Blocks
linax.blocks.base.BlockConfig
Configuration for blocks.
build(in_features: int, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray) -> Block
Build a block from this config.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input features. | required |
| sequence_mixer | SequenceMixer | The sequence mixer instance for this block. | required |
| channel_mixer | ChannelMixer | The channel mixer instance for this block. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |

Returns:

| Type | Description |
|---|---|
| Block | The block instance. |

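The `build` signature matches `Block.__init__` minus the `cfg` argument, which suggests a factory pattern: the config holds hyperparameters and passes itself as `cfg` when constructing the block. The sketch below illustrates that pattern under this assumption, using plain dataclasses so it runs without linax. `GatedBlockConfig`, `GatedBlock`, the `init_scale` field, and the `gate` weight are all hypothetical and are not part of the library.

```python
from dataclasses import dataclass

import jax


@dataclass(frozen=True)
class GatedBlockConfig:
    """Hypothetical config: hyperparameters only, no arrays."""

    init_scale: float = 0.02

    def build(self, in_features, sequence_mixer, channel_mixer, key):
        # Forward `self` as `cfg`, mirroring the documented argument order of
        # Block.__init__(in_features, cfg, sequence_mixer, channel_mixer, key).
        return GatedBlock(in_features, self, sequence_mixer, channel_mixer, key)


class GatedBlock:
    """Hypothetical block; a real one would subclass linax's Block."""

    def __init__(self, in_features, cfg, sequence_mixer, channel_mixer, key):
        self.cfg = cfg
        self.sequence_mixer = sequence_mixer
        self.channel_mixer = channel_mixer
        # The PRNG key initializes any weights the block owns itself.
        self.gate = cfg.init_scale * jax.random.normal(key, (in_features,))


cfg = GatedBlockConfig()
block = cfg.build(
    in_features=16,
    sequence_mixer=lambda x: x,  # placeholder mixers for the sketch
    channel_mixer=lambda x: x,
    key=jax.random.PRNGKey(0),
)
print(block.gate.shape)  # (16,)
```

The mixers are built outside the config and passed in already constructed, so the same config can be reused with different mixer instances.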
linax.blocks.base.Block
Abstract base class for all blocks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input features. | required |
| cfg | ConfigType | Configuration for the block. | required |
| sequence_mixer | SequenceMixer | The sequence mixer instance for this block. | required |
| channel_mixer | ChannelMixer | The channel mixer instance for this block. | required |
| key | PRNGKeyArray | JAX random key for initialization. | required |

__init__(in_features: int, cfg: ConfigType, sequence_mixer: SequenceMixer, channel_mixer: ChannelMixer, key: PRNGKeyArray)
Initialize the block.
__call__(x: Array, state: eqx.nn.State, key: PRNGKeyArray) -> tuple[Array, eqx.nn.State]
Forward pass of the block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input array. | required |
| state | State | Current state for stateful layers. | required |
| key | PRNGKeyArray | JAX random key for any stochastic operations in the forward pass. | required |

Returns:

| Type | Description |
|---|---|
| tuple[Array, State] | Tuple containing the output array and the updated state. |