Source code for delu.random

"""An extension to `torch.random`."""

import random
import secrets
from contextlib import contextmanager
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.cuda

__all__ = ['seed', 'get_state', 'set_state', 'preserve_state']

_2_pow_64 = 1 << 64


def seed(base_seed: Optional[int], /, *, one_cuda_seed: bool = False) -> int:
    """Set *diverse* global random seeds in `random`, `numpy` and `torch`.

    .. note::
        For all libraries, *different deterministically computed* (based on the
        ``base_seed`` argument) seeds are set to ensure that different libraries
        and (by default) devices generate diverse random numbers.

    **Usage**

    >>> import random
    >>> import numpy as np
    >>>
    >>> def f():
    ...     return (
    ...         random.randint(0, 10 ** 9),
    ...         np.random.rand(10).tolist(),
    ...         torch.randn(20).tolist(),
    ...     )
    ...
    >>> # Numbers sampled under the same random seed are equal.
    >>> delu.random.seed(0)
    >>> a = f()
    >>> delu.random.seed(0)
    >>> b = f()
    >>> a == b
    True

    Pass `None` to set a truly random seed generated by the OS:

    >>> # Save the generated `seed` for future reproducibility:
    >>> seed = delu.random.seed(None)
    >>> a = f()
    >>> # Reproduce the results:
    >>> delu.random.seed(seed)
    >>> b = f()
    >>> a == b
    True

    Args:
        base_seed: an integer from ``[0, 2**64)`` used to compute diverse seeds
            for all libraries. If `None`, then an unpredictable seed generated
            by the OS is used and returned.
        one_cuda_seed: if `True`, then the same seed is set for all CUDA devices;
            otherwise, a different seed is set for each CUDA device.

    Returns:
        the provided ``base_seed`` or the generated one if ``base_seed=None``.
    """
    # The implementation is based on:
    # - https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    # - https://github.com/Lightning-AI/lightning/pull/6960#issuecomment-819672341
    if base_seed is None:
        base_seed = secrets.randbelow(_2_pow_64)
    if not (0 <= base_seed < _2_pow_64):
        raise ValueError(
            'base_seed must be a non-negative integer from [0, 2**64).'
            f' The provided value: {base_seed=}'
        )
    sequence = np.random.SeedSequence(base_seed)

    def generate_state(*args, **kwargs) -> np.ndarray:
        new_sequence = sequence.spawn(1)[0]
        return new_sequence.generate_state(*args, **kwargs)

    # To generate a 128-bit seed for the standard library,
    # two uint64 numbers are generated and concatenated (literally).
    state_std = generate_state(2, dtype=np.uint64).tolist()
    random.seed(state_std[0] * _2_pow_64 + state_std[1])
    del state_std

    np.random.seed(generate_state(4))
    torch.manual_seed(int(generate_state(1, dtype=np.uint64)[0]))

    if not torch.cuda._is_in_bad_fork():
        if one_cuda_seed:
            torch.cuda.manual_seed_all(int(generate_state(1, dtype=np.uint64)[0]))
        elif torch.cuda.is_available():
            torch.cuda.init()
            for i in range(torch.cuda.device_count()):
                torch.cuda.default_generators[i].manual_seed(
                    int(generate_state(1, dtype=np.uint64)[0])
                )
    return base_seed
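
# Not part of the library: a minimal sketch of the mechanism `seed` relies on.
# `np.random.SeedSequence` deterministically spawns independent child
# sequences, which is what makes the per-library/per-device seeds above both
# reproducible and diverse (all names below are illustrative):
#
#     import numpy as np
#
#     root = np.random.SeedSequence(0)
#     # Consecutive spawns produce different, yet deterministic, children:
#     first = int(root.spawn(1)[0].generate_state(1, dtype=np.uint64)[0])
#     second = int(root.spawn(1)[0].generate_state(1, dtype=np.uint64)[0])
#     assert first != second  # each consumer gets its own seed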
def get_state() -> Dict[str, Any]:
    """Aggregate the global RNG states from `random`, `numpy` and `torch`.

    The result of this function can be passed to `delu.random.set_state`.
    An important use case is saving the random states to a checkpoint.

    **Usage**

    First, a technical example:

    >>> import random
    >>> import numpy as np
    >>>
    >>> def f():
    ...     return random.random(), np.random.rand(), torch.rand(1).item()
    ...
    >>> # Save the state before the first call:
    >>> state = delu.random.get_state()
    >>> a1, b1, c1 = f()
    >>> # As expected, the second call produces different results:
    >>> a2, b2, c2 = f()
    >>> print(a1 == a2, b1 == b2, c1 == c2)
    False False False
    >>>
    >>> # Restore the state that was before the first call:
    >>> delu.random.set_state(state)
    >>> a3, b3, c3 = f()
    >>> print(a1 == a3, b1 == b3, c1 == c3)
    True True True

    Example pseudocode for saving/loading the global state
    to/from a checkpoint::

        # Resuming from a checkpoint:
        checkpoint_path = ...
        if checkpoint_path.exists():
            checkpoint = torch.load(checkpoint_path)
            delu.random.set_state(checkpoint['random_state'])
        ...
        # Training:
        for batch in batches:
            ...
            if step % checkpoint_frequency == 0:
                torch.save(
                    {
                        'model': ...,
                        'optimizer': ...,
                        'random_state': delu.random.get_state(),
                    },
                    checkpoint_path,
                )

    Returns:
        The aggregated random states.
    """
    return {
        'random': random.getstate(),
        'numpy.random': np.random.get_state(),
        'torch.random': torch.random.get_rng_state(),
        'torch.cuda': torch.cuda.get_rng_state_all(),  # type: ignore
    }
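
# Not part of the library: a runnable counterpart to the checkpointing
# pseudocode in the docstring above. The file name and the dict keys are
# illustrative, not a library convention:
#
#     import torch
#     import delu
#
#     torch.save({'random_state': delu.random.get_state()}, 'checkpoint.pt')
#     ...
#     # When resuming (recent PyTorch versions require weights_only=False
#     # to load non-tensor objects such as RNG states):
#     checkpoint = torch.load('checkpoint.pt', weights_only=False)
#     delu.random.set_state(checkpoint['random_state'])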
def set_state(state: Dict[str, Any], /) -> None:
    """Set the global RNG states in `random`, `numpy` and `torch`.

    **Usage**

    See `delu.random.get_state` for usage examples.

    Args:
        state: the dict with the states produced by `delu.random.get_state`.
            If the `'torch.cuda'` key is present in ``state``, then it must be
            a list of the size equal to the device count as reported by
            `torch.cuda.device_count`. If `'torch.cuda'` is not present
            in ``state``, the state of the global CUDA RNGs is not set.
    """
    random.setstate(state['random'])
    np.random.set_state(state['numpy.random'])
    torch.random.set_rng_state(state['torch.random'])
    if 'torch.cuda' in state:
        torch_cuda_state = state['torch.cuda']
        if torch.cuda.device_count() != len(torch_cuda_state):
            raise RuntimeError(
                'The provided state of the global CUDA RNGs is not compatible'
                f' with the current hardware, because {torch.cuda.device_count()=}'
                f' is not equal to {len(torch_cuda_state)=}'
            )
        torch.cuda.set_rng_state_all(torch_cuda_state)
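
# Not part of the library: since the 'torch.cuda' entry is optional, a state
# captured on a GPU machine can be partially restored on different hardware
# by dropping that entry. A sketch:
#
#     state = delu.random.get_state()
#     # Keep only the entries that do not depend on the CUDA device count:
#     cpu_only_state = {k: v for k, v in state.items() if k != 'torch.cuda'}
#     delu.random.set_state(cpu_only_state)  # leaves the CUDA RNGs untouched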
@contextmanager
def preserve_state():
    """Save the global RNG states before entering a context/function and restore them on exit.

    The function saves the global RNG states in `random`, `numpy` and `torch`
    when entering a context/function and restores them on exit/return.

    .. note::
        Within a context or a function call, random sampling works as usual,
        i.e. it continues to be "random".

    **Usage**

    As a context manager
    (the state after the context is the same as before the context):

    >>> import random
    >>> import numpy as np
    >>>
    >>> def f():
    ...     return random.random(), np.random.rand(), torch.rand(1).item()
    ...
    >>> with delu.random.preserve_state():
    ...     a = f()
    ...     # Within the context, random sampling continues to be random:
    ...     b = f()
    ...     assert a != b
    ...
    >>> # However, now, the state is reset to what it was before the context.
    >>> c = f()
    >>> a == c
    True

    As a decorator
    (the state after the call ``g()`` is the same as before the call):

    >>> @delu.random.preserve_state()
    ... def g():
    ...     return f()
    ...
    >>> a = g()
    >>> b = f()
    >>> a == b
    True
    """  # noqa: E501
    state = get_state()
    try:
        yield
    finally:
        set_state(state)
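
# Not part of the library: a typical application of `preserve_state` is
# running a mid-training evaluation that itself consumes randomness without
# perturbing the training RNG stream (`evaluate`, `model` and `val_loader`
# are hypothetical names):
#
#     with delu.random.preserve_state():
#         evaluate(model, val_loader)
#     # Training continues as if the evaluation had not sampled anything.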