b10-transfer 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,179 @@
1
+ """Cooperative cleanup utilities for b10-tcache.
2
+
3
+ This module provides cooperative cleanup functionality where each pod/replica
4
+ helps maintain the health of shared resources in b10fs by removing stale
5
+ lock files and incomplete cache files.
6
+ """
7
+
8
+ import fnmatch
9
+ import time
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import List, Tuple
13
+
14
+ from .constants import (
15
+ B10FS_CACHE_DIR,
16
+ CACHE_INCOMPLETE_SUFFIX,
17
+ CLEANUP_LOCK_TIMEOUT_SECONDS,
18
+ CLEANUP_INCOMPLETE_TIMEOUT_SECONDS,
19
+ )
20
+ from .utils import safe_execute, safe_unlink
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@safe_execute("Failed to find stale files", [])
def _find_stale_files(
    directory: Path, pattern: str, timeout_seconds: int
) -> List[Path]:
    """Return files directly inside ``directory`` that match ``pattern`` and are stale.

    A file is stale when its modification time lies more than
    ``timeout_seconds`` in the past. Only direct children are inspected
    (``iterdir``); subdirectories are never descended into.

    Args:
        directory: Directory to scan.
        pattern: ``fnmatch``-style pattern applied to each file *name*.
        timeout_seconds: Age threshold in seconds.

    Returns:
        Paths of the stale files; an empty list when ``directory`` is missing.

    Note:
        The ``safe_execute`` decorator presumably converts any unexpected
        exception into the ``[]`` default -- confirm against ``utils``.
    """
    if not directory.exists():
        return []

    now = time.time()
    stale_files: List[Path] = []

    for entry in directory.iterdir():
        # Anything that is not a regular file is unexpected here; report it
        # but never consider it for deletion.
        if not entry.is_file():
            logger.warning(
                f"Found non-file in b10fs cache directory: {entry}, skipping consideration for deletion in cleanup phase."
            )
            continue

        # A name mismatch is the normal case: the directory is scanned once
        # per pattern and also holds the final cache archives. Log at DEBUG
        # so routine cleanups do not emit one warning per healthy file.
        if not fnmatch.fnmatch(entry.name, pattern):
            logger.debug(
                f"Found non-matching file in b10fs cache directory: {entry}, skipping consideration for deletion in cleanup phase."
            )
            continue

        try:
            age_seconds = now - entry.stat().st_mtime
        except OSError:
            # The file may have been removed between listing and stat().
            continue

        if age_seconds > timeout_seconds:
            stale_files.append(entry)

    return stale_files
71
+
72
+
73
@safe_execute("Failed to cleanup files", 0)
def _cleanup_files(files: List[Path], file_type: str) -> int:
    """Delete each path in ``files`` and report how many were processed.

    Args:
        files: Paths to remove.
        file_type: Human-readable label used in log messages
            (e.g. "lock", "incomplete").

    Returns:
        Number of files for which the stat + unlink sequence finished without
        raising ``OSError``.

    Note:
        NOTE(review): if ``safe_unlink`` swallows deletion errors (its
        signature takes an error message, which suggests it might), a failed
        delete is still counted here -- confirm against ``utils.safe_unlink``.
    """
    removed = 0

    for stale_path in files:
        try:
            # The age is read before deletion purely for the debug log below.
            age_seconds = time.time() - stale_path.stat().st_mtime
            safe_unlink(
                stale_path, f"Failed to clean stale {file_type} file: {stale_path}"
            )
        except OSError:
            # The file may already have been removed (e.g. by another pod).
            continue
        else:
            removed += 1
            logger.debug(
                f"Cleaned stale {file_type} file: {stale_path.name} (age: {age_seconds:.1f}s)"
            )

    return removed
