to
- delu.to(obj: T, /, *args, **kwargs) -> T
Change devices and data types of tensors and modules in an arbitrary Python object (like torch.Tensor.to / torch.nn.Module.to, but for any Python object).

The two primary use cases for this function are changing the device and data types of tensors and modules that are part of:
- a complex Python object (e.g. a training state, checkpoint, etc.);
- an object of an unknown type (when implementing generic pipelines).
Usage
>>> import torch
>>> import torch.nn as nn
>>> from dataclasses import dataclass
>>> from typing import List
>>> import delu
>>>
>>> class UserClass:
...     def __init__(self):
...         self.a = torch.randn(5)
...         self.b = ('Hello, world!', torch.randn(10))
...         self.c = nn.Linear(4, 7)
...
>>> @dataclass
... class UserDataclass:
...     d: List[UserClass]
...
>>> data = (
...     torch.rand(3),
...     [{(False, 1): torch.tensor(1.0)}, 2.0],
...     UserDataclass([UserClass(), UserClass()]),
... )
>>> data = delu.to(data, device='cpu', dtype=torch.float16)
Note
Technically, the function traverses the input object as follows:
- for tensors/modules, torch.Tensor.to / torch.nn.Module.to is applied with the provided *args and **kwargs; in particular, tensors are replaced with the tensors returned by torch.Tensor.to (new objects in terms of Python id whenever a conversion actually happens), while modules are modified in place;
- for tuples, named tuples, lists, other sequences (see typing.Sequence), dictionaries and other mappings (see typing.Mapping), a new collection of the same type is returned, where delu.to is recursively applied to all values of the original collection;
- in all other cases, the original object (in terms of Python id) is returned. If the object has attributes (defined in __dict__ or __slots__), then delu.to is recursively applied to all the attributes.
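The traversal rules above can be illustrated with a simplified, pure-Python sketch. It is not DeLU's implementation: `FakeTensor`, `FakeModule` and `to_sketch` are hypothetical names, the two fake classes merely stand in for torch.Tensor (whose `.to` returns a new object here) and torch.nn.Module (whose `.to` mutates and returns `self`), and treating `str`/`bytes` as atomic values is an assumption made so that the recursion terminates.

```python
from collections.abc import Mapping, Sequence
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: .to() returns a new object."""

    def __init__(self, dtype: str = "float32") -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(dtype)


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module: .to() mutates and returns self."""

    def __init__(self) -> None:
        self.dtype = "float32"

    def to(self, dtype: str) -> "FakeModule":
        self.dtype = dtype
        return self


def to_sketch(obj: Any, *args: Any, **kwargs: Any) -> Any:
    """Simplified illustration of the traversal rules (NOT delu's actual code)."""

    def recurse(x: Any) -> Any:
        return to_sketch(x, *args, **kwargs)

    # Rule 1: tensors/modules -> call their own .to() with the forwarded arguments.
    if isinstance(obj, (FakeTensor, FakeModule)):
        return obj.to(*args, **kwargs)
    # Assumption: strings/bytes are atomic (each character is itself a str,
    # so recursing into them would never terminate).
    if isinstance(obj, (str, bytes)):
        return obj
    # Rule 2: collections -> a new collection of the same type.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # Named tuples take their fields as positional arguments.
        return type(obj)(*(recurse(x) for x in obj))
    if isinstance(obj, Sequence):
        return type(obj)(recurse(x) for x in obj)
    if isinstance(obj, Mapping):
        return type(obj)({k: recurse(v) for k, v in obj.items()})
    # Rule 3: anything else is returned as-is (same id), but its attributes
    # from __dict__ / __slots__ are converted recursively and set in place.
    names = list(getattr(obj, "__dict__", {}))
    slots = getattr(type(obj), "__slots__", ())
    names += [slots] if isinstance(slots, str) else list(slots)
    for name in names:
        if hasattr(obj, name):  # a slot may be declared but never assigned
            setattr(obj, name, recurse(getattr(obj, name)))
    return obj
```

For example, `to_sketch(([t, m], {"k": t}), "float16")` with `t = FakeTensor()` and `m = FakeModule()` returns a new tuple whose tensor entries are new `FakeTensor` objects, while the module entry is `m` itself, now converted in place.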
- Parameters:
obj – the input object.
args – the positional arguments for torch.Tensor.to / torch.nn.Module.to.
kwargs – the keyword arguments for torch.Tensor.to / torch.nn.Module.to.
- Returns:
the transformed object.
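Because args and kwargs are forwarded unchanged, one call applies the same conversion to every tensor and module inside obj. For instance, `delu.to(state, 'cuda', non_blocking=True)` would pass the positional device and the `non_blocking` keyword (both accepted by torch.Tensor.to and torch.nn.Module.to) to each of them. The pure-Python sketch below illustrates only that forwarding; `RecordingLeaf` and `forward_to` are hypothetical names, and the walker deliberately handles just plain dicts, lists and tuples.

```python
from typing import Any


class RecordingLeaf:
    """Hypothetical tensor-like leaf that records every .to() call it receives."""

    def __init__(self) -> None:
        self.calls: list = []

    def to(self, *args: Any, **kwargs: Any) -> "RecordingLeaf":
        self.calls.append((args, kwargs))
        return self


def forward_to(obj: Any, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical minimal walker: passes the same args/kwargs to every leaf."""
    if isinstance(obj, RecordingLeaf):
        return obj.to(*args, **kwargs)
    if isinstance(obj, dict):
        return {k: forward_to(v, *args, **kwargs) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(forward_to(v, *args, **kwargs) for v in obj)
    # Everything else (numbers, strings, ...) is left untouched.
    return obj


w, b = RecordingLeaf(), RecordingLeaf()
state = {"model": [w, b], "step": 10}
forward_to(state, "cuda", non_blocking=True)
print(w.calls)  # [(('cuda',), {'non_blocking': True})]
```

Each leaf sees exactly the positional and keyword arguments of the outer call, which is why the parameters of delu.to are documented simply as the arguments of torch.Tensor.to / torch.nn.Module.to.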