LRU Sequence Mixer

discretax.sequence_mixers.lru.LRUSequenceMixer

LRU sequence mixer layer.

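This layer implements the Linear Recurrent Unit (LRU) sequence mixer. Its state matrix is diagonal and complex-valued. Because the matrix is diagonal, each state update is an elementwise multiplication rather than a full matrix-vector product. Because the eigenvalues are complex, each state dimension can oscillate as well as decay.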
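
The recurrence behind the layer can be sketched sequentially in JAX. The sketch follows the formulation of the original LRU paper, h_t = λ ⊙ h_{t-1} + B x_t and y_t = Re(C h_t) + D ⊙ x_t. The function name `lru_recurrence` and the array shapes are hypothetical illustrations, not part of the discretax API.

```python
import jax
import jax.numpy as jnp


def lru_recurrence(lam, B, C, D, x):
    """Run h_t = lam * h_{t-1} + B @ x_t and y_t = Re(C @ h_t) + D * x_t step by step.

    Hypothetical shapes: lam (N,) complex, B (N, H) complex, C (H, N) complex,
    D (H,) real, x (L, H) real. Returns y with shape (L, H).
    """

    def step(h, x_t):
        h = lam * h + B @ x_t          # diagonal update: elementwise multiply, no N x N matmul
        y_t = (C @ h).real + D * x_t   # read out a real feature vector and add the skip path
        return h, y_t

    h0 = jnp.zeros(lam.shape, dtype=lam.dtype)  # state starts at zero
    _, y = jax.lax.scan(step, h0, x)             # scan over the leading (time) axis
    return y
```

Consider one state with lam = 0.5, B = C = 1, D = 0, and the impulse input x = [[1], [0], [0]]. The output is [[1], [0.5], [0.25]]. The magnitude of each eigenvalue therefore sets how quickly that state dimension forgets its input.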

Attributes:

nu_log: Log of the nu parameter, which controls the eigenvalue magnitudes.
theta_log: Log of the theta parameter, which controls the eigenvalue phases.
B_re: Real part of the input projection matrix.
B_im: Imaginary part of the input projection matrix.
C_re: Real part of the output projection matrix.
C_im: Imaginary part of the output projection matrix.
D: Skip-connection weights.
gamma_log: Log of the normalization factor.
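
These attributes store the eigenvalues indirectly. This page does not confirm which parameterization discretax uses. If it follows the exponential parameterization of the original LRU paper, the diagonal entries are λ = exp(-exp(nu_log) + i·exp(theta_log)) and the normalization factor is γ = sqrt(1 - |λ|²). The sketch below shows that mapping; the helper name `lru_diag_and_gamma` is hypothetical.

```python
import jax.numpy as jnp


def lru_diag_and_gamma(nu_log, theta_log):
    """Map nu_log/theta_log to diagonal eigenvalues and a log normalization factor.

    Assumes the exponential parameterization of the original LRU paper:
      |lambda|      = exp(-exp(nu_log))   -> strictly inside the unit circle
      phase(lambda) = exp(theta_log)      -> always a positive angle
      gamma         = sqrt(1 - |lambda|^2)
    """
    nu = jnp.exp(nu_log)
    theta = jnp.exp(theta_log)
    lam = jnp.exp(-nu + 1j * theta)                         # complex diagonal of the state matrix
    gamma_log = jnp.log(jnp.sqrt(1.0 - jnp.abs(lam) ** 2))  # log of the normalization factor
    return lam, gamma_log
```

Writing the magnitude as exp(-exp(nu_log)) keeps every |λ| strictly inside the unit circle for any finite nu_log, so training cannot make the recurrence unstable. In the paper, γ scales the rows of B = B_re + i·B_im. As |λ| approaches 1, γ approaches 0, which damps the input to slowly decaying states and keeps their variance bounded.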

__init__(in_features: int, key: PRNGKeyArray, *args, state_dim: int = 64, r_min: float = 0.0, r_max: float = 1.0, max_phase: float = 6.28, **kwargs)

Initialize the LRU sequence mixer layer.

Parameters:

in_features (int, required): Dimension of the input features.
key (PRNGKeyArray, required): JAX random key for initialization.
state_dim (int, default 64): Dimension of the state space.
r_min (float, default 0.0): Minimum radius (magnitude) of the complex-valued eigenvalues at initialization.
r_max (float, default 1.0): Maximum radius (magnitude) of the complex-valued eigenvalues at initialization.
max_phase (float, default 6.28): Maximum phase angle, in radians, of the complex-valued eigenvalues at initialization (6.28 is approximately 2π).
*args (optional): Additional positional arguments; ignored.
**kwargs (optional): Additional keyword arguments; ignored.
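
The arguments r_min, r_max, and max_phase are used by __init__ to set the initial eigenvalues. In the original LRU paper, eigenvalues are sampled uniformly over a ring of the complex plane between radii r_min and r_max, with phases drawn uniformly from [0, max_phase]. The sketch below shows that sampling. Whether discretax uses the same scheme is an assumption, and `init_lru_eigenvalues` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def init_lru_eigenvalues(key, state_dim=64, r_min=0.0, r_max=1.0, max_phase=6.28):
    """Sample nu_log and theta_log so the eigenvalues fill a ring of the complex plane.

    Assumes the ring initialization of the original LRU paper: |lambda|^2 is drawn
    uniformly from [r_min^2, r_max^2] (uniform over the ring's area) and the phase
    uniformly from [0, max_phase].
    """
    k_mag, k_phase = jax.random.split(key)
    u1 = jax.random.uniform(k_mag, (state_dim,))
    u2 = jax.random.uniform(k_phase, (state_dim,))
    # Invert |lambda| = exp(-exp(nu_log)) for the sampled squared magnitude.
    nu_log = jnp.log(-0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2))
    # Invert phase(lambda) = exp(theta_log) for the sampled phase.
    theta_log = jnp.log(max_phase * u2)
    return nu_log, theta_log
```

Raising r_min toward r_max, for example r_min = 0.9 with r_max = 0.999, places every eigenvalue near the unit circle. Each state then decays slowly, which biases the layer toward long memory.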

__call__(x: Array, key: PRNGKeyArray) -> Array

Forward pass of the LRU sequence mixer layer.

Parameters:

x (Array, required): Input sequence of features.
key (PRNGKeyArray, required): JAX random key; unused and accepted only for interface compatibility.

Returns:

Array: The output sequence of the LRU sequence mixer.
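
The sketch below computes a full forward pass from the attributes listed above. It uses jax.lax.associative_scan, the parallel-scan formulation that diagonal linear recurrences admit. Several things here are illustrative assumptions, not documented behavior:

- the function name `lru_forward` and its argument order;
- the shapes: x is (L, in_features), B_* are (state_dim, in_features), and C_* are (in_features, state_dim);
- the LRU-paper parameterization.

This page does not state how LRUSequenceMixer.__call__ is implemented or what layout it expects for x.

```python
import jax
import jax.numpy as jnp


def lru_forward(nu_log, theta_log, gamma_log, B_re, B_im, C_re, C_im, D, x):
    """Hypothetical LRU forward pass using a parallel (associative) scan.

    Assumed shapes: x (L, H), B_re/B_im (N, H), C_re/C_im (H, N), D (H,),
    nu_log/theta_log/gamma_log (N,). Returns (L, H).
    """
    lam = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))  # (N,) diagonal eigenvalues
    B = (B_re + 1j * B_im) * jnp.exp(gamma_log)[:, None]       # normalized input matrix
    C = C_re + 1j * C_im

    Bu = x @ B.T                              # (L, N): B x_t for every time step
    lams = jnp.broadcast_to(lam, Bu.shape)    # the same decay applies at every step

    # h_t = lam * h_{t-1} + B x_t composes associatively under
    # (a_i, b_i) o (a_j, b_j) = (a_j * a_i, a_j * b_i + b_j), with i earlier than j.
    def combine(earlier, later):
        a_i, b_i = earlier
        a_j, b_j = later
        return a_j * a_i, a_j * b_i + b_j

    _, h = jax.lax.associative_scan(combine, (lams, Bu))  # (L, N) hidden states
    return (h @ C.T).real + x * D             # (L, H) real outputs plus skip connection
```

A sequential jax.lax.scan produces the same hidden states. The associative scan reduces the number of sequential steps from O(L) to O(log L), which is why diagonal recurrences such as the LRU map well onto accelerators.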