DeltaNet Model
discretax.models.deltanet.DeltaNet
DeltaNet model.
This model stacks blocks that pair a DeltaNet sequence mixer with a GLU channel mixer. Compose it with an encoder and a head using `eqx.nn.Sequential`.
Attributes:
| Name | Type | Description |
|---|---|---|
| `blocks` | | List of standard blocks with DeltaNet sequence mixers. |
Example
```python
import equinox as eqx
import jax.random as jr

from discretax.encoder import LinearEncoder
from discretax.heads import ClassificationHead
from discretax.models import DeltaNet

# One independent key per component
key = jr.PRNGKey(0)
keys = jr.split(key, 3)

encoder = LinearEncoder(in_features=784, out_features=64, key=keys[0])
model = DeltaNet(hidden_dim=64, num_blocks=4, key=keys[1])
head = ClassificationHead(in_features=64, out_features=10, key=keys[2])

# Compose with Sequential
full_model = eqx.nn.Sequential([encoder, model, head])
```
Reference
DeltaNet: https://arxiv.org/abs/2406.06484
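For orientation, the delta rule that gives the sequence mixer its name can be sketched as a plain sequential recurrence. This is a minimal NumPy illustration of the update described in the referenced paper, not the discretax implementation, which uses a chunked form. The function name `delta_rule_recurrent`, the argument layout, and the zero initial state are assumptions made for this example.

```python
import numpy as np


def delta_rule_recurrent(q, k, v, beta):
    """Naive sequential delta rule (illustrative only).

    q, k: arrays of shape (T, d_k); v: shape (T, d_v); beta: shape (T,).
    Returns the per-timestep outputs with shape (T, d_v).
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_v, d_k))  # associative memory, assumed to start at zero
    outputs = np.zeros((T, d_v))
    for t in range(T):
        # Value the memory currently associates with key k_t
        prediction = S @ k[t]
        # Move the stored value toward v_t with step size beta_t
        S = S + beta[t] * np.outer(v[t] - prediction, k[t])
        # Read the updated memory with the query
        outputs[t] = S @ q[t]
    return outputs


# Writing twice under the same unit-norm key with beta = 1:
# the second write replaces the value stored by the first.
k = np.array([[1.0, 0.0], [1.0, 0.0]])
v = np.array([[2.0, 3.0], [5.0, 7.0]])
out = delta_rule_recurrent(q=k, k=k, v=v, beta=np.ones(2))
print(out)
```

The loop processes one timestep at a time, which is easy to read but slow for long sequences; the `chunk_size` parameter documented below exists because the library evaluates the rule in chunks instead.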
__init__(key: PRNGKeyArray, *args, hidden_dim: int, num_blocks: int = 4, n_heads: int = 4, head_dim: int | None = None, chunk_size: int = 64, drop_rate: float = 0.1, prenorm: bool = True, **kwargs)
Initialize the DeltaNet model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `hidden_dim` | `int` | Hidden dimension for the model. | required |
| `num_blocks` | `int` | Number of DeltaNet blocks to stack. | `4` |
| `n_heads` | `int` | Number of attention heads in each DeltaNet sequence mixer. | `4` |
| `head_dim` | `int \| None` | Dimensionality per attention head. Defaults to | `None` |
| `chunk_size` | `int` | Timesteps per chunk for the chunked delta rule. Must evenly divide the sequence length at inference time. | `64` |
| `drop_rate` | `float` | Dropout rate for blocks. | `0.1` |
| `prenorm` | `bool` | Whether to apply prenorm in blocks. | `True` |
| `*args` | | Additional positional arguments (ignored). | required |
| `**kwargs` | | Additional keyword arguments (ignored). | required |
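Because `chunk_size` must evenly divide the sequence length, it can help to check a length up front and work out how much padding would be needed. The helper below is a small sketch using only the standard library; `padding_to_chunk_multiple` is a hypothetical name and is not part of discretax.

```python
def padding_to_chunk_multiple(seq_len: int, chunk_size: int = 64) -> int:
    """Return how many timesteps must be appended so that
    seq_len becomes a multiple of chunk_size (0 if it already is)."""
    if seq_len <= 0 or chunk_size <= 0:
        raise ValueError("seq_len and chunk_size must be positive")
    return (-seq_len) % chunk_size


# Report the padding needed for a few sequence lengths
for seq_len in (128, 100, 784):
    pad = padding_to_chunk_multiple(seq_len, chunk_size=64)
    print(f"seq_len={seq_len}: pad {pad} -> {seq_len + pad}")
```

Assuming the input has shape `(seq_len, features)`, the result could be applied with `jnp.pad(x, ((0, pad), (0, 0)))`. Whether trailing zero padding is appropriate depends on the task, so treat this as a starting point rather than a recommendation.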
__call__(x: Array, state: State, key: PRNGKeyArray) -> tuple[Array, equinox.nn._stateful.State]
Forward pass through the DeltaNet blocks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for operations. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and updated state. |