Heads

linax.heads.base.HeadConfig

Configuration for heads.

Attributes:

Name          Description
out_features  Output dimensionality (e.g., number of classes).
reduce        Whether to reduce the time dimension by averaging.
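
As a rough illustration of what `reduce` describes, the sketch below averages a sequence of feature vectors over time using plain JAX. The `(time, features)` layout and the use of axis 0 as the time axis are assumptions made for this example; the reference above does not state which axis Linax treats as time.

```python
import jax.numpy as jnp

# Hypothetical encoder output: 4 time steps, 3 features per step.
x = jnp.arange(12.0).reshape(4, 3)

# reduce=True (assumed behaviour): average over the time axis, giving shape (3,).
pooled = jnp.mean(x, axis=0)

# reduce=False (assumed behaviour): keep one vector per time step, shape (4, 3).
per_step = x
```

With reduction, a classifier head would emit one prediction per sequence; without it, one prediction per time step.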

build(in_features: int, key: PRNGKeyArray) -> Head

Build head from config.

Parameters:

Name         Type          Description                                              Default
in_features  int           Input dimensionality (from encoder's hidden dimension).  required
key          PRNGKeyArray  JAX random key for initialization.                       required

Returns:

Type  Description
Head  The head instance.
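
The config-then-build pattern described above can be sketched as follows. `SketchHeadConfig` and `SketchLinearHead` are hypothetical stand-ins written for this example and are not Linax classes; only the attribute names (`out_features`, `reduce`) and the `build(in_features, key)` signature are taken from this page. The weight initialisation is likewise an assumption.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


class SketchLinearHead:
    """Hypothetical stand-in for a concrete head; not part of Linax."""

    def __init__(self, in_features, cfg, key):
        self.reduce = cfg.reduce
        # Small random weights; the real initialisation scheme is not documented here.
        self.weight = 0.01 * jax.random.normal(key, (cfg.out_features, in_features))
        self.bias = jnp.zeros((cfg.out_features,))


@dataclass(frozen=True)
class SketchHeadConfig:
    """Hypothetical config mirroring the documented attributes."""

    out_features: int
    reduce: bool

    def build(self, in_features, key):
        # The config supplies the output size; the caller supplies
        # the input size (known only once the encoder is chosen) and the key.
        return SketchLinearHead(in_features, self, key)


cfg = SketchHeadConfig(out_features=10, reduce=True)
head = cfg.build(in_features=64, key=jax.random.PRNGKey(0))
```

Keeping `in_features` out of the config means the same config can be reused with encoders of different hidden sizes.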


linax.heads.base.Head

Abstract base class for all heads in Linax.

Parameters:

Name         Type          Description                         Default
in_features  int           Input dimensionality.               required
cfg          ConfigType    Configuration for the head.         required
key          PRNGKeyArray  JAX random key for initialization.  required
__init__(in_features: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)

Initialize the head.

__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]

Forward pass of the head.

Parameters:

Name   Type   Description                         Default
x      Array  Input tensor.                       required
state  State  Current state for stateful layers.  required

Returns:

Type                 Description
tuple[Array, State]  Tuple containing the output tensor and updated state.
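
The call convention above, where the state goes in and comes back out alongside the output, might look like this in practice. `SketchHead` is a hypothetical class written for this example, not a Linax head. It uses `None` as a placeholder state so the snippet needs only JAX, whereas Linax passes an `eqx.nn.State`. The `(time, in_features)` input layout is also an assumption.

```python
import jax
import jax.numpy as jnp


class SketchHead:
    """Hypothetical stateless head that follows the documented call signature."""

    def __init__(self, in_features, out_features, key):
        self.weight = 0.01 * jax.random.normal(key, (out_features, in_features))
        self.bias = jnp.zeros((out_features,))

    def __call__(self, x, state):
        # Assumed input layout: (time, in_features); average over time first.
        pooled = jnp.mean(x, axis=0)
        out = self.weight @ pooled + self.bias
        # Nothing stateful happens here, so the state is returned unchanged.
        return out, state


head = SketchHead(in_features=8, out_features=3, key=jax.random.PRNGKey(0))
x = jnp.ones((5, 8))  # 5 time steps, 8 features
state = None          # placeholder; Linax passes an eqx.nn.State here
logits, state = head(x, state)
```

Returning the state even when it is unchanged lets callers thread it through every layer in the same way, whether or not a given layer is stateful.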