comfy-env 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comfy_env/__init__.py +161 -0
- comfy_env/cli.py +388 -0
- comfy_env/decorator.py +422 -0
- comfy_env/env/__init__.py +30 -0
- comfy_env/env/config.py +144 -0
- comfy_env/env/config_file.py +592 -0
- comfy_env/env/detection.py +176 -0
- comfy_env/env/manager.py +657 -0
- comfy_env/env/platform/__init__.py +21 -0
- comfy_env/env/platform/base.py +96 -0
- comfy_env/env/platform/darwin.py +53 -0
- comfy_env/env/platform/linux.py +68 -0
- comfy_env/env/platform/windows.py +377 -0
- comfy_env/env/security.py +267 -0
- comfy_env/errors.py +325 -0
- comfy_env/install.py +539 -0
- comfy_env/ipc/__init__.py +55 -0
- comfy_env/ipc/bridge.py +512 -0
- comfy_env/ipc/protocol.py +265 -0
- comfy_env/ipc/tensor.py +371 -0
- comfy_env/ipc/torch_bridge.py +401 -0
- comfy_env/ipc/transport.py +318 -0
- comfy_env/ipc/worker.py +221 -0
- comfy_env/registry.py +252 -0
- comfy_env/resolver.py +399 -0
- comfy_env/runner.py +273 -0
- comfy_env/stubs/__init__.py +1 -0
- comfy_env/stubs/folder_paths.py +57 -0
- comfy_env/workers/__init__.py +49 -0
- comfy_env/workers/base.py +82 -0
- comfy_env/workers/pool.py +241 -0
- comfy_env/workers/tensor_utils.py +188 -0
- comfy_env/workers/torch_mp.py +375 -0
- comfy_env/workers/venv.py +903 -0
- comfy_env-0.0.8.dist-info/METADATA +228 -0
- comfy_env-0.0.8.dist-info/RECORD +39 -0
- comfy_env-0.0.8.dist-info/WHEEL +4 -0
- comfy_env-0.0.8.dist-info/entry_points.txt +2 -0
- comfy_env-0.0.8.dist-info/licenses/LICENSE +21 -0
comfy_env/registry.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Built-in registry of CUDA packages and their wheel sources.
|
|
2
|
+
|
|
3
|
+
This module provides a mapping of well-known CUDA packages to their
|
|
4
|
+
installation sources, eliminating the need for users to specify
|
|
5
|
+
wheel_sources in their comfyui_env.toml.
|
|
6
|
+
|
|
7
|
+
Install method types:
|
|
8
|
+
- "index": Use pip --extra-index-url (PEP 503 simple repository)
|
|
9
|
+
- "github_index": GitHub Pages index (--find-links)
|
|
10
|
+
- "find_links": Use pip --find-links (for PyG, etc.)
|
|
11
|
+
- "pypi_variant": Package name varies by CUDA version (e.g., spconv-cu124)
|
|
12
|
+
- "github_release": Direct wheel URL from GitHub releases with fallback sources
- "pypi": Plain PyPI install with no extra index or wheel source (e.g., triton)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from typing import Dict, Any, Optional
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_cuda_short2(cuda_version: str) -> str:
    """Collapse a CUDA version string to the major+minor digits spconv uses.

    spconv package names use "cu124" (not "cu1240") for CUDA 12.4. Any patch
    component is ignored; a version with no "." gets a "0" minor appended.

    Args:
        cuda_version: CUDA version string (e.g., "12.4", "12.8")

    Returns:
        Short format string (e.g., "124", "128")

    Examples:
        >>> get_cuda_short2("12.4")
        '124'
        >>> get_cuda_short2("12.8")
        '128'
        >>> get_cuda_short2("11.8")
        '118'
    """
    major, separator, remainder = cuda_version.partition(".")
    if not separator:
        # No minor component at all: treat "12" as "12.0".
        return major + "0"
    # Keep only the minor component; "4.1" (from "12.4.1") -> "4".
    minor = remainder.split(".")[0]
    return major + minor
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# =============================================================================
# Package Registry
# =============================================================================
# Maps package names to their installation configuration.
#
# Template variables available:
#   {cuda_version}    - Full CUDA version (e.g., "12.8")
#   {cuda_short}      - CUDA without dot (e.g., "128")
#   {cuda_short2}     - CUDA short for spconv (e.g., "124" not "1240")
#   {cuda_major}      - CUDA major version only (e.g., "12")
#   {torch_version}   - Full PyTorch version (e.g., "2.8.0")
#   {torch_short}     - PyTorch without dots (e.g., "280")
#   {torch_mm}        - PyTorch major.minor (e.g., "28")
#   {torch_dotted_mm} - PyTorch major.minor with dot (e.g., "2.8")
#   {py_version}      - Python version (e.g., "3.10")
#   {py_short}        - Python without dot (e.g., "310")
#   {py_minor}        - Python minor version only (e.g., "10")
#   {py_tag}          - CPython wheel tag (e.g., "cp310")
#   {platform}        - Platform tag (e.g., "linux_x86_64")
#   {version}         - Requested package version (e.g., "2.8.3")
# =============================================================================

