Source code for zero.random

"""Utilities for seeding and saving/restoring global random number generators.

The functions of this module operate on the global generators of `random`,
`numpy.random`, `torch` and `torch.cuda`.
"""

# Names exported by `from <module> import *`; this is the public API.
__all__ = ['get_random_state', 'set_random_state', 'set_randomness']

import random
import secrets
from typing import Any, Dict, Optional

import numpy as np
import torch


def set_randomness(
    seed: Optional[int],
    cudnn_deterministic: bool = True,
    cudnn_benchmark: bool = False,
) -> int:
    """Set seeds and settings in `random`, `numpy` and `torch`.

    Sets random seed in `random`, `numpy.random`, `torch`, `torch.cuda` and sets
    settings in :code:`torch.backends.cudnn`.

    Args:
        seed: the seed for all mentioned libraries. If `None`, a **high-quality**
            seed is generated and used instead. The value (given or generated) is
            reduced modulo :code:`2 ** 32`, so any seed in the range
            :code:`[0, 2 ** 32 - 1]` is used unchanged.
        cudnn_deterministic: value for :code:`torch.backends.cudnn.deterministic`
        cudnn_benchmark: value for :code:`torch.backends.cudnn.benchmark`

    Returns:
        seed: the seed that was actually applied to all libraries. If :code:`seed`
        is `None`, this is the generated seed; if :code:`seed` lies in
        :code:`[0, 2 ** 32 - 1]`, it is the :code:`seed` argument as is; otherwise
        it is :code:`seed % 2 ** 32`.

    Note:
        If you don't want to set the seed by hand, but still want to have a chance
        to reproduce things, you can use the following pattern::

            print('Seed:', set_randomness(None))

    Examples:
        .. testcode::

            assert set_randomness(0) == 0
            assert set_randomness(None), (
                '0 was generated as the seed, which is almost impossible!'
            )
    """
    torch.backends.cudnn.deterministic = cudnn_deterministic  # type: ignore
    torch.backends.cudnn.benchmark = cudnn_benchmark  # type: ignore

    raw_seed = seed
    if raw_seed is None:
        # See https://numpy.org/doc/1.18/reference/random/bit_generators/index.html#seeding-and-entropy  # noqa
        raw_seed = secrets.randbits(128)

    # `np.random.seed` accepts integers in [0, 2 ** 32 - 1]. Reducing modulo
    # 2 ** 32 (and not 2 ** 32 - 1) keeps every value of that range intact, so
    # the largest valid seed is not silently mapped to 0.
    seed = raw_seed % (2 ** 32)

    torch.manual_seed(seed)
    # mypy doesn't know about the following functions
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)  # type: ignore
    np.random.seed(seed)
    random.seed(seed)
    return seed
def get_random_state() -> Dict[str, Any]:
    """Collect the global random states of `random`, `numpy` and `torch`.

    The returned dictionary is a snapshot that can later be passed to
    `set_random_state`. This is useful for checkpoints that must allow resuming
    data streams or any other activity driven by the **global** random number
    generators.

    Returns:
        state: a dictionary with the entries :code:`'random'`,
        :code:`'numpy.random'`, :code:`'torch.random'` and :code:`'torch.cuda'`
        (the latter holds one state per cuda device).

    Note:
        The most reliable way to guarantee reproducibility and to make your data
        streams resumable is to create separate random number generators and
        manage them manually (for example, `torch.utils.data.DataLoader` accepts
        the argument :code:`generator` for that purpose). However, if you rely on
        the global random state, this function together with `set_random_state`
        does the job.

    See also:
        `set_random_state`

    Examples:
        .. testcode::

            model = torch.nn.Linear(1, 1)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
            ...
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'random_state': get_random_state(),
            }
            # later
            # torch.save(checkpoint, 'checkpoint.pt')
            # ...
            # set_random_state(torch.load('checkpoint.pt')['random_state'])
    """
    snapshot: Dict[str, Any] = {}
    snapshot['random'] = random.getstate()
    snapshot['numpy.random'] = np.random.get_state()
    snapshot['torch.random'] = torch.random.get_rng_state()
    # One entry per cuda device.
    snapshot['torch.cuda'] = torch.cuda.get_rng_state_all()  # type: ignore
    return snapshot
def set_random_state(state: Dict[str, Any]) -> None:
    """Set global random states from `random`, `numpy` and `torch`.

    The argument must be produced by `get_random_state`.

    Args:
        state: a dictionary with the entries :code:`'random'`,
            :code:`'numpy.random'`, :code:`'torch.random'` and, optionally,
            :code:`'torch.cuda'` (a list with one state per cuda device).

    Note:
        The size of list :code:`state['torch.cuda']` must be equal to the number of
        available cuda devices. If random state of cuda devices is not important,
        remove the entry 'torch.cuda' from the state beforehand (the states of cuda
        devices are then left untouched), or, **at your own risk** adjust its
        value.

    Raises:
        AssertionError: if the entry 'torch.cuda' is present and
            :code:`torch.cuda.device_count() != len(state['torch.cuda'])`. The
            check happens before any generator is modified.
        KeyError: if any of the entries 'random', 'numpy.random' or 'torch.random'
            is missing.
    """
    restore_cuda = 'torch.cuda' in state
    if restore_cuda:
        # Validate before touching any generator, so that a failed call does not
        # leave the global random state partially restored. An explicit `raise`
        # is used instead of `assert`, because `assert` statements are removed
        # when Python runs with `-O`.
        n_devices = torch.cuda.device_count()
        n_states = len(state['torch.cuda'])
        if n_devices != n_states:
            raise AssertionError(
                f'The number of cuda devices ({n_devices}) must be equal to'
                f" len(state['torch.cuda']) ({n_states})"
            )

    random.setstate(state['random'])
    np.random.set_state(state['numpy.random'])
    torch.random.set_rng_state(state['torch.random'])
    if restore_cuda:
        torch.cuda.set_rng_state_all(state['torch.cuda'])  # type: ignore