Layers API Reference
All layers are importable directly from idx_flow. Internally they live in
separate modules (conv, mlp, norm, regularization, attention,
vit, pooling, functional).
Convolution Layers
SpatialConv
- class SpatialConv[source]
Bases: Module
Spatial Convolution layer for downsampling on HEALPix grids.
This layer performs convolution on spherical data discretized using the HEALPix tessellation scheme. It uses precomputed connection indices to gather features from neighboring pixels and applies learnable kernels for spatial feature transformation.
- The operation follows:
Gather: Collect features from k neighbors for each output point
Transform: Apply learnable kernel weights
Aggregate: Sum weighted contributions with bias
- Mathematically:
Y[b,p,f] = sum_k sum_c X[b, idx[p,k], c] * W[k,c,f] + bias[f]
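The formula above amounts to a gather followed by a single tensor contraction. The sketch below is a plain-PyTorch reference of that contraction on toy tensors, not the library's implementation; the function name `spatial_conv_reference` and all sizes are made up for illustration, and the optional distance-based `kernel_weights` scaling is left out.

```python
import torch

def spatial_conv_reference(x, idx, weight, bias=None):
    """Gather, transform, aggregate on toy tensors.

    x:      [B, N_in, C_in] input features
    idx:    [N_out, K] integer indices into the N_in axis
    weight: [K, C_in, F] kernel
    bias:   [F] or None
    """
    gathered = x[:, idx, :]                              # [B, N_out, K, C_in]
    y = torch.einsum("bpkc,kcf->bpf", gathered, weight)  # sum over k and c
    if bias is not None:
        y = y + bias
    return y

torch.manual_seed(0)
B, N_in, C_in, N_out, K, F_out = 2, 12, 3, 5, 4, 6
x = torch.randn(B, N_in, C_in)
idx = torch.randint(0, N_in, (N_out, K))
weight = torch.randn(K, C_in, F_out)
bias = torch.randn(F_out)
y = spatial_conv_reference(x, idx, weight, bias)
print(y.shape)  # torch.Size([2, 5, 6])
```

Indexing `x[:, idx, :]` with a 2-D index tensor produces the `[B, N_out, K, C_in]` neighborhood tensor in one step, so no Python loop over output points is needed.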
- Parameters:
output_points – Number of spatial points in the output tensor.
connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels that connect to each output pixel.
kernel_weights – Optional float array of shape [output_points, kernel_size] containing distance-based weights for each connection. If provided, neighbor features are scaled by these weights before convolution.
filters – Number of output feature channels (filters).
bias – Whether to include a bias term. Default is True.
weight_init – Weight initialization method. Default is “xavier_uniform”.
weight_init_gain – Gain for xavier/orthogonal initialization.
bias_init – Bias initialization value. Default is 0.0.
- kernel
Learnable weight tensor of shape [kernel_size, in_channels, filters].
- bias_param
Learnable bias tensor of shape [filters] if bias=True.
- Shape:
Input: [B, N_in, C_in] where B is batch size, N_in is input points, C_in is input channels.
Output: [B, N_out, filters] where N_out is output_points.
Example
>>> import torch
>>> from idx_flow import SpatialConv
>>> from idx_flow.utils import compute_connection_indices
>>> indices, distances = compute_connection_indices(
...     nside_in=64, nside_out=32, k=4
... )
>>> conv = SpatialConv(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     filters=64,
...     weight_init="kaiming_normal"
... )
>>> x = torch.randn(8, 12 * 64**2, 32)
>>> y = conv(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])
Notes
Connection indices must be precomputed using hp_distance or similar.
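The forward pass costs O(N_out * kernel_size * C_in * filters), i.e. it is linear in the number of output points for a fixed kernel size and fixed channel widths.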
Input channels are inferred from the first forward pass.
- __init__(output_points, connection_indices, kernel_weights=None, filters=32, bias=True, weight_init='xavier_uniform', weight_init_gain=1.0, bias_init=0.0)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Return type:
None
SpatialTransposeConv
- class SpatialTransposeConv[source]
Bases: Module
Spatial Transpose Convolution layer for upsampling on HEALPix grids.
This layer performs transposed (deconvolution) operations for upsampling spatial resolution on HEALPix grids. It maps features from a lower resolution grid to a higher resolution grid using precomputed connection indices.
- Parameters:
output_points – Number of spatial points in the output (higher resolution).
connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output pixel.
kernel_weights – Optional float array of shape [output_points, kernel_size] containing distance-based weights for each connection.
filters – Number of output feature channels.
bias – Whether to include a bias term. Default is True.
weight_init – Weight initialization method. Default is “xavier_uniform”.
weight_init_gain – Gain for xavier/orthogonal initialization.
bias_init – Bias initialization value. Default is 0.0.
- Shape:
Input: [B, N_in, C_in] where N_in is the lower resolution.
Output: [B, N_out, filters] where N_out is output_points (higher res).
Example
>>> import torch
>>> from idx_flow import SpatialTransposeConv
>>> from idx_flow.utils import compute_connection_indices
>>> indices, distances, weights = compute_connection_indices(
...     nside_in=32, nside_out=64, k=4, return_weights=True
... )
>>> transpose_conv = SpatialTransposeConv(
...     output_points=12 * 64**2,
...     connection_indices=indices,
...     kernel_weights=weights,
...     filters=32,
...     weight_init="orthogonal"
... )
>>> x = torch.randn(8, 12 * 32**2, 64)
>>> y = transpose_conv(x)
>>> print(y.shape)  # torch.Size([8, 49152, 32])
- __init__(output_points, connection_indices, kernel_weights=None, filters=32, bias=True, weight_init='xavier_uniform', weight_init_gain=1.0, bias_init=0.0)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Return type:
None
SpatialUpsampling
- class SpatialUpsampling[source]
Bases: Module
Spatial Upsampling layer using distance-based interpolation.
This layer performs upsampling on HEALPix grids using precomputed interpolation weights based on geodesic distances. Unlike SpatialTransposeConv, this layer does not have learnable parameters and performs pure distance-weighted interpolation.
- Interpolation methods:
“linear”: Weight = max(0, 1 - distance / kernel_radius)
“idw”: Inverse distance weighting, Weight = 1 / (distance^2 + eps)
“gaussian”: Weight = exp(-0.5 * (distance / kernel_radius)^2)
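The three weight formulas above can be written out directly. The following is a minimal sketch rather than the layer's actual code: the helper names `interpolation_weights` and `interpolate`, the `eps` value, and the final normalization so that each row of weights sums to one are all assumptions, since this page does not say whether the layer normalizes.

```python
import torch

def interpolation_weights(distances, method="linear", kernel_radius=None,
                          eps=1e-8, normalize=True):
    """Turn [N_out, K] neighbor distances into [N_out, K] weights."""
    if kernel_radius is None:
        kernel_radius = distances.max()  # documented fallback: maximum distance
    if method == "linear":
        w = torch.clamp(1.0 - distances / kernel_radius, min=0.0)
    elif method == "idw":
        w = 1.0 / (distances ** 2 + eps)
    elif method == "gaussian":
        w = torch.exp(-0.5 * (distances / kernel_radius) ** 2)
    else:
        raise ValueError(f"unknown interpolation method: {method}")
    if normalize:
        # Assumption: rescale so the weights of each output point sum to 1.
        w = w / w.sum(dim=-1, keepdim=True).clamp_min(eps)
    return w

def interpolate(x, idx, w):
    """Weighted sum of gathered neighbors: [B, N_in, C] -> [B, N_out, C]."""
    return (x[:, idx, :] * w[None, :, :, None]).sum(dim=2)

distances = torch.tensor([[0.1, 0.2, 0.4], [0.3, 0.3, 0.6]])  # [N_out=2, K=3]
idx = torch.tensor([[0, 1, 2], [1, 2, 3]])
x = torch.ones(2, 4, 3)                                       # constant field
w = interpolation_weights(distances, "idw")
y = interpolate(x, idx, w)
```

With normalized weights a constant input field stays constant after upsampling, which is a convenient sanity check for any of the three methods.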
- Parameters:
output_points – Number of spatial points in the output (higher resolution).
connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output pixel.
distances – Float array of shape [output_points, kernel_size] containing geodesic distances to each neighbor.
interpolation – Interpolation method. One of “linear”, “idw”, “gaussian”.
kernel_radius – Radius for interpolation kernel. If None, uses the maximum distance. Only used for “linear” and “gaussian” methods.
- Shape:
Input: [B, N_in, C_in]
Output: [B, N_out, C_in] (channels preserved)
Example
>>> import torch
>>> from idx_flow import SpatialUpsampling
>>> from idx_flow.utils import compute_connection_indices
>>> indices, distances = compute_connection_indices(
...     nside_in=32, nside_out=64, k=4
... )
>>> upsample = SpatialUpsampling(
...     output_points=12 * 64**2,
...     connection_indices=indices,
...     distances=distances,
...     interpolation="idw"
... )
>>> x = torch.randn(8, 12 * 32**2, 32)
>>> y = upsample(x)
>>> print(y.shape)  # torch.Size([8, 49152, 32])
Notes
This layer has no learnable parameters.
Output channels equal input channels (no feature transformation).
Weights are precomputed and stored as buffers for efficiency.
- __init__(output_points, connection_indices, distances, interpolation='linear', kernel_radius=None)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
MLP Layers
SpatialMLP
- class SpatialMLP[source]
Bases: Module
Spatial Multi-Layer Perceptron for local non-linear processing on HEALPix grids.
This layer replaces the standard linear convolution kernel (matrix multiplication) with a shared Dense Neural Network (MLP) applied to the flattened neighborhood vector. This architectural choice is inspired by principles from Geometric Deep Learning, which extends neural network operations to non-Euclidean domains such as manifolds and graphs.
- Architectural Difference from SpatialConv:
SpatialConv: Applies a linear transformation to neighborhood features. The output is a weighted sum: Y = sum_k(W_k * X_k) + b, which can only capture linear relationships between neighbors.
SpatialMLP: Concatenates neighborhood features into a single vector and processes it through a multi-layer perceptron with non-linear activations: Y = MLP([X_1 || X_2 || … || X_k]). This allows the model to learn complex, non-linear interactions between neighbors.
- Trade-offs:
- Benefits:
Significantly higher representation capacity and expressivity
Can approximate arbitrary non-linear functions over local patches
Better suited for learning complex spatial patterns on manifolds
Supports dropout, batch normalization, and residual connections
- Costs:
Higher computational cost (FLOPs) due to dense layer operations
Increased memory usage for storing MLP parameters
May require more training data to avoid overfitting
- The operation:
Gather k neighbor features for each output point: [B, N_out, k, C_in]
Flatten the neighbor features: [B, N_out, k * C_in]
Process through shared MLP layers with specified activations
Output: [B, N_out, hidden_units[-1]]
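The gather, flatten, and shared-MLP steps listed above can be traced on toy tensors. This is an illustrative sketch built from standard `torch.nn` modules, not the layer's source; the sizes, the random index table, and the two-layer MLP are arbitrary choices.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, N_in, C_in, N_out, K = 2, 48, 4, 12, 4
hidden_units = [64, 32]

x = torch.randn(B, N_in, C_in)
idx = torch.randint(0, N_in, (N_out, K))   # stand-in for connection_indices

# One MLP shared by every output point; its input is the flattened neighborhood.
mlp = nn.Sequential(
    nn.Linear(K * C_in, hidden_units[0]),
    nn.GELU(),
    nn.Linear(hidden_units[0], hidden_units[1]),
)

gathered = x[:, idx, :]                # [B, N_out, K, C_in]
flat = gathered.flatten(start_dim=2)   # [B, N_out, K * C_in]
y = mlp(flat)                          # [B, N_out, hidden_units[-1]]
print(y.shape)  # torch.Size([2, 12, 32])
```

Because `nn.Linear` acts on the last dimension only, the same weights are applied to each of the `N_out` neighborhoods.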
- Literature Context:
This approach draws from Geometric Deep Learning research on extending neural networks to non-Euclidean domains:
Bronstein, M. M., Bruna, J., LeCun, Y., Szlam, A., & Vandergheynst, P. (2017). “Geometric deep learning: Going beyond Euclidean data.” IEEE Signal Processing Magazine, 34(4), 18-42. DOI: 10.1109/MSP.2017.2693418
Masci, J., Boscaini, D., Bronstein, M. M., & Vandergheynst, P. (2015). “Geodesic convolutional neural networks on Riemannian manifolds.” In Proceedings of the IEEE ICCV Workshops, pp. 37-45. DOI: 10.1109/ICCVW.2015.112
- Parameters:
output_points – Number of spatial points in the output.
connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output pixel.
hidden_units – List of hidden layer dimensions. The last value determines the output feature dimension.
activations – List of activation function names, one per hidden layer. Options: “relu”, “selu”, “leaky_relu”, “gelu”, “elu”, “tanh”, “sigmoid”, “swish”, “mish”, “linear”.
dropout – Dropout probability applied after each layer (except last). Default is 0.0 (no dropout).
use_batch_norm – Whether to apply batch normalization after each layer. Default is False.
residual – Whether to add residual connection if input/output dims match. Default is False.
weight_init – Weight initialization method. Default is “xavier_uniform”.
- Shape:
Input: [B, N_in, C_in]
Output: [B, N_out, hidden_units[-1]]
Example
>>> import torch
>>> from idx_flow import SpatialMLP
>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=64, nside_out=32, k=4
... )
>>> mlp = SpatialMLP(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     hidden_units=[64, 64, 32],
...     activations=["gelu", "gelu", "linear"],
...     dropout=0.1,
...     use_batch_norm=True
... )
>>> x = torch.randn(8, 12 * 64**2, 16)
>>> y = mlp(x)
>>> print(y.shape)  # torch.Size([8, 12288, 32])
See also
SpatialConv: Linear convolution with lower computational cost
GlobalMLP: Channel-wise MLP without spatial neighbor gathering
- __init__(output_points, connection_indices, hidden_units=(32, 32, 32), activations=('linear', 'linear', 'linear'), dropout=0.0, use_batch_norm=False, residual=False, weight_init='xavier_uniform')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
output_points (int)
activations (Sequence[Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'] | None])
dropout (float)
use_batch_norm (bool)
residual (bool)
weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])
- Return type:
None
GlobalMLP
- class GlobalMLP[source]
Bases: Module
Global MLP for channel-wise transformations on spatial data.
This layer applies a shared MLP to each spatial point independently, transforming the feature channels without spatial mixing. Useful for pointwise feature transformations in encoder-decoder architectures.
- The operation applies the same MLP to each of the N spatial points:
Y[b,n,:] = MLP(X[b,n,:]) for all n
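A consequence of applying the same MLP to every point is that the operation commutes with any reordering of the spatial axis. The sketch below checks this with a stand-in `nn.Sequential`; it is an assumption-level illustration of the pointwise property, not GlobalMLP itself.

```python
import torch
from torch import nn

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 16))

x = torch.randn(2, 10, 32)   # [B, N, C_in]
perm = torch.randperm(10)    # arbitrary reordering of the N points

y = mlp(x)                   # [B, N, 16]
y_perm = mlp(x[:, perm, :])  # permute the points first, then transform

# No spatial mixing: transforming then permuting equals permuting then transforming.
pointwise = torch.allclose(y[:, perm, :], y_perm, atol=1e-5)
```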
- Parameters:
hidden_units – List of hidden layer dimensions. The last value determines the output feature dimension.
activations – List of activation function names, one per hidden layer.
dropout – Dropout probability applied after each layer (except last).
use_batch_norm – Whether to apply batch normalization.
residual – Whether to add residual connection if input/output dims match.
weight_init – Weight initialization method.
- Shape:
Input: [B, N, C_in]
Output: [B, N, hidden_units[-1]]
Example
>>> import torch
>>> from idx_flow import GlobalMLP
>>> mlp = GlobalMLP(
...     hidden_units=[64, 128, 64],
...     activations=["gelu", "gelu", "linear"],
...     dropout=0.1,
...     residual=True
... )
>>> x = torch.randn(8, 12288, 32)
>>> # First call initializes based on input channels
>>> y = mlp(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])
- __init__(hidden_units=(64, 64), activations=('gelu', 'linear'), dropout=0.0, use_batch_norm=False, residual=False, weight_init='xavier_uniform')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
activations (Sequence[Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'] | None])
dropout (float)
use_batch_norm (bool)
residual (bool)
weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])
- Return type:
None
Normalization Layers
SpatialBatchNorm
- class SpatialBatchNorm[source]
Bases: Module
Batch Normalization for spatial data on HEALPix grids.
Applies batch normalization across the spatial and batch dimensions, normalizing per channel.
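For a `[B, N, C]` tensor, "across the spatial and batch dimensions, per channel" means one mean and one variance per channel, computed over axes 0 and 1. The sketch below spells that out and compares it with `torch.nn.BatchNorm1d` applied to the transposed tensor; it is a reference computation under training-mode statistics with no affine parameters, not the layer's source.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 48, 16)   # [B, N, C]
eps = 1e-5

# One statistic per channel, pooled over batch (dim 0) and points (dim 1).
mean = x.mean(dim=(0, 1), keepdim=True)
var = x.var(dim=(0, 1), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)

# BatchNorm1d expects [B, C, L], so move channels to dim 1 and back.
bn = nn.BatchNorm1d(16, eps=eps, affine=False)
bn.train()
reference = bn(x.transpose(1, 2)).transpose(1, 2)
```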
- Parameters:
num_features – Number of feature channels.
eps – Small constant for numerical stability. Default: 1e-5.
momentum – Momentum for running statistics. Default: 0.1.
affine – Whether to include learnable affine parameters. Default: True.
- Shape:
Input: [B, N, C]
Output: [B, N, C]
Example
>>> import torch
>>> from idx_flow import SpatialBatchNorm
>>> bn = SpatialBatchNorm(num_features=64)
>>> x = torch.randn(8, 12288, 64)
>>> y = bn(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])
SpatialLayerNorm
- class SpatialLayerNorm[source]
Bases: Module
Layer Normalization for spatial data on HEALPix grids.
Applies layer normalization across the feature dimension for each spatial point independently. Unlike BatchNorm, this normalizes across features rather than across the batch.
- Parameters:
num_features – Number of feature channels.
eps – Small constant for numerical stability. Default: 1e-6.
elementwise_affine – Whether to include learnable affine parameters.
- Shape:
Input: [B, N, C]
Output: [B, N, C]
Example
>>> import torch
>>> from idx_flow import SpatialLayerNorm
>>> ln = SpatialLayerNorm(num_features=64)
>>> x = torch.randn(8, 12288, 64)
>>> y = ln(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])
SpatialInstanceNorm
- class SpatialInstanceNorm[source]
Bases: Module
Instance Normalization for spatial data on HEALPix grids.
Applies instance normalization across the spatial dimension for each channel independently. Useful for style transfer and generative models.
- Parameters:
num_features – Number of feature channels.
eps – Small constant for numerical stability. Default: 1e-5.
momentum – Momentum for running statistics. Default: 0.1.
affine – Whether to include learnable affine parameters. Default: False.
- Shape:
Input: [B, N, C]
Output: [B, N, C]
Example
>>> import torch
>>> from idx_flow import SpatialInstanceNorm
>>> instnorm = SpatialInstanceNorm(num_features=64, affine=True)
>>> x = torch.randn(8, 12288, 64)
>>> y = instnorm(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])
SpatialGroupNorm
- class SpatialGroupNorm[source]
Bases: Module
Group Normalization for spatial data on HEALPix grids.
Divides channels into groups and normalizes within each group. Provides a middle ground between LayerNorm and InstanceNorm.
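The grouping can be made concrete by reshaping the channel axis into `[groups, channels_per_group]` and pooling statistics over the points and the channels inside each group. The sketch below is a reference computation checked against `torch.nn.functional.group_norm`; the sizes are arbitrary, no affine parameters are applied, and the contiguous channel grouping is the convention of the PyTorch function, assumed here to match the layer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, C, G = 4, 48, 16, 4
x = torch.randn(B, N, C)
eps = 1e-5

# Split channels into G contiguous groups of C // G channels each.
grouped = x.view(B, N, G, C // G)
# Per sample and per group: statistics over all points and the group's channels.
mean = grouped.mean(dim=(1, 3), keepdim=True)
var = grouped.var(dim=(1, 3), unbiased=False, keepdim=True)
manual = ((grouped - mean) / torch.sqrt(var + eps)).reshape(B, N, C)

# F.group_norm expects channels in dim 1: [B, C, N].
reference = F.group_norm(x.transpose(1, 2).contiguous(), G, eps=eps).transpose(1, 2)
```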
- Parameters:
num_groups – Number of groups to divide channels into.
num_channels – Number of feature channels (must be divisible by num_groups).
eps – Small constant for numerical stability. Default: 1e-5.
affine – Whether to include learnable affine parameters. Default: True.
- Shape:
Input: [B, N, C]
Output: [B, N, C]
Example
>>> import torch
>>> from idx_flow import SpatialGroupNorm
>>> gn = SpatialGroupNorm(num_groups=8, num_channels=64)
>>> x = torch.randn(8, 12288, 64)
>>> y = gn(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])
Regularization Layers
SpatialDropout
- class SpatialDropout[source]
Bases: Module
Spatial Dropout for HEALPix grid data.
Drops entire spatial locations (all channels for selected points) during training. This encourages the model to learn spatially robust features.
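Dropping whole locations corresponds to drawing one keep/drop decision per (sample, point) and broadcasting it over the channel axis. The function below is a hypothetical sketch of that idea; in particular, rescaling kept values by 1 / (1 - p), as standard inverted dropout does, is an assumption that this page does not confirm for the layer.

```python
import torch

def spatial_dropout(x, p=0.1, training=True):
    """Zero all channels of randomly chosen points in a [B, N, C] tensor."""
    if not training or p == 0.0:
        return x
    # One Bernoulli draw per (sample, point); the trailing 1 broadcasts over C.
    keep = (torch.rand(x.shape[0], x.shape[1], 1, device=x.device) >= p).to(x.dtype)
    # Assumption: inverted-dropout rescaling of the surviving points.
    return x * keep / (1.0 - p)

torch.manual_seed(0)
x = torch.ones(4, 100, 8)
y = spatial_dropout(x, p=0.5)
per_point = y.sum(dim=-1)  # 0 for dropped points, 8 / (1 - 0.5) = 16 for kept ones
```

A mask of shape `[B, 1, C]` instead of `[B, N, 1]` would give the channel-wise variant, where whole channels are dropped across all points.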
- Parameters:
p – Probability of dropping a spatial location. Default: 0.1.
- Shape:
Input: [B, N, C]
Output: [B, N, C]
Example
>>> import torch
>>> from idx_flow import SpatialDropout
>>> dropout = SpatialDropout(p=0.2)
>>> x = torch.randn(8, 12288, 64)
>>> y = dropout(x)  # During training, some spatial points are zeroed
>>> print(y.shape)  # torch.Size([8, 12288, 64])
- __init__(p=0.1)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
p (float)
- Return type:
None
ChannelDropout
- class ChannelDropout[source]
Bases: Module
Channel Dropout for HEALPix grid data.
Drops entire channels (all spatial points for selected channels) during training. This encourages the model to learn channel-robust features.
- Parameters:
p – Probability of dropping a channel. Default: 0.1.
- Shape:
Input: [B, N, C]
Output: [B, N, C]
Example
>>> import torch
>>> from idx_flow import ChannelDropout
>>> dropout = ChannelDropout(p=0.2)
>>> x = torch.randn(8, 12288, 64)
>>> y = dropout(x)  # During training, some channels are zeroed
>>> print(y.shape)  # torch.Size([8, 12288, 64])
- __init__(p=0.1)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
p (float)
- Return type:
None
Attention Layers
SpatialSelfAttention
- class SpatialSelfAttention[source]
Bases: Module
Self-Attention layer for spatial data on HEALPix grids.
Applies multi-head self-attention across the spatial dimension, allowing each spatial point to attend to all other points.
When attn_backend="auto" (default) and PyTorch >= 2.0, attention is computed via F.scaled_dot_product_attention, which automatically selects the fastest available kernel:
FlashAttention 2 on Ampere+ GPUs with float16/bfloat16
Memory-efficient attention via xFormers-style backend
Math fallback on CPU or unsupported dtypes
Set attn_backend="manual" to force the explicit matmul-softmax-matmul path (always available).
Note: Complexity is O(N^2) in the number of spatial points regardless of backend. FlashAttention 2 reduces the constant factor and memory usage but does not change asymptotic complexity.
- Parameters:
embed_dim – Total dimension of the model (must be divisible by num_heads).
num_heads – Number of attention heads.
dropout – Dropout probability on attention weights. Default: 0.0.
bias – Whether to include bias in projections. Default: True.
attn_backend – Attention computation backend. Default: “auto”.
"auto": Use SDPA when available (PyTorch >= 2.0), else manual.
"sdpa": Force SDPA (raises if unavailable).
"manual": Force explicit matmul-softmax-matmul.
- Shape:
Input: [B, N, embed_dim]
Output: [B, N, embed_dim]
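The explicit matmul-softmax-matmul path and the fused `scaled_dot_product_attention` call compute the same quantity, which can be checked on random per-head tensors. This sketch uses only PyTorch functions (PyTorch 2.0 or newer is needed for the fused call) and stands apart from the layer: the projections, head splitting, and dropout of SpatialSelfAttention are not reproduced, and the sizes are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 2, 4, 16, 8   # batch, heads, spatial points, head dimension
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Explicit path: the [B, H, N, N] score matrix is materialized.
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
manual = torch.softmax(scores, dim=-1) @ v

# Fused path: same result, kernel chosen by PyTorch.
fused = F.scaled_dot_product_attention(q, k, v)
```

The `[B, H, N, N]` score tensor in the explicit path is where the quadratic memory cost in the number of spatial points comes from.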
Example
>>> import torch
>>> from idx_flow import SpatialSelfAttention
>>> attn = SpatialSelfAttention(embed_dim=64, num_heads=8)
>>> x = torch.randn(4, 768, 64)
>>> y = attn(x)
>>> print(y.shape)  # torch.Size([4, 768, 64])
- __init__(embed_dim, num_heads, dropout=0.0, bias=True, attn_backend='auto')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Vision Transformer Layers
SpatialPatchEmbedding
- class SpatialPatchEmbedding[source]
Bases: Module
Spatial Patch Embedding layer for HEALPix grids.
This layer creates patch embeddings from local neighborhoods on the spherical grid using precomputed connection indices. It gathers k neighbor features for each output point, flattens the neighborhood into a single vector, and projects it to the embedding dimension via a linear layer.
This is the spatial analog of the patch embedding in Vision Transformers (ViT), adapted for non-Euclidean domains. Instead of extracting fixed-size 2D image patches, it uses precomputed topology (connection indices) to define local patches on the sphere.
- The operation:
Gather: Collect k neighbor features for each output point: [B, N_out, k, C_in]
Flatten: Reshape neighborhood to vector: [B, N_out, k * C_in]
Project: Linear projection to embedding dimension: [B, N_out, embed_dim]
- Mathematically:
E[b,p,:] = W_proj * [X[b, idx[p,0], :] || … || X[b, idx[p,k-1], :]] + b_proj
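The projection of a flattened neighborhood can be sketched with a plain `nn.Linear`. The example also shows that, once the projection weight is reshaped to `[k, C_in, embed_dim]`, the same result is obtained from the einsum contraction used in the SpatialConv formula. This is an illustration on toy sizes with a random index table, not the layer's implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, N_in, C_in, N_out, K, E = 2, 48, 4, 12, 9, 32
x = torch.randn(B, N_in, C_in)
idx = torch.randint(0, N_in, (N_out, K))   # stand-in for connection_indices
proj = nn.Linear(K * C_in, E)

gathered = x[:, idx, :]                    # [B, N_out, K, C_in]
patches = gathered.flatten(start_dim=2)    # [B, N_out, K * C_in], neighbor-major
embeddings = proj(patches)                 # [B, N_out, E]

# Equivalent contraction: view the [E, K * C_in] weight as [K, C_in, E].
w = proj.weight.view(E, K, C_in).permute(1, 2, 0)
via_einsum = torch.einsum("bpkc,kce->bpe", gathered, w) + proj.bias
```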
- Literature Context:
Adapted from the Vision Transformer (ViT) architecture:
Dosovitskiy, A., Beyer, L., Kolesnikov, A., et al. (2021). “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” In ICLR 2021. arXiv: 2010.11929
- Parameters:
output_points – Number of spatial points in the output (number of patches).
connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output patch.
embed_dim – Dimension of the patch embedding vectors.
bias – Whether to include a bias term in the projection. Default is True.
weight_init – Weight initialization method. Default is “xavier_uniform”.
weight_init_gain – Gain for xavier/orthogonal initialization.
- projection
Learnable linear projection from flattened patch to embed_dim.
- Shape:
Input: [B, N_in, C_in] where B is batch size, N_in is input points, C_in is input channels.
Output: [B, N_out, embed_dim] where N_out is output_points.
Example
>>> import torch
>>> from idx_flow import SpatialPatchEmbedding
>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=64, nside_out=32, k=9
... )
>>> patch_embed = SpatialPatchEmbedding(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     embed_dim=128,
... )
>>> x = torch.randn(8, 12 * 64**2, 16)
>>> embeddings = patch_embed(x)
>>> print(embeddings.shape)  # torch.Size([8, 12288, 128])
See also
SpatialConv: Linear convolution (einsum-based kernel)
SpatialMLP: MLP kernel for non-linear local processing
SpatialViT: Full Vision Transformer using this embedding
- __init__(output_points, connection_indices, embed_dim=128, bias=True, weight_init='xavier_uniform', weight_init_gain=1.0)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Return type:
None
SpatialTransformerBlock
- class SpatialTransformerBlock[source]
Bases: Module
Transformer Encoder Block for spatial data on HEALPix grids.
Implements a standard pre-norm Transformer encoder block with multi-head self-attention and a position-wise feed-forward network (FFN). Uses residual connections and layer normalization following the pre-norm convention (LN before attention/FFN rather than after).
- The operation:
Pre-norm: LayerNorm -> Multi-Head Self-Attention -> Residual
Pre-norm: LayerNorm -> Feed-Forward Network -> Residual
- Mathematically:
h = x + MHSA(LN(x))
y = h + FFN(LN(h))
- where:
MHSA: Multi-Head Self-Attention with scaled dot-product
FFN: Two-layer MLP with activation (Linear -> Act -> Dropout -> Linear -> Dropout)
LN: Layer Normalization per spatial point
Note: The self-attention has O(N^2) complexity in the number of spatial points. For large grids, consider using this after spatial downsampling.
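The two pre-norm residual equations can be sketched with standard PyTorch modules. The class below is a hypothetical stand-in named `PreNormBlock`: it uses `nn.MultiheadAttention` in place of the library's own attention layer and fixes the activation to GELU, so it illustrates the data flow rather than reproducing SpatialTransformerBlock.

```python
import torch
from torch import nn

class PreNormBlock(nn.Module):
    """h = x + MHSA(LN(x)); y = h + FFN(LN(h))."""

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0, norm_eps=1e-6):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, eps=norm_eps)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim, eps=norm_eps)
        hidden = int(embed_dim * mlp_ratio)
        # Linear -> Act -> Dropout -> Linear -> Dropout
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        n = self.norm1(x)
        h = x + self.attn(n, n, n, need_weights=False)[0]  # attention residual
        return h + self.ffn(self.norm2(h))                 # feed-forward residual

torch.manual_seed(0)
block = PreNormBlock(embed_dim=64, num_heads=4).eval()
x = torch.randn(2, 48, 64)   # [B, N, embed_dim]
y = block(x)
print(y.shape)  # torch.Size([2, 48, 64])
```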
- Parameters:
embed_dim – Embedding dimension (must be divisible by num_heads).
num_heads – Number of attention heads.
mlp_ratio – Ratio of FFN hidden dimension to embed_dim. Default is 4.0.
dropout – Dropout probability in attention and FFN. Default is 0.0.
activation – Activation function for the FFN. Default is “gelu”.
bias – Whether to include bias in linear projections. Default is True.
norm_eps – Epsilon for layer normalization. Default is 1e-6.
attn_backend – Attention backend. "auto" uses SDPA/FlashAttention 2 when available, "manual" forces explicit matmul path. Default is "auto".
- Shape:
Input: [B, N, embed_dim]
Output: [B, N, embed_dim]
Example
>>> import torch
>>> from idx_flow import SpatialTransformerBlock
>>> block = SpatialTransformerBlock(
...     embed_dim=128,
...     num_heads=8,
...     mlp_ratio=4.0,
...     dropout=0.1,
... )
>>> x = torch.randn(8, 768, 128)
>>> y = block(x)
>>> print(y.shape)  # torch.Size([8, 768, 128])
See also
SpatialSelfAttention: Standalone multi-head self-attention
SpatialViT: Full Vision Transformer using this block
- __init__(embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0, activation='gelu', bias=True, norm_eps=1e-06, attn_backend='auto')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
SpatialViT
- class SpatialViT[source]
Bases: Module
Vision Transformer (ViT) for spatial data on HEALPix grids.
This layer implements a complete Vision Transformer adapted for spherical data discretized on HEALPix grids. It combines index-based patch embedding with Transformer encoder blocks, bridging the structure-compilation philosophy of this library with the global attention mechanism of ViT.
- The architecture:
Patch Embedding: Uses precomputed connection indices to gather local neighborhoods and project them to embedding vectors.
Positional Encoding: Learnable positional embeddings added to each spatial point to encode location on the sphere.
Transformer Encoder: Stack of N Transformer blocks, each with multi-head self-attention and feed-forward network.
Output Projection: Optional linear projection to desired output dimension.
- The operation:
Embed: E = PatchEmbed(X) + PosEmbed [B, N_out, embed_dim]
Encode: Z = TransformerBlock_N(…TransformerBlock_1(E))
Project: Y = Linear(LN(Z)) [B, N_out, output_dim]
- Complexity:
Patch embedding: O(N_out * k * C_in) – linear in output points
Self-attention per block: O(N_out^2 * embed_dim) – quadratic in points
Total: O(depth * N_out^2 * embed_dim)
For large grids, use spatial downsampling (via connection indices from higher to lower resolution) to reduce N_out before the transformer.
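To see why downsampling first matters, it helps to count the entries of the N_out x N_out attention score matrix for a few HEALPix resolutions, using N = 12 * nside^2. The helper below is illustrative arithmetic only (its name and defaults are made up); it gives the size of an explicitly materialized score matrix, which fused attention kernels avoid storing in full.

```python
def attention_scores_bytes(nside, num_heads=1, batch=1, bytes_per_element=4):
    """Bytes needed for explicit [batch, heads, N, N] attention scores."""
    n_points = 12 * nside ** 2   # number of HEALPix pixels at this nside
    return batch * num_heads * n_points * n_points * bytes_per_element

print(12 * 8 ** 2)                 # 768 points at nside=8
print(attention_scores_bytes(8))   # 2359296 bytes (about 2.4 MB) per head
print(12 * 32 ** 2)                # 12288 points at nside=32
print(attention_scores_bytes(32))  # 603979776 bytes (about 604 MB) per head
```

Going from nside=8 to nside=32 multiplies the number of points by 16 and the score matrix by 256, in float32 and per head and batch element.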
- Literature Context:
Adapted from:
Dosovitskiy, A., Beyer, L., Kolesnikov, A., et al. (2021). “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.” In ICLR 2021. arXiv: 2010.11929
- Parameters:
output_points – Number of spatial points (patches) in the output.
connection_indices – Integer array of shape [output_points, kernel_size] containing indices of input pixels for each output patch.
embed_dim – Dimension of the patch embeddings. Default is 128.
depth – Number of Transformer encoder blocks. Default is 4.
num_heads – Number of attention heads per block. Default is 8.
mlp_ratio – Ratio of FFN hidden dim to embed_dim. Default is 4.0.
output_dim – Output feature dimension. If None, equals embed_dim.
dropout – Dropout probability for attention and FFN. Default is 0.0.
activation – Activation function for FFN layers. Default is “gelu”.
bias – Whether to include bias in linear projections. Default is True.
weight_init – Weight initialization method for patch embedding. Default is “xavier_uniform”.
norm_eps – Epsilon for layer normalization. Default is 1e-6.
attn_backend – Attention backend for all transformer blocks. "auto" uses SDPA/FlashAttention 2 when available, "manual" forces explicit matmul path. Default is "auto".
- Shape:
Input: [B, N_in, C_in] where B is batch size, N_in is input points, C_in is input channels.
Output: [B, N_out, output_dim] where N_out is output_points.
Example
>>> import torch
>>> from idx_flow import SpatialViT
>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=16, nside_out=8, k=9
... )
>>> vit = SpatialViT(
...     output_points=12 * 8**2,
...     connection_indices=indices,
...     embed_dim=64,
...     depth=4,
...     num_heads=4,
...     output_dim=32,
...     dropout=0.1,
... )
>>> x = torch.randn(4, 12 * 16**2, 8)
>>> y = vit(x)
>>> print(y.shape)  # torch.Size([4, 768, 32])
See also
SpatialPatchEmbedding: Patch embedding using connection indices
SpatialTransformerBlock: Single transformer encoder block
SpatialSelfAttention: Standalone multi-head self-attention
SpatialConv: Linear convolution with O(N) complexity
SpatialMLP: MLP kernel for local non-linear processing
- __init__(output_points, connection_indices, embed_dim=128, depth=4, num_heads=8, mlp_ratio=4.0, output_dim=None, dropout=0.0, activation='gelu', bias=True, weight_init='xavier_uniform', norm_eps=1e-06, attn_backend='auto')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
output_points (int)
embed_dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
output_dim (int | None)
dropout (float)
activation (Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'])
bias (bool)
weight_init (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros'])
norm_eps (float)
attn_backend (Literal['auto', 'sdpa', 'manual'])
- Return type:
None
Pooling and Utility Layers
SpatialPooling
- class SpatialPooling[source]
Bases: Module
Spatial Pooling layer for HEALPix grids.
Performs pooling operations (mean, max, or sum) over local neighborhoods on the spherical grid. This is a non-learnable layer useful for downsampling with simple aggregation.
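Pooling over a neighborhood is the same gather used by the convolution layers followed by a reduction over the neighbor axis. The helper `pool_neighbors` below is a hypothetical sketch of that reduction on a tiny hand-written index table, not the layer's code.

```python
import torch

def pool_neighbors(x, idx, pool_type="mean"):
    """Reduce gathered neighbors: [B, N_in, C] -> [B, N_out, C]."""
    gathered = x[:, idx, :]   # [B, N_out, K, C]
    if pool_type == "mean":
        return gathered.mean(dim=2)
    if pool_type == "max":
        return gathered.amax(dim=2)
    if pool_type == "sum":
        return gathered.sum(dim=2)
    raise ValueError(f"unknown pool_type: {pool_type}")

x = torch.arange(4.0).view(1, 4, 1)    # one sample, 4 points, 1 channel: 0, 1, 2, 3
idx = torch.tensor([[0, 1], [2, 3]])   # two output points, two neighbors each

pooled_mean = pool_neighbors(x, idx, "mean")   # values 0.5 and 2.5
pooled_max = pool_neighbors(x, idx, "max")     # values 1.0 and 3.0
pooled_sum = pool_neighbors(x, idx, "sum")     # values 1.0 and 5.0
```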
- Parameters:
output_points – Number of spatial points in the output.
connection_indices – Integer array of shape [output_points, kernel_size].
pool_type – Type of pooling operation. One of “mean”, “max”, “sum”.
- Shape:
Input: [B, N_in, C_in]
Output: [B, N_out, C_in] (channels preserved)
Example
>>> import torch
>>> from idx_flow import SpatialPooling
>>> from idx_flow.utils import compute_connection_indices
>>> indices, _ = compute_connection_indices(
...     nside_in=64, nside_out=32, k=4
... )
>>> pool = SpatialPooling(
...     output_points=12 * 32**2,
...     connection_indices=indices,
...     pool_type="mean"
... )
>>> x = torch.randn(8, 12 * 64**2, 32)
>>> y = pool(x)
>>> print(y.shape)  # torch.Size([8, 12288, 32])
- __init__(output_points, connection_indices, pool_type='mean')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Squeeze
- class Squeeze[source]
Bases: Module
Squeeze layer that reduces spatial dimension to a single vector.
Performs global aggregation over all spatial points using mean, max, or sum pooling.
- Parameters:
reduction – Reduction method. One of “mean”, “max”, “sum”.
- Shape:
Input: [B, N, C]
Output: [B, C]
Example
>>> import torch
>>> from idx_flow import Squeeze
>>> squeeze = Squeeze(reduction="mean")
>>> x = torch.randn(8, 12288, 64)
>>> y = squeeze(x)
>>> print(y.shape)  # torch.Size([8, 64])
- __init__(reduction='mean')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
reduction (Literal['mean', 'max', 'sum'])
- Return type:
None
Unsqueeze
- class Unsqueeze[source]
Bases: Module
Unsqueeze layer that broadcasts a vector to all spatial points.
Takes a feature vector and replicates it across the spatial dimension.
- Parameters:
num_points – Number of spatial points to broadcast to.
- Shape:
Input: [B, C]
Output: [B, num_points, C]
Example
>>> import torch
>>> from idx_flow import Unsqueeze
>>> unsqueeze = Unsqueeze(num_points=12288)
>>> x = torch.randn(8, 64)
>>> y = unsqueeze(x)
>>> print(y.shape)  # torch.Size([8, 12288, 64])
- __init__(num_points)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
num_points (int)
- Return type:
None
Functional Utilities
get_initializer
- get_initializer(method, gain=1.0, nonlinearity='leaky_relu', mean=0.0, std=0.02, a=0.0, b=1.0)[source]
Get weight initialization function.
- Parameters:
method (Literal['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'orthogonal', 'normal', 'uniform', 'zeros']) – Initialization method name.
gain (float) – Gain factor for xavier/orthogonal initialization.
nonlinearity (str) – Nonlinearity for kaiming initialization.
mean (float) – Mean for normal initialization.
std (float) – Standard deviation for normal initialization.
a (float) – Lower bound for uniform initialization.
b (float) – Upper bound for uniform initialization.
- Returns:
Initialization function that takes a tensor and initializes it in-place.
- Raises:
ValueError – If method is not recognized.
get_activation
- get_activation(name)[source]
Get activation module by name.
- Parameters:
name (Literal['relu', 'selu', 'leaky_relu', 'gelu', 'elu', 'tanh', 'sigmoid', 'swish', 'mish', 'linear'] | None) – Activation function name. If None, returns Identity.
- Returns:
PyTorch activation module.
- Raises:
ValueError – If activation name is not recognized.
Type Aliases
- idx_flow.InitMethod
Weight initialization methods:
"xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal", "orthogonal", "normal", "uniform", "zeros"
- idx_flow.ActivationType
Activation functions:
"relu", "selu", "leaky_relu", "gelu", "elu", "tanh", "sigmoid", "swish", "mish", "linear"
- idx_flow.InterpolationMethod
Interpolation methods:
"linear", "idw", "gaussian"
- idx_flow.PoolingMethod
Pooling methods:
"mean", "max", "sum"
- idx_flow.AttnBackend
Attention computation backends:
"auto" (SDPA when available, else manual), "sdpa" (force scaled_dot_product_attention, requires PyTorch >= 2.0), "manual" (explicit matmul-softmax-matmul)