alloc 0.0.4.tar.gz → 0.0.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {alloc-0.0.4 → alloc-0.0.5}/PKG-INFO +1 -1
  2. {alloc-0.0.4 → alloc-0.0.5}/pyproject.toml +1 -1
  3. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/__init__.py +1 -1
  4. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/browser_auth.py +28 -6
  5. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/callbacks.py +22 -5
  6. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/cli.py +37 -0
  7. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/code_analyzer.py +54 -0
  8. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/config.py +3 -2
  9. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/diagnosis_display.py +75 -28
  10. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/diagnosis_engine.py +2 -2
  11. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/diagnosis_rules.py +5 -3
  12. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/probe.py +60 -6
  13. {alloc-0.0.4 → alloc-0.0.5}/src/alloc.egg-info/PKG-INFO +1 -1
  14. {alloc-0.0.4 → alloc-0.0.5}/tests/test_cli.py +2 -0
  15. {alloc-0.0.4 → alloc-0.0.5}/README.md +0 -0
  16. {alloc-0.0.4 → alloc-0.0.5}/setup.cfg +0 -0
  17. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/artifact_loader.py +0 -0
  18. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/artifact_writer.py +0 -0
  19. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/catalog/__init__.py +0 -0
  20. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/catalog/default_rate_card.json +0 -0
  21. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/catalog/gpus.v1.json +0 -0
  22. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/context.py +0 -0
  23. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/display.py +0 -0
  24. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/extractor_runner.py +0 -0
  25. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/ghost.py +0 -0
  26. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/model_extractor.py +0 -0
  27. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/model_registry.py +0 -0
  28. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/stability.py +0 -0
  29. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/upload.py +0 -0
  30. {alloc-0.0.4 → alloc-0.0.5}/src/alloc/yaml_config.py +0 -0
  31. {alloc-0.0.4 → alloc-0.0.5}/src/alloc.egg-info/SOURCES.txt +0 -0
  32. {alloc-0.0.4 → alloc-0.0.5}/src/alloc.egg-info/dependency_links.txt +0 -0
  33. {alloc-0.0.4 → alloc-0.0.5}/src/alloc.egg-info/entry_points.txt +0 -0
  34. {alloc-0.0.4 → alloc-0.0.5}/src/alloc.egg-info/requires.txt +0 -0
  35. {alloc-0.0.4 → alloc-0.0.5}/src/alloc.egg-info/top_level.txt +0 -0
  36. {alloc-0.0.4 → alloc-0.0.5}/tests/test_artifact.py +0 -0
  37. {alloc-0.0.4 → alloc-0.0.5}/tests/test_artifact_loader.py +0 -0
  38. {alloc-0.0.4 → alloc-0.0.5}/tests/test_auth.py +0 -0
  39. {alloc-0.0.4 → alloc-0.0.5}/tests/test_callbacks.py +0 -0
  40. {alloc-0.0.4 → alloc-0.0.5}/tests/test_catalog.py +0 -0
  41. {alloc-0.0.4 → alloc-0.0.5}/tests/test_code_analyzer.py +0 -0
  42. {alloc-0.0.4 → alloc-0.0.5}/tests/test_context.py +0 -0
  43. {alloc-0.0.4 → alloc-0.0.5}/tests/test_diagnose_cli.py +0 -0
  44. {alloc-0.0.4 → alloc-0.0.5}/tests/test_diagnosis_engine.py +0 -0
  45. {alloc-0.0.4 → alloc-0.0.5}/tests/test_diagnosis_rules.py +0 -0
  46. {alloc-0.0.4 → alloc-0.0.5}/tests/test_extractor_activation.py +0 -0
  47. {alloc-0.0.4 → alloc-0.0.5}/tests/test_ghost.py +0 -0
  48. {alloc-0.0.4 → alloc-0.0.5}/tests/test_init_from_org.py +0 -0
  49. {alloc-0.0.4 → alloc-0.0.5}/tests/test_interconnect.py +0 -0
  50. {alloc-0.0.4 → alloc-0.0.5}/tests/test_model_extractor.py +0 -0
  51. {alloc-0.0.4 → alloc-0.0.5}/tests/test_probe_hw.py +0 -0
  52. {alloc-0.0.4 → alloc-0.0.5}/tests/test_probe_multi.py +0 -0
  53. {alloc-0.0.4 → alloc-0.0.5}/tests/test_stability.py +0 -0
  54. {alloc-0.0.4 → alloc-0.0.5}/tests/test_upload.py +0 -0
  55. {alloc-0.0.4 → alloc-0.0.5}/tests/test_verdict.py +0 -0
  56. {alloc-0.0.4 → alloc-0.0.5}/tests/test_yaml_config.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alloc
-Version: 0.0.4
+Version: 0.0.5
 Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
 Author-email: Alloc Labs <hello@alloclabs.com>
 License-Expression: Apache-2.0
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "alloc"
-version = "0.0.4"
+version = "0.0.5"
 description = "Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints."
 readme = "README.md"
 license = "Apache-2.0"
src/alloc/__init__.py
@@ -2,7 +2,7 @@

 from __future__ import annotations

-__version__ = "0.0.4"
+__version__ = "0.0.5"

 from alloc.ghost import ghost, GhostReport
 from alloc.callbacks import AllocCallback as HuggingFaceCallback
src/alloc/browser_auth.py
@@ -11,6 +11,7 @@ from __future__ import annotations

 import base64
 import hashlib
+import html
 import secrets
 import socket
 import threading

@@ -40,7 +41,8 @@ def _find_open_port(start=17256, attempts=20):
     for port in range(start, start + attempts):
         try:
             with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                s.bind(("127.0.0.1", port))
+                # Bind to all interfaces so both localhost and 127.0.0.1 work
+                s.bind(("0.0.0.0", port))
                 return port
         except OSError:
             continue

@@ -69,7 +71,7 @@ class _CallbackHandler(BaseHTTPRequestHandler):
             self._respond(
                 400,
                 "<html><body style='font-family:system-ui;text-align:center;padding:60px'>"
-                f"<h2>Login failed</h2><p>{error_desc}</p>"
+                f"<h2>Login failed</h2><p>{html.escape(error_desc)}</p>"
                 "</body></html>",
             )
         else:

@@ -108,7 +110,8 @@ def browser_login(
     verifier, challenge = _generate_pkce_pair()
     port = _find_open_port()

-    redirect_uri = f"http://localhost:{port}/callback"
+    # Use 127.0.0.1 (not localhost) — more reliable, avoids IPv6 resolution issues.
+    redirect_uri = f"http://127.0.0.1:{port}/callback"

     authorize_params = urlencode({
         "provider": provider,

@@ -118,7 +121,8 @@ def browser_login(
     })
     authorize_url = f"{supabase_url}/auth/v1/authorize?{authorize_params}"

-    server = HTTPServer(("127.0.0.1", port), _CallbackHandler)
+    # Bind to 0.0.0.0 so both localhost and 127.0.0.1 reach the server.
+    server = HTTPServer(("0.0.0.0", port), _CallbackHandler)
     server.auth_code = None  # type: ignore[attr-defined]
     server.auth_error = None  # type: ignore[attr-defined]
     server.timeout = 1  # poll interval for handle_request()

@@ -128,16 +132,21 @@ def browser_login(
     server_thread.daemon = True
     server_thread.start()

+    import sys
+
     # Open the browser (or print URL as fallback).
     try:
         opened = webbrowser.open(authorize_url)
     except Exception:
         opened = False

+    print(f"\nCallback server listening on http://127.0.0.1:{port}/callback", file=sys.stderr)
     if not opened:
         print(f"\nOpen this URL in your browser to log in:\n\n  {authorize_url}\n")
     else:
-        print("Opened browser for login. Waiting for callback...")
+        print("Opened browser for login. Waiting for callback...", file=sys.stderr)
+        print("If login completes but the terminal stays stuck, your Supabase", file=sys.stderr)
+        print(f"redirect allowlist may not include http://127.0.0.1:{port}/callback", file=sys.stderr)

     server_thread.join(timeout=timeout_seconds + 5)
     server.server_close()

@@ -146,7 +155,20 @@ def browser_login(
         raise RuntimeError(f"OAuth error: {server.auth_error}")

     if not server.auth_code:
-        raise RuntimeError("Login timed out — no callback received within 120 seconds.")
+        raise RuntimeError(
+            f"Login timed out — no callback received within {timeout_seconds} seconds.\n"
+            f"\n"
+            f"  The browser never reached http://127.0.0.1:{port}/callback.\n"
+            f"\n"
+            f"  Common causes:\n"
+            f"    1. Supabase redirect allowlist does not include http://127.0.0.1:{port}/**\n"
+            f"       (Check: Supabase Dashboard → Authentication → URL Configuration → Redirect URLs)\n"
+            f"    2. Browser redirected to your site URL instead of localhost\n"
+            f"    3. Firewall or antivirus blocked the local callback server\n"
+            f"\n"
+            f"  Workaround: alloc login --method token --token <paste-access-token>\n"
+            f"  (Copy token from browser DevTools → Application → Local Storage → sb-*-auth-token)"
+        )

     # Exchange auth code + verifier for tokens.
     with httpx.Client(timeout=15) as client:
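Note: `_generate_pkce_pair()` itself is outside this diff. For readers unfamiliar with the flow, a minimal stdlib sketch of a standard RFC 7636 S256 pair (hypothetical helper name; the package's actual implementation may differ):

    import base64
    import hashlib
    import secrets

    def generate_pkce_pair() -> tuple[str, str]:
        # Verifier: 43+ chars of URL-safe randomness (RFC 7636 section 4.1)
        verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode("ascii")
        # Challenge: BASE64URL(SHA256(verifier)) with padding stripped (section 4.2)
        digest = hashlib.sha256(verifier.encode("ascii")).digest()
        challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
        return verifier, challenge

The challenge travels in the authorize URL; the verifier stays local and is sent only in the final token exchange, so intercepting the callback code alone is not enough to mint tokens.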
src/alloc/callbacks.py
@@ -138,7 +138,8 @@ def _detect_architecture(model, optimizer=None, training_args=None):
         "mistral", "qwen2", "phi", "gemma", "falcon",
         "bert", "roberta", "t5", "bart", "mbart",
         "whisper", "wav2vec2", "vit", "deit", "beit",
-        "swin", "clip", "dinov2"},
+        "swin", "clip", "dinov2", "deepseek",
+        "starcoder2", "cohere", "mamba"},
     "moe": {"mixtral", "switch_transformers"},
     "diffusion": {"unet_2d_condition"},
 }

@@ -389,9 +390,19 @@ class _NvmlMonitor:
                 if 0 <= idx < physical_count:
                     visible_indices.append(idx)
             except ValueError:
-                # UUID-style device identifiers — fall back to physical count
-                visible_indices = list(range(physical_count))
-                break
+                # UUID-style device identifiers — try NVML UUID matching
+                try:
+                    for phys_idx in range(physical_count):
+                        handle = self._pynvml.nvmlDeviceGetHandleByIndex(phys_idx)
+                        uuid = self._pynvml.nvmlDeviceGetUUID(handle)
+                        if isinstance(uuid, bytes):
+                            uuid = uuid.decode("utf-8", errors="replace")
+                        if d in uuid:
+                            visible_indices.append(phys_idx)
+                            break
+                except Exception:
+                    visible_indices = list(range(physical_count))
+                    break
         gpu_indices = visible_indices if visible_indices else list(range(physical_count))
     else:
         gpu_indices = list(range(physical_count))

@@ -687,8 +698,14 @@ def _write_full_artifact(monitor, sidecar_data, step_times_raw=None):
     if sidecar_data.get("is_distributed"):
         probe_dict["is_distributed"] = True
         rank = sidecar_data.get("rank", 0)
+        world_size = sidecar_data.get("world_size", 1)
         probe_dict["rank"] = rank
-        probe_dict["world_size"] = sidecar_data.get("world_size", 1)
+        probe_dict["world_size"] = world_size
+        # Set num_gpus_detected to world_size so the artifact reflects
+        # the full distributed topology, not just the local GPU count.
+        probe_dict["num_gpus_detected"] = max(
+            probe_dict.get("num_gpus_detected", 1), world_size
+        )
         if rank > 0:
            output_path = "alloc_artifact_rank{}.json.gz".format(rank)
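The UUID branch above handles `CUDA_VISIBLE_DEVICES=GPU-...` entries, which `int()` cannot parse. As a standalone illustration of the same matching, assuming pynvml is installed (older pynvml releases return `bytes` from `nvmlDeviceGetUUID`, hence the decode):

    import os
    import pynvml

    def visible_indices_from_uuids() -> list[int]:
        """Map UUID-style CUDA_VISIBLE_DEVICES entries to physical NVML indices."""
        pynvml.nvmlInit()
        try:
            count = pynvml.nvmlDeviceGetCount()
            entries = [e.strip() for e in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if e.strip()]
            matched = []
            for entry in entries:
                for idx in range(count):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
                    uuid = pynvml.nvmlDeviceGetUUID(handle)
                    if isinstance(uuid, bytes):
                        uuid = uuid.decode("utf-8", errors="replace")
                    if entry in uuid:
                        matched.append(idx)
                        break
            return matched or list(range(count))
        finally:
            pynvml.nvmlShutdown()

The substring test (`entry in uuid`) tolerates both the `GPU-` prefix and truncated UUIDs, mirroring the diff's `if d in uuid` check.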
src/alloc/cli.py
@@ -370,6 +370,36 @@ def run(
     callback_data = _read_callback_data()
     step_count = callback_data.get("step_count") if callback_data else None

+    # Auto-merge per-rank callback artifacts for distributed runs.
+    # When DDP callbacks write alloc_artifact_rank{N}.json.gz alongside the
+    # main artifact, merge them to get per-rank peaks and straggler data.
+    try:
+        from alloc.artifact_loader import find_rank_artifacts, merge_artifacts, load_artifact
+        rank_files = find_rank_artifacts(".")
+        if rank_files:
+            # Include rank 0 artifact if it exists
+            main_artifact_path = os.path.join(".", "alloc_artifact.json.gz")
+            all_paths = ([main_artifact_path] if os.path.exists(main_artifact_path) else []) + rank_files
+            if len(all_paths) > 1:
+                merged = merge_artifacts(all_paths)
+                if merged is not None:
+                    # Enrich probe result with merged multi-GPU data
+                    result.num_gpus_detected = max(result.num_gpus_detected, merged.gpu_count or len(all_paths))
+                    if merged.per_rank_peak_vram_mb:
+                        result.per_gpu_peak_vram_mb = merged.per_rank_peak_vram_mb
+                    # Use merged step timing if probe didn't capture it
+                    if callback_data is None:
+                        callback_data = {}
+                    if merged.step_time_p50_ms and not callback_data.get("step_time_ms_p50"):
+                        callback_data["step_time_ms_p50"] = merged.step_time_p50_ms
+                    if merged.step_time_p90_ms and not callback_data.get("step_time_ms_p90"):
+                        callback_data["step_time_ms_p90"] = merged.step_time_p90_ms
+                    if merged.throughput_samples_per_sec and not callback_data.get("samples_per_sec"):
+                        callback_data["samples_per_sec"] = merged.throughput_samples_per_sec
+                    step_count = step_count or callback_data.get("step_count")
+    except Exception:
+        pass  # Never crash on merge failure
+
     # Discover environment context (git, container, Ray)
     from alloc.context import discover_context
     env_context = discover_context()

@@ -2117,6 +2147,13 @@ def login(
     ),
 ):
     """Authenticate with Alloc dashboard."""
+    # Suppress noisy third-party warnings (urllib3 LibreSSL, pynvml deprecation)
+    # that clutter the auth flow output.
+    import warnings
+    warnings.filterwarnings("ignore", category=DeprecationWarning, module="pynvml")
+    warnings.filterwarnings("ignore", message=".*LibreSSL.*", module="urllib3")
+    warnings.filterwarnings("ignore", message=".*pynvml.*", category=FutureWarning)
+
     import httpx
     from alloc.config import get_supabase_url, get_supabase_anon_key, load_config, save_config

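`find_rank_artifacts` and `merge_artifacts` live in `alloc/artifact_loader.py`, which is unchanged in this release and not shown here. A rough sketch of the merge idea under an assumed schema (gzipped JSON with a `peak_vram_mb` field per artifact; the real format may differ):

    import glob
    import gzip
    import json
    import os

    def merge_rank_artifacts(directory: str = ".") -> dict | None:
        """Combine alloc_artifact_rank{N}.json.gz files into one summary (sketch)."""
        paths = sorted(glob.glob(os.path.join(directory, "alloc_artifact_rank*.json.gz")))
        main = os.path.join(directory, "alloc_artifact.json.gz")
        if os.path.exists(main):
            paths.insert(0, main)  # rank 0 writes the unsuffixed artifact
        if len(paths) < 2:
            return None
        per_rank_peaks = []
        for path in paths:
            with gzip.open(path, "rt") as f:
                data = json.load(f)
            per_rank_peaks.append(data.get("peak_vram_mb", 0))
        return {
            "gpu_count": len(paths),
            "per_rank_peak_vram_mb": per_rank_peaks,
            "peak_vram_mb": max(per_rank_peaks),  # the worst rank bounds the fit
        }

Taking the max across ranks matters because rank 0 often carries extra state (logging, checkpoint staging), so the unsuffixed artifact alone can understate the fleet-wide peak.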
src/alloc/code_analyzer.py
@@ -576,6 +576,56 @@ def _find_distributed(
             backend=None,
         ))

+    # Lightning: pl.Trainer(...), Trainer from pytorch_lightning/lightning.pytorch
+    _lightning_prefixes = ("pytorch_lightning", "lightning.pytorch", "lightning")
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.Call):
+            continue
+        fqn = _resolve_call_name(node, imports)
+        if fqn is None:
+            continue
+        # Direct match: pytorch_lightning.Trainer(...) or lightning.pytorch.Trainer(...)
+        if any(fqn.startswith(pfx) for pfx in _lightning_prefixes) and "Trainer" in fqn:
+            if not any(d.kind == "lightning" for d in results):
+                results.append(DistributedFinding(
+                    location=_loc(script_path, node, lines),
+                    kind="lightning",
+                    backend=None,
+                ))
+            break
+        # Import-resolved: Trainer imported from lightning
+        if fqn == "Trainer" or fqn.endswith(".Trainer"):
+            src = imports.get("Trainer", "")
+            if any(src.startswith(pfx) for pfx in _lightning_prefixes):
+                if not any(d.kind == "lightning" for d in results):
+                    results.append(DistributedFinding(
+                        location=_loc(script_path, node, lines),
+                        kind="lightning",
+                        backend=None,
+                    ))
+                break
+
+    # LightningModule subclass detection
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.ClassDef):
+            continue
+        for base in node.bases:
+            base_name = None
+            if isinstance(base, ast.Name):
+                base_name = base.id
+            elif isinstance(base, ast.Attribute):
+                base_name = base.attr
+            if base_name == "LightningModule":
+                src = imports.get("LightningModule", "")
+                if not src or any(src.startswith(pfx) for pfx in _lightning_prefixes):
+                    if not any(d.kind == "lightning" for d in results):
+                        results.append(DistributedFinding(
+                            location=_loc(script_path, node, lines),
+                            kind="lightning",
+                            backend=None,
+                        ))
+                break
+
     return results

@@ -973,6 +1023,10 @@ def _merge_imported_findings(
     for opt in _find_optimizers(tree, imports, lines, imported_path):
         main_findings.optimizers.append(opt)

+    # Merge fine-tuning findings
+    for ft in _find_fine_tuning(tree, imports, lines, imported_path):
+        main_findings.fine_tuning.append(ft)
+
     # Extract TrainingArguments from imported file
     sub_findings = CodeFindings(script_path=imported_path)
     sub_findings.imports = imports
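The Lightning detection above leans on the analyzer's own `_resolve_call_name` and `DistributedFinding` types. A self-contained approximation of the core idea, handling only `from X import Trainer` and `import X as pl` forms:

    import ast

    LIGHTNING_PREFIXES = ("pytorch_lightning", "lightning.pytorch", "lightning")

    def uses_lightning_trainer(source: str) -> bool:
        """Detect Trainer(...) calls bound to a Lightning import (simplified sketch)."""
        tree = ast.parse(source)
        imports: dict[str, str] = {}
        for node in ast.walk(tree):
            if isinstance(node, ast.ImportFrom) and node.module:
                for alias in node.names:
                    imports[alias.asname or alias.name] = f"{node.module}.{alias.name}"
            elif isinstance(node, ast.Import):
                for alias in node.names:
                    imports[alias.asname or alias.name] = alias.name
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            fn = node.func
            if isinstance(fn, ast.Name):
                fqn = imports.get(fn.id, fn.id)
            elif isinstance(fn, ast.Attribute) and isinstance(fn.value, ast.Name):
                fqn = imports.get(fn.value.id, fn.value.id) + "." + fn.attr
            else:
                continue
            if fqn.endswith(".Trainer") and fqn.startswith(LIGHTNING_PREFIXES):
                return True
        return False

    # uses_lightning_trainer("import lightning.pytorch as pl\npl.Trainer()")  -> True
    # uses_lightning_trainer("from pytorch_lightning import Trainer\nTrainer()")  -> True

The real analyzer also deduplicates findings and records source locations, which this sketch omits.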
src/alloc/config.py
@@ -40,8 +40,9 @@ def save_config(data: dict) -> None:
         cfg_file = _config_file()
         cfg_file.write_text(json.dumps(data, indent=2) + "\n")
         os.chmod(cfg_file, 0o600)
-    except Exception:
-        pass
+    except Exception as e:
+        import sys
+        print(f"Warning: could not secure config file permissions: {e}", file=sys.stderr)


 def get_token() -> str:
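A side note on the write-then-chmod pattern above: creating the file with restrictive permissions up front closes the brief window where a freshly created config is readable under the default umask. A sketch of that variant (not what the package currently does):

    import json
    import os

    def write_config_securely(path: str, data: dict) -> None:
        # 0o600 is applied at creation time; if the file already exists,
        # O_TRUNC keeps its existing mode, so there is no chmod race window.
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(json.dumps(data, indent=2) + "\n")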
src/alloc/diagnosis_display.py
@@ -443,23 +443,46 @@ def print_diagnose_efficiency(result: DiagnoseResult) -> None:
     console.print(f" Step time (p50): {p50:.1f} ms")
     console.print()

-    # Visual bar
-    compute_w = int(eff["compute_pct"] / 100 * 48)
-    data_w = int(eff["data_loading_pct"] / 100 * 48)
-    other_w = max(0, 48 - compute_w - data_w)
-
-    bar = (
-        "[green]" + "█" * compute_w + "[/green]"
-        + "[yellow]" + "█" * data_w + "[/yellow]"
-        + "[dim]" + "░" * other_w + "[/dim]"
-    )
-    console.print(f" {bar}")
-    label = f" [green]Compute: {eff['compute_pct']:.0f}%[/green]"
-    if eff["data_loading_pct"] > 0:
-        label += f" [yellow]Data: {eff['data_loading_pct']:.0f}%[/yellow]"
-    if eff["other_pct"] > 0:
-        label += f" [dim]Other: {eff['other_pct']:.0f}%[/dim]"
-    console.print(label)
+    # Visual bar — layout depends on source (cuda_events vs wall_clock)
+    is_cuda = eff.get("source") == "cuda_events"
+
+    if is_cuda:
+        fwd_w = int(eff["forward_pct"] / 100 * 48)
+        bwd_w = int(eff["backward_pct"] / 100 * 48)
+        opt_w = int(eff["optimizer_pct"] / 100 * 48)
+        dl_w = max(0, 48 - fwd_w - bwd_w - opt_w)
+
+        bar = (
+            "[green]" + "█" * fwd_w + "[/green]"
+            + "[cyan]" + "█" * bwd_w + "[/cyan]"
+            + "[magenta]" + "█" * opt_w + "[/magenta]"
+            + "[yellow]" + "█" * dl_w + "[/yellow]"
+        )
+        console.print(f" {bar}")
+        label = f" [green]Forward: {eff['forward_pct']:.0f}%[/green]"
+        label += f" [cyan]Backward: {eff['backward_pct']:.0f}%[/cyan]"
+        if eff["optimizer_pct"] > 0:
+            label += f" [magenta]Optimizer: {eff['optimizer_pct']:.0f}%[/magenta]"
+        if eff["data_loading_pct"] > 0:
+            label += f" [yellow]Data: {eff['data_loading_pct']:.0f}%[/yellow]"
+        console.print(label)
+    else:
+        compute_w = int(eff["compute_pct"] / 100 * 48)
+        data_w = int(eff["data_loading_pct"] / 100 * 48)
+        other_w = max(0, 48 - compute_w - data_w)
+
+        bar = (
+            "[green]" + "█" * compute_w + "[/green]"
+            + "[yellow]" + "█" * data_w + "[/yellow]"
+            + "[dim]" + "░" * other_w + "[/dim]"
+        )
+        console.print(f" {bar}")
+        label = f" [green]Compute: {eff['compute_pct']:.0f}%[/green]"
+        if eff["data_loading_pct"] > 0:
+            label += f" [yellow]Data: {eff['data_loading_pct']:.0f}%[/yellow]"
+        if eff["other_pct"] > 0:
+            label += f" [dim]Other: {eff['other_pct']:.0f}%[/dim]"
+        console.print(label)
     console.print()

     # Component table

@@ -468,14 +491,28 @@ def print_diagnose_efficiency(result: DiagnoseResult) -> None:
     table.add_column("Time (est.)", justify="right", style="bold")
     table.add_column("Notes", style="dim")

-    table.add_row("GPU compute", f"{eff['compute_ms']:.1f} ms", f"{eff['compute_pct']:.0f}% of step")
-    if eff["data_loading_pct"] > 0:
-        dl_note = f"{eff['data_loading_pct']:.0f}%"
-        if eff["data_loading_pct"] > 20:
-            dl_note += " — bottleneck candidate"
-        table.add_row("Data loading", f"{eff['data_loading_ms']:.1f} ms", dl_note)
-    if eff["other_pct"] > 0:
-        table.add_row("Other/overhead", f"{eff['other_ms']:.1f} ms", f"{eff['other_pct']:.0f}%")
+    if is_cuda:
+        table.add_row("Forward", f"{eff['forward_ms']:.1f} ms", f"{eff['forward_pct']:.0f}% of step")
+        table.add_row("Backward", f"{eff['backward_ms']:.1f} ms", f"{eff['backward_pct']:.0f}% of step")
+        if eff["optimizer_pct"] > 0:
+            opt_note = f"{eff['optimizer_pct']:.0f}%"
+            if eff["optimizer_pct"] > 30:
+                opt_note += " — bottleneck candidate"
+            table.add_row("Optimizer", f"{eff['optimizer_ms']:.1f} ms", opt_note)
+        if eff["data_loading_pct"] > 0:
+            dl_note = f"{eff['data_loading_pct']:.0f}%"
+            if eff["data_loading_pct"] > 30:
+                dl_note += " — bottleneck candidate"
+            table.add_row("Data loading", f"{eff['data_loading_ms']:.1f} ms", dl_note)
+    else:
+        table.add_row("GPU compute", f"{eff['compute_ms']:.1f} ms", f"{eff['compute_pct']:.0f}% of step")
+        if eff["data_loading_pct"] > 0:
+            dl_note = f"{eff['data_loading_pct']:.0f}%"
+            if eff["data_loading_pct"] > 20:
+                dl_note += " — bottleneck candidate"
+            table.add_row("Data loading", f"{eff['data_loading_ms']:.1f} ms", dl_note)
+        if eff["other_pct"] > 0:
+            table.add_row("Other/overhead", f"{eff['other_ms']:.1f} ms", f"{eff['other_pct']:.0f}%")

     console.print(table)
     console.print()

@@ -500,9 +537,19 @@ def _print_efficiency_plain(result: DiagnoseResult) -> None:
     print(f"\n Efficiency breakdown (estimated)")
     print(f" Step time (p50): {eff['step_time_p50_ms']:.1f} ms")
     print(f" {'─' * 40}")
-    print(f" GPU compute:  {eff['compute_ms']:>8.1f} ms ({eff['compute_pct']:.0f}%)")
-    if eff["data_loading_pct"] > 0:
-        print(f" Data loading: {eff['data_loading_ms']:>8.1f} ms ({eff['data_loading_pct']:.0f}%)")
+
+    if eff.get("source") == "cuda_events":
+        print(f" Forward:      {eff['forward_ms']:>8.1f} ms ({eff['forward_pct']:.0f}%)")
+        print(f" Backward:     {eff['backward_ms']:>8.1f} ms ({eff['backward_pct']:.0f}%)")
+        if eff["optimizer_pct"] > 0:
+            print(f" Optimizer:    {eff['optimizer_ms']:>8.1f} ms ({eff['optimizer_pct']:.0f}%)")
+        if eff["data_loading_pct"] > 0:
+            print(f" Data loading: {eff['data_loading_ms']:>8.1f} ms ({eff['data_loading_pct']:.0f}%)")
+    else:
+        print(f" GPU compute:  {eff['compute_ms']:>8.1f} ms ({eff['compute_pct']:.0f}%)")
+        if eff["data_loading_pct"] > 0:
+            print(f" Data loading: {eff['data_loading_ms']:>8.1f} ms ({eff['data_loading_pct']:.0f}%)")
+
     bn = eff.get("bottleneck")
     if bn:
         print(f"\n Bottleneck: {bn}")
src/alloc/diagnosis_engine.py
@@ -185,8 +185,8 @@ def _build_comparison(current: ArtifactData, previous: ArtifactData) -> Dict:
     })

     # Peak VRAM
-    cur_peak = max(current.per_gpu_vram_used_mb) if current.per_gpu_vram_used_mb else current.peak_vram_mb
-    prev_peak = max(previous.per_gpu_vram_used_mb) if previous.per_gpu_vram_used_mb else previous.peak_vram_mb
+    cur_peak = max(current.per_gpu_vram_used_mb) if current.per_gpu_vram_used_mb and len(current.per_gpu_vram_used_mb) > 0 else current.peak_vram_mb
+    prev_peak = max(previous.per_gpu_vram_used_mb) if previous.per_gpu_vram_used_mb and len(previous.per_gpu_vram_used_mb) > 0 else previous.peak_vram_mb
     _add("Peak VRAM", cur_peak, prev_peak, "MB", higher_is_worse=True)

     # Step time
src/alloc/diagnosis_rules.py
@@ -289,7 +289,9 @@ def rule_dl005_main_thread(
     """
     results = []
     gpu_count = (hw or {}).get("gpu_count", 1) or 1
-    recommended = max(4, gpu_count * 2)
+    cpu_cores = os.cpu_count() or 4
+    per_gpu_cores = max(1, cpu_cores // max(gpu_count, 1))
+    recommended = max(4, min(gpu_count * 2, per_gpu_cores))

     for dl in findings.dataloaders:
         if dl.num_workers != 0:

@@ -428,7 +430,7 @@ def rule_mem005_no_torch_compile(
     return [Diagnosis(
         rule_id="MEM005",
         severity="info",
-        category="throughput",
+        category="memory",
         title="torch.compile not used",
         file_path=findings.script_path,
         line_number=0,

@@ -446,7 +448,7 @@ def rule_mem005_no_torch_compile(
     return [Diagnosis(
         rule_id="MEM005",
         severity="info",
-        category="throughput",
+        category="memory",
         title="torch.compile not used",
         file_path=findings.script_path,
         line_number=0,
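The new DL005 heuristic keeps the classic "2 workers per GPU" rule but caps it by the cores each GPU can realistically claim. Restated from the diff as a standalone function (hypothetical name):

    import os

    def recommended_num_workers(gpu_count: int) -> int:
        """2 workers per GPU, capped by cores available per GPU, with a floor of 4."""
        cpu_cores = os.cpu_count() or 4
        per_gpu_cores = max(1, cpu_cores // max(gpu_count, 1))
        return max(4, min(gpu_count * 2, per_gpu_cores))

    # 64 cores, 8 GPUs: min(16, 8) = 8  -> 8 workers per DataLoader
    # 4 cores, 1 GPU:   min(2, 4)  = 2  -> floor of 4 wins

Without the cap, an 8-GPU node with 32 cores would be told to spawn 16 workers per process (128 total), oversubscribing the CPUs it was meant to feed.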
src/alloc/probe.py
@@ -8,6 +8,7 @@ Graceful no-op if pynvml is not installed or no GPU is available.

 from __future__ import annotations

+import os
 import signal
 import subprocess
 import sys

@@ -112,16 +113,38 @@ def _discover_gpu_indices(proc_pid, pynvml, fallback_index=0):
     except Exception:
         return [fallback_index]

-    # Collect target PIDs: the main process + its children
+    # Respect CUDA_VISIBLE_DEVICES — only search visible GPUs
+    cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
+    if cvd:
+        visible_physical = []
+        for d in cvd.split(","):
+            d = d.strip()
+            if d:
+                try:
+                    idx = int(d)
+                    if 0 <= idx < device_count:
+                        visible_physical.append(idx)
+                except ValueError:
+                    visible_physical = list(range(device_count))
+                    break
+        search_indices = visible_physical if visible_physical else list(range(device_count))
+    else:
+        search_indices = list(range(device_count))
+
+    # Collect target PIDs: the main process + descendants (3 levels deep).
+    # torchrun uses: torchrun → elastic_agent → worker processes,
+    # so we need at least 3 levels to find DDP worker GPUs.
     target_pids = {proc_pid}
     for child in _get_child_pids(proc_pid):
         target_pids.add(child)
-        # Also check grandchildren (common with torchrun/accelerate)
         for grandchild in _get_child_pids(child):
             target_pids.add(grandchild)
+            # Great-grandchildren: covers torchrun elastic launch wrapper
+            for ggchild in _get_child_pids(grandchild):
+                target_pids.add(ggchild)

     found_indices = []
-    for idx in range(device_count):
+    for idx in search_indices:
         try:
             handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
             procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)

@@ -284,10 +307,27 @@ def probe_command(

     handles = [handle]
     discovery_done = False
+    discovery_attempts = 0
+    max_discovery_attempts = 3  # Retry at samples 5, 15, 30
+
+    # Determine expected GPU count from environment for retry logic
+    expected_gpus = 1
+    ws = os.environ.get("WORLD_SIZE", "").strip()
+    if ws:
+        try:
+            expected_gpus = max(1, int(ws))
+        except ValueError:
+            pass

     while not stop_event.is_set():
-        # After 5 samples, try to discover all GPUs used by the process
-        if not discovery_done and len(samples) >= 5 and proc.pid:
+        # Retry GPU discovery at samples 5, 15, 30.
+        # Keep retrying if we haven't found all expected GPUs yet
+        discovery_thresholds = [5, 15, 30]
+        if (not discovery_done
+                and discovery_attempts < max_discovery_attempts
+                and len(samples) >= discovery_thresholds[discovery_attempts]
+                and proc.pid):
+            discovery_attempts += 1
             try:
                 discovered = _discover_gpu_indices(proc.pid, pynvml, fallback_index=gpu_index)
                 if len(discovered) > 1:

@@ -303,7 +343,9 @@ def probe_command(
                 pass
             # Detect interconnect type between discovered GPUs
             detected_ic_ref[0] = _detect_interconnect(handles, pynvml)
-            discovery_done = True
+            # Stop retrying if we found expected count or exhausted attempts
+            if num_gpus_ref[0] >= expected_gpus or discovery_attempts >= max_discovery_attempts:
+                discovery_done = True

         # Sample from all monitored GPUs — aggregate: peak vram = max, util/power = mean
         try:

@@ -414,6 +456,18 @@ def probe_command(
     if calibration_time_ref[0] is not None:
         cal_duration = round(calibration_time_ref[0] - start_time, 2)

+    # Environment-based fallback: if NVML discovery found fewer GPUs than
+    # WORLD_SIZE indicates, trust the environment. The probe may miss GPUs
+    # due to DDP per-rank CVD isolation or timing races.
+    env_world = os.environ.get("WORLD_SIZE", "").strip()
+    if env_world:
+        try:
+            ws = int(env_world)
+            if ws > num_gpus_ref[0]:
+                num_gpus_ref[0] = ws
+        except ValueError:
+            pass
+
     if not samples:
         return ProbeResult(
             duration_seconds=round(duration, 2),
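The fixed-depth child/grandchild/great-grandchild walk matches the torchrun → elastic_agent → worker chain. If a psutil dependency were acceptable, the whole subtree comes in one call (a sketch; `_get_child_pids` itself is not shown in this diff and presumably walks /proc by hand):

    import psutil  # third-party; pip install psutil

    def descendant_pids(root_pid: int) -> set[int]:
        """All PIDs under root_pid, any depth — covers launcher/agent/worker chains."""
        try:
            root = psutil.Process(root_pid)
            return {root_pid} | {p.pid for p in root.children(recursive=True)}
        except psutil.NoSuchProcess:
            return {root_pid}

The fixed-depth approach trades completeness for zero dependencies; three levels is enough for torchrun and accelerate, but a launcher that adds another wrapper layer would slip past it.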
src/alloc.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alloc
-Version: 0.0.4
+Version: 0.0.5
 Summary: Engineer-first training calibration: estimate VRAM fit, profile short runs, and pick GPU configs under real budget constraints.
 Author-email: Alloc Labs <hello@alloclabs.com>
 License-Expression: Apache-2.0
tests/test_cli.py
@@ -234,6 +234,7 @@ def test_status_json_no_artifact(tmp_path, monkeypatch):
     import json
     monkeypatch.chdir(tmp_path)
     monkeypatch.delenv("ALLOC_TOKEN", raising=False)
+    monkeypatch.setenv("HOME", str(tmp_path))  # isolate from real ~/.alloc/config.json
     result = runner.invoke(app, ["status", "--json"])
     assert result.exit_code == 0
     data = json.loads(result.output.strip())

@@ -325,6 +326,7 @@ def test_status_not_logged_in(tmp_path, monkeypatch):
     """alloc status without token shows not-logged-in state."""
     monkeypatch.chdir(tmp_path)
     monkeypatch.delenv("ALLOC_TOKEN", raising=False)
+    monkeypatch.setenv("HOME", str(tmp_path))  # isolate from real ~/.alloc/config.json
     result = runner.invoke(app, ["status"])
     assert result.exit_code == 0
     out = _plain(result.output)
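Overriding HOME isolates any lookup that resolves through `os.path.expanduser` or `Path.home()` (on Windows, USERPROFILE takes precedence, so CI there may need both). If more tests need the same isolation, an autouse fixture would keep it in one place (a sketch, not part of this diff):

    import pytest

    @pytest.fixture(autouse=True)
    def isolated_home(tmp_path, monkeypatch):
        """Point HOME at a temp dir so the real ~/.alloc/config.json is never read."""
        monkeypatch.setenv("HOME", str(tmp_path))
        monkeypatch.delenv("ALLOC_TOKEN", raising=False)
        yield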