napistu_torch.models.message_passing_encoder

Graph Neural Network models for Napistu-Torch.

This module provides a unified Graph Neural Network encoder supporting multiple architectures (GCN, GAT, SAGE, GraphConv) with consistent behavior and configuration.

Classes

MessagePassingEncoder

Unified Graph Neural Network encoder supporting multiple architectures.


class napistu_torch.models.message_passing_encoder.MessagePassingEncoder(*args: Any, **kwargs: Any)

Bases: Module

Unified Graph Neural Network encoder supporting multiple architectures.

This class eliminates boilerplate by providing a single interface for the SAGE, GCN, GAT, and GraphConv architectures, with consistent behavior and configuration.

Edge Weight Support

  • GCN: ✅ Supports edge_weight parameter

  • GraphConv: ✅ Supports edge_weight parameter (SAGE-like with edge weights)

  • SAGE: ❌ Does not support edge_weight (gracefully ignored)

  • GAT: ❌ Uses learned attention (edge_weight not needed)
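
The support matrix above amounts to a per-encoder capability check: an edge weight is handed to the convolution only when the layer accepts one (PyTorch Geometric's SAGEConv, for example, has no edge_weight argument). The stdlib-only sketch below illustrates that check under stated assumptions; SUPPORTS_EDGE_WEIGHT, its string keys (in particular 'graph_conv'), and conv_call_kwargs are hypothetical names, not part of the napistu_torch API.

```python
from typing import Any, Dict, Optional

# Hypothetical capability table mirroring the Edge Weight Support list.
# The "graph_conv" key is an ASSUMED name for the GraphConv encoder type.
SUPPORTS_EDGE_WEIGHT: Dict[str, bool] = {
    "gcn": True,
    "graph_conv": True,
    "sage": False,  # edge weights are silently ignored
    "gat": False,   # attention coefficients are learned instead
}


def conv_call_kwargs(encoder_type: str, edge_weight: Optional[Any]) -> Dict[str, Any]:
    """Return the extra keyword arguments a conv layer call would receive."""
    if encoder_type not in SUPPORTS_EDGE_WEIGHT:
        raise ValueError(f"Unknown encoder_type: {encoder_type!r}")
    if edge_weight is None or not SUPPORTS_EDGE_WEIGHT[encoder_type]:
        # Encoders without edge-weight support never see the weights.
        return {}
    return {"edge_weight": edge_weight}
```

Filtering the keyword arguments up front keeps a single forward loop for every architecture instead of branching on encoder type at each layer.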

Parameters:
  • in_channels (int) – Number of input node features

  • hidden_channels (int) – Number of hidden channels in each layer

  • num_layers (int) – Number of GNN layers

  • dropout (float, optional) – Dropout probability, by default 0.0

  • encoder_type (str, optional) – Type of encoder (‘sage’, ‘gcn’, ‘gat’), by default ‘sage’

  • weight_edges_by (torch.Tensor or nn.Module, optional) – Edge weighting specification, by default None: None for no edge weighting, a torch.Tensor for static edge weights, or an nn.Module for a learnable edge encoder. Only used by encoder types that support edge weights (see Edge Weight Support above).

  • sage_aggregator (str, optional) – Aggregation method for SAGE (‘mean’, ‘max’, ‘lstm’), by default ‘mean’

  • gat_heads (int, optional) – Number of attention heads for GAT, by default 1

  • gat_concat (bool, optional) – Whether to concatenate attention heads in GAT, by default True

  • graph_conv_aggregator (str, optional) – Aggregation method for GraphConv, by default ‘mean’
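
The gat_heads and gat_concat parameters affect layer width. In PyTorch Geometric's GATConv, concatenating heads yields heads * out_channels output features, while concat=False averages the heads and yields out_channels. The sketch below only illustrates that arithmetic; gat_output_dim and gat_layer_out_channels are hypothetical helpers, and whether MessagePassingEncoder divides hidden_channels by gat_heads to keep each layer at hidden_channels is an assumption this page does not confirm.

```python
def gat_output_dim(out_channels: int, heads: int, concat: bool) -> int:
    """Output width of a GATConv layer: heads are concatenated or averaged."""
    return out_channels * heads if concat else out_channels


def gat_layer_out_channels(hidden_channels: int, heads: int, concat: bool) -> int:
    """Per-head width that keeps the layer output at hidden_channels.

    ASSUMED strategy for illustration only; not confirmed for MessagePassingEncoder.
    """
    if not concat:
        return hidden_channels  # averaging heads does not change the width
    if hidden_channels % heads != 0:
        raise ValueError("hidden_channels must be divisible by gat_heads when gat_concat=True")
    return hidden_channels // heads
```

With hidden_channels=256 and gat_heads=4, each head would produce 64 features, and concatenation restores a 256-wide output.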

Public Methods:
  • config -> Dict[str, Any] (property) – Get the configuration dictionary for this encoder.

  • encode(x: torch.Tensor, edge_index: torch.Tensor, edge_data: Optional[torch.Tensor] = None) -> torch.Tensor – Alias for the forward method, for consistency with other models.

  • from_config(config: ModelConfig, in_channels: int, edge_in_channels: Optional[int] = None) -> “MessagePassingEncoder” – Create a MessagePassingEncoder from a ModelConfig instance.

  • forward(x: torch.Tensor, edge_index: torch.Tensor, edge_data: Optional[torch.Tensor] = None) -> torch.Tensor – Forward pass through the encoder.

  • get_summary(self) -> Dict[str, Any] – Get encoder metadata summary for checkpointing.

Private Methods:
  • _parse_edge_weighting(weight_edges_by: Optional[Union[torch.Tensor, nn.Module]], supports_edge_weight: bool, encoder_type: str) -> tuple[str, Optional[Union[torch.Tensor, nn.Module]]] – Parse the weight_edges_by parameter into a type indicator and a value (a module-level function, documented below).

Notes

Edge weights are only used for message passing when weight_edges_by is set and the encoder type supports them (GCN and GraphConv). With SAGE they are ignored, and GAT relies on learned attention instead. For other forms of weighted message passing, you would need to implement custom message passing with edge attributes.

Edge weights and attributes in your NapistuData remain available for supervision and evaluation even when they are not used during encoding.

Examples

>>> # Direct instantiation
>>> encoder = MessagePassingEncoder(128, 256, 3, encoder_type='sage', sage_aggregator='mean')
>>>
>>> # From config
>>> config = ModelConfig(encoder='sage', hidden_channels=256, num_layers=3)
>>> encoder = MessagePassingEncoder.from_config(config, in_channels=128)

classmethod from_config(config: ModelConfig, in_channels: int, edge_in_channels: int | None = None) -> MessagePassingEncoder

Create MessagePassingEncoder from ModelConfig.

Parameters:
  • config (ModelConfig) – Model configuration containing encoder, hidden_channels, etc.

  • in_channels (int) – Number of input node features (not in config as it depends on data)

  • edge_in_channels (int, optional) – Number of input edge features. Required if use_edge_encoder=True.

Returns:

Configured encoder instance

Return type:

MessagePassingEncoder

Examples

>>> config = ModelConfig(encoder='sage', hidden_channels=256, num_layers=3)
>>> encoder = MessagePassingEncoder.from_config(config, in_channels=128)
>>>
>>> # With edge encoder
>>> config = ModelConfig(encoder='gcn', use_edge_encoder=True, edge_encoder_dim=32)
>>> encoder = MessagePassingEncoder.from_config(config, in_channels=128, edge_in_channels=10)
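
The division of labor in from_config can be sketched as mapping config fields onto __init__ keyword arguments, with in_channels supplied from the data. This is a minimal sketch under stated assumptions: ModelConfigStub is a hypothetical stand-in carrying only the fields shown in the examples above, encoder_kwargs_from_config is a hypothetical helper, and the step in which edge_in_channels and edge_encoder_dim are presumably used to build a learnable edge encoder is not modeled here.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ModelConfigStub:
    """Hypothetical stand-in for ModelConfig; only fields shown on this page."""
    encoder: str
    hidden_channels: int
    num_layers: int
    use_edge_encoder: bool = False
    edge_encoder_dim: Optional[int] = None


def encoder_kwargs_from_config(
    config: ModelConfigStub, in_channels: int, edge_in_channels: Optional[int] = None
) -> Dict[str, Any]:
    """Map config fields onto MessagePassingEncoder.__init__ keyword arguments."""
    if config.use_edge_encoder and edge_in_channels is None:
        # Mirrors the documented requirement for edge_in_channels.
        raise ValueError("edge_in_channels is required when use_edge_encoder=True")
    return {
        "in_channels": in_channels,  # depends on the data, so not stored in the config
        "hidden_channels": config.hidden_channels,
        "num_layers": config.num_layers,
        "encoder_type": config.encoder,
    }
```

Keeping in_channels out of the config lets the same configuration be reused across datasets with different node feature dimensions.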

__init__(in_channels: int, hidden_channels: int, num_layers: int, dropout: float = 0.0, encoder_type: str = 'sage', weight_edges_by: torch.Tensor | torch.nn.Module | None = None, gat_heads: int = 1, gat_concat: bool = True, graph_conv_aggregator: str = 'mean', sage_aggregator: str = 'mean')

encode(x: torch.Tensor, edge_index: torch.Tensor, edge_data: torch.Tensor | None = None) -> torch.Tensor

Alias for the forward method, for consistency with other models.

Parameters:
  • x (torch.Tensor) – Node feature matrix [num_nodes, in_channels]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • edge_data (Optional[torch.Tensor]) – Edge data [num_edges, edge_dim] for edge weighting. Can be edge attributes (for learnable encoder) or edge weights (for static weighting).

Returns:

Node embeddings [num_nodes, hidden_channels]

Return type:

torch.Tensor

forward(x: torch.Tensor, edge_index: torch.Tensor, edge_data: torch.Tensor | None = None) -> torch.Tensor

Forward pass through the GNN encoder.

Parameters:
  • x (torch.Tensor) – Node feature matrix [num_nodes, in_channels]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • edge_data (torch.Tensor, optional) – Edge data [num_edges, edge_dim] for edge weighting. Can be edge attributes (for learnable encoder) or edge weights (for static weighting).

Returns:

Node embeddings [num_nodes, hidden_channels]

Return type:

torch.Tensor

Notes

Edge weighting is handled based on the edge_weighting_type attribute:

  • EDGE_WEIGHTING_TYPE.NONE: No edge weighting (uniform message passing)

  • EDGE_WEIGHTING_TYPE.STATIC_WEIGHTS: Static edge weights (edge_data contains pre-computed weights)

  • EDGE_WEIGHTING_TYPE.LEARNED_ENCODER: Learnable edge encoder (edge_data contains edge attributes)
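
The three cases above can be read as a small dispatch that turns edge_data into the edge weight passed to each layer. This is a stdlib-only sketch under stated assumptions: the EDGE_WEIGHTING_TYPE class here is a stand-in whose string values are assumed, resolve_edge_weight is a hypothetical helper, the edge encoder is modeled as any callable, and raising when edge_data is missing is an assumed policy rather than documented behavior.

```python
from typing import Any, Callable, Optional


class EDGE_WEIGHTING_TYPE:
    """Stand-in named after the real constants; the string values are ASSUMED."""
    NONE = "none"
    STATIC_WEIGHTS = "static_weights"
    LEARNED_ENCODER = "learned_encoder"


def resolve_edge_weight(
    edge_weighting_type: str,
    edge_data: Optional[Any],
    edge_encoder: Optional[Callable[[Any], Any]] = None,
) -> Optional[Any]:
    """Turn forward()'s edge_data into the edge weight handed to each conv layer."""
    if edge_weighting_type == EDGE_WEIGHTING_TYPE.NONE:
        return None  # uniform message passing; edge_data is ignored
    if edge_data is None:
        # ASSUMED policy: weighting was configured but no edge data was supplied.
        raise ValueError("edge_data is required when edge weighting is enabled")
    if edge_weighting_type == EDGE_WEIGHTING_TYPE.STATIC_WEIGHTS:
        return edge_data  # pre-computed weights are used as-is
    if edge_weighting_type == EDGE_WEIGHTING_TYPE.LEARNED_ENCODER:
        if edge_encoder is None:
            raise ValueError("an edge encoder is required for LEARNED_ENCODER")
        return edge_encoder(edge_data)  # edge attributes -> learned weights
    raise ValueError(f"Unknown edge weighting type: {edge_weighting_type!r}")
```

In the learned case the weights are recomputed on every forward pass, so gradients can flow back into the edge encoder.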

get_summary() -> Dict[str, Any]

Get encoder metadata summary for checkpointing.

Returns essential metadata needed to reconstruct the encoder.

property config: Dict[str, Any]

Get the configuration dictionary for this encoder.

Returns a dict containing all initialization parameters needed to reconstruct this encoder instance.

Returns:

Configuration dictionary with all __init__ parameters

Return type:

Dict[str, Any]
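
Because config records the __init__ arguments, an equivalent encoder can be rebuilt with cls(**config). The sketch below shows that round-trip pattern on a hypothetical stand-in class (ConfigurableEncoderStub) that covers only a subset of the __init__ arguments; it is not the real MessagePassingEncoder. Note that a config captures hyperparameters only: learned parameters are normally saved separately, for example via PyTorch's state_dict().

```python
import json
from typing import Any, Dict


class ConfigurableEncoderStub:
    """Hypothetical stand-in illustrating the config round-trip pattern."""

    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
                 dropout: float = 0.0, encoder_type: str = "sage") -> None:
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.encoder_type = encoder_type

    @property
    def config(self) -> Dict[str, Any]:
        # Record every __init__ argument so cls(**config) rebuilds an equivalent instance.
        return {
            "in_channels": self.in_channels,
            "hidden_channels": self.hidden_channels,
            "num_layers": self.num_layers,
            "dropout": self.dropout,
            "encoder_type": self.encoder_type,
        }


original = ConfigurableEncoderStub(128, 256, 3, dropout=0.1, encoder_type="gcn")
# Serialize to plain JSON, as might be stored alongside a checkpoint.
saved = json.dumps(original.config)
rebuilt = ConfigurableEncoderStub(**json.loads(saved))
```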

napistu_torch.models.message_passing_encoder._parse_edge_weighting(weight_edges_by: torch.Tensor | torch.nn.Module | None, supports_edge_weight: bool, encoder_type: str) -> tuple[str, torch.Tensor | torch.nn.Module | None]

Parse weight_edges_by parameter into type indicator and value.

This utility function handles the several forms that the edge weighting option can take by explicitly separating the type indicator from the value.

Parameters:
  • weight_edges_by (Optional[Union[torch.Tensor, nn.Module]]) – Edge weighting specification:
      - None: No edge weighting
      - torch.Tensor: Static edge weights
      - nn.Module: Learnable edge encoder

  • supports_edge_weight (bool) – Whether the encoder type supports edge weighting

  • encoder_type (str) – Name of encoder type (for logging)

Returns:

Tuple of (edge_weighting_type, edge_weighting_value), where:

  • edge_weighting_type: String constant from EDGE_WEIGHTING_TYPE indicating the type of weighting

  • edge_weighting_value: The actual value (None, Tensor, or Module)

Return type:

tuple[str, Optional[Union[torch.Tensor, nn.Module]]]

Raises:

ValueError – If weight_edges_by is not None, Tensor, or Module
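
The parsing contract above can be sketched as a type dispatch. This is a minimal, stdlib-only sketch under stated assumptions: TensorStub and ModuleStub stand in for torch.Tensor and torch.nn.Module, the returned string values are assumed, parse_edge_weighting is a hypothetical name, and downgrading to no weighting (with a logged warning) when the encoder lacks edge-weight support is an assumption based on the "gracefully ignored" behavior described for SAGE, not confirmed behavior of the real helper.

```python
import logging
from typing import Any, Optional, Tuple

logger = logging.getLogger(__name__)


class TensorStub:
    """Stand-in for torch.Tensor in this sketch."""


class ModuleStub:
    """Stand-in for torch.nn.Module in this sketch."""


# ASSUMED values for the EDGE_WEIGHTING_TYPE constants.
NONE, STATIC_WEIGHTS, LEARNED_ENCODER = "none", "static_weights", "learned_encoder"


def parse_edge_weighting(
    weight_edges_by: Optional[Any], supports_edge_weight: bool, encoder_type: str
) -> Tuple[str, Optional[Any]]:
    """Split weight_edges_by into (edge_weighting_type, edge_weighting_value)."""
    if weight_edges_by is None:
        return NONE, None
    if isinstance(weight_edges_by, ModuleStub):
        kind = LEARNED_ENCODER
    elif isinstance(weight_edges_by, TensorStub):
        kind = STATIC_WEIGHTS
    else:
        # Documented behavior: anything other than None, Tensor, or Module is rejected.
        raise ValueError(f"weight_edges_by must be None, a Tensor, or a Module, got {type(weight_edges_by).__name__}")
    if not supports_edge_weight:
        # ASSUMED: unsupported encoders fall back to no weighting instead of failing.
        logger.warning("%s does not support edge weights; ignoring weight_edges_by", encoder_type)
        return NONE, None
    return kind, weight_edges_by
```

Validating the type before checking encoder support means a malformed argument is reported even for encoders that would ignore it.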