Standard Block
discretax.blocks.standard.StandardBlock
A single block in the Standard backbone.
It combines a sequence mixer, BatchNorm normalization, a channel mixer, and dropout.
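Warning
This block uses BatchNorm for normalization. When training with `jax.vmap`, name the mapped batch axis `"batch"` so that BatchNorm can compute batch statistics across it. For example:
```python
import jax

# Batch `x` and `key` along axis 0, share one `state` across the batch,
# and return a single (unbatched) updated state.
batched_model = jax.vmap(
    model, in_axes=(0, None, 0), out_axes=(0, None), axis_name="batch"
)
```
Here `model` is the block, `x` and `key` carry a leading batch dimension (one key per example), and `state` is shared by all examples in the batch.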
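The axis name matters because batch statistics are computed with a collective over the named vmap axis. The following minimal pure-JAX sketch (not discretax code) shows that mechanism in isolation: a function written for a single example computes the batch mean with `jax.lax.pmean` over the axis named `"batch"`. The helper `center` is hypothetical.

```python
import jax
import jax.numpy as jnp


def center(x: jax.Array) -> jax.Array:
    """Subtract the batch mean from a single example (hypothetical helper).

    `jax.lax.pmean` averages `x` across every example mapped along the vmap
    axis named "batch"; it only works inside a vmap that declares that name.
    """
    batch_mean = jax.lax.pmean(x, axis_name="batch")
    return x - batch_mean


xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (batch=3, features=2)
centered = jax.vmap(center, axis_name="batch")(xs)
print(centered)  # [[-2. -2.] [ 0.  0.] [ 2.  2.]]
```

Calling `jax.vmap(center)(xs)` without `axis_name="batch"` fails, because the name `"batch"` is not bound to any mapped axis.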
Attributes:

| Name | Description |
|---|---|
| `norm` | BatchNorm layer; `prenorm` controls whether it runs at the beginning or the end of the block. |
| `sequence_mixer` | Sequence mixer, which mixes information across timesteps. |
| `channel_mixer` | Channel mixer, which mixes information across the hidden (channel) dimension. |
| `drop` | Dropout layer applied after the channel mixer. |
| `prenorm` | If `True`, normalization is applied at the beginning of the block; otherwise at the end. |
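To make the effect of `prenorm` concrete, the sketch below records the order in which stand-in layers run. It is an assumption based only on the attribute descriptions (normalization at the beginning or the end, dropout after the channel mixer), not discretax's implementation, and it omits any residual connections or activations the real block may use. `make_layer` and `block_order` are hypothetical.

```python
from typing import Callable, List

Layer = Callable[[List[str]], List[str]]


def make_layer(name: str) -> Layer:
    """Return a stand-in layer that appends its name to a trace."""
    return lambda trace: trace + [name]


def block_order(prenorm: bool) -> List[str]:
    """Apply stand-in layers in the order assumed for a StandardBlock."""
    norm = make_layer("norm")
    sequence_mixer = make_layer("sequence_mixer")
    channel_mixer = make_layer("channel_mixer")
    drop = make_layer("drop")

    trace: List[str] = []
    if prenorm:
        trace = norm(trace)  # assumed: normalize at the beginning of the block
    trace = sequence_mixer(trace)
    trace = channel_mixer(trace)
    trace = drop(trace)  # dropout follows the channel mixer
    if not prenorm:
        trace = norm(trace)  # assumed: normalize at the end of the block
    return trace


print(block_order(prenorm=True))   # ['norm', 'sequence_mixer', 'channel_mixer', 'drop']
print(block_order(prenorm=False))  # ['sequence_mixer', 'channel_mixer', 'drop', 'norm']
```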
__init__(in_features: int, key: PRNGKeyArray, *args, sequence_mixer: Resolvable[AbstractSequenceMixer], channel_mixer: Resolvable[AbstractChannelMixer], drop_rate: float = 0.1, prenorm: bool = True, **kwargs)
Initialize the Standard block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Number of input features. | required |
| `key` | `PRNGKeyArray` | JAX random key used to initialize the block's layers. | required |
| `sequence_mixer` | `Resolvable[AbstractSequenceMixer]` | Sequence mixer for this block. | required |
| `channel_mixer` | `Resolvable[AbstractChannelMixer]` | Channel mixer for this block. | required |
| `drop_rate` | `float` | Rate of the dropout layer applied after the channel mixer. | `0.1` |
| `prenorm` | `bool` | If `True`, apply normalization at the beginning of the block; otherwise at the end. | `True` |
| `*args` | | Additional positional arguments (ignored). | `()` |
| `**kwargs` | | Additional keyword arguments (ignored). | `{}` |
__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]
Apply the Standard block to the input sequence.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape `(timesteps, hidden_dim)`. | required |
| `state` | `State` | Current state for stateful normalization layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for dropout. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and the updated state. |
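The `(x, state, key) -> (output, state)` convention means the caller threads the updated state into the next call and supplies a fresh key per example. The sketch below illustrates this with a hypothetical pure-JAX stand-in, `toy_block`, that follows the same signature; its "state" is a plain running-mean array rather than an `eqx.nn.State`, and its centering and dropout only approximate what a real BatchNorm plus dropout block does.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in with the same calling convention as
# StandardBlock.__call__: (x, state, key) -> (output, updated_state).
# `state` here is a running mean of the features, not an eqx.nn.State.
def toy_block(x: jax.Array, state: jax.Array, key: jax.Array):
    # x has shape (timesteps, hidden_dim) for a single example.
    batch_mean = jax.lax.pmean(x.mean(axis=0), axis_name="batch")
    new_state = 0.9 * state + 0.1 * batch_mean  # identical on every example
    keep = jax.random.bernoulli(key, p=0.9, shape=x.shape)  # dropout rate 0.1
    y = jnp.where(keep, (x - batch_mean) / 0.9, 0.0)  # inverted dropout
    return y, new_state


batched_block = jax.vmap(
    toy_block,
    in_axes=(0, None, 0),  # batch x and key; share state
    out_axes=(0, None),    # batched outputs; one shared state
    axis_name="batch",
)

batch_size, timesteps, hidden_dim = 4, 8, 3
xs = jnp.ones((batch_size, timesteps, hidden_dim))
state = jnp.zeros((hidden_dim,))
key = jax.random.PRNGKey(0)

for _ in range(2):
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, batch_size)  # one key per example
    ys, state = batched_block(xs, state, keys)  # thread state into next step
```

Declaring `out_axes=None` for the state works because `jax.lax.pmean` over the vmapped axis returns a value that is not mapped, so every example produces the same updated state.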