napistu_torch.data.dataset

Datasets for edge prediction training.

This module provides PyTorch Dataset classes for working with NapistuData objects in training pipelines.

Classes

SingleGraphDataset

Wrapper to make a single NapistuData object work with DataLoader.

EdgeBatchDataset

Dataset that splits edge indices into mini-batches for training.

class napistu_torch.data.dataset.EdgeBatchDataset(*args: Any, **kwargs: Any)

Bases: Dataset

Dataset that splits edge indices into mini-batches for training.

Unlike sharding, where partitions of the data are processed in parallel, these mini-batches are processed sequentially, with a weight update after each one.
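
The contrast with sharding is easiest to see in a training loop. The following is a minimal sketch, assuming a toy edge classifier. The `nn.Linear` model, the random `edge_features` and `edge_labels`, and the use of `torch.tensor_split` to cut the indices are illustrative assumptions, not napistu_torch internals. Because `optimizer.step()` runs once per mini-batch, each batch is scored with weights already updated by the previous one.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative toy data (not napistu_torch): 1,000 "edges", each with an
# 8-dimensional feature vector and a binary label.
num_edges, dim = 1_000, 8
edge_features = torch.randn(num_edges, dim)
edge_labels = torch.randint(0, 2, (num_edges,)).float()

model = nn.Linear(dim, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()
initial_weight = model.weight.detach().clone()

train_indices = torch.arange(num_edges)
batches_per_epoch = 10
optimizer_steps = 0

# One epoch: mini-batches are visited one after another, never in parallel.
for batch in torch.tensor_split(train_indices, batches_per_epoch):
    logits = model(edge_features[batch]).squeeze(-1)
    loss = loss_fn(logits, edge_labels[batch])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # weights change before the next mini-batch is seen
    optimizer_steps += 1
```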

Parameters:
  • edge_indices (torch.Tensor) – All edge indices to split into mini-batches, shape [num_edges]

  • batches_per_epoch (int) – Number of mini-batches per epoch

Examples

>>> # 80M training edges split into 10 mini-batches
>>> train_indices = torch.where(data.train_mask)[0]
>>> dataset = EdgeBatchDataset(train_indices, batches_per_epoch=10)
>>> len(dataset)
10
>>> dataset[0].shape  # first mini-batch
torch.Size([8000000])

__init__(edge_indices: torch.Tensor, batches_per_epoch: int)
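
The splitting behaviour can be sketched as follows. `EdgeBatchDatasetSketch` is a hypothetical stand-in, not the library's implementation. It assumes the split uses `torch.tensor_split`, which returns exactly `batches_per_epoch` pieces and, when the edge count does not divide evenly, gives the first `num_edges % batches_per_epoch` pieces one extra edge. Whether the real class shuffles indices between epochs is not stated here.

```python
import torch
from torch.utils.data import Dataset


class EdgeBatchDatasetSketch(Dataset):
    """Hypothetical stand-in: split edge indices into sequential mini-batches."""

    def __init__(self, edge_indices: torch.Tensor, batches_per_epoch: int):
        if batches_per_epoch < 1:
            raise ValueError("batches_per_epoch must be >= 1")
        # Exactly `batches_per_epoch` pieces; uneven counts put the extra
        # edges in the leading pieces (pieces may be empty if there are
        # fewer edges than batches).
        self.batches = list(torch.tensor_split(edge_indices, batches_per_epoch))

    def __len__(self) -> int:
        # One dataset item per mini-batch.
        return len(self.batches)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Each item is already a full mini-batch of edge indices.
        return self.batches[idx]
```

With this sketch, pass `batch_size=None` to `DataLoader` so each item, which is already a mini-batch, is yielded as-is. The default `batch_size=1` would instead stack it into a tensor with an extra leading dimension of size 1.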

class napistu_torch.data.dataset.SingleGraphDataset(*args: Any, **kwargs: Any)

Bases: Dataset

Wrapper to make a single NapistuData object work with DataLoader.

DataLoader requires a Dataset, so this wrapper exposes a single graph through that interface. For full-batch training, indexing always returns the same graph; use batch_size=1.

__init__(data: NapistuData)
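
A minimal sketch of the wrapper pattern follows, with `SingleGraphDatasetSketch` as a hypothetical stand-in for the real class. It assumes the dataset reports a length of 1 and uses a custom `collate_fn` (the hypothetical helper `unwrap_single`) so that `DataLoader` hands back the graph object itself rather than trying to collate it. A plain `object()` stands in for a `NapistuData` instance.

```python
from typing import Any, List

from torch.utils.data import DataLoader, Dataset


class SingleGraphDatasetSketch(Dataset):
    """Hypothetical stand-in: expose one graph through the Dataset interface."""

    def __init__(self, data: Any):
        self.data = data

    def __len__(self) -> int:
        # Assumption: one item per epoch, i.e. one full-batch training step.
        return 1

    def __getitem__(self, idx: int) -> Any:
        # Every index maps to the same graph object.
        return self.data


def unwrap_single(batch: List[Any]) -> Any:
    # With batch_size=1 the DataLoader passes a one-element list; return the
    # graph itself instead of attempting to collate it.
    return batch[0]


graph = object()  # placeholder for a NapistuData instance
loader = DataLoader(SingleGraphDatasetSketch(graph), batch_size=1, collate_fn=unwrap_single)

steps = 0
for epoch in range(3):
    for full_graph in loader:
        steps += 1  # one full-batch step per epoch, always on `graph`
```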