zero.random.get_state

zero.random.get_state()

Aggregate global random states from random, numpy and torch.

The function is useful for creating checkpoints that make it possible to resume data streams or other activities that depend on the global random number generators (see the note below). The result of this function can be passed to set_state.

Returns

state

Return type

Dict[str, Any]

Note

The most reliable way to guarantee reproducibility and to make your data streams resumable is to create separate random number generators and manage them manually (for example, torch.utils.data.DataLoader accepts the generator argument for that purpose). However, if you rely on the global random state, this function together with set_state lets you save and restore it.
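The approach of managing a separate generator can be sketched as follows. This is a minimal illustration rather than part of the zero API: the toy dataset, batch size and seed are assumptions chosen only for the example. It gives a DataLoader its own torch.Generator, saves the generator state, and restores it later to replay the same shuffling order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset (an assumption made for this example only).
dataset = TensorDataset(torch.arange(10))

# A dedicated generator that is independent of the global torch RNG.
generator = torch.Generator()
generator.manual_seed(0)
loader = DataLoader(dataset, batch_size=5, shuffle=True, generator=generator)

# Save the generator state before the epoch, as a checkpoint would.
saved_state = generator.get_state()
first_epoch = [batch[0].tolist() for batch in loader]

# Restore the state and iterate again: the shuffling order is replayed.
generator.set_state(saved_state)
replayed_epoch = [batch[0].tolist() for batch in loader]
```

Because the generator is a separate object, other code that draws from the global random state does not affect the order in which the loader yields batches.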

See also

set_state

Examples

import torch
import zero

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
...
# Bundle the global random state with the model and optimizer states.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'random_state': zero.random.get_state(),
}
# Later: save the checkpoint, then restore the random state when resuming.
# torch.save(checkpoint, 'checkpoint.pt')
# ...
# zero.random.set_state(torch.load('checkpoint.pt')['random_state'])
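To make the idea of aggregating global random states concrete, here is a rough sketch of how such a state could be collected and restored by hand. It is not the library's implementation: the helper names get_global_state and set_global_state and the dictionary keys are hypothetical, and CUDA generator states are left out for brevity.

```python
import random
from typing import Any, Dict

import numpy as np
import torch


def get_global_state() -> Dict[str, Any]:
    # Hypothetical helper: collect the global RNG states of random, numpy
    # and torch (CPU only; CUDA states are omitted in this sketch).
    return {
        'random': random.getstate(),
        'numpy': np.random.get_state(),
        'torch': torch.get_rng_state(),
    }


def set_global_state(state: Dict[str, Any]) -> None:
    # Hypothetical helper: restore the states captured by get_global_state.
    random.setstate(state['random'])
    np.random.set_state(state['numpy'])
    torch.set_rng_state(state['torch'])


state = get_global_state()
first = (random.random(), np.random.rand(), torch.rand(1).item())

# After restoring, the three global generators repeat the same draws.
set_global_state(state)
second = (random.random(), np.random.rand(), torch.rand(1).item())
```

Restoring the saved dictionary rewinds all three generators at once, which is the property a checkpoint needs in order to resume a run that depends on the global random state.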