napistu_torch.models.message_passing_encoder
Graph Neural Network models for Napistu-Torch.
This module provides a unified Graph Neural Network encoder supporting multiple architectures (GCN, GAT, SAGE, GraphConv) with consistent behavior and configuration.
Classes
- MessagePassingEncoder
Unified Graph Neural Network encoder supporting multiple architectures.
- class napistu_torch.models.message_passing_encoder.MessagePassingEncoder(*args: Any, **kwargs: Any)
Bases: Module
Unified Graph Neural Network encoder supporting multiple architectures.
This class removes per-architecture boilerplate by providing a single interface for SAGE, GCN, GAT, and GraphConv models with consistent behavior and configuration.
Edge Weight Support
GCN: ✅ Supports edge_weight parameter
GraphConv: ✅ Supports edge_weight parameter (SAGE-like with edge weights)
SAGE: ❌ Does not support edge_weight (gracefully ignored)
GAT: ❌ Uses learned attention (edge_weight not needed)
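The table above implies a simple routing rule: forward edge weights to a layer only when that layer type accepts them, and silently drop them otherwise. A minimal sketch, assuming lowercase encoder-type strings; the "graph_conv" key and the SUPPORTS_EDGE_WEIGHT / edge_weight_for_layer names are hypothetical illustrations, not part of the library's API.

```python
from typing import Optional, Sequence

# Hypothetical lookup mirroring the edge-weight support table above.
SUPPORTS_EDGE_WEIGHT = {
    "gcn": True,
    "graph_conv": True,  # assumed key for GraphConv
    "sage": False,
    "gat": False,
}


def edge_weight_for_layer(
    encoder_type: str, edge_weight: Optional[Sequence[float]]
) -> Optional[Sequence[float]]:
    """Return the weights to pass to the layer, or None if they are ignored."""
    if encoder_type not in SUPPORTS_EDGE_WEIGHT:
        raise ValueError(f"Unknown encoder_type: {encoder_type!r}")
    if edge_weight is None or not SUPPORTS_EDGE_WEIGHT[encoder_type]:
        # SAGE and GAT: weights are dropped instead of raising an error
        return None
    return edge_weight
```

For example, edge_weight_for_layer("gcn", [0.5, 1.0]) returns the weights unchanged, while the same call with "sage" returns None.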
- Parameters:
in_channels (int) – Number of input node features
hidden_channels (int) – Number of hidden channels in each layer
num_layers (int) – Number of GNN layers
dropout (float, optional) – Dropout probability, by default 0.0
encoder_type (str, optional) – Type of encoder ('sage', 'gcn', 'gat'), by default 'sage'
weight_edges_by (torch.Tensor | torch.nn.Module | None, optional) – Edge weighting specification: None for no weighting, a Tensor of static edge weights, or an nn.Module used as a learnable edge encoder; by default None. Ignored by encoder types that do not support edge weights.
sage_aggregator (str, optional) – Aggregation method for SAGE ('mean', 'max', 'lstm'), by default 'mean'
gat_heads (int, optional) – Number of attention heads for GAT, by default 1
gat_concat (bool, optional) – Whether to concatenate attention heads in GAT, by default True
graph_conv_aggregator (str, optional) – Aggregation method for GraphConv, by default 'mean'
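For intuition about sage_aggregator, the sketch below contrasts element-wise 'mean' and 'max' pooling of a node's incoming neighbor features. This is the neighborhood-summary step GraphSAGE performs before combining the result with the node's own features through learned linear transforms. The pool_neighbors helper is hypothetical plain Python; it does not cover the 'lstm' aggregator or the learned weights.

```python
from typing import List


def pool_neighbors(
    x: List[List[float]], edge_index: List[List[int]], node: int, aggr: str = "mean"
) -> List[float]:
    """Pool features of all source nodes that have an edge into `node`."""
    # edge_index[0] holds source nodes, edge_index[1] holds target nodes
    neighbors = [x[src] for src, dst in zip(edge_index[0], edge_index[1]) if dst == node]
    if not neighbors:
        return [0.0] * len(x[0])  # isolated node: empty neighborhood pools to zeros
    columns = list(zip(*neighbors))  # one tuple per feature dimension
    if aggr == "mean":
        return [sum(c) / len(c) for c in columns]
    if aggr == "max":
        return [max(c) for c in columns]
    raise ValueError(f"Unsupported aggregator: {aggr!r}")


x = [[1.0, 5.0], [3.0, 1.0], [0.0, 0.0]]
edge_index = [[0, 1], [2, 2]]  # edges 0 -> 2 and 1 -> 2
```

Here pool_neighbors(x, edge_index, 2, "mean") gives [2.0, 3.0], while "max" gives [3.0, 5.0]: max keeps the strongest signal per feature, mean smooths across neighbors.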
Public Methods
- config(self) -> Dict[str, Any]: Get the configuration dictionary for this encoder.
- encode(x: torch.Tensor, edge_index: torch.Tensor, edge_data: Optional[torch.Tensor] = None) -> torch.Tensor: Alias for forward method for consistency with other models.
- from_config(config: ModelConfig, in_channels: int, edge_in_channels: Optional[int] = None) -> "MessagePassingEncoder": Create a MessagePassingEncoder from a ModelConfig instance.
- forward(x: torch.Tensor, edge_index: torch.Tensor, edge_data: Optional[torch.Tensor] = None) -> torch.Tensor: Forward pass through the encoder.
- get_summary(self) -> Dict[str, Any]: Get encoder metadata summary for checkpointing.
Private Methods
- _parse_edge_weighting(weight_edges_by: Optional[Union[torch.Tensor, nn.Module]], supports_edge_weight: bool, encoder_type: str) -> tuple[str, Optional[Union[torch.Tensor, nn.Module]]]: Parse weight_edges_by parameter into type indicator and value.
Notes
Edge weights are used during message passing only when weight_edges_by is set and the selected encoder supports them (GCN and GraphConv, per the edge weight support table above). SAGE ignores them, and GAT relies on learned attention instead. For other forms of edge-aware message passing, you would need to implement custom message passing with edge attributes.
Edge weights and attributes in your NapistuData remain available for supervision and evaluation even when they are not used during encoding.
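To see what an edge weight changes during message passing, the sketch below sums neighbor features over an edge_index laid out as [2, num_edges] (row 0 = source, row 1 = target, the PyTorch Geometric convention). It is a plain-Python illustration of a weighted sum, not the degree-normalized propagation GCNConv actually performs; the aggregate function is hypothetical.

```python
from typing import List, Optional


def aggregate(
    x: List[List[float]],
    edge_index: List[List[int]],
    edge_weight: Optional[List[float]] = None,
) -> List[List[float]]:
    """Sum source-node features into target nodes, optionally scaled per edge."""
    num_nodes, dim = len(x), len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    sources, targets = edge_index
    for e, (src, dst) in enumerate(zip(sources, targets)):
        # Unweighted message passing treats every edge as weight 1.0
        w = 1.0 if edge_weight is None else edge_weight[e]
        for d in range(dim):
            out[dst][d] += w * x[src][d]
    return out


x = [[1.0], [2.0], [4.0]]
edge_index = [[0, 1], [2, 2]]  # edges 0 -> 2 and 1 -> 2
```

Without weights, node 2 receives 1.0 + 2.0 = 3.0; with edge_weight=[0.5, 2.0] it receives 0.5 * 1.0 + 2.0 * 2.0 = 4.5, so the weights shift influence toward the second neighbor.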
Examples
>>> # Direct instantiation
>>> encoder = MessagePassingEncoder(128, 256, 3, encoder_type='sage', sage_aggregator='mean')
>>>
>>> # From config
>>> config = ModelConfig(encoder='sage', hidden_channels=256, num_layers=3)
>>> encoder = MessagePassingEncoder.from_config(config, in_channels=128)
- classmethod from_config(config: ModelConfig, in_channels: int, edge_in_channels: int | None = None) → MessagePassingEncoder
Create MessagePassingEncoder from ModelConfig.
- Parameters:
config (ModelConfig) – Model configuration containing encoder, hidden_channels, etc.
in_channels (int) – Number of input node features (not in config as it depends on data)
edge_in_channels (int, optional) – Number of input edge features. Required if use_edge_encoder=True.
- Returns:
Configured encoder instance
- Return type:
MessagePassingEncoder
Examples
>>> config = ModelConfig(encoder='sage', hidden_channels=256, num_layers=3)
>>> encoder = MessagePassingEncoder.from_config(config, in_channels=128)
>>>
>>> # With edge encoder
>>> config = ModelConfig(encoder='gcn', use_edge_encoder=True, edge_encoder_dim=32)
>>> encoder = MessagePassingEncoder.from_config(config, in_channels=128, edge_in_channels=10)
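The from_config contract (node input size supplied by the caller because it depends on the data, edge input size required only when an edge encoder is enabled) can be sketched with a stand-in config. The ModelConfigSketch dataclass and encoder_kwargs_from_config helper are hypothetical; their field names are taken from the examples above, and the real ModelConfig may carry more fields and different validation.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ModelConfigSketch:
    # Hypothetical stand-in for ModelConfig, using fields shown in the examples
    encoder: str = "sage"
    hidden_channels: int = 256
    num_layers: int = 3
    use_edge_encoder: bool = False
    edge_encoder_dim: Optional[int] = None


def encoder_kwargs_from_config(
    config: ModelConfigSketch, in_channels: int, edge_in_channels: Optional[int] = None
) -> Dict[str, Any]:
    """Assemble constructor arguments, validating the edge-encoder requirement."""
    if config.use_edge_encoder and edge_in_channels is None:
        raise ValueError("edge_in_channels is required when use_edge_encoder=True")
    return {
        "in_channels": in_channels,  # depends on the data, so not stored in config
        "hidden_channels": config.hidden_channels,
        "num_layers": config.num_layers,
        "encoder_type": config.encoder,
    }
```

Keeping in_channels out of the config lets one configuration be reused across datasets with different node feature sizes.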
- __init__(in_channels: int, hidden_channels: int, num_layers: int, dropout: float = 0.0, encoder_type: str = 'sage', weight_edges_by: torch.Tensor | torch.nn.Module | None = None, gat_heads: int = 1, gat_concat: bool = True, graph_conv_aggregator: str = 'mean', sage_aggregator: str = 'mean')
- encode(x: torch.Tensor, edge_index: torch.Tensor, edge_data: torch.Tensor | None = None) → torch.Tensor
Alias for forward method for consistency with other models.
- Parameters:
x (torch.Tensor) – Node feature matrix [num_nodes, in_channels]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
edge_data (Optional[torch.Tensor]) – Edge data [num_edges, edge_dim] for edge weighting. Can be edge attributes (for learnable encoder) or edge weights (for static weighting).
- Returns:
Node embeddings [num_nodes, hidden_channels]
- Return type:
torch.Tensor
- forward(x: torch.Tensor, edge_index: torch.Tensor, edge_data: torch.Tensor | None = None) → torch.Tensor
Forward pass through the GNN encoder.
- Parameters:
x (torch.Tensor) – Node feature matrix [num_nodes, in_channels]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
edge_data (torch.Tensor, optional) – Edge data [num_edges, edge_dim] for edge weighting. Can be edge attributes (for learnable encoder) or edge weights (for static weighting).
- Returns:
Node embeddings [num_nodes, hidden_channels]
- Return type:
torch.Tensor
Notes
Edge weighting is handled based on the edge_weighting_type attribute:
- EDGE_WEIGHTING_TYPE.NONE: No edge weighting (uniform message passing)
- EDGE_WEIGHTING_TYPE.STATIC_WEIGHTS: Static edge weights (edge_data contains pre-computed weights)
- EDGE_WEIGHTING_TYPE.LEARNED_ENCODER: Learnable edge encoder (edge_data contains edge attributes)
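The three cases above amount to a dispatch on the weighting type before the layers run. A minimal sketch, assuming the type constants are plain strings and that a learned encoder is any callable mapping edge attributes to one weight per edge; the constant values and resolve_edge_weight are hypothetical, not the library's EDGE_WEIGHTING_TYPE definitions.

```python
from typing import Callable, List, Optional

# Hypothetical string constants standing in for EDGE_WEIGHTING_TYPE
NONE = "none"
STATIC_WEIGHTS = "static_weights"
LEARNED_ENCODER = "learned_encoder"


def resolve_edge_weight(
    weighting_type: str,
    edge_data: Optional[List],
    edge_encoder: Optional[Callable[[List], List[float]]] = None,
) -> Optional[List[float]]:
    """Turn edge_data into per-edge weights according to the weighting type."""
    if weighting_type == NONE:
        return None  # uniform message passing
    if edge_data is None:
        raise ValueError(f"edge_data is required for weighting type {weighting_type!r}")
    if weighting_type == STATIC_WEIGHTS:
        return list(edge_data)  # edge_data already holds one weight per edge
    if weighting_type == LEARNED_ENCODER:
        if edge_encoder is None:
            raise ValueError("an edge encoder is required for learned weighting")
        return edge_encoder(edge_data)  # edge attributes -> per-edge weights
    raise ValueError(f"Unknown weighting type: {weighting_type!r}")


def mean_encoder(attrs: List[List[float]]) -> List[float]:
    # Toy "encoder": average each edge's attribute vector
    return [sum(a) / len(a) for a in attrs]
```

With LEARNED_ENCODER, edge attributes [[1.0, 2.0], [0.0, 4.0]] pass through mean_encoder to become weights [1.5, 2.0]; a real learnable encoder would be a trainable module instead.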
- get_summary() → Dict[str, Any]
Get encoder metadata summary for checkpointing.
Returns essential metadata needed to reconstruct the encoder.
- property config: Dict[str, Any]
Get the configuration dictionary for this encoder.
Returns a dict containing all initialization parameters needed to reconstruct this encoder instance.
- Returns:
Configuration dictionary with all __init__ parameters
- Return type:
Dict[str, Any]
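The point of the config property is that the returned dict can be unpacked straight back into the constructor. A minimal sketch of that round-trip pattern on a toy class; ToyEncoder is hypothetical, takes only a subset of the __init__ parameters from the signature above, and omits weight_edges_by to keep the dict plain.

```python
from typing import Any, Dict


class ToyEncoder:
    """Hypothetical stand-in that only stores its constructor arguments."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_layers: int,
        dropout: float = 0.0,
        encoder_type: str = "sage",
    ) -> None:
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.encoder_type = encoder_type

    @property
    def config(self) -> Dict[str, Any]:
        # One entry per __init__ parameter, so ToyEncoder(**config) rebuilds it
        return {
            "in_channels": self.in_channels,
            "hidden_channels": self.hidden_channels,
            "num_layers": self.num_layers,
            "dropout": self.dropout,
            "encoder_type": self.encoder_type,
        }


original = ToyEncoder(128, 256, 3, dropout=0.1, encoder_type="gcn")
rebuilt = ToyEncoder(**original.config)
```

Because every constructor argument is captured, a checkpoint only needs this dict (plus learned weights, in a real model) to recreate an equivalent encoder.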
- napistu_torch.models.message_passing_encoder._parse_edge_weighting(weight_edges_by: torch.Tensor | torch.nn.Module | None, supports_edge_weight: bool, encoder_type: str) → tuple[str, torch.Tensor | torch.nn.Module | None]
Parse weight_edges_by parameter into type indicator and value.
This utility function handles the several types that weight_edges_by can take by explicitly separating the type indicator from the value.
- Parameters:
weight_edges_by (Optional[Union[torch.Tensor, nn.Module]]) – Edge weighting specification:
- None: No edge weighting
- torch.Tensor: Static edge weights
- nn.Module: Learnable edge encoder
supports_edge_weight (bool) – Whether the encoder type supports edge weighting
encoder_type (str) – Name of encoder type (for logging)
- Returns:
Tuple of (edge_weighting_type, edge_weighting_value):
- edge_weighting_type: String constant from EDGE_WEIGHTING_TYPE indicating the type of weighting
- edge_weighting_value: The actual value (None, Tensor, or Module)
- Return type:
tuple[str, Optional[Union[torch.Tensor, nn.Module]]]
- Raises:
ValueError – If weight_edges_by is not None, Tensor, or Module
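A plain-Python sketch of the documented parsing contract. To keep it runnable without PyTorch, Tensor and Module below are hypothetical placeholder classes standing in for torch.Tensor and torch.nn.Module, and the returned type strings are assumed values rather than the real EDGE_WEIGHTING_TYPE constants. The fallback for encoder types without edge-weight support (log a warning, return no weighting) is also an assumption, based on the "gracefully ignored" behavior described for SAGE.

```python
import logging
from typing import Optional, Tuple, Union

logger = logging.getLogger(__name__)


class Tensor:  # placeholder for torch.Tensor (assumption, keeps sketch dependency-free)
    pass


class Module:  # placeholder for torch.nn.Module (assumption)
    pass


def parse_edge_weighting(
    weight_edges_by: Optional[Union[Tensor, Module]],
    supports_edge_weight: bool,
    encoder_type: str,
) -> Tuple[str, Optional[Union[Tensor, Module]]]:
    """Split weight_edges_by into a (type indicator, value) pair."""
    if weight_edges_by is None:
        return "none", None
    if not isinstance(weight_edges_by, (Tensor, Module)):
        raise ValueError(
            "weight_edges_by must be None, a Tensor, or a Module, "
            f"got {type(weight_edges_by).__name__}"
        )
    if not supports_edge_weight:
        # Assumed behavior: warn and fall back to unweighted message passing
        logger.warning("%s does not support edge weights; ignoring them", encoder_type)
        return "none", None
    if isinstance(weight_edges_by, Module):
        return "learned_encoder", weight_edges_by
    return "static_weights", weight_edges_by
```

Returning an explicit type tag means forward can branch on a string instead of repeating isinstance checks on every call.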