napistu_torch.lightning.datamodule

Base Lightning DataModule for Napistu networks.

This module provides the abstract base class and shared infrastructure. Concrete implementations should subclass and implement the dataloader methods.

Classes

NapistuDataModule(*args, **kwargs)

Abstract base class for Napistu Lightning DataModules.

class napistu_torch.lightning.datamodule.NapistuDataModule(*args: Any, **kwargs: Any)

Bases: LightningDataModule, ABC

Abstract base class for Napistu Lightning DataModules.

Provides shared infrastructure for all Napistu DataModules:

  • Artifact loading from configs and store

  • Setup logic for transductive/inductive splits

  • Property access (num_node_features, num_edge_features)

Subclasses must implement:

  • train_dataloader()

  • val_dataloader()

  • test_dataloader()

  • predict_dataloader()

Do not instantiate this class directly. Use concrete implementations:

  • FullGraphDataModule: Returns full NapistuData objects (full-batch training)

  • EdgeBatchDataModule: Returns edge indices (mini-batch training for edge prediction)
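The abstract-base pattern described above can be sketched with the standard library alone. This is a minimal illustration, not the napistu_torch implementation: SketchDataModule, FullGraphSketch, and toy_graph are hypothetical stand-ins, and plain lists stand in for torch.utils.data.DataLoader objects.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class SketchDataModule(ABC):
    """Hypothetical stand-in for NapistuDataModule; not the real class."""

    def __init__(self, graph: Dict[str, Any]) -> None:
        self.graph = graph

    @abstractmethod
    def train_dataloader(self) -> List[Any]:
        """Must be implemented by subclasses."""

    @abstractmethod
    def val_dataloader(self) -> List[Any]:
        """Must be implemented by subclasses."""

    @abstractmethod
    def test_dataloader(self) -> List[Any]:
        """Must be implemented by subclasses."""

    @abstractmethod
    def predict_dataloader(self) -> List[Any]:
        """Must be implemented by subclasses."""


class FullGraphSketch(SketchDataModule):
    """Full-batch variant: each loader yields the whole graph as one batch."""

    def _single_batch(self) -> List[Any]:
        # A list with one element plays the role of a one-batch dataloader.
        return [self.graph]

    def train_dataloader(self) -> List[Any]:
        return self._single_batch()

    def val_dataloader(self) -> List[Any]:
        return self._single_batch()

    def test_dataloader(self) -> List[Any]:
        return self._single_batch()

    def predict_dataloader(self) -> List[Any]:
        return self._single_batch()


toy_graph = {"num_nodes": 3, "edges": [(0, 1), (1, 2)]}

# Instantiating the abstract base fails because its dataloader
# methods are still abstract.
try:
    SketchDataModule(toy_graph)
    base_is_instantiable = True
except TypeError:
    base_is_instantiable = False

# The concrete subclass implements all four methods, so it can be created.
dm = FullGraphSketch(toy_graph)
train_batches = dm.train_dataloader()
```

Because the base class declares all four dataloader methods as abstract, a subclass that omits any one of them also raises TypeError on instantiation.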

Parameters:
  • config (ExperimentConfig) – Pydantic experiment configuration containing data, model, task, and training configs

  • napistu_data_name (Optional[str]) – Name of the NapistuData artifact to use for training

  • other_artifacts (Optional[List[str]]) – List of other artifact names needed for the experiment

  • napistu_data (Optional[NapistuData]) – Direct NapistuData object for testing/backward compatibility

  • store (Optional[NapistuDataStore]) – Pre-initialized store

  • artifact_registry (Optional[Dict[str, ArtifactDefinition]]) – Registry of artifact definitions

  • overwrite_artifacts (bool, default=False) – If True, recreate artifact even if it exists
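The artifact_registry parameter maps artifact names to definitions. The sketch below shows one way such a registry could be built and resolved. The field names (name, artifact_type, creation_func, description) come from the default registry shown in the signature; everything else is an assumption for illustration, including ArtifactDefinitionSketch, _create_toy_labels, and resolve_artifacts. How NapistuDataModule actually resolves other_artifacts is not documented on this page.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass(frozen=True)
class ArtifactDefinitionSketch:
    """Hypothetical stand-in mirroring the fields of ArtifactDefinition."""

    name: str
    artifact_type: str
    creation_func: Callable[[], Any]
    description: str


def _create_toy_labels() -> Dict[str, int]:
    # Hypothetical creation function; real ones build NapistuData,
    # VertexTensor, or pandas objects.
    return {"node_a": 0, "node_b": 1}


custom_registry: Dict[str, ArtifactDefinitionSketch] = {
    "toy_labels": ArtifactDefinitionSketch(
        name="toy_labels",
        artifact_type="pandas_dfs",
        creation_func=_create_toy_labels,
        description="Toy mapping of node names to labels",
    )
}


def resolve_artifacts(
    names: List[str], registry: Dict[str, ArtifactDefinitionSketch]
) -> Dict[str, Any]:
    """Create each requested artifact, failing early on unknown names."""
    missing = [name for name in names if name not in registry]
    if missing:
        raise KeyError(f"Artifacts not in registry: {missing}")
    return {name: registry[name].creation_func() for name in names}


artifacts = resolve_artifacts(["toy_labels"], custom_registry)
```

Validating every requested name before creating anything keeps a misspelled entry in other_artifacts from surfacing only after expensive artifacts have already been built.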

__init__(config: ExperimentConfig, napistu_data_name: str | None = None, other_artifacts: List[str] | None = None, napistu_data: NapistuData | None = None, store: NapistuDataStore | None = None, artifact_registry: Dict[str, ArtifactDefinition] | None = {...}, overwrite_artifacts: bool = False)

The default artifact_registry contains the following ArtifactDefinition entries, keyed by name:

  • comprehensive_pathway_memberships (artifact_type='vertex_tensor', creation_func=_create_comprehensive_pathway_memberships) – VertexTensor containing comprehensive pathway membership features

  • edge_prediction (artifact_type='napistu_data', creation_func=_create_edge_prediction_data) – Unlabeled NapistuData with train/test/val edge masking

  • edge_strata_by_edge_sbo_terms (artifact_type='pandas_dfs', creation_func=_create_edge_strata_by_edge_sbo_terms) – Pandas DataFrame containing edge strata by from-to edge SBO terms

  • edge_strata_by_node_species_type (artifact_type='pandas_dfs', creation_func=_create_edge_strata_by_node_species_type) – Pandas DataFrame containing edge strata by node + species type

  • edge_strata_by_node_type (artifact_type='pandas_dfs', creation_func=_create_edge_strata_by_node_type) – Pandas DataFrame containing edge strata by node type

  • name_to_sid_map (artifact_type='pandas_dfs', creation_func=_create_name_to_sid_map) – Pandas DataFrame containing a map of vertex names to species ids

  • relation_prediction (artifact_type='napistu_data', creation_func=_create_relation_prediction_data) – Unlabeled NapistuData with train/test/val edge masking and relation-type labels

  • species_identifiers (artifact_type='pandas_dfs', creation_func=_create_species_identifiers) – Pandas DataFrame containing species identifiers

  • species_type_prediction (artifact_type='napistu_data', creation_func=_create_species_type_prediction_data) – NapistuData containing species type labels with train/test/val vertex masking

  • unlabeled (artifact_type='napistu_data', creation_func=_create_unlabeled_data) – Unlabeled NapistuData without masking

abstractmethod predict_dataloader() → torch.utils.data.DataLoader

Create prediction dataloader. Must be implemented by subclasses.

setup(stage: str | None = None)

Set up NapistuData object(s) from the provided data.

Shared setup logic for all subclasses.

abstractmethod test_dataloader() → torch.utils.data.DataLoader

Create test dataloader. Must be implemented by subclasses.

abstractmethod train_dataloader() → torch.utils.data.DataLoader

Create training dataloader. Must be implemented by subclasses.

abstractmethod val_dataloader() → torch.utils.data.DataLoader

Create validation dataloader. Must be implemented by subclasses.

property num_edge_features: int

Get the number of edge features from the data.

property num_node_features: int

Get the number of node features from the data.
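A typical use of these two properties is sizing a model's input layers from the data rather than hard-coding dimensions. The sketch below is illustrative only: FeatureShapeSketch is a hypothetical stand-in, nested lists stand in for tensors, and it assumes the feature count is the width of a (rows × features) matrix, which follows the usual graph-learning convention but is not confirmed by this page.

```python
from typing import List


class FeatureShapeSketch:
    """Hypothetical stand-in exposing the two feature-count properties."""

    def __init__(
        self,
        node_features: List[List[float]],
        edge_features: List[List[float]],
    ) -> None:
        self._node_features = node_features  # one row per node
        self._edge_features = edge_features  # one row per edge

    @property
    def num_node_features(self) -> int:
        # Width of the node feature matrix; 0 if there are no rows.
        return len(self._node_features[0]) if self._node_features else 0

    @property
    def num_edge_features(self) -> int:
        # Width of the edge feature matrix; 0 if there are no rows.
        return len(self._edge_features[0]) if self._edge_features else 0


dm = FeatureShapeSketch(
    node_features=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    edge_features=[[1.0, 0.0]],
)

# Derive layer widths from the data instead of hard-coding them.
hidden_dim = 16
encoder_layer_sizes = [dm.num_node_features, hidden_dim, hidden_dim]
edge_head_input_dim = 2 * hidden_dim + dm.num_edge_features
```

Reading the dimensions from the data module keeps the model definition valid when a different artifact with a different feature set is selected.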