MLP Channel Mixer
discretax.channel_mixers.mlp.MLPChannelMixer
MLP channel mixer.
This channel mixer applies a multi-layer perceptron (MLP) to the input tensor: a linear layer followed by a non-linearity.
Attributes:
| Name | Type | Description |
|---|---|---|
| `linear` | | Linear layer applied to the input. |
| `non_linearity` | | The non-linearity function used after the linear layer. |
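Read together, the two attributes suggest the following per-vector computation. This is a sketch under assumptions: we assume `linear` is an affine map with weight $W \in \mathbb{R}^{d_\text{out} \times d_\text{in}}$ and optional bias $b \in \mathbb{R}^{d_\text{out}}$, and that $\sigma$ denotes the activation named by `non_linearity` (GELU by default):

$$
y = \sigma\left(W x + b\right), \qquad x \in \mathbb{R}^{d_\text{in}},\; y \in \mathbb{R}^{d_\text{out}},
$$

where $d_\text{in}$ is `in_features`, $d_\text{out}$ is `out_features` (equal to $d_\text{in}$ when `out_features` is `None`), and the $b$ term is dropped when `use_bias=False`.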
__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int | None = None, non_linearity: activation = 'gelu', use_bias: bool = False, **kwargs)
Initialize the MLP channel mixer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | The input dimensionality. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `out_features` | `int \| None` | Optional output dimensionality. If `None`, defaults to `in_features`. | `None` |
| `non_linearity` | `Literal` | Name of the activation function to apply after the linear layer. | `'gelu'` |
| `use_bias` | `bool` | Whether to include a bias term in the linear layer. | `False` |
| `*args` | | Additional positional arguments (ignored). | `()` |
| `**kwargs` | | Additional keyword arguments (ignored). | `{}` |
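To illustrate how `in_features`, `key`, `out_features`, and `use_bias` interact, here is a minimal sketch in plain JAX. The function `init_mlp_channel_mixer` and its dict-of-arrays return value are hypothetical, and the uniform(-1/sqrt(in_features), 1/sqrt(in_features)) initialization is an assumption, not necessarily what discretax does internally. Only the `out_features` fallback and the optional bias follow the documented behaviour.

```python
import jax
import jax.numpy as jnp


def init_mlp_channel_mixer(in_features, key, out_features=None, use_bias=False):
    """Hypothetical initializer mirroring the documented arguments.

    The uniform(-1/sqrt(in), 1/sqrt(in)) scheme is an assumption for
    illustration; the real module may initialize differently.
    """
    # Documented behaviour: out_features falls back to in_features when None.
    if out_features is None:
        out_features = in_features
    # Split the key so weight and bias get independent randomness.
    w_key, b_key = jax.random.split(key)
    lim = 1.0 / jnp.sqrt(in_features)
    weight = jax.random.uniform(
        w_key, (out_features, in_features), minval=-lim, maxval=lim
    )
    # Documented behaviour: no bias term unless use_bias=True.
    bias = (
        jax.random.uniform(b_key, (out_features,), minval=-lim, maxval=lim)
        if use_bias
        else None
    )
    return {"weight": weight, "bias": bias}


params = init_mlp_channel_mixer(8, jax.random.PRNGKey(0))
print(params["weight"].shape, params["bias"])  # (8, 8) None
```

With the defaults, the layer is square (`out_features == in_features`) and bias-free, so the mixer preserves the channel dimension.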
__call__(x: Array) -> Array
Forward pass of the MLP channel mixer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Output tensor. |
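The following sketch shows what the forward pass plausibly computes: the linear layer first, then the activation. It is written in plain JAX rather than against the discretax class. The helper `mlp_channel_mixer_forward` and the `params` dict layout are hypothetical. We also assume `x` has shape `(..., in_features)` with the layer applied along the last axis; the actual module may instead take a single `(in_features,)` vector and rely on `jax.vmap` over the sequence.

```python
import jax
import jax.numpy as jnp


def mlp_channel_mixer_forward(params, x, non_linearity=jax.nn.gelu):
    """Hypothetical forward pass: non_linearity(x @ W.T + b) on the last axis.

    Assumes x has shape (..., in_features); this is an illustration, not the
    discretax implementation.
    """
    y = x @ params["weight"].T  # (..., out_features)
    if params["bias"] is not None:
        y = y + params["bias"]
    # The activation is applied after the linear layer, as documented.
    return non_linearity(y)


key = jax.random.PRNGKey(0)
params = {
    "weight": jax.random.normal(key, (4, 8)),  # out_features=4, in_features=8
    "bias": None,  # use_bias=False, the documented default
}
x = jnp.ones((16, 8))  # e.g. 16 time steps with 8 channels each
y = mlp_channel_mixer_forward(params, x)
print(y.shape)  # (16, 4)
```

Because the layer acts only on the channel axis, each time step is transformed independently: calling the function on `x[0]` gives the same result as `y[0]`.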