to
- delu.to(obj: T, /, *args, **kwargs) → T
Change devices and data types of tensors and modules in an arbitrary Python object.
The two primary use cases for this function are changing the devices and data types of tensors and modules that are part of:
- a complex Python object (e.g. a training state, a checkpoint, etc.);
- an object of an unknown type (e.g. when implementing generic pipelines).
| type(obj) | What delu.to(obj, *args, **kwargs) returns |
|---|---|
| bool, int, float, str, bytes | obj is returned as-is. |
| tuple, list, set, frozenset and all other collections falling under typing.Sequence, typing.Set, typing.FrozenSet | A new collection of the same type where delu.to is recursively applied to all items. |
| dict or any other typing.Mapping | A new collection of the same type where delu.to is recursively applied to all keys and values. |
| torch.Tensor | obj.to(*args, **kwargs) |
| torch.nn.Module | obj.to(*args, **kwargs) (obj is modified in-place) |
| Any other type (custom classes are allowed) | obj itself, with all attributes recursively updated with delu.to (obj is modified in-place). |

Usage
Trivial immutable objects are returned as-is:
>>> import torch
>>> import torch.nn as nn
>>> import delu
>>>
>>> kwargs = {'device': 'cpu', 'dtype': torch.half}
>>>
>>> x = 0
>>> x_new = delu.to(x, **kwargs)
>>> x_new is x
True
If a collection is passed, a new one is created. The behavior for the nested values depends on their types:
>>> x = {
...     # The "unchanged" tensor will not be changed,
...     # because it already has the requested dtype and device.
...     'unchanged': torch.tensor(0, **kwargs),
...     'changed': torch.tensor(1),
...     'module': nn.Linear(2, 3),
...     'other': [False, 1, 2.0, 'hello', b'world'],
... }
>>> x_new = delu.to(x, **kwargs)
>>> # The collection itself is a new object:
>>> x_new is x
False
>>> # Tensors change according to `torch.Tensor.to`:
>>> x_new['unchanged'] is x['unchanged']
True
>>> x_new['changed'] is x['changed']
False
>>> # Modules are modified in-place:
>>> x_new['module'] is x['module']
True
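The collection rule from the table can be pictured with a minimal pure-Python sketch. The helper `map_nested` and its `leaf_fn` argument are hypothetical and not part of delu; `leaf_fn` stands in for whatever happens to non-collection values (in delu, tensors and modules get `.to(...)`). The sketch assumes that every collection type can be rebuilt with `type(obj)(iterable)`, which holds for `list`, `tuple`, `set`, `frozenset`, `dict` and `collections.OrderedDict` but not for named tuples, `range` or `collections.defaultdict`, and it does not guard against reference cycles.

```python
from collections.abc import Mapping, Sequence, Set
from typing import Any, Callable


def map_nested(obj: Any, leaf_fn: Callable[[Any], Any]) -> Any:
    """Hypothetical sketch: rebuild collections, pass everything else to leaf_fn."""
    # str and bytes are Sequences, but they are treated as leaves here.
    if isinstance(obj, (str, bytes)):
        return leaf_fn(obj)
    if isinstance(obj, Mapping):
        # A new mapping of the same type; both keys and values are processed.
        return type(obj)(
            (map_nested(key, leaf_fn), map_nested(value, leaf_fn))
            for key, value in obj.items()
        )
    if isinstance(obj, (Sequence, Set)):
        # A new sequence or set of the same type; every item is processed.
        # Assumption: the type's constructor accepts a single iterable.
        return type(obj)(map_nested(item, leaf_fn) for item in obj)
    # Anything else is a leaf.
    return leaf_fn(obj)


# Usage: the outer containers are new objects, while a leaf that leaf_fn
# returns unchanged keeps its identity (like the "unchanged" tensor above).
sentinel = object()
data = {"nums": [1, 2], "pair": (3, "s"), "keep": sentinel}
result = map_nested(data, lambda v: v * 10 if type(v) is int else v)
print(result["nums"], result["pair"])  # [10, 20] (30, 's')
print(result is data, result["keep"] is sentinel)  # False True
```

Rebuilding with `type(obj)(...)` is what makes the returned collection a new object of the same type, while leaves are only as "new" as `leaf_fn` makes them.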
Complex user-defined types are also allowed:
>>> from dataclasses import dataclass
>>> from typing import List
>>>
>>> class A:
...     def __init__(self):
...         self.a = torch.randn(5)
...         self.b = ('Hello, world!', torch.randn(10))
...         self.c = nn.Linear(4, 7)
...
>>> @dataclass
... class B:
...     d: List[A]
...
>>> x = B([A(), A()])
>>> x_new = delu.to(x, **kwargs)
>>> # The object is the same in terms of Python `id`,
>>> # however, some of its nested attributes have changed.
>>> x_new is x
True
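The "any other type" row of the table can likewise be pictured as an in-place, recursive attribute update. `update_in_place` is a hypothetical helper, not delu's implementation: it assumes that custom objects are plain instances keeping their state in an instance `__dict__` (so classes using `__slots__` are not handled), treats every other value as a leaf passed to `leaf_fn`, does not descend into collections, and does not guard against reference cycles.

```python
from typing import Any, Callable


def update_in_place(obj: Any, leaf_fn: Callable[[Any], Any]) -> Any:
    """Hypothetical sketch: mutate an object's attributes and return the object itself."""
    # Values without an instance __dict__ (ints, strings, lists, ...) are leaves.
    # Assumption: objects that do have a __dict__ are plain instances.
    if not hasattr(obj, "__dict__"):
        return leaf_fn(obj)
    # Snapshot the attributes before assigning new values to them.
    for name, value in list(vars(obj).items()):
        setattr(obj, name, update_in_place(value, leaf_fn))
    # The same object is returned, which is why `x_new is x` holds.
    return obj


class Inner:
    def __init__(self, value):
        self.value = value


class Outer:
    def __init__(self):
        self.inner = Inner(1)
        self.count = 2


# Usage: nested objects are updated in place and keep their identity.
obj = Outer()
inner_before = obj.inner
same = update_in_place(obj, lambda v: v + 100 if type(v) is int else v)
print(same is obj, obj.inner is inner_before)  # True True
print(obj.inner.value, obj.count)  # 101 102
```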
- Parameters:
obj – the input object.
args – the positional arguments for torch.Tensor.to/torch.nn.Module.to.
kwargs – the keyword arguments for torch.Tensor.to/torch.nn.Module.to.
- Returns:
the transformed object.