napistu_torch.tasks.base

Base class for all Napistu learning tasks.

Classes

BaseTask(*args, **kwargs)

Base class for all Napistu learning tasks.

class napistu_torch.tasks.base.BaseTask(*args: Any, **kwargs: Any)

Bases: ABC, Module

Base class for all Napistu learning tasks.

This defines the interface that all tasks must implement. It has no Lightning dependency; it is pure PyTorch.

Tasks handle:
  • Data preparation (e.g., negative sampling)

  • Loss computation

  • Evaluation metrics

Training infrastructure (optimizers, schedulers, logging) is handled by the Lightning adapter in napistu_torch.lightning.
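The division of labor above can be sketched as a small abstract base class plus one concrete task. This is a minimal sketch, not napistu_torch's code: `MiniBaseTask` and `ToyRegressionTask` are hypothetical, and a plain dict stands in for `NapistuData`. It shows subclasses supplying the abstract hooks while the base class wires them into `training_step` and `validation_step`, and that the `ABC` base refuses to be instantiated directly.

```python
from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import nn


class MiniBaseTask(ABC, nn.Module):
    """Hypothetical stand-in for BaseTask: ABC + nn.Module, no Lightning."""

    def __init__(self, encoder: nn.Module, head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.head = head

    @abstractmethod
    def prepare_batch(self, data: dict, split: str = "train") -> Dict[str, torch.Tensor]: ...

    @abstractmethod
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: ...

    @abstractmethod
    def compute_metrics(self, data: dict, split: str = "validation") -> Dict[str, float]: ...

    def training_step(self, data: dict) -> torch.Tensor:
        # Data preparation first, then loss computation
        return self.compute_loss(self.prepare_batch(data, split="train"))

    def validation_step(self, data: dict) -> Dict[str, float]:
        return self.compute_metrics(data, split="validation")


class ToyRegressionTask(MiniBaseTask):
    """Hypothetical concrete task: regress a scalar target per row."""

    def prepare_batch(self, data, split="train"):
        return {"x": data["x"], "y": data["y"]}

    def compute_loss(self, batch):
        pred = self.head(self.encoder(batch["x"])).squeeze(-1)
        return nn.functional.mse_loss(pred, batch["y"])

    def compute_metrics(self, data, split="validation"):
        with torch.no_grad():  # metrics do not need gradients
            loss = self.compute_loss(self.prepare_batch(data, split))
        return {f"{split}_mse": loss.item()}
```

Calling `MiniBaseTask(...)` directly raises `TypeError` because the abstract hooks are unimplemented; only subclasses that fill them in can be built.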

Parameters:
  • encoder (nn.Module) – The encoder model.

  • head (nn.Module) – The head model.

Methods:
  • compute_loss(batch: Dict[str, torch.Tensor]) -> torch.Tensor – Compute task-specific loss.

  • compute_metrics(data: NapistuData, split: str = TRAINING.VALIDATION) -> Dict[str, float] – Compute evaluation metrics.

  • forward(x, edge_index, edge_weight=None, edge_attr=None) – Standard forward pass: encode nodes.

  • get_embeddings(x: torch.Tensor, edge_index: torch.Tensor, edge_data: Optional[torch.Tensor] = None) -> torch.Tensor – Get node embeddings from the encoder.

  • get_learned_edge_weights(edge_attr: torch.Tensor) -> torch.Tensor – Compute learned edge weights using the encoder’s edge encoder.

  • get_summary() -> Dict[str, Any] – Get the complete summary dictionary for this task.

  • predict(data: NapistuData) -> torch.Tensor – Make predictions (inference mode).

  • prepare_batch(data: NapistuData, split: str = TRAINING.TRAIN) -> Dict[str, torch.Tensor] – Prepare data batch for this task.

Internal methods:
  • _predict_impl(data: NapistuData) -> torch.Tensor – Implementation of prediction logic.

Lightning interface methods:
  • training_step(data: NapistuData) -> torch.Tensor – Training step, called by the Lightning adapter.

  • validation_step(data: NapistuData) -> Dict[str, float] – Validation step, called by the Lightning adapter.

  • test_step(data: NapistuData) -> Dict[str, float] – Test step, called by the Lightning adapter.

__init__(encoder: torch.nn.Module, head: torch.nn.Module)

abstractmethod _predict_impl(data: NapistuData) → torch.Tensor

Implementation of prediction logic.

abstractmethod compute_loss(batch: Dict[str, torch.Tensor]) → torch.Tensor

Compute task-specific loss.

abstractmethod compute_metrics(data: NapistuData, split: str = 'validation') → Dict[str, float]

Compute evaluation metrics.

Returns a dictionary mapping metric names to values.
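For an edge-prediction task, such a dictionary might be built from the scores of positive and negative edges. This is an illustrative sketch: the helper `edge_metrics`, the metric keys, and the choice of a pairwise AUC plus 0.5-threshold accuracy are assumptions, not what napistu_torch computes. Values are converted to plain Python floats with `.item()` so the dictionary matches the `Dict[str, float]` return type.

```python
from typing import Dict

import torch


def edge_metrics(pos_scores: torch.Tensor, neg_scores: torch.Tensor,
                 split: str = "validation") -> Dict[str, float]:
    """Hypothetical helper: map metric names to plain Python floats."""
    # Pairwise AUC: probability that a positive edge outscores a negative one,
    # with ties counted as half
    diff = pos_scores.unsqueeze(1) - neg_scores.unsqueeze(0)
    auc = (diff > 0).float().mean() + 0.5 * (diff == 0).float().mean()
    # Accuracy at a 0.5 threshold on sigmoid-transformed scores
    preds = torch.cat([pos_scores, neg_scores]).sigmoid() > 0.5
    labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)]).bool()
    acc = (preds == labels).float().mean()
    return {f"{split}_auc": auc.item(), f"{split}_accuracy": acc.item()}
```

For example, positive scores `[2.0, 1.0]` and negative scores `[-1.0, 0.5]` give an AUC of 1.0 (every positive outranks every negative) but an accuracy of 0.75, because the negative edge scored 0.5 lands above the threshold.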

forward(x, edge_index, edge_weight=None, edge_attr=None)

Standard forward pass: encodes nodes.

get_embeddings(x: torch.Tensor, edge_index: torch.Tensor, edge_data: torch.Tensor | None = None) → torch.Tensor

Get node embeddings from the encoder.

Parameters:
  • x (torch.Tensor) – Node features [num_nodes, num_features]

  • edge_index (torch.Tensor) – Edge connectivity [2, num_edges]

  • edge_data (torch.Tensor, optional) – Edge data for weighting (attributes or weights). When using a learned edge encoder, pass edge attributes; when using static weights, pass the weight tensor.

Returns:

Node embeddings [num_nodes, hidden_channels]

Return type:

torch.Tensor
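The shape contract above (`x` is `[num_nodes, num_features]`, `edge_index` is `[2, num_edges]`, the result is `[num_nodes, hidden_channels]`) can be exercised with a toy encoder. `MeanAggEncoder` is hypothetical and is not the library's encoder: it applies one linear projection and then a weighted mean over incoming edges using plain `index_add_`. It treats the optional edge tensor as one weight per edge, which corresponds to the static-weight case described for `edge_data`.

```python
from typing import Optional

import torch
from torch import nn


class MeanAggEncoder(nn.Module):
    """Hypothetical encoder: linear projection + weighted mean over incoming edges."""

    def __init__(self, in_channels: int, hidden_channels: int):
        super().__init__()
        self.lin = nn.Linear(in_channels, hidden_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.lin(x)  # [num_nodes, hidden_channels]
        src, dst = edge_index  # each [num_edges]; messages flow src -> dst
        if edge_weight is None:
            edge_weight = torch.ones(src.numel(), dtype=h.dtype)
        msgs = h[src] * edge_weight.unsqueeze(-1)  # weighted messages
        out = torch.zeros_like(h).index_add_(0, dst, msgs)  # sum per target node
        deg = torch.zeros(h.size(0), dtype=h.dtype).index_add_(0, dst, edge_weight)
        # Normalize to a weighted mean; nodes with no incoming weight keep
        # their own projection
        mean = out / deg.clamp(min=1e-12).unsqueeze(-1)
        return torch.where(deg.unsqueeze(-1) > 0, mean, h)
```

With the path graph `0 -> 1 -> 2 -> 3`, node 0 has no incoming edges and keeps its own projection, while node 1 receives exactly node 0's projection regardless of the edge weight, because a single weighted message divided by its own weight is unchanged.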

get_learned_edge_weights(edge_attr: torch.Tensor) → torch.Tensor

Compute learned edge weights using the encoder’s edge encoder.

Parameters:

edge_attr (torch.Tensor) – Edge attributes used by the learned edge encoder. Shape [num_edges, edge_dim].

Returns:

Learned edge weights in the range [0, 1] with shape [num_edges].

Return type:

torch.Tensor

Raises:

ValueError – If the encoder does not support a learned edge encoder, if its edge encoder is missing, or if edge_attr is not provided.
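A learned edge encoder with this contract (`edge_attr` of shape `[num_edges, edge_dim]` in, one weight in `[0, 1]` per edge out) can be sketched as a linear layer followed by a sigmoid. `SigmoidEdgeEncoder` and `learned_edge_weights` are hypothetical names, and the library's edge encoder architecture may differ; the helper only mirrors the documented `ValueError` cases.

```python
from typing import Optional

import torch
from torch import nn


class SigmoidEdgeEncoder(nn.Module):
    """Hypothetical edge encoder: edge attributes -> one weight per edge."""

    def __init__(self, edge_dim: int):
        super().__init__()
        self.lin = nn.Linear(edge_dim, 1)

    def forward(self, edge_attr: torch.Tensor) -> torch.Tensor:
        # Sigmoid maps each logit into [0, 1]; squeeze gives shape [num_edges]
        return torch.sigmoid(self.lin(edge_attr)).squeeze(-1)


def learned_edge_weights(edge_encoder: Optional[nn.Module],
                         edge_attr: Optional[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper mirroring the documented ValueError cases."""
    if edge_encoder is None:
        raise ValueError("encoder has no learned edge encoder")
    if edge_attr is None:
        raise ValueError("edge_attr must be provided")
    return edge_encoder(edge_attr)
```

A sigmoid is a common way to bound edge weights, since the result can be used directly as a multiplicative message weight without further normalization.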

get_summary() → Dict[str, Any]

Get the complete summary dictionary for this task.

Collects all hyperparameters from the encoder, the edge encoder (if present), and the head so that they can be embedded in the model and used to reconstruct the architecture.

Returns:

Summary dictionary containing:
  • Encoder configuration (from encoder.config())

  • Edge encoder configuration (from edge_encoder.config(to_model_config_names=True), if present)

  • Head configuration (from head.config)

Return type:

Dict[str, Any]
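The three configuration sources can be merged into one dictionary as sketched below. The stub classes, their config values, the top-level keys (`"encoder"`, `"edge_encoder"`, `"head"`), and treating `head.config` as a property are all hypothetical; only the three calls named above come from the description, and napistu_torch may merge the results differently (for example, by flattening them into a single level).

```python
from typing import Any, Dict, Optional


class StubEncoder:
    def config(self) -> Dict[str, Any]:
        return {"encoder_type": "sage", "hidden_channels": 64}  # hypothetical values


class StubEdgeEncoder:
    def config(self, to_model_config_names: bool = False) -> Dict[str, Any]:
        # Hypothetical: optionally rename keys to model-config style names
        key = "edge_encoder_dim" if to_model_config_names else "dim"
        return {key: 8}


class StubHead:
    @property
    def config(self) -> Dict[str, Any]:  # assumed to be a property, not a method
        return {"head_type": "dot_product"}


def build_summary(encoder: Any, head: Any,
                  edge_encoder: Optional[Any] = None) -> Dict[str, Any]:
    """Collect the hyperparameters needed to rebuild the architecture."""
    summary: Dict[str, Any] = {"encoder": encoder.config()}
    if edge_encoder is not None:  # the edge encoder is optional
        summary["edge_encoder"] = edge_encoder.config(to_model_config_names=True)
    summary["head"] = head.config
    return summary
```

Keeping each component's configuration under its own key makes it straightforward to rebuild each component from its section when the model is reloaded.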

predict(data: NapistuData) → torch.Tensor

Make predictions (inference mode).

This method can be used without Lightning for production inference.
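One common way to make a public prediction method safe to call outside Lightning is to switch the module to eval mode and disable gradient tracking around an abstract hook. This is a sketch of that pattern, assuming `predict` delegates to `_predict_impl` (the method names suggest this, but the text does not state it); `PredictSketch` and `TinyTask` are hypothetical, and a plain tensor stands in for `NapistuData`.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class PredictSketch(ABC, nn.Module):
    """Hypothetical: public predict() wraps an abstract _predict_impl()."""

    @abstractmethod
    def _predict_impl(self, data: torch.Tensor) -> torch.Tensor: ...

    def predict(self, data: torch.Tensor) -> torch.Tensor:
        was_training = self.training
        self.eval()  # disable dropout, use running batch-norm statistics
        try:
            with torch.no_grad():  # build no autograd graph during inference
                return self._predict_impl(data)
        finally:
            self.train(was_training)  # restore the caller's mode


class TinyTask(PredictSketch):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 1)
        self.drop = nn.Dropout(p=0.5)

    def _predict_impl(self, data: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.lin(self.drop(data))).squeeze(-1)
```

Restoring the previous mode in `finally` means calling `predict` in the middle of training does not silently leave dropout disabled.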

abstractmethod prepare_batch(data: NapistuData, split: str = 'train') → Dict[str, torch.Tensor]

Prepare data batch for this task.

Applies task-specific data transformations (e.g., negative sampling for edge prediction, masking for node classification).
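For edge prediction, the simplest form of negative sampling draws random node pairs as non-edges. The sketch below is a hypothetical `prepare_batch`-style helper that works on raw tensors rather than `NapistuData`; the function name and the dictionary keys are assumptions. Uniform sampling can occasionally draw a pair that is actually an edge; implementations often filter these out, as PyTorch Geometric's `negative_sampling` utility does.

```python
from typing import Dict, Optional

import torch


def prepare_edge_batch(edge_index: torch.Tensor, num_nodes: int,
                       num_neg_per_pos: int = 1,
                       generator: Optional[torch.Generator] = None) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: pair observed edges with uniformly sampled negatives."""
    num_pos = edge_index.size(1)
    num_neg = num_pos * num_neg_per_pos
    # Draw source and destination nodes independently and uniformly at random
    neg_edge_index = torch.randint(0, num_nodes, (2, num_neg), generator=generator)
    edges = torch.cat([edge_index, neg_edge_index], dim=1)
    # 1 for observed edges, 0 for sampled negatives
    labels = torch.cat([torch.ones(num_pos), torch.zeros(num_neg)])
    return {"pos_edges": edge_index, "neg_edges": neg_edge_index,
            "edges": edges, "labels": labels}
```

Passing a seeded `torch.Generator` makes the sampled negatives reproducible, which helps when the validation split should see the same negatives every epoch.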

test_step(data: NapistuData) → Dict[str, float]

Test step, called by the Lightning adapter.

training_step(data: NapistuData) → torch.Tensor

Training step, called by the Lightning adapter.

This is the interface Lightning expects.

validation_step(data: NapistuData) → Dict[str, float]

Validation step, called by the Lightning adapter.
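Because the step methods take only data and return either a loss tensor or a metrics dictionary, a plain PyTorch loop can also drive them, playing roughly the role of the Lightning adapter. `LinearTask` and `fit` are hypothetical; the SGD optimizer, the learning rate, and the `(x, y)` tuple standing in for `NapistuData` are assumptions made to keep the example self-contained.

```python
from typing import Dict, Tuple

import torch
from torch import nn

Data = Tuple[torch.Tensor, torch.Tensor]  # stand-in for NapistuData: (x, y)


class LinearTask(nn.Module):
    """Hypothetical task exposing the step interface."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(2, 1)

    def training_step(self, data: Data) -> torch.Tensor:
        x, y = data
        return nn.functional.mse_loss(self.model(x).squeeze(-1), y)

    def validation_step(self, data: Data) -> Dict[str, float]:
        with torch.no_grad():
            return {"val_loss": self.training_step(data).item()}

    def test_step(self, data: Data) -> Dict[str, float]:
        with torch.no_grad():
            return {"test_loss": self.training_step(data).item()}


def fit(task: LinearTask, train_data: Data, val_data: Data,
        epochs: int = 200, lr: float = 0.1) -> Dict[str, float]:
    """Minimal full-batch loop: the optimizer lives outside the task."""
    opt = torch.optim.SGD(task.parameters(), lr=lr)
    for _ in range(epochs):
        task.train()
        opt.zero_grad()
        loss = task.training_step(train_data)  # task supplies only the loss
        loss.backward()
        opt.step()
    task.eval()
    return task.validation_step(val_data)
```

Keeping optimizers, schedulers, and logging out of the task class is what allows the same task to be wrapped by Lightning or by a hand-written loop like this one.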
