Embedding Encoder
linax.encoder.embedding.EmbeddingEncoderConfig
Configuration for the embedding encoder.
Attributes:
| Name | Type | Description |
|---|---|---|
| num_classes | | Number of classes (vocabulary size). |
| out_features | | Output dimensionality (embedding dimension). |
build(key: PRNGKeyArray) -> EmbeddingEncoder
Build encoder from config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| EmbeddingEncoder | The encoder instance. |
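The config-then-build pattern described above can be sketched in plain JAX. This is a minimal illustration, not the linax implementation: `SketchEmbeddingConfig` is a hypothetical stand-in, and the assumptions that the encoder's only parameter is a `(num_classes, out_features)` table and that it is drawn from a standard normal are mine.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class SketchEmbeddingConfig:
    """Hypothetical stand-in for EmbeddingEncoderConfig (not the linax class)."""

    num_classes: int   # vocabulary size
    out_features: int  # embedding dimension

    def build(self, key: jax.Array) -> jax.Array:
        # Assumption: the built object is just a lookup table with one row
        # per class, initialised from a standard normal distribution.
        return jax.random.normal(key, (self.num_classes, self.out_features))


config = SketchEmbeddingConfig(num_classes=10, out_features=4)
table = config.build(jax.random.PRNGKey(0))
```

Keeping the hyperparameters in a config and passing the random key only to `build` means the same config can be reused to create several independently initialised instances.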
linax.encoder.embedding.EmbeddingEncoder
Embedding encoder.
This encoder takes an input of shape (timesteps,) and outputs a hidden representation of shape (timesteps, out_features).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes (vocabulary size). | required |
| out_features | int | Output dimensionality (embedding dimension). | required |
| cfg | ConfigType | Configuration for the embedding encoder. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| embedding | | Embedding layer. |
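The shape contract stated for this class, `(timesteps,)` in and `(timesteps, out_features)` out, can be illustrated with a plain-JAX row lookup. This is a sketch under assumptions: the input is taken to be integer class indices, the parameter is taken to be a `(num_classes, out_features)` table, and none of the names below come from linax.

```python
import jax
import jax.numpy as jnp

num_classes, out_features = 10, 4

# Assumed parameter layout: one row of length out_features per class.
table = jax.random.normal(jax.random.PRNGKey(0), (num_classes, out_features))

# Integer class indices, one per timestep: shape (timesteps,).
x = jnp.array([1, 3, 3, 7, 0, 9])

# Row lookup: result has shape (timesteps, out_features).
hidden = jnp.take(table, x, axis=0)
```

Timesteps that carry the same class index (positions 1 and 2 here) receive identical embedding vectors, because the lookup does not depend on position.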
__init__(num_classes: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)
Initialize the embedding encoder.
__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]
Forward pass of the embedding encoder.
This forward pass applies the embedding layer to the input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
| state | State | Current state for stateful layers. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Array, State] | Tuple containing the output tensor and updated state. |
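The `(x, state) -> (output, state)` calling convention can be sketched as follows. This is an illustration, not the linax code: `embed_forward` is a hypothetical function, the empty dict is a placeholder rather than an `eqx.nn.State`, and the assumption that an embedding lookup returns the state unchanged is mine.

```python
import jax
import jax.numpy as jnp


def embed_forward(table, x, state):
    """Look up one embedding row per timestep and pass the state through."""
    # Assumption: an embedding lookup keeps no running statistics,
    # so the state is returned unchanged.
    return jnp.take(table, x, axis=0), state


table = jax.random.normal(jax.random.PRNGKey(0), (10, 4))  # (num_classes, out_features)
x = jnp.array([2, 5, 5, 8])  # integer class indices, shape (timesteps,)
state = {}  # placeholder; not an eqx.nn.State

out, new_state = embed_forward(table, x, state)
```

Returning the state even when it is untouched keeps the signature uniform, so the encoder can be chained with layers that do update their state.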