Lambda
class delu.nn.Lambda
Bases: torch.nn.Module
A wrapper for functions from torch and methods of torch.Tensor.

An important “feature” of this module is that it is intentionally limited:
- Only the functions from the torch module and the methods of torch.Tensor are allowed.
- The passed callable must accept a single torch.Tensor and return a single torch.Tensor.
- The allowed keyword arguments must be of simple types (see the docstring).
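The first rule can be illustrated with a small membership check. The sketch below is hypothetical and is not delu's implementation: the name is_torch_callable is made up, and it assumes that looking a callable up by its __name__ in torch and torch.Tensor, then comparing object identity, is a good enough test.

```python
import torch


def is_torch_callable(fn) -> bool:
    """Hypothetical check (not delu's implementation): True if `fn` is a
    function from the `torch` module or a method of `torch.Tensor`."""
    if not callable(fn):
        return False
    name = getattr(fn, "__name__", None)
    if not isinstance(name, str):
        return False
    # Look the callable up by name and require the very same object, so a
    # user-defined function that merely shares a name (e.g. "abs") is rejected.
    return getattr(torch, name, None) is fn or getattr(torch.Tensor, name, None) is fn


print(is_torch_callable(torch.squeeze))           # True
print(is_torch_callable(torch.Tensor.abs_))       # True
print(is_torch_callable(lambda x: torch.abs(x)))  # False
```

Name-based lookup is a simplification: a torch function whose __name__ differs from the attribute it is exposed under would be rejected by this sketch.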
Usage
>>> import delu
>>> import torch
>>> m = delu.nn.Lambda(torch.squeeze, dim=1)
>>> m(torch.randn(2, 1, 3, 1)).shape
torch.Size([2, 3, 1])
>>> m = delu.nn.Lambda(torch.Tensor.abs_)
>>> m(torch.tensor(-1.0))
tensor(1.)
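Because the module stores fn together with its keyword arguments, a plain tensor operation can be used as a step of torch.nn.Sequential. The sketch below uses a hypothetical ForwardingLambda class that only mirrors the forwarding behavior seen in the examples above, fn(x, **kwargs), and skips all of the validation that delu.nn.Lambda performs.

```python
import torch


class ForwardingLambda(torch.nn.Module):
    """Hypothetical stand-in: stores `fn` and `kwargs` and applies them in
    forward(). Unlike delu.nn.Lambda, it performs no validation."""

    def __init__(self, fn, **kwargs):
        super().__init__()
        self.fn = fn
        self.kwargs = kwargs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The single input tensor is passed positionally; the stored
        # keyword arguments are forwarded unchanged.
        return self.fn(x, **self.kwargs)


# A tensor operation used as one step of a Sequential model.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    ForwardingLambda(torch.squeeze, dim=-1),  # (batch, 1) -> (batch,)
)
print(model(torch.randn(4, 3)).shape)  # torch.Size([4])
```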
Custom functions are not allowed:
>>> m = delu.nn.Lambda(lambda x: torch.abs(x))
Traceback (most recent call last):
...
ValueError: fn must be a function from `torch` or a method of `torch.Tensor`, but ...
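For the example above, passing torch.abs itself is accepted, since it is a function from torch. Logic that combines several operations cannot be expressed as a single allowed callable; one option, outside of delu, is a small dedicated torch.nn.Module. The SquaredReLU class below is a hypothetical illustration of that approach.

```python
import torch


class SquaredReLU(torch.nn.Module):
    """Hypothetical module for composite logic that would otherwise be
    written as `lambda x: torch.relu(x) ** 2`."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Two operations chained together: ReLU, then elementwise square.
        return torch.relu(x) ** 2


m = SquaredReLU()
print(m(torch.tensor([-2.0, 0.0, 3.0])))  # tensor([0., 0., 9.])
```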
Non-trivial keyword arguments are not allowed:
>>> m = delu.nn.Lambda(torch.mul, other=torch.tensor(2.0))
Traceback (most recent call last):
...
ValueError: For kwargs, the allowed value types include: ...
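The kwargs restriction can be sketched as a recursive type check. The helper names is_simple_value and validate_kwargs are hypothetical, and the set of allowed types used here (None, bool, int, float, str, bytes and nested tuples of them) is an assumption made for illustration; the authoritative list is in the Lambda docstring. A list stands in for the rejected tensor so that the sketch runs without torch.

```python
from typing import Any

# Assumption: illustrative set of "simple" types, not necessarily delu's list.
SIMPLE_TYPES = (bool, int, float, str, bytes)


def is_simple_value(value: Any) -> bool:
    """True if `value` is None, a simple scalar, or a (nested) tuple of such values."""
    if value is None or isinstance(value, SIMPLE_TYPES):
        return True
    if isinstance(value, tuple):
        return all(is_simple_value(v) for v in value)
    return False


def validate_kwargs(**kwargs: Any) -> None:
    """Raise ValueError for the first keyword argument that is not simple."""
    for name, value in kwargs.items():
        if not is_simple_value(value):
            raise ValueError(
                f"keyword argument {name!r} has unsupported type {type(value).__name__}"
            )


validate_kwargs(dim=1, keepdim=True)  # accepted
try:
    validate_kwargs(other=[2.0])  # a list is not a simple value here
except ValueError as err:
    print(err)  # keyword argument 'other' has unsupported type list
```

Under this assumed set, a plain Python number such as other=2.0 would pass, while a tensor, like the list above, would not.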