napistu_torch.utils.torch_utils

Utility functions for managing torch devices and memory.

Public Functions

cleanup_tensors(*tensors)

Clean up tensors and empty cache for all unique devices they were on.

delete_tensors(*tensors)

Delete one or more tensors without emptying cache.

empty_cache(device)

Empty the cache for a given device.

ensure_device(device, allow_autoselect=False, mps_valid=True)

Ensure the device is a torch.device.

memory_manager(device)

Context manager for general memory management.

select_device(mps_valid=True)

Select the device to use for the model.

Functions

cleanup_tensors(*tensors)

Clean up tensors and empty cache for all unique devices they were on.

delete_tensors(*tensors)

Delete one or more tensors without emptying cache.

empty_cache(device)

Empty the cache for a given device.

ensure_device(device[, allow_autoselect, ...])

Ensure the device is a torch.device.

log_memory_usage(label[, device])

Log the memory usage of the device.

memory_manager([device])

Context manager for general memory management.

select_device([mps_valid])

Select the device to use for the model.

napistu_torch.utils.torch_utils.cleanup_tensors(*tensors) → None

Clean up tensors and empty cache for all unique devices they were on.

Deletes the provided tensors and then calls empty_cache() once for each unique device the tensors were on, so cached GPU/MPS memory is returned to the device immediately. A tensor's memory can only be released once no other references to it remain. Non-tensor objects (e.g., DataFrames) are simply deleted without clearing any cache.

Parameters:

*tensors (torch.Tensor or any object) – One or more tensors to clean up. Devices are automatically detected from the tensors themselves. Non-tensor objects are deleted but don’t trigger cache clearing.

Examples

>>> # Clean up tensors on GPU
>>> cleanup_tensors(attention, rank_tensor, edge_attentions)
>>> # Clean up tensors on different devices
>>> cleanup_tensors(tensor1, tensor2)  # Automatically handles both devices
>>> # Non-tensors are handled gracefully
>>> cleanup_tensors(tensor, df)  # DataFrame is deleted but doesn't affect cache
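The device bookkeeping described above can be sketched without PyTorch. This is a hypothetical, torch-free illustration, not the library's implementation: it assumes that any object exposing a `device` attribute counts as a tensor (standing in for an `isinstance(obj, torch.Tensor)` check), and it takes the cache-clearing function as a parameter, which with PyTorch could be this module's own `empty_cache`.

```python
from typing import Any, Callable, List


def unique_devices(*objects: Any) -> List[Any]:
    """Return the distinct devices of the given objects, in first-seen order.

    Assumption for this sketch: an object with a `device` attribute is treated
    as a tensor; anything else is ignored and never triggers a cache clear.
    """
    devices = (getattr(obj, "device", None) for obj in objects)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(d for d in devices if d is not None))


def cleanup_sketch(*objects: Any, empty_cache: Callable[[Any], None]) -> None:
    """Hypothetical cleanup: drop local references, then clear each device once."""
    # Read the devices first; they cannot be recovered after the references go.
    devices = unique_devices(*objects)
    # Only this function's references are removed; callers keep their own.
    del objects
    for device in devices:
        empty_cache(device)
```

Collecting the devices before deleting is the key ordering choice: once the references are dropped, there is nothing left to read the devices from, and de-duplicating means each cache is cleared only once even when many tensors share a device.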
napistu_torch.utils.torch_utils.delete_tensors(*tensors) → None

Delete one or more tensors without emptying cache.

Parameters:

*tensors (torch.Tensor) – One or more tensors to delete.
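What "delete" can achieve depends on Python's reference semantics. Assuming the deletion is a plain `del` (the implementation is not shown here), a function can only drop its own references to its arguments; an object, and for a tensor its memory, is released only once the caller's references are gone too. The sketch below demonstrates this with a plain object and `weakref` so it runs without PyTorch; `Payload` and `delete_all` are hypothetical stand-ins.

```python
import gc
import weakref


class Payload:
    """Stand-in for a large tensor; plain classes support weak references."""


def delete_all(*objects: object) -> None:
    """Hypothetical analogue of delete_tensors using a plain `del`."""
    del objects  # removes only this function's references


def demo() -> tuple:
    payload = Payload()
    probe = weakref.ref(payload)  # observes the object without keeping it alive
    delete_all(payload)
    # Still alive: the local name `payload` continues to refer to it.
    alive_after_call = probe() is not None
    del payload
    gc.collect()  # CPython frees it on refcount alone; harmless elsewhere
    alive_after_del = probe() is not None
    return alive_after_call, alive_after_del
```

Under that assumption, callers may still need to drop their own references (with `del` or by letting variables go out of scope) for the memory to actually be freed.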

napistu_torch.utils.torch_utils.empty_cache(device: str | torch.device) → None

Empty the cache for a given device. If the device is neither MPS nor a CUDA GPU (e.g., 'cpu'), this is a no-op.

Parameters:

device (str or torch.device) – The device to empty the cache for. Can be a string like ‘cuda:0’ or ‘mps’, or a torch.device object.
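The behaviour described above amounts to a lookup from device type to a cache-clearing function. This is a hypothetical, torch-free sketch: the `clearers` mapping is injected so the dispatch runs without a GPU. With PyTorch it could map `"cuda"` to `torch.cuda.empty_cache` and `"mps"` to `torch.mps.empty_cache`, both of which exist in recent PyTorch releases; `device_type` and `empty_cache_sketch` are illustrative names only.

```python
from typing import Any, Callable, Mapping


def device_type(device: Any) -> str:
    """Return the backend name, e.g. 'cuda' for 'cuda:0'.

    Strings are split on ':'; other objects are assumed to expose a `.type`
    attribute, as torch.device does.
    """
    if isinstance(device, str):
        return device.split(":", 1)[0]
    return device.type


def empty_cache_sketch(device: Any, clearers: Mapping[str, Callable[[], None]]) -> bool:
    """Run the cache clearer registered for the device's backend, if any.

    Returns False (and does nothing) for backends without a clearer, such as
    'cpu', mirroring the documented no-op behaviour.
    """
    clear = clearers.get(device_type(device))
    if clear is None:
        return False
    clear()
    return True
```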

napistu_torch.utils.torch_utils.ensure_device(device: str | torch.device | None, allow_autoselect: bool = False, mps_valid: bool = True) → torch.device

Ensure the device is a torch.device, converting a string such as 'cuda:0' or 'mps' if needed.

Parameters:
  • device (Optional[Union[str, torch.device]]) – The device to ensure. May be None, in which case allow_autoselect determines whether a device is selected automatically.

  • allow_autoselect (bool) – Whether to select a device automatically when device is not specified.

  • mps_valid (bool) – Whether to use MPS if it is available.
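One plausible reading of these parameters is sketched below. Raising an error when `device` is None and `allow_autoselect` is False is an assumption, as is passing `mps_valid` through to the selector; the sketch also returns plain strings rather than torch.device objects so it runs without PyTorch. `ensure_device_sketch` and its `select` argument are hypothetical.

```python
from typing import Any, Callable, Optional


def ensure_device_sketch(
    device: Optional[Any],
    allow_autoselect: bool = False,
    mps_valid: bool = True,
    select: Optional[Callable[[bool], str]] = None,
) -> Any:
    """Hypothetical normalisation: autoselect when allowed, else pass through.

    The real function returns a torch.device; wrapping the result with
    torch.device(...) would be the torch-dependent final step.
    """
    if device is None:
        if not allow_autoselect:
            # Assumed behaviour: an unspecified device is an error unless
            # autoselection was explicitly allowed.
            raise ValueError("device is None and allow_autoselect is False")
        if select is None:
            raise ValueError("autoselection requested but no selector given")
        # Assumed: mps_valid is forwarded to the selector.
        return select(mps_valid)
    return device
```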

napistu_torch.utils.torch_utils.log_memory_usage(label: str, device: str | torch.device | None = None) → None

Log the memory usage of the device.

Parameters:
  • label (str) – A label identifying this memory-usage log entry.

  • device (Optional[Union[str, torch.device]]) – The device whose memory usage is logged.
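The documentation does not specify which statistics are logged. A hedged sketch, assuming "allocated" and "reserved" byte counts are reported (with CUDA these could come from `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved`), shows one way to format and emit such a line with the standard `logging` module; `format_memory` and `log_memory_sketch` are hypothetical.

```python
import logging

logger = logging.getLogger(__name__)

GIB = 1024 ** 3  # bytes per gibibyte


def format_memory(label: str, allocated_bytes: int, reserved_bytes: int) -> str:
    """Render byte counts as a single human-readable line in GiB."""
    return (
        f"{label}: allocated={allocated_bytes / GIB:.2f} GiB, "
        f"reserved={reserved_bytes / GIB:.2f} GiB"
    )


def log_memory_sketch(label: str, allocated_bytes: int, reserved_bytes: int) -> None:
    """Emit the formatted line at INFO level."""
    logger.info(format_memory(label, allocated_bytes, reserved_bytes))
```

Keeping the formatting in a separate pure function makes the message testable without a GPU or a configured logging handler.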

napistu_torch.utils.torch_utils.memory_manager(device: torch.device = torch.device)

Context manager for general memory management.

This context manager ensures proper cleanup by:

  1. Clearing the device cache before and after the managed operations.

  2. Forcing garbage collection.

Parameters:
  • device (torch.device) – The device to manage memory for.

Usage:

    with memory_manager(device):
        # Your operations here
        pass
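The cleanup sequence above can be sketched with `contextlib.contextmanager`. In this hypothetical, torch-free sketch the cache-clearing step is passed in as a callable (with PyTorch it might wrap this module's `empty_cache(device)`), and garbage collection is assumed to run both before and after the block, since the documentation does not say exactly when it runs.

```python
import gc
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def memory_manager_sketch(clear_cache: Callable[[], None]) -> Iterator[None]:
    """Collect garbage and clear the cache before and after the managed block."""
    # Collect first so unreachable objects hand their memory back to the
    # allocator before the cache is emptied.
    gc.collect()
    clear_cache()
    try:
        yield
    finally:
        # Runs even if the block raises, so cleanup is never skipped.
        gc.collect()
        clear_cache()
```

The `try`/`finally` is the important design choice: without it, an exception inside the `with` block would skip the trailing cleanup.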

napistu_torch.utils.torch_utils.select_device(mps_valid: bool = True) → torch.device

Select the device to use for the model. If MPS is available and mps_valid is True, use MPS; otherwise, if CUDA is available, use CUDA; otherwise, use CPU.

Parameters:

mps_valid (bool) – Whether to use MPS if available.

Returns:

device – The device to use for the model.

Return type:

torch.device
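The priority order above reduces to a small pure function. In this torch-free sketch the availability flags are parameters; with PyTorch they could come from `torch.backends.mps.is_available()` and `torch.cuda.is_available()`, and the returned string could be wrapped with `torch.device(...)`. `choose_device_type` is a hypothetical helper, not part of the module.

```python
def choose_device_type(
    mps_available: bool, cuda_available: bool, mps_valid: bool = True
) -> str:
    """Apply the documented priority: MPS (if allowed), then CUDA, then CPU."""
    if mps_valid and mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"
```

Keeping the decision separate from the hardware queries makes every branch testable on a machine that has neither accelerator.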