Heads

discretax.heads.base.AbstractHead

Abstract base class for all heads in Discretax.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_features | int | Input dimensionality. | required |
| out_features | int | Output dimensionality. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| `*args` | | Additional positional arguments for the head. | () |
| `**kwargs` | | Additional keyword arguments for the head. | {} |

__init__(in_features: int, out_features: int, key: PRNGKeyArray, *args, **kwargs)

Initialize the head.

__call__(x: Array, state: eqx.nn.State, *, key: PRNGKeyArray | None = None) -> tuple[Array, eqx.nn.State]

Forward pass of the head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor. | required |
| state | State | Current state for stateful layers. | required |
| key | PRNGKeyArray \| None | Optional JAX random key (unused by heads, for Sequential compatibility). | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, State] | Tuple containing the output tensor and the updated state. |
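
The constructor and call signatures above can be illustrated with a minimal sketch. `LinearHead` is a hypothetical name, not part of Discretax. In real code the class would presumably subclass `AbstractHead` and `state` would be an `eqx.nn.State`. Here a plain Python class and a `None` placeholder are used so the snippet runs with only JAX installed, and the initialization scheme is an assumption chosen for illustration.

```python
import jax
import jax.numpy as jnp


class LinearHead:
    """Hypothetical head: one affine map from in_features to out_features."""

    def __init__(self, in_features, out_features, key):
        # Split the key so weight and bias draw independent random values.
        w_key, b_key = jax.random.split(key)
        # Assumed init: uniform in [-1/sqrt(in_features), 1/sqrt(in_features)].
        scale = 1.0 / jnp.sqrt(in_features)
        self.weight = jax.random.uniform(
            w_key, (out_features, in_features), minval=-scale, maxval=scale
        )
        self.bias = jax.random.uniform(
            b_key, (out_features,), minval=-scale, maxval=scale
        )

    def __call__(self, x, state, *, key=None):
        # `key` is accepted only to match the documented signature; it is ignored.
        # This head has no stateful layers, so `state` is returned unchanged.
        return self.weight @ x + self.bias, state


head = LinearHead(in_features=8, out_features=3, key=jax.random.PRNGKey(0))
state = None  # placeholder standing in for an eqx.nn.State
x = jnp.ones(8)
logits, new_state = head(x, state)
```

The call returns a `(output, state)` pair even though this head never modifies the state, which is what lets a head sit in a chain of layers that all share the same calling convention.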