
Standard Block

linax.blocks.standard.StandardBlockConfig

Configuration for the Standard block.

Attributes:

| Name | Description |
|---|---|
| drop_rate | Dropout rate for the channel mixer. |
| prenorm | Whether to apply the normalization at the beginning of the block (pre-norm) instead of at the end (post-norm). |

build(in_features: int, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray) -> StandardBlock

Build block from config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input features per timestep. | required |
| sequence_mixer | SequenceMixer | The sequence mixer instance for this block. | required |
| channel_mixer | ChannelMixer | The channel mixer instance for this block. | required |
| key | PRNGKeyArray | JAX random key used to initialize the block's layers. | required |

Returns:

| Type | Description |
|---|---|
| StandardBlock | The Standard block instance. |
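The build signature is the block's __init__ signature without cfg, which suggests that build passes the config itself as cfg. The sketch below illustrates that config-then-build pattern with hypothetical ToyBlockConfig and ToyBlock classes; the field types (float, bool) and the forwarding behaviour are assumptions, not linax's code.

```python
from dataclasses import dataclass
from typing import Callable

import jax


class ToyBlock:
    """Hypothetical block whose constructor mirrors the documented __init__ order."""

    def __init__(self, in_features: int, cfg: "ToyBlockConfig",
                 sequence_mixer: Callable, channel_mixer: Callable, key: jax.Array):
        self.in_features = in_features
        # Hyperparameters are copied from the config onto the block.
        self.drop_rate = cfg.drop_rate
        self.prenorm = cfg.prenorm
        self.sequence_mixer = sequence_mixer
        self.channel_mixer = channel_mixer
        # This toy block has no randomly initialised weights, so `key` is unused.


@dataclass(frozen=True)
class ToyBlockConfig:
    """Hypothetical stand-in for StandardBlockConfig (field types are assumed)."""

    drop_rate: float = 0.1
    prenorm: bool = True

    def build(self, in_features: int, sequence_mixer: Callable,
              channel_mixer: Callable, key: jax.Array) -> ToyBlock:
        # Assumed behaviour: the config passes itself as `cfg` to the block.
        return ToyBlock(in_features, self, sequence_mixer, channel_mixer, key)


cfg = ToyBlockConfig(drop_rate=0.1, prenorm=True)
block = cfg.build(
    in_features=8,
    sequence_mixer=lambda x: x,  # placeholder mixers for illustration only
    channel_mixer=lambda x: x,
    key=jax.random.PRNGKey(0),
)
```

Keeping hyperparameters in a frozen config and supplying the mixers and key at build time lets one config produce many blocks, for example one per layer of a backbone, each with its own key.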


linax.blocks.standard.StandardBlock

A single block in the Standard backbone.

This block implements a sequence mixer, BatchNorm normalization, and a channel mixer.

Warning

This block uses BatchNorm for normalization. When training with jax.vmap, pass axis_name="batch" so that the mapped batch axis has the name BatchNorm expects. Example:

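```python
import jax

# x and key are mapped over the batch; state is shared (in_axes=None) and the
# updated state is returned once, unbatched (out_axes=None).
jax.vmap(model, in_axes=(0, None, 0), out_axes=(0, None), axis_name="batch")
```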

This ensures BatchNorm can properly compute batch statistics across the named axis.
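To see why the axis name matters, the pure-JAX sketch below (not linax or Equinox code) uses jax.lax.pmean, a collective of the kind a vmapped BatchNorm relies on, to subtract the batch mean from each example. The function name center is hypothetical.

```python
import jax
import jax.numpy as jnp


def center(x):
    # Inside vmap, x is one example of shape (hidden_dim,). pmean averages it
    # over every example mapped along the axis named "batch".
    batch_mean = jax.lax.pmean(x, axis_name="batch")
    return x - batch_mean


xs = jnp.array([[1.0, 2.0], [3.0, 6.0]])  # (batch=2, hidden_dim=2)

# Works: the mapped axis is named "batch", so the collective can find it.
centered = jax.vmap(center, axis_name="batch")(xs)  # [[-1., -2.], [1., 2.]]

# Fails: without axis_name the name "batch" is unbound and JAX raises an error.
try:
    jax.vmap(center)(xs)
    unnamed_ok = True
except Exception:  # JAX reports an unbound axis name
    unnamed_ok = False
```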

Attributes:

| Name | Description |
|---|---|
| norm | BatchNorm layer applied after the sequence mixer. |
| sequence_mixer | The sequence mixer, which mixes information across timesteps. |
| channel_mixer | The channel mixer, which mixes information across feature channels. |
| drop | Dropout layer applied after the channel mixer. |
| prenorm | Whether to apply the normalization at the beginning of the block (pre-norm) instead of at the end (post-norm). |
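The attributes above suggest the data flow sketched below. This is a pure-JAX toy, not linax's implementation: it assumes a residual connection around the mixers, uses a stateless normalization over batch and time in place of BatchNorm, omits dropout to stay deterministic, and places the normalization at the start (prenorm=True) or at the very end (prenorm=False) following the prenorm description. The names normalize, block_forward and causal_mean are hypothetical.

```python
import jax.numpy as jnp


def normalize(x, eps=1e-5):
    # Stateless stand-in for BatchNorm: per-feature statistics over batch and time.
    mean = x.mean(axis=(0, 1), keepdims=True)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def block_forward(x, sequence_mixer, channel_mixer, prenorm):
    """x: (batch, timesteps, hidden_dim). Residual and norm placement are assumed."""
    skip = x
    if prenorm:
        x = normalize(x)      # normalize first ("beginning of the block")
    x = sequence_mixer(x)     # mixes information along the time axis
    x = channel_mixer(x)      # mixes information across hidden features
    x = skip + x              # residual connection (assumption)
    if not prenorm:
        x = normalize(x)      # normalize last ("end of the block")
    return x


def causal_mean(x):
    # Toy sequence mixer: running average over the time axis (axis 1).
    steps = jnp.arange(1, x.shape[1] + 1)[None, :, None]
    return jnp.cumsum(x, axis=1) / steps


x = jnp.arange(24.0).reshape(2, 4, 3)
post = block_forward(x, causal_mean, jnp.tanh, prenorm=False)
pre = block_forward(x, causal_mean, jnp.tanh, prenorm=True)
```

With post-norm the block output itself is normalized; with pre-norm only the mixer input is, and the residual path carries the raw input through unchanged.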

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input features per timestep. | required |
| cfg | ConfigType | Configuration for the Standard block. | required |
| sequence_mixer | SequenceMixer | The sequence mixer instance for this block. | required |
| channel_mixer | ChannelMixer | The channel mixer instance for this block. | required |
| key | PRNGKeyArray | JAX random key used to initialize the block's layers. | required |
__init__(in_features: int, cfg: ConfigType, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray)

Initialize the Standard block.

__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]

Apply the Standard block to the input sequence.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (timesteps, hidden_dim). | required |
| state | State | Current state for stateful normalization layers. | required |
| key | PRNGKeyArray | JAX random key for dropout. | required |

Returns:

| Type | Description |
|---|---|
| tuple[Array, State] | Tuple containing the output tensor and updated state. |
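A minimal sketch of how a (timesteps, hidden_dim) input, a shared state and a per-example dropout key flow through a batched call with the (x, state, key) -> (output, state) signature. toy_block_call is a hypothetical stand-in, not linax code: its state is a plain running-mean array rather than eqx.nn.State, and its "normalization" only subtracts the batch mean. It maps x and key over the batch with in_axes=(0, None, 0) and sets out_axes=(0, None) so that the updated state comes back once, unbatched.

```python
import jax
import jax.numpy as jnp


def toy_block_call(x, state, key, drop_rate=0.5, momentum=0.9):
    """Hypothetical stand-in for StandardBlock.__call__ (not linax code).

    x: (timesteps, hidden_dim) for ONE example; state: running mean of shape
    (hidden_dim,), standing in for eqx.nn.State; key: PRNG key for dropout.
    """
    # Per-feature mean over time, then averaged over the batch via the named axis.
    batch_mean = jax.lax.pmean(x.mean(axis=0), axis_name="batch")
    new_state = momentum * state + (1.0 - momentum) * batch_mean
    y = x - batch_mean
    # Inverted dropout: keep each entry with prob 1 - drop_rate and rescale it.
    keep = jax.random.bernoulli(key, 1.0 - drop_rate, y.shape)
    y = jnp.where(keep, y / (1.0 - drop_rate), 0.0)
    return y, new_state


batch, timesteps, hidden_dim = 4, 8, 3
xs = jax.random.normal(jax.random.PRNGKey(1), (batch, timesteps, hidden_dim))
state = jnp.zeros((hidden_dim,))
# One dropout key per example, mapped with in_axes=0 alongside x.
keys = jax.random.split(jax.random.PRNGKey(0), batch)

batched_call = jax.vmap(
    toy_block_call, in_axes=(0, None, 0), out_axes=(0, None), axis_name="batch"
)
ys, state = batched_call(xs, state, keys)  # ys: (4, 8, 3), state: (3,)
```

Splitting the key per example gives every sequence its own dropout mask, while the returned state is fed into the next call unchanged in shape.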