napistu_torch.models.head_utils

Utility functions supporting a subset of heads.

This module provides utility functions for computing distances and probabilities used by various prediction heads, particularly relation-aware heads like RotatE.

Public Functions

compute_rotate_distance(head_embeddings, tail_embeddings, relation_phase, eps=1e-10)

Compute RotatE distance in complex space.

normalized_distances_to_probs(scores)

Convert distances between softmax-normalized vectors to probabilities.

validate_symmetric_relation_indices(symmetric_relation_indices, num_relations)

Validate that symmetric relation indices are properly configured.

Functions

compute_rotate_distance(head_embeddings, ...)

Compute RotatE distance in complex space.

normalized_distances_to_probs(scores)

Convert distances between softmax-normalized vectors to probabilities.

validate_symmetric_relation_indices(...)

Validate symmetric relation indices.

napistu_torch.models.head_utils.compute_rotate_distance(head_embeddings: torch.Tensor, tail_embeddings: torch.Tensor, relation_phase: torch.Tensor, eps: float = 1e-10) torch.Tensor

Compute RotatE distance in complex space.

Models relations as rotations: h ⊙ r ≈ t. The distance measures how well the rotation transforms h into t.

Parameters:
  • head_embeddings (torch.Tensor) – Source node embeddings of shape [num_edges, embedding_dim]. Must be normalized and have an even embedding dimension.

  • tail_embeddings (torch.Tensor) – Target node embeddings of shape [num_edges, embedding_dim]. Must be normalized and have an even embedding dimension.

  • relation_phase (torch.Tensor) – Rotation phase angles of shape [num_edges, embedding_dim/2], in radians, used for the complex rotation.

  • eps (float, optional) – Small constant for numerical stability, by default 1e-10

Returns:

Distances of shape [num_edges]. Values lie in [0, 2] for normalized embeddings.

Return type:

torch.Tensor

Notes

The computation follows RotatE (Sun et al. 2019):

  1. Split the embeddings into real and imaginary parts.
  2. Convert the phase to a complex rotation: r = cos(θ) + i*sin(θ).
  3. Complex multiply: h ⊙ r.
  4. Compute the L2 distance: ||h ⊙ r - t||.
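The four steps can be traced for a single edge in plain Python. This is a minimal sketch rather than the library's batched torch implementation. The name rotate_distance is hypothetical, and two details are assumptions not stated here: that the first half of each embedding holds the real parts and the second half the imaginary parts, and that eps is added inside the square root.

```python
import math


def rotate_distance(head, tail, phase, eps=1e-10):
    """RotatE-style distance for one edge (hypothetical helper).

    Assumes the first half of each embedding stores real parts and the
    second half stores imaginary parts; the library may use another layout.
    """
    half = len(head) // 2
    if len(head) != len(tail) or len(head) != 2 * half or len(phase) != half:
        raise ValueError("expected even-length embeddings and len(phase) == dim / 2")

    # Step 1: split into real and imaginary parts.
    h_re, h_im = head[:half], head[half:]
    t_re, t_im = tail[:half], tail[half:]

    squared = 0.0
    for k in range(half):
        # Step 2: phase -> unit complex number r = cos(theta) + i*sin(theta).
        r_re, r_im = math.cos(phase[k]), math.sin(phase[k])
        # Step 3: complex multiply h * r.
        rot_re = h_re[k] * r_re - h_im[k] * r_im
        rot_im = h_re[k] * r_im + h_im[k] * r_re
        # Step 4: accumulate the squared difference to the tail.
        squared += (rot_re - t_re[k]) ** 2 + (rot_im - t_im[k]) ** 2

    # Assumption: eps sits inside the square root to keep it away from zero.
    return math.sqrt(squared + eps)


head = [1.0, 0.0, 0.0, 0.0]  # complex vector (1+0i, 0+0i), unit norm
tail = [0.0, 0.0, 1.0, 0.0]  # complex vector (0+1i, 0+0i), unit norm

aligned = rotate_distance(head, tail, [math.pi / 2, 0.0])   # rotates head onto tail
opposed = rotate_distance(head, tail, [-math.pi / 2, 0.0])  # rotates head onto -tail
```

Because a rotation preserves the norm, unit-norm inputs give a distance close to 0 when the rotation maps h onto t (aligned) and close to 2 when it maps h onto -t (opposed), which matches the documented [0, 2] range.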

References

Sun et al. “RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space” ICLR 2019.

napistu_torch.models.head_utils.normalized_distances_to_probs(scores: torch.Tensor) torch.Tensor

Convert distances between softmax-normalized vectors to probabilities.

Parameters:

scores (torch.Tensor) – Raw RotatE scores, i.e. negative distances in [-2, 0]

Returns:

Probabilities in [0, 1]

Return type:

torch.Tensor
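The exact transform is not stated here. As an illustration only, the sketch below shows one mapping consistent with the documented ranges: a linear rescaling from [-2, 0] to [0, 1]. The name scores_to_probs_linear and the linear form are assumptions, and the library may use a different monotone function.

```python
def scores_to_probs_linear(scores):
    """Hypothetical linear rescaling of scores in [-2, 0] to [0, 1].

    Not necessarily what normalized_distances_to_probs does; it only
    illustrates a mapping that respects the documented input and output ranges.
    """
    probs = []
    for s in scores:
        p = 1.0 + s / 2.0  # score -2 -> 0.0, score 0 -> 1.0
        probs.append(min(1.0, max(0.0, p)))  # clamp against rounding overshoot
    return probs


probs = scores_to_probs_linear([-2.0, -1.0, 0.0])
```

Under this assumed mapping, a score of 0 (zero distance) corresponds to probability 1 and a score of -2 (maximum distance) to probability 0.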

napistu_torch.models.head_utils.validate_symmetric_relation_indices(symmetric_relation_indices: List[int] | tuple | range, num_relations: int) None

Validate symmetric relation indices.

Parameters:
  • symmetric_relation_indices (List[int] | tuple | range) – Indices to validate

  • num_relations (int) – Total number of relations

Raises:

ValueError – If the indices are invalid (duplicates, out-of-range values, or all or none of the relations marked symmetric)
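The three documented failure cases can be sketched as follows. This is a hedged re-implementation for illustration, not the library's code. The name check_symmetric_indices is hypothetical, and the sketch assumes zero-based indices in [0, num_relations).

```python
def check_symmetric_indices(symmetric_relation_indices, num_relations):
    """Hypothetical stand-in for validate_symmetric_relation_indices."""
    indices = list(symmetric_relation_indices)

    # Case 1: duplicates.
    if len(set(indices)) != len(indices):
        raise ValueError("symmetric_relation_indices contains duplicates")

    # Case 2: out of range (assumes zero-based indexing).
    out_of_range = [i for i in indices if not 0 <= i < num_relations]
    if out_of_range:
        raise ValueError(f"indices outside [0, {num_relations}): {out_of_range}")

    # Case 3: all or none of the relations marked symmetric.
    if len(indices) == 0 or len(indices) == num_relations:
        raise ValueError("expected a mix of symmetric and asymmetric relations")


check_symmetric_indices(range(2), num_relations=5)  # valid: returns None

errors = []
for bad in ([0, 0], [7], [], range(5)):
    try:
        check_symmetric_indices(bad, num_relations=5)
    except ValueError as exc:
        errors.append(str(exc))
```

Each of the four invalid inputs in the loop triggers one of the three checks, so errors ends up with four messages.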