DeltaNet Sequence Mixer

discretax.sequence_mixers.deltanet.DeltaNetSequenceMixer

DeltaNet sequence mixer layer.

Implements multi-head linear attention with delta-rule updates. The input is projected to queries, keys, values, and a scalar gate (beta); the chunked delta-rule recurrence is applied per head; and the result is projected back to the input dimension.
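The pipeline described above can be sketched in plain JAX as follows. This is a minimal illustration, not discretax's implementation: the names `delta_rule_scan`, `init_params`, `deltanet_forward` and the weight keys `W_q`, `W_k`, `W_v`, `W_beta`, `W_out` are hypothetical; the projections are assumed to be bias-free matrices; beta is assumed to be one scalar per head per timestep; and queries and keys are assumed to be L2-normalized. The recurrence is run step by step with `jax.lax.scan` rather than in chunks, so it shows the math of the delta rule, S_t = S_{t-1} + beta_t (v_t - S_{t-1} k_t) k_t^T and o_t = S_t q_t, rather than the chunked evaluation the layer uses.

```python
import jax
import jax.numpy as jnp


def delta_rule_scan(q, k, v, beta):
    """Sequential delta rule for one head.

    q, k: (T, d_k); v: (T, d_v); beta: (T,). Returns (T, d_v).
    """
    def step(S, inputs):
        q_t, k_t, v_t, b_t = inputs
        # Move the value currently stored under k_t towards v_t by a fraction b_t.
        S = S + b_t * jnp.outer(v_t - S @ k_t, k_t)
        return S, S @ q_t  # read out with the updated state

    S0 = jnp.zeros((v.shape[-1], k.shape[-1]), dtype=v.dtype)
    _, out = jax.lax.scan(step, S0, (q, k, v, beta))
    return out


def init_params(key, in_features, n_heads=4, head_dim=None):
    """Hypothetical bias-free weights standing in for the five projections."""
    if head_dim is None:
        head_dim = in_features // n_heads
    d = n_heads * head_dim
    kq, kk, kv, kb, ko = jax.random.split(key, 5)
    scale = in_features ** -0.5
    params = {
        "W_q": jax.random.normal(kq, (in_features, d)) * scale,
        "W_k": jax.random.normal(kk, (in_features, d)) * scale,
        "W_v": jax.random.normal(kv, (in_features, d)) * scale,
        # Assumption: one beta per head per timestep.
        "W_beta": jax.random.normal(kb, (in_features, n_heads)) * scale,
        "W_out": jax.random.normal(ko, (d, in_features)) * d ** -0.5,
    }
    return params, head_dim


def deltanet_forward(params, x, n_heads, head_dim):
    """x: (T, in_features) -> (T, in_features)."""
    T = x.shape[0]

    def split_heads(w):  # (T, in) @ (in, H*D) -> (H, T, D)
        return (x @ w).reshape(T, n_heads, head_dim).transpose(1, 0, 2)

    def l2norm(a):  # assumed; unit-norm keys keep I - beta k k^T non-expansive
        return a / (jnp.linalg.norm(a, axis=-1, keepdims=True) + 1e-6)

    q = l2norm(split_heads(params["W_q"]))
    k = l2norm(split_heads(params["W_k"]))
    v = split_heads(params["W_v"])
    beta = jax.nn.sigmoid(x @ params["W_beta"]).T  # (H, T), each in (0, 1)
    out = jax.vmap(delta_rule_scan)(q, k, v, beta)  # independent recurrence per head
    out = out.transpose(1, 0, 2).reshape(T, n_heads * head_dim)  # merge heads
    return out @ params["W_out"]  # project back to in_features


params, head_dim = init_params(jax.random.PRNGKey(0), in_features=32)
x = jax.random.normal(jax.random.PRNGKey(1), (16, 32))
y = deltanet_forward(params, x, n_heads=4, head_dim=head_dim)
print(y.shape)  # (16, 32)
```

Because each output step only reads the state built from earlier and current inputs, the sketch is causal: changing the last timestep of `x` leaves all earlier outputs unchanged.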

Attributes:

q_proj: Query projection.
k_proj: Key projection.
v_proj: Value projection.
beta_proj: Scalar gate projection (sigmoid-activated at call time).
out_proj: Output projection back to in_features.
n_heads: Number of attention heads.
head_dim: Dimensionality of each head.
chunk_size: Number of timesteps per chunk for the chunked delta rule.
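The chunk_size attribute refers to evaluating the delta rule chunk by chunk instead of one timestep at a time. Below is a minimal single-head sketch of one common way to do this, the WY-representation form; it is an assumption that discretax uses this exact scheme, and `chunked_delta_rule` is a hypothetical name. Within a chunk, the per-step updates are gathered by solving a unit lower-triangular system (I + A) W = diag(beta) K (and likewise for V), where A[i, j] = beta_i <k_i, k_j> for j < i. Only T / chunk_size sequential steps then remain, each consisting of dense matrix products.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def chunked_delta_rule(q, k, v, beta, chunk_size):
    """Chunked delta rule for one head (WY-representation form).

    q, k: (T, d_k); v: (T, d_v); beta: (T,). T must be divisible by chunk_size.
    State S has shape (d_v, d_k); output is (T, d_v).
    """
    T, d_k = k.shape
    d_v = v.shape[-1]
    n = T // chunk_size
    assert n * chunk_size == T, "sequence length must be divisible by chunk_size"
    qc = q.reshape(n, chunk_size, d_k)
    kc = k.reshape(n, chunk_size, d_k)
    vc = v.reshape(n, chunk_size, d_v)
    bc = beta.reshape(n, chunk_size)
    strict_lower = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1)
    causal = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool))
    eye = jnp.eye(chunk_size, dtype=q.dtype)

    def chunk_step(S, inputs):
        Q, K, V, b = inputs
        # A[i, j] = beta_i * <k_i, k_j> for j < i (strictly lower triangular).
        A = jnp.where(strict_lower, b[:, None] * (K @ K.T), 0.0)
        # Solve (I + A) W = diag(beta) K and (I + A) U = diag(beta) V.
        W = solve_triangular(eye + A, b[:, None] * K, lower=True)
        U = solve_triangular(eye + A, b[:, None] * V, lower=True)
        # Corrections relative to what the incoming state S already predicts.
        R = U - W @ S.T  # (C, d_v)
        # Inter-chunk read-out plus causal intra-chunk contribution.
        O = Q @ S.T + jnp.where(causal, Q @ K.T, 0.0) @ R
        S = S + R.T @ K  # state at the end of this chunk
        return S, O

    S0 = jnp.zeros((d_v, d_k), dtype=q.dtype)
    _, out = jax.lax.scan(chunk_step, S0, (qc, kc, vc, bc))
    return out.reshape(T, d_v)
```

On random inputs with unit-norm keys and beta in (0, 1), the output of this sketch should match a token-by-token scan of S_t = S_{t-1} + beta_t (v_t - S_{t-1} k_t) k_t^T, o_t = S_t q_t, up to floating-point error. Larger chunks trade fewer sequential steps for more work per step (a chunk_size x chunk_size triangular solve).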

__init__(in_features: int, key: PRNGKeyArray, *args, n_heads: int = 4, head_dim: int | None = None, chunk_size: int = 64, **kwargs) -> None

Initialize the DeltaNet sequence mixer.

Parameters:

in_features (int, required): Input dimensionality.
key (PRNGKeyArray, required): JAX random key for initialization.
n_heads (int, default 4): Number of attention heads.
head_dim (int | None, default None): Dimensionality per head. If None, defaults to in_features // n_heads.
chunk_size (int, default 64): Timesteps per chunk for the chunked delta rule. Must evenly divide the sequence length at call time.
*args: Additional positional arguments (optional, ignored).
**kwargs: Additional keyword arguments (optional, ignored).
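Because chunk_size must evenly divide the sequence length at call time, sequences of arbitrary length need padding. A minimal sketch, assuming (as the description of the layer suggests) that the mixer is causal, so zero rows appended after the last real timestep cannot change earlier outputs. `call_with_padding` and `toy_causal_mixer` are hypothetical helpers; the toy mixer only mimics the documented (x, key) signature and divisibility requirement.

```python
import jax.numpy as jnp


def call_with_padding(mixer, x, key, chunk_size):
    """Zero-pad x to a multiple of chunk_size, apply mixer, trim back to len(x).

    Only valid for causal mixers: rows appended after the last real timestep
    cannot influence outputs at earlier timesteps.
    """
    T = x.shape[0]
    pad = (-T) % chunk_size  # 0 if T is already a multiple of chunk_size
    x_padded = jnp.pad(x, ((0, pad), (0, 0)))
    return mixer(x_padded, key)[:T]


def toy_causal_mixer(x, key, chunk_size=64):
    # Stand-in with the same (x, key) signature and divisibility requirement.
    assert x.shape[0] % chunk_size == 0
    return jnp.cumsum(x, axis=0)


x = jnp.ones((100, 8))
y = call_with_padding(toy_causal_mixer, x, key=None, chunk_size=64)
print(y.shape)  # (100, 8)
```

Here 100 timesteps are padded to 128 before the call and the 28 padded outputs are discarded.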
__call__(x: Array, key: PRNGKeyArray) -> Array

Forward pass of the DeltaNet sequence mixer.

Parameters:

x (Array, required): Input sequence of shape (timesteps, in_features). The sequence length must be divisible by chunk_size.
key (PRNGKeyArray, required): JAX random key (unused; kept for interface compatibility).

Returns:

Array: Output sequence of shape (timesteps, in_features).
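The call operates on a single unbatched sequence of shape (timesteps, in_features). A minimal sketch of batching it with `jax.vmap`, giving each sequence its own key; `batched_call` and `toy_mixer` are hypothetical, and the toy mixer merely stands in for any callable with the documented (x, key) -> Array signature. Since `jax.vmap` only needs a callable, a layer instance could presumably be passed as `mixer` directly, with its weights shared across the batch.

```python
import jax
import jax.numpy as jnp


def batched_call(mixer, xs, key):
    """Apply a per-sequence mixer to xs of shape (batch, timesteps, in_features)."""
    # One key per sequence; DeltaNet ignores its key, but other mixers may not.
    keys = jax.random.split(key, xs.shape[0])
    return jax.vmap(mixer)(xs, keys)


def toy_mixer(x, key):
    # Stand-in with the documented (x, key) -> same-shape signature.
    return jnp.tanh(x)


xs = jnp.zeros((3, 128, 16))
ys = batched_call(toy_mixer, xs, jax.random.PRNGKey(0))
print(ys.shape)  # (3, 128, 16)
```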