wafer-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py ADDED
@@ -0,0 +1,1536 @@
1
+ """Wafer CLI - Run commands on remote GPUs in Docker containers."""
2
+
3
+ import json
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import trio
8
+ import typer
9
+
10
+ from .config import WaferConfig, WaferEnvironment
11
+ from .inference import infer_upload_files, resolve_environment
12
+
13
# Subcommand group for managing local GPU targets (stored as TOML files).
targets_app = typer.Typer(help="Manage local GPU targets (TOML files)")

# Root CLI application; the targets group is mounted under the "targets" name.
app = typer.Typer(help="Run commands on remote GPUs in Docker containers")
app.add_typer(targets_app, name="targets")
18
+
19
+
20
@app.command()
def run(
    command: str = typer.Argument(..., help="Command to run in Docker container"),
    env: str | None = typer.Option(None, "--env", "-e", help="Environment name from config"),
    upload: list[str] | None = typer.Option(  # noqa: B008
        None, "--upload", "-u", help="Files to upload (default: auto-infer)"
    ),
    target: str | None = typer.Option(None, "--target", "-t", help="Override target from config"),
    follow: bool = typer.Option(True, "--follow/--no-follow", help="Stream output in real-time"),
    detach: bool = typer.Option(False, "--detach", "-d", help="Run in background, return job ID"),
) -> None:
    """Run command on remote GPU in Docker container.

    Examples:
        # Run with auto-inferred files and default environment
        wafer run "make && ./kernel_test"

        # Specify environment
        wafer run "python train.py" --env pytorch

        # Override target
        wafer run "nvcc kernel.cu -o kernel && ./kernel" --target root@other-node:22

        # Upload specific files
        wafer run "make" --upload kernel.cu --upload Makefile

        # Run in background
        wafer run "python train.py --epochs 100" --detach
    """
    # Load config from the fixed per-user location.
    config_path = Path.home() / ".wafer" / "config.toml"
    if not config_path.exists():
        typer.echo(f"Error: Config not found: {config_path}", err=True)
        typer.echo(
            "Create ~/.wafer/config.toml with your settings. See documentation for format.",
            err=True,
        )
        raise typer.Exit(1)

    try:
        config = WaferConfig.from_toml(config_path)
    except (AssertionError, ValueError, KeyError) as e:
        typer.echo(f"Error: Invalid config: {e}", err=True)
        raise typer.Exit(1) from None

    # Resolve environment (explicit --env or whatever resolve_environment picks).
    try:
        environment = resolve_environment(config, env)
    except (ValueError, AssertionError) as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    # Determine files to upload: explicit --upload paths are taken relative to
    # the current directory and validated; otherwise they are inferred.
    cwd = Path.cwd()
    if upload:
        files_to_upload = [cwd / f for f in upload]
        for f in files_to_upload:
            if not f.exists():
                typer.echo(f"Error: File not found: {f}", err=True)
                raise typer.Exit(1)
            if not f.is_file():
                typer.echo(f"Error: Not a file: {f}", err=True)
                raise typer.Exit(1)
    else:
        try:
            files_to_upload = infer_upload_files(command, cwd)
        except (AssertionError, ValueError) as e:
            typer.echo(f"Error: Failed to infer files: {e}", err=True)
            raise typer.Exit(1) from None

    # Use target override if provided
    effective_target = target or config.target

    # Run async implementation
    try:
        trio.run(
            _run_async,
            effective_target,
            config.ssh_key,
            environment,
            command,
            files_to_upload,
            follow,
            detach,
        )
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except typer.Exit:
        # typer.Exit derives from RuntimeError, so without this clause the
        # generic handler below would swallow the exit raised by _run_sync,
        # print an empty "Error: " line and replace the remote command's exit
        # code with 1. Let it propagate untouched.
        raise
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
112
+
113
+
114
async def _run_async(
    target: str,
    ssh_key: str,
    environment: WaferEnvironment,
    command: str,
    files_to_upload: list[Path],
    follow: bool,
    detach: bool,
) -> None:
    """Run the blocking ``_run_sync`` implementation in a worker thread.

    All arguments are forwarded unchanged to ``_run_sync``.

    Args:
        target: SSH target string (user@host:port)
        ssh_key: Path to SSH key
        environment: Environment configuration with Docker image
        command: Command to execute
        files_to_upload: List of files to upload
        follow: Whether to stream output
        detach: Whether to run in background

    Raises:
        Exception: Whatever ``_run_sync`` raises is re-raised here.
    """
    # trio.to_thread.run_sync forwards positional arguments to the callable,
    # so no closure is needed. ``trio`` is the module-level import.
    await trio.to_thread.run_sync(
        _run_sync,
        target,
        ssh_key,
        environment,
        command,
        files_to_upload,
        follow,
        detach,
    )
142
+
143
+
144
def _run_sync(
    target: str,
    ssh_key: str,
    environment: WaferEnvironment,
    command: str,
    files_to_upload: list[Path],
    follow: bool,
    detach: bool,
) -> None:
    """Sync implementation of run command using internal SSHClient.

    Creates the remote workspace, uploads the given files into it, checks
    that Docker is available and then runs ``command`` inside the
    environment's Docker image with the workspace mounted at ``/workspace``.

    Args:
        target: SSH target string (user@host:port)
        ssh_key: Path to SSH key
        environment: Environment configuration with Docker image
        command: Command to execute
        files_to_upload: List of files to upload
        follow: Whether to stream output
        detach: Whether to run in background

    Raises:
        RuntimeError: If ``which docker`` fails on the remote machine.
        typer.Exit: With the remote exit code when the non-streaming run
            exits non-zero.
        Exception: On any other execution failure
    """

    from wafer_core.ssh import SSHClient

    # The remote workspace is named after the local working directory.
    workspace_name = Path.cwd().name
    remote_workspace = f"~/.wafer/workspaces/{workspace_name}"

    client = SSHClient(target, ssh_key)

    # Ensure workspace directory exists
    print(f"Setting up workspace: {remote_workspace}")
    client.exec(f"mkdir -p {remote_workspace}")

    # Upload files. Only the base name is kept, so every file lands directly
    # in the workspace root.
    if files_to_upload:
        print(f"Uploading {len(files_to_upload)} files...")
        for f in files_to_upload:
            remote_path = f"{remote_workspace}/{f.name}"
            client.upload_files(str(f), remote_path)
            # Restored check mark (the previous literal was mis-decoded UTF-8).
            print(f"  ✓ {f.name}")
    else:
        print("No files to upload (use --upload to specify)")

    # Expand workspace path for volume mount
    expanded_workspace = client.expand_path(remote_workspace)

    print(f"\nEnvironment: {environment.docker}")
    if environment.description:
        print(f"Description: {environment.description}")
    print(f"Command: {command}")
    print("-" * 60)

    # Check if Docker is available
    docker_check = client.exec("which docker")
    if docker_check.exit_code != 0:
        raise RuntimeError(
            "Docker not found on remote machine. Please install Docker with GPU support."
        )

    # Build docker command
    docker_cmd = _build_docker_run_cmd(
        image=environment.docker,
        inner_cmd=command,
        volumes={expanded_workspace: "/workspace"},
        working_dir="/workspace",
    )

    # Execute
    if follow and not detach:
        # Stream output in real-time
        # NOTE(review): no exit status is inspected on this path — confirm
        # whether exec_stream raises when the remote command fails.
        try:
            for line in client.exec_stream(docker_cmd):
                print(line)
        except Exception as e:
            print(f"\nExecution failed: {e}", file=sys.stderr)
            raise
    else:
        # Non-streaming execution
        # NOTE(review): --detach also lands here; nothing in this function
        # returns or prints a job ID.
        result = client.exec(docker_cmd)

        # Print output
        if result.stdout:
            print(result.stdout)
        if result.stderr:
            print(result.stderr, file=sys.stderr)

        # Check exit code
        if result.exit_code != 0:
            print(
                f"\nCommand exited with code {result.exit_code}",
                file=sys.stderr,
            )
            raise typer.Exit(result.exit_code)
