napistu_torch.ml.metrics

Custom metrics for model evaluation.

Classes

RelationWeightedAUC(loss_weights, ...[, ...])

Compute per-relation AUC and weighted average AUC.

class napistu_torch.ml.metrics.RelationWeightedAUC(loss_weights: torch.Tensor, loss_weight_alpha: float, relation_manager: LabelingManager | None = None)

Bases: object

Compute per-relation AUC and weighted average AUC.

Computes:
  1. Overall AUC (all samples pooled, unweighted)
  2. Per-relation AUCs (one for each relation type)
  3. Relation-weighted AUC (average of per-relation AUCs, weighted by loss_weight × validation_count)
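
As a rough illustration of step 3, the sketch below averages per-relation AUCs using weights loss_weight × validation_count, normalized by their sum. It assumes plain Python sequences that are already aligned by relation. In the real class, loss_weights is a tensor indexed by relation type, so an indexing step would be needed first. relation_weighted_average is a hypothetical name, not part of napistu_torch.

```python
from typing import Sequence


def relation_weighted_average(
    per_relation_aucs: Sequence[float],
    per_relation_counts: Sequence[int],
    loss_weights: Sequence[float],
) -> float:
    """Hypothetical helper: weighted mean of per-relation AUCs.

    Assumes all three sequences are aligned by relation (same order).
    """
    if not (len(per_relation_aucs) == len(per_relation_counts) == len(loss_weights)):
        raise ValueError("inputs must be aligned by relation")
    # Effective weight of each relation: its loss weight times its validation count
    weights = [w * n for w, n in zip(loss_weights, per_relation_counts)]
    total = sum(weights)
    if total <= 0:
        raise ValueError("total weight must be positive")
    return sum(w * auc for w, auc in zip(weights, per_relation_aucs)) / total


# Two relations: 10 samples with loss weight 1.0, 30 samples with loss weight 0.5
result = relation_weighted_average([0.8, 0.6], [10, 30], [1.0, 0.5])
print(result)  # ≈ 0.68, i.e. (10 * 0.8 + 15 * 0.6) / 25
```

Because the validation count multiplies the loss weight, a heavily up-weighted relation with very few validation samples does not dominate the average on its own.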

Parameters:
  • loss_weights (torch.Tensor) – Pre-computed loss weights from training [num_relations]

  • loss_weight_alpha (float) – Alpha parameter used for loss weighting (for logging/reference)

  • relation_manager (LabelingManager, optional) – Manager for decoding relation type indices to human-readable names

Methods (Public):
  • compute(y_true, y_pred, relation_type) – Compute overall, per-relation, and weighted AUCs.

Examples

>>> # During validation
>>> rw_auc = RelationWeightedAUC(
...     loss_weights=task._relation_weights,
...     loss_weight_alpha=task.loss_weight_alpha,
...     relation_manager=data.relation_manager
... )
>>> metrics = rw_auc.compute(y_true, y_pred, relation_type)
>>> print(metrics)
{'auc': 0.85, 'auc_relation_weighted': 0.82, 'auc_catalysis': 0.78, 'auc_interaction': 0.88, 'auc_inhibition': 0.81}

__init__(loss_weights: torch.Tensor, loss_weight_alpha: float, relation_manager: LabelingManager | None = None)

compute(y_true: ndarray, y_pred: ndarray, relation_type: torch.Tensor) → Dict[str, float]

Compute overall, per-relation, and weighted AUCs.

Parameters:
  • y_true (np.ndarray) – True binary labels [num_samples]

  • y_pred (np.ndarray) – Predicted probabilities [num_samples]

  • relation_type (torch.Tensor) – Relation type index for each sample [num_samples]

Returns:

Dictionary containing:
  • 'auc': Overall AUC (all samples pooled)
  • 'auc_relation_weighted': Weighted average of per-relation AUCs
  • 'auc_{relation_name}': Per-relation AUC for each relation type

Return type:

Dict[str, float]
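
Because the per-relation entries share the 'auc_' prefix with the weighted score, a caller that wants them separately has to filter on key names. A minimal sketch, assuming the result contains only the three kinds of keys listed above; split_auc_metrics is a hypothetical helper, not part of napistu_torch.

```python
from typing import Dict, Tuple


def split_auc_metrics(
    metrics: Dict[str, float],
) -> Tuple[float, float, Dict[str, float]]:
    """Hypothetical helper: split a compute() result into
    (overall AUC, relation-weighted AUC, {relation_name: AUC})."""
    prefix = "auc_"
    per_relation = {
        key[len(prefix):]: value
        for key, value in metrics.items()
        # 'auc' has no trailing underscore, so only the weighted key needs excluding
        if key.startswith(prefix) and key != "auc_relation_weighted"
    }
    return metrics["auc"], metrics["auc_relation_weighted"], per_relation


# Illustrative values copied from the class example
overall, weighted, per_relation = split_auc_metrics({
    "auc": 0.85,
    "auc_relation_weighted": 0.82,
    "auc_catalysis": 0.78,
    "auc_interaction": 0.88,
    "auc_inhibition": 0.81,
})
print(per_relation)  # {'catalysis': 0.78, 'interaction': 0.88, 'inhibition': 0.81}
```

A relation whose decoded name is literally relation_weighted would collide with the weighted-average key; the sketch does not guard against that.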

napistu_torch.ml.metrics._compute_per_relation_aucs(y_true: ndarray, y_pred: ndarray, relation_type: torch.Tensor, unique_relations: ndarray, relation_manager: LabelingManager | None = None) → Tuple[List[float], List[int], Dict[str, float]]

Compute per-relation AUCs for each relation type.

Assumes all relations have been validated to have both classes present.

Parameters:
  • y_true (np.ndarray) – True binary labels [num_samples]

  • y_pred (np.ndarray) – Predicted probabilities [num_samples]

  • relation_type (torch.Tensor) – Relation type index for each sample [num_samples]

  • unique_relations (np.ndarray) – Unique relation type indices (in sorted order)

  • relation_manager (LabelingManager, optional) – Manager for decoding relation type indices to human-readable names

Returns:

  • per_relation_aucs (List[float]) – AUC for each relation type (in order of unique_relations)

  • per_relation_counts (List[int]) – Number of samples for each relation type (in order of unique_relations)

  • per_relation_results (Dict[str, float]) – Dictionary mapping 'auc_{relation_name}' to the AUC value for each relation
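
The loop implied by these return values can be sketched as follows. This is an illustration under stated assumptions, not the library's code. AUC is computed with the pairwise Mann–Whitney formulation, which gives the same value as sklearn.metrics.roc_auc_score for binary labels, to keep the example dependency-light. relation_type is taken as a NumPy array (a torch tensor can be converted with .cpu().numpy()), and a plain dict relation_names stands in for LabelingManager, whose decoding API is not shown here. Falling back to the raw index when no name is known is likewise a hypothetical choice.

```python
from typing import Dict, List, Optional, Tuple

import numpy as np


def binary_auc(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """ROC AUC as P(random positive scores above random negative), ties counting 0.5.

    Assumes both classes are present (validated beforehand).
    """
    pos = y_pred[y_true == 1]
    neg = y_pred[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive-minus-negative score gap
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def per_relation_aucs(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    relation_type: np.ndarray,
    unique_relations: np.ndarray,
    relation_names: Optional[Dict[int, str]] = None,  # hypothetical stand-in for LabelingManager
) -> Tuple[List[float], List[int], Dict[str, float]]:
    names = relation_names or {}
    aucs: List[float] = []
    counts: List[int] = []
    results: Dict[str, float] = {}
    for rel in unique_relations:  # outputs follow the order of unique_relations
        mask = relation_type == rel
        auc = binary_auc(y_true[mask], y_pred[mask])
        aucs.append(auc)
        counts.append(int(mask.sum()))
        # Hypothetical fallback: use the raw index when no name is known
        name = names.get(int(rel), str(int(rel)))
        results[f"auc_{name}"] = auc
    return aucs, counts, results


y_true = np.array([1, 0, 1, 0, 1, 0])
y_pred = np.array([0.9, 0.1, 0.4, 0.6, 0.8, 0.2])
relation_type = np.array([0, 0, 1, 1, 1, 1])
aucs, counts, results = per_relation_aucs(
    y_true, y_pred, relation_type, np.unique(relation_type), {0: "catalysis", 1: "inhibition"}
)
print(results)  # {'auc_catalysis': 1.0, 'auc_inhibition': 0.75}
```

The pairwise matrix grows as positives × negatives per relation, so for large validation sets a rank-based or library implementation is preferable.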

napistu_torch.ml.metrics._log_pathological_labels(y_true: ndarray, relation_type_np: ndarray, unique_relations: ndarray, relation_manager: LabelingManager | None = None) → None

Check for relations that lack one of the two label classes and raise an informative error.

Parameters:
  • y_true (np.ndarray) – True binary labels [num_samples]

  • relation_type_np (np.ndarray) – Relation type indices [num_samples]

  • unique_relations (np.ndarray) – Unique relation type indices

  • relation_manager (LabelingManager, optional) – Manager for decoding relation type indices to human-readable names

Raises:

ValueError – If any relation type is missing positive or negative samples (i.e., does not contain both classes)
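
A minimal sketch of the check described above, assuming labels are encoded as 0/1. check_both_classes_present and the relation_names dict (a stand-in for LabelingManager) are hypothetical, and the wording of the real function's error message is not documented here.

```python
from typing import Dict, List, Optional

import numpy as np


def check_both_classes_present(
    y_true: np.ndarray,
    relation_type_np: np.ndarray,
    unique_relations: np.ndarray,
    relation_names: Optional[Dict[int, str]] = None,  # hypothetical stand-in for LabelingManager
) -> None:
    names = relation_names or {}
    problems: List[str] = []
    for rel in unique_relations:
        labels = y_true[relation_type_np == rel]
        n_pos = int((labels == 1).sum())
        n_neg = int((labels == 0).sum())
        # AUC is undefined unless a relation has at least one sample of each class
        if n_pos == 0 or n_neg == 0:
            name = names.get(int(rel), str(int(rel)))
            problems.append(f"{name} (positives={n_pos}, negatives={n_neg})")
    if problems:
        raise ValueError(
            "AUC is undefined for relations without both classes: " + ", ".join(problems)
        )


try:
    check_both_classes_present(
        np.array([1, 0, 1, 1]),
        np.array([0, 0, 1, 1]),
        np.array([0, 1]),
        {0: "catalysis", 1: "inhibition"},
    )
except ValueError as err:
    print(err)
# AUC is undefined for relations without both classes: inhibition (positives=2, negatives=0)
```

Collecting every offending relation before raising reports all problems in one pass instead of stopping at the first.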