napistu_torch.tasks.base
Base class for all Napistu learning tasks.
Classes
- BaseTask(*args, **kwargs) – Base class for all Napistu learning tasks.
- class napistu_torch.tasks.base.BaseTask(*args: Any, **kwargs: Any)
Bases: ABC, torch.nn.Module
Base class for all Napistu learning tasks.
This defines the interface that all tasks must implement. It has no Lightning dependency; it is pure PyTorch.
Tasks handle:
  - Data preparation (e.g., negative sampling)
  - Loss computation
  - Evaluation metrics
Training infrastructure (optimizers, schedulers, logging) is handled by the Lightning adapter in napistu_torch.lightning.
- Parameters:
encoder (nn.Module) – The encoder model.
head (nn.Module) – The head model.
Methods
-------
compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor – Compute task-specific loss.
compute_metrics(self, data: NapistuData, split: str = TRAINING.VALIDATION) -> Dict[str, float] – Compute evaluation metrics.
forward(self, x, edge_index, edge_weight=None, edge_attr=None) – Standard forward pass: encode nodes.
get_embeddings(self, x: torch.Tensor, edge_index: torch.Tensor, edge_data: Optional[torch.Tensor] = None) -> torch.Tensor – Get node embeddings from the encoder.
get_learned_edge_weights(self, edge_attr: torch.Tensor) -> torch.Tensor – Compute learned edge weights using the encoder’s edge encoder.
get_summary(self) -> Dict[str, Any] – Get the complete summary dictionary for this task.
predict(self, data: NapistuData) -> torch.Tensor – Make predictions (inference mode).
prepare_batch(self, data: NapistuData, split: str = TRAINING.TRAIN) -> Dict[str, torch.Tensor] – Prepare data batch for this task.

Private methods
---------------
_predict_impl(self, data: NapistuData) -> torch.Tensor – Implementation of prediction logic.

Methods called by the Lightning adapter
---------------------------------------
training_step(self, data: NapistuData) -> torch.Tensor – Training step.
validation_step(self, data: NapistuData) -> Dict[str, float] – Validation step.
test_step(self, data: NapistuData) -> Dict[str, float] – Test step.
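The split of responsibilities described above can be illustrated in plain PyTorch. The sketch below does not import napistu_torch: SketchTask is a hypothetical stand-in that mirrors only prepare_batch, compute_loss, and training_step from the documented interface, a plain dict replaces NapistuData, and the way training_step composes the two abstract methods is an assumption rather than the library's confirmed behavior.

```python
from abc import ABC, abstractmethod
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn


class SketchTask(ABC, nn.Module):
    """Hypothetical stand-in for BaseTask; mirrors only part of the documented interface."""

    def __init__(self, encoder: nn.Module, head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.head = head

    @abstractmethod
    def prepare_batch(self, data, split: str = "train") -> Dict[str, torch.Tensor]:
        """Task-specific data preparation."""

    @abstractmethod
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Task-specific loss."""

    def training_step(self, data) -> torch.Tensor:
        # Assumed composition: prepare the batch, then compute the loss on it.
        return self.compute_loss(self.prepare_batch(data, split="train"))


class ToyRegressionTask(SketchTask):
    """Hypothetical node-regression task; `data` is a dict with "x" and "y" tensors."""

    def prepare_batch(self, data, split: str = "train") -> Dict[str, torch.Tensor]:
        return {"x": data["x"], "y": data["y"]}

    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        pred = self.head(self.encoder(batch["x"])).squeeze(-1)
        return F.mse_loss(pred, batch["y"])


torch.manual_seed(0)
task = ToyRegressionTask(encoder=nn.Linear(4, 8), head=nn.Linear(8, 1))
data = {"x": torch.randn(32, 4), "y": torch.randn(32)}

# The optimizer lives outside the task, as it would in the Lightning adapter.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-2)
initial_loss = task.training_step(data).item()
for _ in range(100):
    optimizer.zero_grad()
    loss = task.training_step(data)
    loss.backward()
    optimizer.step()
final_loss = task.training_step(data).item()
```

Keeping the optimization loop outside the task is what lets the same task object be driven either by a hand-written loop like this one or by the Lightning adapter.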
- __init__(encoder: torch.nn.Module, head: torch.nn.Module)
- abstractmethod _predict_impl(data: NapistuData) → torch.Tensor
Implementation of prediction logic.
- abstractmethod compute_loss(batch: Dict[str, torch.Tensor]) → torch.Tensor
Compute task-specific loss.
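For an edge-prediction task, a typical compute_loss is binary cross-entropy over positive edges and sampled negatives. This minimal sketch assumes the batch carries precomputed logits under the hypothetical keys pos_logits and neg_logits; the loss actually used by a given Napistu task may differ.

```python
from typing import Dict

import torch
import torch.nn.functional as F


def edge_prediction_loss(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """BCE over positive (label 1) and negative (label 0) edge logits."""
    pos, neg = batch["pos_logits"], batch["neg_logits"]  # hypothetical keys
    logits = torch.cat([pos, neg])
    labels = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)])
    return F.binary_cross_entropy_with_logits(logits, labels)


# Logits of 0 give p = 0.5 for every edge, so the loss equals ln 2.
uninformed = edge_prediction_loss({"pos_logits": torch.zeros(3), "neg_logits": torch.zeros(5)})
# Confident, correct logits drive the loss toward 0.
confident = edge_prediction_loss(
    {"pos_logits": torch.full((3,), 10.0), "neg_logits": torch.full((5,), -10.0)}
)
```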
- abstractmethod compute_metrics(data: NapistuData, split: str = 'validation') → Dict[str, float]
Compute evaluation metrics.
Returns a dictionary mapping each metric name to its value.
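compute_metrics returns plain Python floats keyed by metric name. As a hedged illustration for edge prediction, the hypothetical helper below computes a pairwise AUC and a thresholded accuracy from positive and negative scores; the split-prefixed key names are assumptions, not names the library is known to emit.

```python
from typing import Dict

import torch


def edge_metrics(
    pos_scores: torch.Tensor, neg_scores: torch.Tensor, split: str = "validation"
) -> Dict[str, float]:
    """Return metric_name -> value for scored positive and negative edges."""
    # AUC as the probability that a random positive outranks a random negative (ties count half).
    diff = pos_scores[:, None] - neg_scores[None, :]
    auc = ((diff > 0).float() + 0.5 * (diff == 0).float()).mean().item()
    # Accuracy after thresholding sigmoid probabilities at 0.5.
    preds = torch.cat([pos_scores, neg_scores]).sigmoid() > 0.5
    labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)]).bool()
    accuracy = (preds == labels).float().mean().item()
    return {f"{split}_auc": auc, f"{split}_accuracy": accuracy}


metrics = edge_metrics(torch.tensor([2.0, 1.0]), torch.tensor([-1.0, 3.0]))
# {'validation_auc': 0.5, 'validation_accuracy': 0.75}
```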
- forward(x, edge_index, edge_weight=None, edge_attr=None)
Standard forward pass: encode nodes.
- get_embeddings(x: torch.Tensor, edge_index: torch.Tensor, edge_data: torch.Tensor | None = None) → torch.Tensor
Get node embeddings from the encoder.
- Parameters:
x (torch.Tensor) – Node features [num_nodes, num_features]
edge_index (torch.Tensor) – Edge connectivity [2, num_edges]
edge_data (torch.Tensor, optional) – Edge data for weighting (attributes or weights). When using a learned edge encoder, pass edge attributes; when using static weights, pass the weight tensor.
- Returns:
Node embeddings [num_nodes, hidden_channels]
- Return type:
torch.Tensor
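The edge_data parameter above carries either edge attributes (for a learned edge encoder) or precomputed weights. The hypothetical SketchEncoder below shows one way an encoder could branch on that distinction inside a single weighted-sum message-passing layer; it is not the Napistu encoder, and the attribute name edge_encoder is an assumption.

```python
from typing import Optional

import torch
from torch import nn


class SketchEncoder(nn.Module):
    """Hypothetical one-layer encoder showing two ways to read `edge_data`."""

    def __init__(self, in_dim: int, hidden: int, edge_dim: Optional[int] = None):
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden)
        # Optional learned edge encoder mapping edge attributes to weights in [0, 1].
        self.edge_encoder = (
            nn.Sequential(nn.Linear(edge_dim, 1), nn.Sigmoid()) if edge_dim else None
        )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_data: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        src, dst = edge_index
        if edge_data is None:
            weight = torch.ones(src.numel(), device=x.device)
        elif self.edge_encoder is not None:
            weight = self.edge_encoder(edge_data).squeeze(-1)  # attributes -> weights
        else:
            weight = edge_data  # already per-edge weights
        h = self.lin(x)
        # Sum weighted messages from source nodes into their destination nodes.
        return torch.zeros_like(h).index_add_(0, dst, h[src] * weight[:, None])


torch.manual_seed(0)
x = torch.randn(3, 4)  # [num_nodes, num_features]
edge_index = torch.tensor([[0, 1], [1, 2]])  # edges 0->1 and 1->2
static = SketchEncoder(4, 8)(x, edge_index, edge_data=torch.tensor([0.5, 1.0]))
learned = SketchEncoder(4, 8, edge_dim=2)(x, edge_index, edge_data=torch.randn(2, 2))
```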
- get_learned_edge_weights(edge_attr: torch.Tensor) → torch.Tensor
Compute learned edge weights using the encoder’s edge encoder.
- Parameters:
edge_attr (torch.Tensor) – Edge attributes used by the learned edge encoder. Shape [num_edges, edge_dim].
- Returns:
Learned edge weights in the range [0, 1] with shape [num_edges].
- Return type:
torch.Tensor
- Raises:
ValueError – If the encoder does not support a learned edge encoder, if its edge encoder is missing, or if edge_attr is not provided.
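A minimal sketch of the documented contract for get_learned_edge_weights: validate the inputs, then squash per-edge scores into [0, 1] with shape [num_edges]. The attribute name edge_encoder and the use of a sigmoid are assumptions chosen to satisfy the documented range and shape, not confirmed details of the implementation.

```python
from typing import Optional

import torch
from torch import nn


def learned_edge_weights(encoder: nn.Module, edge_attr: Optional[torch.Tensor]) -> torch.Tensor:
    """Map edge attributes [num_edges, edge_dim] to weights in [0, 1] with shape [num_edges]."""
    # `edge_encoder` is a hypothetical attribute name chosen for illustration.
    edge_encoder = getattr(encoder, "edge_encoder", None)
    if edge_encoder is None:
        raise ValueError("Encoder does not provide a learned edge encoder.")
    if edge_attr is None:
        raise ValueError("edge_attr is required to compute learned edge weights.")
    # A sigmoid squashes each raw score into [0, 1]; squeeze drops the trailing dim.
    return torch.sigmoid(edge_encoder(edge_attr)).squeeze(-1)


encoder = nn.Module()
encoder.edge_encoder = nn.Linear(3, 1)  # one raw score per edge
weights = learned_edge_weights(encoder, torch.randn(5, 3))
```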
- get_summary() → Dict[str, Any]
Get the complete summary dictionary for this task.
Collects all hyperparameters from the encoder, the edge encoder (if present), and the head so they can be embedded in the model and used to reconstruct the architecture.
- Returns:
Summary dictionary containing:
  - Encoder configuration (from encoder.config())
  - Edge encoder configuration (from edge_encoder.config(to_model_config_names=True), if present)
  - Head configuration (from head.config)
- Return type:
Dict[str, Any]
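The summary described above bundles component configurations so the architecture can be rebuilt later. The hypothetical build_summary below assumes a flat merge of plain dicts and refuses duplicate keys so that no component silently overwrites another; the real layout of the summary, and all example key names and values, are assumptions.

```python
import json
from typing import Any, Dict, Optional


def build_summary(
    encoder_config: Dict[str, Any],
    head_config: Dict[str, Any],
    edge_encoder_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Merge component configs into one flat dict (hypothetical layout)."""
    summary: Dict[str, Any] = {}
    for part in (encoder_config, edge_encoder_config or {}, head_config):
        overlap = summary.keys() & part.keys()
        if overlap:
            # Distinct key names keep the merge lossless.
            raise KeyError(f"Duplicate config keys: {sorted(overlap)}")
        summary.update(part)
    return summary


summary = build_summary(
    encoder_config={"encoder_type": "sage", "hidden_channels": 128},
    head_config={"head_type": "dot_product"},
    edge_encoder_config={"edge_encoder_dim": 4},
)
# A JSON-serializable summary can be stored alongside the weights and read back.
restored = json.loads(json.dumps(summary))
```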
- predict(data: NapistuData) → torch.Tensor
Make predictions (inference mode).
This can be used without Lightning for production or inference.
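Because predict is meant to run without Lightning, it has to manage inference mode itself. A plausible shape for the predict/_predict_impl split is sketched below with a hypothetical PredictSketch class that takes a tensor instead of NapistuData; switching to eval mode, disabling gradients, and restoring the previous mode are assumptions about what inference mode entails here.

```python
import torch
from torch import nn


class PredictSketch(nn.Module):
    """Hypothetical task showing predict() wrapping an overridable _predict_impl()."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def _predict_impl(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.model(x)).squeeze(-1)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        was_training = self.training
        self.eval()  # disable dropout and use running batch-norm statistics
        try:
            with torch.no_grad():  # build no autograd graph during inference
                return self._predict_impl(x)
        finally:
            self.train(was_training)  # restore the caller's mode


torch.manual_seed(0)
task = PredictSketch(nn.Sequential(nn.Dropout(p=0.5), nn.Linear(3, 1)))
task.train()
x = torch.randn(4, 3)
first, second = task.predict(x), task.predict(x)
```

With dropout disabled, repeated calls on the same input are deterministic, and the module is returned to training mode afterwards.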
- abstractmethod prepare_batch(data: NapistuData, split: str = 'train') → Dict[str, torch.Tensor]
Prepare data batch for this task.
Task-specific data transformations (e.g., negative sampling for edge prediction, masking for node classification).
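Negative sampling, named above as an example of task-specific preparation, can be sketched with uniform rejection sampling: draw random node pairs and discard any that are true edges. This is one simple technique, not necessarily the sampler Napistu uses; the function names and batch keys are hypothetical.

```python
from typing import Dict, Optional

import torch


def sample_negative_edges(
    edge_index: torch.Tensor,
    num_nodes: int,
    num_samples: int,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Uniformly sample node pairs that are not positive edges; returns [2, num_samples]."""
    # Encode each (src, dst) pair as one integer so membership tests are vectorized.
    pos_ids = edge_index[0] * num_nodes + edge_index[1]
    kept = torch.empty(0, dtype=torch.long)
    # Rejection sampling: keep drawing until enough non-edges have been collected.
    while kept.numel() < num_samples:
        cand = torch.randint(0, num_nodes * num_nodes, (2 * num_samples,), generator=generator)
        kept = torch.cat([kept, cand[~torch.isin(cand, pos_ids)]])
    kept = kept[:num_samples]
    src = torch.div(kept, num_nodes, rounding_mode="floor")
    return torch.stack([src, kept % num_nodes])


def prepare_edge_batch(
    x: torch.Tensor,
    edge_index: torch.Tensor,
    num_nodes: int,
    generator: Optional[torch.Generator] = None,
) -> Dict[str, torch.Tensor]:
    # Hypothetical batch keys; one negative is drawn per positive edge.
    neg = sample_negative_edges(edge_index, num_nodes, edge_index.size(1), generator)
    return {"x": x, "pos_edge_index": edge_index, "neg_edge_index": neg}


gen = torch.Generator().manual_seed(0)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
batch = prepare_edge_batch(torch.randn(4, 8), edge_index, num_nodes=4, generator=gen)
```

This sketch may return self-loops and repeated pairs, and its rejection loop would never finish on a complete graph; a production sampler would need to guard against both.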
- test_step(data: NapistuData) → Dict[str, float]
Test step, called by the Lightning adapter.
- training_step(data: NapistuData) → torch.Tensor
Training step, called by the Lightning adapter.
This is the interface Lightning expects.
- validation_step(data: NapistuData) → Dict[str, float]
Validation step, called by the Lightning adapter.