"""Smart Python loops."""
__all__ = ['Stream']
import math
from typing import Any, Dict, Iterable, Iterator, Optional, Sized, Union
from tqdm import tqdm
def _try_len(x):
try:
return len(x)
except (TypeError, NotImplementedError):
return None
class Stream:
    """Smart wrapper for iterables.

    `Stream` simplifies managing loops, especially in typical deep learning scenarios
    (it is usually used to wrap :code:`train_dataloader` or any other data source).

    `Stream`:

    - simplifies management of the "epoch" and "iteration" variables
    - allows to dump and restore loop's state: epoch, iteration, etc.
    - allows to customize the size of epoch
    - allows to change the underlying data loader on the fly
    - enables useful patterns

    Args:
        loader: any kind of iterable (DataLoader, list, iterator, generator, ...)

    Raises:
        AssertionError: if :code:`loader` is not an iterator and is empty

    Examples:
        .. testcode::

            stream = Stream([0, 1, 2, 3])
            stream = Stream(range(10))
            import itertools
            stream = Stream(itertools.repeat(0))

            import torch
            from torch.utils.data import DataLoader, TensorDataset
            dataset = TensorDataset(torch.randn(10, 2))
            stream = Stream(DataLoader(dataset, batch_size=3, shuffle=True))

    .. rubric:: Tutorial

    Let's revise the conventional approach without `Stream`:

    .. code-block::

        loader = DataLoader(...)
        iteration = 0
        for epoch in range(n_epochs):
            for x in loader:
                iteration += 1
                print('Epoch:', epoch, 'Iteration:', iteration)
                ...

    There are several ways how you can use `Stream` to enhance this loop. Let's start
    with creating a stream::

        stream = Stream(DataLoader(...))

    The dataloader is accessible via `Stream.loader`. Now, let's reproduce the loop
    above::

        for epoch in stream.epochs(n_epochs):
            for x in epoch:
                print('Epoch:', stream.epoch, 'Iteration:', stream.iteration)

    We see that `Stream.epoch` and `Stream.iteration` are managed automatically.
    Additionally, a progress bar is displayed while the loop is running.

    Saving the loop's state and resuming the loop is possible with the methods
    `Stream.state_dict`, `Stream.load_state_dict`. In practice, it may look like this::

        model = ...
        optimizer = ...
        stream = Stream(DataLoader(...))
        if load_from_checkpoint:
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint['model'])
            ...
            stream.load_state_dict(checkpoint['stream'])
        ...
        for epoch in stream.epochs(...):
            for batch in epoch:
                ...
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'stream': stream.state_dict(),
                },
                f'checkpoint_{stream.epoch}.pt'
            )

    Note:
        Stream's state does not include the loader's state. See `Stream.state_dict` and
        `Stream.load_state_dict` for details.

    In order to customize the epoch size, pass the size as the second argument::

        for epoch in stream.epochs(n_epochs, custom_epoch_size):
            for x in epoch:
                ...

    Changing the underlying loader on the fly is possible at *any* moment (even in the
    middle of epoch) via `Stream.set_loader`. For example::

        for epoch in stream.epochs(n_epochs, custom_epoch_size):
            for x in epoch:
                ...
                if need_new_data():
                    stream.set_loader(new_loader)

    If the method `Stream.epochs` does not fit your workflow and you want more control
    over the loop, there are more "low-level" methods (in fact, `Stream.epochs` is just
    a thin wrapper around them):

    - `Stream.increment_epoch`
    - `Stream.data`
    - `Stream.next`

    Note:
        For better technical understanding, keep in mind that `Stream` simply
        encapsulates an "infinite iterator" that is constantly moving forward. The
        behavior is absolutely the same for both finite and infinite iterables and can
        be expressed with the following loop::

            while True:
                for item in loader:  # loader which is passed in the constructor
                    ...

        Documentation for `Stream.next` and `Stream.data` provide helpful examples.
    """

    class _EpochData:
        """Iterator that yields at most ``size`` items by calling `Stream.next`.

        The number of produced items is measured as the growth of the stream's
        iteration counter since this object was created. If ``size`` is `None`,
        the iterator never stops on its own.
        """

        def __init__(self, stream, size):
            self._stream = stream
            self._size = size
            # Remember where the epoch started to count the items produced so far.
            self._start = self._stream.iteration

        def __iter__(self):
            return self

        def __next__(self):
            if (
                self._size is not None
                and self._stream.iteration - self._start >= self._size
            ):
                raise StopIteration()
            return self._stream.next()

    def __init__(self, loader: Iterable) -> None:
        # An explicit raise (instead of ``assert``) keeps the documented check
        # active even when Python runs with the -O flag.
        if _try_len(loader) == 0:
            raise AssertionError('loader must not be empty')
        self._iteration = 0
        self._epoch = 0
        self._loader = loader
        # The iterator over the loader is created lazily, on the first `next` call.
        self._iter: Optional[Iterator] = None
        # The progress bar exists only while `epochs(progress_bar=True)` is running.
        self._pbar: Optional[tqdm] = None

    @property
    def iteration(self) -> int:
        """Current iteration.

        Technically, the number of `Stream.next` calls.
        """
        return self._iteration

    @property
    def epoch(self) -> int:
        """Current epoch.

        Technically, the number of "succeeded" `Stream.increment_epoch` calls.
        """
        return self._epoch

    @property
    def loader(self) -> Iterable:
        """The underlying loader."""
        return self._loader

    def set_loader(self, loader: Iterable) -> None:
        """Set new loader.

        If iteration has already started, the underlying iterator is immediately
        replaced with an iterator over the new loader.

        Args:
            loader: the new data source (any kind of iterable).

        Raises:
            AssertionError: if :code:`loader` is not an iterator and is empty.

        Examples:
            .. testcode::

                from itertools import repeat
                stream = Stream(repeat(0))
                for x in stream.data(5):
                    print(stream.iteration, x)
                    if stream.iteration == 2:
                        stream.set_loader(repeat(1))

            .. testoutput::

                1 0
                2 0
                3 1
                4 1
                5 1
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if _try_len(loader) == 0:
            raise AssertionError('loader must not be empty')
        self._loader = loader
        # If nothing has been consumed yet, the iterator will be created lazily
        # from the new loader on the first `next` call.
        if self._iter is not None:
            self._iter = iter(loader)

    def _increment_iteration(self):
        self._iteration += 1

    def increment_epoch(self) -> None:
        """Increment `Stream.epoch`.

        Examples:
            .. testcode::

                stream = Stream(range(5))
                assert stream.epoch == 0
                stream.increment_epoch()
                assert stream.epoch == 1
                stream.increment_epoch()
                assert stream.epoch == 2
        """
        self._epoch += 1

    def reload_iterator(self) -> None:
        """Set the underlying iterator to `iter(self.loader)`.

        If the underlying loader is a finite iterable, the method can be used to
        interrupt and skip the current epoch (i.e. skip its data). If the loader is an
        iterator, the method does nothing.

        Examples:
            .. testcode::

                stream = Stream(range(5))
                assert stream.next() == 0
                assert stream.next() == 1
                stream.reload_iterator()
                assert stream.next() == 0

                stream = Stream(iter(range(5)))
                assert stream.next() == 0
                assert stream.next() == 1
                stream.reload_iterator()
                assert stream.next() == 2
        """
        self._iter = iter(self.loader)

    def next(self) -> Any:
        """Get the next item and increment iteration.

        Returns:
            The next item.

        Raises:
            StopIteration: if :code:`loader` is a finite iterator and the data is over

        Examples:
            .. testcode::

                stream = Stream(range(3))
                assert stream.iteration == 0
                assert stream.next() == 0
                assert stream.iteration == 1
                assert stream.next() == 1
                assert stream.next() == 2
                assert stream.next() == 0
                assert stream.iteration == 4

            .. code-block::

                while True:
                    x = stream.next()
                    ...
                    if stream.iteration % frequency == 0:
                        ...
        """
        if self._iter is None:
            self._iter = iter(self._loader)
        try:
            value = next(self._iter)
        except StopIteration:
            self.reload_iterator()
            # If the following line raises StopIteration too, then the data is over
            # and the exception should be just propagated.
            value = next(self._iter)
        self._increment_iteration()
        if self._pbar is not None:
            self._pbar.update()
        return value

    def data(self, n_items: Optional[Union[int, float]] = None) -> Iterator:
        """Iterate over the loader.

        Under the hood, `Stream.next` is called, hence, `Stream.iteration` changes
        during iterations.

        Args:
            n_items: how many items to produce. If `None`, interpreted as
                :code:`len(self.loader)`. If `float`, must be `math.inf`.

        Returns:
            Iterator over the next :code:`n_items` items of the stream.

        Raises:
            AssertionError: if :code:`n_items` is float, but not `math.inf`
            ValueError: if :code:`n_items` is `None` and :code:`loader` has no
                length (for example, it is an iterator)

        Examples:
            .. testcode::

                stream = Stream(range(5))
                assert list(stream.data()) == [0, 1, 2, 3, 4]
                assert list(stream.data(3)) == [0, 1, 2]
                # stream doesn't "start over"!
                assert list(stream.data(3)) == [3, 4, 0]
                assert list(stream.data(1)) == [1]
                assert list(stream.data(2)) == [2, 3]

            .. code-block::

                for x in stream.data(math.inf):
                    ...
                    if stream.iteration % frequency == 0:
                        ...
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if isinstance(n_items, float) and not math.isinf(n_items):
            raise AssertionError('if n_items is float, it must be math.inf')
        if n_items is None:
            if not isinstance(self.loader, Sized):
                raise ValueError(
                    'n_items must be provided when the loader has no length'
                )
            n_items = len(self.loader)
        return Stream._EpochData(self, n_items)

    def epochs(
        self,
        n_epochs: Union[int, float],
        epoch_size: Optional[Union[int, float]] = None,
        progress_bar: bool = True,
    ) -> Iterator[Iterator[Any]]:
        """Iterate over data epochs.

        A shortcut for what is probably the most popular form of a training loop in Deep
        Learning (plus a progress bar)::

            for epoch in stream.epochs(n_epochs, epoch_size):
                for x in epoch:
                    ...

            # is equivalent to:

            while stream.epoch < n_epochs:
                stream.increment_epoch()
                for x in stream.data(epoch_size):
                    ...

        Args:
            n_epochs: the number of epochs. If `float`, must be `math.inf`.
            epoch_size: the number of data items in one epoch (is forwarded to
                `Stream.data`)
            progress_bar: show the progress bar for iterations. The initial value is set
                to `Stream.iteration`. See also the note below.

        Returns:
            Iterator over iterators over data from `Stream.loader`.

        Raises:
            AssertionError: if :code:`n_epochs` is `float`, but not `math.inf`.

        Note:
            If :code:`progress_bar` is True, *the progress bar is updated on yielding
            every item* which means that the progress bar should be interpreted as "what
            iteration is in progress" instead of "how many iterations are done". The
            percentage will be displayed only if the total number of planned iterations
            can be inferred from the arguments and/or from `Stream.loader`. The progress
            bar is closed when the iteration over epochs ends (or is abandoned).

        Examples:
            .. testcode::

                stream = Stream(range(3))
                for epoch in stream.epochs(2):
                    for x in epoch:
                        print(x)
                    print('-')

            .. testoutput::

                0
                1
                2
                -
                0
                1
                2
                -

            .. testcode::

                stream = Stream(range(3))
                for epoch in stream.epochs(3, 2):
                    for x in epoch:
                        print(x)
                    print('-')

            .. testoutput::

                0
                1
                -
                2
                0
                -
                1
                2
                -
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if isinstance(n_epochs, float) and not math.isinf(n_epochs):
            raise AssertionError('if n_epochs is float, it must be math.inf')

        pbar = None
        if progress_bar:
            pbar_epoch_size = (
                _try_len(self.loader) if epoch_size is None else epoch_size
            )
            # The total is unknown if the epoch size is unknown or if either factor
            # is infinite (this also avoids the "0 * inf = nan" corner case).
            total_is_unknown = (
                pbar_epoch_size is None
                or math.isinf(pbar_epoch_size)
                or math.isinf(n_epochs)
            )
            pbar = tqdm(
                initial=self.iteration,
                total=None if total_is_unknown else n_epochs * pbar_epoch_size,
            )
            self._pbar = pbar

        try:
            while self.epoch < n_epochs:
                self.increment_epoch()
                yield self.data(epoch_size)
        finally:
            # Release the progress bar created by this call, so that it is not
            # leaked and is not updated by unrelated `next` calls afterwards.
            if pbar is not None:
                pbar.close()
                if self._pbar is pbar:
                    self._pbar = None

    def state_dict(self) -> Dict[str, Any]:
        """Get the stream's state.

        The result can be passed to `Stream.load_state_dict`. The result includes:

        - epoch
        - iteration

        Note:
            Fields related to data (loader, iterator etc.) are **NOT** included in the
            state. If you want to save the "state of data stream" then you have to save
            the state of corresponding random number generators separately.

        Returns:
            state

        See also:
            `Stream.load_state_dict`

        Examples:
            .. testcode::

                stream = Stream(range(10))
                assert stream.state_dict() == {'epoch': 0, 'iteration': 0}
                stream.next()
                stream.next()
                stream.increment_epoch()
                assert stream.state_dict() == {'epoch': 1, 'iteration': 2}
        """
        return {'iteration': self.iteration, 'epoch': self.epoch}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load state dictionary.

        Args:
            state_dict: state. Must be produced by `Stream.state_dict`.

        Note:
            The method does not affect data that is produced by `Stream.epochs`,
            `Stream.data`, `Stream.next` (see the examples below), i.e. the method
            only sets some "metadata" such as epoch, iteration etc. If you want to
            load the "state of data stream", you have to load the state of corresponding
            random number generators separately.

        See also:
            `Stream.state_dict`

        Examples:
            .. testcode::

                stream = Stream(range(10))
                stream.next()
                stream.increment_epoch()
                assert stream.state_dict() == {'epoch': 1, 'iteration': 1}

                new_stream = Stream(range(10))
                new_stream.load_state_dict(stream.state_dict())
                assert new_stream.state_dict() == {'epoch': 1, 'iteration': 1}
                assert new_stream.next() == 0
                assert new_stream.state_dict() == {'epoch': 1, 'iteration': 2}
        """
        self._iteration = state_dict['iteration']
        self._epoch = state_dict['epoch']