napistu_torch.load.checkpoints
Checkpoint loading and validation utilities.
This module provides utilities for loading and validating pretrained Napistu-Torch models.
Classes
- Checkpoint
Manager for PyTorch Lightning checkpoint loading and validation.
- DataMetadata
Metadata about the NapistuData used during training.
- EdgeEncoderMetadata
Metadata about the edge encoder component.
- EncoderMetadata
Metadata about the encoder component.
- HeadMetadata
Metadata about the head component.
- ModelMetadata
Metadata about the complete model architecture.
- CheckpointHyperparameters
Hyperparameters stored in the checkpoint.
- CheckpointStructure
Structure definition for checkpoint validation.
- class napistu_torch.load.checkpoints.Checkpoint(checkpoint_dict: Dict[str, Any])
Bases: object
Manager for PyTorch Lightning checkpoint loading and validation.
This class handles loading checkpoints, extracting metadata, validating compatibility with current data, and reconstructing model components.
- Parameters:
checkpoint_dict (Dict[str, Any]) – PyTorch Lightning checkpoint dictionary (validated via Pydantic)
Methods
-------
assert_same_napistu_data(napistu_data: NapistuData) -> None – Validate that current NapistuData is compatible with checkpoint.
load(checkpoint_path: Union[str, Path], map_location: str = DEVICE.CPU) -> Checkpoint – Load and validate a checkpoint from a local file.
get_encoder_config() -> Dict[str, Any] – Get encoder configuration as dictionary.
get_head_config() -> Dict[str, Any] – Get head configuration as dictionary.
get_data_summary() -> Dict[str, Any] – Get data summary as dictionary.
update_model_config(model_config: ModelConfig, inplace: bool = True) -> ModelConfig – Update a ModelConfig instance with settings from the checkpoint.
Private Methods
---------------
_update_model_config_with_encoder(model_config: ModelConfig, inplace: bool = True) -> Optional[ModelConfig] – Update a ModelConfig instance with encoder configuration from checkpoint.
_update_model_config_with_head(model_config: ModelConfig, inplace: bool = True) -> Optional[ModelConfig] – Update a ModelConfig instance with head configuration from checkpoint.
Examples
>>> # Load from local file (automatically validated)
>>> checkpoint = Checkpoint.load("path/to/checkpoint.ckpt")
>>>
>>> # Validate compatibility with current data
>>> checkpoint.assert_same_napistu_data(current_data)
>>>
>>> # Access validated configurations
>>> encoder_config = checkpoint.encoder_metadata
>>> head_config = checkpoint.head_metadata
>>> data_config = checkpoint.data_metadata
- classmethod load(checkpoint_path: str | Path, map_location: str = 'cpu') → Checkpoint
Load and validate a checkpoint from a local file.
- Parameters:
checkpoint_path (Union[str, Path]) – Path to the checkpoint file (.ckpt)
map_location (str, optional) – Device to load tensors to, by default ‘cpu’
- Returns:
Loaded and validated checkpoint object
- Return type:
Checkpoint
- Raises:
FileNotFoundError – If checkpoint file doesn’t exist
RuntimeError – If checkpoint loading fails
ValidationError – If checkpoint structure is invalid
Examples
>>> checkpoint = Checkpoint.load("model.ckpt")
>>> checkpoint = Checkpoint.load("model.ckpt", map_location="cuda:0")
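load is documented to raise FileNotFoundError when the file is missing and RuntimeError when loading itself fails. The sketch below shows one way to honor that error contract using only the standard library. The names load_checkpoint_dict and pickle_loader are hypothetical, pickle stands in for the PyTorch loader so the example runs without torch, and the map_location argument and Pydantic validation step of the real method are deliberately omitted. It is an illustration, not the library's implementation.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Union


def pickle_loader(path: Path) -> Dict[str, Any]:
    # Hypothetical stand-in for a PyTorch loader; lets the sketch run without torch
    with path.open("rb") as fh:
        return pickle.load(fh)


def load_checkpoint_dict(
    checkpoint_path: Union[str, Path],
    loader: Callable[[Path], Dict[str, Any]] = pickle_loader,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the documented error contract of Checkpoint.load."""
    path = Path(checkpoint_path)
    if not path.is_file():
        # Documented behavior: FileNotFoundError if the checkpoint file doesn't exist
        raise FileNotFoundError(f"Checkpoint file not found: {path}")
    try:
        return loader(path)
    except Exception as exc:
        # Documented behavior: RuntimeError if checkpoint loading fails
        raise RuntimeError(f"Failed to load checkpoint {path}: {exc}") from exc


# Usage: round-trip a toy checkpoint dict through a temporary file
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = Path(tmp_dir) / "model.ckpt"
    ckpt_path.write_bytes(pickle.dumps({"state_dict": {"w": [1.0]}, "epoch": 3}))
    loaded = load_checkpoint_dict(ckpt_path)

    # An empty file makes the loader fail, which surfaces as RuntimeError
    empty_path = Path(tmp_dir) / "empty.ckpt"
    empty_path.write_bytes(b"")
    try:
        load_checkpoint_dict(empty_path)
    except RuntimeError as err:
        load_error = str(err)
```

Wrapping the loader error with "raise ... from exc" keeps the original traceback attached, so the underlying cause stays visible when the RuntimeError is reported.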
- __init__(checkpoint_dict: Dict[str, Any])
Initialize Checkpoint from a checkpoint dictionary.
- Parameters:
checkpoint_dict (Dict[str, Any]) – PyTorch Lightning checkpoint dictionary
- Raises:
ValidationError – If checkpoint structure is invalid
- _update_model_config_with_encoder(model_config: ModelConfig, inplace: bool = True) → ModelConfig | None
Update a ModelConfig instance with encoder configuration from checkpoint.
Updates encoder-related fields in the ModelConfig based on the checkpoint’s encoder metadata. This is useful when reconstructing a model from a checkpoint.
- Parameters:
model_config (ModelConfig) – ModelConfig instance to update
inplace (bool, optional) – If True, modify the ModelConfig in place. If False, create a copy and return it. Default is True.
- Returns:
The updated ModelConfig instance. If inplace=True, returns the same object. If inplace=False, returns a new ModelConfig instance.
- Return type:
ModelConfig | None
Examples
>>> checkpoint = Checkpoint.load("model.ckpt")
>>> model_config = ModelConfig()
>>> checkpoint._update_model_config_with_encoder(model_config)
>>> # model_config now has encoder settings from checkpoint
>>>
>>> # Create a copy instead
>>> updated_config = checkpoint._update_model_config_with_encoder(model_config, inplace=False)
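The inplace flag described above decides whether the caller's ModelConfig is mutated or a copy is returned. The sketch below illustrates those semantics with a hypothetical ToyModelConfig dataclass and a hypothetical update_with_encoder function. The field names are illustrative, not ModelConfig's real fields, and copying matching keys is an assumption about how the update works.

```python
import copy
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ToyModelConfig:
    # Hypothetical stand-in for ModelConfig; these field names are illustrative only
    encoder: str = "sage"
    hidden_channels: int = 64
    num_layers: int = 2


def update_with_encoder(
    model_config: ToyModelConfig,
    encoder_metadata: Dict[str, Any],
    inplace: bool = True,
) -> ToyModelConfig:
    """Copy matching encoder fields onto the config (hypothetical sketch)."""
    # inplace=True mutates the caller's object; inplace=False works on a deep copy
    target = model_config if inplace else copy.deepcopy(model_config)
    for field_name, value in encoder_metadata.items():
        # Skip metadata keys that the config does not define
        if hasattr(target, field_name):
            setattr(target, field_name, value)
    return target


saved_encoder = {"encoder": "gat", "hidden_channels": 128, "num_layers": 3, "in_channels": 10}
base = ToyModelConfig()
copied = update_with_encoder(base, saved_encoder, inplace=False)
encoder_after_copy = base.encoder  # still "sage": only the copy changed
same = update_with_encoder(base, saved_encoder)  # mutates and returns base
```

A deep copy keeps nested fields (lists, sub-configs) from being shared between the original and the returned object. _update_model_config_with_head follows the same inplace pattern for head fields.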
- _update_model_config_with_head(model_config: ModelConfig, inplace: bool = True) → ModelConfig | None
Update a ModelConfig instance with head configuration from checkpoint.
Updates head-related fields in the ModelConfig based on the checkpoint’s head metadata. This is useful when reconstructing a model from a checkpoint.
- Parameters:
model_config (ModelConfig) – ModelConfig instance to update
inplace (bool, optional) – If True, modify the ModelConfig in place. If False, create a copy and return it. Default is True.
- Returns:
The updated ModelConfig instance. If inplace=True, returns the same object. If inplace=False, returns a new ModelConfig instance.
- Return type:
ModelConfig | None
Examples
>>> checkpoint = Checkpoint.load("model.ckpt")
>>> model_config = ModelConfig()
>>> checkpoint._update_model_config_with_head(model_config)
>>> # model_config now has head settings from checkpoint
>>>
>>> # Create a copy instead
>>> updated_config = checkpoint._update_model_config_with_head(model_config, inplace=False)
- assert_same_napistu_data(napistu_data: NapistuData) → None
Validate that current NapistuData is compatible with checkpoint.
Compares the data summary from the checkpoint with a summary generated from the provided NapistuData object.
- Parameters:
napistu_data (NapistuData) – Current NapistuData object to validate against checkpoint
- Raises:
ValueError – If data summaries don’t match
Examples
>>> checkpoint = Checkpoint.load("model.ckpt")
>>> checkpoint.assert_same_napistu_data(current_data)
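assert_same_napistu_data compares two data summaries and raises ValueError on a mismatch. The sketch below shows that comparison on plain dictionaries and reports every differing key. assert_same_summary is a hypothetical helper, and the summary keys are illustrative. Which fields the library actually compares, and how its error message is worded, are not documented here.

```python
from typing import Any, Dict


class _Missing:
    # Sentinel shown in error messages when a key exists on only one side
    def __repr__(self) -> str:
        return "<missing>"


_MISSING = _Missing()


def assert_same_summary(saved: Dict[str, Any], current: Dict[str, Any]) -> None:
    """Raise ValueError naming every key whose value differs (hypothetical sketch)."""
    mismatches = []
    for key in sorted(set(saved) | set(current)):
        saved_value = saved.get(key, _MISSING)
        current_value = current.get(key, _MISSING)
        if saved_value != current_value:
            mismatches.append(f"{key}: checkpoint={saved_value!r}, current={current_value!r}")
    if mismatches:
        raise ValueError("Data summaries don't match:\n  " + "\n  ".join(mismatches))


saved_summary = {"name": "demo", "num_nodes": 100, "num_edges": 250}
assert_same_summary(saved_summary, dict(saved_summary))  # identical summaries pass silently
try:
    assert_same_summary(saved_summary, {"name": "demo", "num_nodes": 100, "num_edges": 300})
except ValueError as err:
    mismatch_message = str(err)
```

Collecting every mismatch before raising, rather than stopping at the first, makes a single failed check enough to see everything that changed between the training data and the current data.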
- get_data_summary() → Dict[str, Any]
Get data summary as dictionary.
- Returns:
Data summary dictionary
- Return type:
Dict[str, Any]
- get_edge_encoder_config() → Dict[str, Any]
Get edge encoder configuration as dictionary.
- Returns:
Edge encoder configuration dictionary
- Return type:
Dict[str, Any]
- get_encoder_config() → Dict[str, Any]
Get encoder configuration as dictionary.
- Returns:
Encoder configuration dictionary
- Return type:
Dict[str, Any]
- get_environment_info() → Dict[str, Any]
Get environment information as dictionary.
- Returns:
Environment information dictionary
- Return type:
Dict[str, Any]
- get_head_config() → Dict[str, Any]
Get head configuration as dictionary.
- Returns:
Head configuration dictionary
- Return type:
Dict[str, Any]
- update_model_config(model_config: ModelConfig, inplace: bool = True) → ModelConfig
Update a ModelConfig instance with settings from the checkpoint.
Updates encoder configuration and optionally head configuration from the checkpoint metadata. This is useful when reconstructing a model from a checkpoint or when loading a pretrained model.
The head is only updated if pretrained_model_load_head is True in the model_config.
- Parameters:
model_config (ModelConfig) – ModelConfig instance to update
inplace (bool, optional) – If True, modify the ModelConfig in place. If False, create a copy and return it. Default is True.
- Returns:
The updated ModelConfig instance. If inplace=True, returns the same object (modified in place). If inplace=False, returns a new ModelConfig instance with updated settings.
- Return type:
ModelConfig
Examples
>>> checkpoint = Checkpoint.load("model.ckpt")
>>> model_config = ModelConfig()
>>> checkpoint.update_model_config(model_config)
>>> # model_config now has encoder and head settings from checkpoint
>>>
>>> # Create a copy instead
>>> updated_config = checkpoint.update_model_config(model_config, inplace=False)
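update_model_config always restores encoder settings but restores head settings only when pretrained_model_load_head is True. The sketch below reproduces that branching on a hypothetical ToyModelConfig dataclass. Only pretrained_model_load_head comes from the documentation. The other fields, the head labels, and the helper names (_apply, update_from_checkpoint) are illustrative assumptions.

```python
import copy
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ToyModelConfig:
    # Hypothetical stand-in for ModelConfig; only pretrained_model_load_head comes from the docs
    encoder: str = "sage"
    head: str = "default_head"
    pretrained_model_load_head: bool = False


def _apply(target: ToyModelConfig, settings: Dict[str, Any]) -> None:
    # Copy only the keys the config actually defines
    for name, value in settings.items():
        if hasattr(target, name):
            setattr(target, name, value)


def update_from_checkpoint(
    config: ToyModelConfig,
    encoder_meta: Dict[str, Any],
    head_meta: Dict[str, Any],
    inplace: bool = True,
) -> ToyModelConfig:
    """Hypothetical sketch of the encoder-always, head-on-request update order."""
    target = config if inplace else copy.deepcopy(config)
    _apply(target, encoder_meta)  # encoder settings are always restored
    if target.pretrained_model_load_head:
        _apply(target, head_meta)  # head settings only when explicitly requested
    return target


encoder_meta = {"encoder": "gat"}
head_meta = {"head": "checkpoint_head"}
encoder_only = update_from_checkpoint(ToyModelConfig(), encoder_meta, head_meta)
with_head = update_from_checkpoint(
    ToyModelConfig(pretrained_model_load_head=True), encoder_meta, head_meta
)
```

Keeping the head optional lets a pretrained encoder be reused under a new head, for example when fine-tuning on a different task.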
- class napistu_torch.load.checkpoints.CheckpointHyperparameters(*, config: Dict[str, Any], model: ModelMetadata, data: DataMetadata, environment: EnvironmentInfo | None = None, **extra_data: Any)
Bases: BaseModel
Validated hyperparameters structure from Lightning checkpoint.
This validates the checkpoint[‘hyper_parameters’] structure.
- classmethod from_task_and_data(task: Any, napistu_data: NapistuData, training_config: TrainingConfig | None = None, capture_environment: bool = True, extra_packages: list[str] | None = None) → Dict[str, Any]
Create hyperparameters dict from task and data, with validation.
This is used by SetHyperparameters to build the hyperparameters dict that will be saved to the checkpoint. The dict is validated against the CheckpointHyperparameters schema before returning.
- Parameters:
task (Any) – Task object with get_summary() method that returns model metadata
napistu_data (NapistuData) – NapistuData object to extract data metadata from
training_config (Optional[TrainingConfig], optional) – Training configuration object (usually set by Lightning’s save_hyperparameters()), by default None
capture_environment (bool, optional) – Whether to capture Python environment info, by default True
extra_packages (Optional[list[str]], optional) – Additional package names to capture versions for, by default None
- Returns:
Validated hyperparameters dictionary ready to be saved to checkpoint
- Return type:
Dict[str, Any]
- Raises:
ValidationError – If the constructed hyperparameters don’t match the expected schema
AttributeError – If task doesn’t have get_summary() method
Examples
>>> # In SetHyperparameters
>>> hparams_dict = CheckpointHyperparameters.from_task_and_data(
...     task=pl_module.task,
...     napistu_data=napistu_data,
...     training_config=pl_module.hparams.get('config'),
...     extra_packages=['wandb', 'numpy']
... )
>>> pl_module.hparams.update(hparams_dict)
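capture_environment and extra_packages control which Python environment details are recorded. The fields of EnvironmentInfo are not listed in this reference, so the sketch below is a hypothetical illustration of how package versions could be collected with the standard library's importlib.metadata. It does not show what the library actually stores. capture_package_versions is an invented name, and packages that are not installed are recorded as None rather than raising.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional


def capture_package_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Hypothetical helper: record Python and package versions for a checkpoint."""
    versions: Dict[str, Optional[str]] = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Keep the key so a missing package is visible rather than silently dropped
            versions[name] = None
    return versions


env_info = capture_package_versions(["numpy", "surely-not-installed-sketch-pkg"])
```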
- config: Dict[str, Any]
- data: DataMetadata
- environment: EnvironmentInfo | None
- model: ModelMetadata
- model_config = {'extra': 'allow'}
Pydantic model configuration (a pydantic.config.ConfigDict); extra='allow' keeps undeclared fields.
- class napistu_torch.load.checkpoints.CheckpointStructure(*, state_dict: Dict[str, Any], hyper_parameters: CheckpointHyperparameters, epoch: Annotated[int | None, Ge(ge=0)] = None, global_step: Annotated[int | None, Ge(ge=0)] = None, pytorch_lightning_version: str | None = None, **extra_data: Any)
Bases: BaseModel
Validated structure of a Lightning checkpoint dictionary.
This ensures the checkpoint has all required fields with correct types.
- classmethod validate_state_dict_not_empty(v)
Ensure state_dict is not empty.
- epoch: int | None
- global_step: int | None
- hyper_parameters: CheckpointHyperparameters
- model_config = {'extra': 'allow'}
Pydantic model configuration (a pydantic.config.ConfigDict); extra='allow' keeps undeclared keys.
- pytorch_lightning_version: str | None
- state_dict: Dict[str, Any]
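The rules CheckpointStructure enforces are: state_dict and hyper_parameters are required, state_dict must not be empty, epoch and global_step are optional but non-negative, and extra keys are allowed. The plain-Python sketch below checks those rules and returns a list of problems. check_checkpoint_structure is hypothetical. It does not validate the nested hyper_parameters schema, and it is not a substitute for the Pydantic model. The "optimizer_states" key is just an example of an extra key.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = ("state_dict", "hyper_parameters")


def check_checkpoint_structure(ckpt: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the dict passed (hypothetical sketch)."""
    problems = [f"missing required key: {key}" for key in REQUIRED_KEYS if key not in ckpt]
    if "state_dict" in ckpt and not ckpt["state_dict"]:
        problems.append("state_dict is empty")
    for key in ("epoch", "global_step"):
        value = ckpt.get(key)
        # Optional counters: None is accepted, otherwise a non-negative int is required
        if value is not None and (not isinstance(value, int) or value < 0):
            problems.append(f"{key} must be a non-negative int, got {value!r}")
    # Extra keys are not checked, mirroring extra='allow'
    return problems


good = check_checkpoint_structure(
    {"state_dict": {"w": 1}, "hyper_parameters": {}, "epoch": 2, "optimizer_states": []}
)
bad = check_checkpoint_structure({"state_dict": {}, "epoch": -1})
```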
- class napistu_torch.load.checkpoints.DataMetadata(*, name: str, num_nodes: Annotated[int, Ge(ge=0)], num_edges: Annotated[int, Ge(ge=0)], num_node_features: Annotated[int, Ge(ge=0)], num_edge_features: Annotated[int, Ge(ge=0)], splitting_strategy: str | None = None, num_unique_relations: Annotated[int | None, Ge(ge=0)] = None, num_train_edges: Annotated[int | None, Ge(ge=0)] = None, num_val_edges: Annotated[int | None, Ge(ge=0)] = None, num_test_edges: Annotated[int | None, Ge(ge=0)] = None, vertex_feature_names: List[str] | None = None, edge_feature_names: List[str] | None = None, vertex_feature_name_aliases: Dict[str, str] | None = None, edge_feature_name_aliases: Dict[str, str] | None = None, relation_type_labels: List[str] | None = None, train_mask_hash: str | None = None, val_mask_hash: str | None = None, test_mask_hash: str | None = None)
Bases: BaseModel
Validated metadata about the training data.
This matches the structure saved by SetHyperparameters.
- edge_feature_name_aliases: Dict[str, str] | None
- edge_feature_names: List[str] | None
- model_config = {'extra': 'forbid'}
Pydantic model configuration (a pydantic.config.ConfigDict); extra='forbid' rejects undeclared fields.
- name: str
- num_edge_features: int
- num_edges: int
- num_node_features: int
- num_nodes: int
- num_test_edges: int | None
- num_train_edges: int | None
- num_unique_relations: int | None
- num_val_edges: int | None
- relation_type_labels: List[str] | None
- splitting_strategy: str | None
- test_mask_hash: str | None
- train_mask_hash: str | None
- val_mask_hash: str | None
- vertex_feature_name_aliases: Dict[str, str] | None
- vertex_feature_names: List[str] | None
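This reference does not say how train_mask_hash, val_mask_hash, and test_mask_hash are computed. The sketch below shows one plausible way a boolean split mask could be fingerprinted so that a different train/val/test split is detectable. Both the encoding (one byte per entry) and the choice of SHA-256 are assumptions, and mask_hash is a hypothetical name.

```python
import hashlib
from typing import Sequence


def mask_hash(mask: Sequence[bool]) -> str:
    """Hypothetical fingerprint of a boolean split mask (not the library's scheme)."""
    # One byte per entry keeps the encoding unambiguous for masks of different lengths
    payload = bytes(1 if flag else 0 for flag in mask)
    return hashlib.sha256(payload).hexdigest()


train_mask = [True, True, False, False, True]
saved_hash = mask_hash(train_mask)
```

Comparing a stored hash with one recomputed from the current data is much cheaper than storing the masks themselves, at the cost of only reporting that the split changed, not where.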
- class napistu_torch.load.checkpoints.EdgeEncoderMetadata(*, edge_in_channels: Annotated[int, Ge(ge=1)], edge_encoder_dim: Annotated[int, Ge(ge=1)], edge_encoder_dropout: Annotated[float, Ge(ge=0.0), Le(le=1.0)], edge_encoder_init_bias: float | None = None)
Bases: BaseModel
Validated metadata about the edge encoder.
This matches the structure from EdgeEncoder.get_summary() with to_model_config_names=True.
- edge_encoder_dim: int
- edge_encoder_dropout: float
- edge_encoder_init_bias: float | None
- edge_in_channels: int
- model_config = {}
Pydantic model configuration (a pydantic.config.ConfigDict); left empty, so Pydantic's default extra='ignore' applies and undeclared fields are dropped.
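The signature of EdgeEncoderMetadata encodes its bounds as annotations: Ge(ge=1) on both channel counts and Ge(ge=0.0)/Le(le=1.0) on the dropout. The stdlib dataclass below mirrors those bounds in __post_init__ so the constraints can be read and run without Pydantic. EdgeEncoderSpec is a hypothetical name. The real model raises pydantic's ValidationError, not ValueError.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class EdgeEncoderSpec:
    """Stdlib mirror of the documented EdgeEncoderMetadata bounds (hypothetical)."""

    edge_in_channels: int
    edge_encoder_dim: int
    edge_encoder_dropout: float
    edge_encoder_init_bias: Optional[float] = None

    def __post_init__(self) -> None:
        # Ge(ge=1) on both channel counts
        if self.edge_in_channels < 1 or self.edge_encoder_dim < 1:
            raise ValueError("edge_in_channels and edge_encoder_dim must be >= 1")
        # Ge(ge=0.0) and Le(le=1.0) on dropout
        if not 0.0 <= self.edge_encoder_dropout <= 1.0:
            raise ValueError("edge_encoder_dropout must be between 0 and 1")


spec = EdgeEncoderSpec(edge_in_channels=4, edge_encoder_dim=16, edge_encoder_dropout=0.1)
```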
- class napistu_torch.load.checkpoints.EncoderMetadata(*, encoder: str, in_channels: Annotated[int, Ge(ge=1)], hidden_channels: Annotated[int, Ge(ge=1)], num_layers: Annotated[int, Ge(ge=1)], edge_in_channels: Annotated[int | None, Ge(ge=0)] = None, dropout: Annotated[float | None, Ge(ge=0.0), Le(le=1.0)] = None, sage_aggregator: str | None = None, graph_conv_aggregator: str | None = None, gat_heads: Annotated[int | None, Ge(ge=1)] = None, gat_concat: bool | None = None)
Bases: BaseModel
Validated metadata about the encoder.
This matches the structure from MessagePassingEncoder.get_summary().
- dropout: float | None
- edge_in_channels: int | None
- encoder: str
- gat_concat: bool | None
- gat_heads: int | None
- graph_conv_aggregator: str | None
- hidden_channels: int
- in_channels: int
- model_config = {'extra': 'forbid'}
Pydantic model configuration (a pydantic.config.ConfigDict); extra='forbid' rejects undeclared fields.
- num_layers: int
- sage_aggregator: str | None
- class napistu_torch.load.checkpoints.HeadMetadata(*, head: str, hidden_channels: Annotated[int, Ge(ge=1)], num_relations: Annotated[int | None, Ge(ge=1)] = None, num_classes: Annotated[int | None, Ge(ge=2)] = None, mlp_hidden_dim: Annotated[int | None, Ge(ge=1)] = None, mlp_num_layers: Annotated[int | None, Ge(ge=1)] = None, mlp_dropout: Annotated[float | None, Ge(ge=0.0), Le(le=1.0)] = None, nc_dropout: Annotated[float | None, Ge(ge=0.0), Le(le=1.0)] = None, rotate_margin: Annotated[float | None, Gt(gt=0.0)] = None, transe_margin: Annotated[float | None, Gt(gt=0.0)] = None, **extra_data: Any)
Bases: BaseModel
Validated metadata about the head/decoder.
This matches the structure from Decoder.get_summary().
- head: str
- hidden_channels: int
- mlp_dropout: float | None
- mlp_hidden_dim: int | None
- mlp_num_layers: int | None
- model_config = {'extra': 'allow'}
Pydantic model configuration (a pydantic.config.ConfigDict); extra='allow' keeps undeclared fields, so head-specific parameters not listed here are preserved.
- nc_dropout: float | None
- num_classes: int | None
- num_relations: int | None
- rotate_margin: float | None
- transe_margin: float | None
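These metadata classes differ in how they treat undeclared fields. HeadMetadata uses extra='allow'. EncoderMetadata, DataMetadata and ModelMetadata use extra='forbid'. EdgeEncoderMetadata falls back to Pydantic's default, 'ignore'. The plain-Python sketch below imitates the three policies on a flat dictionary to make the difference concrete. apply_extra_policy, "my_head" and "custom_scale" are hypothetical, and the real models raise ValidationError rather than ValueError.

```python
from typing import Any, Dict, Iterable


def apply_extra_policy(data: Dict[str, Any], declared: Iterable[str], extra: str) -> Dict[str, Any]:
    """Imitate pydantic's extra='forbid' / 'allow' / 'ignore' for a flat dict (sketch)."""
    undeclared = set(data) - set(declared)
    if extra == "forbid" and undeclared:
        raise ValueError(f"undeclared fields not permitted: {sorted(undeclared)}")
    if extra == "ignore":
        return {k: v for k, v in data.items() if k not in undeclared}
    return dict(data)  # 'allow' (or 'forbid' with nothing undeclared) keeps everything


head_summary = {"head": "my_head", "hidden_channels": 64, "custom_scale": 0.5}
head_fields = ["head", "hidden_channels"]  # subset of HeadMetadata fields, for brevity
kept = apply_extra_policy(head_summary, head_fields, "allow")      # like HeadMetadata
dropped = apply_extra_policy(head_summary, head_fields, "ignore")  # like EdgeEncoderMetadata
try:
    apply_extra_policy(head_summary, head_fields, "forbid")        # like EncoderMetadata
except ValueError as err:
    forbid_error = str(err)
```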
- class napistu_torch.load.checkpoints.ModelMetadata(*, encoder: EncoderMetadata, head: HeadMetadata, edge_encoder: EdgeEncoderMetadata | None = None)
Bases: BaseModel
Validated metadata about the complete model.
This matches the structure saved by ModelMetadataCallback under checkpoint[‘hyper_parameters’][‘model’].
- edge_encoder: EdgeEncoderMetadata | None
- encoder: EncoderMetadata
- head: HeadMetadata
- model_config = {'extra': 'forbid'}
Pydantic model configuration (a pydantic.config.ConfigDict); extra='forbid' rejects undeclared fields.
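ModelMetadata lives under checkpoint['hyper_parameters']['model'], with required encoder and head sections and an optional edge_encoder section. The hypothetical helper below walks that documented nesting on a plain dictionary and tolerates a missing edge encoder. The toy checkpoint's values (such as "sage" and "my_head") are illustrative, not taken from a real model.

```python
from typing import Any, Dict, Optional


def extract_model_sections(ckpt: Dict[str, Any]) -> Dict[str, Optional[Dict[str, Any]]]:
    """Pull encoder/head/edge_encoder dicts from the documented nesting (hypothetical helper)."""
    model = ckpt["hyper_parameters"]["model"]
    return {
        "encoder": model["encoder"],
        "head": model["head"],
        # edge_encoder is optional on ModelMetadata, so tolerate absence
        "edge_encoder": model.get("edge_encoder"),
    }


toy_ckpt = {
    "state_dict": {"w": 1},
    "hyper_parameters": {
        "model": {
            "encoder": {"encoder": "sage", "in_channels": 8, "hidden_channels": 32, "num_layers": 2},
            "head": {"head": "my_head", "hidden_channels": 32},
        }
    },
}
sections = extract_model_sections(toy_ckpt)
```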