napistu_torch.lightning.edge_batch_datamodule

DataModule for mini-batch training on edge prediction tasks.

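The training dataloader yields mini-batches of edge indices instead of the full graph. Each epoch therefore performs several gradient updates rather than one, which can improve optimization on large graphs. Validation, test and prediction still use the full graph.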

Classes

EdgeBatchDataModule(*args, **kwargs)

DataModule for edge prediction with mini-batch training.

class napistu_torch.lightning.edge_batch_datamodule.EdgeBatchDataModule(*args: Any, **kwargs: Any)

Bases: NapistuDataModule

DataModule for edge prediction with mini-batch training.

__init__(config: ExperimentConfig, batches_per_epoch: int | None = None, shuffle: bool = True, napistu_data_name: str | None = None, other_artifacts: List[str] | None = None, napistu_data: NapistuData | None = None, store: NapistuDataStore | None = None, artifact_registry: Dict[str, napistu_torch.load.artifacts.ArtifactDefinition] | None = <default registry>, overwrite_artifacts: bool = False)

The default artifact_registry maps each name below to an ArtifactDefinition, listed as name (artifact_type; creation_func) – description:

  • comprehensive_pathway_memberships (vertex_tensor; _create_comprehensive_pathway_memberships) – VertexTensor containing comprehensive pathway membership features

  • edge_prediction (napistu_data; _create_edge_prediction_data) – Unlabeled NapistuData with train/test/val edge masking

  • edge_strata_by_edge_sbo_terms (pandas_dfs; _create_edge_strata_by_edge_sbo_terms) – Pandas DataFrame containing edge strata by from-to edge SBO terms

  • edge_strata_by_node_species_type (pandas_dfs; _create_edge_strata_by_node_species_type) – Pandas DataFrame containing edge strata by node + species type

  • edge_strata_by_node_type (pandas_dfs; _create_edge_strata_by_node_type) – Pandas DataFrame containing edge strata by node type

  • name_to_sid_map (pandas_dfs; _create_name_to_sid_map) – Pandas DataFrame containing a map of vertex names to species ids

  • relation_prediction (napistu_data; _create_relation_prediction_data) – Unlabeled NapistuData with train/test/val edge masking and relation-type labels

  • species_identifiers (pandas_dfs; _create_species_identifiers) – Pandas DataFrame containing species identifiers

  • species_type_prediction (napistu_data; _create_species_type_prediction_data) – NapistuData containing species type labels with train/test/val vertex masking

  • unlabeled (napistu_data; _create_unlabeled_data) – Unlabeled NapistuData without masking

Initialize EdgeBatchDataModule.

Parameters:
  • config (ExperimentConfig) – Pydantic experiment configuration.

  • batches_per_epoch (Optional[int]) – Number of mini-batches per epoch. If None, uses config.training.batches_per_epoch.

  • shuffle (bool) – Whether to shuffle mini-batch order. Default True.

  • napistu_data_name (Optional[str]) – Name of the NapistuData artifact to use.

  • other_artifacts (Optional[List[str]]) – Names of any additional artifacts needed.

  • napistu_data (Optional[NapistuData]) – NapistuData object passed in directly, e.g. for testing.

  • store (Optional[NapistuDataStore]) – Pre-initialized NapistuDataStore.

  • artifact_registry (Optional[Dict[str, ArtifactDefinition]]) – Registry of artifact definitions.

  • overwrite_artifacts (bool) – If True, recreate artifacts even if they already exist. Default False.
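The sketch below illustrates how two groups of these parameters might interact. It assumes only what the descriptions say: an explicit batches_per_epoch overrides config.training.batches_per_epoch, and artifact names (napistu_data_name, other_artifacts) are resolved against artifact_registry. ArtifactDef, resolve_batches_per_epoch and check_artifacts are hypothetical helpers, not napistu_torch API, and the config is faked with SimpleNamespace rather than a real ExperimentConfig.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional


@dataclass(frozen=True)
class ArtifactDef:
    """Hypothetical stand-in with the same fields as the ArtifactDefinition repr."""
    name: str
    artifact_type: str
    creation_func: Callable[..., object]
    description: str


def resolve_batches_per_epoch(config, batches_per_epoch: Optional[int]) -> int:
    # An explicit argument wins; otherwise fall back to the config value.
    if batches_per_epoch is not None:
        return batches_per_epoch
    return config.training.batches_per_epoch


def check_artifacts(names: List[str], registry: Dict[str, ArtifactDef]) -> List[ArtifactDef]:
    # Assumed behavior: fail early if a requested name has no registry entry.
    missing = [n for n in names if n not in registry]
    if missing:
        raise KeyError(f"unknown artifacts: {missing}")
    return [registry[n] for n in names]


registry = {
    "edge_prediction": ArtifactDef(
        "edge_prediction", "napistu_data", lambda: None,
        "Unlabeled NapistuData with train/test/val edge masking",
    ),
    "name_to_sid_map": ArtifactDef(
        "name_to_sid_map", "pandas_dfs", lambda: None,
        "Pandas DataFrame containing a map of vertex names to species ids",
    ),
}
config = SimpleNamespace(training=SimpleNamespace(batches_per_epoch=20))

print(resolve_batches_per_epoch(config, None))  # 20 (taken from config)
print(resolve_batches_per_epoch(config, 8))     # 8 (explicit argument)
print([a.artifact_type for a in check_artifacts(["edge_prediction", "name_to_sid_map"], registry)])
# ['napistu_data', 'pandas_dfs']
```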

predict_dataloader() → torch.utils.data.DataLoader

Prediction uses the full graph (same as FullGraphDataModule).

Returns:

DataLoader that yields the complete NapistuData object.

Return type:

DataLoader

test_dataloader() → torch.utils.data.DataLoader

Testing uses the full graph (same as FullGraphDataModule).

Returns:

DataLoader that yields the complete NapistuData object.

Return type:

DataLoader

train_dataloader() → torch.utils.data.DataLoader

Create the training dataloader with edge mini-batching.

Returns:

DataLoader that yields a tensor of edge indices for each mini-batch.

Return type:

DataLoader
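One way train_dataloader could produce its batches is sketched below in plain Python. It assumes the training edges are shuffled at the start of each pass (when shuffle=True) and then split into batches_per_epoch near-equal chunks. That is one reading of "shuffle mini-batch order"; the library may instead permute fixed batches or size them differently. edge_index_batches is a hypothetical helper, and it yields lists of ints where the real loader yields tensors.

```python
import random
from typing import Iterator, List, Optional


def edge_index_batches(
    train_edge_ids: List[int],
    batches_per_epoch: int,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield `batches_per_epoch` near-equal batches of training-edge ids (one epoch)."""
    if batches_per_epoch < 1:
        raise ValueError("batches_per_epoch must be >= 1")
    order = list(train_edge_ids)
    if shuffle:
        # Shuffle at the start of each pass; seed=None draws a fresh order every call.
        random.Random(seed).shuffle(order)
    # The first `extra` batches get one extra edge so every edge is used exactly once.
    base, extra = divmod(len(order), batches_per_epoch)
    start = 0
    for i in range(batches_per_epoch):
        size = base + (1 if i < extra else 0)
        yield order[start:start + size]
        start += size


batches = list(edge_index_batches(list(range(10)), batches_per_epoch=3, shuffle=True, seed=0))
print([len(b) for b in batches])                                   # [4, 3, 3]
print(sorted(i for b in batches for i in b) == list(range(10)))   # True
```

Each yielded batch corresponds to one optimizer step, so a single epoch here performs three updates instead of the one a full-graph loader would give.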

val_dataloader() → torch.utils.data.DataLoader

Validation uses the full graph (same as FullGraphDataModule).

Returns:

DataLoader that yields the complete NapistuData object.

Return type:

DataLoader
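Because training batches carry only edge indices while validation, test and prediction batches carry the whole graph, a model's step functions must treat the two kinds of batch differently. The sketch below assumes the model keeps a reference to the full graph and uses each index batch to select positions among its training edges. ToyGraph, training_edges and eval_edges are hypothetical stand-ins for NapistuData and a LightningModule's steps, written in plain Python so the example runs without torch.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

Edge = Tuple[int, int]


@dataclass
class ToyGraph:
    """Hypothetical stand-in for NapistuData: an edge list plus a train mask."""
    edges: List[Edge]
    train_mask: List[bool] = field(default_factory=list)

    def train_edges(self) -> List[Edge]:
        return [e for e, m in zip(self.edges, self.train_mask) if m]


def training_edges(graph: ToyGraph, batch: List[int]) -> List[Edge]:
    # Training batch: positions into the graph's training edges (assumed indexing scheme).
    train = graph.train_edges()
    return [train[i] for i in batch]


def eval_edges(graph: ToyGraph) -> List[Edge]:
    # Evaluation batch: the whole graph; here we score its non-training edges.
    return [e for e, m in zip(graph.edges, graph.train_mask) if not m]


g = ToyGraph(
    edges=[(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)],
    train_mask=[True, True, False, True, False],
)
print(training_edges(g, [2, 0]))  # [(3, 0), (0, 1)]
print(eval_edges(g))              # [(2, 3), (0, 2)]
```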
