
Linear Encoder

linax.encoder.linear.LinearEncoderConfig

Configuration for the linear encoder.

Attributes:

Name          Type  Description
in_features   int   Input dimensionality (number of input features per timestep).
out_features  int   Output dimensionality (hidden dimension).
use_bias      bool  Whether the linear layer includes a bias term.

build(key: PRNGKeyArray) -> LinearEncoder

Build a LinearEncoder from this configuration.

Parameters:

Name  Type          Description                         Default
key   PRNGKeyArray  JAX random key for initialization.  required

Returns:

Type           Description
LinearEncoder  The initialized encoder instance.


linax.encoder.linear.LinearEncoder

Linear encoder.

This encoder takes an input of shape (timesteps, in_features) and returns a hidden representation of shape (timesteps, out_features).

Parameters:

Name          Type          Description                            Default
in_features   int           Input dimensionality.                  required
out_features  int           Output dimensionality.                 required
cfg           ConfigType    Configuration for the linear encoder.  required
key           PRNGKeyArray  JAX random key for initialization.     required

Attributes:

Name    Description
linear  Linear layer mapping in_features to out_features.

__init__(in_features: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)

Initialize the linear encoder.

__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]

Forward pass of the linear encoder.

Applies the linear layer to each timestep of the input, mapping in_features to out_features.

Parameters:

Name   Type   Description                                      Default
x      Array  Input tensor of shape (timesteps, in_features).  required
state  State  Current state for stateful layers.               required

Returns:

Type                 Description
tuple[Array, State]  The output tensor, of shape (timesteps, out_features), and the updated state.