PACKAGE_REGISTRY: Dict[str, Dict[str, Any]] = {
    # =========================================================================
    # PyTorch Geometric (PyG) packages - official index
    # https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
    # Uses --find-links (not --extra-index-url) for proper wheel discovery
    # =========================================================================
    "torch-scatter": {
        "method": "find_links",
        "index_url": "https://data.pyg.org/whl/torch-{torch_version}+cu{cuda_short}.html",
        "description": "Scatter operations for PyTorch",
    },
    "torch-cluster": {
        "method": "find_links",
        "index_url": "https://data.pyg.org/whl/torch-{torch_version}+cu{cuda_short}.html",
        "description": "Clustering algorithms for PyTorch",
    },
    "torch-sparse": {
        "method": "find_links",
        "index_url": "https://data.pyg.org/whl/torch-{torch_version}+cu{cuda_short}.html",
        "description": "Sparse tensor operations for PyTorch",
    },
    "torch-spline-conv": {
        "method": "find_links",
        "index_url": "https://data.pyg.org/whl/torch-{torch_version}+cu{cuda_short}.html",
        "description": "Spline convolutions for PyTorch",
    },

    # =========================================================================
    # pytorch3d - Facebook's official wheels
    # https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md
    # The URL is a flat download.html page of wheel links (INSTALL.md uses
    # `pip install ... -f <url>`), not a PEP 503 simple index, so it must be
    # passed with --find-links just like the PyG pages above.
    # =========================================================================
    "pytorch3d": {
        "method": "find_links",
        "index_url": "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py3{py_minor}_cu{cuda_short}_pyt{torch_short}/download.html",
        "description": "PyTorch3D - 3D deep learning library",
    },

    # =========================================================================
    # PozzettiAndrea wheel repos (GitHub Pages indexes)
    # =========================================================================
    # nvdiffrast - wheels are now at cu{cuda}-torch{torch_short} releases
    "nvdiffrast": {
        "method": "github_index",
        "index_url": "https://pozzettiandrea.github.io/nvdiffrast-full-wheels/cu{cuda_short}-torch{torch_short}/",
        "description": "NVIDIA differentiable rasterizer",
    },
    # cumesh, o_voxel, flex_gemm, nvdiffrec_render use torch_short (3 digits: 280)
    "cumesh": {
        "method": "github_index",
        "index_url": "https://pozzettiandrea.github.io/cumesh-wheels/cu{cuda_short}-torch{torch_short}/",
        "description": "CUDA-accelerated mesh utilities",
    },
    "o_voxel": {
        "method": "github_index",
        "index_url": "https://pozzettiandrea.github.io/ovoxel-wheels/cu{cuda_short}-torch{torch_short}/",
        "description": "O-Voxel CUDA extension for TRELLIS",
    },
    "flex_gemm": {
        "method": "github_index",
        "index_url": "https://pozzettiandrea.github.io/flexgemm-wheels/cu{cuda_short}-torch{torch_short}/",
        "description": "Flexible GEMM operations",
    },
    "nvdiffrec_render": {
        "method": "github_release",
        "sources": [
            {
                "name": "PozzettiAndrea",
                "url_template": "https://github.com/PozzettiAndrea/nvdiffrec_render-wheels/releases/download/cu{cuda_short}-torch{torch_short}/nvdiffrec_render-{version}%2Bcu{cuda_short}torch{torch_mm}-{py_tag}-{py_tag}-linux_x86_64.whl",
                "platforms": ["linux_x86_64"],
            },
            {
                "name": "PozzettiAndrea-windows",
                "url_template": "https://github.com/PozzettiAndrea/nvdiffrec_render-wheels/releases/download/cu{cuda_short}-torch{torch_short}/nvdiffrec_render-{version}%2Bcu{cuda_short}torch{torch_mm}-{py_tag}-{py_tag}-win_amd64.whl",
                "platforms": ["win_amd64", "windows_amd64"],
            },
        ],
        "description": "NVDiffRec rendering utilities",
    },

    # =========================================================================
    # spconv - PyPI with CUDA-versioned package names
    # Package names: spconv-cu118, spconv-cu120, spconv-cu121, spconv-cu124, spconv-cu126
    # Note: Max available is cu126 as of Jan 2026, use explicit spconv-cu126 in config
    # =========================================================================
    "spconv": {
        "method": "pypi_variant",
        "package_template": "spconv-cu{cuda_short2}",
        "description": "Sparse convolution library (use spconv-cu126 for CUDA 12.6+)",
    },

    # =========================================================================
    # sageattention - Fast quantized attention (2-5x faster than FlashAttention)
    # Linux: Prebuilt wheels from Kijai/PrecompiledWheels (v2.2.0, cp312)
    # Windows: Prebuilt wheels from woct0rdho (v2.2.0, cp39-abi3)
    # =========================================================================
    "sageattention": {
        "method": "github_release",
        "sources": [
            # Linux: Kijai's precompiled wheels on HuggingFace.
            # NOTE: the cp312 tag is hard-coded, so this source only matches
            # Python 3.12 interpreters.
            {
                "name": "kijai-hf",
                "url_template": "https://huggingface.co/Kijai/PrecompiledWheels/resolve/main/sageattention-{version}-cp312-cp312-linux_x86_64.whl",
                "platforms": ["linux_x86_64"],
            },
            # Windows: woct0rdho prebuilt wheels (ABI3: Python >= 3.9)
            # Format: sageattention-2.2.0+cu128torch2.8.0.post3-cp39-abi3-win_amd64.whl
            # NOTE: release and version (2.2.0 / post3) are pinned in the URL;
            # {version} is not used by this template.
            {
                "name": "woct0rdho",
                "url_template": "https://github.com/woct0rdho/SageAttention/releases/download/v2.2.0-windows.post3/sageattention-2.2.0%2Bcu{cuda_short}torch{torch_version}.post3-cp39-abi3-win_amd64.whl",
                "platforms": ["win_amd64"],
            },
        ],
        "description": "SageAttention - 2-5x faster than FlashAttention with quantized kernels",
    },

    # =========================================================================
    # triton - Required for sageattention on Linux (usually bundled with PyTorch)
    # =========================================================================
    "triton": {
        "method": "pypi",
        "description": "Triton compiler for custom CUDA kernels (required by sageattention)",
    },

    # =========================================================================
    # flash-attn - Multi-source prebuilt wheels
    # Required for UniRig and other transformer-based models
    # Sources: Dao-AILab (official), mjun0812 (Linux), bdashore3 (Windows)
    # =========================================================================
    "flash-attn": {
        "method": "github_release",
        "sources": [
            # Linux: Dao-AILab official wheels (CUDA 12.x, PyTorch 2.4-2.8)
            # Format: flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
            {
                "name": "Dao-AILab",
                "url_template": "https://github.com/Dao-AILab/flash-attention/releases/download/v{version}/flash_attn-{version}%2Bcu{cuda_major}torch{torch_dotted_mm}cxx11abiTRUE-{py_tag}-{py_tag}-linux_x86_64.whl",
                "platforms": ["linux_x86_64"],
            },
            # Linux: mjun0812 prebuilt wheels (CUDA 12.4-13.0, PyTorch 2.5-2.9)
            # Format: flash_attn-2.8.3+cu128torch2.8-cp310-cp310-linux_x86_64.whl
            # Note: Release v0.7.2 contains multiple flash_attn versions
            {
                "name": "mjun0812",
                "url_template": "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.2/flash_attn-{version}%2Bcu{cuda_short}torch{torch_dotted_mm}-{py_tag}-{py_tag}-linux_x86_64.whl",
                "platforms": ["linux_x86_64"],
            },
            # Windows: bdashore3 prebuilt wheels (CUDA 12.4/12.8, PyTorch 2.6-2.8)
            {
                "name": "bdashore3",
                "url_template": "https://github.com/bdashore3/flash-attention/releases/download/v{version}/flash_attn-{version}%2Bcu{cuda_short}torch{torch_version}cxx11abiFALSE-{py_tag}-{py_tag}-win_amd64.whl",
                "platforms": ["win_amd64"],
            },
        ],
        "description": "Flash Attention for fast transformer inference",
    },
}
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def get_package_info(package: str) -> Optional[Dict[str, Any]]:
    """Get registry info for a package.

    Lookup is case-insensitive. Like pip (PEP 503 name normalization), runs of
    "-", "_" and "." are treated as equivalent, so "flash_attn" finds the
    "flash-attn" entry and "o-voxel" finds "o_voxel".

    Args:
        package: Package name (case-insensitive; surrounding whitespace ignored)

    Returns:
        Registry entry dict or None if not found
    """
    import re

    key = package.strip().lower()

    # Fast path: the name exactly as spelled in the registry (lower-cased).
    info = PACKAGE_REGISTRY.get(key)
    if info is not None:
        return info

    # Registry keys mix separator styles ("flash-attn" vs "o_voxel"), so fall
    # back to comparing PEP 503-normalized names.
    def _canonical(name: str) -> str:
        return re.sub(r"[-_.]+", "-", name).lower()

    wanted = _canonical(key)
    for name, entry in PACKAGE_REGISTRY.items():
        if _canonical(name) == wanted:
            return entry
    return None
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def list_packages() -> Dict[str, str]:
    """Return every registered package name mapped to its description.

    Entries without a "description" key are reported as "No description".

    Returns:
        Dict mapping package name to description, in registry order
    """
    descriptions: Dict[str, str] = {}
    for pkg_name, entry in PACKAGE_REGISTRY.items():
        descriptions[pkg_name] = entry.get("description", "No description")
    return descriptions
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def is_registered(package: str) -> bool:
    """Check if a package is in the registry.

    Uses exactly the same matching rules as ``get_package_info`` so the two
    helpers can never disagree about whether a name is known.

    Args:
        package: Package name (case-insensitive)

    Returns:
        True if package is registered
    """
    # Registry entries are always dicts (never None), so "found" is
    # equivalent to a non-None lookup result.
    return get_package_info(package) is not None
