zero.Stream.state_dict

Stream.state_dict()

Get the stream’s state.

The result can be passed to Stream.load_state_dict. It includes:

  • epoch

  • iteration
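The save/load contract described above can be illustrated with a small, self-contained sketch. `ToyStream` is a hypothetical stand-in written for this illustration and is not zero's implementation. It mimics only the documented behaviour: `state_dict()` returns the `epoch` and `iteration` counters, and `load_state_dict()` restores them. The data iterator is deliberately left out of the state, so a restored stream restarts reading its loader from the beginning.

```python
from typing import Any, Dict, Iterable, Iterator


class ToyStream:
    """Hypothetical stand-in (NOT zero.Stream) that tracks only epoch/iteration."""

    def __init__(self, loader: Iterable[Any]) -> None:
        self.loader = loader
        self._iterator: Iterator[Any] = iter(loader)
        self.epoch = 0
        self.iteration = 0

    def next(self) -> Any:
        # Fetch one item from the current iterator and count it.
        value = next(self._iterator)
        self.iteration += 1
        return value

    def increment_epoch(self) -> None:
        # Only the epoch counter changes; iteration keeps counting globally.
        self.epoch += 1

    def state_dict(self) -> Dict[str, Any]:
        # Counters only: the loader and iterator are not part of the state.
        return {'epoch': self.epoch, 'iteration': self.iteration}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.epoch = state['epoch']
        self.iteration = state['iteration']


# Usage: save the counters, then restore them into a fresh stream.
stream = ToyStream(range(10))
stream.next()
stream.next()
stream.increment_epoch()
saved = stream.state_dict()

restored = ToyStream(range(10))
restored.load_state_dict(saved)
assert restored.state_dict() == {'epoch': 1, 'iteration': 2}
# The position in the data is NOT restored: the new iterator starts at 0.
assert restored.next() == 0
```

Because only the counters round-trip, resuming from the same place in the data requires extra bookkeeping outside the state dict.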

Note

Fields related to data (loader, iterator, etc.) are NOT included in the state. If you want to save the “state of the data stream”, you have to save the state of the corresponding random number generators separately.
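A minimal sketch of saving the random number generator state next to the stream state, assuming shuffling is driven by a Python `random.Random` instance (with PyTorch data loaders the analogous object would typically be a `torch.Generator`, which offers `get_state()` and `set_state()`). The helpers `make_checkpoint` and `restore_checkpoint` are hypothetical names used only for this illustration, and the literal `{'epoch': 1, 'iteration': 2}` stands in for what `Stream.state_dict()` would return.

```python
import pickle
import random
from typing import Any, Dict


def make_checkpoint(stream_state: Dict[str, Any], rng: random.Random) -> bytes:
    # Hypothetical helper: stream_state holds only the counters
    # ('epoch', 'iteration'), so the shuffling RNG is stored alongside it.
    return pickle.dumps({'stream': stream_state, 'rng': rng.getstate()})


def restore_checkpoint(blob: bytes, rng: random.Random) -> Dict[str, Any]:
    # Hypothetical helper: restore the RNG in place and return the stream
    # state, which would then be passed to Stream.load_state_dict.
    checkpoint = pickle.loads(blob)
    rng.setstate(checkpoint['rng'])
    return checkpoint['stream']


rng = random.Random(0)
blob = make_checkpoint({'epoch': 1, 'iteration': 2}, rng)
order_before = rng.sample(range(10), 10)  # shuffle drawn after the checkpoint

fresh_rng = random.Random(123)  # different seed; its state gets overwritten
stream_state = restore_checkpoint(blob, fresh_rng)
order_after = fresh_rng.sample(range(10), 10)

assert stream_state == {'epoch': 1, 'iteration': 2}
assert order_before == order_after  # same shuffle order after restoring
```

Restoring the generator is what makes the post-checkpoint shuffle order reproducible; the stream's counters alone cannot do that.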

Returns

state

Return type

Dict[str, Any]

Examples

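from zero import Stream

stream = Stream(range(10))
assert stream.state_dict() == {'epoch': 0, 'iteration': 0}
stream.next()  # iteration: 0 -> 1
stream.next()  # iteration: 1 -> 2
stream.increment_epoch()  # epoch: 0 -> 1; iteration is not reset
assert stream.state_dict() == {'epoch': 1, 'iteration': 2}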