zero.iter_batches

zero.iter_batches(data, *args, **kwargs)

Efficiently iterate over data (a tensor, a tuple of tensors, a dict of tensors, etc.) in a batchwise manner.

The function is useful when you need to iterate over tensor-based data exactly once in a batchwise manner. See the examples below for typical use cases.

For in-memory data, the function is a more efficient alternative to torch.utils.data.DataLoader, because it indexes whole batches at once instead of one item at a time (DataLoader's behavior). The shuffling logic is delegated to the native PyTorch DataLoader, i.e. no custom shuffling logic is performed under the hood.
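As a rough sketch of the idea (not the function's actual implementation), batch-based indexing fetches a whole batch with a single advanced-indexing call, while item-based indexing fetches items one by one and then collates them:

import torch

data = torch.randn(10000, 32)
batch_indices = torch.randperm(len(data))[:64]

# batch-based indexing: one indexing call per batch
batch = data[batch_indices]

# item-based indexing (DataLoader-style): one call per item, plus collation
batch = torch.stack([data[i] for i in batch_indices])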

Parameters
  • data (Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor], torch.utils.data.dataset.TensorDataset]) – the data to iterate over

  • *args – positional arguments for IndexLoader

  • **kwargs – keyword arguments for IndexLoader

Returns

Iterator over batches.

Return type

Iterator

Warning

NumPy arrays are not supported because of how they behave when indexed by a torch tensor of size 1. For details, see the corresponding issue.

Note

If you want to iterate over batches infinitely, wrap the call in a while True: loop, as shown below.
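A minimal sketch (data and batch_size are placeholders, as in the examples below):

while True:
    for batch in iter_batches(data, batch_size, shuffle=True):
        ...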

Examples

Besides loops over batches, the function can be used in combination with concat:

result = concat(map(fn, iter_batches(dataset_or_tensors_or_whatever, ...)))
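For instance, batched inference could look like the following sketch, where model is assumed to be a torch.nn.Module and X a tensor (both are placeholders, not part of this API):

import torch

model.eval()
with torch.no_grad():
    predictions = concat(map(model, iter_batches(X, 1024)))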

The function can also be used for training:

for epoch in range(n_epochs):
    for batch in iter_batches(data, batch_size, shuffle=True):
        ...
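A fuller, self-contained sketch of such a loop (the model, optimizer, loss, and data below are hypothetical placeholders; it is also assumed that a tuple of tensors is yielded as tuples of batches, as with TensorDataset):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

X = torch.randn(256, 8)
y = torch.randn(256, 1)

for epoch in range(10):
    # shuffling is delegated to the native PyTorch DataLoader
    for batch_x, batch_y in iter_batches((X, y), 32, shuffle=True):
        optimizer.zero_grad()
        loss = loss_fn(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()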