239
+
240
+
241
+ def _build_docker_run_cmd(
242
+ image: str,
243
+ inner_cmd: str,
244
+ volumes: dict[str, str],
245
+ working_dir: str,
246
+ gpu_id: int = 0,
247
+ ) -> str:
248
+ """Build docker run command string."""
249
+ import shlex
250
+
251
+ parts = ["docker", "run", "--rm"]
252
+ parts.extend(["--gpus", f"'device={gpu_id}'"])
253
+
254
+ for host_path, container_path in volumes.items():
255
+ parts.extend(["-v", f"{host_path}:{container_path}"])
256
+
257
+ parts.extend(["-w", working_dir])
258
+ parts.append(image)
259
+ parts.append(f"bash -c {shlex.quote(inner_cmd)}")
260
+
261
+ return " ".join(parts)
262
+
263
+
264
@app.command()
def status(job_id: str = typer.Argument(..., help="Job ID to check")) -> None:
    """Get status of a running job."""
    # TODO: Implement in Phase 3
    # Placeholder: only reports that the feature is missing.
    for message in (
        f"Status for job {job_id}: not yet implemented",
        "Job persistence will be added in Phase 3",
    ):
        typer.echo(message)
270
+
271
+
272
@app.command()
def logs(
    job_id: str = typer.Argument(..., help="Job ID"),
    follow_logs: bool = typer.Option(False, "--follow", "-f", help="Follow log output"),
) -> None:
    """Get logs from a job."""
    # TODO: Implement in Phase 3
    # Placeholder: ``follow_logs`` is accepted but not used yet.
    for message in (
        f"Logs for job {job_id}: not yet implemented",
        "Job persistence will be added in Phase 3",
    ):
        typer.echo(message)
281
+
282
+
283
@app.command()
def kill(job_id: str = typer.Argument(..., help="Job ID to kill")) -> None:
    """Kill a running job."""
    # TODO: Implement in Phase 3
    # Placeholder: nothing is killed, the user is only informed.
    for message in (
        f"Kill job {job_id}: not yet implemented",
        "Job persistence will be added in Phase 3",
    ):
        typer.echo(message)
289
+
290
+
291
@app.command()
def config_show() -> None:
    """Show current configuration."""
    toml_path = Path.home() / ".wafer" / "config.toml"
    if not toml_path.exists():
        typer.echo(f"Error: Config not found: {toml_path}", err=True)
        raise typer.Exit(1)

    # Parsing and printing share one try block: any failure while reading the
    # config or its attributes is reported as a read error.
    try:
        cfg = WaferConfig.from_toml(toml_path)
        typer.echo(f"Target: {cfg.target}")
        typer.echo(f"SSH Key: {cfg.ssh_key}")
        typer.echo(f"Default Environment: {cfg.default_environment or '(none)'}")
        typer.echo("\nEnvironments:")
        for env_name, env_cfg in cfg.environments.items():
            typer.echo(f"  {env_name}:")
            typer.echo(f"    Docker: {env_cfg.docker}")
            if not env_cfg.description:
                continue
            typer.echo(f"    Description: {env_cfg.description}")
    except Exception as err:
        typer.echo(f"Error reading config: {err}", err=True)
        raise typer.Exit(1) from None
