delu.cat
- delu.cat(iterable, dim=0)
Like torch.cat, but for collections of tensors. The function is especially useful for concatenating the outputs of a function or a model that returns not a single tensor but a (named)tuple, dictionary, or dataclass of tensors. For example:
```python
class Model(nn.Module):
    ...
    def forward(...) -> tuple[Tensor, Tensor]:
        ...
        return y_pred, embeddings

model = Model(...)
dataset = Dataset(...)
dataloader = DataLoader(...)

# prediction
model.eval()
with torch.inference_mode():
    # Concatenate a sequence of tuples into a single tuple.
    y_pred, embeddings = delu.cat(model(batch) for batch in dataloader)
assert isinstance(y_pred, torch.Tensor) and len(y_pred) == len(dataset)
assert isinstance(embeddings, torch.Tensor) and len(embeddings) == len(dataset)
```
See other examples below.
Note
Under the hood, roughly speaking, the function "transposes" the sequence of collections into a collection of sequences and applies torch.cat to each of those sequences.
- Parameters:
iterable (Iterable[T]) – the iterable of (named)tuples/dictionaries/dataclasses of tensors. All items of the iterable must be of the same type and have the same structure (tuples must be of the same length, dictionaries must have the same keys, dataclasses must have the same fields). Dataclasses must have only tensor-valued fields.
dim (int) – the dimension along which the tensors are concatenated, as in torch.cat.
- Returns:
Concatenated items of the iterable.
- Raises:
ValueError – if the iterable is empty or contains unsupported collections.
- Return type:
T
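To make the "transposition" described in the Note concrete, here is a minimal sketch of the idea in plain PyTorch. It is not delu's actual implementation: the function name cat_sketch is hypothetical, and error handling is reduced to the two ValueError cases listed under Raises.

```python
import dataclasses
from typing import Any, Iterable

import torch


def cat_sketch(iterable: Iterable[Any], dim: int = 0) -> Any:
    """Hypothetical sketch: "transpose" the items, then apply torch.cat."""
    items = list(iterable)  # materialize generators: items are traversed more than once
    if not items:
        raise ValueError('iterable must be non-empty')
    first = items[0]

    if isinstance(first, dict):
        # [{'x': x0, 'y': y0}, {'x': x1, 'y': y1}] -> {'x': cat([x0, x1]), 'y': cat([y0, y1])}
        return {key: torch.cat([item[key] for item in items], dim=dim) for key in first}

    if dataclasses.is_dataclass(first):
        # Same idea, with dataclass fields instead of dictionary keys.
        return type(first)(**{
            field.name: torch.cat([getattr(item, field.name) for item in items], dim=dim)
            for field in dataclasses.fields(first)
        })

    if isinstance(first, tuple):
        # zip(*items): [(x0, y0), (x1, y1)] -> [(x0, x1), (y0, y1)]
        columns = [torch.cat(column, dim=dim) for column in zip(*items)]
        # Rebuild namedtuples with their own type; plain tuples stay plain tuples.
        return type(first)(*columns) if hasattr(first, '_fields') else tuple(columns)

    raise ValueError(f'unsupported collection type: {type(first).__name__}')


# Usage: a list of (x, y) batches becomes a single (x, y) pair.
x, y = cat_sketch([
    (torch.tensor([0.0, 1.0]), torch.tensor([[0], [1]])),
    (torch.tensor([2.0]), torch.tensor([[2]])),
])
```

Converting the iterable to a list first matters because, for dictionaries and dataclasses, the items are traversed once per key or field, which a one-shot generator would not allow.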
Examples
Below, only one-dimensional data and dim=0 are considered for simplicity.
```python
from dataclasses import dataclass

import delu
import torch

tensor = torch.tensor

# Tuples.
batches = [
    # (batch0_x, batch0_y)
    (tensor([0.0, 1.0]), tensor([[0], [1]])),
    # (batch1_x, batch1_y)
    (tensor([2.0, 3.0]), tensor([[2], [3]])),
    # ...
]
# result = (x, y)
result = delu.cat(batches)
assert isinstance(result, tuple) and len(result) == 2
assert torch.equal(result[0], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result[1], tensor([[0], [1], [2], [3]]))

# Dictionaries.
batches = [
    {'x': tensor([0.0, 1.0]), 'y': tensor([[0], [1]])},
    {'x': tensor([2.0, 3.0]), 'y': tensor([[2], [3]])},
]
result = delu.cat(batches)
assert isinstance(result, dict) and set(result) == {'x', 'y'}
assert torch.equal(result['x'], tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result['y'], tensor([[0], [1], [2], [3]]))

# Dataclasses.
@dataclass
class Data:
    # all fields must be tensors
    x: torch.Tensor
    y: torch.Tensor

batches = [
    Data(tensor([0.0, 1.0]), tensor([[0], [1]])),
    Data(tensor([2.0, 3.0]), tensor([[2], [3]])),
]
result = delu.cat(batches)
assert isinstance(result, Data)
assert torch.equal(result.x, tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(result.y, tensor([[0], [1], [2], [3]]))

# Round trip: split (x, y) into batches with delu.iter_batches, then concatenate them back.
x = tensor([0.0, 1.0, 2.0, 3.0, 4.0])
y = tensor([[0], [10], [20], [30], [40]])
batch_size = 2
ab = delu.cat(delu.iter_batches((x, y), batch_size))
assert torch.equal(ab[0], x)
assert torch.equal(ab[1], y)
```