S4D Sequence Mixer
linax.sequence_mixers.s4d.S4DSequenceMixerConfig
Configuration for the S4D sequence mixer.
This configuration class defines the hyperparameters for the S4D sequence mixer. S4D uses diagonal structured state space models with efficient FFT-based convolutions.
Attributes:
| Name | Type | Description |
|---|---|---|
| state_dim | | Dimensionality of the state space. |
| transposed | | Whether input is in transposed format (H, L) vs (L, H). |
| dt_min | | Minimum discretization step size. |
| dt_max | | Maximum discretization step size. |
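
This page does not say how `dt_min` and `dt_max` are used. A common convention in S4-style models, assumed here and not confirmed for linax, is to draw one step size per feature log-uniformly from the interval [dt_min, dt_max]. The sketch below illustrates that convention in JAX; the helper name `sample_log_dt` and the numeric bounds are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_log_dt(key, num_features, dt_min=1e-3, dt_max=1e-1):
    """Draw log step sizes so that dt is log-uniform in [dt_min, dt_max].

    Hypothetical helper: the default bounds are illustrative assumptions,
    not values taken from linax.
    """
    u = jax.random.uniform(key, (num_features,))  # uniform in [0, 1)
    log_min = jnp.log(dt_min)
    log_max = jnp.log(dt_max)
    return log_min + u * (log_max - log_min)


key = jax.random.PRNGKey(0)
log_dt = sample_log_dt(key, num_features=8)
dt = jnp.exp(log_dt)  # one positive step size per feature
print(dt.shape)
```

Storing the step size in log space keeps it positive under unconstrained gradient updates, which is one reason this parameterization is popular.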
build(in_features: int, key: PRNGKeyArray) -> S4DSequenceMixer
Build sequence mixer from config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| S4DSequenceMixer | The sequence mixer instance. |
linax.sequence_mixers.s4d.S4DSequenceMixer
S4D sequence mixer layer.
This layer implements the Structured State Space - Diagonal (S4D) sequence mixer, which uses diagonal parameterization of state space models for efficient sequence modeling via FFT-based convolutions.
Attributes:
| Name | Type | Description |
|---|---|---|
| in_features | | Input dimensionality. |
| state_dim | | State space dimensionality. |
| transposed | | Whether input is in transposed format. |
| kernel | | The S4D kernel for generating convolution kernels. |
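
To make the role of the `kernel` attribute concrete, here is a minimal sketch of how a diagonal state space model can be turned into a convolution kernel. It is not linax's implementation. It assumes a zero-order-hold discretization, an input matrix fixed to ones, and the S4D-Lin initialization (real part -1/2, imaginary part pi * n); the function name `s4d_kernel` and all sizes are hypothetical.

```python
import jax
import jax.numpy as jnp


def s4d_kernel(log_dt, a_real, a_imag, c, length):
    """Convolution kernel of a diagonal SSM (illustrative sketch).

    log_dt: (H,) log step sizes; a_real, a_imag: (H, N) diagonal of A;
    c: (H, N) complex output weights. Returns a real array of shape (H, length).
    """
    dt = jnp.exp(log_dt)                       # (H,)
    a = a_real + 1j * a_imag                   # (H, N)
    dt_a = a * dt[:, None]                     # (H, N)
    # Zero-order hold with B = 1: fold the discretized input weights into C.
    c_bar = c * (jnp.exp(dt_a) - 1.0) / a      # (H, N)
    # Powers of the discretized diagonal state matrix, one per time step.
    powers = jnp.exp(dt_a[:, :, None] * jnp.arange(length))  # (H, N, length)
    # Factor 2 and real part account for the implicit conjugate modes.
    return 2.0 * jnp.einsum("hn,hnl->hl", c_bar, powers).real


H, N, L = 4, 16, 64  # features, state size, sequence length (assumed values)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
log_dt = jnp.log(jnp.full((H,), 0.01))
a_real = -0.5 * jnp.ones((H, N))
a_imag = jnp.pi * jnp.broadcast_to(jnp.arange(N), (H, N))
c = jax.random.normal(k1, (H, N)) + 1j * jax.random.normal(k2, (H, N))

kernel = s4d_kernel(log_dt, a_real, a_imag, c, L)
print(kernel.shape)
```

Because the state matrix is diagonal, the kernel reduces to a sum of complex exponentials per feature, so no matrix powers or inversions are needed.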
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| cfg | ConfigType | Configuration for the S4D sequence mixer. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray)
Initialize the S4D sequence mixer layer.
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the S4D sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input sequence of features with shape (L, H), where L is the sequence length and H is the number of hidden features. | required |
| key | PRNGKeyArray | JAX random key for the forward pass. | required |
Returns:
| Type | Description |
|---|---|
| Array | The output of the S4D sequence mixer with shape (L, H). |
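
The shape contract above, (L, H) in and (L, H) out, can be illustrated with a small FFT-based causal convolution. This is a sketch under stated assumptions, not the library's forward pass: the function name `causal_fft_conv`, the `transposed` keyword default, and the toy kernel are hypothetical, and any additional steps linax applies (skip connections, activations, and so on) are omitted.

```python
import jax.numpy as jnp


def causal_fft_conv(x, kernel, transposed=False):
    """Per-feature causal convolution via FFT (illustrative sketch).

    x: (L, H), or (H, L) when transposed=True; kernel: (H, L).
    The output has the same layout as x.
    """
    u = x if transposed else x.T               # work in (H, L) layout
    length = u.shape[-1]
    n = 2 * length                             # zero-pad so circular conv equals linear conv
    u_f = jnp.fft.rfft(u, n=n, axis=-1)
    k_f = jnp.fft.rfft(kernel, n=n, axis=-1)
    y = jnp.fft.irfft(u_f * k_f, n=n, axis=-1)[..., :length]  # keep the causal part
    return y if transposed else y.T


L, H = 32, 3  # sequence length and feature count (assumed values)
x = jnp.sin(0.1 * jnp.arange(L * H, dtype=jnp.float32)).reshape(L, H)
# Toy decaying kernel, one row per feature.
kernel = jnp.exp(-0.1 * jnp.arange(L))[None, :] * jnp.arange(1, H + 1)[:, None]

y = causal_fft_conv(x, kernel)
print(y.shape)
```

Padding to twice the sequence length avoids wrap-around, so each output step depends only on current and earlier inputs. The cost is O(L log L) per feature, compared with O(L^2) for a direct convolution.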