
Blocks

linax.blocks.base.BlockConfig

Configuration for blocks.

build(in_features: int, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray) -> Block

Build a block from this config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input features. | required |
| sequence_mixer | SequenceMixer | The sequence mixer instance for this block. | required |
| channel_mixer | ChannelMixer | The channel mixer instance for this block. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |

Returns:

| Type | Description |
|---|---|
| Block | The block instance. |
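The signature above separates hyperparameters (held by the config) from the mixer instances (passed in at build time). The sketch below illustrates that factory pattern with pure-Python stand-ins rather than linax itself: the `Block` class, the `residual` field, the `identity` mixer, and the use of a plain int in place of a `PRNGKeyArray` are all hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical stand-ins (not linax): a mixer is any callable, and the
# PRNG key is represented by a plain int seed.
Mixer = Callable[[Any], Any]


class Block:
    """Minimal stand-in for a block produced by a config."""

    def __init__(self, in_features: int, cfg: "BlockConfig",
                 sequence_mixer: Mixer, channel_mixer: Mixer, key: int):
        self.in_features = in_features
        self.cfg = cfg
        self.sequence_mixer = sequence_mixer
        self.channel_mixer = channel_mixer
        self.key = key


@dataclass(frozen=True)
class BlockConfig:
    """Holds hyperparameters only; mixers are supplied at build time."""

    residual: bool = True  # hypothetical hyperparameter

    def build(self, in_features: int, sequence_mixer: Mixer,
              channel_mixer: Mixer, key: int) -> Block:
        # The config passes itself along so the block can read its settings.
        return Block(in_features, self, sequence_mixer, channel_mixer, key)


def identity(x):
    """Placeholder mixer used only for this example."""
    return x


block = BlockConfig().build(4, identity, identity, key=0)
```

Keeping the config frozen makes it safe to reuse one config to build several blocks, each with its own mixers and key.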


linax.blocks.base.Block

Abstract base class for all blocks.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input features. | required |
| cfg | ConfigType | Configuration for the block. | required |
| sequence_mixer | SequenceMixer | The sequence mixer instance for this block. | required |
| channel_mixer | ChannelMixer | The channel mixer instance for this block. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |
__init__(in_features: int, cfg: ConfigType, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray)

Initialize the block.
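Because Block is an abstract base class, a rough mental model is: the constructor stores the config and mixers, and each concrete subclass supplies its own forward pass. The sketch below shows that contract with Python's `abc` module. It is an assumption that linax marks `__call__` as the abstract method, and `IdentityBlock` is a hypothetical subclass used only for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any


class Block(ABC):
    """Hypothetical abstract base mirroring the documented constructor."""

    def __init__(self, in_features: int, cfg: Any, sequence_mixer: Any,
                 channel_mixer: Any, key: Any):
        # Keep the pieces every subclass's forward pass relies on.
        self.in_features = in_features
        self.cfg = cfg
        self.sequence_mixer = sequence_mixer
        self.channel_mixer = channel_mixer

    @abstractmethod
    def __call__(self, x: Any, state: Any, key: Any) -> tuple[Any, Any]:
        """Return (output, updated_state)."""


class IdentityBlock(Block):
    """Concrete subclass: passes the input and state through unchanged."""

    def __call__(self, x, state, key):
        return x, state
```

Instantiating `Block` directly raises `TypeError`, while `IdentityBlock(2, None, None, None, 0)` works because it implements `__call__`.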

__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]

Forward pass of the block.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
| state | State | Current state for stateful layers. | required |
| key | PRNGKeyArray | JAX random key for any stochastic operations in the forward pass. | required |

Returns:

| Type | Description |
|---|---|
| tuple[Array, State] | Tuple containing the output tensor and updated state. |
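The forward pass takes the input, the current state, and a key, and returns the output together with the updated state. The sketch below shows one common way to compose the two mixers: a residual sequence-mixing step followed by a residual channel-mixing step. That layout is an assumption, since the source does not say how linax combines the mixers. Everything here is a pure-Python stand-in: sequences are lists of feature vectors, `causal_mean` and `double` are toy mixers, the state is a dict, and the key is an int that this toy block ignores (a JAX implementation would typically split it with `jax.random.split`).

```python
Seq = list[list[float]]  # shape: (timesteps, features)


def causal_mean(x: Seq) -> Seq:
    """Toy sequence mixer: running mean over time (hypothetical)."""
    out, total = [], [0.0] * len(x[0])
    for t, row in enumerate(x, start=1):
        total = [a + b for a, b in zip(total, row)]
        out.append([v / t for v in total])
    return out


def double(x: Seq) -> Seq:
    """Toy channel mixer: scales every feature by 2 (hypothetical)."""
    return [[2.0 * v for v in row] for row in x]


def add(a: Seq, b: Seq) -> Seq:
    """Elementwise sum of two sequences with the same shape."""
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]


class ResidualBlock:
    def __init__(self, sequence_mixer, channel_mixer):
        self.sequence_mixer = sequence_mixer
        self.channel_mixer = channel_mixer

    def __call__(self, x: Seq, state: dict, key: int) -> tuple[Seq, dict]:
        # Residual sequence-mixing step, then residual channel-mixing step.
        h = add(x, self.sequence_mixer(x))
        y = add(h, self.channel_mixer(h))
        # Build a new state rather than mutating the one passed in.
        new_state = {**state, "calls": state.get("calls", 0) + 1}
        return y, new_state


block = ResidualBlock(causal_mean, double)
state = {}
y, new_state = block([[1.0, 0.0], [3.0, 2.0]], state, key=0)
```

Returning a fresh state object mirrors the functional style of Equinox, where `State.set` returns a new state instead of modifying the existing one.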