NLinear

class delu.nn.NLinear[source]

Bases: Module

N separate linear layers for N embeddings: (*, *N, D1) -> (*, *N, D2).

The usage examples below cover:

  • (NLP) Training a separate linear layer for each token embedding in a sequence. By contrast, using torch.nn.Linear would mean applying the same linear layer to all tokens.

  • (CV) Training a separate linear layer for each patch embedding in an image. By contrast, using torch.nn.Linear would mean applying the same linear layer to all patches.

Technically, NLinear(N, D1, D2) is a collection of N independent linear layers torch.nn.Linear(D1, D2), arranged in the layout N.
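Conceptually, this can be sketched in plain PyTorch as a per-position batched matrix multiplication. The function below is a simplified illustration of the computation, not the library's actual implementation; the weight layout shown (one matrix and bias per position) is an assumption that matches the "N independent linear layers" description:

```python
import torch


def n_linear_sketch(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    # weight: (*n, in_features, out_features), bias: (*n, out_features).
    # x: (*batch, *n, in_features) -> (*batch, *n, out_features).
    # Each position in *n is multiplied by its own weight matrix;
    # broadcasting over the leading batch dimensions does the rest.
    return (x.unsqueeze(-2) @ weight).squeeze(-2) + bias
```

For a single position i, this reduces to the ordinary affine map `x[..., i, :] @ weight[i] + bias[i]`, i.e. one torch.nn.Linear per position.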

Shape

  • Input: (*, *n, in_features), where the leading * denotes batch dimensions.

    Similarly to torch.nn.Linear, the batch dimensions can be arbitrarily complex.

  • Output: (*, *n, out_features).

Usage

(NLP) Training a separate linear layer for each of the token embeddings in a sequence:

>>> import torch
>>> from delu.nn import NLinear
>>> batch_size = 2
>>> sequence_length = 4
>>> d_embedding_in = 6
>>> d_embedding_out = 7
>>> x = torch.randn(batch_size, sequence_length, d_embedding_in)
>>> x.shape
torch.Size([2, 4, 6])
>>> m = NLinear(sequence_length, d_embedding_in, d_embedding_out)
>>> m(x).shape
torch.Size([2, 4, 7])
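The contrast with torch.nn.Linear can be demonstrated directly: a shared layer cannot distinguish positions, so identical embeddings at two different positions always map to identical outputs (with NLinear, the two positions would use different weights and generally produce different outputs). A minimal demonstration of the shared-weights side, using only plain PyTorch:

```python
import torch

# A single shared linear layer, as torch.nn.Linear would apply it
# to every token position in the sequence.
shared = torch.nn.Linear(6, 7)
x = torch.randn(2, 4, 6)
x[:, 1] = x[:, 0]  # copy token 0's embedding into position 1
y = shared(x)
# Positions 0 and 1 now receive identical inputs, and the shared
# layer necessarily produces identical outputs for them.
```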

(CV) Training a separate linear layer (i.e. Conv1x1) for each of the spatial embeddings (patch embeddings, pixel embeddings, etc.) of an image (total count: width x height):

>>> # Batch dimensions can be arbitrarily complex (same as for torch.nn.Linear).
>>> batch_size = 2
>>> width = 4
>>> height = 5
>>> in_channels = 6
>>> out_channels = 7
>>> x = torch.randn(batch_size, width, height, in_channels)
>>> x.shape
torch.Size([2, 4, 5, 6])
>>> m = NLinear((width, height), in_channels, out_channels)
>>> m(x).shape
torch.Size([2, 4, 5, 7])
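Since every position owns a full linear layer, the parameter count grows with the spatial layout. For the example above, assuming the standard torch.nn.Linear parameterization (weight plus bias) per position:

```python
# Each of the width * height positions owns an independent linear layer.
width, height, in_channels, out_channels = 4, 5, 6, 7
n_positions = width * height                                  # 20 layers
params_per_layer = in_channels * out_channels + out_channels  # weight + bias
total_params = n_positions * params_per_layer                 # 20 * 49 = 980
```

By comparison, a single shared torch.nn.Linear(6, 7) would have only 49 parameters regardless of the spatial layout.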

__init__(
    n: int | tuple[int, ...],
    in_features: int,
    out_features: int,
    bias: bool = True,
    *,
    device=None,
    dtype=None,
) -> None[source]

n is the expected layout of the input (see the examples in NLinear). All other arguments are the same as in torch.nn.Linear.

forward(x: Tensor) -> Tensor[source]

Do the forward pass.

reset_parameters()[source]

Reset all parameters.