napistu_torch.lightning.datamodule
Base Lightning DataModule for Napistu networks.
This module provides the abstract base class and shared infrastructure. Concrete implementations should subclass it and implement the dataloader methods.
Classes
| NapistuDataModule | Abstract base class for Napistu Lightning DataModules. |
- class napistu_torch.lightning.datamodule.NapistuDataModule(*args: Any, **kwargs: Any)
Bases: LightningDataModule, ABC
Abstract base class for Napistu Lightning DataModules.
Provides shared infrastructure for all Napistu DataModules:
  - Artifact loading from configs and store
  - Setup logic for transductive/inductive splits
  - Property access (num_node_features, num_edge_features)
Subclasses must implement:
  - train_dataloader()
  - val_dataloader()
  - test_dataloader()
  - predict_dataloader()
Do not instantiate this class directly. Use concrete implementations:
  - FullGraphDataModule: Returns full NapistuData objects (full-batch training)
  - EdgeBatchDataModule: Returns edge indices (mini-batch training for edge prediction)
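The subclassing contract described above can be sketched with the standard library alone. This is a minimal illustration, not the library's implementation: `SketchDataModule` and `FullBatchSketch` are hypothetical stand-ins for `NapistuDataModule` and a concrete subclass, and plain lists stand in for `torch.utils.data.DataLoader` objects.

```python
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional


class SketchDataModule(ABC):
    """Hypothetical stand-in for NapistuDataModule (not the real class)."""

    def __init__(self, items: List[int]):
        self.items = items
        self.prepared: Optional[List[int]] = None

    def setup(self, stage: Optional[str] = None) -> None:
        # Shared setup logic lives in the base class and is inherited as-is.
        self.prepared = list(self.items)

    @abstractmethod
    def train_dataloader(self) -> Iterable: ...

    @abstractmethod
    def val_dataloader(self) -> Iterable: ...

    @abstractmethod
    def test_dataloader(self) -> Iterable: ...

    @abstractmethod
    def predict_dataloader(self) -> Iterable: ...


class FullBatchSketch(SketchDataModule):
    """Returns the whole dataset as a single batch (full-batch style)."""

    def _one_batch(self) -> List[List[int]]:
        return [self.prepared]

    def train_dataloader(self) -> Iterable:
        return self._one_batch()

    def val_dataloader(self) -> Iterable:
        return self._one_batch()

    def test_dataloader(self) -> Iterable:
        return self._one_batch()

    def predict_dataloader(self) -> Iterable:
        return self._one_batch()


# The abstract base cannot be instantiated while abstract methods remain.
try:
    SketchDataModule([1, 2, 3])
    base_instantiable = True
except TypeError:
    base_instantiable = False

# A subclass that implements all four dataloader methods can be.
dm = FullBatchSketch([1, 2, 3])
dm.setup("fit")
batches = list(dm.train_dataloader())
```

Leaving out any one of the four dataloader methods in the subclass would make its instantiation fail with the same `TypeError`, which is how the base class enforces the contract.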
- Parameters:
config (ExperimentConfig) – Pydantic experiment configuration containing data, model, task, and training configs
napistu_data_name (Optional[str]) – Name of the NapistuData artifact to use for training
other_artifacts (Optional[List[str]]) – List of other artifact names needed for the experiment
napistu_data (Optional[NapistuData]) – Direct NapistuData object for testing/backward compatibility
store (Optional[NapistuDataStore]) – Pre-initialized store
artifact_registry (Optional[Dict[str, ArtifactDefinition]]) – Registry of artifact definitions
overwrite_artifacts (bool, default=False) – If True, recreate artifact even if it exists
- __init__(config: ExperimentConfig, napistu_data_name: str | None = None, other_artifacts: List[str] | None = None, napistu_data: NapistuData | None = None, store: NapistuDataStore | None = None, artifact_registry: Dict[str, napistu_torch.load.artifacts.ArtifactDefinition] | None = {
    'comprehensive_pathway_memberships': ArtifactDefinition(name='comprehensive_pathway_memberships', artifact_type='vertex_tensor', creation_func=<function _create_comprehensive_pathway_memberships>, description='VertexTensor containing comprehensive pathway membership features'),
    'edge_prediction': ArtifactDefinition(name='edge_prediction', artifact_type='napistu_data', creation_func=<function _create_edge_prediction_data>, description='Unlabeled NapistuData with train/test/val edge masking'),
    'edge_strata_by_edge_sbo_terms': ArtifactDefinition(name='edge_strata_by_edge_sbo_terms', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_edge_sbo_terms>, description='Pandas DataFrame containing edge strata by from-to edge SBO terms'),
    'edge_strata_by_node_species_type': ArtifactDefinition(name='edge_strata_by_node_species_type', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_node_species_type>, description='Pandas DataFrame containing edge strata by node + species type'),
    'edge_strata_by_node_type': ArtifactDefinition(name='edge_strata_by_node_type', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_node_type>, description='Pandas DataFrame containing edge strata by node type'),
    'name_to_sid_map': ArtifactDefinition(name='name_to_sid_map', artifact_type='pandas_dfs', creation_func=<function _create_name_to_sid_map>, description='Pandas DataFrame containing a map of vertex names to species ids'),
    'relation_prediction': ArtifactDefinition(name='relation_prediction', artifact_type='napistu_data', creation_func=<function _create_relation_prediction_data>, description='Unlabeled NapistuData with train/test/val with edge masking and relation-type labels'),
    'species_identifiers': ArtifactDefinition(name='species_identifiers', artifact_type='pandas_dfs', creation_func=<function _create_species_identifiers>, description='Pandas DataFrame containing species identifiers'),
    'species_type_prediction': ArtifactDefinition(name='species_type_prediction', artifact_type='napistu_data', creation_func=<function _create_species_type_prediction_data>, description='NapistuData containing species type labels with train/test/val vertex masking'),
    'unlabeled': ArtifactDefinition(name='unlabeled', artifact_type='napistu_data', creation_func=<function _create_unlabeled_data>, description='Unlabeled NapistuData without masking')
  }, overwrite_artifacts: bool = False)
- abstractmethod predict_dataloader() -> torch.utils.data.DataLoader
Create prediction dataloader. Must be implemented by subclasses.
- setup(stage: str | None = None)
Set up NapistuData object(s) from the provided data.
Shared setup logic for all subclasses.
- abstractmethod test_dataloader() -> torch.utils.data.DataLoader
Create test dataloader. Must be implemented by subclasses.
- abstractmethod train_dataloader() -> torch.utils.data.DataLoader
Create training dataloader. Must be implemented by subclasses.
- abstractmethod val_dataloader() -> torch.utils.data.DataLoader
Create validation dataloader. Must be implemented by subclasses.
- property num_edge_features: int
Get the number of edge features from the data.
- property num_node_features: int
Get the number of node features from the data.
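As a rough illustration of what the two feature-count properties report, the sketch below derives them from the width of row-major feature matrices. It assumes features are stored as `[num_items, num_features]` arrays, which is a common graph-learning convention but is not confirmed by this page. `ToyGraphData` and `ToyDataModule` are hypothetical names, and nested lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyGraphData:
    """Hypothetical stand-in for NapistuData using nested lists as matrices."""

    x: List[List[float]]          # one row per node
    edge_attr: List[List[float]]  # one row per edge


class ToyDataModule:
    """Hypothetical module exposing read-only feature counts from its data."""

    def __init__(self, data: ToyGraphData):
        self._data = data

    @property
    def num_node_features(self) -> int:
        # Width of the node feature matrix; 0 if there are no nodes.
        return len(self._data.x[0]) if self._data.x else 0

    @property
    def num_edge_features(self) -> int:
        # Width of the edge feature matrix; 0 if there are no edges.
        return len(self._data.edge_attr[0]) if self._data.edge_attr else 0


toy = ToyDataModule(
    ToyGraphData(
        x=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],  # 2 nodes, 3 features each
        edge_attr=[[1.0, 0.0]],                 # 1 edge, 2 features
    )
)
node_dim = toy.num_node_features
edge_dim = toy.num_edge_features
```

Values like these are typically read once after setup to size a model's input layers, so the model definition does not hard-code feature dimensions.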