GLU Channel Mixer
linax.channel_mixers.glu.GLUConfig
Configuration for the GLU channel mixer.
Attributes:
| Name | Type | Description |
|---|---|---|
| use_bias | | Whether to include a bias term in the linear layers. |
build(in_features: int, out_features: int | None, key: PRNGKeyArray) -> GLU
Build GLU from config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| out_features | int \| None | Optional output dimensionality. If None, defaults to in_features. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
Returns:
| Type | Description |
|---|---|
| GLU | The GLU instance. |
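For context, a minimal usage sketch of the config-and-build flow, assuming `GLUConfig` accepts its documented `use_bias` attribute as a keyword argument (the feature size and key below are illustrative):

```python
import jax

from linax.channel_mixers.glu import GLUConfig

# Assumed construction: use_bias is the only attribute documented here;
# any other GLUConfig fields are left at their defaults.
cfg = GLUConfig(use_bias=True)

key = jax.random.PRNGKey(0)

# Build a GLU mapping 256 features to 256 features; out_features=None
# falls back to in_features per the table above.
glu = cfg.build(in_features=256, out_features=None, key=key)
```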
linax.channel_mixers.glu.GLU
Gated Linear Unit (GLU) layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Dimensionality of the input features. | required |
| out_features | int \| None | Optional dimensionality of the output features (defaults to in_features). | None |
| key | PRNGKeyArray | JAX random key for initialization. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| w1 | | First linear layer. |
| w2 | | Second linear layer. |
Source: https://arxiv.org/pdf/2002.05202
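For reference, the GLU defined in the linked paper computes σ(xW + b) ⊗ (xV + c). The sketch below restates that formula in plain JAX; which of `w1`/`w2` plays the gate role in this layer is an assumption, not something stated by the documentation above:

```python
import jax
import jax.numpy as jnp


def glu_reference(x, W1, b1, W2, b2):
    # GLU(x, W, V, b, c) = sigmoid(x @ W + b) * (x @ V + c), per the paper.
    # Treating the first projection as the sigmoid gate is an assumption here.
    gate = jax.nn.sigmoid(x @ W1 + b1)
    value = x @ W2 + b2
    return gate * value
```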
__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray, *, out_features: int | None = None)
Initialize the GLU layer.
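A direct-construction sketch, assuming `ConfigType` is satisfied by the `GLUConfig` documented above (in practice `GLUConfig.build` covers the same ground):

```python
import jax

from linax.channel_mixers.glu import GLU, GLUConfig

key = jax.random.PRNGKey(0)

# Assumed equivalent to GLUConfig(use_bias=True).build(...); out_features is
# keyword-only and defaults to None (i.e. in_features).
layer = GLU(256, GLUConfig(use_bias=True), key, out_features=None)
```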
__call__(x: Array) -> Array
Forward pass of the GLU layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor after applying the gated linear transformation. |
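A call sketch, assuming the layer built earlier and a single feature vector as input; applying it across a batch with `jax.vmap` is a common pattern but an assumption here, not part of the documented signature:

```python
import jax
import jax.numpy as jnp

# Single example with 256 features, matching the in_features used above.
x = jnp.ones((256,))
y = glu(x)
print(y.shape)  # expected (256,) since out_features defaulted to in_features

# Assumed batching pattern: map the layer over the leading batch axis.
batch = jnp.ones((8, 256))
y_batch = jax.vmap(glu)(batch)
print(y_batch.shape)  # expected (8, 256)
```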