DeltaNet Model

discretax.models.deltanet.DeltaNet

DeltaNet model.

This model stacks blocks that pair DeltaNet sequence mixers with GLU channel mixers. Use eqx.nn.Sequential to compose it with an encoder and a head.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| blocks | | List of standard blocks with DeltaNet sequence mixers. |

Example

```python
import equinox as eqx
import jax.random as jr
from discretax.encoder import LinearEncoder
from discretax.heads import ClassificationHead
from discretax.models import DeltaNet

key = jr.PRNGKey(0)
keys = jr.split(key, 3)

encoder = LinearEncoder(in_features=784, out_features=64, key=keys[0])
model = DeltaNet(hidden_dim=64, num_blocks=4, key=keys[1])
head = ClassificationHead(in_features=64, out_features=10, key=keys[2])

# Compose with Sequential
full_model = eqx.nn.Sequential([encoder, model, head])
```

Reference

DeltaNet: https://arxiv.org/abs/2406.06484
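
For orientation, the recurrence behind the delta rule in the referenced paper can be sketched in a few lines of JAX. This is a minimal, sequential, single-head sketch and not discretax's implementation: the library uses a chunked form (see chunk_size), and the function name delta_rule_scan, the argument shapes, and the assumptions that keys are unit-norm and beta lies in [0, 1] are illustrative choices made here.

```python
import jax
import jax.numpy as jnp


def delta_rule_scan(q, k, v, beta):
    """Sequential delta rule for a single head (illustrative helper).

    q, k: (T, d_k); v: (T, d_v); beta: (T,) write strengths.
    Returns the per-step outputs with shape (T, d_v).
    """
    d_k, d_v = k.shape[-1], v.shape[-1]

    def step(S, inputs):
        q_t, k_t, v_t, beta_t = inputs
        v_old = S @ k_t  # value currently stored under key k_t
        # Move the stored value toward v_t by a fraction beta_t.
        S = S + beta_t * jnp.outer(v_t - v_old, k_t)
        return S, S @ q_t  # read the updated memory with the query

    S0 = jnp.zeros((d_v, d_k))  # empty associative memory
    _, out = jax.lax.scan(step, S0, (q, k, v, beta))
    return out
```

With a unit-norm key and beta equal to 1, writing a new value under a key that was used before fully replaces the old value, so querying with that key returns the most recent write.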

__init__(key: PRNGKeyArray, *args, hidden_dim: int, num_blocks: int = 4, n_heads: int = 4, head_dim: int | None = None, chunk_size: int = 64, drop_rate: float = 0.1, prenorm: bool = True, **kwargs)

Initialize the DeltaNet model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| hidden_dim | int | Hidden dimension for the model. | required |
| num_blocks | int | Number of DeltaNet blocks to stack. | 4 |
| n_heads | int | Number of attention heads in each DeltaNet sequence mixer. | 4 |
| head_dim | int \| None | Dimensionality per attention head. Defaults to hidden_dim // n_heads. | None |
| chunk_size | int | Timesteps per chunk for the chunked delta rule. Must evenly divide the sequence length at inference time. | 64 |
| drop_rate | float | Dropout rate for blocks. | 0.1 |
| prenorm | bool | Whether to apply prenorm in blocks. | True |
| *args | | Additional positional arguments (ignored). | optional |
| **kwargs | | Additional keyword arguments (ignored). | optional |
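
The prenorm flag selects where normalization sits relative to the residual connection. The sketch below shows the usual meaning of the two placements, assuming a plain layer norm without learned scale or bias and a generic mixer callable. layer_norm and residual_block are hypothetical helpers written for illustration, and discretax's blocks may differ in detail (for example by adding dropout or using a different norm).

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance (no learned parameters)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def residual_block(x, mixer, prenorm=True):
    """Apply `mixer` inside a residual connection with pre- or post-normalization."""
    if prenorm:
        # Prenorm: normalize only the branch input; the skip path stays untouched.
        return x + mixer(layer_norm(x))
    # Postnorm: normalize after the residual sum.
    return layer_norm(x + mixer(x))
```

With prenorm the skip path carries the input through unchanged, which is a common reason to prefer it when many blocks are stacked.
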
__call__(x: Array, state: State, key: PRNGKeyArray) -> tuple[Array, State]

Forward pass through the DeltaNet blocks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor. | required |
| state | State | Current state for stateful layers. | required |
| key | PRNGKeyArray | JAX random key for operations. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, State] | Tuple containing the output tensor and updated state. |
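
Because chunk_size must evenly divide the sequence length, inputs whose length is not a multiple of it need to be adjusted before the forward pass. One possible approach, sketched below, is to right-pad with zeros and slice the per-step outputs back afterwards. pad_to_chunk_multiple is a hypothetical helper, and both the (seq_len, features) layout and the suitability of zero-padding for a given encoder or head are assumptions.

```python
import jax.numpy as jnp


def pad_to_chunk_multiple(x, chunk_size=64):
    """Right-pad a (seq_len, features) array with zeros to a multiple of chunk_size.

    Returns the padded array and the original length, so per-step outputs
    can be sliced back with y[:orig_len].
    """
    seq_len = x.shape[0]
    pad = -seq_len % chunk_size  # 0 when seq_len is already a multiple
    return jnp.pad(x, ((0, pad), (0, 0))), seq_len


x = jnp.ones((100, 8))
x_padded, orig_len = pad_to_chunk_multiple(x, chunk_size=64)
# x_padded now has 128 timesteps, the next multiple of 64 above 100.
```

Padding on the right leaves earlier timesteps untouched in a causal recurrence, but a head that pools over the whole sequence would also see the padded steps, so outputs should be sliced before pooling in that case.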