S5 Sequence Mixer

discretax.sequence_mixers.s5.S5SequenceMixer

S5 sequence mixer layer.

This layer implements the Simplified State Space Layers (S5) sequence mixer: a structured state space model with a diagonal state matrix, HiPPO initialization, and a parallel scan that evaluates the linear recurrence efficiently.
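
To make the parallel-scan idea concrete, here is a minimal sketch of the diagonal recurrence `x_k = lam_bar * x_{k-1} + bu_k` evaluated two ways: with a plain loop, and with a Hillis-Steele inclusive scan over the associative operator `(a1, b1), (a2, b2) -> (a2 * a1, a2 * b1 + b2)`. It uses NumPy rather than JAX to stay dependency-light, and the names `sequential_scan`, `parallel_scan`, `lam_bar`, and `bu` are illustrative assumptions, not part of the discretax API.

```python
import numpy as np


def sequential_scan(lam_bar, bu):
    """Reference recurrence x_k = lam_bar * x_{k-1} + bu_k, starting from x = 0."""
    x = np.zeros_like(bu[0])
    states = []
    for bu_k in bu:
        x = lam_bar * x + bu_k
        states.append(x)
    return np.stack(states)


def parallel_scan(lam_bar, bu):
    """Same recurrence via a Hillis-Steele scan: about log2(L) vectorized passes."""
    seq_len = bu.shape[0]
    a = np.broadcast_to(lam_bar, bu.shape).astype(complex)  # multiplicative parts
    b = bu.astype(complex)                                  # additive parts
    shift = 1
    while shift < seq_len:
        new_a, new_b = a.copy(), b.copy()
        # Combine the earlier element (a1, b1) at k - shift with the later
        # element (a2, b2) at k: result is (a2 * a1, a2 * b1 + b2).
        new_b[shift:] = a[shift:] * b[:-shift] + b[shift:]
        new_a[shift:] = a[shift:] * a[:-shift]
        a, b = new_a, new_b
        shift *= 2
    return b  # b[k] now holds the state x_k


# Usage: both evaluations agree on a random stable system (hypothetical sizes).
rng = np.random.default_rng(0)
lam_bar = 0.9 * np.exp(1j * rng.uniform(0, np.pi, size=4))
bu = rng.standard_normal((16, 4)) + 1j * rng.standard_normal((16, 4))
assert np.allclose(sequential_scan(lam_bar, bu), parallel_scan(lam_bar, bu))
```

Because the operator is associative, the passes inside the `while` loop can run in parallel across sequence positions, which is what makes a scan primitive attractive on accelerators.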

Attributes:

| Name | Description |
| --- | --- |
| `Lambda_re` | Real part of the diagonal state matrix eigenvalues. |
| `Lambda_im` | Imaginary part of the diagonal state matrix eigenvalues. |
| `B` | Input projection matrix (parameterized as V^{-1}B). |
| `C` | Output projection matrix (parameterized as CV). |
| `D` | Skip connection weights. |
| `log_step` | Log of the discretization step sizes. |
| `in_features` | Number of hidden channels (input features). |
| `P` | Effective state dimensionality. |
| `conj_sym` | Whether conjugate symmetry is enforced. |
| `clip_eigs` | Whether to clip eigenvalues for stability. |
| `discretization` | Discretization method being used. |
| `step_rescale` | Rescaling factor for step sizes. |
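
The following sketch shows one plausible way these attributes combine into a discrete-time system. It follows the standard zero-order-hold and bilinear formulas for a diagonal SSM; the function name `discretize`, the `-1e-4` clipping threshold, and the array shapes are assumptions for illustration and may differ from what discretax does internally.

```python
import numpy as np


def discretize(Lambda_re, Lambda_im, B, log_step, step_rescale=1.0,
               clip_eigs=True, discretization="zoh"):
    """Return (Lambda_bar, B_bar) for a diagonal continuous-time SSM.

    Assumed shapes: Lambda_re, Lambda_im, log_step are (P,); B is (P, H).
    """
    # Assumption: clipping keeps the real part strictly negative (stable).
    re = np.minimum(Lambda_re, -1e-4) if clip_eigs else Lambda_re
    Lambda = re + 1j * Lambda_im
    step = step_rescale * np.exp(log_step)

    if discretization == "zoh":
        # Zero-order hold: exact for piecewise-constant inputs.
        Lambda_bar = np.exp(Lambda * step)
        B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B
    elif discretization == "bilinear":
        # Bilinear (Tustin) transform.
        denom = 1.0 - 0.5 * step * Lambda
        Lambda_bar = (1.0 + 0.5 * step * Lambda) / denom
        B_bar = (step / denom)[:, None] * B
    else:
        raise ValueError(f"unknown discretization: {discretization!r}")
    return Lambda_bar, B_bar


# Usage with hypothetical values: P = 2 states, H = 3 input channels.
Lambda_bar, B_bar = discretize(
    Lambda_re=np.array([-0.5, -0.5]),
    Lambda_im=np.array([1.0, 2.0]),
    B=np.ones((2, 3), dtype=complex),
    log_step=np.log(np.array([0.01, 0.1])),
)
```

With eigenvalues in the open left half-plane, both methods map them strictly inside the unit circle, so the discrete recurrence is stable.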

__init__(in_features: int, key: PRNGKeyArray, *args, state_dim: int = 64, ssm_blocks: int = 1, C_init: Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal'] = 'lecun_normal', conj_sym: bool = True, clip_eigs: bool = True, discretization: Literal['zoh', 'bilinear'] = 'zoh', dt_min: float = 0.001, dt_max: float = 1.0, step_rescale: float = 1.0, **kwargs)

Initialize the S5 sequence mixer layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_features` | `int` | Dimension of the input features. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `state_dim` | `int` | Dimension of the state space (total SSM size). | `64` |
| `ssm_blocks` | `int` | Number of blocks in the block-diagonal HiPPO initialization. | `1` |
| `C_init` | `Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal']` | Initialization method for the output matrix C. | `'lecun_normal'` |
| `conj_sym` | `bool` | Whether to enforce conjugate symmetry (reduces parameters by half). | `True` |
| `clip_eigs` | `bool` | Whether to clip eigenvalues to ensure stability. | `True` |
| `discretization` | `Literal['zoh', 'bilinear']` | Discretization method to use. | `'zoh'` |
| `dt_min` | `float` | Minimum discretization step size. | `0.001` |
| `dt_max` | `float` | Maximum discretization step size. | `1.0` |
| `step_rescale` | `float` | Rescaling factor for the discretization step. | `1.0` |
| `*args` | | Additional positional arguments (ignored). | `()` |
| `**kwargs` | | Additional keyword arguments (ignored). | `{}` |

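
As an illustration of how `dt_min`, `dt_max`, `state_dim`, and `conj_sym` might shape the initial parameters, the sketch below draws step sizes log-uniformly from `[dt_min, dt_max]` and halves the state size under conjugate symmetry. Both behaviors are assumptions based on common S5 practice, and the helper names `init_log_step` and `effective_state_dim` are hypothetical, not discretax functions.

```python
import numpy as np


def init_log_step(P, dt_min=0.001, dt_max=1.0, seed=0):
    """Sample log step sizes uniformly in [log(dt_min), log(dt_max)].

    Assumption: log-uniform sampling, so exp(log_step) lies in [dt_min, dt_max].
    """
    rng = np.random.default_rng(seed)
    u = rng.uniform(size=P)
    return np.log(dt_min) + u * (np.log(dt_max) - np.log(dt_min))


def effective_state_dim(state_dim, conj_sym=True):
    """Assumption: with conjugate symmetry only half of the eigenvalues are stored."""
    return state_dim // 2 if conj_sym else state_dim


# Usage with the documented defaults.
P = effective_state_dim(64, conj_sym=True)
log_step = init_log_step(P, dt_min=0.001, dt_max=1.0)
steps = np.exp(log_step)  # one positive step size per state
```

Storing the logarithm keeps the step sizes positive under unconstrained gradient updates, which is presumably why the attribute is `log_step` rather than the step itself.
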
__call__(x: Array, key: PRNGKeyArray) -> Array

Forward pass of the S5 sequence mixer layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Array` | Input sequence of features. | required |
| `key` | `PRNGKeyArray` | JAX random key (unused, for compatibility). | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | The output of the S5 sequence mixer. |
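
For orientation, a forward pass of this kind can be sketched as a state recurrence followed by an output projection and a skip connection. The sketch below assumes an input of shape `(seq_len, in_features)`, already-discretized parameters, and the usual S5 convention that the real part of the projection is doubled under conjugate symmetry; `s5_output` and its argument names are hypothetical, and a plain loop stands in for the parallel scan.

```python
import numpy as np


def s5_output(u, Lambda_bar, B_bar, C, D, conj_sym=True):
    """Map inputs u (L, H) to outputs (L, H).

    Assumed shapes: Lambda_bar (P,), B_bar (P, H), C (H, P), D (H,).
    """
    Bu = u @ B_bar.T                                   # (L, P) projected inputs
    x = np.zeros(Lambda_bar.shape, dtype=complex)      # initial state
    outputs = []
    for Bu_k, u_k in zip(Bu, u):
        x = Lambda_bar * x + Bu_k                      # diagonal state update
        y = (C @ x).real                               # project back to H channels
        if conj_sym:
            y = 2.0 * y                                # account for the conjugate half
        outputs.append(y + D * u_k)                    # elementwise skip connection
    return np.stack(outputs)


# Usage with hypothetical sizes: L = 8 steps, H = 3 channels, P = 4 states.
rng = np.random.default_rng(0)
u = rng.standard_normal((8, 3))
Lambda_bar = 0.9 * np.exp(1j * rng.uniform(0, np.pi, size=4))
B_bar = rng.standard_normal((4, 3)) + 1j * rng.standard_normal((4, 3))
C = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))
D = rng.standard_normal(3)
y = s5_output(u, Lambda_bar, B_bar, C, D)  # real array with the same shape as u
```

The output is real-valued and has the same shape as the input, so the mixer can be dropped between other per-channel layers.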