Source code for delu.nn

"""An extension to `torch.nn`."""

import inspect
import warnings
from collections import OrderedDict
from typing import Callable, Tuple, Union

import torch.nn
import torch.nn as nn
from torch.nn.parameter import Parameter

from ._utils import deprecated

__all__ = ['Lambda', 'NLinear', 'named_sequential']


class Lambda(torch.nn.Module):
    """A wrapper for functions from `torch` and methods of `torch.Tensor`.

    An important "feature" of this module is that it is intentionally limited:

    - Only the functions from the `torch` module and the methods of `torch.Tensor`
      are allowed.
    - The passed callable must accept a single `torch.Tensor`
      and return a single `torch.Tensor`.
    - The allowed keyword arguments must be of simple types (see the docstring).

    **Usage**

    >>> m = delu.nn.Lambda(torch.squeeze)
    >>> m(torch.randn(2, 1, 3, 1)).shape
    torch.Size([2, 3])
    >>> m = delu.nn.Lambda(torch.squeeze, dim=1)
    >>> m(torch.randn(2, 1, 3, 1)).shape
    torch.Size([2, 3, 1])
    >>> m = delu.nn.Lambda(torch.Tensor.abs_)
    >>> m(torch.tensor(-1.0))
    tensor(1.)

    Custom functions are not allowed
    (technically, they are **temporarily** allowed,
    but this functionality is deprecated and will be removed in future releases):

    >>> # xdoctest: +SKIP
    >>> m = delu.nn.Lambda(lambda x: torch.abs(x))
    Traceback (most recent call last):
        ...
    ValueError: fn must be a function from `torch` or a method of `torch.Tensor`, but ...

    Non-trivial keyword arguments are not allowed:

    >>> m = delu.nn.Lambda(torch.mul, other=torch.tensor(2.0))
    Traceback (most recent call last):
        ...
    ValueError: For kwargs, the allowed value types include: ...
    """  # noqa: E501
    def __init__(self, fn: Callable[..., torch.Tensor], /, **kwargs) -> None:
        """
        Args:
            fn: the callable.
            kwargs: the keyword arguments for ``fn``. The allowed value types
                include: None, bool, int, float, bytes, str
                and (nested) tuples of these simple types.
        """
        super().__init__()
        if not callable(fn) or (
            fn not in vars(torch).values()
            and (
                fn not in (member for _, member in inspect.getmembers(torch.Tensor))
                or inspect.ismethod(fn)  # Check if fn is a @classmethod
            )
        ):
            warnings.warn(
                'Passing custom functions to delu.nn.Lambda is deprecated'
                ' and will be removed in future releases.'
                ' Only functions from the `torch` module'
                ' and methods of `torch.Tensor` are allowed',
                DeprecationWarning,
            )
            # NOTE: in future releases, replace the above warning with this exception:
            # raise ValueError(
            #     'fn must be a function from `torch` or a method of `torch.Tensor`,'
            #     f' but this is not true for the passed {fn=}'
            # )

        def is_valid_value(x):
            return (
                x is None
                or isinstance(x, (bool, int, float, bytes, str))
                or isinstance(x, tuple)
                and all(map(is_valid_value, x))
            )

        for k, v in kwargs.items():
            if not is_valid_value(v):
                raise ValueError(
                    'For kwargs, the allowed value types include:'
                    ' None, bool, int, float, bytes, str and (nested) tuples'
                    ' containing values of these simple types. This is not true'
                    f' for the passed argument {k} with the value {v}'
                )

        self._function = fn
        self._function_kwargs = kwargs
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass."""
        return self._function(x, **self._function_kwargs)
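# Illustrative sketch (not part of the original source; the layer sizes and
# shapes below are arbitrary assumptions): `Lambda` is convenient for expressing
# simple tensor transformations inside `torch.nn.Sequential` without defining
# a custom module.
#
#     model = torch.nn.Sequential(
#         torch.nn.Linear(8, 4),
#         Lambda(torch.squeeze, dim=1),
#     )
#     y = model(torch.randn(3, 1, 8))  # y.shape == torch.Size([3, 4])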
class NLinear(nn.Module):
    """N *separate* linear layers for N embeddings: ``(*, *N, D1) -> (*, *N, D2)``.

    Usage examples covered below:

    - (NLP) Training a *separate* linear layer for each token embedding in a sequence.
      By contrast, using `torch.nn.Linear` would mean applying
      the same linear layer to all tokens.
    - (CV) Training a *separate* linear layer for each patch embedding in an image.
      By contrast, using `torch.nn.Linear` would mean applying
      the same linear layer to all patches.

    Technically, ``NLinear(N, D1, D2)`` is just a layout of ``N`` linear layers
    ``torch.nn.Linear(D1, D2)``.

    **Shape**

    - Input: ``(*, *n, in_features)``, where ``*`` are batch dimensions.
    - Output: ``(*, *n, out_features)``.

    **Usage**

    (NLP) Training a separate linear layer for each of the token embeddings
    in a sequence:

    >>> batch_size = 2
    >>> sequence_length = 4
    >>> d_embedding_in = 6
    >>> d_embedding_out = 7
    >>> x = torch.randn(batch_size, sequence_length, d_embedding_in)
    >>> x.shape
    torch.Size([2, 4, 6])
    >>> m = NLinear(sequence_length, d_embedding_in, d_embedding_out)
    >>> m(x).shape
    torch.Size([2, 4, 7])

    (CV) Training a separate linear layer for each of the patch embeddings
    in an image:

    >>> # Batch dimensions can also be arbitrarily complex.
    >>> batch_size = (2, 3)
    >>> width = 4
    >>> height = 5
    >>> in_channels = 6
    >>> out_channels = 7
    >>> x = torch.randn(*batch_size, width, height, in_channels)
    >>> x.shape
    torch.Size([2, 3, 4, 5, 6])
    >>> # N == width * height == 4 * 5 == 20
    >>> m = NLinear((width, height), in_channels, out_channels)
    >>> m(x).shape
    torch.Size([2, 3, 4, 5, 7])
    """
    def __init__(
        self,
        n: Union[int, Tuple[int, ...]],
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        All arguments are the same as in `torch.nn.Linear` except for ``n``,
        which is the expected layout of the input
        (see the examples in `NLinear`).
        """
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        n_tuple = (n,) if isinstance(n, int) else n
        if not n_tuple or any(x <= 0 for x in n_tuple):
            raise ValueError(
                'n must be a positive integer or a non-empty tuple'
                f' of positive integers. The provided value: {n=}'
            )
        self.weight = Parameter(
            torch.empty(*n_tuple, in_features, out_features, **factory_kwargs)
        )
        self.bias = (
            nn.parameter.Parameter(
                torch.empty(*n_tuple, out_features, **factory_kwargs)
            )
            if bias
            else None
        )
        self.reset_parameters()
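    # Illustrative note (an assumption for clarity, not part of the original
    # source): for the CV example from the class docstring, the parameter layout is
    #
    #     m = NLinear((4, 5), 6, 7)
    #     m.weight.shape  # torch.Size([4, 5, 6, 7])
    #     m.bias.shape    # torch.Size([4, 5, 7])
    #
    # i.e. one independent (6 -> 7) linear layer for each of the 4 * 5 = 20 positions.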
    def reset_parameters(self):
        """Reset all parameters."""
        # The same as in torch.nn.Linear.
        d_in_rsqrt = self.weight.shape[-2] ** -0.5
        nn.init.uniform_(self.weight, -d_in_rsqrt, d_in_rsqrt)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -d_in_rsqrt, d_in_rsqrt)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass."""
        if x.ndim < self.weight.ndim - 1:
            raise ValueError(
                f'The input must have at least {self.weight.ndim - 1} dimensions,'
                f' but {x.ndim=}'
            )
        # The non-batch dimensions corresponding to n and in_features must be
        # exactly equal; it would be incorrect to rely on broadcasting over them.
        if x.shape[-(self.weight.ndim - 1) :] != self.weight.shape[:-1]:
            raise ValueError(
                'The input must have a shape like'
                ' `(*batch_dimensions, *n, in_features)`, where n and in_features'
                f' are the values passed to the constructor of {type(self).__name__}.'
                f' However: {x.shape=}, n={self.weight.shape[:-2]},'
                f' in_features={self.weight.shape[-2]}'
            )

        x = (x[..., None, :] @ self.weight).squeeze(-2)
        if self.bias is not None:
            x = x + self.bias
        return x
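    # Illustrative sketch (an assumption for clarity, not part of the original
    # source): the batched matmul above is equivalent to looping over the N
    # positions and applying a separate linear map to each of them. For example,
    # with m = NLinear(4, 6, 7) and x of shape (2, 4, 6):
    #
    #     naive = torch.stack(
    #         [x[:, i] @ m.weight[i] + m.bias[i] for i in range(4)], dim=1
    #     )
    #     torch.allclose(m(x), naive)  # True (up to floating-point error)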
@deprecated('')
def named_sequential(*names_and_modules: Tuple[str, nn.Module]) -> nn.Sequential:
    """A shortcut for creating `torch.nn.Sequential` with named modules
    without using `collections.OrderedDict`.

    <DEPRECATION MESSAGE>

    The sole purpose of this function is to improve the ergonomics
    and readability of the common construction.

    **Usage**

    This ...

    >>> # xdoctest: +SKIP
    >>> m = delu.nn.named_sequential(
    ...     ('linear1', nn.Linear(10, 20)),
    ...     ('activation', nn.ReLU()),
    ...     ('linear2', nn.Linear(20, 1)),
    ... )

    ... is equivalent to this:

    >>> # xdoctest: +SKIP
    >>> from collections import OrderedDict
    >>> m = torch.nn.Sequential(
    ...     OrderedDict(
    ...         [
    ...             ('linear1', nn.Linear(10, 20)),
    ...             ('activation', nn.ReLU()),
    ...             ('linear2', nn.Linear(20, 1)),
    ...         ]
    ...     )
    ... )

    Args:
        names_and_modules: the names and the modules.
    """  # noqa: E501
    return nn.Sequential(OrderedDict(names_and_modules))