SwiGLU Channel Mixer
linax.channel_mixers.swi_glu.SwiGLUConfig
Configuration for the SwiGLU channel mixer.
Attributes:

| Name | Type | Description |
|---|---|---|
| use_bias | | Whether to include a bias term in the linear layers. |
| hidden_ratio | | Ratio used to scale the hidden dimension when computing the intermediate size. |
| intermediate_dim | | Optional explicit intermediate size. |
build(in_features: int, out_features: int | None, key: PRNGKeyArray) -> SwiGLU
Build SwiGLU from config.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimensionality. | required |
| out_features | int \| None | Optional output dimensionality. If None, defaults to in_features. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
Returns:

| Type | Description |
|---|---|
| SwiGLU | The SwiGLU instance. |
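For illustration, a minimal sketch of building the layer from a config. It assumes SwiGLUConfig accepts its documented attributes as keyword arguments, which is not confirmed by this page; the hidden_ratio value is illustrative.

```python
import jax

from linax.channel_mixers.swi_glu import SwiGLUConfig

# Assumed keyword arguments mirroring the attributes documented above;
# hidden_ratio=4 is an illustrative value, not a library default.
cfg = SwiGLUConfig(use_bias=False, hidden_ratio=4, intermediate_dim=None)

key = jax.random.PRNGKey(0)
# out_features=None means the output dimensionality defaults to in_features.
mixer = cfg.build(in_features=512, out_features=None, key=key)
```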
linax.channel_mixers.swi_glu.SwiGLU
Swish Gated Linear Unit (SwiGLU) layer.
Adapted from https://huggingface.co/blog/sachithgunasekara/nanojaxgpt.

The architecture consists of three linear projections:

- gate_proj: Projects input to intermediate dimension
- up_proj: Projects input to intermediate dimension
- down_proj: Projects intermediate dimension back to hidden dimension

The computation is: down_proj(swish(gate_proj(x)) * up_proj(x))
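As a point of reference, the computation above can be written in plain JAX roughly as follows. This is an illustrative sketch with explicit weight matrices and no biases, not the layer's actual implementation; jax.nn.silu is the swish activation.

```python
import jax

def swiglu_reference(x, w_gate, w_up, w_down):
    """Reference sketch: down_proj(swish(gate_proj(x)) * up_proj(x)), biases omitted.

    w_gate, w_up: shape (hidden_dim, intermediate_dim)
    w_down:       shape (intermediate_dim, hidden_dim)
    """
    gate = jax.nn.silu(x @ w_gate)  # swish(gate_proj(x))
    up = x @ w_up                   # up_proj(x)
    return (gate * up) @ w_down     # down_proj of the gated intermediate
```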
Attributes:

| Name | Type | Description |
|---|---|---|
| gate_proj | | Linear layer for the gate projection. |
| up_proj | | Linear layer for the up projection. |
| down_proj | | Linear layer for the down projection. |
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | The input dimensionality. | required |
| cfg | ConfigType | Configuration for the SwiGLU channel mixer. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| out_features | int \| None | Optional output dimensionality. If None, defaults to in_features. | None |
__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray, *, out_features: int | None = None) -> None
__call__(x: Array) -> Array
Forward pass of the SwiGLU layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
Returns:

| Type | Description |
|---|---|
| Array | Output tensor after applying the SwiGLU transformation. |
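A hedged end-to-end sketch of constructing the layer directly and applying it. The SwiGLUConfig keyword arguments, the dimension values, and the single-vector call convention are assumptions, not confirmed by this page.

```python
import jax

from linax.channel_mixers.swi_glu import SwiGLU, SwiGLUConfig

key = jax.random.PRNGKey(0)
init_key, data_key = jax.random.split(key)

# Assumed keyword arguments; see the SwiGLUConfig attributes above.
cfg = SwiGLUConfig(use_bias=True, hidden_ratio=4, intermediate_dim=None)

# out_features is left at None, so the output dimensionality equals in_features.
layer = SwiGLU(in_features=256, cfg=cfg, key=init_key)

# Single feature vector; batched inputs may require jax.vmap depending on
# the layer's conventions.
x = jax.random.normal(data_key, (256,))
y = layer(x)  # down_proj(swish(gate_proj(x)) * up_proj(x))
```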