LRU Model
linax.models.lru.LRUConfig
Configuration for LRU models.
This is a modular configuration that allows building an LRU model with different components.
Attributes:
| Name | Type | Description |
|---|---|---|
| num_blocks | | Number of LRU blocks to stack. |
| encoder_config | | Configuration for the encoder. |
| head_config | | Configuration for the output head. |
| sequence_mixer_config | | Optional LRU sequence mixer config that will be replicated for each block. If not provided, defaults to LRUSequenceMixerConfig(). |
| block_config | | Optional LRU block config that will be replicated for each block. If not provided, defaults to StandardBlockConfig. |
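Example
```python
import jax.random as jr

from linax.models.lru import LRUConfig
# The encoder, sequence mixer, block, and head config classes used below
# must also be imported from linax; their module paths are not listed here.

# With explicit configs
config = LRUConfig(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    sequence_mixer_config=LRUSequenceMixerConfig(
        state_dim=64,
        r_min=0.0,
        r_max=1.0,
        max_phase=6.28,
    ),
    block_config=StandardBlockConfig(drop_rate=0.1),
    head_config=ClassificationHeadConfig(out_features=10),
)

# With defaults (simpler)
config = LRUConfig(
    num_blocks=4,
    encoder_config=LinearEncoderConfig(in_features=784, out_features=64),
    head_config=ClassificationHeadConfig(out_features=10),
)

key = jr.PRNGKey(0)  # JAX random key for parameter initialization
model = config.build(key=key)
```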
Reference
LRU: https://proceedings.mlr.press/v202/orvieto23a/orvieto23a.pdf
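The `r_min`, `r_max`, and `max_phase` arguments in the example above match the names used in the referenced LRU paper, where the diagonal recurrence eigenvalues are initialized uniformly on a ring in the complex plane: magnitudes between `r_min` and `r_max`, phases between 0 and `max_phase`. The sketch below illustrates that sampling scheme in plain Python. It assumes linax follows the paper's initialization, which this page does not confirm, and the function name `sample_ring_eigenvalues` is hypothetical rather than part of linax.

```python
import cmath
import math
import random


def sample_ring_eigenvalues(state_dim, r_min=0.0, r_max=1.0, max_phase=6.28, seed=0):
    """Sample complex eigenvalues uniformly on a ring (illustrative only)."""
    rng = random.Random(seed)
    eigenvalues = []
    for _ in range(state_dim):
        u1 = rng.random()
        u2 = rng.random()
        # Sampling the squared radius uniformly in [r_min^2, r_max^2]
        # gives a uniform density over the area of the ring.
        radius = math.sqrt(u1 * (r_max**2 - r_min**2) + r_min**2)
        # Phase is uniform in [0, max_phase).
        phase = max_phase * u2
        eigenvalues.append(cmath.rect(radius, phase))
    return eigenvalues


eigs = sample_ring_eigenvalues(state_dim=64, r_min=0.5, r_max=0.9)
```

Choosing `r_max` below 1 keeps every eigenvalue inside the unit circle, so the linear recurrence stays stable, while an `r_min` close to `r_max` concentrates the eigenvalues near the boundary and lengthens the effective memory.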
__post_init__()
Replicates the sequence mixer and block configs for each block, then validates the configuration.
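As a rough illustration of the replicate-then-validate pattern, the sketch below uses standard-library dataclasses. The classes `ToyMixerConfig` and `ToyModelConfig`, the `mixer_configs` field, and the specific validation check are all hypothetical stand-ins; the actual linax implementation may differ.

```python
from dataclasses import dataclass, field, replace
from typing import List, Optional


@dataclass(frozen=True)
class ToyMixerConfig:
    """Hypothetical stand-in for a per-block sequence mixer config."""
    state_dim: int = 64


@dataclass
class ToyModelConfig:
    """Hypothetical model config that expands one optional config per block."""
    num_blocks: int
    mixer_config: Optional[ToyMixerConfig] = None
    # Filled in by __post_init__, one entry per block.
    mixer_configs: List[ToyMixerConfig] = field(init=False)

    def __post_init__(self):
        # Validate before replicating.
        if self.num_blocks < 1:
            raise ValueError("num_blocks must be at least 1")
        # Fall back to the default config when none is provided.
        base = self.mixer_config if self.mixer_config is not None else ToyMixerConfig()
        # replace() with no overrides returns a fresh copy for each block.
        self.mixer_configs = [replace(base) for _ in range(self.num_blocks)]


toy = ToyModelConfig(num_blocks=4)
custom = ToyModelConfig(num_blocks=2, mixer_config=ToyMixerConfig(state_dim=8))
```

Here `toy.mixer_configs` holds four default configs and `custom.mixer_configs` holds two copies of the supplied one, so callers only spell out per-block settings when they need to.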
build(key: PRNGKeyArray | None = None) -> SSM
Build an SSM model from this configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | `PRNGKeyArray \| None` | JAX random key for parameter initialization. | None |
Returns:
| Type | Description |
|---|---|
| SSM | Instantiated SSM model. |
Example
```python
import jax.random as jr

config = SSMConfig(...)
model = config.build(key=jr.PRNGKey(0))
```