
General SSM

linax.models.ssm.SSMConfig

Low-level configuration for State Space Models.

This is a fully modular, component-based configuration that provides fine-grained control over the SSM architecture. Each component config contains its own dimension parameters, making the configuration self-contained and composable.

Use this when:
- Building custom SSM architectures
- Mixing different component types
- Needing full control over each component's configuration

For pre-configured architectures (e.g., LinOSS), use high-level configs such as LinOSSConfig, which automatically compose the appropriate components.

Attributes:

- encoder_config: Configuration for the encoder that processes the input data. Must specify in_features and out_features (the hidden dimension, hidden_dim).
- sequence_mixer_configs: List of sequence mixer configurations, one per block. Each must be compatible with the encoder's out_features (hidden_dim).
- block_configs: List of block configurations, one per sequence mixer.
- head_config: Configuration for the output head. Must specify out_features; in_features is set automatically to match the encoder's out_features.
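The automatic wiring of the head's in_features described above can be sketched with plain dataclasses. This is a minimal illustration under stated assumptions, not linax's implementation: `EncoderCfg`, `HeadCfg`, and `wire_head` are hypothetical names, and the real config classes may expose different fields.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class EncoderCfg:
    # Hypothetical stand-in for an encoder config such as LinearEncoderConfig.
    in_features: int
    out_features: int  # the hidden dimension (hidden_dim)


@dataclass(frozen=True)
class HeadCfg:
    # Hypothetical stand-in for a head config; only out_features is user-supplied.
    out_features: int
    in_features: Optional[int] = None


def wire_head(encoder: EncoderCfg, head: HeadCfg) -> HeadCfg:
    """Return a copy of the head config whose in_features matches the encoder output."""
    return replace(head, in_features=encoder.out_features)


enc = EncoderCfg(in_features=784, out_features=128)
head = wire_head(enc, HeadCfg(out_features=10))
print(head)  # HeadCfg(out_features=10, in_features=128)
```

Using `dataclasses.replace` returns a new config instead of mutating the user's object, which keeps configs safe to share between models.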

Raises:

- ValueError: If the number of sequence_mixer_configs differs from the number of block_configs.
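The length check behind this ValueError can be sketched as a `__post_init__` validation on a dataclass. `MiniSSMConfig` is a hypothetical, stripped-down stand-in; where linax actually performs the check (for example, at construction time or inside `build`) is an assumption here.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class MiniSSMConfig:
    # Hypothetical stand-in for SSMConfig: only the two per-block lists are modeled.
    sequence_mixer_configs: List[Any] = field(default_factory=list)
    block_configs: List[Any] = field(default_factory=list)

    def __post_init__(self) -> None:
        # One block config is expected per sequence mixer config.
        n_mixers = len(self.sequence_mixer_configs)
        n_blocks = len(self.block_configs)
        if n_mixers != n_blocks:
            raise ValueError(
                f"Got {n_mixers} sequence_mixer_configs but {n_blocks} block_configs; "
                "the two lists must have the same length."
            )


MiniSSMConfig(sequence_mixer_configs=["mixer"] * 4, block_configs=["block"] * 4)  # OK

try:
    MiniSSMConfig(sequence_mixer_configs=["mixer"] * 4, block_configs=["block"] * 3)
except ValueError as err:
    print(err)
```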

Example
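import jax.random as jr
from linax.models.ssm import SSMConfig
# LinearEncoderConfig, LinOSSSequenceMixerConfig, StandardBlockConfig, GLUConfig and
# ClassificationHeadConfig must also be imported from linax (module paths not shown here).

key = jr.PRNGKey(0)
config = SSMConfig(
    encoder_config=LinearEncoderConfig(in_features=784, out_features=128),  # hidden_dim = 128
    sequence_mixer_configs=[LinOSSSequenceMixerConfig(state_dim=128)] * 4,  # 4 blocks
    block_configs=[StandardBlockConfig(drop_rate=0.1)] * 4,  # one per sequence mixer
    channel_mixer_configs=[GLUConfig()] * 4,
    head_config=ClassificationHeadConfig(out_features=10),  # 10 output classes
)
model = config.build(key=key)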
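The example builds each per-block list with `[cfg] * 4`. In Python this repeats a reference to one config object rather than creating four copies, which is harmless as long as the configs are treated as immutable. The sketch below uses a hypothetical frozen dataclass (`BlockCfg`) to show the difference; whether linax's own config classes are frozen is an assumption.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class BlockCfg:
    # Hypothetical stand-in for a block config such as StandardBlockConfig.
    drop_rate: float = 0.0


# `[cfg] * 4` repeats one reference: all four entries are the same object.
shared = [BlockCfg(drop_rate=0.1)] * 4
print(all(cfg is shared[0] for cfg in shared))  # True

# Because the dataclass is frozen, the shared object cannot be mutated by accident.
try:
    shared[0].drop_rate = 0.5
except FrozenInstanceError:
    print("frozen")

# When blocks need different settings, build distinct objects instead.
per_block = [BlockCfg(drop_rate=r) for r in (0.0, 0.1, 0.1, 0.2)]
print([cfg.drop_rate for cfg in per_block])  # [0.0, 0.1, 0.1, 0.2]
```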
build(key: PRNGKeyArray | None = None) -> SSM

Build an SSM model from this configuration.

Parameters:

- key (PRNGKeyArray | None, default: None): JAX random key for parameter initialization.

Returns:

- SSM: The instantiated SSM model.

Example
import jax.random as jr

config = SSMConfig(...)  # fill in the component configs as in the example above
model = config.build(key=jr.PRNGKey(0))

linax.models.ssm.SSM

General State Space Model (SSM) implementation.

This is a flexible, composable SSM architecture that can be configured with different encoders, sequence mixers, blocks, and heads. It serves as the base implementation for all SSM variants in linax.

The model applies components in the following order:
1. Encoder: transforms the input to the hidden dimension.
2. Blocks: a stack of (sequence mixer + channel mixer) layers.
3. Head: produces the final output (classification, regression, etc.).
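The encoder, blocks, head order, together with the stateful `(output, state)` convention used by `__call__`, can be sketched with plain Python callables. Everything here is hypothetical: `forward` and the toy components stand in for linax's encoder, block, and head modules, a dict stands in for `eqx.nn.State`, and the random key that the real blocks may consume (for example, for dropout) is omitted.

```python
from typing import Callable, Dict, List, Tuple

State = Dict[str, int]
Vector = List[float]
# Hypothetical block signature following the (x, state) -> (x, state) convention.
Block = Callable[[Vector, State], Tuple[Vector, State]]


def forward(
    encoder: Callable[[Vector], Vector],
    blocks: List[Block],
    head: Callable[[Vector], Vector],
    x: Vector,
    state: State,
) -> Tuple[Vector, State]:
    h = encoder(x)  # 1. map the raw input to the hidden representation
    for block in blocks:  # 2. apply the blocks in order,
        h, state = block(h, state)  # threading the updated state through each one
    return head(h), state  # 3. produce the final output


def make_block(scale: float) -> Block:
    # Toy block: scales the hidden vector and counts how many blocks have run.
    def block(h: Vector, state: State) -> Tuple[Vector, State]:
        new_state = {**state, "calls": state.get("calls", 0) + 1}
        return [scale * v for v in h], new_state

    return block


def toy_encoder(x: Vector) -> Vector:
    return x + [0.0]  # widen the input by one feature


def toy_head(h: Vector) -> Vector:
    return [sum(h)]  # reduce to a single output


out, state = forward(toy_encoder, [make_block(2.0)] * 3, toy_head, [1.0, 2.0], {})
print(out, state)  # [24.0] {'calls': 3}
```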

Parameters:

- cfg (ConfigType, required): Low-level configuration specifying all components (see SSMConfig).
- key (PRNGKeyArray, required): JAX random key for parameter initialization.

Attributes:

- encoder: The encoder instance that processes raw inputs.
- blocks: List of block instances, each containing a sequence mixer and a channel mixer.
- head: The output head instance that produces the final predictions.

__init__(cfg: ConfigType, key: PRNGKeyArray)

__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]

Forward pass of the SSM model.

Parameters:

- x (Array, required): Input tensor.
- state (eqx.nn.State, required): Current state of the stateful layers.
- key (PRNGKeyArray, required): JAX random key for stochastic operations, such as dropout in the blocks.

Returns:

- tuple[Array, eqx.nn.State]: A tuple containing the output tensor and the updated state.
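Because `__call__` returns the updated state rather than modifying it in place, a caller would typically pass the returned state into the next call. A minimal sketch of that pattern, assuming a hypothetical toy layer (`running_mean_model`) and a dict in place of `eqx.nn.State`; the random-key argument is left out for brevity.

```python
from typing import Dict, List, Tuple

State = Dict[str, float]


def running_mean_model(x: List[float], state: State) -> Tuple[List[float], State]:
    # Hypothetical stateful layer: tracks a running mean of its inputs
    # (similar in spirit to batch-norm statistics) and centers x with it.
    count = state["count"] + 1
    mean = state["mean"] + (sum(x) / len(x) - state["mean"]) / count
    new_state = {"count": count, "mean": mean}
    return [v - mean for v in x], new_state


state: State = {"count": 0.0, "mean": 0.0}
for batch in ([1.0, 3.0], [5.0, 7.0]):
    # Always feed the state returned by the previous call into the next one.
    out, state = running_mean_model(batch, state)

print(out, state)  # [1.0, 3.0] {'count': 2.0, 'mean': 4.0}
```

Returning a new state instead of mutating the old one keeps the call a pure function, which is what allows it to be traced by JAX transformations such as `jax.jit`.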