Standard Block
linax.blocks.standard.StandardBlockConfig
Configuration for the Standard block.
Attributes:
| Name | Type | Description |
|---|---|---|
| `drop_rate` | | Dropout rate for the channel mixer. |
| `prenorm` | | If `True`, apply the normalization at the beginning of the block; otherwise apply it at the end. |
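`drop_rate` is the probability of zeroing an activation. The sketch below shows the common inverted-dropout scheme in plain JAX, assuming `0 <= drop_rate < 1`. The `dropout` helper and its `inference` flag are hypothetical and not part of linax, and whether the block's dropout layer behaves exactly this way is an assumption.

```python
import jax
import jax.numpy as jnp


def dropout(x: jax.Array, drop_rate: float, key: jax.Array, inference: bool = False) -> jax.Array:
    """Hypothetical inverted dropout: zero each element with probability `drop_rate`."""
    if inference or drop_rate == 0.0:
        # No-op at evaluation time or when dropout is disabled.
        return x
    keep_prob = 1.0 - drop_rate
    # Bernoulli mask: True with probability keep_prob.
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Rescale kept activations so the expected value matches the input.
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.ones((4, 8))  # (timesteps, hidden_dim)
y = dropout(x, drop_rate=0.5, key=jax.random.PRNGKey(0))
```

Because kept values are divided by `1 - drop_rate`, no rescaling is needed at inference time; every entry of `y` above is either `0.0` or `2.0`.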
build(in_features: int, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray) -> StandardBlock
Build a `StandardBlock` from this config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Number of input features. | required |
| `sequence_mixer` | `SequenceMixer` | The sequence mixer instance for this block. | required |
| `channel_mixer` | `ChannelMixer` | The channel mixer instance for this block. | required |
| `key` | `PRNGKeyArray` | JAX random key for initializing the block's layers. | required |
Returns:
| Type | Description |
|---|---|
| `StandardBlock` | The Standard block instance. |
linax.blocks.standard.StandardBlock
A single block in the Standard backbone.
This block combines a sequence mixer, BatchNorm normalization, a channel mixer, and dropout.
Warning
This block uses BatchNorm for normalization. When training with `jax.vmap`, name the batch axis `"batch"`. For example:

```python
import jax

# Correct usage with axis naming
jax.vmap(model, axis_name="batch")
# or: batch x and key, but share a single state across the batch
jax.vmap(model, in_axes=(0, None, 0), out_axes=(0, None), axis_name="batch")
```
This ensures BatchNorm can properly compute batch statistics across the named axis.
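To see why the axis name matters, here is a minimal sketch of how batch statistics can be pooled across a named `vmap` axis with `jax.lax.pmean`. The function `batch_norm_train` is hypothetical (no learnable scale or shift, no running averages) and is not the BatchNorm implementation used by this block; it only illustrates that a collective over `"batch"` requires the mapped axis to carry that name.

```python
import jax
import jax.numpy as jnp


def batch_norm_train(x: jax.Array, eps: float = 1e-5) -> jax.Array:
    """Hypothetical: normalize one example with statistics pooled over the "batch" axis.

    x has shape (timesteps, hidden_dim); statistics are computed per hidden feature.
    """
    # Per-example mean over time, then averaged across the named batch axis.
    mean = jax.lax.pmean(jnp.mean(x, axis=0), axis_name="batch")
    # Second moment pooled the same way; variance = E[x^2] - E[x]^2.
    mean_sq = jax.lax.pmean(jnp.mean(x**2, axis=0), axis_name="batch")
    var = mean_sq - mean**2
    return (x - mean) / jnp.sqrt(var + eps)


# (batch, timesteps, hidden_dim), deliberately shifted and scaled
xs = jax.random.normal(jax.random.PRNGKey(0), (8, 10, 4)) * 3.0 + 5.0
ys = jax.vmap(batch_norm_train, axis_name="batch")(xs)
```

Calling `jax.vmap(batch_norm_train)` without `axis_name="batch"` fails, because `pmean` then refers to an axis name that is not bound.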
Attributes:
| Name | Type | Description |
|---|---|---|
| `norm` | | BatchNorm layer applied after the sequence mixer. |
| `sequence_mixer` | | The sequence mixing mechanism for sequence processing. |
| `channel_mixer` | | The channel mixing mechanism for channel processing. |
| `drop` | | Dropout layer applied after the channel mixer. |
| `prenorm` | | If `True`, apply the normalization at the beginning of the block; otherwise apply it at the end. |
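One plausible way these attributes fit together is sketched below. This is a hypothetical illustration, not the `StandardBlock` source: the residual connection, the exact position of the normalization, and the `standardize` stand-in for BatchNorm (per-feature statistics over time for a single sequence) are assumptions, and the toy mixers only stand in for real `SequenceMixer` and `ChannelMixer` instances.

```python
from typing import Callable

import jax
import jax.numpy as jnp

Mixer = Callable[[jax.Array], jax.Array]


def standardize(x: jax.Array, eps: float = 1e-5) -> jax.Array:
    # Stand-in for BatchNorm on a single sequence: per-feature stats over time.
    return (x - x.mean(axis=0)) / jnp.sqrt(x.var(axis=0) + eps)


def toy_block(x: jax.Array, sequence_mixer: Mixer, channel_mixer: Mixer, prenorm: bool) -> jax.Array:
    """Hypothetical block; x has shape (timesteps, hidden_dim)."""
    skip = x
    if prenorm:
        x = standardize(x)  # normalize at the start of the block
    x = sequence_mixer(x)  # mix along the time axis
    x = channel_mixer(x)  # mix along the feature axis
    # Dropout would be applied here; omitted to keep the sketch deterministic.
    x = skip + x  # residual connection (assumed)
    if not prenorm:
        x = standardize(x)  # normalize at the end of the block
    return x


def running_mean_over_time(x: jax.Array) -> jax.Array:
    # Toy causal sequence mixer: each timestep averages all steps up to it.
    steps = jnp.arange(1, x.shape[0] + 1)[:, None]
    return jnp.cumsum(x, axis=0) / steps


W = jax.random.normal(jax.random.PRNGKey(0), (4, 4)) / 2.0


def feature_mlp(x: jax.Array) -> jax.Array:
    # Toy channel mixer: a fixed linear map over features followed by GELU.
    return jax.nn.gelu(x @ W)


x = jax.random.normal(jax.random.PRNGKey(1), (10, 4))
y_pre = toy_block(x, running_mean_over_time, feature_mlp, prenorm=True)
y_post = toy_block(x, running_mean_over_time, feature_mlp, prenorm=False)
```

With `prenorm=False` the output is normalized last, so each feature of `y_post` has roughly zero mean over time; with `prenorm=True` the residual path carries the unnormalized input through.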
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Number of input features. | required |
| `cfg` | `ConfigType` | Configuration for the Standard block. | required |
| `sequence_mixer` | `SequenceMixer` | The sequence mixer instance for this block. | required |
| `channel_mixer` | `ChannelMixer` | The channel mixer instance for this block. | required |
| `key` | `PRNGKeyArray` | JAX random key for initializing the block's layers. | required |
__init__(in_features: int, cfg: ConfigType, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray)
Initialize the Standard block.
__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]
Apply the Standard block to the input sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape `(timesteps, hidden_dim)`. | required |
| `state` | `State` | Current state for stateful normalization layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for dropout operations. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and the updated state. |
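The `(x, state, key) -> (y, state)` contract is typically used with a batch-mapped call that shares one state across examples and gives each example its own key. The sketch below uses a hypothetical `toy_block` with the same contract (its running-mean state and the `MOMENTUM` constant are assumptions, not linax internals) to show how the state is threaded from one training step to the next.

```python
import jax
import jax.numpy as jnp

MOMENTUM = 0.9  # hypothetical running-average momentum


def toy_block(x: jax.Array, state: dict, key: jax.Array) -> tuple[jax.Array, dict]:
    """Hypothetical stand-in with the same (x, state, key) -> (y, state) contract."""
    # Batch statistic pooled over the named vmap axis, as a BatchNorm would do.
    batch_mean = jax.lax.pmean(x.mean(axis=0), axis_name="batch")
    new_state = {"running_mean": MOMENTUM * state["running_mean"] + (1 - MOMENTUM) * batch_mean}
    # The per-example key would drive dropout; here it adds tiny noise to show it is used.
    noise = 1e-3 * jax.random.normal(key, x.shape)
    return x - batch_mean + noise, new_state


# Map over x and key; keep a single, unbatched state on input and output.
batched = jax.vmap(toy_block, in_axes=(0, None, 0), out_axes=(0, None), axis_name="batch")

batch, timesteps, hidden_dim = 8, 16, 4
state = {"running_mean": jnp.zeros(hidden_dim)}
key = jax.random.PRNGKey(0)
for _ in range(3):
    key, data_key, drop_key = jax.random.split(key, 3)
    xs = jax.random.normal(data_key, (batch, timesteps, hidden_dim)) + 1.0
    # One dropout key per example; the returned state feeds the next step.
    ys, state = batched(xs, state, jax.random.split(drop_key, batch))
```

Using `out_axes=(0, None)` keeps the returned state unbatched, so it can be passed straight back in on the next call.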