cat
- delu.cat(data: List[T], /, dim: int = 0) -> T
Concatenate a sequence of collections of tensors.
delu.cat is a generalized version of torch.cat. A typical use case is concatenating a sequence of batches (think of xi and yi as batches):

- torch.cat: [x1, x2, ..., xN] -> x
- delu.cat: [x1, x2, ..., xN] -> x
- delu.cat: [(x1, y1), (x2, y2), ..., (xN, yN)] -> (x, y)
- delu.cat: [{'x': x1, 'y': y1}, ..., {'x': xN, 'y': yN}] -> {'x': x, 'y': y}
- Same for named tuples.
- Same for dataclasses.
- Nested collections are supported.
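For intuition, the tuple case can be reproduced by hand with torch.cat alone. The sketch below only illustrates the expected result for a list of tuples; it is not how delu.cat is implemented, and the tensor shapes are arbitrary choices made for the illustration.

```python
import torch

# Two hypothetical batches of different sizes (4 and 2 items).
x1, y1 = torch.randn(4, 3), torch.randn(4)
x2, y2 = torch.randn(2, 3), torch.randn(2)
batches = [(x1, y1), (x2, y2)]

# Manual equivalent for a list of tuples:
# group the tensors by position, then concatenate each group with torch.cat.
X, Y = (torch.cat(field) for field in zip(*batches))

print(X.shape, Y.shape)  # torch.Size([6, 3]) torch.Size([6])
```

The manual version has to be rewritten for every container type (dictionaries, named tuples, dataclasses, nested structures), which is the repetition a single generic call avoids.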
In other words, while torch.cat concatenates a sequence of tensors, delu.cat concatenates a sequence of collections of tensors.

Usage
Common setup:
>>> import torch
>>> import delu
>>> # First batch.
>>> x1 = torch.randn(64, 10)
>>> y1 = torch.randn(64)
>>> # Second batch.
>>> x2 = torch.randn(64, 10)
>>> y2 = torch.randn(64)
>>> # The last (incomplete) batch.
>>> x3 = torch.randn(7, 10)
>>> y3 = torch.randn(7)
>>> # Total size.
>>> len(x1) + len(x2) + len(x3)
135
delu.cat can be applied to tuples:
>>> batches = [(x1, y1), (x2, y2), (x3, y3)]
>>> X, Y = delu.cat(batches)
>>> print(len(X), len(Y))
135 135
delu.cat can be applied to dictionaries:
>>> batches = [
...     {'x': x1, 'y': y1},
...     {'x': x2, 'y': y2},
...     {'x': x3, 'y': y3},
... ]
>>> result = delu.cat(batches)
>>> print(isinstance(result, dict), len(result['x']), len(result['y']))
True 135 135
delu.cat can be applied to named tuples:
>>> from typing import NamedTuple
>>> class Data(NamedTuple):
...     x: torch.Tensor
...     y: torch.Tensor
...
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> print(isinstance(result, Data), len(result.x), len(result.y))
True 135 135
delu.cat can be applied to dataclasses:
>>> from dataclasses import dataclass
>>> @dataclass
... class Data:
...     x: torch.Tensor
...     y: torch.Tensor
...
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> print(isinstance(result, Data), len(result.x), len(result.y))
True 135 135
delu.cat can be applied to nested collections:
>>> batches = [
...     (x1, {'a': {'b': y1}}),
...     (x2, {'a': {'b': y2}}),
...     (x3, {'a': {'b': y3}}),
... ]
>>> X, Y_nested = delu.cat(batches)
>>> print(len(X), len(Y_nested['a']['b']))
135 135
`delu.cat` cannot be applied to lists:
>>> # This does not work. Instead, use tuples.
>>> # batches = [[x1, y1], [x2, y2], [x3, y3]]
>>> # delu.cat(batches)  # Error
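If batches already exist as lists, one possible workaround is to convert each of them to a tuple before concatenation. This is a minimal sketch that handles only top-level lists; the names list_batches and tuple_batches and the tensor shapes are illustrative assumptions, and the final delu.cat call is left commented out.

```python
import torch

# Two hypothetical batches stored as lists (not accepted as collections).
x1, y1 = torch.randn(4, 3), torch.randn(4)
x2, y2 = torch.randn(2, 3), torch.randn(2)
list_batches = [[x1, y1], [x2, y2]]

# Convert every batch to a tuple; the tensors themselves are not copied.
tuple_batches = [tuple(batch) for batch in list_batches]

print(all(isinstance(batch, tuple) for batch in tuple_batches))  # True
print(tuple_batches[0][0] is x1)  # True

# tuple_batches has the same form as the tuple example above:
# X, Y = delu.cat(tuple_batches)
```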
- Parameters:
data – the list of collections of tensors. All items of the list must be of the same type, structure and layout; only the dim dimension can vary (same as for torch.cat). All the “leaf” values must be of the type torch.Tensor.
dim – the dimension along which the tensors are concatenated.
- Returns:
The concatenated items of the list.
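The requirement that only the dim dimension can vary is the same as for torch.cat, so the rule can be checked with torch.cat directly. The sketch below uses arbitrary shapes chosen for illustration and does not call delu.cat.

```python
import torch

a = torch.zeros(3, 2)
b = torch.zeros(3, 5)

# Along dim=1 the sizes may differ (2 vs 5); all other sizes must match (3 and 3).
c = torch.cat([a, b], dim=1)
print(c.shape)  # torch.Size([3, 7])

# Along dim=0 the same tensors cannot be concatenated,
# because their sizes along dim=1 differ.
try:
    torch.cat([a, b], dim=0)
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```

For a collection of tensors, the same constraint applies to every leaf tensor with respect to the corresponding leaves in the other items.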