Channel Mixer

linax.channel_mixers.base.ChannelMixerConfig

Configuration for channel mixers.

build(in_features: int, out_features: int | None, key: PRNGKeyArray) -> ChannelMixer

Build a channel mixer from this config.

Parameters:

in_features (int, required): Input dimensionality.
out_features (int | None, required): Output dimensionality. Pass None to use in_features; the argument itself has no default and must be given.
key (PRNGKeyArray, required): JAX random key for initialization.

Returns:

ChannelMixer: The channel mixer instance.


linax.channel_mixers.base.ChannelMixer

Abstract base class for all channel mixers. It defines the shared interface: construction via __init__ from an input dimensionality, a config, and a JAX random key, and a forward pass via __call__.

Parameters:

in_features (int, required): Input dimensionality.
cfg (ConfigType, required): Configuration for the channel mixer.
key (PRNGKeyArray, required): JAX random key for initialization.
out_features (int | None, keyword-only, default None): Optional output dimensionality.

__init__(in_features: int, cfg: ConfigType, key: PRNGKeyArray, *, out_features: int | None = None)

Initialize the channel mixer.

__call__(x: Array) -> Array

Forward pass of the channel mixer.

Parameters:

x (Array, required): The input tensor to the channel mixer.

Returns:

Array: The output of the channel mixer.
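This page does not state whether __call__ expects a single feature vector or a whole sequence. Assuming a mixer maps one vector of shape (in_features,) to one output vector, it can be applied independently at every position of a (length, in_features) sequence with jax.vmap. MLPMixer below is a hypothetical example (a two-layer GELU MLP), not a linax class.

```python
import jax
import jax.numpy as jnp


class MLPMixer:
    """Hypothetical channel mixer: Linear -> GELU -> Linear on one feature vector."""

    def __init__(self, in_features: int, hidden: int, out_features: int, key: jax.Array):
        k1, k2 = jax.random.split(key)
        self.w1 = jax.random.normal(k1, (hidden, in_features)) / in_features**0.5
        self.w2 = jax.random.normal(k2, (out_features, hidden)) / hidden**0.5

    def __call__(self, x: jax.Array) -> jax.Array:
        # x: shape (in_features,) -> returns shape (out_features,).
        return self.w2 @ jax.nn.gelu(self.w1 @ x)


mixer = MLPMixer(in_features=8, hidden=16, out_features=8, key=jax.random.PRNGKey(0))
seq = jnp.ones((10, 8))  # 10 positions, 8 channels each

# vmap over axis 0 applies the same mixer at every position.
out = jax.vmap(mixer)(seq)
print(out.shape)  # (10, 8)
```

Because vmap maps the same function over axis 0, each position's output depends only on that position's features.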