Standard Block

discretax.blocks.standard.StandardBlock

A single block in the Standard backbone.

This block combines a sequence mixer, a BatchNorm layer, and a channel mixer.

Warning

This block uses BatchNorm for normalization. When batching the model with jax.vmap, name the batch axis "batch" so that BatchNorm can find it. Example:

# Correct usage with axis naming
jax.vmap(model, axis_name="batch")
# or
jax.vmap(model, in_axes=(0, None, 0), axis_name="batch")

This ensures BatchNorm can properly compute batch statistics across the named axis.
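To illustrate why the axis name matters, here is a minimal sketch in plain JAX. It does not use StandardBlock or its BatchNorm layer; `center_over_batch` is a hypothetical stand-in that averages across the named axis with `jax.lax.pmean`, which is the kind of cross-example reduction a batch statistic needs.

```python
import jax
import jax.numpy as jnp


def center_over_batch(x):
    # x is a single example of shape (timesteps, hidden_dim).
    per_example_mean = jnp.mean(x, axis=0)  # (hidden_dim,)
    # Average the per-example means across the vmapped axis named "batch".
    batch_mean = jax.lax.pmean(per_example_mean, axis_name="batch")
    return x - batch_mean


# 4 examples, 3 timesteps, 2 features.
xs = jnp.arange(24.0).reshape(4, 3, 2)

# The name passed to vmap must match the name used inside the function.
centered = jax.vmap(center_over_batch, axis_name="batch")(xs)
```

If `axis_name="batch"` is left out of the `jax.vmap` call, the collective inside the function has no axis to reduce over and the call fails.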

Attributes:

Name Type Description
norm

BatchNorm layer applied after the sequence mixer.

sequence_mixer

The sequence mixing mechanism for sequence processing.

channel_mixer

The channel mixing mechanism for channel processing.

drop

Dropout layer applied after the channel mixer.

prenorm

Whether to apply the normalization at the beginning or the end of the block.
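The effect of the prenorm flag can be sketched as follows. This is only one possible reading of the descriptions above, not the actual StandardBlock implementation: the exact ordering, any residual connection, dropout, and state handling are not specified here and are left out. `block_sketch` and the three stand-in callables are hypothetical.

```python
import jax.numpy as jnp


def block_sketch(x, norm, sequence_mixer, channel_mixer, prenorm=True):
    # Assumed ordering: normalize first when prenorm is True ...
    if prenorm:
        x = norm(x)
    x = sequence_mixer(x)
    x = channel_mixer(x)
    # ... otherwise normalize at the end of the block.
    if not prenorm:
        x = norm(x)
    return x


# Toy stand-ins for the real layers (assumptions for illustration only).
norm = lambda x: x - x.mean(axis=0)               # center each feature over time
sequence_mixer = lambda x: jnp.cumsum(x, axis=0)  # mixes along the time axis
channel_mixer = lambda x: 2.0 * x                 # acts on features only

x = jnp.arange(12.0).reshape(4, 3)  # (timesteps, hidden_dim)
pre = block_sketch(x, norm, sequence_mixer, channel_mixer, prenorm=True)
post = block_sketch(x, norm, sequence_mixer, channel_mixer, prenorm=False)
```

In this sketch only the post-norm output is guaranteed to be centered, because normalization is the last operation applied.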

__init__(in_features: int, key: PRNGKeyArray, *args, sequence_mixer: Resolvable[AbstractSequenceMixer], channel_mixer: Resolvable[AbstractChannelMixer], drop_rate: float = 0.1, prenorm: bool = True, **kwargs)

Initialize the Standard block.

Parameters:

Name Type Description Default
in_features int

Number of input features.

required
key PRNGKeyArray

JAX random key for initialization of layers.

required
sequence_mixer Resolvable[AbstractSequenceMixer]

The sequence mixer instance for this block.

required
channel_mixer Resolvable[AbstractChannelMixer]

The channel mixer instance for this block.

required
drop_rate float

Dropout rate for the channel mixer.

0.1
prenorm bool

Whether to apply the normalization at the beginning or the end of the block.

True
*args

Additional positional arguments (ignored).

()
**kwargs

Additional keyword arguments (ignored).

{}
__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]

Apply the Standard block to the input sequence.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (timesteps, hidden_dim).

required
state State

Current state for stateful normalization layers.

required
key PRNGKeyArray

JAX random key for dropout operations.

required

Returns:

Type Description
tuple[Array, State]

Tuple containing the output tensor and updated state.
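The call signature (x, state, key) -> (output, state) pairs naturally with the `in_axes=(0, None, 0)` pattern from the warning above: inputs and keys are batched, while the state is shared. The sketch below shows that pattern in plain JAX under loud assumptions: `toy_block` is a hypothetical stand-in, not StandardBlock, and its state is a plain array rather than an `eqx.nn.State`.

```python
import jax
import jax.numpy as jnp


def toy_block(x, state, key):
    """Hypothetical stand-in with the same call shape as StandardBlock.__call__."""
    # x: (timesteps, hidden_dim) for a single example.
    batch_mean = jax.lax.pmean(jnp.mean(x, axis=0), axis_name="batch")
    # Running-mean update, loosely imitating a stateful normalization layer.
    new_state = 0.9 * state + 0.1 * batch_mean
    # Inverted dropout with keep probability 0.9, driven by the per-example key.
    keep = jax.random.bernoulli(key, 0.9, x.shape)
    out = jnp.where(keep, (x - batch_mean) / 0.9, 0.0)
    return out, new_state


batch_size, timesteps, hidden_dim = 4, 5, 3
xs = jax.random.normal(jax.random.PRNGKey(1), (batch_size, timesteps, hidden_dim))
state = jnp.zeros((hidden_dim,))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)  # one key per example

# Batch x and key along axis 0; share the state (None) on the way in and out.
batched_block = jax.vmap(
    toy_block, in_axes=(0, None, 0), out_axes=(0, None), axis_name="batch"
)
ys, new_state = batched_block(xs, state, keys)
```

The output keeps its batch dimension, while the returned state stays unbatched because it depends only on statistics reduced over the named axis.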