Linear Encoder
discretax.encoder.linear.LinearEncoder
Linear encoder.
This encoder maps an input of shape (timesteps, in_features) to a hidden representation of shape (timesteps, out_features).
Attributes:

| Name | Type | Description |
|---|---|---|
| `linear` | | Linear layer. |
__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int, use_bias: bool = False, **kwargs)
Initialize the linear encoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality (number of input features). | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `out_features` | `int` | Output dimensionality (hidden dimension). Keyword-only. | required |
| `use_bias` | `bool` | Whether to use a bias term in the linear layer. | `False` |
| `*args` | | Additional positional arguments (ignored). | `()` |
| `**kwargs` | | Additional keyword arguments (ignored). | `{}` |
__call__(x: Array, state: eqx.nn.State, *, key: PRNGKeyArray | None = None) -> tuple[Array, eqx.nn.State]
Forward pass of the linear encoder.
Applies the linear layer to each timestep of the input, mapping shape (timesteps, in_features) to (timesteps, out_features).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape (timesteps, in_features). | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray \| None` | JAX random key for stochastic operations (unused). Keyword-only. | `None` |
Returns:

| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple of the output tensor, of shape (timesteps, out_features), and the updated state. |