to

delu.to(obj: T, /, *args, **kwargs) → T

Change the devices and data types of all tensors and modules inside an arbitrary Python object (a generalization of torch.Tensor.to / torch.nn.Module.to to any Python object).

The two primary use cases for this function are changing the devices and data types of tensors and modules that are part of:

  • a complex Python object (e.g. a training state, checkpoint, etc.)

  • an object of an unknown type (when implementing generic pipelines)

Usage

>>> from dataclasses import dataclass
>>> from typing import List
>>>
>>> import delu
>>> import torch
>>> import torch.nn as nn
>>>
>>> class UserClass:
...     def __init__(self):
...         self.a = torch.randn(5)
...         self.b = ('Hello, world!', torch.randn(10))
...         self.c = nn.Linear(4, 7)
...
>>> @dataclass
... class UserDataclass:
...     d: List[UserClass]
...
>>> data = (
...     torch.rand(3),
...     [{(False, 1): torch.tensor(1.0)}, 2.0],
...     UserDataclass([UserClass(), UserClass()]),
... )
>>> # Convert every tensor and module inside `data` to float16 on the CPU.
>>> data = delu.to(data, device='cpu', dtype=torch.float16)

Note

Technically, the function traverses the input data as follows:

  • for tensors and modules, torch.Tensor.to / torch.nn.Module.to is applied with the provided *args and **kwargs. In particular, tensors are replaced with the tensors returned by torch.Tensor.to (new objects in terms of Python id, unless no conversion was needed, in which case torch.Tensor.to returns the original tensor), while modules are modified in place;

  • for tuples, named tuples, lists, other sequences (see typing.Sequence), dictionaries and other mappings (see typing.Mapping), a new collection of the same type is returned, where delu.to is recursively applied to all values of the original collection;

  • in all other cases, the original object itself (the same Python id) is returned. If the object has attributes (defined in __dict__ or __slots__), delu.to is recursively applied to each of them, so the object is updated in place.
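The traversal rules above can be illustrated with a small pure-Python sketch. It is a hypothetical re-implementation for explanation only, not delu's actual source: to_sketch, FakeTensor, FakeModule and State are illustrative names, and the two fake classes stand in for torch.Tensor (whose .to returns a converted tensor) and torch.nn.Module (whose .to converts the module in place) so that the example runs without PyTorch. Treating strings and bytes as leaves is an assumption of the sketch, not something the Note states.

```python
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: .to() returns a new object.

    (Simplification: the real torch.Tensor.to returns self when no
    conversion is needed.)
    """

    def __init__(self, dtype: str = "float32") -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(dtype)


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module: .to() converts in place."""

    def __init__(self, dtype: str = "float32") -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "FakeModule":
        self.dtype = dtype
        return self


def to_sketch(obj: Any, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical re-implementation of the rules in the Note (not delu's code)."""

    def recurse(x: Any) -> Any:
        return to_sketch(x, *args, **kwargs)

    # Rule 1: tensors and modules delegate to their own .to(*args, **kwargs).
    if isinstance(obj, (FakeTensor, FakeModule)):
        return obj.to(*args, **kwargs)
    # Assumption: str/bytes are leaves (iterating a str yields strs forever).
    if isinstance(obj, (str, bytes)):
        return obj
    # Rule 2: sequences and mappings become a NEW collection of the same type.
    if isinstance(obj, Sequence):
        items = [recurse(x) for x in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*items)  # named tuples take positional fields
        return type(obj)(items)
    if isinstance(obj, Mapping):
        return type(obj)((k, recurse(v)) for k, v in obj.items())  # keys untouched
    # Rule 3: anything else is returned as-is, with its attributes converted.
    if hasattr(obj, "__dict__"):
        for name, value in list(vars(obj).items()):
            setattr(obj, name, recurse(value))
    slots = getattr(type(obj), "__slots__", ())
    for name in (slots,) if isinstance(slots, str) else slots:
        if hasattr(obj, name):
            setattr(obj, name, recurse(getattr(obj, name)))
    return obj


# Demo: a training-state-like object holding a tensor, a module and nested data.
@dataclass
class State:
    weights: FakeTensor
    net: FakeModule
    history: list


t, m = FakeTensor(), FakeModule()
history = [FakeTensor(), ("label", FakeTensor())]
state = State(t, m, history)
result = to_sketch(state, "float16")
# `result` is `state` itself; state.weights is a new FakeTensor; state.net is
# still `m` (converted in place); state.history is a new list of converted items.
```

As the comments in the demo indicate, converted tensors and rebuilt collections exist only in the returned object, while modules and other objects are updated in place; this is why the Usage example reassigns the result (data = delu.to(data, ...)).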

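Parameters:

  • obj – the object to transform.

  • args – the positional arguments for torch.Tensor.to / torch.nn.Module.to.

  • kwargs – the keyword arguments for torch.Tensor.to / torch.nn.Module.to.

Returns: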

the transformed object.