|
comfy_env/resolver.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Wheel URL resolver for CUDA-compiled packages.
|
|
3
|
+
|
|
4
|
+
This module provides deterministic wheel URL construction based on the runtime
|
|
5
|
+
environment (CUDA version, PyTorch version, Python version, platform).
|
|
6
|
+
|
|
7
|
+
Unlike pip's constraint solver, this module constructs exact URLs from templates
|
|
8
|
+
and validates that they exist. If a wheel doesn't exist, it fails fast with
|
|
9
|
+
a clear error message.
|
|
10
|
+
|
|
11
|
+
Example:
|
|
12
|
+
from comfy_env.resolver import WheelResolver, RuntimeEnv
|
|
13
|
+
|
|
14
|
+
env = RuntimeEnv.detect()
|
|
15
|
+
resolver = WheelResolver()
|
|
16
|
+
|
|
17
|
+
url = resolver.resolve("nvdiffrast", version="0.4.0", env=env)
|
|
18
|
+
# Returns: https://github.com/.../nvdiffrast-0.4.0+cu128torch28-cp310-...-linux_x86_64.whl
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import platform
|
|
22
|
+
import re
|
|
23
|
+
import sys
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from typing import Dict, List, Optional, Tuple
|
|
27
|
+
from urllib.parse import urlparse
|
|
28
|
+
|
|
29
|
+
from .env.detection import detect_cuda_version, detect_gpu_info
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class RuntimeEnv:
    """
    Detected runtime environment for wheel resolution.

    Contains all variables needed for wheel URL template expansion; see
    ``as_dict`` for the exact placeholder names it provides.
    """
    # OS/Platform
    os_name: str  # linux, windows, darwin
    platform_tag: str  # linux_x86_64, win_amd64, macosx_...

    # Python
    python_version: str  # 3.10, 3.11, 3.12
    python_short: str  # 310, 311, 312

    # CUDA
    cuda_version: Optional[str]  # 12.8, 12.4, None
    cuda_short: Optional[str]  # 128, 124, None

    # PyTorch (detected or configured)
    torch_version: Optional[str]  # 2.8.0, 2.5.1
    torch_short: Optional[str]  # 280, 251
    torch_mm: Optional[str]  # 28, 25 (major.minor without dot)

    # GPU info
    gpu_name: Optional[str] = None
    gpu_compute: Optional[str] = None  # sm_89, sm_100

    @classmethod
    def detect(cls, torch_version: Optional[str] = None) -> "RuntimeEnv":
        """
        Detect runtime environment from current system.

        Args:
            torch_version: Optional PyTorch version override. If not provided,
                attempts to detect from installed torch.

        Returns:
            RuntimeEnv with detected values.
        """
        # OS detection: normalize sys.platform to a short name. "darwin" and
        # unrecognized platforms are passed through unchanged.
        os_name = sys.platform
        if os_name.startswith('linux'):
            os_name = 'linux'
        elif os_name == 'win32':
            os_name = 'windows'

        # Platform tag
        platform_tag = _get_platform_tag()

        # Python version
        py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
        py_short = f"{sys.version_info.major}{sys.version_info.minor}"

        # CUDA version. Wheel tags encode only major+minor ("cu124"), so any
        # patch component is dropped: "12.4.1" -> "124", not "1241".
        cuda_version = detect_cuda_version()
        cuda_short = "".join(cuda_version.split(".")[:2]) if cuda_version else None

        # PyTorch version
        if torch_version is None:
            torch_version = _detect_torch_version()

        torch_short = None
        torch_mm = None
        if torch_version:
            torch_short = torch_version.replace(".", "")
            torch_mm = "".join(torch_version.split(".")[:2])

        # GPU info is best-effort: any detection failure leaves both fields None.
        gpu_name = None
        gpu_compute = None
        try:
            gpu_info = detect_gpu_info()
            if gpu_info:
                gpu_name = gpu_info.get("name")
                gpu_compute = gpu_info.get("compute_capability")
        except Exception:
            pass

        return cls(
            os_name=os_name,
            platform_tag=platform_tag,
            python_version=py_version,
            python_short=py_short,
            cuda_version=cuda_version,
            cuda_short=cuda_short,
            torch_version=torch_version,
            torch_short=torch_short,
            torch_mm=torch_mm,
            gpu_name=gpu_name,
            gpu_compute=gpu_compute,
        )

    def as_dict(self) -> Dict[str, str]:
        """Convert to dict for template substitution.

        Always provides: os, platform, python_version, py_version, py_short,
        py_minor, py_tag. Adds cuda_version, cuda_short, cuda_major and
        cuda_short2 only when a CUDA version is set, and torch_version,
        torch_short, torch_mm and torch_dotted_mm only when a PyTorch
        version is set.
        """
        # py_minor is the second component of python_version ("3.10" -> "10").
        py_parts = self.python_version.split(".") if self.python_version else []
        if len(py_parts) > 1:
            py_minor = py_parts[1]
        else:
            py_minor = py_parts[-1] if py_parts else ""

        result = {
            "os": self.os_name,
            "platform": self.platform_tag,
            "python_version": self.python_version,
            "py_version": self.python_version,
            "py_short": self.python_short,
            "py_minor": py_minor,
            # CPython wheel ABI/interpreter tag, e.g. "cp310".
            "py_tag": f"cp{self.python_short}",
        }

        if self.cuda_version:
            cuda_parts = self.cuda_version.split(".")
            cuda_major = cuda_parts[0]
            cuda_minor = cuda_parts[1] if len(cuda_parts) > 1 else "0"
            result["cuda_version"] = self.cuda_version
            result["cuda_short"] = self.cuda_short
            # e.g. "12" for flash-attn's "cu12" tags
            result["cuda_major"] = cuda_major
            # e.g. "124" for spconv-cu124 ("12" becomes "120")
            result["cuda_short2"] = f"{cuda_major}{cuda_minor}"

        if self.torch_version:
            result["torch_version"] = self.torch_version
            result["torch_short"] = self.torch_short
            result["torch_mm"] = self.torch_mm
            # torch_dotted_mm: "2.8" format (major.minor with dot) for flash-attn URLs
            result["torch_dotted_mm"] = ".".join(self.torch_version.split(".")[:2])

        return result

    def __str__(self) -> str:
        parts = [
            f"Python {self.python_version}",
            f"CUDA {self.cuda_version}" if self.cuda_version else "CPU",
        ]
        if self.torch_version:
            parts.append(f"PyTorch {self.torch_version}")
        if self.gpu_name:
            parts.append(f"GPU: {self.gpu_name}")
        return ", ".join(parts)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _get_platform_tag() -> str:
