napistu_torch.utils.torch_utils
Utility functions for managing torch devices and memory.
Public Functions
- cleanup_tensors(*tensors)
Clean up tensors and empty cache for all unique devices they were on.
- delete_tensors(*tensors)
Delete one or more tensors without emptying cache.
- empty_cache(device)
Empty the cache for a given device.
- ensure_device(device, allow_autoselect=False, mps_valid=True)
Ensure the device is a torch.device.
- memory_manager(device)
Context manager for general memory management.
- select_device(mps_valid=True)
Select the device to use for the model.
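Functions
- cleanup_tensors(*tensors)
Clean up tensors and empty cache for all unique devices they were on.
- delete_tensors(*tensors)
Delete one or more tensors without emptying cache.
- empty_cache(device)
Empty the cache for a given device.
- ensure_device(device, allow_autoselect=False, mps_valid=True)
Ensure the device is a torch.device.
- log_memory_usage(label, device=None)
Log the memory usage of the device.
- memory_manager(device)
Context manager for general memory management.
- select_device(mps_valid=True)
Selects the device to use for the model.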
- napistu_torch.utils.torch_utils.cleanup_tensors(*tensors) → None
Clean up tensors and empty cache for all unique devices they were on.
Deletes the provided tensors and then calls empty_cache() for each unique device that the tensors were on. This ensures GPU/MPS memory is freed immediately. Non-tensor objects (e.g., DataFrames) are simply deleted without cache clearing.
- Parameters:
*tensors (torch.Tensor or any object) – One or more tensors to clean up. Devices are automatically detected from the tensors themselves. Non-tensor objects are deleted but don’t trigger cache clearing.
Examples
>>> # Clean up tensors on GPU
>>> cleanup_tensors(attention, rank_tensor, edge_attentions)
>>> # Clean up tensors on different devices
>>> cleanup_tensors(tensor1, tensor2)  # Automatically handles both devices
>>> # Non-tensors are handled gracefully
>>> cleanup_tensors(tensor, df)  # DataFrame is deleted but doesn't affect cache
- napistu_torch.utils.torch_utils.delete_tensors(*tensors) → None
Delete one or more tensors without emptying cache.
- Parameters:
*tensors (torch.Tensor) – One or more tensors to delete.
- napistu_torch.utils.torch_utils.empty_cache(device: str | torch.device) → None
Empty the cache for a given device. If the device is not MPS or GPU, do nothing.
- Parameters:
device (str or torch.device) – The device to empty the cache for. Can be a string like ‘cuda:0’ or ‘mps’, or a torch.device object.
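The dispatch described above could look roughly like the following. This is a minimal sketch, not the library's source: `device_type` and `empty_cache_sketch` are hypothetical names, the boolean return value is added purely for illustration, and `torch` is imported lazily so the CPU path never touches a backend. It assumes a PyTorch version that provides `torch.mps.empty_cache()`.

```python
def device_type(device) -> str:
    """Return the backend part of a device, e.g. 'cuda:0' -> 'cuda'."""
    # torch.device objects expose a .type attribute; strings are split on ':'
    kind = getattr(device, "type", None)
    return kind if isinstance(kind, str) else str(device).split(":")[0]


def empty_cache_sketch(device) -> bool:
    """Clear the allocator cache for CUDA or MPS; return True if cleared."""
    kind = device_type(device)
    if kind not in ("cuda", "mps"):
        return False  # CPU and other backends: nothing to do

    import torch  # lazy import (assumption: PyTorch is installed)

    if kind == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
        return True
    if kind == "mps" and torch.backends.mps.is_available():
        torch.mps.empty_cache()
        return True
    return False
```

Calling `empty_cache_sketch("cpu")` returns without doing anything, which mirrors the documented "do nothing" behavior for devices that are neither MPS nor GPU.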
- napistu_torch.utils.torch_utils.ensure_device(device: str | torch.device | None, allow_autoselect: bool = False, mps_valid: bool = True) → torch.device
Ensure the device is a torch.device.
- Parameters:
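device (Optional[Union[str, torch.device]]) – The device to convert to a torch.device. May be None, in which case allow_autoselect determines whether a device is selected automatically.
allow_autoselect (bool) – Whether to allow automatic selection of the device when device is None.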
mps_valid (bool) – Whether to use MPS if available.
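One way the three parameters might interact is sketched below. `ensure_device_sketch` is a hypothetical stand-in, and the error types raised for a missing or unsupported device are assumptions; the source does not say how the real function reports these cases. The autoselect branch follows the MPS, then CUDA, then CPU order documented for select_device.

```python
def ensure_device_sketch(device, allow_autoselect=False, mps_valid=True):
    """Normalize a str / torch.device / None into a torch.device."""
    if device is None and not allow_autoselect:
        # Assumption: a missing device without autoselect is an error
        raise ValueError("device is None and allow_autoselect is False")

    import torch  # lazy import (assumption: PyTorch is installed)

    if device is None:
        # Autoselect: MPS first (if permitted), then CUDA, then CPU
        if mps_valid and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(device, torch.device):
        return device  # already the right type
    if isinstance(device, str):
        return torch.device(device)  # e.g. "cuda:0" or "mps"
    # Assumption: any other input type is rejected
    raise TypeError(f"unsupported device type: {type(device).__name__}")
```

With this shape, callers can pass whatever they received from configuration (a string, a device object, or nothing) and work with a torch.device from that point on.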
- napistu_torch.utils.torch_utils.log_memory_usage(label: str, device: str | torch.device | None = None) → None
Log the memory usage of the device.
- Parameters:
label (str) – The label to log the memory usage for
device (Optional[Union[str, torch.device]]) – The device to log the memory usage for
- napistu_torch.utils.torch_utils.memory_manager(device: torch.device = torch.device)
Context manager for general memory management.
This context manager ensures proper cleanup by:
1. Clearing the device cache before and after operations
2. Forcing garbage collection
- Parameters:
device (torch.device) – The device to manage memory for
Usage:
with memory_manager(device):
    # Your operations here
    pass
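The clear, run, collect, clear pattern described for this context manager can be illustrated with contextlib. In this sketch `memory_manager_sketch` is a hypothetical name, and it takes a `clear_cache` callable in place of the `device` argument so that it stays backend-independent. Running the cleanup in a `finally` block and collecting garbage before the second cache clear are assumptions, not confirmed details of the library.

```python
import gc
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def memory_manager_sketch(clear_cache: Callable[[], None]) -> Iterator[None]:
    """Clear the cache, run the body, then collect garbage and clear again."""
    clear_cache()  # start the block with an empty cache
    try:
        yield
    finally:
        # Runs even if the body raises (assumed behavior)
        gc.collect()   # drop unreachable Python objects first
        clear_cache()  # then release the blocks they were holding


# Example: record the order of events with a stand-in cache function
events = []
with memory_manager_sketch(lambda: events.append("clear")):
    events.append("body")
```

In real use, `clear_cache` could be something like `torch.cuda.empty_cache`, or a small wrapper that picks the right call for the device.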
- napistu_torch.utils.torch_utils.select_device(mps_valid: bool = True) → torch.device
Selects the device to use for the model. If MPS is available and mps_valid is True, use MPS. Otherwise, if CUDA is available, use CUDA. Otherwise, use CPU.
- Parameters:
mps_valid (bool) – Whether to use MPS if available.
- Returns:
device – The device to use for the model.
- Return type:
torch.device
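The selection order stated above (MPS, then CUDA, then CPU) can be written as a small pure function, which makes the priority easy to check without any hardware. Both `pick_device_type` and `select_device_sketch` are hypothetical names used for illustration; only `torch.cuda.is_available()` and `torch.backends.mps.is_available()` are real PyTorch calls.

```python
def pick_device_type(cuda_available: bool, mps_available: bool,
                     mps_valid: bool = True) -> str:
    """Apply the documented priority: MPS (if allowed), then CUDA, then CPU."""
    if mps_valid and mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


def select_device_sketch(mps_valid: bool = True):
    """Probe the real backends and return a torch.device."""
    import torch  # lazy import (assumption: PyTorch is installed)

    return torch.device(
        pick_device_type(
            cuda_available=torch.cuda.is_available(),
            mps_available=torch.backends.mps.is_available(),
            mps_valid=mps_valid,
        )
    )
```

Setting `mps_valid=False` skips MPS even when it is present, so `pick_device_type(True, True, mps_valid=False)` falls through to CUDA.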