S5 Sequence Mixer
linax.sequence_mixers.s5.S5SequenceMixerConfig
Configuration for the S5 sequence mixer.
This configuration class defines the hyperparameters for the S5 sequence mixer. S5 uses structured state space models with HiPPO initialization for efficient sequence modeling.
Attributes:
| Name | Type | Description |
|---|---|---|
| `state_dim` | | Dimensionality of the state space (total SSM size). |
| `ssm_blocks` | | Number of SSM blocks (for block-diagonal structure). |
| `C_init` | | Initialization method for output matrix C. |
| `conj_sym` | | Whether to enforce conjugate symmetry (reduces parameters by half). |
| `clip_eigs` | | Whether to clip eigenvalues to ensure stability. |
| `discretization` | | Discretization method to use. |
| `dt_min` | | Minimum discretization step size. |
| `dt_max` | | Maximum discretization step size. |
| `step_rescale` | | Rescaling factor for the discretization step. |
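The interaction between `state_dim`, `ssm_blocks`, and `conj_sym` comes down to a little arithmetic. The helper below is hypothetical (it is not part of linax) and assumes the convention of the reference S5 implementation: the state is split evenly across blocks, and conjugate symmetry stores one eigenvalue per conjugate pair. linax may differ in detail.

```python
def effective_state_dim(state_dim: int, ssm_blocks: int, conj_sym: bool) -> tuple[int, int]:
    """Return (block_size, effective_dim) for a block-diagonal SSM.

    Hypothetical helper: assumes state_dim is split evenly across ssm_blocks
    and that conjugate symmetry keeps one eigenvalue per conjugate pair.
    """
    if state_dim % ssm_blocks != 0:
        raise ValueError("state_dim must be divisible by ssm_blocks")
    block_size = state_dim // ssm_blocks
    if conj_sym:
        # Store only one member of each complex-conjugate pair.
        block_size //= 2
    return block_size, block_size * ssm_blocks


print(effective_state_dim(64, 4, conj_sym=True))   # (8, 32)
print(effective_state_dim(64, 4, conj_sym=False))  # (16, 64)
```

Under these assumptions, enabling `conj_sym` halves the number of stored eigenvalues, which matches the "reduces parameters by half" note in the table.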
build(in_features: int, key: PRNGKeyArray) -> S5SequenceMixer
Build sequence mixer from config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| `S5SequenceMixer` | The sequence mixer instance. |
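The config-then-build pattern can be illustrated without linax installed. The sketch below uses stand-in classes, `ToyMixerConfig` and `ToyMixer`, which are hypothetical and only mirror the documented `build(in_features, key)` signature. The field defaults are assumptions, and a plain integer stands in for a JAX PRNG key.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyMixerConfig:
    """Hypothetical stand-in for a mixer config (defaults are assumptions)."""

    state_dim: int = 64
    conj_sym: bool = True

    def build(self, in_features: int, key: int) -> "ToyMixer":
        # The config holds only hyperparameters; the input width and the
        # random key are supplied when the layer is actually constructed.
        return ToyMixer(in_features=in_features, cfg=self, key=key)


class ToyMixer:
    """Hypothetical stand-in for the layer returned by build()."""

    def __init__(self, in_features: int, cfg: ToyMixerConfig, key: int):
        self.H = in_features
        # Assumed convention: conjugate symmetry halves the stored state size.
        self.P = cfg.state_dim // 2 if cfg.conj_sym else cfg.state_dim
        self.key = key


cfg = ToyMixerConfig(state_dim=32, conj_sym=True)
mixer = cfg.build(in_features=8, key=0)
print(mixer.H, mixer.P)  # 8 16
```

Keeping hyperparameters in a separate config object means one config can be reused to build layers of different input widths or with different random keys.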
linax.sequence_mixers.s5.S5SequenceMixer
S5 sequence mixer layer.
This layer implements the Simplified State Space Layers (S5) sequence mixer, which uses structured state space models with HiPPO initialization and efficient parallel scan operations.
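A parallel scan applies here because the linear recurrence `x_k = a_k * x_{k-1} + b_k` can be expressed with an associative binary operator on pairs `(a, b)`. The sketch below, in plain Python with scalar states, checks that a sequential loop, a left-to-right reduction, and a tree-shaped combination all reach the same final state. The operator follows the S5 paper; the variable names and example values are illustrative, and whether linax uses this exact formulation internally is an assumption.

```python
from functools import reduce


def binary_operator(q_i, q_j):
    """Compose two recurrence steps: apply q_i first, then q_j."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j * a_i, a_j * b_i + b_j


lam_bar = 0.5                    # discretized eigenvalue (scalar for clarity)
inputs = [1.0, 2.0, 3.0, 4.0]    # already-projected inputs, one per time step
elems = [(lam_bar, u) for u in inputs]

# Reference: the plain sequential recurrence starting from x = 0.
x = 0.0
for u in inputs:
    x = lam_bar * x + u

# Left-to-right reduction with the operator.
_, x_seq = reduce(binary_operator, elems)

# Tree-shaped combination: pairs first, then the pair of pairs.
left = binary_operator(elems[0], elems[1])
right = binary_operator(elems[2], elems[3])
_, x_tree = binary_operator(left, right)

print(x, x_seq, x_tree)  # 6.125 6.125 6.125
```

Because the grouping does not affect the result, scan primitives such as `jax.lax.associative_scan` can evaluate all prefixes in logarithmic depth rather than stepping through the sequence one element at a time.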
Attributes:
| Name | Type | Description |
|---|---|---|
| `Lambda_re` | | Real part of diagonal state matrix eigenvalues. |
| `Lambda_im` | | Imaginary part of diagonal state matrix eigenvalues. |
| `B` | | Input projection matrix (parameterized as V^{-1}B). |
| `C` | | Output projection matrix (parameterized as CV). |
| `D` | | Skip connection weights. |
| `log_step` | | Log of discretization step sizes. |
| `H` | | Number of hidden channels (input features). |
| `P` | | Effective state dimensionality. |
| `conj_sym` | | Whether conjugate symmetry is enforced. |
| `clip_eigs` | | Whether to clip eigenvalues for stability. |
| `discretization` | | Discretization method being used. |
| `step_rescale` | | Rescaling factor for step sizes. |
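Several of these attributes meet at discretization time: `Lambda_re` and `Lambda_im` form the continuous-time eigenvalue, `clip_eigs` bounds its real part, and `log_step` together with `step_rescale` gives the step size. The sketch below works through one diagonal mode in plain Python. It assumes zero-order-hold discretization as described in the S5 paper; the function names are hypothetical, and the clipping threshold `-1e-4` is an assumed value, not a documented linax constant.

```python
import cmath
import math


def continuous_eigenvalue(lambda_re: float, lambda_im: float,
                          clip_eigs: bool, max_re: float = -1e-4) -> complex:
    """Combine real and imaginary parts, optionally clipping the real part.

    max_re is an assumed threshold, not a documented linax value.
    """
    if clip_eigs:
        # A negative real part keeps the continuous-time mode stable.
        lambda_re = min(lambda_re, max_re)
    return complex(lambda_re, lambda_im)


def step_size(log_step: float, step_rescale: float = 1.0) -> float:
    """Recover a positive step size from its logarithm, then rescale it."""
    return step_rescale * math.exp(log_step)


def discretize_zoh(lam: complex, b: complex, step: float) -> tuple[complex, complex]:
    """Zero-order-hold discretization of one diagonal mode x' = lam * x + b * u."""
    lam_bar = cmath.exp(lam * step)
    b_bar = (lam_bar - 1.0) / lam * b
    return lam_bar, b_bar


lam = continuous_eigenvalue(0.3, 2.0, clip_eigs=True)  # real part clipped to -1e-4
step = step_size(math.log(0.01))
lam_bar, b_bar = discretize_zoh(lam, 1.0 + 0.0j, step)
print(abs(lam_bar) < 1.0)  # True: the discrete mode lies inside the unit circle
```

Storing the step in log space keeps it positive under unconstrained optimization, and clipping the real part guarantees `abs(lam_bar) < 1`, so the recurrence cannot blow up over long sequences.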
__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray)
Initialize the S5 sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | required |
| `cfg` | `ConfigType` | Configuration for the S5 sequence mixer. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the S5 sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input sequence of features. | required |
| `key` | `PRNGKeyArray` | JAX random key (unused, for compatibility). | required |
Returns:
| Type | Description |
|---|---|
| `Array` | The output of the S5 sequence mixer. |
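The computation behind the forward pass can be sketched as a sequential loop over an already-discretized diagonal system. This is an illustration in plain Python, not the layer's actual code: the function name `ssm_forward` is hypothetical, the input is assumed to have shape `(length, H)`, and the factor of 2 under conjugate symmetry follows the reference S5 implementation.

```python
def ssm_forward(lam_bar, b_bar, c, d, u, conj_sym=True):
    """Run a diagonal SSM over a sequence.

    lam_bar: P discretized eigenvalues (complex)
    b_bar:   P x H discretized input matrix
    c:       H x P output matrix
    d:       H skip-connection weights
    u:       sequence of input vectors, each of length H
    """
    P, H = len(lam_bar), len(d)
    x = [0j] * P
    # With conjugate symmetry only half the modes are stored, so the real
    # part of the output is doubled to account for the missing conjugates.
    scale = 2.0 if conj_sym else 1.0
    ys = []
    for u_k in u:
        # State update: x_k = lam_bar * x_{k-1} + B_bar @ u_k
        x = [lam_bar[p] * x[p] + sum(b_bar[p][h] * u_k[h] for h in range(H))
             for p in range(P)]
        # Readout: y_k = scale * Re(C @ x_k) + D * u_k
        y_k = [scale * sum(c[h][p] * x[p] for p in range(P)).real + d[h] * u_k[h]
               for h in range(H)]
        ys.append(y_k)
    return ys


# One state, one channel, unit impulse input.
impulse = [[1.0], [0.0], [0.0]]
ys = ssm_forward([0.5 + 0j], [[1.0 + 0j]], [[1.0 + 0j]], [0.0], impulse, conj_sym=False)
print(ys)  # [[1.0], [0.5], [0.25]]
```

The impulse response decays geometrically with the discretized eigenvalue. A scan-based implementation would produce the same states without the explicit Python loop.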