MLP Channel Mixer

discretax.channel_mixers.mlp.MLPChannelMixer

MLP channel mixer.

This channel mixer applies a multi-layer perceptron (MLP) to the input tensor: a linear layer followed by a non-linearity.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `linear` | | Linear layer applied to the input. |
| `non_linearity` | | The non-linearity function used after the linear layer. |

__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int | None = None, non_linearity: activation = 'gelu', use_bias: bool = False, **kwargs)

Initialize the MLP channel mixer.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `in_features` | `int` | The input dimensionality. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `out_features` | `int \| None` | Optional output dimensionality. If `None`, defaults to `in_features`. | `None` |
| `non_linearity` | `Literal` | Name of the activation function to apply after the linear layer. | `'gelu'` |
| `use_bias` | `bool` | Whether to include a bias term in the linear layer. | `False` |
| `*args` | | Additional positional arguments (ignored). | required |
| `**kwargs` | | Additional keyword arguments (ignored). | required |

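The signature above has one consequence that is easy to miss: because `*args` comes before `out_features`, the parameters `out_features`, `non_linearity`, and `use_bias` can only be passed by keyword, and any extra positional value is swallowed by `*args`. The sketch below illustrates this with a hypothetical stand-in function, `resolve_mixer_args`, which copies the documented signature but is not part of discretax. The `ACTIVATIONS` lookup table is also an assumption; only the name `'gelu'` is confirmed by the documented default.

```python
import jax

# Assumed name-to-function table; only 'gelu' is confirmed by the documented default.
ACTIVATIONS = {"gelu": jax.nn.gelu, "relu": jax.nn.relu}


def resolve_mixer_args(in_features, key, *args, out_features=None,
                       non_linearity="gelu", use_bias=False, **kwargs):
    """Hypothetical stand-in that mirrors the documented __init__ signature."""
    # Extra positional and keyword arguments are accepted but unused.
    del args, kwargs
    # Documented fallback: out_features defaults to in_features.
    if out_features is None:
        out_features = in_features
    return {
        "in_features": in_features,
        "out_features": out_features,
        "non_linearity": ACTIVATIONS[non_linearity],
        "use_bias": use_bias,
    }


key = jax.random.PRNGKey(0)

cfg_default = resolve_mixer_args(8, key)
print(cfg_default["out_features"])  # 8

# A third positional value lands in *args and is ignored; it does not set out_features.
cfg_positional = resolve_mixer_args(8, key, 32)
print(cfg_positional["out_features"])  # 8

# Keyword form sets out_features; the unknown keyword is ignored.
cfg_keyword = resolve_mixer_args(8, key, out_features=32, some_unused_option=True)
print(cfg_keyword["out_features"])  # 32
```

In practice this means a call such as `MLPChannelMixer(8, key, 32)` would, going by the signature alone, leave `out_features` at its default rather than set it to 32.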
__call__(x: Array) -> Array

Forward pass of the MLP channel mixer.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Array` | Input tensor. | required |

Returns:

| Type | Description |
|------|-------------|
| `Array` | Output tensor. |
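
Going by the attribute descriptions, the forward pass amounts to a linear layer followed by the non-linearity. The following is a minimal sketch of that computation in plain JAX, not the library's implementation. The helper names `init_linear_params` and `apply_channel_mixer`, the uniform initialization scheme, and the `(sequence_length, in_features)` input layout are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_linear_params(in_features, key, out_features=None, use_bias=False):
    """Hypothetical helper: parameters for a single linear layer."""
    # Documented fallback: out_features defaults to in_features.
    out_features = in_features if out_features is None else out_features
    w_key, b_key = jax.random.split(key)
    # Assumed init: uniform in +/- 1/sqrt(in_features); the library may differ.
    bound = in_features ** -0.5
    weight = jax.random.uniform(
        w_key, (out_features, in_features), minval=-bound, maxval=bound
    )
    bias = None
    if use_bias:
        bias = jax.random.uniform(b_key, (out_features,), minval=-bound, maxval=bound)
    return {"weight": weight, "bias": bias}


def apply_channel_mixer(params, x):
    """Linear map over the last (channel) axis, then GELU."""
    y = x @ params["weight"].T
    if params["bias"] is not None:
        y = y + params["bias"]
    return jax.nn.gelu(y)


params = init_linear_params(in_features=8, key=jax.random.PRNGKey(0))
x = jnp.ones((16, 8))  # assumed layout: (sequence_length, in_features)
y = apply_channel_mixer(params, x)
print(y.shape)  # (16, 8)
```

With the default `out_features=None`, the output keeps the input's channel dimension, so the mixer can be dropped between layers without changing tensor shapes.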