IndexDataset
class delu.utils.data.IndexDataset
Bases: Dataset

A trivial dataset that yields indices back to the user (useful for DistributedDataParallel (DDP)).
This simple dataset is useful when both conditions are true:
1. A dataloader that yields batches of indices instead of objects is needed.
2. The Distributed Data Parallel setup is used.
Note

If only the first condition is true, consider using the combination of torch.randperm and torch.Tensor.split instead.

Usage
>>> from torch.utils.data import DataLoader
>>> from torch.utils.data.distributed import DistributedSampler
>>>
>>> train_size = 1000
>>> batch_size = 64
>>> dataset = delu.data.IndexDataset(train_size)
>>> # The dataset is really *that* trivial:
>>> for i in range(train_size):
...     assert dataset[i] == i
>>> dataloader = DataLoader(
...     dataset,
...     batch_size,
...     sampler=DistributedSampler(dataset),
... )
>>> for epoch in range(n_epochs):
...     for batch_indices in dataloader:
...         ...
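The body of the inner loop is left out above. One common way to consume the index batches is to slice pre-loaded tensors; the following is a minimal sketch, assuming hypothetical in-memory tensors X and Y of length train_size (they are not part of this example):

>>> # A sketch only: X and Y are hypothetical pre-loaded tensors;
>>> # batch_indices selects the rows used in one training step.
>>> for epoch in range(n_epochs):
...     for batch_indices in dataloader:
...         x_batch = X[batch_indices]
...         y_batch = Y[batch_indices]
...         ...  # the rest of the training step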
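For comparison, the non-DDP alternative mentioned in the note above can be sketched as follows; train_size, batch_size, and n_epochs reuse the names from the usage example:

>>> # Without DDP, shuffled batches of indices can be produced directly
>>> # with torch.randperm + torch.Tensor.split, no Dataset/DataLoader needed:
>>> import torch
>>>
>>> for epoch in range(n_epochs):
...     permutation = torch.randperm(train_size)
...     for batch_indices in permutation.split(batch_size):
...         ...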