delu.nn.Lambda

class delu.nn.Lambda(fn)

A parameter-free module for wrapping callables.

Examples

import torch
import delu

module = delu.nn.Lambda(torch.square)
assert torch.equal(module(torch.tensor(3)), torch.tensor(9))

# Any callable can be wrapped in Lambda:
module = delu.nn.Lambda(lambda x, y, z: x + y + z)
assert module(1, 2, z=3) == 6

Parameters:

fn (Callable) – the callable to wrap.

__init__(fn)

Initialize the module with the callable to wrap.

Parameters:

fn (Callable) – the callable to wrap.

forward(*args, **kwargs)

Perform the forward pass: call fn with the given positional and keyword arguments and return its result.
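To illustrate what a parameter-free wrapper amounts to, here is a minimal sketch in plain PyTorch. `MinimalLambda` is a hypothetical stand-in written for this illustration, not DeLU's actual source. It assumes that `forward` simply passes all arguments through to `fn`, which matches the behavior in the examples above. The usage part shows the typical reason for such a wrapper: inlining a plain function as a step in `torch.nn.Sequential` without adding any parameters.

```python
from typing import Any, Callable

import torch
import torch.nn as nn


class MinimalLambda(nn.Module):
    """Hypothetical stand-in for delu.nn.Lambda: wraps a callable, owns no parameters."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        super().__init__()
        # A plain function or lambda is stored as an ordinary attribute,
        # so it is not registered as a parameter.
        self.fn = fn

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Pass all positional and keyword arguments through unchanged.
        return self.fn(*args, **kwargs)


# Usage: parameter-free steps between layers of nn.Sequential.
model = nn.Sequential(
    nn.Linear(4, 3),
    MinimalLambda(torch.relu),
    MinimalLambda(lambda x: x.sum(dim=1)),
)
x = torch.randn(2, 4)
y = model(x)
print(y.shape)  # torch.Size([2])

# The wrapper itself contributes no trainable parameters.
print(sum(p.numel() for p in MinimalLambda(torch.square).parameters()))  # 0
```

Only the `nn.Linear` layer contributes parameters here (its weight and bias); the wrapped functions add none.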