zero.stream

Smart Python loops.

Stream

class zero.stream.Stream(loader)

Smart wrapper for iterables.

Stream simplifies managing loops, especially in typical deep learning scenarios (it is usually used to wrap train_dataloader or any other data source).

Stream:

  • simplifies management of the “epoch” and “iteration” variables

  • allows you to dump and restore the loop’s state: epoch, iteration, etc.

  • allows you to customize the epoch size

  • allows you to change the underlying data loader on the fly

  • enables useful patterns

Parameters

loader – any kind of iterable (DataLoader, list, iterator, generator, …)

Raises

AssertionError – if loader is not an iterator and is empty

Examples

import itertools

import torch
from torch.utils.data import DataLoader, TensorDataset
from zero.stream import Stream

stream = Stream([0, 1, 2, 3])
stream = Stream(range(10))
stream = Stream(itertools.repeat(0))

dataset = TensorDataset(torch.randn(10, 2))
stream = Stream(DataLoader(dataset, batch_size=3, shuffle=True))
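The rule under Raises can be read as the following check. This is a hypothetical sketch: check_loader is an illustrative name, not part of zero, and the rationale that iterators are skipped because peeking into them would consume their first item is an assumption.

```python
from collections.abc import Iterable, Iterator


def check_loader(loader: Iterable) -> None:
    """Hypothetical re-creation of the emptiness check described under Raises."""
    if isinstance(loader, Iterator):
        # Peeking into an iterator would consume an item, so it is not checked.
        return
    # A non-iterator iterable can be iterated again later, so peeking with a
    # fresh iterator does not disturb subsequent loops.
    sentinel = object()
    assert next(iter(loader), sentinel) is not sentinel, 'loader is empty'


check_loader([0, 1, 2, 3])  # OK
check_loader(iter([]))      # OK: empty, but an iterator, so not checked
try:
    check_loader([])
except AssertionError as err:
    print('AssertionError:', err)  # AssertionError: loader is empty
```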

Tutorial

Let’s review the conventional approach without Stream:

loader = DataLoader(...)
iteration = 0
for epoch in range(n_epochs):
    for x in loader:
        iteration += 1
        print('Epoch:', epoch, 'Iteration:', iteration)
        ...

There are several ways to use Stream to enhance this loop. Let’s start by creating a stream:

stream = Stream(DataLoader(...))

The dataloader is accessible via Stream.loader. Now, let’s reproduce the loop above:

for epoch in stream.epochs(n_epochs):
    for x in epoch:
        print('Epoch:', stream.epoch, 'Iteration:', stream.iteration)

We see that Stream.epoch and Stream.iteration are managed automatically. Additionally, a progress bar is displayed while the loop is running.
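To make this bookkeeping concrete, here is a minimal sketch of how such counters could be maintained. MiniStream is a hypothetical, simplified stand-in, not zero’s implementation: it has no progress bar, and its conventions (both counters start at 0, epoch is incremented when an epoch starts, iteration is incremented after each item is fetched) are assumptions of this sketch.

```python
from typing import Any, Iterable, Iterator, Optional


class MiniStream:
    """Hypothetical, simplified stand-in for zero.stream.Stream (no progress bar)."""

    def __init__(self, loader: Iterable) -> None:
        self.loader = loader
        self.epoch = 0
        self.iteration = 0
        self._iter: Optional[Iterator] = None

    def next(self) -> Any:
        # Lazily create the iterator; restart it when the loader is exhausted.
        if self._iter is None:
            self._iter = iter(self.loader)
        try:
            item = next(self._iter)
        except StopIteration:
            self._iter = iter(self.loader)
            item = next(self._iter)
        self.iteration += 1
        return item

    def data(self, n_items: Optional[int] = None) -> Iterator[Any]:
        # By default, one epoch is one full pass over a sized loader.
        if n_items is None:
            n_items = len(self.loader)
        for _ in range(n_items):
            yield self.next()

    def epochs(self, n_epochs: int, epoch_size: Optional[int] = None) -> Iterator[Iterator[Any]]:
        for _ in range(n_epochs):
            self.epoch += 1
            yield self.data(epoch_size)


stream = MiniStream(['a', 'b', 'c'])
log = []
for epoch in stream.epochs(2):
    for x in epoch:
        log.append((stream.epoch, stream.iteration, x))
print(log)
# [(1, 1, 'a'), (1, 2, 'b'), (1, 3, 'c'), (2, 4, 'a'), (2, 5, 'b'), (2, 6, 'c')]
```

Unlike the hand-written loop, the counters live on the stream object, so any code holding a reference to it can read the current epoch and iteration.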

Saving the loop’s state and resuming the loop is possible with the methods Stream.state_dict and Stream.load_state_dict. In practice, it may look like this:

model = ...
optimizer = ...
stream = Stream(DataLoader(...))
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    ...
    stream.load_state_dict(checkpoint['stream'])
...
for epoch in stream.epochs(...):
    for batch in epoch:
        ...
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'stream': stream.state_dict(),
        },
        f'checkpoint_{stream.epoch}.pt'
    )

Note

Stream’s state does not include the loader’s state. See Stream.state_dict and Stream.load_state_dict for details.
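The sketch below illustrates what this note means in practice, assuming the saved state consists only of the two counters (the exact keys are an assumption). MiniStream and its methods are hypothetical: after restoring, the counters continue from where they were saved, but the data starts again from the beginning of the loader.

```python
from typing import Any, Dict, Iterable, Iterator, Optional


class MiniStream:
    """Hypothetical stand-in for Stream, reduced to next() and state handling."""

    def __init__(self, loader: Iterable) -> None:
        self.loader = loader
        self.epoch = 0
        self.iteration = 0
        self._iter: Optional[Iterator] = None

    def next(self) -> Any:
        if self._iter is None:
            self._iter = iter(self.loader)
        try:
            item = next(self._iter)
        except StopIteration:
            self._iter = iter(self.loader)
            item = next(self._iter)
        self.iteration += 1
        return item

    def state_dict(self) -> Dict[str, int]:
        # Only the counters are saved, not the position inside the loader.
        return {'epoch': self.epoch, 'iteration': self.iteration}

    def load_state_dict(self, state_dict: Dict[str, int]) -> None:
        self.epoch = state_dict['epoch']
        self.iteration = state_dict['iteration']
        self._iter = None  # the next call to next() starts a fresh pass


old = MiniStream([0, 1, 2, 3])
old.epoch = 1                          # pretend the first epoch is running
print([old.next() for _ in range(2)])  # [0, 1]
state = old.state_dict()               # {'epoch': 1, 'iteration': 2}

new = MiniStream([0, 1, 2, 3])
new.load_state_dict(state)
print(new.next(), new.iteration)       # 0 3 (counters resume, data restarts)
```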

To customize the epoch size, pass it as the second argument of Stream.epochs:

for epoch in stream.epochs(n_epochs, custom_epoch_size):
    for x in epoch:
        ...
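Because Stream is backed by a single iterator that keeps moving forward, a custom epoch size simply cuts that sequence into chunks, and epoch boundaries need not coincide with passes over the loader. The sketch below shows this with plain itertools; endless and split_into_epochs are hypothetical helpers, and the loader is assumed to be a non-empty re-iterable such as a list.

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')


def endless(loader: Iterable[T]) -> Iterator[T]:
    # The "infinite iterator" model: restart the loader whenever it runs out.
    while True:
        yield from loader


def split_into_epochs(loader: Iterable[T], n_epochs: int, epoch_size: int) -> List[List[T]]:
    # Hypothetical helper: each epoch takes the next `epoch_size` items,
    # so epoch boundaries need not line up with passes over the loader.
    items = endless(loader)
    return [list(itertools.islice(items, epoch_size)) for _ in range(n_epochs)]


print(split_into_epochs([0, 1, 2, 3], n_epochs=3, epoch_size=3))
# [[0, 1, 2], [3, 0, 1], [2, 3, 0]]
```

Note that endless would loop forever without yielding anything if the loader were empty, which illustrates why an empty loader cannot work in this model.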

Changing the underlying loader on the fly is possible at any moment (even in the middle of an epoch) via Stream.set_loader. For example:

for epoch in stream.epochs(n_epochs, custom_epoch_size):
    for x in epoch:
        ...
        if need_new_data():
            stream.set_loader(new_loader)

If Stream.epochs does not fit your workflow and you need more control over the loop, you can use lower-level methods such as Stream.next and Stream.data (in fact, Stream.epochs is just a thin wrapper around them).

Note

For a better technical understanding, keep in mind that Stream simply encapsulates an “infinite iterator” that keeps moving forward. The behavior is exactly the same for finite and infinite iterables and can be expressed with the following loop:

while True:
    for item in loader:  # the loader passed to the constructor
        ...
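One consequence of this loop is that iter(loader) is called again on every pass, so a shuffling DataLoader produces a new order each time. This differs from itertools.cycle, which iterates its input once and then replays the cached items. The sketch below contrasts the two; ReshufflingLoader (which rotates its order on each pass as a deterministic stand-in for shuffling) and endless are hypothetical names.

```python
import itertools
from typing import Iterator, Tuple


class ReshufflingLoader:
    """Hypothetical loader that yields a different order on every pass,
    similar in spirit to a DataLoader with shuffle=True."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.passes = 0

    def __iter__(self) -> Iterator[Tuple[int, int]]:
        pass_id = self.passes
        self.passes += 1
        # Rotate the order by the pass number (deterministic stand-in for shuffling).
        for i in range(self.n):
            yield (pass_id, (i + pass_id) % self.n)


def endless(loader) -> Iterator:
    # The loop from the note: re-iterate the loader every time it runs out.
    while True:
        for item in loader:
            yield item


print(list(itertools.islice(endless(ReshufflingLoader(3)), 6)))
# [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (1, 0)]
print(list(itertools.islice(itertools.cycle(ReshufflingLoader(3)), 6)))
# [(0, 0), (0, 1), (0, 2), (0, 0), (0, 1), (0, 2)]
```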

The documentation for Stream.next and Stream.data provides helpful examples.

Stream.iteration

Current iteration.

Stream.epoch

Current epoch.

Stream.increment_epoch()

Increment Stream.epoch.

Stream.loader

The underlying loader.

Stream.set_loader(loader)

Set new loader.

Stream.reload_iterator()

Set the underlying iterator to iter(self.loader).

Stream.next()

Get the next item and increment iteration.
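A sketch of the wrap-around behavior implied by the “infinite iterator” model, together with Stream.reload_iterator. MiniStream is a hypothetical stand-in; that next() reloads the iterator on StopIteration is inferred from the loop in the note above, not taken from zero’s source.

```python
from typing import Any, Iterable, Iterator, Optional


class MiniStream:
    """Hypothetical stand-in for Stream showing next() and reload_iterator()."""

    def __init__(self, loader: Iterable) -> None:
        self.loader = loader
        self.iteration = 0
        self._iter: Optional[Iterator] = None

    def reload_iterator(self) -> None:
        # Start a fresh pass over the loader.
        self._iter = iter(self.loader)

    def next(self) -> Any:
        if self._iter is None:
            self.reload_iterator()
        try:
            item = next(self._iter)
        except StopIteration:
            # End of the loader: wrap around instead of stopping.
            self.reload_iterator()
            item = next(self._iter)
        self.iteration += 1
        return item


stream = MiniStream(range(3))
print([stream.next() for _ in range(5)], stream.iteration)  # [0, 1, 2, 0, 1] 5
stream.reload_iterator()  # abandon the current pass
print(stream.next())      # 0
```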

Stream.data([n_items])

Iterate over the loader.
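A sketch of Stream.data under stated assumptions: MiniStream is hypothetical, and the default of one full pass (len(loader) items) when n_items is omitted is an assumption of this sketch, which means the default only works for loaders that support len().

```python
from typing import Any, Iterable, Iterator, Optional


class MiniStream:
    """Hypothetical stand-in for Stream showing data([n_items])."""

    def __init__(self, loader: Iterable) -> None:
        self.loader = loader
        self.iteration = 0
        self._iter: Optional[Iterator] = None

    def next(self) -> Any:
        if self._iter is None:
            self._iter = iter(self.loader)
        try:
            item = next(self._iter)
        except StopIteration:
            self._iter = iter(self.loader)
            item = next(self._iter)
        self.iteration += 1
        return item

    def data(self, n_items: Optional[int] = None) -> Iterator[Any]:
        # Assumption: without n_items, yield len(loader) items (one full pass).
        if n_items is None:
            n_items = len(self.loader)
        for _ in range(n_items):
            yield self.next()


stream = MiniStream([0, 1, 2])
print(list(stream.data()))   # [0, 1, 2]
print(list(stream.data(4)))  # [0, 1, 2, 0]
```

Since every item is fetched through next(), consecutive calls to data continue from where the previous call stopped.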

Stream.epochs(n_epochs[, epoch_size, …])

Iterate over data epochs.

Stream.state_dict()

Get the stream’s state.

Stream.load_state_dict(state_dict)

Load state dictionary.