cat

delu.cat(data: List[T], /, dim: int = 0) → T

Concatenate a sequence of collections of tensors (like torch.cat, but for collections of tensors).

While torch.cat concatenates a sequence of tensors, delu.cat concatenates a list of collections of tensors. Supported collections are tuples, named tuples, dictionaries, dataclasses and nested combinations thereof; nested lists are not allowed.

Usage

Common setup:

>>> import torch
>>> import delu
>>> # First batch.
>>> x1 = torch.randn(64, 10)
>>> y1 = torch.randn(64)
>>> # Second batch.
>>> x2 = torch.randn(64, 10)
>>> y2 = torch.randn(64)
>>> # The last (incomplete) batch.
>>> x3 = torch.randn(7, 10)
>>> y3 = torch.randn(7)
>>> # Total size.
>>> len(x1) + len(x2) + len(x3)
135

delu.cat can be applied to tuples:

>>> batches = [(x1, y1), (x2, y2), (x3, y3)]
>>> X, Y = delu.cat(batches)
>>> print(len(X), len(Y))
135 135

delu.cat can be applied to dictionaries:

>>> batches = [
...     {'x': x1, 'y': y1},
...     {'x': x2, 'y': y2},
...     {'x': x3, 'y': y3},
... ]
>>> result = delu.cat(batches)
>>> print(isinstance(result, dict), len(result['x']), len(result['y']))
True 135 135

delu.cat can be applied to named tuples:

>>> from typing import NamedTuple
>>> class Data(NamedTuple):
...     x: torch.Tensor
...     y: torch.Tensor
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> print(isinstance(result, Data), len(result.x), len(result.y))
True 135 135

delu.cat can be applied to dataclasses:

>>> from dataclasses import dataclass
>>> @dataclass
... class Data:
...     x: torch.Tensor
...     y: torch.Tensor
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> print(isinstance(result, Data), len(result.x), len(result.y))
True 135 135

delu.cat can be applied to nested collections:

>>> batches = [
...     (x1, {'a': {'b': y1}}),
...     (x2, {'a': {'b': y2}}),
...     (x3, {'a': {'b': y3}}),
... ]
>>> X, Y_nested = delu.cat(batches)
>>> print(len(X), len(Y_nested['a']['b']))
135 135

delu.cat cannot be applied to lists:

>>> # This does not work. Instead, use tuples.
>>> # batches = [[x1, y1], [x2, y2], [x3, y3]]
>>> # delu.cat(batches)  # Error

Parameters:
  • data – the list of collections of tensors. All items of the list must have the same type, structure and layout; only the dim dimension may vary (same as for torch.cat). All “leaf” values must be of the type torch.Tensor.

  • dim – the dimension along which the tensors are concatenated.

Returns:

A single collection of the same type and structure as the items of data, with the tensors concatenated along dim.
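To make the described behavior concrete, here is a rough sketch of how such a recursive concatenation could be written with torch.cat alone. The function cat_sketch is a hypothetical name introduced for illustration; it is not the library's implementation and may differ from delu.cat in details such as error messages and edge cases.

```python
import dataclasses
from typing import Any, List

import torch


def cat_sketch(data: List[Any], dim: int = 0) -> Any:
    """Illustrative stand-in for delu.cat (hypothetical, not the library code)."""
    if not isinstance(data, list) or not data:
        raise ValueError("data must be a non-empty list")
    first = data[0]
    if isinstance(first, torch.Tensor):
        # Leaf level: defer to torch.cat.
        return torch.cat(data, dim=dim)
    if isinstance(first, tuple):
        parts = [cat_sketch([item[i] for item in data], dim) for i in range(len(first))]
        # Named tuples take their fields positionally; plain tuples take one iterable.
        return type(first)(*parts) if hasattr(first, "_fields") else type(first)(parts)
    if isinstance(first, dict):
        return {key: cat_sketch([item[key] for item in data], dim) for key in first}
    if dataclasses.is_dataclass(first) and not isinstance(first, type):
        return type(first)(**{
            field.name: cat_sketch([getattr(item, field.name) for item in data], dim)
            for field in dataclasses.fields(first)
        })
    # Anything else, including nested lists, is rejected.
    raise ValueError(f"unsupported collection type: {type(first).__name__}")


batches = [
    (torch.zeros(64, 10), {"y": torch.zeros(64)}),
    (torch.zeros(7, 10), {"y": torch.zeros(7)}),
]
X, nested = cat_sketch(batches)
print(X.shape, nested["y"].shape)
```

The structure of the first item drives the recursion, which is why all items of the list must share the same type, structure and layout.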