delu.iter_batches
- delu.iter_batches(data, batch_size, shuffle=False, *, generator=None, drop_last=False)
Iterate over a tensor or a collection of tensors in (random) batches.
The function makes batches over the first dimension of the tensors in data and returns an iterator over collections of the same type as the input. A simple example (see below for more examples):

    n_objects = 100
    n_features = 4
    X = torch.randn(n_objects, n_features)
    y = torch.randn(n_objects)
    for batch_x, batch_y in delu.iter_batches(
        (X, y), batch_size=12, shuffle=True
    ):
        ...  # train(batch_x, batch_y)
- Parameters:
data (T) – the tensor or the collection ((named)tuple/dict/dataclass) of tensors. If data is a collection, then the tensors must have the same first dimension. If data is a dataclass, then all its fields must be tensors.
batch_size (int) – the batch size. If drop_last is False, then the last batch can be smaller than batch_size.
shuffle (bool) – if True, iterate over random batches (without replacement), not sequentially.
generator (Generator | None) – the argument for torch.randperm when shuffle is True.
drop_last (bool) – same as the drop_last argument for torch.utils.data.DataLoader. When True and the last batch is smaller than batch_size, this last batch is not returned.
- Returns:
Iterator over batches.
- Raises:
ValueError – if the data is empty.
- Return type:
Iterator[T]
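The interplay of batch_size and drop_last described in the parameter list comes down to simple arithmetic. The sketch below uses a hypothetical helper, expected_n_batches, which is not part of delu, to compute how many batches the iterator should yield for a given first-dimension size. Rejecting a non-positive batch_size is an assumption of this sketch; the documentation above only states that empty data raises ValueError.

```python
import math


def expected_n_batches(n_items: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper (not part of delu): expected number of batches."""
    if n_items <= 0:
        # Mirrors the documented behavior: empty data is an error.
        raise ValueError("the data is empty")
    if batch_size <= 0:
        # Assumption of this sketch, not taken from the delu documentation.
        raise ValueError("batch_size must be positive")
    if drop_last:
        # An incomplete final batch is discarded.
        return n_items // batch_size
    # An incomplete final batch is kept, so round up.
    return math.ceil(n_items / batch_size)


print(expected_n_batches(5, 2))                  # 3
print(expected_n_batches(5, 2, drop_last=True))  # 2
print(expected_n_batches(100, 12))               # 9
```

With 100 objects and batch_size=12, for instance, eight full batches are followed by one batch of four objects unless drop_last is True.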
Note
The function lazily indexes the provided input with batches of indices. This is faster than iterating over the tensors in data with torch.utils.data.DataLoader.
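The idea of lazily indexing the input with batches of indices can be illustrated with a minimal sketch. This is not delu's implementation: iter_index_batches and its seed parameter are hypothetical names, plain Python lists stand in for tensors, and random.Random stands in for torch.randperm so that the sketch runs without PyTorch.

```python
import random
from typing import Iterator, List, Optional


def iter_index_batches(
    n_items: int,
    batch_size: int,
    shuffle: bool = False,
    seed: Optional[int] = None,
    drop_last: bool = False,
) -> Iterator[List[int]]:
    """Hypothetical sketch (not delu's code): lazily yield batches of indices."""
    if n_items <= 0:
        raise ValueError("the data is empty")
    indices = list(range(n_items))
    if shuffle:
        # Stand-in for torch.randperm: a random permutation (no replacement).
        random.Random(seed).shuffle(indices)
    for start in range(0, n_items, batch_size):
        batch = indices[start : start + batch_size]
        if drop_last and len(batch) < batch_size:
            return  # skip the incomplete final batch
        yield batch


# Plain lists stand in for tensors that share the same first dimension.
a = [0.0, 1.0, 2.0, 3.0, 4.0]
b = [0, 10, 20, 30, 40]
batches = [
    ([a[i] for i in idx], [b[i] for i in idx])
    for idx in iter_index_batches(len(a), batch_size=2)
]
print(batches)  # [([0.0, 1.0], [0, 10]), ([2.0, 3.0], [20, 30]), ([4.0], [40])]
```

Because only the index batches are generated up front, every item of the collection is sliced with the same indices, which keeps the elements of a batch aligned across tensors.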
Examples
    for epoch in range(n_epochs):
        for batch in delu.iter_batches(data, batch_size, shuffle=True):
            ...

    from dataclasses import dataclass

    import torch
    import delu

    a = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])
    b = torch.tensor([[0], [10], [20], [30], [40]])
    batch_size = 2

    # The batches have the same type as the input.
    for batch in delu.iter_batches(a, batch_size):
        assert isinstance(batch, torch.Tensor)
    for batch in delu.iter_batches((a, b), batch_size):
        assert isinstance(batch, tuple) and len(batch) == 2
    for batch in delu.iter_batches({'a': a, 'b': b}, batch_size):
        assert isinstance(batch, dict) and set(batch) == {'a', 'b'}

    @dataclass
    class Data:
        a: torch.Tensor
        b: torch.Tensor

    for batch in delu.iter_batches(Data(a, b), batch_size):
        assert isinstance(batch, Data)

    # Without shuffling, concatenating the batches restores the input.
    ab = delu.cat(list(delu.iter_batches((a, b), batch_size)))
    assert torch.equal(ab[0], a)
    assert torch.equal(ab[1], b)

    # Five items with batch_size=2: the last batch holds a single item.
    n_batches = len(list(delu.iter_batches((a, b), batch_size)))
    assert n_batches == 3
    n_batches = len(list(delu.iter_batches((a, b), batch_size, drop_last=True)))
    assert n_batches == 2