get_state
- delu.random.get_state() -> Dict[str, Any]
Aggregate the global RNG states from random, numpy and torch.
The result of this function can be passed to delu.random.set_state to restore those states. An important use case is saving the random states to a checkpoint.
Usage
First, a technical example:
>>> import random
>>> import numpy as np
>>> import torch
>>> import delu
>>>
>>> def f():
...     return random.random(), np.random.rand(), torch.rand(1).item()
...
>>> # Save the state before the first call
>>> state = delu.random.get_state()
>>> a1, b1, c1 = f()
>>> # As expected, the second call produces different results:
>>> a2, b2, c2 = f()
>>> print(a1 == a2, b1 == b2, c1 == c2)
False False False
>>>
>>> # Restore the state that was before the first call:
>>> delu.random.set_state(state)
>>> a3, b3, c3 = f()
>>> print(a1 == a3, b1 == b3, c1 == c3)
True True True
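To make "aggregate the global RNG states" concrete, here is a minimal sketch of how the same kind of snapshot could be collected and restored by hand with the standard getters and setters of each library. The helper names get_global_rng_state and set_global_rng_state and the dictionary keys are hypothetical, chosen for illustration; this is not claimed to be how delu implements get_state, and the keys need not match what it returns.

```python
import random
from typing import Any, Dict

import numpy as np
import torch


def get_global_rng_state() -> Dict[str, Any]:
    """Collect the global RNG states of random, numpy and torch.

    Hypothetical helper for illustration; the key names are assumptions
    and need not match what delu.random.get_state returns.
    """
    return {
        "random": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
        # One state per visible CUDA device; an empty list without CUDA.
        "torch.cuda": (
            torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
        ),
    }


def set_global_rng_state(state: Dict[str, Any]) -> None:
    """Restore states produced by get_global_rng_state (hypothetical helper)."""
    random.setstate(state["random"])
    np.random.set_state(state["numpy"])
    torch.set_rng_state(state["torch"])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["torch.cuda"])


# Usage: draws made after restoring match the draws made after saving.
state = get_global_rng_state()
first = (random.random(), np.random.rand(), torch.rand(1).item())
set_global_rng_state(state)
second = (random.random(), np.random.rand(), torch.rand(1).item())
assert first == second
```

Each library keeps its own independent global generator, which is why a snapshot has to capture all of them: restoring only one would leave the others advanced.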
Pseudocode for saving the global random state to a checkpoint and loading it back:
# Resuming from a checkpoint:
checkpoint_path = ...
if checkpoint_path.exists():
    # weights_only=False: the random state contains non-tensor objects
    # (e.g. the NumPy RNG state), which weights-only loading may reject.
    checkpoint = torch.load(checkpoint_path, weights_only=False)
    delu.random.set_state(checkpoint['random_state'])
...

# Training:
for batch in batches:
    ...
    if step % checkpoint_frequency == 0:
        torch.save(
            {
                'model': ...,
                'optimizer': ...,
                'random_state': delu.random.get_state(),
            },
            checkpoint_path,
        )
- Returns:
The aggregated random states.