to
- delu.to(obj: T, /, *args, **kwargs) -> T
Change devices and data types of tensors and modules in an arbitrary Python object.
This function is like torch.Tensor.to / torch.nn.Module.to, but applicable to any Python object.
Note
Non-trivial (nested) objects such as user-defined classes or PyTorch modules are modified in-place. See the note below for more details.
The two primary use cases for this function are changing the device and data types of tensors and modules that are a part of:
- a complex Python object (e.g. a training state, a checkpoint, etc.);
- an object of an unknown type (e.g. when implementing generic pipelines).
Usage
A technical example:
>>> from dataclasses import dataclass
>>> from typing import List
>>> import torch
>>> import torch.nn as nn
>>> import delu
>>> class UserClass:
...     def __init__(self):
...         self.a = torch.randn(5)
...         self.b = ('Hello, world!', torch.randn(10))
...         self.c = nn.Linear(4, 7)
...
>>> @dataclass
... class UserDataclass:
...     d: List[UserClass]
...
>>> data = (
...     torch.rand(3),
...     [{(False, 1): torch.tensor(1.0)}, 2.0],
...     UserDataclass([UserClass(), UserClass()]),
... )
>>> data = delu.to(data, device='cpu', dtype=torch.float16)
Note
Technically, the function traverses the input object as follows:
- for tensors and modules, torch.Tensor.to / torch.nn.Module.to is applied with the provided *args and **kwargs; in particular, tensors are replaced with the tensors returned by torch.Tensor.to (new objects in terms of Python id, unless no conversion is needed, in which case torch.Tensor.to returns the original tensor), while modules are modified in-place;
- for tuples, named tuples, lists, other sequences (see typing.Sequence), dictionaries and other mappings (see typing.Mapping), a new collection of the same type is returned, where delu.to is recursively applied to all values of the original collection;
- in all other cases, the original object is modified in-place, and the same object (in terms of Python id) is returned. If the object has attributes (defined in __dict__ or __slots__), then delu.to is recursively applied to all of them.
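The three traversal rules can be sketched in plain Python to make the identity semantics concrete. This is a simplified illustration under stated assumptions, not delu's implementation: FakeTensor and FakeModule are hypothetical stand-ins for torch.Tensor and torch.nn.Module (the first returns a new object from .to(), the second mutates itself), to_sketch is a hypothetical name, strings and numbers are assumed to be treated as atomic, and cycle detection, frozen objects and name-mangled private slots are not handled.

```python
from collections.abc import Mapping, Sequence
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: .to() returns a NEW object."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(dtype)


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module: .to() mutates and returns self."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "FakeModule":
        self.dtype = dtype
        return self


# Assumption: these values are returned unchanged. str/bytes are Sequences,
# so they must be excluded before the sequence rule is reached.
ATOMIC = (str, bytes, int, float, complex, bool, type(None))


def to_sketch(obj: Any, *args: Any, **kwargs: Any) -> Any:
    # Rule 1: tensors and modules delegate to their own .to().
    if isinstance(obj, (FakeTensor, FakeModule)):
        return obj.to(*args, **kwargs)
    if isinstance(obj, ATOMIC):
        return obj

    def convert(x: Any) -> Any:
        return to_sketch(x, *args, **kwargs)

    # Rule 2: sequences and mappings are rebuilt as NEW collections of the
    # same type (assumes the type can be constructed from an iterable).
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # named tuple
        return type(obj)(*map(convert, obj))
    if isinstance(obj, Sequence):
        return type(obj)(map(convert, obj))
    if isinstance(obj, Mapping):
        return type(obj)((k, convert(v)) for k, v in obj.items())

    # Rule 3: any other object is updated IN PLACE through its attributes,
    # collected from __dict__ and from __slots__ along the MRO.
    names = list(vars(obj)) if hasattr(obj, "__dict__") else []
    for cls in type(obj).__mro__:
        slots = cls.__dict__.get("__slots__", ())
        if isinstance(slots, str):
            slots = (slots,)
        names += [
            s for s in slots
            if s not in ("__dict__", "__weakref__") and hasattr(obj, s)
        ]
    for name in names:
        setattr(obj, name, convert(getattr(obj, name)))
    return obj


# Usage on a nested structure (all names here are illustrative).
class Holder:
    def __init__(self) -> None:
        self.t = FakeTensor("float32")
        self.m = FakeModule("float32")
        self.pair = ("label", FakeTensor("float32"))


tensor, module, holder = FakeTensor("float32"), FakeModule("float32"), Holder()
data = (tensor, [{"k": module}, 2.0], holder)
result = to_sketch(data, "float16")
```

After the call, result is a new tuple and its first element is a new float16 object while the original tensor stays float32; module and holder, by contrast, are the very same objects as before, converted in place.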
- Parameters:
obj – the input object.
args – the positional arguments for torch.Tensor.to / torch.nn.Module.to.
kwargs – the keyword arguments for torch.Tensor.to / torch.nn.Module.to.
- Returns:
the transformed object.