napistu_torch.data.data_utils

Utility functions for data loading and batching.

Contains shared helpers for DataLoaders and collate functions.

Public Functions

identity_collate(batch)

Custom collate function that returns NapistuData unchanged.

create_single_graph_dataloader(dataset, shuffle=False, num_workers=0)

Create a DataLoader for a single-graph dataset.


napistu_torch.data.data_utils.create_single_graph_dataloader(dataset: SingleGraphDataset, shuffle: bool = False, num_workers: int = 0) → torch.utils.data.DataLoader

Create a DataLoader for single graph datasets.

Returns a DataLoader configured for single-graph training that:
  • Uses batch_size=1 (only one graph per batch)

  • Uses identity_collate to avoid PyG batching

  • Returns NapistuData objects directly

Parameters:
  • dataset (SingleGraphDataset) – Dataset wrapping a single NapistuData object

  • shuffle (bool, optional) – Whether to shuffle the dataset. Shuffling has no practical effect when the dataset holds a single graph. Default False.

  • num_workers (int, optional) – Number of worker processes. Default 0 (single graph doesn’t benefit from workers).

Returns:

Configured DataLoader that yields NapistuData objects

Return type:

DataLoader
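The interaction between batch_size=1 and a collate function that unwraps the batch can be illustrated without PyTorch. The following is a minimal sketch, not the library's implementation: GraphStandIn, SingleItemDataset, unwrap_single, and iterate_batches are all hypothetical names, and iterate_batches is only a simplified model of how a map-style loader groups items before calling the collate function.

```python
from typing import Any, Callable, Iterator, List


class GraphStandIn:
    """Hypothetical placeholder for a NapistuData object."""

    def __init__(self, name: str) -> None:
        self.name = name


class SingleItemDataset:
    """Hypothetical stand-in for SingleGraphDataset: wraps exactly one object."""

    def __init__(self, graph: GraphStandIn) -> None:
        self.graph = graph

    def __len__(self) -> int:
        return 1

    def __getitem__(self, index: int) -> GraphStandIn:
        if index != 0:
            raise IndexError(index)
        return self.graph


def unwrap_single(batch: List[Any]) -> Any:
    """Collate step: return the only item instead of merging items."""
    assert len(batch) == 1
    return batch[0]


def iterate_batches(
    dataset: SingleItemDataset,
    batch_size: int,
    collate_fn: Callable[[List[Any]], Any],
) -> Iterator[Any]:
    """Simplified loading loop: slice indices into batches, then collate each."""
    indices = list(range(len(dataset)))
    for start in range(0, len(indices), batch_size):
        items = [dataset[i] for i in indices[start:start + batch_size]]
        yield collate_fn(items)


graph = GraphStandIn("toy")
batches = list(
    iterate_batches(SingleItemDataset(graph), batch_size=1, collate_fn=unwrap_single)
)
```

Under these assumptions the loop yields one batch per epoch, and that batch is the wrapped object itself rather than a list or a merged copy.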

Examples

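>>> from napistu_torch.data.data_utils import create_single_graph_dataloader
>>> # `data` is an existing NapistuData object
>>> dataset = SingleGraphDataset(data)
>>> loader = create_single_graph_dataloader(dataset)
>>> batch = next(iter(loader))
>>> assert isinstance(batch, NapistuData)

napistu_torch.data.data_utils.identity_collate(batch: List[NapistuData]) → NapistuData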

Custom collate function that returns NapistuData unchanged.

For single-graph training, we don’t want PyG’s batching behavior. This function extracts the single NapistuData object from the batch list.

Parameters:

batch (List[NapistuData]) – Batch from DataLoader (should contain exactly 1 NapistuData object)

Returns:

The single NapistuData object

Return type:

NapistuData

Raises:

AssertionError – If batch is not a list with exactly 1 NapistuData object
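The contract described above (unwrap a one-element list, fail otherwise) can be sketched as follows. This is an assumed reading of the documented behavior, not the library's source: GraphStandIn and identity_collate_sketch are hypothetical names, and the exact checks and messages in the real function may differ.

```python
from typing import List


class GraphStandIn:
    """Hypothetical placeholder for NapistuData; not the real class."""

    def __init__(self, name: str) -> None:
        self.name = name


def identity_collate_sketch(batch: List[GraphStandIn]) -> GraphStandIn:
    # Assumed checks, mirroring the documented AssertionError condition:
    # the batch must be a list holding exactly one object of the expected type.
    assert isinstance(batch, list), "batch must be a list"
    assert len(batch) == 1, f"expected exactly 1 item, got {len(batch)}"
    assert isinstance(batch[0], GraphStandIn), "item has the wrong type"
    # Return the object itself; nothing is copied or merged.
    return batch[0]


graph = GraphStandIn("toy")
result = identity_collate_sketch([graph])

# A batch of two items violates the contract and trips the assertion.
try:
    identity_collate_sketch([graph, graph])
    rejected_pair = False
except AssertionError:
    rejected_pair = True
```

In this sketch a wrong batch size fails loudly at collate time, which is the situation that would arise if a loader were configured with a batch_size other than 1.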

Examples

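>>> from torch.utils.data import DataLoader
>>> from napistu_torch.data.data_utils import identity_collate
>>> dataset = SingleGraphDataset(data)
>>> loader = DataLoader(dataset, batch_size=1, collate_fn=identity_collate)
>>> for batch in loader:
...     assert isinstance(batch, NapistuData)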