LinOSS Model

discretax.models.linoss.LinOSS

LinOSS model.

The model is a stack of blocks, each pairing a LinOSS sequence mixer with a GLU channel mixer. It contains no encoder or head of its own; use eqx.nn.Sequential to compose it with an encoder and a head.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `blocks` | | List of standard blocks with LinOSS sequence mixers. |
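The GLU channel mixers mentioned above are not documented further on this page. As a rough illustration of the gating idea only, the following plain-Python sketch applies the standard gated linear unit to a vector that is assumed to have already been projected to twice the output width. The function names `sigmoid` and `glu` are hypothetical, and which half acts as the gate inside discretax is an assumption, not something this page confirms.

```python
import math


def sigmoid(v: float) -> float:
    """Logistic function, used here as the gate nonlinearity."""
    return 1.0 / (1.0 + math.exp(-v))


def glu(values: list[float]) -> list[float]:
    """Gated linear unit: first half of the input, gated by sigmoid(second half).

    Assumption: the input is the output of a linear layer that doubled the
    channel width, so the result has half as many channels as the input.
    """
    if len(values) % 2 != 0:
        raise ValueError("GLU input must have an even number of channels")
    half = len(values) // 2
    content, gate = values[:half], values[half:]
    return [c * sigmoid(g) for c, g in zip(content, gate)]


# A strongly positive gate passes the channel almost unchanged,
# a strongly negative gate suppresses it.
print(glu([1.0, -3.0, 100.0, -100.0]))
```

In a real channel mixer the gate values come from learned weights, so the network learns per channel how much of the content to let through.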

Example

import equinox as eqx
import jax.random as jr
from discretax.encoder import LinearEncoder
from discretax.heads import ClassificationHead
from discretax.models import LinOSS

key = jr.PRNGKey(0)
keys = jr.split(key, 3)

encoder = LinearEncoder(in_features=784, out_features=64, key=keys[0])
model = LinOSS(hidden_dim=64, num_blocks=4, key=keys[1])
head = ClassificationHead(in_features=64, out_features=10, key=keys[2])

# Compose with Sequential
full_model = eqx.nn.Sequential([encoder, model, head])

Reference

LinOSS: https://openreview.net/pdf?id=GRMfXcAAFh

__init__(key: PRNGKeyArray, *args, hidden_dim: int, num_blocks: int = 4, state_dim: int = 64, discretization: typing.Literal['IM', 'IMEX'] = 'IMEX', damping: bool = True, r_min: float = 0.9, theta_max: float = 3.141592653589793, drop_rate: float = 0.1, prenorm: bool = True, use_bias: bool = True, **kwargs)

Initialize the LinOSS model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `hidden_dim` | `int` | Hidden dimension of the model. | required |
| `num_blocks` | `int` | Number of LinOSS blocks to stack. | `4` |
| `state_dim` | `int` | State-space dimension of the LinOSS sequence mixers. | `64` |
| `discretization` | `Literal['IM', 'IMEX']` | Discretization method, either `"IM"` or `"IMEX"`. | `'IMEX'` |
| `damping` | `bool` | Whether to use damping in LinOSS. | `True` |
| `r_min` | `float` | Minimum value for the radius in LinOSS. | `0.9` |
| `theta_max` | `float` | Maximum value for the theta parameter in LinOSS. | `3.141592653589793` (π) |
| `drop_rate` | `float` | Dropout rate for the blocks. | `0.1` |
| `prenorm` | `bool` | Whether to apply pre-normalization in the blocks. | `True` |
| `use_bias` | `bool` | Whether to use a bias in the GLU channel mixers. | `True` |
| `*args` | | Additional positional arguments (ignored). | optional |
| `**kwargs` | | Additional keyword arguments (ignored). | optional |
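This page does not say what the two `discretization` choices compute. The sketch below follows the schemes as described in the referenced LinOSS paper, reduced to a single undamped oscillator y'' = -a·y with zero input, and is not taken from the discretax source. All function names are hypothetical, and the `damping`, `r_min`, and `theta_max` options are ignored entirely. It compares the one-step transition matrices on the state (z, y), where z is the velocity: the IMEX step should preserve phase-space volume, while the fully implicit IM step should shrink it and therefore dissipate energy.

```python
def imex_step_matrix(a: float, dt: float) -> list[list[float]]:
    """Implicit-explicit step: z_n = z_{n-1} - dt*a*y_{n-1}, y_n = y_{n-1} + dt*z_n."""
    return [[1.0, -dt * a],
            [dt, 1.0 - dt * dt * a]]


def im_step_matrix(a: float, dt: float) -> list[list[float]]:
    """Fully implicit step: z_n = z_{n-1} - dt*a*y_n, y_n = y_{n-1} + dt*z_n.

    Solving the coupled update for (z_n, y_n) gives the factor s below.
    """
    s = 1.0 / (1.0 + dt * dt * a)
    return [[s, -dt * a * s],
            [dt * s, s]]


def det2(m: list[list[float]]) -> float:
    """Determinant of a 2x2 matrix (phase-space volume scaling per step)."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]


def energy(z: float, y: float, a: float) -> float:
    """Oscillator energy: kinetic plus potential."""
    return 0.5 * (z * z + a * y * y)


def simulate(m: list[list[float]], steps: int, z: float = 0.0, y: float = 1.0):
    """Apply the one-step matrix repeatedly to the state (z, y)."""
    for _ in range(steps):
        z, y = m[0][0] * z + m[0][1] * y, m[1][0] * z + m[1][1] * y
    return z, y


a, dt = 1.0, 0.1
for name, m in (("IMEX", imex_step_matrix(a, dt)), ("IM", im_step_matrix(a, dt))):
    z, y = simulate(m, 1000)
    print(name, "det:", det2(m), "energy after 1000 steps:", energy(z, y, a))
```

Under these assumptions the IMEX energy stays close to its initial value over long sequences, whereas the IM energy decays toward zero, which is one way to think about the trade-off between the two settings.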

__call__(x: Array, state: State, key: PRNGKeyArray) -> tuple[Array, equinox.nn._stateful.State]

Forward pass through the LinOSS blocks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for operations. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Array, State]` | Tuple containing the output tensor and the updated state. |