
LinOSS Model

linax.models.linoss.LinOSSConfig

Configuration for LinOSS models.

A modular configuration for building a LinOSS model from separately configurable components: an encoder, a stack of LinOSS blocks (each with a sequence mixer), and an output head.

Attributes:

num_blocks: Number of LinOSS blocks to stack.

encoder_config: Configuration for the encoder (contains in_features and out_features).

head_config: Configuration for the output head (contains out_features).

sequence_mixer_config: Optional LinOSS sequence mixer config that will be replicated for each block. If not provided, defaults to LinOSSSequenceMixerConfig().

block_config: Optional LinOSS block config that will be replicated for each block. If not provided, defaults to StandardBlockConfig().

Example
import jax.random as jr

from linax.models.linoss import LinOSSConfig
# LinearEncoderConfig, LinOSSSequenceMixerConfig, StandardBlockConfig and
# ClassificationHeadConfig also come from linax; their import paths are not shown on this page.

key = jr.PRNGKey(0)

# With explicit configs
config = LinOSSConfig(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    sequence_mixer_config=LinOSSSequenceMixerConfig(state_dim=64),
    block_config=StandardBlockConfig(drop_rate=0.1),
    head_config=ClassificationHeadConfig(out_features=10),
)

# With defaults (simpler)
config = LinOSSConfig(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    head_config=ClassificationHeadConfig(out_features=10),
)
model = config.build(key=key)
Reference

LinOSS: https://openreview.net/pdf?id=GRMfXcAAFh

__post_init__()

Replicates the sequence mixer and block configs once per block and validates the resulting configuration.
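The replication behavior described for sequence_mixer_config and block_config (one shared config copied to every block, with a default when omitted) can be illustrated with a minimal pure-Python sketch. It assumes a simple copy-per-block strategy; ModelConfig, MixerConfig and BlockConfig are hypothetical stand-ins rather than linax classes, and the num_blocks check is only an assumed example of what "validates" might cover.

```python
from dataclasses import dataclass, field
from typing import List, Optional


# Hypothetical stand-ins for LinOSSSequenceMixerConfig and StandardBlockConfig.
# Field names and default values are illustrative only.
@dataclass(frozen=True)
class MixerConfig:
    state_dim: int = 64


@dataclass(frozen=True)
class BlockConfig:
    drop_rate: float = 0.0


@dataclass
class ModelConfig:
    num_blocks: int
    sequence_mixer_config: Optional[MixerConfig] = None
    block_config: Optional[BlockConfig] = None
    # Filled in by __post_init__ with one entry per block.
    sequence_mixer_configs: List[MixerConfig] = field(init=False)
    block_configs: List[BlockConfig] = field(init=False)

    def __post_init__(self) -> None:
        # Assumed validation: at least one block is required.
        if self.num_blocks < 1:
            raise ValueError("num_blocks must be >= 1")
        # Fall back to default configs when none were provided.
        mixer = self.sequence_mixer_config if self.sequence_mixer_config is not None else MixerConfig()
        block = self.block_config if self.block_config is not None else BlockConfig()
        # Replicate one config per block; sharing one object is safe here
        # because the stand-in configs are frozen (immutable).
        self.sequence_mixer_configs = [mixer] * self.num_blocks
        self.block_configs = [block] * self.num_blocks


cfg = ModelConfig(num_blocks=4, block_config=BlockConfig(drop_rate=0.1))
print(len(cfg.block_configs), cfg.sequence_mixer_configs[0])
```

Using frozen configs makes it harmless for every block to reference the same object; with mutable configs, a real implementation would more likely deep-copy each one.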

build(key: PRNGKeyArray | None = None) -> SSM

Build an SSM model from this configuration.

Parameters:

key (PRNGKeyArray | None, default None): JAX random key for parameter initialization.

Returns:

SSM: Instantiated SSM model.
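build takes a single key, while the model has several independently initialized parts (encoder, blocks, head). A common JAX pattern for this is to split the key once and hand one subkey to each part. The sketch below shows that pattern with a hypothetical helper, split_build_keys; it is an assumption about how such a split could look, not linax's implementation, and it does not describe how build handles key=None.

```python
import jax.random as jr


def split_build_keys(key, num_blocks: int):
    """Hypothetical helper: one key for the encoder, one per block, one for the head."""
    # A single split yields num_blocks + 2 independent subkeys.
    keys = jr.split(key, num_blocks + 2)
    encoder_key = keys[0]
    block_keys = keys[1:-1]  # one key per block
    head_key = keys[-1]
    return encoder_key, block_keys, head_key


encoder_key, block_keys, head_key = split_build_keys(jr.PRNGKey(0), num_blocks=4)
print(block_keys.shape)
```

Giving each component its own subkey keeps their random initializations independent of one another, which is why JAX code generally avoids reusing one key for several draws.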

Example
import jax.random as jr

config = SSMConfig(...)
model = config.build(key=jr.PRNGKey(0))