
Channel Mixer

discretax.channel_mixers.base.AbstractChannelMixer

Abstract base class for all channel mixers.

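This class defines the interface that every channel mixer implements: a constructor that takes the input dimensionality and a JAX random key, and a __call__ forward pass that maps an input array to an output array.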

Parameters:

in_features (int, required): Input dimensionality.
key (PRNGKeyArray, required): JAX random key used for initialization.
out_features (int | None, default None): Optional output dimensionality. If None, defaults to in_features.
*args (optional): Additional positional arguments for the channel mixer.
**kwargs (optional): Additional keyword arguments for the channel mixer.
__init__(in_features: int, key: PRNGKeyArray, *args, out_features: int | None = None, **kwargs)

Initialize the channel mixer.

__call__(x: Array) -> Array

Forward pass of the channel mixer.

Parameters:

x (Array, required): The input tensor to the channel mixer.

Returns:

Array: The output of the channel mixer.