
Encoder

linax.encoder.base.EncoderConfig

Configuration for encoders.

Attributes:

Name          Type  Description
out_features        Output dimensionality (hidden dimension).

build(key: PRNGKeyArray) -> Encoder

Build an encoder from this config.

Parameters:

Name  Type          Description                         Default
key   PRNGKeyArray  JAX random key for initialization.  required

Returns:

Type     Description
Encoder  The encoder instance.


linax.encoder.base.Encoder

Abstract base class for all encoders.

Parameters:

Name          Type          Description                         Default
out_features  int           Output dimensionality.              required
cfg           ConfigType    Configuration for the encoder.      required
key           PRNGKeyArray  JAX random key for initialization.  required
__init__(out_features: int, cfg: ConfigType, key: PRNGKeyArray)

Initialize the encoder.

__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]

Forward pass of the encoder.

Parameters:

Name   Type   Description                         Default
x      Array  Input tensor.                       required
state  State  Current state for stateful layers.  required

Returns:

Type                 Description
tuple[Array, State]  Tuple containing the output tensor and updated state.