Regression Head

linax.heads.regression.RegressionHeadConfig

Configuration for the regression head.

Attributes:

| Name | Description |
| --- | --- |
| out_features | Output dimensionality (prediction dimension). |

build(in_features: int, key: PRNGKeyArray) -> RegressionHead

Build a RegressionHead from this config.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_features | int | Input dimensionality (hidden dimension). | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |

Returns:

| Type | Description |
| --- | --- |
| RegressionHead | The regression head instance. |
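The split between config and module can be sketched as below. This is a minimal illustration, not linax's implementation: `SketchRegressionHeadConfig` and `SketchRegressionHead` are hypothetical stand-ins, the head is reduced to a bare weight and bias, and the initialization is illustrative only. What it shows is that `out_features` is fixed when the config is written, while `in_features` and the PRNG key are supplied later by whatever calls `build`.

```python
import math
from dataclasses import dataclass

import jax


@dataclass
class SketchRegressionHead:
    """Hypothetical stand-in for RegressionHead: just one affine layer's parameters."""
    weight: jax.Array  # shape (out_features, in_features)
    bias: jax.Array    # shape (out_features,)


@dataclass(frozen=True)
class SketchRegressionHeadConfig:
    """Hypothetical stand-in for RegressionHeadConfig."""
    out_features: int  # prediction dimension, known when the config is written

    def build(self, in_features: int, key: jax.Array) -> SketchRegressionHead:
        # in_features (the backbone's hidden size) and the key are supplied at
        # build time instead of being stored in the config.
        wkey, bkey = jax.random.split(key)
        # Illustrative init only: scaled normal weights, small normal biases.
        weight = jax.random.normal(wkey, (self.out_features, in_features)) / math.sqrt(in_features)
        bias = 0.01 * jax.random.normal(bkey, (self.out_features,))
        return SketchRegressionHead(weight=weight, bias=bias)


cfg = SketchRegressionHeadConfig(out_features=3)
head = cfg.build(in_features=16, key=jax.random.PRNGKey(0))
print(head.weight.shape, head.bias.shape)  # (3, 16) (3,)
```

Keeping `in_features` out of the config lets one config be reused with backbones of different hidden sizes.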


linax.heads.regression.RegressionHead

Regression head.

This regression head takes an input of shape (timesteps, in_features) and produces a prediction of shape (out_features,).
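Because the head consumes a single unbatched sequence of shape (timesteps, in_features), a batch of sequences would typically be handled with `jax.vmap`. The sketch below uses a hypothetical functional stand-in, `regression_head_forward`, that applies an affine map per timestep and averages over time; it is not the linax API, and the mean-squared-error loss is just one possible regression objective.

```python
import math

import jax
import jax.numpy as jnp


def regression_head_forward(weight: jax.Array, bias: jax.Array, x: jax.Array) -> jax.Array:
    """Hypothetical stand-in: (timesteps, in_features) -> (out_features,)."""
    per_step = x @ weight.T + bias  # (timesteps, out_features)
    return per_step.mean(axis=0)    # average over the time axis


in_features, out_features, timesteps, batch = 16, 3, 10, 4
key_w, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
weight = jax.random.normal(key_w, (out_features, in_features)) / math.sqrt(in_features)
bias = jnp.zeros((out_features,))

xs = jax.random.normal(key_x, (batch, timesteps, in_features))  # batch of sequences
ys = jax.random.normal(key_y, (batch, out_features))            # regression targets

# Map over the leading batch axis of xs only; weight and bias are shared.
preds = jax.vmap(regression_head_forward, in_axes=(None, None, 0))(weight, bias, xs)
loss = jnp.mean((preds - ys) ** 2)  # mean squared error over batch and outputs
print(preds.shape, loss.shape)  # (4, 3) ()
```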

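Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_features | int | Input dimensionality (features per timestep). | required |
| out_features | int | Output dimensionality (prediction dimension). | required |
| cfg | ConfigType | Configuration for the regression head. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |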

Attributes:

| Name | Description |
| --- | --- |
| linear | Linear layer mapping in_features to out_features. |

__init__(in_features: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)

Initialize the regression head.
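The source only says the head holds a linear layer. Assuming that layer is an `eqx.nn.Linear` (an assumption, not confirmed here), its default initialization draws the weight and bias uniformly from [-1/sqrt(in_features), 1/sqrt(in_features)] using independent subkeys split from `key`. The hypothetical `init_linear` below reproduces that scheme in plain JAX to show how `in_features`, `out_features`, and `key` determine the parameters.

```python
import math

import jax
import jax.numpy as jnp


def init_linear(in_features: int, out_features: int, key: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical helper: uniform init in [-1/sqrt(in_features), 1/sqrt(in_features))."""
    wkey, bkey = jax.random.split(key)  # independent subkeys for weight and bias
    lim = 1.0 / math.sqrt(in_features)
    weight = jax.random.uniform(wkey, (out_features, in_features), minval=-lim, maxval=lim)
    bias = jax.random.uniform(bkey, (out_features,), minval=-lim, maxval=lim)
    return weight, bias


weight, bias = init_linear(in_features=16, out_features=3, key=jax.random.PRNGKey(42))
print(weight.shape, bias.shape)  # (3, 16) (3,)

# The same key always reproduces the same parameters.
weight_again, _ = init_linear(16, 3, jax.random.PRNGKey(42))
print(bool(jnp.array_equal(weight, weight_again)))  # True
```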

__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]

Forward pass of the regression head.

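This forward pass applies the linear layer to the input at every timestep and returns the mean of the outputs over the time axis. When reduce is False, the per-timestep outputs are returned without averaging.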
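A minimal sketch of the averaging forward pass described above, written with plain JAX functions rather than the actual module. `linear` and `head_forward` are hypothetical names, and `state` is simply passed through, which assumes the head itself has no stateful layers. `jax.vmap` applies the affine map to every timestep before the mean is taken over the time axis.

```python
import jax
import jax.numpy as jnp


def linear(weight: jax.Array, bias: jax.Array, v: jax.Array) -> jax.Array:
    """Affine map on one feature vector: (in_features,) -> (out_features,)."""
    return weight @ v + bias


def head_forward(weight, bias, x, state):
    """Hypothetical sketch: linear at every timestep, then mean over time."""
    per_step = jax.vmap(linear, in_axes=(None, None, 0))(weight, bias, x)  # (timesteps, out_features)
    return per_step.mean(axis=0), state  # state returned unchanged


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (3, 16))
bias = jnp.ones((3,))
x = jax.random.normal(key_x, (10, 16))  # (timesteps, in_features)

y, _ = head_forward(weight, bias, x, state=None)
print(y.shape)  # (3,)

# Averaging the inputs first should give the same prediction (see note below).
y_mean_first = linear(weight, bias, x.mean(axis=0))
print(bool(jnp.allclose(y, y_mean_first, rtol=1e-4, atol=1e-4)))
```

Since the layer is affine, mean_t(W x_t + b) = W mean_t(x_t) + b, so in the averaging case one matrix-vector product on the time-averaged input would suffice; the final check confirms the two agree up to floating-point error.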

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor of shape (timesteps, in_features). | required |
| state | State | Current state for stateful layers. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, State] | Tuple containing the output tensor and the updated state. If reduce is True, the output tensor has shape (out_features,); if reduce is False, it has shape (timesteps, out_features). |
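The source does not say where reduce is set (it is not a parameter of `__call__`). The sketch below treats it as a hypothetical keyword argument of a stand-in function, `head_forward`, purely to show the two output shapes listed above; it is not the linax signature.

```python
import jax.numpy as jnp


def head_forward(weight, bias, x, state, reduce: bool = True):
    """Hypothetical sketch of the two output modes.

    reduce=True  -> (out_features,)            one prediction per sequence
    reduce=False -> (timesteps, out_features)  one prediction per timestep
    """
    per_step = x @ weight.T + bias  # (timesteps, out_features)
    out = per_step.mean(axis=0) if reduce else per_step
    return out, state


weight = jnp.zeros((3, 16))  # zero weights: every output equals the bias
bias = jnp.arange(3.0)
x = jnp.ones((10, 16))       # (timesteps, in_features)

y_reduced, _ = head_forward(weight, bias, x, state=None, reduce=True)
y_full, _ = head_forward(weight, bias, x, state=None, reduce=False)
print(y_reduced.shape, y_full.shape)  # (3,) (10, 3)
```

Because reduce selects a Python-level branch that changes the output shape, under `jax.jit` it would have to be static (for example a static field on the module, or an argument listed in `static_argnames`) rather than a traced value.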