napistu_torch.tasks.node_classification
Classes
| NodeClassificationTask | Node classification task. |
- class napistu_torch.tasks.node_classification.NodeClassificationTask(*args: Any, **kwargs: Any)
Bases: BaseTask
Node classification task.
Predicts node labels using node features and graph structure.
This class is Lightning-free: it contains pure PyTorch logic and does not depend on PyTorch Lightning.
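The constructor takes an `encoder` and a `head`. The sketch below shows one way the two might compose. It assumes the encoder maps node features `x` and an `edge_index` of shape `[2, num_edges]` to node embeddings, and that the head maps those embeddings to `num_classes` logits per node. `MeanAggEncoder` is a hypothetical stand-in written in plain PyTorch. It is not part of napistu_torch.

```python
import torch
from torch import nn


class MeanAggEncoder(nn.Module):
    """Hypothetical one-layer encoder: self features + mean of incoming neighbours."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.lin = nn.Linear(2 * in_dim, hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # Sum source-node features into each destination node.
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        # Count incoming edges per node; clamp so isolated nodes avoid division by zero.
        deg = torch.zeros(x.size(0), dtype=x.dtype).index_add_(
            0, dst, torch.ones(dst.size(0), dtype=x.dtype)
        )
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        return torch.relu(self.lin(torch.cat([x, agg], dim=-1)))


torch.manual_seed(0)
num_nodes, in_dim, hidden_dim, num_classes = 5, 8, 16, 3
x = torch.randn(num_nodes, in_dim)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # a simple chain 0->1->2->3->4

encoder = MeanAggEncoder(in_dim, hidden_dim)
head = nn.Linear(hidden_dim, num_classes)  # per-node classifier

embeddings = encoder(x, edge_index)  # [num_nodes, hidden_dim]
logits = head(embeddings)            # [num_nodes, num_classes]
print(logits.shape)                  # torch.Size([5, 3])
```

Keeping the encoder and the head as separate modules means the same graph encoder could, in principle, be reused with a different head.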
- __init__(encoder: torch.nn.Module, head: torch.nn.Module, num_classes: int, metrics: List[str] = None)
- _predict_impl(data: NapistuData) → torch.Tensor
Predict class labels for all nodes.
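Predicting class labels for all nodes usually means taking the arg-max over per-node logits. The sketch below assumes that is what `_predict_impl` does, though the source does not confirm it. `FeatureOnlyEncoder` and `predict_node_labels` are hypothetical helpers for illustration only.

```python
import torch
from torch import nn


class FeatureOnlyEncoder(nn.Module):
    """Hypothetical placeholder encoder; a real graph encoder would also use edge_index."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.lin(x))


@torch.no_grad()  # inference only: no autograd graph is built
def predict_node_labels(encoder, head, x, edge_index):
    """Return one predicted class index per node (assumed behaviour)."""
    encoder.eval()
    head.eval()
    logits = head(encoder(x, edge_index))  # [num_nodes, num_classes]
    return logits.argmax(dim=-1)           # [num_nodes], dtype torch.long


torch.manual_seed(0)
x = torch.randn(6, 4)
edge_index = torch.tensor([[0, 1], [1, 0]])
preds = predict_node_labels(FeatureOnlyEncoder(4, 8), nn.Linear(8, 3), x, edge_index)
print(preds.shape)  # torch.Size([6])
```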
- compute_loss(batch: Dict[str, torch.Tensor]) → torch.Tensor
Compute cross-entropy loss for node classification.
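In a transductive setting, the cross-entropy is normally averaged only over the nodes selected by the current split's mask. The sketch below assumes the batch supplies logits, integer labels `y` and a boolean `mask`. The helper name `masked_node_ce_loss` is hypothetical. With all-zero (uniform) logits over 3 classes, the loss equals log(3).

```python
import math

import torch
import torch.nn.functional as F


def masked_node_ce_loss(logits: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy averaged over the nodes selected by a boolean mask."""
    # Only labelled nodes in the current split contribute to the loss.
    return F.cross_entropy(logits[mask], y[mask])


num_classes = 3
logits = torch.zeros(4, num_classes)  # uniform predictions for every node
y = torch.tensor([0, 2, 1, 1])
mask = torch.tensor([True, True, False, True])

loss = masked_node_ce_loss(logits, y, mask)
print(round(loss.item(), 4))  # 1.0986, i.e. log(3)
print(math.isclose(loss.item(), math.log(num_classes), rel_tol=1e-6))  # True

# Changing the logits of a node outside the mask leaves the loss unchanged.
logits[2] = torch.tensor([10.0, -10.0, 0.0])
assert torch.isclose(masked_node_ce_loss(logits, y, mask), loss)
```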
- compute_metrics(data: NapistuData, split: str = 'validation') → Dict[str, float]
Compute classification metrics on the given split (default 'validation').
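Here is a minimal sketch of a per-split metric that returns a `Dict[str, float]`, using accuracy as the example. The split names (`"train"`, `"validation"`), the `masks` dictionary, the `"accuracy"` key and the helper `split_metrics` are all illustrative assumptions. The metric names the real class accepts through its `metrics` argument are not documented here.

```python
import torch
import torch.nn.functional as F


def split_metrics(logits, y, masks, split="validation"):
    """Accuracy on one split; metric and mask names are illustrative assumptions."""
    mask = masks[split]
    preds = logits[mask].argmax(dim=-1)
    correct = (preds == y[mask]).float()
    return {"accuracy": correct.mean().item()}


y = torch.tensor([0, 1, 2, 1])
# One-hot logits make the predictions easy to read: argmax gives [0, 1, 0, 1].
logits = F.one_hot(torch.tensor([0, 1, 0, 1]), num_classes=3).float()
masks = {
    "train": torch.tensor([True, False, False, False]),
    "validation": torch.tensor([False, True, True, True]),
}

print(split_metrics(logits, y, masks))           # validation: 2 of 3 nodes correct
print(split_metrics(logits, y, masks, "train"))  # {'accuracy': 1.0}
```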
- prepare_batch(data: NapistuData, split: str = 'train') → Dict[str, torch.Tensor]
Prepare a batch for node classification.
For transductive learning, this returns the full graph together with the node mask for the requested split.
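In transductive learning, every split sees the same nodes and edges. Only the mask that decides which labels are used changes between splits. A minimal sketch, assuming hypothetical batch keys `"x"`, `"edge_index"`, `"y"` and `"mask"` and a hypothetical `masks` dictionary keyed by split name. The actual keys produced from a `NapistuData` object may differ.

```python
import torch


def prepare_transductive_batch(x, edge_index, y, masks, split="train"):
    """Every split sees the whole graph; only the boolean node mask changes."""
    mask = masks[split].to(torch.bool)
    if mask.shape != (x.size(0),):
        raise ValueError("mask must have one entry per node")
    return {"x": x, "edge_index": edge_index, "y": y, "mask": mask}


x = torch.randn(4, 2)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
y = torch.tensor([0, 1, 1, 0])
masks = {
    "train": torch.tensor([1, 1, 0, 0]),       # integer masks are cast to bool
    "validation": torch.tensor([0, 0, 1, 0]),
}

train_batch = prepare_transductive_batch(x, edge_index, y, masks, "train")
val_batch = prepare_transductive_batch(x, edge_index, y, masks, "validation")
# Both batches share the same feature and edge tensors; no subgraph is cut out.
print(train_batch["x"] is val_batch["x"])                             # True
print(int(train_batch["mask"].sum()), int(val_batch["mask"].sum()))  # 2 1
```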