
Classification Head

linax.heads.classification.ClassificationHeadConfig

Configuration for the classification head.

Attributes:

out_features: Output dimensionality (number of classes).

build(in_features: int, key: PRNGKeyArray) -> ClassificationHead

Build a ClassificationHead from this config.

Parameters:

in_features (int, required): Input dimensionality (hidden dimension).
key (PRNGKeyArray, required): JAX random key for initialization.

Returns:

ClassificationHead: The classification head instance.


linax.heads.classification.ClassificationHead

Classification head.

This classification head takes an input of shape (timesteps, in_features) and outputs logits of shape (out_features,), or (timesteps, out_features) when reduce is False.

Parameters:

in_features (int, required): Input features.
out_features (int, required): Output features.
cfg (ConfigType, required): Configuration for the classification head.
key (PRNGKeyArray, required): JAX random key for initialization.

Attributes:

linear: Linear layer mapping in_features to out_features.

__init__(in_features: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)

Initialize the classification head.

__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]

Forward pass of the classification head.

The forward pass applies the linear layer to the input and returns the resulting logits together with the state.

Parameters:

x (Array, required): Input tensor.
state (State, required): Current state for stateful layers.

Returns:

tuple[Array, State]: Tuple containing the output tensor and the updated state. If reduce is True, the output tensor has shape (out_features,); if reduce is False, it has shape (timesteps, out_features).
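The two output shapes can be illustrated with a hypothetical reduce_logits helper. Mean pooling over the time axis is an assumption chosen for illustration; this page does not state which reduction linax applies or where reduce is set.

```python
import jax
import jax.numpy as jnp


def reduce_logits(per_step_logits: jax.Array, reduce: bool) -> jax.Array:
    """Hypothetical helper: optionally collapse (timesteps, out_features) logits.

    Mean pooling is an illustrative assumption, not linax's documented behavior.
    """
    if reduce:
        return jnp.mean(per_step_logits, axis=0)  # (out_features,)
    return per_step_logits  # (timesteps, out_features)


per_step = jnp.zeros((128, 10))
print(reduce_logits(per_step, reduce=True).shape)   # (10,)
print(reduce_logits(per_step, reduce=False).shape)  # (128, 10)
```

Because the head is affine, mean pooling before or after the linear layer gives the same logits (the mean of W x_t + b equals W times the mean of x_t, plus b), so for this particular reduction the order of the two steps does not matter.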