SwiGLU Channel Mixer¤

discretax.channel_mixers.swi_glu.SwiGLU ¤

Swish Gated Linear Unit (SwiGLU) layer.

Adapted from https://huggingface.co/blog/sachithgunasekara/nanojaxgpt.

The architecture consists of three linear projections:

- gate_proj: projects the input to the intermediate dimension
- up_proj: projects the input to the intermediate dimension
- down_proj: projects the intermediate dimension back to the hidden dimension

The computation is: down_proj(swish(gate_proj(x)) * up_proj(x))
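The computation above can be sketched with plain matrices and jax.nn.swish — a minimal illustration of the formula, not the actual class (the weight names and bias-free projections here are assumptions):

```python
import jax

def swiglu(x, w_gate, w_up, w_down):
    # down_proj(swish(gate_proj(x)) * up_proj(x)), bias-free projections
    return (jax.nn.swish(x @ w_gate) * (x @ w_up)) @ w_down

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
hidden, inter = 8, 16
x = jax.random.normal(k1, (hidden,))
w_gate = jax.random.normal(k2, (hidden, inter))   # gate projection
w_up = jax.random.normal(k3, (hidden, inter))     # up projection
w_down = jax.random.normal(k4, (inter, hidden))   # back to hidden dim
y = swiglu(x, w_gate, w_up, w_down)
```

Note that the output has the same dimensionality as the input; the expansion to the intermediate dimension is internal.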

Attributes:

| Name | Description |
| --- | --- |
| gate_proj | Linear layer for the gate projection. |
| up_proj | Linear layer for the up projection. |
| down_proj | Linear layer for the down projection. |

__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int | None = None, intermediate_dim: int | None = None, use_bias: bool = False, hidden_ratio: int | float = 4, **kwargs) -> None ¤

Initialize the SwiGLU layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_features | int | The input dimensionality. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
| out_features | int \| None | Optional output dimensionality (unused, kept for compatibility). | None |
| hidden_ratio | int \| float | FFN expansion ratio used to compute the intermediate dimension as in_features * hidden_ratio * 2/3, rounded up to a multiple of 256. Ignored when intermediate_dim is set explicitly. | 4 |
| intermediate_dim | int \| None | Optional explicit intermediate size. When set, hidden_ratio is ignored. | None |
| use_bias | bool | Whether to include a bias term in the linear layers. | False |
| *args | | Additional positional arguments (ignored). | () |
| **kwargs | | Additional keyword arguments (ignored). | {} |
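The hidden_ratio rule described above can be spelled out as a small helper — a hypothetical function mirroring the documented rounding, not part of the library's API:

```python
import math

def intermediate_dim(in_features: int, hidden_ratio: float = 4) -> int:
    # in_features * hidden_ratio * 2/3, rounded up to a multiple of 256
    raw = in_features * hidden_ratio * 2 / 3
    return 256 * math.ceil(raw / 256)

# in_features=768, hidden_ratio=4: 768 * 4 * 2/3 = 2048, already a multiple of 256
d1 = intermediate_dim(768)
# in_features=512, hidden_ratio=4: 512 * 4 * 2/3 ≈ 1365.3, rounded up to 1536
d2 = intermediate_dim(512)
```

The 2/3 factor keeps the parameter count of the three-projection SwiGLU comparable to a standard two-projection FFN with the same expansion ratio.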
__call__(x: Array) -> Array ¤

Forward pass of the SwiGLU layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Output tensor after applying the SwiGLU transformation. |
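Since the forward pass maps one feature vector to a vector of the same size, applying the mixer across a sequence is typically done with jax.vmap. A sketch using a hand-rolled stand-in for the layer (the weights and function here are illustrative assumptions, not the actual module):

```python
import jax

def swiglu(x, w_gate, w_up, w_down):
    # down_proj(swish(gate_proj(x)) * up_proj(x))
    return (jax.nn.swish(x @ w_gate) * (x @ w_up)) @ w_down

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
d, inter, seq = 8, 16, 5
w_gate = jax.random.normal(k1, (d, inter))
w_up = jax.random.normal(k2, (d, inter))
w_down = jax.random.normal(k3, (inter, d))
xs = jax.random.normal(k4, (seq, d))           # a sequence of token vectors
# map the per-token mixer over the sequence axis
ys = jax.vmap(lambda x: swiglu(x, w_gate, w_up, w_down))(xs)
```

Each token is mixed independently along the channel dimension, which is what makes vmap (rather than any sequence-aware machinery) sufficient here.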