napistu_torch.tasks.negative_sampler
Classes
- class napistu_torch.tasks.negative_sampler.NegativeSampler(edge_index: torch.Tensor, edge_strata: torch.Tensor, edge_attr: torch.Tensor | None = None, relation_type: torch.Tensor | None = None, sampling_strategy: Literal['uniform', 'degree_weighted'] = 'uniform', oversample_ratio: float = 1.2, max_oversample_ratio: float = 2.0)
Bases: object
Efficient negative edge sampler using vectorized collision detection.
Uses strata-constrained sampling with fast np.isin() for collision detection. Inspired by PyTorch Geometric’s negative_sampling implementation.
- __init__(edge_index: torch.Tensor, edge_strata: torch.Tensor, edge_attr: torch.Tensor | None = None, relation_type: torch.Tensor | None = None, sampling_strategy: Literal['uniform', 'degree_weighted'] = 'uniform', oversample_ratio: float = 1.2, max_oversample_ratio: float = 2.0)
Initialize sampler with vectorized collision detection.
- Parameters:
edge_index (torch.Tensor) – Training edges [2, num_edges]
edge_strata (torch.Tensor) – Strata label for each edge [num_edges]
edge_attr (torch.Tensor, optional) – Edge attributes [num_edges, num_features]. Not needed when message passing uses only positive edges, but may be useful for other tasks.
relation_type (torch.Tensor, optional) – Relation type for each edge [num_edges]. Used by relation-aware heads to sample relation types for negative edges.
sampling_strategy ({'uniform', 'degree_weighted'}) – How to sample nodes within each stratum. Either:
- 'uniform': Sample nodes uniformly within each stratum
- 'degree_weighted': Sample nodes according to their out- and in-degree within each stratum
oversample_ratio (float) – Initial over-sampling factor (1.2 = 20% extra to account for collisions). Will be adaptively increased if needed and maintained across calls.
max_oversample_ratio (float) – Maximum over-sampling factor (caps adaptive increases)
- _build_degree_distributions(edge_index)
Build degree-weighted sampling distributions per stratum.
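As a rough illustration of what a degree-weighted distribution for one stratum could look like, the sketch below counts how often each node appears as a source or destination and normalizes the counts into probabilities. It uses NumPy rather than torch, and `degree_weights` is a hypothetical helper, not part of napistu_torch; the library's internal representation may differ.

```python
import numpy as np

def degree_weights(nodes):
    """Return the unique nodes and sampling probabilities proportional to their counts."""
    unique_nodes, counts = np.unique(nodes, return_counts=True)
    return unique_nodes, counts / counts.sum()

# Toy graph: three edges in stratum 0, one edge in stratum 1 (illustrative data).
edge_index = np.array([[0, 0, 1, 2],
                       [1, 2, 2, 3]])
edge_strata = np.array([0, 0, 0, 1])

in_stratum = edge_strata == 0
# Out-degree weights for source nodes, in-degree weights for destination nodes.
src_nodes, src_probs = degree_weights(edge_index[0, in_stratum])
dst_nodes, dst_probs = degree_weights(edge_index[1, in_stratum])

# Draw candidate endpoints for stratum 0 according to those weights.
rng = np.random.default_rng(0)
sampled_src = rng.choice(src_nodes, size=5, p=src_probs)
sampled_dst = rng.choice(dst_nodes, size=5, p=dst_probs)
```

With 'uniform' sampling, the same draw would simply omit the `p` argument.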
- _build_edge_hash(edge_index)
Build sorted edge index array for vectorized collision detection.
- _build_strata_structure(edge_index, edge_strata)
Extract valid (from_nodes, to_nodes) pairs for each stratum.
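One way to picture the per-stratum structure is a mapping from each stratum label to the set of nodes seen as sources and the set seen as destinations in that stratum. The sketch below is an assumption about the general idea, written with NumPy; `build_strata_structure` here is a hypothetical standalone function and its return format is not taken from the library.

```python
import numpy as np

def build_strata_structure(edge_index, edge_strata):
    """Map each stratum label to (from_nodes, to_nodes) observed in that stratum."""
    structure = {}
    for stratum in np.unique(edge_strata):
        in_stratum = edge_strata == stratum
        from_nodes = np.unique(edge_index[0, in_stratum])
        to_nodes = np.unique(edge_index[1, in_stratum])
        structure[int(stratum)] = (from_nodes, to_nodes)
    return structure

# Illustrative data: two strata with two edges each.
edge_index = np.array([[0, 0, 1, 2],
                       [1, 2, 2, 3]])
edge_strata = np.array([0, 0, 1, 1])
structure = build_strata_structure(edge_index, edge_strata)
```

Candidate negatives for a stratum would then be drawn only from that stratum's from_nodes and to_nodes, which is what keeps sampling strata-constrained.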
- _check_collisions_vectorized(src: torch.Tensor, dst: torch.Tensor) → torch.Tensor
Fast vectorized collision detection using np.isin.
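A common way to make np.isin work on edges is to encode each (src, dst) pair as a single integer and test candidate codes against the codes of existing edges. The sketch below assumes that encoding (src * num_nodes + dst); the function name and the `num_nodes` argument are illustrative and the library's actual hashing scheme may differ.

```python
import numpy as np

def check_collisions(src, dst, existing_edges, num_nodes):
    """Return a boolean mask that is True where (src[i], dst[i]) is an existing edge."""
    # Encode every directed edge as one integer so membership is a 1-D lookup.
    existing_ids = existing_edges[0].astype(np.int64) * num_nodes + existing_edges[1]
    candidate_ids = src.astype(np.int64) * num_nodes + dst
    return np.isin(candidate_ids, existing_ids)

# Illustrative data: four existing edges on a 4-node graph.
edge_index = np.array([[0, 0, 1, 2],
                       [1, 2, 2, 3]])
cand_src = np.array([0, 1, 3, 2])
cand_dst = np.array([1, 0, 2, 3])

mask = check_collisions(cand_src, cand_dst, edge_index, num_nodes=4)
valid_src, valid_dst = cand_src[~mask], cand_dst[~mask]
```

Here the first and last candidates, (0, 1) and (2, 3), collide with existing edges, so only (1, 0) and (3, 2) survive as negatives.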
- _generate_edge_attributes(sampled_strata: torch.Tensor) → torch.Tensor
Generate plausible edge attributes for negative samples.
For each negative edge, sample attributes from a real edge in the same stratum.
- Parameters:
sampled_strata (torch.Tensor) – Strata assignment for each sampled negative edge [num_neg]
- Returns:
Edge attributes [num_neg, num_features]
- Return type:
torch.Tensor
- _generate_relations(sampled_strata: torch.Tensor) → torch.Tensor
Generate relation types for negative samples.
For each negative edge, sample a relation type from a real edge in the same stratum.
- Parameters:
sampled_strata (torch.Tensor) – Strata assignment for each sampled negative edge [num_neg]
- Returns:
Relation types [num_neg]
- Return type:
torch.Tensor
- _sample_and_filter_batch(num_needed: int) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Sample a batch of candidates and filter for valid negatives.
Adaptively increases oversample ratio if collision rate is high. The increased ratio is maintained across future sample() calls.
- Parameters:
num_needed (int) – Number of valid negatives still needed
- Returns:
Valid source and destination nodes
- Return type:
tuple[torch.Tensor, torch.Tensor]
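The oversample-then-filter loop described above can be sketched as follows. This is a simplified illustration, not the library's implementation: it samples uniformly over all nodes instead of within strata, and the growth factor of 1.5 and the `max_rounds` guard are assumptions made for the example. Only the default ratios (1.2 and 2.0) come from the documented signature.

```python
import math
import numpy as np

def sample_negatives(num_needed, num_nodes, existing_ids,
                     ratio=1.2, max_ratio=2.0, seed=0, max_rounds=100):
    """Draw candidates in oversized batches until num_needed non-edges are kept."""
    rng = np.random.default_rng(seed)
    kept = np.empty(0, dtype=np.int64)
    for _ in range(max_rounds):
        remaining = num_needed - kept.size
        if remaining <= 0:
            break
        # Request extra candidates to absorb the ones lost to collisions.
        n_candidates = math.ceil(remaining * ratio)
        src = rng.integers(0, num_nodes, size=n_candidates)
        dst = rng.integers(0, num_nodes, size=n_candidates)
        ids = src * num_nodes + dst
        valid = ids[~np.isin(ids, existing_ids)]
        kept = np.concatenate([kept, valid])
        if valid.size < remaining:
            # Too many collisions: raise the ratio (assumed factor), capped at max_ratio.
            ratio = min(ratio * 1.5, max_ratio)
    kept = kept[:num_needed]
    # Decode the integer codes back into a [2, num_needed] edge array.
    return np.stack([kept // num_nodes, kept % num_nodes]), ratio

# Existing edges (0,1), (0,2), (1,2), (2,3) on a 4-node graph, encoded as integers.
existing_ids = np.array([1, 2, 6, 11])
neg_edges, final_ratio = sample_negatives(10, 4, existing_ids)
```

Returning the final ratio mirrors the documented behavior that an increased ratio is kept for later calls; in a class it would be stored on the instance instead.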
- _sample_candidates(batch_size: int) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Sample candidate edges respecting strata structure.
- _sample_from_strata(sampled_strata: torch.Tensor, source_data: torch.Tensor) → torch.Tensor
Sample data from source_data for each negative edge based on its stratum.
For each negative edge, samples from a real training edge in the same stratum. This utility method factors out the logic shared by edge attribute and relation sampling.
- Parameters:
sampled_strata (torch.Tensor) – Strata assignment for each sampled negative edge [num_neg]
source_data (torch.Tensor) – Source data to sample from [num_edges, …] (can be 1D or 2D)
- Returns:
Sampled data [num_neg, …] with same shape as source_data except first dim
- Return type:
torch.Tensor
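The stratum-matched lookup can be illustrated with a short NumPy sketch: for every negative edge, pick a random training edge from the same stratum and copy its row of source_data. The standalone function below takes `edge_strata` as an explicit argument, which is an assumption made for the example (the documented method receives only sampled_strata and source_data), and it assumes every sampled stratum occurs among the training edges.

```python
import numpy as np

def sample_from_strata(sampled_strata, edge_strata, source_data, seed=0):
    """For each negative edge, copy source_data from a random real edge in its stratum."""
    rng = np.random.default_rng(seed)
    picked = np.empty(sampled_strata.shape[0], dtype=np.int64)
    for stratum in np.unique(sampled_strata):
        pool = np.flatnonzero(edge_strata == stratum)        # real edges in this stratum
        targets = np.flatnonzero(sampled_strata == stratum)  # negatives in this stratum
        picked[targets] = rng.choice(pool, size=targets.size)
    # Fancy indexing keeps trailing dimensions, so 1-D and 2-D source_data both work.
    return source_data[picked]

# Illustrative data: relation types for four training edges in two strata.
edge_strata = np.array([0, 0, 1, 1])
relation_type = np.array([10, 11, 20, 21])
edge_attr = np.arange(8.0).reshape(4, 2)
sampled_strata = np.array([1, 0, 1, 1, 0])

neg_relations = sample_from_strata(sampled_strata, edge_strata, relation_type)
neg_attr = sample_from_strata(sampled_strata, edge_strata, edge_attr)
```

The same helper serves both relation types [num_neg] and edge attributes [num_neg, num_features], which matches the role described for the shared utility.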
- sample(num_neg: int, device: str | None = None, return_edge_attr: bool = False, return_relations: bool = False) → Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]
Sample negative edges with fast vectorized collision detection.
- Parameters:
num_neg (int) – Number of negative edges to sample
device (str or torch.device, optional) – Device to return results on. If None, returns on CPU.
return_edge_attr (bool) – Whether to return edge attributes for the sampled edges. Default is False.
return_relations (bool) – Whether to return relation types for the sampled edges. Default is False.
- Returns:
tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
Tuple containing
- Negative edges [2, num_neg]
- Relation type [num_neg] if return_relations is True, otherwise None
- Edge attributes [num_neg, num_features] if return_edge_attr is True, otherwise None