delu.to

delu.to(data, *args, **kwargs)

Like torch.Tensor.to, but for collections of tensors.

The function changes the device, the data type, or both for every tensor in a (possibly nested) collection of tensors, just as torch.Tensor.to does for a single tensor.

Note

Technically, the function recursively traverses the input and applies torch.Tensor.to to every tensor it encounters. Non-tensor values (for example, numbers or strings) are not allowed.
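To make the traversal concrete, here is a minimal sketch of such a recursive function. It is an illustrative reimplementation, not delu's actual source: the names `to_sketch`, `Pair` and `Batch` are hypothetical, raising `ValueError` for unsupported values is an assumption about how "not allowed" could be enforced, and rebuilding a dataclass through its `__init__` assumes every field is an init field.

```python
import dataclasses
from typing import Any, NamedTuple

import torch


def _is_namedtuple(x: Any) -> bool:
    # Named tuples are tuples that also expose a `_fields` attribute.
    return isinstance(x, tuple) and hasattr(x, "_fields")


def to_sketch(data: Any, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical sketch: apply torch.Tensor.to to every tensor in `data`."""
    if isinstance(data, torch.Tensor):
        # Leaf case: forward all arguments to torch.Tensor.to unchanged.
        return data.to(*args, **kwargs)
    if _is_namedtuple(data):
        # Rebuild the named tuple via `_make` so that its type is preserved.
        return type(data)._make(to_sketch(x, *args, **kwargs) for x in data)
    if isinstance(data, (tuple, list)):
        # Plain tuples and lists are rebuilt with their own constructor.
        return type(data)(to_sketch(x, *args, **kwargs) for x in data)
    if isinstance(data, dict):
        # Keys are kept as-is; only the values are converted.
        return type(data)(
            (k, to_sketch(v, *args, **kwargs)) for k, v in data.items()
        )
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Assumption: the dataclass can be re-created from its fields via __init__.
        converted = {
            f.name: to_sketch(getattr(data, f.name), *args, **kwargs)
            for f in dataclasses.fields(data)
        }
        return type(data)(**converted)
    # Any other value (int, str, None, ...) is rejected, mirroring the Note.
    raise ValueError(f"unsupported type: {type(data)}")


# Usage: a list holding a named tuple and a dict that contains a dataclass.
class Pair(NamedTuple):
    a: torch.Tensor
    b: torch.Tensor


@dataclasses.dataclass
class Batch:
    x: torch.Tensor
    y: torch.Tensor


data = [
    Pair(torch.tensor(1), torch.tensor(2)),
    {"k": Batch(torch.tensor([0.0]), torch.tensor([1]))},
]
out = to_sketch(data, dtype=torch.float64)
print(type(out[0]).__name__, out[0].a.dtype)  # Pair torch.float64
```

The sketch rebuilds each container with its own type, which is what allows the return type to match the input type T.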

Parameters:
  • data (T) – the tensor or the (nested) collection of tensors. Allowed collections include tuples (including named tuples), lists, dictionaries and dataclasses. For dataclasses, all their fields must be tensors.

  • args – positional arguments forwarded to torch.Tensor.to

  • kwargs – keyword arguments forwarded to torch.Tensor.to
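Because args and kwargs are passed through to torch.Tensor.to, the call forms that torch.Tensor.to accepts for a single tensor should also be usable here. The snippet below shows those forms on one tensor only; the names `x` and `ref` are illustrative, and it runs on CPU so that no GPU is assumed.

```python
import torch

x = torch.tensor([1, 2, 3])  # int64 tensor on CPU

# Positional dtype, as in delu.to(data, torch.half).
a = x.to(torch.half)

# Positional device, as in delu.to(data, device).
b = x.to(torch.device("cpu"))

# Keyword device and dtype, as in delu.to(data, dtype=..., device=...).
c = x.to(device="cpu", dtype=torch.float32)

# Another tensor: the result takes on its dtype and device.
ref = torch.zeros(1, dtype=torch.float64)
d = x.to(ref)

# Extra options such as non_blocking and copy are ordinary keyword arguments.
e = x.to(torch.float32, non_blocking=False, copy=True)

print(a.dtype, b.device, c.dtype, d.dtype, e.dtype)
```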

Returns:

the transformed data, with the same structure as the input.

Return type:

T

Examples

import torch

import delu

# In practice, this can be 'cuda' or any other device.
device = torch.device('cpu')
tensor = torch.tensor

x = tensor(0)
x = delu.to(x, dtype=torch.float, device=device)

batch = {
    'x': tensor([0.0, 1.0]),
    'y': tensor([0, 1]),
}
batch = delu.to(batch, device)

x = [
    tensor(0.0),
    {'a': tensor(1.0), 'b': tensor(2.0)},
    (tensor(3.0), tensor(4.0))
]
x = delu.to(x, torch.half)