napistu_torch.data.dataset
Datasets for edge prediction training.
This module provides PyTorch Dataset classes for working with NapistuData objects in training pipelines.
Classes
- SingleGraphDataset
Wrapper to make a single NapistuData object work with DataLoader.
- EdgeBatchDataset
Dataset that splits edge indices into mini-batches for training.
- class napistu_torch.data.dataset.EdgeBatchDataset(*args: Any, **kwargs: Any)
Bases: Dataset
Dataset that splits edge indices into mini-batches for training.
Unlike sharding (parallel processing), batches are processed sequentially with weight updates between them.
- Parameters:
edge_indices (torch.Tensor) – All edge indices to batch [num_edges]
batches_per_epoch (int) – Number of mini-batches per epoch
Examples
>>> # 80M training edges split into 10 mini-batches
>>> train_indices = torch.where(data.train_mask)[0]
>>> dataset = EdgeBatchDataset(train_indices, batches_per_epoch=10)
>>> len(dataset)  # 10
>>> dataset[0].shape  # torch.Size([8000000]) - first mini-batch
- __init__(edge_indices: torch.Tensor, batches_per_epoch: int)
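The splitting and the sequential weight updates described above can be illustrated with a minimal sketch. This is not the library's implementation: `EdgeBatchSketch` is a hypothetical stand-in, the use of `torch.tensor_split` is an assumption about how near-equal chunks might be produced, and the toy loss only serves to show one optimizer step per mini-batch. The real class may, for example, shuffle edges between epochs.

```python
import torch
from torch.utils.data import Dataset


class EdgeBatchSketch(Dataset):
    """Hypothetical stand-in for EdgeBatchDataset (not the library's code)."""

    def __init__(self, edge_indices: torch.Tensor, batches_per_epoch: int):
        # Assumption: split into exactly `batches_per_epoch` near-equal chunks.
        self.batches = torch.tensor_split(edge_indices, batches_per_epoch)

    def __len__(self) -> int:
        return len(self.batches)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.batches[idx]


edge_indices = torch.arange(25)
dataset = EdgeBatchSketch(edge_indices, batches_per_epoch=10)
batch_sizes = [len(dataset[i]) for i in range(len(dataset))]

# Toy model: one learnable score per edge, pulled toward 1.0.
weights = torch.zeros(25, requires_grad=True)
optimizer = torch.optim.SGD([weights], lr=0.1)

steps = 0
for i in range(len(dataset)):
    batch = dataset[i]  # indices of the edges in this mini-batch
    loss = (weights[batch] - 1.0).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # weights are updated before the next mini-batch is seen
    steps += 1
```

Each mini-batch is consumed in turn and followed by its own optimizer step, which is the difference from sharding, where the pieces would be processed in parallel without updates in between.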
- class napistu_torch.data.dataset.SingleGraphDataset(*args: Any, **kwargs: Any)
Bases: Dataset
Wrapper to make a single NapistuData object work with DataLoader.
This is necessary because DataLoader expects a Dataset interface. For full-batch training on a single graph, this just returns the same graph every time (batch_size should be 1).
- __init__(data: NapistuData)
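A rough sketch of the wrapper pattern is shown below. It is not the library's code: `SingleGraphSketch` is a hypothetical stand-in, a plain dict of tensors replaces the NapistuData object, and the `collate_fn` that unwraps the one-element batch is an assumption about how the graph might be handed to the training step unchanged.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SingleGraphSketch(Dataset):
    """Hypothetical stand-in for SingleGraphDataset (not the library's code)."""

    def __init__(self, data):
        self.data = data

    def __len__(self) -> int:
        # One "sample" per epoch: the whole graph.
        return 1

    def __getitem__(self, idx: int):
        # The same graph is returned regardless of the index.
        return self.data


# A dict of tensors stands in for a NapistuData object here.
graph = {"edge_index": torch.tensor([[0, 1, 2], [1, 2, 0]])}

loader = DataLoader(
    SingleGraphSketch(graph),
    batch_size=1,
    # Assumption: unwrap the one-element batch instead of collating it.
    collate_fn=lambda batch: batch[0],
)

batches = list(loader)
```

With this setup each epoch yields exactly one batch, and that batch is the original graph object.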