S5 Model
discretax.models.s5.S5
S5 model.
This model stacks blocks built from S5 sequence mixers and GLU channel mixers. To build a full network, compose it with an encoder and a head using `eqx.nn.Sequential`.
Attributes:
| Name | Type | Description |
|---|---|---|
| `blocks` | | List of standard blocks with S5 sequence mixers. |
Example
```python
import equinox as eqx
import jax.random as jr

from discretax.encoder import LinearEncoder
from discretax.heads import ClassificationHead
from discretax.models import S5

key = jr.PRNGKey(0)
keys = jr.split(key, 3)

encoder = LinearEncoder(in_features=784, out_features=64, key=keys[0])
model = S5(hidden_dim=64, num_blocks=4, key=keys[1])
head = ClassificationHead(in_features=64, out_features=10, key=keys[2])

# Compose with Sequential
full_model = eqx.nn.Sequential([encoder, model, head])
```
Reference
S5: https://openreview.net/pdf?id=Ai8Hw3AXqks
`__init__(key: PRNGKeyArray, *args, hidden_dim: int, num_blocks: int = 4, state_dim: int = 64, ssm_blocks: int = 1, C_init: typing.Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal'] = 'lecun_normal', conj_sym: bool = True, clip_eigs: bool = True, discretization: typing.Literal['zoh', 'bilinear'] = 'zoh', dt_min: float = 0.001, dt_max: float = 1.0, step_rescale: float = 1.0, drop_rate: float = 0.1, prenorm: bool = True, use_bias: bool = True, **kwargs)`
Initialize the S5 model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `hidden_dim` | `int` | Hidden dimension for the model. | required |
| `num_blocks` | `int` | Number of S5 blocks to stack. | `4` |
| `state_dim` | `int` | State space dimension for S5 sequence mixers. | `64` |
| `ssm_blocks` | `int` | Number of SSM blocks (for block-diagonal structure). | `1` |
| `C_init` | `Literal['trunc_standard_normal', 'lecun_normal', 'complex_normal']` | Initialization method for output matrix C. | `'lecun_normal'` |
| `conj_sym` | `bool` | Whether to enforce conjugate symmetry. | `True` |
| `clip_eigs` | `bool` | Whether to clip eigenvalues to ensure stability. | `True` |
| `discretization` | `Literal['zoh', 'bilinear']` | Discretization method to use. | `'zoh'` |
| `dt_min` | `float` | Minimum discretization step size. | `0.001` |
| `dt_max` | `float` | Maximum discretization step size. | `1.0` |
| `step_rescale` | `float` | Rescaling factor for the discretization step. | `1.0` |
| `drop_rate` | `float` | Dropout rate for blocks. | `0.1` |
| `prenorm` | `bool` | Whether to apply prenorm in blocks. | `True` |
| `use_bias` | `bool` | Whether to use bias in GLU channel mixers. | `True` |
| `*args` | | Additional positional arguments (ignored). | required |
| `**kwargs` | | Additional keyword arguments (ignored). | required |
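To make the `discretization`, `dt_min`, `dt_max`, `step_rescale`, and `clip_eigs` options more concrete, the following is a minimal sketch of what they typically mean for a single diagonal state-space mode, following the formulas in the S5 paper. It is not discretax's implementation: the function names are hypothetical, the clamp threshold of `-1e-4` and the log-uniform sampling of the step size are assumptions, and plain Python scalars stand in for JAX arrays so the snippet runs with the standard library alone.

```python
import cmath
import math
import random


def clip_eigenvalue(lam: complex, max_real: float = -1e-4) -> complex:
    """Clamp the real part so the continuous-time mode decays.

    Assumed behaviour for `clip_eigs`; the threshold is illustrative.
    """
    return complex(min(lam.real, max_real), lam.imag)


def sample_step(rng: random.Random, dt_min: float, dt_max: float,
                step_rescale: float = 1.0) -> float:
    """Draw a step size log-uniformly from [dt_min, dt_max], then rescale it."""
    log_dt = rng.uniform(math.log(dt_min), math.log(dt_max))
    return step_rescale * math.exp(log_dt)


def discretize_zoh(lam: complex, b: complex, dt: float) -> tuple[complex, complex]:
    """Zero-order hold: lam_bar = exp(lam * dt), b_bar = (lam_bar - 1) / lam * b."""
    lam_bar = cmath.exp(lam * dt)
    b_bar = (lam_bar - 1.0) / lam * b
    return lam_bar, b_bar


def discretize_bilinear(lam: complex, b: complex, dt: float) -> tuple[complex, complex]:
    """Bilinear (Tustin) transform of the same continuous-time mode."""
    bl = 1.0 / (1.0 - 0.5 * dt * lam)
    lam_bar = bl * (1.0 + 0.5 * dt * lam)
    b_bar = bl * dt * b
    return lam_bar, b_bar


rng = random.Random(0)
lam = clip_eigenvalue(complex(0.3, 2.0))       # unstable real part gets clamped
dt = sample_step(rng, dt_min=0.001, dt_max=1.0)
zoh = discretize_zoh(lam, 1.0, dt)
bil = discretize_bilinear(lam, 1.0, dt)
```

Both rules map an eigenvalue with negative real part to a discrete eigenvalue inside the unit circle, which is why clamping the real part is enough to keep the recurrence stable. For small step sizes the two rules give nearly identical results, and they diverge as `dt` grows.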
`__call__(x: Array, state: State, key: PRNGKeyArray) -> tuple[Array, equinox.nn._stateful.State]`
Forward pass through the S5 blocks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for operations. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and updated state. |