101
+
102
+
103
@safe_execute("Cooperative cleanup failed", None)
def cooperative_cleanup_b10fs() -> None:
    """Remove stale lock and incomplete-cache files from the b10fs cache dir.

    Intended to be called by every pod/replica so that files orphaned by
    crashed peers eventually disappear.

    Removed:
        - ``*.lock`` files older than ``CLEANUP_LOCK_TIMEOUT_SECONDS``.
        - ``*{CACHE_INCOMPLETE_SUFFIX}*`` files older than
          ``CLEANUP_INCOMPLETE_TIMEOUT_SECONDS``.

    Left alone:
        - Files matching neither pattern (e.g. the final cache archives).
        - Files younger than the thresholds (possibly in active use).

    The helpers skip files that vanish mid-scan, so concurrent runs from
    several pods are tolerated.
    """
    cache_root = Path(B10FS_CACHE_DIR)
    if not cache_root.exists():
        logger.debug("b10fs cache directory doesn't exist, skipping cleanup")
        return

    # Pass 1: stale lock files.
    cleaned_locks = _cleanup_files(
        _find_stale_files(cache_root, "*.lock", CLEANUP_LOCK_TIMEOUT_SECONDS),
        "lock",
    )

    # Pass 2: stale incomplete cache files.
    cleaned_incomplete = _cleanup_files(
        _find_stale_files(
            cache_root,
            f"*{CACHE_INCOMPLETE_SUFFIX}*",
            CLEANUP_INCOMPLETE_TIMEOUT_SECONDS,
        ),
        "incomplete cache",
    )

    # Summary: INFO only when something was actually removed.
    if cleaned_locks + cleaned_incomplete <= 0:
        logger.debug("Cooperative cleanup: no stale files found")
        return

    logger.info(
        f"Cooperative cleanup: removed {cleaned_locks} stale locks, "
        f"{cleaned_incomplete} incomplete files"
    )
146
+
147
+
148
def get_cleanup_info() -> dict:
    """Report the cleanup configuration and how many files are currently stale.

    Nothing is deleted; the directory is only scanned.

    Returns:
        dict with the keys:
            - lock_timeout_seconds: lock-file staleness threshold
            - incomplete_timeout_seconds: incomplete-file staleness threshold
            - b10fs_cache_dir: b10fs cache directory as a string
            - b10fs_exists: whether that directory exists
            - stale_locks_count: lock files a cleanup would target
            - stale_incomplete_count: incomplete files a cleanup would target
    """
    cache_root = Path(B10FS_CACHE_DIR)

    # Static configuration plus directory existence.
    info = {
        "lock_timeout_seconds": CLEANUP_LOCK_TIMEOUT_SECONDS,
        "incomplete_timeout_seconds": CLEANUP_INCOMPLETE_TIMEOUT_SECONDS,
        "b10fs_cache_dir": str(cache_root),
        "b10fs_exists": cache_root.exists(),
    }

    # Live counts, gathered with the same patterns the cleanup itself uses.
    stale_locks = _find_stale_files(
        cache_root, "*.lock", CLEANUP_LOCK_TIMEOUT_SECONDS
    )
    info["stale_locks_count"] = len(stale_locks)

    stale_incomplete = _find_stale_files(
        cache_root,
        f"*{CACHE_INCOMPLETE_SUFFIX}*",
        CLEANUP_INCOMPLETE_TIMEOUT_SECONDS,
    )
    info["stale_incomplete_count"] = len(stale_incomplete)

    return info
@@ -0,0 +1,149 @@
1
+ """Configuration constants for b10-tcache.
2
+
3
+ This module defines configuration constants for the PyTorch compilation cache system.
4
+ Some values can be overridden by environment variables, but security caps are enforced
5
+ to prevent malicious or accidental misuse in production environments.
6
+ """
7
+
8
+ import os
9
+ from enum import Enum, auto
10
+
11
+ # Import helper functions from utils to avoid duplication
12
+ from .utils import (
13
+ get_current_username,
14
+ validate_path_security,
15
+ validate_boolean_env,
16
+ apply_cap,
17
+ )
18
+
19
# ---------------------------------------------------------------------------
# Cache directories (every path is passed through validate_path_security)
# ---------------------------------------------------------------------------

