alloc 0.0.6__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {alloc-0.0.6 → alloc-0.0.8}/PKG-INFO +1 -1
  2. {alloc-0.0.6 → alloc-0.0.8}/pyproject.toml +1 -1
  3. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/__init__.py +3 -1
  4. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/cli.py +60 -5
  5. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/extractor_runner.py +5 -1
  6. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/model_extractor.py +29 -0
  7. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/probe.py +8 -1
  8. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/upload.py +1 -0
  9. {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/PKG-INFO +1 -1
  10. {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/SOURCES.txt +3 -0
  11. {alloc-0.0.6 → alloc-0.0.8}/tests/test_cli.py +1 -1
  12. alloc-0.0.8/tests/test_ghost_degradation.py +145 -0
  13. alloc-0.0.8/tests/test_scan_auth.py +142 -0
  14. alloc-0.0.8/tests/test_topology_strategy.py +87 -0
  15. {alloc-0.0.6 → alloc-0.0.8}/README.md +0 -0
  16. {alloc-0.0.6 → alloc-0.0.8}/setup.cfg +0 -0
  17. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/artifact_loader.py +0 -0
  18. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/artifact_writer.py +0 -0
  19. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/browser_auth.py +0 -0
  20. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/callbacks.py +0 -0
  21. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/catalog/__init__.py +0 -0
  22. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/catalog/default_rate_card.json +0 -0
  23. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/catalog/gpus.v1.json +0 -0
  24. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/code_analyzer.py +0 -0
  25. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/config.py +0 -0
  26. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/context.py +0 -0
  27. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/diagnosis_display.py +0 -0
  28. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/diagnosis_engine.py +0 -0
  29. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/diagnosis_rules.py +0 -0
  30. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/display.py +0 -0
  31. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/ghost.py +0 -0
  32. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/model_registry.py +0 -0
  33. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/stability.py +0 -0
  34. {alloc-0.0.6 → alloc-0.0.8}/src/alloc/yaml_config.py +0 -0
  35. {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/dependency_links.txt +0 -0
  36. {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/entry_points.txt +0 -0
  37. {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/requires.txt +0 -0
  38. {alloc-0.0.6 → alloc-0.0.8}/src/alloc.egg-info/top_level.txt +0 -0
  39. {alloc-0.0.6 → alloc-0.0.8}/tests/test_artifact.py +0 -0
  40. {alloc-0.0.6 → alloc-0.0.8}/tests/test_artifact_loader.py +0 -0
  41. {alloc-0.0.6 → alloc-0.0.8}/tests/test_auth.py +0 -0
  42. {alloc-0.0.6 → alloc-0.0.8}/tests/test_callbacks.py +0 -0
  43. {alloc-0.0.6 → alloc-0.0.8}/tests/test_catalog.py +0 -0
  44. {alloc-0.0.6 → alloc-0.0.8}/tests/test_code_analyzer.py +0 -0
  45. {alloc-0.0.6 → alloc-0.0.8}/tests/test_context.py +0 -0
  46. {alloc-0.0.6 → alloc-0.0.8}/tests/test_diagnose_cli.py +0 -0
  47. {alloc-0.0.6 → alloc-0.0.8}/tests/test_diagnosis_engine.py +0 -0
  48. {alloc-0.0.6 → alloc-0.0.8}/tests/test_diagnosis_rules.py +0 -0
  49. {alloc-0.0.6 → alloc-0.0.8}/tests/test_extractor_activation.py +0 -0
  50. {alloc-0.0.6 → alloc-0.0.8}/tests/test_ghost.py +0 -0
  51. {alloc-0.0.6 → alloc-0.0.8}/tests/test_init_from_org.py +0 -0
  52. {alloc-0.0.6 → alloc-0.0.8}/tests/test_interconnect.py +0 -0
  53. {alloc-0.0.6 → alloc-0.0.8}/tests/test_model_extractor.py +0 -0
  54. {alloc-0.0.6 → alloc-0.0.8}/tests/test_probe_hw.py +0 -0
  55. {alloc-0.0.6 → alloc-0.0.8}/tests/test_probe_multi.py +0 -0
  56. {alloc-0.0.6 → alloc-0.0.8}/tests/test_stability.py +0 -0
  57. {alloc-0.0.6 → alloc-0.0.8}/tests/test_upload.py +0 -0
  58. {alloc-0.0.6 → alloc-0.0.8}/tests/test_verdict.py +0 -0
  59. {alloc-0.0.6 → alloc-0.0.8}/tests/test_yaml_config.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alloc
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
5
5
  Author-email: Alloc Labs <hello@alloclabs.com>
6
6
  License-Expression: Apache-2.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "alloc"
7
- version = "0.0.6"
7
+ version = "0.0.8"
8
8
  description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -5,9 +5,11 @@ from __future__ import annotations
5
5
  import warnings as _warnings
6
6
  _warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
7
7
  _warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
8
+ _warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
9
+ _warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
8
10
  del _warnings
9
11
 
10
- __version__ = "0.0.6"
12
+ __version__ = "0.0.8"
11
13
 
12
14
  from alloc.ghost import ghost, GhostReport
13
15
  from alloc.callbacks import AllocCallback as HuggingFaceCallback
@@ -19,10 +19,12 @@ import sys
19
19
  import warnings
20
20
  from typing import List, Optional
21
21
 
22
- # Suppress noisy third-party warnings globally — pynvml deprecation and
23
- # urllib3 LibreSSL warnings clutter every CLI command on affected systems.
22
+ # Suppress noisy third-party warnings globally — pynvml deprecation (emitted
23
+ # from torch.cuda.__init__) and urllib3 LibreSSL warnings clutter CLI output.
24
24
  warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
25
25
  warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
26
+ warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
27
+ warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
26
28
  warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
27
29
 
28
30
  import typer
@@ -75,6 +77,19 @@ def ghost(
75
77
  console.print(f"[dim]Tip: alloc ghost {script} --param-count-b 7.0[/dim]")
76
78
  raise typer.Exit(1)
77
79
 
80
+ if info.extraction_error:
81
+ if json_output:
82
+ _print_json({
83
+ "error": info.extraction_error,
84
+ "detail": info.extraction_detail,
85
+ "supported": False,
86
+ })
87
+ else:
88
+ console.print(f"[yellow]{info.extraction_detail}[/yellow]")
89
+ if info.extraction_error == "distributed_entrypoint":
90
+ console.print("[dim]Tip: alloc ghost model.py (point to the file that defines your model)[/dim]")
91
+ raise typer.Exit(1)
92
+
78
93
  # Use dtype from execution if available, otherwise CLI flag
79
94
  resolved_dtype = info.dtype if info.method == "execution" else dtype
80
95
 
@@ -440,7 +455,9 @@ def run(
440
455
  "tp_degree": topology.get("tp_degree"),
441
456
  "pp_degree": topology.get("pp_degree"),
442
457
  "dp_degree": topology.get("dp_degree"),
458
+ "strategy": topology.get("strategy"),
443
459
  "interconnect_type": topology.get("interconnect_type"),
460
+ "process_map": result.process_map,
444
461
  "objective": objective,
445
462
  "max_budget_hourly": max_budget_hourly,
446
463
  "command": " ".join(command),
@@ -2099,12 +2116,32 @@ def scan(
2099
2116
 
2100
2117
  try:
2101
2118
  headers = {"Content-Type": "application/json"}
2119
+ used_auth = bool(token)
2120
+
2102
2121
  if token:
2103
2122
  headers["Authorization"] = f"Bearer {token}"
2123
+ endpoint = "/scans"
2124
+ else:
2125
+ endpoint = "/scans/cli"
2104
2126
 
2105
- endpoint = "/scans" if token else "/scans/cli"
2106
2127
  with httpx.Client(timeout=30) as client:
2107
2128
  resp = client.post(f"{api_url}{endpoint}", json=payload, headers=headers)
2129
+
2130
+ # On 401 with a saved token: try refresh, then fall back to public endpoint
2131
+ if resp.status_code == 401 and used_auth:
2132
+ new_token = try_refresh_access_token()
2133
+ if new_token:
2134
+ headers["Authorization"] = f"Bearer {new_token}"
2135
+ resp = client.post(f"{api_url}/scans", json=payload, headers=headers)
2136
+ else:
2137
+ # Token refresh failed — fall back to unauthenticated scan
2138
+ console.print(
2139
+ "[yellow]Session expired — falling back to public scan "
2140
+ "(org fleet context unavailable). Run `alloc login` to restore.[/yellow]",
2141
+ )
2142
+ del headers["Authorization"]
2143
+ resp = client.post(f"{api_url}/scans/cli", json=payload, headers=headers)
2144
+
2108
2145
  resp.raise_for_status()
2109
2146
  result = resp.json()
2110
2147
 
@@ -2333,7 +2370,7 @@ def whoami(
2333
2370
 
2334
2371
  out = {
2335
2372
  "api_url": api_url,
2336
- "logged_in": bool(token),
2373
+ "logged_in": False,
2337
2374
  "token_source": token_source if token else None,
2338
2375
  }
2339
2376
 
@@ -2378,6 +2415,9 @@ def whoami(
2378
2415
  console.print(f"[red]Cannot connect to {api_url}[/red]")
2379
2416
  raise typer.Exit(1)
2380
2417
 
2418
+ # API validated the token — now we know login is real
2419
+ out["logged_in"] = True
2420
+
2381
2421
  gpus = fleet.get("gpus") or []
2382
2422
  fleet_count = len([g for g in gpus if g.get("fleet_status") == "in_fleet"])
2383
2423
  explore_count = len([g for g in gpus if g.get("fleet_status") == "explore"])
@@ -3052,7 +3092,7 @@ def status(
3052
3092
 
3053
3093
  out = {
3054
3094
  "version": __version__,
3055
- "logged_in": bool(token),
3095
+ "has_token": bool(token),
3056
3096
  "api_url": api_url,
3057
3097
  "artifact": None,
3058
3098
  "dashboard_url": None,
@@ -3513,6 +3553,20 @@ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_intercon
3513
3553
  if interconnect not in ("pcie", "nvlink", "nvlink_switch", "nvlink_p2p", "infiniband", "unknown"):
3514
3554
  interconnect = "unknown"
3515
3555
 
3556
+ # Infer strategy from degrees — only when evidence exists
3557
+ strategy = None
3558
+ has_tp = tp is not None and tp > 1
3559
+ has_pp = pp is not None and pp > 1
3560
+ if has_tp and has_pp:
3561
+ strategy = "tp+pp+dp"
3562
+ elif has_tp:
3563
+ strategy = "tp+dp" if (dp is not None and dp > 1) else "tp"
3564
+ elif has_pp:
3565
+ strategy = "pp+dp" if (dp is not None and dp > 1) else "pp"
3566
+ elif dp is not None and dp > 1:
3567
+ strategy = "ddp"
3568
+ # If none of the above matched, strategy stays None (unknown)
3569
+
3516
3570
  return {
3517
3571
  "num_nodes": nnodes or 1,
3518
3572
  "gpus_per_node": gpn,
@@ -3520,6 +3574,7 @@ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_intercon
3520
3574
  "pp_degree": pp,
3521
3575
  "dp_degree": dp,
3522
3576
  "interconnect_type": interconnect,
3577
+ "strategy": strategy,
3523
3578
  }
3524
3579
 
3525
3580
 
@@ -206,7 +206,11 @@ def main():
206
206
  except SystemExit:
207
207
  pass # catch real SystemExit too
208
208
  except Exception as e:
209
- result = {"status": "error", "error": str(e)[:200]}
209
+ error_msg = str(e)[:500]
210
+ _dist_keywords = ("init_process_group", "nccl", "gloo", "distributed",
211
+ "master_addr", "master_port", "rendezvouserror")
212
+ status = "error_distributed" if any(kw in error_msg.lower() for kw in _dist_keywords) else "error"
213
+ result = {"status": status, "error": error_msg}
210
214
  with open(sidecar_path, "w") as f:
211
215
  json.dump(result, f)
212
216
  return
@@ -33,6 +33,8 @@ class ModelInfo:
33
33
  seq_length: Optional[int] = None
34
34
  activation_memory_bytes: Optional[int] = None
35
35
  activation_method: Optional[str] = None # "traced" | None
36
+ extraction_error: Optional[str] = None # "distributed_entrypoint" | None
37
+ extraction_detail: Optional[str] = None # human-readable explanation
36
38
 
37
39
 
38
40
  def extract_model_info(
@@ -106,6 +108,10 @@ def _extract_via_subprocess(
106
108
  env.setdefault("WORLD_SIZE", "1")
107
109
  env.setdefault("MASTER_ADDR", "127.0.0.1")
108
110
  env.setdefault("MASTER_PORT", "29500")
111
+ # Suppress pynvml/torch.cuda deprecation warnings in subprocess
112
+ existing = env.get("PYTHONWARNINGS", "")
113
+ filters = "ignore::FutureWarning,ignore::DeprecationWarning"
114
+ env["PYTHONWARNINGS"] = f"{existing},{filters}" if existing else filters
109
115
 
110
116
  subprocess.run(
111
117
  [sys.executable, "-m", "alloc.extractor_runner", sidecar_path, script_abs],
@@ -134,6 +140,29 @@ def _extract_via_subprocess(
134
140
  activation_method=data.get("activation_method"),
135
141
  )
136
142
 
143
+ # Structured degradation for distributed scripts
144
+ status = data.get("status", "")
145
+ if status in ("error", "error_distributed"):
146
+ is_distributed = status == "error_distributed"
147
+ if not is_distributed:
148
+ # Fallback keyword match for older sidecar format
149
+ error_msg = data.get("error", "")
150
+ _dist_keywords = ("init_process_group", "NCCL", "gloo", "distributed",
151
+ "MASTER_ADDR", "MASTER_PORT", "RendezvousError")
152
+ is_distributed = any(kw.lower() in error_msg.lower() for kw in _dist_keywords)
153
+ if is_distributed:
154
+ return ModelInfo(
155
+ param_count=0,
156
+ dtype="float16",
157
+ model_name=None,
158
+ method="execution",
159
+ extraction_error="distributed_entrypoint",
160
+ extraction_detail=(
161
+ "Script requires a distributed runtime (e.g. torchrun). "
162
+ "Run ghost on the model definition file instead of the launcher script."
163
+ ),
164
+ )
165
+
137
166
  return None
138
167
 
139
168
  except subprocess.TimeoutExpired:
@@ -18,6 +18,13 @@ from dataclasses import dataclass, field
18
18
  from enum import Enum
19
19
  from typing import List, Optional
20
20
 
21
+ import warnings as _warnings
22
+ _warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
23
+ _warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
24
+ _warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
25
+ _warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
26
+ del _warnings
27
+
21
28
 
22
29
  class StopReason(str, Enum):
23
30
  STABLE = "stable"
@@ -356,7 +363,7 @@ def probe_command(
356
363
  """
357
364
  pynvml = _try_import_pynvml()
358
365
 
359
- # Launch the subprocess
366
+ # Launch the user's training subprocess — do NOT modify env (their warnings matter)
360
367
  try:
361
368
  proc = subprocess.Popen(
362
369
  command,
@@ -123,6 +123,7 @@ def upload_artifact(artifact_path: str, api_url: str, token: str) -> dict:
123
123
  "dataloader_wait_pct": probe.get("dataloader_wait_pct"),
124
124
  "comm_overhead_pct": probe.get("comm_overhead_pct"),
125
125
  "per_rank_peak_vram_mb": probe.get("per_rank_peak_vram_mb"),
126
+ "process_map": probe.get("process_map"),
126
127
  # Architecture fields: probe (callbacks) takes priority over ghost defaults
127
128
  "batch_size": probe.get("batch_size") or (ghost.get("batch_size") if ghost else None),
128
129
  "seq_length": ghost.get("seq_length") if ghost else None,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alloc
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
5
5
  Author-email: Alloc Labs <hello@alloclabs.com>
6
6
  License-Expression: Apache-2.0
@@ -43,12 +43,15 @@ tests/test_diagnosis_engine.py
43
43
  tests/test_diagnosis_rules.py
44
44
  tests/test_extractor_activation.py
45
45
  tests/test_ghost.py
46
+ tests/test_ghost_degradation.py
46
47
  tests/test_init_from_org.py
47
48
  tests/test_interconnect.py
48
49
  tests/test_model_extractor.py
49
50
  tests/test_probe_hw.py
50
51
  tests/test_probe_multi.py
52
+ tests/test_scan_auth.py
51
53
  tests/test_stability.py
54
+ tests/test_topology_strategy.py
52
55
  tests/test_upload.py
53
56
  tests/test_verdict.py
54
57
  tests/test_yaml_config.py
@@ -239,7 +239,7 @@ def test_status_json_no_artifact(tmp_path, monkeypatch):
239
239
  assert result.exit_code == 0
240
240
  data = json.loads(result.output.strip())
241
241
  assert data["artifact"] is None
242
- assert data["logged_in"] is False
242
+ assert data["has_token"] is False
243
243
  assert "version" in data
244
244
 
245
245
 
@@ -0,0 +1,145 @@
1
+ """Tests for ghost structured degradation on distributed scripts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import tempfile
8
+ from pathlib import Path
9
+ from unittest.mock import patch
10
+
11
+ from typer.testing import CliRunner
12
+
13
+ from alloc.cli import app
14
+ from alloc.model_extractor import ModelInfo, extract_model_info
15
+
16
+ runner = CliRunner()
17
+
18
+
19
+ def test_distributed_error_returns_structured_modelinfo():
20
+ """When extractor subprocess fails with a distributed keyword, return structured ModelInfo."""
21
+ sidecar_data = json.dumps({
22
+ "status": "error",
23
+ "error": "RuntimeError: torch.distributed.init_process_group requires MASTER_ADDR",
24
+ })
25
+
26
+ def _fake_subprocess_run(*args, **kwargs):
27
+ # Write the sidecar file
28
+ sidecar_path = args[0][3] # [python, -m, alloc.extractor_runner, sidecar_path, script_path]
29
+ with open(sidecar_path, "w") as f:
30
+ f.write(sidecar_data)
31
+
32
+ # Create a dummy script
33
+ fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_dist_")
34
+ os.write(fd, b"import torch\ntorch.distributed.init_process_group('nccl')\n")
35
+ os.close(fd)
36
+
37
+ try:
38
+ with patch("subprocess.run", side_effect=_fake_subprocess_run):
39
+ info = extract_model_info(script_path)
40
+
41
+ assert info is not None
42
+ assert info.extraction_error == "distributed_entrypoint"
43
+ assert "distributed runtime" in info.extraction_detail
44
+ assert info.param_count == 0
45
+ finally:
46
+ os.unlink(script_path)
47
+
48
+
49
+ def test_distributed_error_nccl_keyword():
50
+ """NCCL errors should be caught as distributed failures."""
51
+ sidecar_data = json.dumps({
52
+ "status": "error",
53
+ "error": "NCCL error: unhandled system error",
54
+ })
55
+
56
+ fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_nccl_")
57
+ os.write(fd, b"pass\n")
58
+ os.close(fd)
59
+
60
+ try:
61
+ def _fake_run(*args, **kwargs):
62
+ sidecar_path = args[0][3]
63
+ with open(sidecar_path, "w") as f:
64
+ f.write(sidecar_data)
65
+
66
+ with patch("subprocess.run", side_effect=_fake_run):
67
+ info = extract_model_info(script_path)
68
+
69
+ assert info is not None
70
+ assert info.extraction_error == "distributed_entrypoint"
71
+ finally:
72
+ os.unlink(script_path)
73
+
74
+
75
+ def test_non_distributed_error_returns_none():
76
+ """Non-distributed errors should still return None (fall through to AST)."""
77
+ sidecar_data = json.dumps({
78
+ "status": "error",
79
+ "error": "ImportError: No module named 'custom_lib'",
80
+ })
81
+
82
+ fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_other_")
83
+ # Script with no from_pretrained so AST also returns None
84
+ os.write(fd, b"import custom_lib\n")
85
+ os.close(fd)
86
+
87
+ try:
88
+ def _fake_run(*args, **kwargs):
89
+ sidecar_path = args[0][3]
90
+ with open(sidecar_path, "w") as f:
91
+ f.write(sidecar_data)
92
+
93
+ with patch("subprocess.run", side_effect=_fake_run):
94
+ info = extract_model_info(script_path)
95
+
96
+ # Should be None because error is not distributed and AST won't find a model either
97
+ assert info is None
98
+ finally:
99
+ os.unlink(script_path)
100
+
101
+
102
+ def test_ghost_cli_distributed_error_json(tmp_path: Path):
103
+ """ghost --json shows structured error for distributed scripts."""
104
+ script_path = tmp_path / "train_ddp.py"
105
+ script_path.write_text("import torch\ntorch.distributed.init_process_group('nccl')\n")
106
+
107
+ dist_info = ModelInfo(
108
+ param_count=0,
109
+ dtype="float16",
110
+ model_name=None,
111
+ method="execution",
112
+ extraction_error="distributed_entrypoint",
113
+ extraction_detail="Script requires a distributed runtime (e.g. torchrun). Run ghost on the model definition file instead of the launcher script.",
114
+ )
115
+
116
+ with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
117
+ result = runner.invoke(app, ["ghost", str(script_path), "--json"])
118
+
119
+ assert result.exit_code != 0
120
+ data = json.loads(result.output)
121
+ assert data["error"] == "distributed_entrypoint"
122
+ assert data["supported"] is False
123
+ assert "distributed runtime" in data["detail"]
124
+
125
+
126
+ def test_ghost_cli_distributed_error_human(tmp_path: Path):
127
+ """ghost shows human-readable message with tip for distributed scripts."""
128
+ script_path = tmp_path / "train_ddp.py"
129
+ script_path.write_text("import torch\n")
130
+
131
+ dist_info = ModelInfo(
132
+ param_count=0,
133
+ dtype="float16",
134
+ model_name=None,
135
+ method="execution",
136
+ extraction_error="distributed_entrypoint",
137
+ extraction_detail="Script requires a distributed runtime (e.g. torchrun). Run ghost on the model definition file instead of the launcher script.",
138
+ )
139
+
140
+ with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
141
+ result = runner.invoke(app, ["ghost", str(script_path)])
142
+
143
+ assert result.exit_code != 0
144
+ assert "distributed runtime" in result.output
145
+ assert "model.py" in result.output # tip about pointing to model file
@@ -0,0 +1,142 @@
1
+ """Tests for scan command 401 retry + /scans/cli fallback."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from unittest.mock import MagicMock, patch
8
+
9
+ import httpx
10
+ from typer.testing import CliRunner
11
+
12
+ from alloc.cli import app
13
+
14
+ runner = CliRunner()
15
+
16
+
17
+ def _make_resp(status_code: int, body: dict, url: str = "https://api.example.com/scans"):
18
+ req = httpx.Request("POST", url)
19
+ return httpx.Response(
20
+ status_code,
21
+ request=req,
22
+ content=json.dumps(body).encode(),
23
+ headers={"content-type": "application/json"},
24
+ )
25
+
26
+
27
+ def test_scan_401_refresh_retry(tmp_path: Path):
28
+ """On 401, refresh token and retry on /scans."""
29
+ resp_401 = _make_resp(401, {"detail": "unauthorized"})
30
+ resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
31
+
32
+ mock_client = MagicMock()
33
+ mock_client.__enter__.return_value = mock_client
34
+ mock_client.__exit__.return_value = False
35
+ mock_client.post.side_effect = [resp_401, resp_ok]
36
+
37
+ cfg_file = tmp_path / ".alloc" / "config.json"
38
+ cfg_file.parent.mkdir(parents=True)
39
+ cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
40
+
41
+ env = {
42
+ "HOME": str(tmp_path),
43
+ "ALLOC_API_URL": "https://api.example.com",
44
+ }
45
+
46
+ with (
47
+ patch("httpx.Client", return_value=mock_client),
48
+ patch("alloc.cli.try_refresh_access_token", return_value="new-tok"),
49
+ ):
50
+ result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
51
+
52
+ assert result.exit_code == 0
53
+ assert mock_client.post.call_count == 2
54
+ # Second call should use refreshed token
55
+ second_call = mock_client.post.call_args_list[1]
56
+ assert "Bearer new-tok" in str(second_call)
57
+
58
+
59
+ def test_scan_401_refresh_fails_fallback_public(tmp_path: Path):
60
+ """On 401 + refresh failure, fall back to /scans/cli with warning."""
61
+ resp_401 = _make_resp(401, {"detail": "unauthorized"})
62
+ resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []},
63
+ url="https://api.example.com/scans/cli")
64
+
65
+ mock_client = MagicMock()
66
+ mock_client.__enter__.return_value = mock_client
67
+ mock_client.__exit__.return_value = False
68
+ mock_client.post.side_effect = [resp_401, resp_ok]
69
+
70
+ cfg_file = tmp_path / ".alloc" / "config.json"
71
+ cfg_file.parent.mkdir(parents=True)
72
+ cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
73
+
74
+ env = {
75
+ "HOME": str(tmp_path),
76
+ "ALLOC_API_URL": "https://api.example.com",
77
+ }
78
+
79
+ with (
80
+ patch("httpx.Client", return_value=mock_client),
81
+ patch("alloc.cli.try_refresh_access_token", return_value=None),
82
+ ):
83
+ result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
84
+
85
+ assert result.exit_code == 0
86
+ assert mock_client.post.call_count == 2
87
+ # Second call should hit /scans/cli
88
+ second_url = str(mock_client.post.call_args_list[1])
89
+ assert "/scans/cli" in second_url
90
+
91
+
92
+ def test_scan_401_fallback_warns_about_dropped_features(tmp_path: Path):
93
+ """Fallback to public scan warns user about lost org context."""
94
+ resp_401 = _make_resp(401, {"detail": "unauthorized"})
95
+ resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
96
+
97
+ mock_client = MagicMock()
98
+ mock_client.__enter__.return_value = mock_client
99
+ mock_client.__exit__.return_value = False
100
+ mock_client.post.side_effect = [resp_401, resp_ok]
101
+
102
+ cfg_file = tmp_path / ".alloc" / "config.json"
103
+ cfg_file.parent.mkdir(parents=True)
104
+ cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))
105
+
106
+ env = {
107
+ "HOME": str(tmp_path),
108
+ "ALLOC_API_URL": "https://api.example.com",
109
+ }
110
+
111
+ with (
112
+ patch("httpx.Client", return_value=mock_client),
113
+ patch("alloc.cli.try_refresh_access_token", return_value=None),
114
+ ):
115
+ # Non-JSON mode to see the warning message
116
+ result = runner.invoke(app, ["scan", "--model", "llama-3-8b"], env=env)
117
+
118
+ assert result.exit_code == 0
119
+ assert "expired" in result.output.lower() or "falling back" in result.output.lower()
120
+
121
+
122
+ def test_scan_no_token_uses_public_directly(tmp_path: Path):
123
+ """Without a token, scan goes directly to /scans/cli."""
124
+ resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})
125
+
126
+ mock_client = MagicMock()
127
+ mock_client.__enter__.return_value = mock_client
128
+ mock_client.__exit__.return_value = False
129
+ mock_client.post.return_value = resp_ok
130
+
131
+ env = {
132
+ "HOME": str(tmp_path),
133
+ "ALLOC_API_URL": "https://api.example.com",
134
+ }
135
+
136
+ with patch("httpx.Client", return_value=mock_client):
137
+ result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)
138
+
139
+ assert result.exit_code == 0
140
+ assert mock_client.post.call_count == 1
141
+ call_url = str(mock_client.post.call_args_list[0])
142
+ assert "/scans/cli" in call_url
@@ -0,0 +1,87 @@
1
+ """Tests for strategy inference from topology degrees (P0-B)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from unittest.mock import patch
7
+
8
+ from alloc.cli import _infer_parallel_topology_from_env
9
+
10
+
11
+ class TestStrategyInference:
12
+ """Strategy should be inferred from TP/PP/DP degrees when present."""
13
+
14
+ def _topo(self, env=None, num_gpus=4):
15
+ env = env or {}
16
+ with patch.dict(os.environ, env, clear=False):
17
+ return _infer_parallel_topology_from_env(
18
+ num_gpus_detected=num_gpus,
19
+ )
20
+
21
+ def test_no_degrees_strategy_none(self):
22
+ """When no degree env vars set, strategy should be None."""
23
+ result = self._topo({})
24
+ assert result["strategy"] is None
25
+
26
+ def test_dp_only_is_ddp(self):
27
+ """WORLD_SIZE=4 with no TP/PP → dp inferred → strategy=ddp."""
28
+ result = self._topo({"WORLD_SIZE": "4"})
29
+ assert result["strategy"] == "ddp"
30
+ assert result["dp_degree"] == 4
31
+
32
+ def test_tp_only(self):
33
+ """TP_SIZE=4 alone → strategy=tp."""
34
+ result = self._topo({"TP_SIZE": "4"})
35
+ assert result["strategy"] == "tp"
36
+
37
+ def test_pp_only(self):
38
+ """PP_SIZE=4 alone → strategy=pp."""
39
+ result = self._topo({"PP_SIZE": "4"})
40
+ assert result["strategy"] == "pp"
41
+
42
+ def test_tp_dp(self):
43
+ """TP_SIZE=2 with DP_SIZE=2 → strategy=tp+dp."""
44
+ result = self._topo({"TP_SIZE": "2", "DP_SIZE": "2"})
45
+ assert result["strategy"] == "tp+dp"
46
+
47
+ def test_pp_dp(self):
48
+ """PP_SIZE=2 with DP_SIZE=2 → strategy=pp+dp."""
49
+ result = self._topo({"PP_SIZE": "2", "DP_SIZE": "2"})
50
+ assert result["strategy"] == "pp+dp"
51
+
52
+ def test_tp_pp_dp(self):
53
+ """All three degrees → strategy=tp+pp+dp."""
54
+ result = self._topo({"TP_SIZE": "2", "PP_SIZE": "2", "DP_SIZE": "2"})
55
+ assert result["strategy"] == "tp+pp+dp"
56
+
57
+ def test_tp_pp_no_dp(self):
58
+ """TP+PP without explicit DP → strategy=tp+pp+dp."""
59
+ result = self._topo({"TP_SIZE": "2", "PP_SIZE": "2"})
60
+ assert result["strategy"] == "tp+pp+dp"
61
+
62
+ def test_tp_size_1_not_counted(self):
63
+ """TP_SIZE=1 should not count as tensor parallelism."""
64
+ result = self._topo({"TP_SIZE": "1", "DP_SIZE": "4"})
65
+ assert result["strategy"] == "ddp"
66
+
67
+ def test_pp_size_1_not_counted(self):
68
+ """PP_SIZE=1 should not count as pipeline parallelism."""
69
+ result = self._topo({"PP_SIZE": "1", "DP_SIZE": "4"})
70
+ assert result["strategy"] == "ddp"
71
+
72
+ def test_dp_inferred_from_world_size(self):
73
+ """DP inferred from WORLD_SIZE / (TP * PP) → strategy includes dp."""
74
+ result = self._topo({"WORLD_SIZE": "8", "TP_SIZE": "2"})
75
+ assert result["dp_degree"] == 4
76
+ assert result["strategy"] == "tp+dp"
77
+
78
+
79
+ class TestProcessMapInProbeDictAssembly:
80
+ """process_map should reach probe_dict from ProbeResult."""
81
+
82
+ def test_process_map_present_in_topology_return(self):
83
+ """Topology dict now includes strategy field."""
84
+ with patch.dict(os.environ, {"WORLD_SIZE": "4"}, clear=False):
85
+ topo = _infer_parallel_topology_from_env(num_gpus_detected=4)
86
+ assert "strategy" in topo
87
+ assert topo["strategy"] == "ddp"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes