S5 Sequence Mixer
discretax.sequence_mixers.s5.S5SequenceMixer
S5 sequence mixer layer.
This layer implements the Simplified State Space Layers (S5) sequence mixer, which uses structured state space models with HiPPO initialization and efficient parallel scan operations.
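The parallel scan mentioned above works because a linear recurrence `x_k = a_k * x_{k-1} + b_k` can be expressed with an associative binary operator on `(a, b)` pairs. The following is a minimal sketch in plain Python, not the library's implementation; the names `combine`, `tree_scan`, and `sequential_states` are hypothetical and exist only for illustration.

```python
import cmath


def combine(left, right):
    """Associative operator: apply `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)


def sequential_states(a, b):
    """Reference implementation: step through the recurrence one element at a time."""
    x = 0j
    states = []
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        states.append(x)
    return states


def tree_scan(elems):
    """Inclusive prefix scan by divide and conquer.

    The two halves are independent of each other, which is what a
    parallel scan exploits; here they simply run one after the other.
    """
    if len(elems) == 1:
        return list(elems)
    mid = len(elems) // 2
    left = tree_scan(elems[:mid])
    right = tree_scan(elems[mid:])
    carry = left[-1]  # summary of everything in the left half
    return left + [combine(carry, r) for r in right]


# One complex eigenvalue inside the unit circle, applied at every step.
a = [0.9 * cmath.exp(0.3j)] * 6
b = [1.0, 0.5, -0.25, 2.0, 0.0, 1.0]

# The second entry of each scanned pair is the state x_k (with x_0 = 0).
scanned_states = [pair[1] for pair in tree_scan(list(zip(a, b)))]
```

In JAX, `jax.lax.associative_scan` evaluates an operator of this kind in logarithmic depth; whether this layer calls that function directly is an assumption, since the page does not say.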
Attributes:
| Name | Type | Description |
|---|---|---|
| `Lambda_re` | | Real part of diagonal state matrix eigenvalues. |
| `Lambda_im` | | Imaginary part of diagonal state matrix eigenvalues. |
| `B` | | Input projection matrix (parameterized as V^{-1}B). |
| `C` | | Output projection matrix (parameterized as CV). |
| `D` | | Skip connection weights. |
| `log_step` | | Log of discretization step sizes. |
| `in_features` | | Number of hidden channels (input features). |
| `P` | | Effective state dimensionality. |
| `conj_sym` | | Whether conjugate symmetry is enforced. |
| `clip_eigs` | | Whether to clip eigenvalues for stability. |
| `discretization` | | Discretization method being used. |
| `step_rescale` | | Rescaling factor for step sizes. |
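To show how these attributes fit together, here is a hedged sketch of a diagonal state space layer for a single input channel, written in plain Python. It follows the formulation in the S5 paper (zero-order hold, `step = step_rescale * exp(log_step)`, output `2 * Re(Cx)` under conjugate symmetry). The function `s5_forward`, the clipping bound `-1e-4`, and the scalar treatment of `B`, `C`, and `D` are assumptions made for illustration, not this library's code.

```python
import cmath
import math


def s5_forward(u, Lambda_re, Lambda_im, B, C, D, log_step,
               step_rescale=1.0, conj_sym=True, clip_eigs=True):
    """Run a diagonal SSM over a scalar input sequence `u` (one channel)."""
    # Optionally clip real parts so every eigenvalue is strictly stable.
    # The bound -1e-4 is an assumed value.
    re = [min(r, -1e-4) for r in Lambda_re] if clip_eigs else list(Lambda_re)
    lam = [complex(r, i) for r, i in zip(re, Lambda_im)]

    # One step size per state, stored in log space.
    step = [step_rescale * math.exp(s) for s in log_step]

    # Zero-order hold for a diagonal system.
    lam_bar = [cmath.exp(l * dt) for l, dt in zip(lam, step)]
    b_bar = [(lb - 1) / l * b for lb, l, b in zip(lam_bar, lam, B)]

    x = [0j] * len(lam)
    ys = []
    for u_k in u:
        # State update: x_k = lam_bar * x_{k-1} + b_bar * u_k
        x = [lb * x_p + bb * u_k for lb, x_p, bb in zip(lam_bar, x, b_bar)]
        cx = sum(c * x_p for c, x_p in zip(C, x))
        # With conjugate symmetry only one eigenvalue of each conjugate
        # pair is stored, so the real output is twice the real part.
        y = (2.0 * cx.real if conj_sym else cx.real) + D * u_k
        ys.append(y)
    return ys


# Impulse response of a two-state example with made-up parameter values.
impulse_response = s5_forward(
    u=[1.0, 0.0, 0.0, 0.0],
    Lambda_re=[-0.5, -0.1], Lambda_im=[1.0, 3.0],
    B=[1.0 + 0j, 0.5 - 0.5j], C=[0.3 + 0.2j, -0.1 + 0.4j],
    D=0.1, log_step=[math.log(0.1), math.log(0.05)],
)
```

The loop is written sequentially for readability; the same recurrence is what a parallel scan would evaluate.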
__init__(in_features: int, key: PRNGKeyArray, *args, state_dim: int = 64, ssm_blocks: int = 1, C_init: Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal'] = 'lecun_normal', conj_sym: bool = True, clip_eigs: bool = True, discretization: Literal['zoh', 'bilinear'] = 'zoh', dt_min: float = 0.001, dt_max: float = 1.0, step_rescale: float = 1.0, **kwargs)
Initialize the S5 sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Dimension of the input features. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `state_dim` | `int` | Dimension of the state space (total SSM size). | `64` |
| `ssm_blocks` | `int` | Number of blocks in the block-diagonal HiPPO initialization. | `1` |
| `C_init` | `Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal']` | Initialization method for output matrix C. | `'lecun_normal'` |
| `conj_sym` | `bool` | Whether to enforce conjugate symmetry (reduces parameters by half). | `True` |
| `clip_eigs` | `bool` | Whether to clip eigenvalues to ensure stability. | `True` |
| `discretization` | `Literal['zoh', 'bilinear']` | Discretization method to use. | `'zoh'` |
| `dt_min` | `float` | Minimum discretization step size. | `0.001` |
| `dt_max` | `float` | Maximum discretization step size. | `1.0` |
| `step_rescale` | `float` | Rescaling factor for the discretization step. | `1.0` |
| `*args` | | Additional positional arguments (ignored). | required |
| `**kwargs` | | Additional keyword arguments (ignored). | required |
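The `discretization`, `dt_min`, and `dt_max` parameters interact: step sizes are drawn between the two bounds, and the chosen method turns each continuous eigenvalue into a discrete one. The sketch below illustrates this for a single eigenvalue in plain Python. The log-uniform sampling follows the S5 reference code and is assumed, not confirmed, to match this layer; `init_log_step`, `discretize_zoh`, and `discretize_bilinear` are hypothetical helper names.

```python
import cmath
import math
import random


def init_log_step(n, dt_min=0.001, dt_max=1.0, seed=0):
    """Sample n log step sizes uniformly between log(dt_min) and log(dt_max)."""
    rng = random.Random(seed)
    lo, hi = math.log(dt_min), math.log(dt_max)
    return [lo + rng.random() * (hi - lo) for _ in range(n)]


def discretize_zoh(lam, b, dt):
    """Zero-order hold for one diagonal entry."""
    lam_bar = cmath.exp(lam * dt)
    return lam_bar, (lam_bar - 1) / lam * b


def discretize_bilinear(lam, b, dt):
    """Bilinear (Tustin) transform for one diagonal entry."""
    inv = 1.0 / (1.0 - 0.5 * dt * lam)
    return inv * (1.0 + 0.5 * dt * lam), inv * dt * b


lam, b = complex(-0.5, 2.0), 1.0 + 0j

# The two methods agree closely for small steps and drift apart as dt grows.
gaps = {}
for dt in (0.001, 0.1, 1.0):
    zoh_lam, _ = discretize_zoh(lam, b, dt)
    bil_lam, _ = discretize_bilinear(lam, b, dt)
    gaps[dt] = abs(zoh_lam - bil_lam)

# Step sizes recovered from the sampled logs lie within [dt_min, dt_max].
steps = [math.exp(s) for s in init_log_step(8)]
```

Both methods map an eigenvalue with negative real part inside the unit circle, so the choice affects accuracy at larger step sizes, not stability.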
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the S5 sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input sequence of features. | required |
| `key` | `PRNGKeyArray` | JAX random key (unused, for compatibility). | required |
Returns:
| Type | Description |
|---|---|
| `Array` | The output of the S5 sequence mixer. |