comfy-env 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comfy_env/__init__.py +115 -62
- comfy_env/cli.py +89 -319
- comfy_env/config/__init__.py +18 -8
- comfy_env/config/parser.py +21 -122
- comfy_env/config/types.py +37 -70
- comfy_env/detection/__init__.py +77 -0
- comfy_env/detection/cuda.py +61 -0
- comfy_env/detection/gpu.py +230 -0
- comfy_env/detection/platform.py +70 -0
- comfy_env/detection/runtime.py +103 -0
- comfy_env/environment/__init__.py +53 -0
- comfy_env/environment/cache.py +141 -0
- comfy_env/environment/libomp.py +41 -0
- comfy_env/environment/paths.py +38 -0
- comfy_env/environment/setup.py +88 -0
- comfy_env/install.py +163 -249
- comfy_env/isolation/__init__.py +33 -2
- comfy_env/isolation/tensor_utils.py +83 -0
- comfy_env/isolation/workers/__init__.py +16 -0
- comfy_env/{workers → isolation/workers}/mp.py +1 -1
- comfy_env/{workers → isolation/workers}/subprocess.py +2 -2
- comfy_env/isolation/wrap.py +149 -409
- comfy_env/packages/__init__.py +60 -0
- comfy_env/packages/apt.py +36 -0
- comfy_env/packages/cuda_wheels.py +97 -0
- comfy_env/packages/node_dependencies.py +77 -0
- comfy_env/packages/pixi.py +85 -0
- comfy_env/packages/toml_generator.py +88 -0
- comfy_env-0.1.16.dist-info/METADATA +279 -0
- comfy_env-0.1.16.dist-info/RECORD +36 -0
- comfy_env/cache.py +0 -331
- comfy_env/errors.py +0 -293
- comfy_env/nodes.py +0 -187
- comfy_env/pixi/__init__.py +0 -48
- comfy_env/pixi/core.py +0 -588
- comfy_env/pixi/cuda_detection.py +0 -303
- comfy_env/pixi/platform/__init__.py +0 -21
- comfy_env/pixi/platform/base.py +0 -96
- comfy_env/pixi/platform/darwin.py +0 -53
- comfy_env/pixi/platform/linux.py +0 -68
- comfy_env/pixi/platform/windows.py +0 -284
- comfy_env/pixi/resolver.py +0 -198
- comfy_env/prestartup.py +0 -192
- comfy_env/workers/__init__.py +0 -38
- comfy_env/workers/tensor_utils.py +0 -188
- comfy_env-0.1.14.dist-info/METADATA +0 -291
- comfy_env-0.1.14.dist-info/RECORD +0 -33
- /comfy_env/{workers → isolation/workers}/base.py +0 -0
- {comfy_env-0.1.14.dist-info → comfy_env-0.1.16.dist-info}/WHEEL +0 -0
- {comfy_env-0.1.14.dist-info → comfy_env-0.1.16.dist-info}/entry_points.txt +0 -0
- {comfy_env-0.1.14.dist-info → comfy_env-0.1.16.dist-info}/licenses/LICENSE +0 -0
comfy_env/config/parser.py
CHANGED
|
@@ -1,135 +1,42 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
comfy-env.toml is a superset of pixi.toml. Custom sections we handle:
|
|
4
|
-
- python = "3.11" - Python version for isolated envs
|
|
5
|
-
- [cuda] packages = [...] - CUDA packages (triggers find-links + PyTorch detection)
|
|
6
|
-
- [node_reqs] - Other ComfyUI nodes to clone
|
|
7
|
-
|
|
8
|
-
Everything else passes through to pixi.toml directly.
|
|
9
|
-
|
|
10
|
-
Example config:
|
|
11
|
-
|
|
12
|
-
python = "3.11"
|
|
13
|
-
|
|
14
|
-
[cuda]
|
|
15
|
-
packages = ["cumesh"]
|
|
16
|
-
|
|
17
|
-
[dependencies]
|
|
18
|
-
mesalib = "*"
|
|
19
|
-
cgal = "*"
|
|
20
|
-
|
|
21
|
-
[pypi-dependencies]
|
|
22
|
-
numpy = ">=1.21.0,<2"
|
|
23
|
-
trimesh = { version = ">=4.0.0", extras = ["easy"] }
|
|
24
|
-
|
|
25
|
-
[target.linux-64.pypi-dependencies]
|
|
26
|
-
embreex = "*"
|
|
27
|
-
|
|
28
|
-
[node_reqs]
|
|
29
|
-
SomeNode = "owner/repo"
|
|
30
|
-
"""
|
|
1
|
+
"""Configuration parsing for comfy-env."""
|
|
31
2
|
|
|
32
3
|
import copy
|
|
33
|
-
import sys
|
|
34
4
|
from pathlib import Path
|
|
35
|
-
from typing import
|
|
36
|
-
|
|
37
|
-
# Use built-in tomllib (Python 3.11+) or tomli fallback
|
|
38
|
-
if sys.version_info >= (3, 11):
|
|
39
|
-
import tomllib
|
|
40
|
-
else:
|
|
41
|
-
try:
|
|
42
|
-
import tomli as tomllib
|
|
43
|
-
except ImportError:
|
|
44
|
-
tomllib = None # type: ignore
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
45
6
|
|
|
46
|
-
|
|
7
|
+
import tomli
|
|
47
8
|
|
|
9
|
+
from .types import ComfyEnvConfig, NodeDependency
|
|
48
10
|
|
|
49
11
|
CONFIG_FILE_NAME = "comfy-env.toml"
|
|
50
12
|
|
|
51
|
-
# Sections we handle specially (not passed through to pixi.toml)
|
|
52
|
-
CUSTOM_SECTIONS = {"python", "cuda", "node_reqs", "apt", "env_vars"}
|
|
53
|
-
|
|
54
13
|
|
|
55
14
|
def load_config(path: Path) -> ComfyEnvConfig:
    """Read a comfy-env.toml file and return its parsed configuration.

    Args:
        path: Location of the TOML file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # tomli requires the file to be opened in binary mode.
    with open(config_path, "rb") as handle:
        raw = tomli.load(handle)
    return parse_config(raw)
