
Classification Head

discretax.heads.classification.ClassificationHead

Classification head.

This classification head takes an input of shape (timesteps, in_features) and outputs logits of shape (out_features,), or of shape (timesteps, out_features) when reduce is False.

Attributes:

linear: Linear layer.
reduce: Whether to reduce the time dimension by averaging.

__init__(in_features: int, out_features: int, key: PRNGKeyArray, *args, reduce: bool = True, **kwargs)

Initialize the classification head.

Parameters:

in_features (int, required): Number of input features.
out_features (int, required): Number of output features (number of classes).
key (PRNGKeyArray, required): JAX random key for initialization.
reduce (bool, default True): Whether to reduce the time dimension by averaging.
*args: Additional positional arguments (ignored).
**kwargs: Additional keyword arguments (ignored).

__call__(x: Array, state: eqx.nn.State, *, key: PRNGKeyArray | None = None) -> tuple[Array, eqx.nn.State]

Forward pass of the classification head.

This forward pass applies the linear layer to the input and returns the resulting logits together with the state. If reduce is True, the time dimension is averaged out of the logits.
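
The computation described above can be sketched in plain JAX. This is not the discretax implementation: `init_linear` and `head_forward` are hypothetical stand-ins, the state argument is left out, and the sketch assumes the average is taken after the linear layer. The page does not say which order the library uses; for a linear layer with a bias, averaging before or after gives the same logits up to floating-point error.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_features, out_features):
    """Hypothetical dense-layer initializer (not discretax's own)."""
    w_key, b_key = jax.random.split(key)
    bound = in_features ** -0.5
    weight = jax.random.uniform(
        w_key, (out_features, in_features), minval=-bound, maxval=bound
    )
    bias = jax.random.uniform(b_key, (out_features,), minval=-bound, maxval=bound)
    return weight, bias


def head_forward(params, x, reduce=True):
    """Map an input of shape (timesteps, in_features) to logits."""
    weight, bias = params
    logits = x @ weight.T + bias  # (timesteps, out_features)
    if reduce:
        # Assumption: average over the time axis after the linear layer.
        logits = jnp.mean(logits, axis=0)  # (out_features,)
    return logits


key = jax.random.PRNGKey(0)
params_key, data_key = jax.random.split(key)
params = init_linear(params_key, in_features=8, out_features=3)
x = jax.random.normal(data_key, (20, 8))  # 20 timesteps, 8 features

pooled = head_forward(params, x, reduce=True)
per_step = head_forward(params, x, reduce=False)
print(pooled.shape, per_step.shape)
```

With reduce=True the sketch yields one logit vector per sequence; with reduce=False it keeps one logit vector per timestep.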

Parameters:

x (Array, required): Input tensor of shape (timesteps, in_features).
state (State, required): Current state for stateful layers.
key (PRNGKeyArray | None, default None): JAX random key for stochastic operations (unused).

Returns:

tuple[Array, State]: Tuple containing the output tensor and updated state. If reduce is True, the output tensor is of shape (out_features,). If reduce is False, the output tensor is of shape (timesteps, out_features).
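
The shapes above have no batch dimension. A common JAX pattern for handling a batch of sequences is to map a per-sequence function with `jax.vmap`. The sketch below illustrates that pattern on a hypothetical stand-in, `head_forward`, with made-up weights. It is an assumption about usage, not documented discretax behaviour, and it omits the state that the real `__call__` takes and returns.

```python
import jax
import jax.numpy as jnp


def head_forward(weight, bias, x, reduce=True):
    """Hypothetical stand-in: (timesteps, in_features) -> logits."""
    logits = x @ weight.T + bias  # (timesteps, out_features)
    return jnp.mean(logits, axis=0) if reduce else logits


# Made-up parameters for illustration: 8 input features, 3 classes.
weight = jnp.full((3, 8), 0.1)
bias = jnp.zeros((3,))

# A batch of 4 sequences, each with 20 timesteps and 8 features.
batch = jax.random.normal(jax.random.PRNGKey(0), (4, 20, 8))

# vmap maps the per-sequence function over the leading batch axis.
batched_pooled = jax.vmap(lambda x: head_forward(weight, bias, x, reduce=True))(batch)
batched_steps = jax.vmap(lambda x: head_forward(weight, bias, x, reduce=False))(batch)
print(batched_pooled.shape, batched_steps.shape)
```

Each batched output has the documented per-sequence shape with the batch axis prepended.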