napistu_torch.ml.metrics
Custom metrics for model evaluation.
Classes
RelationWeightedAUC – Compute per-relation AUC and weighted average AUC.
- class napistu_torch.ml.metrics.RelationWeightedAUC(loss_weights: torch.Tensor, loss_weight_alpha: float, relation_manager: LabelingManager | None = None)
Bases: object

Compute per-relation AUC and weighted average AUC.
Computes:
1. Overall AUC (all samples pooled, unweighted)
2. Per-relation AUCs (one for each relation type)
3. Relation-weighted AUC (per-relation AUCs weighted by loss_weight × validation_count)
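The relation-weighted AUC in item 3 can be sketched as a standard weighted mean. This is a minimal illustration, not the library's code: it assumes plain Python lists aligned by relation (in place of the torch.Tensor and ndarray inputs) and that the combined weights are normalized by their sum. The function name relation_weighted_auc is hypothetical.

```python
from typing import Sequence


def relation_weighted_auc(
    per_relation_aucs: Sequence[float],
    loss_weights: Sequence[float],
    counts: Sequence[int],
) -> float:
    """Weighted mean of per-relation AUCs with weight loss_weight * count.

    Hypothetical sketch: assumes all three sequences list relations in the
    same order and that weights are normalized by their total.
    """
    if not (len(per_relation_aucs) == len(loss_weights) == len(counts)):
        raise ValueError("inputs must have one entry per relation")
    # Combined weight for each relation: loss weight times validation count.
    weights = [w * n for w, n in zip(loss_weights, counts)]
    total = sum(weights)
    if total <= 0:
        raise ValueError("combined weights must sum to a positive value")
    return sum(w * auc for w, auc in zip(weights, per_relation_aucs)) / total
```

For example, two relations with AUCs 0.75 and 0.5, loss weights 1.0 and 2.0, and 100 and 50 validation samples both receive a combined weight of 100, so the result is 0.625.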
- Parameters:
loss_weights (torch.Tensor) – Pre-computed loss weights from training [num_relations]
loss_weight_alpha (float) – Alpha parameter used for loss weighting (for logging/reference)
relation_manager (LabelingManager, optional) – Manager for decoding relation type indices to human-readable names
Methods (Public)
compute(y_true, y_pred, relation_type) – Compute overall, per-relation, and weighted AUCs.
Examples
>>> # During validation
>>> rw_auc = RelationWeightedAUC(
...     loss_weights=task._relation_weights,
...     loss_weight_alpha=task.loss_weight_alpha,
...     relation_manager=data.relation_manager,
... )
>>> metrics = rw_auc.compute(y_true, y_pred, relation_type)
>>> print(metrics)
{'auc': 0.85, 'auc_relation_weighted': 0.82, 'auc_catalysis': 0.78, 'auc_interaction': 0.88, 'auc_inhibition': 0.81}
- __init__(loss_weights: torch.Tensor, loss_weight_alpha: float, relation_manager: LabelingManager | None = None)
- compute(y_true: ndarray, y_pred: ndarray, relation_type: torch.Tensor) → Dict[str, float]
Compute overall, per-relation, and weighted AUCs.
- Parameters:
y_true (np.ndarray) – True binary labels [num_samples]
y_pred (np.ndarray) – Predicted probabilities [num_samples]
relation_type (torch.Tensor) – Relation type index for each sample [num_samples]
- Returns:
Dictionary containing:
  - 'auc': Overall AUC (all samples pooled)
  - 'auc_relation_weighted': Weighted average of per-relation AUCs
  - 'auc_{relation_name}': Per-relation AUC for each relation type
- Return type:
Dict[str, float]
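For reference, the pooled 'auc' value is the usual ROC AUC: the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counted as one half. The sketch below computes that definition directly. pooled_auc is a hypothetical helper written with plain lists and O(P × N) pairwise comparisons for clarity; this page does not say which AUC routine the module itself uses.

```python
from typing import Sequence


def pooled_auc(y_true: Sequence[int], y_pred: Sequence[float]) -> float:
    """ROC AUC as P(score of a positive > score of a negative), ties = 0.5.

    Hypothetical illustration using pairwise comparisons (Mann-Whitney form).
    """
    pos = [p for t, p in zip(y_true, y_pred) if t == 1]
    neg = [p for t, p in zip(y_true, y_pred) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined without both positive and negative samples")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0  # positive correctly ranked above negative
            elif p == n:
                wins += 0.5  # tie counts as half a correct ordering
    return wins / (len(pos) * len(neg))
```

With y_true = [1, 0, 1, 0] and y_pred = [0.9, 0.4, 0.35, 0.1], three of the four positive/negative pairs are ordered correctly, giving 0.75.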
- napistu_torch.ml.metrics._compute_per_relation_aucs(y_true: ndarray, y_pred: ndarray, relation_type: torch.Tensor, unique_relations: ndarray, relation_manager: LabelingManager | None = None) → Tuple[List[float], List[int], Dict[str, float]]
Compute per-relation AUCs for each relation type.
Assumes all relations have been validated to have both classes present.
- Parameters:
y_true (np.ndarray) – True binary labels [num_samples]
y_pred (np.ndarray) – Predicted probabilities [num_samples]
relation_type (torch.Tensor) – Relation type index for each sample [num_samples]
unique_relations (np.ndarray) – Unique relation type indices (in sorted order)
relation_manager (LabelingManager, optional) – Manager for decoding relation type indices to human-readable names
- Returns:
per_relation_aucs (List[float]) – AUC for each relation type (in order of unique_relations)
per_relation_counts (List[int]) – Number of samples for each relation type (in order of unique_relations)
per_relation_results (Dict[str, float]) – Dictionary mapping 'auc_{relation_name}' to the AUC value for each relation
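The per-relation step amounts to grouping samples by relation index and computing one AUC per group, in the order of unique_relations. In this sketch, per_relation_aucs and _pairwise_auc are hypothetical names, plain lists stand in for the array and tensor inputs, and a plain index-to-name dict stands in for LabelingManager, whose decoding API is not documented on this page. Falling back to the raw index when no mapping is given is likewise an assumption. As in the documented helper, every relation is assumed to contain both classes.

```python
from typing import Dict, List, Mapping, Optional, Sequence, Tuple


def _pairwise_auc(y_true: Sequence[int], y_pred: Sequence[float]) -> float:
    # Probability that a positive outscores a negative; ties count 0.5.
    pos = [p for t, p in zip(y_true, y_pred) if t == 1]
    neg = [p for t, p in zip(y_true, y_pred) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def per_relation_aucs(
    y_true: Sequence[int],
    y_pred: Sequence[float],
    relation_type: Sequence[int],
    unique_relations: Sequence[int],
    names: Optional[Mapping[int, str]] = None,  # hypothetical stand-in for LabelingManager
) -> Tuple[List[float], List[int], Dict[str, float]]:
    aucs: List[float] = []
    counts: List[int] = []
    results: Dict[str, float] = {}
    for rel in unique_relations:
        # Positions of the samples belonging to this relation type.
        idx = [i for i, r in enumerate(relation_type) if r == rel]
        auc = _pairwise_auc([y_true[i] for i in idx], [y_pred[i] for i in idx])
        aucs.append(auc)
        counts.append(len(idx))
        # Assumed naming: mapped name if available, else the raw index.
        name = names[rel] if names is not None else str(rel)
        results[f"auc_{name}"] = auc
    return aucs, counts, results
```

For instance, with relation_type = [0, 0, 1, 1], y_true = [1, 0, 1, 0], y_pred = [0.9, 0.2, 0.3, 0.6], and names {0: "catalysis", 1: "inhibition"}, the outputs are [1.0, 0.0], [2, 2], and {'auc_catalysis': 1.0, 'auc_inhibition': 0.0}.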
- napistu_torch.ml.metrics._log_pathological_labels(y_true: ndarray, relation_type_np: ndarray, unique_relations: ndarray, relation_manager: LabelingManager | None = None) → None
Check for relations missing an expected class and raise an informative error.
- Parameters:
y_true (np.ndarray) – True binary labels [num_samples]
relation_type_np (np.ndarray) – Relation type indices [num_samples]
unique_relations (np.ndarray) – Unique relation type indices
relation_manager (LabelingManager, optional) – Manager for decoding relation type indices to human-readable names
- Raises:
ValueError – If any relation type lacks positive or negative samples (i.e., does not contain both classes)
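A check in the spirit of _log_pathological_labels can collect every relation that lacks a positive or a negative label and report them together in one ValueError. The function name, the message text, and the index-to-name dict (standing in for LabelingManager) below are assumptions made for illustration, not the module's actual code.

```python
from typing import Mapping, Optional, Sequence


def check_relations_have_both_classes(
    y_true: Sequence[int],
    relation_type: Sequence[int],
    unique_relations: Sequence[int],
    names: Optional[Mapping[int, str]] = None,  # hypothetical stand-in for LabelingManager
) -> None:
    problems = []
    for rel in unique_relations:
        # Set of labels observed for this relation type.
        labels = {t for t, r in zip(y_true, relation_type) if r == rel}
        missing = [cls for cls in (0, 1) if cls not in labels]
        if missing:
            label = names[rel] if names is not None else str(rel)
            problems.append(f"{label} (missing label(s) {missing})")
    if problems:
        # Report all offending relations at once rather than failing on the first.
        raise ValueError(
            "AUC is undefined for relations without both positive and negative samples: "
            + ", ".join(problems)
        )
```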