313
+
314
+
315
@app.command()
def wevin(
    prompt: str | None = typer.Argument(
        None,
        help="Prompt to send (reads from stdin if not provided and not interactive)",
    ),
    interactive: bool = typer.Option(
        False,
        "--interactive",
        "-i",
        help="Launch interactive TUI mode",
    ),
    resume: str | None = typer.Option(
        None,
        "--resume",
        "-r",
        help="Resume session by ID (or 'last' for most recent)",
    ),
    from_turn: int | None = typer.Option(
        None,
        "--from-turn",
        help="Branch from specific turn (default: resume from end)",
    ),
    list_sessions: bool = typer.Option(
        False,
        "--list-sessions",
        help="List recent sessions and exit",
    ),
    tools: str | None = typer.Option(
        None,
        "--tools",
        help="Comma-separated list of tools to enable (default: all)",
    ),
    allow_spawn: bool = typer.Option(
        False,
        "--allow-spawn",
        help="Allow wafer tool to spawn sub-wevin agents (blocked by default)",
    ),
    max_tool_fails: int | None = typer.Option(
        None,
        "--max-tool-fails",
        help="Exit after N consecutive tool failures",
    ),
    max_turns: int | None = typer.Option(
        None,
        "--max-turns",
        help="Max conversation turns (default: 10)",
    ),
    model: str | None = typer.Option(
        None,
        "--model",
        "-m",
        help="Model override (default: claude-sonnet-4-5)",
    ),
    json_output: bool = typer.Option(
        False,
        "--json",
        help="Output in JSON format (stream-json style)",
    ),
    # Legacy kernel optimization options (hidden, for backwards compat)
    problem: Path | None = typer.Option(
        None,
        "--problem",
        hidden=True,
        help="[Legacy] Path to problem YAML config file",
    ),
    reference: Path | None = typer.Option(
        None,
        "--reference",
        "--ref",
        hidden=True,
        help="[Legacy] Path to reference kernel file",
    ),
    description: str | None = typer.Option(
        None,
        "--description",
        "--desc",
        hidden=True,
        help="[Legacy] Problem description",
    ),
    test: list[str] | None = typer.Option(
        None,
        "--test",
        "-t",
        hidden=True,
        help="[Legacy] Test case",
    ),
    benchmark: list[str] | None = typer.Option(
        None,
        "--benchmark",
        "-b",
        hidden=True,
        help="[Legacy] Benchmark case",
    ),
    speedup_target: float | None = typer.Option(
        None,
        "--speedup",
        hidden=True,
        help="[Legacy] Speedup target",
    ),
) -> None:
    """Wevin - GPU programming assistant.

    By default, runs in one-shot mode: answer the prompt and exit.
    Use -i/--interactive for full TUI mode.

    Examples:
        # One-shot query
        wafer wevin "What is TMEM in CuTeDSL?"

        # Pipe from stdin
        cat kernel.py | wafer wevin "optimize this kernel"

        # JSON output for scripting
        wafer wevin "explain shared memory" --json

        # Interactive TUI
        wafer wevin -i

        # Resume a session
        wafer wevin --resume last "follow up question"

        # Limit tools
        wafer wevin --tools read,wafer "What is TMEM?"

        # Legacy kernel optimization mode
        wafer wevin --problem my_kernel.yaml
    """
    from wafer.wevin_cli import main as wevin_main

    # Fall back to piped stdin only when there is no prompt argument, the TUI
    # was not requested and stdin is not a terminal.
    # NOTE(review): when a prompt IS given, stdin is never read here, so the
    # "cat kernel.py | wafer wevin ..." example relies on wevin_main (or
    # nothing) consuming the piped data — confirm.
    effective_prompt = prompt
    if not (effective_prompt or interactive or sys.stdin.isatty()):
        effective_prompt = sys.stdin.read().strip()

    # Normalise CLI values to the plain types forwarded to wevin_main.
    problem_arg = str(problem) if problem else None
    reference_arg = str(reference) if reference else None
    tests_arg = list(test) if test else None
    benchmarks_arg = list(benchmark) if benchmark else None
    tools_arg = tools.split(",") if tools else None

    wevin_main(
        prompt=effective_prompt,
        interactive=interactive,
        problem=problem_arg,
        reference=reference_arg,
        description=description,
        tests=tests_arg,
        benchmarks=benchmarks_arg,
        model=model,
        max_turns=max_turns,
        speedup_target=speedup_target,
        resume=resume,
        from_turn=from_turn,
        list_sessions=list_sessions,
        tools=tools_arg,
        allow_spawn=allow_spawn,
        max_tool_fails=max_tool_fails,
        json_output=json_output,
    )
469
+
470
+
471
+ # =============================================================================
472
+ # Evaluate command
473
+ # =============================================================================
474
+
475
+
476
@app.command()
def evaluate(
    implementation: Path = typer.Option(
        ..., "--implementation", "--impl", help="Path to implementation kernel file"
    ),
    reference: Path = typer.Option(..., "--reference", help="Path to reference kernel file"),
    test_cases: Path = typer.Option(..., "--test-cases", help="Path to test cases JSON file"),
    target: str | None = typer.Option(None, "--target", "-t", help="Target name (or uses default)"),
    benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
    profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
    sync_artifacts: bool = typer.Option(
        True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
    ),
    gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
) -> None:
    """Run kernel evaluation on a remote GPU target.

    Same interface as evaluate.py, but runs remotely.

    Examples:
        wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json

        wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
            --target vultr-b200 --benchmark
    """
    from .evaluate import EvaluateArgs, run_evaluate

    # An omitted --target is forwarded as the empty string.
    eval_args = EvaluateArgs(
        implementation=implementation,
        reference=reference,
        test_cases=test_cases,
        target_name=target or "",
        benchmark=benchmark,
        profile=profile,
        sync_artifacts=sync_artifacts,
        gpu_id=gpu_id,
    )

    try:
        # trio_asyncio runs code that mixes trio and asyncio
        # (AsyncSSHClient uses asyncssh, which is asyncio-based).
        import trio_asyncio

        outcome = trio_asyncio.run(run_evaluate, eval_args)
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as err:
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None

    # Failure of the evaluation run itself: report and stop.
    if not outcome.success:
        typer.echo(f"Error: {outcome.error_message}", err=True)
        raise typer.Exit(1)

    # Print the result summary.
    rule = "=" * 60
    typer.echo("")
    typer.echo(rule)
    verdict = "PASS" if outcome.all_correct else "FAIL"
    typer.echo(f"Result: {verdict}")
    score_pct = f"{outcome.correctness_score:.1%}"
    typer.echo(f"Correctness: {outcome.passed_tests}/{outcome.total_tests} ({score_pct})")
    if outcome.geomean_speedup > 0:
        typer.echo(f"Speedup: {outcome.geomean_speedup:.2f}x")
    if outcome.artifact_path:
        typer.echo(f"Artifacts: {outcome.artifact_path}")
    typer.echo(rule)

    # A completed run with failing tests still yields a non-zero exit status.
    if not outcome.all_correct:
        raise typer.Exit(1)
546
+
547
+
548
+ # =============================================================================
549
+ # Push and Remote-Run commands
550
+ # =============================================================================
551
+
552
+
553
@app.command("push")
def push(
    local_path: Path = typer.Argument(..., help="Local directory to upload"),
    workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace name override"),
    direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
    target_name: str | None = typer.Option(None, "--target", "-t", help="Target name (for --direct mode)"),
) -> None:
    """Push directory to remote GPU.

    By default, uses wafer-api. Use --direct for direct SSH mode.

    Examples:
        wafer push ./my_project
        wafer push . --workspace my-kernel
        wafer push ./my_project --direct --target vultr-b200
    """
    # Validate path
    if not local_path.exists():
        typer.echo(f"Error: Path not found: {local_path}", err=True)
        raise typer.Exit(1)

    if not local_path.is_dir():
        typer.echo(f"Error: Not a directory: {local_path}", err=True)
        raise typer.Exit(1)

    # Resolve to absolute path
    local_path = local_path.resolve()

    if direct:
        # Direct SSH mode (requires target)
        # NOTE(review): --workspace is not consulted on this path.
        if not target_name:
            typer.echo("Error: --target required for --direct mode", err=True)
            raise typer.Exit(1)

        from .gpu_run import push_directory as push_direct
        from .targets import load_target

        try:
            target = load_target(target_name)
        except FileNotFoundError:
            typer.echo(f"Error: Target not found: {target_name}", err=True)
            typer.echo("List targets with: wafer targets list", err=True)
            raise typer.Exit(1) from None

        typer.echo(f"Connecting to {target.ssh_target}...")
        try:
            result = push_direct(local_path, target)
        except Exception as e:
            typer.echo(f"Error: {e}", err=True)
            raise typer.Exit(1) from None

        typer.echo(f"Uploading {len(result.files_uploaded)} files to {result.workspace_path}")
        for f in result.files_uploaded:
            # Restored check mark (the previous literal was mis-decoded UTF-8).
            typer.echo(f"  ✓ {f}")
        typer.echo(f"Pushed to: {result.workspace_path}")
    else:
        # API mode (default)
        from .api_client import push_directory as push_api

        # The workspace defaults to the directory name unless overridden.
        workspace_name = workspace or local_path.name
        typer.echo(f"Pushing {local_path.name} to wafer-api...")

        try:
            result = push_api(local_path, workspace_name)
        except Exception as e:
            typer.echo(f"Error: {e}", err=True)
            raise typer.Exit(1) from None

        typer.echo(f"Uploaded {len(result.files_uploaded)} files")
        for f in result.files_uploaded:
            # Restored check mark (the previous literal was mis-decoded UTF-8).
            typer.echo(f"  ✓ {f}")
        typer.echo(f"Workspace ID: {result.workspace_id}")
