Layers API Reference

All layers are importable directly from idx_flow. Internally they live in separate modules (conv, mlp, norm, regularization, attention, vit, pooling, functional).

Convolution Layers

SpatialConv

class SpatialConv

Bases: Module

Spatial Convolution layer for downsampling on HEALPix grids.

This layer performs convolution on spherical data discretized using the HEALPix tessellation scheme. It uses precomputed connection indices to gather features from neighboring pixels and applies learnable kernels for spatial feature transformation.

The operation follows:
  1. Gather: Collect features from k neighbors for each output point

  2. Transform: Apply learnable kernel weights

  3. Aggregate: Sum weighted contributions with bias

Mathematically:

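Y[b,p,f] = sum_k sum_c X[b, idx[p,k], c] * W[k,c,f] + bias[f]

If kernel_weights w is given, each gathered neighbor is scaled first, i.e. X[b, idx[p,k], c] is replaced by w[p,k] * X[b, idx[p,k], c].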
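The gather, transform and aggregate steps can be sketched in plain PyTorch. This is a minimal illustration, not the idx_flow implementation: it assumes the indices are a LongTensor of shape [N_out, k] and the kernel has shape [k, C_in, filters], and the function name spatial_conv_reference is hypothetical.

```python
import torch


def spatial_conv_reference(x, idx, weight, bias=None, kernel_weights=None):
    """Hypothetical reference for Y[b,p,f] = sum_k sum_c X[b, idx[p,k], c] * W[k,c,f] + bias[f].

    x:              [B, N_in, C_in]
    idx:            [N_out, k] integer indices into the N_in axis
    weight:         [k, C_in, F]
    bias:           [F] or None
    kernel_weights: [N_out, k] or None (scales each gathered neighbor)
    """
    # Gather: advanced indexing on the point axis -> [B, N_out, k, C_in]
    neighbors = x[:, idx, :]
    if kernel_weights is not None:
        # Scale each neighbor by its per-connection weight
        neighbors = neighbors * kernel_weights[None, :, :, None]
    # Transform + aggregate: contract over k and C_in -> [B, N_out, F]
    y = torch.einsum("bpkc,kcf->bpf", neighbors, weight)
    if bias is not None:
        y = y + bias
    return y


# Toy problem: 10 input points, 4 output points, 3 neighbors each
torch.manual_seed(0)
x = torch.randn(2, 10, 5)
idx = torch.randint(0, 10, (4, 3))
w = torch.randn(3, 5, 7)
b = torch.randn(7)
y = spatial_conv_reference(x, idx, w, b)
print(y.shape)  # torch.Size([2, 4, 7])
```

The einsum performs the double sum over k and c in one call, which keeps the cost linear in the number of output points.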

Parameters:
  • output_points – Number of spatial points in the output tensor.

  • connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels that connect to each output pixel.

  • kernel_weights – Optional float array of shape [output_points, kernel_size] containing distance-based weights for each connection. If provided, neighbor features are scaled by these weights before convolution.

  • filters – Number of output feature channels (filters).

  • bias – Whether to include a bias term. Default is True.

  • weight_init – Weight initialization method. Default is “xavier_uniform”.

  • weight_init_gain – Gain for xavier/orthogonal initialization. Default is 1.0.

  • bias_init – Bias initialization value. Default is 0.0.

kernel

Learnable weight tensor of shape [kernel_size, in_channels, filters].

bias_param

Learnable bias tensor of shape [filters] if bias=True.

Shape:
  • Input: [B, N_in, C_in] where B is batch size, N_in is input points, C_in is input channels.

  • Output: [B, N_out, filters] where N_out is output_points.

Example

>>> import torch
>>> from idx_flow import SpatialConv
>>> from idx_flow.utils import compute_connection_indices
>>> indices, distances = compute_connection_indices(
...     nside_in=64, nside_out=32, k=4
... )
>>> conv = SpatialConv(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     filters=64,
...     weight_init="kaiming_normal"
... )
>>> x = torch.randn(8, 12 * 64**2, 32)
>>> y = conv(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])

Notes

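  • Connection indices must be precomputed, for example with idx_flow.utils.compute_connection_indices (as in the example above) or hp_distance.

  • The cost per forward pass is linear in the number of output points: O(N_out * k * C_in * filters).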

  • Input channels are inferred from the first forward pass.
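Because the kernel shape depends on C_in, the parameters only exist after the first forward pass. With lazily initialized PyTorch modules this usually means running one dummy forward before building the optimizer. The sketch below uses torch.nn.LazyLinear as a stand-in; that SpatialConv follows the same pattern is an assumption based on the note above.

```python
import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter

# Stand-in (assumption) for a layer whose input channels are inferred on first use
layer = nn.LazyLinear(64)
print(isinstance(layer.weight, UninitializedParameter))  # True

# A dummy forward with the real input shape materializes the parameters
x = torch.randn(2, 100, 32)
layer(x)
print(layer.weight.shape)  # torch.Size([64, 32])

# Only now do the parameters have concrete shapes to hand to an optimizer
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
```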

__init__(output_points, connection_indices, kernel_weights=None, filters=32, bias=True, weight_init='xavier_uniform', weight_init_gain=1.0, bias_init=0.0)

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:
  • output_points (int)

  • connection_indices (ndarray[tuple[int, ...], dtype[int64]])

  • kernel_weights (ndarray[tuple[int, ...], dtype[float64]] | None)

  • filters (int)

  • bias (bool)

  • weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])

  • weight_init_gain (float)

  • bias_init (float)

Return type:

None

forward(x)

Forward pass of the spatial convolution.

Parameters:

x (Tensor) – Input tensor of shape [B, N_in, C_in].

Returns:

Output tensor of shape [B, N_out, filters].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

SpatialTransposeConv

class SpatialTransposeConv

Bases: Module

Spatial Transpose Convolution layer for upsampling on HEALPix grids.

This layer performs a transposed convolution (sometimes called deconvolution) to increase the spatial resolution of HEALPix data. It maps features from a lower-resolution grid to a higher-resolution grid using precomputed connection indices.

Parameters:
  • output_points – Number of spatial points in the output (higher resolution).

  • connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output pixel.

  • kernel_weights – Optional float array of shape [output_points, kernel_size] containing distance-based weights for each connection.

  • filters – Number of output feature channels.

  • bias – Whether to include a bias term. Default is True.

  • weight_init – Weight initialization method. Default is “xavier_uniform”.

  • weight_init_gain – Gain for xavier/orthogonal initialization. Default is 1.0.

  • bias_init – Bias initialization value. Default is 0.0.

Shape:
  • Input: [B, N_in, C_in] where N_in is the lower resolution.

  • Output: [B, N_out, filters] where N_out is output_points (higher res).

Example

>>> import torch
>>> from idx_flow import SpatialTransposeConv
>>> from idx_flow.utils import compute_connection_indices
>>> indices, distances, weights = compute_connection_indices(
...     nside_in=32, nside_out=64, k=4, return_weights=True
... )
>>> transpose_conv = SpatialTransposeConv(
...     output_points=12 * 64**2,
...     connection_indices=indices,
...     kernel_weights=weights,
...     filters=32,
...     weight_init="orthogonal"
... )
>>> x = torch.randn(8, 12 * 32**2, 64)
>>> y = transpose_conv(x)
>>> print(y.shape)  # torch.Size([8, 49152, 32])

__init__(output_points, connection_indices, kernel_weights=None, filters=32, bias=True, weight_init='xavier_uniform', weight_init_gain=1.0, bias_init=0.0)

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:
  • output_points (int)

  • connection_indices (ndarray[tuple[int, ...], dtype[int64]])

  • kernel_weights (ndarray[tuple[int, ...], dtype[float64]] | None)

  • filters (int)

  • bias (bool)

  • weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])

  • weight_init_gain (float)

  • bias_init (float)

Return type:

None

forward(x)

Forward pass of the spatial transpose convolution.

Parameters:

x (Tensor) – Input tensor of shape [B, N_in, C_in].

Returns:

Output tensor of shape [B, N_out, filters].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

SpatialUpsampling

class SpatialUpsampling

Bases: Module

Spatial Upsampling layer using distance-based interpolation.

This layer performs upsampling on HEALPix grids using precomputed interpolation weights based on geodesic distances. Unlike SpatialTransposeConv, this layer does not have learnable parameters and performs pure distance-weighted interpolation.

Interpolation methods:
  • “linear”: Weight = max(0, 1 - distance / kernel_radius)

  • “idw”: Inverse distance weighting, Weight = 1 / (distance^2 + eps)

  • “gaussian”: Weight = exp(-0.5 * (distance / kernel_radius)^2)
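The three weighting schemes can be written directly from the formulas above. This NumPy sketch additionally normalizes the weights of each output point to sum to one and treats the default kernel_radius as the global maximum distance; both are assumptions, since the text does not state either detail. The helper names are hypothetical.

```python
import numpy as np


def interpolation_weights(distances, method="linear", kernel_radius=None, eps=1e-8):
    """Hypothetical helper: per-neighbor weights for one of the three schemes.

    distances: [N_out, k] geodesic distances to each neighbor.
    Returns weights of the same shape, normalized per output point
    (normalization is an assumption, not documented behavior).
    """
    d = np.asarray(distances, dtype=np.float64)
    if kernel_radius is None:
        kernel_radius = d.max()  # assumption: global maximum distance
    if method == "linear":
        w = np.maximum(0.0, 1.0 - d / kernel_radius)
    elif method == "idw":
        w = 1.0 / (d**2 + eps)
    elif method == "gaussian":
        w = np.exp(-0.5 * (d / kernel_radius) ** 2)
    else:
        raise ValueError(f"unknown method: {method}")
    # Normalize so each row sums to 1 (guard against all-zero rows)
    return w / np.maximum(w.sum(axis=1, keepdims=True), eps)


def upsample(x, indices, weights):
    """x: [B, N_in, C] -> [B, N_out, C] as a weighted sum of gathered neighbors."""
    return np.einsum("bpkc,pk->bpc", x[:, indices, :], weights)


dist = np.array([[0.0, 0.5], [0.25, 0.75]])
w = interpolation_weights(dist, "linear", kernel_radius=1.0)
print(w)  # weights [[2/3, 1/3], [3/4, 1/4]]
```

Since no weight is learned, the same precomputed array can be reused on every forward pass, which matches the note below that the weights are stored as buffers.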

Parameters:
  • output_points – Number of spatial points in the output (higher resolution).

  • connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output pixel.

  • distances – Float array of shape [output_points, kernel_size] containing geodesic distances to each neighbor.

  • interpolation – Interpolation method. One of “linear”, “idw”, “gaussian”.

  • kernel_radius – Radius for interpolation kernel. If None, uses the maximum distance. Only used for “linear” and “gaussian” methods.

Shape:
  • Input: [B, N_in, C_in]

  • Output: [B, N_out, C_in] (channels preserved)

Example

>>> import torch
>>> from idx_flow import SpatialUpsampling
>>> from idx_flow.utils import compute_connection_indices
>>> indices, distances = compute_connection_indices(
...     nside_in=32, nside_out=64, k=4
... )
>>> upsample = SpatialUpsampling(
...     output_points=12 * 64**2,
...     connection_indices=indices,
...     distances=distances,
...     interpolation="idw"
... )
>>> x = torch.randn(8, 12 * 32**2, 32)
>>> y = upsample(x)
>>> print(y.shape)  # torch.Size([8, 49152, 32])

Notes

  • This layer has no learnable parameters.

  • Output channels equal input channels (no feature transformation).

  • Weights are precomputed and stored as buffers for efficiency.

__init__(output_points, connection_indices, distances, interpolation='linear', kernel_radius=None)

Construct the layer. See the class description above for the meaning of each parameter.

Return type:

None

forward(x)

Forward pass of spatial upsampling.

Parameters:

x (Tensor) – Input tensor of shape [B, N_in, C_in].

Returns:

Output tensor of shape [B, N_out, C_in].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

MLP Layers

SpatialMLP

class SpatialMLP

Bases: Module

Spatial Multi-Layer Perceptron for local non-linear processing on HEALPix grids.

This layer replaces the linear convolution kernel (a single matrix multiplication) with a shared multi-layer perceptron (MLP) applied to the flattened neighborhood vector. The design is inspired by Geometric Deep Learning, which extends neural network operations to non-Euclidean domains such as manifolds and graphs.

Architectural Difference from SpatialConv:
  • SpatialConv: Applies a linear transformation to neighborhood features. The output is a weighted sum: Y = sum_k(W_k * X_k) + b, which can only capture linear relationships between neighbors.

  • SpatialMLP: Concatenates neighborhood features into a single vector and processes it through a multi-layer perceptron with non-linear activations: Y = MLP([X_1 || X_2 || … || X_k]). This allows the model to learn complex, non-linear interactions between neighbors.

Trade-offs:
Benefits:
  • Significantly higher representation capacity and expressivity

  • Can approximate arbitrary non-linear functions over local patches

  • Better suited for learning complex spatial patterns on manifolds

  • Supports dropout, batch normalization, and residual connections

Costs:
  • Higher computational cost (FLOPs) due to dense layer operations

  • Increased memory usage for storing MLP parameters

  • May require more training data to avoid overfitting

The operation:
  1. Gather k neighbor features for each output point: [B, N_out, k, C_in]

  2. Flatten the neighbor features: [B, N_out, k * C_in]

  3. Process through shared MLP layers with specified activations

  4. Output: [B, N_out, hidden_units[-1]]
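The four steps above map onto a gather, a flatten and an ordinary nn.Sequential applied to the last axis. This is a minimal sketch, assuming the indices are a LongTensor of shape [N_out, k] and GELU between hidden layers; the class name SpatialMLPSketch is hypothetical, and it omits the dropout, batch-norm and residual options of SpatialMLP.

```python
import torch
from torch import nn


class SpatialMLPSketch(nn.Module):
    """Hypothetical gather -> flatten -> shared MLP layer (not the idx_flow class)."""

    def __init__(self, idx, in_channels, hidden_units=(64, 64, 32)):
        super().__init__()
        self.register_buffer("idx", torch.as_tensor(idx, dtype=torch.long))
        k = self.idx.shape[1]
        layers, width = [], k * in_channels
        for i, units in enumerate(hidden_units):
            layers.append(nn.Linear(width, units))
            if i < len(hidden_units) - 1:
                layers.append(nn.GELU())  # last layer left linear
            width = units
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        neighbors = x[:, self.idx, :]   # [B, N_out, k, C_in]
        flat = neighbors.flatten(2)     # [B, N_out, k * C_in], i.e. [X_1 || ... || X_k]
        return self.mlp(flat)           # [B, N_out, hidden_units[-1]]


idx = torch.randint(0, 48, (12, 4))
layer = SpatialMLPSketch(idx, in_channels=16)
y = layer(torch.randn(8, 48, 16))
print(y.shape)  # torch.Size([8, 12, 32])
```

The first Linear layer sees k * C_in inputs, which is where the extra parameter and FLOP cost relative to SpatialConv comes from.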

Literature Context:

This approach draws from Geometric Deep Learning research on extending neural networks to non-Euclidean domains:

  • Bronstein, M. M., Bruna, J., LeCun, Y., Szlam, A., & Vandergheynst, P. (2017). “Geometric deep learning: Going beyond Euclidean data.” IEEE Signal Processing Magazine, 34(4), 18-42. DOI: 10.1109/MSP.2017.2693418

  • Masci, J., Boscaini, D., Bronstein, M. M., & Vandergheynst, P. (2015). “Geodesic convolutional neural networks on Riemannian manifolds.” In Proceedings of the IEEE ICCV Workshops, pp. 37-45. DOI: 10.1109/ICCVW.2015.112

Parameters:
  • output_points – Number of spatial points in the output.

  • connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output pixel.

  • hidden_units – List of hidden layer dimensions. The last value determines the output feature dimension.

  • activations – List of activation function names, one per hidden layer. Options: “relu”, “selu”, “leaky_relu”, “gelu”, “elu”, “tanh”, “sigmoid”, “swish”, “mish”, “linear”.

  • dropout – Dropout probability applied after each layer (except last). Default is 0.0 (no dropout).

  • use_batch_norm – Whether to apply batch normalization after each layer. Default is False.

  • residual – Whether to add residual connection if input/output dims match. Default is False.

  • weight_init – Weight initialization method. Default is “xavier_uniform”.

Shape:
  • Input: [B, N_in, C_in]

  • Output: [B, N_out, hidden_units[-1]]

Example

>>> import torch
>>> from idx_flow import SpatialMLP
>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=64, nside_out=32, k=4
... )
>>> mlp = SpatialMLP(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     hidden_units=[64, 64, 32],
...     activations=["gelu", "gelu", "linear"],
...     dropout=0.1,
...     use_batch_norm=True
... )
>>> x = torch.randn(8, 12 * 64**2, 16)
>>> y = mlp(x)
>>> print(y.shape)  # torch.Size([8, 12288, 32])

See also

  • SpatialConv: Linear convolution with lower computational cost

  • GlobalMLP: Channel-wise MLP without spatial neighbor gathering

__init__(output_points, connection_indices, hidden_units=(32, 32, 32), activations=('linear', 'linear', 'linear'), dropout=0.0, use_batch_norm=False, residual=False, weight_init='xavier_uniform')

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:
  • output_points (int)

  • connection_indices (ndarray[tuple[int, ...], dtype[int64]])

  • hidden_units (Sequence[int])

  • activations (Sequence[Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'] | None])

  • dropout (float)

  • use_batch_norm (bool)

  • residual (bool)

  • weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])

Return type:

None

forward(x)

Forward pass of the spatial MLP.

Parameters:

x (Tensor) – Input tensor of shape [B, N_in, C_in].

Returns:

Output tensor of shape [B, N_out, hidden_units[-1]].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

GlobalMLP

class GlobalMLP

Bases: Module

Global MLP for channel-wise transformations on spatial data.

This layer applies a shared MLP to each spatial point independently, transforming the feature channels without spatial mixing. Useful for pointwise feature transformations in encoder-decoder architectures.

The operation applies the same MLP to each of the N spatial points:

Y[b,n,:] = MLP(X[b,n,:]) for all n
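Because nn.Linear acts on the last dimension, a plain nn.Sequential already applies the same weights to every point, which is what the formula above expresses. This sketch checks that equivalence numerically; it illustrates the math and is not the GlobalMLP implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Shared MLP: 32 -> 64 -> 64, like hidden_units=[64, 64] on 32 input channels
mlp = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 64))

x = torch.randn(2, 100, 32)  # [B, N, C_in]
y_batched = mlp(x)           # one call over all points

# Same result point by point: Y[b, n, :] = MLP(X[b, n, :])
y_loop = torch.stack(
    [torch.stack([mlp(x[b, n]) for n in range(x.shape[1])]) for b in range(x.shape[0])]
)
print(torch.allclose(y_batched, y_loop, atol=1e-5))  # True
```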

Parameters:
  • hidden_units – List of hidden layer dimensions. The last value determines the output feature dimension.

  • activations – List of activation function names, one per hidden layer.

  • dropout – Dropout probability applied after each layer (except last).

  • use_batch_norm – Whether to apply batch normalization.

  • residual – Whether to add residual connection if input/output dims match.

  • weight_init – Weight initialization method.

Shape:
  • Input: [B, N, C_in]

  • Output: [B, N, hidden_units[-1]]

Example

>>> import torch
>>> from idx_flow import GlobalMLP
>>> mlp = GlobalMLP(
...     hidden_units=[64, 128, 64],
...     activations=["gelu", "gelu", "linear"],
...     dropout=0.1,
...     residual=True
... )
>>> x = torch.randn(8, 12288, 32)
>>> # The first call builds the layers from the input channel count (32).
>>> # No residual is added here: 32 input channels != hidden_units[-1] (64).
>>> y = mlp(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(hidden_units=(64, 64), activations=('gelu', 'linear'), dropout=0.0, use_batch_norm=False, residual=False, weight_init='xavier_uniform')

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:
  • hidden_units (Sequence[int])

  • activations (Sequence[Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'] | None])

  • dropout (float)

  • use_batch_norm (bool)

  • residual (bool)

  • weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])

Return type:

None

forward(x)

Forward pass of the global MLP.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C_in].

Returns:

Output tensor of shape [B, N, hidden_units[-1]].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

Normalization Layers

SpatialBatchNorm

class SpatialBatchNorm

Bases: Module

Batch Normalization for spatial data on HEALPix grids.

Applies batch normalization across the spatial and batch dimensions, normalizing per channel.

Parameters:
  • num_features – Number of feature channels.

  • eps – Small constant for numerical stability. Default: 1e-5.

  • momentum – Momentum for running statistics. Default: 0.1.

  • affine – Whether to include learnable affine parameters. Default: True.

Shape:
  • Input: [B, N, C]

  • Output: [B, N, C]

Example

>>> import torch
>>> from idx_flow import SpatialBatchNorm
>>> bn = SpatialBatchNorm(num_features=64)
>>> x = torch.randn(8, 12288, 64)
>>> y = bn(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(num_features, eps=1e-05, momentum=0.1, affine=True)

Construct the layer. See the class description above for the meaning of each parameter.

Return type:

None

forward(x)

Forward pass of spatial batch normalization.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C].

Returns:

Output tensor of shape [B, N, C].

Return type:

Tensor

SpatialLayerNorm

class SpatialLayerNorm

Bases: Module

Layer Normalization for spatial data on HEALPix grids.

Applies layer normalization across the feature dimension for each spatial point independently. Unlike BatchNorm, this normalizes across features rather than across the batch.

Parameters:
  • num_features – Number of feature channels.

  • eps – Small constant for numerical stability. Default: 1e-6.

  • elementwise_affine – Whether to include learnable affine parameters. Default: True.

Shape:
  • Input: [B, N, C]

  • Output: [B, N, C]

Example

>>> import torch
>>> from idx_flow import SpatialLayerNorm
>>> ln = SpatialLayerNorm(num_features=64)
>>> x = torch.randn(8, 12288, 64)
>>> y = ln(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(num_features, eps=1e-06, elementwise_affine=True)

Construct the layer. See the class description above for the meaning of each parameter.

Return type:

None

forward(x)

Forward pass of spatial layer normalization.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C].

Returns:

Output tensor of shape [B, N, C].

Return type:

Tensor

SpatialInstanceNorm

class SpatialInstanceNorm

Bases: Module

Instance Normalization for spatial data on HEALPix grids.

Applies instance normalization over the spatial dimension, separately for each sample and each channel. Useful for style transfer and generative models.

Parameters:
  • num_features – Number of feature channels.

  • eps – Small constant for numerical stability. Default: 1e-5.

  • momentum – Momentum for running statistics. Default: 0.1.

  • affine – Whether to include learnable affine parameters. Default: False.

Shape:
  • Input: [B, N, C]

  • Output: [B, N, C]

Example

>>> import torch
>>> from idx_flow import SpatialInstanceNorm
>>> instnorm = SpatialInstanceNorm(num_features=64, affine=True)
>>> x = torch.randn(8, 12288, 64)
>>> y = instnorm(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(num_features, eps=1e-05, momentum=0.1, affine=False)

Construct the layer. See the class description above for the meaning of each parameter.

Return type:

None

forward(x)

Forward pass of spatial instance normalization.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C].

Returns:

Output tensor of shape [B, N, C].

Return type:

Tensor

SpatialGroupNorm

class SpatialGroupNorm

Bases: Module

Group Normalization for spatial data on HEALPix grids.

Divides channels into groups and normalizes within each group. Provides a middle ground between LayerNorm and InstanceNorm.
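The four normalization layers differ mainly in which axes of a [B, N, C] tensor the mean and variance are computed over. The sketch below writes this out with plain tensor reductions (no affine parameters or running statistics) and cross-checks it against PyTorch's functional ops. That the idx_flow layers reduce over exactly these axes is an assumption inferred from the descriptions in this section.

```python
import torch
import torch.nn.functional as F


def normalize(x, dims, eps=1e-5):
    """Subtract the mean and divide by the std computed over `dims` (biased variance)."""
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def group_norm(x, num_groups, eps=1e-5):
    """Split C into contiguous groups; normalize each group over its points and channels."""
    b, n, c = x.shape
    g = x.reshape(b, n, num_groups, c // num_groups)
    return normalize(g, dims=(1, 3), eps=eps).reshape(b, n, c)


torch.manual_seed(0)
x = torch.randn(4, 96, 16)                 # [B, N, C]
batch_norm = normalize(x, dims=(0, 1))     # per channel, over batch and points
layer_norm = normalize(x, dims=(2,))       # per point, over channels
instance_norm = normalize(x, dims=(1,))    # per sample and channel, over points
grouped = group_norm(x, num_groups=4)      # per sample and group

# Cross-check batch norm against the built-in op, which expects [B, C, N]
ref_bn = F.batch_norm(x.transpose(1, 2), None, None, training=True).transpose(1, 2)
print(torch.allclose(batch_norm, ref_bn, atol=1e-4))  # True
```

With num_groups equal to the channel count, group_norm reduces over points only (instance normalization); with a single group it reduces over all points and channels of a sample.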

Parameters:
  • num_groups – Number of groups to divide channels into.

  • num_channels – Number of feature channels (must be divisible by num_groups).

  • eps – Small constant for numerical stability. Default: 1e-5.

  • affine – Whether to include learnable affine parameters. Default: True.

Shape:
  • Input: [B, N, C]

  • Output: [B, N, C]

Example

>>> import torch
>>> from idx_flow import SpatialGroupNorm
>>> gn = SpatialGroupNorm(num_groups=8, num_channels=64)
>>> x = torch.randn(8, 12288, 64)
>>> y = gn(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(num_groups, num_channels, eps=1e-05, affine=True)

Construct the layer. See the class description above for the meaning of each parameter.

Return type:

None

forward(x)

Forward pass of spatial group normalization.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C].

Returns:

Output tensor of shape [B, N, C].

Return type:

Tensor

Regularization Layers

SpatialDropout

class SpatialDropout

Bases: Module

Spatial Dropout for HEALPix grid data.

Drops entire spatial locations (all channels for selected points) during training. This encourages the model to learn spatially robust features.

Parameters:

p – Probability of dropping a spatial location. Default: 0.1.

Shape:
  • Input: [B, N, C]

  • Output: [B, N, C]

Example

>>> import torch
>>> from idx_flow import SpatialDropout
>>> dropout = SpatialDropout(p=0.2)
>>> x = torch.randn(8, 12288, 64)
>>> y = dropout(x)  # During training, some spatial points are zeroed
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(p=0.1)

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:

p (float)

Return type:

None

forward(x)

Forward pass of spatial dropout.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C].

Returns:

Output tensor of shape [B, N, C].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

ChannelDropout

class ChannelDropout

Bases: Module

Channel Dropout for HEALPix grid data.

Drops entire channels (all spatial points for selected channels) during training. This encourages the model to learn channel-robust features.
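SpatialDropout and ChannelDropout differ only in the shape of the random mask that is broadcast over the [B, N, C] input. The sketch below assumes one independent mask per sample and inverted-dropout scaling by 1/(1 - p) as in torch.nn.Dropout; neither detail is stated in this reference, and the function name structured_dropout is hypothetical.

```python
import torch


def structured_dropout(x, p, mode, training=True):
    """Hypothetical structured dropout on x of shape [B, N, C].

    mode="spatial": drop whole points   (mask shape [B, N, 1]).
    mode="channel": drop whole channels (mask shape [B, 1, C]).
    """
    if not training or p == 0.0:
        return x  # identity at evaluation time
    b, n, c = x.shape
    shape = (b, n, 1) if mode == "spatial" else (b, 1, c)
    keep = (torch.rand(shape, device=x.device) >= p).to(x.dtype)
    # Rescale survivors so the expected value matches the input (assumption)
    return x * keep / (1.0 - p)


torch.manual_seed(0)
x = torch.ones(2, 1000, 8)
y = structured_dropout(x, p=0.2, mode="spatial")
# Within each point, every channel is either dropped or kept together
print(bool((y == y[..., :1]).all()))  # True
```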

Parameters:

p – Probability of dropping a channel. Default: 0.1.

Shape:
  • Input: [B, N, C]

  • Output: [B, N, C]

Example

>>> import torch
>>> from idx_flow import ChannelDropout
>>> dropout = ChannelDropout(p=0.2)
>>> x = torch.randn(8, 12288, 64)
>>> y = dropout(x)  # During training, some channels are zeroed
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(p=0.1)

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:

p (float)

Return type:

None

forward(x)

Forward pass of channel dropout.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C].

Returns:

Output tensor of shape [B, N, C].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

Attention Layers

SpatialSelfAttention

class SpatialSelfAttention

Bases: Module

Self-Attention layer for spatial data on HEALPix grids.

Applies multi-head self-attention across the spatial dimension, allowing each spatial point to attend to all other points.

When attn_backend="auto" (default) and PyTorch >= 2.0, attention is computed via F.scaled_dot_product_attention (SDPA), which dispatches to the fastest kernel it supports for the current device, dtype and input shapes:

  • FlashAttention on Ampere or newer GPUs with float16/bfloat16 (FlashAttention-2 from PyTorch 2.2 onward)

  • Memory-efficient attention (a kernel derived from xFormers)

  • Math fallback when no fused kernel applies, e.g. for some CPU or dtype combinations

Set attn_backend="manual" to force the explicit matmul-softmax-matmul path (always available).

Note: Complexity is O(N^2) in the number of spatial points regardless of backend. Fused kernels such as FlashAttention reduce the constant factor and avoid materializing the full N x N attention matrix, but they do not change the asymptotic complexity.
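The "manual" path is the textbook softmax(Q K^T / sqrt(d)) V computation. This sketch checks it against torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0) on tensors laid out as [B, heads, N, head_dim]; that layout is an assumption about SpatialSelfAttention's internals, and the Q/K/V and output projections are omitted.

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q, k, v):
    """Explicit matmul -> softmax -> matmul on [B, H, N, d] tensors."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # [B, H, N, N]
    return torch.softmax(scores, dim=-1) @ v                   # [B, H, N, d]


torch.manual_seed(0)
B, H, N, d = 2, 8, 768, 8  # embed_dim = H * d = 64, as in the example below
q, k, v = (torch.randn(B, H, N, d) for _ in range(3))

out_manual = manual_attention(q, k, v)
out_sdpa = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out_manual, out_sdpa, atol=1e-4))  # True
```

The manual path materializes the [B, H, N, N] score tensor, which is why its memory grows quadratically with N.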

Parameters:
  • embed_dim – Total dimension of the model (must be divisible by num_heads).

  • num_heads – Number of attention heads.

  • dropout – Dropout probability on attention weights. Default: 0.0.

  • bias – Whether to include bias in projections. Default: True.

  • attn_backend – Attention computation backend. Default: “auto”.

      - "auto": Use SDPA when available (PyTorch >= 2.0), else manual.
      - "sdpa": Force SDPA (raises if unavailable).
      - "manual": Force explicit matmul-softmax-matmul.

Shape:
  • Input: [B, N, embed_dim]

  • Output: [B, N, embed_dim]

Example

>>> import torch
>>> from idx_flow import SpatialSelfAttention
>>> attn = SpatialSelfAttention(embed_dim=64, num_heads=8)
>>> x = torch.randn(4, 768, 64)
>>> y = attn(x)
>>> print(y.shape)  # torch.Size([4, 768, 64])

__init__(embed_dim, num_heads, dropout=0.0, bias=True, attn_backend='auto')

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:
  • embed_dim (int)

  • num_heads (int)

  • dropout (float)

  • bias (bool)

  • attn_backend (Literal['auto', 'sdpa', 'manual'])

Return type:

None

forward(x)

Forward pass of spatial self-attention.

Parameters:

x (Tensor) – Input tensor of shape [B, N, embed_dim].

Returns:

Output tensor of shape [B, N, embed_dim].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

Vision Transformer Layers

SpatialPatchEmbedding

class SpatialPatchEmbedding

Bases: Module

Spatial Patch Embedding layer for HEALPix grids.

This layer creates patch embeddings from local neighborhoods on the spherical grid using precomputed connection indices. It gathers k neighbor features for each output point, flattens the neighborhood into a single vector, and projects it to the embedding dimension via a linear layer.

This is the spatial analog of the patch embedding in Vision Transformers (ViT), adapted for non-Euclidean domains. Instead of extracting fixed-size 2D image patches, it uses precomputed topology (connection indices) to define local patches on the sphere.

The operation:
  1. Gather: Collect k neighbor features for each output point: [B, N_out, k, C_in]

  2. Flatten: Reshape neighborhood to vector: [B, N_out, k * C_in]

  3. Project: Linear projection to embedding dimension: [B, N_out, embed_dim]

Mathematically:

E[b,p,:] = W_proj * [X[b, idx[p,0], :] || … || X[b, idx[p,k-1], :]] + b_proj
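Because the projection is linear, the concatenation formula above computes the same thing as the SpatialConv sum with the projection matrix reshaped to [k, C_in, embed_dim]. This sketch checks that identity numerically with a plain nn.Linear; it illustrates the math rather than the idx_flow implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, N_in, N_out, k, C_in, D = 2, 48, 12, 9, 16, 128
x = torch.randn(B, N_in, C_in)
idx = torch.randint(0, N_in, (N_out, k))

proj = nn.Linear(k * C_in, D)
# Patch embedding: gather, concatenate neighbors, project
E = proj(x[:, idx, :].flatten(2))                   # [B, N_out, D]

# Same result as a SpatialConv-style einsum with W[k, c, f]
W = proj.weight.T.reshape(k, C_in, D)               # [k*C_in, D] -> [k, C_in, D]
Y = torch.einsum("bpkc,kcf->bpf", x[:, idx, :], W) + proj.bias
print(torch.allclose(E, Y, atol=1e-4))  # True
```

The difference between the two layers therefore lies in how they are used (for example, as the entry point of SpatialViT) rather than in the linear map they compute.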

Literature Context:

Adapted from the Vision Transformer (ViT) architecture:

  • Dosovitskiy, A., Beyer, L., Kolesnikov, A., et al. (2021). “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” In ICLR 2021. arXiv: 2010.11929

Parameters:
  • output_points – Number of spatial points in the output (number of patches).

  • connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output patch.

  • embed_dim – Dimension of the patch embedding vectors.

  • bias – Whether to include a bias term in the projection. Default is True.

  • weight_init – Weight initialization method. Default is “xavier_uniform”.

  • weight_init_gain – Gain for xavier/orthogonal initialization. Default is 1.0.

projection

Learnable linear projection from flattened patch to embed_dim.

Shape:
  • Input: [B, N_in, C_in] where B is batch size, N_in is input points, C_in is input channels.

  • Output: [B, N_out, embed_dim] where N_out is output_points.

Example

>>> import torch
>>> from idx_flow import SpatialPatchEmbedding
>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=64, nside_out=32, k=9
... )
>>> patch_embed = SpatialPatchEmbedding(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     embed_dim=128,
... )
>>> x = torch.randn(8, 12 * 64**2, 16)
>>> embeddings = patch_embed(x)
>>> print(embeddings.shape)  # torch.Size([8, 12288, 128])

See also

  • SpatialConv: Linear convolution (einsum-based kernel)

  • SpatialMLP: MLP kernel for non-linear local processing

  • SpatialViT: Full Vision Transformer using this embedding

__init__(output_points, connection_indices, embed_dim=128, bias=True, weight_init='xavier_uniform', weight_init_gain=1.0)

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:
  • output_points (int)

  • connection_indices (ndarray[tuple[int, ...], dtype[int64]])

  • embed_dim (int)

  • bias (bool)

  • weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])

  • weight_init_gain (float)

Return type:

None

forward(x)

Forward pass of spatial patch embedding.

Parameters:

x (Tensor) – Input tensor of shape [B, N_in, C_in].

Returns:

Output tensor of shape [B, N_out, embed_dim].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

SpatialTransformerBlock

class SpatialTransformerBlock

Bases: Module

Transformer Encoder Block for spatial data on HEALPix grids.

Implements a standard pre-norm Transformer encoder block with multi-head self-attention and a position-wise feed-forward network (FFN). Uses residual connections and layer normalization following the pre-norm convention (LN before attention/FFN rather than after).

The operation:
  1. Pre-norm: LayerNorm -> Multi-Head Self-Attention -> Residual

  2. Pre-norm: LayerNorm -> Feed-Forward Network -> Residual

Mathematically:

h = x + MHSA(LN(x))
y = h + FFN(LN(h))

where:
  • MHSA: Multi-Head Self-Attention with scaled dot-product

  • FFN: Two-layer MLP with activation (Linear -> Act -> Dropout -> Linear -> Dropout)

  • LN: Layer Normalization per spatial point
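A minimal pre-norm block following the two equations above can be built from torch.nn.MultiheadAttention and nn.LayerNorm. The class name is hypothetical, and using nn.MultiheadAttention instead of the library's own attention layer is an assumption made to keep the sketch self-contained.

```python
import torch
from torch import nn


class PreNormBlockSketch(nn.Module):
    """Hypothetical pre-norm encoder block: h = x + MHSA(LN(x)); y = h + FFN(LN(h))."""

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0, eps=1e-6):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, eps=eps)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim, eps=eps)
        hidden = int(embed_dim * mlp_ratio)
        # FFN: Linear -> Act -> Dropout -> Linear -> Dropout
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        z = self.norm1(x)
        h = x + self.attn(z, z, z, need_weights=False)[0]  # attention sub-layer + residual
        return h + self.ffn(self.norm2(h))                  # FFN sub-layer + residual


block = PreNormBlockSketch(embed_dim=128, num_heads=8)
y = block(torch.randn(2, 768, 128))
print(y.shape)  # torch.Size([2, 768, 128])
```

Normalizing before each sub-layer leaves the residual path x -> h -> y free of normalization, which is the defining feature of the pre-norm convention.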

Note: The self-attention has O(N^2) complexity in the number of spatial points. For large grids, consider using this after spatial downsampling.

Parameters:
  • embed_dim – Embedding dimension (must be divisible by num_heads).

  • num_heads – Number of attention heads.

  • mlp_ratio – Ratio of FFN hidden dimension to embed_dim. Default is 4.0.

  • dropout – Dropout probability in attention and FFN. Default is 0.0.

  • activation – Activation function for the FFN. Default is “gelu”.

  • bias – Whether to include bias in linear projections. Default is True.

  • norm_eps – Epsilon for layer normalization. Default is 1e-6.

  • attn_backend – Attention backend. "auto" uses SDPA/FlashAttention 2 when available, "manual" forces explicit matmul path. Default is "auto".

Shape:
  • Input: [B, N, embed_dim]

  • Output: [B, N, embed_dim]

Example

>>> import torch
>>> from idx_flow import SpatialTransformerBlock
>>> block = SpatialTransformerBlock(
...     embed_dim=128,
...     num_heads=8,
...     mlp_ratio=4.0,
...     dropout=0.1,
... )
>>> x = torch.randn(8, 768, 128)
>>> y = block(x)
>>> print(y.shape)  # torch.Size([8, 768, 128])

__init__(embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0, activation='gelu', bias=True, norm_eps=1e-06, attn_backend='auto')

Construct the layer. See the class description above for the meaning of each parameter.

Parameters:
  • embed_dim (int)

  • num_heads (int)

  • mlp_ratio (float)

  • dropout (float)

  • activation (Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'])

  • bias (bool)

  • norm_eps (float)

  • attn_backend (Literal['auto', 'sdpa', 'manual'])

Return type:

None

forward(x)

Forward pass of the transformer block.

Parameters:

x (Tensor) – Input tensor of shape [B, N, embed_dim].

Returns:

Output tensor of shape [B, N, embed_dim].

Return type:

Tensor

extra_repr()

Return a string representation of layer parameters.

Return type:

str

SpatialViT

class SpatialViT

Bases: Module

Vision Transformer (ViT) for spatial data on HEALPix grids.

This layer implements a complete Vision Transformer adapted for spherical data discretized on HEALPix grids. It combines index-based patch embedding with Transformer encoder blocks, bridging the structure-compilation philosophy of this library with the global attention mechanism of ViT.

The architecture:
  1. Patch Embedding: Uses precomputed connection indices to gather local neighborhoods and project them to embedding vectors.

  2. Positional Encoding: Learnable positional embeddings added to each spatial point to encode location on the sphere.

  3. Transformer Encoder: A stack of depth Transformer blocks (see the depth parameter), each with multi-head self-attention and a feed-forward network.

  4. Output Projection: Optional linear projection to the desired output dimension.

The operation:
  1. Embed: E = PatchEmbed(X) + PosEmbed, shape [B, N_out, embed_dim]

  2. Encode: Z = TransformerBlock_depth(…TransformerBlock_1(E))

  3. Project: Y = Linear(LN(Z)), shape [B, N_out, output_dim]
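The three steps above can be sketched with stock PyTorch modules. This is a hypothetical illustration, not the SpatialViT source. TinySpatialViTSketch is an invented name, and the encoder is torch.nn.TransformerEncoder with norm_first=True (the pre-norm choice is an assumption). Random indices stand in for the output of compute_connection_indices so that the snippet runs on its own.

```python
from typing import Optional

import torch
import torch.nn as nn


class TinySpatialViTSketch(nn.Module):
    """Hypothetical sketch of the Embed -> Encode -> Project pipeline."""

    def __init__(self, connection_indices, in_channels: int, embed_dim: int = 64,
                 depth: int = 2, num_heads: int = 4,
                 output_dim: Optional[int] = None) -> None:
        super().__init__()
        idx = torch.as_tensor(connection_indices, dtype=torch.long)
        n_out, k = idx.shape
        self.register_buffer("idx", idx)
        # Patch embedding: flatten each k-neighbourhood and project it
        self.patch_embed = nn.Linear(k * in_channels, embed_dim)
        # One learnable positional vector per output point
        self.pos_embed = nn.Parameter(torch.zeros(1, n_out, embed_dim))
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads,
                                           dim_feedforward=4 * embed_dim,
                                           activation="gelu",
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        out_dim = output_dim if output_dim is not None else embed_dim
        self.proj = nn.Linear(embed_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.shape[0]
        patches = x[:, self.idx, :]                      # [B, N_out, k, C_in]
        patches = patches.reshape(b, self.idx.shape[0], -1)
        e = self.patch_embed(patches) + self.pos_embed   # Embed
        z = self.encoder(e)                              # Encode
        return self.proj(self.norm(z))                   # Project


# Random stand-in indices: 768 output points (nside 8), 9 neighbours each,
# drawn from 3072 input points (nside 16)
indices = torch.randint(0, 12 * 16**2, (12 * 8**2, 9))
model = TinySpatialViTSketch(indices, in_channels=8, embed_dim=64,
                             depth=2, num_heads=4, output_dim=32)
y = model(torch.randn(2, 12 * 16**2, 8))
print(y.shape)  # torch.Size([2, 768, 32])
```

The sketch flattens each [k, C_in] neighborhood and applies a single Linear projection. That is one plausible reading of "project them to embedding vectors". A per-neighbor kernel, as in SpatialConv, would be an equally valid design.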

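Complexity (per sample):
  • Patch embedding: O(N_out * k * C_in * embed_dim) – linear in output points

  • Self-attention per block: O(N_out^2 * embed_dim) – quadratic in points

  • Total: O(depth * N_out^2 * embed_dim) for attention. This term dominates the per-block feed-forward cost, O(N_out * mlp_ratio * embed_dim^2), once N_out is large relative to embed_dim.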

For large grids, reduce N_out before the transformer by downsampling. To do this, use connection indices that map a higher-resolution input grid to a lower-resolution output grid.
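A HEALPix grid at resolution nside has 12 * nside^2 pixels. Halving nside therefore divides N_out by 4 and divides the per-head attention score matrix by 16. The small helpers below (hypothetical names) make that arithmetic concrete.

```python
def healpix_npix(nside: int) -> int:
    # A HEALPix map at resolution nside has 12 * nside**2 pixels
    return 12 * nside ** 2


def attention_scores_per_head(nside: int) -> int:
    # Self-attention builds an N x N score matrix per head and per sample
    n = healpix_npix(nside)
    return n * n


for nside in (8, 16, 32, 64):
    print(nside, healpix_npix(nside), attention_scores_per_head(nside))
```

For example, nside = 8 gives 768 points and 589,824 scores per head. At nside = 64 the grid has 49,152 points and about 2.4 billion scores per head and sample.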

Literature Context:

Adapted from:

  • Dosovitskiy, A., Beyer, L., Kolesnikov, A., et al. (2021). “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” In ICLR 2021. arXiv: 2010.11929

Parameters:
  • output_points – Number of spatial points (patches) in the output.

  • connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output patch.

  • embed_dim – Dimension of the patch embeddings. Default is 128.

  • depth – Number of Transformer encoder blocks. Default is 4.

  • num_heads – Number of attention heads per block. Default is 8.

  • mlp_ratio – Ratio of FFN hidden dim to embed_dim. Default is 4.0.

  • output_dim – Output feature dimension. If None, equals embed_dim.

  • dropout – Dropout probability for attention and FFN. Default is 0.0.

  • activation – Activation function for FFN layers. Default is “gelu”.

  • bias – Whether to include bias in linear projections. Default is True.

  • weight_init – Weight initialization method for patch embedding. Default is “xavier_uniform”.

  • norm_eps – Epsilon for layer normalization. Default is 1e-6.

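  • attn_backend – Attention backend for all transformer blocks. "auto" uses SDPA (FlashAttention 2) when available and otherwise falls back to the manual path. "sdpa" forces scaled_dot_product_attention and requires PyTorch >= 2.0. "manual" forces the explicit matmul-softmax-matmul path. Default is "auto".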

Shape:
  • Input: [B, N_in, C_in] where B is batch size, N_in is input points, C_in is input channels.

  • Output: [B, N_out, output_dim] where N_out is output_points.

Example

>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=16, nside_out=8, k=9
... )
>>> vit = SpatialViT(
...     output_points=12 * 8**2,
...     connection_indices=indices,
...     embed_dim=64,
...     depth=4,
...     num_heads=4,
...     output_dim=32,
...     dropout=0.1,
... )
>>> x = torch.randn(4, 12 * 16**2, 8)
>>> y = vit(x)
>>> print(y.shape)  # torch.Size([4, 768, 32])


__init__(output_points, connection_indices, embed_dim=128, depth=4, num_heads=8, mlp_ratio=4.0, output_dim=None, dropout=0.0, activation='gelu', bias=True, weight_init='xavier_uniform', norm_eps=1e-06, attn_backend='auto')[source]

Initialize the Spatial Vision Transformer.

Parameters:
  • output_points (int)

  • connection_indices (ndarray[tuple[int, ...], dtype[int64]])

  • embed_dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • output_dim (int | None)

  • dropout (float)

  • activation (Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'])

  • bias (bool)

  • weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])

  • norm_eps (float)

  • attn_backend (Literal['auto', 'sdpa', 'manual'])

Return type:

None

forward(x)[source]

Forward pass of the Spatial Vision Transformer.

Parameters:

x (Tensor) – Input tensor of shape [B, N_in, C_in].

Returns:

Output tensor of shape [B, N_out, output_dim].

Return type:

Tensor

extra_repr()[source]

Return a string representation of layer parameters.

Return type:

str

Pooling and Utility Layers

SpatialPooling

class SpatialPooling[source]

Bases: Module

Spatial Pooling layer for HEALPix grids.

Performs pooling operations (mean, max, or sum) over local neighborhoods on the spherical grid. This is a non-learnable layer useful for downsampling with simple aggregation.
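Conceptually, the pooling is a gather over the connection indices followed by a reduction over the neighbor axis. Below is a minimal sketch under that assumption. pool_neighbours is a hypothetical helper and is not part of idx_flow.

```python
import torch


def pool_neighbours(x: torch.Tensor, idx: torch.Tensor,
                    pool_type: str = "mean") -> torch.Tensor:
    """Hypothetical gather-then-reduce pooling.

    x: [B, N_in, C], idx: [N_out, k] integer indices into the N_in axis.
    """
    neighbours = x[:, idx, :]  # [B, N_out, k, C]
    if pool_type == "mean":
        return neighbours.mean(dim=2)
    if pool_type == "max":
        return neighbours.amax(dim=2)
    if pool_type == "sum":
        return neighbours.sum(dim=2)
    raise ValueError(f"unknown pool_type: {pool_type!r}")


# 4 input points with 2 channels, pooled pairwise into 2 output points
x = torch.arange(8.0).reshape(1, 4, 2)
idx = torch.tensor([[0, 1], [2, 3]])
print(pool_neighbours(x, idx, "mean"))
```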

Parameters:
  • output_points – Number of spatial points in the output.

  • connection_indices – Integer array of shape [output_points, kernel_size].

  • pool_type – Type of pooling operation. One of “mean”, “max”, “sum”. Default is “mean”.

Shape:
  • Input: [B, N_in, C_in]

  • Output: [B, N_out, C_in] (channels preserved)

Example

>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=64, nside_out=32, k=4
... )
>>> pool = SpatialPooling(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     pool_type="mean"
... )
>>> x = torch.randn(8, 12 * 64**2, 32)
>>> y = pool(x)
>>> print(y.shape)  # torch.Size([8, 12288, 32])

__init__(output_points, connection_indices, pool_type='mean')[source]

Initialize the spatial pooling layer.

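Parameters:
  • output_points (int)

  • connection_indices (ndarray[tuple[int, ...], dtype[int64]])

  • pool_type (Literal['mean', 'max', 'sum'])
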
Return type:

None

forward(x)[source]

Forward pass of spatial pooling.

Parameters:

x (Tensor) – Input tensor of shape [B, N_in, C_in].

Returns:

Output tensor of shape [B, N_out, C_in].

Return type:

Tensor

extra_repr()[source]

Return a string representation of layer parameters.

Return type:

str

Squeeze

class Squeeze[source]

Bases: Module

Squeeze layer that reduces the spatial dimension to a single feature vector per sample.

Performs global aggregation over all spatial points using mean, max, or sum pooling.
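Each reduction maps directly onto a tensor method applied over the point axis (dim=1). Below is a sketch using a hypothetical helper name.

```python
import torch


def squeeze_sketch(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """Hypothetical global reduction over the point axis: [B, N, C] -> [B, C]."""
    if reduction == "mean":
        return x.mean(dim=1)
    if reduction == "max":
        return x.amax(dim=1)  # values only; unlike torch.max, no indices
    if reduction == "sum":
        return x.sum(dim=1)
    raise ValueError(f"unknown reduction: {reduction!r}")


x = torch.randn(8, 12288, 64)
print(squeeze_sketch(x).shape)  # torch.Size([8, 64])
```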

Parameters:

reduction – Reduction method. One of “mean”, “max”, “sum”. Default is “mean”.

Shape:
  • Input: [B, N, C]

  • Output: [B, C]

Example

>>> squeeze = Squeeze(reduction="mean")
>>> x = torch.randn(8, 12288, 64)
>>> y = squeeze(x)
>>> print(y.shape)  # torch.Size([8, 64])

__init__(reduction='mean')[source]

Initialize the squeeze layer.

Parameters:

reduction (Literal['mean', 'max', 'sum'])

Return type:

None

forward(x)[source]

Forward pass of squeeze.

Parameters:

x (Tensor) – Input tensor of shape [B, N, C].

Returns:

Output tensor of shape [B, C].

Return type:

Tensor

extra_repr()[source]

Return a string representation of layer parameters.

Return type:

str

Unsqueeze

class Unsqueeze[source]

Bases: Module

Unsqueeze layer that broadcasts a vector to all spatial points.

Takes a feature vector and replicates it across the spatial dimension.
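One way to express this broadcast in PyTorch is unsqueeze followed by expand. This is an assumed implementation with a hypothetical helper name. The reference does not say whether Unsqueeze returns a view or a copy.

```python
import torch


def unsqueeze_sketch(x: torch.Tensor, num_points: int) -> torch.Tensor:
    """Hypothetical broadcast of [B, C] to [B, num_points, C]."""
    # unsqueeze adds the point axis; expand repeats it as a view without copying
    return x.unsqueeze(1).expand(-1, num_points, -1)


x = torch.randn(8, 64)
y = unsqueeze_sketch(x, num_points=12288)
print(y.shape)  # torch.Size([8, 12288, 64])
```

The expanded tensor shares memory across all points, so in-place writes to it are unsafe. Call .contiguous(), or use x.unsqueeze(1).repeat(1, num_points, 1), when an independent copy is needed.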

Parameters:

num_points – Number of spatial points to broadcast to.

Shape:
  • Input: [B, C]

  • Output: [B, num_points, C]

Example

>>> unsqueeze = Unsqueeze(num_points=12288)
>>> x = torch.randn(8, 64)
>>> y = unsqueeze(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])

__init__(num_points)[source]

Initialize the unsqueeze layer.

Parameters:

num_points (int)

Return type:

None

forward(x)[source]

Forward pass of unsqueeze.

Parameters:

x (Tensor) – Input tensor of shape [B, C].

Returns:

Output tensor of shape [B, num_points, C].

Return type:

Tensor

extra_repr()[source]

Return a string representation of layer parameters.

Return type:

str

Functional Utilities

get_initializer

get_initializer(method, gain=1.0, nonlinearity='leaky_relu', mean=0.0, std=0.02, a=0.0, b=1.0)[source]

Get weight initialization function.

Parameters:
  • method (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros']) – Initialization method name.

  • gain (float) – Gain factor for xavier/orthogonal initialization.

  • nonlinearity (str) – Nonlinearity for kaiming initialization.

  • mean (float) – Mean for normal initialization.

  • std (float) – Standard deviation for normal initialization.

  • a (float) – Lower bound for uniform initialization.

  • b (float) – Upper bound for uniform initialization.

Returns:

Initialization function that takes a tensor and initializes it in-place.

Raises:

ValueError – If method is not recognized.

Return type:

Callable[[Tensor], Tensor]
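The parameter list maps naturally onto the in-place functions in torch.nn.init. The sketch below shows one plausible mapping under that assumption. make_initializer is a hypothetical name, and idx_flow may wire the parameters differently. For example, torch.nn.init.kaiming_* also accepts its own negative-slope argument a, which this sketch leaves at its default.

```python
from functools import partial
from typing import Callable

import torch
import torch.nn as nn


def make_initializer(method: str, gain: float = 1.0,
                     nonlinearity: str = "leaky_relu", mean: float = 0.0,
                     std: float = 0.02, a: float = 0.0,
                     b: float = 1.0) -> Callable[[torch.Tensor], torch.Tensor]:
    """Hypothetical mapping from method names to torch.nn.init functions."""
    table = {
        "xavier_uniform": partial(nn.init.xavier_uniform_, gain=gain),
        "xavier_normal": partial(nn.init.xavier_normal_, gain=gain),
        "kaiming_uniform": partial(nn.init.kaiming_uniform_, nonlinearity=nonlinearity),
        "kaiming_normal": partial(nn.init.kaiming_normal_, nonlinearity=nonlinearity),
        "orthogonal": partial(nn.init.orthogonal_, gain=gain),
        "normal": partial(nn.init.normal_, mean=mean, std=std),
        "uniform": partial(nn.init.uniform_, a=a, b=b),  # bounds [a, b]
        "zeros": nn.init.zeros_,
    }
    if method not in table:
        raise ValueError(f"Unknown initialization method: {method!r}")
    return table[method]


w = torch.empty(64, 32)
make_initializer("xavier_uniform")(w)  # fills w in place
```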

get_activation

get_activation(name)[source]

Get activation module by name.

Parameters:

name (Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'] | None) – Activation function name. If None, returns Identity.

Returns:

PyTorch activation module.

Raises:

ValueError – If activation name is not recognized.

Return type:

Module
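A name-to-module lookup that accepts the same names could be written as follows. This is an assumed mapping, not the idx_flow source. "swish" maps to torch.nn.SiLU, which computes the same function, x * sigmoid(x). "linear" and None both map to Identity. Default constructor arguments, such as the LeakyReLU negative slope, may differ from the library's. nn.Mish requires PyTorch 1.9 or newer.

```python
from typing import Optional

import torch.nn as nn

# Hypothetical lookup table from activation names to module classes
_ACTIVATIONS = {
    "relu": nn.ReLU,
    "selu": nn.SELU,
    "leaky_relu": nn.LeakyReLU,
    "gelu": nn.GELU,
    "elu": nn.ELU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "swish": nn.SiLU,  # swish(x) = x * sigmoid(x), called SiLU in PyTorch
    "mish": nn.Mish,
    "linear": nn.Identity,
}


def make_activation(name: Optional[str]) -> nn.Module:
    """Hypothetical name -> module lookup; None maps to Identity."""
    if name is None:
        return nn.Identity()
    try:
        return _ACTIVATIONS[name]()
    except KeyError:
        raise ValueError(f"Unknown activation: {name!r}") from None
```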

Type Aliases

idx_flow.InitMethod

Weight initialization methods: "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal", "orthogonal", "normal", "uniform", "zeros"

idx_flow.ActivationType

Activation functions: "relu", "selu", "leaky_relu", "gelu", "elu", "tanh", "sigmoid", "swish", "mish", "linear"

idx_flow.InterpolationMethod

Interpolation methods: "linear", "idw", "gaussian"

idx_flow.PoolingMethod

Pooling methods: "mean", "max", "sum"

idx_flow.AttnBackend

Attention computation backends: "auto" (SDPA when available, else manual), "sdpa" (force scaled_dot_product_attention, requires PyTorch >= 2.0), "manual" (explicit matmul-softmax-matmul)
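The "manual" and "sdpa" backends compute the same quantity, softmax(Q K^T / sqrt(d)) V. The sketch below checks that equivalence numerically on CPU. manual_attention is a hypothetical helper. The tensors use the [B, heads, N, head_dim] layout that torch.nn.functional.scaled_dot_product_attention expects (PyTorch >= 2.0).

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Explicit matmul -> softmax -> matmul; q, k, v: [B, heads, N, head_dim]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v


q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
manual = manual_attention(q, k, v)
sdpa = F.scaled_dot_product_attention(q, k, v)  # default scale is 1/sqrt(head_dim)
print(torch.allclose(manual, sdpa, atol=1e-5))
```

Both paths still materialize or emulate the full N x N score matrix per head. The fused SDPA kernels avoid storing it in memory, but the quadratic compute cost remains.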