iter_batches
delu.iter_batches(
    data: T,
    /,
    batch_size: int,
    *,
    shuffle: bool = False,
    generator: Generator | None = None,
    drop_last: bool = False,
)
Iterate over a tensor or a collection of tensors by (random) batches.

The function makes batches along the first dimension of the tensors in data.

TL;DR (assuming that X and Y denote full tensors, and xi and yi denote batches):

delu.iter_batches: X -> [x1, x2, ..., xN]
delu.iter_batches: (X, Y) -> [(x1, y1), (x2, y2), ..., (xN, yN)]
delu.iter_batches: {'x': X, 'y': Y} -> [{'x': x1, 'y': y1}, ...]

The same applies to named tuples and dataclasses.
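The batching scheme above can be sketched in plain Python. Here, `iter_batches_sketch` and its `seed` parameter are illustrative names only, and lists stand in for tensors; this shows the general idea of batching along the first dimension, not DeLU's actual implementation:

```python
import random

def iter_batches_sketch(data, batch_size, shuffle=False, drop_last=False, seed=None):
    # Sketch: take batches along the first dimension, optionally in a
    # shuffled order (a permutation without replacement).
    indices = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        if drop_last and len(batch_indices) < batch_size:
            break  # Skip the last, incomplete batch.
        yield [data[i] for i in batch_indices]

X = list(range(12))
print([len(b) for b in iter_batches_sketch(X, 5)])                  # -> [5, 5, 2]
print([len(b) for b in iter_batches_sketch(X, 5, drop_last=True)])  # -> [5, 5]
```

With `shuffle=False`, concatenating the batches recovers the original order, which mirrors the behavior demonstrated in the usage examples below.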
Note

delu.iter_batches is significantly faster for in-memory tensors than torch.utils.data.DataLoader, because, when building batches, it uses batched indexing instead of one-by-one indexing.

Usage
>>> import torch
>>> import delu
>>>
>>> X = torch.randn(12, 32)
>>> Y = torch.randn(12)
delu.iter_batches can be applied to tensors:

>>> for x in delu.iter_batches(X, batch_size=5):
...     print(len(x))
5
5
2
delu.iter_batches can be applied to tuples:

>>> # shuffle=True can be useful for training.
>>> dataset = (X, Y)
>>> for x, y in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(len(x), len(y))
5 5
5 5
2 2
>>> # Drop the last incomplete batch.
>>> for x, y in delu.iter_batches(
...     dataset, batch_size=5, shuffle=True, drop_last=True
... ):
...     print(len(x), len(y))
5 5
5 5
>>> # The last batch is complete, so drop_last=True does not have any effect.
>>> batches = []
>>> for x, y in delu.iter_batches(dataset, batch_size=6, drop_last=True):
...     print(len(x), len(y))
...     batches.append((x, y))
6 6
6 6
By default, shuffle is set to False, i.e. the order of items is preserved:

>>> X2, Y2 = delu.cat(list(delu.iter_batches((X, Y), batch_size=5)))
>>> print((X == X2).all().item(), (Y == Y2).all().item())
True True
delu.iter_batches can be applied to dictionaries:

>>> dataset = {'x': X, 'y': Y}
>>> for batch in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(isinstance(batch, dict), len(batch['x']), len(batch['y']))
True 5 5
True 5 5
True 2 2
delu.iter_batches can be applied to named tuples:

>>> from typing import NamedTuple
>>> class Data(NamedTuple):
...     x: torch.Tensor
...     y: torch.Tensor
>>> dataset = Data(X, Y)
>>> for batch in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(isinstance(batch, Data), len(batch.x), len(batch.y))
True 5 5
True 5 5
True 2 2
delu.iter_batches can be applied to dataclasses:

>>> from dataclasses import dataclass
>>> @dataclass
... class Data:
...     x: torch.Tensor
...     y: torch.Tensor
>>> dataset = Data(X, Y)
>>> for batch in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(isinstance(batch, Data), len(batch.x), len(batch.y))
True 5 5
True 5 5
True 2 2
- Parameters:
data – the tensor or the non-empty collection of tensors. If data is a collection, then the tensors must be of the same size along the first dimension.
batch_size – the batch size. If drop_last is False, then the last batch can be smaller than batch_size.
shuffle – if True, iterate over random batches (without replacement), not sequentially.
generator – when shuffle is True, passing generator makes the function reproducible.
drop_last – when True and the last batch is smaller than batch_size, this last batch is not returned (in other words, same as the drop_last argument for torch.utils.data.DataLoader).
- Returns:
the iterator over batches.
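To illustrate why passing a seeded generator makes shuffled iteration reproducible, here is a sketch in which stdlib random.Random stands in for torch.Generator and `shuffled_batches` is a hypothetical helper, not part of DeLU:

```python
import random

def shuffled_batches(n, batch_size, rng):
    # Indices of each batch under a shuffled iteration order (sketch).
    indices = list(range(n))
    rng.shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, n, batch_size)]

# Two runs with identically seeded generators produce identical batches;
# this is the role of the generator argument when shuffle=True.
first = shuffled_batches(12, 5, random.Random(0))
second = shuffled_batches(12, 5, random.Random(0))
print(first == second)  # -> True
```

Without a fixed seed, each run would visit the items in a different random order, which is usually desirable during training but makes experiments harder to reproduce.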