zero.Stream

class zero.Stream(loader)

Smart wrapper for data loaders and iterables.

Stream simplifies managing loops, especially in typical deep learning scenarios. Stream:

  • manages the “epoch” and “iteration” variables

  • allows dumping and restoring the loop’s state: epoch, iteration, etc.

  • allows customizing the size of an epoch

  • allows changing the underlying data loader on the fly

  • enables useful patterns

Tutorial

Let’s start with the most common training loop:

from torch.utils.data import DataLoader

loader = DataLoader(...)
iteration = 0
for epoch in range(max_epoch):
    for batch in loader:
        iteration += 1
        print('Epoch:', epoch, 'Iteration:', iteration)
        ...

Let’s enhance the loop using Stream:

from torch.utils.data import DataLoader
from zero import Stream

stream = Stream(DataLoader(...))  # (A)
for epoch in stream.epochs(max_epoch):  # (B)
    for batch in epoch:  # (C)
        print('Epoch:', stream.epoch, 'Iteration:', stream.iteration)  # (D)
        ...

Some comments for the above code:

  • (A) Stream is created by passing a dataloader as a single argument (in fact, you can pass any iterable object); the dataloader is accessible via Stream.loader

  • (B) epoch is an iterator over batches for one epoch

  • (C) a progress bar for batches is displayed (for the whole training loop, not just for one epoch)

  • (D) Stream.epoch and Stream.iteration are managed automatically
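To make (B) to (D) concrete without PyTorch, here is a minimal sketch of a Stream-like class over a plain list. `ToyStream` is a hypothetical stand-in, not zero's implementation, and it omits the progress bar. It assumes, as the low-level loops later on this page suggest, that `epoch` starts at 0 and is incremented at the start of each epoch, so it equals 1 inside the first epoch.

```python
class ToyStream:
    """Hypothetical stand-in for zero.Stream (not the real implementation)."""

    def __init__(self, loader):
        self.loader = loader
        self.epoch = 0
        self.iteration = 0
        self._iter = None

    def next(self):
        # Lazily create the iterator; restart the loader when it is exhausted.
        if self._iter is None:
            self._iter = iter(self.loader)
        try:
            item = next(self._iter)
        except StopIteration:
            self._iter = iter(self.loader)
            item = next(self._iter)
        self.iteration += 1
        return item

    def epochs(self, max_epoch, epoch_size=None):
        # Each yielded value is an iterator over the batches of one epoch.
        while self.epoch < max_epoch:
            self.epoch += 1
            size = len(self.loader) if epoch_size is None else epoch_size
            yield (self.next() for _ in range(size))


stream = ToyStream(['a', 'b', 'c'])
log = []
for epoch in stream.epochs(2):
    for batch in epoch:
        log.append((stream.epoch, stream.iteration, batch))
print(log)
# [(1, 1, 'a'), (1, 2, 'b'), (1, 3, 'c'), (2, 4, 'a'), (2, 5, 'b'), (2, 6, 'c')]
```

Note that `iteration` keeps growing across epochs instead of restarting at each epoch, which is what makes it usable as a global step counter.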

The loop’s state can be saved and restored with the methods Stream.state_dict and Stream.load_state_dict. In practice, it can look as follows:

import torch
from torch.utils.data import DataLoader
from zero import Stream

model = ...
optimizer = ...
stream = Stream(DataLoader(...))
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    stream.load_state_dict(checkpoint['stream'])
...
for epoch in stream.epochs(...):
    for batch in epoch:
        ...
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'stream': stream.state_dict(),
        },
        f'checkpoint_{stream.epoch}.pt'
    )

Note

Stream’s state does not include the loader’s state. See Stream.state_dict and Stream.load_state_dict for details.
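In other words, a restored stream gets its counters back, but the loader itself starts over from the beginning. The sketch below illustrates this with a hypothetical `ToyStream` whose state is assumed to be just `{'epoch': ..., 'iteration': ...}`; the real Stream.state_dict may contain more or different keys. A JSON round trip stands in for torch.save and torch.load.

```python
import json


class ToyStream:
    """Hypothetical stand-in for zero.Stream; only the counters are saved."""

    def __init__(self, loader):
        self.loader = loader
        self.epoch = 0
        self.iteration = 0
        self._iter = iter(loader)

    def next(self):
        try:
            item = next(self._iter)
        except StopIteration:
            self._iter = iter(self.loader)  # start the loader over
            item = next(self._iter)
        self.iteration += 1
        return item

    def state_dict(self):
        # Assumption: the state consists of the two counters only.
        return {'epoch': self.epoch, 'iteration': self.iteration}

    def load_state_dict(self, state):
        self.epoch = state['epoch']
        self.iteration = state['iteration']


loader = [10, 20, 30, 40]
stream = ToyStream(loader)
stream.epoch = 1
first = [stream.next(), stream.next()]       # consumes 10 and 20
saved = json.dumps(stream.state_dict())      # stands in for torch.save

resumed = ToyStream(loader)
resumed.load_state_dict(json.loads(saved))   # stands in for torch.load
restored = (resumed.epoch, resumed.iteration)
item = resumed.next()                        # the loader's position was not saved
print(restored, item, resumed.iteration)     # (1, 2) 10 3
```

If resuming mid-epoch matters, the loader's own position has to be handled separately, for example by checkpointing only at epoch boundaries as in the example above.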

To customize the epoch size, pass it as the second argument of Stream.epochs:

for epoch in stream.epochs(max_epoch, custom_epoch_size):
    for batch in epoch:
        ...
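Because the stream does not reset at epoch boundaries (see the note on the “infinite iterator” later on this page), a custom epoch size can be pictured as slicing one endless flow of batches into chunks of that size. A minimal sketch with plain generators; `endless` and `split_into_epochs` are hypothetical helpers, not part of zero:

```python
from itertools import islice


def endless(loader):
    # Re-iterate the loader forever (an empty loader would make this loop forever).
    while True:
        yield from loader


def split_into_epochs(loader, max_epoch, epoch_size):
    # Cut the endless flow into max_epoch chunks of epoch_size items each.
    items = endless(loader)
    return [list(islice(items, epoch_size)) for _ in range(max_epoch)]


print(split_into_epochs([0, 1, 2], max_epoch=3, epoch_size=2))
# [[0, 1], [2, 0], [1, 2]]
```

An epoch can therefore end in the middle of a pass over the loader, and the next epoch continues from where the previous one stopped. The epoch size may also exceed the loader's length, in which case an epoch spans more than one pass.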

The underlying loader can be changed on the fly at any moment (even in the middle of an epoch) via Stream.set_loader. For example:

for epoch in stream.epochs(max_epoch, custom_epoch_size):
    for batch in epoch:
        ...
        if need_new_data():
            stream.set_loader(new_loader)
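A sketch of what switching loaders mid-epoch could look like. It assumes, and this page does not confirm it, that after Stream.set_loader the next item comes from the new loader while the counters keep going. `ToyStream` and its internals are hypothetical, and a simple condition on `iteration` stands in for `need_new_data()`.

```python
class ToyStream:
    """Hypothetical stand-in for zero.Stream, reduced to next() and set_loader()."""

    def __init__(self, loader):
        self.loader = loader
        self.iteration = 0
        self._iter = iter(loader)

    def next(self):
        try:
            item = next(self._iter)
        except StopIteration:
            self._iter = iter(self.loader)
            item = next(self._iter)
        self.iteration += 1
        return item

    def set_loader(self, loader):
        # Assumption: drop the old iterator so the next item comes from the new loader.
        self.loader = loader
        self._iter = iter(loader)


stream = ToyStream(['old1', 'old2', 'old3'])
epoch_size = 5
batches = []
for _ in range(epoch_size):
    batches.append(stream.next())
    if stream.iteration == 2:  # stands in for need_new_data()
        stream.set_loader(['new1', 'new2'])
print(batches)           # ['old1', 'old2', 'new1', 'new2', 'new1']
print(stream.iteration)  # 5
```

The remaining items of the old loader (here `'old3'`) are never produced, and the epoch finishes on the new loader, wrapping around it as needed.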

If the method Stream.epochs does not fit your workflow and you need more control over the loop, there are lower-level methods such as Stream.increment_epoch, Stream.data and Stream.next (in fact, Stream.epochs is just a thin wrapper around them).

For example, the most common training loop can be implemented as follows:

while stream.epoch < max_epoch:
    stream.increment_epoch()
    for batch in stream.data():
        ...

Or even like this:

while stream.epoch < max_epoch:
    stream.increment_epoch()
    for _ in range(len(stream.loader)):
        batch = stream.next()  # stream.iteration is incremented automatically
        ...
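The two loops above rely only on `increment_epoch`, `data` and `next`. The hypothetical sketch below shows one way these pieces can fit together and checks that both loops visit the same (epoch, iteration, batch) triples. It assumes that `data()` without arguments yields `len(loader)` items, each obtained via `next()`, which matches the second loop but is not stated on this page.

```python
class ToyStream:
    """Hypothetical stand-in for zero.Stream's low-level methods."""

    def __init__(self, loader):
        self.loader = loader
        self.epoch = 0
        self.iteration = 0
        self.reload_iterator()

    def reload_iterator(self):
        self._iter = iter(self.loader)

    def increment_epoch(self):
        self.epoch += 1

    def next(self):
        try:
            item = next(self._iter)
        except StopIteration:
            self.reload_iterator()  # wrap around: the stream never stops
            item = next(self._iter)
        self.iteration += 1
        return item

    def data(self, n_items=None):
        # Assumption: by default, one pass worth of items, each obtained via next().
        n_items = len(self.loader) if n_items is None else n_items
        for _ in range(n_items):
            yield self.next()


def run_with_data(loader, max_epoch):
    stream, seen = ToyStream(loader), []
    while stream.epoch < max_epoch:
        stream.increment_epoch()
        for batch in stream.data():
            seen.append((stream.epoch, stream.iteration, batch))
    return seen


def run_with_next(loader, max_epoch):
    stream, seen = ToyStream(loader), []
    while stream.epoch < max_epoch:
        stream.increment_epoch()
        for _ in range(len(stream.loader)):
            batch = stream.next()
            seen.append((stream.epoch, stream.iteration, batch))
    return seen


assert run_with_data(['x', 'y'], 2) == run_with_next(['x', 'y'], 2)
print(run_with_next(['x', 'y'], 2))
# [(1, 1, 'x'), (1, 2, 'y'), (2, 3, 'x'), (2, 4, 'y')]
```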

Note

For a better technical understanding, keep in mind that Stream simply encapsulates an “infinite iterator” that is constantly moving forward. The behavior is exactly the same for finite and infinite iterables and can be expressed with the following loop:

while True:
    for item in loader:  # the loader which is passed to the constructor
        ...
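One practical consequence of this loop: the loader is iterated afresh on every pass, so a loader that reshuffles on each pass (as DataLoader(shuffle=True) does) produces a new order every time. This differs from `itertools.cycle`, which replays the items of the first pass. The sketch below uses a hypothetical, deterministic `RotatingLoader` in place of a shuffling DataLoader.

```python
from itertools import cycle, islice


class RotatingLoader:
    """Hypothetical loader whose order changes on every pass, loosely mimicking
    DataLoader(shuffle=True) but deterministic (it rotates instead of shuffling)."""

    def __init__(self, items):
        self.items = list(items)
        self.passes = 0

    def __len__(self):
        return len(self.items)

    def __iter__(self):
        k = self.passes % len(self.items)
        self.passes += 1
        return iter(self.items[k:] + self.items[:k])


def stream_items(loader):
    # The loop from the note: the loader is re-iterated on every pass.
    while True:
        for item in loader:
            yield item


print(list(islice(stream_items(RotatingLoader('abc')), 9)))
# ['a', 'b', 'c', 'b', 'c', 'a', 'c', 'a', 'b']
print(list(islice(cycle(RotatingLoader('abc')), 9)))
# ['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c']
```

For an infinite iterable, the inner `for` never finishes, so the same loop simply keeps drawing from it; this is why the behavior does not depend on whether the loader is finite.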

The documentation for Stream.next and Stream.data provides helpful examples.

Attributes

epoch

Current epoch.

iteration

Current iteration.

loader

The underlying loader.

Methods

__init__

Initialize self.

data

Iterate over the loader.

epochs

Iterate over data epochs.

increment_epoch

Increment Stream.epoch.

load_state_dict

Load state dictionary.

next

Get the next item and increment iteration.

reload_iterator

Set the underlying iterator to iter(self.loader).

set_loader

Set new loader.

state_dict

Get the stream’s state.