lyceum-cli 1.0.28__py3-none-any.whl → 1.0.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1074 @@
+ """GPU selection execution commands"""
+
+ import json
+ import time
+
+ import httpx
+ import typer
+ from rich.console import Console
+ from rich.panel import Panel
+ from rich.table import Table
+
+ from ....shared.config import config
+ from ....shared.streaming import StatusLine
+ from .python import (
+     inject_script_args,
+     load_workspace_config,
+     read_code_from_source,
+     resolve_import_files,
+     resolve_requirements,
+ )
+
+ console = Console()
+
+ gpu_selection_app = typer.Typer(name="gpu-selection", help="GPU selection and profiling commands")
+
+ POLL_INTERVAL = 2.0
+ MAX_POLL_TIME = 3600  # 1 hour - A100/H100 initialization can take up to 30 min
+
+ # Cache for GPU pricing
+ _pricing_cache: dict[str, float] | None = None
+
+ # Mapping from API profile names to display names
+ GPU_DISPLAY_NAMES = {
+     "gpu": "T4",
+     "gpu.t4": "T4",
+     "gpu.t4.64gb": "T4",
+     "gpu.a100": "A100",
+     "gpu.a100.40gb": "A100 (40GB)",
+     "gpu.a100.80gb": "A100 (80GB)",
+     "gpu.h100": "H100",
+     "gpu.h200": "H200",
+     "gpu.l40s": "L40S",
+     "gpu.b200": "B200",
+     "gpu.rtx6000pro": "RTX 6000 Pro",
+ }
+
+ # VRAM in GB for each GPU profile (for showing excluded GPUs)
+ GPU_VRAM_GB = {
+     "gpu": 16,
+     "gpu.t4": 16,
+     "gpu.t4.64gb": 16,
+     "gpu.a100": 40,
+     "gpu.a100.40gb": 40,
+     "gpu.a100.80gb": 80,
+     "gpu.h100": 80,
+     "gpu.h200": 141,
+     "gpu.l40s": 48,
+     "gpu.b200": 180,
+     "gpu.rtx6000pro": 48,
+ }
+
+
+ def format_gpu_name(profile: str) -> str:
+     """Format GPU profile name for display."""
+     if profile in GPU_DISPLAY_NAMES:
+         return GPU_DISPLAY_NAMES[profile]
+     # Fallback: strip "gpu." prefix and uppercase
+     return profile.replace("gpu.", "").upper()
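For reference, the fallback branch means profiles missing from the table still render reasonably; a quick sketch against the mapping above (example inputs are illustrative, not from the package):

    format_gpu_name("gpu.a100.80gb")  # -> "A100 (80GB)" (exact match in GPU_DISPLAY_NAMES)
    format_gpu_name("gpu.l4")         # -> "L4" (fallback: strip "gpu.", uppercase)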
+
+
+ def fetch_gpu_pricing() -> dict[str, float]:
+     """Fetch GPU pricing from API. Returns dict of hardware_profile -> price_per_hour."""
+     global _pricing_cache
+     if _pricing_cache is not None:
+         return _pricing_cache
+
+     try:
+         response = httpx.get(
+             f"{config.base_url}/api/v2/external/compute/machine-types",
+             headers={"Authorization": f"Bearer {config.api_key}"},
+             timeout=10.0,
+         )
+         if response.status_code == 200:
+             data = response.json()
+             _pricing_cache = {
+                 m["hardware_profile"]: m.get("price_per_hour", 0) or 0
+                 for m in data.get("machine_types", [])
+             }
+             return _pricing_cache
+     except Exception:
+         pass
+     return {}
+
+
+ def calculate_cost(execution_time_s: float, hardware_profile: str, pricing: dict[str, float]) -> float | None:
+     """Calculate cost based on execution time and GPU pricing."""
+     price_per_hour = pricing.get(hardware_profile)
+     if price_per_hour is None or price_per_hour == 0:
+         return None
+     return execution_time_s * (price_per_hour / 3600)
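The cost estimate is linear in wall-clock time: a 90 s profiling run on a GPU priced at $1.20/hour (an illustrative price, not from the package) costs 90 × 1.20 / 3600 = $0.03. Zero or missing prices are treated as unknown rather than free:

    calculate_cost(90.0, "gpu.a100", {"gpu.a100": 1.20})  # -> 0.03
    calculate_cost(90.0, "gpu.t4", {})                    # -> None (profile not priced)
    calculate_cost(90.0, "gpu.t4", {"gpu.t4": 0})         # -> None (zero price skipped)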
+
+
+ def submit_gpu_selection(payload: dict, status: StatusLine | None = None) -> str:
+     """Submit a GPU selection request to the API and return the execution_id."""
+     if status:
+         status.update("Submitting GPU selection job...")
+
+     response = httpx.post(
+         f"{config.base_url}/api/v2/external/execution/gpu_selection/start",
+         headers={"Authorization": f"Bearer {config.api_key}"},
+         json=payload,
+         timeout=30.0,
+     )
+
+     if response.status_code != 200:
+         if status:
+             status.stop()
+         console.print(f"[red]Error: HTTP {response.status_code}[/red]")
+         if response.status_code == 401:
+             console.print("[red]Authentication failed. Your session may have expired.[/red]")
+             console.print("[yellow]Run 'lyceum auth login' to re-authenticate.[/yellow]")
+         elif response.status_code == 402:
+             console.print("[red]Insufficient credits. Please purchase more credits to continue.[/red]")
+         elif response.status_code == 403:
+             console.print("[red]You do not have access to GPU instances.[/red]")
+         else:
+             console.print(f"[red]{response.content.decode()}[/red]")
+         raise typer.Exit(1)
+
+     data = response.json()
+     return data["execution_id"]
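All four commands below build their request through this helper; a minimal sketch of the body they send, with field names taken from the payload construction later in this file and values purely illustrative:

    payload = {
        "code": "print('hello')",  # script source, with script args already injected
        "nbcode": 0,               # plain script rather than notebook code
        "timeout": 60,             # per-sub-job timeout in seconds
        # Optional keys, added only when present:
        # "file_name": "train.py",
        # "requirements_content": "torch==2.3.0",
        # "prior_imports": ["torch"],
        # "import_files": ...,
    }
    execution_id = submit_gpu_selection(payload)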
+
+
+ def poll_gpu_selection(execution_id: str, status: StatusLine | None = None) -> dict:
+     """Poll GPU selection status until a terminal state is reached."""
+     elapsed = 0.0
+
+     while elapsed < MAX_POLL_TIME:
+         try:
+             response = httpx.get(
+                 f"{config.base_url}/api/v2/external/execution/gpu_selection/{execution_id}/status",
+                 headers={"Authorization": f"Bearer {config.api_key}"},
+                 timeout=10.0,
+             )
+
+             if response.status_code != 200:
+                 if status:
+                     status.update(f"Waiting for results (status check returned {response.status_code})...")
+                 time.sleep(POLL_INTERVAL)
+                 elapsed += POLL_INTERVAL
+                 continue
+
+             data = response.json()
+             current_status = data.get("status", "unknown")
+
+             if current_status in ("completed", "failed", "aborted", "system_failure"):
+                 return data
+
+             if status:
+                 status.update(f"Status: {current_status}...")
+
+         except httpx.RequestError:
+             if status:
+                 status.update("Reconnecting...")
+
+         time.sleep(POLL_INTERVAL)
+         elapsed += POLL_INTERVAL
+
+     if status:
+         status.stop()
+     console.print("[yellow]Timed out waiting for GPU selection results.[/yellow]")
+     console.print(f"[dim]Check later: lyceum predict status {execution_id}[/dim]")
+     raise typer.Exit(1)
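Note that elapsed only accumulates sleep time, so the real wall-clock limit can exceed MAX_POLL_TIME by the total request latency. The underlying pattern, as a self-contained sketch with a hypothetical check callable:

    import time

    def poll_until_terminal(check, interval=2.0, max_time=3600.0) -> bool:
        """check() returns a status string; stop on any terminal state."""
        elapsed = 0.0
        while elapsed < max_time:
            if check() in ("completed", "failed", "aborted", "system_failure"):
                return True
            time.sleep(interval)
            elapsed += interval
        return False  # timed out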
+
+
+ ERROR_SUGGESTIONS = {
+     "No PyTorch or Hugging Face ecosystem detected": [
+         "Add [cyan]import torch[/cyan] and use PyTorch modules",
+         "Or use the Hugging Face [cyan]transformers[/cyan] library",
+     ],
+     "GPU requirement cannot be determined or is CPU-only": [
+         "Move model to GPU: [cyan]model.to('cuda')[/cyan]",
+         "Move tensors to GPU: [cyan]tensor.to('cuda')[/cyan]",
+         "Or use [cyan]device = torch.device('cuda')[/cyan]",
+     ],
+     "No model found": [
+         "Define a class that inherits from [cyan]nn.Module[/cyan]",
+         "Or use a pretrained model from [cyan]transformers[/cyan]",
+     ],
+     "No training loop detected": [
+         "Add a training loop with [cyan]loss.backward()[/cyan]",
+         "And [cyan]optimizer.step()[/cyan]",
+     ],
+ }
+
+
+ def get_suggestions(error: str) -> list[str]:
+     """Get suggestions for a given error message."""
+     for key, suggestions in ERROR_SUGGESTIONS.items():
+         if key.lower() in error.lower():
+             return suggestions
+     return []
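Matching is case-insensitive substring containment, so a backend message only needs to embed one of the keys above (example strings are illustrative):

    get_suggestions("Analysis failed: no model found in script")
    # -> ["Define a class that inherits from [cyan]nn.Module[/cyan]", ...]
    get_suggestions("unrelated error")  # -> []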
+
+
+ def display_results(data: dict, file_path: str | None = None) -> None:
+     """Display GPU selection results."""
+     if data is None:
+         console.print()
+         console.print(Panel(
+             "[red]✗[/red] No response data received",
+             title="[red]GPU Selection Failed[/red]",
+             border_style="red",
+             padding=(1, 2),
+         ))
+         return
+
+     status = data.get("status", "unknown")
+
+     # Parse metadata if it's a string
+     metadata = data.get("metadata")
+     if isinstance(metadata, str):
+         try:
+             metadata = json.loads(metadata)
+         except (json.JSONDecodeError, TypeError):
+             metadata = {}
+     metadata = metadata or {}
+
+     if status != "completed":
+         errors = data.get("system_errors") or []
+
+         # Build error content
+         error_lines = []
+         all_suggestions = []
+
+         for err in errors:
+             error_lines.append(f"[red]✗[/red] {err}")
+             all_suggestions.extend(get_suggestions(err))
+
+         if not error_lines:
+             error_lines.append(f"[red]✗[/red] Status: {status}")
+
+         error_content = "\n".join(error_lines)
+
+         # Add suggestions if available
+         if all_suggestions:
+             error_content += "\n\n[dim]Suggestions:[/dim]"
+             for suggestion in all_suggestions:
+                 error_content += f"\n → {suggestion}"
+
+         console.print()
+         console.print(Panel(
+             error_content,
+             title="[red]GPU Selection Failed[/red]",
+             border_style="red",
+             padding=(1, 2),
+         ))
+         return
+
+     profiling = metadata.get("profiling_results", [])
+     extraction = metadata.get("extraction_result", {})
+
+     # Get memory info for summary
+     mem_config = extraction.get("memory_config", {})
+     minimal_configs = mem_config.get("minimal_configs", [])
+
+     # Find the smallest/cheapest GPU option (lowest VRAM that works)
+     best_gpu = None
+     if minimal_configs:
+         # Sort by VRAM to find smallest viable option
+         sorted_configs = sorted(minimal_configs, key=lambda x: x.get("per_gpu_vram_gb", 999))
+         best_gpu = sorted_configs[0] if sorted_configs else None
+
+     # Summary panel
+     console.print()
+     summary_lines = ["[green]✓[/green] Analysis complete"]
+
+     if best_gpu:
+         gpu_name = format_gpu_name(best_gpu.get("gpu_type", "unknown"))
+         vram = best_gpu.get("per_gpu_vram_gb", "?")
+         count = best_gpu.get("min_gpu_count", 1)
+         gpu_str = f"{count}x " if count > 1 else ""
+         summary_lines.append("")
+         summary_lines.append(f"[bold]Recommended:[/bold] [cyan]{gpu_str}{gpu_name}[/cyan] ({vram} GB VRAM)")
+
+     # Add runtime info if available
+     if profiling:
+         completed = [p for p in profiling if p.get("status") == "completed"]
+         if completed:
+             fastest = min(completed, key=lambda x: x.get("execution_time") or 999)
+             time_s = fastest.get("execution_time")
+             if time_s is not None:
+                 summary_lines.append(f"[bold]Est. runtime:[/bold] {time_s:.2f}s")
+
+             report = fastest.get("runtime_report", {})
+             iters = report.get("train_iteration", {}).get("train_iterations_per_second")
+             if iters:
+                 summary_lines.append(f"[bold]Throughput:[/bold] {iters:.0f} iters/sec")
+
+     console.print(Panel(
+         "\n".join(summary_lines),
+         title="[green]GPU Selection Results[/green]",
+         border_style="green",
+         padding=(1, 2),
+     ))
+
+     # Profiling results table with cost
+     if profiling:
+         pricing = fetch_gpu_pricing()
+         console.print()
+         prof_table = Table(title="Profiling Results", show_header=True, header_style="bold")
+         prof_table.add_column("GPU", style="cyan")
+         prof_table.add_column("Status")
+         prof_table.add_column("Time", justify="right")
+         prof_table.add_column("Cost", justify="right")
+         prof_table.add_column("Iters/sec", justify="right")
+         prof_table.add_column("Peak VRAM", justify="right")
+
+         # Sort by execution time
+         sorted_profiling = sorted(profiling, key=lambda x: x.get("execution_time") or 999)
+
+         for result in sorted_profiling:
+             profile = result.get("profile", "?")
+             rst = result.get("status", "unknown")
+             style = "green" if rst in ("completed", "success") else "red"
+
+             report = result.get("runtime_report") or {}
+             train_iter = report.get("train_iteration") or {}
+
+             time_s = result.get("execution_time")
+             # Calculate cost from pricing
+             cost = result.get("cost")
+             if cost is None and time_s is not None:
+                 cost = calculate_cost(time_s, profile, pricing)
+
+             iters = train_iter.get("train_iterations_per_second")
+             vram = report.get("Peak VRAM Allocated (MB)")
+
+             prof_table.add_row(
+                 format_gpu_name(profile),
+                 f"[{style}]{rst}[/{style}]",
+                 f"{time_s:.2f}s" if time_s is not None else "-",
+                 f"${cost:.6f}" if cost is not None else "-",
+                 f"{iters:.0f}" if iters else "-",
+                 f"{vram:.1f} MB" if vram else "-",
+             )
+
+         console.print(prof_table)
+
+     # Compatible GPU configurations table
+     if minimal_configs:
+         console.print()
+         gpu_table = Table(title="Compatible GPUs", show_header=True, header_style="bold")
+         gpu_table.add_column("GPU", style="cyan")
+         gpu_table.add_column("VRAM", justify="right")
+         gpu_table.add_column("GPUs Needed", justify="right")
+         gpu_table.add_column("Utilization", justify="right")
+
+         # Sort by VRAM size for better display
+         sorted_configs = sorted(minimal_configs, key=lambda x: x.get("per_gpu_vram_gb", 0))
+         compatible_gpu_types = {cfg.get("gpu_type") for cfg in minimal_configs}
+
+         for i, cfg in enumerate(sorted_configs):
+             gpu_type = format_gpu_name(cfg.get("gpu_type", "?"))
+             vram = cfg.get("per_gpu_vram_gb", 0)
+             count = cfg.get("min_gpu_count", 1)
+             util = cfg.get("vram_utilization_percent", 0)
+
+             # Highlight the recommended (smallest) option
+             if i == 0:
+                 gpu_type = f"[green]{gpu_type}[/green] ✓"
+
+             gpu_table.add_row(
+                 gpu_type,
+                 f"{vram} GB",
+                 str(count),
+                 f"{util}%",
+             )
+
+         console.print(gpu_table)
+
+         # Show excluded GPUs (those not in minimal_configs)
+         # Get total memory required from extraction result
+         mem_reqs = mem_config.get("memory_requirements", {})
+         total_mem_gb = 0
+         if mem_reqs:
+             # Sum up memory components
+             total_mem_gb = (
+                 mem_reqs.get("model_weights", 0) +
+                 mem_reqs.get("gradients", 0) +
+                 mem_reqs.get("optimizer_states", 0) +
+                 mem_reqs.get("activations", 0)
+             )
+
+         if not total_mem_gb and sorted_configs:
+             # Estimate from minimal config utilization
+             first_cfg = sorted_configs[0]
+             vram = first_cfg.get("per_gpu_vram_gb", 0)
+             util = first_cfg.get("vram_utilization_percent", 0)
+             if vram and util:
+                 total_mem_gb = vram * (util / 100)
+
+         # Find GPUs that were excluded
+         excluded_gpus = []
+         for profile, vram in GPU_VRAM_GB.items():
+             if profile not in compatible_gpu_types and profile in ("gpu", "gpu.t4"):
+                 # Only show common GPUs that users expect to see
+                 excluded_gpus.append((profile, vram))
+
+         if excluded_gpus and total_mem_gb:
+             console.print()
+             console.print("[dim]Excluded GPUs (insufficient VRAM):[/dim]")
+             for profile, vram in sorted(excluded_gpus, key=lambda x: x[1]):
+                 gpu_name = format_gpu_name(profile)
+                 console.print(f"[dim] • {gpu_name}: {vram} GB available, ~{total_mem_gb:.1f} GB required[/dim]")
+
+     # Show run command hint if we have a best GPU and file path
+     if best_gpu and file_path:
+         machine_flag = best_gpu.get("gpu_type", "gpu")
+         console.print()
+         console.print(f"[dim]To run on optimal machine: lyceum python run {file_path} -m {machine_flag}[/dim]")
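Pieced together from the lookups above, the metadata this renderer consumes looks roughly like the following; the key names come from the code, the values are illustrative:

    metadata = {
        "profiling_results": [
            {"profile": "gpu.t4", "status": "completed", "execution_time": 42.1,
             "runtime_report": {"train_iteration": {"train_iterations_per_second": 88.0},
                                "Peak VRAM Allocated (MB)": 1523.4}},
        ],
        "extraction_result": {
            "memory_config": {
                "memory_requirements": {"model_weights": 0.5, "gradients": 0.5,
                                        "optimizer_states": 1.0, "activations": 0.3},
                "minimal_configs": [
                    {"gpu_type": "gpu.t4", "per_gpu_vram_gb": 16,
                     "min_gpu_count": 1, "vram_utilization_percent": 14},
                ],
            },
        },
    }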
+
+
+ @gpu_selection_app.command("run", context_settings={"allow_extra_args": True, "allow_interspersed_args": True})
+ def run_gpu_selection(
+     ctx: typer.Context,
+     code_or_file: str = typer.Argument(..., help="Python code to execute or path to Python file"),
+     file_name: str | None = typer.Option(None, "--file-name", "-f", help="Name for the execution"),
+     timeout: int = typer.Option(60, "--timeout", "-t", help="Timeout per sub-job in seconds (1-600)"),
+     requirements: str | None = typer.Option(
+         None, "--requirements", "-r", help="Requirements file path or pip requirements string"
+     ),
+     imports: list[str] | None = typer.Option(
+         None, "--import", help="Pre-import modules (can be used multiple times)"
+     ),
+     use_config: bool = typer.Option(
+         True, "--use-config/--no-config",
+         help="Use workspace config from .lyceum/config.json if available"
+     ),
+     debug: bool = typer.Option(
+         False, "--debug", "-d",
+         help="Show detailed debug information about config, requirements, and payload"
+     ),
+ ):
+     """Run code on multiple GPUs and select the optimal hardware.
+
+     Submits the code to run on all GPU profiles available to your account,
+     then reports which GPU performed best.
+
+     Script arguments can be passed after the file path:
+
+         lyceum predict run train.py -- --epochs 10 --lr 0.001
+     """
+     status = StatusLine()
+
+     try:
+         config.get_client()
+         status.start()
+
+         script_args = [arg for arg in (ctx.args or []) if arg != "--"]
+
+         code, file_path, detected_file_name = read_code_from_source(code_or_file, status)
+         if not file_name:
+             file_name = detected_file_name
+
+         code = inject_script_args(code, script_args, file_name)
+
+         workspace_config = None
+         if use_config:
+             status.update("Loading workspace config...")
+             workspace_config = load_workspace_config(file_path)
+             if workspace_config and debug:
+                 status.stop()
+                 console.print(f"[cyan]DEBUG: Config keys: {list(workspace_config.keys())}[/cyan]")
+                 status.start()
+
+         requirements_content = resolve_requirements(requirements, workspace_config, debug, status)
+         import_files = resolve_import_files(file_path, workspace_config, debug, status)
+
+         # Build payload matching GPUSelectionRequest schema
+         payload = {
+             "code": code,
+             "nbcode": 0,
+             "timeout": timeout,
+         }
+         if file_name:
+             payload["file_name"] = file_name
+         if requirements_content:
+             payload["requirements_content"] = requirements_content
+         if imports:
+             payload["prior_imports"] = imports
+         if import_files:
+             payload["import_files"] = import_files
+
+         if debug:
+             status.stop()
+             console.print("[cyan]DEBUG: Payload summary:[/cyan]")
+             console.print(f"[cyan] - timeout: {timeout}[/cyan]")
+             console.print(f"[cyan] - code length: {len(code)} chars[/cyan]")
+             console.print(f"[cyan] - requirements_content: {len(requirements_content or '')} chars[/cyan]")
+             console.print(f"[cyan] - import_files: {len(import_files or '')} chars[/cyan]")
+             status.start()
+
+         execution_id = submit_gpu_selection(payload, status)
+         console.print(f"[dim]Execution ID: {execution_id}[/dim]")
+
+         status.update("Waiting for GPU selection results...")
+         data = poll_gpu_selection(execution_id, status)
+         status.stop()
+
+         display_results(data, file_path=code_or_file)
+
+         if data.get("status") != "completed":
+             raise typer.Exit(1)
+
+     except typer.Exit:
+         status.stop()
+         raise
+     except Exception as e:
+         status.stop()
+         console.print(f"[red]Error: {e}[/red]")
+         raise typer.Exit(1)
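The allow_extra_args context setting is what lets script arguments ride along after the file path; unparsed tokens land in ctx.args and the command strips a literal "--" separator itself. In miniature (token values illustrative):

    # lyceum predict run train.py -- --epochs 10 --lr 0.001
    ctx_args = ["--", "--epochs", "10", "--lr", "0.001"]  # roughly what arrives in ctx.args
    script_args = [a for a in ctx_args if a != "--"]      # ["--epochs", "10", "--lr", "0.001"]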
+
+
+ @gpu_selection_app.command("status")
+ def predict_status(
+     execution_id: str = typer.Argument(..., help="Execution ID to check"),
+ ):
+     """Check the status of a GPU selection execution."""
+     try:
+         config.get_client()
+
+         response = httpx.get(
+             f"{config.base_url}/api/v2/external/execution/gpu_selection/{execution_id}/status",
+             headers={"Authorization": f"Bearer {config.api_key}"},
+             timeout=10.0,
+         )
+
+         if response.status_code == 404:
+             console.print("[red]Execution not found.[/red]")
+             raise typer.Exit(1)
+
+         if response.status_code != 200:
+             console.print(f"[red]Error: HTTP {response.status_code}[/red]")
+             console.print(f"[red]{response.content.decode()}[/red]")
+             raise typer.Exit(1)
+
+         data = response.json()
+
+         # Parse metadata if it's a string
+         if isinstance(data.get("metadata"), str):
+             try:
+                 data["metadata"] = json.loads(data["metadata"])
+             except (json.JSONDecodeError, TypeError):
+                 pass
+
+         current_status = data.get("status", "unknown")
+         console.print(f"Status: [bold]{current_status}[/bold]")
+
+         if current_status in ("completed", "failed", "aborted", "system_failure"):
+             display_results(data)
+         else:
+             console.print("[dim]Job is still running. Check again later.[/dim]")
+
+     except typer.Exit:
+         raise
+     except Exception as e:
+         console.print(f"[red]Error: {e}[/red]")
+         raise typer.Exit(1)
+
+
+ def display_memory_results(data: dict, file_path: str | None = None) -> None:
+     """Display memory analysis results."""
+     if data is None:
+         console.print("[red]No data received[/red]")
+         return
+
+     metadata = data.get("metadata")
+     if isinstance(metadata, str):
+         try:
+             metadata = json.loads(metadata)
+         except (json.JSONDecodeError, TypeError):
+             metadata = {}
+     metadata = metadata or {}
+
+     extraction = metadata.get("extraction_result", {})
+     mem_config = extraction.get("memory_config", {})
+     mem_reqs = mem_config.get("memory_requirements", {})
+     minimal_configs = mem_config.get("minimal_configs", [])
+
+     if not mem_reqs and not minimal_configs:
+         console.print("[yellow]No memory analysis data available.[/yellow]")
+         return
+
+     # Memory requirements breakdown
+     if mem_reqs:
+         console.print()
+         mem_table = Table(title="Memory Requirements", show_header=True, header_style="bold")
+         mem_table.add_column("Component", style="cyan")
+         mem_table.add_column("Size", justify="right")
+
+         def format_gb(val: float) -> str:
+             if val < 0.001:
+                 return f"{val * 1024:.2f} MB"
+             return f"{val:.3f} GB"
+
+         components = [
+             ("Model Weights", mem_reqs.get("model_weights", 0)),
+             ("Gradients", mem_reqs.get("gradients", 0)),
+             ("Optimizer States", mem_reqs.get("optimizer_states", 0)),
+             ("Activations", mem_reqs.get("activations", 0)),
+             ("Largest Layer", mem_reqs.get("largest_layer", 0)),
+         ]
+
+         total = sum(v for _, v in components if v)
+         for name, val in components:
+             if val:
+                 mem_table.add_row(name, format_gb(val))
+
+         mem_table.add_row("─" * 20, "─" * 10)
+         mem_table.add_row("[bold]Total[/bold]", f"[bold]{format_gb(total)}[/bold]")
+
+         param_count = mem_reqs.get("parameter_count", 0)
+         if param_count:
+             mem_table.add_row("", "")
+             mem_table.add_row("Parameter Count", f"{param_count:.2e}")
+
+         console.print(mem_table)
+
+     # Compatible GPUs
+     if minimal_configs:
+         console.print()
+         gpu_table = Table(title="Compatible GPUs", show_header=True, header_style="bold")
+         gpu_table.add_column("GPU", style="cyan")
+         gpu_table.add_column("VRAM", justify="right")
+         gpu_table.add_column("GPUs Needed", justify="right")
+         gpu_table.add_column("Utilization", justify="right")
+
+         sorted_configs = sorted(minimal_configs, key=lambda x: x.get("per_gpu_vram_gb", 0))
+
+         for i, cfg in enumerate(sorted_configs):
+             gpu_type = format_gpu_name(cfg.get("gpu_type", "?"))
+             vram = cfg.get("per_gpu_vram_gb", 0)
+             count = cfg.get("min_gpu_count", 1)
+             util = cfg.get("vram_utilization_percent", 0)
+
+             if i == 0:
+                 gpu_type = f"[green]{gpu_type}[/green] ✓"
+
+             gpu_table.add_row(gpu_type, f"{vram} GB", str(count), f"{util}%")
+
+         console.print(gpu_table)
+
+         # Show run command hint
+         if file_path and sorted_configs:
+             best = sorted_configs[0]
+             machine_flag = best.get("gpu_type", "gpu")
+             console.print()
+             console.print(f"[dim]To run on optimal machine: lyceum python run {file_path} -m {machine_flag}[/dim]")
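format_gb switches to megabytes below one thousandth of a gigabyte, so sub-MB components stay readable:

    format_gb(0.0005)  # -> "0.51 MB" (0.0005 * 1024)
    format_gb(1.5)     # -> "1.500 GB"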
+
+
+ @gpu_selection_app.command("memory", context_settings={"allow_extra_args": True, "allow_interspersed_args": True})
+ def predict_memory(
+     ctx: typer.Context,
+     code_or_file: str = typer.Argument(..., help="Python code or path to Python file"),
+     file_name: str | None = typer.Option(None, "--file-name", "-f", help="Name for the execution"),
+     requirements: str | None = typer.Option(
+         None, "--requirements", "-r", help="Requirements file path or pip requirements string"
+     ),
+     imports: list[str] | None = typer.Option(
+         None, "--import", help="Pre-import modules (can be used multiple times)"
+     ),
+     mixed_precision: str | None = typer.Option(
+         None, "--mixed-precision", "-mp",
+         help="Mixed precision dtype (fp16, bf16)"
+     ),
+     strategy: str | None = typer.Option(
+         None, "--strategy", "-s",
+         help="Parallelization strategy (ddp, fsdp, zero1, zero2, zero3)"
+     ),
+     use_config: bool = typer.Option(
+         True, "--use-config/--no-config",
+         help="Use workspace config from .lyceum/config.json if available"
+     ),
+ ):
+     """Estimate memory requirements for your training script.
+
+     Analyzes model architecture to predict VRAM usage without running
+     full GPU profiling. Faster than 'predict run'.
+
+     Examples:
+         lyceum predict memory train.py
+         lyceum predict memory train.py --mixed-precision fp16
+         lyceum predict memory train.py --strategy fsdp
+     """
+     status = StatusLine()
+
+     try:
+         config.get_client()
+         status.start()
+
+         script_args = [arg for arg in (ctx.args or []) if arg != "--"]
+
+         code, file_path, detected_file_name = read_code_from_source(code_or_file, status)
+         if not file_name:
+             file_name = detected_file_name
+
+         code = inject_script_args(code, script_args, file_name)
+
+         workspace_config = None
+         if use_config:
+             status.update("Loading workspace config...")
+             workspace_config = load_workspace_config(file_path)
+
+         requirements_content = resolve_requirements(requirements, workspace_config, False, status)
+         import_files = resolve_import_files(file_path, workspace_config, False, status)
+
+         # Build payload - same as gpu_selection but we'll only show memory results
+         payload = {
+             "code": code,
+             "nbcode": 0,
+             "timeout": 60,  # Memory analysis is quick
+         }
+         if file_name:
+             payload["file_name"] = file_name
+         if requirements_content:
+             payload["requirements_content"] = requirements_content
+         if imports:
+             payload["prior_imports"] = imports
+         if import_files:
+             payload["import_files"] = import_files
+
+         # TODO: When backend supports it, add mixed_precision and strategy to payload
+         if mixed_precision:
+             console.print(f"[dim]Note: --mixed-precision {mixed_precision} (backend support coming soon)[/dim]")
+         if strategy:
+             console.print(f"[dim]Note: --strategy {strategy} (backend support coming soon)[/dim]")
+
+         execution_id = submit_gpu_selection(payload, status)
+         console.print(f"[dim]Execution ID: {execution_id}[/dim]")
+
+         status.update("Analyzing memory requirements...")
+         data = poll_gpu_selection(execution_id, status)
+         status.stop()
+
+         if data.get("status") != "completed":
+             display_results(data)  # Show error with suggestions
+             raise typer.Exit(1)
+
+         display_memory_results(data, file_path=code_or_file)
+
+     except typer.Exit:
+         status.stop()
+         raise
+     except Exception as e:
+         status.stop()
+         console.print(f"[red]Error: {e}[/red]")
+         raise typer.Exit(1)
+
+
+ @gpu_selection_app.command("recommend-gpus", context_settings={"allow_extra_args": True, "allow_interspersed_args": True})
+ def recommend_gpus(
+     ctx: typer.Context,
+     code_or_file: str = typer.Argument(..., help="Python code or path to Python file"),
+     file_name: str | None = typer.Option(None, "--file-name", "-f", help="Name for the execution"),
+     requirements: str | None = typer.Option(
+         None, "--requirements", "-r", help="Requirements file path or pip requirements string"
+     ),
+     use_config: bool = typer.Option(
+         True, "--use-config/--no-config",
+         help="Use workspace config from .lyceum/config.json if available"
+     ),
+     top: int = typer.Option(3, "--top", "-n", help="Number of recommendations to show"),
+ ):
+     """Quick GPU recommendations based on memory analysis.
+
+     Analyzes your model and recommends the best GPU configurations,
+     sorted by cost-efficiency.
+
+     Examples:
+         lyceum predict recommend-gpus train.py
+         lyceum predict recommend-gpus train.py --top 5
+     """
+     status = StatusLine()
+
+     try:
+         config.get_client()
+         status.start()
+
+         script_args = [arg for arg in (ctx.args or []) if arg != "--"]
+
+         code, file_path, detected_file_name = read_code_from_source(code_or_file, status)
+         if not file_name:
+             file_name = detected_file_name
+
+         code = inject_script_args(code, script_args, file_name)
+
+         workspace_config = None
+         if use_config:
+             status.update("Loading workspace config...")
+             workspace_config = load_workspace_config(file_path)
+
+         requirements_content = resolve_requirements(requirements, workspace_config, False, status)
+         import_files = resolve_import_files(file_path, workspace_config, False, status)
+
+         payload = {
+             "code": code,
+             "nbcode": 0,
+             "timeout": 60,
+         }
+         if file_name:
+             payload["file_name"] = file_name
+         if requirements_content:
+             payload["requirements_content"] = requirements_content
+         if import_files:
+             payload["import_files"] = import_files
+
+         execution_id = submit_gpu_selection(payload, status)
+         console.print(f"[dim]Execution ID: {execution_id}[/dim]")
+
+         status.update("Analyzing model and generating recommendations...")
+         data = poll_gpu_selection(execution_id, status)
+         status.stop()
+
+         if data.get("status") != "completed":
+             display_results(data)
+             raise typer.Exit(1)
+
+         # Get pricing for cost estimates
+         pricing = fetch_gpu_pricing()
+
+         metadata = data.get("metadata")
+         if isinstance(metadata, str):
+             try:
+                 metadata = json.loads(metadata)
+             except (json.JSONDecodeError, TypeError):
+                 metadata = {}
+         metadata = metadata or {}
+
+         extraction = metadata.get("extraction_result", {})
+         mem_config = extraction.get("memory_config", {})
+         minimal_configs = mem_config.get("minimal_configs", [])
+
+         if not minimal_configs:
+             console.print("[yellow]No GPU recommendations available.[/yellow]")
+             raise typer.Exit(1)
+
+         # Sort by VRAM (smaller = likely cheaper)
+         sorted_configs = sorted(minimal_configs, key=lambda x: x.get("per_gpu_vram_gb", 0))[:top]
+
+         console.print()
+         console.print(Panel(
+             f"[green]✓[/green] Found {len(minimal_configs)} compatible GPU configurations",
+             title="[green]GPU Recommendations[/green]",
+             border_style="green",
+         ))
+
+         console.print()
+         rec_table = Table(title=f"Top {min(top, len(sorted_configs))} Recommendations", show_header=True, header_style="bold")
+         rec_table.add_column("#", style="dim", width=3)
+         rec_table.add_column("GPU", style="cyan")
+         rec_table.add_column("VRAM", justify="right")
+         rec_table.add_column("GPUs", justify="right")
+         rec_table.add_column("$/hour", justify="right")
+         rec_table.add_column("Utilization", justify="right")
+
+         for i, cfg in enumerate(sorted_configs):
+             gpu_type = cfg.get("gpu_type", "?")
+             gpu_display = format_gpu_name(gpu_type)
+             vram = cfg.get("per_gpu_vram_gb", 0)
+             count = cfg.get("min_gpu_count", 1)
+             util = cfg.get("vram_utilization_percent", 0)
+
+             price = pricing.get(gpu_type, 0)
+             price_str = f"${price:.2f}" if price else "-"
+
+             rank = f"[green]{i + 1}[/green]" if i == 0 else str(i + 1)
+
+             rec_table.add_row(
+                 rank,
+                 gpu_display,
+                 f"{vram} GB",
+                 str(count),
+                 price_str,
+                 f"{util}%",
+             )
+
+         console.print(rec_table)
+
+         # Show command hint
+         best = sorted_configs[0]
+         machine_flag = best.get("gpu_type", "gpu")
+         console.print()
+         console.print(f"[dim]Run with: lyceum python run {code_or_file} -m {machine_flag}[/dim]")
+
+     except typer.Exit:
+         status.stop()
+         raise
+     except Exception as e:
+         status.stop()
+         console.print(f"[red]Error: {e}[/red]")
+         raise typer.Exit(1)
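The $/hour column is a join of minimal_configs against the machine-type pricing on the profile key; the same logic in miniature, with illustrative prices:

    pricing = {"gpu.t4": 0.35, "gpu.a100.80gb": 2.10}
    cfgs = [
        {"gpu_type": "gpu.a100.80gb", "per_gpu_vram_gb": 80},
        {"gpu_type": "gpu.t4", "per_gpu_vram_gb": 16},
    ]
    top = sorted(cfgs, key=lambda c: c.get("per_gpu_vram_gb", 0))[:3]
    rows = [(c["gpu_type"], pricing.get(c["gpu_type"], 0)) for c in top]
    # -> [("gpu.t4", 0.35), ("gpu.a100.80gb", 2.10)]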
+
+
+ def display_runtime_results(data: dict, file_path: str | None = None) -> None:
+     """Display runtime profiling results."""
+     if data is None:
+         console.print("[red]No data received[/red]")
+         return
+
+     metadata = data.get("metadata")
+     if isinstance(metadata, str):
+         try:
+             metadata = json.loads(metadata)
+         except (json.JSONDecodeError, TypeError):
+             metadata = {}
+     metadata = metadata or {}
+
+     profiling = metadata.get("profiling_results", [])
+
+     if not profiling:
+         console.print("[yellow]No runtime profiling data available.[/yellow]")
+         return
+
+     pricing = fetch_gpu_pricing()
+
+     # Find best performers
+     completed = [p for p in profiling if p.get("status") in ("completed", "success")]
+
+     if completed:
+         fastest = min(completed, key=lambda x: x.get("execution_time") or 999)
+         fastest_profile = format_gpu_name(fastest.get("profile", "?"))
+         fastest_time = fastest.get("execution_time", 0)
+
+         console.print()
+         console.print(Panel(
+             f"[green]✓[/green] Profiling complete\n\n"
+             f"[bold]Fastest:[/bold] [cyan]{fastest_profile}[/cyan] ({fastest_time:.2f}s)",
+             title="[green]Runtime Analysis[/green]",
+             border_style="green",
+         ))
+
+     # Detailed results table
+     console.print()
+     prof_table = Table(title="Runtime Results by GPU", show_header=True, header_style="bold")
+     prof_table.add_column("GPU", style="cyan")
+     prof_table.add_column("Status")
+     prof_table.add_column("Time", justify="right")
+     prof_table.add_column("Cost", justify="right")
+     prof_table.add_column("Iters/sec", justify="right")
+     prof_table.add_column("Avg Batch (ms)", justify="right")
+     prof_table.add_column("Peak VRAM", justify="right")
+
+     sorted_profiling = sorted(profiling, key=lambda x: x.get("execution_time") or 999)
+
+     for result in sorted_profiling:
+         profile = result.get("profile", "?")
+         rst = result.get("status", "unknown")
+         style = "green" if rst in ("completed", "success") else "red"
+
+         report = result.get("runtime_report") or {}
+         train = report.get("training") or {}
+         train_iter = report.get("train_iteration") or {}
+
+         time_s = result.get("execution_time")
+         cost = result.get("cost")
+         if cost is None and time_s is not None:
+             cost = calculate_cost(time_s, profile, pricing)
+
+         iters = train_iter.get("train_iterations_per_second")
+         avg_batch = train.get("avg_train_time_ms")
+         vram = report.get("Peak VRAM Allocated (MB)")
+
+         prof_table.add_row(
+             format_gpu_name(profile),
+             f"[{style}]{rst}[/{style}]",
+             f"{time_s:.2f}s" if time_s is not None else "-",
+             f"${cost:.6f}" if cost is not None else "-",
+             f"{iters:.0f}" if iters else "-",
+             f"{avg_batch:.2f}" if avg_batch else "-",
+             f"{vram:.1f} MB" if vram else "-",
+         )
+
+     console.print(prof_table)
+
+     # Show run command hint based on fastest GPU
+     if file_path and completed:
+         fastest = min(completed, key=lambda x: x.get("execution_time") or 999)
+         machine_flag = fastest.get("profile", "gpu")
+         console.print()
+         console.print(f"[dim]To run on fastest machine: lyceum python run {file_path} -m {machine_flag}[/dim]")
+
+
+ @gpu_selection_app.command("runtime", context_settings={"allow_extra_args": True, "allow_interspersed_args": True})
+ def predict_runtime(
+     ctx: typer.Context,
+     code_or_file: str = typer.Argument(..., help="Python code or path to Python file"),
+     file_name: str | None = typer.Option(None, "--file-name", "-f", help="Name for the execution"),
+     timeout: int = typer.Option(120, "--timeout", "-t", help="Timeout per GPU in seconds (1-600)"),
+     requirements: str | None = typer.Option(
+         None, "--requirements", "-r", help="Requirements file path or pip requirements string"
+     ),
+     imports: list[str] | None = typer.Option(
+         None, "--import", help="Pre-import modules (can be used multiple times)"
+     ),
+     use_config: bool = typer.Option(
+         True, "--use-config/--no-config",
+         help="Use workspace config from .lyceum/config.json if available"
+     ),
+ ):
+     """Profile runtime performance across different GPUs.
+
+     Runs your training script on available GPUs and measures actual
+     execution time, throughput, and VRAM usage.
+
+     Examples:
+         lyceum predict runtime train.py
+         lyceum predict runtime train.py --timeout 180
+     """
+     status = StatusLine()
+
+     try:
+         config.get_client()
+         status.start()
+
+         script_args = [arg for arg in (ctx.args or []) if arg != "--"]
+
+         code, file_path, detected_file_name = read_code_from_source(code_or_file, status)
+         if not file_name:
+             file_name = detected_file_name
+
+         code = inject_script_args(code, script_args, file_name)
+
+         workspace_config = None
+         if use_config:
+             status.update("Loading workspace config...")
+             workspace_config = load_workspace_config(file_path)
+
+         requirements_content = resolve_requirements(requirements, workspace_config, False, status)
+         import_files = resolve_import_files(file_path, workspace_config, False, status)
+
+         payload = {
+             "code": code,
+             "nbcode": 0,
+             "timeout": timeout,
+         }
+         if file_name:
+             payload["file_name"] = file_name
+         if requirements_content:
+             payload["requirements_content"] = requirements_content
+         if imports:
+             payload["prior_imports"] = imports
+         if import_files:
+             payload["import_files"] = import_files
+
+         execution_id = submit_gpu_selection(payload, status)
+         console.print(f"[dim]Execution ID: {execution_id}[/dim]")
+
+         status.update("Profiling runtime across GPUs...")
+         data = poll_gpu_selection(execution_id, status)
+         status.stop()
+
+         if data.get("status") != "completed":
+             display_results(data)
+             raise typer.Exit(1)
+
+         display_runtime_results(data, file_path=code_or_file)
+
+     except typer.Exit:
+         status.stop()
+         raise
+     except Exception as e:
+         status.stop()
+         console.print(f"[red]Error: {e}[/red]")
+         raise typer.Exit(1)