comfy-env 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comfy_env/__init__.py +116 -41
- comfy_env/cli.py +89 -317
- comfy_env/config/__init__.py +18 -6
- comfy_env/config/parser.py +22 -76
- comfy_env/config/types.py +37 -0
- comfy_env/detection/__init__.py +77 -0
- comfy_env/detection/cuda.py +61 -0
- comfy_env/detection/gpu.py +230 -0
- comfy_env/detection/platform.py +70 -0
- comfy_env/detection/runtime.py +103 -0
- comfy_env/environment/__init__.py +53 -0
- comfy_env/environment/cache.py +141 -0
- comfy_env/environment/libomp.py +41 -0
- comfy_env/environment/paths.py +38 -0
- comfy_env/environment/setup.py +88 -0
- comfy_env/install.py +127 -329
- comfy_env/isolation/__init__.py +32 -2
- comfy_env/isolation/tensor_utils.py +83 -0
- comfy_env/isolation/workers/__init__.py +16 -0
- comfy_env/{workers → isolation/workers}/mp.py +1 -1
- comfy_env/{workers → isolation/workers}/subprocess.py +1 -1
- comfy_env/isolation/wrap.py +128 -509
- comfy_env/packages/__init__.py +60 -0
- comfy_env/packages/apt.py +36 -0
- comfy_env/packages/cuda_wheels.py +97 -0
- comfy_env/packages/node_dependencies.py +77 -0
- comfy_env/packages/pixi.py +85 -0
- comfy_env/packages/toml_generator.py +88 -0
- comfy_env-0.1.16.dist-info/METADATA +279 -0
- comfy_env-0.1.16.dist-info/RECORD +36 -0
- comfy_env/cache.py +0 -203
- comfy_env/nodes.py +0 -187
- comfy_env/pixi/__init__.py +0 -48
- comfy_env/pixi/core.py +0 -587
- comfy_env/pixi/cuda_detection.py +0 -303
- comfy_env/pixi/platform/__init__.py +0 -21
- comfy_env/pixi/platform/base.py +0 -96
- comfy_env/pixi/platform/darwin.py +0 -53
- comfy_env/pixi/platform/linux.py +0 -68
- comfy_env/pixi/platform/windows.py +0 -284
- comfy_env/pixi/resolver.py +0 -198
- comfy_env/prestartup.py +0 -208
- comfy_env/workers/__init__.py +0 -38
- comfy_env/workers/tensor_utils.py +0 -188
- comfy_env-0.1.15.dist-info/METADATA +0 -291
- comfy_env-0.1.15.dist-info/RECORD +0 -31
- /comfy_env/{workers → isolation/workers}/base.py +0 -0
- {comfy_env-0.1.15.dist-info → comfy_env-0.1.16.dist-info}/WHEEL +0 -0
- {comfy_env-0.1.15.dist-info → comfy_env-0.1.16.dist-info}/entry_points.txt +0 -0
- {comfy_env-0.1.15.dist-info → comfy_env-0.1.16.dist-info}/licenses/LICENSE +0 -0
comfy_env/config/parser.py
CHANGED
|
@@ -1,88 +1,42 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Configuration parsing for comfy-env.
|
|
3
|
-
|
|
4
|
-
Loads comfy-env.toml (a superset of pixi.toml) and provides typed config objects.
|
|
5
|
-
"""
|
|
1
|
+
"""Configuration parsing for comfy-env."""
|
|
6
2
|
|
|
7
3
|
import copy
|
|
8
|
-
import sys
|
|
9
|
-
from dataclasses import dataclass, field
|
|
10
4
|
from pathlib import Path
|
|
11
|
-
from typing import
|
|
12
|
-
import tomli
|
|
13
|
-
|
|
14
|
-
# --- Types&Constants ---
|
|
15
|
-
CONFIG_FILE_NAME = "comfy-env.toml"
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
16
6
|
|
|
17
|
-
|
|
18
|
-
class NodeReq:
|
|
19
|
-
"""A node dependency (another ComfyUI custom node)."""
|
|
20
|
-
name: str
|
|
21
|
-
repo: str # GitHub repo, e.g., "owner/repo"
|
|
7
|
+
import tomli
|
|
22
8
|
|
|
23
|
-
|
|
24
|
-
class ComfyEnvConfig:
|
|
25
|
-
"""Configuration from comfy-env.toml."""
|
|
26
|
-
python: Optional[str] = None
|
|
27
|
-
cuda_packages: List[str] = field(default_factory=list)
|
|
28
|
-
apt_packages: List[str] = field(default_factory=list)
|
|
29
|
-
env_vars: Dict[str, str] = field(default_factory=dict)
|
|
30
|
-
node_reqs: List[NodeReq] = field(default_factory=list)
|
|
31
|
-
pixi_passthrough: Dict[str, Any] = field(default_factory=dict)
|
|
9
|
+
from .types import ComfyEnvConfig, NodeDependency
|
|
32
10
|
|
|
33
|
-
|
|
34
|
-
def has_cuda(self) -> bool:
|
|
35
|
-
return bool(self.cuda_packages)
|
|
36
|
-
# --- Types&Constants ---
|
|
11
|
+
CONFIG_FILE_NAME = "comfy-env.toml"
|
|
37
12
|
|
|
38
13
|
|
|
39
14
|
def load_config(path: Path) -> ComfyEnvConfig:
|
|
40
|
-
"""Load
|
|
15
|
+
"""Load and parse comfy-env.toml."""
|
|
41
16
|
path = Path(path)
|
|
42
17
|
if not path.exists():
|
|
43
18
|
raise FileNotFoundError(f"Config file not found: {path}")
|
|
44
19
|
with open(path, "rb") as f:
|
|
45
|
-
|
|
46
|
-
return _parse_config(data)
|
|
20
|
+
return parse_config(tomli.load(f))
|
|
47
21
|
|
|
48
22
|
|
|
49
23
|
def discover_config(node_dir: Path) -> Optional[ComfyEnvConfig]:
|
|
50
|
-
"""Find and load comfy-env.toml from
|
|
24
|
+
"""Find and load comfy-env.toml from directory."""
|
|
51
25
|
config_path = Path(node_dir) / CONFIG_FILE_NAME
|
|
52
|
-
if config_path.exists()
|
|
53
|
-
return load_config(config_path)
|
|
26
|
+
return load_config(config_path) if config_path.exists() else None
|
|
54
27
|
|
|
55
|
-
return None
|
|
56
28
|
|
|
57
|
-
|
|
58
|
-
def _parse_config(data: Dict[str, Any]) -> ComfyEnvConfig:
|
|
29
|
+
def parse_config(data: Dict[str, Any]) -> ComfyEnvConfig:
|
|
59
30
|
"""Parse TOML data into ComfyEnvConfig."""
|
|
60
|
-
# Make a copy so we can pop our custom sections
|
|
61
31
|
data = copy.deepcopy(data)
|
|
62
32
|
|
|
63
|
-
# Extract python version (top-level key)
|
|
64
33
|
python_version = data.pop("python", None)
|
|
65
|
-
if python_version
|
|
66
|
-
python_version = str(python_version)
|
|
67
|
-
|
|
68
|
-
# Extract [cuda] section
|
|
69
|
-
cuda_data = data.pop("cuda", {})
|
|
70
|
-
cuda_packages = _ensure_list(cuda_data.get("packages", []))
|
|
71
|
-
|
|
72
|
-
# Extract [apt] section
|
|
73
|
-
apt_data = data.pop("apt", {})
|
|
74
|
-
apt_packages = _ensure_list(apt_data.get("packages", []))
|
|
75
|
-
|
|
76
|
-
# Extract [env_vars] section
|
|
77
|
-
env_vars_data = data.pop("env_vars", {})
|
|
78
|
-
env_vars = {str(k): str(v) for k, v in env_vars_data.items()}
|
|
79
|
-
|
|
80
|
-
# Extract [node_reqs] section
|
|
81
|
-
node_reqs_data = data.pop("node_reqs", {})
|
|
82
|
-
node_reqs = _parse_node_reqs(node_reqs_data)
|
|
34
|
+
python_version = str(python_version) if python_version else None
|
|
83
35
|
|
|
84
|
-
|
|
85
|
-
|
|
36
|
+
cuda_packages = _ensure_list(data.pop("cuda", {}).get("packages", []))
|
|
37
|
+
apt_packages = _ensure_list(data.pop("apt", {}).get("packages", []))
|
|
38
|
+
env_vars = {str(k): str(v) for k, v in data.pop("env_vars", {}).items()}
|
|
39
|
+
node_reqs = _parse_node_reqs(data.pop("node_reqs", {}))
|
|
86
40
|
|
|
87
41
|
return ComfyEnvConfig(
|
|
88
42
|
python=python_version,
|
|
@@ -90,25 +44,17 @@ def _parse_config(data: Dict[str, Any]) -> ComfyEnvConfig:
|
|
|
90
44
|
apt_packages=apt_packages,
|
|
91
45
|
env_vars=env_vars,
|
|
92
46
|
node_reqs=node_reqs,
|
|
93
|
-
pixi_passthrough=
|
|
47
|
+
pixi_passthrough=data,
|
|
94
48
|
)
|
|
95
49
|
|
|
96
50
|
|
|
97
|
-
def _parse_node_reqs(data: Dict[str, Any]) -> List[
|
|
51
|
+
def _parse_node_reqs(data: Dict[str, Any]) -> List[NodeDependency]:
|
|
98
52
|
"""Parse [node_reqs] section."""
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
elif isinstance(value, dict):
|
|
104
|
-
node_reqs.append(NodeReq(name=name, repo=value.get("repo", "")))
|
|
105
|
-
return node_reqs
|
|
53
|
+
return [
|
|
54
|
+
NodeDependency(name=name, repo=value if isinstance(value, str) else value.get("repo", ""))
|
|
55
|
+
for name, value in data.items()
|
|
56
|
+
]
|
|
106
57
|
|
|
107
58
|
|
|
108
59
|
def _ensure_list(value) -> List:
|
|
109
|
-
|
|
110
|
-
if isinstance(value, list):
|
|
111
|
-
return value
|
|
112
|
-
if value:
|
|
113
|
-
return [value]
|
|
114
|
-
return []
|
|
60
|
+
return value if isinstance(value, list) else ([value] if value else [])
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Configuration types for comfy-env."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class NodeDependency:
    """A ComfyUI custom node dependency."""

    # Name the dependency is known by.
    name: str
    # Where to fetch the node from.
    repo: str  # "owner/repo" or full URL


# Old class name, kept as an alias so existing references keep working.
NodeReq = NodeDependency  # Backwards compat
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ComfyEnvConfig:
    """Typed view of a parsed comfy-env.toml file.

    The first five fields hold the comfy-env specific settings; everything
    else found in the file is kept untouched in ``pixi_passthrough``.
    """

    # Requested Python version, or None when the file does not set one.
    python: Optional[str] = None
    # Package names listed for CUDA.
    cuda_packages: List[str] = field(default_factory=list)
    # Package names listed for apt.
    apt_packages: List[str] = field(default_factory=list)
    # Environment variables as string -> string.
    env_vars: Dict[str, str] = field(default_factory=dict)
    # Other custom nodes this one depends on.
    node_reqs: List[NodeDependency] = field(default_factory=list)
    # Remaining TOML data that is not interpreted by this class.
    pixi_passthrough: Dict[str, Any] = field(default_factory=dict)

    @property
    def has_cuda(self) -> bool:
        """True when at least one CUDA package is listed."""
        return True if self.cuda_packages else False

    @property
    def has_dependencies(self) -> bool:
        """True when any kind of dependency is declared.

        Checks the CUDA, apt and node lists first, then the
        ``dependencies`` / ``pypi-dependencies`` tables of the passthrough
        data.
        """
        if self.cuda_packages or self.apt_packages or self.node_reqs:
            return True
        extra = self.pixi_passthrough
        return bool(extra.get("dependencies") or extra.get("pypi-dependencies"))
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Detection layer - Pure functions for system detection.
|
|
3
|
+
|
|
4
|
+
No side effects. These functions gather information about the runtime environment.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .cuda import (
|
|
8
|
+
CUDA_VERSION_ENV_VAR,
|
|
9
|
+
detect_cuda_version,
|
|
10
|
+
get_cuda_from_torch,
|
|
11
|
+
get_cuda_from_nvml,
|
|
12
|
+
get_cuda_from_nvcc,
|
|
13
|
+
get_cuda_from_env,
|
|
14
|
+
)
|
|
15
|
+
from .gpu import (
|
|
16
|
+
GPUInfo,
|
|
17
|
+
CUDAEnvironment,
|
|
18
|
+
COMPUTE_TO_ARCH,
|
|
19
|
+
detect_gpu,
|
|
20
|
+
detect_gpus,
|
|
21
|
+
detect_cuda_environment,
|
|
22
|
+
get_compute_capability,
|
|
23
|
+
compute_capability_to_architecture,
|
|
24
|
+
get_recommended_cuda_version,
|
|
25
|
+
get_gpu_summary,
|
|
26
|
+
)
|
|
27
|
+
from .platform import (
|
|
28
|
+
PlatformInfo,
|
|
29
|
+
detect_platform,
|
|
30
|
+
get_platform_tag,
|
|
31
|
+
get_pixi_platform,
|
|
32
|
+
get_library_extension,
|
|
33
|
+
get_executable_suffix,
|
|
34
|
+
is_linux,
|
|
35
|
+
is_windows,
|
|
36
|
+
is_macos,
|
|
37
|
+
)
|
|
38
|
+
from .runtime import (
|
|
39
|
+
RuntimeEnv,
|
|
40
|
+
detect_runtime,
|
|
41
|
+
parse_wheel_requirement,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
# CUDA detection
|
|
46
|
+
"CUDA_VERSION_ENV_VAR",
|
|
47
|
+
"detect_cuda_version",
|
|
48
|
+
"get_cuda_from_torch",
|
|
49
|
+
"get_cuda_from_nvml",
|
|
50
|
+
"get_cuda_from_nvcc",
|
|
51
|
+
"get_cuda_from_env",
|
|
52
|
+
# GPU detection
|
|
53
|
+
"GPUInfo",
|
|
54
|
+
"CUDAEnvironment",
|
|
55
|
+
"COMPUTE_TO_ARCH",
|
|
56
|
+
"detect_gpu",
|
|
57
|
+
"detect_gpus",
|
|
58
|
+
"detect_cuda_environment",
|
|
59
|
+
"get_compute_capability",
|
|
60
|
+
"compute_capability_to_architecture",
|
|
61
|
+
"get_recommended_cuda_version",
|
|
62
|
+
"get_gpu_summary",
|
|
63
|
+
# Platform detection
|
|
64
|
+
"PlatformInfo",
|
|
65
|
+
"detect_platform",
|
|
66
|
+
"get_platform_tag",
|
|
67
|
+
"get_pixi_platform",
|
|
68
|
+
"get_library_extension",
|
|
69
|
+
"get_executable_suffix",
|
|
70
|
+
"is_linux",
|
|
71
|
+
"is_windows",
|
|
72
|
+
"is_macos",
|
|
73
|
+
# Runtime detection
|
|
74
|
+
"RuntimeEnv",
|
|
75
|
+
"detect_runtime",
|
|
76
|
+
"parse_wheel_requirement",
|
|
77
|
+
]
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""CUDA version detection. Priority: env -> torch -> nvcc"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
import subprocess
|
|
8
|
+
|
|
9
|
+
# Name of the environment variable read as a manual CUDA version override.
CUDA_VERSION_ENV_VAR = "COMFY_ENV_CUDA_VERSION"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def detect_cuda_version() -> str | None:
    """Return the CUDA version from the first source that reports one.

    Sources are tried in order: environment override, PyTorch, nvcc. Later
    sources are not consulted once an earlier one yields a value.

    Returns:
        The version string, or None when no source produced one.
    """
    version = None
    for probe in (get_cuda_from_env, get_cuda_from_torch, get_cuda_from_nvcc):
        version = probe()
        if version:
            break
    return version
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_cuda_from_env() -> str | None:
    """Read the CUDA version override from the environment.

    A value without a dot and at least two characters long is treated as a
    compact form and gets a dot inserted before its last character
    (``"128"`` -> ``"12.8"``). Any other non-empty value is returned as is.

    Returns:
        The override, or None when the variable is unset or blank.
    """
    raw = os.environ.get(CUDA_VERSION_ENV_VAR, "").strip()
    if raw == "":
        return None
    is_compact = "." not in raw and len(raw) >= 2
    if is_compact:
        major, minor = raw[:-1], raw[-1]
        return f"{major}.{minor}"
    return raw
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_cuda_from_torch() -> str | None:
    """Return the CUDA version PyTorch reports.

    Best effort: any failure (PyTorch missing, broken install, ...) is
    treated the same as "no CUDA".

    Returns:
        ``torch.version.cuda`` when CUDA is available and the value is
        non-empty, otherwise None.
    """
    try:
        import torch

        version = torch.version.cuda if torch.cuda.is_available() else None
    except Exception:
        return None
    return version or None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_cuda_from_nvml() -> str | None:
    """Return the CUDA driver version reported by NVML as "major.minor".

    NVML returns a single integer; it is split as ``value // 1000`` for the
    major part and ``(value % 1000) // 10`` for the minor part. NVML is
    always shut down again once it has been initialised.

    Returns:
        The version string, or None when pynvml is missing or any NVML
        call fails.
    """
    try:
        import pynvml

        pynvml.nvmlInit()
        try:
            raw = pynvml.nvmlSystemGetCudaDriverVersion_v2()
            major, remainder = divmod(raw, 1000)
            return f"{major}.{remainder // 10}"
        finally:
            pynvml.nvmlShutdown()
    except Exception:
        return None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_cuda_from_nvcc() -> str | None:
    """Return the CUDA version printed by ``nvcc --version``.

    The version is taken from the ``release X.Y`` part of the output.

    Returns:
        The "X.Y" string, or None when nvcc cannot be run, exits non-zero,
        or its output contains no release number.
    """
    try:
        proc = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, timeout=5)
        if proc.returncode != 0:
            return None
        found = re.search(r"release (\d+\.\d+)", proc.stdout)
        return found.group(1) if found else None
    except Exception:
        return None
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""GPU detection. Methods: NVML -> PyTorch -> nvidia-smi -> sysfs"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import subprocess
|
|
7
|
+
import time
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
from .cuda import CUDA_VERSION_ENV_VAR
|
|
12
|
+
|
|
13
|
+
# Exact (major, minor) compute capability -> architecture name.
COMPUTE_TO_ARCH = {
    # Maxwell
    (5, 0): "Maxwell",
    (5, 2): "Maxwell",
    (5, 3): "Maxwell",
    # Pascal
    (6, 0): "Pascal",
    (6, 1): "Pascal",
    (6, 2): "Pascal",
    # Volta / Turing
    (7, 0): "Volta",
    (7, 2): "Volta",
    (7, 5): "Turing",
    # Ampere / Ada
    (8, 0): "Ampere",
    (8, 6): "Ampere",
    (8, 7): "Ampere",
    (8, 9): "Ada",
    # Hopper
    (9, 0): "Hopper",
    # Blackwell
    (10, 0): "Blackwell",
    (10, 1): "Blackwell",
    (10, 2): "Blackwell",
}

# (time the entry was stored, cached detection result or None).
_cache: tuple[float, "CUDAEnvironment | None"] = (0, None)
# Maximum age of a cached result, in seconds (compared against time.time()).
CACHE_TTL = 60
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class GPUInfo:
    """Description of a single detected GPU.

    Only the first four fields are required; the rest default to
    zero / empty when a detection method cannot supply them.
    """

    index: int
    name: str
    compute_capability: tuple[int, int]  # (major, minor)
    architecture: str
    vram_total_mb: int = 0
    vram_free_mb: int = 0
    uuid: str = ""
    pci_bus_id: str = ""
    driver_version: str = ""

    def cc_str(self) -> str:
        """Return the compute capability as "major.minor"."""
        cc = self.compute_capability
        return "{}.{}".format(cc[0], cc[1])

    def sm_version(self) -> str:
        """Return the compute capability as "sm_<major><minor>"."""
        cc = self.compute_capability
        return "sm_{}{}".format(cc[0], cc[1])
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class CUDAEnvironment:
    """Result of one GPU/CUDA detection pass, as built by detect_cuda_environment()."""

    # Detected GPUs; empty when no detection method found any.
    gpus: list[GPUInfo] = field(default_factory=list)
    # Driver version string; "" when it could not be determined.
    driver_version: str = ""
    # CUDA version reported by PyTorch; "" when unavailable.
    cuda_runtime_version: str = ""
    # Output of get_recommended_cuda_version() for the detected GPUs.
    recommended_cuda: str = ""
    # Detector that produced `gpus`: "nvml", "torch", "smi", "sysfs", or "none".
    detection_method: str = ""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def compute_capability_to_architecture(major: int, minor: int) -> str:
    """Map a compute capability to an architecture name.

    An exact match in COMPUTE_TO_ARCH wins. Otherwise the name is chosen
    from the major version (and, for 7.x / 8.x, the minor version).

    Returns:
        The architecture name, or "Unknown" for majors below 5.
    """
    exact = COMPUTE_TO_ARCH.get((major, minor))
    if exact:
        return exact
    # Anything newer than the table is assumed to be Blackwell.
    if major >= 10:
        return "Blackwell"
    families = {
        9: "Hopper",
        8: "Ada" if minor >= 9 else "Ampere",
        7: "Turing" if minor >= 5 else "Volta",
        6: "Pascal",
        5: "Maxwell",
    }
    return families.get(major, "Unknown")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_compute_capability(gpu_index: int = 0) -> tuple[int, int] | None:
    """Return the compute capability of the GPU at ``gpu_index``.

    Returns:
        The (major, minor) tuple, or None when no GPU was detected or the
        index is past the end of the GPU list.
    """
    gpus = detect_cuda_environment().gpus
    if not gpus or gpu_index >= len(gpus):
        return None
    return gpus[gpu_index].compute_capability
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_recommended_cuda_version(gpus: list[GPUInfo] | None = None) -> str:
    """Pick the CUDA version to use for the given GPUs.

    Rules, in order:
      1. A non-blank CUDA_VERSION_ENV_VAR override wins ("128" -> "12.8").
      2. No GPUs -> "".
      3. Any GPU with compute capability major >= 10 -> "12.8".
      4. Any GPU with compute capability below 7.5 -> "12.4".
      5. Otherwise -> "12.8".

    Args:
        gpus: GPUs to consider. None triggers a detection pass; an empty
            list is taken as "no GPUs" without detecting.

    NOTE(review): GPUs found through sysfs carry capability (0, 0) and so
    land in the "below 7.5" bucket - confirm that is intended.
    """
    override = os.environ.get(CUDA_VERSION_ENV_VAR, "").strip()
    if override:
        if "." in override or len(override) < 2:
            return override
        return f"{override[:-1]}.{override[-1]}"

    if gpus is None:
        gpus = detect_cuda_environment().gpus
    if not gpus:
        return ""

    capabilities = [gpu.compute_capability for gpu in gpus]
    if any(cc[0] >= 10 for cc in capabilities):
        return "12.8"
    if any(cc[0] < 7 or (cc[0] == 7 and cc[1] < 5) for cc in capabilities):
        return "12.4"
    return "12.8"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def detect_gpu() -> GPUInfo | None:
    """Return the first detected GPU, or None when there is none."""
    gpus = detect_cuda_environment().gpus
    if gpus:
        return gpus[0]
    return None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def detect_gpus() -> list[GPUInfo]:
    """Return the GPU list of the current detection result (possibly empty)."""
    environment = detect_cuda_environment()
    return environment.gpus
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def detect_cuda_environment(force_refresh: bool = False) -> CUDAEnvironment:
    """Detect GPUs and CUDA/driver versions, with a short-lived cache.

    Detectors are tried in order (NVML, PyTorch, nvidia-smi, sysfs); the
    first one that returns a non-empty result is used.

    Args:
        force_refresh: Ignore the cached result and detect again.

    Returns:
        The cached environment if it is younger than CACHE_TTL seconds and
        ``force_refresh`` is false, otherwise a freshly built one (which
        also replaces the cache).
    """
    global _cache
    cached_at, cached_env = _cache
    if not force_refresh and cached_env and time.time() - cached_at < CACHE_TTL:
        return cached_env

    detectors = (
        ("nvml", _detect_nvml),
        ("torch", _detect_torch),
        ("smi", _detect_smi),
        ("sysfs", _detect_sysfs),
    )
    found, method = None, "none"
    for label, detector in detectors:
        result = detector()
        if result:
            found, method = result, label
            break

    gpu_list = found or []
    driver = _get_driver_version()
    runtime = _get_cuda_version()
    # Pass a list (never None) so the recommendation does not re-enter detection.
    recommended = get_recommended_cuda_version(gpu_list)
    env = CUDAEnvironment(
        gpus=gpu_list,
        driver_version=driver,
        cuda_runtime_version=runtime,
        recommended_cuda=recommended,
        detection_method=method,
    )
    _cache = (time.time(), env)
    return env
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def get_gpu_summary() -> str:
    """Build a human-readable, multi-line summary of the detected GPUs.

    Without GPUs a single line is returned that either shows the active
    CUDA_VERSION_ENV_VAR override or hints that it can be set.
    """
    env = detect_cuda_environment()
    if not env.gpus:
        override = os.environ.get(CUDA_VERSION_ENV_VAR)
        if override:
            return f"No GPU detected (using {CUDA_VERSION_ENV_VAR}={override})"
        return f"No GPU (set {CUDA_VERSION_ENV_VAR} to override)"

    summary = [f"Detection: {env.detection_method}"]
    # Driver and CUDA lines are only shown when the value is known.
    if env.driver_version:
        summary.append(f"Driver: {env.driver_version}")
    if env.cuda_runtime_version:
        summary.append(f"CUDA: {env.cuda_runtime_version}")
    summary += [f"Recommended: CUDA {env.recommended_cuda}", ""]
    for gpu in env.gpus:
        if gpu.vram_total_mb:
            vram = f"{gpu.vram_total_mb}MB"
        else:
            vram = "?"
        summary.append(f" GPU {gpu.index}: {gpu.name} ({gpu.sm_version()}) [{gpu.architecture}] {vram}")
    return "\n".join(summary)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _parse_cc(s: str) -> tuple[int, int]:
|
|
130
|
+
try:
|
|
131
|
+
if "." in s: p = s.split("."); return (int(p[0]), int(p[1]))
|
|
132
|
+
if len(s) >= 2: return (int(s[:-1]), int(s[-1]))
|
|
133
|
+
except (ValueError, IndexError): pass
|
|
134
|
+
return (0, 0)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _detect_nvml() -> list[GPUInfo] | None:
|
|
138
|
+
try:
|
|
139
|
+
import pynvml
|
|
140
|
+
pynvml.nvmlInit()
|
|
141
|
+
try:
|
|
142
|
+
count = pynvml.nvmlDeviceGetCount()
|
|
143
|
+
if not count: return None
|
|
144
|
+
gpus = []
|
|
145
|
+
for i in range(count):
|
|
146
|
+
h = pynvml.nvmlDeviceGetHandleByIndex(i)
|
|
147
|
+
name = pynvml.nvmlDeviceGetName(h)
|
|
148
|
+
if isinstance(name, bytes): name = name.decode()
|
|
149
|
+
cc = pynvml.nvmlDeviceGetCudaComputeCapability(h)
|
|
150
|
+
mem = pynvml.nvmlDeviceGetMemoryInfo(h)
|
|
151
|
+
gpus.append(GPUInfo(i, name, cc, compute_capability_to_architecture(*cc),
|
|
152
|
+
mem.total // (1024 * 1024), mem.free // (1024 * 1024)))
|
|
153
|
+
return gpus
|
|
154
|
+
finally:
|
|
155
|
+
pynvml.nvmlShutdown()
|
|
156
|
+
except Exception: return None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _detect_torch() -> list[GPUInfo] | None:
|
|
160
|
+
try:
|
|
161
|
+
import torch
|
|
162
|
+
if not torch.cuda.is_available(): return None
|
|
163
|
+
gpus = []
|
|
164
|
+
for i in range(torch.cuda.device_count()):
|
|
165
|
+
p = torch.cuda.get_device_properties(i)
|
|
166
|
+
gpus.append(GPUInfo(i, p.name, (p.major, p.minor),
|
|
167
|
+
compute_capability_to_architecture(p.major, p.minor),
|
|
168
|
+
p.total_memory // (1024 * 1024)))
|
|
169
|
+
return gpus or None
|
|
170
|
+
except Exception: return None
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _detect_smi() -> list[GPUInfo] | None:
|
|
174
|
+
try:
|
|
175
|
+
r = subprocess.run(
|
|
176
|
+
["nvidia-smi", "--query-gpu=index,name,uuid,pci.bus_id,compute_cap,memory.total,memory.free,driver_version",
|
|
177
|
+
"--format=csv,noheader,nounits"], capture_output=True, text=True, timeout=10)
|
|
178
|
+
if r.returncode != 0: return None
|
|
179
|
+
gpus = []
|
|
180
|
+
for line in r.stdout.strip().split("\n"):
|
|
181
|
+
if not line.strip(): continue
|
|
182
|
+
p = [x.strip() for x in line.split(",")]
|
|
183
|
+
if len(p) < 5: continue
|
|
184
|
+
cc = _parse_cc(p[4])
|
|
185
|
+
gpus.append(GPUInfo(
|
|
186
|
+
int(p[0]) if p[0].isdigit() else len(gpus), p[1], cc, compute_capability_to_architecture(*cc),
|
|
187
|
+
int(p[5]) if len(p) > 5 and p[5].isdigit() else 0,
|
|
188
|
+
int(p[6]) if len(p) > 6 and p[6].isdigit() else 0,
|
|
189
|
+
p[2] if len(p) > 2 else "", p[3] if len(p) > 3 else "", p[7] if len(p) > 7 else ""))
|
|
190
|
+
return gpus or None
|
|
191
|
+
except Exception: return None
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _detect_sysfs() -> list[GPUInfo] | None:
|
|
195
|
+
try:
|
|
196
|
+
pci_path = Path("/sys/bus/pci/devices")
|
|
197
|
+
if not pci_path.exists(): return None
|
|
198
|
+
gpus = []
|
|
199
|
+
for d in sorted(pci_path.iterdir()):
|
|
200
|
+
vendor = (d / "vendor").read_text().strip().lower() if (d / "vendor").exists() else ""
|
|
201
|
+
if "10de" not in vendor: continue
|
|
202
|
+
cls = (d / "class").read_text().strip() if (d / "class").exists() else ""
|
|
203
|
+
if not (cls.startswith("0x0300") or cls.startswith("0x0302")): continue
|
|
204
|
+
gpus.append(GPUInfo(len(gpus), "NVIDIA GPU", (0, 0), "Unknown", pci_bus_id=d.name))
|
|
205
|
+
return gpus or None
|
|
206
|
+
except Exception: return None
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _get_driver_version() -> str:
|
|
210
|
+
try:
|
|
211
|
+
import pynvml
|
|
212
|
+
pynvml.nvmlInit()
|
|
213
|
+
v = pynvml.nvmlSystemGetDriverVersion()
|
|
214
|
+
pynvml.nvmlShutdown()
|
|
215
|
+
return v.decode() if isinstance(v, bytes) else v
|
|
216
|
+
except Exception: pass
|
|
217
|
+
try:
|
|
218
|
+
r = subprocess.run(["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
|
|
219
|
+
capture_output=True, text=True, timeout=5)
|
|
220
|
+
if r.returncode == 0: return r.stdout.strip().split("\n")[0]
|
|
221
|
+
except Exception: pass
|
|
222
|
+
return ""
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _get_cuda_version() -> str:
|
|
226
|
+
try:
|
|
227
|
+
import torch
|
|
228
|
+
if torch.cuda.is_available() and torch.version.cuda: return torch.version.cuda
|
|
229
|
+
except Exception: pass
|
|
230
|
+
return ""
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Platform detection - OS, architecture, platform tags."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import platform as platform_module
|
|
6
|
+
import sys
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class PlatformInfo:
    """Operating system, CPU architecture and platform tag of the running system."""

    os_name: str  # linux, windows, darwin
    arch: str  # x86_64, aarch64, arm64
    platform_tag: str  # linux_x86_64, win_amd64, macosx_11_0_arm64
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def detect_platform() -> PlatformInfo:
    """Collect OS name, lower-cased machine architecture and platform tag."""
    os_name = _get_os_name()
    machine = platform_module.machine().lower()
    tag = get_platform_tag()
    return PlatformInfo(os_name=os_name, arch=machine, platform_tag=tag)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _get_os_name() -> str:
|
|
26
|
+
if sys.platform.startswith('linux'): return 'linux'
|
|
27
|
+
if sys.platform == 'win32': return 'windows'
|
|
28
|
+
if sys.platform == 'darwin': return 'darwin'
|
|
29
|
+
return sys.platform
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# (sys.platform, lower-cased machine) -> platform tag.
_PLATFORM_TAGS = {
    ('linux', 'x86_64'): 'linux_x86_64',
    ('linux', 'amd64'): 'linux_x86_64',
    ('linux', 'aarch64'): 'linux_aarch64',
    ('win32', 'amd64'): 'win_amd64',
    ('win32', 'x86_64'): 'win_amd64',
    ('darwin', 'arm64'): 'macosx_11_0_arm64',
    ('darwin', 'x86_64'): 'macosx_10_9_x86_64',
}


def get_platform_tag() -> str:
    """Return the platform tag for the running interpreter.

    Known (sys.platform, machine) pairs come from _PLATFORM_TAGS; anything
    else falls back to "<sys.platform>_<machine>".
    """
    machine = platform_module.machine().lower()
    try:
        return _PLATFORM_TAGS[(sys.platform, machine)]
    except KeyError:
        return f'{sys.platform}_{machine}'
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# (normalised OS name, lower-cased machine) -> pixi platform name.
_PIXI_PLATFORMS = {
    ('linux', 'x86_64'): 'linux-64',
    ('linux', 'amd64'): 'linux-64',
    ('linux', 'aarch64'): 'linux-aarch64',
    ('windows', 'amd64'): 'win-64',
    ('windows', 'x86_64'): 'win-64',
    ('darwin', 'arm64'): 'osx-arm64',
    ('darwin', 'x86_64'): 'osx-64',
}


def get_pixi_platform() -> str:
    """Return the pixi platform name for the running system.

    Known (OS, machine) pairs come from _PIXI_PLATFORMS; anything else
    falls back to "<os>-<machine>".
    """
    os_name = _get_os_name()
    machine = platform_module.machine().lower()
    fallback = f'{os_name}-{machine}'
    return _PIXI_PLATFORMS.get((os_name, machine), fallback)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_library_extension() -> str:
    """Return the shared-library file extension for the current OS.

    Returns:
        ".dll" on Windows, ".dylib" on macOS, ".so" everywhere else.
    """
    # The previous lookup table was keyed by extension instead of OS name,
    # so every platform got ".so". Branch on sys.platform directly.
    if sys.platform == 'win32':
        return '.dll'
    if sys.platform == 'darwin':
        return '.dylib'
    return '.so'
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_executable_suffix() -> str:
    """Return ".exe" on Windows and an empty string on every other OS."""
    if _get_os_name() == 'windows':
        return '.exe'
    return ''
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def is_linux() -> bool:
    """Return True when the normalised OS name is "linux"."""
    os_name = _get_os_name()
    return os_name == 'linux'


def is_windows() -> bool:
    """Return True when the normalised OS name is "windows"."""
    os_name = _get_os_name()
    return os_name == 'windows'


def is_macos() -> bool:
    """Return True when the normalised OS name is "darwin"."""
    os_name = _get_os_name()
    return os_name == 'darwin'
|