napistu_torch.data.data_utils
Utility functions for data loading and batching.
Contains shared helpers for DataLoaders and collate functions.
Public Functions
- identity_collate(batch)
Custom collate function that returns NapistuData unchanged.
- create_single_graph_dataloader(dataset, shuffle=False, num_workers=0)
Create a DataLoader for a dataset wrapping a single NapistuData object.
Functions
- create_single_graph_dataloader(dataset[, shuffle, num_workers])
Create a DataLoader for single graph datasets.
- identity_collate(batch)
Custom collate function that returns NapistuData unchanged.
- napistu_torch.data.data_utils.create_single_graph_dataloader(dataset: SingleGraphDataset, shuffle: bool = False, num_workers: int = 0) → torch.utils.data.DataLoader
Create a DataLoader for single graph datasets.
Returns a DataLoader configured for single-graph training: it uses batch_size=1 (only one graph per batch), uses identity_collate to avoid PyG batching, and yields NapistuData objects directly.
- Parameters:
dataset (SingleGraphDataset) – Dataset wrapping a single NapistuData object
shuffle (bool, optional) – Whether to shuffle (not useful for single graph). Default False.
num_workers (int, optional) – Number of worker processes. Default 0 (single graph doesn’t benefit from workers).
- Returns:
Configured DataLoader that yields NapistuData objects
- Return type:
DataLoader
Examples
>>> from napistu_torch.data.data_utils import create_single_graph_dataloader
>>> dataset = SingleGraphDataset(data)
>>> loader = create_single_graph_dataloader(dataset)
>>> batch = next(iter(loader))
>>> assert isinstance(batch, NapistuData)
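The three properties listed above (batch size of one, an unwrapping collate step, and the graph object yielded directly) can be illustrated with a minimal pure-Python sketch of the data flow. It avoids torch on purpose so that it runs anywhere. `FakeGraph`, `OneGraphDataset`, `unwrap_single`, and `iter_batches` are hypothetical stand-ins for NapistuData, SingleGraphDataset, identity_collate, and the batching step of a DataLoader; none of them is the library's implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterator, List


@dataclass
class FakeGraph:
    """Hypothetical stand-in for a NapistuData object."""
    num_nodes: int


class OneGraphDataset:
    """Hypothetical stand-in for SingleGraphDataset: always has length 1."""

    def __init__(self, graph: FakeGraph) -> None:
        self.graph = graph

    def __len__(self) -> int:
        return 1

    def __getitem__(self, idx: int) -> FakeGraph:
        if idx != 0:
            raise IndexError(idx)
        return self.graph


def unwrap_single(batch: List[Any]) -> Any:
    """Return the only element of a one-item batch (assumed collate behavior)."""
    assert isinstance(batch, list) and len(batch) == 1
    return batch[0]


def iter_batches(
    dataset: OneGraphDataset,
    batch_size: int,
    collate_fn: Callable[[List[Any]], Any],
) -> Iterator[Any]:
    """Group items into lists of batch_size, then pass each list to collate_fn."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield collate_fn([dataset[i] for i in range(start, stop)])


graph = FakeGraph(num_nodes=5)
batches = list(
    iter_batches(OneGraphDataset(graph), batch_size=1, collate_fn=unwrap_single)
)
# One epoch produces one batch, and that batch is the graph object itself
# rather than a list or a merged copy.
```

With a one-element dataset and a batch size of one, each epoch is a single step, which is why shuffling and extra worker processes bring no benefit in this setup.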
- napistu_torch.data.data_utils.identity_collate(batch: List[NapistuData]) → NapistuData
Custom collate function that returns NapistuData unchanged.
For single-graph training, we don’t want PyG’s batching behavior. This function extracts the single NapistuData object from the batch list.
- Parameters:
batch (List[NapistuData]) – Batch from DataLoader (should contain exactly 1 NapistuData object)
- Returns:
The single NapistuData object
- Return type:
NapistuData
- Raises:
AssertionError – If batch is not a list with exactly 1 NapistuData object
Examples
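>>> from torch.utils.data import DataLoader
>>> from napistu_torch.data.data_utils import identity_collate
>>> dataset = SingleGraphDataset(data)
>>> loader = DataLoader(dataset, batch_size=1, collate_fn=identity_collate)
>>> for batch in loader:
...     assert isinstance(batch, NapistuData)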
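The documented contract (a list with exactly one element goes in, that element comes out, and anything else raises AssertionError) can be sketched without torch. `collate_single` below is a hypothetical reimplementation written from the description above, not the library's source. A plain dict stands in for a NapistuData object, and the sketch omits the check on the element's type.

```python
from typing import Any, List


def collate_single(batch: List[Any]) -> Any:
    """Sketch of the documented contract: unwrap a one-element list."""
    assert isinstance(batch, list), "batch must be a list"
    assert len(batch) == 1, f"expected exactly 1 item, got {len(batch)}"
    return batch[0]


graph = {"name": "toy-graph"}  # stand-in for a NapistuData object

# Happy path: the very same object comes back, not a copy.
same_object = collate_single([graph]) is graph

# A batch of two (for example, batch_size=2 set by mistake) is rejected.
try:
    collate_single([graph, graph])
    rejected = False
except AssertionError:
    rejected = True
```

In this sketch the checks are plain `assert` statements, so they would be skipped if Python runs with the `-O` flag. Whether the library's own check behaves the same way is not stated in this page.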