SwiGLU Channel Mixer¤
discretax.channel_mixers.swi_glu.SwiGLU
¤
Swish Gated Linear Unit (SwiGLU) layer.
Adapted from https://huggingface.co/blog/sachithgunasekara/nanojaxgpt.
The architecture consists of three linear projections:

- `gate_proj`: projects the input to the intermediate dimension
- `up_proj`: projects the input to the intermediate dimension
- `down_proj`: projects the intermediate dimension back to the hidden dimension

The computation is: `down_proj(swish(gate_proj(x)) * up_proj(x))`
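The three-projection computation above can be sketched in plain JAX with explicit weight matrices (shapes and initialization here are illustrative only; the actual layer uses its own linear modules):

```python
import jax
import jax.numpy as jnp

def swiglu(x, w_gate, w_up, w_down):
    """down_proj(swish(gate_proj(x)) * up_proj(x)) with bias-free projections."""
    gate = x @ w_gate                 # (hidden,) -> (intermediate,)
    up = x @ w_up                     # (hidden,) -> (intermediate,)
    hidden = jax.nn.swish(gate) * up  # elementwise gating
    return hidden @ w_down            # (intermediate,) -> (hidden,)

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
hidden_dim, intermediate_dim = 8, 32  # expansion ratio of 4, for illustration
w_gate = jax.random.normal(k1, (hidden_dim, intermediate_dim))
w_up = jax.random.normal(k2, (hidden_dim, intermediate_dim))
w_down = jax.random.normal(k3, (intermediate_dim, hidden_dim))

x = jnp.ones(hidden_dim)
y = swiglu(x, w_gate, w_up, w_down)
print(y.shape)  # (8,)
```

Note that the gating is elementwise in the intermediate dimension, so `gate_proj` and `up_proj` must share the same output size, while `down_proj` returns to the hidden dimension.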
Attributes:

| Name | Type | Description |
|---|---|---|
| `gate_proj` | | Linear layer for the gate projection. |
| `up_proj` | | Linear layer for the up projection. |
| `down_proj` | | Linear layer for the down projection. |
__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int | None = None, intermediate_dim: int | None = None, use_bias: bool = False, hidden_ratio: int | float = 4, **kwargs) -> None
¤
Initialize the SwiGLU layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | The input dimensionality. | *required* |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | *required* |
| `out_features` | `int \| None` | Optional output dimensionality (unused, kept for compatibility). | `None` |
| `hidden_ratio` | `int \| float` | FFN expansion ratio used to compute the intermediate dimension when `intermediate_dim` is not given. | `4` |
| `intermediate_dim` | `int \| None` | Optional explicit intermediate size; when set, it takes precedence over `hidden_ratio`. | `None` |
| `use_bias` | `bool` | Whether to include a bias term in the linear layers. | `False` |
| `*args` | | Additional positional arguments (ignored). | |
| `**kwargs` | | Additional keyword arguments (ignored). | |
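Since the layer initializes three projections from a single `key`, construction typically involves key splitting, and the intermediate dimension is derived from `hidden_ratio` when not given explicitly. A minimal sketch of that bookkeeping (the plain product shown here is an assumption; the library's exact derivation rule, e.g. a SwiGLU-style 2/3 correction, may differ):

```python
import jax

# One key in, three projections to initialize: split it.
key = jax.random.PRNGKey(42)
k_gate, k_up, k_down = jax.random.split(key, 3)

in_features = 16
hidden_ratio = 4
# With no explicit intermediate_dim, derive it from the expansion ratio
# (illustrative convention, not necessarily discretax's exact formula).
intermediate_dim = hidden_ratio * in_features
print(intermediate_dim)  # 64
```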
__call__(x: Array) -> Array
¤
Forward pass of the SwiGLU layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | *required* |
Returns:

| Type | Description |
|---|---|
| `Array` | Output tensor after applying the SwiGLU transformation. |
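Given the single-`Array` signature, the forward pass plausibly maps one feature vector at a time (an assumption based on the signature above, following the usual Equinox convention); batched inputs would then be handled with `jax.vmap`. Sketched with a standalone re-implementation so the example is self-contained:

```python
import jax
import jax.numpy as jnp

def swiglu(x, w_gate, w_up, w_down):
    # SwiGLU: down_proj(swish(gate_proj(x)) * up_proj(x))
    return (jax.nn.swish(x @ w_gate) * (x @ w_up)) @ w_down

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
hidden_dim, inter_dim = 4, 16
params = (jax.random.normal(k1, (hidden_dim, inter_dim)),
          jax.random.normal(k2, (hidden_dim, inter_dim)),
          jax.random.normal(k3, (inter_dim, hidden_dim)))

batch = jnp.ones((5, hidden_dim))  # (batch, hidden)
out = jax.vmap(lambda v: swiglu(v, *params))(batch)
print(out.shape)  # (5, 4)
```

The output has the same trailing dimension as the input, since `down_proj` maps the intermediate dimension back to the hidden dimension.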