alloc 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alloc/cli.py ADDED
@@ -0,0 +1,1341 @@
1
+ """Alloc CLI — GPU intelligence for ML training.
2
+
3
+ Commands:
4
+ alloc ghost <script.py> Ghost scan — static VRAM analysis without executing the model
5
+ alloc run <command...> Wrap training with probe monitoring, write artifact
6
+ alloc scan --model <name> Remote ghost scan via API — no GPU needed
7
+ alloc login Authenticate with Alloc dashboard
8
+ alloc upload <artifact> Upload an artifact to the Alloc dashboard
9
+ alloc version Show version
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import sys
16
+ from typing import Optional
17
+
18
+ import typer
19
+ from rich.console import Console
20
+
21
+ import json as json_mod
22
+
23
+ from alloc import __version__
24
+ from alloc.config import get_api_url, get_token, should_upload, try_refresh_access_token
25
+
26
# Typer application: `no_args_is_help` shows usage when invoked bare, and
# `add_completion=False` keeps shell-completion installers out of the command list.
app = typer.Typer(
    name="alloc",
    help="GPU intelligence for ML training. Right-size before you launch.",
    no_args_is_help=True,
    add_completion=False,
)
# Shared Rich console used by every command for styled terminal output.
console = Console()
33
+
34
+
35
@app.command()
def ghost(
    script: str = typer.Argument(..., help="Python script to analyze (e.g. train.py)"),
    dtype: str = typer.Option("fp16", help="Data type: fp16, bf16, fp32"),
    batch_size: int = typer.Option(32, help="Training batch size"),
    seq_length: int = typer.Option(2048, help="Sequence length"),
    hidden_dim: int = typer.Option(4096, help="Hidden dimension"),
    json_output: bool = typer.Option(False, "--json", help="Output machine-readable JSON"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed VRAM formula breakdown"),
    param_count_b: Optional[float] = typer.Option(None, "--param-count-b", "-p", help="Param count in billions (skip script analysis)"),
    timeout: int = typer.Option(60, "--timeout", help="Max seconds for model extraction"),
    no_config: bool = typer.Option(False, "--no-config", help="Skip .alloc.yaml (use catalog defaults)"),
):
    """Ghost scan — static VRAM analysis without executing the model."""
    from alloc.ghost import ghost as ghost_fn
    from alloc.display import print_ghost_report
    from alloc.model_extractor import extract_model_info

    # GPU context comes from .alloc.yaml unless --no-config was passed.
    gpu_context = _load_gpu_context(no_config)

    info = extract_model_info(script, timeout=timeout, param_count_b=param_count_b)
    if info is None:
        # Extraction failed — report in the requested format and exit non-zero.
        if json_output:
            _print_json({"error": f"Could not extract model from {script}"})
        else:
            console.print(f"[yellow]Could not extract model from {script}.[/yellow]")
            console.print("[dim]Supported: PyTorch nn.Module, HuggingFace AutoModel, Lightning modules.[/dim]")
            console.print(f"[dim]Tip: alloc ghost {script} --param-count-b 7.0[/dim]")
        raise typer.Exit(1)

    # Prefer the dtype actually observed during execution over the CLI flag.
    resolved_dtype = dtype
    if info.method == "execution":
        resolved_dtype = info.dtype

    report = ghost_fn(
        param_count=info.param_count,
        dtype=resolved_dtype,
        batch_size=batch_size,
        seq_length=seq_length,
        hidden_dim=hidden_dim,
    )
    report.extraction_method = info.method

    if json_output:
        payload = report.to_dict()
        if gpu_context:
            payload["gpu_context"] = gpu_context
        _print_json(payload)
        return

    # Human-readable output: base report, then either the verbose breakdown
    # (with full GPU-context detail) or a one-line context summary.
    print_ghost_report(report)
    if verbose:
        from alloc.display import print_verbose_ghost
        print_verbose_ghost(report)
        if gpu_context:
            _print_gpu_context_detail(gpu_context)
    elif gpu_context:
        _print_gpu_context_summary(gpu_context)
93
+
94
+
95
@app.command()
def run(
    command: list[str] = typer.Argument(..., help="Command to run (e.g. python train.py)"),
    timeout: int = typer.Option(120, help="Max calibration time in seconds"),
    gpu: int = typer.Option(0, help="GPU index to monitor"),
    save: bool = typer.Option(True, help="Save artifact to disk"),
    out: Optional[str] = typer.Option(None, "--out", help="Output path for artifact"),
    upload: bool = typer.Option(False, "--upload", help="Upload artifact to Alloc dashboard after run"),
    full: bool = typer.Option(False, "--full", help="Monitor full training run instead of calibrating"),
    json_output: bool = typer.Option(False, "--json", help="Output machine-readable JSON"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show hardware context, sample dump, recommendation reasoning"),
    no_config: bool = typer.Option(False, "--no-config", help="Show hardware context, sample dump, recommendation reasoning" if False else "Skip .alloc.yaml (use catalog defaults)"),
):
    """Run a training command with GPU monitoring.

    Flow: resolve mode (calibrate vs full) → launch the command under the
    probe → merge framework-callback sidecar data → optionally write an
    artifact → print a verdict (or JSON) → optionally upload. Re-raises the
    child process's non-zero exit code as this command's exit code.
    """
    from alloc.probe import probe_command
    # NOTE(review): print_probe_result is imported but never used below.
    from alloc.display import print_probe_result, print_verdict
    from alloc.artifact_writer import write_report

    # Load GPU context from .alloc.yaml (skipped with --no-config).
    gpu_context = _load_gpu_context(no_config)

    if not command:
        console.print("[red]No command provided.[/red]")
        console.print("Usage: alloc run python train.py")
        raise typer.Exit(1)

    # ALLOC_POLICY: "warn" or "enforce" forces full monitoring; any other
    # non-empty value is warned about and ignored.
    alloc_policy = os.environ.get("ALLOC_POLICY", "").lower().strip()
    if alloc_policy and alloc_policy not in ("warn", "enforce"):
        console.print(f"[yellow]Unknown ALLOC_POLICY='{alloc_policy}', ignoring.[/yellow]")
        alloc_policy = ""
    if alloc_policy:
        full = True

    # Determine mode: calibrate is simply the complement of full monitoring.
    calibrate = not full

    # Effective timeout: unlimited (0) for --full unless user explicitly set it.
    effective_timeout = timeout
    if full and timeout == 120:  # 120 is the default, so the user didn't override it
        effective_timeout = 0  # Unlimited for full mode

    # Mode label for the banner.
    if calibrate:
        mode_label = "Calibrate"
    elif full:
        mode_label = "Full monitoring"
    else:
        # NOTE(review): unreachable — calibrate == (not full), so one of the
        # two branches above always fires.
        mode_label = "Calibrate"

    if not json_output:
        console.print(f"[green]alloc[/green] [dim]v{__version__}[/dim] — {mode_label}")
        console.print(f"[dim]Command: {' '.join(command)}[/dim]")
        if calibrate:
            console.print(f"[dim]Auto-stop when metrics stabilize (timeout: {timeout}s)[/dim]")
        console.print()

    # Launch the child command under GPU monitoring.
    result = probe_command(
        command,
        timeout_seconds=effective_timeout,
        gpu_index=gpu,
        calibrate=calibrate,
    )

    if not json_output:
        # A pynvml error means monitoring was unavailable but the command ran;
        # anything else is reported as a hard error.
        if result.error and "pynvml" in result.error:
            console.print(f"[yellow]{result.error}[/yellow]")
            console.print("[dim]Process ran without GPU monitoring.[/dim]")
        elif result.error:
            console.print(f"[red]Error: {result.error}[/red]")

    # Read callback data from the sidecar file (written by framework callbacks).
    callback_data = _read_callback_data()
    step_count = callback_data.get("step_count") if callback_data else None

    # Discover environment context (git, container, Ray) and parallel topology.
    from alloc.context import discover_context
    env_context = discover_context()
    topology = _infer_parallel_topology_from_env(
        num_gpus_detected=result.num_gpus_detected,
        config_interconnect=gpu_context.get("interconnect") if gpu_context else None,
        detected_interconnect=result.detected_interconnect,
    )
    # ALLOC_OBJECTIVE env var takes precedence over the .alloc.yaml objective.
    objective = os.environ.get("ALLOC_OBJECTIVE", "").strip().lower() or _objective_from_context(gpu_context)
    max_budget_hourly = _max_budget_hourly_from_context(gpu_context)

    # Build and persist the artifact unless --no-save was requested.
    artifact_path = ""
    if save:
        probe_dict = {
            "peak_vram_mb": result.peak_vram_mb,
            "avg_gpu_util": result.avg_gpu_util,
            "avg_power_watts": result.avg_power_watts,
            "duration_seconds": result.duration_seconds,
            "samples": result.samples,
            "exit_code": result.exit_code,
            "probe_mode": result.probe_mode,
            "steps_profiled": result.steps_profiled,
            "stop_reason": result.stop_reason,
            "gpu_name": result.gpu_name,
            "gpu_total_vram_mb": result.gpu_total_vram_mb,
            "calibration_duration_s": result.calibration_duration_s,
            "step_count": step_count,
            "num_nodes": topology.get("num_nodes"),
            "gpus_per_node": topology.get("gpus_per_node"),
            "tp_degree": topology.get("tp_degree"),
            "pp_degree": topology.get("pp_degree"),
            "dp_degree": topology.get("dp_degree"),
            "interconnect_type": topology.get("interconnect_type"),
            "objective": objective,
            "max_budget_hourly": max_budget_hourly,
        }
        # Merge timing fields from the callback sidecar (only present keys).
        if callback_data:
            for key in ("step_time_ms_p50", "step_time_ms_p90", "samples_per_sec",
                        "dataloader_wait_pct", "comm_overhead_pct"):
                val = callback_data.get(key)
                if val is not None:
                    probe_dict[key] = val
        # Merge per-GPU peak VRAM from the probe (maps to per_rank_peak_vram_mb).
        if result.per_gpu_peak_vram_mb:
            probe_dict["per_rank_peak_vram_mb"] = result.per_gpu_peak_vram_mb
        hw_context = {
            "gpu_name": result.gpu_name,
            "gpu_total_vram_mb": result.gpu_total_vram_mb,
            "driver_version": result.driver_version,
            "cuda_version": result.cuda_version,
            "sm_version": result.sm_version,
            "num_gpus_detected": result.num_gpus_detected,
        }
        artifact_path = write_report(
            probe_result=probe_dict,
            output_path=out,
            hardware_context=hw_context,
            context=env_context if env_context else None,
        )

    # Build budget context for verdict display.
    budget_ctx = _build_budget_context(gpu_context, result)

    if json_output:
        from alloc.display import build_verdict_dict
        data = build_verdict_dict(result, artifact_path=artifact_path, step_count=step_count, callback_data=callback_data, budget_context=budget_ctx)
        if result.error:
            data["error"] = result.error
        _print_json(data)
    else:
        # Display verdict for all modes when we have GPU data; otherwise just
        # confirm where the artifact landed.
        if result.peak_vram_mb > 0:
            print_verdict(result, artifact_path=artifact_path, step_count=step_count, callback_data=callback_data, budget_context=budget_ctx)
            if verbose:
                from alloc.display import print_verbose_run
                print_verbose_run(result, step_count=step_count)
        elif artifact_path:
            console.print(f"[dim]Artifact saved: {artifact_path}[/dim]")

    # Next action hint — deliberate no-op: print_verdict already shows
    # "Next: alloc upload ...".
    if artifact_path and not (upload or should_upload()):
        pass

    # Upload if --upload flag or the ALLOC_UPLOAD env var requests it.
    if artifact_path and (upload or should_upload()):
        _try_upload(artifact_path)

    # Propagate the child's failure exit code to the shell.
    if result.exit_code and result.exit_code != 0:
        raise typer.Exit(result.exit_code)
261
+
262
+
263
@app.command()
def scan(
    model: str = typer.Option(..., "--model", "-m", help="Model name (e.g. llama-3-70b)"),
    gpu: str = typer.Option("A100-80GB", "--gpu", "-g", help="Target GPU type"),
    dtype: str = typer.Option("fp16", help="Data type: fp16, bf16, fp32"),
    strategy: str = typer.Option(
        "ddp",
        help=(
            "Strategy: ddp, fsdp, tp, pp, tp+dp, pp+dp, tp+pp+dp, tp+pp+fsdp"
        ),
    ),
    num_gpus: int = typer.Option(1, help="Number of GPUs"),
    objective: str = typer.Option(
        "best_value",
        help="Objective: cheapest, fastest, fastest_within_budget, best_value",
    ),
    max_budget_hourly: Optional[float] = typer.Option(
        None,
        "--max-budget-hourly",
        help="Optional max $/hr budget for this planning run",
    ),
    num_nodes: Optional[int] = typer.Option(None, "--num-nodes", help="Cluster node count"),
    gpus_per_node: Optional[int] = typer.Option(None, "--gpus-per-node", help="GPUs per node"),
    tp_degree: Optional[int] = typer.Option(None, "--tp-degree", help="Tensor parallel degree"),
    pp_degree: Optional[int] = typer.Option(None, "--pp-degree", help="Pipeline parallel degree"),
    dp_degree: Optional[int] = typer.Option(None, "--dp-degree", help="Data parallel degree"),
    interconnect: str = typer.Option(
        "unknown",
        "--interconnect",
        help="Interconnect: pcie, nvlink, infiniband, unknown",
    ),
    param_count_b: Optional[float] = typer.Option(None, "--param-count-b", "-p", help="Param count in billions (overrides model lookup)"),
    batch_size: int = typer.Option(32, help="Batch size"),
    seq_length: int = typer.Option(2048, help="Sequence length"),
    hidden_dim: int = typer.Option(4096, help="Hidden dimension"),
    json_output: bool = typer.Option(False, "--json", help="Output machine-readable JSON"),
):
    """Remote ghost scan via Alloc API — no GPU needed.

    Resolves a parameter count (explicit flag wins over the model-name
    lookup), POSTs the scan payload to the API, and prints the result.
    Unauthenticated users hit the public /scans/cli endpoint instead of
    /scans. Exits 1 on unknown model, API error, or connection failure.
    """
    import httpx

    # Resolve param count from the explicit flag, falling back to model-name lookup.
    resolved_param_count = param_count_b or _model_to_params(model)
    if resolved_param_count is None:
        if json_output:
            _print_json({"error": f"Unknown model: {model}"})
        else:
            console.print(f"[yellow]Unknown model: {model}[/yellow]")
            console.print("[dim]Use --param-count-b to specify directly.[/dim]")
        raise typer.Exit(1)

    api_url = get_api_url()
    token = get_token()

    payload = {
        "entrypoint": f"{model}.py",
        "param_count_b": resolved_param_count,
        "dtype": dtype,
        "strategy": strategy,
        "gpu_type": gpu,
        "num_gpus": num_gpus,
        "objective": objective,
        "max_budget_hourly": max_budget_hourly,
        "num_nodes": num_nodes,
        "gpus_per_node": gpus_per_node,
        "tp_degree": tp_degree,
        "pp_degree": pp_degree,
        "dp_degree": dp_degree,
        "interconnect_type": interconnect,
        "batch_size": batch_size,
        "seq_length": seq_length,
        "hidden_dim": hidden_dim,
    }

    if not json_output:
        console.print(f"[green]alloc[/green] [dim]v{__version__}[/dim] — Remote Ghost Scan")
        console.print(f"[dim]Model: {model} ({resolved_param_count}B) → {gpu} x{num_gpus}[/dim]")
        console.print()

    try:
        headers = {"Content-Type": "application/json"}
        if token:
            headers["Authorization"] = f"Bearer {token}"

        # Authenticated users get /scans; anonymous CLI traffic uses /scans/cli.
        endpoint = "/scans" if token else "/scans/cli"
        with httpx.Client(timeout=30) as client:
            resp = client.post(f"{api_url}{endpoint}", json=payload, headers=headers)
            resp.raise_for_status()
            result = resp.json()

        if json_output:
            _print_json(result)
        else:
            _print_scan_result(result, gpu, strategy)
    except httpx.HTTPStatusError as e:
        if json_output:
            _print_json({"error": f"API error {e.response.status_code}"})
        elif e.response.status_code == 403:
            # 403 means the plan doesn't include AI analysis.
            # NOTE(review): this branch still falls through to Exit(1) even
            # though the message says "the scan still works" — confirm intent.
            console.print("[yellow]AI analysis requires a Pro or Enterprise plan.[/yellow]")
            console.print("[dim]The scan still works — just without Euler AI analysis.[/dim]")
        else:
            console.print(f"[red]API error {e.response.status_code}[/red]")
        raise typer.Exit(1)
    except httpx.ConnectError:
        if json_output:
            _print_json({"error": f"Cannot connect to {api_url}"})
        else:
            console.print(f"[red]Cannot connect to {api_url}[/red]")
            console.print("[dim]Check ALLOC_API_URL or try: alloc ghost <script.py> for local ghost scan[/dim]")
        raise typer.Exit(1)
372
+
373
+
374
@app.command()
def login(
    method: str = typer.Option(
        "password",
        "--method",
        help="Auth method: password (Supabase) or token (paste an access token)",
    ),
    token: Optional[str] = typer.Option(
        None,
        "--token",
        help="Access token (implies token login). If omitted, you'll be prompted.",
    ),
):
    """Authenticate with Alloc dashboard.

    Two flows:
      * token    — store a pasted access token (cannot auto-refresh later).
      * password — exchange email/password with Supabase for access and
                   refresh tokens, then persist them to local config.

    Exits 2 on an invalid --method value, 1 on any authentication failure.
    """
    import httpx
    from alloc.config import get_supabase_url, get_supabase_anon_key, load_config, save_config

    method = (method or "").strip().lower()
    if token is not None and method == "password":
        # UX: allow `alloc login --token <...>` without requiring `--method token`.
        method = "token"
    if method not in ("password", "token"):
        console.print("[red]Invalid --method. Use: password or token[/red]")
        raise typer.Exit(2)

    if method == "token":
        access_token = token or typer.prompt("Access token", hide_input=True)
        if not access_token.strip():
            console.print("[red]Login failed: empty token.[/red]")
            raise typer.Exit(1)

        cfg = load_config()
        cfg["token"] = access_token.strip()
        cfg.pop("refresh_token", None)  # token paste mode cannot refresh automatically
        cfg.pop("email", None)
        cfg["api_url"] = get_api_url()
        save_config(cfg)
        console.print("[green]Saved token.[/green]")
        return

    email = typer.prompt("Email")
    password = typer.prompt("Password", hide_input=True)

    supabase_url = get_supabase_url()
    anon_key = get_supabase_anon_key()

    try:
        with httpx.Client(timeout=15) as client:
            resp = client.post(
                f"{supabase_url}/auth/v1/token?grant_type=password",
                json={"email": email, "password": password},
                headers={
                    "apikey": anon_key,
                    "Content-Type": "application/json",
                },
            )
            resp.raise_for_status()
            data = resp.json()

        access_token = data.get("access_token", "")
        refresh = data.get("refresh_token", "")
        if not access_token:
            console.print("[red]Login failed: no access token received.[/red]")
            raise typer.Exit(1)

        cfg = load_config()
        cfg["token"] = access_token
        cfg["refresh_token"] = refresh
        cfg["email"] = email
        cfg["api_url"] = get_api_url()
        save_config(cfg)

        console.print(f"[green]Logged in as {email}[/green]")
    except httpx.HTTPStatusError as e:
        # Map the two common Supabase error descriptions to friendlier text.
        detail = "authentication error"
        try:
            body = e.response.json()
            desc = body.get("error_description", "")
            if desc in ("Invalid login credentials", "Email not confirmed"):
                detail = desc.lower()
        except Exception:
            pass
        console.print(f"[red]Login failed: {detail}[/red]")
        raise typer.Exit(1)
    except httpx.ConnectError:
        console.print(f"[red]Cannot connect to {supabase_url}[/red]")
        raise typer.Exit(1)
    except typer.Exit:
        # Bug fix: typer.Exit derives from click's Exit (a RuntimeError), so the
        # generic `except Exception` below would otherwise swallow the deliberate
        # Exit raised on the "no access token" path and print a second, bogus
        # "Login failed: ..." message. Re-raise it untouched.
        raise
    except Exception as e:
        console.print(f"[red]Login failed: {e}[/red]")
        raise typer.Exit(1)
464
+
465
+
466
@app.command()
def logout():
    """Log out (clear saved token from local config)."""
    from alloc.config import load_config, save_config

    cfg = load_config()

    # Remember whether any session material existed before clearing it,
    # so we can tell the user what actually happened.
    session_keys = ("token", "refresh_token", "email")
    had_session = any(cfg.get(k) for k in session_keys)
    for key in session_keys:
        cfg.pop(key, None)
    save_config(cfg)

    message = "[green]Logged out.[/green]" if had_session else "[dim]No saved session found.[/dim]"
    console.print(message)

    # An env-var token is outside local config and survives logout — warn.
    if os.environ.get("ALLOC_TOKEN"):
        console.print("[dim]Note: ALLOC_TOKEN is set in your environment and still overrides local config.[/dim]")
486
+
487
+
488
@app.command()
def whoami(
    json_output: bool = typer.Option(False, "--json", help="Output machine-readable JSON"),
):
    """Show current CLI auth + org context.

    Fetches /profile and /gpu-fleet from the API; on a 401 it attempts one
    token refresh and retries. Exits 1 on API or connection errors.
    """
    import httpx
    from alloc.config import get_api_url, get_token

    api_url = get_api_url()
    token = get_token()
    # ALLOC_TOKEN in the environment takes precedence over local config.
    token_source = "env" if os.environ.get("ALLOC_TOKEN") else "config"

    out = {
        "api_url": api_url,
        "logged_in": bool(token),
        "token_source": token_source if token else None,
    }

    if not token:
        if json_output:
            _print_json(out)
        else:
            console.print("[yellow]Not logged in.[/yellow]")
            console.print("[dim]Run: alloc login[/dim]")
        return

    # Closure reads `token` from the enclosing scope, so rebinding `token`
    # after a refresh (below) makes the retry use the new credentials.
    def _get(path: str) -> dict:
        headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
        with httpx.Client(timeout=15) as client:
            resp = client.get(f"{api_url}{path}", headers=headers)
            resp.raise_for_status()
            return resp.json()

    try:
        profile = _get("/profile")
        fleet = _get("/gpu-fleet")
    except httpx.HTTPStatusError as e:
        # Only a 401 triggers a refresh attempt; any other status reports and exits.
        new_token = try_refresh_access_token() if e.response.status_code == 401 else None
        if new_token:
            # Retry once with refreshed token
            token = new_token
            profile = _get("/profile")
            fleet = _get("/gpu-fleet")
        else:
            if json_output:
                out["error"] = f"API error {e.response.status_code}"
                _print_json(out)
            else:
                console.print(f"[red]API error {e.response.status_code}[/red]")
                console.print("[dim]Run: alloc login[/dim]")
            raise typer.Exit(1)
    except httpx.ConnectError:
        if json_output:
            out["error"] = f"Cannot connect to {api_url}"
            _print_json(out)
        else:
            console.print(f"[red]Cannot connect to {api_url}[/red]")
        raise typer.Exit(1)

    # Tally fleet membership by status.
    gpus = fleet.get("gpus") or []
    fleet_count = len([g for g in gpus if g.get("fleet_status") == "in_fleet"])
    explore_count = len([g for g in gpus if g.get("fleet_status") == "explore"])

    out.update({
        "email": profile.get("email"),
        "user_id": profile.get("user_id"),
        "onboarding_complete": profile.get("onboarding_complete"),
        "objective": fleet.get("objective"),
        "priority_cost": fleet.get("priority_cost"),
        "budget_monthly_usd": fleet.get("budget_monthly"),
        "effective_budget_monthly_usd": fleet.get("effective_budget_monthly"),
        "budget_cap_applied": fleet.get("budget_cap_applied"),
        "fleet_count": fleet_count,
        "explore_count": explore_count,
        "org_budget": fleet.get("org_budget"),
    })

    if json_output:
        _print_json(out)
        return

    # Human-readable summary.
    email = out.get("email") or "(unknown)"
    console.print(f"[green]Logged in[/green] as {email}")
    console.print(f"[dim]API: {api_url}[/dim]")
    if out.get("objective"):
        console.print(f"[dim]Objective: {out['objective']}[/dim]")
    if out.get("effective_budget_monthly_usd") is not None:
        cap = float(out["effective_budget_monthly_usd"])
        line = f"[dim]Effective budget cap: ${cap:.2f}/mo[/dim]"
        if out.get("budget_cap_applied"):
            line += " [yellow](org cap applied)[/yellow]"
        console.print(line)
    console.print(f"[dim]Fleet GPUs: {fleet_count}, Explore GPUs: {explore_count}[/dim]")
    console.print()
582
+
583
+
584
@app.command()
def upload(
    artifact: str = typer.Argument(..., help="Path to alloc artifact (.json.gz)"),
):
    """Upload an artifact to the Alloc dashboard."""
    # Validate locally before handing off to the uploader: the path must
    # exist and carry the expected artifact extension.
    problem = None
    if not os.path.isfile(artifact):
        problem = f"[red]File not found: {artifact}[/red]"
    elif not artifact.endswith(".json.gz"):
        problem = "[red]Expected a .json.gz artifact file.[/red]"

    if problem is not None:
        console.print(problem)
        raise typer.Exit(1)

    _try_upload(artifact)
598
+
599
+
600
+ @app.command()
601
+ def init(
602
+ yes: bool = typer.Option(False, "--yes", "-y", help="Non-interactive: full catalog, 50/50 priority, no budget"),
603
+ from_org: bool = typer.Option(False, "--from-org", help="Pull fleet/budget from your org (requires alloc login)"),
604
+ ):
605
+ """Create .alloc.yaml with GPU fleet, explore, budget, and priority."""
606
+ from alloc.yaml_config import AllocConfig, FleetEntry, write_alloc_config, validate_config
607
+ from alloc.catalog import list_gpus
608
+
609
+ if os.path.isfile(".alloc.yaml"):
610
+ if not yes:
611
+ overwrite = typer.confirm(".alloc.yaml already exists. Overwrite?", default=False)
612
+ if not overwrite:
613
+ raise typer.Exit(0)
614
+
615
+ if from_org:
616
+ import httpx
617
+
618
+ token = get_token()
619
+ if not token:
620
+ console.print("[yellow]Not logged in. Run `alloc login` first.[/yellow]")
621
+ raise typer.Exit(1)
622
+
623
+ api_url = get_api_url()
624
+ headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
625
+
626
+ def _fetch() -> dict:
627
+ with httpx.Client(timeout=15) as client:
628
+ resp = client.get(f"{api_url}/gpu-fleet", headers=headers)
629
+ resp.raise_for_status()
630
+ return resp.json()
631
+
632
+ try:
633
+ payload = _fetch()
634
+ except httpx.HTTPStatusError as e:
635
+ if e.response.status_code == 401:
636
+ new_token = try_refresh_access_token()
637
+ if not new_token:
638
+ console.print("[yellow]Session expired. Run `alloc login` again.[/yellow]")
639
+ raise typer.Exit(1)
640
+ headers["Authorization"] = f"Bearer {new_token}"
641
+ payload = _fetch()
642
+ else:
643
+ console.print(f"[red]API error {e.response.status_code}[/red]")
644
+ console.print(f"[dim]{e.response.text[:200]}[/dim]")
645
+ raise typer.Exit(1)
646
+ except httpx.ConnectError:
647
+ console.print(f"[red]Cannot connect to {api_url}[/red]")
648
+ raise typer.Exit(1)
649
+
650
+ catalog = list_gpus()
651
+ display_to_id = {g["display_name"]: g["id"] for g in catalog}
652
+
653
+ unknown = [] # type: list[str]
654
+
655
+ def _resolve_gpu_id(display_name: str) -> str:
656
+ resolved = display_to_id.get(display_name)
657
+ if resolved:
658
+ return resolved
659
+ unknown.append(display_name)
660
+ return display_name
661
+
662
+ fleet_entries = []
663
+ explore_entries = []
664
+ for g in payload.get("gpus") or []:
665
+ status = (g.get("fleet_status") or "").strip().lower()
666
+ if status not in ("in_fleet", "explore"):
667
+ continue
668
+
669
+ display = (g.get("display_name") or g.get("gpu_id") or "").strip()
670
+ if not display:
671
+ continue
672
+
673
+ gpu_id = _resolve_gpu_id(display)
674
+
675
+ count = None
676
+ if g.get("max_count") is not None:
677
+ try:
678
+ parsed = int(g["max_count"])
679
+ count = parsed if parsed > 0 else None
680
+ except Exception:
681
+ count = None
682
+
683
+ rate = None
684
+ if g.get("rate") is not None and g.get("rate_source") in ("user", "org"):
685
+ try:
686
+ parsed = float(g["rate"])
687
+ rate = parsed if parsed >= 0 else None
688
+ except Exception:
689
+ rate = None
690
+
691
+ entry = FleetEntry(
692
+ gpu=gpu_id,
693
+ cloud=g.get("cloud"),
694
+ count=count,
695
+ rate=rate,
696
+ explore=(status == "explore"),
697
+ )
698
+ if status == "explore":
699
+ explore_entries.append(entry)
700
+ else:
701
+ fleet_entries.append(entry)
702
+
703
+ budget_monthly = payload.get("effective_budget_monthly")
704
+ if budget_monthly is not None:
705
+ try:
706
+ budget_monthly = float(budget_monthly)
707
+ except Exception:
708
+ budget_monthly = None
709
+ budget_hourly = None
710
+ if budget_monthly is not None and budget_monthly > 0:
711
+ # Match API behavior: budgets are enforced hourly using monthly/730.
712
+ budget_hourly = round(float(budget_monthly) / 730.0, 4)
713
+
714
+ # Preserve org ceiling so CLI can show "(org cap applied)"
715
+ org_budget_raw = payload.get("org_budget") or {}
716
+ org_ceiling = org_budget_raw.get("budget_monthly_usd") if isinstance(org_budget_raw, dict) else None
717
+ if org_ceiling is not None:
718
+ try:
719
+ org_ceiling = float(org_ceiling)
720
+ except Exception:
721
+ org_ceiling = None
722
+
723
+ # Resolve org interconnect preference
724
+ org_interconnect = None
725
+ raw_ic = payload.get("interconnect")
726
+ if isinstance(raw_ic, str) and raw_ic.strip().lower() in ("pcie", "nvlink", "infiniband"):
727
+ org_interconnect = raw_ic.strip().lower()
728
+
729
+ config = AllocConfig(
730
+ fleet=fleet_entries,
731
+ explore=explore_entries,
732
+ objective=payload.get("objective"),
733
+ priority_cost=int(payload.get("priority_cost") or 50),
734
+ budget_monthly=budget_monthly,
735
+ budget_hourly=budget_hourly,
736
+ org_budget_monthly=org_ceiling,
737
+ interconnect=org_interconnect,
738
+ )
739
+
740
+ errors = validate_config(config)
741
+ if errors:
742
+ for err in errors:
743
+ console.print(f"[red]{err}[/red]")
744
+ raise typer.Exit(1)
745
+
746
+ path = write_alloc_config(config)
747
+ console.print(f"\n[green]Created {path}[/green] [dim](from org)[/dim]")
748
+ console.print(f" Fleet: {len(config.fleet)} GPUs, Explore: {len(config.explore)} GPUs")
749
+ if config.objective:
750
+ console.print(f" Objective: {config.objective}")
751
+ console.print(f" Priority: cost={config.priority_cost}, latency={config.priority_latency}")
752
+ if config.budget_monthly:
753
+ console.print(f" Budget: ${config.budget_monthly:.0f}/mo")
754
+ if config.budget_hourly:
755
+ console.print(f" Budget cap: ${config.budget_hourly:.4f}/hr [dim](monthly/730)[/dim]")
756
+ if config.interconnect:
757
+ console.print(f" Interconnect: {config.interconnect}")
758
+ if payload.get("budget_cap_applied"):
759
+ console.print(" [yellow]Note: org cap applied to effective budget[/yellow]")
760
+ if unknown:
761
+ uniq = sorted(set(unknown))
762
+ console.print(" [yellow]Warning:[/yellow] some GPUs were not found in the local catalog:")
763
+ for name in uniq[:10]:
764
+ console.print(f" [dim]- {name}[/dim]")
765
+ if len(uniq) > 10:
766
+ console.print(f" [dim]... (+{len(uniq) - 10} more)[/dim]")
767
+ console.print()
768
+ return
769
+
770
+ gpus = list_gpus()
771
+
772
+ if yes:
773
+ # Non-interactive: all GPUs as fleet, balanced priority
774
+ config = AllocConfig(
775
+ fleet=[FleetEntry(gpu=g["id"]) for g in gpus],
776
+ priority_cost=50,
777
+ )
778
+ else:
779
+ # Interactive wizard
780
+ console.print(f"\n[green]alloc init[/green] — GPU fleet configuration\n")
781
+
782
+ # 1. Select fleet GPUs
783
+ console.print("[bold]Available GPUs:[/bold]")
784
+ for i, g in enumerate(gpus):
785
+ price_str = ""
786
+ aws = g["pricing"].get("aws")
787
+ if aws:
788
+ price_str = f" ${aws:.2f}/hr"
789
+ console.print(f" [dim]{i + 1:>2}.[/dim] {g['display_name']:>20} {g['vram_gb']:.0f} GB{price_str}")
790
+
791
+ console.print()
792
+ fleet_input = typer.prompt(
793
+ "Fleet GPUs (comma-separated numbers, or 'all')",
794
+ default="all",
795
+ )
796
+
797
+ if fleet_input.strip().lower() == "all":
798
+ fleet_entries = [FleetEntry(gpu=g["id"]) for g in gpus]
799
+ else:
800
+ fleet_entries = []
801
+ for part in fleet_input.split(","):
802
+ part = part.strip()
803
+ if part.isdigit():
804
+ idx = int(part) - 1
805
+ if 0 <= idx < len(gpus):
806
+ fleet_entries.append(FleetEntry(gpu=gpus[idx]["id"]))
807
+
808
+ if not fleet_entries:
809
+ console.print("[yellow]No GPUs selected, using full catalog.[/yellow]")
810
+ fleet_entries = [FleetEntry(gpu=g["id"]) for g in gpus]
811
+
812
+ # 2. Explore GPUs (from remaining catalog)
813
+ fleet_ids = {e.gpu for e in fleet_entries}
814
+ remaining = [g for g in gpus if g["id"] not in fleet_ids]
815
+ explore_entries = [] # type: list
816
+ if remaining:
817
+ add_explore = typer.confirm("Add explore GPUs (evaluate GPUs you don't have)?", default=False)
818
+ if add_explore:
819
+ console.print("\n[bold]GPUs not in fleet:[/bold]")
820
+ for i, g in enumerate(remaining):
821
+ console.print(f" [dim]{i + 1:>2}.[/dim] {g['display_name']:>20} {g['vram_gb']:.0f} GB")
822
+ explore_input = typer.prompt("Explore GPUs (comma-separated numbers)", default="")
823
+ for part in explore_input.split(","):
824
+ part = part.strip()
825
+ if part.isdigit():
826
+ idx = int(part) - 1
827
+ if 0 <= idx < len(remaining):
828
+ explore_entries.append(FleetEntry(gpu=remaining[idx]["id"], explore=True))
829
+
830
+ # 3. Priority
831
+ console.print("\n[bold]Priority[/bold] — how to rank GPU recommendations:")
832
+ console.print(" 1. Minimize Cost (cost=80, latency=20)")
833
+ console.print(" 2. Balanced (cost=50, latency=50)")
834
+ console.print(" 3. Minimize Time (cost=20, latency=80)")
835
+ priority_choice = typer.prompt("Choose", default="2")
836
+ priority_map = {"1": 80, "2": 50, "3": 20}
837
+ priority_cost = priority_map.get(priority_choice.strip(), 50)
838
+
839
+ # 4. Budget
840
+ budget_monthly = None # type: Optional[float]
841
+ set_budget = typer.confirm("\nSet a monthly GPU budget?", default=False)
842
+ if set_budget:
843
+ budget_str = typer.prompt("Monthly budget (USD)", default="0")
844
+ try:
845
+ budget_monthly = float(budget_str)
846
+ if budget_monthly <= 0:
847
+ budget_monthly = None
848
+ except ValueError:
849
+ budget_monthly = None
850
+
851
+ # 5. Interconnect
852
+ console.print("\n[bold]Interconnect[/bold] — GPU-to-GPU communication fabric:")
853
+ console.print(" 1. Unknown (auto-detect at runtime)")
854
+ console.print(" 2. PCIe")
855
+ console.print(" 3. NVLink")
856
+ console.print(" 4. InfiniBand")
857
+ ic_choice = typer.prompt("Choose", default="1")
858
+ ic_map = {"1": None, "2": "pcie", "3": "nvlink", "4": "infiniband"}
859
+ interconnect_val = ic_map.get(ic_choice.strip())
860
+
861
+ config = AllocConfig(
862
+ fleet=fleet_entries,
863
+ explore=explore_entries,
864
+ priority_cost=priority_cost,
865
+ budget_monthly=budget_monthly,
866
+ interconnect=interconnect_val,
867
+ )
868
+
869
+ errors = validate_config(config)
870
+ if errors:
871
+ for err in errors:
872
+ console.print(f"[red]{err}[/red]")
873
+ raise typer.Exit(1)
874
+
875
+ path = write_alloc_config(config)
876
+ fleet_count = len(config.fleet)
877
+ explore_count = len(config.explore)
878
+ console.print(f"\n[green]Created {path}[/green]")
879
+ console.print(f" Fleet: {fleet_count} GPUs, Explore: {explore_count} GPUs")
880
+ console.print(f" Priority: cost={config.priority_cost}, latency={config.priority_latency}")
881
+ if config.budget_monthly:
882
+ console.print(f" Budget: ${config.budget_monthly:.0f}/mo")
883
+ if config.interconnect:
884
+ console.print(f" Interconnect: {config.interconnect}")
885
+ console.print()
886
+
887
+
888
@app.command()
def version():
    """Show alloc version."""
    # __version__ comes from the package root import at the top of the file.
    console.print("alloc v" + __version__)
892
+
893
+
894
+ # ---------------------------------------------------------------------------
895
+ # Catalog subcommands
896
+ # ---------------------------------------------------------------------------
897
+ catalog_app = typer.Typer(
898
+ name="catalog",
899
+ help="Browse the GPU hardware catalog.",
900
+ no_args_is_help=True,
901
+ )
902
+ app.add_typer(catalog_app, name="catalog")
903
+
904
+
905
@catalog_app.command("list")
def catalog_list(
    sort: str = typer.Option("vram", help="Sort by: vram, cost, tflops, name"),
):
    """List all GPUs in the catalog."""
    from rich.table import Table
    from alloc.catalog import list_gpus

    gpus = list_gpus()

    # The catalog comes back VRAM-sorted, so only re-sort for the other keys.
    if sort == "cost":
        # 999 sentinel pushes GPUs with no pricing data to the end.
        gpus.sort(key=lambda entry: next(iter(entry["pricing"].values()), 999))
    elif sort == "tflops":
        gpus.sort(key=lambda entry: entry["bf16_tflops"], reverse=True)
    elif sort == "name":
        gpus.sort(key=lambda entry: entry["display_name"])

    table = Table(show_header=True, header_style="bold cyan", box=None, padding=(0, 2))
    table.add_column("GPU", style="bold", no_wrap=True)
    table.add_column("VRAM", justify="right")
    table.add_column("BF16 TFLOPS", justify="right")
    table.add_column("BW (GB/s)", justify="right")
    table.add_column("TDP", justify="right")
    table.add_column("Arch", style="dim", no_wrap=True)
    table.add_column("$/hr (AWS)", justify="right")

    for entry in gpus:
        rate = entry["pricing"].get("aws")
        # \u2014 (em dash) stands in for missing specs/pricing.
        table.add_row(
            entry["display_name"],
            f"{entry['vram_gb']:.0f} GB",
            f"{entry['bf16_tflops']:.0f}" if entry["bf16_tflops"] else "\u2014",
            f"{entry['bandwidth_gbps']:.0f}",
            f"{entry['tdp_watts']}W",
            entry["architecture"],
            f"${rate:.2f}" if rate else "\u2014",
        )

    console.print(table)
    console.print(f"\n[dim]{len(gpus)} GPUs in catalog[/dim]")
947
+
948
+
949
@catalog_app.command("show")
def catalog_show(
    gpu_id: str = typer.Argument(..., help="GPU ID or alias (e.g. nvidia-h100-sxm-80gb or H100)"),
):
    """Show detailed specs for a single GPU.

    Resolves *gpu_id* (canonical ID or short alias) via the catalog, then
    prints a spec panel followed by per-cloud pricing. Exits with code 1
    when the GPU is not found.
    """
    from rich.panel import Panel
    from rich.table import Table
    # NOTE(review): imports the private alias map purely for the error hint
    # below — couples this command to alloc.catalog internals.
    from alloc.catalog import get_gpu, _ALIASES

    gpu = get_gpu(gpu_id)
    if not gpu:
        console.print(f"[red]Unknown GPU: {gpu_id}[/red]")
        console.print(f"[dim]Available aliases: {', '.join(sorted(_ALIASES.keys()))}[/dim]")
        raise typer.Exit(1)

    # Specs table: borderless two-column key/value layout.
    table = Table(show_header=False, box=None, padding=(0, 2))
    table.add_column("Field", style="dim")
    table.add_column("Value", style="bold")

    table.add_row("ID", gpu["id"])
    table.add_row("Display Name", gpu["display_name"])
    table.add_row("Vendor", gpu["vendor"])
    table.add_row("Architecture", gpu["architecture"])
    table.add_row("VRAM", f"{gpu['vram_gb']:.0f} GB")
    table.add_row("Memory BW", f"{gpu['bandwidth_gbps']:.0f} GB/s")
    # \u2014 (em dash) marks specs the catalog has no figure for.
    table.add_row("BF16 TFLOPS", f"{gpu['bf16_tflops']:.0f}" if gpu["bf16_tflops"] else "\u2014")
    table.add_row("FP16 TFLOPS", f"{gpu['fp16_tflops']:.0f}" if gpu["fp16_tflops"] else "\u2014")
    table.add_row("FP32 TFLOPS", f"{gpu['fp32_tflops']:.1f}")
    table.add_row("TF32 TFLOPS", f"{gpu['tf32_tflops']:.0f}" if gpu["tf32_tflops"] else "\u2014")
    table.add_row("TDP", f"{gpu['tdp_watts']} W")

    # Interconnect rows are optional — only shown when present in the record.
    ic = gpu.get("interconnect")
    if ic:
        if ic.get("nvlink_gen"):
            table.add_row("NVLink", f"Gen {ic['nvlink_gen']} ({ic.get('nvlink_bw_gbps', '?')} GB/s)")
        if ic.get("pcie_gen"):
            table.add_row("PCIe", f"Gen {ic['pcie_gen']}")

    console.print(Panel(table, title=f"[bold]{gpu['display_name']}[/bold]", border_style="green"))

    # Pricing, sorted alphabetically by cloud provider name.
    if gpu["pricing"]:
        console.print("\n [bold]Pricing[/bold]")
        for cloud, price in sorted(gpu["pricing"].items()):
            console.print(f" {cloud:>12} ${price:.2f}/hr")
    else:
        console.print("\n [dim]No pricing data available[/dim]")

    console.print()
1000
+
1001
+
1002
+ def _print_json(data: dict) -> None:
1003
+ """Print a dict as formatted JSON to stdout."""
1004
+ print(json_mod.dumps(data, indent=2, default=str))
1005
+
1006
+
1007
def _try_upload(artifact_path: str) -> None:
    """Attempt to upload an artifact. Prints status, never raises.

    Args:
        artifact_path: Path to the locally saved artifact file.

    A 401 response triggers one token refresh + retry; upload-limit and
    HTTP errors get specific messaging, anything else a generic warning.
    """
    # Fix: import the optional dependencies in their own guarded block.
    # Previously the imports lived inside the main try, so if they failed
    # the `except UploadLimitError` clause below referenced an undefined
    # name and a NameError escaped — violating the "never raises" contract
    # (except-clause expressions are evaluated at exception-handling time).
    try:
        import httpx
        from alloc.upload import upload_artifact, UploadLimitError
    except Exception as e:
        console.print(f"[yellow]Upload failed: {e}[/yellow]")
        console.print(f"[dim]You can retry later: alloc upload {artifact_path}[/dim]")
        return

    try:
        token = get_token()
        if not token:
            console.print("[yellow]Not logged in. Run `alloc login` first.[/yellow]")
            return

        api_url = get_api_url()
        console.print(f"[dim]Uploading to {api_url}...[/dim]")
        try:
            result = upload_artifact(artifact_path, api_url, token)
        except httpx.HTTPStatusError as e:
            # One-shot retry: refresh the access token only on a 401.
            new_token = try_refresh_access_token() if e.response.status_code == 401 else None
            if new_token:
                console.print("[dim]Session expired; refreshed token. Retrying...[/dim]")
                result = upload_artifact(artifact_path, api_url, new_token)
            else:
                raise

        run_id = result.get("run_id", "unknown")
        console.print(f"[green]Uploaded.[/green] Run ID: {run_id}")
        budget_warning = result.get("budget_warning")
        if budget_warning:
            console.print(f"[yellow]{budget_warning}[/yellow]")
    except UploadLimitError as e:
        detail = e.detail
        used = detail.get("used", "?")
        limit = detail.get("limit", "?")
        console.print(f"[yellow]Upload limit reached ({used}/{limit} this month).[/yellow]")
        console.print("[dim]Upgrade to Pro for more uploads. Artifact is saved locally.[/dim]")
        console.print(f"[dim] {artifact_path}[/dim]")
    except httpx.HTTPStatusError as e:
        status = e.response.status_code
        if status == 401:
            console.print("[yellow]Session expired. Run `alloc login` again.[/yellow]")
        else:
            console.print(f"[yellow]Upload failed: API error {status}[/yellow]")
            console.print(f"[dim]{e.response.text[:200]}[/dim]")
        console.print(f"[dim]You can retry later: alloc upload {artifact_path}[/dim]")
    except Exception as e:
        # Catch-all keeps the CLI alive; artifact remains on disk for retry.
        console.print(f"[yellow]Upload failed: {e}[/yellow]")
        console.print(f"[dim]You can retry later: alloc upload {artifact_path}[/dim]")
1053
+
1054
+
1055
+ def _read_callback_data():
1056
+ # type: () -> Optional[dict]
1057
+ """Read callback data from .alloc_callback.json."""
1058
+ try:
1059
+ callback_path = os.path.join(os.getcwd(), ".alloc_callback.json")
1060
+ if os.path.isfile(callback_path):
1061
+ with open(callback_path, "r") as f:
1062
+ return json_mod.load(f)
1063
+ except Exception:
1064
+ pass
1065
+ return None
1066
+
1067
+
1068
def _print_scan_result(result: dict, gpu: str, strategy: str) -> None:
    """Print remote scan result.

    Args:
        result: Scan response payload; reads ``vram_breakdown`` and
            ``strategy_verdict``, plus optional ``est_cost_per_hour``
            and ``euler_analysis`` keys.
        gpu: GPU name the scan was evaluated against (display only).
        strategy: Strategy name that was evaluated (display only).
    """
    from rich.table import Table
    from rich.panel import Panel

    vram = result.get("vram_breakdown", {})
    verdict = result.get("strategy_verdict", {})

    # VRAM component breakdown table.
    table = Table(show_header=True, header_style="bold cyan", box=None, padding=(0, 2))
    table.add_column("Component", style="dim")
    table.add_column("Size", justify="right", style="bold")

    table.add_row("Model weights", f"{vram.get('weights_gb', 0):.2f} GB")
    table.add_row("Optimizer (Adam)", f"{vram.get('optimizer_gb', 0):.2f} GB")
    table.add_row("Activations (est.)", f"{vram.get('activations_gb', 0):.2f} GB")
    table.add_row("Buffer (10%)", f"{vram.get('buffer_gb', 0):.2f} GB")
    table.add_row("", "")  # visual separator before the total row
    table.add_row("[bold]Total VRAM[/bold]", f"[bold]{vram.get('total_gb', 0):.2f} GB[/bold]")

    console.print(Panel(table, title="VRAM Breakdown", border_style="green", padding=(1, 2)))

    feasible = verdict.get("feasible", False)
    status = "[green]FEASIBLE[/green]" if feasible else "[red]INFEASIBLE[/red]"
    console.print(f" Strategy: {strategy.upper()} on {gpu} — {status}")
    if verdict.get("strategy_topology"):
        console.print(f" [dim]Topology: {verdict['strategy_topology']}[/dim]")
    if verdict.get("objective"):
        console.print(f" [dim]Objective: {verdict['objective']}[/dim]")
    if verdict.get("effective_max_budget_hourly") is not None:
        if verdict.get("budget_cap_applied"):
            # Org ceiling capped the user's own budget — show all figures.
            user_cap = verdict.get("user_max_budget_hourly")
            org_cap = verdict.get("org_max_budget_hourly")
            eff_cap = float(verdict["effective_max_budget_hourly"])
            console.print(" [dim]Budget:[/dim]")
            if user_cap is not None:
                console.print(f" [dim]Your cap: ${float(user_cap):.4f}/hr[/dim]")
            if org_cap is not None:
                console.print(f" [dim]Org ceiling: ${float(org_cap):.4f}/hr[/dim]")
            console.print(f" [dim]Effective cap: ${eff_cap:.4f}/hr (org ceiling applied)[/dim]")
        else:
            console.print(
                f" [dim]Effective budget cap: ${float(verdict['effective_max_budget_hourly']):.4f}/hr[/dim]"
            )

    if not feasible and verdict.get("recommendation"):
        rec = verdict.get("best_recommendation") or {}
        if rec.get("strategy_topology"):
            console.print(
                f" [yellow]Suggestion: {rec.get('strategy', '').upper()} "
                f"({rec.get('strategy_topology')})[/yellow]"
            )
        else:
            console.print(f" [yellow]Suggestion: switch to {str(verdict['recommendation']).upper()}[/yellow]")

    if verdict.get("reason"):
        console.print(f" [dim]{verdict['reason']}[/dim]")

    # Cost estimate if present
    cost = result.get("est_cost_per_hour")
    if cost is not None:
        console.print(f" [dim]Est. cost: ~${cost:.2f}/hr[/dim]")

    # Euler analysis if present
    euler = result.get("euler_analysis")
    if euler and euler.get("summary"):
        console.print()
        # Fix: dropped a stray `f` prefix — this literal has no placeholders.
        console.print(" [bold cyan]Euler Analysis[/bold cyan]")
        console.print(f" {euler['summary']}")
        for rec in euler.get("recommendations", []):
            console.print(f" [dim]• {rec}[/dim]")

    console.print()
1140
+
1141
+
1142
def _model_to_params(model: str) -> Optional[float]:
    """Look up model param count by name."""
    # Thin wrapper around the registry; lazy import keeps CLI startup fast.
    from alloc.model_registry import lookup_model_params

    params = lookup_model_params(model)
    return params
1146
+
1147
+
1148
+ # ---------------------------------------------------------------------------
1149
+ # GPU context helpers
1150
+ # ---------------------------------------------------------------------------
1151
+
1152
+ def _load_gpu_context(no_config: bool) -> Optional[dict]:
1153
+ """Load GPU context from .alloc.yaml. Returns None if disabled or not found."""
1154
+ if no_config:
1155
+ return None
1156
+ try:
1157
+ from alloc.yaml_config import load_alloc_config
1158
+ config = load_alloc_config()
1159
+ if config is None:
1160
+ return None
1161
+ return {
1162
+ "fleet": config.fleet_gpu_ids,
1163
+ "explore": config.explore_gpu_ids,
1164
+ "objective": config.objective,
1165
+ "priority_cost": config.priority_cost,
1166
+ "priority_latency": config.priority_latency,
1167
+ "budget_monthly": config.budget_monthly,
1168
+ "budget_hourly": config.budget_hourly,
1169
+ "rate_overrides": config.rate_overrides,
1170
+ "org_budget_monthly": config.org_budget_monthly,
1171
+ "interconnect": config.interconnect,
1172
+ }
1173
+ except Exception:
1174
+ return None
1175
+
1176
+
1177
def _print_gpu_context_summary(ctx: dict) -> None:
    """Print a one-line GPU context summary."""
    segments = []
    n_fleet = len(ctx.get("fleet", []))
    if n_fleet:
        segments.append(f"{n_fleet} fleet")
    n_explore = len(ctx.get("explore", []))
    if n_explore:
        segments.append(f"{n_explore} explore")
    # Priority is always shown; budget only when set.
    segments.append(f"cost={ctx.get('priority_cost', 50)}")
    monthly = ctx.get("budget_monthly")
    if monthly:
        segments.append(f"${monthly:.0f}/mo")
    console.print(f" [dim]GPU context: {', '.join(segments)} (from .alloc.yaml)[/dim]")
    console.print()
1193
+
1194
+
1195
def _print_gpu_context_detail(ctx: dict) -> None:
    """Print detailed GPU context for verbose mode."""
    from rich.panel import Panel

    rows = []
    fleet = ctx.get("fleet", [])
    if fleet:
        rows.append(f" Fleet GPUs {', '.join(fleet)}")
    explore = ctx.get("explore", [])
    if explore:
        rows.append(f" Explore GPUs {', '.join(explore)}")
    # Priority weights always appear; budget and overrides only when set.
    rows.append(f" Priority cost={ctx.get('priority_cost', 50)}, latency={ctx.get('priority_latency', 50)}")
    monthly = ctx.get("budget_monthly")
    if monthly:
        rows.append(f" Budget ${monthly:.0f}/mo")
    for gpu, rate in ctx.get("rate_overrides", {}).items():
        rows.append(f" Rate override {gpu}: ${rate:.2f}/hr")
    console.print(Panel("\n".join(rows), title="GPU Context (.alloc.yaml)", border_style="cyan", padding=(1, 0)))
1214
+
1215
+
1216
+ def _infer_parallel_topology_from_env(*, num_gpus_detected: int, config_interconnect: Optional[str] = None, detected_interconnect: Optional[str] = None) -> dict:
1217
+ """Infer distributed topology hints from common launcher env vars."""
1218
+
1219
+ def _get_int(name: str) -> Optional[int]:
1220
+ val = os.environ.get(name)
1221
+ if val is None:
1222
+ return None
1223
+ try:
1224
+ parsed = int(val)
1225
+ return parsed if parsed > 0 else None
1226
+ except Exception:
1227
+ return None
1228
+
1229
+ world_size = _get_int("WORLD_SIZE")
1230
+ local_world = _get_int("LOCAL_WORLD_SIZE")
1231
+ nnodes = _get_int("NNODES")
1232
+ gpn = local_world or num_gpus_detected or 1
1233
+
1234
+ if nnodes is None and world_size is not None and gpn > 0 and world_size % gpn == 0:
1235
+ nnodes = max(1, world_size // gpn)
1236
+
1237
+ tp = _get_int("TP_SIZE") or _get_int("TENSOR_PARALLEL_SIZE")
1238
+ pp = _get_int("PP_SIZE") or _get_int("PIPELINE_PARALLEL_SIZE")
1239
+ dp = _get_int("DP_SIZE") or _get_int("DATA_PARALLEL_SIZE")
1240
+
1241
+ if dp is None and world_size is not None:
1242
+ denom = (tp or 1) * (pp or 1)
1243
+ if denom > 0 and world_size % denom == 0:
1244
+ dp = max(1, world_size // denom)
1245
+
1246
+ # Priority: env var > NVML probe detection > .alloc.yaml config > "unknown"
1247
+ interconnect = os.environ.get("ALLOC_INTERCONNECT", "").strip().lower()
1248
+ if not interconnect and detected_interconnect:
1249
+ interconnect = detected_interconnect
1250
+ if not interconnect and config_interconnect:
1251
+ interconnect = config_interconnect
1252
+ if not interconnect:
1253
+ interconnect = "unknown"
1254
+ if interconnect not in ("pcie", "nvlink", "infiniband", "unknown"):
1255
+ interconnect = "unknown"
1256
+
1257
+ return {
1258
+ "num_nodes": nnodes or 1,
1259
+ "gpus_per_node": gpn,
1260
+ "tp_degree": tp,
1261
+ "pp_degree": pp,
1262
+ "dp_degree": dp,
1263
+ "interconnect_type": interconnect,
1264
+ }
1265
+
1266
+
1267
+ def _objective_from_context(ctx: Optional[dict]) -> str:
1268
+ """Resolve ranking objective from .alloc.yaml."""
1269
+ if not ctx:
1270
+ return "best_value"
1271
+ explicit = (ctx.get("objective") or "").strip().lower()
1272
+ if explicit in ("cheapest", "fastest", "fastest_within_budget", "best_value"):
1273
+ return explicit
1274
+ priority_cost = int(ctx.get("priority_cost", 50))
1275
+ if priority_cost >= 70:
1276
+ return "cheapest"
1277
+ if priority_cost <= 30:
1278
+ return "fastest"
1279
+ return "best_value"
1280
+
1281
+
1282
+ def _build_budget_context(gpu_context, probe_result):
1283
+ # type: (Optional[dict], Any) -> Optional[dict]
1284
+ """Build budget context dict for verdict display from gpu_context + probe result."""
1285
+ if not gpu_context:
1286
+ return None
1287
+ budget_monthly = gpu_context.get("budget_monthly")
1288
+ rate_overrides = gpu_context.get("rate_overrides") or {}
1289
+
1290
+ # Look up cost_per_hour: first try rate overrides from .alloc.yaml, then catalog
1291
+ gpu_name = getattr(probe_result, "gpu_name", None) or ""
1292
+ cost_per_hour = None
1293
+
1294
+ # Check rate overrides (user's .alloc.yaml rates)
1295
+ for gpu_id, rate in rate_overrides.items():
1296
+ if gpu_id.lower() in gpu_name.lower() or gpu_name.lower() in gpu_id.lower():
1297
+ cost_per_hour = rate
1298
+ break
1299
+
1300
+ # Fallback: look up from catalog
1301
+ if cost_per_hour is None and gpu_name:
1302
+ try:
1303
+ from alloc.catalog import get_default_rate
1304
+ cost_per_hour = get_default_rate(gpu_name)
1305
+ except Exception:
1306
+ pass
1307
+
1308
+ if cost_per_hour is None and budget_monthly is None:
1309
+ return None
1310
+
1311
+ num_gpus = getattr(probe_result, "num_gpus_detected", 1) or 1
1312
+ if cost_per_hour is not None:
1313
+ cost_per_hour = cost_per_hour * num_gpus
1314
+
1315
+ org_budget = gpu_context.get("org_budget_monthly")
1316
+ budget_cap_applied = (
1317
+ budget_monthly is not None
1318
+ and org_budget is not None
1319
+ and budget_monthly >= org_budget
1320
+ )
1321
+
1322
+ return {
1323
+ "cost_per_hour": cost_per_hour,
1324
+ "budget_monthly": budget_monthly,
1325
+ "org_budget_monthly": org_budget,
1326
+ "budget_cap_applied": budget_cap_applied,
1327
+ }
1328
+
1329
+
1330
+ def _max_budget_hourly_from_context(ctx: Optional[dict]) -> Optional[float]:
1331
+ """Resolve hourly budget from .alloc.yaml context."""
1332
+ if not ctx:
1333
+ return None
1334
+ val = ctx.get("budget_hourly")
1335
+ if val is None:
1336
+ return None
1337
+ try:
1338
+ parsed = float(val)
1339
+ return parsed if parsed > 0 else None
1340
+ except Exception:
1341
+ return None