delu.cat

delu.cat(iterable, dim=0)

Like torch.cat, but for collections of tensors.

The function is especially useful for concatenating the outputs of a function or model that returns not a single tensor, but a (named) tuple, dictionary, or dataclass of tensors. For example:

class Model(nn.Module):
    ...
    def forward(...) -> tuple[Tensor, Tensor]:
        ...
        return y_pred, embeddings

model = Model(...)
dataset = Dataset(...)
dataloader = DataLoader(...)

# prediction
model.eval()
with torch.inference_mode():
    # Concatenate a sequence of tuples into a single tuple.
    y_pred, embeddings = delu.cat(model(batch) for batch in dataloader)
assert isinstance(y_pred, torch.Tensor) and len(y_pred) == len(dataset)
assert isinstance(embeddings, torch.Tensor) and len(embeddings) == len(dataset)

See other examples below.

Note

Under the hood, roughly speaking, the function “transposes” the sequence of collections into a collection of sequences and applies torch.cat to each of those sequences.
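The "transpose, then concatenate" idea can be sketched as follows. This is a minimal illustration under stated assumptions, not delu's actual implementation: the name cat_sketch is hypothetical, and the sketch handles only plain tuples and dictionaries (named tuples and dataclasses are omitted).

```python
import torch


def cat_sketch(items, dim=0):
    """Hypothetical illustration of the idea; not delu's implementation."""
    items = list(items)
    if not items:
        raise ValueError("the iterable must not be empty")
    first = items[0]
    if isinstance(first, dict):
        # Collect the tensors stored under each key, then concatenate them.
        return {
            key: torch.cat([item[key] for item in items], dim=dim)
            for key in first
        }
    if isinstance(first, tuple):
        # zip(*items) turns [(x0, y0), (x1, y1)] into [(x0, x1), (y0, y1)].
        return tuple(torch.cat(field, dim=dim) for field in zip(*items))
    raise ValueError(f"unsupported collection type: {type(first)}")


batches = [
    (torch.tensor([0.0, 1.0]), torch.tensor([[0], [1]])),
    (torch.tensor([2.0, 3.0]), torch.tensor([[2], [3]])),
]
x, y = cat_sketch(batches)
```

Here zip(*items) performs the "transpose" step, and each resulting group of tensors is passed to torch.cat independently.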

Parameters:
  • iterable (Iterable[T]) – the iterable of (named)tuples/dictionaries/dataclasses of tensors. All items of the iterable must be of the same type and have the same structure (tuples must be of the same length, dictionaries must have the same keys, dataclasses must have the same fields). Dataclasses must have only tensor-valued fields.

  • dim (int) – the dimension along which to concatenate; passed to torch.cat as its dim argument.
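Since dim is passed on to torch.cat, its effect can be previewed with torch.cat alone. The snippet below is a small sketch using made-up tensors a and b; it shows only how dim changes the shape of the result for one group of tensors.

```python
import torch

# Two hypothetical tensors of shape (2, 3).
a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# dim=0 joins along the first axis (more rows).
rows = torch.cat([a, b], dim=0)
# dim=1 joins along the second axis (more columns).
cols = torch.cat([a, b], dim=1)

print(rows.shape)  # torch.Size([4, 3])
print(cols.shape)  # torch.Size([2, 6])
```

With dim=0, the usual choice when concatenating batches, every dimension other than the first must match across the concatenated tensors.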

Returns:

Concatenated items of the iterable.

Raises:

ValueError – if the iterable is empty or contains unsupported collections.

Return type:

T

Examples

Below, only small tensors and dim=0 are considered for simplicity.

import delu
import torch

# Shorthand used in the examples below.
tensor = torch.tensor

batches = [
    # (batch0_x,         batch0_y)
    (tensor([0.0, 1.0]), tensor([[0], [1]])),

    # (batch1_x,         batch1_y)
    (tensor([2.0, 3.0]), tensor([[2], [3]])),

    # ...
]
# result = (x, y)
result = delu.cat(batches)
assert isinstance(result, tuple) and len(result) == 2
assert torch.equal(result[0], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result[1], tensor([[0], [1], [2], [3]]))

batches = [
    {'x': tensor([0.0, 1.0]), 'y': tensor([[0], [1]])},
    {'x': tensor([2.0, 3.0]), 'y': tensor([[2], [3]])},
]
result = delu.cat(batches)
assert isinstance(result, dict) and set(result) == {'x', 'y'}
assert torch.equal(result['x'], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result['y'], tensor([[0], [1], [2], [3]]))

from dataclasses import dataclass

@dataclass
class Data:
    # all fields must be tensors
    x: torch.Tensor
    y: torch.Tensor

batches = [
    Data(tensor([0.0, 1.0]), tensor([[0], [1]])),
    Data(tensor([2.0, 3.0]), tensor([[2], [3]])),
]
result = delu.cat(batches)
assert isinstance(result, Data)
assert torch.equal(result.x, tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result.y, tensor([[0], [1], [2], [3]]))

x = tensor([0.0, 1.0, 2.0, 3.0, 4.0])
y = tensor([[0], [10], [20], [30], [40]])
batch_size = 2
# Splitting the data into batches and concatenating them restores the original tensors.
ab = delu.cat(delu.iter_batches((x, y), batch_size))
assert torch.equal(ab[0], x)
assert torch.equal(ab[1], y)