alloc 0.0.6__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {alloc-0.0.6 → alloc-0.0.8}/PKG-INFO +1 -1
- {alloc-0.0.6 → alloc-0.0.8}/pyproject.toml +1 -1
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/__init__.py +3 -1
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/cli.py +60 -5
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/extractor_runner.py +5 -1
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/model_extractor.py +29 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/probe.py +8 -1
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/upload.py +1 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/PKG-INFO +1 -1
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/SOURCES.txt +3 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_cli.py +1 -1
- alloc-0.0.8/tests/test_ghost_degradation.py +145 -0
- alloc-0.0.8/tests/test_scan_auth.py +142 -0
- alloc-0.0.8/tests/test_topology_strategy.py +87 -0
- {alloc-0.0.6 → alloc-0.0.8}/README.md +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/setup.cfg +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/artifact_loader.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/artifact_writer.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/browser_auth.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/callbacks.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/catalog/__init__.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/catalog/default_rate_card.json +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/catalog/gpus.v1.json +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/code_analyzer.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/config.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/context.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/diagnosis_display.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/diagnosis_engine.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/diagnosis_rules.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/display.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/ghost.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/model_registry.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/stability.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc/yaml_config.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/dependency_links.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/entry_points.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/requires.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/top_level.txt +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_artifact.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_artifact_loader.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_auth.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_callbacks.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_catalog.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_code_analyzer.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_context.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_diagnose_cli.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_diagnosis_engine.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_diagnosis_rules.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_extractor_activation.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_ghost.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_init_from_org.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_interconnect.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_model_extractor.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_probe_hw.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_probe_multi.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_stability.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_upload.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_verdict.py +0 -0
- {alloc-0.0.6 → alloc-0.0.8}/tests/test_yaml_config.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alloc
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
|
|
5
5
|
Author-email: Alloc Labs <hello@alloclabs.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "alloc"
|
|
7
|
-
version = "0.0.6"
|
|
7
|
+
version = "0.0.8"
|
|
8
8
|
description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "Apache-2.0"
|
|
@@ -5,9 +5,11 @@ from __future__ import annotations
|
|
|
5
5
|
import warnings as _warnings
|
|
6
6
|
_warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
7
7
|
_warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
8
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
9
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
8
10
|
del _warnings
|
|
9
11
|
|
|
10
|
-
__version__ = "0.0.6"
|
|
12
|
+
__version__ = "0.0.8"
|
|
11
13
|
|
|
12
14
|
from alloc.ghost import ghost, GhostReport
|
|
13
15
|
from alloc.callbacks import AllocCallback as HuggingFaceCallback
|
|
@@ -19,10 +19,12 @@ import sys
|
|
|
19
19
|
import warnings
|
|
20
20
|
from typing import List, Optional
|
|
21
21
|
|
|
22
|
-
# Suppress noisy third-party warnings globally — pynvml deprecation
|
|
23
|
-
# urllib3 LibreSSL warnings clutter
|
|
22
|
+
# Suppress noisy third-party warnings globally — pynvml deprecation (emitted
|
|
23
|
+
# from torch.cuda.__init__) and urllib3 LibreSSL warnings clutter CLI output.
|
|
24
24
|
warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
25
25
|
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
26
|
+
warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
27
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
26
28
|
warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
|
|
27
29
|
|
|
28
30
|
import typer
|
|
@@ -75,6 +77,19 @@ def ghost(
|
|
|
75
77
|
console.print(f"[dim]Tip: alloc ghost {script} --param-count-b 7.0[/dim]")
|
|
76
78
|
raise typer.Exit(1)
|
|
77
79
|
|
|
80
|
+
if info.extraction_error:
|
|
81
|
+
if json_output:
|
|
82
|
+
_print_json({
|
|
83
|
+
"error": info.extraction_error,
|
|
84
|
+
"detail": info.extraction_detail,
|
|
85
|
+
"supported": False,
|
|
86
|
+
})
|
|
87
|
+
else:
|
|
88
|
+
console.print(f"[yellow]{info.extraction_detail}[/yellow]")
|
|
89
|
+
if info.extraction_error == "distributed_entrypoint":
|
|
90
|
+
console.print("[dim]Tip: alloc ghost model.py (point to the file that defines your model)[/dim]")
|
|
91
|
+
raise typer.Exit(1)
|
|
92
|
+
|
|
78
93
|
# Use dtype from execution if available, otherwise CLI flag
|
|
79
94
|
resolved_dtype = info.dtype if info.method == "execution" else dtype
|
|
80
95
|
|
|
@@ -440,7 +455,9 @@ def run(
|
|
|
440
455
|
"tp_degree": topology.get("tp_degree"),
|
|
441
456
|
"pp_degree": topology.get("pp_degree"),
|
|
442
457
|
"dp_degree": topology.get("dp_degree"),
|
|
458
|
+
"strategy": topology.get("strategy"),
|
|
443
459
|
"interconnect_type": topology.get("interconnect_type"),
|
|
460
|
+
"process_map": result.process_map,
|
|
444
461
|
"objective": objective,
|
|
445
462
|
"max_budget_hourly": max_budget_hourly,
|
|
446
463
|
"command": " ".join(command),
|
|
@@ -2099,12 +2116,32 @@ def scan(
|
|
|
2099
2116
|
|
|
2100
2117
|
try:
|
|
2101
2118
|
headers = {"Content-Type": "application/json"}
|
|
2119
|
+
used_auth = bool(token)
|
|
2120
|
+
|
|
2102
2121
|
if token:
|
|
2103
2122
|
headers["Authorization"] = f"Bearer {token}"
|
|
2123
|
+
endpoint = "/scans"
|
|
2124
|
+
else:
|
|
2125
|
+
endpoint = "/scans/cli"
|
|
2104
2126
|
|
|
2105
|
-
endpoint = "/scans" if token else "/scans/cli"
|
|
2106
2127
|
with httpx.Client(timeout=30) as client:
|
|
2107
2128
|
resp = client.post(f"{api_url}{endpoint}", json=payload, headers=headers)
|
|
2129
|
+
|
|
2130
|
+
# On 401 with a saved token: try refresh, then fall back to public endpoint
|
|
2131
|
+
if resp.status_code == 401 and used_auth:
|
|
2132
|
+
new_token = try_refresh_access_token()
|
|
2133
|
+
if new_token:
|
|
2134
|
+
headers["Authorization"] = f"Bearer {new_token}"
|
|
2135
|
+
resp = client.post(f"{api_url}/scans", json=payload, headers=headers)
|
|
2136
|
+
else:
|
|
2137
|
+
# Token refresh failed — fall back to unauthenticated scan
|
|
2138
|
+
console.print(
|
|
2139
|
+
"[yellow]Session expired — falling back to public scan "
|
|
2140
|
+
"(org fleet context unavailable). Run `alloc login` to restore.[/yellow]",
|
|
2141
|
+
)
|
|
2142
|
+
del headers["Authorization"]
|
|
2143
|
+
resp = client.post(f"{api_url}/scans/cli", json=payload, headers=headers)
|
|
2144
|
+
|
|
2108
2145
|
resp.raise_for_status()
|
|
2109
2146
|
result = resp.json()
|
|
2110
2147
|
|
|
@@ -2333,7 +2370,7 @@ def whoami(
|
|
|
2333
2370
|
|
|
2334
2371
|
out = {
|
|
2335
2372
|
"api_url": api_url,
|
|
2336
|
-
"logged_in":
|
|
2373
|
+
"logged_in": False,
|
|
2337
2374
|
"token_source": token_source if token else None,
|
|
2338
2375
|
}
|
|
2339
2376
|
|
|
@@ -2378,6 +2415,9 @@ def whoami(
|
|
|
2378
2415
|
console.print(f"[red]Cannot connect to {api_url}[/red]")
|
|
2379
2416
|
raise typer.Exit(1)
|
|
2380
2417
|
|
|
2418
|
+
# API validated the token — now we know login is real
|
|
2419
|
+
out["logged_in"] = True
|
|
2420
|
+
|
|
2381
2421
|
gpus = fleet.get("gpus") or []
|
|
2382
2422
|
fleet_count = len([g for g in gpus if g.get("fleet_status") == "in_fleet"])
|
|
2383
2423
|
explore_count = len([g for g in gpus if g.get("fleet_status") == "explore"])
|
|
@@ -3052,7 +3092,7 @@ def status(
|
|
|
3052
3092
|
|
|
3053
3093
|
out = {
|
|
3054
3094
|
"version": __version__,
|
|
3055
|
-
"
|
|
3095
|
+
"has_token": bool(token),
|
|
3056
3096
|
"api_url": api_url,
|
|
3057
3097
|
"artifact": None,
|
|
3058
3098
|
"dashboard_url": None,
|
|
@@ -3513,6 +3553,20 @@ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_intercon
|
|
|
3513
3553
|
if interconnect not in ("pcie", "nvlink", "nvlink_switch", "nvlink_p2p", "infiniband", "unknown"):
|
|
3514
3554
|
interconnect = "unknown"
|
|
3515
3555
|
|
|
3556
|
+
# Infer strategy from degrees — only when evidence exists
|
|
3557
|
+
strategy = None
|
|
3558
|
+
has_tp = tp is not None and tp > 1
|
|
3559
|
+
has_pp = pp is not None and pp > 1
|
|
3560
|
+
if has_tp and has_pp:
|
|
3561
|
+
strategy = "tp+pp+dp"
|
|
3562
|
+
elif has_tp:
|
|
3563
|
+
strategy = "tp+dp" if (dp is not None and dp > 1) else "tp"
|
|
3564
|
+
elif has_pp:
|
|
3565
|
+
strategy = "pp+dp" if (dp is not None and dp > 1) else "pp"
|
|
3566
|
+
elif dp is not None and dp > 1:
|
|
3567
|
+
strategy = "ddp"
|
|
3568
|
+
# If none of the above matched, strategy stays None (unknown)
|
|
3569
|
+
|
|
3516
3570
|
return {
|
|
3517
3571
|
"num_nodes": nnodes or 1,
|
|
3518
3572
|
"gpus_per_node": gpn,
|
|
@@ -3520,6 +3574,7 @@ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_intercon
|
|
|
3520
3574
|
"pp_degree": pp,
|
|
3521
3575
|
"dp_degree": dp,
|
|
3522
3576
|
"interconnect_type": interconnect,
|
|
3577
|
+
"strategy": strategy,
|
|
3523
3578
|
}
|
|
3524
3579
|
|
|
3525
3580
|
|
|
@@ -206,7 +206,11 @@ def main():
|
|
|
206
206
|
except SystemExit:
|
|
207
207
|
pass # catch real SystemExit too
|
|
208
208
|
except Exception as e:
|
|
209
|
-
|
|
209
|
+
error_msg = str(e)[:500]
|
|
210
|
+
_dist_keywords = ("init_process_group", "nccl", "gloo", "distributed",
|
|
211
|
+
"master_addr", "master_port", "rendezvouserror")
|
|
212
|
+
status = "error_distributed" if any(kw in error_msg.lower() for kw in _dist_keywords) else "error"
|
|
213
|
+
result = {"status": status, "error": error_msg}
|
|
210
214
|
with open(sidecar_path, "w") as f:
|
|
211
215
|
json.dump(result, f)
|
|
212
216
|
return
|
|
@@ -33,6 +33,8 @@ class ModelInfo:
|
|
|
33
33
|
seq_length: Optional[int] = None
|
|
34
34
|
activation_memory_bytes: Optional[int] = None
|
|
35
35
|
activation_method: Optional[str] = None # "traced" | None
|
|
36
|
+
extraction_error: Optional[str] = None # "distributed_entrypoint" | None
|
|
37
|
+
extraction_detail: Optional[str] = None # human-readable explanation
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
def extract_model_info(
|
|
@@ -106,6 +108,10 @@ def _extract_via_subprocess(
|
|
|
106
108
|
env.setdefault("WORLD_SIZE", "1")
|
|
107
109
|
env.setdefault("MASTER_ADDR", "127.0.0.1")
|
|
108
110
|
env.setdefault("MASTER_PORT", "29500")
|
|
111
|
+
# Suppress pynvml/torch.cuda deprecation warnings in subprocess
|
|
112
|
+
existing = env.get("PYTHONWARNINGS", "")
|
|
113
|
+
filters = "ignore::FutureWarning,ignore::DeprecationWarning"
|
|
114
|
+
env["PYTHONWARNINGS"] = f"{existing},{filters}" if existing else filters
|
|
109
115
|
|
|
110
116
|
subprocess.run(
|
|
111
117
|
[sys.executable, "-m", "alloc.extractor_runner", sidecar_path, script_abs],
|
|
@@ -134,6 +140,29 @@ def _extract_via_subprocess(
|
|
|
134
140
|
activation_method=data.get("activation_method"),
|
|
135
141
|
)
|
|
136
142
|
|
|
143
|
+
# Structured degradation for distributed scripts
|
|
144
|
+
status = data.get("status", "")
|
|
145
|
+
if status in ("error", "error_distributed"):
|
|
146
|
+
is_distributed = status == "error_distributed"
|
|
147
|
+
if not is_distributed:
|
|
148
|
+
# Fallback keyword match for older sidecar format
|
|
149
|
+
error_msg = data.get("error", "")
|
|
150
|
+
_dist_keywords = ("init_process_group", "NCCL", "gloo", "distributed",
|
|
151
|
+
"MASTER_ADDR", "MASTER_PORT", "RendezvousError")
|
|
152
|
+
is_distributed = any(kw.lower() in error_msg.lower() for kw in _dist_keywords)
|
|
153
|
+
if is_distributed:
|
|
154
|
+
return ModelInfo(
|
|
155
|
+
param_count=0,
|
|
156
|
+
dtype="float16",
|
|
157
|
+
model_name=None,
|
|
158
|
+
method="execution",
|
|
159
|
+
extraction_error="distributed_entrypoint",
|
|
160
|
+
extraction_detail=(
|
|
161
|
+
"Script requires a distributed runtime (e.g. torchrun). "
|
|
162
|
+
"Run ghost on the model definition file instead of the launcher script."
|
|
163
|
+
),
|
|
164
|
+
)
|
|
165
|
+
|
|
137
166
|
return None
|
|
138
167
|
|
|
139
168
|
except subprocess.TimeoutExpired:
|
|
@@ -18,6 +18,13 @@ from dataclasses import dataclass, field
|
|
|
18
18
|
from enum import Enum
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
|
|
21
|
+
import warnings as _warnings
|
|
22
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
|
|
23
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
|
|
24
|
+
_warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
|
|
25
|
+
_warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
|
|
26
|
+
del _warnings
|
|
27
|
+
|
|
21
28
|
|
|
22
29
|
class StopReason(str, Enum):
|
|
23
30
|
STABLE = "stable"
|
|
@@ -356,7 +363,7 @@ def probe_command(
|
|
|
356
363
|
"""
|
|
357
364
|
pynvml = _try_import_pynvml()
|
|
358
365
|
|
|
359
|
-
# Launch the subprocess
|
|
366
|
+
# Launch the user's training subprocess — do NOT modify env (their warnings matter)
|
|
360
367
|
try:
|
|
361
368
|
proc = subprocess.Popen(
|
|
362
369
|
command,
|
|
@@ -123,6 +123,7 @@ def upload_artifact(artifact_path: str, api_url: str, token: str) -> dict:
|
|
|
123
123
|
"dataloader_wait_pct": probe.get("dataloader_wait_pct"),
|
|
124
124
|
"comm_overhead_pct": probe.get("comm_overhead_pct"),
|
|
125
125
|
"per_rank_peak_vram_mb": probe.get("per_rank_peak_vram_mb"),
|
|
126
|
+
"process_map": probe.get("process_map"),
|
|
126
127
|
# Architecture fields: probe (callbacks) takes priority over ghost defaults
|
|
127
128
|
"batch_size": probe.get("batch_size") or (ghost.get("batch_size") if ghost else None),
|
|
128
129
|
"seq_length": ghost.get("seq_length") if ghost else None,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: alloc
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
|
|
5
5
|
Author-email: Alloc Labs <hello@alloclabs.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -43,12 +43,15 @@ tests/test_diagnosis_engine.py
|
|
|
43
43
|
tests/test_diagnosis_rules.py
|
|
44
44
|
tests/test_extractor_activation.py
|
|
45
45
|
tests/test_ghost.py
|
|
46
|
+
tests/test_ghost_degradation.py
|
|
46
47
|
tests/test_init_from_org.py
|
|
47
48
|
tests/test_interconnect.py
|
|
48
49
|
tests/test_model_extractor.py
|
|
49
50
|
tests/test_probe_hw.py
|
|
50
51
|
tests/test_probe_multi.py
|
|
52
|
+
tests/test_scan_auth.py
|
|
51
53
|
tests/test_stability.py
|
|
54
|
+
tests/test_topology_strategy.py
|
|
52
55
|
tests/test_upload.py
|
|
53
56
|
tests/test_verdict.py
|
|
54
57
|
tests/test_yaml_config.py
|
|
@@ -239,7 +239,7 @@ def test_status_json_no_artifact(tmp_path, monkeypatch):
|
|
|
239
239
|
assert result.exit_code == 0
|
|
240
240
|
data = json.loads(result.output.strip())
|
|
241
241
|
assert data["artifact"] is None
|
|
242
|
-
assert data["
|
|
242
|
+
assert data["has_token"] is False
|
|
243
243
|
assert "version" in data
|
|
244
244
|
|
|
245
245
|
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Tests for ghost structured degradation on distributed scripts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import tempfile
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from unittest.mock import patch
|
|
10
|
+
|
|
11
|
+
from typer.testing import CliRunner
|
|
12
|
+
|
|
13
|
+
from alloc.cli import app
|
|
14
|
+
from alloc.model_extractor import ModelInfo, extract_model_info
|
|
15
|
+
|
|
16
|
+
runner = CliRunner()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_distributed_error_returns_structured_modelinfo():
|
|
20
|
+
"""When extractor subprocess fails with a distributed keyword, return structured ModelInfo."""
|
|
21
|
+
sidecar_data = json.dumps({
|
|
22
|
+
"status": "error",
|
|
23
|
+
"error": "RuntimeError: torch.distributed.init_process_group requires MASTER_ADDR",
|
|
24
|
+
})
|
|
25
|
+
|
|
26
|
+
def _fake_subprocess_run(*args, **kwargs):
|
|
27
|
+
# Write the sidecar file
|
|
28
|
+
sidecar_path = args[0][3] # [python, -m, alloc.extractor_runner, sidecar_path, script_path]
|
|
29
|
+
with open(sidecar_path, "w") as f:
|
|
30
|
+
f.write(sidecar_data)
|
|
31
|
+
|
|
32
|
+
# Create a dummy script
|
|
33
|
+
fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_dist_")
|
|
34
|
+
os.write(fd, b"import torch\ntorch.distributed.init_process_group('nccl')\n")
|
|
35
|
+
os.close(fd)
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
with patch("subprocess.run", side_effect=_fake_subprocess_run):
|
|
39
|
+
info = extract_model_info(script_path)
|
|
40
|
+
|
|
41
|
+
assert info is not None
|
|
42
|
+
assert info.extraction_error == "distributed_entrypoint"
|
|
43
|
+
assert "distributed runtime" in info.extraction_detail
|
|
44
|
+
assert info.param_count == 0
|
|
45
|
+
finally:
|
|
46
|
+
os.unlink(script_path)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_distributed_error_nccl_keyword():
|
|
50
|
+
"""NCCL errors should be caught as distributed failures."""
|
|
51
|
+
sidecar_data = json.dumps({
|
|
52
|
+
"status": "error",
|
|
53
|
+
"error": "NCCL error: unhandled system error",
|
|
54
|
+
})
|
|
55
|
+
|
|
56
|
+
fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_nccl_")
|
|
57
|
+
os.write(fd, b"pass\n")
|
|
58
|
+
os.close(fd)
|
|
59
|
+
|
|
60
|
+
try:
|
|
61
|
+
def _fake_run(*args, **kwargs):
|
|
62
|
+
sidecar_path = args[0][3]
|
|
63
|
+
with open(sidecar_path, "w") as f:
|
|
64
|
+
f.write(sidecar_data)
|
|
65
|
+
|
|
66
|
+
with patch("subprocess.run", side_effect=_fake_run):
|
|
67
|
+
info = extract_model_info(script_path)
|
|
68
|
+
|
|
69
|
+
assert info is not None
|
|
70
|
+
assert info.extraction_error == "distributed_entrypoint"
|
|
71
|
+
finally:
|
|
72
|
+
os.unlink(script_path)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def test_non_distributed_error_returns_none():
|
|
76
|
+
"""Non-distributed errors should still return None (fall through to AST)."""
|
|
77
|
+
sidecar_data = json.dumps({
|
|
78
|
+
"status": "error",
|
|
79
|
+
"error": "ImportError: No module named 'custom_lib'",
|
|
80
|
+
})
|
|
81
|
+
|
|
82
|
+
fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_other_")
|
|
83
|
+
# Script with no from_pretrained so AST also returns None
|
|
84
|
+
os.write(fd, b"import custom_lib\n")
|
|
85
|
+
os.close(fd)
|
|
86
|
+
|
|
87
|
+
try:
|
|
88
|
+
def _fake_run(*args, **kwargs):
|
|
89
|
+
sidecar_path = args[0][3]
|
|
90
|
+
with open(sidecar_path, "w") as f:
|
|
91
|
+
f.write(sidecar_data)
|
|
92
|
+
|
|
93
|
+
with patch("subprocess.run", side_effect=_fake_run):
|
|
94
|
+
info = extract_model_info(script_path)
|
|
95
|
+
|
|
96
|
+
# Should be None because error is not distributed and AST won't find a model either
|
|
97
|
+
assert info is None
|
|
98
|
+
finally:
|
|
99
|
+
os.unlink(script_path)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_ghost_cli_distributed_error_json(tmp_path: Path):
|
|
103
|
+
"""ghost --json shows structured error for distributed scripts."""
|
|
104
|
+
script_path = tmp_path / "train_ddp.py"
|
|
105
|
+
script_path.write_text("import torch\ntorch.distributed.init_process_group('nccl')\n")
|
|
106
|
+
|
|
107
|
+
dist_info = ModelInfo(
|
|
108
|
+
param_count=0,
|
|
109
|
+
dtype="float16",
|
|
110
|
+
model_name=None,
|
|
111
|
+
method="execution",
|
|
112
|
+
extraction_error="distributed_entrypoint",
|
|
113
|
+
extraction_detail="Script requires a distributed runtime (e.g. torchrun). Run ghost on the model definition file instead of the launcher script.",
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
|
|
117
|
+
result = runner.invoke(app, ["ghost", str(script_path), "--json"])
|
|
118
|
+
|
|
119
|
+
assert result.exit_code != 0
|
|
120
|
+
data = json.loads(result.output)
|
|
121
|
+
assert data["error"] == "distributed_entrypoint"
|
|
122
|
+
assert data["supported"] is False
|
|
123
|
+
assert "distributed runtime" in data["detail"]
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_ghost_cli_distributed_error_human(tmp_path: Path):
|
|
127
|
+
"""ghost shows human-readable message with tip for distributed scripts."""
|
|
128
|
+
script_path = tmp_path / "train_ddp.py"
|
|
129
|
+
script_path.write_text("import torch\n")
|
|
130
|
+
|
|
131
|
+
dist_info = ModelInfo(
|
|
132
|
+
param_count=0,
|
|
133
|
+
dtype="float16",
|
|
134
|
+
model_name=None,
|
|
135
|
+
method="execution",
|
|
136
|
+
extraction_error="distributed_entrypoint",
|
|
137
|
+
extraction_detail="Script requires a distributed runtime (e.g. torchrun). Run ghost on the model definition file instead of the launcher script.",
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
|
|
141
|
+
result = runner.invoke(app, ["ghost", str(script_path)])
|
|
142
|
+
|
|
143
|
+
assert result.exit_code != 0
|
|
144
|
+
assert "distributed runtime" in result.output
|
|
145
|
+
assert "model.py" in result.output # tip about pointing to model file
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Tests for scan command 401 retry + /scans/cli fallback."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from unittest.mock import MagicMock, patch
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
from typer.testing import CliRunner
|
|
11
|
+
|
|
12
|
+
from alloc.cli import app
|
|
13
|
+
|
|
14
|
+
runner = CliRunner()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _make_resp(status_code: int, body: dict, url: str = "https://api.example.com/scans"):
|
|
18
|
+
req = httpx.Request("POST", url)
|
|
19
|
+
return httpx.Response(
|
|
20
|
+
status_code,
|
|
21
|
+
request=req,
|
|
22
|
+
content=json.dumps(body).encode(),
|
|
23
|
+
headers={"content-type": "application/json"},
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_scan_401_refresh_retry(tmp_path: Path):
|
|
28
|
+
"""On 401, refresh token and retry on /scans."""
|
|
29
|
+
resp_401 = _make_resp(401, {"detail": "unauthorized"})
|
|
30
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
|
|
31
|
+
|
|
32
|
+
mock_client = MagicMock()
|
|
33
|
+
mock_client.__enter__.return_value = mock_client
|
|
34
|
+
mock_client.__exit__.return_value = False
|
|
35
|
+
mock_client.post.side_effect = [resp_401, resp_ok]
|
|
36
|
+
|
|
37
|
+
cfg_file = tmp_path / ".alloc" / "config.json"
|
|
38
|
+
cfg_file.parent.mkdir(parents=True)
|
|
39
|
+
cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
|
|
40
|
+
|
|
41
|
+
env = {
|
|
42
|
+
"HOME": str(tmp_path),
|
|
43
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
with (
|
|
47
|
+
patch("httpx.Client", return_value=mock_client),
|
|
48
|
+
patch("alloc.cli.try_refresh_access_token", return_value="new-tok"),
|
|
49
|
+
):
|
|
50
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
|
|
51
|
+
|
|
52
|
+
assert result.exit_code == 0
|
|
53
|
+
assert mock_client.post.call_count == 2
|
|
54
|
+
# Second call should use refreshed token
|
|
55
|
+
second_call = mock_client.post.call_args_list[1]
|
|
56
|
+
assert "Bearer new-tok" in str(second_call)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_scan_401_refresh_fails_fallback_public(tmp_path: Path):
|
|
60
|
+
"""On 401 + refresh failure, fall back to /scans/cli with warning."""
|
|
61
|
+
resp_401 = _make_resp(401, {"detail": "unauthorized"})
|
|
62
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []},
|
|
63
|
+
url="https://api.example.com/scans/cli")
|
|
64
|
+
|
|
65
|
+
mock_client = MagicMock()
|
|
66
|
+
mock_client.__enter__.return_value = mock_client
|
|
67
|
+
mock_client.__exit__.return_value = False
|
|
68
|
+
mock_client.post.side_effect = [resp_401, resp_ok]
|
|
69
|
+
|
|
70
|
+
cfg_file = tmp_path / ".alloc" / "config.json"
|
|
71
|
+
cfg_file.parent.mkdir(parents=True)
|
|
72
|
+
cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
|
|
73
|
+
|
|
74
|
+
env = {
|
|
75
|
+
"HOME": str(tmp_path),
|
|
76
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
with (
|
|
80
|
+
patch("httpx.Client", return_value=mock_client),
|
|
81
|
+
patch("alloc.cli.try_refresh_access_token", return_value=None),
|
|
82
|
+
):
|
|
83
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
|
|
84
|
+
|
|
85
|
+
assert result.exit_code == 0
|
|
86
|
+
assert mock_client.post.call_count == 2
|
|
87
|
+
# Second call should hit /scans/cli
|
|
88
|
+
second_url = str(mock_client.post.call_args_list[1])
|
|
89
|
+
assert "/scans/cli" in second_url
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_scan_401_fallback_warns_about_dropped_features(tmp_path: Path):
|
|
93
|
+
"""Fallback to public scan warns user about lost org context."""
|
|
94
|
+
resp_401 = _make_resp(401, {"detail": "unauthorized"})
|
|
95
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
|
|
96
|
+
|
|
97
|
+
mock_client = MagicMock()
|
|
98
|
+
mock_client.__enter__.return_value = mock_client
|
|
99
|
+
mock_client.__exit__.return_value = False
|
|
100
|
+
mock_client.post.side_effect = [resp_401, resp_ok]
|
|
101
|
+
|
|
102
|
+
cfg_file = tmp_path / ".alloc" / "config.json"
|
|
103
|
+
cfg_file.parent.mkdir(parents=True)
|
|
104
|
+
cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
|
|
105
|
+
|
|
106
|
+
env = {
|
|
107
|
+
"HOME": str(tmp_path),
|
|
108
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
with (
|
|
112
|
+
patch("httpx.Client", return_value=mock_client),
|
|
113
|
+
patch("alloc.cli.try_refresh_access_token", return_value=None),
|
|
114
|
+
):
|
|
115
|
+
# Non-JSON mode to see the warning message
|
|
116
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b"], env=env)
|
|
117
|
+
|
|
118
|
+
assert result.exit_code == 0
|
|
119
|
+
assert "expired" in result.output.lower() or "falling back" in result.output.lower()
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_scan_no_token_uses_public_directly(tmp_path: Path):
|
|
123
|
+
"""Without a token, scan goes directly to /scans/cli."""
|
|
124
|
+
resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
|
|
125
|
+
|
|
126
|
+
mock_client = MagicMock()
|
|
127
|
+
mock_client.__enter__.return_value = mock_client
|
|
128
|
+
mock_client.__exit__.return_value = False
|
|
129
|
+
mock_client.post.return_value = resp_ok
|
|
130
|
+
|
|
131
|
+
env = {
|
|
132
|
+
"HOME": str(tmp_path),
|
|
133
|
+
"ALLOC_API_URL": "https://api.example.com",
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
with patch("httpx.Client", return_value=mock_client):
|
|
137
|
+
result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
|
|
138
|
+
|
|
139
|
+
assert result.exit_code == 0
|
|
140
|
+
assert mock_client.post.call_count == 1
|
|
141
|
+
call_url = str(mock_client.post.call_args_list[0])
|
|
142
|
+
assert "/scans/cli" in call_url
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Tests for strategy inference from topology degrees (P0-B)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from unittest.mock import patch
|
|
7
|
+
|
|
8
|
+
from alloc.cli import _infer_parallel_topology_from_env
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TestStrategyInference:
|
|
12
|
+
"""Strategy should be inferred from TP/PP/DP degrees when present."""
|
|
13
|
+
|
|
14
|
+
def _topo(self, env=None, num_gpus=4):
|
|
15
|
+
env = env or {}
|
|
16
|
+
with patch.dict(os.environ, env, clear=False):
|
|
17
|
+
return _infer_parallel_topology_from_env(
|
|
18
|
+
num_gpus_detected=num_gpus,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
def test_no_degrees_strategy_none(self):
|
|
22
|
+
"""When no degree env vars set, strategy should be None."""
|
|
23
|
+
result = self._topo({})
|
|
24
|
+
assert result["strategy"] is None
|
|
25
|
+
|
|
26
|
+
def test_dp_only_is_ddp(self):
|
|
27
|
+
"""WORLD_SIZE=4 with no TP/PP → dp inferred → strategy=ddp."""
|
|
28
|
+
result = self._topo({"WORLD_SIZE": "4"})
|
|
29
|
+
assert result["strategy"] == "ddp"
|
|
30
|
+
assert result["dp_degree"] == 4
|
|
31
|
+
|
|
32
|
+
def test_tp_only(self):
|
|
33
|
+
"""TP_SIZE=4 alone → strategy=tp."""
|
|
34
|
+
result = self._topo({"TP_SIZE": "4"})
|
|
35
|
+
assert result["strategy"] == "tp"
|
|
36
|
+
|
|
37
|
+
def test_pp_only(self):
|
|
38
|
+
"""PP_SIZE=4 alone → strategy=pp."""
|
|
39
|
+
result = self._topo({"PP_SIZE": "4"})
|
|
40
|
+
assert result["strategy"] == "pp"
|
|
41
|
+
|
|
42
|
+
def test_tp_dp(self):
|
|
43
|
+
"""TP_SIZE=2 with DP_SIZE=2 → strategy=tp+dp."""
|
|
44
|
+
result = self._topo({"TP_SIZE": "2", "DP_SIZE": "2"})
|
|
45
|
+
assert result["strategy"] == "tp+dp"
|
|
46
|
+
|
|
47
|
+
def test_pp_dp(self):
|
|
48
|
+
"""PP_SIZE=2 with DP_SIZE=2 → strategy=pp+dp."""
|
|
49
|
+
result = self._topo({"PP_SIZE": "2", "DP_SIZE": "2"})
|
|
50
|
+
assert result["strategy"] == "pp+dp"
|
|
51
|
+
|
|
52
|
+
def test_tp_pp_dp(self):
|
|
53
|
+
"""All three degrees → strategy=tp+pp+dp."""
|
|
54
|
+
result = self._topo({"TP_SIZE": "2", "PP_SIZE": "2", "DP_SIZE": "2"})
|
|
55
|
+
assert result["strategy"] == "tp+pp+dp"
|
|
56
|
+
|
|
57
|
+
def test_tp_pp_no_dp(self):
|
|
58
|
+
"""TP+PP without explicit DP → strategy=tp+pp+dp."""
|
|
59
|
+
result = self._topo({"TP_SIZE": "2", "PP_SIZE": "2"})
|
|
60
|
+
assert result["strategy"] == "tp+pp+dp"
|
|
61
|
+
|
|
62
|
+
def test_tp_size_1_not_counted(self):
|
|
63
|
+
"""TP_SIZE=1 should not count as tensor parallelism."""
|
|
64
|
+
result = self._topo({"TP_SIZE": "1", "DP_SIZE": "4"})
|
|
65
|
+
assert result["strategy"] == "ddp"
|
|
66
|
+
|
|
67
|
+
def test_pp_size_1_not_counted(self):
|
|
68
|
+
"""PP_SIZE=1 should not count as pipeline parallelism."""
|
|
69
|
+
result = self._topo({"PP_SIZE": "1", "DP_SIZE": "4"})
|
|
70
|
+
assert result["strategy"] == "ddp"
|
|
71
|
+
|
|
72
|
+
def test_dp_inferred_from_world_size(self):
|
|
73
|
+
"""DP inferred from WORLD_SIZE / (TP * PP) → strategy includes dp."""
|
|
74
|
+
result = self._topo({"WORLD_SIZE": "8", "TP_SIZE": "2"})
|
|
75
|
+
assert result["dp_degree"] == 4
|
|
76
|
+
assert result["strategy"] == "tp+dp"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class TestProcessMapInProbeDictAssembly:
|
|
80
|
+
"""process_map should reach probe_dict from ProbeResult."""
|
|
81
|
+
|
|
82
|
+
def test_process_map_present_in_topology_return(self):
|
|
83
|
+
"""Topology dict now includes strategy field."""
|
|
84
|
+
with patch.dict(os.environ, {"WORLD_SIZE": "4"}, clear=False):
|
|
85
|
+
topo = _infer_parallel_topology_from_env(num_gpus_detected=4)
|
|
86
|
+
assert "strategy" in topo
|
|
87
|
+
assert topo["strategy"] == "ddp"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|