zero.random.get_state

zero.random.get_state()

Aggregate global random states from random, numpy and torch.

This function is useful for creating checkpoints that allow resuming data streams or other activities that depend on the global random number generators (see the note below). The result of this function can be passed to set_state.
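Aggregating the global states amounts to calling each library's own state getter and collecting the results in one dictionary. The following is a minimal sketch under that assumption. The names get_global_state and set_global_state are hypothetical, this is not zero's actual code, and the dictionary keys are illustrative. CUDA generator states, which a complete implementation would likely also cover, are left out, and torch is treated as optional so the snippet runs without it.

```python
import random
from typing import Any, Dict

import numpy as np

try:
    import torch
except ImportError:  # torch is optional in this sketch
    torch = None


def get_global_state() -> Dict[str, Any]:
    """Hypothetical stand-in for zero.random.get_state (not its real code)."""
    state: Dict[str, Any] = {
        'random': random.getstate(),     # Python's global Mersenne Twister
        'numpy': np.random.get_state(),  # NumPy's legacy global RandomState
    }
    if torch is not None:
        state['torch'] = torch.get_rng_state()  # torch's default CPU generator
    return state


def set_global_state(state: Dict[str, Any]) -> None:
    """Hypothetical counterpart of set_state: restore what get_global_state saved."""
    random.setstate(state['random'])
    np.random.set_state(state['numpy'])
    if torch is not None and 'torch' in state:
        torch.set_rng_state(state['torch'])
```

If you capture the state, draw some numbers, restore the state and draw again, you get identical values. That round trip is the property a checkpoint relies on.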

Returns

state – the aggregated global random states of random, numpy and torch

Return type

Dict[str, Any]

See also

set_state

Examples

import torch
import zero

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
...  # train the model
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'random_state': zero.random.get_state(),
}
# save the checkpoint:
# torch.save(checkpoint, 'checkpoint.pt')
# later, when resuming:
# checkpoint = torch.load('checkpoint.pt')
# zero.random.set_state(checkpoint['random_state'])
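Storing 'random_state' in the checkpoint means that, once it is restored, every subsequent random draw repeats exactly, so a random data stream resumes where it left off. The following minimal sketch illustrates that property without zero. It uses only the standard random module and NumPy, leaves torch out so it runs without it, and uses pickle in place of torch.save and torch.load. The data-stream helper draw_batch is hypothetical.

```python
import pickle
import random

import numpy as np


def draw_batch(size: int):
    """Hypothetical data stream: a NumPy shuffle plus Python-level random noise."""
    order = np.random.permutation(size).tolist()
    noise = [random.random() for _ in range(size)]
    return order, noise


random.seed(0)
np.random.seed(0)
draw_batch(4)  # some progress made before the checkpoint

# Save the global states (pickle plays the role of torch.save here).
blob = pickle.dumps({'random': random.getstate(), 'numpy': np.random.get_state()})
expected = [draw_batch(4) for _ in range(3)]  # batches produced after the checkpoint

# Later: restore the states and resume the stream.
saved = pickle.loads(blob)
random.setstate(saved['random'])
np.random.set_state(saved['numpy'])
resumed = [draw_batch(4) for _ in range(3)]

assert resumed == expected  # the resumed stream repeats the original one
```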