Source code for zero.nn
"""Missing batteries from `torch.nn`."""
from typing import Callable
import torch.nn
class Lambda(torch.nn.Module):
    """A parameters-free module for wrapping callables.

    The wrapped callable receives exactly the positional and keyword arguments
    passed to the module, and its result is returned unchanged.

    Examples:
        .. testcode::

            assert zero.nn.Lambda(lambda: 0)() == 0
            assert zero.nn.Lambda(lambda x: x)(1) == 1
            assert zero.nn.Lambda(lambda x, y, z: x + y + z)(1, 2, z=3) == 6
    """

    def __init__(self, fn: Callable):
        """Initialize self.

        Args:
            fn: the callable to invoke on every forward pass.
        """
        super().__init__()
        # If ``fn`` is itself a Module, this assignment registers it as a child.
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Perform the forward pass.

        Args:
            *args: positional arguments forwarded to ``fn``.
            **kwargs: keyword arguments forwarded to ``fn``.
        Returns:
            Whatever ``fn(*args, **kwargs)`` returns.
        """
        return self.fn(*args, **kwargs)

    def extra_repr(self) -> str:
        """Describe the wrapped callable in the module's repr."""
        if isinstance(self.fn, torch.nn.Module):
            # Already printed as a registered child module; avoid duplication.
            return ''
        # ``functools.partial`` and callable objects may lack ``__qualname__``.
        return f'fn={getattr(self.fn, "__qualname__", repr(self.fn))}'