napistu_torch.models.heads

Prediction heads for Napistu-Torch.

This module provides implementations of different prediction heads for various tasks like edge prediction, node classification, etc. All heads follow a consistent interface.

Classes

AttentionHead

Lightweight attention head for edge prediction.

ConditionalRotatEHead

Conditional RotatE head for relation-aware edge prediction.

DistMultHead

DistMult head for relation-aware edge prediction.

DotProductHead

Simple dot product head for edge prediction.

EdgeMLPHead

MLP-based head for edge prediction.

NodeClassificationHead

Head for node classification tasks.

RelationAttentionHead

Relation-aware attention head for edge prediction.

RelationGatedMLPHead

Relation-aware gated MLP head for edge prediction.

RelationAttentionMLPHead

Relation-aware attention-MLP hybrid head for edge prediction.

RotatEHead

RotatE head for relation-aware edge prediction.

TransEHead

TransE head for relation-aware edge prediction.

Decoder

Unified head decoder that can create different types of prediction heads.


class napistu_torch.models.heads.AttentionHead(*args: Any, **kwargs: Any)

Bases: Module

Lightweight attention head for edge prediction.

Projects nodes to query/key spaces and computes scaled dot-product attention scores. More expressive than a plain dot product, but much lighter than a full MLP with attention.

Architecture:
  1. Project source nodes → query space
  2. Project target nodes → key space
  3. Compute scaled dot product: (W_q @ src)^T @ (W_k @ tgt) / sqrt(d_attn)

Parameters:
  • embedding_dim (int) – Dimension of input node embeddings

  • attention_dim (int, optional) – Dimension of attention space (lower = more compression), by default 64

  • init_as_identity (bool, optional) – Initialize projections to approximate identity (dot product), by default False

Notes

  • Projects to lower dimension for efficiency and regularization

  • Learns separate query/key transformations (more flexible than dot product)

  • Scaling by 1/sqrt(d_attn) keeps scores from growing with the attention dimension, which helps avoid vanishing gradients

  • Normalizes embeddings for numerical stability

  • ~2 * embedding_dim * attention_dim parameters (e.g., ~16K for 128→64)

Comparison to other heads:
  • vs DotProduct: learns transformations (more expressive)

  • vs MLP: far fewer parameters, easier to interpret

  • vs RelationAttention: no relation-specific modulation

Examples

>>> head = AttentionHead(embedding_dim=128, attention_dim=64)
>>> scores = head(node_embeddings, edge_index)
>>> # scores ∈ ℝ^{num_edges}, approximately normalized by scaling
__init__(embedding_dim: int, attention_dim: int = 64, init_as_identity: bool = False)
_initialize_weights(init_as_identity: bool)

Initialize projection weights.

forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor

Compute attention-based edge scores.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings from encoder [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

Returns:

Edge scores [num_edges]

Return type:

torch.Tensor

Notes

Scores are computed as:

score = (W_q @ normalize(src))^T @ (W_k @ normalize(tgt)) / sqrt(d_attn)

Normalization ensures embeddings have bounded norms, preventing score explosion with pretrained encoders.
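A minimal PyTorch sketch of the scoring rule above follows. It is not the library's implementation: the class name AttentionScoreSketch is hypothetical, and bias-free nn.Linear projections combined with F.normalize are assumptions. It also confirms the parameter estimate from the Notes, since two 128→64 projections hold 2 × 128 × 64 = 16,384 weights.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionScoreSketch(nn.Module):
    """Hypothetical sketch of the documented AttentionHead scoring rule."""

    def __init__(self, embedding_dim: int, attention_dim: int = 64):
        super().__init__()
        # Separate query/key projections; bias=False is an assumption.
        self.query = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.key = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.scale = 1.0 / math.sqrt(attention_dim)

    def forward(self, node_embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # Unit-normalize so scores stay bounded with pretrained encoders.
        z = F.normalize(node_embeddings, p=2, dim=-1)
        src, tgt = edge_index  # each [num_edges]
        q = self.query(z[src])  # [num_edges, attention_dim]
        k = self.key(z[tgt])    # [num_edges, attention_dim]
        # Row-wise (W_q @ src)^T (W_k @ tgt), scaled by 1/sqrt(d_attn).
        return (q * k).sum(dim=-1) * self.scale


torch.manual_seed(0)
head = AttentionScoreSketch(embedding_dim=128, attention_dim=64)
scores = head(torch.randn(10, 128), torch.tensor([[0, 1, 2], [3, 4, 5]]))
n_params = sum(p.numel() for p in head.parameters())  # 2 * 128 * 64
```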

loss_type = 'bce'
class napistu_torch.models.heads.ConditionalRotatEHead(*args: Any, **kwargs: Any)

Bases: Module

Conditional decoder: DotProduct for symmetric relations, RotatE for asymmetric.

Automatically routes each relation type to an appropriate scoring function:
  • Symmetric relations (e.g., “protein->protein”): DotProduct distance

  • Asymmetric relations (e.g., “catalyst->modified”): RotatE rotation distance

Both heads produce distance-based scores in [-2, 0] for margin loss.

Parameters:
  • embedding_dim (int) – Dimension of node embeddings (must be even for RotatE complex embeddings)

  • num_relations (int) – Total number of relation types

  • symmetric_relation_indices (List[int]) – Indices of relations that should use dot product (symmetric). All other relations use RotatE (asymmetric). Obtained from NapistuData.analyze_relation_symmetry()

  • init_asymmetric_as_identity (bool, optional) – Initialize RotatE phases to 0 (identity rotation), by default False

  • margin (float, optional) – Margin for ranking loss (applied to both heads), by default 0.1

Notes

Score ranges (with normalized embeddings): both heads produce scores in [-2, 0]:
  • DotProduct: distance = 1 - similarity, score = -distance

  • RotatE: distance = ||h⊙r - t||, score = -distance

When to use:
  • The graph has a mix of symmetric (A↔B) and asymmetric (A→B) relations

  • Example: protein-protein interactions (symmetric) + catalysis (asymmetric)

When NOT to use:
  • All relations symmetric → use DotProduct or DistMult instead

  • All relations asymmetric → use RotatE or TransE instead

__init__(embedding_dim: int, num_relations: int, symmetric_relation_indices: List[int], init_asymmetric_as_identity: bool = False, margin: float = 0.1)
_compute_dot_scores(node_embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor

Compute dot product-based distance scores.

_compute_rotate_scores(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute RotatE rotation-based distance scores using shared utility.

forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute edge scores using conditional head selection.

Routes edges to DotProduct (symmetric) or RotatE (asymmetric) based on relation type, then returns unified distance-based scores.
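The routing step can be sketched with a boolean mask over relation_type. This is a hypothetical stand-alone function (conditional_scores is not part of napistu_torch). It assumes unit-normalized embeddings, cosine similarity for the DotProduct distance, a first-half/second-half real/imaginary split for RotatE, and a hypothetical phases tensor of shape [num_relations, embedding_dim // 2].

```python
import torch
import torch.nn.functional as F


def conditional_scores(
    node_embeddings: torch.Tensor,
    edge_index: torch.Tensor,
    relation_type: torch.Tensor,
    symmetric_relation_indices: list[int],
    phases: torch.Tensor,  # hypothetical: [num_relations, embedding_dim // 2]
) -> torch.Tensor:
    z = F.normalize(node_embeddings, p=2, dim=-1)
    h, t = z[edge_index[0]], z[edge_index[1]]
    scores = torch.empty(relation_type.shape[0], dtype=z.dtype)

    # True where the edge's relation type is listed as symmetric.
    sym_ids = torch.tensor(symmetric_relation_indices, dtype=relation_type.dtype)
    symmetric = torch.isin(relation_type, sym_ids)

    # Symmetric edges: distance = 1 - cosine similarity, score = -distance.
    cos_sim = (h[symmetric] * t[symmetric]).sum(dim=-1)
    scores[symmetric] = -(1.0 - cos_sim)

    # Asymmetric edges: rotate h by e^{i*theta}, score = -||h ⊙ r - t||.
    asym = ~symmetric
    half = z.shape[-1] // 2
    h_re, h_im = h[asym, :half], h[asym, half:]
    t_re, t_im = t[asym, :half], t[asym, half:]
    theta = phases[relation_type[asym]]
    cos, sin = torch.cos(theta), torch.sin(theta)
    diff_re = h_re * cos - h_im * sin - t_re
    diff_im = h_re * sin + h_im * cos - t_im
    scores[asym] = -torch.sqrt((diff_re**2 + diff_im**2).sum(dim=-1))
    return scores


torch.manual_seed(0)
z = torch.randn(6, 8)  # embedding_dim must be even for the RotatE branch
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
relation_type = torch.tensor([0, 1, 0, 1])  # relation 0 symmetric, relation 1 not
scores = conditional_scores(z, edge_index, relation_type, [0], torch.randn(2, 4))
```

With zero phases, a self-loop scores approximately 0 under both branches, which is a quick sanity check that the two scoring functions share the [-2, 0] range.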

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • relation_type (torch.Tensor) – Relation type for each edge [num_edges]

Returns:

Edge scores [num_edges] in [-2, 0], suitable for margin loss. Higher scores (closer to 0) indicate more likely edges.

Return type:

torch.Tensor

scores_to_probs(scores: torch.Tensor) -> torch.Tensor
loss_type = 'margin'
class napistu_torch.models.heads.Decoder(*args: Any, **kwargs: Any)

Bases: Module

Unified head decoder that can create different types of prediction heads.

This class provides a single interface for creating various head types (e.g., dot product, MLP, attention, node classification) with a from_config classmethod for easy integration with configuration systems.

Parameters:
  • hidden_channels (int) – Dimension of input node embeddings (should match GNN encoder output)

  • head_type (str) – Type of head to create (e.g., dot_product, mlp, attention, node_classification)

  • num_relations (int, optional) – Number of relation types (required for relation-aware heads)

  • symmetric_relation_indices (List[int], optional) – List of relation type indices that are symmetric. This is required for heads that support special symmetry handling.

  • num_classes (int, optional) – Number of output classes for node classification head

  • init_head_as_identity (bool, optional) – Whether to initialize the head to approximate an identity transformation, by default False

  • mlp_hidden_dim (int, optional) – Hidden layer dimension for MLP head, by default 128

  • mlp_num_layers (int, optional) – Number of hidden layers for MLP head, by default 2

  • mlp_dropout (float, optional) – Dropout probability for MLP head, by default 0.1

  • nc_dropout (float, optional) – Dropout probability for node classification head, by default 0.1

  • rotate_margin (float, optional) – Margin for RotatE head, by default 0.1

  • transe_margin (float, optional) – Margin for TransE head, by default 0.1

  • relation_emb_dim (int, optional) – Dimension of relation embeddings for relation-aware MLP heads, by default 64

  • relation_attention_heads (int, optional) – Number of attention heads for RelationAttentionMLP, by default 4

Public Methods

  • config (property) -> Dict[str, Any] – Get the configuration dictionary for this decoder.

  • from_config(config: ModelConfig, num_relations: Optional[int] = None, num_classes: Optional[int] = None, symmetric_relation_indices: Optional[List[int]] = None) -> Decoder – Create a Decoder from a ModelConfig instance.

  • forward(node_embeddings: torch.Tensor, edge_index: Optional[torch.Tensor] = None, relation_type: Optional[torch.Tensor] = None) -> torch.Tensor – Forward pass through the head.

  • supports_relations (property) -> bool – Check if this decoder supports relation-aware heads.

classmethod from_config(config: ModelConfig, num_relations: int | None = None, num_classes: int | None = None, symmetric_relation_indices: List[int] | None = None)

Create a Decoder from a configuration object.

Parameters:
  • config (ModelConfig) – Configuration object containing head parameters

  • num_relations (int, optional) – Number of relation types (required for relation-aware heads). This should be inferred from edge_strata.

  • num_classes (int, optional) – Number of output classes for node classification head (required for node classification head). This should be inferred from the data.

  • symmetric_relation_indices (List[int], optional) – List of relation type indices that are symmetric. This is required for heads that support special symmetry handling.

Returns:

Configured head decoder

Return type:

Decoder

__init__(hidden_channels: int, head_type: str = 'dot_product', num_relations: int | None = None, symmetric_relation_indices: List[int] | None = None, num_classes: int | None = None, init_head_as_identity: bool = False, mlp_hidden_dim: int = 128, mlp_num_layers: int = 2, mlp_dropout: float = 0.1, nc_dropout: float = 0.1, rotate_margin: float = 0.1, transe_margin: float = 0.1, relation_emb_dim: int = 64, relation_attention_heads: int = 4)
forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor | None = None, relation_type: torch.Tensor | None = None) -> torch.Tensor

Forward pass through the head.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]

  • edge_index (torch.Tensor, optional) – Edge connectivity [2, num_edges] (required for edge prediction heads)

  • relation_type (torch.Tensor, optional) – Relation type for each edge [num_edges] (required for relation-aware heads)

Returns:

Head output (edge scores or node predictions)

Return type:

torch.Tensor

get_summary() -> Dict[str, Any]

Get decoder metadata summary for checkpointing.

Returns essential metadata needed to reconstruct the decoder from a checkpoint, including ALL parameters that were used.

Returns:

Dictionary containing all initialization parameters, with None values filtered out for head-type-specific params

Return type:

Dict[str, Any]

property config: Dict[str, Any]

Get the configuration dictionary for this decoder.

Returns a dict containing all initialization parameters needed to reconstruct this decoder instance.

Returns:

Configuration dictionary with all __init__ parameters

Return type:

Dict[str, Any]

property loss_type: str

Get the loss type required by the underlying head.

Returns:

Loss type (e.g., LOSSES.BCE, LOSSES.MARGIN)

Return type:

str

property margin: float

Get the margin value for heads that support margin loss (e.g., RotatE, TransE, ConditionalRotatE).

Returns:

Margin value for ranking loss

Return type:

float

Raises:

AttributeError – If the underlying head does not have a margin attribute
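A training loop might branch on loss_type and margin roughly as follows. The compute_loss helper is hypothetical. It assumes 'bce' heads return logits and 'margin' heads return scores where higher means more likely, and it pairs each positive edge with one negative edge for torch.nn.functional.margin_ranking_loss.

```python
import torch
import torch.nn.functional as F


def compute_loss(
    loss_type: str,
    pos_scores: torch.Tensor,
    neg_scores: torch.Tensor,
    margin: float = 0.1,
) -> torch.Tensor:
    if loss_type == "bce":
        # Treat scores as logits: positive edges -> 1, negative edges -> 0.
        logits = torch.cat([pos_scores, neg_scores])
        labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
        return F.binary_cross_entropy_with_logits(logits, labels)
    if loss_type == "margin":
        # target = 1 asks each positive to beat its paired negative by >= margin.
        target = torch.ones_like(pos_scores)
        return F.margin_ranking_loss(pos_scores, neg_scores, target, margin=margin)
    raise ValueError(f"Unknown loss_type: {loss_type!r}")


pos = torch.tensor([-0.05, -0.10])  # distance-style scores near 0
neg = torch.tensor([-1.50, -1.20])
bce = compute_loss("bce", pos, neg)
ranking = compute_loss("margin", pos, neg, margin=0.1)  # every pair separated
```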

property supports_relations: bool

Check if this decoder supports relation-aware heads.

Returns:

True if the head type is in RELATION_AWARE_HEADS, False otherwise

Return type:

bool

class napistu_torch.models.heads.DistMultHead(*args: Any, **kwargs: Any)

Bases: Module

DistMult-style relation scoring for graph neural networks.

Adapted from knowledge-graph DistMult (Yang et al. 2015) to the GNN setting, where all nodes share one embedding space instead of having separate entity embeddings.

Score = mean(h ⊙ r ⊙ t), where h and t come from the same embedding space

Parameters:
  • embedding_dim (int) – Dimension of node embeddings from GNN

  • num_relations (int) – Number of distinct relation types

Notes

Key difference from original DistMult:
  • Original: separate embeddings per entity (h_aspirin, t_headache)

  • This version: shared node embedding space (all nodes use the same encoder)

Symmetry warning: like the original DistMult, this score is symmetric, score(h, r, t) = score(t, r, h), so it cannot distinguish directed relations unless combined with an asymmetric encoder or relation-specific directionality in the GNN.
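The symmetry warning can be checked numerically with a small sketch of the DistMult score. distmult_scores and the relation_weights tensor (one learned vector per relation) are hypothetical stand-ins for the head's internals; flipping edge_index swaps source and target without changing any score.

```python
import torch


def distmult_scores(
    node_embeddings: torch.Tensor,
    edge_index: torch.Tensor,
    relation_type: torch.Tensor,
    relation_weights: torch.Tensor,  # hypothetical: [num_relations, embedding_dim]
) -> torch.Tensor:
    h = node_embeddings[edge_index[0]]
    t = node_embeddings[edge_index[1]]
    r = relation_weights[relation_type]
    # score = mean(h ⊙ r ⊙ t) over the embedding dimension
    return (h * r * t).mean(dim=-1)


torch.manual_seed(0)
z = torch.randn(5, 16)
rel = torch.randn(3, 16)
edge_index = torch.tensor([[0, 1], [2, 3]])
relation_type = torch.tensor([0, 2])
forward_scores = distmult_scores(z, edge_index, relation_type, rel)
reverse_scores = distmult_scores(z, edge_index.flip(0), relation_type, rel)  # t -> h
```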

References

Yang et al. “Embedding Entities and Relations for Learning and Inference in Knowledge Bases” ICLR 2015.

Examples

>>> # Use only if relations are symmetric
>>> head = DistMultHead(embedding_dim=256, num_relations=4)
>>> scores = head(z, edge_index, relation_type)
__init__(embedding_dim: int, num_relations: int, init_as_identity: bool = False)
forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute edge scores using DistMult.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • relation_type (torch.Tensor) – Relation type for each edge [num_edges]

Returns:

Edge scores [num_edges]

Return type:

torch.Tensor

loss_type = 'bce'
class napistu_torch.models.heads.DotProductHead(*args: Any, **kwargs: Any)

Bases: Module

Dot product head for edge prediction.

Computes edge scores as the dot product of source and target node embeddings. This is the simplest and most efficient head for edge prediction tasks.
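A sketch of this scoring rule for a batch of edges, assuming raw (unnormalized) embeddings and that the scores are later treated as logits, consistent with loss_type = 'bce'. dot_product_scores is a hypothetical helper, not the class itself.

```python
import torch


def dot_product_scores(node_embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    src, tgt = edge_index
    # Row-wise dot product between source and target embeddings.
    return (node_embeddings[src] * node_embeddings[tgt]).sum(dim=-1)


z = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = torch.tensor([[0, 0, 2], [1, 2, 2]])
scores = dot_product_scores(z, edge_index)  # → [0.0, 2.0, 8.0]
probs = torch.sigmoid(scores)  # logits -> edge probabilities
```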

__init__()
forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor

Compute edge scores using dot product.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

Returns:

Edge scores [num_edges]

Return type:

torch.Tensor

loss_type = 'bce'
class napistu_torch.models.heads.EdgeMLPHead(*args: Any, **kwargs: Any)

Bases: Module

Multi-layer perceptron head for edge prediction.

Uses an MLP to predict edge scores from concatenated source and target embeddings. More expressive than dot product but requires more parameters.
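A sketch of an MLP edge scorer over [src || tgt], assuming ReLU activations and a final single-logit linear layer. The layer layout of EdgeMLPSketch is an assumption, not the library's exact architecture.

```python
import torch
import torch.nn as nn


class EdgeMLPSketch(nn.Module):
    def __init__(self, embedding_dim: int, hidden_dim: int = 64, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        layers: list[nn.Module] = []
        in_dim = 2 * embedding_dim  # concatenated [src || tgt]
        for _ in range(num_layers):
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, 1))  # one logit per edge
        self.mlp = nn.Sequential(*layers)

    def forward(self, node_embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, tgt = edge_index
        pair = torch.cat([node_embeddings[src], node_embeddings[tgt]], dim=-1)
        return self.mlp(pair).squeeze(-1)


torch.manual_seed(0)
head = EdgeMLPSketch(embedding_dim=32).eval()  # eval() disables dropout
z = torch.randn(8, 32)
edge_index = torch.tensor([[0, 1, 2], [3, 4, 5]])
with torch.no_grad():
    scores = head(z, edge_index)
    reverse = head(z, edge_index.flip(0))
```

Because the input is an ordered concatenation, this sketch generally scores (u, v) and (v, u) differently, unlike the dot product head.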

Parameters:
  • embedding_dim (int) – Dimension of input node embeddings

  • hidden_dim (int, optional) – Hidden layer dimension, by default 64

  • num_layers (int, optional) – Number of hidden layers, by default 2

  • dropout (float, optional) – Dropout probability, by default 0.1

__init__(embedding_dim: int, hidden_dim: int = 64, num_layers: int = 2, dropout: float = 0.1)
forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor

Compute edge scores using MLP.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

Returns:

Edge scores [num_edges]

Return type:

torch.Tensor

loss_type = 'bce'
class napistu_torch.models.heads.NodeClassificationHead(*args: Any, **kwargs: Any)

Bases: Module

Simple linear head for node classification tasks.
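A sketch of a dropout-then-linear classifier matching the documented parameters. NodeClassificationSketch is hypothetical and the dropout placement is an assumption; the logits pair naturally with cross-entropy loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NodeClassificationSketch(nn.Module):
    def __init__(self, embedding_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(embedding_dim, num_classes)

    def forward(self, node_embeddings: torch.Tensor) -> torch.Tensor:
        # [num_nodes, embedding_dim] -> [num_nodes, num_classes] logits
        return self.linear(self.dropout(node_embeddings))


torch.manual_seed(0)
head = NodeClassificationSketch(embedding_dim=64, num_classes=5)
logits = head(torch.randn(10, 64))
loss = F.cross_entropy(logits, torch.randint(0, 5, (10,)))
```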

Parameters:
  • embedding_dim (int) – Dimension of input node embeddings

  • num_classes (int) – Number of output classes

  • dropout (float, optional) – Dropout probability, by default 0.1

__init__(embedding_dim: int, num_classes: int, dropout: float = 0.1)
forward(node_embeddings: torch.Tensor) -> torch.Tensor

Compute node class predictions.

Parameters:

node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]

Returns:

Node class logits [num_nodes, num_classes]

Return type:

torch.Tensor

class napistu_torch.models.heads.RelationAttentionHead(*args: Any, **kwargs: Any)

Bases: Module

Lightweight relation-aware multi-head attention for edge prediction.

Simplified version of RelationAttentionMLPHead:
  • Directly projects nodes instead of using an edge MLP

  • Multi-head attention with relation queries

  • No residual connection or output MLP (lighter)

  • Still captures relation-specific feature selection

Architecture:
  1. Project nodes → attention space (replaces edge MLP)
  2. Relation embeddings → queries (multi-head)
  3. Node projections → keys, values (multi-head)
  4. Attention → weighted combination → score
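The four architecture steps can be sketched with torch.nn.MultiheadAttention. This is one hypothetical reading, not the library's code: RelationAttentionSketch treats the projected source and target nodes as a two-token key/value sequence for each edge, and the relation embedding supplies a single query token.

```python
import torch
import torch.nn as nn


class RelationAttentionSketch(nn.Module):
    def __init__(self, embedding_dim: int, num_relations: int, relation_emb_dim: int = 64,
                 hidden_dim: int = 128, num_attention_heads: int = 4):
        super().__init__()
        if hidden_dim % num_attention_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_attention_heads")
        self.node_proj = nn.Linear(embedding_dim, hidden_dim)          # step 1
        self.relation_emb = nn.Embedding(num_relations, relation_emb_dim)
        self.rel_to_query = nn.Linear(relation_emb_dim, hidden_dim)    # step 2
        self.attn = nn.MultiheadAttention(hidden_dim, num_attention_heads, batch_first=True)
        self.out = nn.Linear(hidden_dim, 1)                            # step 4

    def forward(self, node_embeddings: torch.Tensor, edge_index: torch.Tensor,
                relation_type: torch.Tensor) -> torch.Tensor:
        x = self.node_proj(node_embeddings)
        # Step 3: keys/values are the projected [src, tgt] nodes, shape [E, 2, hidden].
        kv = torch.stack([x[edge_index[0]], x[edge_index[1]]], dim=1)
        # One relation query token per edge, shape [E, 1, hidden].
        q = self.rel_to_query(self.relation_emb(relation_type)).unsqueeze(1)
        attended, _ = self.attn(q, kv, kv)  # [E, 1, hidden]
        return self.out(attended.squeeze(1)).squeeze(-1)


torch.manual_seed(0)
head = RelationAttentionSketch(embedding_dim=32, num_relations=3)
scores = head(torch.randn(6, 32), torch.tensor([[0, 1], [2, 3]]), torch.tensor([0, 2]))
```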

Parameters:
  • embedding_dim (int) – Dimension of input node embeddings

  • num_relations (int) – Number of distinct relation types

  • relation_emb_dim (int, optional) – Dimension of relation embeddings, by default 64

  • hidden_dim (int, optional) – Hidden dimension (must be divisible by num_attention_heads), by default 128

  • num_attention_heads (int, optional) – Number of attention heads, by default 4

Notes

Comparison to RelationAttentionMLPHead:
  • Much lighter: no edge MLP, no output MLP, no residual

  • Same core idea: relation queries edge features via attention

  • Parameters: ~50K vs ~100K (about half the size)

  • More interpretable: fewer non-linearities

Comparison to AttentionHead:
  • Adds relation-specific attention (like RelationAttentionMLPHead)

  • Multi-head for richer feature selection

  • More parameters but more expressive

__init__(embedding_dim: int, num_relations: int, relation_emb_dim: int = 64, hidden_dim: int = 128, num_attention_heads: int = 4)
_initialize_weights()

Initialize weights with sensible defaults.

Strategy:
  • Relation embeddings: small random values (std=0.1) to allow learning

  • Edge projection: Xavier with moderate gain to preserve information

  • Attention Q/K/V: standard Xavier to balance stability and expressiveness

  • Output: small gain to avoid initial saturation

forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute relation-aware multi-head attention scores.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • relation_type (torch.Tensor) – Relation type for each edge [num_edges]

Returns:

Edge scores [num_edges]

Return type:

torch.Tensor

loss_type = 'bce'
class napistu_torch.models.heads.RelationAttentionMLPHead(*args: Any, **kwargs: Any)

Bases: Module

Relation-attention MLP head for edge prediction.

Uses relation embeddings to query edge features via multi-head attention. The relation type determines WHICH aspects of edge features to attend to.

Architecture:
  1. Process [src || tgt] through edge MLP → hidden features
  2. Relation embedding → query
  3. Edge features → key, value
  4. Multi-head attention: relation queries edge features
  5. Residual connection + output MLP → edge score

Parameters:
  • embedding_dim (int) – Dimension of input node embeddings

  • num_relations (int) – Number of distinct relation types

  • relation_emb_dim (int, optional) – Dimension of relation embeddings, by default 64

  • hidden_dim (int, optional) – Hidden layer dimension (must be divisible by num_attention_heads), by default 128

  • num_layers (int, optional) – Number of layers in output MLP, by default 2

  • dropout (float, optional) – Dropout probability, by default 0.1

  • num_attention_heads (int, optional) – Number of attention heads, by default 4

Notes

  • More expressive than gating (can learn complex feature selection)

  • Different heads can specialize (e.g., one for catalysis, one for inhibition)

  • Attention is per-edge over hidden dimensions (not graph-level like GAT)

  • More parameters than gating but potentially better for diverse relation semantics

  • Reuses MLP hyperparameters from EdgeMLPHead for consistency

Examples

>>> head = RelationAttentionMLPHead(
...     embedding_dim=256,
...     num_relations=10,
...     relation_emb_dim=64,
...     hidden_dim=128,
...     num_attention_heads=4
... )
>>> scores = head(node_embeddings, edge_index, relation_type)
__init__(embedding_dim: int, num_relations: int, relation_emb_dim: int = 64, hidden_dim: int = 128, num_layers: int = 2, dropout: float = 0.1, num_attention_heads: int = 4)
_initialize_weights()

Initialize weights, optionally starting near identity.

forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute edge scores using relation-attention MLP.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • relation_type (torch.Tensor) – Relation type for each edge [num_edges]

Returns:

Edge scores [num_edges]

Return type:

torch.Tensor

loss_type = 'bce'
class napistu_torch.models.heads.RelationGatedMLPHead(*args: Any, **kwargs: Any)

Bases: Module

Relation-gated MLP head for edge prediction.

Uses relation embeddings to modulate edge features via element-wise gating. The relation type controls HOW the MLP processes each edge pair.

Architecture:
  1. Process [src || tgt] through edge MLP → hidden features
  2. Relation embedding → gate values (via Tanh)
  3. Modulate: gated_features = edge_features * relation_gates
  4. Final MLP → edge score
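The gating steps can be sketched as below. RelationGatedSketch is hypothetical; the single-layer edge MLP, ReLU activations, and the linear-plus-Tanh gate network are assumptions, and the documented num_layers parameter is omitted for brevity.

```python
import torch
import torch.nn as nn


class RelationGatedSketch(nn.Module):
    def __init__(self, embedding_dim: int, num_relations: int, relation_emb_dim: int = 64,
                 hidden_dim: int = 128, dropout: float = 0.1):
        super().__init__()
        # Step 1: edge MLP on [src || tgt].
        self.edge_mlp = nn.Sequential(nn.Linear(2 * embedding_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout))
        # Step 2: relation embedding -> per-feature gates in [-1, 1].
        self.relation_emb = nn.Embedding(num_relations, relation_emb_dim)
        self.gate = nn.Sequential(nn.Linear(relation_emb_dim, hidden_dim), nn.Tanh())
        # Step 4: final MLP -> one logit per edge.
        self.out = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))

    def forward(self, node_embeddings: torch.Tensor, edge_index: torch.Tensor,
                relation_type: torch.Tensor) -> torch.Tensor:
        pair = torch.cat([node_embeddings[edge_index[0]], node_embeddings[edge_index[1]]], dim=-1)
        features = self.edge_mlp(pair)
        gates = self.gate(self.relation_emb(relation_type))
        # Step 3: element-wise modulation of edge features by relation gates.
        return self.out(features * gates).squeeze(-1)


torch.manual_seed(0)
head = RelationGatedSketch(embedding_dim=32, num_relations=4)
scores = head(torch.randn(6, 32), torch.tensor([[0, 1, 2], [3, 4, 5]]), torch.tensor([0, 3, 1]))
all_gates = head.gate(head.relation_emb.weight)  # [num_relations, hidden_dim]
```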

Parameters:
  • embedding_dim (int) – Dimension of input node embeddings

  • num_relations (int) – Number of distinct relation types

  • relation_emb_dim (int, optional) – Dimension of relation embeddings, by default 64

  • hidden_dim (int, optional) – Hidden layer dimension, by default 128

  • num_layers (int, optional) – Number of layers in output MLP, by default 2

  • dropout (float, optional) – Dropout probability, by default 0.1

Notes

  • Handles imbalanced relation frequencies well (rare relations share parameters)

  • More parameter-efficient than separate MLPs per relation

  • Tanh gates lie in [-1, 1], so each hidden feature can be suppressed (gate near 0), passed through (gate near 1), or sign-flipped (negative gate)

  • Reuses MLP hyperparameters from EdgeMLPHead for consistency

Examples

>>> head = RelationGatedMLPHead(
...     embedding_dim=256,
...     num_relations=10,
...     relation_emb_dim=64,
...     hidden_dim=128
... )
>>> scores = head(node_embeddings, edge_index, relation_type)
__init__(embedding_dim: int, num_relations: int, relation_emb_dim: int = 64, hidden_dim: int = 128, num_layers: int = 2, dropout: float = 0.1)
forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute edge scores using relation-gated MLP.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • relation_type (torch.Tensor) – Relation type for each edge [num_edges]

Returns:

Edge scores [num_edges]

Return type:

torch.Tensor

loss_type = 'bce'
class napistu_torch.models.heads.RotatEHead(*args: Any, **kwargs: Any)

Bases: Module

RotatE decoder for relation-aware edge prediction.

Models relations as rotations in complex space: h ⊙ r ≈ t, where ⊙ is element-wise complex multiplication and each coordinate of r is a unit-modulus complex number e^{iθ}.

Scoring function: score = -||h ⊙ r - t||

Parameters:
  • embedding_dim (int) – Dimension of node embeddings from GNN (must be even for complex embeddings)

  • num_relations (int) – Number of distinct relation types

  • margin (float, optional) – Margin for ranking loss, by default 0.1

  • init_as_identity (bool, optional) – Initialize relations as identity rotations (angle=0), by default False

Notes

  • Embeddings are split into real/imaginary parts: [embedding_dim/2, embedding_dim/2]

  • Relations are phase angles that rotate head embeddings

  • Can model symmetric relations (phases θ ∈ {0, π}), inverse relations (θ₂ = -θ₁), and composition (θ₃ = θ₁ + θ₂)

  • Requires normalized embeddings (||h|| = ||t|| = 1) for bounded distances

  • Distance range with unit norm: [0, 2]

  • Score range: [-2, 0]
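A sketch of the RotatE score using the real/imaginary split described in the Notes. rotate and rotate_scores are hypothetical helpers, and the phases tensor of shape [num_relations, embedding_dim // 2] is an assumption about how relations are stored. Because every coordinate of r has unit modulus, rotation preserves norms, which is why unit-norm embeddings keep the distance in [0, 2].

```python
import math

import torch
import torch.nn.functional as F


def rotate(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Rotate x = [re || im] element-wise by unit complex numbers e^{i*theta}."""
    half = x.shape[-1] // 2
    re, im = x[..., :half], x[..., half:]
    cos, sin = torch.cos(theta), torch.sin(theta)
    # (re + i*im) * (cos + i*sin)
    return torch.cat([re * cos - im * sin, re * sin + im * cos], dim=-1)


def rotate_scores(node_embeddings: torch.Tensor, edge_index: torch.Tensor,
                  relation_type: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
    z = F.normalize(node_embeddings, p=2, dim=-1)
    h, t = z[edge_index[0]], z[edge_index[1]]
    # score = -||h ⊙ r - t|| over concatenated real/imaginary parts
    return -torch.linalg.vector_norm(rotate(h, phases[relation_type]) - t, dim=-1)


torch.manual_seed(0)
phases = torch.rand(2, 4) * 2 * math.pi  # [num_relations, embedding_dim // 2]
h = F.normalize(torch.randn(1, 8), dim=-1)
t = rotate(h, phases[0])  # t is exactly h rotated by relation 0
z = torch.cat([h, t, torch.randn(1, 8)])
scores = rotate_scores(z, torch.tensor([[0, 0], [1, 2]]), torch.tensor([0, 1]), phases)
```

Rotating by θ₁ and then θ₂ gives the same result as rotating by θ₁ + θ₂, which is the composition property listed in the Notes.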

References

Sun et al. “RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space” ICLR 2019.

Examples

>>> head = RotatEHead(embedding_dim=256, num_relations=4)
>>> scores = head(z, edge_index, relation_type)
__init__(embedding_dim: int, num_relations: int, margin: float = 0.1, init_as_identity: bool = False)
forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute edge scores using RotatE.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • relation_type (torch.Tensor) – Relation type for each edge [num_edges]

Returns:

Edge scores [num_edges] (higher = more likely, range [-2, 0])

Return type:

torch.Tensor

scores_to_probs(scores: torch.Tensor) -> torch.Tensor
loss_type = 'margin'
class napistu_torch.models.heads.TransEHead(*args: Any, **kwargs: Any)

Bases: Module

TransE decoder for relation-aware edge prediction.

Models relations as translations in embedding space: h + r ≈ t. Simpler than RotatE and often easier to interpret.

Scoring function: score = -||h + r - t||
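A sketch of the TransE score on a two-node toy example. transe_scores and relation_vectors are hypothetical names, and normalizing only the node embeddings (not the relation vectors) is an assumption. The reversed edge scores much lower than the forward edge, illustrating the asymmetry described in the Notes.

```python
import torch
import torch.nn.functional as F


def transe_scores(node_embeddings: torch.Tensor, edge_index: torch.Tensor,
                  relation_type: torch.Tensor, relation_vectors: torch.Tensor,
                  norm: int = 2) -> torch.Tensor:
    z = F.normalize(node_embeddings, p=2, dim=-1)
    h, t = z[edge_index[0]], z[edge_index[1]]
    r = relation_vectors[relation_type]
    # score = -||h + r - t||_p, with p given by `norm` (1 or 2)
    return -torch.linalg.vector_norm(h + r - t, ord=norm, dim=-1)


z = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
rel = torch.tensor([[-1.0, 1.0]])  # translates node 0 exactly onto node 1
edge_index = torch.tensor([[0, 1], [1, 0]])  # forward edge, then reversed edge
scores = transe_scores(z, edge_index, torch.tensor([0, 0]), rel)
l1_scores = transe_scores(z, edge_index, torch.tensor([0, 0]), rel, norm=1)
```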

Parameters:
  • embedding_dim (int) – Dimension of node embeddings from GNN

  • num_relations (int) – Number of distinct relation types

  • margin (float, optional) – Margin for ranking loss, by default 0.1

  • norm (int, optional) – Norm to use for distance (1 or 2), by default 2

Notes

  • Simpler than RotatE (fewer parameters, easier optimization)

  • Naturally models asymmetric relations: if h + r ≈ t, then t + r ≈ h only when r ≈ 0

  • May struggle with 1-to-N relations (e.g., one reaction → many products)

  • Good baseline before trying more complex heads

  • Requires normalized embeddings (||h|| = ||t|| = 1) for bounded distances

References

Bordes et al. “Translating Embeddings for Modeling Multi-relational Data” NeurIPS 2013.

Examples

>>> head = TransEHead(embedding_dim=256, num_relations=4)
>>> scores = head(z, edge_index, relation_type)
__init__(embedding_dim: int, num_relations: int, margin: float = 0.1, norm: int = 2, init_as_identity: bool = False)
forward(node_embeddings: torch.Tensor, edge_index: torch.Tensor, relation_type: torch.Tensor) -> torch.Tensor

Compute edge scores using TransE.

Parameters:
  • node_embeddings (torch.Tensor) – Node embeddings from GNN [num_nodes, embedding_dim]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • relation_type (torch.Tensor) – Relation type for each edge [num_edges]

Returns:

Edge scores [num_edges] (higher = more likely)

Return type:

torch.Tensor

scores_to_probs(scores: torch.Tensor) -> torch.Tensor
loss_type = 'margin'