# Local torch compile cache. TORCHINDUCTOR_CACHE_DIR is what torch uses by
# default; when it is unset a per-user directory under /tmp is used instead.
# Accepted prefixes: /tmp/ and /cache/.
_torch_cache_dir = os.getenv(
    "TORCHINDUCTOR_CACHE_DIR", f"/tmp/torchinductor_{get_current_username()}"
)
TORCH_CACHE_DIR = validate_path_security(
    _torch_cache_dir, ["/tmp/", "/cache/"], "TORCHINDUCTOR_CACHE_DIR"
)

# Shared b10fs cache directory.
# NOTE(review): despite its name, this prefix constrains B10FS_CACHE_DIR, not
# the torch cache dir. It also has no trailing slash, unlike the other
# prefixes -- confirm how validate_path_security treats that.
_REQUIRED_TORCH_CACHE_DIR_PREFIX = "/cache/model"
_b10fs_cache_dir = os.getenv(
    "B10FS_CACHE_DIR", f"{_REQUIRED_TORCH_CACHE_DIR_PREFIX}/compile_cache"
)
B10FS_CACHE_DIR = validate_path_security(
    _b10fs_cache_dir, [_REQUIRED_TORCH_CACHE_DIR_PREFIX], "B10FS_CACHE_DIR"
)

# Local scratch directory for intermediate work (e.g. compression/archival).
# Accepted prefixes: /app/, /tmp/ and /cache/.
_local_work_dir = os.getenv("LOCAL_WORK_DIR", "/app")
LOCAL_WORK_DIR = validate_path_security(
    _local_work_dir, ["/app/", "/tmp/", "/cache/"], "LOCAL_WORK_DIR"
)

# ---------------------------------------------------------------------------
# Resource limits
# ---------------------------------------------------------------------------

# Hard caps that environment overrides cannot exceed.
_MAX_CACHE_SIZE_CAP_MB = 1024  # 1GB hard limit per cache archive
# Only used as an estimate for b10fs space requirements/thresholding.
_MAX_CONCURRENT_SAVES_CAP = 100

# User-tunable limits, clamped to the caps above via apply_cap.
# NOTE(review): int() raises ValueError at import time when the environment
# variable is not an integer.
_user_max_cache_size = int(os.getenv("MAX_CACHE_SIZE_MB", "1024"))
MAX_CACHE_SIZE_MB = apply_cap(
    _user_max_cache_size, _MAX_CACHE_SIZE_CAP_MB, "MAX_CACHE_SIZE_MB"
)

_user_max_concurrent_saves = int(os.getenv("MAX_CONCURRENT_SAVES", "50"))
MAX_CONCURRENT_SAVES = apply_cap(
    _user_max_concurrent_saves, _MAX_CONCURRENT_SAVES_CAP, "MAX_CONCURRENT_SAVES"
)

# Free-space requirements.
MIN_LOCAL_SPACE_MB = 50 * 1024  # 50GB minimum space on local machine
# Worst-case concurrent archives, but never below 100_000 MB.
REQUIRED_B10FS_SPACE_MB = max(MAX_CONCURRENT_SAVES * MAX_CACHE_SIZE_MB, 100_000)

# ---------------------------------------------------------------------------
# B10FS configuration
# ---------------------------------------------------------------------------

# Defaults to "0" (disabled) to prevent accidental enabling. This limits the
# ability to switch b10fs on for debugging; a separate B10FS_ENABLED flag
# would probably be the better tool for that.
_baseten_fs_enabled = os.getenv("BASETEN_FS_ENABLED", "0")
BASETEN_FS_ENABLED = validate_boolean_env(_baseten_fs_enabled, "BASETEN_FS_ENABLED")

