wafer-cli 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -30,6 +30,14 @@ import typer
30
30
 
31
31
  from .config import WaferConfig, WaferEnvironment
32
32
  from .inference import infer_upload_files, resolve_environment
33
+ from .problems import (
34
+ download_problems,
35
+ get_problem_path,
36
+ get_problems_path,
37
+ )
38
+ from .problems import (
39
+ list_problems as list_problems_fn,
40
+ )
33
41
 
34
42
  app = typer.Typer(
35
43
  help="GPU development toolkit for LLM coding agents",
@@ -91,11 +99,15 @@ def main_callback(ctx: typer.Context) -> None:
91
99
  # Install exception hook to catch SystemExit and mark failures
92
100
  original_excepthook = sys.excepthook
93
101
 
94
- def custom_excepthook(exc_type, exc_value, exc_traceback):
102
+ def custom_excepthook(
103
+ exc_type: type[BaseException],
104
+ exc_value: BaseException,
105
+ exc_traceback: object,
106
+ ) -> None:
95
107
  global _command_outcome
96
108
  # Mark as failure if SystemExit with non-zero code, or any other exception
97
109
  if exc_type is SystemExit:
98
- exit_code = exc_value.code if hasattr(exc_value, 'code') else 1
110
+ exit_code = exc_value.code if hasattr(exc_value, "code") else 1
99
111
  if exit_code != 0 and exit_code is not None:
100
112
  _command_outcome = "failure"
101
113
  else:
@@ -170,7 +182,12 @@ workspaces_app = typer.Typer(
170
182
 
171
183
  Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
172
184
 
173
- wafer workspaces create dev --gpu H100 # Create workspace
185
+ Available GPUs:
186
+ MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
187
+ B200 NVIDIA Blackwell B200 (180GB HBM3e, CUDA)
188
+
189
+ Commands:
190
+ wafer workspaces create dev --gpu B200 # Create workspace
174
191
  wafer workspaces exec dev -- python x.py # Run commands
175
192
  wafer workspaces ssh dev # Interactive SSH
176
193
  wafer workspaces sync dev ./project # Sync files
@@ -178,6 +195,36 @@ Workspaces are on-demand cloud GPU environments. Requires authentication (wafer
178
195
  )
179
196
  app.add_typer(workspaces_app, name="workspaces")
180
197
 
198
# SSH Key management (BYOK - Bring Your Own Key)
# `wafer ssh-keys ...` sub-app. Its commands are defined elsewhere in the file;
# the help text below is the user-facing usage summary shown by `--help`.
ssh_keys_app = typer.Typer(
    help="""Manage SSH public keys for workspace access.

Register your SSH public keys here. These keys are installed in all workspaces
you provision, enabling SSH access from any machine with your private key.

wafer ssh-keys list # List registered keys
wafer ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
wafer ssh-keys remove <key-id> # Remove a key"""
)
app.add_typer(ssh_keys_app, name="ssh-keys")

# Target operations (exec/ssh/sync on configured targets)
# `wafer targets ...` sub-app for direct access to targets configured via
# `wafer config targets`, bypassing the evaluate pipeline. Command functions
# are registered against this app elsewhere in the file.
targets_ops_app = typer.Typer(
    help="""Execute commands on configured GPU targets.

Run commands, SSH, or sync files to targets without going through evaluate.
Useful for exploratory work, debugging, or custom scripts.

wafer targets exec my-target -- python test.py # Run command
wafer targets ssh my-target # Interactive SSH
wafer targets sync my-target ./local_dir # Sync files

Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
Configure targets with: wafer config targets init ..."""
)
app.add_typer(targets_ops_app, name="targets")
227
+
181
228
  # Billing management
182
229
  billing_app = typer.Typer(help="Manage billing, credits, and subscription")
183
230
  app.add_typer(billing_app, name="billing")
@@ -200,6 +247,13 @@ kernelbench_app = typer.Typer(
200
247
  )
201
248
  evaluate_app.add_typer(kernelbench_app, name="kernelbench")
202
249
 
250
# Nested subcommand for gpumode format (`wafer evaluate gpumode ...`).
# invoke_without_command=True lets the group's callback run even when no
# further subcommand is given — presumably mirroring kernelbench_app, whose
# callback is also declared with invoke_without_command=True (confirm the
# gpumode callback, which is not visible in this view).
gpumode_app = typer.Typer(
    help="Evaluate kernels in GPUMode format (custom_kernel/ref_kernel functions)",
    invoke_without_command=True,
)
evaluate_app.add_typer(gpumode_app, name="gpumode")
256
+
203
257
  # =============================================================================
204
258
  # Dev commands (internal, used by web app proxy)
205
259
  # =============================================================================
@@ -238,10 +292,101 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
238
292
  amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
239
293
  app.add_typer(amd_app, name="amd")
240
294
 
241
- # ISA analysis - under amd
242
- isa_app = typer.Typer(help="ISA analysis for AMD GPU code objects (.co files)")
295
+ # Unified ISA Analyzer - supports both .co files and Triton artifacts
296
+ isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
243
297
  amd_app.add_typer(isa_app, name="isa")
244
298
 
299
+ # =============================================================================
300
+ # Roofline analysis (wafer roofline)
301
+ # =============================================================================
302
+
303
+
304
@app.command("roofline")
def roofline_cmd(
    gpu: str | None = typer.Option(
        None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
    ),
    bytes_moved: float | None = typer.Option(
        None, "--bytes", "-b", help="Theoretical minimum bytes moved"
    ),
    flops: float | None = typer.Option(None, "--flops", "-f", help="Theoretical minimum FLOPs"),
    time_ms: float | None = typer.Option(
        None, "--time-ms", "-t", help="Actual kernel time in milliseconds"
    ),
    dtype: str = typer.Option(
        "fp16", "--dtype", "-d", help="Data type for compute ceiling (fp16, fp32, bf16, fp8, int8)"
    ),
    list_gpus: bool = typer.Option(False, "--list-gpus", help="List available GPU specs and exit"),
) -> None:
    """Analyze kernel performance against roofline model.

    The roofline model shows the theoretical speed-of-light (SOL) for your kernel
    based on whether it's memory-bound or compute-bound.

    You need to provide:
      - The GPU you ran on
      - Theoretical minimum bytes moved (not actual - what the algorithm requires)
      - Theoretical minimum FLOPs
      - Actual measured kernel time

    Example:
        # Analyze a matmul kernel (4096x4096x4096, FP16)
        # Theoretical: 2*M*N*K FLOPs = 137.4 GFLOP
        # Theoretical bytes: (M*K + K*N + M*N) * 2 = 100.7 MB
        wafer roofline --gpu H100 --bytes 100.7e6 --flops 137.4e9 --time-ms 0.25

        # Analyze a memory-bound elementwise add (1B elements FP32)
        # Reads 2 tensors, writes 1 = 12 GB total
        # 1B adds = 1 GFLOP
        wafer roofline --gpu H100 --bytes 12e9 --flops 1e9 --time-ms 4 --dtype fp32

        # List available GPUs
        wafer roofline --list-gpus
    """
    # Imported lazily so `wafer --help` and unrelated commands don't pay for it.
    from wafer_core.roofline import get_gpu_spec, roofline_analysis
    from wafer_core.roofline import list_gpus as get_all_gpus

    if list_gpus:
        typer.echo("Available GPUs:")
        for name in get_all_gpus():
            spec = get_gpu_spec(name)
            typer.echo(
                f" {name}: {spec.peak_bandwidth_gbps:.0f} GB/s, {spec.peak_tflops_fp16:.0f} TFLOPS FP16"
            )
        return

    # All four measurements are required unless --list-gpus was given, so they
    # are declared optional to Typer and checked here with a combined message.
    missing = []
    if gpu is None:
        missing.append("--gpu")
    if bytes_moved is None:
        missing.append("--bytes")
    if flops is None:
        missing.append("--flops")
    if time_ms is None:
        missing.append("--time-ms")

    if missing:
        typer.echo(f"Error: Missing required options: {', '.join(missing)}", err=True)
        typer.echo("", err=True)
        typer.echo("Run 'wafer roofline --help' for usage.", err=True)
        raise typer.Exit(1)

    # Reject non-physical inputs up front so the user gets a clear CLI error
    # instead of a meaningless report or an uncaught exception from the
    # analysis (time and bytes are likely used as divisors there).
    invalid: list[str] = []
    if time_ms <= 0:
        invalid.append("--time-ms must be greater than 0")
    if bytes_moved <= 0:
        invalid.append("--bytes must be greater than 0")
    if flops < 0:
        invalid.append("--flops must not be negative")
    if invalid:
        for message in invalid:
            typer.echo(f"Error: {message}", err=True)
        raise typer.Exit(1)

    try:
        result = roofline_analysis(
            gpu=gpu,
            dtype=dtype,
            bytes_moved=bytes_moved,
            flops=flops,
            time_ms=time_ms,
        )
    except ValueError as e:
        # e.g. unknown GPU name or dtype; surface the message without a traceback
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    typer.echo(result.format_report())
388
+
389
+
245
390
  # =============================================================================
246
391
  # Skill management (wafer skill ...)
247
392
  # =============================================================================
@@ -256,21 +401,22 @@ def skill_install(
256
401
  "all",
257
402
  "--target",
258
403
  "-t",
259
- help="Target tool: claude, codex, or all",
404
+ help="Target tool: claude, codex, cursor, or all",
260
405
  ),
261
406
  force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing skill"),
262
407
  ) -> None:
263
408
  """Install the wafer-guide skill for AI coding assistants.
264
409
 
265
410
  Installs the bundled skill to make wafer commands discoverable by
266
- Claude Code and/or OpenAI Codex CLI.
411
+ Claude Code, OpenAI Codex CLI, and/or Cursor.
267
412
 
268
413
  Skills follow the open agent skills specification (agentskills.io).
269
414
 
270
415
  Examples:
271
- wafer skill install # Install for both Claude and Codex
416
+ wafer skill install # Install for all tools
272
417
  wafer skill install -t claude # Install for Claude Code only
273
418
  wafer skill install -t codex # Install for Codex CLI only
419
+ wafer skill install -t cursor # Install for Cursor only
274
420
  wafer skill install --force # Overwrite existing installation
275
421
  """
276
422
  # Locate bundled skill
@@ -288,9 +434,13 @@ def skill_install(
288
434
  ))
289
435
  if target in ("all", "codex"):
290
436
  targets_to_install.append(("Codex CLI", Path.home() / ".codex" / "skills" / "wafer-guide"))
437
+ if target in ("all", "cursor"):
438
+ targets_to_install.append(("Cursor", Path.home() / ".cursor" / "skills" / "wafer-guide"))
291
439
 
292
440
  if not targets_to_install:
293
- typer.echo(f"Error: Unknown target '{target}'. Use: claude, codex, or all", err=True)
441
+ typer.echo(
442
+ f"Error: Unknown target '{target}'. Use: claude, codex, cursor, or all", err=True
443
+ )
294
444
  raise typer.Exit(1)
295
445
 
296
446
  for tool_name, dest_path in targets_to_install:
@@ -325,14 +475,15 @@ def skill_uninstall(
325
475
  "all",
326
476
  "--target",
327
477
  "-t",
328
- help="Target tool: claude, codex, or all",
478
+ help="Target tool: claude, codex, cursor, or all",
329
479
  ),
330
480
  ) -> None:
331
481
  """Uninstall the wafer-guide skill.
332
482
 
333
483
  Examples:
334
- wafer skill uninstall # Uninstall from both
484
+ wafer skill uninstall # Uninstall from all tools
335
485
  wafer skill uninstall -t claude # Uninstall from Claude Code only
486
+ wafer skill uninstall -t cursor # Uninstall from Cursor only
336
487
  """
337
488
  targets_to_uninstall: list[tuple[str, Path]] = []
338
489
 
@@ -346,9 +497,16 @@ def skill_uninstall(
346
497
  "Codex CLI",
347
498
  Path.home() / ".codex" / "skills" / "wafer-guide",
348
499
  ))
500
+ if target in ("all", "cursor"):
501
+ targets_to_uninstall.append((
502
+ "Cursor",
503
+ Path.home() / ".cursor" / "skills" / "wafer-guide",
504
+ ))
349
505
 
350
506
  if not targets_to_uninstall:
351
- typer.echo(f"Error: Unknown target '{target}'. Use: claude, codex, or all", err=True)
507
+ typer.echo(
508
+ f"Error: Unknown target '{target}'. Use: claude, codex, cursor, or all", err=True
509
+ )
352
510
  raise typer.Exit(1)
353
511
 
354
512
  for tool_name, dest_path in targets_to_uninstall:
@@ -383,6 +541,7 @@ def skill_status() -> None:
383
541
  installations = [
384
542
  ("Claude Code", Path.home() / ".claude" / "skills" / "wafer-guide"),
385
543
  ("Codex CLI", Path.home() / ".codex" / "skills" / "wafer-guide"),
544
+ ("Cursor", Path.home() / ".cursor" / "skills" / "wafer-guide"),
386
545
  ]
387
546
 
388
547
  for tool_name, path in installations:
@@ -396,6 +555,122 @@ def skill_status() -> None:
396
555
  typer.echo(f"{tool_name}: Not installed")
397
556
 
398
557
 
558
# =============================================================================
# Provider auth management (wafer auth ...)
# =============================================================================

# `wafer auth ...` sub-app: stores/removes API keys for third-party cloud GPU
# providers (the login/logout help text lists runpod, digitalocean, modal).
provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
app.add_typer(provider_auth_app, name="auth")
564
+
565
+
566
@provider_auth_app.command("login")
def provider_auth_login(
    provider: str = typer.Argument(
        ...,
        help="Provider name: runpod, digitalocean, or modal",
    ),
    api_key: str | None = typer.Option(
        None,
        "--api-key",
        "-k",
        help="API key (if not provided, reads from stdin)",
    ),
) -> None:
    """Save API key for a cloud GPU provider.

    Stores the key in ~/.wafer/auth.json. Environment variables
    (e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.

    If --api-key is omitted, the key is read from a hidden prompt (interactive
    terminal) or from stdin (piped input). Prefer those over --api-key, which
    can leave the key in shell history and process listings.

    Examples:
        wafer auth login runpod --api-key rp_xxx
        wafer auth login digitalocean --api-key dop_v1_xxx
        echo $API_KEY | wafer auth login runpod
    """
    import sys

    from wafer_core.auth import PROVIDERS, save_api_key

    # Validate provider before prompting so users aren't asked for a key
    # that can't be stored.
    if provider not in PROVIDERS:
        typer.echo(f"Error: Unknown provider '{provider}'", err=True)
        typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
        raise typer.Exit(1)

    display_name = PROVIDERS[provider]["display_name"]

    # Get API key from option, interactive prompt, or piped stdin.
    if api_key is None:
        if sys.stdin.isatty():
            typer.echo(f"Enter API key for {display_name}:")
            api_key = typer.prompt("API key", hide_input=True)
        else:
            api_key = sys.stdin.read()

    # Normalize regardless of source: `echo` adds a trailing newline, and a
    # pasted key often carries stray spaces. A whitespace-only key is treated
    # as missing rather than saved.
    api_key = api_key.strip()
    if not api_key:
        typer.echo("Error: No API key provided", err=True)
        raise typer.Exit(1)

    try:
        save_api_key(provider, api_key)
    except OSError as e:
        # Filesystem problems (permissions, read-only home, ...) get a clean
        # CLI error instead of a traceback.
        typer.echo(f"Error: Could not save API key: {e}", err=True)
        raise typer.Exit(1) from None

    typer.echo(f"API key saved for {display_name}")
    typer.echo("Stored in: ~/.wafer/auth.json")
615
+
616
+
617
@provider_auth_app.command("logout")
def provider_auth_logout(
    provider: str = typer.Argument(
        ...,
        help="Provider name: runpod, digitalocean, or modal",
    ),
) -> None:
    """Remove stored API key for a cloud GPU provider.

    Examples:
        wafer auth logout runpod
        wafer auth logout digitalocean
    """
    from wafer_core.auth import PROVIDERS, remove_api_key

    # Guard clause: refuse unknown provider names before touching storage.
    if provider not in PROVIDERS:
        known = ", ".join(PROVIDERS.keys())
        typer.echo(f"Error: Unknown provider '{provider}'", err=True)
        typer.echo(f"Valid providers: {known}", err=True)
        raise typer.Exit(1)

    # remove_api_key reports whether a stored key actually existed.
    was_removed = remove_api_key(provider)
    pretty_name = PROVIDERS[provider]["display_name"]
    outcome = (
        f"API key removed for {pretty_name}"
        if was_removed
        else f"No stored API key found for {pretty_name}"
    )
    typer.echo(outcome)
642
+
643
+
644
@provider_auth_app.command("status")
def provider_auth_status() -> None:
    """Show authentication status for all cloud GPU providers.

    Displays which providers have API keys configured and where
    the keys are coming from (environment variable or auth.json).

    Example:
        wafer auth status
    """
    from wafer_core.auth import get_all_auth_status

    # Fetch first so a lookup failure happens before any output is printed.
    all_statuses = get_all_auth_status()

    typer.echo("Cloud GPU Provider Authentication Status")
    typer.echo("=" * 45)

    for entry in all_statuses:
        if not entry.is_authenticated:
            # Unconfigured provider: show how to fix it.
            typer.echo(f" {entry.display_name}: ✗ Not configured")
            typer.echo(f" Run: wafer auth login {entry.provider}")
            typer.echo(f" Or set: {entry.key_url}")
            continue

        origin = f"({entry.source})" if entry.source else ""
        typer.echo(f" {entry.display_name}: ✓ {entry.key_preview} {origin}")

    typer.echo("")
    typer.echo("Note: Environment variables take precedence over stored keys.")
672
+
673
+
399
674
  @app.command(hidden=True)
