LinOSS Sequence Mixer
linax.sequence_mixers.linoss.LinOSSSequenceMixerConfig
Configuration for the LinOSS sequence mixer.
This configuration class defines the hyperparameters for the LinOSS sequence mixer: the state dimension, the discretization method, whether damping is used, and bounds on the radius and theta parameters.
Attributes:
| Name | Description |
|---|---|
| state_dim | Dimensionality of the state space. |
| discretization | Discretization method to use. |
| damping | Whether to use damping. |
| r_min | Minimum value for the radius. |
| theta_max | Maximum value for the theta parameter. |
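For orientation, the LinOSS paper (Rusch and Rus, "Oscillatory State-Space Models") writes the model as a system of forced harmonic oscillators and discretizes it with either an implicit scheme (IM) or an implicit-explicit scheme (IMEX). The equations below restate both for the undamped case, with Δt the step size taken per state coordinate. This page does not list the accepted values of discretization, so the correspondence to IM and IMEX is an assumption.

```latex
% Hidden state y (one oscillator per coordinate), input u, output o; A diagonal, A >= 0.
y''(t) = -A\,y(t) + B\,u(t), \qquad o(t) = C\,y(t) + D\,u(t)

% With z = y' and step size \Delta t:
\begin{aligned}
\text{IM:}   \quad & z_n = z_{n-1} + \Delta t\,\bigl(-A\,y_n + B\,u_n\bigr),     && y_n = y_{n-1} + \Delta t\,z_n \\
\text{IMEX:} \quad & z_n = z_{n-1} + \Delta t\,\bigl(-A\,y_{n-1} + B\,u_n\bigr), && y_n = y_{n-1} + \Delta t\,z_n
\end{aligned}

% Solving for (z_n, y_n) gives a linear recurrence with one 2x2 block per oscillator:
\begin{bmatrix} z_n \\ y_n \end{bmatrix}
  = M \begin{bmatrix} z_{n-1} \\ y_{n-1} \end{bmatrix} + F\,u_n,
\qquad S = \bigl(I + \Delta t^2 A\bigr)^{-1}

M_{\mathrm{IM}} = \begin{bmatrix} S & -\Delta t\,A S \\ \Delta t\,S & S \end{bmatrix},
\quad
F_{\mathrm{IM}} = \begin{bmatrix} \Delta t\,S B \\ \Delta t^2\,S B \end{bmatrix},
\qquad
M_{\mathrm{IMEX}} = \begin{bmatrix} I & -\Delta t\,A \\ \Delta t\,I & I - \Delta t^2 A \end{bmatrix},
\quad
F_{\mathrm{IMEX}} = \begin{bmatrix} \Delta t\,B \\ \Delta t^2\,B \end{bmatrix}
```

Because A and S are diagonal, each oscillator evolves through its own 2 × 2 block, so the state update itself costs time linear in state_dim.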
build(in_features: int, key: PRNGKeyArray) -> LinOSSSequenceMixer
Build a LinOSS sequence mixer from this config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| LinOSSSequenceMixer | The sequence mixer instance. |
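A minimal, hypothetical illustration of the config-then-build pattern. ToyLinOSSConfig and ToyLinOSSMixer are stand-ins invented for this sketch, not linax classes. The parameter shapes, the initialization distributions, the "IMEX" default, and the integer seed used in place of a JAX PRNGKeyArray are all assumptions.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyLinOSSConfig:
    """Hypothetical stand-in mirroring the documented config fields."""

    state_dim: int = 64
    discretization: str = "IMEX"  # assumed value; not confirmed by the docs
    damping: bool = True
    r_min: float = 0.9
    theta_max: float = np.pi

    def build(self, in_features: int, seed: int) -> "ToyLinOSSMixer":
        # Mirrors build(in_features, key): the config holds hyperparameters,
        # build() supplies the input width and the randomness.
        return ToyLinOSSMixer(in_features, self, seed)


class ToyLinOSSMixer:
    """Hypothetical layer holding parameters under the documented names."""

    def __init__(self, in_features: int, cfg: ToyLinOSSConfig, seed: int):
        rng = np.random.default_rng(seed)  # stand-in for a JAX PRNG key
        p, h = cfg.state_dim, in_features
        self.A_diag = rng.uniform(0.0, 1.0, size=p)    # oscillator stiffness (assumed shape)
        self.G_diag = rng.uniform(0.0, 1.0, size=p)    # damping coefficients, meaningful only with damping
        self.B = rng.normal(size=(p, h)) / np.sqrt(h)  # input -> state
        self.C = rng.normal(size=(h, p)) / np.sqrt(p)  # state -> output
        self.D = rng.normal(size=h)                    # per-feature skip term (assumed vector)
        self.steps = rng.uniform(0.1, 1.0, size=p)     # one step size per oscillator
        self.discretization = cfg.discretization
        self.damping = cfg.damping


mixer = ToyLinOSSConfig(state_dim=16).build(in_features=8, seed=0)
print(mixer.B.shape, mixer.C.shape, mixer.D.shape)  # (16, 8) (8, 16) (8,)
```

Passing in_features to build, rather than storing it in the config, presumably lets one config be reused for layers of different widths.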
linax.sequence_mixers.linoss.LinOSSSequenceMixer
LinOSS sequence mixer layer.
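This layer implements the LinOSS (linear oscillatory state-space) sequence mixer. Following the LinOSS formulation, each coordinate of the hidden state y is a forced harmonic oscillator driven by the input u, y''(t) = -A y(t) + B u(t), and the layer's output is C y(t) + D u(t). When damping is enabled, a damping term -G y'(t) is added to the right-hand side. The continuous-time dynamics are discretized using the step sizes in steps.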
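Once discretized, the state s_n = (z_n, y_n) of each oscillator, with z = y', obeys a linear recurrence s_n = M s_{n-1} + F u_n. Each step is therefore an affine map, and composing affine maps is associative, which the LinOSS paper exploits to evaluate the recurrence with a parallel (associative) scan. The NumPy sketch below illustrates the idea for one undamped oscillator under IMEX, using a Hillis-Steele prefix scan that needs about log2(L) vectorized rounds, and checks it against a plain loop. It illustrates the technique only and is not linax's implementation; dt, a and the input are arbitrary.

```python
import numpy as np


def combine(earlier, later):
    """Compose two affine steps s -> M s + b, applying `earlier` first."""
    M1, b1 = earlier
    M2, b2 = later
    # later(earlier(s)) = M2 @ (M1 @ s + b1) + b2
    M = np.einsum("...ij,...jk->...ik", M2, M1)
    b = np.einsum("...ij,...j->...i", M2, b1) + b2
    return M, b


def parallel_prefix(Ms, bs):
    """Hillis-Steele inclusive scan over L affine steps in about log2(L) rounds."""
    offset = 1
    while offset < len(bs):
        # Position i >= offset absorbs the composite ending at i - offset.
        # Both operands come from the previous round, so nothing is overwritten early.
        M_new, b_new = combine((Ms[:-offset], bs[:-offset]), (Ms[offset:], bs[offset:]))
        Ms = np.concatenate([Ms[:offset], M_new])
        bs = np.concatenate([bs[:offset], b_new])
        offset *= 2
    # Applied to a zero initial state, each prefix composite yields just its b part.
    return bs


def sequential(Ms, bs):
    """Reference loop: s_n = M_n s_{n-1} + b_n, starting from a zero state."""
    s, out = np.zeros(bs.shape[1]), []
    for M, b in zip(Ms, bs):
        s = M @ s + b
        out.append(s)
    return np.stack(out)


# One undamped oscillator under IMEX, state ordered as (z, y), with B = 1.
dt, a, L = 0.5, 2.0, 13
M = np.array([[1.0, -dt * a], [dt, 1.0 - dt**2 * a]])
F = np.array([dt, dt**2])
u = np.random.default_rng(0).normal(size=L)
Ms = np.broadcast_to(M, (L, 2, 2))
bs = u[:, None] * F  # F * u_n for every step
print(np.allclose(parallel_prefix(Ms, bs), sequential(Ms, bs)))  # True
```

In JAX, the same combine rule is what one would pass to jax.lax.associative_scan, which performs the tree reduction itself.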
Attributes:
| Name | Description |
|---|---|
| A_diag | Diagonal state matrix. |
| G_diag | Diagonal damping matrix. |
| B | Input matrix. |
| C | Output matrix. |
| D | Feedthrough (skip) term mapping the input directly to the output. |
| steps | Discretization step sizes. |
| discretization | Discretization method to use. |
| damping | Whether to use damping. |
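To see why the discretization method matters, and what damping adds, one can inspect the 2 × 2 update matrix of a single oscillator acting on (z, y) = (y', y). The IM and IMEX matrices follow the LinOSS paper's schemes. The damped variant treats a damping term -g z implicitly inside the IMEX update; this is an assumed form in the spirit of damped LinOSS, and how linax actually uses G_diag may differ. The values of a, dt and g are arbitrary.

```python
import numpy as np


def transition(a, dt, scheme, g=0.0):
    """2x2 update matrix acting on (z, y) = (y', y) for one oscillator."""
    if scheme == "IM":
        s = 1.0 / (1.0 + dt**2 * a)
        return np.array([[s, -dt * a * s], [dt * s, s]])
    if scheme == "IMEX":
        # g > 0 adds a damping term -g z, treated implicitly (assumed form).
        s = 1.0 / (1.0 + dt * g)
        return np.array([[s, -dt * a * s], [dt * s, 1.0 - dt**2 * a * s]])
    raise ValueError(f"unknown scheme: {scheme!r}")


a, dt = 2.0, 0.5
for label, M in [
    ("IM", transition(a, dt, "IM")),
    ("IMEX", transition(a, dt, "IMEX")),
    ("IMEX + damping", transition(a, dt, "IMEX", g=0.4)),
]:
    radius = np.abs(np.linalg.eigvals(M)).max()
    print(f"{label:<15} det = {np.linalg.det(M):.4f}   spectral radius = {radius:.4f}")
```

The determinants come out as 1/(1 + Δt² a) for IM, exactly 1 for undamped IMEX, and 1/(1 + Δt g) with damping. When the eigenvalues form a complex pair, as they do for these values, the spectral radius is the square root of the determinant: IM always decays, undamped IMEX preserves amplitude, and in the damped variant the decay rate is set by G_diag alone rather than being tied to A_diag as in IM.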
__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray)
Initialize the LinOSS sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| cfg | ConfigType | Configuration for the LinOSS sequence mixer. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
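This page does not say how r_min and theta_max enter the initialization. One plausible reading, similar to the ring initialization used for diagonal recurrences such as the LRU and to the damped LinOSS variant, is to sample target eigenvalues for each oscillator's discrete update, with magnitude in [r_min, 1) and phase in [0, theta_max), and then solve for the stiffness a (an A_diag entry) and the damping g (a G_diag entry). The sketch assumes a damped IMEX update; that update form, the uniform sampling, the fixed step size dt, and the helper names damped_imex_block and init_from_spectrum are all hypothetical, not linax's code.

```python
import numpy as np


def damped_imex_block(a, g, dt):
    """2x2 update on (z, y) = (y', y) with damping -g z treated implicitly (assumed form)."""
    s = 1.0 / (1.0 + dt * g)
    return np.array([[s, -dt * a * s], [dt * s, 1.0 - dt**2 * a * s]])


def init_from_spectrum(state_dim, r_min, theta_max, dt, seed=0):
    """Hypothetical init: sample targets r*exp(+-i*theta), then solve for a and g.

    damped_imex_block has det = 1/(1 + dt*g) and trace = (2 + dt*g - dt**2*a)/(1 + dt*g);
    matching det = r**2 and trace = 2*r*cos(theta) gives the closed forms below.
    """
    rng = np.random.default_rng(seed)
    r = rng.uniform(r_min, 1.0, size=state_dim)          # magnitudes in [r_min, 1)
    theta = rng.uniform(0.0, theta_max, size=state_dim)  # phases in [0, theta_max)
    g = (1.0 / r**2 - 1.0) / dt                           # >= 0 because r <= 1
    # a = |1 - r e^{i theta}|^2 / (dt r)^2, hence always >= 0
    a = (1.0 + r**2 - 2.0 * r * np.cos(theta)) / (dt**2 * r**2)
    return a, g, r, theta


dt = 0.5
a, g, r, theta = init_from_spectrum(8, r_min=0.9, theta_max=np.pi, dt=dt)

# Check: the eigenvalues of each block land on the sampled targets.
worst = 0.0
for ai, gi, ri, ti in zip(a, g, r, theta):
    eig = np.linalg.eigvals(damped_imex_block(ai, gi, dt))
    eig = eig[np.argsort(eig.imag)]                      # conjugate pair, -theta first
    target = ri * np.exp(1j * np.array([-ti, ti]))
    worst = max(worst, np.abs(eig - target).max())
print(f"max eigenvalue mismatch: {worst:.1e}")
```

Under this reading, radii close to 1 give slowly decaying oscillators that retain long-range information, while a smaller r_min admits faster-decaying modes.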
__call__(x: Array, key: PRNGKeyArray) -> Array
Forward pass of the LinOSS sequence mixer layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input sequence of features. | required |
| key | PRNGKeyArray | JAX random key (the layer's parameters are initialized in __init__, not in the forward pass). | required |
Returns:
| Type | Description |
|---|---|
| Array | The output of the LinOSS sequence mixer. |
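The page documents only the signature of __call__. The loop below is a minimal sequential NumPy sketch of what a LinOSS forward pass computes, under assumptions that are ours rather than linax's: IMEX discretization, real-valued parameters with the shapes listed in the docstring, a zero initial state, optional damping treated implicitly, and no use of the key argument. The input called x in the signature is named u here to match the oscillator equations, and linoss_forward is a hypothetical helper, not a linax function.

```python
import numpy as np


def linoss_forward(u, A_diag, B, C, D, steps, G_diag=None):
    """Sequential reference for a LinOSS-style mixer (hypothetical, real-valued, IMEX).

    Shapes (assumed): u (L, H); A_diag, steps, G_diag (P,); B (P, H); C (H, P); D (H,).
    """
    g = np.zeros_like(A_diag) if G_diag is None else G_diag
    s = 1.0 / (1.0 + steps * g)   # implicit damping factor; equals 1 when undamped
    z = np.zeros_like(A_diag)     # velocities y'
    y = np.zeros_like(A_diag)     # positions y (the oscillator states)
    outputs = []
    for u_n in u:
        # IMEX: the restoring force uses the previous position y_{n-1}.
        z = s * (z + steps * (-A_diag * y + B @ u_n))
        y = y + steps * z
        outputs.append(C @ y + D * u_n)  # readout plus per-feature skip term
    return np.stack(outputs)


rng = np.random.default_rng(0)
L, H, P = 32, 4, 16  # sequence length, in_features, state_dim
params = dict(
    A_diag=rng.uniform(0.0, 1.0, P),
    B=rng.normal(size=(P, H)) / np.sqrt(H),
    C=rng.normal(size=(H, P)) / np.sqrt(P),
    D=rng.normal(size=H),
    steps=rng.uniform(0.1, 1.0, P),
)
u = rng.normal(size=(L, H))
out = linoss_forward(u, **params)
print(out.shape)  # (32, 4)
```

A practical implementation would replace the Python loop with a parallel scan over the sequence; the loop is kept here because it mirrors the recurrence line by line.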