napistu_torch.napistu_data

NapistuData - A PyTorch Geometric Data subclass for Napistu networks.

This module provides a PyTorch Geometric Data subclass with Napistu-specific functionality including safe save/load methods and additional utilities.

Classes

NapistuData

A PyTorch Geometric Data subclass for Napistu biological networks.

class napistu_torch.napistu_data.NapistuData(*args: Any, **kwargs: Any)

Bases: Data

A PyTorch Geometric Data subclass for Napistu biological networks.

This class extends PyG’s Data class with Napistu-specific functionality including safe save/load methods and additional utilities for working with biological network data.

Parameters:
  • x (torch.Tensor) – Node feature matrix with shape [num_nodes, num_node_features]

  • edge_index (torch.Tensor) – Graph connectivity in COO format with shape [2, num_edges]

  • edge_attr (torch.Tensor) – Edge feature matrix with shape [num_edges, num_edge_features]

  • name (str, default=NAPISTU_DATA_DEFAULT_NAME) – Name of the NapistuData object. Used for summaries and for organizing objects in the NapistuDataStore.

  • edge_weight (torch.Tensor, optional) – Edge weights tensor with shape [num_edges]

  • y (torch.Tensor, optional) – Node labels tensor with shape [num_nodes] for supervised learning tasks

  • vertex_feature_names (List[str], optional) – Names of vertex features for interpretability

  • edge_feature_names (List[str], optional) – Names of edge features for interpretability

  • vertex_feature_name_aliases (Dict[str, str], optional) – Mapping from vertex feature names to their canonical names (for deduplicated features)

  • edge_feature_name_aliases (Dict[str, str], optional) – Mapping from edge feature names to their canonical names (for deduplicated features)

  • ng_vertex_names (pd.Series, optional) – Minimal vertex names from the original NapistuGraph. The Series is aligned with the vertex tensor (x): each element corresponds to the vertex in the same row of the tensor. Used for debugging and for validating tensor alignment.

  • ng_edge_names (pd.DataFrame, optional) – Minimal edge names from the original NapistuGraph. The DataFrame has ‘from’ and ‘to’ columns aligned with the edge tensors (edge_index, edge_attr): each row corresponds to the edge at the same position along the edge dimension. Used for debugging and validation.

  • splitting_strategy (str, optional) – Strategy used to split the data into train/test/val sets. The split itself happens upstream; the strategy is recorded here for reference.

  • labeling_manager (LabelingManager, optional) – Labeling manager used to encode the labels. It is used to decode the labels back to their original values for validation.

  • relation_type (torch.Tensor, optional) – Relation type tensor

  • relation_manager (LabelingManager, optional) – Relation manager used to encode relation types

  • **kwargs – Additional attributes to store in the data object

Methods

  • copy() – Create a deep copy of the NapistuData object

  • estimate_memory_footprint() – Estimate the memory footprint of the NapistuData object

  • get_edge_feature_names() – Get the names of edge features

  • get_edge_indices(df, from_col, to_col) – Get an edge index tensor from a DataFrame with vertex names

  • get_edge_names() – Get the edge names from the original NapistuGraph

  • get_edge_weights() – Get edge weights as a 1D tensor

  • get_feature_by_name(feature_name) – Get a feature by name from the NapistuData object

  • get_features_by_regex(regex, return_suffixes=False) – Get features whose names match a regex

  • get_num_relations() – Get the number of relation types

  • get_summary(summary_type="basic") – Get a summary of the NapistuData object

  • get_symmetrical_relation_indices(treat_asymmetrically={"other relation"}) – Get the indices of symmetric relation types

  • get_vertex_feature_names() – Get the names of vertex features

  • get_vertex_indices(vertex_names) – Get the indices of vertices by their names

  • get_vertex_names() – Get the vertex names from the original NapistuGraph

  • has_edges(edge_indices) – Check which edges in edge_indices are present in this NapistuData

  • load(filepath, map_location="cpu") – Load a NapistuData object from disk

  • reverse_edges(inplace=True) – Reverse all edges by swapping source and target

  • save(filepath) – Save the NapistuData object to disk

  • show_memory_footprint() – Display the memory footprint of the NapistuData object

  • show_summary() – Display a summary of the NapistuData object

  • trim(keep_edge_attr=True, keep_labels=True, keep_masks=True, keep_relation_type=True, inplace=False) – Trim the NapistuData object to keep only the specified attributes

  • unencode_features(napistu_graph, attribute_type, attribute, encoding_manager=None) – Unencode features from the NapistuData object

  • validate_graph_alignment(napistu_graph) – Validate the alignment of the NapistuData object with the NapistuGraph

Private Methods

  • _validate_edge_encoding(napistu_graph, edge_attribute, encoding_manager=None) – Validate the edge encoding of the NapistuData object

  • _validate_labels(napistu_graph, labeling_manager) – Validate the labels of the NapistuData object

  • _validate_vertex_encoding(napistu_graph, vertex_attribute, encoding_manager=None) – Validate the vertex encoding of the NapistuData object

Examples

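>>> import pandas as pd
>>> import torch
>>> from napistu_torch.napistu_data import NapistuData
>>>
>>> # Create a NapistuData object (x, edge_index, and edge_attr are required)
>>> num_nodes, num_edges = 100, 200
>>> edge_index = torch.randint(0, num_nodes, (2, num_edges))
>>> # Hypothetical vertex names, one per row of x
>>> vertex_names = pd.Series([f"vertex_{i}" for i in range(num_nodes)])
>>> # 'from'/'to' names, one row per column of edge_index
>>> edge_names = pd.DataFrame({
...     "from": vertex_names.values[edge_index[0].numpy()],
...     "to": vertex_names.values[edge_index[1].numpy()],
... })
>>> data = NapistuData(
...     x=torch.randn(num_nodes, 10),                # Required: node features
...     edge_index=edge_index,                       # Required: graph connectivity
...     edge_attr=torch.randn(num_edges, 5),         # Required: edge features
...     y=torch.randint(0, 3, (num_nodes,)),         # Optional: node labels
...     vertex_feature_names=[f"feature_{i}" for i in range(10)],    # Optional: one per column of x
...     edge_feature_names=[f"edge_feature_{i}" for i in range(5)],  # Optional: one per column of edge_attr
...     ng_vertex_names=vertex_names,                # Optional: minimal vertex names
...     ng_edge_names=edge_names,                    # Optional: minimal edge names
... )
>>>
>>> # Save and load
>>> data.save('my_network.pt')
>>> loaded_data = NapistuData.load('my_network.pt')

classmethod load(filepath: str | Path, map_location: str = 'cpu') NapistuData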

Load a NapistuData object from disk.

This method automatically uses weights_only=False to ensure compatibility with PyG Data objects, which contain custom classes that aren’t allowed with the default weights_only=True setting in PyTorch 2.6+.

Parameters:
  • filepath (Union[str, Path]) – Path to the saved data object

  • map_location (str, default='cpu') – Device to map tensors to (e.g., ‘cpu’, ‘cuda:0’). Defaults to ‘cpu’ for universal compatibility.

Returns:

The loaded NapistuData object

Return type:

NapistuData

Raises:
  • FileNotFoundError – If the file doesn’t exist

  • RuntimeError – If loading fails

  • TypeError – If the loaded object is not a NapistuData or Data object

Examples

>>> data = NapistuData.load('my_network.pt')  # Loads to CPU by default
>>> data = NapistuData.load('my_network.pt', map_location='cuda:0')  # Load to GPU

Notes

This method always loads with weights_only=False because PyG Data objects contain custom classes that aren’t allowed with weights_only=True. Only use it with trusted files: loading an untrusted file can result in arbitrary code execution.

__init__(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, name: str = 'default', edge_weight: torch.Tensor | None = None, y: torch.Tensor | None = None, vertex_feature_names: List[str] | None = None, edge_feature_names: List[str] | None = None, vertex_feature_name_aliases: Dict[str, str] | None = None, edge_feature_name_aliases: Dict[str, str] | None = None, ng_vertex_names: Series | None = None, ng_edge_names: DataFrame | None = None, splitting_strategy: str | None = None, labeling_manager: LabelingManager | None = None, relation_type: torch.Tensor | None = None, relation_manager: LabelingManager | None = None, **kwargs)

_validate_edge_encoding(napistu_graph: NapistuGraph, edge_attribute: str, encoding_manager: EncodingManager | None = None) bool

Validate consistency between encoded values and original NapistuGraph edge values.

This method compares the edge values recovered from encoding in the NapistuData object with the original edge values stored in the NapistuGraph object to ensure data consistency.

Parameters:
  • napistu_graph (NapistuGraph) – The NapistuGraph object containing the original edge values

  • edge_attribute (str) – The name of the edge attribute to validate (e.g., ‘r_irreversible’)

  • encoding_manager (Optional[EncodingManager]) – The encoding manager to use to unencode the features. If this is not provided then the default encoding managers will be used.

Returns:

True if the encoding is consistent, False otherwise

Return type:

bool

Raises:

ValueError – If the edge attribute is not found in the NapistuGraph, if edge names don’t match between NapistuData and NapistuGraph, or if there are encoding inconsistencies.

Examples

>>> # Validate r_irreversible encoding consistency
>>> is_consistent = napistu_data._validate_edge_encoding(napistu_graph, 'r_irreversible')
>>> print(f"Encoding is consistent: {is_consistent}")
Encoding is consistent: True
>>> # Validate a different edge attribute
>>> is_consistent = napistu_data._validate_edge_encoding(napistu_graph, 'weight')
>>> print(f"Weight encoding is consistent: {is_consistent}")
Weight encoding is consistent: True

_validate_labels(napistu_graph: NapistuGraph, labeling_manager: LabelingManager) bool

Validate consistency between encoded labels and original NapistuGraph vertex labels.

This method compares the labels recovered from encoding in the NapistuData object with the original labels stored in the NapistuGraph object to ensure data consistency.

Parameters:
  • napistu_graph (NapistuGraph) – The NapistuGraph object containing the original vertex labels

  • labeling_manager (LabelingManager) – The labeling manager used to decode the encoded labels

Returns:

True if the label encoding is consistent, False otherwise

Return type:

bool

Raises:

ValueError – If the NapistuData object doesn’t have encoded labels (y attribute), if vertex names don’t match between NapistuData and NapistuGraph, or if there are label encoding inconsistencies.

Examples

>>> # Validate label encoding consistency
>>> is_consistent = napistu_data._validate_labels(napistu_graph, labeling_manager)
>>> print(f"Label encoding is consistent: {is_consistent}")
Label encoding is consistent: True

_validate_vertex_encoding(napistu_graph: NapistuGraph, vertex_attribute: str, encoding_manager: EncodingManager | None = None) bool

Validate consistency between encoded values and original NapistuGraph vertex values.

This method compares the vertex values recovered from encoding in the NapistuData object with the original vertex values stored in the NapistuGraph object to ensure data consistency.

Parameters:
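  • napistu_graph (NapistuGraph) – The NapistuGraph object containing the original vertex values

  • vertex_attribute (str) – The name of the vertex attribute to validate (e.g., ‘node_type’)

  • encoding_manager (Optional[EncodingManager]) – The encoding manager to use to unencode the features. If this is not provided then the default encoding managers will be used.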

Returns:

True if the encoding is consistent, False otherwise

Return type:

bool

Raises:

ValueError – If the vertex attribute is not found in the NapistuGraph, if vertex names don’t match between NapistuData and NapistuGraph, or if there are encoding inconsistencies.

Examples

>>> # Validate node_type encoding consistency
>>> is_consistent = napistu_data._validate_vertex_encoding(napistu_graph, 'node_type')
>>> print(f"Encoding is consistent: {is_consistent}")
Encoding is consistent: True
>>> # Validate a different categorical attribute
>>> is_consistent = napistu_data._validate_vertex_encoding(napistu_graph, 'species_type')
>>> print(f"Species type encoding is consistent: {is_consistent}")
Species type encoding is consistent: True

copy() NapistuData

Create a deep copy of the NapistuData object.

estimate_memory_footprint() Dict[str, int | None]

Estimate memory footprint of the NapistuData object.

Calculates the memory usage (in bytes) for each major component of the data object, including node features, edge index, edge attributes, and training/validation/test masks.

Returns:

Dictionary containing memory usage in bytes for each component:
  • “node_features” – Memory used by node features (x)
  • “edge_index” – Memory used by edge index
  • “edge_attr” – Memory used by edge attributes
  • “train_mask” – Memory used by train mask (None if not present)
  • “val_mask” – Memory used by validation mask (None if not present)
  • “test_mask” – Memory used by test mask (None if not present)
  • “total” – Total memory usage in bytes

Return type:

Dict[str, Optional[int]]

Examples

>>> footprint = data.estimate_memory_footprint()
>>> print(f"Total memory: {footprint['total'] / 1e9:.2f} GB")
>>> print(f"Node features: {footprint['node_features'] / 1e9:.2f} GB")
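As a rough illustration of how such an estimate can be computed, the sketch below multiplies each component's element count by its bytes per element (in PyTorch, Tensor.numel() times Tensor.element_size()) and reports None for absent masks. It is written in plain Python so it runs without PyTorch; the estimate_footprint helper, the shapes, and the dtype sizes are hypothetical and not the library's implementation.

```python
from math import prod
from typing import Dict, Optional, Tuple

# Hypothetical stand-in for a tensor: (shape, bytes per element)
TensorSpec = Tuple[Tuple[int, ...], int]


def estimate_footprint(components: Dict[str, Optional[TensorSpec]]) -> Dict[str, Optional[int]]:
    """Return bytes per component (None if absent) plus a 'total' entry."""
    footprint: Dict[str, Optional[int]] = {}
    for name, spec in components.items():
        if spec is None:
            footprint[name] = None  # e.g. a mask that was never set
            continue
        shape, bytes_per_element = spec
        footprint[name] = prod(shape) * bytes_per_element
    # Sum only the components that are present
    footprint["total"] = sum(v for v in footprint.values() if v is not None)
    return footprint


footprint = estimate_footprint({
    "node_features": ((100, 10), 4),  # float32
    "edge_index": ((2, 200), 8),      # int64 (torch.long)
    "edge_attr": ((200, 5), 4),       # float32
    "train_mask": ((100,), 1),        # bool
    "val_mask": None,
    "test_mask": None,
})
print(footprint["total"])  # 11300
```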
get_edge_feature_names() List[str] | None

Get the names of edge features.

Returns:

List of edge feature names, or None if not available

Return type:

Optional[List[str]]

get_edge_indices(df: DataFrame, from_col: str, to_col: str) torch.Tensor

Get edge index tensor from a DataFrame with vertex names.

Extracts vertex names from specified columns in a DataFrame, converts them to indices using get_vertex_indices, and returns a tensor of shape (2, num_edges) suitable for use as edge_index in PyTorch Geometric.

Parameters:
  • df (pd.DataFrame) – DataFrame containing edge information with vertex names.

  • from_col (str) – Name of the column containing source vertex names.

  • to_col (str) – Name of the column containing target vertex names.

Returns:

Tensor of shape (2, num_edges) with dtype torch.long, where:
  • Row 0 contains source vertex indices
  • Row 1 contains target vertex indices

Return type:

torch.Tensor

Raises:
  • KeyError – If from_col or to_col are not in the DataFrame.

  • ValueError – If any vertex names in the columns are not found in NapistuData.
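The name-to-index lookup described for get_edge_indices (and get_vertex_indices) can be sketched as a dictionary from vertex name to row position, assuming vertex names are unique. In this plain-Python sketch, a list of dicts stands in for the DataFrame rows and nested lists stand in for the (2, num_edges) tensor; the helper names vertex_indices and edge_indices are hypothetical.

```python
from typing import Dict, List, Sequence


def vertex_indices(all_names: Sequence[str], query: Sequence[str]) -> List[int]:
    """Map vertex names to row positions in x; unknown names are reported as errors."""
    position: Dict[str, int] = {name: i for i, name in enumerate(all_names)}
    indices = [position.get(name, -1) for name in query]  # -1 marks a missing name
    missing = [name for name, idx in zip(query, indices) if idx == -1]
    if missing:
        raise ValueError(f"Vertex names not found: {missing}")
    return indices


def edge_indices(all_names: Sequence[str], rows: List[Dict[str, str]],
                 from_col: str, to_col: str) -> List[List[int]]:
    """Build a 2 x num_edges index: row 0 = sources, row 1 = targets."""
    sources = [row[from_col] for row in rows]  # KeyError if a row lacks the column
    targets = [row[to_col] for row in rows]
    return [vertex_indices(all_names, sources), vertex_indices(all_names, targets)]


names = ["A", "B", "C"]
edges = [{"from": "A", "to": "B"}, {"from": "C", "to": "A"}]
print(edge_indices(names, edges, "from", "to"))  # [[0, 2], [1, 0]]
```

With PyTorch, the nested list would then be wrapped as torch.tensor(..., dtype=torch.long) to match the documented return type.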

get_edge_names() Index | None

Get the edge names as a pandas Index.

Returns:

Pandas Index of edge names, or None if not available

Return type:

Optional[pd.Index]

get_edge_weights() torch.Tensor | None

Get edge weights as a 1D tensor.

This method provides access to the original edge weights stored in the edge_weight attribute, which is the standard PyG convention for scalar edge weights.

Returns:

1D tensor of edge weights, or None if not available

Return type:

Optional[torch.Tensor]

Examples

>>> weights = data.get_edge_weights()
>>> if weights is not None:
...     print(f"Edge weights shape: {weights.shape}")
...     print(f"Mean weight: {weights.mean():.3f}")

get_feature_by_name(feature_name: str) torch.Tensor

Get a feature by name from the NapistuData object.

Parameters:

feature_name (str) – The name of the feature to get

Returns:

The feature tensor

Return type:

torch.Tensor

get_features_by_regex(regex: str, return_suffixes: bool = False) Tuple[torch.Tensor, List[str]]

Get features by regex from the NapistuData object.

Parameters:
  • regex (str) – The regex to search for in the vertex feature names

  • return_suffixes (bool) – If True, return the substring following the regex as feature_names

Returns:

  • features (torch.Tensor) – The features tensor

  • feature_names (List[str]) – The feature names

Return type:

Tuple[torch.Tensor, List[str]]
Examples

>>> features, feature_names = napistu_data.get_features_by_regex("__source")
>>> print(features.shape)
torch.Size([100, 5])
>>> print(feature_names)
['source_1', 'source_2', 'source_3', 'source_4', 'source_5']
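The selection described for get_features_by_regex can be pictured as a regex search over the vertex feature names followed by column slicing. This is a plain-Python sketch under assumptions: a list of rows stands in for x, re.search is assumed as the matching rule, and the suffix is taken as the text after the first match, following the return_suffixes description. The function and feature names are hypothetical.

```python
import re
from typing import List, Tuple


def features_by_regex(x: List[List[float]], feature_names: List[str],
                      regex: str, return_suffixes: bool = False
                      ) -> Tuple[List[List[float]], List[str]]:
    """Select the columns of x whose names match regex."""
    pattern = re.compile(regex)
    cols, names = [], []
    for j, name in enumerate(feature_names):
        match = pattern.search(name)
        if match is None:
            continue
        cols.append(j)
        # Optionally keep only the substring that follows the match
        names.append(name[match.end():] if return_suffixes else name)
    selected = [[row[j] for j in cols] for row in x]
    return selected, names


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
feature_names = ["degree", "source__reactome", "source__string"]
feats, suffixes = features_by_regex(x, feature_names, "^source__", return_suffixes=True)
print(feats, suffixes)  # [[2.0, 3.0], [5.0, 6.0]] ['reactome', 'string']
```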
get_num_relations() int | None

Get the number of relations from relation_type tensor.

Computes the number of unique relation types and validates that they are consecutive integers starting from 0 (0, 1, 2, …, N-1).

Returns:

Number of unique relation types

Return type:

int

Raises:

ValueError – If relation_type is missing or contains non-consecutive integers
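The consecutive-integer check can be sketched as follows: the sorted unique relation types must equal 0, 1, ..., N-1, and N is returned. A plain Python sequence stands in for the relation_type tensor, and the function name num_relations is hypothetical.

```python
from typing import Optional, Sequence


def num_relations(relation_type: Optional[Sequence[int]]) -> int:
    """Count relation types, requiring the unique values to be exactly 0, 1, ..., N-1."""
    if relation_type is None or len(relation_type) == 0:
        raise ValueError("relation_type is missing")
    unique = sorted(set(relation_type))
    if unique != list(range(len(unique))):
        raise ValueError(f"Relation types are not consecutive integers from 0: {unique}")
    return len(unique)


print(num_relations([0, 2, 1, 1, 0]))  # 3
```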

get_summary(summary_type: str = 'basic') Dict[str, Any]

Get a summary of the NapistuData object.

Parameters:

summary_type (str, default="basic") – Type of summary to return:
  • “basic” – Core structural attributes only (num_nodes, num_edges, etc.)
  • “validation” – Basic plus feature metadata for compatibility validation (feature names, aliases, relation labels, mask hashes)
  • “detailed” – Validation plus boolean flags for attribute presence (for backward compatibility and debugging)
  • “all” – All available information

Returns:

Dictionary containing summary information about the data object

Return type:

Dict[str, Any]

get_symmetrical_relation_indices(treat_asymmetrically: Set[str] = {'other relation'}) List[int]

Analyze relation type names to detect symmetric ones.

Parses relation names in the format “{source_type} -> {target_type}” (spaces around the arrow are optional) and categorizes them based on whether source_type == target_type.

Parameters:

treat_asymmetrically (Set[str], optional) – Set of relation names to treat as asymmetric even if they don’t match the standard pattern. Defaults to {MERGE_RARE_STRATA_DEFS.OTHER_RELATION}.

Returns:

List of relation type indices that are symmetric

Return type:

List[int]

Raises:

ValueError – If relation_manager is missing or all relations are of the same type
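The parsing rule described above (split on "->" with optional surrounding spaces, then compare source and target types) can be sketched with a regular expression. This sketch takes relation names as a plain list instead of reading them from a relation_manager, and it does not reproduce the method's error handling; the function name is hypothetical.

```python
import re
from typing import AbstractSet, List

# "{source_type} -> {target_type}", spaces around the arrow optional
ARROW = re.compile(r"^\s*(.+?)\s*->\s*(.+?)\s*$")


def symmetric_relation_indices(
    relation_names: List[str],
    treat_asymmetrically: AbstractSet[str] = frozenset({"other relation"}),
) -> List[int]:
    """Return indices of relations whose source type equals their target type."""
    symmetric = []
    for idx, name in enumerate(relation_names):
        if name in treat_asymmetrically:
            continue  # always treated as asymmetric
        match = ARROW.match(name)
        if match and match.group(1) == match.group(2):
            symmetric.append(idx)
    return symmetric


names = ["protein -> protein", "protein->metabolite", "other relation", "gene ->gene"]
print(symmetric_relation_indices(names))  # [0, 3]
```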

get_vertex_feature_names() List[str] | None

Get the names of vertex features.

Returns:

List of vertex feature names, or None if not available

Return type:

Optional[List[str]]

get_vertex_indices(vertex_names: List[str] | Series) List[int]

Get the indices of vertices by their names.

Parameters:

vertex_names (List[str] or pd.Series) – List or Series of vertex names to look up. If Series, uses the values.

Returns:

List of integer indices corresponding to the vertex names. Indices are aligned with the vertex tensor (x) rows.

Return type:

List[int]

Raises:
  • TypeError – If vertex_names is not a list or pd.Series.

  • ValueError – If vertex names are not available in this NapistuData.

  • ValueError – If any vertex names are not found (i.e., they would map to an index of -1).

get_vertex_names() Index | None

Get the vertex names as a pandas Index.

Returns:

Pandas Index of vertex names, or None if not available

Return type:

Optional[pd.Index]

has_edges(edge_indices: torch.Tensor) torch.Tensor

Check which edges in edge_indices are present in this NapistuData.

Uses a set-based lookup, so checking large numbers of edges at once (e.g., 30K+) stays fast.

Parameters:

edge_indices (torch.Tensor) – Edge indices tensor of shape (2, num_edges) to check. Row 0 should contain source vertex indices. Row 1 should contain target vertex indices.

Returns:

Boolean tensor of shape (num_edges,) where True indicates the edge exists in this NapistuData.

Return type:

torch.Tensor

Examples

>>> # Get edge indices from a DataFrame
>>> query_edges = napistu_data.get_edge_indices(df, from_col='from', to_col='to')
>>> # Check which edges exist
>>> matches = napistu_data.has_edges(query_edges)
>>> # Filter to only existing edges
>>> existing_edges = query_edges[:, matches]
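One way to realize the set-based lookup is to build a set of (source, target) pairs once and then test each query pair in constant average time. This plain-Python sketch uses nested lists in place of the (2, num_edges) tensors, treats edges as directed, and uses a hypothetical function name.

```python
from typing import List


def has_edges(edge_index: List[List[int]], query: List[List[int]]) -> List[bool]:
    """Return one flag per query column: True if (src, dst) is an existing edge."""
    existing = set(zip(edge_index[0], edge_index[1]))  # built once, O(num_edges)
    return [(s, d) in existing for s, d in zip(query[0], query[1])]  # O(1) per query


edge_index = [[0, 1, 2], [1, 2, 0]]
query = [[0, 1, 2], [1, 0, 0]]
print(has_edges(edge_index, query))  # [True, False, True]
```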
reverse_edges(inplace: bool = True) NapistuData | None

Reverse all edges by swapping source and target in edge_index.

Only the edge indices are swapped. Edge attributes (edge_attr) are not modified. If direction-dependent edge attributes need to be swapped as well, reverse the NapistuGraph prior to NapistuData construction.

Parameters:

inplace (bool, default=True) – If True, modify in place. If False, return a new NapistuData.

Returns:

If inplace=False, a new NapistuData with reversed edges; otherwise None.

Return type:

NapistuData or None
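Reversing edges only swaps the two rows of edge_index; the edge order is unchanged, so each column still lines up with the same row of edge_attr. A plain-Python sketch with a hypothetical function name (with a PyTorch tensor, edge_index.flip(0) swaps the rows in the same way):

```python
from typing import List


def reverse_edges(edge_index: List[List[int]]) -> List[List[int]]:
    """Swap source and target rows; edge order (and edge_attr alignment) is unchanged."""
    return [list(edge_index[1]), list(edge_index[0])]


print(reverse_edges([[0, 1], [1, 2]]))  # [[1, 2], [0, 1]]
```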

save(filepath: str | Path) None

Save the NapistuData object to disk.

This method provides a safe way to save NapistuData objects, ensuring compatibility with PyTorch’s security features.

Parameters:

filepath (Union[str, Path]) – Path where to save the data object

Examples

>>> data.save('my_network.pt')

show_memory_footprint() None

Display memory footprint of the NapistuData object.

Prints a formatted breakdown of memory usage for each component of the data object in gigabytes (GB), showing node features, edge index, edge attributes, and training/validation/test masks.

Examples

>>> data.show_memory_footprint()
Node features: 0.05 GB
Edge index: 0.01 GB
Edge attributes: 0.20 GB
train_mask: 0.00 GB
val_mask: 0.00 GB
test_mask: 0.00 GB

Total data: 0.26 GB

show_summary() None

Display a summary of the NapistuData object.

trim(keep_edge_attr: bool = True, keep_labels: bool = True, keep_masks: bool = True, keep_relation_type: bool = True, inplace: bool = False) NapistuData

Create a memory-optimized version containing only the attributes essential for training.

By default, this method returns a new NapistuData object with only the core attributes needed for training, stripping away metadata and debugging information. Set inplace=True to trim the current object instead.

What’s always kept:
  • x (node features)
  • edge_index (graph structure)
  • edge_weight (if present)

What’s always removed:
  • ng_vertex_names, ng_edge_names (pandas objects)
  • vertex_feature_names, edge_feature_names (metadata)
  • name, splitting_strategy (metadata)

Conditionally kept:
  • labeling_manager – Kept if keep_labels=True (needed for label metadata)
  • relation_manager – Kept if keep_relation_type=True (needed for relation metadata)

Parameters:
  • keep_edge_attr (bool, default=True) – Whether to keep edge_attr. Set False if not using edge features (e.g., no edge encoder). Major memory savings for large graphs.

  • keep_labels (bool, default=True) – Whether to keep y (node labels). Set False for unlabeled tasks.

  • keep_masks (bool, default=True) – Whether to keep train_mask, val_mask, test_mask. Set False if using custom splitting.

  • keep_relation_type (bool, default=True) – Whether to keep relation_type. Set False if not using relation-aware heads.

  • inplace (bool, default=False) – Whether to modify the current object in place or return a new object.

Returns:

The trimmed NapistuData object with minimal attributes

Return type:

NapistuData

Examples

>>> # Default - keep everything except metadata
>>> trimmed = data.trim()
>>>
>>> # No edge features needed (biggest memory savings)
>>> trimmed = data.trim(keep_edge_attr=False)
>>>
>>> # Unlabeled learning
>>> trimmed = data.trim(keep_labels=False)
>>>
>>> # Check memory savings
>>> print(f"Before: {data.estimate_memory_footprint()['total'] / 1e9:.2f} GB")
>>> print(f"After: {trimmed.estimate_memory_footprint()['total'] / 1e9:.2f} GB")
>>>
>>> # Minimal for inference (no edge features, labels, or masks)
>>> trimmed = data.trim(
...     keep_edge_attr=False,
...     keep_labels=False,
...     keep_masks=False
... )

Notes

Memory impact (10M edges example):
  • Removing edge_attr: saves ~4 GB with 100 features, or ~0.4 GB with 10 features (float32)
  • Removing pandas names: saves ~10-100 MB
  • Removing labels/masks: saves ~10-50 MB
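The edge_attr figures follow from the size of a dense matrix: num_edges × num_features × 4 bytes for float32 values. A quick check of that arithmetic (the helper name is illustrative only):

```python
def edge_attr_gb(num_edges: int, num_features: int, bytes_per_value: int = 4) -> float:
    """Size of a dense [num_edges, num_features] edge_attr matrix in GB (1e9 bytes)."""
    return num_edges * num_features * bytes_per_value / 1e9


print(edge_attr_gb(10_000_000, 100))  # 4.0
print(edge_attr_gb(10_000_000, 10))   # 0.4
```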

unencode_features(napistu_graph: NapistuGraph, attribute_type: str, attribute: str, encoding_manager: EncodingManager | None = None) Series

Unencode features from the NapistuData object back to the original values.

This supports only categorical and passthrough encodings and is useful for validation, to ensure that encoded features are properly aligned with their values in the original NapistuGraph.

Parameters:
  • napistu_graph (NapistuGraph) – The NapistuGraph object containing the original values

  • attribute_type (str) – The type of attribute to unencode (“vertices” or “edges”)

  • attribute (str) – An attribute to unencode (e.g., “node_type” or “species_type”)

  • encoding_manager (Optional[EncodingManager]) – The encoding manager to use to unencode the features. If this is not provided then the default encoding managers will be used.

Returns:

A pandas Series with the unencoded values

Return type:

pd.Series
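For a categorical attribute, unencoding amounts to mapping each one-hot row back to its category and comparing the result with the original values. The sketch below assumes a plain one-hot layout with a known category order; it is not the EncodingManager logic, and decode_one_hot and the example values are hypothetical.

```python
from typing import List


def decode_one_hot(rows: List[List[float]], categories: List[str]) -> List[str]:
    """Recover the category for each row from its one-hot columns."""
    decoded = []
    for row in rows:
        if sum(row) != 1 or any(v not in (0, 1) for v in row):
            raise ValueError(f"Row is not one-hot: {row}")
        decoded.append(categories[row.index(1)])
    return decoded


original = ["species", "reaction", "species"]
one_hot = [[1, 0], [0, 1], [1, 0]]
recovered = decode_one_hot(one_hot, ["species", "reaction"])
print(recovered == original)  # True
```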

validate_graph_alignment(napistu_graph: NapistuGraph) None

Validate structural alignment between this NapistuData and a NapistuGraph.

Checks that vertex and edge counts match, and that the stored ng_vertex_names and ng_edge_names in this NapistuData are consistent with the NapistuGraph’s vertex and edge ordering.

Parameters:

napistu_graph (NapistuGraph) – The NapistuGraph this data was built from (or should align with).

Raises:

ValueError – If vertex counts, edge counts, or name orderings don’t match.
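The checks listed above can be sketched as count comparisons followed by order comparisons, with plain lists standing in for ng_vertex_names, ng_edge_names, and the NapistuGraph's vertices and edges. The function name and inputs are hypothetical.

```python
from typing import List, Tuple


def check_alignment(data_vertex_names: List[str], graph_vertex_names: List[str],
                    data_edges: List[Tuple[str, str]],
                    graph_edges: List[Tuple[str, str]]) -> None:
    """Raise ValueError if counts or orderings differ; return None otherwise."""
    if len(data_vertex_names) != len(graph_vertex_names):
        raise ValueError("Vertex counts differ")
    if len(data_edges) != len(graph_edges):
        raise ValueError("Edge counts differ")
    if data_vertex_names != graph_vertex_names:
        raise ValueError("Vertex ordering differs")
    if data_edges != graph_edges:
        raise ValueError("Edge ordering differs")


check_alignment(["A", "B"], ["A", "B"], [("A", "B")], [("A", "B")])  # passes silently
```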

napistu_torch.napistu_data._apply_optional_nd_args(params: Dict[str, Any], x: torch.Tensor, edge_attr: torch.Tensor, y: torch.Tensor | None, vertex_feature_names: List[str] | None, edge_feature_names: List[str] | None, vertex_feature_name_aliases: Dict[str, str] | None, edge_feature_name_aliases: Dict[str, str] | None, ng_vertex_names: Series | None, ng_edge_names: DataFrame | None, splitting_strategy: str | None, labeling_manager: LabelingManager | None, relation_type: torch.Tensor | None, relation_manager: LabelingManager | None) None

Apply and validate optional NapistuData arguments.

Parameters:
  • params (Dict[str, Any]) – Dictionary to update with validated optional parameters

  • x (torch.Tensor) – Node feature matrix (for validation of vertex_feature_names length)

  • edge_attr (torch.Tensor) – Edge feature matrix (for validation of edge_feature_names length)

  • y (Optional[torch.Tensor]) – Node labels tensor

  • vertex_feature_names (Optional[List[str]]) – Names of vertex features

  • edge_feature_names (Optional[List[str]]) – Names of edge features

  • vertex_feature_name_aliases (Optional[Dict[str, str]]) – Mapping from vertex feature names to their canonical names

  • edge_feature_name_aliases (Optional[Dict[str, str]]) – Mapping from edge feature names to their canonical names

  • ng_vertex_names (Optional[pd.Series]) – Minimal vertex names from the original NapistuGraph

  • ng_edge_names (Optional[pd.DataFrame]) – Minimal edge names from the original NapistuGraph

  • splitting_strategy (Optional[str]) – Strategy used to split the data into train/test/val sets

  • labeling_manager (Optional[LabelingManager]) – Labeling manager used to encode the labels

  • relation_type (Optional[torch.Tensor]) – Relation type tensor

  • relation_manager (Optional[LabelingManager]) – Relation manager used to encode relation types

Returns:

Modifies params dict in place

Return type:

None
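The length checks mentioned for vertex_feature_names and edge_feature_names presumably compare each name list against the width of the matching tensor before storing it in params. A sketch under that assumption, with shape tuples standing in for x.shape and edge_attr.shape and a hypothetical function name:

```python
from typing import Any, Dict, List, Optional, Tuple


def apply_feature_names(
    params: Dict[str, Any],
    x_shape: Tuple[int, int],
    edge_attr_shape: Tuple[int, int],
    vertex_feature_names: Optional[List[str]] = None,
    edge_feature_names: Optional[List[str]] = None,
) -> None:
    """Check feature-name counts against tensor widths, then store them in params."""
    checks = [
        ("vertex_feature_names", vertex_feature_names, x_shape[1]),
        ("edge_feature_names", edge_feature_names, edge_attr_shape[1]),
    ]
    for key, names, width in checks:
        if names is None:
            continue  # optional argument not supplied: leave params untouched
        if len(names) != width:
            raise ValueError(f"{key} has {len(names)} entries but the tensor has {width} columns")
        params[key] = names


params: Dict[str, Any] = {}
apply_feature_names(params, (100, 3), (200, 2), ["a", "b", "c"], None)
print(params)  # {'vertex_feature_names': ['a', 'b', 'c']}
```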

napistu_torch.napistu_data._validate_required_nd_args(x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, name: str) None

Validate required NapistuData arguments for correct types.

Parameters:
  • x (torch.Tensor) – Node feature matrix

  • edge_index (torch.Tensor) – Graph connectivity tensor

  • edge_attr (torch.Tensor) – Edge feature matrix

  • name (str) – Name of the NapistuData object

Raises:
  • TypeError – If any argument is not of the expected type

  • ValueError – If name is empty