"""An extension to `torch.random`."""
import random
import secrets
from contextlib import contextmanager
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.cuda
__all__ = ['seed', 'get_state', 'set_state', 'preserve_state']
_2_pow_64 = 1 << 64


def seed(base_seed: Optional[int], /, *, one_cuda_seed: bool = False) -> int:
"""Set *diverse* global random seeds in `random`, `numpy` and `torch`.
.. note::
For all libraries, *different deterministically computed*
(based on the ``base_seed`` argument) seeds are set to ensure that
different libraries and (by default) devices generate diverse random numbers.
**Usage**
>>> import random
>>> import numpy as np
>>>
>>> def f():
... return (
... random.randint(0, 10 ** 9),
... np.random.rand(10).tolist(),
... torch.randn(20).tolist(),
... )
...
>>> # Numbers sampled under the same random seed are equal.
>>> delu.random.seed(0)
>>> a = f()
>>> delu.random.seed(0)
>>> b = f()
>>> a == b
True

    Pass `None` to set a truly random seed generated by the OS:

>>> # Save the generated `seed` for future reproducibility:
>>> seed = delu.random.seed(None)
>>> a = f()
>>> # Reproduce the results:
>>> delu.random.seed(seed)
>>> b = f()
>>> a == b
True
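
    Different base seeds lead to different random numbers:

    >>> delu.random.seed(0)
    >>> a = f()
    >>> delu.random.seed(1)
    >>> b = f()
    >>> a == b
    False
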
    Args:
        base_seed: an integer from `[0, 2**64)` used to compute diverse seeds
            for all libraries. If `None`, an unpredictable seed generated by
            the OS is used and returned.
        one_cuda_seed: if `True`, the same seed is set for all CUDA devices;
            otherwise, a different seed is set for each CUDA device.

    Returns:
        the provided ``base_seed``, or the generated one if ``base_seed`` is `None`.
"""
# The implementation is based on:
# - https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
# - https://github.com/Lightning-AI/lightning/pull/6960#issuecomment-819672341
if base_seed is None:
base_seed = secrets.randbelow(_2_pow_64)
if not (0 <= base_seed < _2_pow_64):
raise ValueError(
'base_seed must be a non-negative integer from [0, 2**64).'
f' The provided value: {base_seed=}'
)
sequence = np.random.SeedSequence(base_seed)
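
    # Each call to `generate_state` spawns a fresh child `SeedSequence`,
    # so every consumer below receives its own, statistically independent
    # entropy derived from the same `base_seed`.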
def generate_state(*args, **kwargs) -> np.ndarray:
new_sequence = sequence.spawn(1)[0]
return new_sequence.generate_state(*args, **kwargs)
# To generate a 128-bit seed for the standard library,
# two uint64 numbers are generated and concatenated (literally).
state_std = generate_state(2, dtype=np.uint64).tolist()
random.seed(state_std[0] * _2_pow_64 + state_std[1])
del state_std
np.random.seed(generate_state(4))
torch.manual_seed(int(generate_state(1, dtype=np.uint64)[0]))
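    # `torch.manual_seed` above also seeds all CUDA devices (with the same
    # value everywhere); the block below overrides this with diverse
    # per-device seeds, unless `one_cuda_seed=True`.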
    if not torch.cuda._is_in_bad_fork():
        if one_cuda_seed:
            torch.cuda.manual_seed_all(int(generate_state(1, dtype=np.uint64)[0]))
        elif torch.cuda.is_available():
            torch.cuda.init()
            for i in range(torch.cuda.device_count()):
                torch.cuda.default_generators[i].manual_seed(
                    int(generate_state(1, dtype=np.uint64)[0])
                )
return base_seed


def get_state() -> Dict[str, Any]:
"""Aggregate the global RNG states from `random`, `numpy` and `torch`.
The result of this function can be passed to `delu.random.set_state`.
An important use case is saving the random states to a checkpoint.
**Usage**
First, a technical example:
>>> import random
>>> import numpy as np
>>>
>>> def f():
... return random.random(), np.random.rand(), torch.rand(1).item()
...
>>> # Save the state before the first call
>>> state = delu.random.get_state()
>>> a1, b1, c1 = f()
>>> # As expected, the second call produces different results:
>>> a2, b2, c2 = f()
>>> print(a1 == a2, b1 == b2, c1 == c2)
False False False
>>>
>>> # Restore the state that was before the first call:
>>> delu.random.set_state(state)
>>> a3, b3, c3 = f()
>>> print(a1 == a3, b1 == b3, c1 == c3)
True True True

    Example pseudocode for saving/loading the global state
    to/from a checkpoint::

# Resuming from a checkpoint:
checkpoint_path = ...
if checkpoint_path.exists():
checkpoint = torch.load(checkpoint_path)
delu.random.set_state(checkpoint['random_state'])
...
# Training:
for batch in batches:
...
if step % checkpoint_frequency == 0:
torch.save(
{
'model': ...,
'optimizer': ...,
'random_state': delu.random.get_state(),
},
checkpoint_path,
)
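
    The returned dict contains one entry per library:

    >>> sorted(delu.random.get_state())
    ['numpy.random', 'random', 'torch.cuda', 'torch.random']
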
Returns:
The aggregated random states.
"""
return {
'random': random.getstate(),
'numpy.random': np.random.get_state(),
'torch.random': torch.random.get_rng_state(),
'torch.cuda': torch.cuda.get_rng_state_all(), # type: ignore
}


def set_state(state: Dict[str, Any], /) -> None:
"""Set the global RNG states in `random`, `numpy` and `torch`.
**Usage**
See `delu.random.get_state` for usage examples.
Args:
state: the dict with the states produced by `delu.random.get_state`.
If the `'torch.cuda'` key is presented in ``state``, then it must be
a list of the size equal to the device count as reported by
`torch.cuda.device_count`.
If `'torch.cuda'` is not presented in ``state``, the state of the CUDA
global RNGs is not set.
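
    If a checkpoint was saved on a machine with CUDA devices but is being
    restored on a machine without them (or with a different device count),
    one option is to drop the CUDA entry and restore only the remaining states
    (a sketch; whether this is appropriate depends on the use case)::

        state = checkpoint['random_state']  # produced by delu.random.get_state
        state.pop('torch.cuda', None)       # skip restoring the CUDA RNG states
        delu.random.set_state(state)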
"""
random.setstate(state['random'])
np.random.set_state(state['numpy.random'])
torch.random.set_rng_state(state['torch.random'])
if 'torch.cuda' in state:
torch_cuda_state = state['torch.cuda']
        if torch.cuda.device_count() != len(torch_cuda_state):
raise RuntimeError(
'The provided state of the global CUDA RNGs is not compatible'
f' with the current hardware, because {torch.cuda.device_count()=}'
f' is not equal to {len(torch_cuda_state)=}'
)
torch.cuda.set_rng_state_all(torch_cuda_state)


@contextmanager
def preserve_state():
"""Save the global RNG states before entering a context/function and restore it on exit.
The function saves the global RNG states in `random`, `numpy` and `torch`
when entering a context/function and restores it on exit/return.
.. note::
Within a context or a function call, random sampling works as usual,
i.e. it continues to be "random".
**Usage**
    As a context manager
    (the state after the context is the same as before the context):

>>> import random
>>> import numpy as np
>>>
>>> def f():
... return random.random(), np.random.rand(), torch.rand(1).item()
...
>>> with delu.random.preserve_state():
... a = f()
... # Within the context, random sampling continues to be random:
... b = f()
... assert a != b
...
    >>> # Now, the state is restored to what it was before the context:
>>> c = f()
>>> a == c
True

    As a decorator
    (the state after the call `g()` is the same as before the call):

>>> @delu.random.preserve_state()
... def g():
... return f()
...
>>> a = g()
>>> b = f()
>>> a == b
True
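
    A typical use case (a sketch in pseudocode): sampling something for
    evaluation or logging in the middle of training without affecting
    the main training RNG stream::

        @delu.random.preserve_state()
        def evaluate(model):
            delu.random.seed(0)  # deterministic sampling inside the call
            ...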
""" # noqa: E501
state = get_state()
try:
yield
finally:
set_state(state)