|
|
83
21
|
|
|
84
22
|
|
|
85
23
|
def discover_config(node_dir: Path) -> Optional[ComfyEnvConfig]:
    """Load the config file named ``CONFIG_FILE_NAME`` from ``node_dir``.

    Returns:
        The parsed configuration, or ``None`` when the directory does not
        contain that file.
    """
    candidate = Path(node_dir) / CONFIG_FILE_NAME
    if not candidate.exists():
        return None
    return load_config(candidate)
|
|
103
27
|
|
|
104
28
|
|
|
105
|
-
def
|
|
29
|
+
def parse_config(data: Dict[str, Any]) -> ComfyEnvConfig:
|
|
106
30
|
"""Parse TOML data into ComfyEnvConfig."""
|
|
107
|
-
# Make a copy so we can pop our custom sections
|
|
108
31
|
data = copy.deepcopy(data)
|
|
109
32
|
|
|
110
|
-
# Extract python version (top-level key)
|
|
111
33
|
python_version = data.pop("python", None)
|
|
112
|
-
if python_version
|
|
113
|
-
python_version = str(python_version)
|
|
114
|
-
|
|
115
|
-
# Extract [cuda] section
|
|
116
|
-
cuda_data = data.pop("cuda", {})
|
|
117
|
-
cuda_packages = _ensure_list(cuda_data.get("packages", []))
|
|
118
|
-
|
|
119
|
-
# Extract [apt] section
|
|
120
|
-
apt_data = data.pop("apt", {})
|
|
121
|
-
apt_packages = _ensure_list(apt_data.get("packages", []))
|
|
122
|
-
|
|
123
|
-
# Extract [env_vars] section
|
|
124
|
-
env_vars_data = data.pop("env_vars", {})
|
|
125
|
-
env_vars = {str(k): str(v) for k, v in env_vars_data.items()}
|
|
126
|
-
|
|
127
|
-
# Extract [node_reqs] section
|
|
128
|
-
node_reqs_data = data.pop("node_reqs", {})
|
|
129
|
-
node_reqs = _parse_node_reqs(node_reqs_data)
|
|
34
|
+
python_version = str(python_version) if python_version else None
|
|
130
35
|
|
|
131
|
-
|
|
132
|
-
|
|
36
|
+
cuda_packages = _ensure_list(data.pop("cuda", {}).get("packages", []))
|
|
37
|
+
apt_packages = _ensure_list(data.pop("apt", {}).get("packages", []))
|
|
38
|
+
env_vars = {str(k): str(v) for k, v in data.pop("env_vars", {}).items()}
|
|
39
|
+
node_reqs = _parse_node_reqs(data.pop("node_reqs", {}))
|
|
133
40
|
|
|
134
41
|
return ComfyEnvConfig(
|
|
135
42
|
python=python_version,
|
|
@@ -137,25 +44,17 @@ def _parse_config(data: Dict[str, Any]) -> ComfyEnvConfig:
|
|
|
137
44
|
apt_packages=apt_packages,
|
|
138
45
|
env_vars=env_vars,
|
|
139
46
|
node_reqs=node_reqs,
|
|
140
|
-
pixi_passthrough=
|
|
47
|
+
pixi_passthrough=data,
|
|
141
48
|
)
|
|
142
49
|
|
|
143
50
|
|
|
144
|
-
def _parse_node_reqs(data: Dict[str, Any]) -> List[
|
|
51
|
+
def _parse_node_reqs(data: Dict[str, Any]) -> List[NodeDependency]:
|
|
145
52
|
"""Parse [node_reqs] section."""
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
elif isinstance(value, dict):
|
|
151
|
-
node_reqs.append(NodeReq(name=name, repo=value.get("repo", "")))
|
|
152
|
-
return node_reqs
|
|
53
|
+
return [
|
|
54
|
+
NodeDependency(name=name, repo=value if isinstance(value, str) else value.get("repo", ""))
|
|
55
|
+
for name, value in data.items()
|
|
56
|
+
]
|
|
153
57
|
|
|
154
58
|
|
|
155
59
|
def _ensure_list(value) -> List:
|
|
156
|
-
|
|
157
|
-
if isinstance(value, list):
|
|
158
|
-
return value
|
|
159
|
-
if value:
|
|
160
|
-
return [value]
|
|
161
|
-
return []
|
|
60
|
+
return value if isinstance(value, list) else ([value] if value else [])
|
comfy_env/config/types.py
CHANGED
|
@@ -1,70 +1,37 @@
|
|
|
1
|
-
"""Configuration types for comfy-env."""
|
|
2
|
-
|
|
3
|
-
from dataclasses import dataclass, field
|
|
4
|
-
from typing import Any, Dict, List, Optional
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
@dataclass
|
|
8
|
-
class
|
|
9
|
-
"""A
|
|
10
|
-
name: str
|
|
11
|
-
repo: str #
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
comfy-env.toml
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
cgal = "*"
|
|
39
|
-
|
|
40
|
-
[pypi-dependencies]
|
|
41
|
-
numpy = ">=1.21.0,<2"
|
|
42
|
-
trimesh = { version = ">=4.0.0", extras = ["easy"] }
|
|
43
|
-
|
|
44
|
-
[target.linux-64.pypi-dependencies]
|
|
45
|
-
embreex = "*"
|
|
46
|
-
|
|
47
|
-
[node_reqs]
|
|
48
|
-
SomeNode = "owner/repo"
|
|
49
|
-
"""
|
|
50
|
-
# python = "3.11" - Python version (for isolated envs)
|
|
51
|
-
python: Optional[str] = None
|
|
52
|
-
|
|
53
|
-
# [cuda] - CUDA packages (installed via find-links index)
|
|
54
|
-
cuda_packages: List[str] = field(default_factory=list)
|
|
55
|
-
|
|
56
|
-
# [apt] - System packages to install via apt (Linux only)
|
|
57
|
-
apt_packages: List[str] = field(default_factory=list)
|
|
58
|
-
|
|
59
|
-
# [env_vars] - Environment variables to set early (in prestartup)
|
|
60
|
-
env_vars: Dict[str, str] = field(default_factory=dict)
|
|
61
|
-
|
|
62
|
-
# [node_reqs] - other ComfyUI nodes to clone
|
|
63
|
-
node_reqs: List[NodeReq] = field(default_factory=list)
|
|
64
|
-
|
|
65
|
-
# Everything else from comfy-env.toml passes through to pixi.toml
|
|
66
|
-
pixi_passthrough: Dict[str, Any] = field(default_factory=dict)
|
|
67
|
-
|
|
68
|
-
@property
|
|
69
|
-
def has_cuda(self) -> bool:
|
|
70
|
-
return bool(self.cuda_packages)
|
|
1
|
+
"""Configuration types for comfy-env."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class NodeDependency:
    """A ComfyUI custom node that another node depends on."""

    # Name identifying the dependency.
    name: str
    # Repository, written as "owner/repo" or as a full URL.
    repo: str
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
NodeReq = NodeDependency # Backwards compat
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ComfyEnvConfig:
    """In-memory form of a parsed comfy-env.toml file."""

    # Python version string, or None when not specified.
    python: Optional[str] = None
    # CUDA package names.
    cuda_packages: List[str] = field(default_factory=list)
    # apt package names.
    apt_packages: List[str] = field(default_factory=list)
    # Environment variable names mapped to their values.
    env_vars: Dict[str, str] = field(default_factory=dict)
    # Other ComfyUI nodes this one depends on.
    node_reqs: List[NodeDependency] = field(default_factory=list)
    # Remaining configuration data destined for pixi.
    pixi_passthrough: Dict[str, Any] = field(default_factory=dict)

    @property
    def has_cuda(self) -> bool:
        """True when at least one CUDA package is listed."""
        return True if self.cuda_packages else False

    @property
    def has_dependencies(self) -> bool:
        """True when any package source is non-empty.

        Checks CUDA packages, apt packages, node requirements and the
        ``dependencies`` / ``pypi-dependencies`` entries of the passthrough
        data.  ``python`` and ``env_vars`` are not taken into account.
        """
        if self.cuda_packages or self.apt_packages or self.node_reqs:
            return True
        extra = self.pixi_passthrough
        return bool(extra.get("dependencies") or extra.get("pypi-dependencies"))
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Detection layer - Pure functions for system detection.
|
|
3
|
+
|
|
4
|
+
No side effects. These functions gather information about the runtime environment.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .cuda import (
|
|
8
|
+
CUDA_VERSION_ENV_VAR,
|
|
9
|
+
detect_cuda_version,
|
|
10
|
+
get_cuda_from_torch,
|
|
11
|
+
get_cuda_from_nvml,
|
|
12
|
+
get_cuda_from_nvcc,
|
|
13
|
+
get_cuda_from_env,
|
|
14
|
+
)
|
|
15
|
+
from .gpu import (
|
|
16
|
+
GPUInfo,
|
|
17
|
+
CUDAEnvironment,
|
|
18
|
+
COMPUTE_TO_ARCH,
|
|
19
|
+
detect_gpu,
|
|
20
|
+
detect_gpus,
|
|
21
|
+
detect_cuda_environment,
|
|
22
|
+
get_compute_capability,
|
|
23
|
+
compute_capability_to_architecture,
|
|
24
|
+
get_recommended_cuda_version,
|
|
25
|
+
get_gpu_summary,
|
|
26
|
+
)
|
|
27
|
+
from .platform import (
|
|
28
|
+
PlatformInfo,
|
|
29
|
+
detect_platform,
|
|
30
|
+
get_platform_tag,
|
|
31
|
+
get_pixi_platform,
|
|
32
|
+
get_library_extension,
|
|
33
|
+
get_executable_suffix,
|
|
34
|
+
is_linux,
|
|
35
|
+
is_windows,
|
|
36
|
+
is_macos,
|
|
37
|
+
)
|
|
38
|
+
from .runtime import (
|
|
39
|
+
RuntimeEnv,
|
|
40
|
+
detect_runtime,
|
|
41
|
+
parse_wheel_requirement,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
# CUDA detection
|
|
46
|
+
"CUDA_VERSION_ENV_VAR",
|
|
47
|
+
"detect_cuda_version",
|
|
48
|
+
"get_cuda_from_torch",
|
|
49
|
+
"get_cuda_from_nvml",
|
|
50
|
+
"get_cuda_from_nvcc",
|
|
51
|
+
"get_cuda_from_env",
|
|
52
|
+
# GPU detection
|
|
53
|
+
"GPUInfo",
|
|
54
|
+
"CUDAEnvironment",
|
|
55
|
+
"COMPUTE_TO_ARCH",
|
|
56
|
+
"detect_gpu",
|
|
57
|
+
"detect_gpus",
|
|
58
|
+
"detect_cuda_environment",
|
|
59
|
+
"get_compute_capability",
|
|
60
|
+
"compute_capability_to_architecture",
|
|
61
|
+
"get_recommended_cuda_version",
|
|
62
|
+
"get_gpu_summary",
|
|
63
|
+
# Platform detection
|
|
64
|
+
"PlatformInfo",
|
|
65
|
+
"detect_platform",
|
|
66
|
+
"get_platform_tag",
|
|
67
|
+
"get_pixi_platform",
|
|
68
|
+
"get_library_extension",
|
|
69
|
+
"get_executable_suffix",
|
|
70
|
+
"is_linux",
|
|
71
|
+
"is_windows",
|
|
72
|
+
"is_macos",
|
|
73
|
+
# Runtime detection
|
|
74
|
+
"RuntimeEnv",
|
|
75
|
+
"detect_runtime",
|
|
76
|
+
"parse_wheel_requirement",
|
|
77
|
+
]
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""CUDA version detection. Priority: env -> torch -> nvcc"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
import subprocess
|
|
8
|
+
|
|
9
|
+
CUDA_VERSION_ENV_VAR = "COMFY_ENV_CUDA_VERSION"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def detect_cuda_version() -> str | None:
    """Return the CUDA version from the first source that reports one.

    Sources are tried in this order: environment override, PyTorch, nvcc.
    """
    for probe in (get_cuda_from_env, get_cuda_from_torch):
        if found := probe():
            return found
    # Last resort; its result is returned as-is even when it found nothing.
    return get_cuda_from_nvcc()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_cuda_from_env() -> str | None:
    """Read the CUDA version override from ``CUDA_VERSION_ENV_VAR``.

    Returns:
        ``None`` when the variable is unset or blank.  A value without a dot
        and at least two characters long is treated as compact notation and
        gets a dot inserted before its last character ("128" -> "12.8");
        note that this also turns "12" into "1.2".  Any other value is
        returned stripped but otherwise unchanged.
    """
    raw = os.environ.get(CUDA_VERSION_ENV_VAR, "").strip()
    if not raw:
        return None

    is_compact = "." not in raw and len(raw) >= 2
    if is_compact:
        head, last = raw[:-1], raw[-1]
        return f"{head}.{last}"
    return raw
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_cuda_from_torch() -> str | None:
    """Return ``torch.version.cuda`` when PyTorch reports CUDA as available.

    Best effort: a missing PyTorch or any error while querying it gives
    ``None``.
    """
    try:
        import torch

        if not torch.cuda.is_available():
            return None
        # An empty or missing build version counts as "not found".
        return torch.version.cuda or None
    except Exception:
        return None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_cuda_from_nvml() -> str | None:
    """Return the CUDA version reported by NVML as ``"major.minor"``.

    The integer from ``nvmlSystemGetCudaDriverVersion_v2`` is split into
    thousands (major) and tens of the remainder (minor).  Best effort: a
    missing ``pynvml`` or any NVML error gives ``None``.
    """
    try:
        import pynvml

        pynvml.nvmlInit()
        try:
            raw = pynvml.nvmlSystemGetCudaDriverVersion_v2()
            major, remainder = divmod(raw, 1000)
            return f"{major}.{remainder // 10}"
        finally:
            # Always pair nvmlInit with nvmlShutdown, even on failure.
            pynvml.nvmlShutdown()
    except Exception:
        return None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_cuda_from_nvcc() -> str | None:
    """Return the ``major.minor`` release printed by ``nvcc --version``.

    Best effort: a missing compiler, a non-zero exit status, a timeout
    (5 seconds) or unrecognised output all give ``None``.
    """
    try:
        proc = subprocess.run(
            ["nvcc", "--version"], capture_output=True, text=True, timeout=5
        )
        if proc.returncode != 0:
            return None
        match = re.search(r"release (\d+\.\d+)", proc.stdout)
        return match.group(1) if match else None
    except Exception:
        return None
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""GPU detection. Methods: NVML -> PyTorch -> nvidia-smi -> sysfs"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import subprocess
|
|
7
|
+
import time
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
from .cuda import CUDA_VERSION_ENV_VAR
|
|
12
|
+
|
|
13
|
+
COMPUTE_TO_ARCH = {
|
|
14
|
+
(5, 0): "Maxwell", (5, 2): "Maxwell", (5, 3): "Maxwell",
|
|
15
|
+
(6, 0): "Pascal", (6, 1): "Pascal", (6, 2): "Pascal",
|
|
16
|
+
(7, 0): "Volta", (7, 2): "Volta", (7, 5): "Turing",
|
|
17
|
+
(8, 0): "Ampere", (8, 6): "Ampere", (8, 7): "Ampere", (8, 9): "Ada",
|
|
18
|
+
(9, 0): "Hopper",
|
|
19
|
+
(10, 0): "Blackwell", (10, 1): "Blackwell", (10, 2): "Blackwell",
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
_cache: tuple[float, "CUDAEnvironment | None"] = (0, None)
|
|
23
|
+
CACHE_TTL = 60
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class GPUInfo:
    """Description of a single detected GPU."""

    index: int
    name: str
    compute_capability: tuple[int, int]  # (major, minor)
    architecture: str
    vram_total_mb: int = 0
    vram_free_mb: int = 0
    uuid: str = ""
    pci_bus_id: str = ""
    driver_version: str = ""

    def cc_str(self) -> str:
        """Return the compute capability formatted as ``"major.minor"``."""
        major = self.compute_capability[0]
        minor = self.compute_capability[1]
        return f"{major}.{minor}"

    def sm_version(self) -> str:
        """Return the compute capability formatted as ``"sm_<major><minor>"``."""
        major = self.compute_capability[0]
        minor = self.compute_capability[1]
        return f"sm_{major}{minor}"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class CUDAEnvironment:
    """Snapshot of the detected GPUs together with driver and CUDA details."""

    # Detected GPUs; empty when none were found.
    gpus: list[GPUInfo] = field(default_factory=list)
    # Driver version string, "" when unknown.
    driver_version: str = ""
    # CUDA runtime version string, "" when unknown.
    cuda_runtime_version: str = ""
    # CUDA version suggested for these GPUs, "" when there is none.
    recommended_cuda: str = ""
    # Name of the detection method that produced ``gpus``.
    detection_method: str = ""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def compute_capability_to_architecture(major: int, minor: int) -> str:
    """Map a compute capability to an architecture name.

    Exact ``(major, minor)`` pairs come from ``COMPUTE_TO_ARCH``; pairs that
    are not listed fall back to a guess based on the major version (and, for
    majors 7 and 8, the minor version).  Majors below 5 give ``"Unknown"``.
    """
    known = COMPUTE_TO_ARCH.get((major, minor))
    if known:
        return known

    if major >= 10:
        return "Blackwell"
    if major == 8:
        return "Ampere" if minor < 9 else "Ada"
    if major == 7:
        return "Volta" if minor < 5 else "Turing"
    # Remaining majors each map to a single architecture.
    return {9: "Hopper", 6: "Pascal", 5: "Maxwell"}.get(major, "Unknown")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_compute_capability(gpu_index: int = 0) -> tuple[int, int] | None:
    """Return the compute capability of the GPU at ``gpu_index``.

    Args:
        gpu_index: Zero-based position in the detected GPU list.

    Returns:
        The ``(major, minor)`` tuple, or ``None`` when no GPU exists at that
        position.  Negative indices are treated as out of range rather than
        wrapping around or raising ``IndexError``.
    """
    gpus = detect_cuda_environment().gpus
    if 0 <= gpu_index < len(gpus):
        return gpus[gpu_index].compute_capability
    return None
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_recommended_cuda_version(gpus: list[GPUInfo] | None = None) -> str:
    """Pick the CUDA version to recommend for a set of GPUs.

    Precedence:
      1. A non-blank ``CUDA_VERSION_ENV_VAR`` wins; a dotless value of two or
         more characters gets a dot before its last character.
      2. No GPUs -> ``""``.
      3. Any GPU with compute capability major >= 10 -> ``"12.8"``.
      4. Otherwise, any GPU below compute capability 7.5 -> ``"12.4"``.
      5. Otherwise ``"12.8"``.

    Args:
        gpus: GPUs to consider; detected automatically when ``None``.
    """
    override = os.environ.get(CUDA_VERSION_ENV_VAR, "").strip()
    if override:
        if "." in override or len(override) < 2:
            return override
        return f"{override[:-1]}.{override[-1]}"

    if gpus is None:
        gpus = detect_cuda_environment().gpus
    if not gpus:
        return ""

    if any(gpu.compute_capability[0] >= 10 for gpu in gpus):
        return "12.8"

    def _below_7_5(cc) -> bool:
        return cc[0] < 7 or (cc[0] == 7 and cc[1] < 5)

    if any(_below_7_5(gpu.compute_capability) for gpu in gpus):
        return "12.4"
    return "12.8"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def detect_gpu() -> GPUInfo | None:
    """Return the first detected GPU, or ``None`` when there is none."""
    gpus = detect_cuda_environment().gpus
    if gpus:
        return gpus[0]
    return None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def detect_gpus() -> list[GPUInfo]:
    """Return the GPU list of the (possibly cached) CUDA environment."""
    environment = detect_cuda_environment()
    return environment.gpus
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def detect_cuda_environment(force_refresh: bool = False) -> CUDAEnvironment:
    """Detect GPUs and CUDA details, caching the result for ``CACHE_TTL`` seconds.

    Detection methods are tried in order (nvml, torch, smi, sysfs) and the
    first one that returns GPUs wins; ``detection_method`` is ``"none"`` when
    all of them come back empty.

    Args:
        force_refresh: Ignore a still-valid cached result and detect again.

    Returns:
        The detected (or cached) ``CUDAEnvironment``.
    """
    global _cache
    cached_at, cached_env = _cache
    # Cache age is measured with the monotonic clock so that wall-clock
    # adjustments can neither expire the entry early nor keep it alive.
    if (
        not force_refresh
        and cached_env is not None
        and time.monotonic() - cached_at < CACHE_TTL
    ):
        return cached_env

    detectors = (
        ("nvml", _detect_nvml),
        ("torch", _detect_torch),
        ("smi", _detect_smi),
        ("sysfs", _detect_sysfs),
    )
    gpus: list[GPUInfo] = []
    method = "none"
    for name, detect in detectors:
        if found := detect():
            gpus, method = found, name
            break

    env = CUDAEnvironment(
        gpus=gpus,
        driver_version=_get_driver_version(),
        cuda_runtime_version=_get_cuda_version(),
        recommended_cuda=get_recommended_cuda_version(gpus),
        detection_method=method,
    )
    _cache = (time.monotonic(), env)
    return env
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def get_gpu_summary() -> str:
    """Build a human-readable, multi-line summary of the detected GPUs.

    Without GPUs a single line is returned that either reports the value of
    ``CUDA_VERSION_ENV_VAR`` or suggests setting it.
    """
    env = detect_cuda_environment()
    if not env.gpus:
        override = os.environ.get(CUDA_VERSION_ENV_VAR)
        if override:
            return f"No GPU detected (using {CUDA_VERSION_ENV_VAR}={override})"
        return f"No GPU (set {CUDA_VERSION_ENV_VAR} to override)"

    report = [f"Detection: {env.detection_method}"]
    if env.driver_version:
        report.append(f"Driver: {env.driver_version}")
    if env.cuda_runtime_version:
        report.append(f"CUDA: {env.cuda_runtime_version}")
    # The empty entry puts a blank line between the header and the GPU rows.
    report.extend([f"Recommended: CUDA {env.recommended_cuda}", ""])

    for gpu in env.gpus:
        if gpu.vram_total_mb:
            vram = f"{gpu.vram_total_mb}MB"
        else:
            vram = "?"
        report.append(f" GPU {gpu.index}: {gpu.name} ({gpu.sm_version()}) [{gpu.architecture}] {vram}")
    return "\n".join(report)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _parse_cc(s: str) -> tuple[int, int]:
|
|
130
|
+
try:
|
|
131
|
+
if "." in s: p = s.split("."); return (int(p[0]), int(p[1]))
|
|
132
|
+
if len(s) >= 2: return (int(s[:-1]), int(s[-1]))
|
|
133
|
+
except (ValueError, IndexError): pass
|
|
134
|
+
return (0, 0)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _detect_nvml() -> list[GPUInfo] | None:
|
|
138
|
+
try:
|
|
139
|
+
import pynvml
|
|
140
|
+
pynvml.nvmlInit()
|
|
141
|
+
try:
|
|
142
|
+
count = pynvml.nvmlDeviceGetCount()
|
|
143
|
+
if not count: return None
|
|
144
|
+
gpus = []
|
|
145
|
+
for i in range(count):
|
|
146
|
+
h = pynvml.nvmlDeviceGetHandleByIndex(i)
|
|
147
|
+
name = pynvml.nvmlDeviceGetName(h)
|
|
148
|
+
if isinstance(name, bytes): name = name.decode()
|
|
149
|
+
cc = pynvml.nvmlDeviceGetCudaComputeCapability(h)
|
|
150
|
+
mem = pynvml.nvmlDeviceGetMemoryInfo(h)
|
|
151
|
+
gpus.append(GPUInfo(i, name, cc, compute_capability_to_architecture(*cc),
|
|
152
|
+
mem.total // (1024 * 1024), mem.free // (1024 * 1024)))
|
|
153
|
+
return gpus
|
|
154
|
+
finally:
|
|
155
|
+
pynvml.nvmlShutdown()
|
|
156
|
+
except Exception: return None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _detect_torch() -> list[GPUInfo] | None:
|
|
160
|
+
try:
|
|
161
|
+
import torch
|
|
162
|
+
if not torch.cuda.is_available(): return None
|
|
163
|
+
gpus = []
|
|
164
|
+
for i in range(torch.cuda.device_count()):
|
|
165
|
+
p = torch.cuda.get_device_properties(i)
|
|
166
|
+
gpus.append(GPUInfo(i, p.name, (p.major, p.minor),
|
|
167
|
+
compute_capability_to_architecture(p.major, p.minor),
|
|
168
|
+
p.total_memory // (1024 * 1024)))
|
|
169
|
+
return gpus or None
|
|
170
|
+
except Exception: return None
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _detect_smi() -> list[GPUInfo] | None:
|
|
174
|
+
try:
|
|
175
|
+
r = subprocess.run(
|
|
176
|
+
["nvidia-smi", "--query-gpu=index,name,uuid,pci.bus_id,compute_cap,memory.total,memory.free,driver_version",
|
|
177
|
+
"--format=csv,noheader,nounits"], capture_output=True, text=True, timeout=10)
|
|
178
|
+
if r.returncode != 0: return None
|
|
179
|
+
gpus = []
|
|
180
|
+
for line in r.stdout.strip().split("\n"):
|
|
181
|
+
if not line.strip(): continue
|
|
182
|
+
p = [x.strip() for x in line.split(",")]
|
|
183
|
+
if len(p) < 5: continue
|
|
184
|
+
cc = _parse_cc(p[4])
|
|
185
|
+
gpus.append(GPUInfo(
|
|
186
|
+
int(p[0]) if p[0].isdigit() else len(gpus), p[1], cc, compute_capability_to_architecture(*cc),
|
|
187
|
+
int(p[5]) if len(p) > 5 and p[5].isdigit() else 0,
|
|
188
|
+
int(p[6]) if len(p) > 6 and p[6].isdigit() else 0,
|
|
189
|
+
p[2] if len(p) > 2 else "", p[3] if len(p) > 3 else "", p[7] if len(p) > 7 else ""))
|
|
190
|
+
return gpus or None
|
|
191
|
+
except Exception: return None
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _detect_sysfs() -> list[GPUInfo] | None:
|
|
195
|
+
try:
|
|
196
|
+
pci_path = Path("/sys/bus/pci/devices")
|
|
197
|
+
if not pci_path.exists(): return None
|
|
198
|
+
gpus = []
|
|
199
|
+
for d in sorted(pci_path.iterdir()):
|
|
200
|
+
vendor = (d / "vendor").read_text().strip().lower() if (d / "vendor").exists() else ""
|
|
201
|
+
if "10de" not in vendor: continue
|
|
202
|
+
cls = (d / "class").read_text().strip() if (d / "class").exists() else ""
|
|
203
|
+
if not (cls.startswith("0x0300") or cls.startswith("0x0302")): continue
|
|
204
|
+
gpus.append(GPUInfo(len(gpus), "NVIDIA GPU", (0, 0), "Unknown", pci_bus_id=d.name))
|
|
205
|
+
return gpus or None
|
|
206
|
+
except Exception: return None
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _get_driver_version() -> str:
|
|
210
|
+
try:
|
|
211
|
+
import pynvml
|
|
212
|
+
pynvml.nvmlInit()
|
|
213
|
+
v = pynvml.nvmlSystemGetDriverVersion()
|
|
214
|
+
pynvml.nvmlShutdown()
|
|
215
|
+
return v.decode() if isinstance(v, bytes) else v
|
|
216
|
+
except Exception: pass
|
|
217
|
+
try:
|
|
218
|
+
r = subprocess.run(["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
|
|
219
|
+
capture_output=True, text=True, timeout=5)
|
|
220
|
+
if r.returncode == 0: return r.stdout.strip().split("\n")[0]
|
|
221
|
+
except Exception: pass
|
|
222
|
+
return ""
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _get_cuda_version() -> str:
|
|
226
|
+
try:
|
|
227
|
+
import torch
|
|
228
|
+
if torch.cuda.is_available() and torch.version.cuda: return torch.version.cuda
|
|
229
|
+
except Exception: pass
|
|
230
|
+
return ""
|