
Linear Encoder

discretax.encoder.linear.LinearEncoder

Linear encoder.

This encoder takes an input of shape (timesteps, in_features) and outputs a hidden representation of shape (timesteps, out_features).
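The shape contract can be illustrated with a minimal pure-JAX sketch. It does not use discretax or Equinox: the helper `linear` and the weight values are hypothetical, and lifting a single linear map over the time axis with `jax.vmap` is an assumption about how the encoder handles the timestep dimension.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the encoder's linear layer: a weight matrix of
# shape (out_features, in_features) and an optional bias of shape (out_features,).
def linear(weight, bias, x_t):
    """Map a single timestep x_t of shape (in_features,) to (out_features,)."""
    y = weight @ x_t
    return y if bias is None else y + bias


timesteps, in_features, out_features = 5, 3, 8
weight = jnp.ones((out_features, in_features))
x = jnp.arange(timesteps * in_features, dtype=jnp.float32).reshape(timesteps, in_features)

# vmap applies the same layer independently to every timestep (each row of x).
y = jax.vmap(lambda x_t: linear(weight, None, x_t))(x)
print(y.shape)  # (5, 8)
```

Every row of the output depends only on the matching row of the input, so the time axis is preserved while the feature axis changes from in_features to out_features (the hidden dimension).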

Attributes:

| Name | Type | Description |
|------|------|-------------|
| linear | | Linear layer mapping in_features to out_features. |

__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int, use_bias: bool = False, **kwargs)

Initialize the linear encoder.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| in_features | int | Input dimensionality (number of input features). | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| out_features | int | Output dimensionality (hidden dimension). Keyword-only, because it follows *args. | required |
| use_bias | bool | Whether to use a bias in the linear layer. | False |
| *args | | Additional positional arguments (ignored). | () |
| **kwargs | | Additional keyword arguments (ignored). | {} |
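How these constructor arguments could translate into layer parameters is sketched below in pure JAX. The function `init_linear_params` is hypothetical and is not the discretax implementation; the uniform initialization in the range ±1/sqrt(in_features) is an assumption (it matches Equinox's default for `eqx.nn.Linear`). The signature mirrors the documented one so that the handling of `*args`, `**kwargs`, and the keyword-only `out_features` can be seen.

```python
import math

import jax
import jax.numpy as jnp


def init_linear_params(in_features, key, *args, out_features, use_bias=False, **kwargs):
    """Hypothetical initializer mirroring the documented __init__ signature.

    Extra positional and keyword arguments are accepted and ignored. The
    uniform(-1/sqrt(in_features), 1/sqrt(in_features)) init is an assumption.
    """
    del args, kwargs  # ignored, as documented
    w_key, b_key = jax.random.split(key)
    lim = 1.0 / math.sqrt(in_features)
    weight = jax.random.uniform(
        w_key, (out_features, in_features), minval=-lim, maxval=lim
    )
    bias = (
        jax.random.uniform(b_key, (out_features,), minval=-lim, maxval=lim)
        if use_bias
        else None
    )
    return {"weight": weight, "bias": bias}


# Extra arguments ("ignored", extra=1) are swallowed by *args / **kwargs.
params = init_linear_params(3, jax.random.PRNGKey(0), "ignored", out_features=8, extra=1)
print(params["weight"].shape, params["bias"])  # (8, 3) None
```

Because `out_features` comes after `*args`, it must be passed by keyword: a call such as `init_linear_params(3, key, 8)` puts `8` into `*args` and then fails with a `TypeError` for the missing `out_features`. The same Python rule applies to the documented constructor signature.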
__call__(x: Array, state: eqx.nn.State, *, key: PRNGKeyArray | None = None) -> tuple[Array, eqx.nn.State]

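Forward pass of the linear encoder.

Applies the linear layer to the input, mapping an array of shape (timesteps, in_features) to one of shape (timesteps, out_features), and returns the result together with the state.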

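Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Array | Input tensor of shape (timesteps, in_features). | required |
| state | State | Current state for stateful layers. | required |
| key | PRNGKeyArray \| None | JAX random key for stochastic operations (unused). | None |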

Returns:

| Type | Description |
|------|-------------|
| tuple[Array, State] | Tuple of the output tensor, of shape (timesteps, out_features), and the updated state. |
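The `(output, state)` return convention can be sketched in pure JAX as follows. The function `linear_encoder_call`, the parameter dictionary, and the use of a plain `dict` in place of `eqx.nn.State` are all hypothetical; passing the state through unchanged is an assumption, since a plain linear layer has nothing in it to update.

```python
import jax.numpy as jnp


def linear_encoder_call(params, x, state, *, key=None):
    """Hypothetical forward pass following the documented __call__ contract.

    x has shape (timesteps, in_features); returns (y, state), where y has shape
    (timesteps, out_features). `key` is accepted but unused, and the state is
    returned unchanged (an assumption for a stateless linear map).
    """
    del key  # unused, as documented
    # (timesteps, in) @ (in, out) -> (timesteps, out); equivalent to applying
    # weight @ x_t to every row x_t.
    y = x @ params["weight"].T
    if params["bias"] is not None:
        y = y + params["bias"]
    return y, state


params = {"weight": jnp.eye(2, 3), "bias": jnp.array([10.0, 20.0])}
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
state = {}  # placeholder standing in for eqx.nn.State
y, new_state = linear_encoder_call(params, x, state)
# y == [[11., 22.], [14., 25.]]; new_state is the same object as state
```

Returning the state even when it is unchanged keeps the call signature uniform with stateful layers, so callers can thread `state` through a stack of modules without special-casing stateless ones.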