625
+
626
+
627
def _run_direct_mode(
    cmd_str: str,
    target_name: str,
    upload_dir: Path | None,
    workspace_id: str | None,
    gpu_id: int | None,
) -> int:
    """Execute ``cmd_str`` on a named target over direct SSH.

    Loads the target, optionally pushes ``upload_dir`` first, prints a
    summary banner and delegates to ``gpu_run.run_command``.

    Returns:
        The exit code returned by ``run_command``.

    Raises:
        typer.Exit: If the target is missing, has no docker image, the
            upload fails, the user interrupts, or the run raises.
    """
    from .gpu_run import push_directory as push_direct
    from .gpu_run import run_command as run_direct
    from .targets import load_target

    try:
        target = load_target(target_name)
    except FileNotFoundError:
        typer.echo(f"Error: Target not found: {target_name}", err=True)
        typer.echo("List targets with: wafer targets list", err=True)
        raise typer.Exit(1) from None

    if not target.docker_image:
        typer.echo(f"Error: Target '{target_name}' has no docker_image configured", err=True)
        raise typer.Exit(1)

    # Pick the workspace: a fresh upload wins, then the given ID, then "tmp".
    if upload_dir:
        typer.echo(f"Uploading {upload_dir.name}...")
        try:
            pushed = push_direct(upload_dir, target)
            workspace_name = pushed.workspace_path
            typer.echo(f"Uploaded {len(pushed.files_uploaded)} files")
        except Exception as err:
            typer.echo(f"Error uploading: {err}", err=True)
            raise typer.Exit(1) from None
    else:
        workspace_name = workspace_id or "tmp"

    # GPU shown in the banner: explicit override, else the target's first GPU.
    # NOTE(review): run_direct below receives the raw ``gpu_id`` (possibly
    # None), not this value — confirm run_command applies the same default.
    if gpu_id is not None:
        effective_gpu = gpu_id
    else:
        effective_gpu = target.gpu_ids[0]

    typer.echo(f"Target: {target_name} (docker: {target.docker_image})")
    typer.echo(f"Workspace: {workspace_name}")
    typer.echo(f"GPU: {effective_gpu}")
    typer.echo(f"Command: {cmd_str}")
    typer.echo("-" * 60)

    try:
        return run_direct(cmd_str, workspace_name, target, gpu_id)
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as err:
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None
679
+
680
+
681
def _run_api_mode(
    cmd_str: str,
    upload_dir: Path | None,
    workspace_id: str | None,
    gpu_id: int | None,
    docker_image: str | None,
    docker_entrypoint: str | None,
    pull_image: bool,
    require_hwc: bool,
) -> int:
    """Execute ``cmd_str`` through wafer-api and stream its output.

    Prints a banner describing the options in effect, then delegates to
    ``api_client.run_command_stream``.

    Returns:
        The exit code returned by ``run_command_stream``.

    Raises:
        typer.Exit: 130 on Ctrl-C, 1 on any other exception from the run.
    """
    from .api_client import run_command_stream

    # Collect the banner lines first; only options that are set are shown.
    banner: list[str] = []
    if upload_dir:
        banner.append(f"Uploading: {upload_dir}")
    elif workspace_id:
        banner.append(f"Workspace: {workspace_id}")
    if gpu_id is not None:
        banner.append(f"GPU: {gpu_id}")
    if docker_image:
        banner.append(f"Image: {docker_image}")
    if docker_entrypoint:
        banner.append(f"Entrypoint: {docker_entrypoint}")
    if pull_image:
        banner.append("Pull image: yes")
    banner.append(f"Command: {cmd_str}")
    if require_hwc:
        banner.append("Hardware counters: required (baremetal)")
    banner.append("-" * 60)

    for banner_line in banner:
        typer.echo(banner_line)

    try:
        return run_command_stream(
            command=cmd_str,
            upload_dir=upload_dir,
            workspace_id=workspace_id,
            gpu_id=gpu_id,
            docker_image=docker_image,
            docker_entrypoint=docker_entrypoint,
            pull_image=pull_image,
            require_hardware_counters=require_hwc,
        )
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as err:
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None
728
+
729
+
730
@app.command("remote-run")
def remote_run(  # noqa: PLR0913
    command: list[str] = typer.Argument(..., help="Command to run"),
    upload_dir: Path | None = typer.Option(None, "--upload-dir", "-u", help="Directory to upload (stateless mode)"),
    workspace_id: str | None = typer.Option(None, "--workspace-id", "-w", help="Workspace ID (from wafer push)"),
    gpu_id: int | None = typer.Option(None, "--gpu", "-g", help="GPU ID"),
    docker_image: str | None = typer.Option(None, "--image", "-i", help="Docker image override"),
    docker_entrypoint: str | None = typer.Option(None, "--docker-entrypoint", help="Override Docker entrypoint (e.g., 'bash')"),
    pull_image: bool = typer.Option(False, "--pull-image", help="Pull image if not available on target"),
    require_hwc: bool = typer.Option(False, "--require-hwc", help="Require hardware counters (baremetal)"),
    direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
    target_name: str | None = typer.Option(None, "--target", "-t", help="Target name (for --direct mode)"),
) -> None:
    """Run command on remote GPU in Docker.

    Two modes:
    - High-level (stateless): --upload-dir uploads files and runs command
    - Low-level: --workspace-id uses existing workspace from 'wafer push'

    By default, uses wafer-api. Use --direct for direct SSH mode.

    Examples:
        # Stateless: upload and run
        wafer remote-run --upload-dir ./my_project -- python train.py

        # Run without files
        wafer remote-run -- nvidia-smi

        # Low-level: use existing workspace
        wafer remote-run --workspace-id ws_abc123 -- python train.py

        # Direct SSH mode
        wafer remote-run --upload-dir ./my_project --direct --target vultr-b200 -- python train.py
    """

    def _fail(message: str) -> None:
        # Report a usage error on stderr and exit with status 1.
        typer.echo(message, err=True)
        raise typer.Exit(1)

    # The argv tokens are joined with single spaces; no re-quoting is done.
    cmd_str = " ".join(command)
    if not cmd_str.strip():
        _fail("Error: Empty command")

    if upload_dir and workspace_id:
        _fail("Error: --upload-dir and --workspace-id are mutually exclusive")

    if upload_dir:
        if not upload_dir.exists():
            _fail(f"Error: Directory not found: {upload_dir}")
        if not upload_dir.is_dir():
            _fail(f"Error: Not a directory: {upload_dir}")
        upload_dir = upload_dir.resolve()

    # Dispatch to the chosen transport; both helpers return the exit code.
    if not direct:
        exit_code = _run_api_mode(
            cmd_str,
            upload_dir,
            workspace_id,
            gpu_id,
            docker_image,
            docker_entrypoint,
            pull_image,
            require_hwc,
        )
    else:
        if not target_name:
            _fail("Error: --target required for --direct mode")
        exit_code = _run_direct_mode(cmd_str, target_name, upload_dir, workspace_id, gpu_id)

    # Always exit through typer so the remote exit code becomes ours.
    raise typer.Exit(exit_code)
