napistu_torch.configs

Configuration classes for Napistu-Torch experiments.

This module provides Pydantic-based configuration classes for defining experiments, data loading, model architecture, tasks, training, and Weights & Biases integration.

Classes

DataConfig

Data loading and splitting configuration.

ModelConfig

Model architecture and component configuration.

TaskConfig

Task-specific configuration.

TrainingConfig

Training hyperparameters and settings.

WandBConfig

Weights & Biases integration configuration.

ExperimentConfig

Complete experiment configuration combining all component configs.

RunManifest

Manifest tracking experiment run metadata and artifacts.

Functions

config_to_data_trimming_spec(config)

Based on the config, return a dictionary of booleans indicating whether each attribute should be kept.

create_template_yaml(output_path[, ...])

Create a minimal YAML template file for experiment configuration.

task_config_to_artifact_names(task_config)

Convert a TaskConfig to a list of artifact names required by the task.

Classes

DataConfig(*, store_dir, sbml_dfs_path, ...)

Data loading and splitting configuration.

ExperimentConfig(*, name, seed, ...)

Top-level experiment configuration.

ModelConfig(*[, encoder, hidden_channels, ...])

Model architecture configuration.

RunManifest(*, created_at, wandb_run_id, ...)

Manifest file containing all information about a training run.

TaskConfig(*, task, metrics, ...)

Task-specific configuration.

TrainingConfig(*[, lr, weight_decay, ...])

Training hyperparameters.

WandBConfig(*, project, entity, group, tags, ...)

Weights & Biases configuration.

class napistu_torch.configs.DataConfig(*, store_dir: Path = PosixPath('.store'), sbml_dfs_path: Path | None = None, napistu_graph_path: Path | None = None, copy_to_store: bool = False, hf_repo_id: str | None = None, hf_revision: str | None = None, napistu_data_name: str = 'edge_prediction', other_artifacts: List[str] = <factory>)

Bases: BaseModel

Data loading and splitting configuration. These parameters are used to set up the NapistuDataStore object and construct the NapistuData object.

classmethod remove_deprecated_fields(data)

Remove deprecated fields from data before validation and warn about them.

validate_paths()

Validate that sbml_dfs_path and napistu_graph_path are either both None or both defined.
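
The both-or-neither rule can be illustrated with a small standalone check. This is only a sketch: check_paths_consistent is a hypothetical helper, not part of napistu_torch, and it assumes the two paths in question are sbml_dfs_path and napistu_graph_path.

```python
from pathlib import Path
from typing import Optional


def check_paths_consistent(
    sbml_dfs_path: Optional[Path],
    napistu_graph_path: Optional[Path],
) -> None:
    """Raise ValueError unless both paths are set or both are None."""
    # Hypothetical helper: the rule is violated exactly when one path
    # is None and the other is not.
    if (sbml_dfs_path is None) != (napistu_graph_path is None):
        raise ValueError(
            "sbml_dfs_path and napistu_graph_path must both be set or both be None"
        )


check_paths_consistent(None, None)  # accepted: neither path given
check_paths_consistent(Path("sbml.pkl"), Path("graph.pkl"))  # accepted: both given
```

Supplying only one of the two paths raises ValueError in this sketch.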

_abc_impl = <_abc._abc_data object>
copy_to_store: bool
hf_repo_id: str | None
hf_revision: str | None
model_config = {'extra': 'forbid'}

Pydantic model configuration; should be a dictionary conforming to pydantic.config.ConfigDict.

napistu_data_name: str
napistu_graph_path: Path | None
other_artifacts: List[str]
sbml_dfs_path: Path | None
store_dir: Path
class napistu_torch.configs.ExperimentConfig(*, name: str | None = None, seed: int = 42, deterministic: bool = True, output_dir: Path = PosixPath('output'), model: ModelConfig = <factory>, data: DataConfig = <factory>, task: TaskConfig = <factory>, training: TrainingConfig = <factory>, wandb: WandBConfig = <factory>, fast_dev_run: bool = False, limit_train_batches: float = 1.0, limit_val_batches: float = 1.0)

Bases: BaseModel

Top-level experiment configuration.

Public methods

anonymize(inplace: bool = False, placeholder: str = ANONYMIZATION_PLACEHOLDER_DEFAULT) -> "ExperimentConfig":

Create an anonymized copy of the config with all Path-like values masked.

from_json(filepath: Path) -> "ExperimentConfig":

Load from JSON file.

from_yaml(filepath: Path) -> "ExperimentConfig":

Load from YAML file.

get_experiment_name() -> str:

Generate a descriptive experiment name based on model and task configs.

to_dict() -> dict:

Export to plain dictionary.

to_json(filepath: Path) -> None:

Save to JSON file.

to_yaml(filepath: Path) -> None:

Save to YAML file.

classmethod from_json(filepath: Path)

Load from a JSON file.

classmethod from_yaml(filepath: Path)

Load from a YAML file.

anonymize(inplace: bool = False, placeholder: str = '[REDACTED]') -> ExperimentConfig

Create an anonymized copy of the config with all Path-like values masked.

Replaces all Path objects and absolute path strings with a placeholder string. Useful for sharing configs without exposing local file paths.

Parameters:
  • inplace (bool, default=False) – If True, modifies the config in place. If False, returns a new config.

  • placeholder (str, default="[REDACTED]") – String to use as placeholder for masked paths.

Returns:

Anonymized config (new instance if inplace=False, self if inplace=True)

Return type:

ExperimentConfig

Examples

>>> config = ExperimentConfig(
...     output_dir=Path("/Users/me/experiments/run1"),
...     data=DataConfig(
...         sbml_dfs_path=Path("/Users/me/data/sbml.pkl"),
...         napistu_graph_path=Path("/Users/me/data/graph.pkl")
...     )
... )
>>> anonymized = config.anonymize()
>>> str(anonymized.output_dir)
'[REDACTED]'
>>> str(anonymized.data.sbml_dfs_path)
'[REDACTED]'
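
The masking behaviour described above can be approximated on a plain nested dictionary. The following is a minimal sketch, not the library's implementation: mask_paths is a hypothetical helper, and it assumes that "absolute path strings" means strings that parse as absolute POSIX paths.

```python
from pathlib import Path, PurePath, PurePosixPath
from typing import Any


def mask_paths(value: Any, placeholder: str = "[REDACTED]") -> Any:
    """Return a copy of `value` with path-like entries replaced by `placeholder`."""
    # Path objects are always masked
    if isinstance(value, PurePath):
        return placeholder
    # Strings are masked only when they look like absolute POSIX paths
    if isinstance(value, str) and PurePosixPath(value).is_absolute():
        return placeholder
    # Recurse into containers, building new ones so the input is untouched
    if isinstance(value, dict):
        return {key: mask_paths(item, placeholder) for key, item in value.items()}
    if isinstance(value, list):
        return [mask_paths(item, placeholder) for item in value]
    return value


config = {
    "output_dir": Path("/Users/me/experiments/run1"),
    "seed": 42,
    "data": {"sbml_dfs_path": "/Users/me/data/sbml.pkl", "napistu_data_name": "edge_prediction"},
}
masked = mask_paths(config)
```

Because the sketch builds new containers, it corresponds to the inplace=False behaviour: the original dictionary keeps its paths.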

get_experiment_name() -> str

Generate a descriptive experiment name based on model and task configs.

to_dict()

Export to a plain dictionary.

to_json(filepath: Path)

Save to a JSON file.

to_yaml(filepath: Path)

Save to a YAML file.

_abc_impl = <_abc._abc_data object>
data: DataConfig
deterministic: bool
fast_dev_run: bool
limit_train_batches: float
limit_val_batches: float
model: ModelConfig
model_config = {'extra': 'forbid'}

Pydantic model configuration; should be a dictionary conforming to pydantic.config.ConfigDict.

name: str | None
output_dir: Path
seed: int
task: TaskConfig
training: TrainingConfig
wandb: WandBConfig
class napistu_torch.configs.ModelConfig(*, encoder: str = 'sage', hidden_channels: Annotated[int, Gt(gt=0)] = 128, num_layers: Annotated[int, Ge(ge=1), Le(le=10)] = 3, dropout: Annotated[float, Ge(ge=0.0), Lt(lt=1.0)] = 0.1, head: str = 'dot_product', init_head_as_identity: bool | None = False, gat_heads: Annotated[int | None, Gt(gt=0)] = 4, gat_concat: bool | None = True, graph_conv_aggregator: str | None = 'mean', sage_aggregator: str | None = 'mean', mlp_hidden_dim: Annotated[int | None, Gt(gt=0)] = 128, mlp_num_layers: Annotated[int | None, Ge(ge=1)] = 2, mlp_dropout: Annotated[float | None, Ge(ge=0.0), Lt(lt=1.0)] = 0.1, nc_dropout: Annotated[float | None, Ge(ge=0.0), Lt(lt=1.0)] = 0.1, rotate_margin: Annotated[float | None, Gt(gt=0.0)] = 0.1, transe_margin: Annotated[float | None, Gt(gt=0.0)] = 0.1, relation_emb_dim: Annotated[int | None, Gt(gt=0)] = 64, relation_attention_heads: Annotated[int | None, Gt(gt=0)] = 4, use_edge_encoder: bool | None = False, edge_encoder_dim: Annotated[int | None, Gt(gt=0)] = 32, edge_encoder_dropout: Annotated[float | None, Ge(ge=0.0), Lt(lt=1.0)] = 0.1, edge_encoder_init_bias: float | None = None, use_pretrained_model: bool | None = False, pretrained_model_source: str | None = None, pretrained_model_path: str | None = None, pretrained_model_revision: str | None = None, pretrained_model_load_head: bool | None = True, pretrained_model_freeze_encoder_weights: bool | None = False, pretrained_model_freeze_head_weights: bool | None = False)

Bases: BaseModel

Model architecture configuration.

Public methods

get_architecture_string() -> str:

Get a string representation of the model architecture.

__repr__() -> str:

Return a formatted string representation of the model architecture.

classmethod remove_deprecated_fields(data)

Remove deprecated fields from data before validation and warn about them.

classmethod validate_encoder(v, info)
classmethod validate_head(v, info)
classmethod validate_power_of_2(v)

Optionally enforce power of 2 for efficiency
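
A power-of-2 test is commonly written as a bit trick. The sketch below shows that check in isolation; is_power_of_2 is a hypothetical name, and which fields the validator applies to, and whether it raises or only warns, is not stated here.

```python
def is_power_of_2(n: int) -> bool:
    """Return True when n is a positive integer power of 2."""
    # A power of 2 has exactly one bit set, so n & (n - 1) clears it to zero
    return n > 0 and (n & (n - 1)) == 0


# Which sizes from a small candidate list pass the check
candidates = [1, 64, 96, 128]
passing = [n for n in candidates if is_power_of_2(n)]
```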

get_architecture_string() -> str

Generate a string representation of the model architecture.

Returns the encoder, head, hidden channels, and number of layers in the format “encoder-head_h{hidden_channels}_l{num_layers}”.

Returns:

Architecture string like “sage-dot_product_h128_l3” or “graph_conv-mlp_h64_l2”

Return type:

str

Examples

>>> config = ModelConfig(encoder="sage", head="dot_product", hidden_channels=128, num_layers=3)
>>> config.get_architecture_string()
'sage-dot_product_h128_l3'
>>> config = ModelConfig(encoder="graph_conv", head="mlp", hidden_channels=64, num_layers=2)
>>> config.get_architecture_string()
'graph_conv-mlp_h64_l2'
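
The documented format can be reproduced with a single f-string. This is a standalone sketch of the naming pattern only; architecture_string is a hypothetical function that takes the four values directly rather than reading them from a ModelConfig.

```python
def architecture_string(
    encoder: str, head: str, hidden_channels: int, num_layers: int
) -> str:
    """Build a name of the form encoder-head_h{hidden_channels}_l{num_layers}."""
    return f"{encoder}-{head}_h{hidden_channels}_l{num_layers}"


name = architecture_string("sage", "dot_product", 128, 3)
```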

validate_pretrained_model()

Validate that pretrained model settings are provided when use_pretrained_model=True.

_abc_impl = <_abc._abc_data object>
dropout: float
edge_encoder_dim: int | None
edge_encoder_dropout: float | None
edge_encoder_init_bias: float | None
encoder: str
gat_concat: bool | None
gat_heads: int | None
graph_conv_aggregator: str | None
head: str
hidden_channels: int
init_head_as_identity: bool | None
mlp_dropout: float | None
mlp_hidden_dim: int | None
mlp_num_layers: int | None
model_config = {'extra': 'forbid'}

Pydantic model configuration; should be a dictionary conforming to pydantic.config.ConfigDict.

nc_dropout: float | None
num_layers: int
pretrained_model_freeze_encoder_weights: bool | None
pretrained_model_freeze_head_weights: bool | None
pretrained_model_load_head: bool | None
pretrained_model_path: str | None
pretrained_model_revision: str | None
pretrained_model_source: str | None
relation_attention_heads: int | None
relation_emb_dim: int | None
rotate_margin: float | None
sage_aggregator: str | None
transe_margin: float | None
use_edge_encoder: bool | None
use_pretrained_model: bool | None
class napistu_torch.configs.RunManifest(*, created_at: datetime = <factory>, wandb_run_id: str | None = None, wandb_run_url: str | None = None, wandb_project: str | None = None, wandb_entity: str | None = None, experiment_name: str | None = None, experiment_config: ExperimentConfig)

Bases: BaseModel

Manifest file containing all information about a training run.

This is a wrapper around the ExperimentConfig that includes WandB information and a timestamp.

Public methods

from_yaml(filepath: Path) -> "RunManifest":

Load manifest from YAML file.

get_run_summary() -> dict:

Get summary metrics from WandB for this experiment.

to_yaml(filepath: Path) -> None:

Save manifest to YAML file.

classmethod from_huggingface(model_loader: HFModelLoader, repo_id: str) -> RunManifest

Reconstruct RunManifest from HuggingFace artifacts.

Loads experiment_config from config.json and WandB metadata from wandb_run_info.yaml (if available).

Parameters:
  • model_loader (HFModelLoader) – Loader instance with downloaded artifacts

  • repo_id (str) – HuggingFace repository ID (for fallback experiment name)

Returns:

Reconstructed manifest

Return type:

RunManifest

Examples

>>> from napistu_torch.ml.hugging_face import HFModelLoader
>>> loader = HFModelLoader("username/model-name")
>>> manifest = RunManifest.from_huggingface(loader, "username/model-name")

classmethod from_yaml(filepath: Path) -> RunManifest

Load manifest from YAML file.

Parameters:

filepath (Path) – Path to the YAML file

Returns:

Loaded manifest object with experiment_config as ExperimentConfig instance

Return type:

RunManifest

get_run_summary() -> dict

Get summary metrics from WandB for this experiment.

Retrieves the summary metrics (final values) from the WandB run associated with this experiment.

Returns:

Dictionary containing summary metrics from WandB (e.g., final validation AUC, training loss, etc.)

Return type:

dict

Raises:
  • ValueError – If WandB run ID is not available

  • RuntimeError – If WandB API access fails

Examples

>>> manifest = RunManifest.from_yaml("run_manifest.yaml")
>>> summary = manifest.get_run_summary()
>>> print(summary["val_auc"])  # Final validation AUC

to_yaml(filepath: Path) -> None

Save manifest to YAML file.

Parameters:

filepath (Path) – Path where the YAML file will be written

_abc_impl = <_abc._abc_data object>
created_at: datetime
experiment_config: ExperimentConfig
experiment_name: str | None
model_config = {}

Pydantic model configuration; should be a dictionary conforming to pydantic.config.ConfigDict.

wandb_entity: str | None
wandb_project: str | None
wandb_run_id: str | None
wandb_run_url: str | None
class napistu_torch.configs.TaskConfig(*, task: str = 'edge_prediction', metrics: List[str] = <factory>, edge_prediction_neg_sampling_ratio: Annotated[float, Gt(gt=0)] = 1.0, edge_prediction_neg_sampling_stratify_by: str = 'node_type', edge_prediction_neg_sampling_strategy: str = 'degree_weighted', weight_loss_by_relation_frequency: bool = False, loss_weight_alpha: Annotated[float, Ge(ge=0.0), Le(le=1.0)] = 0.5)

Bases: BaseModel

Task-specific configuration.

classmethod validate_task(v)
_abc_impl = <_abc._abc_data object>
edge_prediction_neg_sampling_ratio: float
edge_prediction_neg_sampling_strategy: str
edge_prediction_neg_sampling_stratify_by: str
loss_weight_alpha: float
metrics: List[str]
model_config = {'extra': 'forbid'}

Pydantic model configuration; should be a dictionary conforming to pydantic.config.ConfigDict.

task: str
weight_loss_by_relation_frequency: bool
class napistu_torch.configs.TrainingConfig(*, lr: Annotated[float, Gt(gt=0)] = 0.001, weight_decay: Annotated[float, Ge(ge=0)] = 0.0, optimizer: str = 'adam', scheduler: str | None = None, gradient_clip_val: Annotated[float | None, Ge(ge=0.0)] = None, epochs: Annotated[int, Gt(gt=0)] = 200, batches_per_epoch: Annotated[int, Gt(gt=0)] = 1, accelerator: str = 'auto', devices: int = 1, precision: Literal[16, 32, '16-mixed', '32-true'] = 32, early_stopping: bool = True, early_stopping_patience: Annotated[int, Ge(ge=1)] = 20, early_stopping_metric: str = 'val_auc', save_checkpoints: bool = True, checkpoint_subdir: str = 'checkpoints', checkpoint_metric: str = 'val_auc', score_distribution_monitoring: bool = False, score_distribution_monitoring_log_every_n_epochs: Annotated[int, Ge(ge=1)] = 10, embedding_norm_monitoring: bool = False, embedding_norm_monitoring_log_every_n_epochs: Annotated[int, Ge(ge=1)] = 10, weight_monitoring: bool = True)

Bases: BaseModel

Training hyperparameters.

Public methods

get_checkpoint_dir(output_dir: Path) -> Path:

Get absolute checkpoint directory.

classmethod validate_optimizer(v)
classmethod validate_scheduler(v)
get_checkpoint_dir(output_dir: Path) -> Path

Get absolute checkpoint directory.
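
One plausible reading is that the checkpoint directory is checkpoint_subdir joined onto the experiment's output_dir and made absolute. The sketch below illustrates that reading only; checkpoint_dir is a hypothetical function, and the actual method may resolve the path differently.

```python
from pathlib import Path


def checkpoint_dir(output_dir: Path, checkpoint_subdir: str = "checkpoints") -> Path:
    """Join the subdirectory onto output_dir and anchor it at the working directory."""
    # Assumption: absolute() is used, so symlinks are not resolved
    return (output_dir / checkpoint_subdir).absolute()


ckpt = checkpoint_dir(Path("output"))
```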

_abc_impl = <_abc._abc_data object>
accelerator: str
batches_per_epoch: int
checkpoint_metric: str
checkpoint_subdir: str
devices: int
early_stopping: bool
early_stopping_metric: str
early_stopping_patience: int
embedding_norm_monitoring: bool
embedding_norm_monitoring_log_every_n_epochs: int
epochs: int
gradient_clip_val: float | None
lr: float
model_config = {'extra': 'forbid'}

Pydantic model configuration; should be a dictionary conforming to pydantic.config.ConfigDict.

optimizer: str
precision: Literal[16, 32, '16-mixed', '32-true']
save_checkpoints: bool
scheduler: str | None
score_distribution_monitoring: bool
score_distribution_monitoring_log_every_n_epochs: int
weight_decay: float
weight_monitoring: bool
class napistu_torch.configs.WandBConfig(*, project: str = 'napistu-experiments', entity: str | None = 'napistu', group: str | None = 'baseline', tags: List[str] = <factory>, log_model: bool = False, mode: str = 'online', wandb_subdir: str = 'logs')

Bases: BaseModel

Weights & Biases configuration.

Public methods

get_enhanced_tags(model_config: ModelConfig, task_config: TaskConfig) -> List[str]:

Get tags with model and task-specific additions.

get_save_dir(output_dir: Path) -> Path:

Get absolute wandb save directory.

classmethod validate_mode(v)
get_enhanced_tags(model_config: ModelConfig, task_config: TaskConfig) -> List[str]

Get tags with model- and task-specific additions.

get_save_dir(output_dir: Path) -> Path

Get absolute wandb save directory.

_abc_impl = <_abc._abc_data object>
entity: str | None
group: str | None
log_model: bool
mode: str
model_config = {'extra': 'forbid'}

Pydantic model configuration; should be a dictionary conforming to pydantic.config.ConfigDict.

project: str
tags: List[str]
wandb_subdir: str
napistu_torch.configs._remove_deprecated_fields(data, deprecated_fields: dict[str, str], config_class_name: str)

Remove deprecated fields from data before validation and warn about them.

This allows old configs to load while maintaining strict validation (extra="forbid").

Parameters:
  • data (dict or Any) – The data dictionary to process

  • deprecated_fields (dict[str, str]) – Dictionary mapping deprecated field names to their deprecation messages

  • config_class_name (str) – Name of the config class (for warning messages)

Returns:

The data with deprecated fields removed

Return type:

dict or Any
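
The idea of stripping deprecated keys before strict validation can be sketched without Pydantic. The helper below is hypothetical (drop_deprecated is not the library function), and the warning category and message format are assumptions.

```python
import warnings
from typing import Any, Dict


def drop_deprecated(
    data: Any, deprecated_fields: Dict[str, str], config_class_name: str
) -> Any:
    """Return `data` without deprecated keys, warning once per removed key."""
    # Non-dict input is passed through untouched for later validation
    if not isinstance(data, dict):
        return data
    cleaned = dict(data)  # shallow copy so the caller's dict is not mutated
    for field, message in deprecated_fields.items():
        if field in cleaned:
            cleaned.pop(field)
            warnings.warn(
                f"{config_class_name}: '{field}' is deprecated and was ignored. {message}",
                DeprecationWarning,
                stacklevel=2,
            )
    return cleaned


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "old_flag" is a made-up field name used only for illustration
    cleaned = drop_deprecated(
        {"lr": 0.01, "old_flag": True},
        {"old_flag": "Use new_flag instead."},
        "TrainingConfig",
    )
```

With the deprecated key gone, a validator that forbids extra fields would accept the remaining dictionary.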

napistu_torch.configs._task_config_to_artifact_names_edge_prediction(task_config: TaskConfig) -> List[str]

Convert a TaskConfig to a list of artifact names for edge prediction.

napistu_torch.configs.config_to_data_trimming_spec(config: ExperimentConfig) -> Dict[str, bool]

Based on the config, return a dictionary of booleans indicating whether each attribute should be kept.

Parameters:

config (ExperimentConfig) – The experiment configuration

Returns:

A dictionary with keys "keep_edge_attr", "keep_labels", "keep_masks", and "keep_relation_type" whose boolean values indicate whether each attribute should be kept. These match the arguments to NapistuData.trim().

Return type:

Dict[str, bool]

napistu_torch.configs.create_template_yaml(output_path: Path, sbml_dfs_path: Path | None = None, napistu_graph_path: Path | None = None, name: str | None = None) -> None

Create a minimal YAML template file for experiment configuration.

This creates a clean, minimal YAML file with only:
  • Required data paths (sbml_dfs_path, napistu_graph_path)

  • Experiment metadata (name)

  • Common configuration options (without default values)

Users can then customize this template without all the default values cluttering the file.

Parameters:
  • output_path (Path) – Path where the YAML template file will be written

  • sbml_dfs_path (Optional[Path], default=None) – Path to the SBML_dfs pickle file. If None, uses a placeholder.

  • napistu_graph_path (Optional[Path], default=None) – Path to the NapistuGraph pickle file. If None, uses a placeholder.

  • name (Optional[str], default=None) – Experiment name. If None, omits the name field.

Examples

>>> from pathlib import Path
>>> from napistu_torch.configs import create_template_yaml
>>>
>>> # Create a template with explicit data paths and an experiment name
>>> create_template_yaml(
...     output_path=Path("config.yaml"),
...     sbml_dfs_path=Path("data/sbml_dfs.pkl"),
...     napistu_graph_path=Path("data/graph.pkl"),
...     name="my_experiment"
... )

napistu_torch.configs.task_config_to_artifact_names(task_config: TaskConfig) -> List[str]

Convert a TaskConfig to a list of artifact names required by the task.

Parameters:

task_config (TaskConfig) – Task configuration object

Returns:

List of artifact names required by the task

Return type:

List[str]

Examples

>>> from napistu_torch.configs import TaskConfig, task_config_to_artifact_names
>>> task_config = TaskConfig(
...     task="edge_prediction",
...     edge_prediction_neg_sampling_stratify_by="edge_strata_by_node_type"
... )
>>> artifacts = task_config_to_artifact_names(task_config)
>>> print(artifacts)
['edge_strata_by_node_type']