napistu_torch.ml.splitting

Functions

create_split_masks(df, splits_dict)

Create train/test/val masks from split DataFrames.

train_test_val_split(df[, train_size, ...])

Split DataFrame into train, test, and validation sets.

napistu_torch.ml.splitting.create_split_masks(df: DataFrame, splits_dict: Dict[str, DataFrame]) → Dict[str, torch.Tensor]

Create train/test/val masks from split DataFrames.

Parameters:
  • df (pd.DataFrame) – Full DataFrame (before splitting)

  • splits_dict (Dict[str, pd.DataFrame]) – Dictionary with ‘train’, ‘test’, ‘validation’ keys containing split DataFrames

Returns:

Dictionary with ‘train_mask’, ‘test_mask’, ‘validation_mask’ boolean tensors

Return type:

Dict[str, torch.Tensor]
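
As a rough illustration of what such masks look like, the sketch below builds one boolean array per split by checking which rows of the full DataFrame appear in each split. It is not the napistu_torch implementation: the helper name split_masks_sketch is hypothetical, it assumes each split keeps the index labels of the full DataFrame, and it returns NumPy arrays rather than tensors (torch.from_numpy could convert them).

```python
import numpy as np
import pandas as pd


def split_masks_sketch(df, splits_dict):
    """Hypothetical index-membership masks; not the napistu_torch code."""
    masks = {}
    for name, split_df in splits_dict.items():
        # True where a row of the full frame also appears in this split.
        # Assumes the split kept the original index labels.
        masks[f"{name}_mask"] = df.index.isin(split_df.index)
    return masks


# Toy data: six rows, split 4 / 1 / 1.
df = pd.DataFrame({"x": range(6)})
splits = {
    "train": df.iloc[[0, 1, 2, 3]],
    "test": df.iloc[[4]],
    "validation": df.iloc[[5]],
}
masks = split_masks_sketch(df, splits)
```

Each mask has one entry per row of the full DataFrame. When the splits are disjoint and cover every row, exactly one mask is True at each position.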

napistu_torch.ml.splitting.train_test_val_split(df: DataFrame, train_size: float = 0.7, test_size: float = 0.15, val_size: float = 0.15, random_state: int = 42, shuffle: bool = True, stratify: Series | None = None, return_dict: bool = False) → DataFrame | tuple[DataFrame, DataFrame, DataFrame] | dict[str, DataFrame]

Split DataFrame into train, test, and validation sets.

This is an extension of sklearn’s train_test_split for three-way splits.

Parameters:
  • df (pd.DataFrame) – DataFrame to split

  • train_size (float, default=0.7) – Proportion of data for training (0.0 to 1.0)

  • test_size (float, default=0.15) – Proportion of data for testing (0.0 to 1.0)

  • val_size (float, default=0.15) – Proportion of data for validation (0.0 to 1.0)

  • random_state (int, default=42) – Random seed for reproducibility

  • shuffle (bool, default=True) – Whether to shuffle the data before splitting

  • stratify (array-like, optional) – If not None, data is split in a stratified fashion using this as class labels

  • return_dict (bool, default=False) – If True, return a dictionary with keys for each split

Returns:

  • If return_dict is False, a tuple (train_df, test_df, val_df):

    • train_df (pd.DataFrame) – Training data

    • test_df (pd.DataFrame) – Test data

    • val_df (pd.DataFrame) – Validation data

  • If return_dict is True, a dictionary with one key per split:

    • train (pd.DataFrame) – Training data

    • test (pd.DataFrame) – Test data

    • val (pd.DataFrame) – Validation data

Examples

>>> from napistu_torch.ml.splitting import train_test_val_split
>>>
>>> # Basic usage
>>> train, test, val = train_test_val_split(df)
>>>
>>> # Custom split ratios
>>> train, test, val = train_test_val_split(df, train_size=0.8, test_size=0.1, val_size=0.1)
>>>
>>> # Stratified split by edge type
>>> train, test, val = train_test_val_split(df, stratify=df['edge_type'])
>>>
>>> # No shuffling (preserve order)
>>> train, test, val = train_test_val_split(df, shuffle=False)
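
One common way to get a three-way split out of sklearn's two-way train_test_split is to call it twice: first to carve off the training set, then to divide the remainder into test and validation sets. The sketch below shows that approach under stated assumptions. The helper name three_way_split_sketch is hypothetical, the requirement that the three sizes sum to 1.0 is an assumption of the sketch, and the source does not confirm that train_test_val_split works this way internally.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def three_way_split_sketch(df, train_size=0.7, test_size=0.15, val_size=0.15,
                           random_state=42, shuffle=True, stratify=None):
    """Hypothetical two-stage split; not the napistu_torch implementation."""
    # Assumption of this sketch: the three proportions cover the whole frame.
    if abs(train_size + test_size + val_size - 1.0) > 1e-9:
        raise ValueError("train_size, test_size and val_size must sum to 1.0")

    # Stage 1: separate the training rows from everything else.
    train_df, rest_df = train_test_split(
        df, train_size=train_size, random_state=random_state,
        shuffle=shuffle, stratify=stratify,
    )

    # Stage 2: divide the remainder between test and validation.
    # The labels must be re-aligned to the rows that are left.
    rest_labels = stratify.loc[rest_df.index] if stratify is not None else None
    test_df, val_df = train_test_split(
        rest_df, train_size=test_size / (test_size + val_size),
        random_state=random_state, shuffle=shuffle, stratify=rest_labels,
    )
    return train_df, test_df, val_df


# Toy data: 100 rows with two balanced classes.
demo_df = pd.DataFrame({"x": range(100), "edge_type": ["a", "b"] * 50})
train_df, test_df, val_df = three_way_split_sketch(
    demo_df, stratify=demo_df["edge_type"]
)
```

The second call rescales the test proportion relative to what remains after the first cut, which is why test_size is divided by test_size + val_size. Note that sklearn rejects stratify when shuffle is False, so the two options cannot be combined in this sketch.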