comfy-env 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,176 @@
1
+ """
2
+ GPU detection for automatic CUDA version selection.
3
+
4
+ Detects Blackwell GPUs (RTX 50xx, B100, B200) and selects a CUDA version:
5
+ legacy Pascal-or-older GPUs (compute < 7.5) get CUDA 12.4, while Turing
+ and newer (including Blackwell) get CUDA 12.8.
6
+
7
+ This runs BEFORE PyTorch is installed, so we use nvidia-smi directly.
8
+ """
9
+
10
+ import subprocess
11
+ from typing import List, Dict, Optional
12
+
13
+
14
def detect_gpu_info() -> List[Dict[str, str]]:
    """
    Detect GPU name and compute capability using nvidia-smi.

    Runs ``nvidia-smi --query-gpu=name,compute_cap`` and parses its CSV
    output. Detection is deliberately best-effort: this runs before any
    CUDA/PyTorch install, so every failure mode degrades to "no GPUs".

    Returns:
        List of dicts with 'name' and 'compute_cap' keys.
        Empty list if detection fails (no nvidia-smi, non-zero exit,
        timeout, or any other error).
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except FileNotFoundError:
        # nvidia-smi not found - no NVIDIA GPU or driver not installed
        return []
    except subprocess.TimeoutExpired:
        return []
    except Exception:
        # Deliberate catch-all: GPU probing must never break the caller.
        return []

    if result.returncode != 0:
        return []
    return _parse_gpu_csv(result.stdout)


def _parse_gpu_csv(stdout: str) -> List[Dict[str, str]]:
    """Parse 'name, compute_cap' CSV rows from nvidia-smi output."""
    gpus: List[Dict[str, str]] = []
    for line in stdout.strip().split('\n'):
        line = line.strip()
        if not line:
            continue
        # compute_cap is always the LAST field, so split from the right;
        # this keeps a GPU name intact even if it contains a comma.
        head, sep, tail = line.rpartition(',')
        if sep:
            name, cc = head.strip(), tail.strip()
        else:
            # No comma at all: treat the whole line as the name.
            name, cc = line, "0.0"
        gpus.append({"name": name, "compute_cap": cc or "0.0"})
    return gpus
47
+
48
+
49
def is_blackwell_gpu(name: str, compute_cap: str) -> bool:
    """
    Check if a GPU is Blackwell architecture.

    Args:
        name: GPU name from nvidia-smi
        compute_cap: Compute capability string (e.g., "8.9", "12.0")

    Returns:
        True if Blackwell (requires CUDA 12.8)
    """
    # Known Blackwell name fragments: consumer parts (RTX 50xx, with or
    # without a space), datacenter parts (B100/B200) and die codenames.
    hints = (
        "RTX 50",
        "RTX50",
        "B100",
        "B200",
        "GB202",
        "GB203",
        "GB205",
        "GB206",
        "GB207",
    )
    upper_name = name.upper()
    for hint in hints:
        if hint in upper_name:
            return True

    # Fall back to compute capability: 10.0 and above means Blackwell.
    try:
        return float(compute_cap) >= 10.0
    except (ValueError, TypeError):
        # Unparseable capability and no name match: assume not Blackwell.
        return False
87
+
88
+
89
def needs_cuda_128() -> bool:
    """
    Check if any detected GPU requires CUDA 12.8.

    Returns:
        True if Blackwell GPU detected, False otherwise.
    """
    # Short-circuits on the first Blackwell card found.
    return any(
        is_blackwell_gpu(gpu["name"], gpu["compute_cap"])
        for gpu in detect_gpu_info()
    )
103
+
104
+
105
def is_legacy_gpu(compute_cap: str) -> bool:
    """
    Check if GPU is Pascal or older (requires legacy CUDA/PyTorch).

    Args:
        compute_cap: Compute capability string (e.g., "6.1", "7.5")

    Returns:
        True if Pascal or older (compute < 7.5)
    """
    try:
        parsed = float(compute_cap)
    except (ValueError, TypeError):
        # Unparseable capability: treat the GPU as modern, not legacy.
        return False
    # Turing starts at 7.5; anything below is Pascal or older.
    return parsed < 7.5
120
+
121
+
122
def detect_cuda_version() -> Optional[str]:
    """
    Get recommended CUDA version based on detected GPU.

    Returns:
        "12.4" for Pascal or older (compute < 7.5),
        "12.8" for Turing or newer (compute >= 7.5),
        None if no GPU detected.

    GPU Architecture Reference:
        - Pascal (GTX 10xx, P100): compute 6.0-6.1 → CUDA 12.4
        - Turing (RTX 20xx, T4): compute 7.5 → CUDA 12.8
        - Ampere (RTX 30xx, A100): compute 8.0-8.6 → CUDA 12.8
        - Ada (RTX 40xx, L40): compute 8.9 → CUDA 12.8
        - Hopper (H100): compute 9.0 → CUDA 12.8
        - Blackwell (RTX 50xx, B100/B200): compute 10.0+ → CUDA 12.8
    """
    gpus = detect_gpu_info()
    if not gpus:
        # No GPU visible: leave the choice to the caller.
        return None

    # A single legacy card forces the whole install onto the older stack.
    any_legacy = any(
        is_legacy_gpu(gpu.get("compute_cap", "0.0")) for gpu in gpus
    )
    return "12.4" if any_legacy else "12.8"
150
+
151
+
152
def get_gpu_summary() -> str:
    """
    Get a human-readable summary of detected GPUs.

    Returns:
        Summary string for logging.
    """
    gpus = detect_gpu_info()
    if not gpus:
        return "No NVIDIA GPU detected"

    summary_lines = []
    for idx, gpu in enumerate(gpus):
        cc = gpu.get("compute_cap", "0.0")
        # Tag each card with the CUDA stack it maps to.
        if is_legacy_gpu(cc):
            tag = " [Pascal - CUDA 12.4]"
        elif is_blackwell_gpu(gpu["name"], cc):
            tag = " [Blackwell - CUDA 12.8]"
        else:
            tag = " [CUDA 12.8]"
        summary_lines.append(
            f"  GPU {idx}: {gpu['name']} (sm_{cc.replace('.', '')}){tag}"
        )

    return "\n".join(summary_lines)