napistu_torch.visualization.basic_metrics
Plots of basic model metrics, such as training loss and test/validation AUC.
Functions
plot_auc_only(summaries, display_names[, ...]) | Plot only test/val AUC comparison.
plot_model_comparison(summaries, display_names[, ...]) | Create comparison plots for model training loss and test/val AUC.
- napistu_torch.visualization.basic_metrics._extract_metric(summaries: Dict[str, Dict[str, Any]], metric_key: str) → List[float | None]
Extract a metric from all model summaries.
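A minimal sketch of what this helper plausibly does. It assumes, based only on the List[float | None] return type, that a summary lacking metric_key contributes None instead of raising. The name extract_metric_sketch is hypothetical; this is not the library's implementation.

```python
from typing import Any, Dict, List, Optional


def extract_metric_sketch(
    summaries: Dict[str, Dict[str, Any]], metric_key: str
) -> List[Optional[float]]:
    """Hypothetical stand-in for _extract_metric.

    Assumption: a summary without `metric_key` yields None instead of raising.
    Values come back in the iteration order of `summaries`, which is the
    order that `display_names` is expected to follow.
    """
    values: List[Optional[float]] = []
    for summary in summaries.values():
        value = summary.get(metric_key)
        values.append(None if value is None else float(value))
    return values


summaries = {
    "model1": {"test_auc": 0.75, "val_auc": 0.74},
    "model2": {"test_auc": 0.78},  # no val_auc recorded for this model
}
print(extract_metric_sketch(summaries, "val_auc"))  # [0.74, None]
```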
- napistu_torch.visualization.basic_metrics._plot_test_val_auc(ax: Axes, display_names: List[str], test_aucs: List[float], val_aucs: List[float], ylim: Tuple[float, float] | None = None, bar_width: float = 0.35, title: str | None = None, horizontal: bool = False) → None
Plot test and validation AUC as a grouped bar chart.
- Parameters:
ax (plt.Axes) – Matplotlib axes object to plot on
display_names (List[str]) – Model names for axis labels
test_aucs (List[float]) – Test AUC values for each model
val_aucs (List[float]) – Validation AUC values for each model
ylim (Optional[Tuple[float, float]]) – Axis limits for AUC values as (min, max). If None, calculated automatically
bar_width (float) – Width of each bar in the grouped bar chart
title (Optional[str]) – Custom title for the plot. If None, uses default title
horizontal (bool) – If True, create horizontal bars. If False (default), create vertical bars
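A common way to lay out a grouped bar chart like this is to center each model's pair of bars on an integer position and offset the test and validation bars by half of bar_width. The sketch below assumes that layout, plus a simple padded range (clipped to [0, 1]) for the automatic ylim. The helper names and the padding rule are hypothetical, not taken from the library, and the title handling is omitted. The drawing function only calls methods on the ax passed in, so the layout helpers run without matplotlib installed.

```python
from typing import List, Optional, Sequence, Tuple


def grouped_bar_positions(
    n_models: int, bar_width: float = 0.35
) -> Tuple[List[float], List[float]]:
    """Return (test_positions, val_positions), offset by -/+ bar_width / 2."""
    test_pos = [i - bar_width / 2 for i in range(n_models)]
    val_pos = [i + bar_width / 2 for i in range(n_models)]
    return test_pos, val_pos


def auto_auc_limits(
    test_aucs: Sequence[float], val_aucs: Sequence[float], pad: float = 0.05
) -> Tuple[float, float]:
    """Hypothetical automatic limits: data range plus `pad`, clipped to [0, 1]."""
    values = list(test_aucs) + list(val_aucs)
    low = max(0.0, min(values) - pad)
    high = min(1.0, max(values) + pad)
    return low, high


def draw_test_val_auc(ax, display_names, test_aucs, val_aucs,
                      ylim: Optional[Tuple[float, float]] = None,
                      bar_width: float = 0.35, horizontal: bool = False) -> None:
    """Draw grouped test/val AUC bars on an existing matplotlib Axes."""
    test_pos, val_pos = grouped_bar_positions(len(display_names), bar_width)
    limits = ylim if ylim is not None else auto_auc_limits(test_aucs, val_aucs)
    ticks = list(range(len(display_names)))
    if horizontal:
        # barh(y, width, height): bar length runs along the x-axis
        ax.barh(test_pos, test_aucs, bar_width, label="Test AUC")
        ax.barh(val_pos, val_aucs, bar_width, label="Val AUC")
        ax.set_yticks(ticks)
        ax.set_yticklabels(display_names)
        ax.set_xlim(*limits)
    else:
        # bar(x, height, width): bar length runs along the y-axis
        ax.bar(test_pos, test_aucs, bar_width, label="Test AUC")
        ax.bar(val_pos, val_aucs, bar_width, label="Val AUC")
        ax.set_xticks(ticks)
        ax.set_xticklabels(display_names)
        ax.set_ylim(*limits)
    ax.legend()
```

Note that in this sketch the "ylim" limits apply to whichever axis carries the AUC values, so they become x-limits when horizontal is True.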
- napistu_torch.visualization.basic_metrics._plot_train_loss(ax: Axes, display_names: List[str], train_losses: List[float], ylim: Tuple[float, float] | None = None) → None
Plot training loss as a bar chart.
- Parameters:
ax (plt.Axes) – Matplotlib axes object to plot on
display_names (List[str]) – Model names for x-axis labels
train_losses (List[float]) – Training loss values for each model
ylim (Optional[Tuple[float, float]]) – Y-axis limits as (min, max). If None, calculated automatically
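A minimal sketch of a single-series loss bar chart, assuming non-negative losses and a hypothetical automatic rule of starting the y-axis at 0 with 10% headroom above the tallest bar. The helper names and the headroom value are illustrative assumptions, not the library's behavior.

```python
from typing import Optional, Sequence, Tuple


def auto_loss_limits(
    train_losses: Sequence[float], headroom: float = 0.1
) -> Tuple[float, float]:
    """Hypothetical rule: start at 0 and leave `headroom` above the tallest bar.

    Assumes losses are non-negative.
    """
    top = max(train_losses)
    return 0.0, top * (1.0 + headroom)


def draw_train_loss(ax, display_names, train_losses,
                    ylim: Optional[Tuple[float, float]] = None) -> None:
    """Draw one bar per model on an existing matplotlib Axes."""
    positions = list(range(len(display_names)))
    ax.bar(positions, train_losses)
    ax.set_xticks(positions)
    ax.set_xticklabels(display_names)
    ax.set_ylabel("Training loss")
    # An explicit ylim wins; otherwise fall back to the hypothetical rule
    ax.set_ylim(*(ylim if ylim is not None else auto_loss_limits(train_losses)))
```

Starting a bar chart at zero keeps bar lengths proportional to the values. A narrower window exaggerates small differences between models, which is the trade-off an explicit ylim lets the caller choose.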
- napistu_torch.visualization.basic_metrics.plot_auc_only(summaries: Dict[str, Dict[str, Any]], display_names: List[str], figsize: Tuple[int, int] = (10, 6), test_auc_attribute: str = 'test_auc', val_auc_attribute: str = 'val_auc', title: str | None = None, horizontal: bool = True, ax: Axes | None = None, **kwargs) → Tuple[Figure, Axes]
Plot only test/val AUC comparison.
- Parameters:
summaries (Dict[str, Dict[str, Any]]) – Dictionary mapping model names to their summary metrics. Each summary must contain the keys named by test_auc_attribute and val_auc_attribute (by default ‘test_auc’ and ‘val_auc’).
display_names (List[str]) – Display names for the models, in the same order as summaries.keys()
figsize (Tuple[int, int]) – Figure size as (width, height)
test_auc_attribute (str) – Key of the test AUC value in each summary
val_auc_attribute (str) – Key of the validation AUC value in each summary
title (Optional[str]) – Custom title for the plot. If None, uses default title
horizontal (bool) – If True (default), create horizontal bars. If False, create vertical bars
ax (Optional[plt.Axes]) – Matplotlib axes object to plot on. If None, creates a new figure and axes.
**kwargs (dict) – Additional keyword arguments passed to _plot_test_val_auc (for example ylim or bar_width)
- Returns:
Figure and Axes objects
- Return type:
Tuple[plt.Figure, plt.Axes]
Examples
>>> import matplotlib.pyplot as plt
>>> from napistu_torch.visualization.basic_metrics import plot_auc_only
>>> summaries = {
...     'model1': {'test_auc': 0.75, 'val_auc': 0.74},
...     'model2': {'test_auc': 0.78, 'val_auc': 0.77},
... }
>>> display_names = ['Model 1', 'Model 2']
>>> fig, ax = plot_auc_only(summaries, display_names)
>>> plt.show()
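The two caller-facing mechanics of plot_auc_only, looking metrics up through test_auc_attribute and val_auc_attribute and reusing a caller's ax instead of creating a new figure, can be sketched as below. Both helpers are hypothetical. The length check on display_names is an assumption; the docstring only says the names must follow the order of summaries.keys(). matplotlib is imported lazily so that only the figure-creation path needs it.

```python
from typing import Any, Dict, List, Sequence, Tuple


def collect_auc_pairs(
    summaries: Dict[str, Dict[str, Any]],
    display_names: Sequence[str],
    test_auc_attribute: str = "test_auc",
    val_auc_attribute: str = "val_auc",
) -> Tuple[List[float], List[float]]:
    """Pull test/val AUC in summaries' key order, checking the inputs line up."""
    if len(display_names) != len(summaries):
        raise ValueError(
            f"{len(display_names)} display names for {len(summaries)} summaries"
        )
    test_aucs = [s[test_auc_attribute] for s in summaries.values()]
    val_aucs = [s[val_auc_attribute] for s in summaries.values()]
    return test_aucs, val_aucs


def get_figure_and_axes(ax=None, figsize=(10, 6)):
    """Reuse a caller's Axes (and the Figure that owns it) or create a new pair."""
    if ax is None:
        import matplotlib.pyplot as plt  # only needed when creating a figure
        return plt.subplots(figsize=figsize)
    return ax.figure, ax


summaries = {
    "model1": {"test_auc": 0.75, "val_auc": 0.74},
    "model2": {"test_auc": 0.78, "val_auc": 0.77},
}
print(collect_auc_pairs(summaries, ["Model 1", "Model 2"]))
# ([0.75, 0.78], [0.74, 0.77])
```

In this sketch, passing ax lets the AUC panel sit inside a figure the caller has already laid out, and the returned Figure is the one that owns that ax.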
- napistu_torch.visualization.basic_metrics.plot_model_comparison(summaries: Dict[str, Dict[str, Any]], display_names: List[str], figsize: Tuple[int, int] = (16, 6), train_loss_attribute: str = 'train_loss', test_auc_attribute: str = 'test_auc', val_auc_attribute: str = 'val_auc') → Tuple[Figure, Tuple[Axes, Axes]]
Create comparison plots for model training loss and test/val AUC.
- Parameters:
summaries (Dict[str, Dict[str, Any]]) – Dictionary mapping model names to their summary metrics. Each summary must contain the keys named by train_loss_attribute, test_auc_attribute, and val_auc_attribute (by default ‘train_loss’, ‘test_auc’, and ‘val_auc’).
display_names (List[str]) – Display names for the models, in the same order as summaries.keys()
figsize (Tuple[int, int]) – Figure size as (width, height)
train_loss_attribute (str) – Key of the training loss value in each summary
test_auc_attribute (str) – Key of the test AUC value in each summary
val_auc_attribute (str) – Key of the validation AUC value in each summary
- Returns:
The Figure and a tuple of Axes (ax1 for training loss, ax2 for test/val AUC)
- Return type:
Tuple[plt.Figure, Tuple[plt.Axes, plt.Axes]]
Examples
>>> import matplotlib.pyplot as plt
>>> from napistu_torch.visualization.basic_metrics import plot_model_comparison
>>> summaries = {
...     'model1': {'train_loss': 1.5, 'test_auc': 0.75, 'val_auc': 0.74},
...     'model2': {'train_loss': 1.2, 'test_auc': 0.78, 'val_auc': 0.77},
... }
>>> display_names = ['Model 1', 'Model 2']
>>> fig, (ax1, ax2) = plot_model_comparison(summaries, display_names)
>>> plt.show()
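Because all three metrics are looked up through train_loss_attribute, test_auc_attribute, and val_auc_attribute, a quick check can report which summaries lack a key before any plotting happens. Both helpers below are hypothetical and not part of the documented module. The one-row, two-panel plt.subplots(1, 2, ...) layout is an assumption consistent with the returned (ax1, ax2) tuple and the wide default figsize=(16, 6).

```python
from typing import Any, Dict, Iterable, List


def missing_metrics(
    summaries: Dict[str, Dict[str, Any]], required: Iterable[str]
) -> Dict[str, List[str]]:
    """Map each model name to the required keys its summary lacks."""
    keys = list(required)
    return {
        name: [key for key in keys if key not in summary]
        for name, summary in summaries.items()
        if any(key not in summary for key in keys)
    }


def make_comparison_axes(figsize=(16, 6)):
    """One row, two panels: ax1 for training loss, ax2 for test/val AUC."""
    import matplotlib.pyplot as plt  # only needed when actually drawing
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    return fig, (ax1, ax2)


summaries = {
    "model1": {"train_loss": 1.5, "test_auc": 0.75, "val_auc": 0.74},
    "model2": {"train_loss": 1.2, "test_auc": 0.78},  # val_auc missing
}
print(missing_metrics(summaries, ["train_loss", "test_auc", "val_auc"]))
# {'model2': ['val_auc']}
```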