
MLP Channel Mixer

linax.channel_mixers.mlp.MLPChannelMixerConfig

Configuration for the MLP channel mixer.

Attributes:

| Name | Type | Description |
|---|---|---|
| `non_linearity` | `str` | Name of the activation function applied after the linear layer. |
| `use_bias` | `bool` | Whether the linear layer includes a bias term. |
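The config stores the activation as a name rather than as a function. One plausible way to turn such a name into a callable is to look it up in `jax.nn`. The helper `resolve_activation` below is hypothetical and not part of linax's documented API, and it is an assumption that linax resolves names this way.

```python
import jax
import jax.numpy as jnp


def resolve_activation(name: str):
    """Hypothetical helper: map a name such as "gelu" or "relu" to a jax.nn function."""
    fn = getattr(jax.nn, name, None)
    # Reject unknown names and non-callable attributes (e.g. the jax.nn.initializers submodule).
    if fn is None or not callable(fn):
        raise ValueError(f"Unknown activation function: {name!r}")
    return fn


relu = resolve_activation("relu")
print(relu(jnp.array([-1.0, 2.0])))
```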

`build(in_features: int, out_features: int | None, key: PRNGKeyArray) -> MLPChannelMixer`

Build an `MLPChannelMixer` from this config.

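Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | required |
| `out_features` | `int \| None` | Output dimensionality, or `None` to use `in_features`. The argument has no default, so it must always be passed (possibly as `None`). | required |
| `key` | `PRNGKeyArray` | JAX random key used for initialization. | required |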

Returns:

| Type | Description |
|---|---|
| `MLPChannelMixer` | The constructed `MLPChannelMixer` instance. |
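The documented behaviour of `build` (fall back to `in_features` when `out_features` is `None`, and draw the initial parameters from `key`) can be sketched in plain JAX. `SketchMLPConfig`, its field defaults, and the uniform initialization in `[-1/sqrt(in_features), 1/sqrt(in_features))` are all assumptions for illustration. The real `build` returns an `MLPChannelMixer`, whereas this sketch returns a plain dict of parameters.

```python
from __future__ import annotations

import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class SketchMLPConfig:
    """Hypothetical stand-in for MLPChannelMixerConfig; the defaults are assumptions."""

    non_linearity: str = "gelu"
    use_bias: bool = True

    def build(self, in_features: int, out_features: int | None, key) -> dict:
        # Mirror the documented rule: None falls back to in_features.
        out = in_features if out_features is None else out_features
        w_key, b_key = jax.random.split(key)
        # Assumed init scheme: uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)).
        lim = 1.0 / math.sqrt(in_features)
        weight = jax.random.uniform(w_key, (out, in_features), minval=-lim, maxval=lim)
        bias = (
            jax.random.uniform(b_key, (out,), minval=-lim, maxval=lim)
            if self.use_bias
            else None
        )
        return {"weight": weight, "bias": bias, "act": getattr(jax.nn, self.non_linearity)}


params = SketchMLPConfig().build(4, None, jax.random.PRNGKey(0))
print(params["weight"].shape)
```

Splitting the key once per parameter keeps the weight and bias draws independent, which is the usual JAX idiom for initializing several arrays from a single `key` argument.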


linax.channel_mixers.mlp.MLPChannelMixer

MLP channel mixer.

This channel mixer applies a small multi-layer perceptron (MLP), consisting of a linear layer followed by a non-linearity, to the input tensor.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | required |
| `cfg` | `ConfigType` | Configuration for the MLP channel mixer (non-linearity and bias settings). | required |
| `key` | `PRNGKeyArray` | JAX random key used for initialization. | required |
| `out_features` | `int \| None` | Output dimensionality. If `None`, defaults to `in_features`. Keyword-only. | `None` |

Attributes:

| Name | Description |
|---|---|
| `linear` | Linear layer applied to the input. |
| `non_linearity` | Non-linearity applied to the output of the linear layer. |
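Reading the two attributes together, the forward computation can be written as follows. This assumes `linear` is an affine map with weight $W$ and optional bias $b$ (present when `use_bias` is true), and that $\sigma$ is the configured non-linearity:

$$
y = \sigma\left(W x + b\right), \qquad W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},\; b \in \mathbb{R}^{d_{\text{out}}},
$$

where $d_{\text{in}}$ is `in_features`, $d_{\text{out}}$ is `out_features` (or $d_{\text{in}}$ when it is `None`), and the $b$ term is dropped when `use_bias` is false.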

`__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray, *, out_features: int | None = None)`

Initialize the MLP channel mixer.

`__call__(x: Array) -> Array`

Forward pass of the MLP channel mixer.

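Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |

Returns:

| Type | Description |
|---|---|
| `Array` | Output tensor: the result of the linear layer followed by the non-linearity. |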
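To make the forward pass concrete, here is a minimal pure-JAX sketch of what `__call__` computes given the attributes above. The function `mlp_channel_mix`, the GELU default, and the assumption that the mixer maps a single feature vector of shape `(in_features,)` (so that a whole sequence is handled with `jax.vmap`) are illustrative assumptions, not documented linax behaviour.

```python
import jax
import jax.numpy as jnp


def mlp_channel_mix(weight, bias, x, act=jax.nn.gelu):
    """Hypothetical forward pass: act(W @ x + b) for one vector x of shape (in_features,)."""
    y = weight @ x
    if bias is not None:  # mirrors use_bias=False by passing bias=None
        y = y + bias
    return act(y)


in_features, out_features, seq_len = 3, 5, 7
w_key, x_key = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(w_key, (out_features, in_features))
bias = 0.1 * jnp.arange(out_features, dtype=jnp.float32)

# Assumption: channels are mixed independently at each position, so map over the sequence axis.
xs = jax.random.normal(x_key, (seq_len, in_features))
ys = jax.vmap(lambda x: mlp_channel_mix(weight, bias, x))(xs)
print(ys.shape)  # (7, 5)
```

If the real `__call__` already accepts sequence-shaped input, the `vmap` wrapper would be unnecessary.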