# ---------------------------------------------------------------------------
# File naming patterns
# ---------------------------------------------------------------------------
CACHE_FILE_EXTENSION = ".tar.gz"
CACHE_LATEST_SUFFIX = ".latest"
CACHE_INCOMPLETE_SUFFIX = ".incomplete"
CACHE_PREFIX = "cache_"

# ---------------------------------------------------------------------------
# Space monitoring
# ---------------------------------------------------------------------------

# How often (seconds) to check disk space during operations.
SPACE_MONITOR_CHECK_INTERVAL_SECONDS = 0.5

# ---------------------------------------------------------------------------
# Cooperative cleanup
# ---------------------------------------------------------------------------

# Cache operations (load/save) should complete within ~15 seconds under
# normal conditions; the defaults below are multiples of that estimate.
_LOCK_TIMEOUT_CAP_SECONDS = 3600  # 1 hour hard limit
_INCOMPLETE_TIMEOUT_CAP_SECONDS = 7200  # 2 hours hard limit

# Lock files: default 30 seconds (2x the expected operation time).
_user_lock_timeout = int(os.getenv("CLEANUP_LOCK_TIMEOUT_SECONDS", "30"))
CLEANUP_LOCK_TIMEOUT_SECONDS = apply_cap(
    _user_lock_timeout, _LOCK_TIMEOUT_CAP_SECONDS, "CLEANUP_LOCK_TIMEOUT_SECONDS"
)

# Incomplete files: default 60 seconds (4x the expected operation time).
_user_incomplete_timeout = int(os.getenv("CLEANUP_INCOMPLETE_TIMEOUT_SECONDS", "60"))
CLEANUP_INCOMPLETE_TIMEOUT_SECONDS = apply_cap(
    _user_incomplete_timeout,
    _INCOMPLETE_TIMEOUT_CAP_SECONDS,
    "CLEANUP_INCOMPLETE_TIMEOUT_SECONDS",
)
107
+
108
+
109
+ # Worker process result status enum
110
class WorkerStatus(Enum):
    """Outcome reported by a worker process.

    Member values are assigned by ``auto()`` in declaration order.
    """

    SUCCESS = auto()
    ERROR = auto()
    CANCELLED = auto()
116
+
117
+
118
class LoadStatus(Enum):
    """Outcome of a cache load operation.

    Member values are assigned by ``auto()`` in declaration order.
    """

    SUCCESS = auto()
    ERROR = auto()
    DOES_NOT_EXIST = auto()
    SKIPPED = auto()
125
+
126
+
127
class SaveStatus(Enum):
    """Outcome of a cache save operation.

    Member values are assigned by ``auto()`` in declaration order.
    """

    SUCCESS = auto()
    ERROR = auto()
    SKIPPED = auto()
133
+
134
+
135
class TransferStatus(Enum):
    """Outcome of a generic transfer operation.

    Member values are assigned by ``auto()`` in declaration order.
    """

    SUCCESS = auto()
    ERROR = auto()
    INTERRUPTED = auto()
141
+
142
+
143
class AsyncTransferStatus(Enum):
    """Status values for asynchronous transfer operations.

    Compared with ``TransferStatus`` this adds the ``NOT_STARTED``,
    ``IN_PROGRESS`` and ``CANCELLED`` states. Member values are assigned by
    ``auto()`` in declaration order.
    """

    NOT_STARTED = auto()
    IN_PROGRESS = auto()
    SUCCESS = auto()
    ERROR = auto()
    INTERRUPTED = auto()
    CANCELLED = auto()
