napistu_torch.models.heads
Prediction heads for Napistu-Torch.
This module provides implementations of different prediction heads for various tasks like edge prediction, node classification, etc. All heads follow a consistent interface.
Classes
- AttentionHead
Lightweight attention head for edge prediction.
- ConditionalRotatEHead
Conditional RotatE head for relation-aware edge prediction.
- DistMultHead
DistMult head for relation-aware edge prediction.
- DotProductHead
Simple dot product head for edge prediction.
- EdgeMLPHead
MLP-based head for edge prediction.
- NodeClassificationHead
Head for node classification tasks.
- RelationAttentionHead
Relation-aware attention head for edge prediction.
- RelationGatedMLPHead
Relation-aware gated MLP head for edge prediction.
- RelationAttentionMLPHead
Relation-aware attention-MLP hybrid head for edge prediction.
- RotatEHead
RotatE head for relation-aware edge prediction.
- TransEHead
TransE head for relation-aware edge prediction.
- Decoder
Decoder combining encoder and head for complete model architecture.
- class napistu_torch.models.heads.AttentionHead(*args: Any, **kwargs: Any)
Bases: Module
Lightweight attention head for edge prediction.
Projects nodes to query/key spaces and computes scaled dot-product attention. More expressive than dot product but much lighter than full MLP with attention.
Architecture:
1. Project source nodes → query space
2. Project target nodes → key space
3. Compute scaled dot product: (W_q @ src)^T @ (W_k @ tgt) / sqrt(d)
- Parameters:
embedding_dim (int) – Dimension of input node embeddings
attention_dim (int, optional) – Dimension of attention space (lower = more compression), by default 64
init_as_identity (bool, optional) – Initialize projections to approximate identity (dot product), by default False
Notes
Projects to lower dimension for efficiency and regularization
Learns separate query/key transformations (more flexible than dot product)
Scaled dot product prevents gradient vanishing in high dimensions
Normalizes embeddings for numerical stability
~2 * embedding_dim * attention_dim parameters (e.g., ~16K for 128→64)
Comparison to other heads:
- vs DotProduct: Learns transformations (more expressive)
- vs MLP: Far fewer parameters, easier to interpret
- vs RelationAttention: No relation-specific modulation
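The parameter estimate in the notes above can be checked with one line of arithmetic. This is a sketch that assumes the head consists of two bias-free linear projections (query and key); the helper name is hypothetical, and the real layers may carry bias terms that add a few more parameters.

```python
def attention_head_param_count(embedding_dim: int, attention_dim: int) -> int:
    """Count weights in two bias-free projections (query and key).

    Assumed layout for illustration only, not read from the library source.
    """
    return 2 * embedding_dim * attention_dim


# The 128 -> 64 case quoted in the notes.
count = attention_head_param_count(embedding_dim=128, attention_dim=64)
print(count)
```

Halving attention_dim halves the count, which is the compression trade-off the attention_dim parameter exposes.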
Examples
>>> head = AttentionHead(embedding_dim=128, attention_dim=64)
>>> scores = head(node_embeddings, edge_index)
>>> # scores ∈ ℝ^{num_edges}, approximately normalized by scaling
- __init__(embedding_dim: int, attention_dim: int = 64, init_as_identity: bool = False)
- _initialize_weights(init_as_identity: bool)
Initialize projection weights.
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor) torch.Tensor
Compute attention-based edge scores.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings from encoder [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
- Returns:
Edge scores [num_edges]
- Return type:
torch.Tensor
Notes
- Scores are computed as:
score = (W_q @ normalize(src))^T @ (W_k @ normalize(tgt)) / sqrt(d_attn)
Normalization ensures embeddings have bounded norms, preventing score explosion with pretrained encoders.
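The scoring formula above can be traced by hand on a toy graph. The following is a dependency-free sketch in plain Python rather than the library's tensor implementation; the helper names and the use of nested lists for W_q and W_k are assumptions made for illustration.

```python
import math


def l2_normalize(v, eps=1e-12):
    """Scale a vector to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def attention_edge_scores(node_embeddings, edge_index, W_q, W_k):
    """score = (W_q @ normalize(src))^T @ (W_k @ normalize(tgt)) / sqrt(d_attn)."""
    d_attn = len(W_q)  # number of rows = dimension of the attention space
    sources, targets = edge_index
    scores = []
    for s, t in zip(sources, targets):
        q = matvec(W_q, l2_normalize(node_embeddings[s]))
        k = matvec(W_k, l2_normalize(node_embeddings[t]))
        scores.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(d_attn))
    return scores


# Toy graph: 3 nodes, 2 edges (0 -> 1 and 1 -> 2), identity projections.
node_embeddings = [[3.0, 4.0], [0.0, 2.0], [1.0, 0.0]]
edge_index = [[0, 1], [1, 2]]
identity = [[1.0, 0.0], [0.0, 1.0]]
scores = attention_edge_scores(node_embeddings, edge_index, identity, identity)
print(scores)
```

With identity projections the score reduces to cosine similarity divided by sqrt(d_attn), which is the behaviour init_as_identity is described as approximating.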
- loss_type = 'bce'
- class napistu_torch.models.heads.ConditionalRotatEHead(*args: Any, **kwargs: Any)
Bases: Module
Conditional decoder: DotProduct for symmetric relations, RotatE for asymmetric.
Automatically routes different relation types to appropriate scoring functions:
- Symmetric relations (e.g., “protein->protein”): DotProduct distance
- Asymmetric relations (e.g., “catalyst->modified”): RotatE rotation distance
Both heads produce distance-based scores in [-2, 0] for margin loss.
- Parameters:
embedding_dim (int) – Dimension of node embeddings (must be even for RotatE complex embeddings)
num_relations (int) – Total number of relation types
symmetric_relation_indices (List[int]) – Indices of relations that should use dot product (symmetric). All other relations use RotatE (asymmetric). Obtained from NapistuData.analyze_relation_symmetry()
init_asymmetric_as_identity (bool, optional) – Initialize RotatE phases to 0 (identity rotation), by default False
margin (float, optional) – Margin for ranking loss (applied to both heads), by default 0.1
Notes
Score Ranges (with normalized embeddings): Both heads produce scores in [-2, 0]:
- DotProduct: distance = 1 - similarity, score = -distance
- RotatE: distance = ||h⊙r - t||, score = -distance
When to Use:
- Graph has mix of symmetric (A↔B) and asymmetric (A→B) relations
- Example: protein-protein interactions (symmetric) + catalysis (asymmetric)
When NOT to Use:
- All relations symmetric → Use DotProduct or DistMult instead
- All relations asymmetric → Use RotatE or TransE instead
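The routing described above can be sketched without tensors. This is an illustrative plain-Python version under stated assumptions: embeddings are L2-normalized first, the first half of each vector is treated as the real part and the second half as the imaginary part, and relation_phases is a hypothetical mapping from relation index to phase angles. None of these helper names come from the library.

```python
import cmath
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def dot_distance_score(h, t):
    """Symmetric branch: distance = 1 - similarity, score = -distance."""
    similarity = sum(a * b for a, b in zip(h, t))
    return -(1.0 - similarity)


def rotate_distance_score(h, t, phases):
    """Asymmetric branch: score = -||h ⊙ r - t|| with r = exp(i * phase)."""
    half = len(h) // 2
    h_c = [complex(h[i], h[half + i]) for i in range(half)]
    t_c = [complex(t[i], t[half + i]) for i in range(half)]
    diff = [hc * cmath.exp(1j * p) - tc for hc, p, tc in zip(h_c, phases, t_c)]
    return -math.sqrt(sum(abs(d) ** 2 for d in diff))


def conditional_scores(node_embeddings, edge_index, relation_type,
                       symmetric_relation_indices, relation_phases):
    """Route each edge to the dot-product or RotatE branch by relation type."""
    symmetric = set(symmetric_relation_indices)
    z = [l2_normalize(v) for v in node_embeddings]
    scores = []
    for s, t, r in zip(edge_index[0], edge_index[1], relation_type):
        if r in symmetric:
            scores.append(dot_distance_score(z[s], z[t]))
        else:
            scores.append(rotate_distance_score(z[s], z[t], relation_phases[r]))
    return scores


# Two nodes in a 2-dimensional space (one complex coordinate each).
node_embeddings = [[1.0, 0.0], [0.0, 1.0]]
edge_index = [[0, 0, 1], [1, 1, 0]]
relation_type = [0, 1, 1]  # relation 0 is symmetric, relation 1 is not
scores = conditional_scores(
    node_embeddings, edge_index, relation_type,
    symmetric_relation_indices=[0],
    relation_phases={1: [math.pi / 2]},
)
print(scores)
```

Under relation 1 the edge 0→1 scores near 0 while the reversed edge 1→0 scores near -2, which is the directional sensitivity the symmetric branch cannot provide.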
- __init__(embedding_dim: int, num_relations: int, symmetric_relation_indices: List[int], init_asymmetric_as_identity: bool = False, margin: float = 0.1)
- _compute_dot_scores(node_embeddings: torch.Tensor, edge_index: torch.Tensor) torch.Tensor
Compute dot product-based distance scores.
- _compute_rotate_scores(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute RotatE rotation-based distance scores using shared utility.
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute edge scores using conditional head selection.
Routes edges to DotProduct (symmetric) or RotatE (asymmetric) based on relation type, then returns unified distance-based scores.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
relation_type (torch.Tensor) – Relation type for each edge [num_edges]
- Returns:
Edge scores in [-2, 0] for margin loss [num_edges]. Higher score = more likely edge (closer to 0).
- Return type:
torch.Tensor
- scores_to_probs(scores: torch.Tensor) torch.Tensor
- loss_type = 'margin'
- class napistu_torch.models.heads.Decoder(*args: Any, **kwargs: Any)
Bases: Module
Unified head decoder that can create different types of prediction heads.
This class provides a single interface for creating various head types (e.g., dot product, MLP, attention, node classification) with a from_config classmethod for easy integration with configuration systems.
- Parameters:
hidden_channels (int) – Dimension of input node embeddings (should match GNN encoder output)
head_type (str) – Type of head to create (dot_product, mlp, attention, node_classification)
num_relations (int, optional) – Number of relation types (required for relation-aware heads)
symmetric_relation_indices (List[int], optional) – List of relation type indices that are symmetric. This is required for heads that support special symmetry handling.
num_classes (int, optional) – Number of output classes for node classification head
init_head_as_identity (bool, optional) – Whether to initialize the head to approximate an identity transformation, by default False
mlp_hidden_dim (int, optional) – Hidden layer dimension for MLP head, by default 128
mlp_num_layers (int, optional) – Number of hidden layers for MLP head, by default 2
mlp_dropout (float, optional) – Dropout probability for MLP head, by default 0.1
nc_dropout (float, optional) – Dropout probability for node classification head, by default 0.1
rotate_margin (float, optional) – Margin for RotatE head, by default 0.1
transe_margin (float, optional) – Margin for TransE head, by default 0.1
relation_emb_dim (int, optional) – Dimension of relation embeddings for relation-aware MLP heads, by default 64
relation_attention_heads (int, optional) – Number of attention heads for RelationAttentionMLP, by default 4
Public methods
config(self) -> Dict[str, Any] – Get the configuration dictionary for this decoder.
from_config(config: ModelConfig, num_relations: Optional[int] = None, num_classes: Optional[int] = None) -> Decoder – Create a Decoder from a ModelConfig instance.
forward(node_embeddings: torch.Tensor, edge_index: Optional[torch.Tensor] = None, relation_type: Optional[torch.Tensor] = None) -> torch.Tensor – Forward pass through the head.
supports_relations(self) -> bool – Check if this decoder supports relation-aware heads.
- classmethod from_config(config: ModelConfig, num_relations: int | None = None, num_classes: int | None = None, symmetric_relation_indices: List[int] | None = None)
Create a Decoder from a configuration object.
- Parameters:
config (ModelConfig) – Configuration object containing head parameters
num_relations (int, optional) – Number of relation types (required for relation-aware heads). This should be inferred from edge_strata.
num_classes (int, optional) – Number of output classes for node classification head (required for node classification head). This should be inferred from the data.
symmetric_relation_indices (List[int], optional) – List of relation type indices that are symmetric. This is required for heads that support special symmetry handling.
- Returns:
Configured head decoder
- Return type:
Decoder
- __init__(hidden_channels: int, head_type: str = 'dot_product', num_relations: int | None = None, symmetric_relation_indices: List[int] | None = None, num_classes: int | None = None, init_head_as_identity: bool = False, mlp_hidden_dim: int = 128, mlp_num_layers: int = 2, mlp_dropout: float = 0.1, nc_dropout: float = 0.1, rotate_margin: float = 0.1, transe_margin: float = 0.1, relation_emb_dim: int = 64, relation_attention_heads: int = 4)
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor | None = None, relation_type: torch.Tensor | None = None) torch.Tensor
Forward pass through the head.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]
edge_index (torch.Tensor, optional) – Edge connectivity [2, num_edges] (required for edge prediction heads)
relation_type (torch.Tensor, optional) – Relation type for each edge [num_edges] (required for relation-aware heads)
- Returns:
Head output (edge scores or node predictions)
- Return type:
torch.Tensor
- get_summary() Dict[str, Any]
Get decoder metadata summary for checkpointing.
Returns essential metadata needed to reconstruct the decoder from a checkpoint, including ALL parameters that were used.
- Returns:
Dictionary containing all initialization parameters, with None values filtered out for head-type-specific params
- Return type:
Dict[str, Any]
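The None-filtering behaviour described for get_summary can be illustrated with a small dictionary comprehension. This is a sketch, not the library's code: the function name is hypothetical, the keys are taken from the __init__ signature, and it simplifies by dropping every None value rather than only head-type-specific ones.

```python
from typing import Any, Dict


def summarize_init_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of params without entries whose value is None."""
    return {key: value for key, value in params.items() if value is not None}


# Parameters of a dot-product decoder: relation and class counts are unused.
init_params = {
    "hidden_channels": 128,
    "head_type": "dot_product",
    "num_relations": None,
    "num_classes": None,
    "init_head_as_identity": False,
}
summary = summarize_init_params(init_params)
print(summary)
```

Note that False is kept: only None marks a parameter as not applicable, so boolean flags survive the filter.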
- property config: Dict[str, Any]
Get the configuration dictionary for this decoder.
Returns a dict containing all initialization parameters needed to reconstruct this decoder instance.
- Returns:
Configuration dictionary with all __init__ parameters
- Return type:
Dict[str, Any]
- property loss_type: str
Get the loss type required by the underlying head.
- Returns:
Loss type (e.g., LOSSES.BCE, LOSSES.MARGIN)
- Return type:
str
- property margin: float
Get the margin value for heads that support margin loss (RotatE, TransE).
- Returns:
Margin value for ranking loss
- Return type:
float
- Raises:
AttributeError – If the underlying head does not have a margin attribute
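To show where a margin value enters training, here is a sketch of the conventional margin ranking loss over positive and negative edge scores. It assumes the standard formulation mean(max(0, margin - pos + neg)); whether napistu_torch uses exactly this form is not stated here, and the function name is hypothetical.

```python
def margin_ranking_loss(pos_scores, neg_scores, margin):
    """Mean hinge loss pushing each positive score above its negative by `margin`."""
    terms = [max(0.0, margin - p + n) for p, n in zip(pos_scores, neg_scores)]
    return sum(terms) / len(terms)


# Distance-based scores in [-2, 0]; higher means a more likely edge.
pos_scores = [-0.25, -0.5]
neg_scores = [-2.0, -1.0]
loss = margin_ranking_loss(pos_scores, neg_scores, margin=1.0)
print(loss)
```

The first pair already clears the margin and contributes zero; only the second pair, separated by 0.5, is penalized.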
- property supports_relations: bool
Check if this decoder supports relation-aware heads.
- Returns:
True if the head type is in RELATION_AWARE_HEADS, False otherwise
- Return type:
bool
- class napistu_torch.models.heads.DistMultHead(*args: Any, **kwargs: Any)
Bases: Module
DistMult-style relation scoring for graph neural networks.
Adapted from knowledge graph DistMult (Yang et al. 2015) to GNN setting where nodes share embedding space instead of having separate entity embeddings.
Score = mean(h ⊙ r ⊙ t) where h,t ∈ same embedding space
- Parameters:
embedding_dim (int) – Dimension of node embeddings from GNN
num_relations (int) – Number of distinct relation types
Notes
Key Difference from Original DistMult:
- Original: Separate embeddings per entity (h_aspirin, t_headache)
- This version: Shared node embedding space (all nodes use same encoder)
Symmetry Warning: Like the original DistMult, this head is symmetric: score(h,r,t) = score(t,r,h). It cannot distinguish directed relations unless combined with an asymmetric encoder or relation-specific directionality in the GNN.
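The symmetry follows directly from the commutativity of the element-wise product, as this plain-Python sketch of Score = mean(h ⊙ r ⊙ t) shows. The function name and toy vectors are illustrative only.

```python
def distmult_score(h, r, t):
    """Mean of the element-wise triple product h ⊙ r ⊙ t."""
    return sum(a * b * c for a, b, c in zip(h, r, t)) / len(h)


h = [1.0, 2.0]   # source node embedding
r = [0.5, -1.0]  # relation embedding
t = [3.0, 4.0]   # target node embedding

forward_score = distmult_score(h, r, t)
reverse_score = distmult_score(t, r, h)
print(forward_score, reverse_score)
```

Swapping source and target leaves the score unchanged, so edge direction has to be encoded upstream if it matters.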
References
Yang et al. “Embedding Entities and Relations for Learning and Inference in Knowledge Bases” ICLR 2015.
Examples
>>> # Use only if relations are symmetric
>>> head = DistMultHead(embedding_dim=256, num_relations=4)
>>> scores = head(z, edge_index, relation_type)
- __init__(embedding_dim: int, num_relations: int, init_as_identity: bool = False)
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute edge scores using DistMult.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
relation_type (torch.Tensor) – Relation type for each edge [num_edges]
- Returns:
Edge scores [num_edges]
- Return type:
torch.Tensor
- loss_type = 'bce'
- class napistu_torch.models.heads.DotProductHead(*args: Any, **kwargs: Any)
Bases: Module
Dot product head for edge prediction.
Computes edge scores as the dot product of source and target node embeddings. This is the simplest and most efficient head for edge prediction tasks.
- __init__()
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor) torch.Tensor
Compute edge scores using dot product.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
- Returns:
Edge scores [num_edges]
- Return type:
torch.Tensor
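A minimal sketch of how edge_index selects the embedding pairs to score, written with nested lists instead of tensors so the [2, num_edges] layout is easy to follow. The function name is hypothetical.

```python
def dot_product_scores(node_embeddings, edge_index):
    """One score per edge: dot product of source and target embeddings."""
    sources, targets = edge_index  # row 0 = source ids, row 1 = target ids
    return [
        sum(a * b for a, b in zip(node_embeddings[s], node_embeddings[t]))
        for s, t in zip(sources, targets)
    ]


node_embeddings = [[1.0, 2.0], [3.0, 4.0], [0.0, -1.0]]  # [num_nodes, embedding_dim]
edge_index = [[0, 1], [1, 2]]                            # edges 0 -> 1 and 1 -> 2
scores = dot_product_scores(node_embeddings, edge_index)
print(scores)
```

The head has no learnable parameters, which is why its __init__ takes no arguments.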
- loss_type = 'bce'
- class napistu_torch.models.heads.EdgeMLPHead(*args: Any, **kwargs: Any)
Bases: Module
Multi-layer perceptron head for edge prediction.
Uses an MLP to predict edge scores from concatenated source and target embeddings. More expressive than dot product but requires more parameters.
- Parameters:
embedding_dim (int) – Dimension of input node embeddings
hidden_dim (int, optional) – Hidden layer dimension, by default 64
num_layers (int, optional) – Number of hidden layers, by default 2
dropout (float, optional) – Dropout probability, by default 0.1
- __init__(embedding_dim: int, hidden_dim: int = 64, num_layers: int = 2, dropout: float = 0.1)
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor) torch.Tensor
Compute edge scores using MLP.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
- Returns:
Edge scores [num_edges]
- Return type:
torch.Tensor
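The concatenate-then-MLP idea can be sketched for a single edge as follows. This assumes one hidden layer with a ReLU activation and hand-picked weights; the activation, layer layout, and helper names are assumptions, and dropout is omitted.

```python
def linear(W, b, x):
    """Affine map W @ x + b with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def relu(v):
    return [max(0.0, x) for x in v]


def edge_mlp_score(src, tgt, W1, b1, w2, b2):
    """Score one edge from the concatenation [src || tgt]."""
    pair = src + tgt                       # list concatenation = [src || tgt]
    hidden = relu(linear(W1, b1, pair))    # hidden layer
    return sum(w * h for w, h in zip(w2, hidden)) + b2  # scalar output


src = [1.0, -1.0]
tgt = [2.0, 0.0]
W1 = [[1.0, 0.0, 1.0, 0.0],   # sums the first coordinates of src and tgt
      [0.0, 1.0, 0.0, 1.0]]   # sums the second coordinates
b1 = [0.0, 0.0]
w2 = [0.5, 0.5]
b2 = 1.0
score = edge_mlp_score(src, tgt, W1, b1, w2, b2)
print(score)
```

Because the input is a concatenation rather than a product, the score can differ between (src, tgt) and (tgt, src) once the weights are not tied.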
- loss_type = 'bce'
- class napistu_torch.models.heads.NodeClassificationHead(*args: Any, **kwargs: Any)
Bases: Module
Simple linear head for node classification tasks.
- Parameters:
embedding_dim (int) – Dimension of input node embeddings
num_classes (int) – Number of output classes
dropout (float, optional) – Dropout probability, by default 0.1
- __init__(embedding_dim: int, num_classes: int, dropout: float = 0.1)
- forward(node_embeddings: torch.Tensor) torch.Tensor
Compute node class predictions.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]
- Returns:
Node class logits [num_nodes, num_classes]
- Return type:
torch.Tensor
- class napistu_torch.models.heads.RelationAttentionHead(*args: Any, **kwargs: Any)
Bases: Module
Lightweight relation-aware multi-head attention for edge prediction.
Simplified version of RelationAttentionMLPHead:
- Directly projects nodes instead of edge MLP
- Multi-head attention with relation queries
- No residual connection or output MLP (lighter)
- Still captures relation-specific feature selection
Architecture:
1. Project nodes → attention space (replaces edge MLP)
2. Relation embeddings → queries (multi-head)
3. Node projections → keys, values (multi-head)
4. Attention → weighted combination → score
- Parameters:
embedding_dim (int) – Dimension of input node embeddings
num_relations (int) – Number of distinct relation types
relation_emb_dim (int, optional) – Dimension of relation embeddings, by default 64
hidden_dim (int, optional) – Hidden dimension (must be divisible by num_attention_heads), by default 128
num_attention_heads (int, optional) – Number of attention heads, by default 4
Notes
Comparison to RelationAttentionMLPHead:
- Much lighter: no edge MLP, no output MLP, no residual
- Same core idea: relation queries edge features via attention
- Parameters: ~50K vs ~100K (half the size)
- More interpretable: fewer non-linearities
Comparison to AttentionHead:
- Adds relation-specific attention (like RelationAttentionMLPHead)
- Multi-head for richer feature selection
- More parameters but more expressive
- __init__(embedding_dim: int, num_relations: int, relation_emb_dim: int = 64, hidden_dim: int = 128, num_attention_heads: int = 4)
- _initialize_weights()
Initialize weights with sensible defaults.
Strategy:
- Relation embeddings: Small random (std=0.1) to allow learning
- Edge projection: Xavier with moderate gain to preserve info
- Attention Q/K/V: Standard Xavier to balance stability/expressiveness
- Output: Small gain to avoid initial saturation
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute relation-aware multi-head attention scores.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
relation_type (torch.Tensor) – Relation type for each edge [num_edges]
- Returns:
Edge scores [num_edges]
- Return type:
torch.Tensor
- loss_type = 'bce'
- class napistu_torch.models.heads.RelationAttentionMLPHead(*args: Any, **kwargs: Any)
Bases: Module
Relation-attention MLP head for edge prediction.
Uses relation embeddings to query edge features via multi-head attention. The relation type determines WHICH aspects of edge features to attend to.
Architecture:
1. Process [src || tgt] through edge MLP → hidden features
2. Relation embedding → Query
3. Edge features → Key, Value
4. Multi-head attention: relation queries edge features
5. Residual connection + output MLP → edge score
- Parameters:
embedding_dim (int) – Dimension of input node embeddings
num_relations (int) – Number of distinct relation types
relation_emb_dim (int, optional) – Dimension of relation embeddings, by default 64
hidden_dim (int, optional) – Hidden layer dimension (must be divisible by num_attention_heads), by default 128
num_layers (int, optional) – Number of layers in output MLP, by default 2
dropout (float, optional) – Dropout probability, by default 0.1
num_attention_heads (int, optional) – Number of attention heads, by default 4
Notes
More expressive than gating (can learn complex feature selection)
Different heads can specialize (e.g., one for catalysis, one for inhibition)
Attention is per-edge over hidden dimensions (not graph-level like GAT)
More parameters than gating but potentially better for diverse relation semantics
Reuses MLP hyperparameters from EdgeMLPHead for consistency
Examples
>>> head = RelationAttentionMLPHead(
...     embedding_dim=256,
...     num_relations=10,
...     relation_emb_dim=64,
...     hidden_dim=128,
...     num_attention_heads=4
... )
>>> scores = head(node_embeddings, edge_index, relation_type)
- __init__(embedding_dim: int, num_relations: int, relation_emb_dim: int = 64, hidden_dim: int = 128, num_layers: int = 2, dropout: float = 0.1, num_attention_heads: int = 4)
- _initialize_weights()
Initialize weights, optionally starting near identity.
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute edge scores using relation-attention MLP.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
relation_type (torch.Tensor) – Relation type for each edge [num_edges]
- Returns:
Edge scores [num_edges]
- Return type:
torch.Tensor
- loss_type = 'bce'
- class napistu_torch.models.heads.RelationGatedMLPHead(*args: Any, **kwargs: Any)
Bases: Module
Relation-gated MLP head for edge prediction.
Uses relation embeddings to modulate edge features via element-wise gating. The relation type controls HOW the MLP processes each edge pair.
Architecture:
1. Process [src || tgt] through edge MLP → hidden features
2. Relation embedding → gate values (via Tanh)
3. Modulate: gated_features = edge_features * relation_gates
4. Final MLP → edge score
- Parameters:
embedding_dim (int) – Dimension of input node embeddings
num_relations (int) – Number of distinct relation types
relation_emb_dim (int, optional) – Dimension of relation embeddings, by default 64
hidden_dim (int, optional) – Hidden layer dimension, by default 128
num_layers (int, optional) – Number of layers in output MLP, by default 2
dropout (float, optional) – Dropout probability, by default 0.1
Notes
Handles imbalanced relation frequencies well (rare relations share parameters)
More parameter-efficient than separate MLPs per relation
Tanh gating allows both suppression (negative) and amplification (positive)
Reuses MLP hyperparameters from EdgeMLPHead for consistency
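The gating step (gated_features = edge_features * relation_gates) is easy to see in isolation. The sketch below assumes the gate is tanh applied to some relation-derived pre-activation; how that pre-activation is produced from the relation embedding is not shown, and the names are hypothetical.

```python
import math


def gate_features(edge_features, gate_preactivations):
    """Element-wise modulation: feature * tanh(pre-activation)."""
    return [f * math.tanh(g) for f, g in zip(edge_features, gate_preactivations)]


edge_features = [2.0, 2.0, 2.0]
# One relation's gate: block the first feature, pass the second, flip the third.
gate_preactivations = [0.0, 10.0, -10.0]
gated = gate_features(edge_features, gate_preactivations)
print(gated)
```

Because tanh ranges over (-1, 1), a gate can zero out a feature or invert its sign but cannot scale it beyond its original magnitude.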
Examples
>>> head = RelationGatedMLPHead(
...     embedding_dim=256,
...     num_relations=10,
...     relation_emb_dim=64,
...     hidden_dim=128
... )
>>> scores = head(node_embeddings, edge_index, relation_type)
- __init__(embedding_dim: int, num_relations: int, relation_emb_dim: int = 64, hidden_dim: int = 128, num_layers: int = 2, dropout: float = 0.1)
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute edge scores using relation-gated MLP.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
relation_type (torch.Tensor) – Relation type for each edge [num_edges]
- Returns:
Edge scores [num_edges]
- Return type:
torch.Tensor
- loss_type = 'bce'
- class napistu_torch.models.heads.RotatEHead(*args: Any, **kwargs: Any)
Bases: Module
RotatE decoder for relation-aware edge prediction.
Models relations as rotations in complex space: h ⊙ r ≈ t where ⊙ is complex multiplication (Hadamard product in re/im components).
Scoring function: score = -||h ⊙ r - t||
- Parameters:
embedding_dim (int) – Dimension of node embeddings from GNN (must be even for complex embeddings)
num_relations (int) – Number of distinct relation types
margin (float, optional) – Margin for ranking loss, by default 0.1
init_as_identity (bool, optional) – Initialize relations as identity rotations (angle=0), by default False
Notes
Embeddings are split into real/imaginary parts: [embedding_dim/2, embedding_dim/2]
Relations are phase angles that rotate head embeddings
Handles symmetric relations (phase 0 or π), inverse relations (r₁ = -r₂), and composition (r₃ = r₁ + r₂), where r denotes phase angles
Requires normalized embeddings (||h|| = ||t|| = 1) for bounded distances
Distance range with unit norm: [0, 2]
Score range: [-2, 0]
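The bounds listed above can be reproduced with Python's built-in complex numbers. This sketch assumes the first half of an embedding holds real parts and the second half imaginary parts, matching the split described in the notes; the function names are illustrative and not taken from the library.

```python
import cmath
import math


def to_complex(v):
    """Pair the first half (real parts) with the second half (imaginary parts)."""
    half = len(v) // 2
    return [complex(v[i], v[half + i]) for i in range(half)]


def rotate_score(h, t, phases):
    """score = -||h ⊙ r - t|| where r = exp(i * phase) has unit modulus."""
    rotated = [hc * cmath.exp(1j * p) for hc, p in zip(to_complex(h), phases)]
    return -math.sqrt(sum(abs(a - b) ** 2 for a, b in zip(rotated, to_complex(t))))


h = [1.0, 0.0]    # the complex number 1 + 0i, unit norm
t = [-1.0, 0.0]   # the complex number -1 + 0i, unit norm

score_identity = rotate_score(h, t, [0.0])       # no rotation: h and t are antipodal
score_half_turn = rotate_score(h, t, [math.pi])  # half turn maps h onto t
print(score_identity, score_half_turn)
```

The identity rotation hits the worst case of -2 for antipodal unit vectors, while the half-turn relation brings the score to approximately 0, the best possible value.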
References
Sun et al. “RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space” ICLR 2019.
Examples
>>> head = RotatEHead(embedding_dim=256, num_relations=4)
>>> scores = head(z, edge_index, relation_type)
- __init__(embedding_dim: int, num_relations: int, margin: float = 0.1, init_as_identity: bool = False)
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute edge scores using RotatE.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
relation_type (torch.Tensor) – Relation type for each edge [num_edges]
- Returns:
Edge scores [num_edges] (higher = more likely, range [-2, 0])
- Return type:
torch.Tensor
- scores_to_probs(scores: torch.Tensor) torch.Tensor
- loss_type = 'margin'
- class napistu_torch.models.heads.TransEHead(*args: Any, **kwargs: Any)
Bases: Module
TransE decoder for relation-aware edge prediction.
Models relations as translations in embedding space: h + r ≈ t. Simpler than RotatE and often easier to interpret.
Scoring function: score = -||h + r - t||
- Parameters:
embedding_dim (int) – Dimension of node embeddings from GNN
num_relations (int) – Number of distinct relation types
margin (float, optional) – Margin for ranking loss, by default 0.1
norm (int, optional) – Norm to use for distance (1 or 2), by default 2
Notes
Simpler than RotatE (fewer parameters, easier optimization)
Naturally handles asymmetric relations: h+r₁ vs h+r₂
May struggle with 1-to-N relations (e.g., one reaction → many products)
Good baseline before trying more complex heads
Requires normalized embeddings (||h|| = ||t|| = 1) for bounded distances
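A plain-Python sketch of the translation score with the two supported norms follows. The toy vectors are deliberately unnormalized so the arithmetic stays simple, and the function name is hypothetical.

```python
import math


def transe_score(h, r, t, norm=2):
    """score = -||h + r - t|| under the L1 or L2 norm."""
    diff = [a + b - c for a, b, c in zip(h, r, t)]
    if norm == 1:
        return -sum(abs(d) for d in diff)
    if norm == 2:
        return -math.sqrt(sum(d * d for d in diff))
    raise ValueError("norm must be 1 or 2")


h = [1.0, 0.0]
r = [2.0, 4.0]
t = [0.0, 0.0]

score_l2 = transe_score(h, r, t, norm=2)   # residual is [3, 4]
score_l1 = transe_score(h, r, t, norm=1)
score_reversed = transe_score(t, r, h, norm=2)  # residual is [1, 4]
print(score_l2, score_l1, score_reversed)
```

Reversing the edge changes the residual and therefore the score, which is why a non-zero translation makes the head asymmetric.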
References
Bordes et al. “Translating Embeddings for Modeling Multi-relational Data” NeurIPS 2013.
Examples
>>> head = TransEHead(embedding_dim=256, num_relations=4)
>>> scores = head(z, edge_index, relation_type)
- __init__(embedding_dim: int, num_relations: int, margin: float = 0.1, norm: int = 2, init_as_identity: bool = False)
- forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) torch.Tensor
Compute edge scores using TransE.
- Parameters:
node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
relation_type (torch.Tensor) – Relation type for each edge [num_edges]
- Returns:
Edge scores [num_edges] (higher = more likely)
- Return type:
torch.Tensor
- scores_to_probs(scores: torch.Tensor) torch.Tensor
- loss_type = 'margin'