delu.iter_batches

delu.iter_batches(data, batch_size, shuffle=False, *, generator=None, drop_last=False)

Iterate over a tensor or a collection of tensors in batches, optionally in random order.

The function makes batches over the first dimension of the tensors in data and returns an iterator over collections of the same type as the input.

Parameters:
  • data (T) – a tensor or a collection (tuple, named tuple, dict, or dataclass) of tensors. If data is a collection, all tensors must have the same size along the first dimension. If data is a dataclass, all of its fields must be tensors.

  • batch_size (int) – the batch size. If drop_last is False, then the last batch can be smaller than batch_size.

  • shuffle (bool) – if True, iterate over random batches (without replacement), not sequentially.

  • generator (Generator | None) – the random number generator passed to torch.randperm when shuffle is True.

  • drop_last (bool) – same as the drop_last argument for torch.utils.data.DataLoader. If True and the last batch is smaller than batch_size, that batch is not returned.
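
The interaction between batch_size and drop_last comes down to simple arithmetic. The following is a minimal sketch in plain Python, not delu's implementation: the helper name batch_sizes is hypothetical, n stands for the size of the first dimension, and the check on batch_size is an assumption of this sketch.

```python
from typing import List


def batch_sizes(n: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Sizes of the batches produced for n items (hypothetical helper, not part of delu)."""
    if n <= 0:
        raise ValueError("the data is empty")
    if batch_size <= 0:
        # Assumption of this sketch: reject non-positive batch sizes explicitly.
        raise ValueError("batch_size must be positive")
    n_full, remainder = divmod(n, batch_size)
    sizes = [batch_size] * n_full
    # The trailing, smaller batch is kept only when drop_last is False.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


print(batch_sizes(5, 2))                  # [2, 2, 1]
print(batch_sizes(5, 2, drop_last=True))  # [2, 2]
```

With five items and a batch size of two, this gives three batches by default and two batches with drop_last=True.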

Returns:

Iterator over batches.

Raises:

ValueError – if the data is empty.

Return type:

Iterator[T]

Note

The function lazily indexes into the provided input with batches of indices. This is faster than iterating over the tensors in data with torch.utils.data.DataLoader.
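
The idea of lazily indexing with batches of indices can be sketched without any library. The following is only an illustration under stated assumptions, not delu's code: Python lists stand in for tensors, random.shuffle stands in for torch.randperm, and the name iter_index_batches and its rng parameter are hypothetical.

```python
import random
from typing import Iterator, List, Optional


def iter_index_batches(
    n: int,
    batch_size: int,
    shuffle: bool = False,
    rng: Optional[random.Random] = None,
    drop_last: bool = False,
) -> Iterator[List[int]]:
    """Lazily yield batches of indices into a sequence of length n (hypothetical helper)."""
    if n == 0:
        raise ValueError("the data is empty")
    indices = list(range(n))
    if shuffle:
        # A single permutation gives random batches without replacement.
        (rng or random).shuffle(indices)
    for start in range(0, n, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            return
        yield batch


data = [0.0, 1.0, 2.0, 3.0, 4.0]  # stands in for a tensor
for idx in iter_index_batches(len(data), 2):
    # Each batch is built only when the loop asks for it.
    print([data[i] for i in idx])
```

Because the helper is a generator, nothing is computed until iteration starts, so the ValueError for empty input is also raised only at that point in this sketch.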

Examples

import torch
import delu

# Typical usage; n_epochs, data and batch_size are defined elsewhere.
for epoch in range(n_epochs):
    for batch in delu.iter_batches(data, batch_size, shuffle=True):
        ...

a = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])
b = torch.tensor([[0], [10], [20], [30], [40]])
batch_size = 2

for batch in delu.iter_batches(a, batch_size):
    assert isinstance(batch, torch.Tensor)
for batch in delu.iter_batches((a, b), batch_size):
    assert isinstance(batch, tuple) and len(batch) == 2
for batch in delu.iter_batches({'a': a, 'b': b}, batch_size):
    assert isinstance(batch, dict) and set(batch) == {'a', 'b'}

from dataclasses import dataclass

@dataclass
class Data:
    a: torch.Tensor
    b: torch.Tensor

for batch in delu.iter_batches(Data(a, b), batch_size):
    assert isinstance(batch, Data)

ab = delu.cat(delu.iter_batches((a, b), batch_size))
assert torch.equal(ab[0], a)
assert torch.equal(ab[1], b)

n_batches = len(list(delu.iter_batches((a, b), batch_size)))
assert n_batches == 3
n_batches = len(list(delu.iter_batches((a, b), batch_size, drop_last=True)))
assert n_batches == 2