napistu_torch.load.checkpoints

Checkpoint loading and validation utilities.

This module provides utilities for loading and validating pretrained Napistu-Torch models.

Classes

Checkpoint

Manager for PyTorch Lightning checkpoint loading and validation.

DataMetadata

Metadata about the NapistuData used during training.

EdgeEncoderMetadata

Metadata about the edge encoder component.

EncoderMetadata

Metadata about the encoder component.

HeadMetadata

Metadata about the head component.

ModelMetadata

Metadata about the complete model architecture.

CheckpointHyperparameters

Hyperparameters stored in the checkpoint.

CheckpointStructure

Structure definition for checkpoint validation.

Classes

Checkpoint(checkpoint_dict)

Manager for PyTorch Lightning checkpoint loading and validation.

CheckpointHyperparameters(*, config, model, data)

Validated hyperparameters structure from Lightning checkpoint.

CheckpointStructure(*, state_dict, ...[, ...])

Validated structure of a Lightning checkpoint dictionary.

DataMetadata(*, name, num_nodes, num_edges, ...)

Validated metadata about the training data.

EdgeEncoderMetadata(*, edge_in_channels, ...)

Validated metadata about the edge encoder.

EncoderMetadata(*, encoder, in_channels, ...)

Validated metadata about the encoder.

HeadMetadata(*, head, hidden_channels[, ...])

Validated metadata about the head/decoder.

ModelMetadata(*, encoder, head[, edge_encoder])

Validated metadata about the complete model.

class napistu_torch.load.checkpoints.Checkpoint(checkpoint_dict: Dict[str, Any])

Bases: object

Manager for PyTorch Lightning checkpoint loading and validation.

This class handles loading checkpoints, extracting metadata, validating compatibility with current data, and reconstructing model components.

Parameters:
  • checkpoint_dict (Dict[str, Any]) – PyTorch Lightning checkpoint dictionary (validated via Pydantic)

Public methods:
  • assert_same_napistu_data(napistu_data: NapistuData) -> None – Validate that current NapistuData is compatible with checkpoint.

  • load(checkpoint_path: Union[str, Path], map_location: str = DEVICE.CPU) -> "Checkpoint" – Load and validate a checkpoint from a local file.

  • get_encoder_config() -> Dict[str, Any] – Get encoder configuration as dictionary.

  • get_head_config() -> Dict[str, Any] – Get head configuration as dictionary.

  • get_data_summary() -> Dict[str, Any] – Get data summary as dictionary.

  • update_model_config(model_config: ModelConfig, inplace: bool = True) -> ModelConfig – Update a ModelConfig instance with settings from the checkpoint.

Private methods:
  • _update_model_config_with_encoder(model_config: ModelConfig, inplace: bool = True) -> Optional[ModelConfig] – Update a ModelConfig instance with encoder configuration from checkpoint.

  • _update_model_config_with_head(model_config: ModelConfig, inplace: bool = True) -> Optional[ModelConfig] – Update a ModelConfig instance with head configuration from checkpoint.

Examples

>>> # Load from local file (automatically validated)
>>> checkpoint = Checkpoint.load("path/to/checkpoint.ckpt")
>>>
>>> # Validate compatibility with current data
>>> checkpoint.assert_same_napistu_data(current_data)
>>>
>>> # Access validated configurations
>>> encoder_config = checkpoint.encoder_metadata
>>> head_config = checkpoint.head_metadata
>>> data_config = checkpoint.data_metadata
classmethod load(checkpoint_path: str | Path, map_location: str = 'cpu') -> Checkpoint

Load and validate a checkpoint from a local file.

Parameters:
  • checkpoint_path (Union[str, Path]) – Path to the checkpoint file (.ckpt)

  • map_location (str, optional) – Device to load tensors to, by default ‘cpu’

Returns:

Loaded and validated checkpoint object

Return type:

Checkpoint

Raises:
  • FileNotFoundError – If checkpoint file doesn’t exist

  • RuntimeError – If checkpoint loading fails

  • ValidationError – If checkpoint structure is invalid

Examples

>>> checkpoint = Checkpoint.load("model.ckpt")
>>> checkpoint = Checkpoint.load("model.ckpt", map_location="cuda:0")
__init__(checkpoint_dict: Dict[str, Any])

Initialize Checkpoint from a checkpoint dictionary.

Parameters:

checkpoint_dict (Dict[str, Any]) – PyTorch Lightning checkpoint dictionary

Raises:

ValidationError – If checkpoint structure is invalid

_update_model_config_with_encoder(model_config: ModelConfig, inplace: bool = True) -> ModelConfig | None

Update a ModelConfig instance with encoder configuration from checkpoint.

Updates encoder-related fields in the ModelConfig based on the checkpoint’s encoder metadata. This is useful when reconstructing a model from a checkpoint.

Parameters:
  • model_config (ModelConfig) – ModelConfig instance to update

  • inplace (bool, optional) – If True, modify the ModelConfig in place. If False, create a copy and return it. Default is True.

Returns:

The updated ModelConfig instance. If inplace=True, returns the same object. If inplace=False, returns a new ModelConfig instance.

Return type:

ModelConfig

Examples

>>> checkpoint = Checkpoint.load("model.ckpt")
>>> model_config = ModelConfig()
>>> checkpoint._update_model_config_with_encoder(model_config)
>>> # model_config now has encoder settings from checkpoint
>>>
>>> # Create a copy instead
>>> updated_config = checkpoint._update_model_config_with_encoder(model_config, inplace=False)
_update_model_config_with_head(model_config: ModelConfig, inplace: bool = True) -> ModelConfig | None

Update a ModelConfig instance with head configuration from checkpoint.

Updates head-related fields in the ModelConfig based on the checkpoint’s head metadata. This is useful when reconstructing a model from a checkpoint.

Parameters:
  • model_config (ModelConfig) – ModelConfig instance to update

  • inplace (bool, optional) – If True, modify the ModelConfig in place. If False, create a copy and return it. Default is True.

Returns:

The updated ModelConfig instance. If inplace=True, returns the same object. If inplace=False, returns a new ModelConfig instance.

Return type:

ModelConfig

Examples

>>> checkpoint = Checkpoint.load("model.ckpt")
>>> model_config = ModelConfig()
>>> checkpoint._update_model_config_with_head(model_config)
>>> # model_config now has head settings from checkpoint
>>>
>>> # Create a copy instead
>>> updated_config = checkpoint._update_model_config_with_head(model_config, inplace=False)
assert_same_napistu_data(napistu_data: NapistuData) -> None

Validate that current NapistuData is compatible with checkpoint.

Compares the data summary from the checkpoint with a summary generated from the provided NapistuData object.

Parameters:

napistu_data (NapistuData) – Current NapistuData object to validate against checkpoint

Raises:

ValueError – If data summaries don’t match

Examples

>>> checkpoint = Checkpoint.load("model.ckpt")
>>> checkpoint.assert_same_napistu_data(current_data)
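
The summary comparison described above can be pictured with a minimal sketch. It assumes both summaries are flat dictionaries and that mismatches are reported per key. `diff_summaries` and `assert_same_summary` are hypothetical helpers, not part of napistu_torch, and the library's actual comparison rules (for example, which keys it ignores) are not documented on this page.

```python
from typing import Any, Dict, List


def diff_summaries(saved: Dict[str, Any], current: Dict[str, Any]) -> List[str]:
    """Return one human-readable message per key whose value differs."""
    problems = []
    # Walk the union of keys; a key present on only one side is compared to None.
    for key in sorted(set(saved) | set(current)):
        if saved.get(key) != current.get(key):
            problems.append(
                f"{key}: checkpoint={saved.get(key)!r}, current={current.get(key)!r}"
            )
    return problems


def assert_same_summary(saved: Dict[str, Any], current: Dict[str, Any]) -> None:
    """Raise ValueError if the two summaries disagree on any key."""
    problems = diff_summaries(saved, current)
    if problems:
        raise ValueError("Data summaries do not match:\n" + "\n".join(problems))


# Illustrative summaries that borrow field names from DataMetadata.
saved = {"name": "demo", "num_nodes": 100, "num_edges": 250}
current = {"name": "demo", "num_nodes": 100, "num_edges": 300}
mismatches = diff_summaries(saved, current)
```

Collecting every mismatch before raising makes the error message more useful than failing on the first differing key.
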
get_data_summary() -> Dict[str, Any]

Get data summary as dictionary.

Returns:

Data summary dictionary

Return type:

Dict[str, Any]

get_edge_encoder_config() -> Dict[str, Any]

Get edge encoder configuration as dictionary.

Returns:

Edge encoder configuration dictionary

Return type:

Dict[str, Any]

get_encoder_config() -> Dict[str, Any]

Get encoder configuration as dictionary.

Returns:

Encoder configuration dictionary

Return type:

Dict[str, Any]

get_environment_info() -> Dict[str, Any]

Get environment information as dictionary.

Returns:

Environment information dictionary

Return type:

Dict[str, Any]

get_head_config() -> Dict[str, Any]

Get head configuration as dictionary.

Returns:

Head configuration dictionary

Return type:

Dict[str, Any]

update_model_config(model_config: ModelConfig, inplace: bool = True) -> ModelConfig

Update a ModelConfig instance with settings from the checkpoint.

Updates encoder configuration and optionally head configuration from the checkpoint metadata. This is useful when reconstructing a model from a checkpoint or when loading a pretrained model.

The head is only updated if pretrained_model_load_head is True in the model_config.

Parameters:
  • model_config (ModelConfig) – ModelConfig instance to update

  • inplace (bool, optional) – If True, modify the ModelConfig in place. If False, create a copy and return it. Default is True.

Returns:

The updated ModelConfig instance. If inplace=True, returns the same object (modified in place). If inplace=False, returns a new ModelConfig instance with updated settings.

Return type:

ModelConfig

Examples

>>> checkpoint = Checkpoint.load("model.ckpt")
>>> model_config = ModelConfig()
>>> checkpoint.update_model_config(model_config)
>>> # model_config now has encoder and head settings from checkpoint
>>>
>>> # Create a copy instead
>>> updated_config = checkpoint.update_model_config(model_config, inplace=False)
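
The inplace behaviour and the pretrained_model_load_head gate described above can be sketched with a toy configuration object. This is an illustration under stated assumptions, not the library's implementation: `ToyModelConfig`, `apply_checkpoint_settings`, and the field values ("sage", "gat", "mlp", "dot_product") are hypothetical stand-ins.

```python
import copy
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ToyModelConfig:
    """Hypothetical stand-in for ModelConfig with a few illustrative fields."""

    encoder: str = "sage"
    hidden_channels: int = 64
    head: str = "dot_product"
    pretrained_model_load_head: bool = False


def apply_checkpoint_settings(
    config: ToyModelConfig,
    encoder_settings: Dict[str, Any],
    head_settings: Dict[str, Any],
    inplace: bool = True,
) -> ToyModelConfig:
    """Apply encoder settings always, and head settings only when the flag is set."""
    # inplace=False works on a deep copy so the caller's object is untouched.
    target = config if inplace else copy.deepcopy(config)
    for name, value in encoder_settings.items():
        setattr(target, name, value)
    if target.pretrained_model_load_head:
        for name, value in head_settings.items():
            setattr(target, name, value)
    return target


original = ToyModelConfig()
updated = apply_checkpoint_settings(
    original,
    encoder_settings={"encoder": "gat", "hidden_channels": 128},
    head_settings={"head": "mlp"},
    inplace=False,
)
```

Here `updated` carries the encoder settings while `original` is unchanged, and the head stays at its default because the flag is False.
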
class napistu_torch.load.checkpoints.CheckpointHyperparameters(*, config: Dict[str, Any], model: ModelMetadata, data: DataMetadata, environment: EnvironmentInfo | None = None, **extra_data: Any)

Bases: BaseModel

Validated hyperparameters structure from Lightning checkpoint.

This validates the checkpoint['hyper_parameters'] structure.

classmethod from_task_and_data(task: Any, napistu_data: NapistuData, training_config: TrainingConfig | None = None, capture_environment: bool = True, extra_packages: list[str] | None = None) -> Dict[str, Any]

Create hyperparameters dict from task and data, with validation.

This is used by SetHyperparameters to build the hyperparameters dict that will be saved to the checkpoint. The dict is validated against the CheckpointHyperparameters schema before returning.

Parameters:
  • task (Any) – Task object with get_summary() method that returns model metadata

  • napistu_data (NapistuData) – NapistuData object to extract data metadata from

  • training_config (Optional[TrainingConfig], optional) – Training configuration object (usually set by Lightning’s save_hyperparameters()), by default None

  • capture_environment (bool, optional) – Whether to capture Python environment info, by default True

  • extra_packages (Optional[list[str]], optional) – Additional package names to capture versions for, by default None

Returns:

Validated hyperparameters dictionary ready to be saved to checkpoint

Return type:

Dict[str, Any]

Raises:
  • ValidationError – If the constructed hyperparameters don’t match the expected schema

  • AttributeError – If task doesn’t have get_summary() method

Examples

>>> # In SetHyperparameters
>>> hparams_dict = CheckpointHyperparameters.from_task_and_data(
...     task=pl_module.task,
...     napistu_data=napistu_data,
...     training_config=pl_module.hparams.get('config'),
...     extra_packages=['wandb', 'numpy']
... )
>>> pl_module.hparams.update(hparams_dict)
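
As a rough illustration of what capturing environment info and extra package versions might involve, the sketch below uses only the standard library. The helper names and the dictionary keys are assumptions made for this example; the actual fields of EnvironmentInfo are not shown on this page.

```python
import platform
import sys
from importlib import metadata
from typing import Dict, List, Optional


def capture_package_versions(packages: List[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def capture_environment(extra_packages: Optional[List[str]] = None) -> Dict[str, object]:
    """Build a small, serializable description of the running environment."""
    return {
        "python_version": platform.python_version(),
        "platform": sys.platform,
        "packages": capture_package_versions(extra_packages or []),
    }


env = capture_environment(extra_packages=["wandb", "numpy"])
```

Recording None for a missing package, rather than raising, keeps checkpoint saving from failing just because an optional dependency is absent.
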
config: Dict[str, Any]
data: DataMetadata
environment: EnvironmentInfo | None
model: ModelMetadata
model_config = {'extra': 'allow'}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class napistu_torch.load.checkpoints.CheckpointStructure(*, state_dict: Dict[str, Any], hyper_parameters: CheckpointHyperparameters, epoch: Annotated[int | None, Ge(ge=0)] = None, global_step: Annotated[int | None, Ge(ge=0)] = None, pytorch_lightning_version: str | None = None, **extra_data: Any)

Bases: BaseModel

Validated structure of a Lightning checkpoint dictionary.

This ensures the checkpoint has all required fields with correct types.

classmethod validate_state_dict_not_empty(v)

Ensure state_dict is not empty.
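
The checks implied by this class (required fields, a non-empty state_dict, non-negative counters) can be approximated in plain Python. This is a simplified sketch and not the Pydantic model itself; `find_structure_problems` is a hypothetical helper and the sample dictionaries are made up.

```python
from typing import Any, Dict, List


def find_structure_problems(checkpoint: Dict[str, Any]) -> List[str]:
    """Collect every structural problem instead of stopping at the first one."""
    problems = []
    # state_dict and hyper_parameters are the two required fields.
    for key in ("state_dict", "hyper_parameters"):
        if key not in checkpoint:
            problems.append(f"missing required key: {key}")
    if "state_dict" in checkpoint and not checkpoint["state_dict"]:
        problems.append("state_dict is empty")
    # epoch and global_step are optional but must be >= 0 when present.
    for key in ("epoch", "global_step"):
        value = checkpoint.get(key)
        if value is not None and (not isinstance(value, int) or value < 0):
            problems.append(f"{key} must be a non-negative integer or None")
    return problems


good = {"state_dict": {"encoder.weight": [0.1]}, "hyper_parameters": {}, "epoch": 3}
bad = {"state_dict": {}, "epoch": -1}
bad_problems = find_structure_problems(bad)
```
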

epoch: int | None
global_step: int | None
hyper_parameters: CheckpointHyperparameters
model_config = {'extra': 'allow'}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

pytorch_lightning_version: str | None
state_dict: Dict[str, Any]
class napistu_torch.load.checkpoints.DataMetadata(*, name: str, num_nodes: Annotated[int, Ge(ge=0)], num_edges: Annotated[int, Ge(ge=0)], num_node_features: Annotated[int, Ge(ge=0)], num_edge_features: Annotated[int, Ge(ge=0)], splitting_strategy: str | None = None, num_unique_relations: Annotated[int | None, Ge(ge=0)] = None, num_train_edges: Annotated[int | None, Ge(ge=0)] = None, num_val_edges: Annotated[int | None, Ge(ge=0)] = None, num_test_edges: Annotated[int | None, Ge(ge=0)] = None, vertex_feature_names: List[str] | None = None, edge_feature_names: List[str] | None = None, vertex_feature_name_aliases: Dict[str, str] | None = None, edge_feature_name_aliases: Dict[str, str] | None = None, relation_type_labels: List[str] | None = None, train_mask_hash: str | None = None, val_mask_hash: str | None = None, test_mask_hash: str | None = None)

Bases: BaseModel

Validated metadata about the training data.

This matches the structure saved by SetHyperparameters.

edge_feature_name_aliases: Dict[str, str] | None
edge_feature_names: List[str] | None
model_config = {'extra': 'forbid'}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

name: str
num_edge_features: int
num_edges: int
num_node_features: int
num_nodes: int
num_test_edges: int | None
num_train_edges: int | None
num_unique_relations: int | None
num_val_edges: int | None
relation_type_labels: List[str] | None
splitting_strategy: str | None
test_mask_hash: str | None
train_mask_hash: str | None
val_mask_hash: str | None
vertex_feature_name_aliases: Dict[str, str] | None
vertex_feature_names: List[str] | None
class napistu_torch.load.checkpoints.EdgeEncoderMetadata(*, edge_in_channels: Annotated[int, Ge(ge=1)], edge_encoder_dim: Annotated[int, Ge(ge=1)], edge_encoder_dropout: Annotated[float, Ge(ge=0.0), Le(le=1.0)], edge_encoder_init_bias: float | None = None)

Bases: BaseModel

Validated metadata about the edge encoder.

This matches the structure from EdgeEncoder.get_summary() with to_model_config_names=True.
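
The bounds in the signature (channels and dimension at least 1, dropout between 0.0 and 1.0) can be mirrored in a plain dataclass to show what the validation enforces. This swaps Pydantic for a standard-library dataclass purely for illustration; `ToyEdgeEncoderMetadata` is a hypothetical name and the error messages are invented.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyEdgeEncoderMetadata:
    """Dataclass stand-in mirroring the documented bounds; not the Pydantic model."""

    edge_in_channels: int
    edge_encoder_dim: int
    edge_encoder_dropout: float
    edge_encoder_init_bias: Optional[float] = None

    def __post_init__(self) -> None:
        # Mirrors Ge(ge=1) on the two integer fields.
        if self.edge_in_channels < 1:
            raise ValueError("edge_in_channels must be >= 1")
        if self.edge_encoder_dim < 1:
            raise ValueError("edge_encoder_dim must be >= 1")
        # Mirrors Ge(ge=0.0) and Le(le=1.0) on the dropout rate.
        if not 0.0 <= self.edge_encoder_dropout <= 1.0:
            raise ValueError("edge_encoder_dropout must be in [0.0, 1.0]")


valid = ToyEdgeEncoderMetadata(
    edge_in_channels=4, edge_encoder_dim=32, edge_encoder_dropout=0.1
)
```
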

edge_encoder_dim: int
edge_encoder_dropout: float
edge_encoder_init_bias: float | None
edge_in_channels: int
model_config = {}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class napistu_torch.load.checkpoints.EncoderMetadata(*, encoder: str, in_channels: Annotated[int, Ge(ge=1)], hidden_channels: Annotated[int, Ge(ge=1)], num_layers: Annotated[int, Ge(ge=1)], edge_in_channels: Annotated[int | None, Ge(ge=0)] = None, dropout: Annotated[float | None, Ge(ge=0.0), Le(le=1.0)] = None, sage_aggregator: str | None = None, graph_conv_aggregator: str | None = None, gat_heads: Annotated[int | None, Ge(ge=1)] = None, gat_concat: bool | None = None)

Bases: BaseModel

Validated metadata about the encoder.

This matches the structure from MessagePassingEncoder.get_summary().

dropout: float | None
edge_in_channels: int | None
encoder: str
gat_concat: bool | None
gat_heads: int | None
graph_conv_aggregator: str | None
hidden_channels: int
in_channels: int
model_config = {'extra': 'forbid'}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

num_layers: int
sage_aggregator: str | None
class napistu_torch.load.checkpoints.HeadMetadata(*, head: str, hidden_channels: Annotated[int, Ge(ge=1)], num_relations: Annotated[int | None, Ge(ge=1)] = None, num_classes: Annotated[int | None, Ge(ge=2)] = None, mlp_hidden_dim: Annotated[int | None, Ge(ge=1)] = None, mlp_num_layers: Annotated[int | None, Ge(ge=1)] = None, mlp_dropout: Annotated[float | None, Ge(ge=0.0), Le(le=1.0)] = None, nc_dropout: Annotated[float | None, Ge(ge=0.0), Le(le=1.0)] = None, rotate_margin: Annotated[float | None, Gt(gt=0.0)] = None, transe_margin: Annotated[float | None, Gt(gt=0.0)] = None, **extra_data: Any)

Bases: BaseModel

Validated metadata about the head/decoder.

This matches the structure from Decoder.get_summary().

head: str
hidden_channels: int
mlp_dropout: float | None
mlp_hidden_dim: int | None
mlp_num_layers: int | None
model_config = {'extra': 'allow'}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

nc_dropout: float | None
num_classes: int | None
num_relations: int | None
rotate_margin: float | None
transe_margin: float | None
class napistu_torch.load.checkpoints.ModelMetadata(*, encoder: EncoderMetadata, head: HeadMetadata, edge_encoder: EdgeEncoderMetadata | None = None)

Bases: BaseModel

Validated metadata about the complete model.

This matches the structure saved by ModelMetadataCallback under checkpoint['hyper_parameters']['model'].

edge_encoder: EdgeEncoderMetadata | None
encoder: EncoderMetadata
head: HeadMetadata
model_config = {'extra': 'forbid'}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.