b10_transfer/core.py ADDED
@@ -0,0 +1,160 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ from .cleanup import cooperative_cleanup_b10fs
5
+ from .utils import (
6
+ timed_fn,
7
+ safe_execute,
8
+ cache_operation,
9
+ )
10
+ from .space_monitor import (
11
+ check_sufficient_disk_space,
12
+ CacheSpaceMonitor,
13
+ CacheOperationInterrupted,
14
+ run_monitored_process,
15
+ )
16
+ from .constants import (
17
+ B10FS_CACHE_DIR,
18
+ LOCAL_WORK_DIR,
19
+ REQUIRED_B10FS_SPACE_MB,
20
+ MIN_LOCAL_SPACE_MB,
21
+ TransferStatus,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
@timed_fn(logger=logger, name="Generic transfer operation")
@safe_execute("Transfer failed", TransferStatus.ERROR)
def transfer(
    source: Path,
    dest: Path,
    callback: callable,
    *callback_args,
    monitor_local: bool = True,
    monitor_b10fs: bool = True,
    **callback_kwargs,
) -> TransferStatus:
    """Run ``callback(source, dest, ...)`` wrapped in disk-space safeguards.

    The transfer itself is entirely delegated to ``callback``; this function
    only adds the surrounding housekeeping:

    1. Cooperative cleanup of stale shared files on b10fs.
    2. Up-front free-space checks on the local work dir and/or b10fs.
    3. Space monitoring while the callback runs. Only ONE monitor is started:
       b10fs when it is involved and ``monitor_b10fs`` is true, otherwise the
       local work dir when ``monitor_local`` is true.

    Args:
        source: Source path handed to ``callback``.
        dest: Destination path handed to ``callback``.
        callback: Callable performing the actual transfer; invoked as
            ``callback(source, dest, *callback_args, **callback_kwargs)``.
        *callback_args: Extra positional arguments for ``callback``.
        monitor_local: Whether to check/monitor local disk space.
        monitor_b10fs: Whether to check/monitor b10fs disk space.
        **callback_kwargs: Extra keyword arguments for ``callback``.

    Returns:
        TransferStatus.SUCCESS when the callback finished,
        TransferStatus.INTERRUPTED when ``CacheOperationInterrupted`` was
        raised while the monitored callback ran, and TransferStatus.ERROR as
        the ``safe_execute`` default (presumably returned when any other
        exception escapes -- confirm against ``utils.safe_execute``).
    """
    # Function-scope import keeps this block self-contained.
    import functools

    with cache_operation("Transfer"):
        # Help keep the shared filesystem tidy before using it.
        cooperative_cleanup_b10fs()

        b10fs_dir = Path(B10FS_CACHE_DIR)
        work_dir = Path(LOCAL_WORK_DIR)

        # Classify both endpoints by plain string-prefix comparison.
        b10fs_prefix = str(b10fs_dir)
        work_prefix = str(work_dir)
        source_str, dest_str = str(source), str(dest)
        source_on_b10fs = source_str.startswith(b10fs_prefix)
        dest_on_b10fs = dest_str.startswith(b10fs_prefix)

        # b10fs is involved when either endpoint lives under it.
        b10fs_path = b10fs_dir if (source_on_b10fs or dest_on_b10fs) else None

        # Local disk is involved when an endpoint is under the work dir or
        # when at least one endpoint is not on b10fs.
        uses_local_disk = (
            source_str.startswith(work_prefix)
            or dest_str.startswith(work_prefix)
            or not source_on_b10fs
            or not dest_on_b10fs
        )
        local_path = work_dir if uses_local_disk else None

        # Initial disk space checks
        if monitor_local and local_path:
            check_sufficient_disk_space(
                local_path, MIN_LOCAL_SPACE_MB, "local transfer operations"
            )
            logger.debug(
                f"Initial local space check passed: {MIN_LOCAL_SPACE_MB:.1f}MB required"
            )

        if monitor_b10fs and b10fs_path:
            check_sufficient_disk_space(
                b10fs_path, REQUIRED_B10FS_SPACE_MB, "b10fs transfer operations"
            )
            logger.debug(
                f"Initial b10fs space check passed: {REQUIRED_B10FS_SPACE_MB:.1f}MB required"
            )

        # Pick the single monitor to run (b10fs takes priority over local).
        primary_monitor = None
        if monitor_b10fs and b10fs_path:
            primary_monitor = CacheSpaceMonitor(REQUIRED_B10FS_SPACE_MB, b10fs_path)
        elif monitor_local and local_path:
            primary_monitor = CacheSpaceMonitor(MIN_LOCAL_SPACE_MB, local_path)

        if primary_monitor is None:
            # No monitoring requested, execute callback directly
            logger.info(f"Starting transfer (no monitoring): {source} -> {dest}")
            callback(source, dest, *callback_args, **callback_kwargs)
            logger.info("Transfer complete")
            return TransferStatus.SUCCESS

        primary_monitor.start()

        try:
            logger.info(f"Starting monitored transfer: {source} -> {dest}")

            # run_monitored_process only receives a positional-args tuple, so
            # keyword arguments are bound up front; otherwise they would be
            # silently dropped on this path. The callback is wrapped only
            # when kwargs exist, so kwargs-free calls behave exactly as before.
            monitored_callback = (
                functools.partial(callback, **callback_kwargs)
                if callback_kwargs
                else callback
            )

            try:
                run_monitored_process(
                    monitored_callback,
                    (source, dest, *callback_args),
                    primary_monitor,
                    "transfer callback",
                )
                logger.info("Transfer complete (monitored)")
                return TransferStatus.SUCCESS

            except (TypeError, AttributeError, ImportError, OSError) as e:
                # Treated as "callback cannot be run via the monitored
                # process" (e.g. not pickleable).
                # NOTE(review): these exception types could also come from
                # the callback itself, in which case it is executed a second
                # time below -- confirm run_monitored_process semantics.
                logger.warning(
                    f"Callback not suitable for process isolation, running without monitoring: {e}"
                )

                # Fallback to direct execution without process isolation
                callback(source, dest, *callback_args, **callback_kwargs)
                logger.info("Transfer complete (unmonitored)")
                return TransferStatus.SUCCESS

        except CacheOperationInterrupted as e:
            logger.warning(f"Transfer interrupted: {e}")
            return TransferStatus.INTERRUPTED

        finally:
            # Always stop the monitor, whatever the outcome.
            primary_monitor.stop()
@@ -0,0 +1,134 @@
1
+ """Environment detection utilities for GPU cache management.
2
+
3
+ This module provides functions to generate unique environment keys based on
4
+ GPU hardware and driver information for cache compatibility.
5
+ """
6
+
7
import hashlib
import json
import logging
import os
from typing import Any
11
+
12
+ # Optional imports - may not be available in all environments
13
+ try:
14
+ import torch
15
+
16
+ TORCH_AVAILABLE = True
17
+ except ImportError:
18
+ torch = None
19
+ TORCH_AVAILABLE = False
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ KEY_LENGTH = 16
24
+ UNKNOWN_HOSTNAME = "unknown-host"
25
+
26
+
27
def get_cache_filename() -> str:
    """Build the cache filename prefix for the current environment and host.

    Combines the environment key with the host name so that cache files are
    specific to both the GPU environment and the machine.

    Returns:
        str: Prefix in the format "cache_{environment_key}.{hostname}".

    Raises:
        RuntimeError: Propagated from ``get_environment_key()``.
    """
    env_key = get_environment_key()
    # Chain with ``or`` so that a HOSTNAME variable that is set but empty
    # still falls through to UNKNOWN_HOSTNAME instead of producing a
    # filename that ends in ".".
    hostname = os.uname().nodename or os.getenv("HOSTNAME") or UNKNOWN_HOSTNAME
    return f"cache_{env_key}.{hostname}"
40
+
41
+
42
def get_environment_key() -> str:
    """Derive a deterministic key from the GPU model and CUDA version.

    The key is the first ``KEY_LENGTH`` hex characters of the SHA-256 digest
    of the JSON-serialised (sorted keys) properties returned by
    ``_extract_gpu_properties`` for the current CUDA device.

    Returns:
        str: A 16-character hex hash identifying the environment.

    Raises:
        RuntimeError: If PyTorch/CUDA are unavailable or key generation fails
            for any other reason.

    Note:
        Properties included in the hash (PyTorch repository references):
        - Device name: GPU model identification (codecache.py:199)
        - CUDA version: Driver compatibility (codecache.py:200)

        Deliberately excluded for now (can be added back if needed):
        - Compute capability (scheduler.py:4286, triton_heuristics.py:480)
        - Multi-processor count (choices.py:210, triton_heuristics.py:539)
        - Warp size (triton_heuristics.py:487, triton.py:2763)
        - Register limits (triton_heuristics.py:518,536)
        They are not explicitly used in the torchinductor_root cache check, a
        mismatch likely only causes local re-compilation when torch guards
        activate, and over-splitting environments would raise the cache-miss
        rate.

        The torch and triton versions are also excluded, although the
        compilation cache depends on them: the cache is stored under
        `/cache/model`, which is already deployment-specific, so those
        versions are constant there.
    """
    try:
        _validate_cuda_environment()

        current_device = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(current_device)
        gpu_info = _extract_gpu_properties(props, torch.version.cuda)

        # sort_keys keeps the serialisation (and thus the hash) stable.
        serialized = json.dumps(gpu_info, sort_keys=True)
        digest = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
        return digest[:KEY_LENGTH]

    except (ImportError, RuntimeError, AssertionError) as e:
        # Expected failure modes: no torch, no CUDA, or a CUDA runtime error.
        logger.error(f"GPU environment unavailable: {e}")
        raise RuntimeError(f"Cannot generate environment key: {e}") from e
    except Exception as e:
        # Anything else is still surfaced to callers as RuntimeError.
        logger.error(f"Unexpected error during environment key generation: {e}")
        raise RuntimeError(f"Environment key generation failed: {e}") from e
94
+
95
+
96
def _validate_cuda_environment() -> None:
    """Ensure PyTorch is importable and reports a usable CUDA runtime.

    Raises:
        ImportError: If PyTorch could not be imported.
        RuntimeError: If CUDA is unavailable or ``torch.version.cuda`` is None.
    """
    # torch is None when the optional import failed, so check this first.
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch not available")

    cuda_usable = torch.cuda.is_available()
    if not cuda_usable:
        raise RuntimeError("CUDA must be available - AMD/HIP not supported")

    cuda_version = torch.version.cuda
    if cuda_version is None:
        raise RuntimeError("CUDA version must be available")
111
+
112
+
113
+ def _extract_gpu_properties(
114
+ device_properties: any, cuda_version: str
115
+ ) -> dict[str, any]:
116
+ """Extract relevant GPU properties for environment key generation.
117
+
118
+ Args:
119
+ device_properties: CUDA device properties object
120
+ cuda_version: CUDA version string
121
+ SEE docstring of get_environment_key() for more details and why certain properties are excluded.
122
+
123
+ Returns:
124
+ Dict containing GPU properties that affect kernel generation
125
+ """
126
+ return {
127
+ "device_name": device_properties.name, # GPU model
128
+ "cuda_version": cuda_version, # Driver version
129
+ # "compute_capability": (device_properties.major, device_properties.minor), # GPU features
130
+ # "multi_processor_count": device_properties.multi_processor_count, # SM count for occupancy
131
+ # "warp_size": device_properties.warp_size, # Thread grouping size
132
+ # "regs_per_multiprocessor": getattr(device_properties, "regs_per_multiprocessor", None), # Register limits
133
+ # "max_threads_per_multi_processor": getattr(device_properties, "max_threads_per_multi_processor", None), # Thread limits
134
+ }