cat
- `delu.cat(data: List[T], /, dim: int = 0) -> T`
Concatenate a sequence of collections of tensors.
`delu.cat` is a generalized version of `torch.cat`. A typical use case is concatenating a sequence of batches (think of `xi` and `yi` as batches):

    torch.cat: [x1, x2, ..., xN] -> x
    delu.cat: [x1, x2, ..., xN] -> x
    delu.cat: [(x1, y1), (x2, y2), ..., (xN, yN)] -> (x, y)
    delu.cat: [{'x': x1, 'y': y1}, ..., {'x': xN, 'y': yN}] -> {'x': x, 'y': y}

The same holds for named tuples and dataclasses, and nested collections are supported.
In other words, while `torch.cat` concatenates a sequence of tensors, `delu.cat` concatenates a sequence of collections of tensors.

Usage
Common setup:

    >>> import delu
    >>> import torch
    >>> # First batch.
    >>> x1 = torch.randn(64, 10)
    >>> y1 = torch.randn(64)
    >>> # Second batch.
    >>> x2 = torch.randn(64, 10)
    >>> y2 = torch.randn(64)
    >>> # The last (incomplete) batch.
    >>> x3 = torch.randn(7, 10)
    >>> y3 = torch.randn(7)
    >>> # Total size.
    >>> len(x1) + len(x2) + len(x3)
    135
`delu.cat` can be applied to tuples:

    >>> batches = [(x1, y1), (x2, y2), (x3, y3)]
    >>> X, Y = delu.cat(batches)
    >>> print(len(X), len(Y))
    135 135
`delu.cat` can be applied to dictionaries:

    >>> batches = [
    ...     {'x': x1, 'y': y1},
    ...     {'x': x2, 'y': y2},
    ...     {'x': x3, 'y': y3},
    ... ]
    >>> result = delu.cat(batches)
    >>> print(isinstance(result, dict), len(result['x']), len(result['y']))
    True 135 135
`delu.cat` can be applied to named tuples:

    >>> from typing import NamedTuple
    >>> class Data(NamedTuple):
    ...     x: torch.Tensor
    ...     y: torch.Tensor
    ...
    >>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
    >>> result = delu.cat(batches)
    >>> print(isinstance(result, Data), len(result.x), len(result.y))
    True 135 135
`delu.cat` can be applied to dataclasses:

    >>> from dataclasses import dataclass
    >>> @dataclass
    ... class Data:
    ...     x: torch.Tensor
    ...     y: torch.Tensor
    ...
    >>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
    >>> result = delu.cat(batches)
    >>> print(isinstance(result, Data), len(result.x), len(result.y))
    True 135 135
`delu.cat` can be applied to nested collections:

    >>> batches = [
    ...     (x1, {'a': {'b': y1}}),
    ...     (x2, {'a': {'b': y2}}),
    ...     (x3, {'a': {'b': y3}}),
    ... ]
    >>> X, Y_nested = delu.cat(batches)
    >>> print(len(X), len(Y_nested['a']['b']))
    135 135
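All the examples above follow one recursive pattern: split the items by position (tuples and named tuples), by key (dictionaries) or by field (dataclasses), recurse, and concatenate the leaves. The following is a minimal sketch of that pattern, not DeLU's actual implementation. The names `cat_collections` and `concat_lists` are hypothetical. To keep the sketch dependency-free, leaves are concatenated by a caller-supplied `leaf_cat` function, and `concat_lists` (plain Python lists) stands in for `torch.cat`; unlike `delu.cat`, the sketch therefore treats lists as leaves.

```python
import dataclasses
from typing import Any, Callable, List


def cat_collections(data: List[Any], leaf_cat: Callable[[List[Any]], Any]) -> Any:
    """Hypothetical sketch: concatenate a list of identically structured collections.

    Tuples, named tuples, dicts and dataclass instances are traversed recursively;
    everything else is a leaf and is handed to `leaf_cat` (e.g. torch.cat).
    No structure validation is performed.
    """
    if not data:
        raise ValueError("data must be a non-empty list")
    first = data[0]
    if isinstance(first, tuple):
        # Position by position: [(x1, y1), (x2, y2)] -> (cat([x1, x2]), cat([y1, y2])).
        parts = [cat_collections(list(group), leaf_cat) for group in zip(*data)]
        # Named tuples are rebuilt with their own class; plain tuples stay tuples.
        return type(first)(*parts) if hasattr(first, "_fields") else tuple(parts)
    if isinstance(first, dict):
        # Key by key, following the key order of the first item.
        return {key: cat_collections([item[key] for item in data], leaf_cat) for key in first}
    if dataclasses.is_dataclass(first) and not isinstance(first, type):
        # Field by field (fields with init=False are skipped), then rebuild
        # an instance of the same dataclass.
        return type(first)(**{
            field.name: cat_collections([getattr(item, field.name) for item in data], leaf_cat)
            for field in dataclasses.fields(first)
            if field.init
        })
    # Leaf values (tensors, in delu.cat's case).
    return leaf_cat(data)


def concat_lists(leaves: List[list]) -> list:
    """Stand-in for torch.cat in this sketch: concatenates plain Python lists."""
    return [value for leaf in leaves for value in leaf]
```

With tensors, one would pass something like `lambda ts: torch.cat(ts, dim=0)` as `leaf_cat`, so `cat_collections([(x1, y1), (x2, y2)], lambda ts: torch.cat(ts, dim=0))` would mirror the tuple example above.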
`delu.cat` cannot be applied to a list of lists (the items of `data` must not be lists):

    >>> # This does not work. Instead, use tuples.
    >>> # batches = [[x1, y1], [x2, y2], [x3, y3]]
    >>> # delu.cat(batches)  # Error
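Since lists are not accepted as items, list-based batches have to be converted to tuples first. Below is a minimal sketch of that conversion. `lists_to_tuples` is a hypothetical helper, not part of DeLU; it handles lists, plain tuples and dicts only, and returns everything else (tensors, named tuples, dataclasses) unchanged.

```python
from typing import Any


def lists_to_tuples(obj: Any) -> Any:
    """Hypothetical helper: recursively replace lists with tuples,
    descending into plain tuples and dicts."""
    if isinstance(obj, list) or type(obj) is tuple:
        return tuple(lists_to_tuples(item) for item in obj)
    if isinstance(obj, dict):
        return {key: lists_to_tuples(value) for key, value in obj.items()}
    # Tensors and all other values are returned unchanged.
    return obj


# Intended usage with DeLU (requires torch and delu, so it is not executed here).
# Each batch is converted, but the outer list is kept: delu.cat expects a list.
# batches = [[x1, y1], [x2, y2], [x3, y3]]
# X, Y = delu.cat([lists_to_tuples(batch) for batch in batches])
```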
- Parameters:
  - data – the list of collections of tensors. All items of the list must be of the same type, structure and layout; only the `dim` dimension can vary (same as for `torch.cat`). All the “leaf” values must be of the type `torch.Tensor`.
  - dim – the dimension along which the tensors are concatenated.
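- Returns:
  The concatenated items of the list: a collection of the same type and structure as each item, in which every tensor is the concatenation of the corresponding tensors of all items.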
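Because every item of `data` must share the same type, structure and layout, a mismatch can be detected before any concatenation happens. The following is a minimal, hypothetical sketch of such a check: `layout` and `same_structure` are not part of DeLU, and `delu.cat` may report mismatches differently. It compares container types, tuple lengths, dict keys and dataclass fields, and records only the type of each leaf, so tensor shapes are left for `torch.cat` to validate.

```python
import dataclasses
from typing import Any, List


def layout(obj: Any) -> Any:
    """Hypothetical helper: describe the container layout of `obj`, ignoring leaf values."""
    if isinstance(obj, tuple):
        # Covers named tuples too: their class is part of the layout.
        return (type(obj), tuple(layout(item) for item in obj))
    if isinstance(obj, dict):
        # Stored as a dict, so key order does not affect the comparison.
        return (type(obj), {key: layout(value) for key, value in obj.items()})
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return (type(obj), {f.name: layout(getattr(obj, f.name)) for f in dataclasses.fields(obj)})
    # Leaf: only its type is recorded (torch.Tensor in delu.cat's case).
    return type(obj)


def same_structure(data: List[Any]) -> bool:
    """Hypothetical check: do all items of `data` share the layout of the first one?"""
    if not data:
        # An empty list has nothing to concatenate.
        return False
    reference = layout(data[0])
    return all(layout(item) == reference for item in data[1:])
```

For example, one could guard a call with `if not same_structure(batches): raise ValueError(...)` before calling `delu.cat(batches)`.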