791
+
792
+
793
+ # =============================================================================
794
+ # Authentication commands
795
+ # =============================================================================
796
+
797
+
798
@app.command("login")
def login(
    token: str | None = typer.Option(None, "--token", "-t", help="Access token (skip browser OAuth)"),
) -> None:
    """Authenticate CLI with wafer-api via GitHub OAuth.

    Opens browser for GitHub authentication. Use --token to skip browser.

    Examples:
        wafer login              # opens browser for GitHub OAuth
        wafer login --token xyz  # use existing token
    """
    import httpx

    from .auth import browser_login, save_credentials, verify_token

    # Without --token, obtain one through the browser flow.
    if token is None:
        try:
            token = browser_login()
        except (TimeoutError, RuntimeError) as err:
            typer.echo(f"Error: {err}", err=True)
            raise typer.Exit(1) from None
        except KeyboardInterrupt:
            typer.echo("\nCancelled", err=True)
            raise typer.Exit(1) from None

    # Surrounding whitespace is dropped; a blank token is rejected.
    token = token.strip()
    if not token:
        typer.echo("Error: Token cannot be empty", err=True)
        raise typer.Exit(1)

    # Verify token with API
    typer.echo("Verifying token...")
    try:
        user_info = verify_token(token)
    except httpx.HTTPStatusError as err:
        status_code = err.response.status_code
        message = (
            "Error: Invalid token"
            if status_code == 401
            else f"Error: API returned {status_code}"
        )
        typer.echo(message, err=True)
        raise typer.Exit(1) from None
    except httpx.RequestError as err:
        typer.echo(f"Error: Could not reach API: {err}", err=True)
        raise typer.Exit(1) from None

    # Persist only after the token has been verified.
    save_credentials(token, user_info.email)

    identity = (
        f"Logged in as {user_info.email}"
        if user_info.email
        else f"Logged in (user_id: {user_info.user_id})"
    )
    typer.echo(identity)
    typer.echo("Token saved to ~/.wafer/credentials.json")
856
+
857
+
858
@app.command("logout")
def logout() -> None:
    """Remove stored credentials."""
    from .auth import clear_credentials

    # A truthy result from clear_credentials() selects the "removed" message.
    message = (
        "Logged out. Credentials removed."
        if clear_credentials()
        else "Not logged in (no credentials found)."
    )
    typer.echo(message)
867
+
868
+
869
@app.command("whoami")
def whoami() -> None:
    """Show current authenticated user."""
    from .auth import load_credentials

    stored = load_credentials()
    if stored is None:
        typer.echo("Not logged in. Run: wafer login")
        raise typer.Exit(1)

    # Credentials may lack an email; fall back to a generic message.
    typer.echo(stored.email if stored.email else "Logged in (email not available)")
883
+
884
+
885
+ # =============================================================================
886
+ # Ask-docs command
887
+ # =============================================================================
888
+
889
+
890
@app.command("ask-docs")
def ask_docs(
    query: str = typer.Argument(..., help="Question about GPU programming/documentation"),
    session_id: str | None = typer.Option(
        None,
        "--session-id",
        "-s",
        help="Session ID for follow-up questions (returned from previous query)",
    ),
    docs_url: str = typer.Option(
        "https://www.api.wafer.ai",
        "--docs-url",
        envvar="WAFER_DOCS_URL",
        help="URL of docs-tool service",
    ),
    json_output: bool = typer.Option(
        False,
        "--json",
        help="Output raw JSON events instead of streaming text",
    ),
) -> None:
    """Query GPU documentation using the docs-tool service.

    NOTE: Requires docs-tool service to be running locally:
        cd services/docs-tool && uv run uvicorn src.main:app --reload

    Examples:
        # Ask a question (requires local docs-tool)
        wafer ask-docs "What is TMEM in CuTeDSL?"

        # Follow-up question using session ID
        wafer ask-docs "How do I use it for fp4?" --session-id abc123
    """
    import httpx

    url = f"{docs_url.rstrip('/')}/v1/docs/rag/stream"
    payload = {"query": query}
    # Note: session_id not yet supported by /v1/docs/rag/stream endpoint
    # TODO: Add conversation_history support for follow-up questions
    if session_id:
        # Tell the user instead of silently dropping the option.
        typer.echo(
            "Warning: --session-id is not yet supported by the docs endpoint; ignoring it",
            err=True,
        )

    try:
        with httpx.Client(timeout=120.0) as client:
            with client.stream("POST", url, json=payload) as response:
                if response.status_code != 200:
                    # Read the error response body
                    error_body = response.read().decode("utf-8", errors="replace")
                    typer.echo(f"Error: {response.status_code} - {error_body}", err=True)
                    raise typer.Exit(1)

                current_session_id = None
                for line in response.iter_lines():
                    # Only lines carrying the "data: " prefix are processed;
                    # everything else in the stream is skipped.
                    if not line.startswith("data: "):
                        continue

                    try:
                        event = json.loads(line[6:])  # Skip "data: " prefix
                    except json.JSONDecodeError:
                        # Best effort: malformed events are ignored.
                        continue

                    if json_output:
                        typer.echo(json.dumps(event))
                        continue

                    # Handle different event types
                    event_type = event.get("type", "")

                    if event_type == "sources":
                        # Optionally show sources
                        pass
                    elif event_type == "chunk":
                        # Stream text to stdout
                        text = event.get("text", "")
                        if text:
                            sys.stdout.write(text)
                            sys.stdout.flush()
                    elif event_type == "done":
                        # Final chunk - get session ID for follow-ups
                        text = event.get("text", "")
                        if text:
                            sys.stdout.write(text)
                            sys.stdout.flush()
                        current_session_id = event.get("session_id")

                # Print newline and session info
                if not json_output:
                    print()  # Final newline
                    if current_session_id:
                        typer.echo(
                            f"\n[Session: {current_session_id}] "
                            f"Use --session-id {current_session_id} for follow-up questions",
                            err=True,
                        )

    except httpx.ConnectError:
        typer.echo(
            f"Error: Could not connect to docs-tool at {docs_url}\n"
            "Make sure the docs-tool service is running:\n"
            "  cd services/docs-tool && uv run uvicorn src.main:app --reload",
            err=True,
        )
        raise typer.Exit(1) from None
    except httpx.TimeoutException:
        typer.echo("Error: Request timed out", err=True)
        raise typer.Exit(1) from None
    except httpx.RequestError as e:
        # Any other transport failure (e.g. the connection dropping while
        # streaming) gets a clean error instead of a traceback, matching how
        # the login command reports request errors.
        typer.echo(f"Error: Request to docs-tool failed: {e}", err=True)
        raise typer.Exit(1) from None
