"""An extension to `torch.utils.data`."""

from typing import Any, Tuple, TypeVar

from torch.utils.data import Dataset

from .._utils import deprecated

T = TypeVar('T')


__all__ = ['Enumerate', 'IndexDataset']


@deprecated('')
class Enumerate(Dataset):
    """Make a PyTorch dataset return indices in addition to items
    (like `enumerate`, but for datasets).

    <DEPRECATION MESSAGE>

    TL;DR:

    - ``dataset[i] -> value``
    - ``enumerated_dataset[i] -> (i, value)``

    **Usage**

    Creating the initial non-enumerated ``dataset``:

    >>> from torch.utils.data import DataLoader, TensorDataset
    >>>
    >>> X = torch.arange(10).float().view(5, 2)
    >>> X
    tensor([[0., 1.],
            [2., 3.],
            [4., 5.],
            [6., 7.],
            [8., 9.]])
    >>> Y = -10 * torch.arange(5)
    >>> Y
    tensor([  0, -10, -20, -30, -40])
    >>>
    >>> dataset = TensorDataset(X, Y)
    >>> dataset[2]
    (tensor([4., 5.]), tensor(-20))

    The enumerated dataset returns indices in addition to items:

    >>> enumerated_dataset = delu.utils.data.Enumerate(dataset)
    >>> enumerated_dataset[2]
    (2, (tensor([4., 5.]), tensor(-20)))
    >>>
    >>> for x_batch, y_batch in DataLoader(
    ...     dataset, batch_size=2
    ... ):
    ...     ...
    ...
    >>> for batch_idx, (x_batch, y_batch) in DataLoader(
    ...     enumerated_dataset, batch_size=2
    ... ):
    ...     print(batch_idx)
    tensor([0, 1])
    tensor([2, 3])
    tensor([4])

    The original dataset and its size remain accessible:

    >>> enumerated_dataset.dataset is dataset
    True
    >>> len(enumerated_dataset) == len(dataset)
    True
    """  # noqa: E501

    def __init__(self, dataset: Dataset, /) -> None:
        """
        Args:
            dataset: the original dataset.
        """
        self._dataset = dataset

    @property
    def dataset(self) -> Dataset:
        """The original dataset."""
        return self._dataset

    def __len__(self) -> int:
        """Get the length of the original dataset."""
        return len(self._dataset)  # type: ignore

    def __getitem__(self, index) -> Tuple[Any, Any]:
        """Return index and the corresponding item from the original dataset.

        Args:
            index: the index.

        Returns:
            (index, item)
        """
        return index, self._dataset[index]
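

# A minimal usage sketch (illustrative only; not part of the original module):
# a common reason to enumerate a dataset is to attribute per-sample
# quantities, such as losses, back to dataset positions. The squared-error
# "loss" and the `per_sample_loss` dict below are hypothetical placeholders
# for a real model's loss.
#
#     import torch
#     from torch.utils.data import DataLoader, TensorDataset
#
#     dataset = TensorDataset(torch.randn(8, 2), torch.randn(8))
#     loader = DataLoader(Enumerate(dataset), batch_size=4)
#     per_sample_loss = {}
#     for batch_idx, (x_batch, y_batch) in loader:
#         # Stand-in for a real model's per-sample loss.
#         losses = (x_batch.sum(1) - y_batch) ** 2
#         for i, value in zip(batch_idx.tolist(), losses.tolist()):
#             per_sample_loss[i] = value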


@deprecated('')
class IndexDataset(Dataset):
    """A trivial dataset that yields indices back to the user
    (useful for DistributedDataParallel (DDP)).

    <DEPRECATION MESSAGE>

    This simple dataset is useful when *both* conditions are true:

    1. A dataloader that yields batches of *indices* instead of *objects*
       is needed.
    2. The `Distributed Data Parallel
       <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_
       setup is used.

    .. note::
        If only the first condition is true, consider using the combination
        of `torch.randperm` and `torch.Tensor.split` instead.

    **Usage**

    >>> # doctest: +SKIP
    >>> from torch.utils.data import DataLoader
    >>> from torch.utils.data.distributed import DistributedSampler
    >>>
    >>> train_size = 1000
    >>> batch_size = 64
    >>> dataset = delu.data.IndexDataset(train_size)
    >>> # The dataset is really *that* trivial:
    >>> for i in range(train_size):
    ...     assert dataset[i] == i
    >>> dataloader = DataLoader(
    ...     dataset,
    ...     batch_size,
    ...     sampler=DistributedSampler(dataset),
    ... )
    >>> for epoch in range(n_epochs):
    ...     for batch_indices in dataloader:
    ...         ...
    """  # noqa: E501

    def __init__(self, size: int) -> None:
        """
        Args:
            size: the dataset size.
        """
        if size < 1:
            raise ValueError('size must be positive')
        self.size = size

    def __len__(self) -> int:
        """Get the dataset size."""
        return self.size

    def __getitem__(self, index: int) -> int:
        """Get the same index back.

        The index must be an integer from ``range(len(self))``.
        """
        # Some datasets support non-integer indices;
        # this dataset intentionally requires plain integers.
        if not isinstance(index, int):
            raise ValueError('index must be an integer')
        if index < 0 or index >= self.size:
            raise IndexError(
                f'The index {index} is out of range (the dataset size is {self.size})'
            )
        return index
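

# A minimal sketch (illustrative only; not part of the original module) of the
# alternative mentioned in the note above: when DDP is not involved, shuffled
# index batches can be produced with `torch.randperm` + `torch.Tensor.split`,
# with no dataset or dataloader at all.
#
#     import torch
#
#     train_size, batch_size = 10, 4
#     for batch_indices in torch.randperm(train_size).split(batch_size):
#         ...  # e.g. tensor([7, 2, 9, 0]), tensor([5, 1, 8, 3]), tensor([6, 4])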