NLinear#

class delu.nn.NLinear[source]#

Bases: Module

N linear layers for N inputs: (*, *N, D1) -> (*, *N, D2)

For a tensor x of the shape (*B, *N, D1), where *B are the batch dimensions, *N are the object dimensions (e.g. the sequence length in NLP, or width & height in computer vision) and D1 is the current embedding size (e.g. the number of features/channels):

  • applying torch.nn.Linear(D1, D2) to x means applying the same linear transformation to each of the math.prod(N) embeddings.

  • applying NLinear(N, D1, D2) to x means applying a separate linear transformation to each of the math.prod(N) embeddings.

In other words, NLinear(N, D1, D2) is a collection of math.prod(N) non-shared torch.nn.Linear(D1, D2) layers.
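
To make this concrete, below is a minimal, intentionally naive sketch of a module with the same semantics, built from an explicit collection of torch.nn.Linear layers. The name NaiveNLinear and the loop-based implementation are illustrative only and are not how the library implements NLinear:

import math
from typing import Tuple, Union

import torch


class NaiveNLinear(torch.nn.Module):
    """Same semantics as delu.nn.NLinear, written as an explicit
    collection of math.prod(n) separate torch.nn.Linear layers.
    """

    def __init__(
        self, n: Union[int, Tuple[int, ...]], in_features: int, out_features: int
    ) -> None:
        super().__init__()
        self.n = (n,) if isinstance(n, int) else tuple(n)
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(in_features, out_features)
            for _ in range(math.prod(self.n))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (*B, *N, D1) -> (*B, prod(N), D1): flatten the object dimensions.
        batch_shape = x.shape[: x.ndim - len(self.n) - 1]
        flat = x.reshape(*batch_shape, -1, x.shape[-1])
        # Apply the i-th layer to the i-th embedding: each embedding gets
        # its own, non-shared linear transformation.
        out = torch.stack(
            [layer(flat[..., i, :]) for i, layer in enumerate(self.layers)],
            dim=-2,
        )
        # (*B, prod(N), D2) -> (*B, *N, D2): restore the object dimensions.
        return out.reshape(*batch_shape, *self.n, out.shape[-1])

A vectorized implementation can replace the Python loop with a single batched matrix multiplication; the loop above is kept only for clarity.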

Shape

  • Input: (*, *n, in_features), where * denotes any number of batch dimensions and *n matches the n argument of the constructor.

  • Output: (*, *n, out_features).

Usage

Let’s consider a Transformer-like model that outputs tensors of the shape (batch_size, n_tokens, d_embedding) (in terms of NLP, n_tokens is the sequence length). The following example demonstrates how to train a separate linear transformation for each of the n_tokens embeddings using NLinear.

>>> import torch
>>> from delu.nn import NLinear
>>>
>>> batch_size = 2
>>> n_tokens = 3
>>> d_embedding_in = 4
>>> d_embedding_out = 5
>>> x = torch.randn(batch_size, n_tokens, d_embedding_in)
>>> x.shape
torch.Size([2, 3, 4])
>>> m = NLinear(n_tokens, d_embedding_in, d_embedding_out)
>>> m(x).shape
torch.Size([2, 3, 5])

Similarly to torch.nn.Linear, the input can have any number of batch dimensions. The layout n, in turn, can also span multiple dimensions.

>>> # Computer vision.
>>> batch_size = (2, 3)
>>> width = 4
>>> height = 5
>>> in_channels = 6
>>> out_channels = 7
>>> x = torch.randn(*batch_size, width, height, in_channels)
>>> x.shape
torch.Size([2, 3, 4, 5, 6])
>>> # The number of layers: width * height = 4 * 5 = 20
>>> m = NLinear((width, height), in_channels, out_channels)
>>> m(x).shape
torch.Size([2, 3, 4, 5, 7])
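
As a sanity check, the module holds exactly as many parameters as width * height = 20 independent linear layers would (assuming the usual weight-plus-bias parametrization of each layer):

>>> # 20 layers, each with a (6 -> 7) weight matrix and a bias vector:
>>> sum(p.numel() for p in m.parameters()) == 20 * (6 * 7 + 7)
True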

__init__(
    n: int | Tuple[int, ...],
    in_features: int,
    out_features: int,
    bias: bool = True,
    device=None,
    dtype=None,
) → None[source]#

All arguments are the same as in torch.nn.Linear except for n, which is the expected layout of the input (see the examples in NLinear).
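
For illustration, and consistently with the usage examples above, an integer n behaves like a one-element tuple (a minimal sketch, not an exhaustive specification):

>>> NLinear(3, 4, 5)(torch.randn(2, 3, 4)).shape
torch.Size([2, 3, 5])
>>> NLinear((3,), 4, 5)(torch.randn(2, 3, 4)).shape
torch.Size([2, 3, 5])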

forward(x: Tensor) → Tensor[source]#

Do the forward pass.

reset_parameters()[source]#

Reset all parameters.
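
A minimal illustration (re-initializes the module in-place, e.g. before restarting training):

>>> m = NLinear(3, 4, 5)
>>> m.reset_parameters()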