# Source code for zero.module

# Public API of this module (names exported by `from ... import *`).
__all__ = [
    'call',
    'ecall',
    'evaluation',
]

import contextlib
import typing as ty

import torch
import torch.nn as nn

# Generic type variable: `_to_device` is annotated as returning the same type it receives.
T = ty.TypeVar('T')

def _to_device(value: T, device: torch.device) -> T:
    """Recursively move all tensors found in ``value`` to ``device``.

    Tensors are moved with ``Tensor.to``. Tuples, lists, named tuples and
    dictionaries are rebuilt by calling their own type on the recursively
    processed items. Any other value is returned unchanged.

    Args:
        value: a tensor, a (possibly nested) container of tensors, or any other
            value
        device: the target device

    Returns:
        ``value`` with every reachable tensor moved to ``device``; containers
        are new objects of the same type, other values are returned as is.
    """
    # TODO: support dataclasses
    if isinstance(value, torch.Tensor):
        return value.to(device)  # type: ignore
    if isinstance(value, (tuple, list)):
        cls = type(value)
        data = (_to_device(x, device) for x in value)
        # Named tuples take their fields as separate positional arguments,
        # while plain tuples/lists take a single iterable.
        is_namedtuple = all(
            hasattr(value, x) for x in ('_fields', '_replace', '_asdict')
        )
        return cls(*data) if is_namedtuple else cls(data)  # type: ignore
    if isinstance(value, dict):
        return type(value)(  # type: ignore
            (k, _to_device(v, device)) for k, v in value.items()
        )
    return value

def call(module: ty.Union[nn.Module, nn.DataParallel], *args, **kwargs):
    """Move arguments to module's device and call the module with these arguments.

    With this function you don't have to do the following anymore:

    - pass the model's device everywhere as an additional argument along with
      the model
    - infer the model's device by hand
    - move model's arguments to the correct device before the call

    Args:
        module: the module. If an instance of `torch.nn.DataParallel`, then
            arguments are passed to the module as is.
        *args: positional arguments for the module
        **kwargs: keyword arguments for the module

    Returns:
        result: the output of :code:`module(*args, **kwargs)` after the arguments
            are moved to the module's device.

    Note:
        The module's device is inferred as the device of the first tensor yielded
        by ``module.parameters()`` (or by ``module.buffers()`` if the module has
        no parameters). So, the function works only for cases when all the
        parameters of the module are located on the same device. If the module
        has neither parameters nor buffers, the arguments are passed as is.

    Note:
        The transfer happens only to tensors and simple containers containing
        tensors (tuples, lists, named tuples and dictionaries). Other values are
        not changed anyhow. For example, the following call is successfully
        handled by `call`::

            # all tensor_X variables will be moved to the module's device
            call(
                model,
                tensor_0,
                1,
                'hello',
                [tensor_1, True, (tensor_2, False)],
                {'world': [[[tensor_3]]]},
                abc=my_namedtuple_with_tensors,
            )

        However, if the arguments include instances of custom classes and some
        of their fields are tensors that need to be moved, then this is not the
        case for `call`.

    See also:
        - `ecall`

    Examples:
        .. testcode::

            model = torch.nn.Linear(3, 5)
            ...  # model is moved to some device here
            x = torch.randn(4, 3)
            call(model, x)  # no need to move `x` to the model's device
    """
    if isinstance(module, nn.DataParallel):
        return module(*args, **kwargs)
    # Use a parameter (or, failing that, a buffer) as the device reference;
    # `next(..., None)` avoids a StopIteration for parameter-free modules.
    reference = next(module.parameters(), None)
    if reference is None:
        reference = next(module.buffers(), None)
    if reference is None:
        # Nothing to infer the device from: call the module unchanged.
        return module(*args, **kwargs)
    device = reference.device
    return module(*_to_device(args, device), **_to_device(kwargs, device))
@contextlib.contextmanager
def evaluation(*modules: nn.Module):
    """Context-manager for models evaluation.

    Warning:
        The function must be used only as a context manager as shown below in
        the examples. The behaviour for call without the `with` keyword is
        unspecified.

    This code...::

        model.eval()
        with torch.no_grad():
            ...

    ...is equivalent to ::

        with evaluation(model):
            ...

    Args:
        modules: one or more modules to switch to the evaluation mode

    Raises:
        ValueError: if no modules are passed (raised when entering the context)

    See also:
        - `ecall`

    Examples:
        .. testcode::

            a = torch.nn.Linear(1, 1)
            b = torch.nn.Linear(2, 2)
            with evaluation(a):
                ...
            with evaluation(a, b):
                ...

        .. testcode::

            model = torch.nn.Linear(1, 1)
            for grad_before_context in False, True:
                for train in False, True:
                    torch.set_grad_enabled(grad_before_context)
                    model.train(train)
                    with evaluation(model):
                        assert not model.training
                        assert not torch.is_grad_enabled()
                        ...
                    assert torch.is_grad_enabled() == grad_before_context
                    # model.training is unspecified here
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not modules:
        raise ValueError('At least one module must be passed')
    for x in modules:
        x.eval()
    # The previous grad mode is restored on exit, even if the body raises.
    with torch.no_grad():
        yield
def ecall(module: ty.Union[nn.Module, nn.DataParallel], *args, **kwargs):
    """Call the module in the evaluation mode (`evaluation` + `call`).

    The function:

    1. switches the module to the evaluation mode
    2. turns off gradients
    3. moves the arguments to the module's device
    4. calls the module and returns the result

    In fact, the function is just a shortcut for the combination of
    `evaluation` and `call` (hence, all constraints of the `call` function are
    inherited, see its documentation for details), i.e.::

        result = ecall(model, x)
        # is equivalent to:
        with evaluation(model):
            result = call(model, x)

    Args:
        module: the module
        *args: positional arguments for the module
        **kwargs: keyword arguments for the module

    Returns:
        result: the output of :code:`call(module, *args, **kwargs)` computed
            inside :code:`evaluation(module)`

    See also:
        - `evaluation`
        - `call`
    """
    with evaluation(module):
        return call(module, *args, **kwargs)