LRU Sequence Mixer
linax.sequence_mixers.lru.LRUSequenceMixerConfig
Configuration for the LRU sequence mixer.
This configuration class defines the hyperparameters for the LRU sequence mixer. LRU uses complex-valued diagonal state matrices for efficient sequence modeling.
Attributes:
| Name | Type | Description |
|---|---|---|
| `state_dim` | | Dimensionality of the state space. |
| `r_min` | | Minimum radius for the complex-valued eigenvalues. |
| `r_max` | | Maximum radius for the complex-valued eigenvalues. |
| `max_phase` | | Maximum phase angle for the complex-valued eigenvalues. |
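
To illustrate how `r_min`, `r_max`, and `max_phase` could shape the eigenvalues, the sketch below follows the ring initialization described in the original LRU paper: magnitudes are sampled between `r_min` and `r_max`, and phases between 0 and `max_phase`. The function name `init_lru_eigenvalues` is hypothetical, and it is an assumption that linax initializes its parameters this way.

```python
import jax
import jax.numpy as jnp


def init_lru_eigenvalues(key, state_dim, r_min, r_max, max_phase):
    """Hypothetical helper: sample nu_log and theta_log on a ring in the unit disk.

    Assumes 0 < r_min < r_max < 1 so that every logarithm below is finite.
    """
    u1_key, u2_key = jax.random.split(key)
    u1 = jax.random.uniform(u1_key, (state_dim,))
    u2 = jax.random.uniform(u2_key, (state_dim,))
    # |lambda|^2 is uniform in [r_min^2, r_max^2); nu = -0.5 * log(|lambda|^2).
    nu_log = jnp.log(-0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2))
    # The phase theta is uniform in [0, max_phase).
    theta_log = jnp.log(max_phase * u2)
    return nu_log, theta_log


nu_log, theta_log = init_lru_eigenvalues(
    jax.random.PRNGKey(0), state_dim=64, r_min=0.4, r_max=0.9, max_phase=3.14
)
magnitudes = jnp.exp(-jnp.exp(nu_log))  # lies between r_min and r_max
phases = jnp.exp(theta_log)             # lies between 0 and max_phase
```

Under this scheme, a narrow ring close to the unit circle (for example `r_min` near `r_max` near 1) gives slowly decaying states, while a small `max_phase` limits how fast the states oscillate.
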
build(in_features: int, key: PRNGKeyArray) -> LRUSequenceMixer
Build sequence mixer from config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| `LRUSequenceMixer` | The sequence mixer instance. |
linax.sequence_mixers.lru.LRUSequenceMixer
LRU sequence mixer layer.
This layer implements the Linear Recurrent Unit (LRU) sequence mixer using complex-valued diagonal state matrices for efficient and expressive sequence modeling.
Attributes:
| Name | Type | Description |
|---|---|---|
| `nu_log` | | Log of nu parameter (controls eigenvalue magnitudes). |
| `theta_log` | | Log of theta parameter (controls eigenvalue phases). |
| `B_re` | | Real part of input projection matrix. |
| `B_im` | | Imaginary part of input projection matrix. |
| `C_re` | | Real part of output projection matrix. |
| `C_im` | | Imaginary part of output projection matrix. |
| `D` | | Skip connection weights. |
| `gamma_log` | | Log of normalization factor. |
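
As a rough guide to how these real-valued attributes might combine into the complex quantities of the recurrence, the sketch below uses the parameterization from the original LRU paper. The helper name `assemble_lru_params` and the assumed shapes (`B_*` as `(state_dim, in_features)`, `C_*` as `(in_features, state_dim)`) are illustrative and not taken from linax.

```python
import jax.numpy as jnp


def assemble_lru_params(nu_log, theta_log, B_re, B_im, C_re, C_im, gamma_log):
    """Hypothetical helper: build the complex eigenvalues and projections."""
    # lambda = exp(-exp(nu_log) + i * exp(theta_log)); since exp(nu_log) > 0,
    # the magnitude exp(-exp(nu_log)) is always below 1, so the recurrence is stable.
    lam = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    # The input projection is scaled row-wise by the normalization factor gamma.
    B = (B_re + 1j * B_im) * jnp.exp(gamma_log)[:, None]
    C = C_re + 1j * C_im
    return lam, B, C
```

Storing `nu_log` and `theta_log` rather than the eigenvalues themselves keeps the magnitudes inside the unit disk for any real parameter value, which is why the attributes are kept in log form.
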
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | required |
| `cfg` | `ConfigType` | Configuration for the LRU sequence mixer. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray)
Initialize the LRU sequence mixer layer.
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the LRU sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input sequence of features. | required |
| `key` | `PRNGKeyArray` | JAX random key (unused, for compatibility). | required |
Returns:
| Type | Description |
|---|---|
| `Array` | The output of the LRU sequence mixer. |
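
For intuition about what a forward pass of a diagonal linear recurrence involves, here is a minimal sketch built on `jax.lax.associative_scan`. It is not the linax implementation: the function name `lru_forward`, the argument layout, the input shape `(seq_len, in_features)`, and the element-wise skip connection `D * x` are all assumptions, and the unused `key` argument of `__call__` is omitted.

```python
import jax
import jax.numpy as jnp


def lru_forward(x, lam, B, C, D):
    """Hypothetical sketch of h_t = lam * h_{t-1} + B x_t, y_t = Re(C h_t) + D * x_t.

    x: (seq_len, in_features) real; lam: (state_dim,) complex;
    B: (state_dim, in_features) complex; C: (in_features, state_dim) complex;
    D: (in_features,) real. The initial state is zero.
    """
    Bu = x.astype(B.dtype) @ B.T                 # (seq_len, state_dim)
    lam_seq = jnp.broadcast_to(lam, Bu.shape)    # one multiplier per time step

    def combine(left, right):
        # Compose two affine maps h -> a * h + b, applying `left` first.
        a_l, b_l = left
        a_r, b_r = right
        return a_l * a_r, a_r * b_l + b_r

    # A parallel prefix scan over time yields every hidden state h_t.
    _, h = jax.lax.associative_scan(combine, (lam_seq, Bu))
    return (h @ C.T).real + D * x
```

Because the state matrix is diagonal, each state channel evolves independently, so the scan can be evaluated in parallel over the sequence instead of step by step.
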