"""An extension to `torch.nn`."""
import inspect
import warnings
from collections import OrderedDict
from typing import Callable, Tuple, Union
import torch.nn
import torch.nn as nn
from torch.nn.parameter import Parameter
from ._utils import deprecated
__all__ = ['Lambda', 'NLinear', 'named_sequential']


class Lambda(torch.nn.Module):
"""A wrapper for functions from `torch` and methods of `torch.Tensor`.
An important "feature" of this module is that it is intentionally limited:
- Only the functions from the `torch` module and the methods of `torch.Tensor`
are allowed.
- The passed callable must accept a single `torch.Tensor`
and return a single `torch.Tensor`.
- The allowed keyword arguments must be of simple types (see the docstring).
**Usage**
>>> m = delu.nn.Lambda(torch.squeeze)
>>> m(torch.randn(2, 1, 3, 1)).shape
torch.Size([2, 3])
>>> m = delu.nn.Lambda(torch.squeeze, dim=1)
>>> m(torch.randn(2, 1, 3, 1)).shape
torch.Size([2, 3, 1])
>>> m = delu.nn.Lambda(torch.Tensor.abs_)
>>> m(torch.tensor(-1.0))
tensor(1.)
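
    One more illustration (an addition to the original examples): a `torch`
    function combined with a simple keyword argument:

    >>> m = delu.nn.Lambda(torch.flatten, start_dim=1)
    >>> m(torch.randn(2, 3, 4)).shape
    torch.Size([2, 12])
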
Custom functions are not allowed
(technically, they are **temporarily** allowed,
but this functionality is deprecated and will be removed in future releases):
>>> # xdoctest: +SKIP
>>> m = delu.nn.Lambda(lambda x: torch.abs(x))
Traceback (most recent call last):
...
ValueError: fn must be a function from `torch` or a method of `torch.Tensor`, but ...
Non-trivial keyword arguments are not allowed:
>>> m = delu.nn.Lambda(torch.mul, other=torch.tensor(2.0))
Traceback (most recent call last):
...
ValueError: For kwargs, the allowed value types include: ...
""" # noqa: E501

    def __init__(self, fn: Callable[..., torch.Tensor], /, **kwargs) -> None:
"""
Args:
fn: the callable.
            kwargs: the keyword arguments for ``fn``. The allowed value types include:
                None, bool, int, float, bytes, str
                and (nested) tuples of these simple types.
"""
super().__init__()
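        # Accept only callables that come from the `torch` namespace or are
        # (non-classmethod) members of `torch.Tensor`; anything else currently
        # triggers only a DeprecationWarning, but is planned to become an error.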
if not callable(fn) or (
fn not in vars(torch).values()
and (
fn not in (member for _, member in inspect.getmembers(torch.Tensor))
or inspect.ismethod(fn) # Check if fn is a @classmethod
)
):
warnings.warn(
'Passing custom functions to delu.nn.Lambda is deprecated'
' and will be removed in future releases.'
' Only functions from the `torch` module and methods of `torch.Tensor`'
' are allowed',
DeprecationWarning,
)
# NOTE: in future releases, replace the above warning with this exception:
# raise ValueError(
# 'fn must be a function from `torch` or a method of `torch.Tensor`,'
# f' but this is not true for the passed {fn=}'
# )
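        # Keyword argument values are restricted to simple built-in scalars and
        # (nested) tuples of them; anything else is rejected below.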
def is_valid_value(x):
return (
x is None
or isinstance(x, (bool, int, float, bytes, str))
or isinstance(x, tuple)
and all(map(is_valid_value, x))
)
for k, v in kwargs.items():
if not is_valid_value(v):
raise ValueError(
'For kwargs, the allowed value types include:'
' None, bool, int, float, bytes, str and (nested) tuples containing'
' values of these simple types. This is not true for the passed'
f' argument {k} with the value {v}'
)
self._function = fn
self._function_kwargs = kwargs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Do the forward pass."""
return self._function(x, **self._function_kwargs)


class NLinear(nn.Module):
"""N *separate* linear layers for N embeddings: ``(*, *N, D1) -> (*, *N, D2)``.
Usage examples covered below:
- (NLP) Training a *separate* linear layer for each token embedding in a sequence.
By contrast, using `torch.nn.Linear` would mean applying the same linear layer
to all tokens.
    - (CV) Training a *separate* linear layer for each patch embedding in an image.
      By contrast, using `torch.nn.Linear` would mean applying the same linear layer
      to all patches.
Technically, ``NLinear(N, D1, D2)`` is just a layout of ``N``
linear layers ``torch.nn.Linear(D1, D2)``.
**Shape**
    - Input: ``(*, *n, in_features)``, where ``*`` denotes batch dimensions
      and ``*n`` is the ``n`` passed to the constructor.
    - Output: ``(*, *n, out_features)``.
**Usage**
(NLP)
Training a separate linear layer for each of the token embeddings in a sequence:
>>> batch_size = 2
>>> sequence_length = 4
>>> d_embedding_in = 6
>>> d_embedding_out = 7
>>> x = torch.randn(batch_size, sequence_length, d_embedding_in)
>>> x.shape
torch.Size([2, 4, 6])
>>> m = NLinear(sequence_length, d_embedding_in, d_embedding_out)
>>> m(x).shape
torch.Size([2, 4, 7])
(CV)
Training a separate linear layer for each of the patch embeddings in an image:
>>> # Batch dimensions can also be arbitrarily complex.
>>> batch_size = (2, 3)
>>> width = 4
>>> height = 5
>>> in_channels = 6
>>> out_channels = 7
>>> x = torch.randn(*batch_size, width, height, in_channels)
>>> x.shape
torch.Size([2, 3, 4, 5, 6])
    >>> # N == width * height == 4 * 5 == 20
>>> m = NLinear((width, height), in_channels, out_channels)
>>> m(x).shape
torch.Size([2, 3, 4, 5, 7])
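
    The parameter shapes reflect the "one linear layer per position" layout
    described above (an extra check added for illustration, reusing ``m`` from
    the CV example):

    >>> m.weight.shape, m.bias.shape
    (torch.Size([4, 5, 6, 7]), torch.Size([4, 5, 7]))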
"""

    def __init__(
self,
n: Union[int, Tuple[int, ...]],
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
"""
All arguments are the same as in `torch.nn.Linear` except for ``n``,
which is the expected layout of the input (see the examples in `NLinear`).
"""
super().__init__()
factory_kwargs = {'device': device, 'dtype': dtype}
n_tuple = (n,) if isinstance(n, int) else n
if not n_tuple or any(x <= 0 for x in n_tuple):
raise ValueError(
'n must be a positive integer or a non-empty tuple'
f' of positive integers. The provided value: {n=}'
)
self.weight = Parameter(
torch.empty(*n_tuple, in_features, out_features, **factory_kwargs)
)
        self.bias = (
            Parameter(torch.empty(*n_tuple, out_features, **factory_kwargs))
            if bias
            else None
        )
self.reset_parameters()

    def reset_parameters(self):
"""Reset all parameters."""
# The same as in torch.nn.Linear.
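        # weight.shape[-2] == in_features, so the bound is 1 / sqrt(in_features).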
d_in_rsqrt = self.weight.shape[-2] ** -0.5
nn.init.uniform_(self.weight, -d_in_rsqrt, d_in_rsqrt)
if self.bias is not None:
nn.init.uniform_(self.bias, -d_in_rsqrt, d_in_rsqrt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Do the forward pass."""
if x.ndim < self.weight.ndim - 1:
raise ValueError(
                f'The input must have at least {self.weight.ndim - 1} dimensions,'
f' but {x.ndim=}'
)
        # The non-batch dimensions corresponding to n and in_features must match
        # the weight layout exactly; it would be incorrect to rely on broadcasting
        # over them.
if x.shape[-(self.weight.ndim - 1) :] != self.weight.shape[:-1]:
raise ValueError(
'The input must have a shape like'
' `(*batch_dimensions, *n, in_features)`, where n and in_features '
f'are the values passed to the constructor of {type(self).__name__}.'
f' However: {x.shape=}, n={self.weight.shape[:-2]},'
f' in_features={self.weight.shape[-2]}'
)
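        # Broadcasted batched matmul:
        # (*batch, *n, 1, in_features) @ (*n, in_features, out_features)
        # -> (*batch, *n, 1, out_features); then drop the singleton dimension.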
x = (x[..., None, :] @ self.weight).squeeze(-2)
if self.bias is not None:
x = x + self.bias
return x


@deprecated('')
def named_sequential(*names_and_modules: Tuple[str, nn.Module]) -> nn.Sequential:
"""A shortcut for creating `torch.nn.Sequential` with named modules without using `collections.OrderedDict`.
<DEPRECATION MESSAGE>
The sole purpose of this function is to improve the ergonomics and readability
of the common construction.
**Usage**
This ...
>>> # xdoctest: +SKIP
>>> m = delu.nn.named_sequential(
... ('linear1', nn.Linear(10, 20)),
... ('activation', nn.ReLU()),
... ('linear2', nn.Linear(20, 1)),
... )

    ... is equivalent to this:

>>> # xdoctest: +SKIP
>>> from collections import OrderedDict
>>> m = torch.nn.Sequential(
... OrderedDict(
... [
... ('linear1', nn.Linear(10, 20)),
... ('activation', nn.ReLU()),
... ('linear2', nn.Linear(20, 1)),
... ]
... )
... )
Args:
names_and_modules: the names and the modules.
""" # noqa: E501
return nn.Sequential(OrderedDict(names_and_modules))