zero.stream.Stream.load_state_dict

Stream.load_state_dict(state_dict)

Load a state dictionary.

Parameters

state_dict (Dict[str, Any]) – the state to load. It must be produced by Stream.state_dict.

Return type

None

Note

The method does not affect the data produced by Stream.epochs, Stream.data, or Stream.next (see the examples below). It only sets "metadata" such as the epoch and the iteration. To restore the "state of the data stream" itself, you have to load the state of the corresponding random number generators separately.

Examples

from zero.stream import Stream

stream = Stream(range(10))
stream.next()  # returns the first item; iteration becomes 1
stream.increment_epoch()  # epoch becomes 1
assert stream.state_dict() == {'epoch': 1, 'iteration': 1}

new_stream = Stream(range(10))
new_stream.load_state_dict(stream.state_dict())
assert new_stream.state_dict() == {'epoch': 1, 'iteration': 1}
# Only the metadata was restored: the data still starts from the beginning
assert new_stream.next() == 0
assert new_stream.state_dict() == {'epoch': 1, 'iteration': 2}
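The note above says that restoring the data itself requires restoring the random number generators separately. The following is a minimal sketch of that idea using only the standard library. ToyStream, make_checkpoint and restore are hypothetical names invented for illustration; they are not part of zero, and this is not how zero.stream.Stream is implemented. The sketch assumes the data order is driven by a random.Random instance whose state is stored next to the stream's state dict in one checkpoint.

```python
import random
from typing import Any, Dict, Iterable, Iterator


class ToyStream:
    """Hypothetical stand-in for a data stream (NOT zero.stream.Stream).

    Its state dict holds only metadata (epoch, iteration); the order of the
    data is decided by `rng`, which is not part of the state dict.
    """

    def __init__(self, data: Iterable[int], rng: random.Random) -> None:
        self.data = list(data)
        self.rng = rng
        self.epoch = 0
        self.iteration = 0

    def epoch_items(self) -> Iterator[int]:
        # Shuffle a copy of the data with the stream's own RNG.
        order = self.data[:]
        self.rng.shuffle(order)
        self.epoch += 1
        for item in order:
            self.iteration += 1
            yield item

    def state_dict(self) -> Dict[str, Any]:
        # Metadata only: the data order is NOT captured here.
        return {'epoch': self.epoch, 'iteration': self.iteration}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.epoch = state_dict['epoch']
        self.iteration = state_dict['iteration']


def make_checkpoint(stream: ToyStream) -> Dict[str, Any]:
    # Save the metadata AND the RNG state that controls the data order.
    return {'stream': stream.state_dict(), 'rng': stream.rng.getstate()}


def restore(checkpoint: Dict[str, Any], data: Iterable[int]) -> ToyStream:
    # Rebuild the RNG first, then load the metadata into a fresh stream.
    rng = random.Random()
    rng.setstate(checkpoint['rng'])
    stream = ToyStream(data, rng)
    stream.load_state_dict(checkpoint['stream'])
    return stream


# Usage: checkpoint after one epoch, then check that the restored stream
# reproduces exactly the same second epoch as the original one.
stream = ToyStream(range(10), random.Random(0))
list(stream.epoch_items())
checkpoint = make_checkpoint(stream)
expected = list(stream.epoch_items())

restored = restore(checkpoint, range(10))
assert restored.state_dict() == {'epoch': 1, 'iteration': 10}
assert list(restored.epoch_items()) == expected
```

Keeping both parts in a single checkpoint dictionary is just one possible design; in a real pipeline you would save and restore whichever generators actually drive the data order.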