IndexDataset
class delu.utils.data.IndexDataset
Bases: Dataset
A trivial dataset that yields indices back to the user (useful for DistributedDataParallel (DDP)).
This simple dataset is useful when both of the following conditions are true:
- A dataloader that yields batches of indices instead of objects is needed.
- The Distributed Data Parallel setup is used.
Note
If only the first condition is true, consider using the combination of torch.randperm and torch.Tensor.split instead.
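For reference, a minimal sketch of that alternative (the values of train_size and batch_size are illustrative, and a plain single-process setup is assumed):
>>> import torch
>>>
>>> train_size, batch_size = 1000, 64
>>> for batch_indices in torch.randperm(train_size).split(batch_size):
...     ...  # batch_indices is a 1D tensor of shuffled indices.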
Usage
>>> import delu
>>> from torch.utils.data import DataLoader
>>> from torch.utils.data.distributed import DistributedSampler
>>>
>>> train_size = 1000
>>> batch_size = 64
>>> n_epochs = 2
>>> dataset = delu.data.IndexDataset(train_size)
>>> # The dataset is really *that* trivial:
>>> for i in range(train_size):
...     assert dataset[i] == i
>>> dataloader = DataLoader(
...     dataset,
...     batch_size,
...     sampler=DistributedSampler(dataset),
... )
>>> for epoch in range(n_epochs):
...     for batch_indices in dataloader:
...         ...
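The yielded index batches can then be used to slice data that is already available in memory. A minimal sketch, assuming hypothetical preloaded tensors X and Y that are not part of the example above:
>>> import torch
>>>
>>> # X and Y are hypothetical feature and label tensors of length train_size.
>>> X = torch.randn(train_size, 8)
>>> Y = torch.randn(train_size)
>>> for epoch in range(n_epochs):
...     for batch_indices in dataloader:
...         x_batch, y_batch = X[batch_indices], Y[batch_indices]
...         ...  # The forward/backward pass for the batch goes here.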