Lambda

class delu.nn.Lambda

Bases: Module

A wrapper for functions from torch and methods of torch.Tensor.

An important “feature” of this module is that it is intentionally limited:

  • Only the functions from the torch module and the methods of torch.Tensor are allowed.

  • The passed callable must accept a single torch.Tensor and return a single torch.Tensor.

  • The values of keyword arguments must be of simple types (see the kwargs parameter of __init__).

Usage

>>> m = delu.nn.Lambda(torch.squeeze, dim=1)
>>> m(torch.randn(2, 1, 3, 1)).shape
torch.Size([2, 3, 1])
>>> m = delu.nn.Lambda(torch.Tensor.abs_)
>>> m(torch.tensor(-1.0))
tensor(1.)

Custom functions are not allowed:

>>> m = delu.nn.Lambda(lambda x: torch.abs(x))
Traceback (most recent call last):
    ...
ValueError: fn must be a function from `torch` or a method of `torch.Tensor`, but ...
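
When custom logic is needed, one common PyTorch alternative is a small nn.Module subclass. The following is a minimal sketch of that approach; the class name ClampedAbs and its max_value parameter are hypothetical and are not part of delu.

```python
import torch
from torch import nn


class ClampedAbs(nn.Module):
    """Hypothetical module: absolute value followed by an upper clamp."""

    def __init__(self, max_value: float) -> None:
        super().__init__()
        self.max_value = max_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Custom logic lives in a named, inspectable module
        # instead of an anonymous lambda.
        return torch.clamp(torch.abs(x), max=self.max_value)


# The custom module composes with built-in layers like any other module.
model = nn.Sequential(nn.Linear(4, 3), ClampedAbs(max_value=1.0))
y = model(torch.randn(2, 4))
```

Unlike an anonymous function, such a module has a name and explicit parameters that appear when the model is printed.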

Non-trivial keyword arguments are not allowed:

>>> m = delu.nn.Lambda(torch.mul, other=torch.tensor(2.0))
Traceback (most recent call last):
    ...
ValueError: For kwargs, the allowed value types include: ...

__init__(fn: Callable[[...], Tensor], /, **kwargs) → None
Parameters:
  • fn – the callable.

  • kwargs – the keyword arguments for fn. The allowed value types include: None, bool, int, float, bytes, str and (nested) tuples of these simple types.
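
The rule for kwargs can be illustrated with a short recursive check. This is only a sketch of the rule as stated above, not the code that delu actually uses; the helper names is_simple_value and check_kwargs are hypothetical, and details such as the exact error message are assumptions.

```python
from typing import Any

# Hypothetical sketch of the "simple types" rule; not delu's implementation.
_SIMPLE_TYPES = (bool, int, float, bytes, str)


def is_simple_value(value: Any) -> bool:
    """Return True for None, simple scalars, and (nested) tuples of them."""
    if value is None or isinstance(value, _SIMPLE_TYPES):
        return True
    if isinstance(value, tuple):
        # Tuples are allowed only if every item is itself allowed.
        return all(is_simple_value(item) for item in value)
    return False


def check_kwargs(**kwargs: Any) -> None:
    """Raise ValueError if any keyword argument has a non-simple value."""
    for name, value in kwargs.items():
        if not is_simple_value(value):
            raise ValueError(
                f"kwarg {name!r} has an unsupported type: {type(value).__name__}"
            )
```

Under this sketch, check_kwargs(dim=1) and check_kwargs(dim=(0, (1, 2))) pass, while a list or tensor value raises ValueError.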

forward(x: Tensor) → Tensor

Do the forward pass.