napistu_torch.lightning.full_graph_datamodule
DataModule for full-batch training on a single graph.
The dataloaders return complete NapistuData objects, so the entire graph is processed at once. This is the traditional approach and matches the original NapistuDataModule behavior.
Classes
FullGraphDataModule: DataModule for full-batch training on a single graph.
- class napistu_torch.lightning.full_graph_datamodule.FullGraphDataModule(*args: Any, **kwargs: Any)
Bases: NapistuDataModule
DataModule for full-batch training on a single graph.
- __init__(
      config: ExperimentConfig,
      napistu_data_name: str | None = None,
      other_artifacts: List[str] | None = None,
      napistu_data: NapistuData | None = None,
      store: NapistuDataStore | None = None,
      artifact_registry: Dict[str, napistu_torch.load.artifacts.ArtifactDefinition] | None = {
          'comprehensive_pathway_memberships': ArtifactDefinition(name='comprehensive_pathway_memberships', artifact_type='vertex_tensor', creation_func=<function _create_comprehensive_pathway_memberships>, description='VertexTensor containing comprehensive pathway membership features'),
          'edge_prediction': ArtifactDefinition(name='edge_prediction', artifact_type='napistu_data', creation_func=<function _create_edge_prediction_data>, description='Unlabeled NapistuData with train/test/val edge masking'),
          'edge_strata_by_edge_sbo_terms': ArtifactDefinition(name='edge_strata_by_edge_sbo_terms', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_edge_sbo_terms>, description='Pandas DataFrame containing edge strata by from-to edge SBO terms'),
          'edge_strata_by_node_species_type': ArtifactDefinition(name='edge_strata_by_node_species_type', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_node_species_type>, description='Pandas DataFrame containing edge strata by node + species type'),
          'edge_strata_by_node_type': ArtifactDefinition(name='edge_strata_by_node_type', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_node_type>, description='Pandas DataFrame containing edge strata by node type'),
          'name_to_sid_map': ArtifactDefinition(name='name_to_sid_map', artifact_type='pandas_dfs', creation_func=<function _create_name_to_sid_map>, description='Pandas DataFrame containing a map of vertex names to species ids'),
          'relation_prediction': ArtifactDefinition(name='relation_prediction', artifact_type='napistu_data', creation_func=<function _create_relation_prediction_data>, description='Unlabeled NapistuData with train/test/val edge masking and relation-type labels'),
          'species_identifiers': ArtifactDefinition(name='species_identifiers', artifact_type='pandas_dfs', creation_func=<function _create_species_identifiers>, description='Pandas DataFrame containing species identifiers'),
          'species_type_prediction': ArtifactDefinition(name='species_type_prediction', artifact_type='napistu_data', creation_func=<function _create_species_type_prediction_data>, description='NapistuData containing species type labels with train/test/val vertex masking'),
          'unlabeled': ArtifactDefinition(name='unlabeled', artifact_type='napistu_data', creation_func=<function _create_unlabeled_data>, description='Unlabeled NapistuData without masking'),
      },
      overwrite_artifacts: bool = False,
  )
Initialize FullGraphDataModule.
- Parameters:
config (ExperimentConfig) – Pydantic experiment configuration
napistu_data_name (Optional[str]) – Name of the NapistuData artifact to use
other_artifacts (Optional[List[str]]) – List of other artifact names needed
napistu_data (Optional[NapistuData]) – Direct NapistuData object for testing
store (Optional[NapistuDataStore]) – Pre-initialized store
artifact_registry (Optional[Dict[str, ArtifactDefinition]]) – Registry of artifact definitions
overwrite_artifacts (bool) – If True, recreate artifacts even if they exist
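The interplay between artifact_registry and overwrite_artifacts can be pictured as a get-or-create lookup: an artifact is built from its creation function the first time it is requested and rebuilt only when overwriting is requested. The sketch below is an illustration under stated assumptions, not the library's implementation. The ArtifactDefinition dataclass is a stand-in that only mirrors the four fields visible in the default registry, InMemoryStore and its ensure method are hypothetical, and the creation function here takes no arguments, which the real creation functions may well not do.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class ArtifactDefinition:
    """Stand-in with the fields shown in the default registry (assumption)."""

    name: str
    artifact_type: str
    creation_func: Callable[[], Any]  # assumed zero-argument for this sketch
    description: str


class InMemoryStore:
    """Hypothetical store; the real NapistuDataStore is not modeled here."""

    def __init__(self) -> None:
        self._artifacts: Dict[str, Any] = {}

    def ensure(
        self,
        name: str,
        registry: Dict[str, ArtifactDefinition],
        overwrite: bool = False,
    ) -> Any:
        """Return the named artifact, creating it if missing or if overwrite is True."""
        if name not in registry:
            raise KeyError(f"no artifact definition registered for {name!r}")
        if overwrite or name not in self._artifacts:
            self._artifacts[name] = registry[name].creation_func()
        return self._artifacts[name]


calls: List[str] = []  # records every time the creation function runs


def _create_unlabeled_data() -> Dict[str, Any]:
    """Toy creation function that returns a placeholder instead of NapistuData."""
    calls.append("unlabeled")
    return {"kind": "unlabeled", "build": len(calls)}


registry = {
    "unlabeled": ArtifactDefinition(
        name="unlabeled",
        artifact_type="napistu_data",
        creation_func=_create_unlabeled_data,
        description="Unlabeled NapistuData without masking",
    )
}

store = InMemoryStore()
first = store.ensure("unlabeled", registry)                   # created
second = store.ensure("unlabeled", registry)                  # reused
third = store.ensure("unlabeled", registry, overwrite=True)   # recreated
```

In this sketch the second call reuses the stored object, while the third call runs the creation function again because overwrite is True.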
- predict_dataloader() → torch.utils.data.DataLoader
Return prediction dataloader.
- test_dataloader() → torch.utils.data.DataLoader
Return test dataloader.
- train_dataloader() → torch.utils.data.DataLoader
Return training dataloader.
- Returns:
DataLoader that yields the complete NapistuData object.
- Return type:
DataLoader
- val_dataloader() → torch.utils.data.DataLoader
Return validation dataloader.
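The full-batch idea behind these dataloaders can be sketched without any dependencies: the dataset has exactly one item, the whole graph, so each epoch consists of a single batch. The names SingleGraphDataset and full_batch_loader are hypothetical and the graph is a plain dictionary rather than a NapistuData object. With PyTorch, one common way to get the same effect is a torch.utils.data.DataLoader over a length-1 dataset with batch_size=1 and a collate_fn that unwraps the single item; whether FullGraphDataModule does exactly this is an assumption.

```python
from typing import Any, Dict, Iterator


class SingleGraphDataset:
    """Dataset of length 1 whose only item is the entire graph."""

    def __init__(self, graph: Any) -> None:
        self.graph = graph

    def __len__(self) -> int:
        return 1

    def __getitem__(self, index: int) -> Any:
        if index != 0:
            raise IndexError("a full-graph dataset has exactly one item")
        return self.graph


def full_batch_loader(dataset: SingleGraphDataset) -> Iterator[Any]:
    """Yield each item unchanged, so one epoch is one batch: the whole graph."""
    for index in range(len(dataset)):
        yield dataset[index]


# Toy graph standing in for a NapistuData object (assumption).
graph: Dict[str, Any] = {"num_nodes": 4, "edge_index": [(0, 1), (1, 2), (2, 3)]}
batches = list(full_batch_loader(SingleGraphDataset(graph)))
```

Because the loader passes the graph through untouched, train/validation/test separation in a setup like this would have to come from masks stored on the graph itself rather than from different batches.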