# Source code for zero.random
"""Random sampling utilities."""
import random
from contextlib import contextmanager
from typing import Any, Dict
import numpy as np
import torch
def seed(seed: int) -> None:
    """Set seeds in `random`, `numpy` and `torch`.

    Args:
        seed: the seed for all mentioned libraries. Must be a non-negative integer not
            greater than :code:`2 ** 32 - 1` (the largest seed accepted by `numpy`).

    Raises:
        AssertionError: if the seed is not within the required interval

    Examples:
        .. testcode::

            zero.random.seed(0)
    """
    # An explicit check (instead of a bare `assert`) keeps the validation active under
    # `python -O`; AssertionError is still raised for backward compatibility.
    if not 0 <= seed <= 2 ** 32 - 1:
        raise AssertionError(
            'seed must be in the interval [0, 2 ** 32 - 1], got {}'.format(seed)
        )
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # mypy doesn't know about the following functions
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)  # type: ignore
def get_state() -> Dict[str, Any]:
    """Collect the global random states of `random`, `numpy` and `torch`.

    This is handy for checkpoints from which data streams (or any other activity that
    depends on the **global** random number generators) can later be resumed. The
    returned dictionary is meant to be passed to `set_state`.

    Returns:
        state: a dictionary with the keys ``'random'``, ``'numpy.random'``,
        ``'torch.random'`` and ``'torch.cuda'``; the last one holds a list of states,
        one per available CUDA device.

    See also:
        `set_state`

    Examples:
        .. testcode::

            model = torch.nn.Linear(1, 1)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
            ...
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'random_state': zero.random.get_state(),
            }
            # later
            # torch.save(checkpoint, 'checkpoint.pt')
            # ...
            # zero.random.set_state(torch.load('checkpoint.pt')['random_state'])
    """
    python_state = random.getstate()
    numpy_state = np.random.get_state()
    torch_cpu_state = torch.random.get_rng_state()
    # mypy doesn't know about this function
    torch_cuda_states = torch.cuda.get_rng_state_all()  # type: ignore
    return {
        'random': python_state,
        'numpy.random': numpy_state,
        'torch.random': torch_cpu_state,
        'torch.cuda': torch_cuda_states,
    }
def set_state(state: Dict[str, Any]) -> None:
    """Set global random states in `random`, `numpy` and `torch`.

    All entries of ``state`` are read and validated before any generator is modified,
    so an invalid ``state`` leaves the global random state untouched.

    Args:
        state: global RNG states. Must be produced by `get_state`. The size of the list
            :code:`state['torch.cuda']` must be equal to the number of available cuda
            devices.

    See also:
        `get_state`

    Raises:
        AssertionError: if :code:`torch.cuda.device_count() != len(state['torch.cuda'])`
        KeyError: if one of the keys produced by `get_state` is missing
    """
    # Read every entry first: a missing key must not leave generators half-restored.
    random_state = state['random']
    numpy_state = state['numpy.random']
    torch_state = state['torch.random']
    cuda_states = state['torch.cuda']
    n_devices = torch.cuda.device_count()
    # Validate before mutating anything; an explicit raise (not `assert`) also keeps
    # the check active under `python -O`.
    if n_devices != len(cuda_states):
        raise AssertionError(
            'The number of CUDA RNG states ({}) does not match the number of '
            'available CUDA devices ({})'.format(len(cuda_states), n_devices)
        )
    random.setstate(random_state)
    np.random.set_state(numpy_state)
    torch.random.set_rng_state(torch_state)
    torch.cuda.set_rng_state_all(cuda_states)  # type: ignore
@contextmanager
def preserve_state():
    """Decorator and a context manager that preserve the global random state.

    The states of `random`, `numpy` and `torch` are captured with `get_state` when the
    block is entered and restored with `set_state` when it is left, including when the
    block raises.

    Examples:
        .. testcode::

            import random
            f = lambda: (
                random.randint(0, 10),
                np.random.randint(10),
                torch.randint(10, (1,)).item()
            )
            with preserve_state():
                a = f()
            b = f()
            assert a == b

            @preserve_state()
            def g():
                return f()

            with preserve_state():
                a = g()
            b = g()
            assert a == b
    """
    snapshot = get_state()
    try:
        yield
    finally:
        # Restore unconditionally so that an exception inside the managed block
        # does not leak a modified global random state.
        set_state(snapshot)