LRU Model

discretax.models.lru.LRU

Linear Recurrent Unit (LRU) model.

This model stacks blocks that each pair an LRU sequence mixer with a GLU channel mixer. Compose it with an encoder and a head using eqx.nn.Sequential.

Attributes:

| Name | Type | Description |
|---|---|---|
| `blocks` | | List of standard blocks with LRU sequence mixers. |
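
As a rough illustration of the "GLU channel mixer" mentioned above, the following is a minimal sketch of a textbook gated linear unit in plain Python. It assumes the common form (value projection multiplied elementwise by a sigmoid gate); the helper names `glu`, `w_value`, `b_value`, `w_gate`, and `b_gate` are hypothetical, and the discretax implementation may differ in detail.

```python
import math


def sigmoid(z):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def glu(x, w_value, b_value, w_gate, b_gate):
    """Textbook GLU: (W_value x + b_value) * sigmoid(W_gate x + b_gate), elementwise.

    Hypothetical helper for illustration only, not the discretax API.
    """
    def affine(w, b):
        # Matrix-vector product plus bias, written with plain lists.
        return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
                for row, b_i in zip(w, b)]

    value = affine(w_value, b_value)
    gate = affine(w_gate, b_gate)
    return [v * sigmoid(g) for v, g in zip(value, gate)]


identity = [[1.0, 0.0], [0.0, 1.0]]
zeros = [0.0, 0.0]
# With zero gate weights the gate is sigmoid(0) = 0.5, so each output is half the input.
out = glu([2.0, -4.0], identity, zeros, [[0.0, 0.0], [0.0, 0.0]], zeros)
```

In this sketch, disabling the bias (compare the `use_bias` parameter) would correspond to fixing `b_value` and `b_gate` to zero.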

Example
import equinox as eqx
import jax.random as jr
from discretax.encoder import LinearEncoder
from discretax.heads import ClassificationHead
from discretax.models import LRU

key = jr.PRNGKey(0)
keys = jr.split(key, 3)

encoder = LinearEncoder(in_features=784, out_features=64, key=keys[0])
model = LRU(hidden_dim=64, num_blocks=4, key=keys[1])
head = ClassificationHead(in_features=64, out_features=10, key=keys[2])

# Compose with Sequential
full_model = eqx.nn.Sequential([encoder, model, head])

Reference

LRU: Orvieto et al., "Resurrecting Recurrent Neural Networks for Long Sequences" (ICML 2023), https://proceedings.mlr.press/v202/orvieto23a/orvieto23a.pdf

__init__(key: PRNGKeyArray, *args, hidden_dim: int, num_blocks: int = 4, state_dim: int = 64, r_min: float = 0.0, r_max: float = 1.0, max_phase: float = 6.283185307179586, drop_rate: float = 0.1, prenorm: bool = True, use_bias: bool = True, **kwargs)

Initialize the LRU model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `hidden_dim` | `int` | Hidden dimension of the model. | required |
| `num_blocks` | `int` | Number of LRU blocks to stack. | `4` |
| `state_dim` | `int` | State-space dimension of the LRU sequence mixers. | `64` |
| `r_min` | `float` | Minimum radius (modulus) of the complex-valued eigenvalues. | `0.0` |
| `r_max` | `float` | Maximum radius (modulus) of the complex-valued eigenvalues. | `1.0` |
| `max_phase` | `float` | Maximum phase angle, in radians, of the complex-valued eigenvalues. | `6.283185307179586` (2π) |
| `drop_rate` | `float` | Dropout rate for the blocks. | `0.1` |
| `prenorm` | `bool` | Whether to apply prenorm in the blocks. | `True` |
| `use_bias` | `bool` | Whether to use a bias in the GLU channel mixers. | `True` |
| `*args` | | Additional positional arguments (ignored). | `()` |
| `**kwargs` | | Additional keyword arguments (ignored). | `{}` |
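
To make the roles of `r_min`, `r_max`, and `max_phase` concrete, here is a minimal sketch of the eigenvalue initialization described in the referenced LRU paper: magnitudes are drawn so that eigenvalues are uniform on the ring between `r_min` and `r_max`, and phases are uniform in `[0, max_phase)`. It uses plain Python rather than JAX to stay dependency-free. The function name `sample_lru_eigenvalues` and the `seed` argument are hypothetical, and whether discretax follows this exact scheme is an assumption based on the parameter names. The paper also stores these values in a log parametrization for stability, which this sketch skips.

```python
import cmath
import math
import random


def sample_lru_eigenvalues(state_dim, r_min=0.0, r_max=1.0,
                           max_phase=2 * math.pi, seed=0):
    """Sample complex eigenvalues with r_min <= |lambda| < r_max.

    Hypothetical helper following the ring initialization in the LRU paper;
    not the discretax implementation.
    """
    rng = random.Random(seed)
    eigenvalues = []
    for _ in range(state_dim):
        u1, u2 = rng.random(), rng.random()
        # Uniform by area on the ring: the squared radius is uniform
        # in [r_min^2, r_max^2).
        radius = math.sqrt(u1 * (r_max**2 - r_min**2) + r_min**2)
        # Phase is uniform in [0, max_phase).
        phase = max_phase * u2
        eigenvalues.append(cmath.rect(radius, phase))
    return eigenvalues


lambdas = sample_lru_eigenvalues(state_dim=64, r_min=0.9, r_max=0.999)
moduli = [abs(lam) for lam in lambdas]
```

Choosing `r_min` and `r_max` close to 1, as in this usage example, keeps every eigenvalue near the unit circle, so the state decays slowly and can carry information over long sequences. The defaults (`0.0` and `1.0`) cover the whole unit disk.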
__call__(x: Array, state: State, key: PRNGKeyArray) -> tuple[Array, State]

Forward pass through the LRU blocks.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for operations. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and updated state. |
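
For intuition about what an LRU sequence mixer computes inside each block during the forward pass, the sketch below runs the diagonal linear recurrence from the referenced paper, `h_t = lambda * h_{t-1} + b * u_t`, with a real-valued readout `y_t = Re(sum_i c_i * h_{t,i})`. It is a simplified plain-Python illustration on a scalar input sequence. The name `lru_scan` and its arguments are hypothetical, and it omits the paper's input normalization and skip connection as well as everything discretax adds around the mixer (normalization, dropout, GLU).

```python
def lru_scan(inputs, eigenvalues, b, c):
    """Run a diagonal complex linear recurrence over a scalar sequence.

    inputs: list of floats, one per time step
    eigenvalues, b, c: lists of complex numbers, one per state dimension
    Hypothetical helper for illustration; not the discretax API.
    """
    hidden = [0j] * len(eigenvalues)
    outputs = []
    for u in inputs:
        # Each state dimension evolves independently (diagonal recurrence).
        hidden = [lam * h + b_i * u
                  for lam, h, b_i in zip(eigenvalues, hidden, b)]
        # Project the complex state back to a real output.
        outputs.append(sum(c_i * h for c_i, h in zip(c, hidden)).real)
    return outputs


# One state dimension with eigenvalue 0.5: an impulse decays geometrically.
ys = lru_scan([1.0, 0.0, 0.0], eigenvalues=[0.5 + 0j], b=[1 + 0j], c=[1 + 0j])
# ys == [1.0, 0.5, 0.25]
```

Because the recurrence is linear and diagonal, practical implementations can evaluate it with a parallel scan instead of the sequential loop shown here.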