Lambda
class delu.nn.Lambda
Bases: Module

A wrapper for functions from torch and methods of torch.Tensor.

An important “feature” of this module is that it is intentionally limited:

- Only the functions from the torch module and the methods of torch.Tensor are allowed.
- The passed callable must accept a single torch.Tensor and return a single torch.Tensor.
- The allowed keyword arguments must be of simple types (see the docstring).
Usage
>>> m = delu.nn.Lambda(torch.squeeze, dim=1)
>>> m(torch.randn(2, 1, 3, 1)).shape
torch.Size([2, 3, 1])
>>> m = delu.nn.Lambda(torch.Tensor.abs_)
>>> m(torch.tensor(-1.0))
tensor(1.)
Custom functions are not allowed:
>>> m = delu.nn.Lambda(lambda x: torch.abs(x))
Traceback (most recent call last):
    ...
ValueError: fn must be a function from `torch` or a method of `torch.Tensor`, but ...
Non-trivial keyword arguments are not allowed:
>>> m = delu.nn.Lambda(torch.mul, other=torch.tensor(2.0))
Traceback (most recent call last):
    ...
ValueError: For kwargs, the allowed value types include: ...
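
To make the "simple types" restriction more concrete, here is a rough pure-Python sketch of how such a check could look. It is not the library's implementation. The exact list of allowed types is given in the docstring, not on this page, so the sketch assumes that "simple" means None, bool, int, float, bytes, str and (nested) tuples of these. The names is_simple_value and check_kwargs are hypothetical and are not part of delu.

```python
def is_simple_value(value):
    """Return True for None, a basic built-in scalar, or a (nested) tuple of such values.

    ASSUMPTION: this set of types is illustrative; consult the docstring of
    delu.nn.Lambda for the authoritative list.
    """
    if value is None or isinstance(value, (bool, int, float, bytes, str)):
        return True
    if isinstance(value, tuple):
        return all(is_simple_value(item) for item in value)
    return False


def check_kwargs(**kwargs):
    """Raise ValueError for the first keyword argument that is not a simple value."""
    for name, value in kwargs.items():
        if not is_simple_value(value):
            raise ValueError(
                f"Unsupported value type for {name!r}: {type(value).__name__}"
            )


# Plain values and nested tuples pass silently.
check_kwargs(dim=1, dims=(0, (1, 2)), mode=None)

# A list stands in here for any non-simple value, such as a tensor.
try:
    check_kwargs(other=[2.0])
except ValueError as err:
    print(err)  # Unsupported value type for 'other': list
```

Under this reading, other=torch.tensor(2.0) is rejected because a tensor is neither one of the listed built-in types nor a tuple of them, while dim=1 is accepted.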