zero.data.Enumerate

class zero.data.Enumerate(dataset)

Wrap a dataset so that it returns each item together with its index.

Tutorial

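import torch
from torch.utils.data import DataLoader, TensorDataset
from zero.data import Enumerate

X, y = torch.randn(9, 2), torch.randn(9)
dataset = TensorDataset(X, y)
# Each item of Enumerate(dataset) is (index, dataset[index]), so the default
# collate function yields a tensor of indices alongside the batch itself.
for batch_idx, batch in DataLoader(Enumerate(dataset), batch_size=3):
    print(batch_idx)
# Output:
# tensor([0, 1, 2])
# tensor([3, 4, 5])
# tensor([6, 7, 8])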

Attributes

dataset

Access the underlying dataset.

Methods

__init__

Initialize self with the dataset to wrap.

__len__

Get the length of the underlying dataset.

__getitem__

Return the pair (index, item), where item is the element of the underlying dataset at that index.
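The methods above describe a small wrapper interface. The sketch below is a hypothetical stand-in, `EnumerateSketch`, written only to illustrate that interface; it is not the library's source, and exposing `dataset` as a read-only property is an assumption. It relies only on `__len__` and `__getitem__`, so a wrapper like this should also be usable as a map-style dataset with `torch.utils.data.DataLoader`. The usage part shows why the returned indices are useful: when items are visited in shuffled order (as with `shuffle=True`), the indices let you write per-item results back in the original order.

```python
import random
from typing import Any, Sequence, Tuple


class EnumerateSketch:
    """Hypothetical stand-in for zero.data.Enumerate: wraps any object that
    supports len() and integer indexing so that item i is (i, dataset[i])."""

    def __init__(self, dataset: Sequence[Any]) -> None:
        self._dataset = dataset

    @property
    def dataset(self) -> Sequence[Any]:
        # Read-only access to the wrapped dataset (assumed design choice).
        return self._dataset

    def __len__(self) -> int:
        # Same length as the underlying dataset.
        return len(self._dataset)

    def __getitem__(self, index: int) -> Tuple[int, Any]:
        # Return the index together with the item it refers to.
        return index, self._dataset[index]


# Usage: visit items in a shuffled order, as DataLoader(shuffle=True) would,
# then use the returned indices to store results in the original order.
data = ["a", "b", "c", "d", "e"]
wrapped = EnumerateSketch(data)
order = list(range(len(wrapped)))
random.shuffle(order)

results = [None] * len(wrapped)
for i in order:
    idx, item = wrapped[i]
    results[idx] = item.upper()

print(results)  # ['A', 'B', 'C', 'D', 'E']
```

Whatever order `random.shuffle` produces, each result lands at the position given by its index, so the final list always matches the original order.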