delu.cat
- delu.cat(data, dim=0)
Like torch.cat, but for collections of tensors.

A typical use case is concatenating a model/function’s outputs for batches into a single output for the whole dataset:
class Model(nn.Module):
    def forward(self, ...) -> tuple[Tensor, Tensor]:
        ...
        return y_pred, embeddings

# Concatenate a sequence of tuples (batch_y_pred, batch_embeddings) into a single tuple.
y_pred, embeddings = delu.cat([model(batch) for batch in dataloader])
The function operates recursively, so nested structures are supported as well (e.g. tuple[Tensor, dict[str, tuple[Tensor, Tensor]]]). See other examples below.

Note
Roughly speaking, the function “transposes” the list of collections into a collection of lists and applies torch.cat to each of those lists.

- Parameters:
data (List[T]) – the list of (nested) (named)tuples/dictionaries/dataclasses of tensors. All items of the list must be of the same type and have the same structure (tuples must be of the same length, dictionaries must have the same keys, dataclasses must have the same fields, etc.). All the “leaf” values must be of the type torch.Tensor.
dim (int) – the dimension over which the tensors are concatenated.
- Returns:
Concatenated items of the list.
- Raises:
ValueError – if data is empty or contains unsupported collections.
- Return type:
T
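The “transpose, then concatenate” idea from the note above can be illustrated with a small sketch. This is not delu's actual implementation: the names cat_sketch, cat_leaves and concat_lists are hypothetical, and to keep the example dependency-free, plain Python lists stand in for tensors and list concatenation stands in for torch.cat. The structure checks mirror the requirements listed for data, though the exact conditions under which the real function raises may differ.

```python
from dataclasses import fields, is_dataclass
from typing import Any, Callable, List


def cat_sketch(data: List[Any], cat_leaves: Callable[[List[Any]], Any]) -> Any:
    """Hypothetical sketch: turn a list of collections into a collection of
    lists, then reduce every list of leaves with `cat_leaves`."""
    if not data:
        raise ValueError("data must be non-empty")
    first = data[0]
    if any(type(item) is not type(first) for item in data):
        raise ValueError("all items must be of the same type")

    def recurse(items: List[Any]) -> Any:
        return cat_sketch(items, cat_leaves)

    if isinstance(first, tuple):
        if any(len(item) != len(first) for item in data):
            raise ValueError("tuples must be of the same length")
        parts = [recurse([item[i] for item in data]) for i in range(len(first))]
        # Named tuples are rebuilt from positional arguments.
        return type(first)(*parts) if hasattr(first, "_fields") else tuple(parts)

    if isinstance(first, dict):
        if any(item.keys() != first.keys() for item in data):
            raise ValueError("dictionaries must have the same keys")
        return {key: recurse([item[key] for item in data]) for key in first}

    if is_dataclass(first) and not isinstance(first, type):
        return type(first)(
            **{
                f.name: recurse([getattr(item, f.name) for item in data])
                for f in fields(first)
            }
        )

    # Assumption of this sketch: anything else is a leaf.
    return cat_leaves(data)


def concat_lists(chunks: List[list]) -> list:
    # Stand-in for torch.cat: concatenate plain Python lists.
    return [x for chunk in chunks for x in chunk]


batches = [
    ([0.0, 1.0], {"y": [0, 1]}),
    ([2.0, 3.0], {"y": [2, 3]}),
]
result = cat_sketch(batches, concat_lists)
```

With PyTorch installed, passing something like lambda tensors: torch.cat(tensors, dim=0) as cat_leaves should give the tensor behavior described above; the sketch keeps the leaf operation pluggable only so that it can run without extra dependencies.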
Examples
Below, only small one- and two-dimensional tensors and dim=0 are used for simplicity.
import torch
import delu

tensor = torch.tensor

# A list of tuples.
batches = [
    # (batch_x, batch_y)
    (tensor([0.0, 1.0]), tensor([[0], [1]])),
    (tensor([2.0, 3.0]), tensor([[2], [3]])),
]
# result = (x, y)
result = delu.cat(batches)
assert isinstance(result, tuple) and len(result) == 2
assert torch.equal(result[0], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result[1], tensor([[0], [1], [2], [3]]))

# A list of dictionaries.
batches = [
    # {'x': batch_x, 'y': batch_y}
    {'x': tensor([0.0, 1.0]), 'y': tensor([[0], [1]])},
    {'x': tensor([2.0, 3.0]), 'y': tensor([[2], [3]])},
]
result = delu.cat(batches)
assert isinstance(result, dict) and set(result) == {'x', 'y'}
assert torch.equal(result['x'], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result['y'], tensor([[0], [1], [2], [3]]))

# A list of dataclasses.
from dataclasses import dataclass

@dataclass
class Data:
    # all fields must be tensors
    x: torch.Tensor
    y: torch.Tensor

batches = [
    Data(tensor([0.0, 1.0]), tensor([[0], [1]])),
    Data(tensor([2.0, 3.0]), tensor([[2], [3]])),
]
result = delu.cat(batches)
assert isinstance(result, Data)
assert torch.equal(result.x, tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result.y, tensor([[0], [1], [2], [3]]))

# A list of nested structures.
batches = [
    {
        'x': tensor([0.0, 1.0]),
        'y': (tensor([[0], [1]]), tensor([[10], [20]]))
    },
    {
        'x': tensor([2.0, 3.0]),
        'y': (tensor([[2], [3]]), tensor([[30], [40]]))
    },
]
result = delu.cat(batches)
assert isinstance(result, dict) and set(result) == {'x', 'y'}
assert torch.equal(result['x'], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result['y'][0], tensor([[0], [1], [2], [3]]))
assert torch.equal(result['y'][1], tensor([[10], [20], [30], [40]]))

# Splitting into batches with delu.iter_batches and concatenating them back.
x = tensor([0.0, 1.0, 2.0, 3.0, 4.0])
y = tensor([[0], [10], [20], [30], [40]])
batch_size = 2
ab = delu.cat(list(delu.iter_batches((x, y), batch_size)))
assert torch.equal(ab[0], x)
assert torch.equal(ab[1], y)