IndexDataset#

class delu.utils.data.IndexDataset[source]#

Bases: Dataset

A trivial dataset that yields indices back to the user (useful for DistributedDataParallel (DDP)).

This simple dataset is useful when both of the following conditions hold:

  1. A dataloader must yield batches of indices instead of objects.

  2. The Distributed Data Parallel setup is used.

Note

If only the first condition is true, consider using the combination of torch.randperm and torch.Tensor.split instead, as sketched below.
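If only the first condition applies, no dataset object is needed at all. A minimal sketch of the alternative mentioned in the note (the sizes are arbitrary placeholders):

>>> import torch
>>>
>>> train_size = 1000
>>> batch_size = 64
>>> for epoch in range(2):
...     # One fresh permutation of all indices per epoch,
...     # split into chunks of (at most) batch_size.
...     for batch_indices in torch.randperm(train_size).split(batch_size):
...         ...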

Usage

>>> import delu
>>> from torch.utils.data import DataLoader
>>> from torch.utils.data.distributed import DistributedSampler
>>>
>>> train_size = 1000
>>> batch_size = 64
>>> dataset = delu.utils.data.IndexDataset(train_size)
>>> # The dataset is really *that* trivial:
>>> for i in range(train_size):
...     assert dataset[i] == i
>>> dataloader = DataLoader(
...     dataset,
...     batch_size,
...     sampler=DistributedSampler(dataset),
... )
>>> n_epochs = 2
>>> for epoch in range(n_epochs):
...     for batch_indices in dataloader:
...         ...
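In an actual DDP run, the yielded indices are typically used to slice the real data stored elsewhere, and DistributedSampler.set_epoch should be called at the start of each epoch so that shuffling differs across epochs. A minimal sketch of that pattern (the tensor features below is a hypothetical stand-in for the actual training data, not part of this API):

>>> import torch
>>>
>>> features = torch.randn(train_size, 8)  # hypothetical training data
>>> sampler = DistributedSampler(dataset)
>>> dataloader = DataLoader(dataset, batch_size, sampler=sampler)
>>> for epoch in range(n_epochs):
...     # Reshuffle differently on each epoch (standard DistributedSampler usage).
...     sampler.set_epoch(epoch)
...     for batch_indices in dataloader:
...         batch = features[batch_indices]  # map indices back to objects
...         ...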
__getitem__(index: int) → int[source]#

Get the same index back.

The index must be an integer from range(len(self)).

__init__(size: int) → None[source]#
Parameters:

size – the dataset size.

__len__() → int[source]#

Get the dataset size.