to

- delu.to(obj: T, /, *args, **kwargs) -> T
Change devices and data types of tensors and modules in an arbitrary Python object.
The two primary use cases for this function are changing the devices and data types of tensors and modules that are part of:

- a complex Python object (e.g. a training state, a checkpoint, etc.)
- an object of an unknown type (when implementing generic pipelines)
The table below shows how each type of obj is processed:

| type(obj) | What delu.to(obj, *args, **kwargs) returns |
|---|---|
| bool, int, float, str, bytes | obj is returned as-is. |
| tuple, list, set, frozenset, and all other collections falling under typing.Sequence, typing.Set, typing.FrozenSet | A new collection of the same type where delu.to is recursively applied to all items. |
| dict or any other typing.Mapping | A new collection of the same type where delu.to is recursively applied to all keys and values. |
| torch.Tensor | obj.to(*args, **kwargs) |
| torch.nn.Module | (obj is modified in-place) obj.to(*args, **kwargs) |
| Any other type (custom classes are allowed) | (obj is modified in-place) obj itself with all attributes recursively updated with delu.to. |

Usage
Trivial immutable objects are returned as-is:
>>> import torch
>>> import torch.nn as nn
>>> import delu
>>>
>>> kwargs = {'device': 'cpu', 'dtype': torch.half}
>>>
>>> x = 0
>>> x_new = delu.to(x, **kwargs)
>>> x_new is x
True
If a collection is passed, a new one is created. The behavior for the nested values depends on their types:
>>> x = {
...     # The "unchanged" tensor will not be changed,
...     # because it already has the requested dtype and device.
...     'unchanged': torch.tensor(0, **kwargs),
...     'changed': torch.tensor(1),
...     'module': nn.Linear(2, 3),
...     'other': [False, 1, 2.0, 'hello', b'world']
... }
>>> x_new = delu.to(x, **kwargs)
>>> # The collection itself is a new object:
>>> x_new is x
False
>>> # Tensors change according to `torch.Tensor.to`:
>>> x_new['unchanged'] is x['unchanged']
True
>>> x_new['changed'] is x['changed']
False
>>> # Modules are modified in-place:
>>> x_new['module'] is x['module']
True
Complex user-defined types are also allowed:
>>> from dataclasses import dataclass
>>> from typing import List
>>>
>>> class A:
...     def __init__(self):
...         self.a = torch.randn(5)
...         self.b = ('Hello, world!', torch.randn(10))
...         self.c = nn.Linear(4, 7)
...
>>> @dataclass
... class B:
...     d: List[A]
...
>>> x = B([A(), A()])
>>> x_new = delu.to(x, **kwargs)
>>> # The object is the same in terms of Python `id`,
>>> # but some of its nested attributes have changed.
>>> x_new is x
True
>>> x_new.d[0].a.dtype
torch.float16
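The dispatch rules from the table above can be illustrated with a simplified, pure-Python sketch. This is not DeLU's implementation: to_sketch and FakeTensor are hypothetical names, and, to keep the sketch free of PyTorch, it treats any object with a callable .to method as a tensor or module, whereas the table refers to torch.Tensor and torch.nn.Module specifically.

```python
from collections.abc import Mapping, Sequence, Set
from typing import Any


def to_sketch(obj: Any, *args: Any, **kwargs: Any) -> Any:
    """A hypothetical, simplified version of the dispatch table (not DeLU's code)."""
    # Trivial immutable objects are returned as-is. This check comes first
    # because str and bytes are also Sequences.
    if isinstance(obj, (bool, int, float, str, bytes)):
        return obj
    # Assumption: duck typing stands in for
    # isinstance(obj, (torch.Tensor, torch.nn.Module)).
    to_method = getattr(obj, 'to', None)
    if callable(to_method):
        return to_method(*args, **kwargs)
    # Mappings: a new mapping of the same type; keys and values are converted.
    # (Assumes the constructor accepts an iterable of pairs, as dict does.)
    if isinstance(obj, Mapping):
        return type(obj)(
            (to_sketch(k, *args, **kwargs), to_sketch(v, *args, **kwargs))
            for k, v in obj.items()
        )
    # Sequences and sets (tuple, list, set, frozenset, ...): a new collection
    # of the same type. (Assumes the constructor accepts an iterable, so e.g.
    # namedtuples are not handled by this sketch.)
    if isinstance(obj, (Sequence, Set)):
        return type(obj)(to_sketch(item, *args, **kwargs) for item in obj)
    # Any other object: update its attributes in-place and return the object
    # itself. (Assumes the object has a __dict__.)
    for name, value in list(vars(obj).items()):
        setattr(obj, name, to_sketch(value, *args, **kwargs))
    return obj


class FakeTensor:
    """A hypothetical stand-in for torch.Tensor, used only for this demo."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> 'FakeTensor':
        return FakeTensor(device)


x = {'weights': [FakeTensor('cpu'), 1.0], 'name': 'demo'}
x_new = to_sketch(x, 'cuda')
print(x_new['weights'][0].device)  # cuda
print(x_new is x)                  # False
```

The order of the checks matters: str and bytes would otherwise be treated as sequences and rebuilt item by item, and a collection that happens to define .to would be handed to that method instead of being traversed.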
- Parameters:
  - obj – the input object.
  - args – the positional arguments for torch.Tensor.to / torch.nn.Module.to.
  - kwargs – the keyword arguments for torch.Tensor.to / torch.nn.Module.to.
- Returns:
  the transformed object.