Regression Head
linax.heads.regression.RegressionHeadConfig
Configuration for the regression head.
Attributes:
| Name | Type | Description |
|---|---|---|
| `out_features` | `int` | Output dimensionality (prediction dimension). |
build(in_features: int, key: PRNGKeyArray) -> RegressionHead
Build a RegressionHead from this config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality (hidden dimension). | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| `RegressionHead` | The regression head instance. |
linax.heads.regression.RegressionHead
Regression head.
This head takes an input of shape (timesteps, in_features) and produces a prediction of shape (out_features,), or (timesteps, out_features) when reduce is False.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Number of input features. | required |
| `out_features` | `int` | Number of output features. | required |
| `cfg` | `ConfigType` | Configuration for the regression head. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| `linear` | | Linear layer mapping `in_features` to `out_features`. |
__init__(in_features: int, out_features: int, cfg: ConfigType, key: PRNGKeyArray)
Initialize the regression head.
__call__(x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]
Forward pass of the regression head.
Applies the linear layer to each timestep of the input; if reduce is True, the result is then averaged over the time axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor of shape `(timesteps, in_features)`. | required |
| `state` | `State` | Current state for stateful layers. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, State]` | Tuple containing the output tensor and the updated state. If `reduce` is True, the output tensor has shape `(out_features,)`; if `reduce` is False, it has shape `(timesteps, out_features)`. |