Enumerate

class delu.utils.data.Enumerate

Bases: Dataset

Make a PyTorch dataset return indices in addition to items (like enumerate, but for datasets).

Usage

Creating the initial non-enumerated dataset:

>>> import torch
>>> import delu
>>> from torch.utils.data import DataLoader, TensorDataset
>>>
>>> X = torch.arange(10).float().view(5, 2)
>>> X
tensor([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.],
        [8., 9.]])
>>> Y = -10 * torch.arange(5)
>>> Y
tensor([  0, -10, -20, -30, -40])
>>>
>>> dataset = TensorDataset(X, Y)
>>> dataset[2]
(tensor([4., 5.]), tensor(-20))

The enumerated dataset returns indices in addition to items:

>>> enumerated_dataset = delu.utils.data.Enumerate(dataset)
>>> enumerated_dataset[2]
(2, (tensor([4., 5.]), tensor(-20)))
>>>
>>> for x_batch, y_batch in DataLoader(
...     dataset, batch_size=2
... ):
...     pass  # the original dataset yields only (x_batch, y_batch), without indices
...
>>> for batch_idx, (x_batch, y_batch) in DataLoader(
...     enumerated_dataset, batch_size=2
... ):
...     print(batch_idx)
tensor([0, 1])
tensor([2, 3])
tensor([4])
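A common reason to carry indices through a DataLoader is shuffling: when batches arrive in random order, the indices tell you where each per-sample result belongs. The sketch below imitates that flow in plain Python so it runs without PyTorch. The helper make_shuffled_batches is hypothetical and only stands in for something like DataLoader(enumerated_dataset, batch_size=2, shuffle=True); the "model" is a placeholder that doubles each target.

```python
import random
from typing import Any, List, Sequence, Tuple


def make_shuffled_batches(
    items: Sequence[Any], batch_size: int, seed: int = 0
) -> List[Tuple[List[int], List[Any]]]:
    """Hypothetical stand-in for a shuffling DataLoader over an enumerated dataset.

    Returns a list of (indices, items) batches; the last batch may be smaller.
    """
    order = list(range(len(items)))
    random.Random(seed).shuffle(order)  # deterministic shuffle for the example
    batches = []
    for start in range(0, len(order), batch_size):
        batch_indices = order[start:start + batch_size]
        batches.append((batch_indices, [items[i] for i in batch_indices]))
    return batches


ys = [0, -10, -20, -30, -40]  # same values as Y above, as plain ints
predictions = [None] * len(ys)  # preallocated, one slot per sample

for batch_idx, y_batch in make_shuffled_batches(ys, batch_size=2):
    outputs = [2 * y for y in y_batch]  # placeholder "model"
    # Write each output back to its sample's original position.
    for i, out in zip(batch_idx, outputs):
        predictions[i] = out

print(predictions)  # [0, -20, -40, -60, -80]
```

Whatever order the shuffle produces, each result lands in its original slot. With real tensors, the inner loop would typically become a single indexed assignment such as predictions[batch_idx] = outputs, since the index batch arrives as a tensor.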

The original dataset and its size remain accessible:

>>> enumerated_dataset.dataset is dataset
True
>>> len(enumerated_dataset) == len(dataset)
True

__getitem__(index) → Tuple[Any, Any]

Return index and the corresponding item from the original dataset.

Parameters:

index – the index.

Returns:

(index, item)

__init__(dataset: Dataset, /) → None

Parameters:

dataset – the original dataset.

__len__() → int

Get the length of the original dataset.

property dataset: Dataset

The original dataset.
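Taken together, the members above describe a thin wrapper. The following is a minimal sketch of such a wrapper under stated assumptions, not delu's actual source. The class name EnumerateSketch is hypothetical, and it deliberately does not subclass torch.utils.data.Dataset so that it runs without PyTorch; it assumes that a map-style dataset only needs __getitem__ and __len__.

```python
from typing import Any, Sequence, Tuple


class EnumerateSketch:
    """Hypothetical re-implementation of the documented Enumerate behaviour."""

    def __init__(self, dataset: Sequence[Any], /) -> None:
        # Keep a reference to the original dataset (no copy is made).
        self._dataset = dataset

    @property
    def dataset(self) -> Sequence[Any]:
        # The original dataset, returned unchanged.
        return self._dataset

    def __len__(self) -> int:
        # Same length as the original dataset.
        return len(self._dataset)

    def __getitem__(self, index: int) -> Tuple[int, Any]:
        # Pair the requested index with the original item.
        return index, self._dataset[index]


pairs = [("a", 1), ("b", 2), ("c", 3)]
enumerated = EnumerateSketch(pairs)
print(enumerated[2])                # (2, ('c', 3))
print(len(enumerated))              # 3
print(enumerated.dataset is pairs)  # True
```

Subclassing Dataset instead, as the documented class does (Bases: Dataset), would leave this logic unchanged.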