cat
- delu.cat(data: List[T], /, dim: int = 0) -> T
Concatenate a sequence of collections of tensors.
delu.cat is a generalized version of torch.cat for concatenating not only tensors, but also (nested) collections of tensors.
Usage
Let’s see how a sequence of model outputs for batches can be concatenated into an output tuple for the whole dataset:
>>> import torch
>>> import delu
>>> from torch.utils.data import DataLoader, TensorDataset
>>> dataset = TensorDataset(torch.randn(320, 24))
>>> batch_size = 32
>>>
>>> # The model returns not only predictions, but also embeddings.
>>> def model(x_batch):
...     # A dummy forward pass.
...     embeddings_batch = torch.randn(batch_size, 16)
...     y_pred_batch = torch.randn(batch_size)
...     return (y_pred_batch, embeddings_batch)
...
>>> y_pred, embeddings = delu.cat(
...     [model(batch) for batch in DataLoader(dataset, batch_size, shuffle=True)]
... )
>>> len(y_pred) == len(dataset)
True
>>> len(embeddings) == len(dataset)
True
The same works for dictionaries:
>>> def model(x_batch):
...     return {
...         'y_pred': torch.randn(batch_size),
...         'embeddings': torch.randn(batch_size, 16)
...     }
...
>>> outputs = delu.cat(
...     [model(batch) for batch in DataLoader(dataset, batch_size, shuffle=True)]
... )
>>> len(outputs['y_pred']) == len(dataset)
True
>>> len(outputs['embeddings']) == len(dataset)
True
The same works for sequences of named tuples, dataclasses, tensors and nested combinations of all mentioned collection types.
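To make the handling of nested structures more concrete, here is a minimal sketch of how such a recursive concatenation could be written on top of torch.cat. The helper name `cat_nested` is hypothetical and this is not delu's actual implementation; it assumes a non-empty list whose items all share the same structure.

```python
import dataclasses
from typing import Any, List

import torch


def cat_nested(data: List[Any], dim: int = 0) -> Any:
    """Hypothetical helper: recursively concatenate same-structured items."""
    first = data[0]  # Assumption: data is non-empty.
    if isinstance(first, torch.Tensor):
        # Leaf case: defer to torch.cat.
        return torch.cat(data, dim=dim)
    if isinstance(first, tuple):
        parts = [
            cat_nested([item[i] for item in data], dim)
            for i in range(len(first))
        ]
        # Named tuples take their fields positionally; plain tuples do not.
        if hasattr(first, '_fields'):
            return type(first)(*parts)
        return tuple(parts)
    if isinstance(first, dict):
        return {
            key: cat_nested([item[key] for item in data], dim)
            for key in first
        }
    if dataclasses.is_dataclass(first) and not isinstance(first, type):
        return type(first)(**{
            field.name: cat_nested(
                [getattr(item, field.name) for item in data], dim
            )
            for field in dataclasses.fields(first)
        })
    raise TypeError(f'Unsupported collection type: {type(first).__name__}')


batches = [
    (torch.zeros(4, 3), {'y': torch.zeros(4)}),
    (torch.zeros(2, 3), {'y': torch.zeros(2)}),
]
X, nested = cat_nested(batches)
print(X.shape, nested['y'].shape)  # torch.Size([6, 3]) torch.Size([6])
```

The structure of the first item drives the recursion, which is why all items must share the same type and layout.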
Below, additional technical examples are provided.
The common setup:
>>> # First batch.
>>> x1 = torch.randn(64, 10)
>>> y1 = torch.randn(64)
>>> # Second batch.
>>> x2 = torch.randn(64, 10)
>>> y2 = torch.randn(64)
>>> # The last (incomplete) batch.
>>> x3 = torch.randn(7, 10)
>>> y3 = torch.randn(7)
>>> total_size = len(x1) + len(x2) + len(x3)
delu.cat can be applied to tuples:
>>> batches = [(x1, y1), (x2, y2), (x3, y3)]
>>> X, Y = delu.cat(batches)
>>> len(X) == total_size and len(Y) == total_size
True
delu.cat can be applied to dictionaries:
>>> batches = [
...     {'x': x1, 'y': y1},
...     {'x': x2, 'y': y2},
...     {'x': x3, 'y': y3},
... ]
>>> result = delu.cat(batches)
>>> isinstance(result, dict)
True
>>> len(result['x']) == total_size and len(result['y']) == total_size
True
delu.cat can be applied to named tuples:
>>> from typing import NamedTuple
>>> class Data(NamedTuple):
...     x: torch.Tensor
...     y: torch.Tensor
...
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> isinstance(result, Data)
True
>>> len(result.x) == total_size and len(result.y) == total_size
True
delu.cat can be applied to dataclasses:
>>> from dataclasses import dataclass
>>> @dataclass
... class Data:
...     x: torch.Tensor
...     y: torch.Tensor
...
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> isinstance(result, Data)
True
>>> len(result.x) == total_size and len(result.y) == total_size
True
delu.cat can be applied to nested collections:
>>> batches = [
...     (x1, {'a': {'b': y1}}),
...     (x2, {'a': {'b': y2}}),
...     (x3, {'a': {'b': y3}}),
... ]
>>> X, Y_nested = delu.cat(batches)
>>> len(X) == total_size and len(Y_nested['a']['b']) == total_size
True
Lists are not supported:
>>> # This does not work. Instead, use tuples.
>>> # batches = [[x1, y1], [x2, y2], [x3, y3]]
>>> # delu.cat(batches)  # Error
- Parameters:
data – the list of collections of tensors. All items of the list must be of the same type, structure and layout; only the dim dimension can vary (same as for torch.cat). All the “leaf” values must be of the type torch.Tensor.
dim – the dimension along which the tensors are concatenated.
- Returns:
The concatenated items of the list.
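The rule that only the dim dimension can vary is inherited from torch.cat, so it can be illustrated with torch.cat alone. The sketch below uses arbitrary shapes; per the description above, delu.cat is expected to apply the same constraint to each leaf tensor.

```python
import torch

a = torch.zeros(2, 3)
b = torch.zeros(2, 5)

# Along dim=1 the sizes 3 and 5 may differ, because dim 0 matches (2 == 2).
wide = torch.cat([a, b], dim=1)
print(wide.shape)  # torch.Size([2, 8])

# Along dim=0 the tensors would have to agree in dim 1 (3 != 5), so this fails.
try:
    torch.cat([a, b], dim=0)
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```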