995
+
996
+ # =============================================================================
997
+ # Targets subcommands
998
+ # =============================================================================
999
+
1000
+
1001
@targets_app.command("add")
def targets_add(
    file_path: Path = typer.Argument(..., help="Path to target TOML file"),
) -> None:
    """Add a target from a TOML config file.

    Prints the new target's name followed by its details. Exits with
    status 1 when the file is missing or its contents are rejected.

    Example:
        wafer targets add ~/configs/modal-b200.toml
    """
    from .targets import add_target_from_file, get_target_info

    try:
        added = add_target_from_file(file_path)
        typer.echo(f"Added target: {added.name}")
        # One "key: value" line per detail reported for the new target.
        details = get_target_info(added)
        for field_name, field_value in details.items():
            typer.echo(f" {field_name}: {field_value}")
    except FileNotFoundError as err:
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None
    except (ValueError, AssertionError) as err:
        typer.echo(f"Error: Invalid target config: {err}", err=True)
        raise typer.Exit(1) from None
1024
+
1025
+
1026
@targets_app.command("list")
def targets_list() -> None:
    """List all configured targets.

    The current default target, if any, is marked with "(default)".

    Example:
        wafer targets list
    """
    from .targets import get_default_target, list_targets

    names = list_targets()
    default_name = get_default_target()

    if names:
        typer.echo("Configured targets:")
        for target_name in names:
            suffix = " (default)" if target_name == default_name else ""
            typer.echo(f" {target_name}{suffix}")
    else:
        # Nothing configured yet: point the user at the "add" subcommand.
        typer.echo("No targets configured.")
        typer.echo("Add one with: wafer targets add <path/to/target.toml>")
1047
+
1048
+
1049
@targets_app.command("show")
def targets_show(
    name: str = typer.Argument(..., help="Target name"),
) -> None:
    """Show details for a target.

    Exits with status 1 when the target cannot be found.

    Example:
        wafer targets show modal-b200
    """
    from .targets import get_target_info, load_target

    try:
        loaded = load_target(name)
        typer.echo(f"Target: {name}")
        details = get_target_info(loaded)
        for field_name, field_value in details.items():
            typer.echo(f" {field_name}: {field_value}")
    except FileNotFoundError as err:
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None
1069
+
1070
+
1071
@targets_app.command("remove")
def targets_remove(
    name: str = typer.Argument(..., help="Target name"),
    force: bool = typer.Option(False, "--force", "-f", help="Skip confirmation"),
) -> None:
    """Remove a target.

    Asks for confirmation unless --force is given. Declining exits with
    status 0; a missing target exits with status 1.

    Example:
        wafer targets remove modal-b200
    """
    from .targets import remove_target

    # The prompt is only shown when --force was not supplied.
    if not force and not typer.confirm(f"Remove target '{name}'?"):
        typer.echo("Cancelled.")
        raise typer.Exit(0)

    try:
        remove_target(name)
        typer.echo(f"Removed target: {name}")
    except FileNotFoundError as err:
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None
1095
+
1096
+
1097
@targets_app.command("default")
def targets_default(
    name: str = typer.Argument(..., help="Target name to set as default"),
) -> None:
    """Set the default target.

    Exits with status 1 when the named target cannot be found.

    Example:
        wafer targets default modal-b200
    """
    from .targets import set_default_target

    try:
        set_default_target(name)
    except FileNotFoundError as err:
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None
    else:
        typer.echo(f"Default target set to: {name}")
1114
+
1115
+
1116
+ # =============================================================================
1117
+ # NCU Analyze command
1118
+ # =============================================================================
1119
+
1120
+
1121
@app.command("ncu-analyze")
def ncu_analyze(
    filepath: Path = typer.Argument(..., help="Path to .ncu-rep profile file"),
    output_dir: Path | None = typer.Option(
        None, "--output-dir", "-o", help="Output directory for analysis files"
    ),
    json_output: bool = typer.Option(False, "--json", help="Output raw JSON instead of formatted text"),
    remote: bool | None = typer.Option(
        None,
        "--remote/--local",
        help="Force remote (via API) or local analysis. Default: auto-detect (remote if NCU not installed locally)",
    ),
    target: str | None = typer.Option(
        None,
        "--target",
        "-t",
        help="Target name for direct SSH mode (e.g., 'vultr-b200'). Bypasses API.",
    ),
) -> None:
    """Analyze an NVIDIA Nsight Compute profile (.ncu-rep file).

    Returns kernel performance metrics including duration, occupancy,
    compute/memory throughput, and optimization recommendations.

    By default, uses local NCU if available, otherwise runs analysis
    remotely via wafer-api (requires authentication: wafer login).

    Use --target for direct SSH mode (like wafer remote-run --direct).

    Examples:
        wafer ncu-analyze profile.ncu-rep
        wafer ncu-analyze profile.ncu-rep --json
        wafer ncu-analyze profile.ncu-rep --output-dir ./analysis
        wafer ncu-analyze profile.ncu-rep --remote # Force remote via API
        wafer ncu-analyze profile.ncu-rep --target vultr-b200 # Direct SSH
    """
    from .ncu_analyze import analyze_ncu_profile

    # Validate the input path up front so the user gets a direct message.
    problem = None
    if not filepath.exists():
        problem = f"Error: File not found: {filepath}"
    elif filepath.suffix != ".ncu-rep":
        problem = f"Error: Expected .ncu-rep file, got: {filepath.suffix}"
    if problem is not None:
        typer.echo(problem, err=True)
        raise typer.Exit(1)

    try:
        report = analyze_ncu_profile(
            filepath,
            output_dir=output_dir,
            json_output=json_output,
            remote=remote,
            target=target,
        )
        typer.echo(report)
    except (FileNotFoundError, RuntimeError) as err:
        # Both failure kinds are reported the same way.
        typer.echo(f"Error: {err}", err=True)
        raise typer.Exit(1) from None
