napistu_torch.models.edge_encoder
Edge encoder for Napistu-Torch.
This module provides a simple MLP-based edge encoder for learning edge importance weights.
Classes
- EdgeEncoder
Learns edge importance weights from edge features.
- class napistu_torch.models.edge_encoder.EdgeEncoder(*args: Any, **kwargs: Any)
Bases: Module
Learns edge importance weights from edge features.
This is a standalone module that composes with GNNEncoder to provide learned edge weights for message passing.
Architecture
edge_features → MLP → sigmoid → edge_weights [0, 1]
The output edge weights scale each edge's message contribution during GNN aggregation, so the model can learn to down-weight noisy edges.
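For illustration only, the pipeline above can be sketched in plain PyTorch. This is not the napistu_torch source. The class name `SketchEdgeEncoder`, the single hidden layer, and the ReLU activation are assumptions. Only the constructor arguments `edge_dim`, `hidden_dim`, `dropout`, and `init_bias` come from the documented signature.

```python
import torch
import torch.nn as nn


class SketchEdgeEncoder(nn.Module):
    """Hypothetical stand-in for EdgeEncoder: edge features -> MLP -> sigmoid."""

    def __init__(self, edge_dim: int, hidden_dim: int = 32,
                 dropout: float = 0.1, init_bias: float = 0.0):
        super().__init__()
        # Assumed layout: one hidden layer with ReLU and dropout.
        self.mlp = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        # Only the output bias is set, so initial weights are near (not exactly) sigmoid(init_bias).
        nn.init.constant_(self.mlp[-1].bias, init_bias)

    def forward(self, edge_attr: torch.Tensor) -> torch.Tensor:
        # [num_edges, edge_dim] -> [num_edges, 1] -> [num_edges]
        logits = self.mlp(edge_attr).squeeze(-1)
        return torch.sigmoid(logits)


torch.manual_seed(0)
encoder = SketchEdgeEncoder(edge_dim=10, hidden_dim=32)
edge_attr = torch.randn(5, 10)
edge_weight = encoder(edge_attr)
print(edge_weight.shape)  # torch.Size([5])
```

The sigmoid at the end is what keeps every weight strictly between 0 and 1, so an edge can be strongly down-weighted but never fully removed.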
- Parameters:
edge_dim (int) – Dimensionality of the input edge features.
hidden_dim (int, default=32) – Hidden layer size. Keep small to avoid overfitting.
dropout (float, default=0.1) – Dropout probability for regularization.
init_bias (float, default=0.0) – Initial bias of the output layer. Controls the starting edge weights:
  - 0.0 → sigmoid(0) = 0.5 (neutral, equal weighting)
  - 1.4 → sigmoid(1.4) ≈ 0.8 (optimistic, most edges good)
  - -1.4 → sigmoid(-1.4) ≈ 0.2 (pessimistic, most edges bad)
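The starting weights listed for `init_bias` follow directly from the logistic function. The short check below recomputes them and also inverts the sigmoid (the logit) to choose a bias for a desired starting weight. `bias_for_weight` is a hypothetical helper, not part of napistu_torch. It also shows where 1.4 comes from: logit(0.8) = ln 4 ≈ 1.386.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function that squashes a logit into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def bias_for_weight(p: float) -> float:
    """Hypothetical helper: the init_bias whose sigmoid equals p (the logit of p)."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    return math.log(p / (1.0 - p))


for bias in (0.0, 1.4, -1.4):
    print(f"init_bias={bias:+.1f} -> starting weight {sigmoid(bias):.3f}")
# init_bias=+0.0 -> starting weight 0.500
# init_bias=+1.4 -> starting weight 0.802
# init_bias=-1.4 -> starting weight 0.198

print(round(bias_for_weight(0.8), 3))  # 1.386
```

These values describe only the bias term's contribution. The randomly initialized hidden-to-output weights shift each edge's actual starting weight slightly around them.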
Public Methods
- config (property) -> Dict[str, Any]
Get the configuration dictionary for this edge encoder.
- forward(self, edge_attr: torch.Tensor) -> torch.Tensor
Compute edge importance weights from edge features.
- get_summary(self, to_model_config_names: bool = False) -> Dict[str, Any]
Get the summary dictionary for this edge encoder.
Examples
>>> # Create edge encoder
>>> edge_encoder = EdgeEncoder(edge_dim=10, hidden_dim=32)
>>>
>>> # Use with GNNEncoder
>>> edge_weights = edge_encoder(edge_attr)  # [num_edges, 10] -> [num_edges]
>>> z = gnn_encoder(x, edge_index, edge_weight=edge_weights)
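The example above assumes `edge_attr`, `gnn_encoder`, `x`, and `edge_index` already exist. To see what `edge_weight` does during aggregation without relying on GNNEncoder internals, the sketch below implements one weighted sum-aggregation step in plain PyTorch. `weighted_sum_aggregate` is hypothetical. Real GNN layers (for example, GCN with its degree normalization) combine weights differently. The sketch only shows how a weight in [0, 1] scales each message.

```python
import torch


def weighted_sum_aggregate(x: torch.Tensor, edge_index: torch.Tensor,
                           edge_weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical message passing step: out[dst] = sum over edges of w * x[src]."""
    src, dst = edge_index  # edge_index has shape [2, num_edges]
    messages = x[src] * edge_weight.unsqueeze(-1)  # scale each message by its edge weight
    out = torch.zeros_like(x)
    out.index_add_(0, dst, messages)  # sum incoming messages per target node
    return out


x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[0, 1], [2, 2]])  # edges 0->2 and 1->2
edge_weight = torch.tensor([0.9, 0.1])       # second edge treated as noisy
out = weighted_sum_aggregate(x, edge_index, edge_weight)
print(out.squeeze(-1).tolist())
```

Node 2 receives 0.9 × 1.0 + 0.1 × 2.0 ≈ 1.1, so the low-weight edge contributes little even though its source feature is larger.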
Notes
- Output is in [0, 1] via sigmoid.
- Very lightweight: roughly edge_dim × hidden_dim parameters.
- Learns end-to-end with the main task.
- Can be initialized to approximate existing heuristics.
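To make the parameter-count note concrete, the sketch below counts parameters for an assumed layout with one hidden layer: Linear(edge_dim, hidden_dim) → ReLU → Dropout → Linear(hidden_dim, 1). This layout is an assumption. The real count depends on how EdgeEncoder builds its MLP. Under this layout the exact count is (edge_dim + 2) × hidden_dim + 1, which is dominated by the edge_dim × hidden_dim term.

```python
import torch.nn as nn


def assumed_param_count(edge_dim: int, hidden_dim: int) -> int:
    """Closed-form count for the assumed edge_dim -> hidden_dim -> 1 MLP."""
    # W1: edge_dim*hidden_dim, b1: hidden_dim, W2: hidden_dim, b2: 1
    return (edge_dim + 2) * hidden_dim + 1


edge_dim, hidden_dim = 10, 32
mlp = nn.Sequential(
    nn.Linear(edge_dim, hidden_dim),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(hidden_dim, 1),
)
n_params = sum(p.numel() for p in mlp.parameters())
print(n_params, assumed_param_count(edge_dim, hidden_dim))  # 385 385
```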
- __init__(edge_dim: int, hidden_dim: int = 32, dropout: float = 0.1, init_bias: float = 0.0)
- forward(edge_attr: torch.Tensor) → torch.Tensor
Compute edge importance weights from edge features.
- Parameters:
edge_attr (torch.Tensor) – Edge features [num_edges, edge_dim]
- Returns:
edge_weight – Learned edge importance weights of shape [num_edges], with values in [0, 1]; higher values mean more important edges.
- Return type:
torch.Tensor
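A caller can check the documented `forward` contract before passing weights to a GNN. The contract is an input of shape [num_edges, edge_dim] and an output of shape [num_edges] with values in [0, 1]. `check_edge_weights` below is a hypothetical helper, not part of napistu_torch. The sigmoid of random logits stands in for an actual EdgeEncoder output.

```python
import torch


def check_edge_weights(edge_attr: torch.Tensor, edge_weight: torch.Tensor) -> None:
    """Hypothetical validator for the documented forward() input/output contract."""
    if edge_attr.dim() != 2:
        raise ValueError(f"edge_attr must be [num_edges, edge_dim], got {tuple(edge_attr.shape)}")
    num_edges = edge_attr.size(0)
    if edge_weight.shape != (num_edges,):
        raise ValueError(f"edge_weight must be [{num_edges}], got {tuple(edge_weight.shape)}")
    if edge_weight.min() < 0 or edge_weight.max() > 1:
        raise ValueError("edge_weight values must lie in [0, 1]")


edge_attr = torch.randn(4, 10)
edge_weight = torch.sigmoid(torch.randn(4))  # stand-in for EdgeEncoder output
check_edge_weights(edge_attr, edge_weight)   # passes silently
```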
- get_summary(to_model_config_names: bool = False) → Dict[str, Any]
Get the summary dictionary for this edge encoder.
Returns a dict containing all initialization parameters needed to reconstruct this edge encoder instance.
- property config: Dict[str, Any]
Get the configuration dictionary for this edge encoder.