General SSM
linax.models.ssm.SSMConfig
Low-level configuration for State Space Models.
This is a fully modular, component-based configuration that provides fine-grained control over the SSM architecture. Each component config contains its own dimension parameters, making the configuration self-contained and composable.
Use this when:
- Building custom SSM architectures
- Mixing different component types
- Needing full control over each component's configuration
For pre-configured architectures (e.g., LinOSS), use a high-level config such as `LinOSSConfig`, which automatically composes the appropriate components.
Attributes:

| Name | Description |
|---|---|
| `encoder_config` | Configuration for the encoder that processes input data. Must specify `in_features` and `out_features` (the hidden dimension). |
| `sequence_mixer_configs` | List of configurations for sequence mixers, one per block. Must be compatible with the encoder's `out_features` (the hidden dimension). |
| `block_configs` | List of configurations for blocks, one per sequence mixer. |
| `head_config` | Configuration for the output head. Must specify `out_features`. The `in_features` will be automatically set to match the encoder's `out_features`. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `sequence_mixer_configs` and `block_configs` have different lengths. |
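The two rules above (matching lengths for `sequence_mixer_configs` and `block_configs`, and the head's `in_features` following the encoder's `out_features`) can be illustrated with plain dataclasses. This is a minimal sketch under stated assumptions, not linax's implementation: `EncoderCfg`, `HeadCfg`, and `ToySSMConfig` are hypothetical stand-ins, the mixer and block configs are reduced to opaque placeholders, and doing the wiring in `__post_init__` is an assumption (linax may instead do it inside `build`).

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class EncoderCfg:  # hypothetical stand-in for an encoder config
    in_features: int
    out_features: int  # the hidden dimension


@dataclass(frozen=True)
class HeadCfg:  # hypothetical stand-in for a head config
    out_features: int
    in_features: Optional[int] = None  # filled in from the encoder


@dataclass(frozen=True)
class ToySSMConfig:  # hypothetical; mirrors only the documented checks
    encoder_config: EncoderCfg
    sequence_mixer_configs: list[Any]
    block_configs: list[Any]
    head_config: HeadCfg

    def __post_init__(self) -> None:
        # One block per sequence mixer: mismatched lengths are rejected.
        if len(self.sequence_mixer_configs) != len(self.block_configs):
            raise ValueError(
                f"Got {len(self.sequence_mixer_configs)} sequence mixer configs "
                f"but {len(self.block_configs)} block configs."
            )
        # Wire the head's input size to the encoder's hidden dimension.
        # frozen=True forbids normal assignment, so use object.__setattr__.
        wired = replace(self.head_config, in_features=self.encoder_config.out_features)
        object.__setattr__(self, "head_config", wired)


cfg = ToySSMConfig(
    encoder_config=EncoderCfg(in_features=784, out_features=128),
    sequence_mixer_configs=["mixer"] * 4,
    block_configs=["block"] * 4,
    head_config=HeadCfg(out_features=10),
)
print(cfg.head_config.in_features)  # 128

try:
    ToySSMConfig(EncoderCfg(784, 128), ["mixer"] * 4, ["block"] * 3, HeadCfg(10))
except ValueError as err:
    print(err)  # Got 4 sequence mixer configs but 3 block configs.
```

Validating at construction time means a malformed configuration fails before any parameters are initialized.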
Example
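```python
import jax.random as jr
from linax.models.ssm import SSMConfig

# LinearEncoderConfig, LinOSSSequenceMixerConfig, StandardBlockConfig,
# GLUConfig and ClassificationHeadConfig also come from linax; their
# import paths are not shown on this page.
key = jr.PRNGKey(0)
config = SSMConfig(
    encoder_config=LinearEncoderConfig(in_features=784, out_features=128),
    sequence_mixer_configs=[LinOSSSequenceMixerConfig(state_dim=128)] * 4,
    block_configs=[StandardBlockConfig(drop_rate=0.1)] * 4,
    channel_mixer_configs=[GLUConfig()] * 4,
    head_config=ClassificationHeadConfig(out_features=10),
)
model = config.build(key=key)
```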
build(key: PRNGKeyArray | None = None) -> SSM
Build an SSM model from this configuration.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray \| None` | JAX random key for parameter initialization. | `None` |
Returns:

| Type | Description |
|---|---|
| `SSM` | Instantiated SSM model. |
Example
```python
import jax.random as jr

config = SSMConfig(...)  # component configs as in the SSMConfig example
model = config.build(key=jr.PRNGKey(0))
```
linax.models.ssm.SSM
General State Space Model (SSM) implementation.
This is a flexible, composable SSM architecture that can be configured with different encoders, sequence mixers, blocks, and heads. It serves as the base implementation for all SSM variants in linax.
The model applies components in the following order:
1. Encoder: transforms the input to the hidden dimension.
2. Blocks: a stack of (sequence mixer + channel mixer) layers.
3. Head: produces the final output (classification, regression, etc.).
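To make the three-stage order concrete, here is a toy JAX sketch of encoder, then blocks, then head, on a single unbatched sequence of shape `(length, features)`. Every component is a hypothetical stand-in, not linax code: the sequence mixer is a fixed-decay diagonal recurrence computed with `jax.lax.scan`, the channel mixer is a GELU projection, each block is wrapped in a residual connection, and the head mean-pools over time before a linear layer. The names `init_params`, `sequence_mixer`, and `forward` are illustrative only.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_params(key, in_dim, hidden, out_dim, num_blocks):
    # Hypothetical parameters: one matrix per stage, scaled by 1/sqrt(fan_in).
    k_enc, k_head, *k_blocks = jr.split(key, 2 + num_blocks)
    return {
        "encoder": jr.normal(k_enc, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "blocks": [jr.normal(k, (hidden, hidden)) / jnp.sqrt(hidden) for k in k_blocks],
        "head": jr.normal(k_head, (hidden, out_dim)) / jnp.sqrt(hidden),
    }


def sequence_mixer(u, decay=0.9):
    # Toy diagonal linear recurrence h_t = decay * h_{t-1} + u_t (mixes along time).
    def step(h, u_t):
        h = decay * h + u_t
        return h, h

    _, hs = jax.lax.scan(step, jnp.zeros(u.shape[-1]), u)
    return hs


def forward(params, x):
    h = x @ params["encoder"]  # 1. encoder: (T, in_dim) -> (T, hidden)
    for w in params["blocks"]:  # 2. blocks: sequence mixer, then channel mixer
        h = h + jax.nn.gelu(sequence_mixer(h) @ w)  # residual connection (assumed)
    return h.mean(axis=0) @ params["head"]  # 3. head: pool over time, then project


params = init_params(jr.PRNGKey(0), in_dim=16, hidden=32, out_dim=10, num_blocks=4)
x = jnp.ones((100, 16))  # (sequence length, features)
print(forward(params, x).shape)  # (10,)
```

The hidden dimension stays fixed across all blocks, which is why the sequence mixers must be compatible with the encoder's `out_features` and why the head's `in_features` can be inferred from the encoder.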
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `ConfigType` | Low-level configuration specifying all components (see | required |
| `key` | `PRNGKeyArray` | JAX random key for parameter initialization. | required |
Attributes:

| Name | Description |
|---|---|
| `encoder` | The encoder instance that processes raw inputs. |
| `blocks` | List of block instances, each containing a sequence mixer and a channel mixer. |
| `head` | The output head instance that produces final predictions. |
__init__(cfg: ConfigType, key: PRNGKeyArray)
__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]
Forward pass of the SSM model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for stochastic operations (e.g., dropout). | required |
Returns:

| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and updated state. |