S5 Model
linax.models.s5.S5Config
Configuration for S5 models.
This is a modular configuration that allows building an S5 model with different components.
Attributes:
| Name | Type | Description |
|---|---|---|
| `num_blocks` | | Number of S5 blocks to stack. |
| `encoder_config` | | Configuration for the encoder. |
| `head_config` | | Configuration for the output head. |
| `sequence_mixer_config` | | Optional S5 sequence mixer config that will be replicated for each block. If not provided, defaults to `S5SequenceMixerConfig()`. |
| `block_config` | | Optional S5 block config that will be replicated for each block. If not provided, defaults to `StandardBlockConfig`. |
Example
    # The config classes below are assumed to be imported from linax already.
    import jax.random as jr

    # With explicit configs
    config = S5Config(
        num_blocks=4,
        encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
        sequence_mixer_config=S5SequenceMixerConfig(
            state_dim=64,
            ssm_blocks=1,
            conj_sym=True,
            clip_eigs=True,
            discretization="zoh",
        ),
        block_config=StandardBlockConfig(drop_rate=0.05),
        head_config=ClassificationHeadConfig(out_features=10),
    )

    # With defaults (simpler)
    config = S5Config(
        num_blocks=4,
        encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
        head_config=ClassificationHeadConfig(out_features=10),
    )

    # JAX random key used for parameter initialization
    key = jr.PRNGKey(0)
    model = config.build(key=key)
Reference
S5: https://openreview.net/pdf?id=Ai8Hw3AXqks
__post_init__()
Replicates the sequence mixer and block configs once per block, then validates the resulting configuration.
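The replicate-then-validate step described for `__post_init__` can be illustrated with a small standalone sketch. It does not use linax: `ToyS5Config`, `MixerConfig`, `BlockConfig`, the per-block list fields, and the `num_blocks >= 1` check are all hypothetical stand-ins chosen for illustration, and the real implementation may store or validate things differently.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(frozen=True)
class MixerConfig:
    """Hypothetical stand-in for a sequence mixer config."""
    state_dim: int = 64


@dataclass(frozen=True)
class BlockConfig:
    """Hypothetical stand-in for a block config."""
    drop_rate: float = 0.0


@dataclass
class ToyS5Config:
    """Toy config illustrating the replicate-then-validate pattern."""
    num_blocks: int
    sequence_mixer_config: Optional[MixerConfig] = None
    block_config: Optional[BlockConfig] = None
    # Populated in __post_init__ with one entry per block (assumed layout).
    sequence_mixer_configs: List[MixerConfig] = field(init=False, default_factory=list)
    block_configs: List[BlockConfig] = field(init=False, default_factory=list)

    def __post_init__(self) -> None:
        # Validation step (assumed rule): at least one block is required.
        if self.num_blocks < 1:
            raise ValueError("num_blocks must be at least 1")
        # Fall back to default configs when none are provided.
        mixer = self.sequence_mixer_config
        if mixer is None:
            mixer = MixerConfig()
        block = self.block_config
        if block is None:
            block = BlockConfig()
        # Replicate the single config once per block.
        self.sequence_mixer_configs = [mixer] * self.num_blocks
        self.block_configs = [block] * self.num_blocks


cfg = ToyS5Config(num_blocks=4, sequence_mixer_config=MixerConfig(state_dim=16))
print(len(cfg.sequence_mixer_configs), cfg.block_configs[0])
```

Because the stand-in configs are frozen dataclasses, sharing one instance across all blocks is safe in this sketch; a mutable config would need a copy per block.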
build(key: PRNGKeyArray | None = None) -> SSM
Build an SSM model from this configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray \| None` | JAX random key for parameter initialization. | `None` |
Returns:
| Type | Description |
|---|---|
| `SSM` | Instantiated SSM model. |
Example
    import jax.random as jr

    config = SSMConfig(...)
    model = config.build(key=jr.PRNGKey(0))