to

delu.to(obj: T, /, *args, **kwargs) → T

Change devices and data types of tensors and modules in an arbitrary Python object.

The two primary use cases for this function are changing the device and data types of tensors and modules that are a part of:

  • a complex Python object (e.g. a training state, checkpoint, etc.)

  • an object of an unknown type (when implementing generic pipelines)

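+------------------------------------+----------------------------------------------------+
| type(obj)                          | What delu.to(obj, *args, **kwargs) returns         |
+====================================+====================================================+
| bool, int, float, str, bytes       | obj is returned as-is.                             |
+------------------------------------+----------------------------------------------------+
| tuple, list, set, frozenset, and   | A new collection of the same type where delu.to    |
| all other collections falling      | is recursively applied to all items.               |
| under typing.Sequence,             |                                                    |
| typing.Set, typing.FrozenSet       |                                                    |
+------------------------------------+----------------------------------------------------+
| dict or any other typing.Mapping   | A new collection of the same type where delu.to    |
|                                    | is recursively applied to all keys and values.     |
+------------------------------------+----------------------------------------------------+
| torch.Tensor                       | obj.to(*args, **kwargs)                            |
+------------------------------------+----------------------------------------------------+
| torch.nn.Module                    | (obj is modified in-place) obj.to(*args, **kwargs) |
+------------------------------------+----------------------------------------------------+
| Any other type (custom classes     | (obj is modified in-place) obj itself, with all    |
| are allowed)                       | attributes recursively updated with delu.to.       |
+------------------------------------+----------------------------------------------------+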
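
The table can be read as a recursive dispatch on the type of obj. The following is a hypothetical pure-Python sketch of that dispatch, not delu's source code: `to_sketch`, `Point`, and `State` are illustrative names, and any object with a callable `.to` attribute is assumed to stand in for torch.Tensor or torch.nn.Module so that the sketch runs without PyTorch. Rebuilding namedtuples positionally is likewise an assumption about how "a new collection of the same type" might be created.

```python
from collections.abc import Mapping, Sequence, Set
from typing import Any, NamedTuple

# Hypothetical sketch of the dispatch table above, NOT delu's actual source.
# Assumption: any object with a callable `.to` attribute plays the role of
# torch.Tensor / torch.nn.Module, so the sketch runs without PyTorch.


def to_sketch(obj: Any, *args: Any, **kwargs: Any) -> Any:
    if isinstance(obj, (bool, int, float, str, bytes)):
        # Immutable scalars: the very same object is returned.
        return obj
    if callable(getattr(obj, "to", None)):
        # Tensor-like or module-like leaf: delegate to its own `.to`.
        return obj.to(*args, **kwargs)
    if isinstance(obj, Mapping):
        # New mapping of the same type; keys AND values are converted.
        return type(obj)(
            (to_sketch(k, *args, **kwargs), to_sketch(v, *args, **kwargs))
            for k, v in obj.items()
        )
    if isinstance(obj, (Sequence, Set)):
        # New collection of the same type with every item converted.
        items = [to_sketch(item, *args, **kwargs) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*items)  # namedtuples take positional fields
        return type(obj)(items)
    # Any other object: rewrite its attributes in place and return it.
    for name, value in list(vars(obj).items()):
        setattr(obj, name, to_sketch(value, *args, **kwargs))
    return obj


class Point(NamedTuple):
    x: int
    y: int


class State:
    def __init__(self) -> None:
        self.step = 7
        self.history = [Point(1, 2), frozenset({3})]


state = State()
old_history = state.history
result = to_sketch({"state": state, "tags": ("a", b"b")}, "cpu")
print(result["state"] is state)         # True: custom objects are updated in place
print(state.history is old_history)     # False: the nested list was rebuilt
print(type(state.history[0]).__name__)  # Point
```

The order of the checks matters: str and bytes are themselves sequences, so they have to be caught as scalars before the collection branch, otherwise they would be split into characters or integers.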

Usage

Trivial immutable objects are returned as-is:

>>> import torch
>>> import torch.nn as nn
>>> import delu
>>>
>>> kwargs = {'device': 'cpu', 'dtype': torch.half}
>>>
>>> x = 0
>>> x_new = delu.to(x, **kwargs)
>>> x_new is x
True

If a collection is passed, a new one is created. The behavior for the nested values depends on their types:

>>> x = {
...     # The "unchanged" tensor will not be changed,
...     # because it already has the requested dtype and device.
...     'unchanged': torch.tensor(0, **kwargs),
...     'changed': torch.tensor(1),
...     'module': nn.Linear(2, 3),
...     'other': [False, 1, 2.0, 'hello', b'world']
... }
>>> x_new = delu.to(x, **kwargs)
>>> # The collection itself is a new object:
>>> x_new is x
False
>>> # Tensors change according to `torch.Tensor.to`:
>>> x_new['unchanged'] is x['unchanged']
True
>>> x_new['changed'] is x['changed']
False
>>> # Modules are modified in-place:
>>> x_new['module'] is x['module']
True
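
The identity checks in this example follow from two contracts: torch.Tensor.to returns self when the dtype and device already match and a converted copy otherwise, while torch.nn.Module.to modifies the module in place and returns it. A toy pure-Python model of those two contracts is sketched here; `ToyTensor` and `ToyModule` are hypothetical classes, and dtypes are represented as plain strings instead of torch dtypes.

```python
from typing import Optional

# Toy model (an assumption, pure Python) of the two `.to` contracts:
# - torch.Tensor.to returns `self` if dtype/device already match, else a copy;
# - torch.nn.Module.to converts the module in place and returns `self`.


class ToyTensor:
    def __init__(self, dtype: str, device: str = "cpu") -> None:
        self.dtype = dtype
        self.device = device

    def to(self, device: Optional[str] = None, dtype: Optional[str] = None) -> "ToyTensor":
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        if (device, dtype) == (self.device, self.dtype):
            return self  # nothing to change: the same object is returned
        return ToyTensor(dtype, device)  # otherwise a converted copy


class ToyModule:
    def __init__(self) -> None:
        self.weight = ToyTensor("float32")

    def to(self, device: Optional[str] = None, dtype: Optional[str] = None) -> "ToyModule":
        # Mutate the module in place, then return the same module object.
        self.weight = self.weight.to(device=device, dtype=dtype)
        return self


kwargs = {"device": "cpu", "dtype": "float16"}
unchanged = ToyTensor("float16")
changed = ToyTensor("float32")
module = ToyModule()

print(unchanged.to(**kwargs) is unchanged)  # True
print(changed.to(**kwargs) is changed)      # False
print(module.to(**kwargs) is module)        # True
```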

Complex user-defined types are also allowed:

>>> from dataclasses import dataclass
>>> from typing import List
>>>
>>> class A:
...     def __init__(self):
...         self.a = torch.randn(5)
...         self.b = ('Hello, world!', torch.randn(10))
...         self.c = nn.Linear(4, 7)
...
>>> @dataclass
... class B:
...     d: List[A]
...
>>> x = B([A(), A()])
>>> x_new = delu.to(x, **kwargs)
>>> # The object is the same in terms of Python `id`;
>>> # however, some of its nested attributes have changed.
>>> x_new is x
True
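
Parameters:

  • obj: the object.

  • args: the positional arguments for torch.Tensor.to/torch.nn.Module.to.

  • kwargs: the keyword arguments for torch.Tensor.to/torch.nn.Module.to.

Returns:

the transformed object.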