iter_batches

delu.iter_batches(
data: T,
/,
batch_size: int,
*,
shuffle: bool = False,
generator: Generator | None = None,
drop_last: bool = False,
) -> Iterator[T]

Iterate over a tensor or a collection of tensors by (random) batches.

The function creates batches along the first dimension of the tensors in data.

Note

delu.iter_batches is significantly faster for in-memory tensors than torch.utils.data.DataLoader, because, when building batches, it uses batched indexing instead of one-by-one indexing.
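A rough way to check this on your own machine is sketched below (the tensor size, batch size, and variable names are illustrative, not part of the API; the exact numbers depend on hardware):

>>> import timeit
>>> import torch
>>> import torch.utils.data
>>> import delu
>>> X_big = torch.randn(100_000, 32)
>>> # DataLoader fetches items one by one and then collates them into a batch.
>>> loader = torch.utils.data.DataLoader(
...     torch.utils.data.TensorDataset(X_big), batch_size=256
... )
>>> time_dataloader = timeit.timeit(lambda: list(loader), number=1)
>>> # delu.iter_batches slices the tensor with batched indexing.
>>> time_iter_batches = timeit.timeit(
...     lambda: list(delu.iter_batches(X_big, batch_size=256)), number=1
... )
>>> # Typically, time_iter_batches is noticeably smaller than time_dataloader.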

Usage

>>> import torch
>>> import delu
>>> X = torch.randn(12, 32)
>>> Y = torch.randn(12)

delu.iter_batches can be applied to tensors:

>>> for x in delu.iter_batches(X, batch_size=5):
...     print(len(x))
5
5
2
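The batches are slices along the first dimension: for the (12, 32) tensor X above, every batch keeps the remaining dimensions intact.

>>> for x in delu.iter_batches(X, batch_size=5):
...     print(x.shape)
torch.Size([5, 32])
torch.Size([5, 32])
torch.Size([2, 32])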

delu.iter_batches can be applied to tuples:

>>> # shuffle=True can be useful for training.
>>> dataset = (X, Y)
>>> for x, y in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(len(x), len(y))
5 5
5 5
2 2
>>> # Drop the last incomplete batch.
>>> for x, y in delu.iter_batches(
...     dataset, batch_size=5, shuffle=True, drop_last=True
... ):
...     print(len(x), len(y))
5 5
5 5
>>> # The last batch is complete, so drop_last=True has no effect.
>>> for x, y in delu.iter_batches(dataset, batch_size=6, drop_last=True):
...     print(len(x), len(y))
6 6
6 6

By default, shuffle is set to False, i.e. the order of items is preserved:

>>> X2, Y2 = delu.cat(list(delu.iter_batches((X, Y), batch_size=5)))
>>> print((X == X2).all().item(), (Y == Y2).all().item())
True True

delu.iter_batches can be applied to dictionaries:

>>> dataset = {'x': X, 'y': Y}
>>> for batch in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(isinstance(batch, dict), len(batch['x']), len(batch['y']))
True 5 5
True 5 5
True 2 2

delu.iter_batches can be applied to named tuples:

>>> from typing import NamedTuple
>>> class Data(NamedTuple):
...     x: torch.Tensor
...     y: torch.Tensor
>>> dataset = Data(X, Y)
>>> for batch in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(isinstance(batch, Data), len(batch.x), len(batch.y))
True 5 5
True 5 5
True 2 2

delu.iter_batches can be applied to dataclasses:

>>> from dataclasses import dataclass
>>> @dataclass
... class Data:
...     x: torch.Tensor
...     y: torch.Tensor
>>> dataset = Data(X, Y)
>>> for batch in delu.iter_batches(dataset, batch_size=5, shuffle=True):
...     print(isinstance(batch, Data), len(batch.x), len(batch.y))
True 5 5
True 5 5
True 2 2

Parameters:
  • data – the tensor or the non-empty collection of tensors. If data is a collection, then all its tensors must have the same size along the first dimension.

  • batch_size – the batch size. If drop_last is False, then the last batch can be smaller than batch_size.

  • shuffle – if True, iterate over random batches (without replacement), not sequentially.

  • generator – when shuffle is True, passing generator makes the function reproducible (see the sketch after this section).

  • drop_last – if True and the last batch is smaller than batch_size, then this last batch is not returned (in other words, the same as the drop_last argument of torch.utils.data.DataLoader).

Returns:

the iterator over batches.
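
The following sketch illustrates the reproducibility provided by generator, reusing X and Y from the Usage section (the seed 0 is arbitrary; the assumption is that identically seeded generators produce identical shuffling):

>>> g = torch.Generator().manual_seed(0)
>>> run1 = [
...     x for x, _ in delu.iter_batches((X, Y), batch_size=5, shuffle=True, generator=g)
... ]
>>> g = torch.Generator().manual_seed(0)
>>> run2 = [
...     x for x, _ in delu.iter_batches((X, Y), batch_size=5, shuffle=True, generator=g)
... ]
>>> all((a == b).all().item() for a, b in zip(run1, run2))
True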