
Embedding Encoder

discretax.encoder.embedding.EmbeddingEncoder

Embedding encoder.

This encoder takes an input of integer class indices with shape (timesteps,) and outputs a hidden representation of shape (timesteps, out_features), one embedding vector per timestep.
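A minimal sketch of this shape contract, assuming the input holds integer class indices in [0, num_classes) and that the embedding layer is a plain lookup table. The sizes and the names `table`, `x` and `h` are illustrative only; the one-hot product is shown as an equivalent way to read the lookup, not as how discretax computes it.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
num_classes, out_features, timesteps = 10, 4, 6

table_key, data_key = jax.random.split(jax.random.PRNGKey(0))

# Assumption: the embedding layer stores a (num_classes, out_features) table.
table = jax.random.normal(table_key, (num_classes, out_features))

# Input: one integer class index per timestep, shape (timesteps,).
x = jax.random.randint(data_key, (timesteps,), 0, num_classes)

# Lookup picks row x[t] of the table for every timestep t.
h = table[x]

# Equivalent view: one-hot encode each index, then multiply by the table.
# HIGHEST precision keeps the float32 matmul exact on every backend.
h_onehot = jnp.dot(
    jax.nn.one_hot(x, num_classes), table, precision=jax.lax.Precision.HIGHEST
)

print(h.shape)                    # (6, 4)
print(jnp.allclose(h, h_onehot))  # True
```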

Attributes:

Name       Description
embedding  Embedding layer.

__init__(key: PRNGKeyArray, *args, out_features: int, num_classes: int, **kwargs)

Initialize the embedding encoder.

Parameters:

Name          Type          Description                                                 Default
key           PRNGKeyArray  JAX random key for initialization.                          required
out_features  int           Output dimensionality (embedding dimension); keyword-only.  required
num_classes   int           Number of classes (vocabulary size); keyword-only.          required
*args                       Additional positional arguments (ignored).                  ()
**kwargs                    Additional keyword arguments (ignored).                     {}
__call__(x: Array, state: eqx.nn.State, *, key: PRNGKeyArray | None = None) -> tuple[Array, eqx.nn.State]

Forward pass of the embedding encoder.

This forward pass applies the embedding layer to every timestep of the input, mapping each class index to a vector of length out_features.

Parameters:

Name   Type                 Description                                                      Default
x      Array                Input tensor of integer class indices with shape (timesteps,).   required
state  State                Current state for stateful layers.                               required
key    PRNGKeyArray | None  JAX random key for stochastic operations (unused); keyword-only. None

Returns:

Type                 Description
tuple[Array, State]  Tuple of the output tensor, with shape (timesteps, out_features), and the updated state.