napistu_torch.tasks.node_classification

Classes

NodeClassificationTask(*args, **kwargs)

Node classification task.

class napistu_torch.tasks.node_classification.NodeClassificationTask(*args: Any, **kwargs: Any)

Bases: BaseTask

Node classification task.

Predicts node labels using node features and graph structure.

This class does not depend on PyTorch Lightning; it contains only plain PyTorch logic.

__init__(encoder: torch.nn.Module, head: torch.nn.Module, num_classes: int, metrics: List[str] = None)

_predict_impl(data: NapistuData) → torch.Tensor

Predict class labels for all nodes.

compute_loss(batch: Dict[str, torch.Tensor]) → torch.Tensor

Compute cross-entropy loss for node classification.
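
To illustrate the kind of computation involved, here is a minimal sketch of cross-entropy restricted to a subset of nodes, written in plain PyTorch. It is not taken from the napistu_torch source: the helper name `masked_cross_entropy` and the toy tensors are hypothetical, and the actual layout of `batch` may differ.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """Mean cross-entropy over the nodes selected by a boolean mask.

    Hypothetical helper for illustration; not part of napistu_torch.
    """
    # Boolean indexing keeps only the rows (nodes) where mask is True.
    return F.cross_entropy(logits[mask], labels[mask])


# Toy graph: 5 nodes, 3 classes (values are arbitrary).
torch.manual_seed(0)
logits = torch.randn(5, 3)                # one row of class scores per node
labels = torch.tensor([0, 2, 1, 0, 2])    # integer class label per node
train_mask = torch.tensor([True, True, False, False, True])

loss = masked_cross_entropy(logits, labels, train_mask)
```

Only the three masked nodes contribute to `loss`; the remaining nodes still exist in the graph but are ignored by the objective.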

compute_metrics(data: NapistuData, split: str = 'validation') → Dict[str, float]

Compute classification metrics.

prepare_batch(data: NapistuData, split: str = 'train') → Dict[str, torch.Tensor]

Prepare batch for node classification.

For transductive learning, returns the full graph together with a mask.
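
The sketch below shows what a transductive batch can look like in general: every node and edge is passed through unchanged, and a boolean mask marks which nodes belong to the requested split. The function `make_transductive_batch`, the dictionary keys, and the assumption that the mask is chosen by `split` are all hypothetical; the real `prepare_batch` operates on a `NapistuData` object whose fields are not documented here.

```python
from typing import Dict

import torch


def make_transductive_batch(
    x: torch.Tensor,
    edge_index: torch.Tensor,
    y: torch.Tensor,
    masks: Dict[str, torch.Tensor],
    split: str = "train",
) -> Dict[str, torch.Tensor]:
    """Bundle the whole graph with the mask for one split.

    Hypothetical helper for illustration; key names are assumptions.
    """
    if split not in masks:
        raise ValueError(f"Unknown split: {split!r}")
    return {
        "x": x,                    # features for ALL nodes
        "edge_index": edge_index,  # ALL edges, shape (2, num_edges)
        "y": y,                    # labels for ALL nodes
        "mask": masks[split],      # which nodes count for this split
    }


# Toy graph: 4 nodes with 2 features each, 3 directed edges.
x = torch.ones(4, 2)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
y = torch.tensor([0, 1, 0, 1])
masks = {
    "train": torch.tensor([True, True, False, False]),
    "validation": torch.tensor([False, False, True, False]),
    "test": torch.tensor([False, False, False, True]),
}

batch = make_transductive_batch(x, edge_index, y, masks, split="train")
```

Because the encoder sees the full graph in every split, message passing can use neighbors outside the split, while the mask limits which nodes are scored.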
