LRU Model
discretax.models.lru.LRU
LRU (Linear Recurrent Unit) model.
This model stacks blocks that pair an LRU sequence mixer with a GLU channel mixer. It operates on features of size hidden_dim, so compose it with an encoder and a head using eqx.nn.Sequential.
Attributes:
| Name | Type | Description |
|---|---|---|
| `blocks` | | List of standard blocks with LRU sequence mixers. |
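As a rough illustration of what one such block computes, here is a minimal pure-Python sketch of a GLU channel mixer inside a residual connection with optional pre-normalization. The names `layer_norm`, `glu`, and `residual` are hypothetical, and the exact layout inside discretax (choice of normalization, placement of dropout) is an assumption rather than something this page states.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a feature vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def glu(x, w_value, w_gate, b_value, b_gate):
    """Gated linear unit: (W_v x + b_v) * sigmoid(W_g x + b_g)."""

    def affine(w, b):
        return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]

    value = affine(w_value, b_value)
    gate = affine(w_gate, b_gate)
    return [v / (1.0 + math.exp(-g)) for v, g in zip(value, gate)]


def residual(x, fn, prenorm=True):
    """Residual wrapper (assumed layout): normalize before fn, or after the sum."""
    if prenorm:
        return [xi + yi for xi, yi in zip(x, fn(layer_norm(x)))]
    return layer_norm([xi + yi for xi, yi in zip(x, fn(x))])


# Toy usage: identity weights and zero biases (hypothetical values).
identity = [[1.0, 0.0], [0.0, 1.0]]
zeros = [0.0, 0.0]
x = [0.0, 2.0]
mixed = residual(x, lambda h: glu(h, identity, identity, zeros, zeros))
```

With `prenorm=True` the input passes through the skip connection unnormalized, which is the usual motivation for pre-normalization in deep stacks.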
Example
```python
import equinox as eqx
import jax.random as jr

from discretax.encoder import LinearEncoder
from discretax.heads import ClassificationHead
from discretax.models import LRU

key = jr.PRNGKey(0)
keys = jr.split(key, 3)

encoder = LinearEncoder(in_features=784, out_features=64, key=keys[0])
model = LRU(hidden_dim=64, num_blocks=4, key=keys[1])
head = ClassificationHead(in_features=64, out_features=10, key=keys[2])

# Compose with Sequential
full_model = eqx.nn.Sequential([encoder, model, head])
```
Reference
LRU: https://proceedings.mlr.press/v202/orvieto23a/orvieto23a.pdf
__init__(key: PRNGKeyArray, *args, hidden_dim: int, num_blocks: int = 4, state_dim: int = 64, r_min: float = 0.0, r_max: float = 1.0, max_phase: float = 6.283185307179586, drop_rate: float = 0.1, prenorm: bool = True, use_bias: bool = True, **kwargs)
Initialize the LRU model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `hidden_dim` | `int` | Hidden dimension for the model. | required |
| `num_blocks` | `int` | Number of LRU blocks to stack. | `4` |
| `state_dim` | `int` | State space dimension for LRU sequence mixers. | `64` |
| `r_min` | `float` | Minimum radius for complex-valued eigenvalues. | `0.0` |
| `r_max` | `float` | Maximum radius for complex-valued eigenvalues. | `1.0` |
| `max_phase` | `float` | Maximum phase angle for complex-valued eigenvalues. | `6.283185307179586` |
| `drop_rate` | `float` | Dropout rate for blocks. | `0.1` |
| `prenorm` | `bool` | Whether to apply prenorm in blocks. | `True` |
| `use_bias` | `bool` | Whether to use bias in GLU channel mixers. | `True` |
| `*args` | | Additional positional arguments (ignored). | required |
| `**kwargs` | | Additional keyword arguments (ignored). | required |
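To make the roles of r_min, r_max, and max_phase concrete, the following sketch samples eigenvalues the way the referenced LRU paper describes: uniformly on the ring between radii r_min and r_max, with phases in [0, max_phase). The function name `sample_eigenvalues` is hypothetical, and it is an assumption that discretax follows the paper's scheme. The paper also stores the result in a log-parameterized form, which this sketch omits.

```python
import cmath
import math
import random


def sample_eigenvalues(state_dim, r_min=0.0, r_max=1.0, max_phase=2 * math.pi, seed=0):
    """Sample complex eigenvalues with r_min <= |lambda| <= r_max."""
    rng = random.Random(seed)
    eigenvalues = []
    for _ in range(state_dim):
        u1, u2 = rng.random(), rng.random()
        # Uniform in area: the squared radius is uniform on [r_min^2, r_max^2].
        radius = math.sqrt(u1 * (r_max**2 - r_min**2) + r_min**2)
        # Phase is uniform on [0, max_phase).
        phase = max_phase * u2
        eigenvalues.append(cmath.rect(radius, phase))
    return eigenvalues


# Example: keep eigenvalues away from the origin for longer memory.
lams = sample_eigenvalues(state_dim=64, r_min=0.5, r_max=0.9)
```

Radii closer to 1 make the recurrence decay more slowly, so raising r_min biases the model toward longer-range memory, while a smaller max_phase limits how fast the hidden state oscillates.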
__call__(x: Array, state: State, key: PRNGKeyArray) -> tuple[Array, State]
Forward pass through the LRU blocks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for operations. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and updated state. |
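For intuition about what an LRU sequence mixer computes inside the forward pass, here is a minimal pure-Python sketch of the diagonal linear recurrence from the referenced paper, for a single input channel. The function `lru_scan` and its arguments are hypothetical and are not part of the discretax API. The paper's input normalization factor is omitted for brevity, and the recurrent hidden state returned here is not necessarily the same thing as the Equinox `State` object in the signature above.

```python
def lru_scan(inputs, lam, b, c, d=0.0):
    """Run a single-channel diagonal linear recurrence over a sequence.

    x_k = lam * x_{k-1} + b * u_k          (one complex state per eigenvalue)
    y_k = Re(sum_j c_j * x_k[j]) + d * u_k
    """
    state = [0j] * len(lam)
    outputs = []
    for u in inputs:
        # Each state component evolves independently because the
        # recurrence matrix is diagonal.
        state = [l * s + bj * u for l, s, bj in zip(lam, state, b)]
        outputs.append(sum(cj * s for cj, s in zip(c, state)).real + d * u)
    return outputs, state


# An impulse input decays geometrically with the eigenvalue magnitude.
ys, final_state = lru_scan([1.0, 0.0, 0.0], lam=[0.5 + 0j], b=[1.0], c=[1.0])
```

Because the recurrence is linear and diagonal, practical implementations can replace this Python loop with a parallel scan over the sequence.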