Encoder¤

discretax.encoder.base.AbstractEncoder ¤

Abstract base class for all encoders.

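Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `*args` | | Additional positional arguments for the encoder (ignored). | optional |
| `out_features` | `int` | Keyword-only argument listed in the `__init__` signature; presumably the size of the encoder output. | required |
| `**kwargs` | | Additional keyword arguments (ignored). | optional |
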
__init__(key: PRNGKeyArray, *args, out_features: int, **kwargs) ¤

Initialize the encoder.

__call__(x: Array, state: eqx.nn.State, *, key: PRNGKeyArray | None = None) -> tuple[Array, eqx.nn.State] ¤

Forward pass of the encoder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray \| None` | Optional JAX random key (unused by encoders, for Sequential compatibility). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Array, State]` | Tuple containing the output tensor and updated state. |