DeltaNet Sequence Mixer
discretax.sequence_mixers.deltanet.DeltaNetSequenceMixer
DeltaNet sequence mixer layer.
Implements multi-head linear attention with delta rule updates. The input is projected to queries, keys, values, and a scalar gate (beta); the chunked delta rule recurrence is then applied to each head, and the result is projected back to the input dimension.
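The following is a minimal sequential (non-chunked) reference for a single head that illustrates the recurrence itself. It is a sketch under stated assumptions, not the library's implementation. It assumes the common delta rule formulation, with a state S of shape (d_k, d_v), the update S_t = S_{t-1} + β_t k_t (v_t - S_{t-1}^T k_t)^T, the readout o_t = S_t^T q_t, and a zero initial state. It also omits any key normalization or other extras the actual layer may apply. The name `delta_rule_naive` is hypothetical.

```python
import jax
import jax.numpy as jnp


def delta_rule_naive(q, k, v, beta):
    """Sequential delta rule for a single head (hypothetical reference).

    q, k: (T, d_k); v: (T, d_v); beta: (T,), typically in (0, 1).
    Returns outputs of shape (T, d_v).
    """
    d_k, d_v = k.shape[-1], v.shape[-1]

    def step(S, inputs):
        q_t, k_t, v_t, b_t = inputs
        # Value currently stored in the memory under key k_t.
        v_old = S.T @ k_t  # (d_v,)
        # Delta rule: move the stored value toward v_t by a fraction b_t.
        S = S + b_t * jnp.outer(k_t, v_t - v_old)  # (d_k, d_v)
        # Read the updated memory with the query.
        return S, S.T @ q_t

    S0 = jnp.zeros((d_k, d_v), dtype=v.dtype)
    _, out = jax.lax.scan(step, S0, (q, k, v, beta))
    return out


# Write two different values under the same unit key with beta = 1.
k = jnp.array([[1.0, 0.0], [1.0, 0.0]])
v = jnp.array([[2.0, 3.0], [5.0, 7.0]])
out = delta_rule_naive(k, k, v, jnp.array([1.0, 1.0]))
print(out[1])  # [5. 7.]: the second write replaced the first
```

The usage lines show what distinguishes the delta rule from plain additive linear attention. When two values are written under the same key, the second read returns the second value rather than the sum of both (which would be `[7. 10.]`).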
Attributes:
| Name | Description |
|---|---|
| q_proj | Query projection. |
| k_proj | Key projection. |
| v_proj | Value projection. |
| beta_proj | Scalar gate projection (sigmoid-activated at call time). |
| out_proj | Output projection back to `in_features`. |
| n_heads | Number of attention heads. |
| head_dim | Dimensionality of each head. |
| chunk_size | Number of timesteps per chunk for the chunked delta rule. |
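The attributes above suggest a standard multi-head layout. The sketch below shows one plausible way the shapes flow through the projections. It assumes bias-free projections, `n_heads * head_dim` output features for `q_proj`, `k_proj` and `v_proj`, and one gate per head and timestep from `beta_proj`. These shapes and the helper names `split_heads` and `merge_heads` are assumptions and are not taken from the discretax source.

```python
import jax
import jax.numpy as jnp


def split_heads(x, w_q, w_k, w_v, w_beta, n_heads, head_dim):
    """Project x of shape (T, in_features) to per-head q, k, v and beta.

    Hypothetical bias-free weights:
      w_q, w_k, w_v: (in_features, n_heads * head_dim)
      w_beta:        (in_features, n_heads), one gate per head (assumption)
    """
    T = x.shape[0]

    def to_heads(w):
        # (T, n_heads * head_dim) -> (n_heads, T, head_dim)
        return (x @ w).reshape(T, n_heads, head_dim).transpose(1, 0, 2)

    # Sigmoid squashes each gate into (0, 1); result has shape (n_heads, T).
    beta = jax.nn.sigmoid(x @ w_beta).T
    return to_heads(w_q), to_heads(w_k), to_heads(w_v), beta


def merge_heads(o):
    """(n_heads, T, head_dim) -> (T, n_heads * head_dim), the input to out_proj."""
    n_heads, T, head_dim = o.shape
    return o.transpose(1, 0, 2).reshape(T, n_heads * head_dim)


T, in_features, n_heads, head_dim = 8, 16, 4, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (T, in_features))
w_q, w_k, w_v = (jax.random.normal(kk, (in_features, n_heads * head_dim)) for kk in keys[1:4])
w_beta = jax.random.normal(keys[4], (in_features, n_heads))
q, k, v, beta = split_heads(x, w_q, w_k, w_v, w_beta, n_heads, head_dim)
print(q.shape, beta.shape)  # (4, 8, 4) (4, 8)
```

Each head's `(q, k, v, beta)` slice would then go through the delta rule recurrence, for example via `jax.vmap` over the leading head axis. `merge_heads` then produces the `(timesteps, n_heads * head_dim)` array that `out_proj` would map back to `in_features`.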
__init__(in_features: int, key: PRNGKeyArray, *args, n_heads: int = 4, head_dim: int | None = None, chunk_size: int = 64, **kwargs) -> None
Initialize the DeltaNet sequence mixer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| n_heads | int | Number of attention heads. | 4 |
| head_dim | int \| None | Dimensionality per head. Defaults to | None |
| chunk_size | int | Timesteps per chunk for the chunked delta rule. Must evenly divide the sequence length at call time. | 64 |
| *args | | Additional positional arguments (ignored). | () |
| **kwargs | | Additional keyword arguments (ignored). | {} |
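The sketch below shows one standard way to evaluate the delta rule chunk by chunk, which is what `chunk_size` controls. The sequence is reshaped into equal-length chunks, so in this sketch `chunk_size` must divide the sequence length evenly. Within a chunk, the per-step state updates are coupled by a unit lower-triangular linear system that can be solved with dense matrix operations. Only the state between chunks is carried sequentially. The sketch assumes a single-head formulation with a state of shape (d_k, d_v), the update S_t = S_{t-1} + β_t k_t (v_t - S_{t-1}^T k_t)^T, the readout S_t^T q_t, and a zero initial state. `delta_rule_chunked` is a hypothetical name, and discretax may organize the computation differently, for example in how it solves the triangular system or normalizes keys.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def delta_rule_chunked(q, k, v, beta, chunk_size):
    """Chunked delta rule for a single head (hypothetical sketch).

    q, k: (T, d_k); v: (T, d_v); beta: (T,). T must be a multiple of chunk_size.
    Returns outputs of shape (T, d_v).
    """
    T, d_k = k.shape
    d_v = v.shape[-1]
    if T % chunk_size != 0:
        raise ValueError(f"sequence length {T} is not divisible by chunk_size {chunk_size}")
    C = chunk_size
    idx = jnp.arange(C)
    strict_lower = idx[:, None] > idx[None, :]  # j < i
    causal = idx[:, None] >= idx[None, :]  # j <= i

    def to_chunks(a):
        # (T, ...) -> (T // C, C, ...)
        return a.reshape(T // C, C, *a.shape[1:])

    def chunk_step(S, inputs):
        Qc, Kc, Vc, bc = inputs  # (C, d_k), (C, d_k), (C, d_v), (C,)
        # The per-step updates u_i = b_i (v_i - S_{i-1}^T k_i) inside a chunk
        # satisfy the unit lower-triangular system
        #   (I + strict_lower(diag(b) K K^T)) U = diag(b) (V - K S).
        A = jnp.where(strict_lower, bc[:, None] * (Kc @ Kc.T), 0.0)
        rhs = bc[:, None] * (Vc - Kc @ S)
        U = solve_triangular(jnp.eye(C, dtype=A.dtype) + A, rhs, lower=True)
        # Output: read of the incoming state plus the causal intra-chunk term.
        O = Qc @ S + jnp.where(causal, Qc @ Kc.T, 0.0) @ U
        # State carried into the next chunk.
        return S + Kc.T @ U, O

    S0 = jnp.zeros((d_k, d_v), dtype=v.dtype)
    _, out = jax.lax.scan(chunk_step, S0, tuple(to_chunks(a) for a in (q, k, v, beta)))
    return out.reshape(T, d_v)
```

With `chunk_size=1` this reduces to the plain sequential recurrence. Larger chunks need fewer sequential scan steps, but the intra-chunk work grows quadratically with `chunk_size`.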
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the DeltaNet sequence mixer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input sequence of shape `(timesteps, in_features)`. The sequence length must be divisible by `chunk_size`. | required |
| key | PRNGKeyArray | JAX random key (unused; kept for interface compatibility). | required |
Returns:
| Type | Description |
|---|---|
| Array | Output sequence of shape `(timesteps, in_features)`. |
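`__call__` requires the sequence length to be a multiple of `chunk_size`, so inputs of arbitrary length must be padded first. The following is a minimal sketch. It assumes the layer is causal, meaning each output depends only on the current and earlier timesteps, as a delta rule recurrence does. Under that assumption, zero-padding appended at the end does not change the outputs at the original positions. `pad_to_multiple` is a hypothetical helper, not part of discretax.

```python
import jax.numpy as jnp


def pad_to_multiple(x, chunk_size):
    """Right-pad x of shape (T, in_features) with zeros along time.

    Returns the padded array, whose length is the smallest multiple of
    chunk_size that is >= T, together with the original length T.
    """
    T = x.shape[0]
    pad = (-T) % chunk_size  # 0 when T is already a multiple
    return jnp.pad(x, ((0, pad), (0, 0))), T


x = jnp.ones((100, 8))
x_padded, T = pad_to_multiple(x, 64)
print(x_padded.shape)  # (128, 8)
# With a mixer built with chunk_size=64 (assumed causal), trim the padding afterwards:
# y = mixer(x_padded, key)[:T]
```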