cat
- delu.cat(data: List[T], /, dim: int = 0) -> T
Concatenate a sequence of collections of tensors (like torch.cat, but for collections of tensors).

While torch.cat concatenates a sequence of tensors, delu.cat concatenates a sequence of collections of tensors (tuples, named tuples, dictionaries, dataclasses and nested combinations thereof; nested lists are not allowed).

Usage
Common setup:
>>> import torch
>>> import delu
>>> # First batch.
>>> x1 = torch.randn(64, 10)
>>> y1 = torch.randn(64)
>>> # Second batch.
>>> x2 = torch.randn(64, 10)
>>> y2 = torch.randn(64)
>>> # The last (incomplete) batch.
>>> x3 = torch.randn(7, 10)
>>> y3 = torch.randn(7)
>>> # Total size.
>>> len(x1) + len(x2) + len(x3)
135
delu.cat can be applied to tuples:
>>> batches = [(x1, y1), (x2, y2), (x3, y3)]
>>> X, Y = delu.cat(batches)
>>> print(len(X), len(Y))
135 135
delu.cat can be applied to dictionaries:
>>> batches = [
...     {'x': x1, 'y': y1},
...     {'x': x2, 'y': y2},
...     {'x': x3, 'y': y3},
... ]
>>> result = delu.cat(batches)
>>> print(isinstance(result, dict), len(result['x']), len(result['y']))
True 135 135
delu.cat can be applied to named tuples:
>>> from typing import NamedTuple
>>> class Data(NamedTuple):
...     x: torch.Tensor
...     y: torch.Tensor
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> print(isinstance(result, Data), len(result.x), len(result.y))
True 135 135
delu.cat can be applied to dataclasses:
>>> from dataclasses import dataclass
>>> @dataclass
... class Data:
...     x: torch.Tensor
...     y: torch.Tensor
>>> batches = [Data(x1, y1), Data(x2, y2), Data(x3, y3)]
>>> result = delu.cat(batches)
>>> print(isinstance(result, Data), len(result.x), len(result.y))
True 135 135
delu.cat can be applied to nested collections:
>>> batches = [
...     (x1, {'a': {'b': y1}}),
...     (x2, {'a': {'b': y2}}),
...     (x3, {'a': {'b': y3}}),
... ]
>>> X, Y_nested = delu.cat(batches)
>>> print(len(X), len(Y_nested['a']['b']))
135 135
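For intuition about how nested collections can be handled, the recursion might be sketched roughly as below. This is not delu's implementation: `cat_collections` is a hypothetical helper written only for illustration, and it assumes every item in the list has the same structure.

```python
from dataclasses import fields, is_dataclass

import torch


def cat_collections(data, dim=0):
    """Hypothetical sketch: recursively concatenate same-structured collections."""
    first = data[0]
    if isinstance(first, torch.Tensor):
        # Leaf case: plain tensors are handed to torch.cat.
        return torch.cat(data, dim=dim)
    if isinstance(first, dict):
        return {
            key: cat_collections([item[key] for item in data], dim)
            for key in first
        }
    if isinstance(first, tuple):
        columns = [cat_collections(list(col), dim) for col in zip(*data)]
        # Named tuples take positional fields; plain tuples take one iterable.
        return type(first)(*columns) if hasattr(first, "_fields") else tuple(columns)
    if is_dataclass(first) and not isinstance(first, type):
        return type(first)(**{
            f.name: cat_collections([getattr(item, f.name) for item in data], dim)
            for f in fields(first)
        })
    # Lists and any other types are rejected in this sketch.
    raise TypeError(f"Unsupported collection type: {type(first)}")


x1, y1 = torch.randn(64, 10), torch.randn(64)
x3, y3 = torch.randn(7, 10), torch.randn(7)
X, nested = cat_collections([(x1, {"a": {"b": y1}}), (x3, {"a": {"b": y3}})])
print(X.shape, nested["a"]["b"].shape)  # torch.Size([71, 10]) torch.Size([71])
```

The sketch walks the structure of the first item and gathers the matching leaves from all items, which is why all items must share the same type and layout.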
delu.cat cannot be applied to lists:
>>> # This does not work. Instead, use tuples.
>>> # batches = [[x1, y1], [x2, y2], [x3, y3]]
>>> # delu.cat(batches)  # Error
- Parameters:
data – the list of collections of tensors. All items of the list must be of the same type, structure and layout; only the dim dimension can vary (same as for torch.cat). All the “leaf” values must be of the type torch.Tensor.
dim – the dimension along which the tensors are concatenated.
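- Returns:
The concatenated items of the list: a single collection of the same type and structure as the items, with the tensors concatenated along dim.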
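The shape rule for data ("only the dim dimension can vary, same as for torch.cat") can be illustrated with torch.cat itself. The tensors below are made up for the illustration; the same constraint is expected to apply to each group of corresponding leaf tensors passed to delu.cat.

```python
import torch

a = torch.randn(64, 10)
b = torch.randn(7, 10)
c = torch.randn(64, 3)

# Sizes differ only along dim=0, so concatenating along dim=0 works.
rows = torch.cat([a, b], dim=0)
print(rows.shape)  # torch.Size([71, 10])

# Sizes differ only along dim=1, so concatenating along dim=1 works.
cols = torch.cat([a, c], dim=1)
print(cols.shape)  # torch.Size([64, 13])

# a and b differ along dim=0, so concatenating them along dim=1 fails.
try:
    torch.cat([a, b], dim=1)
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```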