zero.Stream
class zero.Stream(loader)

Smart wrapper for data loaders and iterables.
Stream simplifies managing loops, especially in typical Deep Learning scenarios.

Stream:

- manages the “epoch” and “iteration” variables
- allows dumping and restoring the loop’s state: epoch, iteration, etc.
- allows customizing the epoch size
- allows changing the underlying data loader on the fly
- enables useful patterns
Tutorial
Let’s start with the most common training loop:
    loader = DataLoader(...)
    iteration = 0
    for epoch in range(max_epoch):
        for batch in loader:
            iteration += 1
            print('Epoch:', epoch, 'Iteration:', iteration)
            ...
Let’s enhance the loop using Stream:

    stream = Stream(DataLoader(...))  # (A)
    for epoch in stream.epochs(max_epoch):  # (B)
        for batch in epoch:  # (C)
            print('Epoch:', stream.epoch, 'Iteration:', stream.iteration)  # (D)
            ...
Some comments for the above code:
- (A) Stream is created by passing a dataloader as a single argument (in fact, you can pass any iterable object); the dataloader is accessible via Stream.loader (see the short sketch after this list)
- (B) epoch is an iterator over batches for one epoch
- (C) a progress bar for batches is displayed (for the whole training loop, not just for one epoch)
- (D) Stream.epoch and Stream.iteration are managed automatically
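For instance, here is a minimal sketch of comment (A), using a plain range as the iterable; the initial values of the counters shown below are an assumption rather than a documented guarantee:

    from zero import Stream

    data = range(5)                # any iterable works, not only a DataLoader
    stream = Stream(data)
    assert stream.loader is data   # the original object is exposed as Stream.loader
    # Presumably both counters start at zero before any iteration:
    assert stream.epoch == 0 and stream.iteration == 0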
Saving the loop’s state and resuming the loop is possible with the methods Stream.state_dict and Stream.load_state_dict. In practice, it can look like this:

    model = ...
    optimizer = ...
    stream = Stream(DataLoader(...))
    if load_from_checkpoint:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        stream.load_state_dict(checkpoint['stream'])
    ...
    for epoch in stream.epochs(...):
        for batch in epoch:
            ...
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'stream': stream.state_dict(),
            },
            f'checkpoint_{stream.epoch}.pt'
        )
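To make the save/restore mechanics concrete, here is a hypothetical minimal sketch with a plain list as the loader; it also illustrates the note below, namely that restoring the state does not fast-forward the loader itself:

    stream = Stream([0, 1, 2, 3])
    stream.next()  # -> 0
    stream.next()  # -> 1
    state = stream.state_dict()  # captures the progress counters (epoch, iteration, etc.)

    fresh = Stream([0, 1, 2, 3])
    fresh.load_state_dict(state)
    assert fresh.iteration == stream.iteration  # the counters are restored...
    # ...but the loader is not rewound or advanced: the next item is 0 again, not 2
    assert fresh.next() == 0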
Note
Stream’s state does not include the loader’s state. See Stream.state_dict and Stream.load_state_dict for details.

In order to customize the epoch size, pass the size as the second argument:
    for epoch in stream.epochs(max_epoch, custom_epoch_size):
        for batch in epoch:
            ...
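To illustrate what a custom epoch size means, here is a hypothetical sketch with a four-item loader and an epoch size of three; since Stream draws items from one continuous flow over the loader (see the note further below), epochs simply slice that flow:

    stream = Stream([0, 1, 2, 3])
    for epoch in stream.epochs(2, 3):  # 2 epochs, 3 iterations each
        print(list(epoch))
    # Expected output under the behavior described in the note below:
    # [0, 1, 2]
    # [3, 0, 1]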
Changing the underlying loader on the fly is possible at any moment (even in the middle of an epoch) via Stream.set_loader. For example:

    for epoch in stream.epochs(max_epoch, custom_epoch_size):
        for batch in epoch:
            ...
            if need_new_data():
                stream.set_loader(new_loader)
If the method Stream.epochs does not fit your workflow and you want more control over the loop, there are more “low-level” methods (in fact, Stream.epochs is just a thin wrapper around them):

- Stream.increment_epoch
- Stream.data
- Stream.next

For example, the most common training loop can be implemented as follows:
    # A
    while stream.epoch < max_epoch:
        stream.increment_epoch()
        for batch in stream.data():
            ...

    # B
    while stream.epoch < max_epoch:
        stream.increment_epoch()
        for _ in range(len(stream.loader)):
            batch = stream.next()  # stream.iteration is incremented automatically
            ...
The “infinite” stream of data can be implemented as follows:
    for item in stream.data(float('inf')):
        ...
        if condition:  # for example: `if stream.iteration % frequency == 0`
            ...
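In a full training script, this iteration-driven pattern might look roughly like the following sketch, where train_step, evaluate, eval_frequency, and max_iteration are hypothetical placeholders:

    for batch in stream.data(float('inf')):
        train_step(batch)                           # hypothetical training step
        if stream.iteration % eval_frequency == 0:  # periodic evaluation
            evaluate()
        if stream.iteration >= max_iteration:       # manual stopping condition
            break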
Note
For better technical understanding, keep in mind that Stream simply encapsulates an “infinite iterator” that is constantly moving forward. The behavior is exactly the same for both finite and infinite iterables and can be expressed with the following loop:

    while True:
        for item in loader:  # the loader that is passed to the constructor
            ...
The documentation for Stream.next and Stream.data provides helpful examples.

Attributes
epoch: Current epoch.
iteration: Current iteration.
loader: The underlying loader.
Methods
__init__: Initialize self.
data: Iterate over the loader.
epochs: Iterate over data epochs.
increment_epoch: Increment Stream.epoch.
load_state_dict: Load state dictionary.
next: Get the next item and increment iteration.
reload_iterator: Set the underlying iterator to iter(self.loader).
set_loader: Set new loader.
state_dict: Get the stream’s state.