|
|
169
|
+
"""Get wheel platform tag for current system."""
|
|
170
|
+
machine = platform.machine().lower()
|
|
171
|
+
|
|
172
|
+
if sys.platform.startswith('linux'):
|
|
173
|
+
# Use manylinux tag
|
|
174
|
+
if machine in ('x86_64', 'amd64'):
|
|
175
|
+
return 'linux_x86_64'
|
|
176
|
+
elif machine == 'aarch64':
|
|
177
|
+
return 'linux_aarch64'
|
|
178
|
+
return f'linux_{machine}'
|
|
179
|
+
|
|
180
|
+
elif sys.platform == 'win32':
|
|
181
|
+
if machine in ('amd64', 'x86_64'):
|
|
182
|
+
return 'win_amd64'
|
|
183
|
+
return 'win32'
|
|
184
|
+
|
|
185
|
+
elif sys.platform == 'darwin':
|
|
186
|
+
# macOS - use generic tag
|
|
187
|
+
if machine == 'arm64':
|
|
188
|
+
return 'macosx_11_0_arm64'
|
|
189
|
+
return 'macosx_10_9_x86_64'
|
|
190
|
+
|
|
191
|
+
return f'{sys.platform}_{machine}'
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _detect_torch_version() -> Optional[str]:
|
|
195
|
+
"""Detect installed PyTorch version."""
|
|
196
|
+
try:
|
|
197
|
+
import torch
|
|
198
|
+
version = torch.__version__
|
|
199
|
+
# Strip CUDA suffix (e.g., "2.8.0+cu128" -> "2.8.0")
|
|
200
|
+
if '+' in version:
|
|
201
|
+
version = version.split('+')[0]
|
|
202
|
+
return version
|
|
203
|
+
except ImportError:
|
|
204
|
+
return None
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@dataclass
class WheelSource:
    """A named wheel location described by a URL template.

    Fields are documented inline below.
    """
    name: str  # human-readable identifier for this source
    url_template: str  # URL containing {placeholders} filled in at resolve time
    packages: List[str] = field(default_factory=list)  # Empty = all packages

    def supports(self, package: str) -> bool:
        """Return True if this source serves ``package`` (case-insensitive).

        A source with an empty ``packages`` list accepts every package.
        """
        if self.packages:
            wanted = package.lower()
            return wanted in {candidate.lower() for candidate in self.packages}
        return True
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# Default wheel sources for common CUDA packages.
# WheelResolver uses these when it is constructed without explicit sources.
# Each url_template is expanded by WheelResolver._substitute with the keys
# from RuntimeEnv.as_dict() plus "package" and "version".
DEFAULT_WHEEL_SOURCES = [
    # nvdiffrast wheels attached to GitHub releases tagged v{version}.
    WheelSource(
        name="nvdiffrast-wheels",
        url_template="https://github.com/PozzettiAndrea/nvdiffrast-full-wheels/releases/download/v{version}/nvdiffrast-{version}%2Bcu{cuda_short}torch{torch_mm}-cp{py_short}-cp{py_short}-{platform}.whl",
        packages=["nvdiffrast"],
    ),
    # Multi-package source: {package} selects the wheel file name.
    # NOTE(review): this source is named "cumesh-wheels" but lists pytorch3d
    # and PyG packages rather than cumesh itself — confirm that this
    # repository really hosts these wheels.
    WheelSource(
        name="cumesh-wheels",
        url_template="https://github.com/PozzettiAndrea/cumesh-wheels/releases/download/v{version}/{package}-{version}%2Bcu{cuda_short}torch{torch_mm}-cp{py_short}-cp{py_short}-{platform}.whl",
        packages=["pytorch3d", "torch-cluster", "torch-scatter", "torch-sparse"],
    ),
]
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class WheelResolver:
    """
    Resolves CUDA wheel URLs from package name and runtime environment.

    Resolution strategy:
    1. Check explicit overrides in config
    2. Try configured wheel sources in order
    3. Fail with actionable error message

    This is NOT a constraint solver. It constructs deterministic URLs
    based on exact version matches.
    """

    def __init__(
        self,
        sources: Optional[List["WheelSource"]] = None,
        overrides: Optional[Dict[str, str]] = None,
    ):
        """
        Initialize resolver.

        Args:
            sources: List of WheelSource configurations. ``None`` or an empty
                list falls back to a private copy of DEFAULT_WHEEL_SOURCES.
            overrides: Package-specific URL overrides (package -> template).
                Package names are matched case-insensitively.
        """
        # Copy the module-level defaults so that mutating self.sources (e.g.
        # appending a source) cannot leak into every other resolver.
        self.sources = sources or list(DEFAULT_WHEEL_SOURCES)
        # resolve() looks overrides up by package.lower(), so normalize keys
        # here; otherwise a key such as "NvDiffRast" could never match.
        self.overrides = {
            name.lower(): template for name, template in (overrides or {}).items()
        }

    def resolve(
        self,
        package: str,
        version: str,
        env: "RuntimeEnv",
        verify: bool = False,
    ) -> str:
        """
        Resolve wheel URL for a package.

        Args:
            package: Package name (e.g., "nvdiffrast").
            version: Package version (e.g., "0.4.0").
            env: Runtime environment for template expansion.
            verify: If True, probe each candidate URL with an HTTP HEAD
                request and only accept one that answers 200.

        Returns:
            Fully resolved wheel URL.

        Raises:
            WheelNotFoundError: If no wheel URL could be constructed or verified.
        """
        from .errors import WheelNotFoundError

        # Prepare template variables
        variables = env.as_dict()
        variables["package"] = package
        variables["version"] = version

        # 1. Check explicit override
        key = package.lower()
        if key in self.overrides:
            url = self._substitute(self.overrides[key], variables)
            if verify and not self._url_exists(url):
                raise WheelNotFoundError(
                    package=package,
                    version=version,
                    env=env,
                    tried_urls=[url],
                    reason="Override URL could not be verified (HEAD request did not return 200)",
                )
            return url

        # 2. Try wheel sources in order; without verification the first
        # source that supports the package wins.
        tried_urls = []
        for source in self.sources:
            if not source.supports(package):
                continue

            url = self._substitute(source.url_template, variables)
            tried_urls.append(url)

            if not verify or self._url_exists(url):
                return url

        # 3. Fail with helpful error
        raise WheelNotFoundError(
            package=package,
            version=version,
            env=env,
            tried_urls=tried_urls,
            reason="No wheel source found for package",
        )

    def resolve_all(
        self,
        packages: Dict[str, str],
        env: "RuntimeEnv",
        verify: bool = False,
    ) -> Dict[str, str]:
        """
        Resolve URLs for multiple packages.

        Args:
            packages: Dict of package -> version.
            env: Runtime environment.
            verify: Whether to verify URLs exist.

        Returns:
            Dict of package -> resolved URL.

        Raises:
            WheelNotFoundError: If any package cannot be resolved.
        """
        return {
            package: self.resolve(package, version, env, verify=verify)
            for package, version in packages.items()
        }

    def _substitute(self, template: str, variables: Dict[str, str]) -> str:
        """
        Substitute variables into URL template.

        Every ``{key}`` placeholder whose key is present in ``variables`` is
        replaced with ``str(value)``. Placeholders for missing variables are
        left as-is (caller should validate).
        """
        result = template
        for key, value in variables.items():
            result = result.replace(f"{{{key}}}", str(value))
        return result

    def _url_exists(self, url: str, timeout: float = 10.0) -> bool:
        """
        Check if a URL exists using HTTP HEAD request.

        Returns True only if the URL answers 200 OK. This is a best-effort
        probe: any failure (HTTP error, timeout, malformed URL) yields False.
        """
        try:
            import urllib.request
            request = urllib.request.Request(url, method='HEAD')
            with urllib.request.urlopen(request, timeout=timeout) as response:
                return response.status == 200
        except Exception:
            return False
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def parse_wheel_requirement(req: str) -> Tuple[str, Optional[str]]:
    """
    Parse a wheel requirement string.

    The package name is everything before the left-most version operator.
    The version is the text after that operator, cut at the first "," so only
    the first specifier's version is returned. PEP 508 environment markers
    (anything after ";") are ignored.

    Examples:
        "nvdiffrast==0.4.0" -> ("nvdiffrast", "0.4.0")
        "pytorch3d>=0.7.8" -> ("pytorch3d", "0.7.8")
        "torch-cluster" -> ("torch-cluster", None)
        "pkg<2.0,>=1.0" -> ("pkg", "2.0")

    Returns:
        Tuple of (package_name, version_or_None).
    """
    # Drop environment markers such as '; python_version >= "3.10"'.
    spec = req.split(";", 1)[0]

    # Find the left-most operator. At any given position the alternation
    # tries longer operators first, so "===", "==", ">=" etc. win over ">"/"<".
    match = re.search(r"===|==|>=|<=|~=|!=|>|<", spec)
    if match is None:
        return (spec.strip(), None)

    name = spec[:match.start()].strip()
    version = spec[match.end():].split(",", 1)[0].strip()
    return (name, version)
|