Blocks

discretax.blocks.base.AbstractBlock

Abstract base class for all blocks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `in_features` | `int` | Number of input features. | required |
| `sequence_mixer` | `AbstractSequenceMixer` | The sequence mixer instance for this block. | required |
| `channel_mixer` | `AbstractChannelMixer` | The channel mixer instance for this block. | required |
| `key` | `PRNGKeyArray` | JAX random key for initialization. | required |
| `*args` | | Additional positional arguments (ignored). | optional |
| `**kwargs` | | Additional keyword arguments (ignored). | optional |

__init__(in_features: int, key: PRNGKeyArray, *args, sequence_mixer: AbstractSequenceMixer, channel_mixer: AbstractChannelMixer, **kwargs)

Initialize the block.
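
In the signature above, `sequence_mixer` and `channel_mixer` come after `*args`, so Python treats them as keyword-only. The sketch below illustrates that calling convention with a hypothetical stand-in class, `SketchBlock`, which copies only the documented signature. The string placeholders for the key and the mixers are assumptions made so the example runs without JAX. It is not the library's implementation.

```python
class SketchBlock:
    """Hypothetical stand-in that copies only the documented signature."""

    def __init__(self, in_features, key, *args, sequence_mixer, channel_mixer, **kwargs):
        # Extra positional and keyword arguments are accepted and discarded.
        self.in_features = in_features
        self.sequence_mixer = sequence_mixer
        self.channel_mixer = channel_mixer


# Mixers passed by keyword: accepted. The unknown `dropout` kwarg is ignored.
block = SketchBlock(
    16,
    "key-placeholder",
    sequence_mixer="seq-mixer-placeholder",
    channel_mixer="chan-mixer-placeholder",
    dropout=0.1,
)

# Mixers passed positionally land in *args, so the keyword-only
# parameters are still missing and Python raises TypeError.
try:
    SketchBlock(16, "key-placeholder", "seq-mixer-placeholder", "chan-mixer-placeholder")
except TypeError as err:
    positional_error = str(err)
```

Under this convention, a forgotten keyword fails at construction time instead of being silently swallowed by `*args`.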

__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]

Forward pass of the block.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Array` | Input tensor. | required |
| `state` | `State` | Current state for stateful layers. | required |
| `key` | `PRNGKeyArray` | JAX random key for operations. | required |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Array, State]` | Tuple containing the output tensor and updated state. |