S4D Sequence Mixer
discretax.sequence_mixers.s4d.S4DSequenceMixer
S4D sequence mixer layer.
This layer implements the Structured State Space - Diagonal (S4D) sequence mixer, which uses diagonal parameterization of state space models for efficient sequence modeling via FFT-based convolutions.
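The FFT-based convolution mentioned above can be illustrated with a minimal JAX sketch. It is not the `discretax` implementation: the function names `s4d_kernel` and `fft_causal_conv`, the zero-order-hold discretization, and the `-1/2 + iπn` diagonal initialization are assumptions taken from the usual S4D formulation, and the sketch omits details such as conjugate-pair scaling and a skip connection.

```python
import jax
import jax.numpy as jnp


def s4d_kernel(log_dt, A, B, C, length):
    """Convolution kernel of a diagonal SSM, one independent SSM per feature.

    Shapes (H = features, N = state size): log_dt (H,), A/B/C (H, N) complex.
    Zero-order-hold discretization; K[h, l] = Re(sum_n C * B_bar * A_bar**l).
    """
    dt = jnp.exp(log_dt)[:, None]                 # (H, 1) step sizes
    dtA = dt * A                                  # (H, N)
    B_bar = (jnp.exp(dtA) - 1.0) / A * B          # discretized input matrix
    # A_bar**l for l = 0..length-1, computed as exp(l * dt * A): (H, N, length)
    powers = jnp.exp(dtA[..., None] * jnp.arange(length))
    return jnp.einsum("hn,hnl->hl", C * B_bar, powers).real


def fft_causal_conv(u, kernel):
    """Causal convolution of u (L, H) with kernel (H, L) via zero-padded FFT."""
    L = u.shape[0]
    n = 2 * L                                     # padding avoids circular wrap-around
    U = jnp.fft.rfft(u.T, n=n)                    # (H, n // 2 + 1)
    K = jnp.fft.rfft(kernel, n=n)
    return jnp.fft.irfft(U * K, n=n)[:, :L].T     # back to (L, H)


# Small demonstration with illustrative sizes.
H, N, L = 3, 8, 16
k_c, k_u = jax.random.split(jax.random.PRNGKey(0))
A = jnp.broadcast_to(-0.5 + 1j * jnp.pi * jnp.arange(N), (H, N))
B = jnp.ones((H, N), dtype=jnp.complex64)
C = jax.random.normal(k_c, (H, N)) + 0j
log_dt = jnp.log(jnp.array([0.01, 0.05, 0.1]))

kernel = s4d_kernel(log_dt, A, B, C, L)           # (H, L)
u = jax.random.normal(k_u, (L, H))
y = fft_causal_conv(u, kernel)                    # (L, H)
```

Because `A` is diagonal, the kernel reduces to elementwise exponentials, and the FFT brings the cost of applying it down from O(L²) to O(L log L) per feature.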
Attributes:
| Name | Type | Description |
|---|---|---|
| `in_features` | | Input dimensionality. |
| `state_dim` | | State space dimensionality. |
| `transposed` | | Whether input is in transposed format. |
| `kernel` | | The S4D kernel for generating convolution kernels. |
__init__(in_features: int, key: PRNGKeyArray, *args, state_dim: int = 64, transposed: bool = False, dt_min: float = 0.001, dt_max: float = 0.1, **kwargs)
Initialize the S4D sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Dimension of the input features. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `state_dim` | `int` | Dimension of the state space. | `64` |
| `transposed` | `bool` | Whether the input is in transposed format `(H, L)` rather than `(L, H)`. | `False` |
| `dt_min` | `float` | Minimum discretization step size. | `0.001` |
| `dt_max` | `float` | Maximum discretization step size. | `0.1` |
| `*args` | | Additional positional arguments (ignored). | `()` |
| `**kwargs` | | Additional keyword arguments (ignored). | `{}` |
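The reference does not say how `dt_min` and `dt_max` are used. A common convention in S4-style models, assumed here, is to draw one step size per feature log-uniformly from `[dt_min, dt_max]` and store its logarithm as the trainable parameter. The helper name `init_log_dt` is hypothetical.

```python
import jax
import jax.numpy as jnp


def init_log_dt(key, in_features, dt_min=0.001, dt_max=0.1):
    """Sample log step sizes so that exp(log_dt) is log-uniform in [dt_min, dt_max].

    Assumed convention, not confirmed by the library documentation.
    """
    u = jax.random.uniform(key, (in_features,))   # uniform in [0, 1)
    return jnp.log(dt_min) + u * (jnp.log(dt_max) - jnp.log(dt_min))


log_dt = init_log_dt(jax.random.PRNGKey(0), in_features=8)
dt = jnp.exp(log_dt)                              # per-feature step sizes
```

Sampling in log space spreads the step sizes evenly across orders of magnitude, so different features start out sensitive to different time scales.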
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the S4D sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input sequence of features with shape (L, H) where L is sequence length and H is the number of hidden features. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | The output of the S4D sequence mixer with shape (L, H). |
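Since `__call__` is documented for a single `(L, H)` sequence, a batch would typically be handled with `jax.vmap`. The sketch below shows that pattern with `standin_mixer`, a hypothetical placeholder that only mimics the documented `(x, key) -> Array` signature and shape contract. It is not the library's layer.

```python
import jax
import jax.numpy as jnp


def standin_mixer(x, key):
    """Hypothetical stand-in with the documented signature: (L, H) -> (L, H)."""
    del key                                       # unused in this placeholder
    return jnp.cumsum(x, axis=0)                  # any causal map along the sequence axis


batch = jnp.ones((4, 16, 8))                      # (batch, L, H)
keys = jax.random.split(jax.random.PRNGKey(0), batch.shape[0])  # one key per sequence

# Map over the leading batch axis of both the inputs and the keys.
out = jax.vmap(standin_mixer)(batch, keys)        # (batch, L, H)
```

To share a single key across the batch instead, pass `in_axes=(0, None)` to `jax.vmap` and supply one unsplit key.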