400
675
  def run(
401
676
  command: str = typer.Argument(..., help="Command to run in Docker container"),
@@ -975,6 +1250,11 @@ def agent( # noqa: PLR0913
975
1250
  "--list-sessions",
976
1251
  help="List recent sessions and exit",
977
1252
  ),
1253
+ get_session: str | None = typer.Option(
1254
+ None,
1255
+ "--get-session",
1256
+ help="Get session by ID and print messages (use with --json)",
1257
+ ),
978
1258
  tools: str | None = typer.Option(
979
1259
  None,
980
1260
  "--tools",
@@ -1021,47 +1301,7 @@ def agent( # noqa: PLR0913
1021
1301
  None,
1022
1302
  "--corpus",
1023
1303
  "-c",
1024
- help="Documentation corpus to use (cuda, cutlass, hip). Must be downloaded first.",
1025
- ),
1026
- # Legacy kernel optimization options (hidden, for backwards compat)
1027
- problem: Path | None = typer.Option(
1028
- None,
1029
- "--problem",
1030
- hidden=True,
1031
- help="[Legacy] Path to problem YAML config file",
1032
- ),
1033
- reference: Path | None = typer.Option(
1034
- None,
1035
- "--reference",
1036
- "--ref",
1037
- hidden=True,
1038
- help="[Legacy] Path to reference kernel file",
1039
- ),
1040
- description: str | None = typer.Option(
1041
- None,
1042
- "--description",
1043
- "--desc",
1044
- hidden=True,
1045
- help="[Legacy] Problem description",
1046
- ),
1047
- test: list[str] | None = typer.Option(
1048
- None,
1049
- "--test",
1050
- hidden=True,
1051
- help="[Legacy] Test case",
1052
- ),
1053
- benchmark: list[str] | None = typer.Option(
1054
- None,
1055
- "--benchmark",
1056
- "-b",
1057
- hidden=True,
1058
- help="[Legacy] Benchmark case",
1059
- ),
1060
- speedup_target: float | None = typer.Option(
1061
- None,
1062
- "--speedup",
1063
- hidden=True,
1064
- help="[Legacy] Speedup target",
1304
+ help="Documentation corpus to use (cuda, cutlass, hip, amd). Must be downloaded first.",
1065
1305
  ),
1066
1306
  ) -> None:
1067
1307
  """AI assistant for GPU kernel development.
@@ -1148,20 +1388,15 @@ def agent( # noqa: PLR0913
1148
1388
  prompt=actual_prompt,
1149
1389
  interactive=use_tui,
1150
1390
  single_turn=single_turn,
1151
- problem=str(problem) if problem else None,
1152
- reference=str(reference) if reference else None,
1153
- description=description,
1154
- tests=list(test) if test else None,
1155
- benchmarks=list(benchmark) if benchmark else None,
1156
1391
  model=model,
1157
- max_turns=max_turns,
1158
- speedup_target=speedup_target,
1159
1392
  resume=resume,
1160
1393
  from_turn=from_turn,
1161
1394
  list_sessions=list_sessions,
1395
+ get_session=get_session,
1162
1396
  tools=tools.split(",") if tools else None,
1163
1397
  allow_spawn=allow_spawn,
1164
1398
  max_tool_fails=max_tool_fails,
1399
+ max_turns=max_turns,
1165
1400
  json_output=json_output,
1166
1401
  template=template,
1167
1402
  template_args=parsed_template_args,
@@ -1171,7 +1406,7 @@ def agent( # noqa: PLR0913
1171
1406
 
1172
1407
  # =============================================================================
1173
1408
  # Evaluate command
1174
- # Hidden aliases for backwards compatibility
1409
+ # Hidden aliases for agent command
1175
1410
  def _make_agent_alias(name: str, doc: str) -> None:
1176
1411
  """Create a hidden alias that delegates to agent()."""
1177
1412
 
@@ -1186,6 +1421,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
1186
1421
  resume: str | None = typer.Option(None, "--resume", "-r"),
1187
1422
  from_turn: int | None = typer.Option(None, "--from-turn"),
1188
1423
  list_sessions: bool = typer.Option(False, "--list-sessions"),
1424
+ get_session: str | None = typer.Option(None, "--get-session"),
1189
1425
  tools: str | None = typer.Option(None, "--tools"),
1190
1426
  allow_spawn: bool = typer.Option(False, "--allow-spawn"),
1191
1427
  max_tool_fails: int | None = typer.Option(None, "--max-tool-fails"),
@@ -1195,12 +1431,6 @@ def _make_agent_alias(name: str, doc: str) -> None:
1195
1431
  template: str | None = typer.Option(None, "--template", "-t"),
1196
1432
  template_args: list[str] | None = typer.Option(None, "--args"),
1197
1433
  corpus: str | None = typer.Option(None, "--corpus"),
1198
- problem: Path | None = typer.Option(None, "--problem", hidden=True),
1199
- reference: Path | None = typer.Option(None, "--reference", hidden=True),
1200
- description: str | None = typer.Option(None, "--description", hidden=True),
1201
- test: list[Path] | None = typer.Option(None, "--test", hidden=True),
1202
- benchmark: list[Path] | None = typer.Option(None, "--benchmark", hidden=True),
1203
- speedup_target: float | None = typer.Option(None, "--speedup-target", hidden=True),
1204
1434
  ) -> None:
1205
1435
  agent(
1206
1436
  prompt=prompt,
@@ -1210,6 +1440,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
1210
1440
  resume=resume,
1211
1441
  from_turn=from_turn,
1212
1442
  list_sessions=list_sessions,
1443
+ get_session=get_session,
1213
1444
  tools=tools,
1214
1445
  allow_spawn=allow_spawn,
1215
1446
  max_tool_fails=max_tool_fails,
@@ -1219,12 +1450,6 @@ def _make_agent_alias(name: str, doc: str) -> None:
1219
1450
  template=template,
1220
1451
  template_args=template_args,
1221
1452
  corpus=corpus,
1222
- problem=problem,
1223
- reference=reference,
1224
- description=description,
1225
- test=test,
1226
- benchmark=benchmark,
1227
- speedup_target=speedup_target,
1228
1453
  )
1229
1454
 
1230
1455
  alias_cmd.__doc__ = doc
@@ -1289,86 +1514,37 @@ def evaluate( # noqa: PLR0913
1289
1514
  --benchmark --defensive
1290
1515
 
1291
1516
  Subcommands:
1292
- make-template Generate template files for this format
1517
+ gpumode Use GPUMode format (functional) - RECOMMENDED
1293
1518
  kernelbench Use KernelBench format (ModelNew class)
1519
+ make-template Generate template files for this format (deprecated)
1294
1520
  """
1295
1521
  # If a subcommand is being invoked, skip the main evaluation logic
1296
1522
  if ctx.invoked_subcommand is not None:
1297
1523
  return
1298
1524
 
1299
- # Validate required args when running evaluation (not subcommands)
1300
- missing_args = []
1301
- if implementation is None:
1302
- missing_args.append("--impl/-i")
1303
- if reference is None:
1304
- missing_args.append("--reference")
1305
- if test_cases is None:
1306
- missing_args.append("--test-cases")
1307
-
1308
- if missing_args:
1309
- typer.echo("Error: Missing required arguments", err=True)
1310
- typer.echo(f" Required: {', '.join(missing_args)}", err=True)
1311
- typer.echo("", err=True)
1312
- typer.echo(
1313
- "Usage: wafer evaluate --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
1314
- err=True,
1315
- )
1316
- typer.echo("", err=True)
1317
- typer.echo("Run 'wafer evaluate --help' for full options.", err=True)
1318
- typer.echo("Run 'wafer evaluate make-template DIR' to generate starter files.", err=True)
1319
- raise typer.Exit(1)
1320
-
1321
- from .evaluate import EvaluateArgs, run_evaluate
1322
-
1323
- args = EvaluateArgs(
1324
- implementation=implementation,
1325
- reference=reference,
1326
- test_cases=test_cases,
1327
- target_name=target or "",
1328
- benchmark=benchmark,
1329
- profile=profile,
1330
- defensive=defensive,
1331
- sync_artifacts=sync_artifacts,
1332
- gpu_id=gpu_id,
1525
+ # Bare 'wafer evaluate' is no longer supported - must use subcommand
1526
+ typer.echo("Error: 'wafer evaluate' requires a subcommand.", err=True)
1527
+ typer.echo("", err=True)
1528
+ typer.echo("Available subcommands:", err=True)
1529
+ typer.echo(
1530
+ " gpumode Evaluate GPUMode format (custom_kernel/ref_kernel functions)", err=True
1333
1531
  )
1334
-
1335
- try:
1336
- # Use trio_asyncio to run async code that uses both trio and asyncio
1337
- # (AsyncSSHClient uses asyncssh which is asyncio-based, bridged via trio_asyncio)
1338
- import trio_asyncio
1339
-
1340
- result = trio_asyncio.run(run_evaluate, args)
1341
- except KeyboardInterrupt:
1342
- typer.echo("\nInterrupted by user", err=True)
1343
- raise typer.Exit(130) from None
1344
- except Exception as e:
1345
- # Unwrap ExceptionGroup (from Trio nurseries) to show actual error
1346
- if hasattr(e, "exceptions") and e.exceptions:
1347
- for exc in e.exceptions:
1348
- typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
1349
- else:
1350
- typer.echo(f"Error: {e}", err=True)
1351
- raise typer.Exit(1) from None
1352
-
1353
- # Print results
1354
- if result.success:
1355
- typer.echo("")
1356
- typer.echo("=" * 60)
1357
- status = "PASS" if result.all_correct else "FAIL"
1358
- typer.echo(f"Result: {status}")
1359
- score_pct = f"{result.correctness_score:.1%}"
1360
- typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
1361
- if result.geomean_speedup > 0:
1362
- typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
1363
- if result.artifact_path:
1364
- typer.echo(f"Artifacts: {result.artifact_path}")
1365
- typer.echo("=" * 60)
1366
-
1367
- if not result.all_correct:
1368
- raise typer.Exit(1)
1369
- else:
1370
- typer.echo(f"Error: {result.error_message}", err=True)
1371
- raise typer.Exit(1)
1532
+ typer.echo(" kernelbench Evaluate KernelBench format (ModelNew class)", err=True)
1533
+ typer.echo("", err=True)
1534
+ typer.echo("Examples:", err=True)
1535
+ typer.echo(
1536
+ " wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json",
1537
+ err=True,
1538
+ )
1539
+ typer.echo(
1540
+ " wafer evaluate kernelbench --impl impl.py --reference ref.py --benchmark", err=True
1541
+ )
1542
+ typer.echo("", err=True)
1543
+ typer.echo(
1544
+ "Run 'wafer evaluate gpumode --help' or 'wafer evaluate kernelbench --help' for options.",
1545
+ err=True,
1546
+ )
1547
+ raise typer.Exit(1)
1372
1548
 
1373
1549
 
1374
1550
  TEMPLATE_KERNEL = '''\
@@ -1503,12 +1679,63 @@ def evaluate_make_template(
1503
1679
  # KernelBench format evaluation
1504
1680
  # =============================================================================
1505
1681
 
1506
- # Path to KernelBench problems (relative to wafer root)
1507
- KERNELBENCH_ROOT = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench"
1682
+
1683
def _get_kernelbench_root() -> Path | None:
    """Locate the KernelBench problem directory.

    Resolution order:
      1. The downloaded problem set (from get_problems_path). Its nested
         "KernelBench" subdirectory is used when present, otherwise the
         download root itself.
      2. A development checkout under research/KernelBench/KernelBench,
         four directories above this file.

    Returns None when neither location is available.
    """
    download_root = get_problems_path("kernelbench")
    if download_root is not None:
        nested = download_root / "KernelBench"
        return nested if nested.exists() else download_root

    # Legacy in-repo location, only relevant when running from a source tree.
    repo_root = Path(__file__).parent.parent.parent.parent
    dev_checkout = repo_root / "research" / "KernelBench" / "KernelBench"
    return dev_checkout if dev_checkout.exists() else None
1699
+
1700
+
1701
@kernelbench_app.command("download")
def kernelbench_download(
    force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
) -> None:
    """Download KernelBench problems from GitHub.

    Downloads the problem set to ~/.cache/wafer/problems/kernelbench/

    Examples:
        wafer evaluate kernelbench download
        wafer evaluate kernelbench download --force # Re-download
    """
    try:
        target_dir = download_problems("kernelbench", force=force, verbose=True)
        # Success summary is emitted inside the try, so any failure while
        # printing is reported the same way as a download failure.
        for line in (
            "",
            f"Problems available at: {target_dir}",
            "Run 'wafer evaluate kernelbench list-problems' to see available problems.",
        ):
            typer.echo(line)
    except Exception as exc:
        # CLI boundary: report any download failure as a clean error exit.
        typer.echo(f"Error downloading problems: {exc}", err=True)
        raise typer.Exit(1) from None
1721
+
1722
+
1723
@kernelbench_app.command("list-problems")
def kernelbench_list_problems() -> None:
    """List available KernelBench problems.

    Examples:
        wafer evaluate kernelbench list-problems
    """
    try:
        list_problems_fn("kernelbench", verbose=True)
    except ValueError as err:
        # The helper signals user-facing problems via ValueError; print its
        # message verbatim to stderr and exit non-zero without a traceback.
        message = str(err)
        typer.echo(message, err=True)
        raise typer.Exit(1) from None
1508
1735
 
1509
1736
 
1510
1737
  @kernelbench_app.callback(invoke_without_command=True)
1511
- def kernelbench_evaluate( # noqa: PLR0913
1738
+ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
1512
1739
  ctx: typer.Context,
1513
1740
  implementation: Path | None = typer.Option(
1514
1741
  None,
@@ -1528,17 +1755,38 @@ def kernelbench_evaluate( # noqa: PLR0913
1528
1755
  help="GPU target name. See 'wafer config targets list' for available targets.",
1529
1756
  autocompletion=complete_target_name,
1530
1757
  ),
1758
+ pool: str | None = typer.Option(
1759
+ None,
1760
+ "--pool",
1761
+ "-p",
1762
+ help="Target pool name. Acquires first available target from the pool. "
1763
+ "Define pools in ~/.wafer/config.toml under [pools.<name>].",
1764
+ ),
1531
1765
  benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
1532
1766
  profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
1533
- inputs: Path | None = typer.Option(None, "--inputs", help="Custom inputs file to override get_inputs()"),
1767
+ inputs: Path | None = typer.Option(
1768
+ None, "--inputs", help="Custom inputs file to override get_inputs()"
1769
+ ),
1534
1770
  seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
1535
1771
  defensive: bool = typer.Option(
1536
1772
  False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
1537
1773
  ),
1774
+ backend: str | None = typer.Option(
1775
+ None,
1776
+ "--backend",
1777
+ help="Kernel backend for static validation (hip, cuda, triton, cute, tilelang, thunderkittens). "
1778
+ "When specified, validates that the implementation uses the correct backend primitives.",
1779
+ ),
1538
1780
  sync_artifacts: bool = typer.Option(
1539
1781
  True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
1540
1782
  ),
1541
1783
  gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
1784
+ json_output: bool = typer.Option(
1785
+ False, "--json", help="Output as single JSON object (machine-readable)"
1786
+ ),
1787
+ jsonl_output: bool = typer.Option(
1788
+ False, "--jsonl", help="Output as streaming JSON Lines (one object per event)"
1789
+ ),
1542
1790
  ) -> None:
1543
1791
  """Run kernel evaluation in KernelBench format (ModelNew class).
1544
1792
 
@@ -1588,48 +1836,106 @@ def kernelbench_evaluate( # noqa: PLR0913
1588
1836
  )
1589
1837
  raise typer.Exit(1)
1590
1838
 
1839
+ # Validate --target and --pool are mutually exclusive
1840
+ if target and pool:
1841
+ typer.echo("Error: Cannot specify both --target and --pool", err=True)
1842
+ raise typer.Exit(1)
1843
+
1591
1844
  from .evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench
1845
+ from .output import OutputCollector, format_evaluate_result, get_output_format
1846
+
1847
+ output_format = get_output_format(json_output, jsonl_output)
1848
+ collector = OutputCollector(format=output_format)
1849
+
1850
+ # If pool specified, acquire a target from the pool
1851
+ resolved_target = target or ""
1852
+ pool_lock_context = None
1853
+
1854
+ if pool:
1855
+ from .target_lock import acquire_from_pool
1856
+ from .targets import filter_pool_by_auth, get_pool
1857
+
1858
+ try:
1859
+ pool_targets = get_pool(pool)
1860
+ except FileNotFoundError as e:
1861
+ collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
1862
+ collector.finalize()
1863
+ raise typer.Exit(1) from None
1864
+
1865
+ # Filter to only targets with valid auth
1866
+ usable_targets, skipped = filter_pool_by_auth(pool_targets)
1867
+ if skipped:
1868
+ collector.emit("pool_auth_skip", targets=skipped)
1869
+
1870
+ if not usable_targets:
1871
+ collector.set_error("pool", "NoUsableTargets", pool=pool)
1872
+ collector.finalize()
1873
+ raise typer.Exit(1) from None
1874
+
1875
+ collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
1876
+ pool_lock_context = acquire_from_pool(usable_targets)
1877
+ acquired_target = pool_lock_context.__enter__()
1878
+
1879
+ if acquired_target is None:
1880
+ # Exit context manager before raising to avoid resource leak
1881
+ pool_lock_context.__exit__(None, None, None)
1882
+ collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
1883
+ collector.finalize()
1884
+ raise typer.Exit(1)
1885
+
1886
+ collector.emit("pool_acquired", target=acquired_target)
1887
+ resolved_target = acquired_target
1888
+
1889
+ collector.target = resolved_target
1592
1890
 
1593
1891
  args = KernelBenchEvaluateArgs(
1594
1892
  implementation=implementation,
1595
1893
  reference=reference,
1596
- target_name=target or "",
1894
+ target_name=resolved_target,
1597
1895
  benchmark=benchmark,
1598
1896
  profile=profile,
1599
1897
  inputs=inputs,
1600
1898
  seed=seed,
1601
1899
  defensive=defensive,
1900
+ backend=backend,
1602
1901
  sync_artifacts=sync_artifacts,
1603
1902
  gpu_id=gpu_id,
1604
1903
  )
1605
1904
 
1905
+ collector.emit("started", target=resolved_target)
1906
+
1606
1907
  try:
1607
1908
  import trio_asyncio
1608
1909
 
1910
+ collector.emit("evaluation", status="running")
1609
1911
  result = trio_asyncio.run(run_evaluate_kernelbench, args)
1610
1912
  except KeyboardInterrupt:
1611
- typer.echo("\nInterrupted by user", err=True)
1913
+ collector.set_error("evaluation", "Interrupted", message="Interrupted by user")
1914
+ collector.finalize()
1612
1915
  raise typer.Exit(130) from None
1613
1916
  except Exception as e:
1614
- typer.echo(f"Error: {e}", err=True)
1917
+ collector.set_error("evaluation", "Exception", message=str(e))
1918
+ collector.finalize()
1615
1919
  raise typer.Exit(1) from None
1920
+ finally:
1921
+ # Release pool lock if we acquired one
1922
+ if pool_lock_context is not None:
1923
+ pool_lock_context.__exit__(None, None, None)
1616
1924
 
1617
- # Print results
1925
+ # Build structured output
1926
+ eval_output = format_evaluate_result(result, target=resolved_target)
1927
+ collector._result = eval_output
1928
+
1929
+ # Print results based on output format
1618
1930
  if result.success:
1619
- typer.echo("")
1620
- typer.echo("=" * 60)
1621
- status = "PASS" if result.all_correct else "FAIL"
1622
- typer.echo(f"Result: {status}")
1623
- score_pct = f"{result.correctness_score:.1%}"
1624
- typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
1625
- if result.geomean_speedup > 0:
1626
- typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
1627
- typer.echo("=" * 60)
1931
+ collector.output_text_result(result)
1932
+ collector.finalize()
1628
1933
 
1629
1934
  if not result.all_correct:
1630
1935
  raise typer.Exit(1)
1631
1936
  else:
1632
- typer.echo(f"Error: {result.error_message}", err=True)
1937
+ collector.output_text_error(result.error_message or "Unknown error")
1938
+ collector.finalize()
1633
1939
  raise typer.Exit(1)
1634
1940
 
1635
1941
 
@@ -1659,7 +1965,14 @@ def kernelbench_make_template(
1659
1965
  # Overwrite existing
1660
1966
  wafer evaluate kernelbench make-template level1/1 --force
1661
1967
  """
1662
- # Parse problem ID
1968
+ # Get problems root (downloaded or legacy)
1969
+ kb_root = _get_kernelbench_root()
1970
+ if kb_root is None:
1971
+ typer.echo("Error: KernelBench problems not found.", err=True)
1972
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1973
+ raise typer.Exit(1)
1974
+
1975
+ # Parse problem ID
1663
1976
  parts = problem.split("/")
1664
1977
  if len(parts) != 2:
1665
1978
  typer.echo(f"Error: Invalid problem ID '{problem}'. Expected format: level1/1", err=True)
@@ -1670,10 +1983,10 @@ def kernelbench_make_template(
1670
1983
  level_str = f"level{level_str}"
1671
1984
 
1672
1985
  # Find the problem file
1673
- problem_dir = KERNELBENCH_ROOT / "KernelBench" / level_str
1986
+ problem_dir = kb_root / level_str
1674
1987
  if not problem_dir.exists():
1675
1988
  typer.echo(f"Error: KernelBench level directory not found: {problem_dir}", err=True)
1676
- typer.echo(f"Make sure KernelBench is at: {KERNELBENCH_ROOT}", err=True)
1989
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1677
1990
  raise typer.Exit(1)
1678
1991
 
1679
1992
  # Find matching problem file
@@ -1740,6 +2053,306 @@ def kernelbench_make_template(
1740
2053
  typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
1741
2054
 
1742
2055
 
2056
+ # =============================================================================
2057
+ # GPUMode format evaluation
2058
+ # =============================================================================
2059
+
2060
+
2061
+ @gpumode_app.command("download")
2062
+ def gpumode_download(
2063
+ force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
2064
+ ) -> None:
2065
+ """Download GPUMode reference kernels from GitHub.
2066
+
2067
+ Downloads the problem set to ~/.cache/wafer/problems/gpumode/
2068
+
2069
+ Examples:
2070
+ wafer evaluate gpumode download
2071
+ wafer evaluate gpumode download --force # Re-download
2072
+ """
2073
+ try:
2074
+ path = download_problems("gpumode", force=force, verbose=True)
2075
+ typer.echo("")
2076
+ typer.echo(f"Problems available at: {path}")
2077
+ typer.echo("Run 'wafer evaluate gpumode list-problems' to see available problems.")
2078
+ except Exception as e:
2079
+ typer.echo(f"Error downloading problems: {e}", err=True)
2080
+ raise typer.Exit(1) from None
2081
+
2082
+
2083
+ @gpumode_app.command("list-problems")
2084
+ def gpumode_list_problems() -> None:
2085
+ """List available GPUMode problems.
2086
+
2087
+ Examples:
2088
+ wafer evaluate gpumode list-problems
2089
+ """
2090
+ try:
2091
+ list_problems_fn("gpumode", verbose=True)
2092
+ except ValueError as e:
2093
+ typer.echo(str(e), err=True)
2094
+ raise typer.Exit(1) from None
2095
+
2096
+
2097
+ @gpumode_app.command("make-template")
2098
+ def gpumode_make_template(
2099
+ problem: str = typer.Option(
2100
+ ...,
2101
+ "--problem",
2102
+ "-p",
2103
+ help="Problem ID (e.g., 'pmpp/vectoradd_py' or 'amd/fp8-mm')",
2104
+ ),
2105
+ output: Path = typer.Option(
2106
+ None, "--output", "-o", help="Output directory (default: ./<problem_name>/)"
2107
+ ),
2108
+ force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
2109
+ ) -> None:
2110
+ """Extract a GPUMode problem as template files.
2111
+
2112
+ Creates a directory with reference.py, task.yml, and other problem files.
2113
+ You then create kernel.py with your custom_kernel implementation.
2114
+
2115
+ Examples:
2116
+ # Extract pmpp vectoradd problem
2117
+ wafer evaluate gpumode make-template --problem pmpp/vectoradd_py
2118
+
2119
+ # Extract to specific directory
2120
+ wafer evaluate gpumode make-template --problem pmpp/vectoradd_py --output ./my-kernel/
2121
+ """
2122
+ import shutil
2123
+
2124
+ # Get problem path
2125
+ problem_path = get_problem_path("gpumode", problem)
2126
+ if problem_path is None:
2127
+ # Check if problems are downloaded
2128
+ if get_problems_path("gpumode") is None:
2129
+ typer.echo("Error: GPUMode problems not downloaded.", err=True)
2130
+ typer.echo("Run 'wafer evaluate gpumode download' first.", err=True)
2131
+ else:
2132
+ typer.echo(f"Error: Problem '{problem}' not found.", err=True)
2133
+ typer.echo(
2134
+ "Run 'wafer evaluate gpumode list-problems' to see available problems.", err=True
2135
+ )
2136
+ raise typer.Exit(1)
2137
+
2138
+ # Determine output path
2139
+ if output is None:
2140
+ output = Path.cwd() / problem.replace("/", "_")
2141
+
2142
+ output = output.resolve()
2143
+
2144
+ # Check if exists
2145
+ if output.exists() and not force:
2146
+ typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
2147
+ raise typer.Exit(1)
2148
+
2149
+ # Copy the problem directory
2150
+ if output.exists():
2151
+ shutil.rmtree(output)
2152
+ shutil.copytree(problem_path, output)
2153
+
2154
+ typer.echo(f"Created {output}/")
2155
+ typer.echo("")
2156
+ typer.echo("Contents:")
2157
+ for f in sorted(output.iterdir()):
2158
+ if not f.name.startswith("."):
2159
+ typer.echo(f" {f.name}")
2160
+ typer.echo("")
2161
+ typer.echo("Next steps:")
2162
+ typer.echo(" 1. Read reference.py to understand the kernel interface")
2163
+ typer.echo(" 2. Create kernel.py with your custom_kernel implementation:")
2164
+ typer.echo("")
2165
+ typer.echo(" def custom_kernel(data):")
2166
+ typer.echo(" # Your optimized implementation")
2167
+ typer.echo(" ...")
2168
+ typer.echo("")
2169
+ typer.echo(" 3. Run evaluation:")
2170
+ typer.echo(
2171
+ f" wafer evaluate gpumode --impl {output}/kernel.py --reference {output}/reference.py \\"
2172
+ )
2173
+ typer.echo(f" --test-cases {output}/test_cases.json --target <target>")
2174
+
2175
+
2176
+ @gpumode_app.callback(invoke_without_command=True)
2177
+ def gpumode_evaluate( # noqa: PLR0913, PLR0915
2178
+ ctx: typer.Context,
2179
+ implementation: Path | None = typer.Option(
2180
+ None, "--impl", "-i", help="Path to implementation kernel file"
2181
+ ),
2182
+ reference: Path | None = typer.Option(
2183
+ None, "--reference", help="Path to reference kernel file"
2184
+ ),
2185
+ test_cases: Path | None = typer.Option(
2186
+ None, "--test-cases", help="Path to test cases JSON file"
2187
+ ),
2188
+ target: str | None = typer.Option(
2189
+ None,
2190
+ "--target",
2191
+ "-t",
2192
+ help="GPU target name. See 'wafer config targets list' for available targets.",
2193
+ autocompletion=complete_target_name,
2194
+ ),
2195
+ pool: str | None = typer.Option(
2196
+ None,
2197
+ "--pool",
2198
+ "-p",
2199
+ help="Target pool name. Acquires first available target from the pool. "
2200
+ "Define pools in ~/.wafer/config.toml under [pools.<name>].",
2201
+ ),
2202
+ benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
2203
+ profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
2204
+ defensive: bool = typer.Option(
2205
+ False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
2206
+ ),
2207
+ sync_artifacts: bool = typer.Option(
2208
+ True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
2209
+ ),
2210
+ gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
2211
+ ) -> None:
2212
+ """Run kernel evaluation in GPUMode format (functional).
2213
+
2214
+ This format expects:
2215
+ - Implementation: Python file with `custom_kernel(inputs)` function
2216
+ - Reference: Python file with `ref_kernel(inputs)` and `generate_input(**kwargs)` functions
2217
+ - Test cases: JSON file with test parameters
2218
+
2219
+ Examples:
2220
+ # Basic correctness check
2221
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
2222
+
2223
+ # With benchmarking
2224
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
2225
+ --target vultr-b200 --benchmark
2226
+
2227
+ Subcommands:
2228
+ download Download GPUMode problems from GitHub
2229
+ list-problems List available problems
2230
+ make-template Extract a problem as template files
2231
+ """
2232
+ # If a subcommand is being invoked, skip the main evaluation logic
2233
+ if ctx.invoked_subcommand is not None:
2234
+ return
2235
+
2236
+ # Validate required args when running evaluation (not subcommands)
2237
+ missing_args = []
2238
+ if implementation is None:
2239
+ missing_args.append("--impl/-i")
2240
+ if reference is None:
2241
+ missing_args.append("--reference")
2242
+ if test_cases is None:
2243
+ missing_args.append("--test-cases")
2244
+
2245
+ if missing_args:
2246
+ typer.echo("Error: Missing required arguments", err=True)
2247
+ typer.echo(f" Required: {', '.join(missing_args)}", err=True)
2248
+ typer.echo("", err=True)
2249
+ typer.echo(
2250
+ "Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
2251
+ err=True,
2252
+ )
2253
+ typer.echo("", err=True)
2254
+ typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
2255
+ typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
2256
+ raise typer.Exit(1)
2257
+
2258
+ # Validate --target and --pool are mutually exclusive
2259
+ if target and pool:
2260
+ typer.echo("Error: Cannot specify both --target and --pool", err=True)
2261
+ raise typer.Exit(1)
2262
+
2263
+ from .evaluate import EvaluateArgs, run_evaluate
2264
+
2265
+ # If pool specified, acquire a target from the pool
2266
+ resolved_target = target or ""
2267
+ pool_lock_context = None
2268
+
2269
+ if pool:
2270
+ from .target_lock import acquire_from_pool
2271
+ from .targets import filter_pool_by_auth, get_pool
2272
+
2273
+ try:
2274
+ pool_targets = get_pool(pool)
2275
+ except FileNotFoundError as e:
2276
+ typer.echo(f"Error: {e}", err=True)
2277
+ raise typer.Exit(1) from None
2278
+
2279
+ # Filter to only targets with valid auth
2280
+ usable_targets, skipped = filter_pool_by_auth(pool_targets)
2281
+ if skipped:
2282
+ typer.echo(f"Skipping targets without auth: {', '.join(skipped)}", err=True)
2283
+
2284
+ if not usable_targets:
2285
+ typer.echo(f"Error: No usable targets in pool '{pool}'", err=True)
2286
+ typer.echo(" All targets require authentication that is not configured.", err=True)
2287
+ typer.echo(" Run 'wafer auth status' to see which providers need setup.", err=True)
2288
+ raise typer.Exit(1) from None
2289
+
2290
+ typer.echo(f"Acquiring target from pool '{pool}' ({len(usable_targets)} targets)...")
2291
+ pool_lock_context = acquire_from_pool(usable_targets)
2292
+ acquired_target = pool_lock_context.__enter__()
2293
+
2294
+ if acquired_target is None:
2295
+ # Exit context manager before raising to avoid resource leak
2296
+ pool_lock_context.__exit__(None, None, None)
2297
+ typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
2298
+ typer.echo(f" Targets: {', '.join(usable_targets)}", err=True)
2299
+ raise typer.Exit(1)
2300
+
2301
+ typer.echo(f"Acquired target: {acquired_target}")
2302
+ resolved_target = acquired_target
2303
+
2304
+ args = EvaluateArgs(
2305
+ implementation=implementation,
2306
+ reference=reference,
2307
+ test_cases=test_cases,
2308
+ target_name=resolved_target,
2309
+ benchmark=benchmark,
2310
+ profile=profile,
2311
+ defensive=defensive,
2312
+ sync_artifacts=sync_artifacts,
2313
+ gpu_id=gpu_id,
2314
+ )
2315
+
2316
+ try:
2317
+ import trio_asyncio
2318
+
2319
+ result = trio_asyncio.run(run_evaluate, args)
2320
+ except KeyboardInterrupt:
2321
+ typer.echo("\nInterrupted by user", err=True)
2322
+ raise typer.Exit(130) from None
2323
+ except Exception as e:
2324
+ if hasattr(e, "exceptions") and e.exceptions:
2325
+ for exc in e.exceptions:
2326
+ typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
2327
+ else:
2328
+ typer.echo(f"Error: {e}", err=True)
2329
+ raise typer.Exit(1) from None
2330
+ finally:
2331
+ # Release pool lock if we acquired one
2332
+ if pool_lock_context is not None:
2333
+ pool_lock_context.__exit__(None, None, None)
2334
+
2335
+ # Print results
2336
+ if result.success:
2337
+ typer.echo("")
2338
+ typer.echo("=" * 60)
2339
+ status = "PASS" if result.all_correct else "FAIL"
2340
+ typer.echo(f"Result: {status}")
2341
+ score_pct = f"{result.correctness_score:.1%}"
2342
+ typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
2343
+ if result.geomean_speedup > 0:
2344
+ typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
2345
+ if result.artifact_path:
2346
+ typer.echo(f"Artifacts: {result.artifact_path}")
2347
+ typer.echo("=" * 60)
2348
+
2349
+ if not result.all_correct:
2350
+ raise typer.Exit(1)
2351
+ else:
2352
+ typer.echo(f"Error: {result.error_message}", err=True)
2353
+ raise typer.Exit(1)
2354
+
2355
+
1743
2356
  # =============================================================================
1744
2357
  # Push and Remote-Run commands
1745
2358
  # =============================================================================
@@ -1871,7 +2484,7 @@ def _run_direct_mode(
1871
2484
  typer.echo(f"Uploading {upload_dir.name}...")
1872
2485
  try:
1873
2486
  push_result = push_direct(upload_dir, target)
1874
- workspace_name = push_result.workspace_path
2487
+ workspace_name = push_result.workspace_name
1875
2488
  typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
1876
2489
  except Exception as e:
1877
2490
  typer.echo(f"Error uploading: {e}", err=True)
@@ -1901,6 +2514,7 @@ def _run_api_mode( # noqa: PLR0913
1901
2514
  upload_dir: Path | None,
1902
2515
  workspace_id: str | None,
1903
2516
  gpu_id: int | None,
2517
+ gpu_count: int,
1904
2518
  docker_image: str | None,
1905
2519
  docker_entrypoint: str | None,
1906
2520
  pull_image: bool,
@@ -1915,6 +2529,8 @@ def _run_api_mode( # noqa: PLR0913
1915
2529
  typer.echo(f"Workspace: {workspace_id}")
1916
2530
  if gpu_id is not None:
1917
2531
  typer.echo(f"GPU: {gpu_id}")
2532
+ if gpu_count > 1:
2533
+ typer.echo(f"GPU count: {gpu_count}")
1918
2534
  if docker_image:
1919
2535
  typer.echo(f"Image: {docker_image}")
1920
2536
  if docker_entrypoint:
@@ -1932,6 +2548,7 @@ def _run_api_mode( # noqa: PLR0913
1932
2548
  upload_dir=upload_dir,
1933
2549
  workspace_id=workspace_id,
1934
2550
  gpu_id=gpu_id,
2551
+ gpu_count=gpu_count,
1935
2552
  docker_image=docker_image,
1936
2553
  docker_entrypoint=docker_entrypoint,
1937
2554
  pull_image=pull_image,
@@ -1955,6 +2572,7 @@ def remote_run( # noqa: PLR0913
1955
2572
  None, "--workspace-id", "-w", help="Workspace ID (from wafer push)"
1956
2573
  ),
1957
2574
  gpu_id: int | None = typer.Option(None, "--gpu", "-g", help="GPU ID"),
2575
+ gpu_count: int = typer.Option(1, "--gpu-count", "-n", help="Number of GPUs (1-8)"),
1958
2576
  docker_image: str | None = typer.Option(None, "--image", "-i", help="Docker image override"),
1959
2577
  docker_entrypoint: str | None = typer.Option(
1960
2578
  None, "--docker-entrypoint", help="Override Docker entrypoint (e.g., 'bash')"
@@ -2024,6 +2642,7 @@ def remote_run( # noqa: PLR0913
2024
2642
  upload_dir,
2025
2643
  workspace_id,
2026
2644
  gpu_id,
2645
+ gpu_count,
2027
2646
  docker_image,
2028
2647
  docker_entrypoint,
2029
2648
  pull_image,
@@ -2044,27 +2663,41 @@ def login(
2044
2663
  None, "--token", "-t", help="Access token (skip browser OAuth)"
2045
2664
  ),
2046
2665
  port: int | None = typer.Option(
2047
- None, "--port", "-p", help="Port for OAuth callback server (default: 8765 for SSH, random for local)"
2666
+ None,
2667
+ "--port",
2668
+ "-p",
2669
+ help="Port for OAuth callback server (local only, ignored for SSH)",
2670
+ ),
2671
+ no_device_code: bool = typer.Option(
2672
+ False,
2673
+ "--no-device-code",
2674
+ help="Force browser OAuth even on SSH (requires port forwarding)",
2048
2675
  ),
2049
2676
  ) -> None:
2050
2677
  """Authenticate CLI with wafer-api via GitHub OAuth.
2051
2678
 
2052
- Opens browser for GitHub authentication. Use --token to skip browser.
2679
+ Local: Opens browser for GitHub authentication.
2680
+ SSH: Uses device code flow (no port forwarding needed).
2681
+
2053
2682
  Uses the API environment from config (see 'wafer config show').
2054
2683
 
2055
- SSH Users:
2056
- - Automatically uses port 8765 (just set up port forwarding once)
2057
- - On local machine: ssh -L 8765:localhost:8765 user@host
2058
- - On remote machine: wafer login
2059
- - Browser opens locally, redirect works through tunnel
2684
+ SSH Users (Easiest):
2685
+ - Just run: wafer login
2686
+ - Visit the URL and enter the code shown
2687
+ - No port forwarding needed!
2688
+
2689
+ SSH with browser (Advanced):
2690
+ - Use --no-device-code to force browser flow
2691
+ - Requires: ssh -L 8765:localhost:8765 user@host
2060
2692
 
2061
2693
  Manual token option:
2062
2694
  - Visit auth.wafer.ai, authenticate, copy token from URL
2063
2695
  - Run: wafer login --token <paste-token>
2064
2696
 
2065
2697
  Examples:
2066
- wafer login # auto-detects SSH, uses appropriate port
2067
- wafer login --port 9000 # override port
2698
+ wafer login # device code on SSH, browser on local
2699
+ wafer login --no-device-code # force browser (needs port forwarding on SSH)
2700
+ wafer login --port 9000 # custom port for browser flow
2068
2701
  wafer login --token xyz # manual token (no browser)
2069
2702
 
2070
2703
  # Change environment:
@@ -2073,7 +2706,7 @@ def login(
2073
2706
  """
2074
2707
  import httpx
2075
2708
 
2076
- from .auth import browser_login, save_credentials, verify_token
2709
+ from .auth import browser_login, device_code_login, save_credentials, verify_token
2077
2710
  from .global_config import get_api_url, get_supabase_url, load_global_config
2078
2711
 
2079
2712
  # Show which environment we're logging into
@@ -2083,21 +2716,31 @@ def login(
2083
2716
  typer.echo(f"Auth: {get_supabase_url()}")
2084
2717
  typer.echo("")
2085
2718
 
2086
- # Auto-detect SSH and use fixed port
2087
- if port is None:
2088
- is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
2089
- if is_ssh:
2090
- port = 8765
2091
- typer.echo("🔒 SSH session detected - using port 8765 for OAuth callback")
2092
- typer.echo(" Make sure you have port forwarding set up:")
2093
- typer.echo(" ssh -L 8765:localhost:8765 user@host")
2094
- typer.echo("")
2719
+ # Auto-detect SSH
2720
+ is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
2095
2721
 
2096
- # Browser OAuth if no token provided
2722
+ # Choose auth method
2097
2723
  refresh_token = None
2098
2724
  if token is None:
2099
2725
  try:
2100
- token, refresh_token = browser_login(port=port)
2726
+ if is_ssh and not no_device_code:
2727
+ # Use device code flow for SSH (no port forwarding needed)
2728
+ typer.echo("🔒 SSH session detected - using device code authentication")
2729
+ typer.echo(" (No port forwarding required!)")
2730
+ typer.echo("")
2731
+ token, refresh_token = device_code_login()
2732
+ else:
2733
+ # Use browser OAuth for local or if explicitly requested
2734
+ if is_ssh:
2735
+ typer.echo("🔒 SSH session detected - using browser authentication")
2736
+ typer.echo(" Make sure you have port forwarding set up:")
2737
+ if port is None:
2738
+ port = 8765
2739
+ typer.echo(f" ssh -L {port}:localhost:{port} user@host")
2740
+ else:
2741
+ typer.echo(f" ssh -L {port}:localhost:{port} user@host")
2742
+ typer.echo("")
2743
+ token, refresh_token = browser_login(port=port)
2101
2744
  except TimeoutError as e:
2102
2745
  typer.echo(f"Error: {e}", err=True)
2103
2746
  raise typer.Exit(1) from None
@@ -2146,9 +2789,8 @@ def login(
2146
2789
  @app.command("logout")
2147
2790
  def logout() -> None:
2148
2791
  """Remove stored credentials."""
2149
- from .auth import clear_credentials
2150
-
2151
2792
  from . import analytics
2793
+ from .auth import clear_credentials
2152
2794
 
2153
2795
  # Track logout event first (while credentials still exist for user identification)
2154
2796
  # Note: track_logout() handles the case where user is not logged in
@@ -2625,6 +3267,7 @@ init_app = typer.Typer(
2625
3267
 
2626
3268
  Choose based on your GPU access:
2627
3269
 
3270
+ local GPU on current machine (no SSH)
2628
3271
  ssh Your own hardware via SSH
2629
3272
  runpod RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
2630
3273
  digitalocean DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
@@ -2632,57 +3275,143 @@ Choose based on your GPU access:
2632
3275
  targets_app.add_typer(init_app, name="init")
2633
3276
 
2634
3277
 
2635
- @init_app.command("runpod")
2636
- def init_runpod(
2637
- name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
2638
- gpu_type: str = typer.Option("MI300X", "--gpu", "-g", help="GPU type (MI300X, H100, A100)"),
2639
- ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
2640
- keep_alive: bool = typer.Option(
2641
- True, "--keep-alive/--no-keep-alive", help="Keep pod running after eval"
2642
- ),
3278
+ @init_app.command("local")
3279
+ def init_local(
3280
+ name: str = typer.Option("local", "--name", "-n", help="Target name"),
3281
+ gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
2643
3282
  ) -> None:
2644
- """Initialize a RunPod target.
3283
+ """Initialize a local target for GPU on current machine.
2645
3284
 
2646
- Creates a target config for auto-provisioned RunPod GPUs.
2647
- Requires WAFER_RUNPOD_API_KEY environment variable.
3285
+ Detects your local GPU and configures a target for direct execution
3286
+ (no SSH). Use this when running wafer on the same machine as the GPU.
2648
3287
 
2649
3288
  Examples:
2650
- wafer config targets init runpod
2651
- wafer config targets init runpod --name my-runpod --gpu H100
3289
+ wafer config targets init local
3290
+ wafer config targets init local --name my-5090 --gpu-ids 0,1
2652
3291
  """
2653
- import os
2654
-
2655
3292
  from .targets import save_target
2656
3293
 
2657
- # Check for API key
2658
- api_key = os.environ.get("WAFER_RUNPOD_API_KEY", "")
2659
- if not api_key:
2660
- typer.echo("Error: WAFER_RUNPOD_API_KEY environment variable not set.", err=True)
2661
- typer.echo("", err=True)
2662
- typer.echo("Get your API key from: https://runpod.io/console/user/settings", err=True)
2663
- typer.echo("Then run: export WAFER_RUNPOD_API_KEY=your_key_here", err=True)
2664
- raise typer.Exit(1)
3294
+ # Parse GPU IDs
3295
+ try:
3296
+ parsed_gpu_ids = [int(g.strip()) for g in gpu_ids.split(",")]
3297
+ except ValueError:
3298
+ typer.echo(f"Error: Invalid GPU IDs '{gpu_ids}'. Use comma-separated integers.", err=True)
3299
+ raise typer.Exit(1) from None
2665
3300
 
2666
- # GPU type mappings
2667
- gpu_configs = {
2668
- "MI300X": {
2669
- "gpu_type_id": "AMD Instinct MI300X OAM",
2670
- "image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
2671
- "compute_capability": "9.4",
2672
- },
2673
- "H100": {
2674
- "gpu_type_id": "NVIDIA H100 80GB HBM3",
2675
- "image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
2676
- "compute_capability": "9.0",
2677
- },
2678
- "A100": {
2679
- "gpu_type_id": "NVIDIA A100 80GB PCIe",
2680
- "image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
2681
- "compute_capability": "8.0",
2682
- },
2683
- }
3301
+ typer.echo("Detecting local GPU...")
2684
3302
 
2685
- if gpu_type not in gpu_configs:
3303
+ try:
3304
+ from wafer_core.gpu_detect import (
3305
+ detect_local_gpu,
3306
+ get_compute_capability,
3307
+ get_torch_requirements,
3308
+ )
3309
+
3310
+ detected_gpu = detect_local_gpu()
3311
+
3312
+ if detected_gpu:
3313
+ typer.echo(f" Found: {detected_gpu.gpu_name}")
3314
+ if detected_gpu.vendor == "nvidia":
3315
+ typer.echo(f" CUDA: {detected_gpu.driver_version}")
3316
+ else:
3317
+ typer.echo(f" ROCm: {detected_gpu.driver_version}")
3318
+ typer.echo(f" GPU count: {detected_gpu.gpu_count}")
3319
+
3320
+ # Get torch requirements and compute capability
3321
+ torch_reqs = get_torch_requirements(detected_gpu)
3322
+ compute_capability = get_compute_capability(detected_gpu)
3323
+ gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
3324
+
3325
+ typer.echo(f" PyTorch: {torch_reqs.packages[0]}")
3326
+ else:
3327
+ typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)", err=True)
3328
+ raise typer.Exit(1)
3329
+
3330
+ except ImportError as e:
3331
+ typer.echo(f"Error: Missing dependency: {e}", err=True)
3332
+ raise typer.Exit(1) from None
3333
+
3334
+ # Build target data
3335
+ target_data = {
3336
+ "name": name,
3337
+ "type": "local",
3338
+ "gpu_ids": parsed_gpu_ids,
3339
+ "gpu_type": gpu_type,
3340
+ "compute_capability": compute_capability,
3341
+ "torch_package": torch_reqs.packages[0],
3342
+ "torch_index_url": torch_reqs.index_url,
3343
+ "vendor": detected_gpu.vendor,
3344
+ "driver_version": detected_gpu.driver_version,
3345
+ }
3346
+
3347
+ try:
3348
+ target = save_target(target_data)
3349
+ typer.echo(f"✓ Created target: {target.name}")
3350
+ typer.echo(" Type: Local (no SSH)")
3351
+ typer.echo(f" GPU IDs: {parsed_gpu_ids}")
3352
+ typer.echo(f" GPU Type: {gpu_type}")
3353
+ typer.echo(f" Compute: {compute_capability}")
3354
+ typer.echo(f" Torch: {torch_reqs.packages[0]}")
3355
+ typer.echo("")
3356
+ typer.echo(
3357
+ f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
3358
+ )
3359
+ except (ValueError, AssertionError) as e:
3360
+ typer.echo(f"Error: {e}", err=True)
3361
+ raise typer.Exit(1) from None
3362
+
3363
+
3364
+ @init_app.command("runpod")
3365
+ def init_runpod(
3366
+ name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
3367
+ gpu_type: str = typer.Option("MI300X", "--gpu", "-g", help="GPU type (MI300X, H100, A100)"),
3368
+ ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
3369
+ keep_alive: bool = typer.Option(
3370
+ True, "--keep-alive/--no-keep-alive", help="Keep pod running after eval"
3371
+ ),
3372
+ ) -> None:
3373
+ """Initialize a RunPod target.
3374
+
3375
+ Creates a target config for auto-provisioned RunPod GPUs.
3376
+ Requires WAFER_RUNPOD_API_KEY environment variable.
3377
+
3378
+ Examples:
3379
+ wafer config targets init runpod
3380
+ wafer config targets init runpod --name my-runpod --gpu H100
3381
+ """
3382
+ import os
3383
+
3384
+ from .targets import save_target
3385
+
3386
+ # Check for API key
3387
+ api_key = os.environ.get("WAFER_RUNPOD_API_KEY", "")
3388
+ if not api_key:
3389
+ typer.echo("Error: WAFER_RUNPOD_API_KEY environment variable not set.", err=True)
3390
+ typer.echo("", err=True)
3391
+ typer.echo("Get your API key from: https://runpod.io/console/user/settings", err=True)
3392
+ typer.echo("Then run: export WAFER_RUNPOD_API_KEY=your_key_here", err=True)
3393
+ raise typer.Exit(1)
3394
+
3395
+ # GPU type mappings
3396
+ gpu_configs = {
3397
+ "MI300X": {
3398
+ "gpu_type_id": "AMD Instinct MI300X OAM",
3399
+ "image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
3400
+ "compute_capability": "9.4",
3401
+ },
3402
+ "H100": {
3403
+ "gpu_type_id": "NVIDIA H100 80GB HBM3",
3404
+ "image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
3405
+ "compute_capability": "9.0",
3406
+ },
3407
+ "A100": {
3408
+ "gpu_type_id": "NVIDIA A100 80GB PCIe",
3409
+ "image": "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04",
3410
+ "compute_capability": "8.0",
3411
+ },
3412
+ }
3413
+
3414
+ if gpu_type not in gpu_configs:
2686
3415
  typer.echo(
2687
3416
  f"Error: Unknown GPU type '{gpu_type}'. Available: {', '.join(gpu_configs.keys())}",
2688
3417
  err=True,
@@ -2795,23 +3524,29 @@ def init_ssh(
2795
3524
  host: str = typer.Option(..., "--host", "-H", help="SSH host (user@hostname:port)"),
2796
3525
  ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
2797
3526
  gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
2798
- gpu_type: str = typer.Option(
2799
- "H100", "--gpu-type", help="GPU type (H100, A100, B200, MI300X, etc.)"
3527
+ gpu_type: str | None = typer.Option(
3528
+ None, "--gpu-type", help="GPU type (auto-detected if not specified)"
2800
3529
  ),
2801
3530
  docker_image: str | None = typer.Option(
2802
3531
  None, "--docker-image", "-d", help="Docker image (optional)"
2803
3532
  ),
2804
3533
  ncu: bool = typer.Option(False, "--ncu/--no-ncu", help="NCU profiling available"),
3534
+ no_detect: bool = typer.Option(False, "--no-detect", help="Skip GPU auto-detection"),
2805
3535
  ) -> None:
2806
3536
  """Initialize an SSH target for your own GPU hardware.
2807
3537
 
2808
3538
  Creates a target config for direct SSH access to a GPU machine.
2809
- Use for baremetal servers, VMs, or any machine you have SSH access to.
3539
+ Automatically detects GPU type and selects compatible PyTorch version.
2810
3540
 
2811
3541
  Examples:
3542
+ # Auto-detect GPU (recommended)
2812
3543
  wafer config targets init ssh --name my-gpu --host user@192.168.1.100:22
3544
+
3545
+ # Multiple GPUs with NCU profiling
2813
3546
  wafer config targets init ssh --name lab-h100 --host ubuntu@gpu.lab.com:22 --gpu-ids 0,1 --ncu
2814
- wafer config targets init ssh --name docker-gpu --host user@host:22 --docker-image nvcr.io/nvidia/pytorch:24.01-py3
3547
+
3548
+ # Skip detection, specify manually
3549
+ wafer config targets init ssh --name my-gpu --host user@host:22 --gpu-type H100 --no-detect
2815
3550
  """
2816
3551
  from .targets import save_target
2817
3552
 
@@ -2828,17 +3563,86 @@ def init_ssh(
2828
3563
  typer.echo("Example: user@192.168.1.100:22", err=True)
2829
3564
  raise typer.Exit(1)
2830
3565
 
3566
+ # Auto-detect GPU if not specified
3567
+ detected_gpu = None
3568
+ torch_package = None
3569
+ torch_index_url = None
3570
+
3571
+ if not no_detect:
3572
+ typer.echo(f"Connecting to {host}...")
3573
+ try:
3574
+ import trio
3575
+ import trio_asyncio
3576
+ from wafer_core.async_ssh import AsyncSSHClient
3577
+ from wafer_core.gpu_detect import (
3578
+ detect_remote_gpu,
3579
+ get_compute_capability,
3580
+ get_torch_requirements,
3581
+ )
3582
+
3583
+ expanded_key = str(Path(ssh_key).expanduser())
3584
+
3585
+ async def _detect() -> None:
3586
+ nonlocal detected_gpu, torch_package, torch_index_url
3587
+ # Need trio_asyncio.open_loop() for asyncssh bridge
3588
+ async with trio_asyncio.open_loop():
3589
+ async with AsyncSSHClient(host, expanded_key) as client:
3590
+ detected_gpu = await detect_remote_gpu(client)
3591
+
3592
+ trio.run(_detect)
3593
+
3594
+ if detected_gpu:
3595
+ typer.echo(f" Found: {detected_gpu.gpu_name}")
3596
+ if detected_gpu.vendor == "nvidia":
3597
+ typer.echo(f" CUDA: {detected_gpu.driver_version}")
3598
+ else:
3599
+ typer.echo(f" ROCm: {detected_gpu.driver_version}")
3600
+
3601
+ # Get torch requirements
3602
+ torch_reqs = get_torch_requirements(detected_gpu)
3603
+ torch_package = torch_reqs.packages[0] # Just torch, not all packages
3604
+ torch_index_url = torch_reqs.index_url
3605
+ typer.echo(f" PyTorch: {torch_package}")
3606
+
3607
+ # Use detected GPU type if not specified
3608
+ if not gpu_type:
3609
+ # Extract GPU name (e.g., "H100" from "NVIDIA H100 80GB HBM3")
3610
+ gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
3611
+ else:
3612
+ typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)")
3613
+ if not gpu_type:
3614
+ gpu_type = "H100" # Default fallback
3615
+ typer.echo(f" Using default: {gpu_type}")
3616
+
3617
+ except Exception as e:
3618
+ typer.echo(f" Detection failed: {e}", err=True)
3619
+ if not gpu_type:
3620
+ gpu_type = "H100"
3621
+ typer.echo(f" Using default: {gpu_type}")
3622
+
3623
+ # Fallback if no detection
3624
+ if not gpu_type:
3625
+ gpu_type = "H100"
3626
+
2831
3627
  # Compute capability mappings
2832
- compute_caps = {
2833
- "B200": "10.0",
2834
- "H100": "9.0",
2835
- "A100": "8.0",
2836
- "A10": "8.6",
2837
- "V100": "7.0",
2838
- "MI300X": "9.4",
2839
- "MI250X": "9.0",
2840
- }
2841
- compute_capability = compute_caps.get(gpu_type, "8.0")
3628
+ if detected_gpu:
3629
+ from wafer_core.gpu_detect import get_compute_capability
3630
+
3631
+ compute_capability = get_compute_capability(detected_gpu)
3632
+ else:
3633
+ compute_caps = {
3634
+ "B200": "10.0",
3635
+ "H100": "9.0",
3636
+ "A100": "8.0",
3637
+ "A10": "8.6",
3638
+ "V100": "7.0",
3639
+ "MI300X": "9.4",
3640
+ "MI250X": "9.0",
3641
+ "RTX 5090": "10.0",
3642
+ "RTX 4090": "8.9",
3643
+ "RTX 3090": "8.6",
3644
+ }
3645
+ compute_capability = compute_caps.get(gpu_type, "8.0")
2842
3646
 
2843
3647
  # Build target data
2844
3648
  target_data = {
@@ -2855,6 +3659,12 @@ def init_ssh(
2855
3659
  if docker_image:
2856
3660
  target_data["docker_image"] = docker_image
2857
3661
 
3662
+ # Add torch requirements if detected
3663
+ if torch_package:
3664
+ target_data["torch_package"] = torch_package
3665
+ if torch_index_url:
3666
+ target_data["torch_index_url"] = torch_index_url
3667
+
2858
3668
  try:
2859
3669
  target = save_target(target_data)
2860
3670
  typer.echo(f"✓ Created target: {target.name}")
@@ -2862,9 +3672,12 @@ def init_ssh(
2862
3672
  typer.echo(f" Host: {host}")
2863
3673
  typer.echo(f" GPU IDs: {parsed_gpu_ids}")
2864
3674
  typer.echo(f" GPU Type: {gpu_type}")
3675
+ typer.echo(f" Compute: {compute_capability}")
2865
3676
  typer.echo(f" NCU: {'Yes' if ncu else 'No'}")
2866
3677
  if docker_image:
2867
3678
  typer.echo(f" Docker: {docker_image}")
3679
+ if torch_package:
3680
+ typer.echo(f" Torch: {torch_package}")
2868
3681
  typer.echo("")
2869
3682
  typer.echo(
2870
3683
  f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
@@ -2874,6 +3687,44 @@ def init_ssh(
2874
3687
  raise typer.Exit(1) from None
2875
3688
 
2876
3689
 
3690
+ def _extract_gpu_type(gpu_name: str) -> str:
3691
+ """Extract GPU type from full GPU name.
3692
+
3693
+ Examples:
3694
+ "NVIDIA H100 80GB HBM3" -> "H100"
3695
+ "NVIDIA GeForce RTX 4090" -> "RTX 4090"
3696
+ "AMD Instinct MI300X OAM" -> "MI300X"
3697
+ """
3698
+ gpu_name_upper = gpu_name.upper()
3699
+
3700
+ # Check for known GPU types
3701
+ known_types = [
3702
+ "B200",
3703
+ "B100",
3704
+ "H200",
3705
+ "H100",
3706
+ "A100",
3707
+ "A10",
3708
+ "V100",
3709
+ "RTX 5090",
3710
+ "RTX 5080",
3711
+ "RTX 4090",
3712
+ "RTX 4080",
3713
+ "RTX 3090",
3714
+ "RTX 3080",
3715
+ "MI300X",
3716
+ "MI250X",
3717
+ "MI100",
3718
+ ]
3719
+
3720
+ for gpu_type in known_types:
3721
+ if gpu_type in gpu_name_upper:
3722
+ return gpu_type
3723
+
3724
+ # Fallback: return cleaned name
3725
+ return gpu_name.replace("NVIDIA ", "").replace("AMD ", "").strip()
3726
+
3727
+
2877
3728
  @targets_app.command("add")
2878
3729
  def targets_add(
2879
3730
  file_path: Path = typer.Argument(..., help="Path to target TOML file"),
@@ -2956,6 +3807,93 @@ def targets_show(
2956
3807
  raise typer.Exit(1) from None
2957
3808
 
2958
3809
 
3810
@targets_app.command("probe")
def targets_probe(
    name: str = typer.Argument(..., help="Target name"),
) -> None:
    """Probe a target to discover available compilation backends.

    Connects to the target and checks what's available:
    - Triton
    - torch.compile/inductor
    - HIP/hipcc or CUDA/nvcc
    - ROCm or CUDA version
    - Python packages (torch, triton, etc.)

    Exits with status 1 in two cases: the target config cannot be loaded
    (FileNotFoundError from ``load_target``), or probing raises.

    Example:
        wafer config targets probe runpod-mi300x
    """
    import trio

    from .targets import ProbeError, load_target, probe_target_capabilities

    try:
        target = load_target(name)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    typer.echo(f"Probing target: {name}...")

    try:
        # probe_target_capabilities is an async function; trio.run drives it.
        capabilities = trio.run(probe_target_capabilities, target)
    except ProbeError as e:
        # ProbeError already has actionable context
        typer.echo(f"\nError: {e}", err=True)
        raise typer.Exit(1) from None
    except Exception as e:
        # Unexpected errors - include type for debugging
        typer.echo(f"\nUnexpected error probing target: {type(e).__name__}: {e}", err=True)
        raise typer.Exit(1) from None

    # Display results
    typer.echo(f"\nTarget: {name}")

    if capabilities.get("gpu_name"):
        typer.echo(f" GPU: {capabilities['gpu_name']}")
    if capabilities.get("compute_capability"):
        typer.echo(f" Compute: {capabilities['compute_capability']}")

    typer.echo("\n Compilation Backends:")
    backends = capabilities.get("backends", {})

    # Triton
    triton_ver = backends.get("triton")
    if triton_ver:
        typer.echo(f" ✓ Triton: {triton_ver}")
    else:
        typer.echo(" ✗ Triton: not installed")

    # torch.compile/inductor needs both Triton and PyTorch. Report whichever
    # one is missing instead of always blaming Triton.
    # NOTE(review): this reads backends["torch"], while the PyTorch version
    # below is read from packages["torch"]. Confirm which key
    # probe_target_capabilities actually populates.
    if triton_ver and backends.get("torch"):
        typer.echo(" ✓ torch.compile/inductor: available")
    elif not triton_ver:
        typer.echo(" ✗ torch.compile/inductor: requires Triton")
    else:
        typer.echo(" ✗ torch.compile/inductor: requires PyTorch")

    # HIP/CUDA compiler (hipcc takes precedence when both are reported)
    if backends.get("hipcc"):
        typer.echo(f" ✓ HIP/hipcc: {backends['hipcc']}")
    elif backends.get("nvcc"):
        typer.echo(f" ✓ CUDA/nvcc: {backends['nvcc']}")
    else:
        typer.echo(" ✗ No GPU compiler found")

    # ROCm/CUDA version
    if capabilities.get("rocm_version"):
        typer.echo(f" ROCm: {capabilities['rocm_version']}")
    if capabilities.get("cuda_version"):
        typer.echo(f" CUDA: {capabilities['cuda_version']}")

    typer.echo("\n Python Environment:")
    typer.echo(f" Python: {capabilities.get('python_version', 'unknown')}")

    packages = capabilities.get("packages", {})
    if packages.get("torch"):
        typer.echo(f" PyTorch: {packages['torch']}")
    if triton_ver:
        typer.echo(f" Triton: {triton_ver}")
3895
+
3896
+
2959
3897
  @targets_app.command("remove")
2960
3898
  def targets_remove(
2961
3899
  name: str = typer.Argument(..., help="Target name"),
@@ -3086,6 +4024,92 @@ def targets_pods() -> None:
3086
4024
  typer.echo()
3087
4025
 
3088
4026
 
4027
+ # ── Pool commands ───────────────────────────────────────────────────────────
4028
+
4029
+
4030
@targets_app.command("pool-list")
def targets_pool_list() -> None:
    """Print every configured target pool together with its member targets.

    A pool whose members cannot be resolved is still listed, with the error
    text shown in place of its members. If no pools exist, prints a short
    hint showing how to declare one in the config file.

    Example:
        wafer config targets pool-list
    """
    from .targets import get_pool, list_pools

    pool_names = list_pools()

    if pool_names:
        typer.echo("Configured pools:\n")
        for pool in pool_names:
            # Best-effort listing: one broken pool must not hide the others.
            try:
                members = get_pool(pool)
                typer.echo(f" {pool}: {', '.join(members)}")
            except Exception as err:
                typer.echo(f" {pool}: (error: {err})")
        return

    hint_lines = (
        "No pools configured",
        "",
        "Define pools in ~/.wafer/config.toml:",
        " [pools.my-pool]",
        ' targets = ["target-1", "target-2"]',
    )
    for line in hint_lines:
        typer.echo(line)
4056
+
4057
+
4058
@targets_app.command("pool-create")
def targets_pool_create(
    name: str = typer.Argument(..., help="Pool name"),
    targets: list[str] = typer.Argument(..., help="Target names to include in pool"),
) -> None:
    """Save a pool named NAME containing the given targets (create or update).

    If saving raises FileNotFoundError, the error is printed to stderr and the
    command exits with status 1.

    Example:
        wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2 mi300x-3
    """
    from .targets import save_pool as persist_pool

    try:
        persist_pool(name, targets)
    except FileNotFoundError as exc:
        typer.echo(f"Error: {exc}", err=True)
        raise typer.Exit(1) from None
    else:
        member_count = len(targets)
        typer.echo(f"Pool '{name}' created with {member_count} targets")
4076
+
4077
+
4078
@targets_app.command("pool-status")
def targets_pool_status(
    name: str = typer.Argument(..., help="Pool name"),
) -> None:
    """Report which targets in a pool are locked (busy) and which are free.

    A busy target is annotated with the PID of its lock holder when one is
    returned. A final line shows how many targets are free out of the total.
    An unknown pool (FileNotFoundError) exits with status 1.

    Example:
        wafer config targets pool-status mi300x-pool
    """
    from .target_lock import get_lock_holder, is_target_locked
    from .targets import get_pool

    try:
        members = get_pool(name)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    typer.echo(f"Pool '{name}' ({len(members)} targets):\n")

    free_count = 0
    for member in members:
        if not is_target_locked(member):
            typer.echo(f" [free] {member}")
            free_count += 1
            continue
        holder = get_lock_holder(member)
        suffix = f" (pid {holder})" if holder else ""
        typer.echo(f" [busy] {member}{suffix}")

    typer.echo("")
    typer.echo(f"Available: {free_count}/{len(members)}")
4111
+
4112
+
3089
4113
  # =============================================================================
3090
4114
  # Billing commands
3091
4115
  # =============================================================================
@@ -3119,7 +4143,9 @@ def billing_usage(
3119
4143
  @billing_app.command("topup")
3120
4144
  def billing_topup(
3121
4145
  amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
3122
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
4146
+ no_browser: bool = typer.Option(
4147
+ False, "--no-browser", help="Print URL instead of opening browser"
4148
+ ),
3123
4149
  ) -> None:
3124
4150
  """Add credits to your account.
3125
4151
 
@@ -3165,7 +4191,9 @@ def billing_topup(
3165
4191
 
3166
4192
  @billing_app.command("portal")
3167
4193
  def billing_portal(
3168
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
4194
+ no_browser: bool = typer.Option(
4195
+ False, "--no-browser", help="Print URL instead of opening browser"
4196
+ ),
3169
4197
  ) -> None:
3170
4198
  """Open Stripe billing portal.
3171
4199
 
@@ -3198,6 +4226,81 @@ def billing_portal(
3198
4226
  raise typer.Exit(1) from None
3199
4227
 
3200
4228
 
4229
+ # =============================================================================
4230
+ # SSH Keys commands (BYOK - Bring Your Own Key)
4231
+ # =============================================================================
4232
+
4233
+
4234
@ssh_keys_app.command("list")
def ssh_keys_list(
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
) -> None:
    """List all registered SSH public keys.

    Prints the text returned by ``list_ssh_keys``, formatted as JSON when
    ``--json`` is given. A RuntimeError from the helper is printed to stderr
    and the command exits with status 1.

    Example:
        wafer ssh-keys list
        wafer ssh-keys list --json
    """
    from .ssh_keys import list_ssh_keys

    try:
        result = list_ssh_keys(json_output=json_output)
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        # Same convention as the other commands: the message above is the
        # user-facing report, so the internal exception is not chained.
        raise typer.Exit(1) from None
    typer.echo(result)
4252
+
4253
+
4254
@ssh_keys_app.command("add")
def ssh_keys_add(
    pubkey_path: Path | None = typer.Argument(
        None, help="Path to public key file (auto-detects ~/.ssh/id_ed25519.pub if not specified)"
    ),
    name: str | None = typer.Option(None, "--name", "-n", help="Friendly name for the key"),
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
) -> None:
    """Add an SSH public key.

    If no path is specified, auto-detects keys from ~/.ssh/ in preference order:
    id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.

    Prints the text returned by ``add_ssh_key``. A RuntimeError from the
    helper is printed to stderr and the command exits with status 1.

    Example:
        wafer ssh-keys add                                  # Auto-detect
        wafer ssh-keys add ~/.ssh/id_rsa.pub                # Specific file
        wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
    """
    from .ssh_keys import add_ssh_key

    try:
        result = add_ssh_key(pubkey_path=pubkey_path, name=name, json_output=json_output)
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        # Same convention as the other commands: the message above is the
        # user-facing report, so the internal exception is not chained.
        raise typer.Exit(1) from None
    typer.echo(result)
4280
+
4281
+
4282
@ssh_keys_app.command("remove")
def ssh_keys_remove(
    key_id: str = typer.Argument(..., help="UUID of the SSH key to remove"),
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
) -> None:
    """Remove an SSH public key.

    Get the key ID from 'wafer ssh-keys list'. Prints the text returned by
    ``remove_ssh_key``. A RuntimeError from the helper is printed to stderr
    and the command exits with status 1.

    Example:
        wafer ssh-keys remove abc123-def456-...
    """
    from .ssh_keys import remove_ssh_key

    try:
        result = remove_ssh_key(key_id=key_id, json_output=json_output)
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        # Same convention as the other commands: the message above is the
        # user-facing report, so the internal exception is not chained.
        raise typer.Exit(1) from None
    typer.echo(result)
4302
+
4303
+
3201
4304
  # =============================================================================
3202
4305
  # Workspaces commands
3203
4306
  # =============================================================================
@@ -3226,21 +4329,34 @@ def workspaces_list(
3226
4329
  @workspaces_app.command("create")
3227
4330
  def workspaces_create(
3228
4331
  name: str = typer.Argument(..., help="Workspace name"),
3229
- gpu_type: str = typer.Option("B200", "--gpu", "-g", help="GPU type (default: B200)"),
4332
+ gpu_type: str = typer.Option("B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"),
3230
4333
  image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
4334
+ wait: bool = typer.Option(False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"),
3231
4335
  json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
3232
4336
  ) -> None:
3233
4337
  """Create a new workspace.
3234
4338
 
4339
+ Available GPUs:
4340
+ MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
4341
+ B200 NVIDIA Blackwell B200 (180GB HBM3e, CUDA)
4342
+
3235
4343
  Example:
3236
- wafer workspaces create my-kernel
3237
- wafer workspaces create my-kernel --gpu H100
4344
+ wafer workspaces create my-kernel # B200 (default)
4345
+ wafer workspaces create my-kernel --gpu MI300X # AMD MI300X
4346
+ wafer workspaces create my-kernel --gpu B200 # NVIDIA B200
3238
4347
  wafer workspaces create my-kernel --image pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
4348
+ wafer workspaces create my-kernel --wait
3239
4349
  """
3240
4350
  from .workspaces import create_workspace
3241
4351
 
3242
4352
  try:
3243
- result = create_workspace(name, gpu_type=gpu_type, image=image, json_output=json_output)
4353
+ result = create_workspace(
4354
+ name,
4355
+ gpu_type=gpu_type,
4356
+ image=image,
4357
+ wait=wait,
4358
+ json_output=json_output,
4359
+ )
3244
4360
  typer.echo(result)
3245
4361
  except RuntimeError as e:
3246
4362
  typer.echo(f"Error: {e}", err=True)
@@ -3250,16 +4366,23 @@ def workspaces_create(
3250
4366
  @workspaces_app.command("delete")
3251
4367
  def workspaces_delete(
3252
4368
  workspace_id: str = typer.Argument(..., help="Workspace ID to delete"),
4369
+ yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
3253
4370
  json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
3254
4371
  ) -> None:
3255
4372
  """Delete a workspace.
3256
4373
 
3257
4374
  Example:
3258
4375
  wafer workspaces delete ws_abc123
4376
+ wafer workspaces delete ws_abc123 -y
3259
4377
  """
3260
4378
  from .workspaces import delete_workspace
3261
4379
 
3262
4380
  try:
4381
+ if not yes:
4382
+ confirm = typer.confirm(f"Delete workspace '{workspace_id}'?")
4383
+ if not confirm:
4384
+ typer.echo("Cancelled.")
4385
+ raise typer.Exit(0)
3263
4386
  result = delete_workspace(workspace_id, json_output=json_output)
3264
4387
  typer.echo(result)
3265
4388
  except RuntimeError as e:
@@ -3267,32 +4390,6 @@ def workspaces_delete(
3267
4390
  raise typer.Exit(1) from None
3268
4391
 
3269
4392
 
3270
- @workspaces_app.command("attach")
3271
- def workspaces_attach(
3272
- workspace_id: str = typer.Argument(..., help="Workspace ID to attach to"),
3273
- json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
3274
- ) -> None:
3275
- """Attach to a workspace (get SSH credentials).
3276
-
3277
- This will:
3278
- 1. Start the workspace if needed
3279
- 2. Return SSH connection details
3280
- 3. Save the private key to ~/.wafer/keys/
3281
-
3282
- Example:
3283
- wafer workspaces attach ws_abc123
3284
- wafer workspaces attach ws_abc123 --json
3285
- """
3286
- from .workspaces import attach_workspace
3287
-
3288
- try:
3289
- result = attach_workspace(workspace_id, json_output=json_output)
3290
- typer.echo(result)
3291
- except RuntimeError as e:
3292
- typer.echo(f"Error: {e}", err=True)
3293
- raise typer.Exit(1) from None
3294
-
3295
-
3296
4393
  @workspaces_app.command("show")
3297
4394
  def workspaces_show(
3298
4395
  workspace_id: str = typer.Argument(..., help="Workspace ID to show"),
@@ -3314,12 +4411,19 @@ def workspaces_show(
3314
4411
  raise typer.Exit(1) from None
3315
4412
 
3316
4413
 
3317
- @workspaces_app.command("exec", context_settings={"allow_interspersed_args": False})
4414
+ @workspaces_app.command(
4415
+ "exec",
4416
+ context_settings={
4417
+ "allow_interspersed_args": False,
4418
+ "ignore_unknown_options": True,
4419
+ "allow_extra_args": True,
4420
+ },
4421
+ )
3318
4422
  def workspaces_exec(
4423
+ ctx: typer.Context,
3319
4424
  workspace: str | None = typer.Argument(
3320
4425
  None, help="Workspace name or ID (optional if default set)"
3321
4426
  ),
3322
- command: list[str] = typer.Argument(..., help="Command to execute on GPU"),
3323
4427
  timeout: int | None = typer.Option(
3324
4428
  None,
3325
4429
  "--timeout",
@@ -3332,17 +4436,30 @@ def workspaces_exec(
3332
4436
  "-s",
3333
4437
  help="Sync local directory to workspace before executing",
3334
4438
  ),
4439
+ gpu: bool = typer.Option(False, "--gpu", help="Force GPU routing (default behavior)"),
4440
+ cpu: bool = typer.Option(False, "--cpu", help="Run in workspace container (no GPU)"),
4441
+ baremetal: bool = typer.Option(
4442
+ False, "--baremetal", help="Force baremetal target (for hardware counters like ncu/nsys)"
4443
+ ),
4444
+ pull_image: bool = typer.Option(False, "--pull-image", help="Pull image on target if missing"),
3335
4445
  verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
3336
4446
  quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
3337
4447
  ) -> None:
3338
- """Execute a command in workspace with GPU routing.
4448
+ """Execute a command in workspace.
4449
+
4450
+ By default, auto-detects whether to route to GPU based on the command.
4451
+ Use --gpu, --cpu, or --baremetal to override.
3339
4452
 
3340
- Runs the command on the workspace's configured GPU target (Modal, baremetal, etc.)
3341
- and streams output back. No SSH or zsh plugin required.
4453
+ Routing options:
4454
+ --gpu Force GPU container (Modal or baremetal with GPU)
4455
+ --cpu Run in workspace container directly (no GPU)
4456
+ --baremetal Force baremetal target (for ncu, nsys, hardware counters)
3342
4457
 
3343
4458
  If workspace is not specified, uses the default workspace from config,
3344
4459
  or the only workspace if you have exactly one.
3345
4460
 
4461
+ IMPORTANT: Options must come before the workspace name.
4462
+
3346
4463
  Examples:
3347
4464
  wafer workspaces exec dev -- python train.py
3348
4465
  wafer workspaces exec dev -- python -c "import torch; print(torch.cuda.is_available())"
@@ -3353,6 +4470,49 @@ def workspaces_exec(
3353
4470
  from .global_config import get_defaults, get_preferences
3354
4471
  from .workspaces import exec_command, resolve_workspace, sync_files
3355
4472
 
4473
+ # Enforce option ordering to avoid treating CLI flags as remote commands
4474
+ known_options = {
4475
+ "--timeout",
4476
+ "-t",
4477
+ "--sync",
4478
+ "-s",
4479
+ "--gpu",
4480
+ "--cpu",
4481
+ "--baremetal",
4482
+ "--pull-image",
4483
+ "--verbose",
4484
+ "-v",
4485
+ "--quiet",
4486
+ "-q",
4487
+ "--help",
4488
+ "-h",
4489
+ }
4490
+ for arg in ctx.args:
4491
+ if arg == "--":
4492
+ break
4493
+ if arg in known_options:
4494
+ typer.echo(
4495
+ "Error: options must come before the workspace name. "
4496
+ "Example: wafer workspaces exec --pull-image dev -- python -V",
4497
+ err=True,
4498
+ )
4499
+ raise typer.Exit(1)
4500
+
4501
+ # Validate mutually exclusive routing flags
4502
+ routing_flags = sum([gpu, cpu, baremetal])
4503
+ if routing_flags > 1:
4504
+ typer.echo("Error: --gpu, --cpu, and --baremetal are mutually exclusive", err=True)
4505
+ raise typer.Exit(1)
4506
+
4507
+ # Determine routing (None = auto-detect)
4508
+ routing: str | None = None
4509
+ if gpu:
4510
+ routing = "gpu"
4511
+ elif cpu:
4512
+ routing = "cpu"
4513
+ elif baremetal:
4514
+ routing = "baremetal"
4515
+
3356
4516
  # Resolve workspace (specified, config default, or single workspace)
3357
4517
  try:
3358
4518
  resolved_workspace = resolve_workspace(workspace)
@@ -3377,7 +4537,8 @@ def workspaces_exec(
3377
4537
  show_status = prefs.mode == "explicit"
3378
4538
 
3379
4539
  if show_status:
3380
- typer.echo(f"[wafer] Workspace: {resolved_workspace}", err=True)
4540
+ routing_label = routing or "auto"
4541
+ typer.echo(f"[wafer] Workspace: {resolved_workspace} (routing: {routing_label})", err=True)
3381
4542
 
3382
4543
  # Sync files if requested
3383
4544
  if sync is not None:
@@ -3403,114 +4564,617 @@ def workspaces_exec(
3403
4564
  typer.echo(f"Error: {e}", err=True)
3404
4565
  raise typer.Exit(1) from None
3405
4566
 
4567
+ # Get command from context args (passthrough after --)
4568
+ import shlex
4569
+
4570
+ command = list(ctx.args)
4571
+ if command and command[0] == "--":
4572
+ command = command[1:]
4573
+
4574
+ if not command:
4575
+ typer.echo("Error: No command specified", err=True)
4576
+ raise typer.Exit(1)
4577
+
3406
4578
  if show_status:
3407
4579
  typer.echo(f"[wafer] Executing (timeout: {effective_timeout}s)...", err=True)
3408
4580
 
3409
- # Join command list into shell command string, stripping leading "--" separator
4581
+ # Build command string
4582
+ # Handle two cases:
4583
+ # 1. Single element: user quoted the whole command (e.g., "echo hello world")
4584
+ # -> use directly, don't re-quote
4585
+ # 2. Multiple elements: user passed separate args (e.g., -- python -c "print(1)")
4586
+ # -> use shlex.join to properly quote args with spaces
4587
+ if len(command) == 1:
4588
+ command_str = command[0]
4589
+ else:
4590
+ command_str = shlex.join(command)
4591
+
4592
+ try:
4593
+ exit_code = exec_command(
4594
+ workspace_id=resolved_workspace,
4595
+ command=command_str,
4596
+ timeout_seconds=effective_timeout,
4597
+ routing=routing,
4598
+ pull_image=pull_image,
4599
+ )
4600
+ except RuntimeError as e:
4601
+ typer.echo(f"Error: {e}", err=True)
4602
+ raise typer.Exit(1) from None
4603
+
4604
+ if show_status:
4605
+ typer.echo(f"[wafer] Exit code: {exit_code}", err=True)
4606
+
4607
+ raise typer.Exit(exit_code)
4608
+
4609
+
4610
+ @workspaces_app.command("ssh")
4611
+ def workspaces_ssh(
4612
+ workspace: str | None = typer.Argument(
4613
+ None, help="Workspace name or ID (optional if default set)"
4614
+ ),
4615
+ ) -> None:
4616
+ """SSH into a workspace.
4617
+
4618
+ Uses workspace SSH credentials once the workspace is running.
4619
+ If workspace is not specified, uses the default workspace.
4620
+
4621
+ Examples:
4622
+ wafer workspaces ssh dev
4623
+ wafer workspaces ssh # uses default workspace
4624
+ """
4625
+ import os
4626
+
4627
+ from .workspaces import get_workspace_raw, resolve_workspace
4628
+
4629
+ # Resolve workspace
4630
+ try:
4631
+ resolved_workspace = resolve_workspace(workspace)
4632
+ except RuntimeError as e:
4633
+ typer.echo(f"Error: {e}", err=True)
4634
+ raise typer.Exit(1) from None
4635
+
4636
+ typer.echo(f"Connecting to workspace: {resolved_workspace}...", err=True)
4637
+
4638
+ # Get SSH credentials from workspace
4639
+ try:
4640
+ ws = get_workspace_raw(resolved_workspace)
4641
+ except RuntimeError as e:
4642
+ typer.echo(f"Error: {e}", err=True)
4643
+ raise typer.Exit(1) from None
4644
+
4645
+ from .workspaces import VALID_STATUSES
4646
+
4647
+ workspace_status = ws.get("status")
4648
+ assert workspace_status in VALID_STATUSES, (
4649
+ f"Workspace {resolved_workspace} has invalid status '{workspace_status}'. "
4650
+ f"Valid statuses: {VALID_STATUSES}"
4651
+ )
4652
+
4653
+ if workspace_status != "running":
4654
+ typer.echo(f"Error: Workspace is {workspace_status}. Wait for it to be running.", err=True)
4655
+ raise typer.Exit(1)
4656
+ if not ws.get("ssh_host") or not ws.get("ssh_port") or not ws.get("ssh_user"):
4657
+ typer.echo("Error: SSH credentials not available yet.", err=True)
4658
+ raise typer.Exit(1)
4659
+
4660
+ # Build SSH args - key_path is None for BYOK model (uses default SSH key)
4661
+ ssh_args = ["ssh"]
4662
+ ssh_args.extend([
4663
+ "-p",
4664
+ str(ws.get("ssh_port")),
4665
+ "-o",
4666
+ "StrictHostKeyChecking=no",
4667
+ "-o",
4668
+ "UserKnownHostsFile=/dev/null",
4669
+ f"{ws.get('ssh_user')}@{ws.get('ssh_host')}",
4670
+ ])
4671
+
4672
+ # Replace current process with SSH
4673
+ os.execvp("ssh", ssh_args)
4674
+
4675
+
4676
+ @workspaces_app.command("sync")
4677
+ def workspaces_sync(
4678
+ workspace: str | None = typer.Argument(
4679
+ None, help="Workspace name or ID (optional if default set)"
4680
+ ),
4681
+ path: Path = typer.Argument(..., help="Local file or directory to sync"),
4682
+ verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
4683
+ quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
4684
+ ) -> None:
4685
+ """Sync local files to workspace.
4686
+
4687
+ Uses rsync over SSH to sync files to the workspace's /workspace directory.
4688
+ If workspace is not specified, uses the default workspace.
4689
+
4690
+ Examples:
4691
+ wafer workspaces sync dev ./my-project
4692
+ wafer workspaces sync ./my-project # uses default workspace
4693
+ wafer workspaces sync dev . # sync current directory
4694
+ wafer workspaces sync dev ./script.py # sync single file
4695
+ """
4696
+ from .global_config import get_preferences
4697
+ from .workspaces import resolve_workspace, sync_files
4698
+
4699
+ # Determine verbosity based on mode
4700
+ prefs = get_preferences()
4701
+ if quiet:
4702
+ show_status = False
4703
+ elif verbose:
4704
+ show_status = True
4705
+ else:
4706
+ show_status = prefs.mode == "explicit"
4707
+
4708
+ # Validate path
4709
+ if not path.exists():
4710
+ typer.echo(f"Error: Path not found: {path}", err=True)
4711
+ raise typer.Exit(1)
4712
+
4713
+ # Resolve workspace
4714
+ try:
4715
+ resolved_workspace = resolve_workspace(workspace)
4716
+ except RuntimeError as e:
4717
+ typer.echo(f"Error: {e}", err=True)
4718
+ raise typer.Exit(1) from None
4719
+
4720
+ if show_status:
4721
+ typer.echo(f"[wafer] Syncing {path} to workspace {resolved_workspace}...", err=True)
4722
+
4723
+ def on_progress(msg: str) -> None:
4724
+ if show_status:
4725
+ typer.echo(f"[wafer] {msg}", err=True)
4726
+
4727
+ try:
4728
+ file_count, warning = sync_files(
4729
+ resolved_workspace, path.resolve(), on_progress=on_progress
4730
+ )
4731
+ except RuntimeError as e:
4732
+ typer.echo(f"Error: {e}", err=True)
4733
+ raise typer.Exit(1) from None
4734
+
4735
+
4736
+ # =============================================================================
4737
+ # Target operations commands (exec/ssh/sync)
4738
+ # =============================================================================
4739
+
4740
+
4741
+ @targets_ops_app.command("exec", context_settings={"allow_interspersed_args": False})
4742
+ def targets_exec(
4743
+ target: str = typer.Argument(
4744
+ ...,
4745
+ help="Target name",
4746
+ autocompletion=complete_target_name,
4747
+ ),
4748
+ command: list[str] = typer.Argument(..., help="Command to execute"),
4749
+ timeout: int | None = typer.Option(
4750
+ None,
4751
+ "--timeout",
4752
+ "-t",
4753
+ help="Execution timeout in seconds (default: 300)",
4754
+ ),
4755
+ verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
4756
+ quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
4757
+ ) -> None:
4758
+ """Execute a command on a configured target.
4759
+
4760
+ Provisions the target if needed (RunPod, DigitalOcean), then runs the command via SSH.
4761
+ For cloud targets, the instance is kept alive after execution - use
4762
+ 'wafer config targets cleanup <name>' to terminate.
4763
+
4764
+ Supported targets: RunPod, DigitalOcean, SSH (baremetal/vm).
4765
+ Not supported: Modal (serverless), Local (no SSH), Workspace (use 'wafer workspaces exec').
4766
+
4767
+ Examples:
4768
+ wafer targets exec runpod-mi300x -- python -c "import torch; print(torch.cuda.is_available())"
4769
+ wafer targets exec runpod-mi300x -- rocm-smi
4770
+ wafer targets exec my-ssh-server -- nvidia-smi
4771
+ wafer targets exec runpod-mi300x "echo hello && ls -la" --timeout 60
4772
+ """
4773
+ from .global_config import get_preferences
4774
+ from .targets import load_target
4775
+ from .targets_ops import TargetExecError, exec_on_target_sync, get_target_ssh_info
4776
+
4777
+ # Determine verbosity
4778
+ prefs = get_preferences()
4779
+ if quiet:
4780
+ show_status = False
4781
+ elif verbose:
4782
+ show_status = True
4783
+ else:
4784
+ show_status = prefs.mode == "explicit"
4785
+
4786
+ # Load target
4787
+ try:
4788
+ target_config = load_target(target)
4789
+ except FileNotFoundError as e:
4790
+ typer.echo(f"Error: {e}", err=True)
4791
+ typer.echo("List available targets with: wafer config targets list", err=True)
4792
+ raise typer.Exit(1) from None
4793
+ except ValueError as e:
4794
+ typer.echo(f"Error loading target config: {e}", err=True)
4795
+ raise typer.Exit(1) from None
4796
+
4797
+ if show_status:
4798
+ typer.echo(f"[wafer] Target: {target} ({type(target_config).__name__})", err=True)
4799
+
4800
+ # Get SSH info (may provision)
4801
+ if show_status:
4802
+ typer.echo("[wafer] Connecting to target...", err=True)
4803
+
4804
+ try:
4805
+ ssh_info = trio.run(get_target_ssh_info, target_config)
4806
+ except TargetExecError as e:
4807
+ typer.echo(f"Error: {e}", err=True)
4808
+ raise typer.Exit(1) from None
4809
+
4810
+ if show_status:
4811
+ typer.echo(f"[wafer] Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}", err=True)
4812
+
4813
+ # Build command string
3410
4814
  if isinstance(command, list):
3411
4815
  import shlex
3412
4816
 
3413
- # Remove leading "--" if present (typer passes it through with allow_interspersed_args=False)
4817
+ # Remove leading "--" if present
3414
4818
  if command and command[0] == "--":
3415
4819
  command = command[1:]
3416
- # Use shlex.join to properly quote args containing spaces/special chars
3417
- command_str = shlex.join(command)
4820
+
4821
+ if not command:
4822
+ typer.echo("Error: No command specified", err=True)
4823
+ raise typer.Exit(1)
4824
+
4825
+ if len(command) == 1:
4826
+ command_str = command[0]
4827
+ else:
4828
+ command_str = shlex.join(command)
3418
4829
  else:
3419
4830
  command_str = command
3420
4831
 
4832
+ # Default timeout
4833
+ effective_timeout = timeout if timeout is not None else 300
4834
+
4835
+ if show_status:
4836
+ typer.echo(f"[wafer] Executing (timeout: {effective_timeout}s)...", err=True)
4837
+
4838
+ # Execute
4839
+ try:
4840
+ exit_code = exec_on_target_sync(ssh_info, command_str, effective_timeout)
4841
+ except TargetExecError as e:
4842
+ typer.echo(f"Error: {e}", err=True)
4843
+ raise typer.Exit(1) from None
4844
+
4845
+ if show_status:
4846
+ typer.echo(f"[wafer] Exit code: {exit_code}", err=True)
4847
+
4848
+ raise typer.Exit(exit_code)
4849
+
4850
+
4851
+ @targets_ops_app.command("ssh")
4852
+ def targets_ssh(
4853
+ target: str = typer.Argument(
4854
+ ...,
4855
+ help="Target name",
4856
+ autocompletion=complete_target_name,
4857
+ ),
4858
+ ) -> None:
4859
+ """SSH into a configured target.
4860
+
4861
+ Provisions the target if needed (RunPod, DigitalOcean), then starts an interactive SSH session.
4862
+ For cloud targets, the instance is kept alive - use 'wafer config targets cleanup <name>' to terminate.
4863
+
4864
+ Examples:
4865
+ wafer targets ssh runpod-mi300x
4866
+ wafer targets ssh my-baremetal-server
4867
+ """
4868
+ from .targets import load_target
4869
+ from .targets_ops import TargetExecError, get_target_ssh_info
4870
+
4871
+ # Load target
4872
+ try:
4873
+ target_config = load_target(target)
4874
+ except FileNotFoundError as e:
4875
+ typer.echo(f"Error: {e}", err=True)
4876
+ typer.echo("List available targets with: wafer config targets list", err=True)
4877
+ raise typer.Exit(1) from None
4878
+ except ValueError as e:
4879
+ typer.echo(f"Error loading target config: {e}", err=True)
4880
+ raise typer.Exit(1) from None
4881
+
4882
+ typer.echo(f"Connecting to target: {target}...", err=True)
4883
+
4884
+ # Get SSH info (may provision)
4885
+ try:
4886
+ ssh_info = trio.run(get_target_ssh_info, target_config)
4887
+ except TargetExecError as e:
4888
+ typer.echo(f"Error: {e}", err=True)
4889
+ raise typer.Exit(1) from None
4890
+
4891
+ # Build SSH command
4892
+ ssh_args = [
4893
+ "ssh",
4894
+ "-i",
4895
+ str(ssh_info.key_path),
4896
+ "-p",
4897
+ str(ssh_info.port),
4898
+ "-o",
4899
+ "StrictHostKeyChecking=no",
4900
+ "-o",
4901
+ "UserKnownHostsFile=/dev/null",
4902
+ f"{ssh_info.user}@{ssh_info.host}",
4903
+ ]
4904
+
4905
+ # Replace current process with SSH
4906
+ os.execvp("ssh", ssh_args)
4907
+
4908
+
4909
+ @targets_ops_app.command("sync")
4910
+ def targets_sync(
4911
+ target: str = typer.Argument(
4912
+ ...,
4913
+ help="Target name",
4914
+ autocompletion=complete_target_name,
4915
+ ),
4916
+ path: Path = typer.Argument(..., help="Local file or directory to sync"),
4917
+ dest: str | None = typer.Option(
4918
+ None,
4919
+ "--dest",
4920
+ "-d",
4921
+ help="Remote destination path (default: /tmp/<basename>)",
4922
+ ),
4923
+ verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
4924
+ quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
4925
+ ) -> None:
4926
+ """Sync local files to a configured target.
4927
+
4928
+ Uses rsync over SSH to copy files to the target. Provisions the target if needed.
4929
+
4930
+ Examples:
4931
+ wafer targets sync runpod-mi300x ./my-project
4932
+ wafer targets sync runpod-mi300x ./script.py --dest /workspace/script.py
4933
+ wafer targets sync my-server ./kernels --dest /tmp/kernels
4934
+ """
4935
+ from .global_config import get_preferences
4936
+ from .targets import load_target
4937
+ from .targets_ops import TargetExecError, get_target_ssh_info, sync_to_target
4938
+
4939
+ # Determine verbosity
4940
+ prefs = get_preferences()
4941
+ if quiet:
4942
+ show_status = False
4943
+ elif verbose:
4944
+ show_status = True
4945
+ else:
4946
+ show_status = prefs.mode == "explicit"
4947
+
4948
+ # Validate path
4949
+ if not path.exists():
4950
+ typer.echo(f"Error: Path not found: {path}", err=True)
4951
+ raise typer.Exit(1)
4952
+
4953
+ # Load target
4954
+ try:
4955
+ target_config = load_target(target)
4956
+ except FileNotFoundError as e:
4957
+ typer.echo(f"Error: {e}", err=True)
4958
+ typer.echo("List available targets with: wafer config targets list", err=True)
4959
+ raise typer.Exit(1) from None
4960
+ except ValueError as e:
4961
+ typer.echo(f"Error loading target config: {e}", err=True)
4962
+ raise typer.Exit(1) from None
4963
+
4964
+ if show_status:
4965
+ typer.echo(f"[wafer] Target: {target} ({type(target_config).__name__})", err=True)
4966
+
4967
+ # Get SSH info (may provision)
4968
+ if show_status:
4969
+ typer.echo("[wafer] Connecting to target...", err=True)
4970
+
4971
+ try:
4972
+ ssh_info = trio.run(get_target_ssh_info, target_config)
4973
+ except TargetExecError as e:
4974
+ typer.echo(f"Error: {e}", err=True)
4975
+ raise typer.Exit(1) from None
4976
+
4977
+ if show_status:
4978
+ typer.echo(f"[wafer] Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}", err=True)
4979
+
4980
+ # Sync
4981
+ def on_progress(msg: str) -> None:
4982
+ if show_status:
4983
+ typer.echo(f"[wafer] {msg}", err=True)
4984
+
3421
4985
  try:
3422
- exit_code = exec_command(
3423
- workspace_id=resolved_workspace,
3424
- command=command_str,
3425
- timeout_seconds=effective_timeout,
3426
- )
3427
- except RuntimeError as e:
4986
+ file_count = sync_to_target(ssh_info, path.resolve(), dest, on_progress)
4987
+ except TargetExecError as e:
3428
4988
  typer.echo(f"Error: {e}", err=True)
3429
4989
  raise typer.Exit(1) from None
3430
4990
 
3431
4991
  if show_status:
3432
- typer.echo(f"[wafer] Exit code: {exit_code}", err=True)
3433
-
3434
- raise typer.Exit(exit_code)
4992
+ typer.echo(f"[wafer] Done. Synced {file_count} files.", err=True)
3435
4993
 
3436
4994
 
3437
- @workspaces_app.command("ssh")
3438
- def workspaces_ssh(
3439
- workspace: str | None = typer.Argument(
3440
- None, help="Workspace name or ID (optional if default set)"
3441
- ),
4995
+ @targets_ops_app.command("scp")
4996
+ def targets_scp(
4997
+ source: str = typer.Argument(..., help="Source path (prefix with target: for remote)"),
4998
+ dest: str = typer.Argument(..., help="Destination path (prefix with target: for remote)"),
4999
+ recursive: bool = typer.Option(False, "-r", "--recursive", help="Copy directories recursively"),
5000
+ verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
5001
+ quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
3442
5002
  ) -> None:
3443
- """SSH into a workspace.
5003
+ """Copy files to/from a target using scp-style syntax.
3444
5004
 
3445
- Gets SSH credentials via attach, then execs into SSH.
3446
- If workspace is not specified, uses the default workspace.
5005
+ Use target: prefix to indicate remote paths. Exactly one of source or dest
5006
+ must be remote.
3447
5007
 
3448
5008
  Examples:
3449
- wafer workspaces ssh dev
3450
- wafer workspaces ssh # uses default workspace
5009
+ wafer targets scp runpod-mi300x:/tmp/trace.json ./trace.json # download
5010
+ wafer targets scp ./script.py runpod-mi300x:/tmp/script.py # upload
5011
+ wafer targets scp -r ./kernels runpod-mi300x:/tmp/kernels # upload dir
5012
+ wafer targets scp -r runpod-mi300x:/tmp/results ./results # download dir
3451
5013
  """
3452
- import os
5014
+ from .global_config import get_preferences
5015
+ from .targets import load_target
5016
+ from .targets_ops import TargetExecError, get_target_ssh_info, parse_scp_path, scp_transfer
5017
+
5018
+ # Determine verbosity
5019
+ prefs = get_preferences()
5020
+ if quiet:
5021
+ show_status = False
5022
+ elif verbose:
5023
+ show_status = True
5024
+ else:
5025
+ show_status = prefs.mode == "explicit"
3453
5026
 
3454
- from .workspaces import get_ssh_credentials, resolve_workspace
5027
+ # Parse source and dest
5028
+ source_target, source_path = parse_scp_path(source)
5029
+ dest_target, dest_path = parse_scp_path(dest)
3455
5030
 
3456
- # Resolve workspace
5031
+ # Validate: exactly one must be remote
5032
+ if source_target and dest_target:
5033
+ typer.echo("Error: Both paths are remote. Use ssh to transfer between remotes.", err=True)
5034
+ raise typer.Exit(1)
5035
+
5036
+ if not source_target and not dest_target:
5037
+ typer.echo("Error: Both paths are local. Use regular cp command.", err=True)
5038
+ raise typer.Exit(1)
5039
+
5040
+ # Determine direction and target
5041
+ is_download = source_target is not None
5042
+ target_name = source_target if is_download else dest_target
5043
+
5044
+ # Load target
3457
5045
  try:
3458
- resolved_workspace = resolve_workspace(workspace)
3459
- except RuntimeError as e:
3460
- typer.echo(f"Error: {e}", err=True)
5046
+ target_config = load_target(target_name)
5047
+ except FileNotFoundError:
5048
+ typer.echo(f"Error: Target '{target_name}' not found.", err=True)
5049
+ typer.echo("Run 'wafer config targets list' to see available targets.", err=True)
5050
+ raise typer.Exit(1) from None
5051
+ except ValueError as e:
5052
+ typer.echo(f"Error loading target config: {e}", err=True)
3461
5053
  raise typer.Exit(1) from None
3462
5054
 
3463
- typer.echo(f"Connecting to workspace: {resolved_workspace}...", err=True)
5055
+ # Validate local path exists (for upload)
5056
+ if not is_download:
5057
+ local_path = Path(source_path)
5058
+ if not local_path.exists():
5059
+ typer.echo(f"Error: Local path '{source_path}' does not exist.", err=True)
5060
+ raise typer.Exit(1)
5061
+ if local_path.is_dir() and not recursive:
5062
+ typer.echo(
5063
+ f"Error: '{source_path}' is a directory. Use -r flag for recursive copy.", err=True
5064
+ )
5065
+ raise typer.Exit(1)
3464
5066
 
3465
- # Get SSH credentials (this calls attach)
5067
+ if show_status:
5068
+ typer.echo(f"[wafer] Target: {target_name} ({type(target_config).__name__})", err=True)
5069
+ typer.echo("[wafer] Connecting to target...", err=True)
5070
+
5071
+ # Get SSH info (may provision)
3466
5072
  try:
3467
- creds = get_ssh_credentials(resolved_workspace)
3468
- except RuntimeError as e:
5073
+ ssh_info = trio.run(get_target_ssh_info, target_config)
5074
+ except TargetExecError as e:
3469
5075
  typer.echo(f"Error: {e}", err=True)
3470
5076
  raise typer.Exit(1) from None
3471
5077
 
3472
- # Exec into SSH - replaces this process
3473
- ssh_args = [
3474
- "ssh",
3475
- "-i",
3476
- str(creds.key_path),
3477
- "-p",
3478
- str(creds.port),
3479
- "-o",
3480
- "StrictHostKeyChecking=no",
3481
- "-o",
3482
- "UserKnownHostsFile=/dev/null",
3483
- f"{creds.user}@{creds.host}",
3484
- ]
5078
+ if show_status:
5079
+ typer.echo(f"[wafer] Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}", err=True)
5080
+ direction = "Downloading" if is_download else "Uploading"
5081
+ typer.echo(f"[wafer] {direction}...", err=True)
3485
5082
 
3486
- # Replace current process with SSH
3487
- os.execvp("ssh", ssh_args)
5083
+ # Transfer
5084
+ try:
5085
+ if is_download:
5086
+ scp_transfer(ssh_info, source_path, dest_path, is_download=True, recursive=recursive)
5087
+ else:
5088
+ scp_transfer(ssh_info, source_path, dest_path, is_download=False, recursive=recursive)
5089
+ except TargetExecError as e:
5090
+ typer.echo(f"Error: {e}", err=True)
5091
+ raise typer.Exit(1) from None
3488
5092
 
5093
+ if show_status:
5094
+ typer.echo("[wafer] Done.", err=True)
3489
5095
 
3490
- @workspaces_app.command("sync")
3491
- def workspaces_sync(
3492
- workspace: str | None = typer.Argument(
3493
- None, help="Workspace name or ID (optional if default set)"
5096
+
5097
+ @targets_ops_app.command("ensure")
5098
+ def targets_ensure( # noqa: PLR0915
5099
+ target: str = typer.Argument(
5100
+ None,
5101
+ help="Target name",
5102
+ autocompletion=complete_target_name,
3494
5103
  ),
3495
- path: Path = typer.Argument(..., help="Local file or directory to sync"),
5104
+ tool: str = typer.Argument(None, help="Tool to ensure is installed"),
5105
+ check_only: bool = typer.Option(False, "--check-only", "-c", help="Only check, don't install"),
5106
+ force: bool = typer.Option(False, "--force", "-f", help="Reinstall even if present"),
5107
+ list_tools: bool = typer.Option(False, "--list", "-l", help="List available tools"),
5108
+ timeout: int = typer.Option(300, "--timeout", "-t", help="Installation timeout in seconds"),
3496
5109
  verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
3497
5110
  quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
3498
5111
  ) -> None:
3499
- """Sync local files to workspace.
5112
+ """Ensure a tool is installed on a target.
3500
5113
 
3501
- Uses rsync over SSH to sync files to the workspace's /workspace directory.
3502
- If workspace is not specified, uses the default workspace.
5114
+ Checks if a tool exists on the target and installs it if missing.
5115
+ Useful for profiling tools like rocprof-compute that aren't pre-installed.
3503
5116
 
3504
5117
  Examples:
3505
- wafer workspaces sync dev ./my-project
3506
- wafer workspaces sync ./my-project # uses default workspace
3507
- wafer workspaces sync dev . # sync current directory
3508
- wafer workspaces sync dev ./script.py # sync single file
5118
+ wafer targets ensure runpod-mi300x rocprof-compute
5119
+ wafer targets ensure runpod-mi300x rocprof-compute --check-only
5120
+ wafer targets ensure runpod-mi300x rocprof-compute --force
5121
+ wafer targets ensure --list
3509
5122
  """
3510
5123
  from .global_config import get_preferences
3511
- from .workspaces import resolve_workspace, sync_files
5124
+ from .targets import load_target
5125
+ from .targets_ops import (
5126
+ TOOL_REGISTRY,
5127
+ TargetExecError,
5128
+ ensure_tool,
5129
+ get_target_platform,
5130
+ get_target_ssh_info,
5131
+ )
3512
5132
 
3513
- # Determine verbosity based on mode
5133
+ # Handle --list flag
5134
+ if list_tools:
5135
+ typer.echo("Available tools:\n")
5136
+ typer.echo("AMD tools:")
5137
+ for name, spec in sorted(TOOL_REGISTRY.items()):
5138
+ if spec.platform == "amd":
5139
+ auto = "auto-install" if spec.install_cmd else "manual"
5140
+ typer.echo(f" {name:20} ({auto}) - {spec.description}")
5141
+
5142
+ typer.echo("\nNVIDIA tools:")
5143
+ for name, spec in sorted(TOOL_REGISTRY.items()):
5144
+ if spec.platform == "nvidia":
5145
+ auto = "auto-install" if spec.install_cmd else "manual"
5146
+ typer.echo(f" {name:20} ({auto}) - {spec.description}")
5147
+
5148
+ typer.echo("\nCross-platform:")
5149
+ for name, spec in sorted(TOOL_REGISTRY.items()):
5150
+ if spec.platform == "any":
5151
+ auto = "auto-install" if spec.install_cmd else "manual"
5152
+ typer.echo(f" {name:20} ({auto}) - {spec.description}")
5153
+ return
5154
+
5155
+ # Require target and tool if not listing
5156
+ if not target:
5157
+ typer.echo("Error: Missing argument 'TARGET'", err=True)
5158
+ typer.echo("Usage: wafer targets ensure TARGET TOOL", err=True)
5159
+ typer.echo(" or: wafer targets ensure --list", err=True)
5160
+ raise typer.Exit(1)
5161
+
5162
+ if not tool:
5163
+ typer.echo("Error: Missing argument 'TOOL'", err=True)
5164
+ typer.echo("Usage: wafer targets ensure TARGET TOOL", err=True)
5165
+ typer.echo(" or: wafer targets ensure --list", err=True)
5166
+ raise typer.Exit(1)
5167
+
5168
+ # Check tool exists
5169
+ if tool not in TOOL_REGISTRY:
5170
+ typer.echo(f"Error: Unknown tool '{tool}'", err=True)
5171
+ typer.echo(f"Available tools: {', '.join(sorted(TOOL_REGISTRY.keys()))}", err=True)
5172
+ typer.echo("Run 'wafer targets ensure --list' for details.", err=True)
5173
+ raise typer.Exit(1)
5174
+
5175
+ spec = TOOL_REGISTRY[tool]
5176
+
5177
+ # Determine verbosity
3514
5178
  prefs = get_preferences()
3515
5179
  if quiet:
3516
5180
  show_status = False
@@ -3519,33 +5183,72 @@ def workspaces_sync(
3519
5183
  else:
3520
5184
  show_status = prefs.mode == "explicit"
3521
5185
 
3522
- # Validate path
3523
- if not path.exists():
3524
- typer.echo(f"Error: Path not found: {path}", err=True)
3525
- raise typer.Exit(1)
3526
-
3527
- # Resolve workspace
5186
+ # Load target
3528
5187
  try:
3529
- resolved_workspace = resolve_workspace(workspace)
3530
- except RuntimeError as e:
5188
+ target_config = load_target(target)
5189
+ except FileNotFoundError as e:
3531
5190
  typer.echo(f"Error: {e}", err=True)
5191
+ typer.echo("List available targets with: wafer config targets list", err=True)
5192
+ raise typer.Exit(1) from None
5193
+ except ValueError as e:
5194
+ typer.echo(f"Error loading target config: {e}", err=True)
3532
5195
  raise typer.Exit(1) from None
3533
5196
 
3534
- if show_status:
3535
- typer.echo(f"[wafer] Syncing {path} to workspace {resolved_workspace}...", err=True)
5197
+ # Platform validation
5198
+ platform = get_target_platform(target_config)
5199
+ if spec.platform != "any" and spec.platform != platform:
5200
+ typer.echo(
5201
+ f"Error: {tool} is an {spec.platform.upper()} tool but target '{target}' "
5202
+ f"is {platform.upper()}",
5203
+ err=True,
5204
+ )
5205
+ raise typer.Exit(1)
3536
5206
 
3537
- def on_progress(msg: str) -> None:
3538
- if show_status:
3539
- typer.echo(f"[wafer] {msg}", err=True)
5207
+ if show_status:
5208
+ typer.echo(f"[wafer] Target: {target} ({platform.upper()})", err=True)
5209
+ typer.echo(f"[wafer] Checking for {tool}...", err=True)
3540
5210
 
5211
+ # Get SSH info (may provision)
3541
5212
  try:
3542
- file_count, warning = sync_files(
3543
- resolved_workspace, path.resolve(), on_progress=on_progress
3544
- )
3545
- except RuntimeError as e:
5213
+ ssh_info = trio.run(get_target_ssh_info, target_config)
5214
+ except TargetExecError as e:
3546
5215
  typer.echo(f"Error: {e}", err=True)
3547
5216
  raise typer.Exit(1) from None
3548
5217
 
5218
+ if show_status:
5219
+ typer.echo(f"[wafer] Connected: {ssh_info.user}@{ssh_info.host}:{ssh_info.port}", err=True)
5220
+
5221
+ # Check-only mode
5222
+ if check_only:
5223
+ from .targets_ops import TargetExecError, exec_on_target_sync
5224
+
5225
+ try:
5226
+ exit_code = exec_on_target_sync(ssh_info, spec.check_cmd, timeout_seconds=30)
5227
+ except TargetExecError as e:
5228
+ typer.echo(f"Error: {e}", err=True)
5229
+ raise typer.Exit(1) from None
5230
+ if exit_code == 0:
5231
+ typer.echo(f"{tool} is installed")
5232
+ else:
5233
+ typer.echo(f"{tool} is NOT installed", err=True)
5234
+ raise typer.Exit(1)
5235
+ return
5236
+
5237
+ # Ensure tool is installed
5238
+ result = ensure_tool(ssh_info, tool, force=force, timeout=timeout)
5239
+
5240
+ if result.error:
5241
+ typer.echo(f"Error: {result.error}", err=True)
5242
+ raise typer.Exit(1)
5243
+
5244
+ if result.already_installed:
5245
+ typer.echo(f"{tool} is already installed")
5246
+ elif result.installed:
5247
+ if result.verified:
5248
+ typer.echo(f"{tool} installed successfully")
5249
+ else:
5250
+ typer.echo(f"{tool} installed (verification skipped)")
5251
+
3549
5252
 
3550
5253
  # =============================================================================
3551
5254
  # Perfetto trace analysis commands
@@ -3830,13 +5533,39 @@ def ncu_analyze(
3830
5533
 
3831
5534
 
3832
5535
  # =============================================================================
3833
- # NSYS Analyze command
5536
+ # NSYS commands
3834
5537
  # =============================================================================
3835
5538
 
3836
5539
 
5540
+ @nsys_app.command("check")
5541
+ def nsys_check() -> None:
5542
+ """Check if NSYS (Nsight Systems) is installed and show version.
5543
+
5544
+ NSYS is required for local analysis. If not installed, shows install instructions.
5545
+
5546
+ Examples:
5547
+ wafer nvidia nsys check
5548
+ """
5549
+ from .nsys_analyze import check_nsys_installation
5550
+
5551
+ result = check_nsys_installation()
5552
+
5553
+ if result.installed:
5554
+ typer.echo(f"✓ NSYS installed: {result.path}")
5555
+ if result.version:
5556
+ typer.echo(f" Version: {result.version}")
5557
+ else:
5558
+ typer.echo("✗ NSYS not installed")
5559
+ if result.install_command:
5560
+ typer.echo(f" Install with: {result.install_command}")
5561
+
5562
+
3837
5563
  @nsys_app.command("analyze")
3838
5564
  def nsys_analyze(
3839
5565
  filepath: Path = typer.Argument(..., help="Path to .nsys-rep profile file"),
5566
+ output_dir: Path | None = typer.Option(
5567
+ None, "--output-dir", "-o", help="Output directory for analysis files"
5568
+ ),
3840
5569
  json_output: bool = typer.Option(
3841
5570
  False, "--json", help="Output raw JSON instead of formatted text"
3842
5571
  ),
@@ -3845,6 +5574,12 @@ def nsys_analyze(
3845
5574
  "--remote/--local",
3846
5575
  help="Force remote (via API) or local analysis. Default: auto-detect (remote if nsys not installed locally)",
3847
5576
  ),
5577
+ target: str | None = typer.Option(
5578
+ None,
5579
+ "--target",
5580
+ "-t",
5581
+ help="Remote target: 'workspace:id' for workspace execution, or target name from ~/.wafer/targets/",
5582
+ ),
3848
5583
  ) -> None:
3849
5584
  """Analyze an NVIDIA Nsight Systems profile (.nsys-rep file).
3850
5585
 
@@ -3853,10 +5588,20 @@ def nsys_analyze(
3853
5588
  By default, uses local nsys if available, otherwise runs analysis
3854
5589
  remotely via wafer-api (requires authentication: wafer login).
3855
5590
 
5591
+ Supports multiple execution modes:
5592
+ - Local: Uses local nsys CLI (no GPU required for analysis)
5593
+ - Remote API: Uploads file and runs analysis on Modal
5594
+ - Workspace: Runs analysis on a Wafer workspace via SSH
5595
+ - Target: Runs analysis on a configured target machine via SSH
5596
+
3856
5597
  Examples:
3857
5598
  wafer nvidia nsys analyze profile.nsys-rep
3858
5599
  wafer nvidia nsys analyze profile.nsys-rep --json
5600
+ wafer nvidia nsys analyze profile.nsys-rep --local
3859
5601
  wafer nvidia nsys analyze profile.nsys-rep --remote
5602
+ wafer nvidia nsys analyze profile.nsys-rep --target workspace:abc123
5603
+ wafer nvidia nsys analyze profile.nsys-rep --target vultr-b200
5604
+ wafer nvidia nsys analyze profile.nsys-rep -o ./results/
3860
5605
  """
3861
5606
  from .nsys_analyze import analyze_nsys_profile
3862
5607
 
@@ -3868,11 +5613,20 @@ def nsys_analyze(
3868
5613
  typer.echo(f"Error: Expected .nsys-rep file, got: {filepath.suffix}", err=True)
3869
5614
  raise typer.Exit(1)
3870
5615
 
5616
+ # Warn if both remote flag and target are specified
5617
+ if target and remote is not None:
5618
+ typer.echo(
5619
+ "Warning: --target overrides --remote/--local flag",
5620
+ err=True,
5621
+ )
5622
+
3871
5623
  try:
3872
5624
  result = analyze_nsys_profile(
3873
5625
  filepath,
3874
5626
  json_output=json_output,
3875
5627
  remote=remote,
5628
+ target=target,
5629
+ output_dir=output_dir,
3876
5630
  )
3877
5631
  typer.echo(result)
3878
5632
  except FileNotFoundError as e:
@@ -3883,6 +5637,150 @@ def nsys_analyze(
3883
5637
  raise typer.Exit(1) from None
3884
5638
 
3885
5639
 
5640
def _join_nsys_command(command: list[str] | str) -> str:
    """Collapse the positional COMMAND arguments into one shell command string.

    A leading ``--`` separator is dropped. A single remaining argument is used
    verbatim so users can pass an already-quoted command such as
    ``"python train.py --lr 1e-3"``; several arguments are re-quoted with
    :func:`shlex.join` so arguments containing spaces survive the shell.
    Returns an empty string when no command remains.
    """
    import shlex

    if isinstance(command, str):
        # Direct (non-Typer) callers may already hand us a command string.
        return command
    if command and command[0] == "--":
        command = command[1:]
    if len(command) == 1:
        return command[0]
    return shlex.join(command)


def _split_trace_option(trace: str | None) -> list[str] | None:
    """Parse the comma-separated ``--trace`` value into a list of API names.

    Surrounding whitespace is stripped and empty entries are dropped, so
    ``"cuda, nvtx,"`` yields ``["cuda", "nvtx"]``. Returns None (let the
    profiler use its default) when no usable entry is present.
    """
    if not trace:
        return None
    apis = [api.strip() for api in trace.split(",") if api.strip()]
    return apis or None


@nsys_app.command("profile", context_settings={"allow_interspersed_args": False})
def nsys_profile(
    command: list[str] = typer.Argument(..., help="Command to profile"),
    output: str = typer.Option(
        "profile",
        "--output",
        "-o",
        help="Output filename (without .nsys-rep extension)",
    ),
    trace: str | None = typer.Option(
        None,
        "--trace",
        "-t",
        help="Trace APIs to capture (comma-separated: cuda,nvtx,osrt,cudnn,cublas). Default: cuda",
    ),
    duration: int | None = typer.Option(
        None,
        "--duration",
        "-d",
        help="Maximum profiling duration in seconds",
    ),
    target: str | None = typer.Option(
        None,
        "--target",
        help="Remote target: 'workspace:id' for workspace execution, or target name from ~/.wafer/targets/",
    ),
    analyze: bool = typer.Option(
        False,
        "--analyze",
        "-a",
        help="Automatically analyze the profile after completion",
    ),
    json_output: bool = typer.Option(
        False,
        "--json",
        help="Output analysis as JSON (only with --analyze)",
    ),
    verbose: bool = typer.Option(
        False,
        "--verbose",
        "-v",
        help="Show verbose progress messages",
    ),
    extra_args: str | None = typer.Option(
        None,
        "--extra",
        help="Extra arguments to pass to nsys profile",
    ),
) -> None:
    """Profile a command with NVIDIA Nsight Systems.

    Runs nsys profile on the specified command and generates a .nsys-rep file.
    Profiling requires an NVIDIA GPU. Use --target to run on a remote GPU server
    or workspace.

    Examples:
        wafer nvidia nsys profile -- python train.py
        wafer nvidia nsys profile -o gemm_profile -- ./gemm_kernel
        wafer nvidia nsys profile --trace cuda,nvtx -- python model.py
        wafer nvidia nsys profile --duration 60 -- ./long_running_app
        wafer nvidia nsys profile --target workspace:abc123 -- python test.py
        wafer nvidia nsys profile --target vultr-b200 -- ./benchmark
        wafer nvidia nsys profile --analyze -- python train.py
        wafer nvidia nsys profile --analyze --json -- ./kernel > results.json
    """
    from .nsys_analyze import _parse_target
    from .nsys_profile import (
        NSYSProfileOptions,
        profile_and_analyze,
        profile_local,
        profile_remote_ssh,
        profile_workspace,
    )

    command_str = _join_nsys_command(command)
    if not command_str:
        typer.echo("Error: No command specified", err=True)
        raise typer.Exit(1)

    # --json only shapes the analysis output; tell the user instead of silently ignoring it.
    if json_output and not analyze:
        typer.echo("Warning: --json has no effect without --analyze", err=True)

    options = NSYSProfileOptions(
        command=command_str,
        output=output,
        trace=_split_trace_option(trace),
        duration=duration,
        extra_args=extra_args,
    )

    if verbose:
        typer.echo(f"[nsys] Command: {command_str}", err=True)
        if target:
            typer.echo(f"[nsys] Target: {target}", err=True)

    # Dispatch: combined profile+analysis, remote target, or local nsys.
    analysis_result = None
    if analyze:
        profile_result, analysis_result = profile_and_analyze(
            options,
            target=target,
            json_output=json_output,
            verbose=verbose,
        )
    elif target:
        target_type, target_id = _parse_target(target)
        if target_type == "workspace":
            profile_result = profile_workspace(target_id, options, verbose=verbose)
        else:
            profile_result = profile_remote_ssh(target_id, options, verbose=verbose)
    else:
        profile_result = profile_local(options, verbose=verbose)

    if not profile_result.success:
        typer.echo(f"Error: {profile_result.error}", err=True)
        if profile_result.stderr:
            typer.echo(f"stderr: {profile_result.stderr}", err=True)
        raise typer.Exit(1)

    if not analyze:
        typer.echo(f"Profile created: {profile_result.output_path}")
    elif verbose:
        # With --analyze, stdout carries the analysis (e.g. `--json > results.json`),
        # so this progress note goes to stderr like the other verbose messages.
        typer.echo(f"Profile created: {profile_result.output_path}", err=True)

    # NOTE(review): the analysis text itself is not echoed here; presumably
    # profile_and_analyze emits it — confirm against nsys_profile.py.
    if analysis_result is not None and not analysis_result.success:
        typer.echo(f"Analysis error: {analysis_result.error}", err=True)
        raise typer.Exit(1)
5782
+
5783
+
3886
5784
  # =============================================================================
3887
5785
  # ROCprof-Compute commands
3888
5786
  # =============================================================================
@@ -4441,8 +6339,8 @@ def _setup_wafer_core_env() -> None:
4441
6339
  - WAFER_API_URL: If already set, uses that instead of config
4442
6340
  - WAFER_AUTH_TOKEN: If already set, uses that instead of cached token
4443
6341
  """
4444
- from .global_config import get_api_url
4445
6342
  from .auth import get_valid_token
6343
+ from .global_config import get_api_url
4446
6344
 
4447
6345
  # Set API URL (get_api_url already respects WAFER_API_URL env var)
4448
6346
  os.environ["WAFER_API_URL"] = get_api_url()
@@ -4746,8 +6644,8 @@ def capture_command( # noqa: PLR0915
4746
6644
  import os
4747
6645
  import tomllib
4748
6646
 
4749
- from .global_config import get_api_url
4750
6647
  from .auth import get_valid_token
6648
+ from .global_config import get_api_url
4751
6649
 
4752
6650
  # Set environment variables for wafer-core BEFORE importing it
4753
6651
  # wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
@@ -4951,8 +6849,8 @@ def capture_list_command(
4951
6849
  """
4952
6850
  import os
4953
6851
 
4954
- from .global_config import get_api_url
4955
6852
  from .auth import get_valid_token
6853
+ from .global_config import get_api_url
4956
6854
 
4957
6855
  # Set environment variables for wafer-core BEFORE importing it
4958
6856
  os.environ["WAFER_API_URL"] = get_api_url()
@@ -5015,13 +6913,14 @@ def capture_list_command(
5015
6913
 
5016
6914
  @corpus_app.command("download")
5017
6915
  def corpus_download(
5018
- name: str = typer.Argument(..., help="Corpus name (cuda, cutlass, hip)"),
6916
+ name: str = typer.Argument(..., help="Corpus name (cuda, cutlass, hip, amd)"),
5019
6917
  force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
5020
6918
  ) -> None:
5021
6919
  """Download a documentation corpus for agent filesystem access.
5022
6920
 
5023
6921
  Examples:
5024
6922
  wafer corpus download cuda
6923
+ wafer corpus download amd
5025
6924
  wafer corpus download cutlass --force
5026
6925
  """
5027
6926
  from .corpus import CORPORA, download_corpus
@@ -5236,71 +7135,107 @@ def tracelens_collective(
5236
7135
 
5237
7136
 
5238
7137
  # =============================================================================
5239
- # ISA Analysis Commands
7138
+ # Unified ISA Analysis Commands (wafer amd isa ...)
5240
7139
  # =============================================================================
5241
7140
 
5242
7141
 
5243
7142
@isa_app.command("analyze")
def isa_analyze(
    path: Path = typer.Argument(..., help="Path to file or directory to analyze"),
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
    csv_output: bool = typer.Option(False, "--csv", help="Output as CSV"),
    recursive: bool = typer.Option(
        True, "--recursive/--no-recursive", "-r", help="Scan directories recursively"
    ),
    filter_expr: str | None = typer.Option(
        None, "--filter", "-f", help="Filter results (e.g., 'spills > 0')"
    ),
    output_file: Path | None = typer.Option(None, "--output", "-o", help="Write output to file"),
    kernel_index: int = typer.Option(0, "--kernel", "-k", help="Kernel index if multiple in file"),
) -> None:
    """Analyze AMD GPU ISA files (.co, .s, .ll, .ttgir).

    Performs static analysis to extract performance metrics like register
    pressure, spills, MFMA density, and occupancy limits.

    Supports:
        - AMD GPU code objects (.co) - Requires API authentication
        - AMDGCN ISA assembly (.s, .gcn, .asm) - Local parsing
        - LLVM-IR files (.ll) - Local parsing
        - TTGIR files (.ttgir, .ttir, .mlir) - Local parsing

    Examples:
        wafer amd isa analyze kernel.co # Code object (needs login)
        wafer amd isa analyze kernel.s # ISA assembly
        wafer amd isa analyze kernel.s --json # Output as JSON
        wafer amd isa analyze ~/.triton/cache/ --filter 'spills > 0'
        wafer amd isa analyze . -r --csv -o metrics.csv
    """
    from .auth import get_auth_headers
    from .global_config import get_api_url
    from .kernel_scope import analyze_command

    # The two output formats are contradictory; reject the pair up front
    # rather than letting analyze_command silently pick one.
    if json_output and csv_output:
        typer.echo("Error: --json and --csv are mutually exclusive", err=True)
        raise typer.Exit(1)

    # API credentials are forwarded unconditionally; per the help text they
    # are only required for .co files.
    api_url = get_api_url()
    auth_headers = get_auth_headers()

    try:
        output = analyze_command(
            path=str(path),
            json_output=json_output,
            csv_output=csv_output,
            recursive=recursive,
            filter_expr=filter_expr,
            output_file=str(output_file) if output_file else None,
            kernel_index=kernel_index,
            api_url=api_url,
            auth_headers=auth_headers,
        )
        typer.echo(output)
    except Exception as e:
        # CLI boundary: missing files, RuntimeErrors from the analyzer, and any
        # other failure are all reported the same way as a one-line error.
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
5302
7205
 
5303
7206
 
7207
@isa_app.command("metrics")
def isa_metrics() -> None:
    """List available metrics for ISA analysis.

    Shows all metrics that can be extracted from AMD GPU ISA files,
    along with their derivation.

    Examples:
        wafer amd isa metrics
    """
    # The docstring above doubles as the Typer --help text.
    # kernel_scope renders the full metric table; this command only prints it.
    from .kernel_scope import metrics_command as render_metrics_table

    typer.echo(render_metrics_table())
7221
+
7222
+
7223
@isa_app.command("targets")
def isa_targets() -> None:
    """List supported GPU targets and their specifications.

    Shows hardware specs (VGPRs, SGPRs, LDS, etc.) for each supported
    AMD GPU architecture.

    Examples:
        wafer amd isa targets
    """
    # The docstring above doubles as the Typer --help text.
    # kernel_scope builds the per-architecture spec listing; we just print it.
    from .kernel_scope import targets_command as render_target_specs

    typer.echo(render_target_specs())
7237
+
7238
+
5304
7239
  def main() -> None:
5305
7240
  """Entry point for wafer CLI."""
5306
7241
  app()