LinOSS Sequence Mixer
discretax.sequence_mixers.linoss.LinOSSSequenceMixer
LinOSS sequence mixer layer.
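This layer implements the LinOSS (linear oscillatory state-space) sequence mixer, which processes the input sequence with a bank of forced harmonic oscillators.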
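For orientation, the LinOSS model of Rusch & Rus is built on forced harmonic oscillators. Assuming this layer follows that formulation, with the damped variant adding a velocity-dependent term (take G = 0 when damping is disabled), the continuous-time dynamics can be written as below. Here y is the oscillator position, y' its velocity, u the input and o the output; A and G are diagonal. This is a sketch of the standard formulation, not a statement of discretax's exact implementation.

$$
y''(t) = -A\,y(t) - G\,y'(t) + B\,u(t), \qquad o(t) = C\,y(t) + D\,u(t)
$$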
Attributes:
| Name | Type | Description |
|---|---|---|
| A_diag | | Diagonal state matrix. |
| G_diag | | Diagonal damping matrix. |
| B | | Input matrix. |
| C | | Output matrix. |
| D | | Feedthrough matrix mapping the input directly to the output. |
| steps | | Learnable step sizes for the sequence mixer (parameterized via sigmoid). |
| discretization | Literal['IM', 'IMEX'] | Discretization method to use. |
| damping | bool | Whether to use damping. |
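To illustrate how these attributes could fit together, here is a minimal JAX sketch of a single recurrent update. It assumes that `steps` is mapped to step sizes in (0, 1) by a sigmoid (as the table states), that `B` has shape (state_dim, in_features), and that the damping term is treated implicitly. The function name `damped_imex_step` and this particular damping scheme are hypothetical and not taken from discretax.

```python
import jax
import jax.numpy as jnp


def damped_imex_step(A_diag, G_diag, B, steps, z, y, u):
    """Hypothetical single IMEX update of the oscillator state (not discretax's code).

    Shapes: A_diag, G_diag, steps, z, y -> (state_dim,); B -> (state_dim, in_features);
    u -> (in_features,). z is the velocity y', y is the position.
    """
    dt = jax.nn.sigmoid(steps)  # learnable step sizes squashed into (0, 1)
    forcing = B @ u             # input projected into the state space
    # Restoring force -A y is taken explicitly (old y); the damping term -G z
    # is taken implicitly (new z), which is an assumed choice for this sketch.
    z_new = (z + dt * (-A_diag * y + forcing)) / (1.0 + dt * G_diag)
    y_new = y + dt * z_new      # position update uses the freshly updated velocity
    return z_new, y_new


# Usage: one step from a zero state; G_diag = 0 mimics damping=False.
state_dim, in_features = 4, 3
k_a, k_b, k_u = jax.random.split(jax.random.PRNGKey(0), 3)
A_diag = jax.random.uniform(k_a, (state_dim,))
G_diag = jnp.zeros(state_dim)
B = jax.random.normal(k_b, (state_dim, in_features))
steps = jnp.zeros(state_dim)
z, y = jnp.zeros(state_dim), jnp.zeros(state_dim)
u = jax.random.normal(k_u, (in_features,))
z, y = damped_imex_step(A_diag, G_diag, B, steps, z, y, u)
```

With `G_diag = 0` the update reduces to the symplectic (IMEX) Euler scheme of undamped LinOSS.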
__init__(in_features: int, key: PRNGKeyArray, *args, state_dim: int = 64, discretization: Literal['IM', 'IMEX'] = 'IMEX', damping: bool = True, r_min: float = 0.9, theta_max: float = 3.141592653589793, **kwargs)
Initialize the LinOSS sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Dimension of the input features. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| state_dim | int | Dimension of the state space. | 64 |
| discretization | Literal['IM', 'IMEX'] | Discretization method to use: 'IM' (implicit) or 'IMEX' (implicit-explicit). | 'IMEX' |
| damping | bool | Whether to use damping. | True |
| r_min | float | Minimum value for the radius used at initialization. | 0.9 |
| theta_max | float | Maximum value for the theta (angle) parameter used at initialization. | 3.141592653589793 (π) |
| *args | | Additional positional arguments (ignored). | () |
| **kwargs | | Additional keyword arguments (ignored). | {} |
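The `discretization` argument chooses how the oscillator ODE is turned into a recurrence. The sketch below, assuming the standard undamped LinOSS update rules, builds the 2×2 transition matrix acting on (velocity, position) for a single oscillator with stiffness `a` (one entry of `A_diag`, assumed non-negative) and step size `dt`. `transition_matrix` is a hypothetical helper, not part of discretax.

```python
import jax.numpy as jnp


def transition_matrix(a, dt, discretization="IMEX"):
    """Hypothetical 2x2 map from (z_{n-1}, y_{n-1}) to (z_n, y_n), input term omitted."""
    if discretization == "IMEX":
        # Symplectic Euler: z_n = z - dt*a*y, then y_n = y + dt*z_n
        return jnp.array([[1.0, -dt * a],
                          [dt, 1.0 - dt**2 * a]])
    if discretization == "IM":
        # Fully implicit: (1 + dt^2 a) z_n = z - dt*a*y, then y_n = y + dt*z_n
        s = 1.0 / (1.0 + dt**2 * a)
        return jnp.array([[s, -dt * a * s],
                          [dt * s, s]])
    raise ValueError(f"unknown discretization: {discretization!r}")


# Usage: compare how much phase-space volume each scheme keeps per step.
for method in ("IMEX", "IM"):
    print(method, jnp.linalg.det(transition_matrix(1.0, 0.5, method)))
```

For a > 0, the IMEX map has determinant exactly 1, so it preserves phase-space volume, while the IM map has determinant 1/(1 + Δt² a) < 1, meaning the fully implicit scheme introduces numerical dissipation.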
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the LinOSS sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input sequence of features. | required |
| key | PRNGKeyArray | JAX random key. | required |
Returns:
| Type | Description |
|---|---|
| Array | The output of the LinOSS sequence mixer. |
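As a rough picture of what a forward pass involves, the following hypothetical sketch runs an undamped IMEX recurrence over a whole sequence with `jax.lax.scan` and reads out C y + D u at every step. It assumes `x` has shape (seq_len, in_features), real-valued `B` and `C`, a zero initial state, and an element-wise feedthrough `D` of shape (in_features,); none of these details are confirmed by this page, and `linoss_imex_forward` is not a discretax function. A sequential scan is used for readability; because the recurrence is linear, it could also be evaluated in parallel with `jax.lax.associative_scan`.

```python
import jax
import jax.numpy as jnp


def linoss_imex_forward(A_diag, B, C, D, steps, x):
    """Hypothetical undamped LinOSS-IMEX pass: x (seq_len, in_features) -> same shape.

    A_diag, steps: (state_dim,); B: (state_dim, in_features);
    C: (in_features, state_dim); D: (in_features,).
    """
    dt = jax.nn.sigmoid(steps)  # step sizes in (0, 1)
    state_dim = A_diag.shape[0]

    def step(carry, u):
        z, y = carry
        z = z + dt * (-A_diag * y + B @ u)  # velocity update, explicit in y
        y = y + dt * z                      # position update with the new velocity
        out = C @ y + D * u                 # readout plus direct feedthrough
        return (z, y), out

    init = (jnp.zeros(state_dim), jnp.zeros(state_dim))
    _, outputs = jax.lax.scan(step, init, x)
    return outputs


# Usage with random parameters (shapes are assumptions, see above).
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(1), 5)
state_dim, in_features, seq_len = 4, 3, 6
A_diag = jax.random.uniform(k1, (state_dim,))
B = jax.random.normal(k2, (state_dim, in_features))
C = jax.random.normal(k3, (in_features, state_dim))
D = jax.random.normal(k4, (in_features,))
steps = jnp.zeros(state_dim)
x = jax.random.normal(k5, (seq_len, in_features))
out = linoss_imex_forward(A_diag, B, C, D, steps, x)
```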