get_state

delu.random.get_state() -> Dict[str, Any]

Aggregate the global RNG states from random, numpy and torch.

The result of this function can be passed to delu.random.set_state. An important use case is saving the random states to a checkpoint.
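
To make "aggregating" concrete, here is a rough sketch of how the global states of the three libraries can be collected and restored through their public state getters and setters. The helper names `collect_rng_states` and `restore_rng_states` and the dictionary keys are hypothetical and are not taken from delu. numpy and torch are imported lazily so that the sketch also runs where they are not installed.

```python
import random
from typing import Any, Dict


def collect_rng_states() -> Dict[str, Any]:
    """Hypothetical helper: gather the global RNG states into one dict."""
    states: Dict[str, Any] = {'random': random.getstate()}
    try:
        import numpy as np
        states['numpy.random'] = np.random.get_state()
    except ImportError:
        pass  # numpy is not installed; skip it
    try:
        import torch
        states['torch.random'] = torch.get_rng_state()
        if torch.cuda.is_available():
            # One state per visible CUDA device
            states['torch.cuda'] = torch.cuda.get_rng_state_all()
    except ImportError:
        pass  # torch is not installed; skip it
    return states


def restore_rng_states(states: Dict[str, Any]) -> None:
    """Hypothetical helper: put back whatever collect_rng_states() stored."""
    random.setstate(states['random'])
    if 'numpy.random' in states:
        import numpy as np
        np.random.set_state(states['numpy.random'])
    if 'torch.random' in states:
        import torch
        torch.set_rng_state(states['torch.random'])
        if 'torch.cuda' in states:
            torch.cuda.set_rng_state_all(states['torch.cuda'])


states = collect_rng_states()
first = random.random()
restore_rng_states(states)
second = random.random()  # drawn from the restored state
```

After the restore, `second` repeats `first`, because the generator is back in the state it had when `collect_rng_states()` was called.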

Usage

First, a technical example:

>>> import random
>>> import numpy as np
>>> import torch
>>> import delu
>>>
>>> def f():
...     return random.random(), np.random.rand(), torch.rand(1).item()
...
>>> # Save the state before the first call
>>> state = delu.random.get_state()
>>> a1, b1, c1 = f()
>>> # As expected, the second call produces different results:
>>> a2, b2, c2 = f()
>>> print(a1 == a2, b1 == b2, c1 == c2)
False False False
>>>
>>> # Restore the state saved before the first call:
>>> delu.random.set_state(state)
>>> a3, b3, c3 = f()
>>> print(a1 == a3, b1 == b3, c1 == c3)
True True True

Pseudocode for saving the global state to a checkpoint and restoring it from one:

# Resuming from a checkpoint:
checkpoint_path = ...
if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path)
    delu.random.set_state(checkpoint['random_state'])
...
# Training:
for batch in batches:
    ...
    if step % checkpoint_frequency == 0:
        torch.save(
            {
                'model': ...,
                'optimizer': ...,
                'random_state': delu.random.get_state(),
            },
            checkpoint_path,
        )
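
The checkpoint pattern can also be tried end to end without a training setup. The sketch below is a simplified stand-in, not the delu workflow itself: it uses only the standard-library `random` generator in place of the aggregated state and `pickle` in place of `torch.save`. The `run` function and the checkpoint layout are hypothetical.

```python
import pickle
import random
import tempfile
from pathlib import Path


def run(steps: int, path: Path, save_at=None) -> list:
    """Draw one number per step; optionally checkpoint after step `save_at`."""
    draws = []
    for step in range(1, steps + 1):
        draws.append(random.random())
        if step == save_at:
            # Store the RNG state next to whatever else belongs in the checkpoint
            with path.open('wb') as f:
                pickle.dump({'step': step, 'random_state': random.getstate()}, f)
    return draws


with tempfile.TemporaryDirectory() as tmp:
    checkpoint_path = Path(tmp) / 'checkpoint.pkl'

    # Uninterrupted run of five steps, with a checkpoint after step 3
    random.seed(0)
    full = run(5, checkpoint_path, save_at=3)

    # Simulate a restart: disturb the RNG, then resume from the checkpoint
    random.seed(12345)
    with checkpoint_path.open('rb') as f:
        checkpoint = pickle.load(f)
    random.setstate(checkpoint['random_state'])
    resumed = run(5 - checkpoint['step'], checkpoint_path)

# The resumed steps match the tail of the uninterrupted run
print(resumed == full[3:])
```

Without the `random.setstate` call, the resumed run would draw from the disturbed generator and diverge from the original run.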

Returns:

A dictionary with the aggregated random states, suitable for passing to delu.random.set_state.