1179
+
1180
+
1181
+ # =============================================================================
1182
+ # Compiler Analyze command
1183
+ # =============================================================================
1184
+
1185
+
1186
@app.command("compiler-analyze")
def compiler_analyze(
    mlir_file: Path | None = typer.Option(None, "--mlir", help="Path to MLIR file"),
    ptx_file: Path | None = typer.Option(None, "--ptx", help="Path to PTX file"),
    sass_file: Path | None = typer.Option(None, "--sass", help="Path to SASS file"),
    source_file: Path | None = typer.Option(None, "--source", help="Path to source file"),
    mlir_text: str | None = typer.Option(None, "--mlir-text", help="MLIR text content"),
    ptx_text: str | None = typer.Option(None, "--ptx-text", help="PTX text content"),
    sass_text: str | None = typer.Option(None, "--sass-text", help="SASS text content"),
    source_text: str | None = typer.Option(None, "--source-text", help="Source code text"),
    kernel_name: str | None = typer.Option(None, "--kernel-name", help="Kernel name"),
    json_output: bool = typer.Option(True, "--json/--no-json", help="Output JSON"),
) -> None:
    """Analyze compiler kernel (MLIR/PTX/SASS).

    Each input can be given either as a file path or as inline text.
    On failure the error is written to stderr (as a JSON object when JSON
    output is enabled) and the command exits with status 1.

    Examples:
        wafer compiler-analyze --mlir-text "..." --ptx-text "..." --sass-text "..."
        wafer compiler-analyze --mlir file.mlir --ptx file.ptx --sass file.sass
    """
    from .compiler_analyze import analyze_compiler_kernel

    try:
        result = analyze_compiler_kernel(
            mlir_file=mlir_file,
            ptx_file=ptx_file,
            sass_file=sass_file,
            source_file=source_file,
            mlir_text=mlir_text,
            ptx_text=ptx_text,
            sass_text=sass_text,
            source_text=source_text,
            kernel_name=kernel_name,
            json_output=json_output,
        )
        if json_output:
            print(result)
        else:
            typer.echo(result)
    except Exception as e:
        # CLI boundary: every failure (ValueError included) is reported in
        # the selected output format and turned into exit status 1.
        if json_output:
            error_json = json.dumps({"success": False, "error": str(e)}, indent=2)
            print(error_json, file=sys.stderr)
        else:
            typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
