Channel Mixer
`discretax.channel_mixers.base.AbstractChannelMixer`
Abstract base class for all channel mixers.
It defines the interface (constructor and `__call__` forward pass) that every concrete channel mixer implements.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimensionality. | *required* |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | *required* |
| `out_features` | `int \| None` | Optional output dimensionality. If `None`, defaults to `in_features`. | `None` |
| `*args` | | Additional positional arguments for the channel mixer. | `()` |
| `**kwargs` | | Additional keyword arguments for the channel mixer. | `{}` |
`__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int | None = None, **kwargs)`
Initialize the channel mixer.
`__call__(x: Array) -> Array`
Forward pass of the channel mixer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | The input tensor to the channel mixer. | *required* |
Returns:

| Type | Description |
|---|---|
| `Array` | The output of the channel mixer. |