LinOSS Model
linax.models.linoss.LinOSSConfig
Configuration for LinOSS models.
A modular configuration for building a LinOSS model from interchangeable components: an encoder, a stack of blocks (each containing a sequence mixer), and an output head.
Attributes:
| Name | Type | Description |
|---|---|---|
| `num_blocks` | | Number of LinOSS blocks to stack. |
| `encoder_config` | | Configuration for the encoder (contains `in_features` and `out_features`). |
| `head_config` | | Configuration for the output head (contains `out_features`). |
| `sequence_mixer_config` | | Optional LinOSS sequence-mixer config, replicated for each block. Defaults to `LinOSSSequenceMixerConfig()` if not provided. |
| `block_config` | | Optional block config, replicated for each block. Defaults to `StandardBlockConfig()` if not provided. |
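Example
```python
import jax.random as jr
from linax.models.linoss import LinOSSConfig
# LinearEncoderConfig, LinOSSSequenceMixerConfig, StandardBlockConfig and
# ClassificationHeadConfig must also be imported from linax
# (their module paths are not shown on this page).

# With explicit configs
config = LinOSSConfig(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    sequence_mixer_config=LinOSSSequenceMixerConfig(state_dim=64),
    block_config=StandardBlockConfig(drop_rate=0.1),
    head_config=ClassificationHeadConfig(out_features=10),
)

# With defaults (simpler): the sequence mixer and block configs fall back to
# LinOSSSequenceMixerConfig() and StandardBlockConfig()
config = LinOSSConfig(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    head_config=ClassificationHeadConfig(out_features=10),
)
key = jr.PRNGKey(0)
model = config.build(key=key)
```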
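The example assembles a model from an encoder, a stack of `num_blocks` blocks, and a head. The following rough, shape-level sketch shows how such a stack could compose at call time, using the sizes from the example (784 input features, 64 hidden features, 4 blocks, 10 classes). It is plain NumPy, not linax: `encoder`, `block`, `head`, and `forward` are hypothetical stand-ins, and the residual block and the mean-pooling head are assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical stand-ins, NOT linax code: sizes mirror the example above.
rng = np.random.default_rng(0)
seq_len, in_features, hidden, num_classes, num_blocks = 100, 784, 64, 10, 4

W_enc = rng.normal(size=(in_features, hidden)) / np.sqrt(in_features)
W_head = rng.normal(size=(hidden, num_classes)) / np.sqrt(hidden)


def encoder(x):
    # Linear encoder: (seq_len, in_features) -> (seq_len, hidden)
    return x @ W_enc


def block(h):
    # Placeholder for one block: a shape-preserving map with a residual
    # connection. A real LinOSS block would apply its sequence mixer here.
    return h + np.tanh(h)


def head(h):
    # Assumed classification head: mean-pool over time, then project to logits.
    return h.mean(axis=0) @ W_head


def forward(x):
    h = encoder(x)
    for _ in range(num_blocks):  # the same block structure, stacked num_blocks times
        h = block(h)
    return head(h)


x = rng.normal(size=(seq_len, in_features))
logits = forward(x)
print(logits.shape)  # (10,)
```

Because every block preserves the hidden width, `num_blocks` can be changed without touching the encoder or head configs.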
Reference
LinOSS: https://openreview.net/pdf?id=GRMfXcAAFh
__post_init__()
Replicates the sequence-mixer and block configs once per block (filling in defaults when they are not provided) and validates the configuration.
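The replicate-and-validate step can be pictured with a minimal sketch. It assumes frozen dataclass configs and stores the per-block copies in new list attributes; `MixerConfig`, `BlockConfig`, `ToyLinOSSConfig`, `sequence_mixer_configs`, and `block_configs` are hypothetical names, and the specific validation check is an assumption, not linax's actual logic.

```python
from dataclasses import dataclass, replace
from typing import Optional


# Hypothetical stand-ins for the linax config classes (fields are illustrative).
@dataclass(frozen=True)
class MixerConfig:
    state_dim: int = 64


@dataclass(frozen=True)
class BlockConfig:
    drop_rate: float = 0.0


@dataclass
class ToyLinOSSConfig:
    num_blocks: int
    sequence_mixer_config: Optional[MixerConfig] = None
    block_config: Optional[BlockConfig] = None

    def __post_init__(self):
        # Assumed validation: at least one block is required.
        if self.num_blocks < 1:
            raise ValueError("num_blocks must be >= 1")
        # Fall back to default configs when none were given.
        mixer = self.sequence_mixer_config if self.sequence_mixer_config is not None else MixerConfig()
        block = self.block_config if self.block_config is not None else BlockConfig()
        # Replicate: one independent copy per block (replace() with no changes copies).
        self.sequence_mixer_configs = [replace(mixer) for _ in range(self.num_blocks)]
        self.block_configs = [replace(block) for _ in range(self.num_blocks)]


cfg = ToyLinOSSConfig(num_blocks=4, sequence_mixer_config=MixerConfig(state_dim=32))
print(len(cfg.sequence_mixer_configs))  # 4
```

Here the explicit mixer config is copied to all four blocks, while the block config falls back to `BlockConfig()`.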
build(key: PRNGKeyArray | None = None) -> SSM
Build an SSM model from this configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray \| None` | JAX random key for parameter initialization. | `None` |
Returns:
| Type | Description |
|---|---|
| `SSM` | Instantiated SSM model. |
Example
```python
import jax.random as jr

config = SSMConfig(...)  # any concrete config, e.g. LinOSSConfig
model = config.build(key=jr.PRNGKey(0))
```
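`build` takes a single optional PRNG key for parameter initialization. This page does not document how that key is distributed across components or what happens when `key` is `None`. As a hedged sketch of a common JAX pattern (an assumption, not confirmed linax behavior), one key can be split with `jax.random.split` into independent subkeys for the encoder, each block, and the head. `split_model_keys` is a hypothetical helper.

```python
import jax.random as jr


def split_model_keys(key, num_blocks):
    """Hypothetical helper (not part of linax): split one key into
    independent subkeys for the encoder, each block, and the head."""
    # num_blocks + 2 subkeys: one for the encoder, one per block, one for the head.
    keys = jr.split(key, num_blocks + 2)
    encoder_key = keys[0]
    block_keys = list(keys[1:-1])
    head_key = keys[-1]
    return encoder_key, block_keys, head_key


encoder_key, block_keys, head_key = split_model_keys(jr.PRNGKey(0), num_blocks=4)
print(len(block_keys))  # 4
```

Splitting, rather than reusing one key everywhere, keeps the random initializations of different components independent of each other.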