alloc 0.0.5__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {alloc-0.0.5 → alloc-0.0.7}/PKG-INFO +1 -1
  2. {alloc-0.0.5 → alloc-0.0.7}/pyproject.toml +1 -1
  3. alloc-0.0.7/src/alloc/__init__.py +18 -0
  4. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/cli.py +68 -9
  5. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/extractor_runner.py +3 -1
  6. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/model_extractor.py +27 -0
  7. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/probe.py +156 -8
  8. {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/PKG-INFO +1 -1
  9. {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/SOURCES.txt +2 -0
  10. {alloc-0.0.5 → alloc-0.0.7}/tests/test_auth.py +12 -1
  11. alloc-0.0.7/tests/test_ghost_degradation.py +145 -0
  12. {alloc-0.0.5 → alloc-0.0.7}/tests/test_probe_multi.py +79 -1
  13. alloc-0.0.7/tests/test_scan_auth.py +142 -0
  14. alloc-0.0.5/src/alloc/__init__.py +0 -11
  15. {alloc-0.0.5 → alloc-0.0.7}/README.md +0 -0
  16. {alloc-0.0.5 → alloc-0.0.7}/setup.cfg +0 -0
  17. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/artifact_loader.py +0 -0
  18. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/artifact_writer.py +0 -0
  19. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/browser_auth.py +0 -0
  20. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/callbacks.py +0 -0
  21. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/catalog/__init__.py +0 -0
  22. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/catalog/default_rate_card.json +0 -0
  23. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/catalog/gpus.v1.json +0 -0
  24. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/code_analyzer.py +0 -0
  25. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/config.py +0 -0
  26. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/context.py +0 -0
  27. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/diagnosis_display.py +0 -0
  28. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/diagnosis_engine.py +0 -0
  29. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/diagnosis_rules.py +0 -0
  30. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/display.py +0 -0
  31. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/ghost.py +0 -0
  32. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/model_registry.py +0 -0
  33. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/stability.py +0 -0
  34. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/upload.py +0 -0
  35. {alloc-0.0.5 → alloc-0.0.7}/src/alloc/yaml_config.py +0 -0
  36. {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/dependency_links.txt +0 -0
  37. {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/entry_points.txt +0 -0
  38. {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/requires.txt +0 -0
  39. {alloc-0.0.5 → alloc-0.0.7}/src/alloc.egg-info/top_level.txt +0 -0
  40. {alloc-0.0.5 → alloc-0.0.7}/tests/test_artifact.py +0 -0
  41. {alloc-0.0.5 → alloc-0.0.7}/tests/test_artifact_loader.py +0 -0
  42. {alloc-0.0.5 → alloc-0.0.7}/tests/test_callbacks.py +0 -0
  43. {alloc-0.0.5 → alloc-0.0.7}/tests/test_catalog.py +0 -0
  44. {alloc-0.0.5 → alloc-0.0.7}/tests/test_cli.py +0 -0
  45. {alloc-0.0.5 → alloc-0.0.7}/tests/test_code_analyzer.py +0 -0
  46. {alloc-0.0.5 → alloc-0.0.7}/tests/test_context.py +0 -0
  47. {alloc-0.0.5 → alloc-0.0.7}/tests/test_diagnose_cli.py +0 -0
  48. {alloc-0.0.5 → alloc-0.0.7}/tests/test_diagnosis_engine.py +0 -0
  49. {alloc-0.0.5 → alloc-0.0.7}/tests/test_diagnosis_rules.py +0 -0
  50. {alloc-0.0.5 → alloc-0.0.7}/tests/test_extractor_activation.py +0 -0
  51. {alloc-0.0.5 → alloc-0.0.7}/tests/test_ghost.py +0 -0
  52. {alloc-0.0.5 → alloc-0.0.7}/tests/test_init_from_org.py +0 -0
  53. {alloc-0.0.5 → alloc-0.0.7}/tests/test_interconnect.py +0 -0
  54. {alloc-0.0.5 → alloc-0.0.7}/tests/test_model_extractor.py +0 -0
  55. {alloc-0.0.5 → alloc-0.0.7}/tests/test_probe_hw.py +0 -0
  56. {alloc-0.0.5 → alloc-0.0.7}/tests/test_stability.py +0 -0
  57. {alloc-0.0.5 → alloc-0.0.7}/tests/test_upload.py +0 -0
  58. {alloc-0.0.5 → alloc-0.0.7}/tests/test_verdict.py +0 -0
  59. {alloc-0.0.5 → alloc-0.0.7}/tests/test_yaml_config.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alloc
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
5
5
  Author-email: Alloc Labs <hello@alloclabs.com>
6
6
  License-Expression: Apache-2.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "alloc"
7
- version = "0.0.5"
7
+ version = "0.0.7"
8
8
  description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -0,0 +1,18 @@
1
+ """Alloc — GPU intelligence for ML training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import warnings as _warnings
6
+ _warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
7
+ _warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
8
+ _warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
9
+ _warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
10
+ del _warnings
11
+
12
+ __version__ = "0.0.7"
13
+
14
+ from alloc.ghost import ghost, GhostReport
15
+ from alloc.callbacks import AllocCallback as HuggingFaceCallback
16
+ from alloc.callbacks import AllocLightningCallback as LightningCallback
17
+
18
+ __all__ = ["ghost", "GhostReport", "HuggingFaceCallback", "LightningCallback", "__version__"]
@@ -16,8 +16,17 @@ from __future__ import annotations
16
16
 
17
17
  import os
18
18
  import sys
19
+ import warnings
19
20
  from typing import List, Optional
20
21
 
22
+ # Suppress noisy third-party warnings globally — pynvml deprecation (emitted
23
+ # from torch.cuda.__init__) and urllib3 LibreSSL warnings clutter CLI output.
24
+ warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
25
+ warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
26
+ warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
27
+ warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
28
+ warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
29
+
21
30
  import typer
22
31
  from rich.console import Console
23
32
 
@@ -68,6 +77,19 @@ def ghost(
68
77
  console.print(f"[dim]Tip: alloc ghost {script} --param-count-b 7.0[/dim]")
69
78
  raise typer.Exit(1)
70
79
 
80
+ if info.extraction_error:
81
+ if json_output:
82
+ _print_json({
83
+ "error": info.extraction_error,
84
+ "detail": info.extraction_detail,
85
+ "supported": False,
86
+ })
87
+ else:
88
+ console.print(f"[yellow]{info.extraction_detail}[/yellow]")
89
+ if info.extraction_error == "distributed_entrypoint":
90
+ console.print("[dim]Tip: alloc ghost model.py (point to the file that defines your model)[/dim]")
91
+ raise typer.Exit(1)
92
+
71
93
  # Use dtype from execution if available, otherwise CLI flag
72
94
  resolved_dtype = info.dtype if info.method == "execution" else dtype
73
95
 
@@ -2092,12 +2114,32 @@ def scan(
2092
2114
 
2093
2115
  try:
2094
2116
  headers = {"Content-Type": "application/json"}
2117
+ used_auth = bool(token)
2118
+
2095
2119
  if token:
2096
2120
  headers["Authorization"] = f"Bearer {token}"
2121
+ endpoint = "/scans"
2122
+ else:
2123
+ endpoint = "/scans/cli"
2097
2124
 
2098
- endpoint = "/scans" if token else "/scans/cli"
2099
2125
  with httpx.Client(timeout=30) as client:
2100
2126
  resp = client.post(f"{api_url}{endpoint}", json=payload, headers=headers)
2127
+
2128
+ # On 401 with a saved token: try refresh, then fall back to public endpoint
2129
+ if resp.status_code == 401 and used_auth:
2130
+ new_token = try_refresh_access_token()
2131
+ if new_token:
2132
+ headers["Authorization"] = f"Bearer {new_token}"
2133
+ resp = client.post(f"{api_url}/scans", json=payload, headers=headers)
2134
+ else:
2135
+ # Token refresh failed — fall back to unauthenticated scan
2136
+ console.print(
2137
+ "[yellow]Session expired — falling back to public scan "
2138
+ "(org fleet context unavailable). Run `alloc login` to restore.[/yellow]",
2139
+ )
2140
+ del headers["Authorization"]
2141
+ resp = client.post(f"{api_url}/scans/cli", json=payload, headers=headers)
2142
+
2101
2143
  resp.raise_for_status()
2102
2144
  result = resp.json()
2103
2145
 
@@ -2107,7 +2149,12 @@ def scan(
2107
2149
  _print_scan_result(result, gpu, strategy)
2108
2150
  except httpx.HTTPStatusError as e:
2109
2151
  if json_output:
2110
- _print_json({"error": f"API error {e.response.status_code}"})
2152
+ detail = ""
2153
+ try:
2154
+ detail = e.response.json().get("detail", "")
2155
+ except Exception:
2156
+ pass
2157
+ _print_json({"error": f"API error {e.response.status_code}", "detail": detail})
2111
2158
  elif e.response.status_code == 403:
2112
2159
  console.print("[yellow]AI analysis requires a Pro or Enterprise plan.[/yellow]")
2113
2160
  console.print("[dim]The scan still works — just without AI-powered analysis.[/dim]")
@@ -2147,18 +2194,30 @@ def login(
2147
2194
  ),
2148
2195
  ):
2149
2196
  """Authenticate with Alloc dashboard."""
2150
- # Suppress noisy third-party warnings (urllib3 LibreSSL, pynvml deprecation)
2151
- # that clutter the auth flow output.
2152
- import warnings
2153
- warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
2154
- warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
2155
- warnings.filterwarnings("ignore", message=".*pynvml.*", category=FutureWarning)
2156
-
2157
2197
  import httpx
2158
2198
  from alloc.config import get_supabase_url, get_supabase_anon_key, load_config, save_config
2159
2199
 
2160
2200
  # --- Browser OAuth flow ---
2161
2201
  if browser:
2202
+ # Detect headless/SSH environments where browser login won't work
2203
+ is_headless = (
2204
+ not os.environ.get("DISPLAY")
2205
+ and not os.environ.get("WAYLAND_DISPLAY")
2206
+ and sys.platform != "darwin"
2207
+ and sys.platform != "win32"
2208
+ )
2209
+ is_ssh = bool(os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"))
2210
+ if is_headless or is_ssh:
2211
+ console.print("[yellow]Headless/SSH environment detected — browser login won't work here.[/yellow]")
2212
+ console.print()
2213
+ console.print("Use token login instead:")
2214
+ console.print(" 1. Log in at [cyan]https://alloclabs.com[/cyan] in your local browser")
2215
+ console.print(" 2. Open DevTools → Application → Local Storage → copy your access_token")
2216
+ console.print(" 3. Run: [green]alloc login --method token --token <paste-token>[/green]")
2217
+ console.print()
2218
+ console.print("[dim]Or set ALLOC_TOKEN=<token> in your environment.[/dim]")
2219
+ raise typer.Exit(1)
2220
+
2162
2221
  provider = (provider or "").strip().lower()
2163
2222
  if provider not in ("google", "azure"):
2164
2223
  console.print("[red]Invalid --provider. Use: google or azure[/red]")
@@ -217,7 +217,9 @@ def main():
217
217
  try:
218
218
  obj = getattr(module, attr_name)
219
219
  if isinstance(obj, nn.Module):
220
- count, dtype_str = _count_params(obj)
220
+ # Unwrap DDP/FSDP wrappers to get the underlying model
221
+ unwrapped = getattr(obj, "module", obj)
222
+ count, dtype_str = _count_params(unwrapped)
221
223
  if count > 0:
222
224
  models.append((count, dtype_str, attr_name))
223
225
  except Exception:
@@ -33,6 +33,8 @@ class ModelInfo:
33
33
  seq_length: Optional[int] = None
34
34
  activation_memory_bytes: Optional[int] = None
35
35
  activation_method: Optional[str] = None # "traced" | None
36
+ extraction_error: Optional[str] = None # "distributed_entrypoint" | None
37
+ extraction_detail: Optional[str] = None # human-readable explanation
36
38
 
37
39
 
38
40
  def extract_model_info(
@@ -99,6 +101,13 @@ def _extract_via_subprocess(
99
101
 
100
102
  env = os.environ.copy()
101
103
  env["CUDA_VISIBLE_DEVICES"] = "" # prevent GPU allocation
104
+ # Set distributed env vars so DDP scripts don't crash on
105
+ # torch.distributed.init_process_group() during model extraction.
106
+ env.setdefault("RANK", "0")
107
+ env.setdefault("LOCAL_RANK", "0")
108
+ env.setdefault("WORLD_SIZE", "1")
109
+ env.setdefault("MASTER_ADDR", "127.0.0.1")
110
+ env.setdefault("MASTER_PORT", "29500")
102
111
 
103
112
  subprocess.run(
104
113
  [sys.executable, "-m", "alloc.extractor_runner", sidecar_path, script_abs],
@@ -127,6 +136,24 @@ def _extract_via_subprocess(
127
136
  activation_method=data.get("activation_method"),
128
137
  )
129
138
 
139
+ # Structured degradation for distributed scripts
140
+ if data.get("status") == "error":
141
+ error_msg = data.get("error", "")
142
+ _dist_keywords = ("init_process_group", "NCCL", "gloo", "distributed",
143
+ "MASTER_ADDR", "MASTER_PORT", "RendezvousError")
144
+ if any(kw.lower() in error_msg.lower() for kw in _dist_keywords):
145
+ return ModelInfo(
146
+ param_count=0,
147
+ dtype="float16",
148
+ model_name=None,
149
+ method="execution",
150
+ extraction_error="distributed_entrypoint",
151
+ extraction_detail=(
152
+ "Script requires a distributed runtime (e.g. torchrun). "
153
+ "Run ghost on the model definition file instead of the launcher script."
154
+ ),
155
+ )
156
+
130
157
  return None
131
158
 
132
159
  except subprocess.TimeoutExpired:
@@ -18,6 +18,13 @@ from dataclasses import dataclass, field
18
18
  from enum import Enum
19
19
  from typing import List, Optional
20
20
 
21
+ import warnings as _warnings
22
+ _warnings.filterwarnings("ignore", category=FutureWarning, module="pynvml")
23
+ _warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
24
+ _warnings.filterwarnings("ignore", category=FutureWarning, module=r"torch\.cuda")
25
+ _warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"torch\.cuda")
26
+ del _warnings
27
+
21
28
 
22
29
  class StopReason(str, Enum):
23
30
  STABLE = "stable"
@@ -83,6 +90,85 @@ def _try_import_pynvml():
83
90
 
84
91
 
85
92
 
93
# Count-carrying flags of the supported launchers, in lookup priority order.
# Each entry lists the underscore and hyphen spellings of the same flag.
_LAUNCHER_COUNT_FLAGS = (
    ("--nproc_per_node", "--nproc-per-node"),  # torchrun / torch.distributed.launch
    ("--num_processes", "--num-processes"),  # accelerate launch
    ("--num_gpus", "--num-gpus"),  # deepspeed
)


def _find_int_flag(args, flags):
    # type: (list, tuple) -> Optional[int]
    """Return the first integer value given for any spelling in *flags*.

    Both ``--flag=N`` and ``--flag N`` are accepted. Values that are not
    integers (for example ``--nproc_per_node=gpu``) are skipped, and the scan
    continues with the remaining arguments. Returns None when nothing parses.
    """
    for i, arg in enumerate(args):
        for flag in flags:
            if arg.startswith(flag + "="):
                raw = arg.split("=", 1)[1]
            elif arg == flag and i + 1 < len(args):
                raw = args[i + 1]
            else:
                continue
            try:
                return int(raw)
            except ValueError:
                # Not a number; keep looking at later arguments.
                pass
    return None


def _parse_launcher_gpu_count(command):
    # type: (list) -> Optional[int]
    """Extract expected GPU count from launcher command line.

    Recognizes: torchrun, torch.distributed.launch, accelerate launch,
    deepspeed, mpirun/mpiexec.

    Args:
        command: Launcher argv; every element is converted with ``str()``.

    Returns:
        The parsed process/GPU count, or None when no recognized flag carries
        an integer value. Flag families are tried in the order torchrun,
        accelerate, deepspeed, then mpirun/mpiexec.
    """
    args = [str(a) for a in command]

    for flags in _LAUNCHER_COUNT_FLAGS:
        count = _find_int_flag(args, flags)
        if count is not None:
            return count

    # mpirun/mpiexec -np N or -n N. The short flags are only trusted when the
    # command line mentions an MPI launcher, since "-n" is a common option of
    # unrelated programs. The check is loop-invariant, so do it once.
    cmd_str = " ".join(args)
    if any(name in cmd_str for name in ("mpirun", "mpiexec")):
        for i, arg in enumerate(args[:-1]):
            if arg in ("-np", "-n"):
                try:
                    return int(args[i + 1])
                except ValueError:
                    pass

    return None
155
+
156
+
157
def _read_child_env(pid, var_name):
    # type: (int, str) -> Optional[str]
    """Read an environment variable from a child process via /proc (Linux only).

    Args:
        pid: Process id whose ``/proc/<pid>/environ`` file is read.
        var_name: Name of the variable to look up.

    Returns:
        The variable's value decoded as UTF-8, or None when the file cannot
        be read (no /proc, process gone, permission denied), the variable is
        absent, or its value is not valid UTF-8.
    """
    # Entries are NUL-separated "NAME=value" byte strings; build the prefix once.
    prefix = var_name.encode() + b"="
    try:
        with open("/proc/{}/environ".format(pid), "rb") as f:
            env_data = f.read()
        for entry in env_data.split(b"\x00"):
            if entry.startswith(prefix):
                return entry.split(b"=", 1)[1].decode("utf-8")
    except (OSError, ValueError):
        # Best-effort lookup: OSError covers a missing or unreadable file,
        # ValueError covers UnicodeDecodeError from the decode above.
        pass
    return None
170
+
171
+
86
172
  def _get_child_pids(pid):
87
173
  # type: (int) -> List[int]
88
174
  """Get child PIDs of a process. Returns empty list on failure."""
@@ -101,11 +187,15 @@ def _get_child_pids(pid):
101
187
  return []
102
188
 
103
189
 
104
- def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0):
105
- # type: (int, ..., int) -> List[int]
190
+ def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0, expected_gpus=None):
191
+ # type: (int, ..., int, Optional[int]) -> List[int]
106
192
  """Discover which GPUs a process (and its children) are using.
107
193
 
108
- Iterates all GPU devices and checks running compute processes.
194
+ Two strategies:
195
+ 1. PID-matching: walks process tree, matches PIDs against NVML compute processes.
196
+ 2. Active-GPU counting: counts GPUs with ANY compute processes. Used when PID
197
+ matching finds fewer GPUs than expected (common for DDP launchers).
198
+
109
199
  Falls back to [fallback_index] if discovery fails or finds nothing.
110
200
  """
111
201
  try:
@@ -143,11 +233,48 @@ def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0):
143
233
  for ggchild in _get_child_pids(grandchild):
144
234
  target_pids.add(ggchild)
145
235
 
236
+ # Also try reading WORLD_SIZE from child process environments (Linux)
237
+ child_world_size = None
238
+ for child in _get_child_pids(proc_pid):
239
+ ws = _read_child_env(child, "WORLD_SIZE")
240
+ if ws:
241
+ try:
242
+ child_world_size = int(ws)
243
+ except ValueError:
244
+ pass
245
+ break
246
+ for grandchild in _get_child_pids(child):
247
+ ws = _read_child_env(grandchild, "WORLD_SIZE")
248
+ if ws:
249
+ try:
250
+ child_world_size = int(ws)
251
+ except ValueError:
252
+ pass
253
+ break
254
+ for ggchild in _get_child_pids(grandchild):
255
+ ws = _read_child_env(ggchild, "WORLD_SIZE")
256
+ if ws:
257
+ try:
258
+ child_world_size = int(ws)
259
+ except ValueError:
260
+ pass
261
+ break
262
+
263
+ # Determine how many GPUs we expect
264
+ effective_expected = expected_gpus
265
+ if child_world_size is not None:
266
+ if effective_expected is None or child_world_size > effective_expected:
267
+ effective_expected = child_world_size
268
+
269
+ # Strategy 1: PID-based matching
146
270
  found_indices = []
271
+ active_indices = [] # GPUs with ANY compute processes
147
272
  for idx in search_indices:
148
273
  try:
149
274
  handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
150
275
  procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
276
+ if procs:
277
+ active_indices.append(idx)
151
278
  for p in procs:
152
279
  if p.pid in target_pids:
153
280
  found_indices.append(idx)
@@ -155,6 +282,16 @@ def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0):
155
282
  except Exception:
156
283
  continue
157
284
 
285
+ # Strategy 2: If PID matching found fewer than expected, use active GPU count.
286
+ # This handles the common case where DDP workers are running but their PIDs
287
+ # weren't in our process tree (e.g., process group isolation, timing issues).
288
+ if effective_expected is not None and len(found_indices) < effective_expected:
289
+ if len(active_indices) >= effective_expected:
290
+ return active_indices[:effective_expected]
291
+ # Even if active < expected, active is better than found
292
+ if len(active_indices) > len(found_indices):
293
+ return active_indices
294
+
158
295
  return found_indices if found_indices else [fallback_index]
159
296
 
160
297
 
@@ -310,12 +447,17 @@ def probe_command(
310
447
  discovery_attempts = 0
311
448
  max_discovery_attempts = 3 # Retry at samples 5, 15, 30
312
449
 
313
- # Determine expected GPU count from environment for retry logic
450
+ # Determine expected GPU count from command line + environment
314
451
  expected_gpus = 1
452
+ launcher_gpus = _parse_launcher_gpu_count(command)
453
+ if launcher_gpus is not None and launcher_gpus > 1:
454
+ expected_gpus = launcher_gpus
315
455
  ws = os.environ.get("WORLD_SIZE", "").strip()
316
456
  if ws:
317
457
  try:
318
- expected_gpus = max(1, int(ws))
458
+ ws_int = max(1, int(ws))
459
+ if ws_int > expected_gpus:
460
+ expected_gpus = ws_int
319
461
  except ValueError:
320
462
  pass
321
463
 
@@ -329,7 +471,10 @@ def probe_command(
329
471
  and proc.pid):
330
472
  discovery_attempts += 1
331
473
  try:
332
- discovered = _discover_gpu_indices(proc.pid, pynvml, fallback_index=gpu_index)
474
+ discovered = _discover_gpu_indices(
475
+ proc.pid, pynvml, fallback_index=gpu_index,
476
+ expected_gpus=expected_gpus if expected_gpus > 1 else None,
477
+ )
333
478
  if len(discovered) > 1:
334
479
  handles = []
335
480
  pmap = []
@@ -456,9 +601,12 @@ def probe_command(
456
601
  if calibration_time_ref[0] is not None:
457
602
  cal_duration = round(calibration_time_ref[0] - start_time, 2)
458
603
 
459
- # Environment-based fallback: if NVML discovery found fewer GPUs than
460
- # WORLD_SIZE indicates, trust the environment. The probe may miss GPUs
604
+ # Final fallback: if NVML discovery found fewer GPUs than expected,
605
+ # trust the command-line / environment signals. The probe may miss GPUs
461
606
  # due to DDP per-rank CVD isolation or timing races.
607
+ launcher_count = _parse_launcher_gpu_count(command)
608
+ if launcher_count is not None and launcher_count > num_gpus_ref[0]:
609
+ num_gpus_ref[0] = launcher_count
462
610
  env_world = os.environ.get("WORLD_SIZE", "").strip()
463
611
  if env_world:
464
612
  try:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alloc
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
5
5
  Author-email: Alloc Labs <hello@alloclabs.com>
6
6
  License-Expression: Apache-2.0
@@ -43,11 +43,13 @@ tests/test_diagnosis_engine.py
43
43
  tests/test_diagnosis_rules.py
44
44
  tests/test_extractor_activation.py
45
45
  tests/test_ghost.py
46
+ tests/test_ghost_degradation.py
46
47
  tests/test_init_from_org.py
47
48
  tests/test_interconnect.py
48
49
  tests/test_model_extractor.py
49
50
  tests/test_probe_hw.py
50
51
  tests/test_probe_multi.py
52
+ tests/test_scan_auth.py
51
53
  tests/test_stability.py
52
54
  tests/test_upload.py
53
55
  tests/test_verdict.py
@@ -223,6 +223,7 @@ def test_browser_login_saves_tokens(tmp_path: Path):
223
223
  env = {
224
224
  "HOME": str(tmp_path),
225
225
  "ALLOC_API_URL": "https://api.example.com",
226
+ "DISPLAY": ":0", # prevent headless detection in CI
226
227
  }
227
228
 
228
229
  with patch("alloc.browser_auth.browser_login", return_value=mock_result):
@@ -248,6 +249,7 @@ def test_browser_login_with_azure_provider(tmp_path: Path):
248
249
  env = {
249
250
  "HOME": str(tmp_path),
250
251
  "ALLOC_API_URL": "https://api.example.com",
252
+ "DISPLAY": ":0",
251
253
  }
252
254
 
253
255
  with patch("alloc.browser_auth.browser_login", return_value=mock_result) as mock_bl:
@@ -263,7 +265,7 @@ def test_browser_login_with_azure_provider(tmp_path: Path):
263
265
 
264
266
 
265
267
  def test_browser_login_invalid_provider(tmp_path: Path):
266
- env = {"HOME": str(tmp_path)}
268
+ env = {"HOME": str(tmp_path), "DISPLAY": ":0"}
267
269
  result = runner.invoke(
268
270
  app, ["login", "--browser", "--provider", "facebook"], env=env
269
271
  )
@@ -271,10 +273,19 @@ def test_browser_login_invalid_provider(tmp_path: Path):
271
273
  assert "Invalid --provider" in result.output
272
274
 
273
275
 
276
def test_browser_login_headless_detection(tmp_path: Path):
    """Browser login inside an SSH session must abort with token-login guidance.

    DISPLAY is set in the environment, so the SSH_CLIENT variable is the
    signal that should make the command refuse the browser flow here.
    """
    env = {"HOME": str(tmp_path), "SSH_CLIENT": "1.2.3.4 1234 22", "DISPLAY": ":0"}
    result = runner.invoke(app, ["login", "--browser"], env=env)
    # The command exits via typer.Exit(1) after printing the guidance.
    assert result.exit_code == 1
    # Require both the detection notice and the token-login hint; accepting
    # either one alone would also pass on unrelated failures that happen to
    # mention "token".
    assert "Headless" in result.output
    assert "token" in result.output
282
+
283
+
274
284
  def test_browser_login_timeout(tmp_path: Path):
275
285
  env = {
276
286
  "HOME": str(tmp_path),
277
287
  "ALLOC_API_URL": "https://api.example.com",
288
+ "DISPLAY": ":0",
278
289
  }
279
290
 
280
291
  with patch(
@@ -0,0 +1,145 @@
1
+ """Tests for ghost structured degradation on distributed scripts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import tempfile
8
+ from pathlib import Path
9
+ from unittest.mock import patch
10
+
11
+ from typer.testing import CliRunner
12
+
13
+ from alloc.cli import app
14
+ from alloc.model_extractor import ModelInfo, extract_model_info
15
+
16
+ runner = CliRunner()
17
+
18
+
19
def test_distributed_error_returns_structured_modelinfo():
    """A sidecar error naming a distributed primitive yields a structured ModelInfo.

    subprocess.run is replaced by a stub that only writes an error sidecar
    mentioning init_process_group / MASTER_ADDR. extract_model_info is then
    expected to report "distributed_entrypoint" rather than return None.
    """
    error_payload = {
        "status": "error",
        "error": "RuntimeError: torch.distributed.init_process_group requires MASTER_ADDR",
    }
    sidecar_data = json.dumps(error_payload)

    def _fake_subprocess_run(*args, **kwargs):
        # argv layout: [python, -m, alloc.extractor_runner, sidecar_path, script_path]
        argv = args[0]
        with open(argv[3], "w") as sidecar:
            sidecar.write(sidecar_data)

    # Dummy script on disk for the extractor to be pointed at.
    fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_dist_")
    with os.fdopen(fd, "wb") as script:
        script.write(b"import torch\ntorch.distributed.init_process_group('nccl')\n")

    try:
        with patch("subprocess.run", side_effect=_fake_subprocess_run):
            info = extract_model_info(script_path)

        assert info is not None
        assert info.extraction_error == "distributed_entrypoint"
        assert "distributed runtime" in info.extraction_detail
        assert info.param_count == 0
    finally:
        os.unlink(script_path)
47
+
48
+
49
def test_distributed_error_nccl_keyword():
    """An error message mentioning NCCL is classified as a distributed failure."""
    sidecar_data = json.dumps(
        {"status": "error", "error": "NCCL error: unhandled system error"}
    )

    def _fake_run(*args, **kwargs):
        # The sidecar path is the fourth element of the argv list.
        argv = args[0]
        with open(argv[3], "w") as sidecar:
            sidecar.write(sidecar_data)

    fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_nccl_")
    with os.fdopen(fd, "wb") as script:
        script.write(b"pass\n")

    try:
        with patch("subprocess.run", side_effect=_fake_run):
            info = extract_model_info(script_path)

        assert info is not None
        assert info.extraction_error == "distributed_entrypoint"
    finally:
        os.unlink(script_path)
73
+
74
+
75
def test_non_distributed_error_returns_none():
    """An error without distributed keywords still produces None.

    The dummy script contains no from_pretrained call, so the AST fallback
    is not expected to find a model either.
    """
    sidecar_data = json.dumps(
        {"status": "error", "error": "ImportError: No module named 'custom_lib'"}
    )

    def _fake_run(*args, **kwargs):
        # The sidecar path is the fourth element of the argv list.
        argv = args[0]
        with open(argv[3], "w") as sidecar:
            sidecar.write(sidecar_data)

    fd, script_path = tempfile.mkstemp(suffix=".py", prefix="alloc_test_other_")
    with os.fdopen(fd, "wb") as script:
        script.write(b"import custom_lib\n")

    try:
        with patch("subprocess.run", side_effect=_fake_run):
            info = extract_model_info(script_path)

        # Neither the structured distributed path nor the AST path applies.
        assert info is None
    finally:
        os.unlink(script_path)
100
+
101
+
102
def test_ghost_cli_distributed_error_json(tmp_path: Path):
    """`ghost --json` emits the structured error object for a distributed script."""
    script_path = tmp_path / "train_ddp.py"
    script_path.write_text("import torch\ntorch.distributed.init_process_group('nccl')\n")

    # ModelInfo as the extractor reports it for a distributed entrypoint.
    detail = (
        "Script requires a distributed runtime (e.g. torchrun). "
        "Run ghost on the model definition file instead of the launcher script."
    )
    dist_info = ModelInfo(
        param_count=0,
        dtype="float16",
        model_name=None,
        method="execution",
        extraction_error="distributed_entrypoint",
        extraction_detail=detail,
    )

    cli_args = ["ghost", str(script_path), "--json"]
    with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
        result = runner.invoke(app, cli_args)

    assert result.exit_code != 0
    data = json.loads(result.output)
    assert data["error"] == "distributed_entrypoint"
    assert data["supported"] is False
    assert "distributed runtime" in data["detail"]
124
+
125
+
126
def test_ghost_cli_distributed_error_human(tmp_path: Path):
    """Plain `ghost` output carries the explanation and the model-file tip."""
    script_path = tmp_path / "train_ddp.py"
    script_path.write_text("import torch\n")

    # ModelInfo as the extractor reports it for a distributed entrypoint.
    detail = (
        "Script requires a distributed runtime (e.g. torchrun). "
        "Run ghost on the model definition file instead of the launcher script."
    )
    dist_info = ModelInfo(
        param_count=0,
        dtype="float16",
        model_name=None,
        method="execution",
        extraction_error="distributed_entrypoint",
        extraction_detail=detail,
    )

    cli_args = ["ghost", str(script_path)]
    with patch("alloc.model_extractor.extract_model_info", return_value=dist_info):
        result = runner.invoke(app, cli_args)

    assert result.exit_code != 0
    assert "distributed runtime" in result.output
    # The tip names model.py as the file to point ghost at.
    assert "model.py" in result.output
@@ -4,7 +4,12 @@ from __future__ import annotations
4
4
 
5
5
  from unittest.mock import MagicMock, patch
6
6
 
7
- from alloc.probe import _discover_gpu_indices, _get_child_pids, ProbeResult
7
+ from alloc.probe import (
8
+ _discover_gpu_indices,
9
+ _get_child_pids,
10
+ _parse_launcher_gpu_count,
11
+ ProbeResult,
12
+ )
8
13
 
9
14
 
10
15
  def _mock_pynvml_multi_gpu(proc_pid, gpu_process_map):
@@ -112,3 +117,76 @@ def test_probe_result_defaults_single_gpu():
112
117
  result = ProbeResult()
113
118
  assert result.num_gpus_detected == 1
114
119
  assert result.process_map is None
120
+
121
+
122
+ # ── Launcher command-line parsing ──
123
+
124
+
125
def test_parse_torchrun_equals():
    """torchrun, --nproc_per_node=N form."""
    command = ["torchrun", "--nproc_per_node=2", "train.py"]
    assert _parse_launcher_gpu_count(command) == 2


def test_parse_torchrun_space():
    """torchrun, --nproc_per_node N form."""
    command = ["torchrun", "--nproc_per_node", "4", "train.py"]
    assert _parse_launcher_gpu_count(command) == 4


def test_parse_torchrun_hyphen():
    """torchrun, hyphenated --nproc-per-node=N form."""
    command = ["torchrun", "--nproc-per-node=8", "train.py"]
    assert _parse_launcher_gpu_count(command) == 8


def test_parse_accelerate_equals():
    """accelerate launch, --num_processes=N form."""
    command = ["accelerate", "launch", "--num_processes=4", "train.py"]
    assert _parse_launcher_gpu_count(command) == 4


def test_parse_accelerate_hyphen_space():
    """accelerate launch, hyphenated --num-processes N form."""
    command = ["accelerate", "launch", "--num-processes", "8", "train.py"]
    assert _parse_launcher_gpu_count(command) == 8


def test_parse_deepspeed_equals():
    """deepspeed, --num_gpus=N form."""
    command = ["deepspeed", "--num_gpus=4", "train.py"]
    assert _parse_launcher_gpu_count(command) == 4


def test_parse_deepspeed_hyphen_space():
    """deepspeed, hyphenated --num-gpus N form."""
    command = ["deepspeed", "--num-gpus", "2", "train.py"]
    assert _parse_launcher_gpu_count(command) == 2


def test_parse_mpirun():
    """mpirun, -np N form."""
    command = ["mpirun", "-np", "8", "python", "train.py"]
    assert _parse_launcher_gpu_count(command) == 8


def test_parse_plain_python():
    """A plain python invocation carries no launcher count."""
    command = ["python", "train.py"]
    assert _parse_launcher_gpu_count(command) is None


def test_parse_torch_distributed_launch():
    """python -m torch.distributed.launch, --nproc_per_node=N form."""
    command = [
        "python", "-m", "torch.distributed.launch", "--nproc_per_node=2", "train.py",
    ]
    assert _parse_launcher_gpu_count(command) == 2
165
+
166
+
167
+ # ── Active-GPU fallback discovery ──
168
+
169
+
170
def test_active_gpu_fallback_when_pid_mismatch():
    """Active GPUs are returned when no PID matches but the expected count does.

    Both GPUs report compute processes whose PIDs differ from the probed
    process, and expected_gpus equals the number of active GPUs.
    """
    # PIDs 9999 and 8888 do not match the probed PID 1000.
    nvml = _mock_pynvml_multi_gpu(
        proc_pid=1000,
        gpu_process_map={0: [9999], 1: [8888]},
    )
    no_children = patch("alloc.probe._get_child_pids", return_value=[])
    no_child_env = patch("alloc.probe._read_child_env", return_value=None)
    with no_children:
        with no_child_env:
            result = _discover_gpu_indices(1000, nvml, fallback_index=0, expected_gpus=2)
    assert result == [0, 1]
181
+
182
+
183
def test_active_gpu_fallback_not_used_without_expected():
    """With no expected_gpus given, unmatched active GPUs are ignored."""
    nvml = _mock_pynvml_multi_gpu(
        proc_pid=1000,
        gpu_process_map={0: [9999], 1: [8888]},
    )
    no_children = patch("alloc.probe._get_child_pids", return_value=[])
    no_child_env = patch("alloc.probe._read_child_env", return_value=None)
    with no_children:
        with no_child_env:
            result = _discover_gpu_indices(1000, nvml, fallback_index=0)
    # Only the fallback index comes back.
    assert result == [0]
@@ -0,0 +1,142 @@
1
+ """Tests for scan command 401 retry + /scans/cli fallback."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from unittest.mock import MagicMock, patch
8
+
9
+ import httpx
10
+ from typer.testing import CliRunner
11
+
12
+ from alloc.cli import app
13
+
14
+ runner = CliRunner()
15
+
16
+
17
def _make_resp(status_code: int, body: dict, url: str = "https://api.example.com/scans"):
    """Build an ``httpx.Response`` with a JSON-encoded *body*.

    The response is attached to a synthetic POST request for *url* and
    carries a ``content-type: application/json`` header.
    """
    post_request = httpx.Request("POST", url)
    encoded_body = json.dumps(body).encode()
    json_headers = {"content-type": "application/json"}
    return httpx.Response(
        status_code,
        request=post_request,
        content=encoded_body,
        headers=json_headers,
    )
25
+
26
+
27
def test_scan_401_refresh_retry(tmp_path: Path):
    """On 401, refresh token and retry on /scans.

    The mocked client answers the first POST with 401 and the second with
    200, while ``try_refresh_access_token`` is patched to return
    ``"new-tok"``. The command must succeed after exactly two POSTs, and the
    second call must carry the refreshed bearer token.
    """
    resp_401 = _make_resp(401, {"detail": "unauthorized"})
    resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})

    mock_client = MagicMock()
    # The mock doubles as its own context manager so it can be used in a `with` block.
    mock_client.__enter__.return_value = mock_client
    mock_client.__exit__.return_value = False
    mock_client.post.side_effect = [resp_401, resp_ok]

    # Credentials file under the fake HOME (presumably where the CLI looks
    # for its stored token — confirm against alloc.config).
    cfg_file = tmp_path / ".alloc" / "config.json"
    cfg_file.parent.mkdir(parents=True)
    cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))

    env = {
        "HOME": str(tmp_path),
        "ALLOC_API_URL": "https://api.example.com",
    }

    # Nested `with` blocks instead of a parenthesized group: the parenthesized
    # form is only official grammar from Python 3.10 and fails to parse on
    # older interpreters.
    with patch("httpx.Client", return_value=mock_client):
        with patch("alloc.cli.try_refresh_access_token", return_value="new-tok"):
            result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)

    assert result.exit_code == 0
    assert mock_client.post.call_count == 2
    # Second call should use refreshed token
    second_call = mock_client.post.call_args_list[1]
    assert "Bearer new-tok" in str(second_call)
57
+
58
+
59
def test_scan_401_refresh_fails_fallback_public(tmp_path: Path):
    """On 401 + refresh failure, fall back to /scans/cli with warning.

    The mocked client answers the first POST with 401 and the second with
    200; ``try_refresh_access_token`` is patched to return ``None``. The
    command must succeed after exactly two POSTs, and the second call must
    mention ``/scans/cli``.
    """
    resp_401 = _make_resp(401, {"detail": "unauthorized"})
    resp_ok = _make_resp(
        200,
        {"vram_gb": 16.0, "configs": []},
        url="https://api.example.com/scans/cli",
    )

    mock_client = MagicMock()
    # The mock doubles as its own context manager so it can be used in a `with` block.
    mock_client.__enter__.return_value = mock_client
    mock_client.__exit__.return_value = False
    mock_client.post.side_effect = [resp_401, resp_ok]

    # Credentials file under the fake HOME (presumably where the CLI looks
    # for its stored token — confirm against alloc.config).
    cfg_file = tmp_path / ".alloc" / "config.json"
    cfg_file.parent.mkdir(parents=True)
    cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))

    env = {
        "HOME": str(tmp_path),
        "ALLOC_API_URL": "https://api.example.com",
    }

    # Nested `with` blocks instead of a parenthesized group: the parenthesized
    # form is only official grammar from Python 3.10 and fails to parse on
    # older interpreters.
    with patch("httpx.Client", return_value=mock_client):
        with patch("alloc.cli.try_refresh_access_token", return_value=None):
            result = runner.invoke(app, ["scan", "--model", "llama-3-8b", "--json"], env=env)

    assert result.exit_code == 0
    assert mock_client.post.call_count == 2
    # Second call should hit /scans/cli. The whole call is stringified, so the
    # check holds whether the URL was passed positionally or by keyword.
    second_call_repr = str(mock_client.post.call_args_list[1])
    assert "/scans/cli" in second_call_repr
90
+
91
+
92
def test_scan_401_fallback_warns_about_dropped_features(tmp_path: Path):
    """Fallback to public scan warns user about lost org context.

    Same 401-then-200 sequence with a failing token refresh as the fallback
    test, but the command runs without ``--json`` so the human-readable
    output can be checked for the words "expired" or "falling back".
    """
    resp_401 = _make_resp(401, {"detail": "unauthorized"})
    resp_ok = _make_resp(200, {"vram_gb": 16.0, "configs": []})

    mock_client = MagicMock()
    # The mock doubles as its own context manager so it can be used in a `with` block.
    mock_client.__enter__.return_value = mock_client
    mock_client.__exit__.return_value = False
    mock_client.post.side_effect = [resp_401, resp_ok]

    # Credentials file under the fake HOME (presumably where the CLI looks
    # for its stored token — confirm against alloc.config).
    cfg_file = tmp_path / ".alloc" / "config.json"
    cfg_file.parent.mkdir(parents=True)
    cfg_file.write_text(json.dumps({"token": "old-tok", "refresh_token": "rt"}))

    env = {
        "HOME": str(tmp_path),
        "ALLOC_API_URL": "https://api.example.com",
    }

    # Nested `with` blocks instead of a parenthesized group: the parenthesized
    # form is only official grammar from Python 3.10 and fails to parse on
    # older interpreters.
    with patch("httpx.Client", return_value=mock_client):
        with patch("alloc.cli.try_refresh_access_token", return_value=None):
            # Non-JSON mode to see the warning message
            result = runner.invoke(app, ["scan", "--model", "llama-3-8b"], env=env)

    assert result.exit_code == 0
    output_lower = result.output.lower()
    assert "expired" in output_lower or "falling back" in output_lower
120
+
121
+
122
def test_scan_no_token_uses_public_directly(tmp_path: Path):
    """With no config file under HOME, scan issues a single POST that targets /scans/cli."""
    ok_response = _make_resp(200, {"vram_gb": 16.0, "configs": []})

    fake_client = MagicMock()
    # The mock doubles as its own context manager so it can be used in a `with` block.
    fake_client.__enter__.return_value = fake_client
    fake_client.__exit__.return_value = False
    fake_client.post.return_value = ok_response

    # HOME points at an empty tmp dir: no .alloc/config.json is created here.
    cli_env = {
        "HOME": str(tmp_path),
        "ALLOC_API_URL": "https://api.example.com",
    }
    cli_args = ["scan", "--model", "llama-3-8b", "--json"]

    with patch("httpx.Client", return_value=fake_client):
        outcome = runner.invoke(app, cli_args, env=cli_env)

    assert outcome.exit_code == 0
    assert fake_client.post.call_count == 1
    first_call_repr = str(fake_client.post.call_args_list[0])
    assert "/scans/cli" in first_call_repr
@@ -1,11 +0,0 @@
1
- """Alloc — GPU intelligence for ML training."""
2
-
3
- from __future__ import annotations
4
-
5
- __version__ = "0.0.5"
6
-
7
- from alloc.ghost import ghost, GhostReport
8
- from alloc.callbacks import AllocCallback as HuggingFaceCallback
9
- from alloc.callbacks import AllocLightningCallback as LightningCallback
10
-
11
- __all__ = ["ghost", "GhostReport", "HuggingFaceCallback", "LightningCallback", "__version__"]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes