Embedding Encoder

linax.encoder.embedding.EmbeddingEncoderConfig

Configuration for the embedding encoder.

Attributes:

| Name | Description |
| --- | --- |
| num_classes | Number of classes (vocabulary size). |
| out_features | Output dimensionality (embedding dimension). |

build(key: PRNGKeyArray) -> EmbeddingEncoder

Build an EmbeddingEncoder from this config.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | PRNGKeyArray | JAX random key for initialization. | required |

Returns:

| Type | Description |
| --- | --- |
| EmbeddingEncoder | The encoder instance. |
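The config-then-build pattern can be sketched in plain JAX. This is a minimal illustration under stated assumptions: `SketchEmbeddingEncoderConfig` and `SketchEmbeddingEncoder` are hypothetical stand-ins, not linax classes, and the standard-normal initialization of the lookup table is assumed rather than taken from linax. The constructor mirrors the documented `(num_classes, out_features, cfg, key)` signature.

```python
from dataclasses import dataclass

import jax


@dataclass(frozen=True)
class SketchEmbeddingEncoderConfig:
    """Hypothetical stand-in mirroring the documented config fields."""

    num_classes: int  # vocabulary size
    out_features: int  # embedding dimension

    def build(self, key: jax.Array) -> "SketchEmbeddingEncoder":
        # Mirrors the documented pattern: the config builds the encoder from a key.
        return SketchEmbeddingEncoder(self.num_classes, self.out_features, self, key)


class SketchEmbeddingEncoder:
    """Hypothetical minimal encoder holding a (num_classes, out_features) table."""

    def __init__(self, num_classes: int, out_features: int, cfg, key: jax.Array):
        self.cfg = cfg
        # Standard-normal initialization is an assumption made for this sketch.
        self.weight = jax.random.normal(key, (num_classes, out_features))


cfg = SketchEmbeddingEncoderConfig(num_classes=10, out_features=4)
encoder = cfg.build(jax.random.PRNGKey(0))
print(encoder.weight.shape)  # (10, 4)
```

Building twice from the same key yields identical weights, because JAX random number generation is fully determined by the key.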


linax.encoder.embedding.EmbeddingEncoder

Embedding encoder.

This encoder takes an input of integer class indices with shape (timesteps,) and maps each index to a learned embedding vector, producing a hidden representation of shape (timesteps, out_features).
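The shape contract can be illustrated with a plain JAX table lookup. This is a minimal sketch, assuming the embedding layer stores a (num_classes, out_features) weight table that is indexed by the integer inputs; the names `weight` and `x` are illustrative and not part of linax.

```python
import jax
import jax.numpy as jnp

num_classes, out_features = 10, 4

# Hypothetical embedding table: one learned row per class.
weight = jax.random.normal(jax.random.PRNGKey(0), (num_classes, out_features))

# Input of shape (timesteps,): one integer class index per timestep.
x = jnp.array([0, 3, 3, 9, 1, 0])

# Row lookup gives one embedding vector per timestep.
h = weight[x]
print(x.shape, "->", h.shape)  # (6,) -> (6, 4)
```

Timesteps that share a class index receive identical embedding vectors, since they read the same row of the table.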

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_classes | int | Number of classes (vocabulary size). | required |
| out_features | int | Output dimensionality (embedding dimension). | required |
| cfg | ConfigType | Configuration for the embedding encoder. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |

Attributes:

| Name | Description |
| --- | --- |
| embedding | Embedding layer. |

__init__(num_classes: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)

Initialize the embedding encoder.

__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]

Forward pass of the embedding encoder.

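It applies the embedding layer to the input, looking up one embedding vector per timestep so that an input of shape (timesteps,) becomes an output of shape (timesteps, out_features).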
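The `(x, state) -> (output, state)` convention can be sketched as a pure function. This is a hypothetical stand-in, not linax's implementation: `embedding_forward` is an illustrative name, the dict is a placeholder for an `eqx.nn.State` object, and returning the state untouched is an assumption based on the lookup itself holding no mutable state.

```python
from typing import Any

import jax
import jax.numpy as jnp


def embedding_forward(
    weight: jax.Array, x: jax.Array, state: Any
) -> tuple[jax.Array, Any]:
    """Hypothetical forward pass following the (x, state) -> (output, state) convention.

    A table lookup has no mutable state, so this sketch hands `state` back untouched.
    """
    out = weight[x]  # (timesteps,) -> (timesteps, out_features)
    return out, state


weight = jax.random.normal(jax.random.PRNGKey(0), (10, 4))
x = jnp.array([2, 7, 7, 0])
state = {"step": jnp.array(0)}  # placeholder standing in for eqx.nn.State

out, new_state = embedding_forward(weight, x, state)
print(out.shape)  # (4, 4)
```

Threading the state through every call, even when a layer does not modify it, lets stateful and stateless layers share one call signature.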

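Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input array of integer class indices with shape (timesteps,). | required |
| state | State | Current state for stateful layers. | required |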

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, State] | Tuple of the output array, with shape (timesteps, out_features), and the updated state. |
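The documented input is unbatched, with shape (timesteps,). A batch of sequences with shape (batch, timesteps) can be handled by mapping an unbatched forward pass with `jax.vmap`, sharing the weights and state across the batch. This is a common JAX pattern, assumed here rather than confirmed as how linax batches its encoders; `embedding_forward` is a hypothetical stand-in and the dict is a placeholder for `eqx.nn.State`.

```python
import jax
import jax.numpy as jnp


def embedding_forward(weight, x, state):
    # Hypothetical unbatched forward pass: (timesteps,) -> (timesteps, out_features).
    return weight[x], state


weight = jax.random.normal(jax.random.PRNGKey(0), (10, 4))
batch = jnp.array([[1, 2, 3], [4, 5, 6]])  # (batch, timesteps)
state = {"step": jnp.array(0)}  # placeholder for shared, unbatched state

# Map over the batch axis of x only; weight and state are shared across the batch.
batched_forward = jax.vmap(
    embedding_forward, in_axes=(None, 0, None), out_axes=(0, None)
)
out, new_state = batched_forward(weight, batch, state)
print(out.shape)  # (2, 3, 4)
```

Here `out_axes=(0, None)` tells JAX that the returned state carries no batch axis, which holds because the state passes through unchanged.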