delu.iter_batches
- delu.iter_batches(data, batch_size, shuffle=False, *, generator=None, drop_last=False)
Iterate over a tensor or a collection of tensors in (random) batches.
The function makes batches over the first dimension of the tensors in data and returns an iterator over collections of the same type as the input. A simple example (see below for more examples):

    import torch
    import delu

    n_objects = 100
    n_features = 4
    X = torch.randn(n_objects, n_features)
    y = torch.randn(n_objects)
    for batch_x, batch_y in delu.iter_batches(
        (X, y), batch_size=12, shuffle=True
    ):
        ...  # train(batch_x, batch_y)
- Parameters:
data (T) – the tensor or the collection ((named)tuple/dict/dataclass) of tensors. If data is a collection, then the tensors must have the same first dimension. If data is a dataclass, then all its fields must be tensors.
batch_size (int) – the batch size. If drop_last is False, then the last batch can be smaller than batch_size.
shuffle (bool) – if True, iterate over random batches (without replacement), not sequentially.
generator (Generator | None) – the argument for torch.randperm when shuffle is True.
drop_last (bool) – same as the drop_last argument for torch.utils.data.DataLoader. When True and the last batch is smaller than batch_size, this last batch is not returned.
- Returns:
Iterator over batches.
- Raises:
ValueError – if the data is empty.
- Return type:
Iterator[T]
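The effect of batch_size and drop_last on the number of batches can be checked with simple arithmetic. The sketch below is only an illustration in plain Python: count_batches and last_batch_size are hypothetical helpers, not part of delu.

```python
import math


def count_batches(n_objects: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches for n_objects items (hypothetical helper, not a delu function)."""
    if drop_last:
        # An incomplete trailing batch is discarded.
        return n_objects // batch_size
    # An incomplete trailing batch is kept.
    return math.ceil(n_objects / batch_size)


def last_batch_size(n_objects: int, batch_size: int) -> int:
    """Size of the final batch when drop_last is False (hypothetical helper)."""
    remainder = n_objects % batch_size
    return remainder or batch_size


# 100 objects with batch_size=12: eight full batches plus one batch of 4 objects.
print(count_batches(100, 12), last_batch_size(100, 12))
# With drop_last=True the batch of 4 objects is skipped.
print(count_batches(100, 12, drop_last=True))
```

With 5 objects and batch_size=2, the same arithmetic gives 3 batches, or 2 with drop_last=True.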
Note
The function lazily indexes into the provided input with batches of indices. This works faster than iterating over the tensors in data with torch.utils.data.DataLoader.
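The idea of indexing the input with batches of indices can be sketched in plain Python. This is a simplified illustration, not delu's implementation: lists stand in for tensors, random.Random stands in for torch.randperm, and the names iter_index_batches, iter_list_batches and seed are hypothetical.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple


def iter_index_batches(
    n_objects: int,
    batch_size: int,
    shuffle: bool = False,
    seed: Optional[int] = None,
    drop_last: bool = False,
) -> Iterator[List[int]]:
    """Lazily yield batches of indices (hypothetical helper)."""
    if n_objects == 0:
        raise ValueError("the data is empty")
    indices = list(range(n_objects))
    if shuffle:
        # A random permutation gives sampling without replacement.
        random.Random(seed).shuffle(indices)
    for start in range(0, n_objects, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            return  # skip the incomplete last batch
        yield batch


def iter_list_batches(
    columns: Sequence[Sequence], batch_size: int, **kwargs
) -> Iterator[Tuple[list, ...]]:
    """Apply each batch of indices to every column (assumes at least one column)."""
    n_objects = len(columns[0])
    if any(len(column) != n_objects for column in columns):
        raise ValueError("all columns must have the same length")
    for batch_indices in iter_index_batches(n_objects, batch_size, **kwargs):
        # The same indices are used for every column, so rows stay aligned.
        yield tuple([column[i] for i in batch_indices] for column in columns)


a = [0.0, 1.0, 2.0, 3.0, 4.0]
b = [0, 10, 20, 30, 40]
for batch_a, batch_b in iter_list_batches((a, b), batch_size=2, shuffle=True, seed=0):
    print(batch_a, batch_b)
```

Because every column is indexed with the same batch of indices, shuffling keeps corresponding items together, which is why the inputs must share the same first dimension.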
Examples
    for epoch in range(n_epochs):
        for batch in delu.iter_batches(data, batch_size, shuffle=True):
            ...
    import torch
    import delu

    a = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    b = torch.tensor([[0], [10], [20], [30], [40]])
    batch_size = 2

    # The batches have the same type as the input.
    for batch in delu.iter_batches(a, batch_size):
        assert isinstance(batch, torch.Tensor)
    for batch in delu.iter_batches((a, b), batch_size):
        assert isinstance(batch, tuple) and len(batch) == 2
    for batch in delu.iter_batches({'a': a, 'b': b}, batch_size):
        assert isinstance(batch, dict) and set(batch) == {'a', 'b'}

    from dataclasses import dataclass

    @dataclass
    class Data:
        a: torch.Tensor
        b: torch.Tensor

    for batch in delu.iter_batches(Data(a, b), batch_size):
        assert isinstance(batch, Data)

    # Without shuffling, concatenating the batches restores the input.
    ab = delu.cat(list(delu.iter_batches((a, b), batch_size)))
    assert torch.equal(ab[0], a)
    assert torch.equal(ab[1], b)

    # Five objects with batch_size=2: the last batch has one object.
    n_batches = len(list(delu.iter_batches((a, b), batch_size)))
    assert n_batches == 3
    n_batches = len(list(delu.iter_batches((a, b), batch_size, drop_last=True)))
    assert n_batches == 2