napistu_torch.lightning.full_graph_datamodule

DataModule for full-batch training on a single graph.

Returns complete NapistuData objects, so the entire graph is processed at once rather than in mini-batches. This is the traditional approach and matches the original NapistuDataModule behavior.
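To make the full-batch idea concrete, the following is a minimal, dependency-free sketch of a loader whose epoch consists of exactly one item: the whole graph. `SingleGraphLoader` and `toy_graph` are hypothetical names invented for illustration; a plain dict stands in for a NapistuData object, and nothing here is taken from the napistu_torch implementation.

```python
from typing import Any, Iterator


class SingleGraphLoader:
    """Hypothetical stand-in for a full-batch loader: one item per epoch."""

    def __init__(self, graph: Any) -> None:
        self.graph = graph

    def __len__(self) -> int:
        # A full-batch epoch consists of exactly one step.
        return 1

    def __iter__(self) -> Iterator[Any]:
        # Yield the whole graph once; no mini-batching or sampling.
        yield self.graph


# Toy graph: a plain dict standing in for a NapistuData object (assumption).
toy_graph = {"num_nodes": 4, "edges": [(0, 1), (1, 2), (2, 3)]}
loader = SingleGraphLoader(toy_graph)

# Iterating over one epoch produces a single "batch": the graph itself.
batches = list(loader)
```

Because the object is passed through unchanged, each training step sees every node and edge, which is simple but requires the whole graph to fit in memory.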

Classes

FullGraphDataModule(*args, **kwargs)

DataModule for full-batch training on a single graph.

class napistu_torch.lightning.full_graph_datamodule.FullGraphDataModule(*args: Any, **kwargs: Any)

Bases: NapistuDataModule

DataModule for full-batch training on a single graph.

__init__(
    config: ExperimentConfig,
    napistu_data_name: str | None = None,
    other_artifacts: List[str] | None = None,
    napistu_data: NapistuData | None = None,
    store: NapistuDataStore | None = None,
    artifact_registry: Dict[str, napistu_torch.load.artifacts.ArtifactDefinition] | None = {
        'comprehensive_pathway_memberships': ArtifactDefinition(name='comprehensive_pathway_memberships', artifact_type='vertex_tensor', creation_func=<function _create_comprehensive_pathway_memberships>, description='VertexTensor containing comprehensive pathway membership features'),
        'edge_prediction': ArtifactDefinition(name='edge_prediction', artifact_type='napistu_data', creation_func=<function _create_edge_prediction_data>, description='Unlabeled NapistuData with train/test/val edge masking'),
        'edge_strata_by_edge_sbo_terms': ArtifactDefinition(name='edge_strata_by_edge_sbo_terms', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_edge_sbo_terms>, description='Pandas DataFrame containing edge strata by from-to edge SBO terms'),
        'edge_strata_by_node_species_type': ArtifactDefinition(name='edge_strata_by_node_species_type', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_node_species_type>, description='Pandas DataFrame containing edge strata by node + species type'),
        'edge_strata_by_node_type': ArtifactDefinition(name='edge_strata_by_node_type', artifact_type='pandas_dfs', creation_func=<function _create_edge_strata_by_node_type>, description='Pandas DataFrame containing edge strata by node type'),
        'name_to_sid_map': ArtifactDefinition(name='name_to_sid_map', artifact_type='pandas_dfs', creation_func=<function _create_name_to_sid_map>, description='Pandas DataFrame containing a map of vertex names to species ids'),
        'relation_prediction': ArtifactDefinition(name='relation_prediction', artifact_type='napistu_data', creation_func=<function _create_relation_prediction_data>, description='Unlabeled NapistuData with train/test/val with edge masking and relation-type labels'),
        'species_identifiers': ArtifactDefinition(name='species_identifiers', artifact_type='pandas_dfs', creation_func=<function _create_species_identifiers>, description='Pandas DataFrame containing species identifiers'),
        'species_type_prediction': ArtifactDefinition(name='species_type_prediction', artifact_type='napistu_data', creation_func=<function _create_species_type_prediction_data>, description='NapistuData containing species type labels with train/test/val vertex masking'),
        'unlabeled': ArtifactDefinition(name='unlabeled', artifact_type='napistu_data', creation_func=<function _create_unlabeled_data>, description='Unlabeled NapistuData without masking')
    },
    overwrite_artifacts: bool = False
)

Initialize FullGraphDataModule.

Parameters:
  • config (ExperimentConfig) – Pydantic experiment configuration

  • napistu_data_name (Optional[str]) – Name of the NapistuData artifact to use

  • other_artifacts (Optional[List[str]]) – List of other artifact names needed

  • napistu_data (Optional[NapistuData]) – Direct NapistuData object for testing

  • store (Optional[NapistuDataStore]) – Pre-initialized store

  • artifact_registry (Optional[Dict[str, ArtifactDefinition]]) – Registry of artifact definitions

  • overwrite_artifacts (bool) – If True, recreate artifacts even if they exist
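The interplay between artifact_registry and overwrite_artifacts can be pictured with a small, self-contained sketch. Everything below is hypothetical: `ToyArtifactDefinition` only mirrors the four fields visible in the default registry (name, artifact_type, creation_func, description), `get_or_create` is an invented helper, a plain dict stands in for a NapistuDataStore, and the creation function takes no arguments for simplicity. The actual lookup and caching logic in napistu_torch may differ.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class ToyArtifactDefinition:
    """Hypothetical stand-in mirroring the fields shown in the default registry."""

    name: str
    artifact_type: str
    creation_func: Callable[[], Any]
    description: str


def get_or_create(
    name: str,
    registry: Dict[str, ToyArtifactDefinition],
    cache: Dict[str, Any],
    overwrite: bool = False,
) -> Any:
    """Return a cached artifact, (re)building it when missing or when overwrite is set."""
    if name not in registry:
        raise KeyError(f"Unknown artifact: {name!r}")
    if overwrite or name not in cache:
        cache[name] = registry[name].creation_func()
    return cache[name]


# Count how often the creation function runs, to make rebuilds visible.
calls = {"count": 0}


def _create_unlabeled() -> dict:
    calls["count"] += 1
    return {"kind": "unlabeled", "build": calls["count"]}


registry = {
    "unlabeled": ToyArtifactDefinition(
        name="unlabeled",
        artifact_type="napistu_data",
        creation_func=_create_unlabeled,
        description="Unlabeled NapistuData without masking",
    )
}
cache: Dict[str, Any] = {}  # dict standing in for a store (assumption)

first = get_or_create("unlabeled", registry, cache)                  # built
second = get_or_create("unlabeled", registry, cache)                 # reused
third = get_or_create("unlabeled", registry, cache, overwrite=True)  # rebuilt
```

In this sketch the second request reuses the cached object, while the third forces a rebuild, which is the behavior the overwrite_artifacts description suggests.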

predict_dataloader() → torch.utils.data.DataLoader

Return prediction dataloader.

test_dataloader() → torch.utils.data.DataLoader

Return test dataloader.

train_dataloader() → torch.utils.data.DataLoader

Return training dataloader.

Returns:

DataLoader that yields a complete NapistuData object.

Return type:

DataLoader

val_dataloader() → torch.utils.data.DataLoader

Return validation dataloader.
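Several artifacts in the default registry are described as carrying train/test/val masking on edges or vertices. One plausible reading, which this page does not confirm, is that every split sees the same full graph and the masks decide which entries count toward each phase. The sketch below illustrates that idea with plain lists; `labels`, `masks`, and `select` are hypothetical names and do not correspond to NapistuData attributes.

```python
from typing import Dict, List

# Toy full graph: per-node labels plus one boolean mask per split (assumption).
labels: List[int] = [0, 1, 1, 0, 1, 0]
masks: Dict[str, List[bool]] = {
    "train": [True, True, True, False, False, False],
    "val": [False, False, False, True, False, False],
    "test": [False, False, False, False, True, True],
}


def select(values: List[int], mask: List[bool]) -> List[int]:
    """Keep only the entries whose mask flag is set."""
    return [v for v, keep in zip(values, mask) if keep]


# Each phase reads the same label list but scores only its own masked subset.
train_labels = select(labels, masks["train"])
val_labels = select(labels, masks["val"])
test_labels = select(labels, masks["test"])
```

Under this reading the splits partition the nodes, so no entry is scored in more than one phase even though the graph itself is shared.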
