Source code for delu.data

"""An addition to `torch.utils.data`."""

from typing import Any, Callable, Iterable, Optional, TypeVar, Union

import torch
from torch.utils.data import DataLoader, Dataset

from ._stream import Stream  # noqa: F401
from ._utils import deprecated
from .utils import data as utils_data

T = TypeVar('T')


@deprecated('Renamed to `delu.utils.data.Enumerate`')
class Enumerate(utils_data.Enumerate):
    """
    <DEPRECATION MESSAGE>
    """

    pass

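# Migration sketch, assuming the replacement behaves like this deprecated
# alias: `delu.utils.data.Enumerate` wraps a dataset so that indexing returns
# ``(index, item)`` pairs (`base_dataset` is hypothetical):
#
#     dataset = delu.utils.data.Enumerate(base_dataset)
#     i, item = dataset[0]  # i == 0, item == base_dataset[0]
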
@deprecated('Instead, use `delu.utils.data.IndexDataset`')
class IndexDataset(utils_data.IndexDataset):
    """
    <DEPRECATION MESSAGE>
    """

    pass

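# Migration sketch, assuming the replacement behaves like this deprecated
# alias: `delu.utils.data.IndexDataset` simply yields indices back, i.e.
# ``dataset[i] == i``:
#
#     dataset = delu.utils.data.IndexDataset(4)
#     assert len(dataset) == 4
#     assert dataset[3] == 3
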
@deprecated('')
class FnDataset(Dataset):
    """Create simple PyTorch datasets without classes and inheritance.

    <DEPRECATION MESSAGE>

    `FnDataset` allows you to avoid implementing `~torch.utils.data.Dataset`
    classes in simple cases.

    .. rubric:: Tutorial

    First, a quick example. Without `FnDataset`::

        from PIL import Image

        class ImagesList(Dataset):
            def __init__(self, filenames, transform):
                self.filenames = filenames
                self.transform = transform

            def __len__(self):
                return len(self.filenames)

            def __getitem__(self, index):
                return self.transform(Image.open(self.filenames[index]))

        dataset = ImagesList(filenames, transform)

    With `FnDataset`::

        dataset = delu.data.FnDataset(Image.open, filenames, transform)

        # Cache images after the first load:
        from functools import lru_cache
        dataset = delu.data.FnDataset(lru_cache(None)(Image.open), filenames)

    In other words, with vanilla PyTorch, in order to create a dataset, you
    have to inherit from `torch.utils.data.Dataset` and implement three
    methods:

    - ``__init__``
    - ``__len__``
    - ``__getitem__``

    With `FnDataset` the only thing you *may* need to implement is the ``fn``
    argument that will power ``__getitem__``.

    The easiest way to learn `FnDataset` is to go through the examples below.

    A list of images::

        dataset = delu.data.FnDataset(Image.open, filenames)
        # dataset[i] returns Image.open(filenames[i])

    A list of images that are cached after the first load::

        from functools import lru_cache
        dataset = delu.data.FnDataset(lru_cache(None)(Image.open), filenames)

    `pathlib.Path` is handy for creating datasets that read from files::

        images_dir = Path(...)
        dataset = delu.data.FnDataset(Image.open, images_dir.iterdir())

    If you only need files with specific extensions::

        dataset = delu.data.FnDataset(Image.open, images_dir.glob('*.png'))

    If you only need files with specific extensions located in all
    subfolders::

        dataset = delu.data.FnDataset(
            Image.open, (x for x in images_dir.rglob('*.png') if condition(x))
        )

    A segmentation dataset::

        image_filenames = ...
        gt_filenames = ...

        def get(i):
            return Image.open(image_filenames[i]), Image.open(gt_filenames[i])

        dataset = delu.data.FnDataset(get, len(image_filenames))

    A dummy dataset that demonstrates that `FnDataset` is a very general
    thing:

    .. testcode::

        def f(x):
            return x * 10

        def g(x):
            return x * 2

        dataset = delu.data.FnDataset(f, 3, g)
        # dataset[i] returns g(f(i))
        assert len(dataset) == 3
        assert dataset[0] == 0
        assert dataset[1] == 20
        assert dataset[2] == 40
    """

    def __init__(
        self,
        fn: Callable[..., T],
        args: Union[int, Iterable],
        transform: Optional[Callable[[T], Any]] = None,
    ) -> None:
        """
        Args:
            fn: the function that produces values based on arguments from
                ``args``
            args: arguments for ``fn``. If an iterable, but not a list, it is
                cast to a list. If an integer, the behavior is the same as for
                ``list(range(args))``. The size of ``args`` defines the return
                value of `FnDataset.__len__`.
            transform: if provided, is applied to the return value of ``fn``
                in `FnDataset.__getitem__`

        Examples:
            .. code-block::

                import PIL.Image as Image
                import torchvision.transforms as T

                dataset = delu.data.FnDataset(Image.open, filenames, T.ToTensor())
        """
        self._fn = fn
        if isinstance(args, Iterable):
            if not isinstance(args, list):
                args = list(args)
        self._args = args
        self._transform = transform

    def __len__(self) -> int:
        """Get the dataset size.

        See `FnDataset` for details.

        Returns:
            size
        """
        return len(self._args) if isinstance(self._args, list) else self._args

    def __getitem__(self, index: int) -> Any:
        """Get value by index.

        See `FnDataset` for details.

        Args:
            index

        Returns:
            value
        """
        if isinstance(self._args, list):
            x = self._args[index]
        elif index < self._args:
            x = index
        else:
            raise IndexError(f'Index {index} is out of range')
        x = self._fn(x)
        return x if self._transform is None else self._transform(x)

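# Usage sketch (hypothetical `filenames`; `Image` and `T` as in the docstring
# example above): `FnDataset` is a regular map-style dataset, so it plugs
# directly into `torch.utils.data.DataLoader`:
#
#     dataset = delu.data.FnDataset(Image.open, filenames, T.ToTensor())
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
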
@deprecated('Instead, use `delu.data.IndexDataset` and `~torch.utils.data.DataLoader`')
def make_index_dataloader(size: int, *args, **kwargs) -> DataLoader:
    """Make `~torch.utils.data.DataLoader` over indices instead of data.

    <DEPRECATION MESSAGE>

    This is just a shortcut for
    ``torch.utils.data.DataLoader(delu.data.IndexDataset(...), ...)``.

    Args:
        size: the dataset size
        *args: positional arguments for `torch.utils.data.DataLoader`
        **kwargs: keyword arguments for `torch.utils.data.DataLoader`

    Examples:
        Usage for training:

        .. code-block::

            train_loader = delu.data.make_index_dataloader(
                len(train_dataset), batch_size, shuffle=True
            )
            for epoch in range(n_epochs):
                for i_batch in train_loader:
                    x_batch = X[i_batch]
                    y_batch = Y[i_batch]
                    ...

        Other examples:

        .. testcode::

            dataset_size = 10  # len(dataset)
            for batch_idx in delu.data.make_index_dataloader(
                dataset_size, batch_size=3
            ):
                print(batch_idx)

        .. testoutput::

            tensor([0, 1, 2])
            tensor([3, 4, 5])
            tensor([6, 7, 8])
            tensor([9])

        .. testcode::

            dataset_size = 10  # len(dataset)
            for batch_idx in delu.data.make_index_dataloader(
                dataset_size, 3, drop_last=True
            ):
                print(batch_idx)

        .. testoutput::

            tensor([0, 1, 2])
            tensor([3, 4, 5])
            tensor([6, 7, 8])
    """
    return DataLoader(IndexDataset(size), *args, **kwargs)

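# Migration sketch: as the docstring notes, the call above is just a shortcut,
# so the non-deprecated spelling is the explicit composition (`dataset_size`
# and `batch_size=3` are illustrative):
#
#     loader = DataLoader(IndexDataset(dataset_size), batch_size=3)
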
@deprecated('Instead, use `delu.data.IndexDataset` and `~torch.utils.data.DataLoader`')
class IndexLoader:
    """Like `~torch.utils.data.DataLoader`, but over indices instead of data.

    <DEPRECATION MESSAGE>

    **The shuffling logic is delegated to the native PyTorch DataLoader**,
    i.e. no custom logic is performed under the hood. The data loader that
    actually generates indices is available as `IndexLoader.loader`.

    Examples:
        Usage for training:

        .. code-block::

            train_loader = delu.data.IndexLoader(
                len(train_dataset), batch_size, shuffle=True
            )
            for epoch in range(n_epochs):
                for batch_idx in train_loader:
                    ...

        Other examples:

        .. testcode::

            dataset_size = 10  # len(dataset)
            for batch_idx in delu.data.IndexLoader(dataset_size, batch_size=3):
                print(batch_idx)

        .. testoutput::

            tensor([0, 1, 2])
            tensor([3, 4, 5])
            tensor([6, 7, 8])
            tensor([9])

        .. testcode::

            dataset_size = 10  # len(dataset)
            for batch_idx in delu.data.IndexLoader(dataset_size, 3, drop_last=True):
                print(batch_idx)

        .. testoutput::

            tensor([0, 1, 2])
            tensor([3, 4, 5])
            tensor([6, 7, 8])
    """

    def __init__(
        self,
        size: int,
        *args,
        device: Union[int, str, torch.device] = 'cpu',
        **kwargs,
    ) -> None:
        """
        Args:
            size: the number of items (for example, :code:`len(dataset)`)
            args: positional arguments for `torch.utils.data.DataLoader`
            device: if not CPU, then all indices are materialized and moved
                to the device at the beginning of every loop. This can be
                useful when the indices are applied to non-CPU data (e.g.
                CUDA tensors) and moving data between devices takes
                non-negligible time (which can happen with simple and fast
                models like MLPs).
            kwargs: keyword arguments for `torch.utils.data.DataLoader`
        """
        assert size > 0
        self._batch_size = args[0] if args else kwargs.get('batch_size', 1)
        self._loader = DataLoader(IndexDataset(size), *args, **kwargs)
        if isinstance(device, (int, str)):
            device = torch.device(device)
        self._device = device

    @property
    def loader(self) -> DataLoader:
        """The underlying DataLoader."""
        return self._loader

    def __len__(self) -> int:
        """Get the size of the underlying DataLoader."""
        return len(self.loader)

    def __iter__(self):
        return iter(
            self._loader
            if self._device.type == 'cpu'
            else torch.cat(list(self.loader))
            .to(self._device)
            .split(self._batch_size)
        )

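    # A sketch of what the non-CPU branch in `__iter__` does: all index
    # batches are drawn eagerly, concatenated into one tensor, moved to the
    # target device once, and re-split into batches of the original size:
    #
    #     indices = torch.cat(list(self.loader)).to(self._device)
    #     batches = indices.split(self._batch_size)  # tuple of index tensors
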
@deprecated('Instead, use `torch.utils.data.default_collate`')
def collate(iterable: Iterable) -> Any:
    """Almost an alias for :code:`torch.utils.data.dataloader.default_collate`.

    <DEPRECATION MESSAGE>

    Namely, the input is allowed to be any kind of iterable, not only a list.
    First, if the input is not a list, it is transformed into a list. Then the
    list is passed to the original function, and the result is returned as is.
    """
    if not isinstance(iterable, list):
        iterable = list(iterable)
    # mypy: Module has no attribute "default_collate"
    return torch.utils.data.dataloader.default_collate(iterable)  # type: ignore

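# Migration sketch: with `torch.utils.data.default_collate` (exposed publicly
# in PyTorch 1.11+), the equivalent of the function above is:
#
#     from torch.utils.data import default_collate
#     batch = default_collate(list(items))  # `items` is any iterable of samples
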