cat

delu.cat(data: List[T], /, dim: int = 0) → T

Concatenate a sequence of collections of tensors.

delu.cat is a generalized version of torch.cat for concatenating not only tensors, but also (nested) collections of tensors.

Usage

Let’s see how a list of per-batch model outputs can be concatenated into a single output tuple for the whole dataset:

>>> import torch
>>> import delu
>>> from torch.utils.data import DataLoader, TensorDataset
>>> dataset = TensorDataset(torch.randn(320, 24))
>>> batch_size = 32
>>>
>>> # The model returns not only predictions, but also embeddings.
>>> def model(x_batch):
...     # A dummy forward pass.
...     embeddings_batch = torch.randn(batch_size, 16)
...     y_pred_batch = torch.randn(batch_size)
...     return (y_pred_batch, embeddings_batch)
...
>>> y_pred, embeddings = delu.cat(
...     [model(batch) for batch in DataLoader(dataset, batch_size, shuffle=True)]
... )
>>> len(y_pred) == len(dataset)
True
>>> len(embeddings) == len(dataset)
True
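
For the tuple case, the result is what one would get by hand by transposing the list of per-batch tuples with zip and concatenating each group with torch.cat. The torch-only sketch below shows that manual equivalent; it illustrates the semantics and is not delu's implementation. The dataset size of 100 is a hypothetical choice that makes the last batch incomplete, so the dummy model reads the batch size from its input instead of from a global variable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical size: 100 is not a multiple of 32, so the last batch has 4 rows.
dataset = TensorDataset(torch.randn(100, 24))


def model(x_batch):
    # Dummy forward pass that uses the actual batch size of its input.
    n = len(x_batch)
    return torch.randn(n), torch.randn(n, 16)


# Each batch from a single-tensor TensorDataset is a one-element list.
outputs = [model(x_batch) for (x_batch,) in DataLoader(dataset, batch_size=32)]
# Transpose [(y1, e1), (y2, e2), ...] into [(y1, y2, ...), (e1, e2, ...)]
# and concatenate each group along dim 0.
y_pred, embeddings = (torch.cat(group) for group in zip(*outputs))
print(y_pred.shape, embeddings.shape)  # torch.Size([100]) torch.Size([100, 16])
```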

The same works for dictionaries:

>>> def model(x_batch):
...     return {
...         'y_pred': torch.randn(batch_size),
...         'embeddings': torch.randn(batch_size, 16)
...     }
...
>>> outputs = delu.cat(
...     [model(batch) for batch in DataLoader(dataset, batch_size, shuffle=True)]
... )
>>> len(outputs['y_pred']) == len(dataset)
True
>>> len(outputs['embeddings']) == len(dataset)
True
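
For dictionaries, the manual equivalent is a per-key concatenation. A torch-only sketch with hypothetical batch sizes of 32, 32 and 7, taking the keys from the first batch's output (this assumes every output has the same keys):

```python
import torch

batch_sizes = [32, 32, 7]  # hypothetical sizes; the last batch is incomplete
outputs = [
    {"y_pred": torch.randn(n), "embeddings": torch.randn(n, 16)}
    for n in batch_sizes
]
# Concatenate per key, iterating over the keys of the first output.
merged = {key: torch.cat([out[key] for out in outputs]) for key in outputs[0]}
print(merged["y_pred"].shape, merged["embeddings"].shape)
# torch.Size([71]) torch.Size([71, 16])
```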

The same works for sequences of named tuples, dataclasses, and plain tensors, as well as for nested combinations of these collection types.

Additional technical examples are provided below.

The common setup:

>>> # First batch.
>>> x1 = torch.randn(64, 10)
>>> y1 = torch.randn(64)
>>> # Second batch.
>>> x2 = torch.randn(64, 10)
>>> y2 = torch.randn(64)
>>> # The last (incomplete) batch.
>>> x3 = torch.randn(7, 10)
>>> y3 = torch.randn(7)
>>> total_size = len(x1) + len(x2) + len(x3)

delu.cat can be applied to tuples:

>>> batches = [(x1, y1), (x2, y2), (x3, y3)]
>>> X, Y = delu.cat(batches)
>>> len(X) == total_size and len(Y) == total_size
True

delu.cat can be applied to dictionaries:

>>> batches = [
...     {'x': x1, 'y': y1},
...     {'x': x2, 'y': y2},
...     {'x': x3, 'y': y3},
... ]
>>> result = delu.cat(batches)
>>> isinstance(result, dict)
True
>>> len(result['x']) == total_size and len(result['y']) == total_size
True

delu.cat can be applied to named tuples:

>>> from typing import NamedTuple
>>> class Data(NamedTuple):
...     x: torch.Tensor
...     y: torch.Tensor
...
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> isinstance(result, Data)
True
>>> len(result.x) == total_size and len(result.y) == total_size
True
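
To return a value of the same named tuple class, the per-field results can be rebuilt with _make, a classmethod that every named tuple class provides. This is a torch-only sketch of that idea, not necessarily how delu.cat preserves the type:

```python
from typing import NamedTuple

import torch


class Data(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor


batches = [
    Data(torch.randn(64, 10), torch.randn(64)),
    Data(torch.randn(7, 10), torch.randn(7)),
]
# zip(*batches) groups the fields: (all x tensors), (all y tensors).
# type(batches[0])._make rebuilds an instance of the same named tuple class.
result = type(batches[0])._make(torch.cat(group) for group in zip(*batches))
print(type(result).__name__, result.x.shape, result.y.shape)
# Data torch.Size([71, 10]) torch.Size([71])
```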

delu.cat can be applied to dataclasses:

>>> from dataclasses import dataclass
>>> @dataclass
... class Data:
...     x: torch.Tensor
...     y: torch.Tensor
...
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> isinstance(result, Data)
True
>>> len(result.x) == total_size and len(result.y) == total_size
True
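
For dataclasses, dataclasses.fields lists the fields, so the per-field results can be passed back to the class constructor. A torch-only sketch, assuming every field is an __init__ argument (no field(init=False)); it illustrates the idea and is not delu's source:

```python
import dataclasses
from dataclasses import dataclass

import torch


@dataclass
class Data:
    x: torch.Tensor
    y: torch.Tensor


batches = [
    Data(torch.randn(64, 10), torch.randn(64)),
    Data(torch.randn(7, 10), torch.randn(7)),
]
first = batches[0]
# Concatenate each field across batches, then rebuild an instance of the same class.
# Assumption: every field is accepted by __init__.
result = type(first)(
    **{
        f.name: torch.cat([getattr(b, f.name) for b in batches])
        for f in dataclasses.fields(first)
    }
)
print(type(result).__name__, result.x.shape, result.y.shape)
# Data torch.Size([71, 10]) torch.Size([71])
```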

delu.cat can be applied to nested collections:

>>> batches = [
...     (x1, {'a': {'b': y1}}),
...     (x2, {'a': {'b': y2}}),
...     (x3, {'a': {'b': y3}}),
... ]
>>> X, Y_nested = delu.cat(batches)
>>> len(X) == total_size and len(Y_nested['a']['b']) == total_size
True
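
Nested inputs suggest a recursive strategy: descend into each collection until reaching tensors, concatenate those with torch.cat, and rebuild the containers on the way back. The function generic_cat below is a hypothetical, simplified sketch of that idea, not delu's actual implementation; it assumes every item has exactly the same structure and performs no validation.

```python
import dataclasses
from typing import Any, List, NamedTuple

import torch


def generic_cat(items: List[Any], dim: int = 0) -> Any:
    """Hypothetical recursive concatenation (a sketch, not delu's source)."""
    first = items[0]
    if isinstance(first, torch.Tensor):
        # Leaves: delegate to torch.cat.
        return torch.cat(items, dim=dim)
    if isinstance(first, tuple):
        # zip(*items) groups the i-th element of every item together.
        parts = [generic_cat(list(group), dim) for group in zip(*items)]
        # Named tuples are rebuilt with their own class, plain tuples as tuples.
        return type(first)(*parts) if hasattr(first, "_fields") else tuple(parts)
    if isinstance(first, dict):
        # Keys (and their order) are taken from the first item.
        return {key: generic_cat([item[key] for item in items], dim) for key in first}
    if dataclasses.is_dataclass(first) and not isinstance(first, type):
        # Assumes every field is an __init__ argument.
        values = {
            f.name: generic_cat([getattr(item, f.name) for item in items], dim)
            for f in dataclasses.fields(first)
        }
        return type(first)(**values)
    # Lists and all other types are rejected in this sketch.
    raise TypeError(f"unsupported collection type: {type(first).__name__}")


class Pair(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor


batches = [
    (torch.randn(64, 10), {"a": Pair(torch.randn(64), torch.randn(64, 2))}),
    (torch.randn(7, 10), {"a": Pair(torch.randn(7), torch.randn(7, 2))}),
]
X, nested = generic_cat(batches)
print(X.shape)  # torch.Size([71, 10])
print(type(nested["a"]).__name__, nested["a"].y.shape)  # Pair torch.Size([71, 2])
```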

Lists are not supported as the collections being concatenated (only data itself, the outer container, is a list):

>>> # This does not work. Instead, use tuples.
>>> # batches = [[x1, y1], [x2, y2], [x3, y3]]
>>> # delu.cat(batches)  # Error
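
If existing code produces lists, one workaround is to convert them to tuples before calling delu.cat, while keeping the outer container a list as the data parameter requires. lists_to_tuples below is a hypothetical helper, not part of delu:

```python
import torch


def lists_to_tuples(obj):
    """Hypothetical helper: recursively replace lists with tuples."""
    if isinstance(obj, (list, tuple)):
        converted = [lists_to_tuples(v) for v in obj]
        # Named tuples keep their class; lists and plain tuples become tuples.
        if hasattr(obj, "_fields"):
            return type(obj)._make(converted)
        return tuple(converted)
    if isinstance(obj, dict):
        return {k: lists_to_tuples(v) for k, v in obj.items()}
    return obj  # tensors (and any other values) are left unchanged


batches = [[torch.randn(64, 10), torch.randn(64)], [torch.randn(7, 10), torch.randn(7)]]
# Convert each item, but keep the outer container a list.
batches = [lists_to_tuples(b) for b in batches]
print(type(batches).__name__, type(batches[0]).__name__)  # list tuple
# The converted batches have the same shape as the tuple example and
# could then be passed to delu.cat(batches).
```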

Parameters:
  • data – the list of collections of tensors. All items of the list must have the same type, structure, and layout; only the size along dim can vary (as in torch.cat). All “leaf” values must be of type torch.Tensor.

  • dim – the dimension along which the tensors are concatenated.

Returns:

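The concatenated items of the list: a collection with the same structure as each item of data (for example, a dict for dicts, or the same named tuple or dataclass type), in which every leaf tensor is the concatenation of the corresponding leaves along dim.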
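
Per the data parameter description, only the size along dim may differ between items, exactly as for torch.cat. The torch-only sketch below shows the per-leaf equivalent of concatenating a list of tuples along dim=1, which would correspond to delu.cat(chunks, dim=1); the chunks data is hypothetical.

```python
import torch

# Hypothetical setup: the same 64 rows, with the columns split into chunks.
chunks = [
    (torch.randn(64, 10), torch.randn(64, 1)),
    (torch.randn(64, 20), torch.randn(64, 2)),
]
# Per-leaf concatenation along dim=1: sizes along dim 1 may differ,
# while all other dimensions (here, 64 rows) must match.
features, targets = (torch.cat(group, dim=1) for group in zip(*chunks))
print(features.shape, targets.shape)  # torch.Size([64, 30]) torch.Size([64, 3])
```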