delu.cat

delu.cat(data, dim=0)

Like torch.cat, but for collections of tensors.

A typical use case is concatenating a model/function’s outputs for batches into a single output for the whole dataset:

class Model(nn.Module):
    def forward(self, batch: Tensor) -> tuple[Tensor, Tensor]:
        ...
        return y_pred, embeddings

# Concatenate a sequence of tuples (batch_y_pred, batch_embeddings) into a single tuple.
y_pred, embeddings = delu.cat([model(batch) for batch in dataloader])

The function operates recursively, so nested structures are supported as well (e.g. tuple[Tensor, dict[str, tuple[Tensor, Tensor]]]). See other examples below.

Note

Roughly speaking, the function “transposes” the list of collections into a collection of lists and applies torch.cat to each of those lists.
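The "transpose, then concatenate" idea from the note can be sketched with plain torch. The helper name cat_sketch is hypothetical and this is not delu's actual implementation. It is a minimal illustration that handles only plain tuples and dictionaries, not named tuples or dataclasses.

```python
import torch
from torch import Tensor


def cat_sketch(data: list, dim: int = 0):
    """Hypothetical sketch of the idea; not delu's actual code."""
    if not data:
        raise ValueError("data must be non-empty")
    first = data[0]
    if isinstance(first, Tensor):
        # Leaf case: a plain list of tensors, so torch.cat applies directly.
        return torch.cat(data, dim=dim)
    if isinstance(first, tuple):
        # "Transpose" a list of tuples into a tuple of lists, then recurse.
        return tuple(cat_sketch(list(column), dim) for column in zip(*data))
    if isinstance(first, dict):
        # Same idea for dictionaries: collect the values stored under each key.
        return {key: cat_sketch([item[key] for item in data], dim) for key in first}
    raise ValueError(f"unsupported collection: {type(first)}")


batches = [
    (torch.tensor([0.0, 1.0]), {"y": torch.tensor([[0], [1]])}),
    (torch.tensor([2.0, 3.0]), {"y": torch.tensor([[2], [3]])}),
]
x, rest = cat_sketch(batches)
```

Here x holds the four concatenated values and rest["y"] holds the four concatenated rows, so the nested structure of a single batch is preserved in the result.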

Parameters:
  • data (List[T]) – a list of (nested) tuples, named tuples, dictionaries or dataclasses of tensors. All items of the list must be of the same type and have the same structure (tuples must be of the same length, dictionaries must have the same keys, dataclasses must have the same fields, etc.). All “leaf” values must be of the type torch.Tensor.

  • dim (int) – the dimension over which the tensors are concatenated.
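To illustrate what dim controls, the sketch below uses plain torch.cat on a single pair of tensors. It assumes that delu.cat forwards dim unchanged to torch.cat for every leaf tensor, which is what the description above suggests but is not shown here.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# dim=0 appends rows; dim=1 appends columns.
rows = torch.cat([a, b], dim=0)
cols = torch.cat([a, b], dim=1)

print(rows.shape)  # torch.Size([4, 3])
print(cols.shape)  # torch.Size([2, 6])
```

All tensors in one list must agree in every dimension except the one being concatenated.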

Returns:

Concatenated items of the list.

Raises:

ValueError – if data is empty or contains unsupported collections.

Return type:

T

Examples

For simplicity, the examples below use small tensors and the default dim=0.

import torch
import delu

tensor = torch.tensor

batches = [
    # (batch_x, batch_y)
    (tensor([0.0, 1.0]), tensor([[0], [1]])),
    (tensor([2.0, 3.0]), tensor([[2], [3]])),
]
# result = (x, y)
result = delu.cat(batches)
assert isinstance(result, tuple) and len(result) == 2
assert torch.equal(result[0], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result[1], tensor([[0], [1], [2], [3]]))

batches = [
    # {'x': batch_x, 'y': batch_y}
    {'x': tensor([0.0, 1.0]), 'y': tensor([[0], [1]])},
    {'x': tensor([2.0, 3.0]), 'y': tensor([[2], [3]])},
]
result = delu.cat(batches)
assert isinstance(result, dict) and set(result) == {'x', 'y'}
assert torch.equal(result['x'], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result['y'], tensor([[0], [1], [2], [3]]))

from dataclasses import dataclass

@dataclass
class Data:
    # all fields must be tensors
    x: torch.Tensor
    y: torch.Tensor

batches = [
    Data(tensor([0.0, 1.0]), tensor([[0], [1]])),
    Data(tensor([2.0, 3.0]), tensor([[2], [3]])),
]
result = delu.cat(batches)
assert isinstance(result, Data)
assert torch.equal(result.x, tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result.y, tensor([[0], [1], [2], [3]]))

batches = [
    {
        'x': tensor([0.0, 1.0]),
        'y': (tensor([[0], [1]]), tensor([[10], [20]]))
    },
    {
        'x': tensor([2.0, 3.0]),
        'y': (tensor([[2], [3]]), tensor([[30], [40]]))
    },
]
result = delu.cat(batches)
assert isinstance(result, dict) and set(result) == {'x', 'y'}
assert torch.equal(result['x'], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result['y'][0], tensor([[0], [1], [2], [3]]))
assert torch.equal(result['y'][1], tensor([[10], [20], [30], [40]]))

# Splitting data into batches and concatenating them back recovers the original data.
x = tensor([0.0, 1.0, 2.0, 3.0, 4.0])
y = tensor([[0], [10], [20], [30], [40]])
batch_size = 2
result = delu.cat(list(delu.iter_batches((x, y), batch_size)))
assert torch.equal(result[0], x)
assert torch.equal(result[1], y)