alloc 0.0.5__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alloc-0.0.5 → alloc-0.0.7}/PKG-INFO +1 -1
- {alloc-0.0.5 → alloc-0.0.7}/pyproject.toml +1 -1
- alloc-0.0.7/src/alloc/__init__.py +18 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/cli.py +68 -9
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/extractor_runner.py +3 -1
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/model_extractor.py +27 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/probe.py +156 -8
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/PKG-INFO +1 -1
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/SOURCES.txt +2 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_auth.py +12 -1
- alloc-0.0.7/tests/test_ghost_degradation.py +145 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_probe_multi.py +79 -1
- alloc-0.0.7/tests/test_scan_auth.py +142 -0
- alloc-0.0.5/src/alloc/__init__.py +0 -11
- {alloc-0.0.5 → alloc-0.0.7}/README.md +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/setup.cfg +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/artifact_loader.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/artifact_writer.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/browser_auth.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/callbacks.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/catalog/__init__.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/catalog/default_rate_card.json +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/catalog/gpus.v1.json +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/code_analyzer.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/config.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/context.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/diagnosis_display.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/diagnosis_engine.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/diagnosis_rules.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/display.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/ghost.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/model_registry.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/stability.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/upload.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc/yaml_config.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/dependency_links.txt +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/entry_points.txt +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/requires.txt +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/top_level.txt +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_artifact.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_artifact_loader.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_callbacks.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_catalog.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_cli.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_code_analyzer.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_context.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_diagnose_cli.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_diagnosis_engine.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_diagnosis_rules.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_extractor_activation.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_ghost.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_init_from_org.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_interconnect.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_model_extractor.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_probe_hw.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_stability.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_upload.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_verdict.py +0 -0
- {alloc-0.0.5 → alloc-0.0.7}/tests/test_yaml_config.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alloc
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
|
|
5
5
|
Author-email: Alloc Labs <hello@alloclabs.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "alloc"
|
|
7
|
-
version = "0.0.5"
|
|
7
|
+
version = "0.0.7"
|
|
8
8
|
description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "Apache-2.0"
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Alloc — GPU intelligence for ML training."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import warnings as _warnings
|
|
6
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
7
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
8
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
9
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
10
|
+
del _warnings
|
|
11
|
+
|
|
12
|
+
__version__ = "0.0.7"
|
|
13
|
+
|
|
14
|
+
from alloc.ghost import ghost, GhostReport
|
|
15
|
+
from alloc.callbacks import AllocCallback as HuggingFaceCallback
|
|
16
|
+
from alloc.callbacks import AllocLightningCallback as LightningCallback
|
|
17
|
+
|
|
18
|
+
__all__ = ["ghost", "GhostReport", "HuggingFaceCallback", "LightningCallback", "__version__"]
|
|
@@ -16,8 +16,17 @@ from __future__ import annotations
|
|
|
16
16
|
|
|
17
17
|
import os
|
|
18
18
|
import sys
|
|
19
|
+
import warnings
|
|
19
20
|
from typing import List, Optional
|
|
20
21
|
|
|
22
|
+
# Suppress noisy third-party warnings globally — pynvml deprecation (emitted
|
|
23
|
+
# from torch.cuda.__init__) and urllib3 LibreSSL warnings clutter CLI output.
|
|
24
|
+
warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
25
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
26
|
+
warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
27
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
28
|
+
warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
|
|
29
|
+
|
|
21
30
|
import typer
|
|
22
31
|
from rich.console import Console
|
|
23
32
|
|
|
@@ -68,6 +77,19 @@ def ghost(
|
|
|
68
77
|
console.print(f"[dim]Tip: alloc ghost {script} --param-count-b 7.0[/dim]")
|
|
69
78
|
raise typer.Exit(1)
|
|
70
79
|
|
|
80
|
+
if info.extraction_error:
|
|
81
|
+
if json_output:
|
|
82
|
+
_print_json({
|
|
83
|
+
"error": info.extraction_error,
|
|
84
|
+
"detail": info.extraction_detail,
|
|
85
|
+
"supported": False,
|
|
86
|
+
})
|
|
87
|
+
else:
|
|
88
|
+
console.print(f"[yellow]{info.extraction_detail}[/yellow]")
|
|
89
|
+
if info.extraction_error == "distributed_entrypoint":
|
|
90
|
+
console.print("[dim]Tip: alloc ghost model.py (point to the file that defines your model)[/dim]")
|
|
91
|
+
raise typer.Exit(1)
|
|
92
|
+
|
|
71
93
|
# Use dtype from execution if available, otherwise CLI flag
|
|
72
94
|
resolved_dtype = info.dtype if info.method == "execution" else dtype
|
|
73
95
|
|
|
@@ -2092,12 +2114,32 @@ def scan(
|
|
|
2092
2114
|
|
|
2093
2115
|
try:
|
|
2094
2116
|
headers = {"Content-Type": "application/json"}
|
|
2117
|
+
used_auth = bool(token)
|
|
2118
|
+
|
|
2095
2119
|
if token:
|
|
2096
2120
|
headers["Authorization"] = f"Bearer {token}"
|
|
2121
|
+
endpoint = "/scans"
|
|
2122
|
+
else:
|
|
2123
|
+
endpoint = "/scans/cli"
|
|
2097
2124
|
|
|
2098
|
-
endpoint = "/scans" if token else "/scans/cli"
|
|
2099
2125
|
with httpx.Client(timeout=30) as client:
|
|
2100
2126
|
resp = client.post(f"{api_url}{endpoint}", json=payload, headers=headers)
|
|
2127
|
+
|
|
2128
|
+
# On 401 with a saved token: try refresh, then fall back to public endpoint
|
|
2129
|
+
if resp.status_code == 401 and used_auth:
|
|
2130
|
+
new_token = try_refresh_access_token()
|
|
2131
|
+
if new_token:
|
|
2132
|
+
headers["Authorization"] = f"Bearer {new_token}"
|
|
2133
|
+
resp = client.post(f"{api_url}/scans", json=payload, headers=headers)
|
|
2134
|
+
else:
|
|
2135
|
+
# Token refresh failed — fall back to unauthenticated scan
|
|
2136
|
+
console.print(
|
|
2137
|
+
"[yellow]Session expired — falling back to public scan "
|
|
2138
|
+
"(org fleet context unavailable). Run `alloc login` to restore.[/yellow]",
|
|
2139
|
+
)
|
|
2140
|
+
del headers["Authorization"]
|
|
2141
|
+
resp = client.post(f"{api_url}/scans/cli", json=payload, headers=headers)
|
|
2142
|
+
|
|
2101
2143
|
resp.raise_for_status()
|
|
2102
2144
|
result = resp.json()
|
|
2103
2145
|
|
|
@@ -2107,7 +2149,12 @@ def scan(
|
|
|
2107
2149
|
_print_scan_result(result, gpu, strategy)
|
|
2108
2150
|
except httpx.HTTPStatusError as e:
|
|
2109
2151
|
if json_output:
|
|
2110
|
-
|
|
2152
|
+
detail = ""
|
|
2153
|
+
try:
|
|
2154
|
+
detail = e.response.json().get("detail", "")
|
|
2155
|
+
except Exception:
|
|
2156
|
+
pass
|
|
2157
|
+
_print_json({"error": f"API error {e.response.status_code}", "detail": detail})
|
|
2111
2158
|
elif e.response.status_code == 403:
|
|
2112
2159
|
console.print("[yellow]AI analysis requires a Pro or Enterprise plan.[/yellow]")
|
|
2113
2160
|
console.print("[dim]The scan still works — just without AI-powered analysis.[/dim]")
|
|
@@ -2147,18 +2194,30 @@ def login(
|
|
|
2147
2194
|
),
|
|
2148
2195
|
):
|
|
2149
2196
|
"""Authenticate with Alloc dashboard."""
|
|
2150
|
-
# Suppress noisy third-party warnings (urllib3 LibreSSL, pynvml deprecation)
|
|
2151
|
-
# that clutter the auth flow output.
|
|
2152
|
-
import warnings
|
|
2153
|
-
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
2154
|
-
warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
|
|
2155
|
-
warnings.filterwarnings("ignore", message=".*pynvml.*", category=FutureWarning)
|
|
2156
|
-
|
|
2157
2197
|
import httpx
|
|
2158
2198
|
from alloc.config import get_supabase_url, get_supabase_anon_key, load_config, save_config
|
|
2159
2199
|
|
|
2160
2200
|
# --- Browser OAuth flow ---
|
|
2161
2201
|
if browser:
|
|
2202
|
+
# Detect headless/SSH environments where browser login won't work
|
|
2203
|
+
is_headless = (
|
|
2204
|
+
not os.environ.get("DISPLAY")
|
|
2205
|
+
and not os.environ.get("WAYLAND_DISPLAY")
|
|
2206
|
+
and sys.platform != "darwin"
|
|
2207
|
+
and sys.platform != "win32"
|
|
2208
|
+
)
|
|
2209
|
+
is_ssh = bool(os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"))
|
|
2210
|
+
if is_headless or is_ssh:
|
|
2211
|
+
console.print("[yellow]Headless/SSH environment detected — browser login won't work here.[/yellow]")
|
|
2212
|
+
console.print()
|
|
2213
|
+
console.print("Use token login instead:")
|
|
2214
|
+
console.print(" 1. Log in at [cyan]https://alloclabs.com[/cyan] in your local browser")
|
|
2215
|
+
console.print(" 2. Open DevTools → Application → Local Storage → copy your access_token")
|
|
2216
|
+
console.print(" 3. Run: [green]alloc login --method token --token <paste-token>[/green]")
|
|
2217
|
+
console.print()
|
|
2218
|
+
console.print("[dim]Or set ALLOC_TOKEN=<token> in your environment.[/dim]")
|
|
2219
|
+
raise typer.Exit(1)
|
|
2220
|
+
|
|
2162
2221
|
provider = (provider or "").strip().lower()
|
|
2163
2222
|
if provider not in ("google", "azure"):
|
|
2164
2223
|
console.print("[red]Invalid --provider. Use: google or azure[/red]")
|
|
@@ -217,7 +217,9 @@ def main():
|
|
|
217
217
|
try:
|
|
218
218
|
obj = getattr(module, attr_name)
|
|
219
219
|
if isinstance(obj, nn.Module):
|
|
220
|
-
|
|
220
|
+
# Unwrap DDP/FSDP wrappers to get the underlying model
|
|
221
|
+
unwrapped = getattr(obj, "module", obj)
|
|
222
|
+
count, dtype_str = _count_params(unwrapped)
|
|
221
223
|
if count > 0:
|
|
222
224
|
models.append((count, dtype_str, attr_name))
|
|
223
225
|
except Exception:
|
|
@@ -33,6 +33,8 @@ class ModelInfo:
|
|
|
33
33
|
seq_length: Optional[int] = None
|
|
34
34
|
activation_memory_bytes: Optional[int] = None
|
|
35
35
|
activation_method: Optional[str] = None # "traced" | None
|
|
36
|
+
extraction_error: Optional[str] = None # "distributed_entrypoint" | None
|
|
37
|
+
extraction_detail: Optional[str] = None # human-readable explanation
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
def extract_model_info(
|
|
@@ -99,6 +101,13 @@ def _extract_via_subprocess(
|
|
|
99
101
|
|
|
100
102
|
env = os.environ.copy()
|
|
101
103
|
env["CUDA_VISIBLE_DEVICES"] = "" # prevent GPU allocation
|
|
104
|
+
# Set distributed env vars so DDP scripts don't crash on
|
|
105
|
+
# torch.distributed.init_process_group() during model extraction.
|
|
106
|
+
env.setdefault("RANK", "0")
|
|
107
|
+
env.setdefault("LOCAL_RANK", "0")
|
|
108
|
+
env.setdefault("WORLD_SIZE", "1")
|
|
109
|
+
env.setdefault("MASTER_ADDR", "127.0.0.1")
|
|
110
|
+
env.setdefault("MASTER_PORT", "29500")
|
|
102
111
|
|
|
103
112
|
subprocess.run(
|
|
104
113
|
[sys.executable, "-m", "alloc.extractor_runner", sidecar_path, script_abs],
|
|
@@ -127,6 +136,24 @@ def _extract_via_subprocess(
|
|
|
127
136
|
activation_method=data.get("activation_method"),
|
|
128
137
|
)
|
|
129
138
|
|
|
139
|
+
# Structured degradation for distributed scripts
|
|
140
|
+
if data.get("status") == "error":
|
|
141
|
+
error_msg = data.get("error", "")
|
|
142
|
+
_dist_keywords = ("init_process_group", "NCCL", "gloo", "distributed",
|
|
143
|
+
"MASTER_ADDR", "MASTER_PORT", "RendezvousError")
|
|
144
|
+
if any(kw.lower() in error_msg.lower() for kw in _dist_keywords):
|
|
145
|
+
return ModelInfo(
|
|
146
|
+
param_count=0,
|
|
147
|
+
dtype="float16",
|
|
148
|
+
model_name=None,
|
|
149
|
+
method="execution",
|
|
150
|
+
extraction_error="distributed_entrypoint",
|
|
151
|
+
extraction_detail=(
|
|
152
|
+
"Script requires a distributed runtime (e.g. torchrun). "
|
|
153
|
+
"Run ghost on the model definition file instead of the launcher script."
|
|
154
|
+
),
|
|
155
|
+
)
|
|
156
|
+
|
|
130
157
|
return None
|
|
131
158
|
|
|
132
159
|
except subprocess.TimeoutExpired:
|
|
@@ -18,6 +18,13 @@ from dataclasses import dataclass, field
|
|
|
18
18
|
from enum import Enum
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
|
|
21
|
+
import warnings as _warnings
|
|
22
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
23
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
24
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
25
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
26
|
+
del _warnings
|
|
27
|
+
|
|
21
28
|
|
|
22
29
|
class StopReason(str, Enum):
|
|
23
30
|
STABLE = "stable"
|
|
@@ -83,6 +90,85 @@ def _try_import_pynvml():
|
|
|
83
90
|
|
|
84
91
|
|
|
85
92
|
|
|
93
|
+
def _parse_launcher_gpu_count(command):
|
|
94
|
+
# type: (list) -> Optional[int]
|
|
95
|
+
"""Extract expected GPU count from launcher command line.
|
|
96
|
+
|
|
97
|
+
Recognizes: torchrun, torch.distributed.launch, accelerate launch,
|
|
98
|
+
deepspeed, mpirun/mpiexec.
|
|
99
|
+
"""
|
|
100
|
+
args = [str(a) for a in command]
|
|
101
|
+
cmd_str = " ".join(args)
|
|
102
|
+
|
|
103
|
+
# torchrun --nproc_per_node=N or --nproc-per-node=N
|
|
104
|
+
for i, a in enumerate(args):
|
|
105
|
+
for flag in ("--nproc_per_node", "--nproc-per-node"):
|
|
106
|
+
if a.startswith(flag + "="):
|
|
107
|
+
try:
|
|
108
|
+
return int(a.split("=", 1)[1])
|
|
109
|
+
except ValueError:
|
|
110
|
+
pass
|
|
111
|
+
elif a == flag and i + 1 < len(args):
|
|
112
|
+
try:
|
|
113
|
+
return int(args[i + 1])
|
|
114
|
+
except ValueError:
|
|
115
|
+
pass
|
|
116
|
+
|
|
117
|
+
# accelerate launch --num_processes=N or --num-processes=N
|
|
118
|
+
for i, a in enumerate(args):
|
|
119
|
+
for flag in ("--num_processes", "--num-processes"):
|
|
120
|
+
if a.startswith(flag + "="):
|
|
121
|
+
try:
|
|
122
|
+
return int(a.split("=", 1)[1])
|
|
123
|
+
except ValueError:
|
|
124
|
+
pass
|
|
125
|
+
elif a == flag and i + 1 < len(args):
|
|
126
|
+
try:
|
|
127
|
+
return int(args[i + 1])
|
|
128
|
+
except ValueError:
|
|
129
|
+
pass
|
|
130
|
+
|
|
131
|
+
# deepspeed --num_gpus=N or --num-gpus=N
|
|
132
|
+
for i, a in enumerate(args):
|
|
133
|
+
for flag in ("--num_gpus", "--num-gpus"):
|
|
134
|
+
if a.startswith(flag + "="):
|
|
135
|
+
try:
|
|
136
|
+
return int(a.split("=", 1)[1])
|
|
137
|
+
except ValueError:
|
|
138
|
+
pass
|
|
139
|
+
elif a == flag and i + 1 < len(args):
|
|
140
|
+
try:
|
|
141
|
+
return int(args[i + 1])
|
|
142
|
+
except ValueError:
|
|
143
|
+
pass
|
|
144
|
+
|
|
145
|
+
# mpirun/mpiexec -np N or -n N
|
|
146
|
+
for i, a in enumerate(args):
|
|
147
|
+
if a in ("-np", "-n") and i + 1 < len(args):
|
|
148
|
+
if any(x in cmd_str for x in ("mpirun", "mpiexec")):
|
|
149
|
+
try:
|
|
150
|
+
return int(args[i + 1])
|
|
151
|
+
except ValueError:
|
|
152
|
+
pass
|
|
153
|
+
|
|
154
|
+
return None
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _read_child_env(pid, var_name):
|
|
158
|
+
# type: (int, str) -> Optional[str]
|
|
159
|
+
"""Read an environment variable from a child process via /proc (Linux only)."""
|
|
160
|
+
try:
|
|
161
|
+
env_path = "/proc/{}/environ".format(pid)
|
|
162
|
+
with open(env_path, "rb") as f:
|
|
163
|
+
env_data = f.read()
|
|
164
|
+
for entry in env_data.split(b"\x00"):
|
|
165
|
+
if entry.startswith(var_name.encode() + b"="):
|
|
166
|
+
return entry.split(b"=", 1)[1].decode("utf-8")
|
|
167
|
+
except Exception:
|
|
168
|
+
pass
|
|
169
|
+
return None
|
|
170
|
+
|
|
171
|
+
|
|
86
172
|
def _get_child_pids(pid):
|
|
87
173
|
# type: (int) -> List[int]
|
|
88
174
|
"""Get child PIDs of a process. Returns empty list on failure."""
|
|
@@ -101,11 +187,15 @@ def _get_child_pids(pid):
|
|
|
101
187
|
return []
|
|
102
188
|
|
|
103
189
|
|
|
104
|
-
def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0):
|
|
105
|
-
# type: (int, ..., int) -> List[int]
|
|
190
|
+
def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0, expected_gpus=None):
|
|
191
|
+
# type: (int, ..., int, Optional[int]) -> List[int]
|
|
106
192
|
"""Discover which GPUs a process (and its children) are using.
|
|
107
193
|
|
|
108
|
-
|
|
194
|
+
Two strategies:
|
|
195
|
+
1. PID-matching: walks process tree, matches PIDs against NVML compute processes.
|
|
196
|
+
2. Active-GPU counting: counts GPUs with ANY compute processes. Used when PID
|
|
197
|
+
matching finds fewer GPUs than expected (common for DDP launchers).
|
|
198
|
+
|
|
109
199
|
Falls back to [fallback_index] if discovery fails or finds nothing.
|
|
110
200
|
"""
|
|
111
201
|
try:
|
|
@@ -143,11 +233,48 @@ def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0):
|
|
|
143
233
|
for ggchild in _get_child_pids(grandchild):
|
|
144
234
|
target_pids.add(ggchild)
|
|
145
235
|
|
|
236
|
+
# Also try reading WORLD_SIZE from child process environments (Linux)
|
|
237
|
+
child_world_size = None
|
|
238
|
+
for child in _get_child_pids(proc_pid):
|
|
239
|
+
ws = _read_child_env(child, "WORLD_SIZE")
|
|
240
|
+
if ws:
|
|
241
|
+
try:
|
|
242
|
+
child_world_size = int(ws)
|
|
243
|
+
except ValueError:
|
|
244
|
+
pass
|
|
245
|
+
break
|
|
246
|
+
for grandchild in _get_child_pids(child):
|
|
247
|
+
ws = _read_child_env(grandchild, "WORLD_SIZE")
|
|
248
|
+
if ws:
|
|
249
|
+
try:
|
|
250
|
+
child_world_size = int(ws)
|
|
251
|
+
except ValueError:
|
|
252
|
+
pass
|
|
253
|
+
break
|
|
254
|
+
for ggchild in _get_child_pids(grandchild):
|
|
255
|
+
ws = _read_child_env(ggchild, "WORLD_SIZE")
|
|
256
|
+
if ws:
|
|
257
|
+
try:
|
|
258
|
+
child_world_size = int(ws)
|
|
259
|
+
except ValueError:
|
|
260
|
+
pass
|
|
261
|
+
break
|
|
262
|
+
|
|
263
|
+
# Determine how many GPUs we expect
|
|
264
|
+
effective_expected = expected_gpus
|
|
265
|
+
if child_world_size is not None:
|
|
266
|
+
if effective_expected is None or child_world_size > effective_expected:
|
|
267
|
+
effective_expected = child_world_size
|
|
268
|
+
|
|
269
|
+
# Strategy 1: PID-based matching
|
|
146
270
|
found_indices = []
|
|
271
|
+
active_indices = [] # GPUs with ANY compute processes
|
|
147
272
|
for idx in search_indices:
|
|
148
273
|
try:
|
|
149
274
|
handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
|
|
150
275
|
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
|
|
276
|
+
if procs:
|
|
277
|
+
active_indices.append(idx)
|
|
151
278
|
for p in procs:
|
|
152
279
|
if p.pid in target_pids:
|
|
153
280
|
found_indices.append(idx)
|
|
@@ -155,6 +282,16 @@ def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0):
|
|
|
155
282
|
except Exception:
|
|
156
283
|
continue
|
|
157
284
|
|
|
285
|
+
# Strategy 2: If PID matching found fewer than expected, use active GPU count.
|
|
286
|
+
# This handles the common case where DDP workers are running but their PIDs
|
|
287
|
+
# weren't in our process tree (e.g., process group isolation, timing issues).
|
|
288
|
+
if effective_expected is not None and len(found_indices) < effective_expected:
|
|
289
|
+
if len(active_indices) >= effective_expected:
|
|
290
|
+
return active_indices[:effective_expected]
|
|
291
|
+
# Even if active < expected, active is better than found
|
|
292
|
+
if len(active_indices) > len(found_indices):
|
|
293
|
+
return active_indices
|
|
294
|
+
|
|
158
295
|
return found_indices if found_indices else [fallback_index]
|
|
159
296
|
|
|
160
297
|
|
|
@@ -310,12 +447,17 @@ def probe_command(
|
|
|
310
447
|
discovery_attempts = 0
|
|
311
448
|
max_discovery_attempts = 3 # Retry at samples 5, 15, 30
|
|
312
449
|
|
|
313
|
-
# Determine expected GPU count from
|
|
450
|
+
# Determine expected GPU count from command line + environment
|
|
314
451
|
expected_gpus = 1
|
|
452
|
+
launcher_gpus = _parse_launcher_gpu_count(command)
|
|
453
|
+
if launcher_gpus is not None and launcher_gpus > 1:
|
|
454
|
+
expected_gpus = launcher_gpus
|
|
315
455
|
ws = os.environ.get("WORLD_SIZE", "").strip()
|
|
316
456
|
if ws:
|
|
317
457
|
try:
|
|
318
|
-
|
|
458
|
+
ws_int = max(1, int(ws))
|
|
459
|
+
if ws_int > expected_gpus:
|
|
460
|
+
expected_gpus = ws_int
|
|
319
461
|
except ValueError:
|
|
320
462
|
pass
|
|
321
463
|
|
|
@@ -329,7 +471,10 @@ def probe_command(
|
|
|
329
471
|
and proc.pid):
|
|
330
472
|
discovery_attempts += 1
|
|
331
473
|
try:
|
|
332
|
-
discovered = _discover_gpu_indices(
|
|
474
|
+
discovered = _discover_gpu_indices(
|
|
475
|
+
proc.pid, pynvml, fallback_index=gpu_index,
|
|
476
|
+
expected_gpus=expected_gpus if expected_gpus > 1 else None,
|
|
477
|
+
)
|
|
333
478
|
if len(discovered) > 1:
|
|
334
479
|
handles = []
|
|
335
480
|
pmap = []
|
|
@@ -456,9 +601,12 @@ def probe_command(
|
|
|
456
601
|
if calibration_time_ref[0] is not None:
|
|
457
602
|
cal_duration = round(calibration_time_ref[0] - start_time, 2)
|
|
458
603
|
|
|
459
|
-
#
|
|
460
|
-
#
|
|
604
|
+
# Final fallback: if NVML discovery found fewer GPUs than expected,
|
|
605
|
+
# trust the command-line / environment signals. The probe may miss GPUs
|
|
461
606
|
# due to DDP per-rank CVD isolation or timing races.
|
|
607
|
+
launcher_count = _parse_launcher_gpu_count(command)
|
|
608
|
+
if launcher_count is not None and launcher_count > num_gpus_ref[0]:
|
|
609
|
+
num_gpus_ref[0] = launcher_count
|
|
462
610
|
env_world = os.environ.get("WORLD_SIZE", "").strip()
|
|
463
611
|
if env_world:
|
|
464
612
|
try:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alloc
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
|
|
5
5
|
Author-email: Alloc Labs <hello@alloclabs.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -43,11 +43,13 @@ tests/test_diagnosis_engine.py
|
|
|
43
43
|
tests/test_diagnosis_rules.py
|
|
44
44
|
tests/test_extractor_activation.py
|
|
45
45
|
tests/test_ghost.py
|
|
46
|
+
tests/test_ghost_degradation.py
|
|
46
47
|
tests/test_init_from_org.py
|
|
47
48
|
tests/test_interconnect.py
|
|
48
49
|
tests/test_model_extractor.py
|
|
49
50
|
tests/test_probe_hw.py
|
|
50
51
|
tests/test_probe_multi.py
|
|
52
|
+
tests/test_scan_auth.py
|
|
51
53
|
tests/test_stability.py
|
|
52
54
|
tests/test_upload.py
|
|
53
55
|
tests/test_verdict.py
|
|
@@ -223,6 +223,7 @@ def test_browser_login_saves_tokens(tmp_path: Path):
|
|
|
223
223
|
env = {
|
|
224
224
|
"HOME": str(tmp_path),
|
|
225
225
|
"ALLOC_API_URL": "https://api.example.com",
|
|
226
|
+
"DISPLAY": ":0", # prevent headless detection in CI
|
|
226
227
|
}
|
|
227
228
|
|
|
228
229
|
with patch("alloc.browser_auth.browser_login", return_value=mock_result):
|
|
@@ -248,6 +249,7 @@ def test_browser_login_with_azure_provider(tmp_path: Path):
|
|
|
248
249
|
env = {
|
|
249
250
|
"HOME": str(tmp_path),
|
|
250
251
|
"ALLOC_API_URL": "https://api.example.com",
|
|
252
|
+
"DISPLAY": ":0",
|
|
251
253
|
}
|
|
252
254
|
|
|
253
255
|
with patch("alloc.browser_auth.browser_login", return_value=mock_result) as mock_bl:
|
|
@@ -263,7 +265,7 @@ def test_browser_login_with_azure_provider(tmp_path: Path):
|
|
|
263
265
|
|
|
264
266
|
|
|
265
267
|
def test_browser_login_invalid_provider(tmp_path: Path):
|
|
266
|
-
env = {"HOME": str(tmp_path)}
|
|
268
|
+
env = {"HOME": str(tmp_path), "DISPLAY": ":0"}
|
|
267
269
|
result = runner.invoke(
|
|
268
270
|
app, ["login", "--browser", "--provider", "facebook"], env=env
|
|
269
271
|
)
|
|
@@ -271,10 +273,19 @@ def test_browser_login_invalid_provider(tmp_path: Path):
|
|
|
271
273
|
assert "Invalid --provider" in result.output
|
|
272
274
|
|
|
273
275
|
|
|
276
|
+
def test_browser_login_headless_detection(tmp_path: Path):
    """Browser login with SSH_CLIENT set must exit non-zero and print guidance."""
    ssh_env = {
        "HOME": str(tmp_path),
        "SSH_CLIENT": "1.2.3.4 1234 22",
        "DISPLAY": ":0",
    }
    outcome = runner.invoke(app, ["login", "--browser"], env=ssh_env)

    assert outcome.exit_code != 0
    # Either the headless warning or the token-login hint is acceptable.
    assert any(marker in outcome.output for marker in ("Headless", "token"))
|
|
282
|
+
|
|
283
|
+
|
|
274
284
|
def test_browser_login_timeout(tmp_path: Path):
|
|
275
285
|
env = {
|
|
276
286
|
"HOME": str(tmp_path),
|
|
277
287
|
"ALLOC_API_URL": "https://api.example.com",
|
|
288
|
+
"DISPLAY": ":0",
|
|
278
289
|
}
|
|
279
290
|
|
|
280
291
|
with patch(
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Tests for ghost structured degradation on distributed scripts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import tempfile
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from unittest.mock import patch
|
|
10
|
+
|
|
11
|
+
from typer.testing import CliRunner
|
|
12
|
+
|
|
13
|
+
from alloc.cli import app
|
|
14
|
+
from alloc.model_extractor import ModelInfo, extract_model_info
|
|
15
|
+
|
|
16
|
+
runner = CliRunner()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_distributed_error_returns_structured_modelinfo():
    """A distributed-keyword error in the sidecar yields a structured ModelInfo."""
    payload = json.dumps({
        "status": "error",
        "error": "RuntimeError: torch.distributed.init_process_group requires MASTER_ADDR",
    })

    def _write_sidecar(*args, **kwargs):
        # argv layout: [python, -m, alloc.extractor_runner, sidecar_path, script_path]
        with open(args[0][3], "w") as sidecar:
            sidecar.write(payload)

    # Dummy script on disk for extract_model_info to point at.
    handle, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_dist_")
    with os.fdopen(handle, "wb") as script:
        script.write(b"import torch\ntorch.distributed.init_process_group('nccl')\n")

    try:
        with patch("subprocess.run", side_effect=_write_sidecar):
            extracted = extract_model_info(script_path)

        assert extracted is not None
        assert extracted.extraction_error == "distributed_entrypoint"
        assert "distributed runtime" in extracted.extraction_detail
        assert extracted.param_count == 0
    finally:
        os.unlink(script_path)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_distributed_error_nccl_keyword():
    """An error message mentioning NCCL is classified as a distributed failure."""
    payload = json.dumps({
        "status": "error",
        "error": "NCCL error: unhandled system error",
    })

    def _write_sidecar(*args, **kwargs):
        # The sidecar path is the fourth element of the subprocess argv.
        with open(args[0][3], "w") as sidecar:
            sidecar.write(payload)

    handle, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_nccl_")
    with os.fdopen(handle, "wb") as script:
        script.write(b"pass\n")

    try:
        with patch("subprocess.run", side_effect=_write_sidecar):
            extracted = extract_model_info(script_path)

        assert extracted is not None
        assert extracted.extraction_error == "distributed_entrypoint"
    finally:
        os.unlink(script_path)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def test_non_distributed_error_returns_none():
    """An error without distributed keywords still results in None."""
    payload = json.dumps({
        "status": "error",
        "error": "ImportError: No module named 'custom_lib'",
    })

    def _write_sidecar(*args, **kwargs):
        # The sidecar path is the fourth element of the subprocess argv.
        with open(args[0][3], "w") as sidecar:
            sidecar.write(payload)

    handle, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_other_")
    # The script has no from_pretrained call, so the AST fallback is also
    # expected to find no model.
    with os.fdopen(handle, "wb") as script:
        script.write(b"import custom_lib\n")

    try:
        with patch("subprocess.run", side_effect=_write_sidecar):
            extracted = extract_model_info(script_path)

        # Not a distributed error, and nothing for the AST path to find.
        assert extracted is None
    finally:
        os.unlink(script_path)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_ghost_cli_distributed_error_json(tmp_path: Path):
    """`ghost --json` emits the structured error for a distributed script."""
    ddp_script = tmp_path / "train_ddp.py"
    ddp_script.write_text("import torch\ntorch.distributed.init_process_group('nccl')\n")

    degraded = ModelInfo(
        param_count=0,
        dtype="float16",
        model_name=None,
        method="execution",
        extraction_error="distributed_entrypoint",
        extraction_detail=(
            "Script requires a distributed runtime (e.g. torchrun). "
            "Run ghost on the model definition file instead of the launcher script."
        ),
    )

    # Bypass real extraction: the CLI receives the degraded ModelInfo directly.
    with patch("alloc.model_extractor.extract_model_info", return_value=degraded):
        outcome = runner.invoke(app, ["ghost", str(ddp_script), "--json"])

    assert outcome.exit_code != 0
    parsed = json.loads(outcome.output)
    assert parsed["error"] == "distributed_entrypoint"
    assert parsed["supported"] is False
    assert "distributed runtime" in parsed["detail"]
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_ghost_cli_distributed_error_human(tmp_path: Path):
    """Without `--json`, `ghost` must print the distributed-runtime message
    plus a tip mentioning `model.py`, and must exit non-zero."""
    launcher = tmp_path / "train_ddp.py"
    launcher.write_text("import torch\n")

    # Canned extraction result handed back by the patched extractor.
    stub_fields = dict(
        param_count=0,
        dtype="float16",
        model_name=None,
        method="execution",
        extraction_error="distributed_entrypoint",
        extraction_detail="Script requires a distributed runtime (e.g. torchrun). Run ghost on the model definition file instead of the launcher script.",
    )
    stubbed_info = ModelInfo(**stub_fields)

    extractor_patch = patch("alloc.model_extractor.extract_model_info", return_value=stubbed_info)
    with extractor_patch:
        outcome = runner.invoke(app, ["ghost", str(launcher)])

    assert outcome.exit_code != 0

    printed = outcome.output
    assert "distributed runtime" in printed
    # The tip should point the user at the model definition file.
    assert "model.py" in printed
|
|
@@ -4,7 +4,12 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from unittest.mock import MagicMock, patch
|
|
6
6
|
|
|
7
|
-
from alloc.probe import
|
|
7
|
+
from alloc.probe import (
|
|
8
|
+
_discover_gpu_indices,
|
|
9
|
+
_get_child_pids,
|
|
10
|
+
_parse_launcher_gpu_count,
|
|
11
|
+
ProbeResult,
|
|
12
|
+
)
|
|
8
13
|
|
|
9
14
|
|
|
10
15
|
def _mock_pynvml_multi_gpu(proc_pid, gpu_process_map):
|
|
@@ -112,3 +117,76 @@ def test_probe_result_defaults_single_gpu():
|
|
|
112
117
|
result = ProbeResult()
|
|
113
118
|
assert result.num_gpus_detected == 1
|
|
114
119
|
assert result.process_map is None
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# ── Launcher command-line parsing ──
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def test_parse_torchrun_equals():
    """torchrun with `--nproc_per_node=2` (equals form) parses to 2."""
    argv = ["torchrun", "--nproc_per_node=2", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 2
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def test_parse_torchrun_space():
    """torchrun with `--nproc_per_node 4` (separate value token) parses to 4."""
    argv = ["torchrun", "--nproc_per_node", "4", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 4
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def test_parse_torchrun_hyphen():
    """torchrun with the hyphenated spelling `--nproc-per-node=8` parses to 8."""
    argv = ["torchrun", "--nproc-per-node=8", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 8
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def test_parse_accelerate_equals():
    """`accelerate launch --num_processes=4` parses to 4."""
    argv = ["accelerate", "launch", "--num_processes=4", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 4
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def test_parse_accelerate_hyphen_space():
    """`accelerate launch --num-processes 8` (hyphenated, separate value) parses to 8."""
    argv = ["accelerate", "launch", "--num-processes", "8", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 8
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_parse_deepspeed_equals():
    """deepspeed with `--num_gpus=4` parses to 4."""
    argv = ["deepspeed", "--num_gpus=4", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 4
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def test_parse_deepspeed_hyphen_space():
    """deepspeed with `--num-gpus 2` (hyphenated, separate value) parses to 2."""
    argv = ["deepspeed", "--num-gpus", "2", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 2
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def test_parse_mpirun():
    """`mpirun -np 8 python train.py` parses to 8."""
    argv = ["mpirun", "-np", "8", "python", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 8
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_parse_plain_python():
    """A plain `python train.py` command yields None."""
    argv = ["python", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed is None
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def test_parse_torch_distributed_launch():
    """`python -m torch.distributed.launch --nproc_per_node=2` parses to 2."""
    argv = ["python", "-m", "torch.distributed.launch", "--nproc_per_node=2", "train.py"]
    parsed = _parse_launcher_gpu_count(argv)
    assert parsed == 2
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# ── Active-GPU fallback discovery ──
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def test_active_gpu_fallback_when_pid_mismatch():
    """If no GPU process PID matches the probed PID, but GPUs report
    processes and `expected_gpus` is given, those GPUs are returned."""
    # Neither 9999 nor 8888 matches the probed PID 1000.
    foreign_pids_by_gpu = {0: [9999], 1: [8888]}
    pynvml_stub = _mock_pynvml_multi_gpu(
        proc_pid=1000,
        gpu_process_map=foreign_pids_by_gpu,
    )

    no_children = patch("alloc.probe._get_child_pids", return_value=[])
    no_child_env = patch("alloc.probe._read_child_env", return_value=None)
    with no_children, no_child_env:
        discovered = _discover_gpu_indices(1000, pynvml_stub, fallback_index=0, expected_gpus=2)

    assert discovered == [0, 1]
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_active_gpu_fallback_not_used_without_expected():
    """When `expected_gpus` is omitted, discovery returns only the fallback index."""
    foreign_pids_by_gpu = {0: [9999], 1: [8888]}
    pynvml_stub = _mock_pynvml_multi_gpu(
        proc_pid=1000,
        gpu_process_map=foreign_pids_by_gpu,
    )

    no_children = patch("alloc.probe._get_child_pids", return_value=[])
    no_child_env = patch("alloc.probe._read_child_env", return_value=None)
    with no_children, no_child_env:
        discovered = _discover_gpu_indices(1000, pynvml_stub, fallback_index=0)

    # Only the default fallback index comes back.
    assert discovered == [0]
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Tests for scan command 401 retry + /scans/cli fallback."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from unittest.mock import MagicMock, patch
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
from typer.testing import CliRunner
|
|
11
|
+
|
|
12
|
+
from alloc.cli import app
|
|
13
|
+
|
|
14
|
+
runner = CliRunner()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _make_resp(status_code: int, body: dict, url: str = "https://api.example.com/scans"):
    """Build an `httpx.Response` carrying a JSON body.

    The response is bound to a POST request for `url`. `body` is serialised
    with `json.dumps` and sent with a JSON content-type header.
    """
    origin_request = httpx.Request("POST", url)
    encoded_body = json.dumps(body).encode()
    json_headers = {"content-type": "application/json"}
    return httpx.Response(
        status_code,
        request=origin_request,
        content=encoded_body,
        headers=json_headers,
    )
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_scan_401_refresh_retry(tmp_path: Path):
    """After a 401 on the first POST, a successful token refresh leads to a
    second POST that carries the refreshed bearer token."""
    # First POST is rejected, the retry succeeds.
    queued_responses = [
        _make_resp(401, {"detail": "unauthorized"}),
        _make_resp(200, {"vram_gb": 16.0, "configs": []}),
    ]

    fake_client = MagicMock()
    # The mock acts as its own context-manager target.
    fake_client.__enter__.return_value = fake_client
    fake_client.__exit__.return_value = False
    fake_client.post.side_effect = queued_responses

    # Stored credentials under the temporary HOME.
    alloc_dir = tmp_path / ".alloc"
    alloc_dir.mkdir(parents=True)
    stored_credentials = {"token": "old-tok", "refresh_token": "rt"}
    (alloc_dir / "config.json").write_text(json.dumps(stored_credentials))

    cli_env = dict(HOME=str(tmp_path), ALLOC_API_URL="https://api.example.com")

    client_patch = patch("httpx.Client", return_value=fake_client)
    refresh_patch = patch("alloc.cli.try_refresh_access_token", return_value="new-tok")
    with client_patch, refresh_patch:
        outcome = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=cli_env)

    assert outcome.exit_code == 0
    assert fake_client.post.call_count == 2

    # The retry must be sent with the refreshed token.
    retry_call = fake_client.post.call_args_list[1]
    assert "Bearer new-tok" in str(retry_call)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_scan_401_refresh_fails_fallback_public(tmp_path: Path):
    """After a 401, if the refresh returns no token, the second POST goes to
    `/scans/cli`."""
    # First POST is rejected, the second one succeeds.
    queued_responses = [
        _make_resp(401, {"detail": "unauthorized"}),
        _make_resp(
            200,
            {"vram_gb": 16.0, "configs": []},
            url="https://api.example.com/scans/cli",
        ),
    ]

    fake_client = MagicMock()
    # The mock acts as its own context-manager target.
    fake_client.__enter__.return_value = fake_client
    fake_client.__exit__.return_value = False
    fake_client.post.side_effect = queued_responses

    # Stored credentials under the temporary HOME.
    alloc_dir = tmp_path / ".alloc"
    alloc_dir.mkdir(parents=True)
    stored_credentials = {"token": "old-tok", "refresh_token": "rt"}
    (alloc_dir / "config.json").write_text(json.dumps(stored_credentials))

    cli_env = dict(HOME=str(tmp_path), ALLOC_API_URL="https://api.example.com")

    client_patch = patch("httpx.Client", return_value=fake_client)
    # The refresh helper returning None stands for a failed refresh.
    refresh_patch = patch("alloc.cli.try_refresh_access_token", return_value=None)
    with client_patch, refresh_patch:
        outcome = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=cli_env)

    assert outcome.exit_code == 0
    assert fake_client.post.call_count == 2

    # The second POST must target the /scans/cli endpoint.
    second_call_text = str(fake_client.post.call_args_list[1])
    assert "/scans/cli" in second_call_text
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_scan_401_fallback_warns_about_dropped_features(tmp_path: Path):
    """In non-JSON mode, a 401 followed by a failed refresh must produce
    output mentioning "expired" or "falling back"."""
    # First POST is rejected, the second one succeeds.
    queued_responses = [
        _make_resp(401, {"detail": "unauthorized"}),
        _make_resp(200, {"vram_gb": 16.0, "configs": []}),
    ]

    fake_client = MagicMock()
    # The mock acts as its own context-manager target.
    fake_client.__enter__.return_value = fake_client
    fake_client.__exit__.return_value = False
    fake_client.post.side_effect = queued_responses

    # Stored credentials under the temporary HOME.
    alloc_dir = tmp_path / ".alloc"
    alloc_dir.mkdir(parents=True)
    stored_credentials = {"token": "old-tok", "refresh_token": "rt"}
    (alloc_dir / "config.json").write_text(json.dumps(stored_credentials))

    cli_env = dict(HOME=str(tmp_path), ALLOC_API_URL="https://api.example.com")

    client_patch = patch("httpx.Client", return_value=fake_client)
    # The refresh helper returning None stands for a failed refresh.
    refresh_patch = patch("alloc.cli.try_refresh_access_token", return_value=None)
    with client_patch, refresh_patch:
        # No --json flag, so the human-readable warning is printed.
        outcome = runner.invoke(app, ["scan", "--model", "llama-3-8b"], env=cli_env)

    assert outcome.exit_code == 0

    lowered_output = outcome.output.lower()
    assert any(marker in lowered_output for marker in ("expired", "falling back"))
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_scan_no_token_uses_public_directly(tmp_path: Path):
    """With no config file under HOME, scan issues exactly one POST and it
    targets `/scans/cli`."""
    success = _make_resp(200, {"vram_gb": 16.0, "configs": []})

    fake_client = MagicMock()
    # The mock acts as its own context-manager target.
    fake_client.__enter__.return_value = fake_client
    fake_client.__exit__.return_value = False
    fake_client.post.return_value = success

    # HOME points at an empty temp dir, so no stored token is written here.
    cli_env = dict(HOME=str(tmp_path), ALLOC_API_URL="https://api.example.com")

    client_patch = patch("httpx.Client", return_value=fake_client)
    with client_patch:
        outcome = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=cli_env)

    assert outcome.exit_code == 0
    assert fake_client.post.call_count == 1

    only_call_text = str(fake_client.post.call_args_list[0])
    assert "/scans/cli" in only_call_text
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
"""Alloc — GPU intelligence for ML training."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
__version__ = "0.0.5"
|
|
6
|
-
|
|
7
|
-
from alloc.ghost import ghost, GhostReport
|
|
8
|
-
from alloc.callbacks import AllocCallback as HuggingFaceCallback
|
|
9
|
-
from alloc.callbacks import AllocLightningCallback as LightningCallback
|
|
10
|
-
|
|
11
|
-
__all__ = ["ghost", "GhostReport", "HuggingFaceCallback", "LightningCallback", "__version__"]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|