S5 Model

discretax.models.s5.S5

S5 model.

This model implements a stack of blocks with S5 sequence mixers and GLU channel mixers. Compose it with an encoder and a head using eqx.nn.Sequential.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| blocks | | List of standard blocks with S5 sequence mixers. |
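To make the block layout more concrete, here is a minimal sketch of one residual block with a GLU channel mixer, written in plain JAX. The names `layer_norm`, `glu`, `block`, and `running_mean`, the parameter dictionary, and the exact ordering of normalization, mixing, and residual addition are illustrative assumptions, not discretax's implementation. The sequence mixer is passed in as an arbitrary callable.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def layer_norm(x, eps=1e-5):
    # Normalize each time step over the feature axis (no learned scale or shift).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def glu(x, params):
    # Gated linear unit: a linear "value" path gated elementwise by a sigmoid path.
    value = x @ params["w_value"] + params["b_value"]
    gate = x @ params["w_gate"] + params["b_gate"]
    return value * jax.nn.sigmoid(gate)


def block(x, params, sequence_mixer, prenorm=True):
    # Assumed ordering: norm -> sequence mixer -> GLU -> residual (prenorm),
    # or sequence mixer -> GLU -> residual -> norm (postnorm).
    skip = x
    if prenorm:
        x = layer_norm(x)
    x = sequence_mixer(x)
    x = glu(x, params)
    x = skip + x
    if not prenorm:
        x = layer_norm(x)
    return x


def running_mean(u):
    # Stand-in sequence mixer: a causal running mean over the time axis.
    counts = jnp.arange(1, u.shape[0] + 1)[:, None]
    return jnp.cumsum(u, axis=0) / counts


hidden_dim, seq_len = 8, 16
k1, k2, k3 = jr.split(jr.PRNGKey(0), 3)
params = {
    "w_value": jr.normal(k1, (hidden_dim, hidden_dim)) / jnp.sqrt(hidden_dim),
    "b_value": jnp.zeros(hidden_dim),
    "w_gate": jr.normal(k2, (hidden_dim, hidden_dim)) / jnp.sqrt(hidden_dim),
    "b_gate": jnp.zeros(hidden_dim),
}
x = jr.normal(k3, (seq_len, hidden_dim))

y = block(x, params, running_mean, prenorm=True)  # shape (seq_len, hidden_dim)
```

In this sketch, dropping the two bias entries would correspond to something like use_bias=False, and the prenorm flag only moves the normalization from before the mixers to after the residual addition.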

Example
import equinox as eqx
import jax.random as jr
from discretax.encoder import LinearEncoder
from discretax.heads import ClassificationHead
from discretax.models import S5

key = jr.PRNGKey(0)
keys = jr.split(key, 3)

encoder = LinearEncoder(in_features=784, out_features=64, key=keys[0])
model = S5(hidden_dim=64, num_blocks=4, key=keys[1])
head = ClassificationHead(in_features=64, out_features=10, key=keys[2])

# Compose with Sequential
full_model = eqx.nn.Sequential([encoder, model, head])

Reference

S5: https://openreview.net/pdf?id=Ai8Hw3AXqks
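For readers who want a feel for what an S5 sequence mixer computes, the sketch below runs the diagonal linear recurrence x_k = lambda_bar * x_{k-1} + B_bar u_k, y_k = Re(C x_k) described in the referenced paper, using `jax.lax.associative_scan`. The function names, shapes, and toy parameters are assumptions made for illustration and are not part of the discretax API; the feedthrough term and the conjugate-symmetry shortcut are omitted.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def binary_operator(left, right):
    # Compose two affine steps x -> a * x + b, with `left` applied first.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def diagonal_ssm_scan(lambda_bar, b_bar, c, u):
    # lambda_bar: (P,) complex, b_bar: (P, H) complex, c: (H, P) complex, u: (T, H) real.
    bu = u @ b_bar.T                                  # (T, P) input contribution per step
    a = jnp.broadcast_to(lambda_bar, bu.shape)        # (T, P) same transition at every step
    _, xs = jax.lax.associative_scan(binary_operator, (a, bu))
    return (xs @ c.T).real                            # (T, H) real-valued outputs


def diagonal_ssm_loop(lambda_bar, b_bar, c, u):
    # Sequential reference implementation of the same recurrence.
    x = jnp.zeros(lambda_bar.shape, dtype=lambda_bar.dtype)
    ys = []
    for u_k in u:
        x = lambda_bar * x + b_bar @ u_k
        ys.append((c @ x).real)
    return jnp.stack(ys)


state_dim, hidden_dim, seq_len = 4, 3, 10
k1, k2, k3, k4 = jr.split(jr.PRNGKey(0), 4)

# Toy discrete-time eigenvalues inside the unit circle (magnitude 0.9).
lambda_bar = 0.9 * jnp.exp(1j * jr.uniform(k1, (state_dim,), maxval=jnp.pi))
b_bar = jr.normal(k2, (state_dim, hidden_dim)) + 1j * jr.normal(k3, (state_dim, hidden_dim))
c = jr.normal(k4, (hidden_dim, state_dim)) + 0j
u = jnp.sin(jnp.arange(seq_len * hidden_dim, dtype=jnp.float32)).reshape(seq_len, hidden_dim)

y_scan = diagonal_ssm_scan(lambda_bar, b_bar, c, u)
y_loop = diagonal_ssm_loop(lambda_bar, b_bar, c, u)
```

Both functions compute the same outputs; the scan version expresses the recurrence as an associative operation, which is what allows it to be evaluated in parallel over the sequence.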

__init__(key: PRNGKeyArray, *args, hidden_dim: int, num_blocks: int = 4, state_dim: int = 64, ssm_blocks: int = 1, C_init: typing.Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal'] = 'lecun_normal', conj_sym: bool = True, clip_eigs: bool = True, discretization: typing.Literal['zoh', 'bilinear'] = 'zoh', dt_min: float = 0.001, dt_max: float = 1.0, step_rescale: float = 1.0, drop_rate: float = 0.1, prenorm: bool = True, use_bias: bool = True, **kwargs)

Initialize the S5 model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| hidden_dim | int | Hidden dimension of the model. | required |
| num_blocks | int | Number of S5 blocks to stack. | 4 |
| state_dim | int | State space dimension of the S5 sequence mixers. | 64 |
| ssm_blocks | int | Number of SSM blocks (for block-diagonal structure). | 1 |
| C_init | Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal'] | Initialization method for the output matrix C. | 'lecun_normal' |
| conj_sym | bool | Whether to enforce conjugate symmetry. | True |
| clip_eigs | bool | Whether to clip eigenvalues to ensure stability. | True |
| discretization | Literal['zoh', 'bilinear'] | Discretization method to use. | 'zoh' |
| dt_min | float | Minimum discretization step size. | 0.001 |
| dt_max | float | Maximum discretization step size. | 1.0 |
| step_rescale | float | Rescaling factor for the discretization step. | 1.0 |
| drop_rate | float | Dropout rate for the blocks. | 0.1 |
| prenorm | bool | Whether to apply prenorm in the blocks. | True |
| use_bias | bool | Whether to use bias in the GLU channel mixers. | True |
| *args | | Additional positional arguments (ignored). | not required |
| **kwargs | | Additional keyword arguments (ignored). | not required |
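The discretization-related parameters above can be illustrated with a short sketch. It uses the standard zero-order hold and bilinear formulas for a diagonal continuous-time system, and draws the step sizes log-uniformly between dt_min and dt_max as is common in S5-style models. Whether discretax uses exactly these formulas and this sampling scheme is an assumption here; the function names and toy eigenvalues are hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr


def sample_log_step(key, state_dim, dt_min=0.001, dt_max=1.0):
    # Log-uniform sample so that exp(log_step) lies between dt_min and dt_max.
    u = jr.uniform(key, (state_dim,))
    return u * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)


def discretize_zoh(lam, b, delta):
    # Zero-order hold for a diagonal state matrix with entries lam.
    lam_bar = jnp.exp(lam * delta)
    b_bar = ((lam_bar - 1.0) / lam)[:, None] * b
    return lam_bar, b_bar


def discretize_bilinear(lam, b, delta):
    # Bilinear (Tustin) transform for the same diagonal system.
    inv = 1.0 / (1.0 - 0.5 * delta * lam)
    lam_bar = inv * (1.0 + 0.5 * delta * lam)
    b_bar = (inv * delta)[:, None] * b
    return lam_bar, b_bar


state_dim, hidden_dim = 64, 8
k_step, k_b = jr.split(jr.PRNGKey(0))

step_rescale = 1.0
delta = step_rescale * jnp.exp(sample_log_step(k_step, state_dim))

# Toy continuous-time eigenvalues with negative real part (a stable system).
lam = -0.5 + 1j * jnp.arange(state_dim, dtype=jnp.float32)
# Force real parts to be at most -1e-4, as a clip_eigs-style safeguard might do.
lam = jnp.minimum(lam.real, -1e-4) + 1j * lam.imag
b = jr.normal(k_b, (state_dim, hidden_dim)) + 0j

lam_zoh, b_zoh = discretize_zoh(lam, b, delta)
lam_bil, b_bil = discretize_bilinear(lam, b, delta)
```

With eigenvalues in the left half-plane, both methods map them strictly inside the unit circle, so the discrete recurrence does not blow up over long sequences.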
__call__(x: Array, state: State, key: PRNGKeyArray) -> tuple[Array, equinox.nn._stateful.State]

Forward pass through the S5 blocks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor. | required |
| state | State | Current state for stateful layers. | required |
| key | PRNGKeyArray | JAX random key for operations. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, State] | Tuple containing the output tensor and the updated state. |
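The (x, state, key) -> (output, state) calling convention can be sketched in plain JAX as follows. This is only an illustration of how a stack of blocks might thread a state object through and give each block its own random subkey for dropout. The `toy_block` stand-in, the integer used as state (not an equinox State), and the per-block key splitting are assumptions, not discretax's actual forward pass.

```python
import jax.numpy as jnp
import jax.random as jr


def dropout(x, key, drop_rate):
    # Inverted dropout: zero entries with probability drop_rate, rescale the rest.
    keep = jr.bernoulli(key, 1.0 - drop_rate, x.shape)
    return jnp.where(keep, x / (1.0 - drop_rate), 0.0)


def toy_block(x, state, key, drop_rate):
    # Hypothetical stand-in block: residual update, with a call counter as "state".
    y = x + dropout(jnp.tanh(x), key, drop_rate)
    return y, state + 1


def forward(blocks, x, state, key, drop_rate=0.1):
    # One independent subkey per block, so dropout masks are not shared.
    keys = jr.split(key, len(blocks))
    for block, block_key in zip(blocks, keys):
        x, state = block(x, state, block_key, drop_rate)
    return x, state


x = jnp.ones((16, 8))  # (sequence length, hidden_dim)
out, new_state = forward([toy_block] * 4, x, state=0, key=jr.PRNGKey(0))
```

Passing the same key reproduces the same output, which is why the caller is responsible for supplying a fresh key on each training step.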