Linear Encoder
`linax.encoder.linear.LinearEncoderConfig`
Configuration for the linear encoder.
Attributes:

| Name | Type | Description |
|---|---|---|
| `in_features` | | Input dimensionality (number of input features). |
| `out_features` | | Output dimensionality (hidden dimension). |
| `use_bias` | | Whether to use a bias term in the linear layer. |
`build(key: PRNGKeyArray) -> LinearEncoder`

Build a `LinearEncoder` from this config.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
Returns:

| Type | Description |
|---|---|
| `LinearEncoder` | The encoder instance. |
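The config/`build(key)` split keeps all shape information in a plain data object and defers randomness to a single PRNG key. The sketch below illustrates that pattern in plain JAX; it is not linax's implementation. `LinearEncoderConfigSketch` is a hypothetical name, the `use_bias=True` default is an assumption, and the uniform initialization in `[-1/sqrt(in_features), 1/sqrt(in_features)]` is borrowed from a common linear-layer convention (for example, Equinox's `eqx.nn.Linear`) rather than taken from this page. To stay self-contained, it returns a parameter dictionary instead of a `LinearEncoder`.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class LinearEncoderConfigSketch:
    """Hypothetical stand-in for LinearEncoderConfig (same three fields)."""

    in_features: int
    out_features: int
    use_bias: bool = True  # assumed default, not documented

    def build(self, key):
        # Shapes come from the config; randomness comes only from `key`.
        w_key, b_key = jax.random.split(key)
        # Assumed init: uniform in [-1/sqrt(in_features), 1/sqrt(in_features)].
        scale = 1.0 / jnp.sqrt(self.in_features)
        weight = jax.random.uniform(
            w_key, (self.out_features, self.in_features), minval=-scale, maxval=scale
        )
        bias = (
            jax.random.uniform(b_key, (self.out_features,), minval=-scale, maxval=scale)
            if self.use_bias
            else None
        )
        return {"weight": weight, "bias": bias}


cfg = LinearEncoderConfigSketch(in_features=3, out_features=8)
params = cfg.build(jax.random.PRNGKey(0))
print(params["weight"].shape)  # (8, 3)
```

Because the config is frozen and the key is explicit, calling `build` twice with the same key yields identical parameters, which keeps model construction reproducible.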
`linax.encoder.linear.LinearEncoder`
Linear encoder.
This encoder takes an input of shape `(timesteps, in_features)` and outputs a hidden representation of shape `(timesteps, out_features)`.
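Since the documented input is a sequence of shape `(timesteps, in_features)`, one natural reading is that the same linear map is applied independently at every timestep. The following is a minimal sketch of that shape contract in plain JAX, using `jax.vmap` over the time axis. The functions `linear` and `encode` are hypothetical, and this is not necessarily how linax implements the encoder.

```python
import jax
import jax.numpy as jnp


def linear(weight, bias, x_t):
    # One timestep: (in_features,) -> (out_features,)
    y = weight @ x_t
    return y if bias is None else y + bias


def encode(weight, bias, x):
    # Share weight/bias across time and map over the leading axis:
    # (timesteps, in_features) -> (timesteps, out_features)
    return jax.vmap(linear, in_axes=(None, None, 0))(weight, bias, x)


timesteps, in_features, out_features = 5, 3, 8
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (out_features, in_features))
bias = jnp.zeros(out_features)
x = jax.random.normal(key_x, (timesteps, in_features))

y = encode(weight, bias, x)
print(y.shape)  # (5, 8)
```

Mapping a per-timestep layer with `jax.vmap` gives the same result as the batched product `x @ weight.T + bias`. The `vmap` form is convenient when the per-step layer only accepts a single feature vector, as Equinox layers such as `eqx.nn.Linear` do.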
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | required |
| `out_features` | `int` | Output dimensionality. | required |
| `cfg` | `ConfigType` | Configuration for the linear encoder. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
Attributes:

| Name | Type | Description |
|---|---|---|
| `linear` | | The linear layer that maps `in_features` to `out_features`. |
`__init__(in_features: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)`

Initialize the linear encoder.
`__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]`

Forward pass of the linear encoder: applies the linear layer to the input `x` and returns the result together with the state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
Returns:

| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and the updated state. |
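Taking and returning a state lets the encoder compose uniformly with stateful layers (such as Equinox's `eqx.nn.BatchNorm`) inside a larger model: every layer has the same `(x, state) -> (y, state)` signature. The sketch below shows that calling convention with a hypothetical `LinearEncoderSketch` class and a hypothetical `run` helper. It uses a plain dictionary in place of `eqx.nn.State` and assumes that a linear layer keeps no state of its own and therefore passes the state through unchanged.

```python
import jax.numpy as jnp


class LinearEncoderSketch:
    """Hypothetical encoder following the (x, state) -> (y, state) convention."""

    def __init__(self, weight, bias=None):
        self.weight = weight  # (out_features, in_features)
        self.bias = bias  # (out_features,) or None

    def __call__(self, x, state):
        # x: (timesteps, in_features) -> y: (timesteps, out_features)
        y = x @ self.weight.T
        if self.bias is not None:
            y = y + self.bias
        # Assumption: a plain linear layer has no running statistics,
        # so the incoming state is returned as-is.
        return y, state


def run(layers, x, state):
    # Thread the state through every layer in order, as a model would.
    for layer in layers:
        x, state = layer(x, state)
    return x, state


encoder = LinearEncoderSketch(jnp.ones((4, 3)), jnp.zeros(4))
y, state = encoder(jnp.ones((5, 3)), {})
print(y.shape)  # (5, 4)
```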