Enumerate

class delu.utils.data.Enumerate
Bases: Dataset
Make a PyTorch dataset return indices in addition to items (like enumerate, but for datasets).

TL;DR:
dataset[i] -> value
enumerated_dataset[i] -> (i, value)
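To make the contract above concrete, here is a minimal sketch of how such a wrapper could work. It is not DeLU's actual source: `EnumerateSketch` is a hypothetical name, and it is written in plain Python against the map-style dataset protocol (`__getitem__` plus `__len__`) instead of subclassing `torch.utils.data.Dataset`, so it runs without PyTorch installed.

```python
from typing import Any, Tuple


class EnumerateSketch:
    """Hypothetical stand-in for an enumerating dataset wrapper (not DeLU's code).

    Wraps any object implementing __getitem__ and __len__ (the map-style
    dataset protocol) and returns (index, item) pairs.
    """

    def __init__(self, dataset: Any) -> None:
        # Keep a reference so the original dataset stays accessible.
        self.dataset = dataset

    def __len__(self) -> int:
        # The wrapper has exactly as many items as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[int, Any]:
        # Return the requested index alongside the original item.
        return index, self.dataset[index]


# A plain list stands in for a dataset here.
data = [("a", 0), ("b", -10), ("c", -20)]
enumerated = EnumerateSketch(data)
print(enumerated[2])                  # (2, ('c', -20))
print(len(enumerated) == len(data))   # True
print(enumerated.dataset is data)     # True
```

Because the wrapper only adds the index and delegates everything else, the item structure produced by the wrapped dataset is preserved unchanged inside the pair.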
Usage
Creating the initial non-enumerated dataset:

>>> import torch
>>> import delu
>>> from torch.utils.data import DataLoader, TensorDataset
>>>
>>> X = torch.arange(10).float().view(5, 2)
>>> X
tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.],
        [8., 9.]])
>>> Y = -10 * torch.arange(5)
>>> Y
tensor([  0, -10, -20, -30, -40])
>>>
>>> dataset = TensorDataset(X, Y)
>>> dataset[2]
(tensor([4., 5.]), tensor(-20))
The enumerated dataset returns indices in addition to items:
>>> enumerated_dataset = delu.utils.data.Enumerate(dataset)
>>> enumerated_dataset[2]
(2, (tensor([4., 5.]), tensor(-20)))
>>>
>>> for x_batch, y_batch in DataLoader(
...     dataset, batch_size=2
... ):
...     ...
...
>>> for batch_idx, (x_batch, y_batch) in DataLoader(
...     enumerated_dataset, batch_size=2
... ):
...     print(batch_idx)
tensor([0, 1])
tensor([2, 3])
tensor([4])
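Indices are most useful when the iteration order is not the dataset order, for example with shuffling: each batch then carries enough information to put per-sample results back in place. The sketch below illustrates that idea in plain Python, under stated assumptions: a list stands in for the dataset, `enumerated_item` is a hypothetical helper that mimics the `(i, value)` items of an enumerated dataset, a seeded `random.shuffle` stands in for a shuffling `DataLoader`, and the "model" is a toy doubling function.

```python
import random

# Toy map-style "dataset": item i is the pair (x, y).
dataset = [(float(i), -10 * i) for i in range(5)]


def enumerated_item(i):
    # Hypothetical helper mirroring an enumerated dataset: returns (index, item).
    return i, dataset[i]


# Shuffled index order, standing in for a loader with shuffling enabled.
rng = random.Random(0)
order = list(range(len(dataset)))
rng.shuffle(order)

batch_size = 2
predictions = [None] * len(dataset)
for start in range(0, len(order), batch_size):
    batch = [enumerated_item(i) for i in order[start:start + batch_size]]
    for idx, (x, _y) in batch:
        # Toy "model" doubles the input; the index says where the result belongs.
        predictions[idx] = 2 * x

print(predictions)  # [0.0, 2.0, 4.0, 6.0, 8.0]
```

The final list is in dataset order no matter how the indices were shuffled. With real PyTorch batches, `batch_idx` arrives as an integer tensor, so one could presumably write results into a preallocated tensor with `predictions[batch_idx] = outputs`.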
The original dataset and its size remain accessible:
>>> enumerated_dataset.dataset is dataset
True
>>> len(enumerated_dataset) == len(dataset)
True