S5 Model

linax.models.s5.S5Config

Configuration for S5 models.

A modular configuration that assembles an S5 model from interchangeable components: an encoder, a sequence mixer and block type that are replicated for each block, and an output head.

Attributes:

num_blocks: Number of S5 blocks to stack.

encoder_config: Configuration for the encoder.

head_config: Configuration for the output head.

sequence_mixer_config: Optional S5 sequence mixer config, replicated for each block. Defaults to S5SequenceMixerConfig() if not provided.

block_config: Optional block config, replicated for each block. Defaults to StandardBlockConfig() if not provided.

Example
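import jax.random as jr

from linax.models.s5 import S5Config

# LinearEncoderConfig, S5SequenceMixerConfig, StandardBlockConfig and
# ClassificationHeadConfig are also linax config classes; import them
# from their respective linax modules.

# With explicit configs
config = S5Config(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    sequence_mixer_config=S5SequenceMixerConfig(
        state_dim=64,
        ssm_blocks=1,
        conj_sym=True,
        clip_eigs=True,
        discretization="zoh",
    ),
    block_config=StandardBlockConfig(drop_rate=0.05),
    head_config=ClassificationHeadConfig(out_features=10),
)

# With defaults (simpler)
config = S5Config(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    head_config=ClassificationHeadConfig(out_features=10),
)
key = jr.PRNGKey(0)  # random key for parameter initialization
model = config.build(key=key)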
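The `discretization="zoh"` option above refers to zero-order-hold discretization. As described in the S5 paper, for a diagonal state matrix Λ and step size Δ, ZOH gives Λ̄ = exp(ΛΔ) and B̄ = Λ⁻¹(Λ̄ − I)B, after which the state evolves as x_k = Λ̄ x_{k−1} + B̄ u_k with output y_k = Re(C x_k). The sketch below implements that recipe in plain JAX. The function names `zoh_discretize` and `run_recurrence` are hypothetical, and this is not linax's internal code.

```python
import jax
import jax.numpy as jnp


def zoh_discretize(Lambda, B, delta):
    """Zero-order-hold discretization of a diagonal continuous-time SSM.

    Lambda: (P,) complex diagonal of the state matrix.
    B:      (P, H) complex input matrix.
    delta:  (P,) positive step sizes.
    """
    Lambda_bar = jnp.exp(Lambda * delta)
    # Elementwise form of Lambda^{-1} (Lambda_bar - I) applied to the rows of B.
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B
    return Lambda_bar, B_bar


def run_recurrence(Lambda_bar, B_bar, C, u):
    """Sequential scan x_k = Lambda_bar * x_{k-1} + B_bar u_k, y_k = Re(C x_k).

    u: (L, H) real input sequence. Returns y: (L, H_out).
    """
    def step(x, u_k):
        x = Lambda_bar * x + B_bar @ u_k
        return x, (C @ x).real

    x0 = jnp.zeros(Lambda_bar.shape, dtype=Lambda_bar.dtype)
    _, ys = jax.lax.scan(step, x0, u)
    return ys


# Usage: a 2-state system driven by a length-5, 1-channel input.
Lambda = jnp.array([-0.5 + 1.0j, -1.0 + 0.0j])
B = jnp.ones((2, 1), dtype=jnp.complex64)
C = jnp.ones((1, 2), dtype=jnp.complex64)
delta = jnp.full((2,), 0.1)
Lambda_bar, B_bar = zoh_discretize(Lambda, B, delta)
y = run_recurrence(Lambda_bar, B_bar, C, jnp.ones((5, 1)))
print(y.shape)  # (5, 1)
```

The S5 paper evaluates this recurrence with a parallel associative scan; `jax.lax.scan` is used here only because the sequential form is easier to read, and both produce the same outputs.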
Reference

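S5: Smith, J. T. H., Warrington, A., and Linderman, S. W. "Simplified State Space Layers for Sequence Modeling." ICLR 2023. https://openreview.net/pdf?id=Ai8Hw3AXqks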

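__post_init__()

Replicates the sequence mixer and block configs once per block (substituting the defaults when they are not provided) and validates the resulting configuration.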
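The replication step can be pictured with a small stand-alone sketch. The dataclasses here (`MixerConfig`, `BlockConfig`, `ModelConfig`) and their default values are hypothetical stand-ins, not linax classes. The validation shown (requiring at least one block) is only an assumption about what "validates" might cover.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class MixerConfig:  # hypothetical stand-in for S5SequenceMixerConfig
    state_dim: int = 64


@dataclass(frozen=True)
class BlockConfig:  # hypothetical stand-in for StandardBlockConfig
    drop_rate: float = 0.1


@dataclass
class ModelConfig:  # hypothetical stand-in for S5Config
    num_blocks: int
    sequence_mixer_config: Union[MixerConfig, List[MixerConfig], None] = None
    block_config: Union[BlockConfig, List[BlockConfig], None] = None

    def __post_init__(self):
        # Validation (assumed): at least one block is required.
        if self.num_blocks < 1:
            raise ValueError("num_blocks must be >= 1")
        # Fall back to default configs when none are given.
        mixer = self.sequence_mixer_config
        if mixer is None:
            mixer = MixerConfig()
        block = self.block_config
        if block is None:
            block = BlockConfig()
        # Replicate the single config once per block; frozen dataclasses are
        # immutable, so sharing one instance across blocks is safe.
        self.sequence_mixer_config = [mixer] * self.num_blocks
        self.block_config = [block] * self.num_blocks


cfg = ModelConfig(num_blocks=4, block_config=BlockConfig(drop_rate=0.05))
print(len(cfg.sequence_mixer_config))  # 4
print(cfg.block_config[0].drop_rate)  # 0.05
```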

build(key: PRNGKeyArray | None = None) -> SSM

Build an SSM model from this configuration.

Parameters:

key (PRNGKeyArray | None, default None): JAX random key used to initialize the model parameters.

Returns:

SSM: The instantiated SSM model.

Example
import jax.random as jr

config = SSMConfig(...)
model = config.build(key=jr.PRNGKey(0))
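JAX random keys are explicit values, so a builder that receives a single `key` has to split it before handing randomness to each component. The sketch below shows one common way to do that with `jax.random.split`. The helper `split_for_model` and the one-key-per-component layout (encoder, each block, head) are hypothetical and not taken from linax. What `build` does when `key` is `None` is not documented here, so the sketch simply requires a key.

```python
import jax.random as jr


def split_for_model(key, num_blocks):
    """Hypothetical helper: one independent subkey per model component."""
    # Split into num_blocks + 2 keys: encoder, one per block, and head.
    keys = jr.split(key, num_blocks + 2)
    encoder_key, head_key = keys[0], keys[-1]
    block_keys = list(keys[1:-1])
    return encoder_key, block_keys, head_key


key = jr.PRNGKey(0)
encoder_key, block_keys, head_key = split_for_model(key, num_blocks=4)
# Each subkey yields independent random draws, e.g. for weight initialization.
w_enc = jr.normal(encoder_key, (784, 64))
w_head = jr.normal(head_key, (64, 10))
print(len(block_keys))  # 4
```

Splitting once up front keeps initialization reproducible: the same top-level key always produces the same per-component keys.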