1242
+
1243
+
1244
@app.command("capture")
def capture_command(
    label: str = typer.Argument(..., help="Label for this capture (e.g., 'baseline', 'optimized-v2')"),
    command: str = typer.Argument(..., help="Command to execute and capture"),
    variant: str | None = typer.Option(None, "--variant", "-v", help="Variant identifier for grouping related captures"),
    tags: list[str] | None = typer.Option(None, "--tag", "-t", help="Tags for categorization (can be specified multiple times)"),  # noqa: B008
    working_dir: Path | None = typer.Option(None, "--dir", "-d", help="Working directory (default: current directory)"),
    sweep: list[str] | None = typer.Option(None, "--sweep", "-s", help="Parameter sweep (format: VAR=val1,val2,val3)"),  # noqa: B008
    code_denylist: list[str] | None = typer.Option(None, "--code-denylist", help="Patterns to exclude from code files (e.g., '*.log', '**/test/**')"),  # noqa: B008
    artifact_denylist: list[str] | None = typer.Option(None, "--artifact-denylist", help="Patterns to exclude from artifacts (e.g., '*.tmp', '*.o')"),  # noqa: B008
) -> None:
    """Capture a complete execution snapshot for reproducibility.

    Captures everything needed to reproduce a benchmark run:
    - Command output (stdout/stderr), exit code, duration
    - Generated artifacts (outputs, profiles, logs)
    - Code files used in execution
    - Git context (repo, commit, branch, dirty status)
    - System context (GPU model, CUDA version, hostname)
    - Metrics extracted from stdout (latency, throughput, etc.)

    All data is uploaded to Supabase for later analysis and comparison.

    Denylist Configuration (precedence: CLI > Project > Global > Defaults):
    1. CLI flags: --code-denylist and --artifact-denylist
    2. Project config: .wafer-capture.toml in working directory
    3. Global config: ~/.wafer/capture.toml
    4. Built-in defaults (excludes common binaries, dependencies, etc.)

    Examples:
        # Basic capture
        wafer capture baseline "python benchmark.py"

        # With variant for A/B testing
        wafer capture optimized "python benchmark.py" --variant v2

        # With tags
        wafer capture test-run "make && ./kernel" --tag cuda --tag fp16

        # Different working directory
        wafer capture training "python train.py" --dir ./experiments/run1

        # Custom denylists via CLI
        wafer capture test "make" --code-denylist "*.log" --code-denylist "**/test/**"

        # Parameter sweep (runs multiple captures with different values)
        wafer capture batch-sizes "python train.py --batch-size {BATCH}" --sweep "BATCH=16,32,64,128"

        # Multiple variable sweep (cartesian product)
        wafer capture grid-search "python train.py --lr {LR} --bs {BS}" --sweep "LR=0.001,0.01,0.1" --sweep "BS=16,32"
    """
    import itertools
    import tomllib

    from wafer_core.capture.core import capture
    from wafer_core.capture.dtypes import CaptureConfig
    from wafer_core.capture.executor import execute_command

    # Resolve working directory
    work_dir = working_dir.resolve() if working_dir else Path.cwd()

    # Load denylists from config files. Files are read from lowest to highest
    # precedence (global, then project); a key present in a later file
    # replaces the value taken from an earlier one.
    denylist_keys = ("code_denylist", "artifact_denylist")
    config_denylists: dict[str, list[str] | None] = dict.fromkeys(denylist_keys)
    config_paths = (
        Path.home() / ".wafer" / "capture.toml",  # global config
        work_dir / ".wafer-capture.toml",  # project-specific config
    )
    for config_path in config_paths:
        if not config_path.exists():
            continue
        try:
            with open(config_path, "rb") as f:
                config_data = tomllib.load(f)
            for key in denylist_keys:
                if key in config_data:
                    config_denylists[key] = config_data[key]
        except Exception as e:
            # Best effort: a broken config file must not prevent the capture.
            typer.echo(f"⚠️ Warning: Failed to load {config_path}: {e}", err=True)

    # Denylist kwargs with precedence: CLI > config file > CaptureConfig
    # defaults (a key is omitted entirely when neither source provides it).
    # Built once here because the result is identical for every sweep run.
    denylist_kwargs = {}
    for key, cli_value in zip(denylist_keys, (code_denylist, artifact_denylist), strict=True):
        chosen = cli_value or config_denylists[key]
        if chosen:
            denylist_kwargs[key] = chosen

    # Parse sweep parameters (format: "VAR=val1,val2,val3")
    sweep_vars: dict[str, list[str]] = {}
    for sweep_spec in sweep or []:
        raw_name, separator, values_str = sweep_spec.partition("=")
        var_name = raw_name.strip()
        # Reject specs without "=" and specs with an empty variable name
        # (an empty name would substitute the literal "{}" in the command).
        if not separator or not var_name:
            typer.echo(f"❌ Invalid sweep format: {sweep_spec}", err=True)
            typer.echo(" Expected format: VAR=val1,val2,val3", err=True)
            raise typer.Exit(1)
        sweep_vars[var_name] = [v.strip() for v in values_str.split(",")]

    # Generate all combinations (cartesian product) of sweep variables
    if sweep_vars:
        var_names = list(sweep_vars)
        combinations = list(itertools.product(*sweep_vars.values()))

        typer.echo(f"🔬 Running sweep: {label}")
        typer.echo(f" Variables: {', '.join(var_names)}")
        typer.echo(f" Total runs: {len(combinations)}")
        typer.echo()
    else:
        # Single run (no sweep): one empty combination
        combinations = [()]
        var_names = []

    # Progress callback handed to the capture routine
    def progress(msg: str) -> None:
        typer.echo(f" {msg}")

    async def run_capture_sweep() -> None:
        successful = 0
        failed = 0

        for idx, combo in enumerate(combinations, 1):
            # Substitute each {VAR} placeholder in the command
            sweep_info = dict(zip(var_names, combo, strict=True))
            substituted_cmd = command
            for var_name, value in sweep_info.items():
                substituted_cmd = substituted_cmd.replace(f"{{{var_name}}}", value)

            # Sweep runs get a variant derived from their parameter values,
            # prefixed with the user-supplied variant when there is one.
            assignments = [f"{k}={v}" for k, v in sweep_info.items()]
            if sweep_vars:
                run_variant = "_".join(assignments)
                if variant:
                    run_variant = f"{variant}_{run_variant}"
            else:
                run_variant = variant

            config = CaptureConfig(
                label=label,
                command=substituted_cmd,
                working_dir=work_dir,
                variant=run_variant,
                tags=tags or [],
                **denylist_kwargs,
            )

            try:
                if sweep_vars:
                    typer.echo(f"[{idx}/{len(combinations)}] {', '.join(assignments)}")
                else:
                    typer.echo(f"🔬 Capturing: {label}")

                typer.echo(f" Command: {substituted_cmd}")
                typer.echo(f" Working dir: {work_dir}")
                typer.echo()

                result = await capture(
                    config=config,
                    runner=execute_command,
                    progress_callback=progress,
                )

                typer.echo()
                typer.echo("✅ Capture complete!")
                typer.echo(f" ID: {result.id}")
                typer.echo(f" Exit code: {result.exit_code}")
                typer.echo(f" Duration: {result.duration_seconds:.2f}s")
                typer.echo(f" Code files: {len(result.code_files)}")
                typer.echo(f" Artifacts: {len(result.artifacts)}")
                if result.metrics.stdout_metrics:
                    typer.echo(f" Metrics: {len(result.metrics.stdout_metrics)}")
                typer.echo()

                successful += 1

            except Exception as e:
                # Keep going: one failed run must not abort the rest of the sweep.
                typer.echo(f"\n❌ Capture failed: {e}", err=True)
                typer.echo()
                failed += 1

        # Summary for sweep runs
        if sweep_vars and len(combinations) > 1:
            typer.echo("=" * 60)
            typer.echo(f"Sweep complete: {successful} successful, {failed} failed")

        # Any failed run makes the whole command exit non-zero.
        if failed > 0:
            raise typer.Exit(1)

    trio.run(run_capture_sweep)
1458
+
1459
+
1460
@app.command("capture-list")
def capture_list_command(
    label: str | None = typer.Option(None, "--label", "-l", help="Filter by label"),
    limit: int = typer.Option(100, "--limit", "-n", help="Maximum number of results"),
    offset: int = typer.Option(0, "--offset", "-o", help="Offset for pagination"),
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
) -> None:
    """List captured executions.

    Query captures from the backend with optional filtering and pagination.
    Output can be formatted as a table (default) or JSON.

    Examples:
        # List all captures
        wafer capture-list

        # Filter by label
        wafer capture-list --label baseline

        # Get JSON output for scripting
        wafer capture-list --json --limit 10

        # Pagination
        wafer capture-list --limit 20 --offset 20
    """
    from wafer_core.tools.backend import list_captures

    async def run_list() -> None:
        try:
            captures = await list_captures(label=label, limit=limit, offset=offset)

            if json_output:
                # JSON output for machine consumption
                typer.echo(json.dumps(captures, indent=2))
                return

            # Human-readable table output
            if not captures:
                typer.echo("No captures found.")
                return

            typer.echo(f"Found {len(captures)} captures:\n")

            # Print table header
            typer.echo(
                f"{'ID':<36} {'Label':<20} {'Variant':<20} {'Exit':<4} {'Duration':<8} {'Created'}"
            )
            typer.echo("-" * 120)

            # Print each capture. A field may be missing or null in a record,
            # so every value is normalised before slicing or formatting.
            for cap in captures:
                cap_id = str(cap.get("id") or "")[:36]
                cap_label = str(cap.get("label") or "")[:20]
                cap_variant = str(cap.get("variant") or "")[:20]
                # Exit code 0 is falsy, so test for None explicitly.
                raw_exit = cap.get("exit_code")
                exit_code = "" if raw_exit is None else str(raw_exit)
                duration = f"{float(cap.get('duration_seconds') or 0):.2f}s"
                # Keep the first 19 characters (drops microseconds; presumably
                # an ISO-8601 timestamp - confirm against the backend).
                created = str(cap.get("created_at") or "")[:19]

                typer.echo(
                    f"{cap_id:<36} {cap_label:<20} {cap_variant:<20} {exit_code:<4} {duration:<8} {created}"
                )

        except Exception as e:
            # CLI boundary: report any backend/formatting failure and exit 1.
            typer.echo(f"❌ Failed to list captures: {e}", err=True)
            raise typer.Exit(1) from None

    trio.run(run_list)
1528
+
1529
+
1530
def main() -> None:
    """Run the wafer command-line application."""
    app()


# Running this module directly behaves like invoking main().
if __name__ == "__main__":
    main()