wafer-cli 0.2.14__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -1,6 +1,8 @@
1
- # ruff: noqa: PLR0913
1
+ # ruff: noqa: PLR0913, E402
2
2
  # PLR0913 (too many arguments) is suppressed because Typer CLI commands
3
3
  # naturally have many parameters - each --flag becomes a function argument.
4
+ # E402 (module level import not at top) is suppressed because we intentionally
5
+ # load .env files before importing other modules that may read env vars.
4
6
  """Wafer CLI - GPU development toolkit for LLM coding agents.
5
7
 
6
8
  Core commands:
@@ -27,6 +29,12 @@ from pathlib import Path
27
29
 
28
30
  import trio
29
31
  import typer
32
+ from dotenv import load_dotenv
33
+
34
+ # Auto-load .env from current directory and ~/.wafer/.env
35
+ # This runs at import time so env vars are available before any config is accessed
36
+ load_dotenv() # cwd/.env
37
+ load_dotenv(Path.home() / ".wafer" / ".env") # ~/.wafer/.env
30
38
 
31
39
  from .config import WaferConfig, WaferEnvironment
32
40
  from .inference import infer_upload_files, resolve_environment
@@ -42,6 +50,7 @@ from .problems import (
42
50
  app = typer.Typer(
43
51
  help="GPU development toolkit for LLM coding agents",
44
52
  no_args_is_help=True,
53
+ pretty_exceptions_show_locals=False, # Don't dump local vars (makes tracebacks huge)
45
54
  )
46
55
 
47
56
  # =============================================================================
@@ -58,11 +67,11 @@ def _show_version() -> None:
58
67
  """Show CLI version and environment, then exit."""
59
68
  from .analytics import _get_cli_version
60
69
  from .global_config import load_global_config
61
-
70
+
62
71
  version = _get_cli_version()
63
72
  config = load_global_config()
64
73
  environment = config.environment
65
-
74
+
66
75
  typer.echo(f"wafer-cli {version} ({environment})")
67
76
  raise typer.Exit()
68
77
 
@@ -110,7 +119,7 @@ def main_callback(
110
119
  if version:
111
120
  _show_version()
112
121
  return
113
-
122
+
114
123
  global _command_start_time, _command_outcome
115
124
  _command_start_time = time.time()
116
125
  _command_outcome = "success" # Default to success, mark failure on exceptions
@@ -121,6 +130,7 @@ def main_callback(
121
130
  analytics.init_analytics()
122
131
 
123
132
  # Install exception hook to catch SystemExit and mark failures
133
+ # Also prints error message FIRST so it's visible even when traceback is truncated
124
134
  original_excepthook = sys.excepthook
125
135
 
126
136
  def custom_excepthook(
@@ -136,7 +146,11 @@ def main_callback(
136
146
  _command_outcome = "failure"
137
147
  else:
138
148
  _command_outcome = "failure"
139
- # Call original excepthook
149
+ # Print error summary FIRST (before traceback) so it's visible even if truncated
150
+ print(
151
+ f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
152
+ )
153
+ # Call original excepthook (prints the full traceback)
140
154
  original_excepthook(exc_type, exc_value, exc_traceback)
141
155
 
142
156
  sys.excepthook = custom_excepthook
@@ -180,11 +194,16 @@ def complete_target_name(incomplete: str) -> list[str]:
180
194
 
181
195
  # =============================================================================
182
196
  # Core subcommand groups (visible in --help)
197
+ #
198
+ # TODO: Further consolidate top-level commands to reduce --help surface area.
199
+ # Candidates:
200
+ # - compare → wafer nvidia compare or keep top-level (cross-platform)
201
+ # - guide/skill/demo → wafer onboard {guide,skill,demo}
183
202
  # =============================================================================
184
203
 
185
204
  # Config management (includes targets as nested subcommand)
186
205
  config_app = typer.Typer(help="Manage CLI configuration and local GPU targets")
187
- app.add_typer(config_app, name="config")
206
+ app.add_typer(config_app, name="config", rich_help_panel="Configuration")
188
207
 
189
208
  # Target management - nested under config
190
209
  targets_app = typer.Typer(
@@ -204,7 +223,7 @@ config_app.add_typer(targets_app, name="targets")
204
223
  workspaces_app = typer.Typer(
205
224
  help="""Manage cloud GPU workspaces for remote development.
206
225
 
207
- Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
226
+ Workspaces are on-demand cloud GPU environments. Requires authentication (wafer auth login).
208
227
 
209
228
  Available GPUs:
210
229
  MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
@@ -217,21 +236,21 @@ Commands:
217
236
  wafer workspaces sync dev ./project # Sync files
218
237
  wafer workspaces delete dev # Clean up"""
219
238
  )
220
- app.add_typer(workspaces_app, name="workspaces")
239
+ app.add_typer(workspaces_app, name="workspaces", rich_help_panel="Infrastructure")
221
240
 
222
- # SSH Key management (BYOK - Bring Your Own Key)
241
+ # SSH Key management (BYOK - Bring Your Own Key) - nested under config
223
242
  ssh_keys_app = typer.Typer(
224
243
  help="""Manage SSH public keys for workspace access.
225
244
 
226
245
  Register your SSH public keys here. These keys are installed in all workspaces
227
246
  you provision, enabling SSH access from any machine with your private key.
228
247
 
229
- wafer ssh-keys list # List registered keys
230
- wafer ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
231
- wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
232
- wafer ssh-keys remove <key-id> # Remove a key"""
248
+ wafer config ssh-keys list # List registered keys
249
+ wafer config ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
250
+ wafer config ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
251
+ wafer config ssh-keys remove <key-id> # Remove a key"""
233
252
  )
234
- app.add_typer(ssh_keys_app, name="ssh-keys")
253
+ config_app.add_typer(ssh_keys_app, name="ssh-keys")
235
254
 
236
255
  # Target operations (exec/ssh/sync on configured targets)
237
256
  targets_ops_app = typer.Typer(
@@ -247,22 +266,48 @@ Useful for exploratory work, debugging, or custom scripts.
247
266
  Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
248
267
  Configure targets with: wafer config targets init ..."""
249
268
  )
250
- app.add_typer(targets_ops_app, name="targets")
269
+ app.add_typer(targets_ops_app, name="targets", rich_help_panel="Infrastructure")
270
+
271
+ # Specs management (new: local TOML configs)
272
+ from wafer.specs_cli import specs_app
273
+
274
+ app.add_typer(specs_app, name="specs", rich_help_panel="Configuration")
275
+
276
+ # Live resource management (new: API-backed commands on `wafer targets`)
277
+ # These become: wafer targets list, wafer targets terminate, etc.
278
+ from wafer.targets_cli import (
279
+ targets_list as _targets_list_cmd,
280
+ )
281
+ from wafer.targets_cli import (
282
+ targets_provision as _targets_provision_cmd,
283
+ )
284
+ from wafer.targets_cli import (
285
+ targets_reconcile as _targets_reconcile_cmd,
286
+ )
287
+ from wafer.targets_cli import (
288
+ targets_terminate as _targets_terminate_cmd,
289
+ )
290
+ from wafer.targets_cli import (
291
+ targets_pools as _targets_pools_cmd,
292
+ )
293
+ from wafer.targets_cli import (
294
+ targets_probe as _targets_probe_cmd,
295
+ )
251
296
 
252
- # Billing management
297
+ # Billing management - nested under config
253
298
  billing_app = typer.Typer(help="Manage billing, credits, and subscription")
254
- app.add_typer(billing_app, name="billing")
299
+ config_app.add_typer(billing_app, name="billing")
255
300
 
256
301
  # Corpus management
257
302
  corpus_app = typer.Typer(help="Download and manage GPU documentation")
258
- app.add_typer(corpus_app, name="corpus")
303
+ app.add_typer(corpus_app, name="corpus", rich_help_panel="Kernel Development")
259
304
 
260
305
  # Evaluate (supports multiple kernel formats)
261
306
  evaluate_app = typer.Typer(
262
307
  help="Test kernel correctness and performance",
263
308
  invoke_without_command=True,
264
309
  )
265
- app.add_typer(evaluate_app, name="evaluate")
310
+ app.add_typer(evaluate_app, name="evaluate", rich_help_panel="Kernel Development")
266
311
 
267
312
  # Nested subcommand for kernelbench format
268
313
  kernelbench_app = typer.Typer(
@@ -291,7 +336,7 @@ app.add_typer(dev_app, name="dev")
291
336
  # =============================================================================
292
337
 
293
338
  nvidia_app = typer.Typer(help="NVIDIA GPU profiling and analysis tools")
294
- app.add_typer(nvidia_app, name="nvidia")
339
+ app.add_typer(nvidia_app, name="nvidia", rich_help_panel="Profiling")
295
340
 
296
341
  # NCU analysis - under nvidia
297
342
  ncu_app = typer.Typer(help="Nsight Compute profile analysis")
@@ -314,18 +359,25 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
314
359
  # =============================================================================
315
360
 
316
361
  amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
317
- app.add_typer(amd_app, name="amd")
362
+ app.add_typer(amd_app, name="amd", rich_help_panel="Profiling")
318
363
 
319
364
  # Unified ISA Analyzer - supports both .co files and Triton artifacts
320
365
  isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
321
366
  amd_app.add_typer(isa_app, name="isa")
322
367
 
368
+ # =============================================================================
369
+ # Trace comparison (wafer compare)
370
+ # =============================================================================
371
+
372
+ compare_app = typer.Typer(help="Compare GPU traces across platforms (AMD vs NVIDIA)")
373
+ app.add_typer(compare_app, name="compare", rich_help_panel="Profiling")
374
+
323
375
  # =============================================================================
324
376
  # Roofline analysis (wafer roofline)
325
377
  # =============================================================================
326
378
 
327
379
 
328
- @app.command("roofline")
380
+ @app.command("roofline", rich_help_panel="Kernel Development")
329
381
  def roofline_cmd(
330
382
  gpu: str | None = typer.Option(
331
383
  None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
@@ -416,7 +468,7 @@ def roofline_cmd(
416
468
  # =============================================================================
417
469
 
418
470
  skill_app = typer.Typer(help="Manage AI coding assistant skills (Claude Code, Codex)")
419
- app.add_typer(skill_app, name="skill")
471
+ app.add_typer(skill_app, name="skill", rich_help_panel="Onboarding")
420
472
 
421
473
 
422
474
  @skill_app.command("install")
@@ -580,18 +632,23 @@ def skill_status() -> None:
580
632
 
581
633
 
582
634
  # =============================================================================
583
- # Provider auth management (wafer auth ...)
635
+ # Authentication (wafer auth ...)
584
636
  # =============================================================================
585
637
 
586
- provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
587
- app.add_typer(provider_auth_app, name="auth")
638
+ auth_app = typer.Typer(help="Authenticate with Wafer and cloud GPU providers")
639
+ app.add_typer(auth_app, name="auth", rich_help_panel="Configuration")
588
640
 
641
+ providers_app = typer.Typer(
642
+ help="Manage API keys for cloud GPU providers (RunPod, DigitalOcean, etc.)"
643
+ )
644
+ auth_app.add_typer(providers_app, name="providers")
589
645
 
590
- @provider_auth_app.command("login")
646
+
647
+ @providers_app.command("login")
591
648
  def provider_auth_login(
592
649
  provider: str = typer.Argument(
593
650
  ...,
594
- help="Provider name: runpod, digitalocean, or modal",
651
+ help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
595
652
  ),
596
653
  api_key: str | None = typer.Option(
597
654
  None,
@@ -600,15 +657,16 @@ def provider_auth_login(
600
657
  help="API key (if not provided, reads from stdin)",
601
658
  ),
602
659
  ) -> None:
603
- """Save API key for a cloud GPU provider.
660
+ """Save API key for a provider.
604
661
 
605
662
  Stores the key in ~/.wafer/auth.json. Environment variables
606
- (e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
663
+ (e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
607
664
 
608
665
  Examples:
609
- wafer auth login runpod --api-key rp_xxx
610
- wafer auth login digitalocean --api-key dop_v1_xxx
611
- echo $API_KEY | wafer auth login runpod
666
+ wafer auth providers login anthropic --api-key sk-ant-xxx
667
+ wafer auth providers login runpod --api-key rp_xxx
668
+ wafer auth providers login openai --api-key sk-xxx
669
+ echo $API_KEY | wafer auth providers login anthropic
612
670
  """
613
671
  import sys
614
672
 
@@ -638,18 +696,18 @@ def provider_auth_login(
638
696
  typer.echo("Stored in: ~/.wafer/auth.json")
639
697
 
640
698
 
641
- @provider_auth_app.command("logout")
699
+ @providers_app.command("logout")
642
700
  def provider_auth_logout(
643
701
  provider: str = typer.Argument(
644
702
  ...,
645
- help="Provider name: runpod, digitalocean, or modal",
703
+ help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
646
704
  ),
647
705
  ) -> None:
648
706
  """Remove stored API key for a cloud GPU provider.
649
707
 
650
708
  Examples:
651
- wafer auth logout runpod
652
- wafer auth logout digitalocean
709
+ wafer auth providers logout runpod
710
+ wafer auth providers logout digitalocean
653
711
  """
654
712
  from wafer_core.auth import PROVIDERS, remove_api_key
655
713
 
@@ -665,7 +723,7 @@ def provider_auth_logout(
665
723
  typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
666
724
 
667
725
 
668
- @provider_auth_app.command("status")
726
+ @providers_app.command("status")
669
727
  def provider_auth_status() -> None:
670
728
  """Show authentication status for all cloud GPU providers.
671
729
 
@@ -673,7 +731,7 @@ def provider_auth_status() -> None:
673
731
  the keys are coming from (environment variable or auth.json).
674
732
 
675
733
  Example:
676
- wafer auth status
734
+ wafer auth providers status
677
735
  """
678
736
  from wafer_core.auth import get_all_auth_status
679
737
 
@@ -688,7 +746,7 @@ def provider_auth_status() -> None:
688
746
  typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
689
747
  else:
690
748
  typer.echo(f" {status.display_name}: ✗ Not configured")
691
- typer.echo(f" Run: wafer auth login {status.provider}")
749
+ typer.echo(f" Run: wafer auth providers login {status.provider}")
692
750
  typer.echo(f" Or set: {status.key_url}")
693
751
 
694
752
  typer.echo("")
@@ -1233,7 +1291,7 @@ def config_show_legacy() -> None:
1233
1291
  config_show_new()
1234
1292
 
1235
1293
 
1236
- @app.command()
1294
+ @app.command(rich_help_panel="Kernel Development")
1237
1295
  def agent( # noqa: PLR0913
1238
1296
  prompt: str | None = typer.Argument(
1239
1297
  None,
@@ -1303,7 +1361,7 @@ def agent( # noqa: PLR0913
1303
1361
  None,
1304
1362
  "--model",
1305
1363
  "-m",
1306
- help="Model override (default: claude-sonnet-4-5)",
1364
+ help="Model override (default: claude-opus-4-5)",
1307
1365
  ),
1308
1366
  json_output: bool = typer.Option(
1309
1367
  False,
@@ -1327,6 +1385,16 @@ def agent( # noqa: PLR0913
1327
1385
  "-c",
1328
1386
  help="Documentation corpus to use (cuda, cutlass, hip, amd). Must be downloaded first.",
1329
1387
  ),
1388
+ no_sandbox: bool = typer.Option(
1389
+ False,
1390
+ "--no-sandbox",
1391
+ help="Disable OS-level sandboxing (YOU accept liability for any damage caused by the agent)",
1392
+ ),
1393
+ no_proxy: bool = typer.Option(
1394
+ False,
1395
+ "--no-proxy",
1396
+ help="Skip wafer proxy, use ANTHROPIC_API_KEY directly",
1397
+ ),
1330
1398
  ) -> None:
1331
1399
  """AI assistant for GPU kernel development.
1332
1400
 
@@ -1408,6 +1476,13 @@ def agent( # noqa: PLR0913
1408
1476
  raise typer.Exit(1) from None
1409
1477
  corpus_path = str(path)
1410
1478
 
1479
+ # Warn user about sandbox disabled
1480
+ if no_sandbox:
1481
+ print(
1482
+ "Warning: Sandbox disabled. You accept liability for any damage caused by the agent.",
1483
+ file=sys.stderr,
1484
+ )
1485
+
1411
1486
  wevin_main(
1412
1487
  prompt=actual_prompt,
1413
1488
  interactive=use_tui,
@@ -1425,6 +1500,8 @@ def agent( # noqa: PLR0913
1425
1500
  template=template,
1426
1501
  template_args=parsed_template_args,
1427
1502
  corpus_path=corpus_path,
1503
+ no_sandbox=no_sandbox,
1504
+ no_proxy=no_proxy,
1428
1505
  )
1429
1506
 
1430
1507
 
@@ -1455,6 +1532,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
1455
1532
  template: str | None = typer.Option(None, "--template", "-t"),
1456
1533
  template_args: list[str] | None = typer.Option(None, "--args"),
1457
1534
  corpus: str | None = typer.Option(None, "--corpus"),
1535
+ no_sandbox: bool = typer.Option(False, "--no-sandbox"),
1458
1536
  ) -> None:
1459
1537
  agent(
1460
1538
  prompt=prompt,
@@ -1474,6 +1552,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
1474
1552
  template=template,
1475
1553
  template_args=template_args,
1476
1554
  corpus=corpus,
1555
+ no_sandbox=no_sandbox,
1477
1556
  )
1478
1557
 
1479
1558
  alias_cmd.__doc__ = doc
@@ -1497,7 +1576,11 @@ def evaluate( # noqa: PLR0913
1497
1576
  None, "--reference", help="Path to reference kernel file"
1498
1577
  ),
1499
1578
  test_cases: Path | None = typer.Option(
1500
- None, "--test-cases", help="Path to test cases JSON file"
1579
+ None,
1580
+ "--test-cases",
1581
+ help="Path to test cases JSON file. "
1582
+ 'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
1583
+ "Run 'wafer evaluate make-template' to generate an example.",
1501
1584
  ),
1502
1585
  target: str | None = typer.Option(
1503
1586
  None,
@@ -1527,20 +1610,20 @@ def evaluate( # noqa: PLR0913
1527
1610
 
1528
1611
  Examples:
1529
1612
  # Basic correctness check
1530
- wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json
1613
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
1531
1614
 
1532
1615
  # With benchmarking on a specific target
1533
- wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
1616
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
1534
1617
  --target vultr-b200 --benchmark
1535
1618
 
1536
1619
  # Full evaluation with defensive timing (detects cheating)
1537
- wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
1620
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
1538
1621
  --benchmark --defensive
1539
1622
 
1540
1623
  Subcommands:
1541
1624
  gpumode Use GPUMode format (functional) - RECOMMENDED
1542
1625
  kernelbench Use KernelBench format (ModelNew class)
1543
- make-template Generate template files for this format (deprecated)
1626
+ make-template Generate template files for this format
1544
1627
  """
1545
1628
  # If a subcommand is being invoked, skip the main evaluation logic
1546
1629
  if ctx.invoked_subcommand is not None:
@@ -1694,7 +1777,7 @@ def evaluate_make_template(
1694
1777
  typer.echo(f" 2. Edit {output_dir / 'reference.py'} with the ground truth + input generator")
1695
1778
  typer.echo(f" 3. Edit {output_dir / 'test_cases.json'} with your test parameters")
1696
1779
  typer.echo(" 4. Run:")
1697
- typer.echo(f" wafer evaluate --impl {output_dir / 'kernel.py'} \\")
1780
+ typer.echo(f" wafer evaluate gpumode --impl {output_dir / 'kernel.py'} \\")
1698
1781
  typer.echo(f" --reference {output_dir / 'reference.py'} \\")
1699
1782
  typer.echo(f" --test-cases {output_dir / 'test_cases.json'} --benchmark")
1700
1783
 
@@ -1758,6 +1841,93 @@ def kernelbench_list_problems() -> None:
1758
1841
  raise typer.Exit(1) from None
1759
1842
 
1760
1843
 
1844
+ def _resolve_pool_query(pool: str, collector) -> tuple[str, object]:
1845
+ """Resolve a PoolQuery pool to a target spec name + lock context.
1846
+
1847
+ Queries live providers, matches by pool query, locks one target,
1848
+ returns (spec_name, lock_context) for the evaluator.
1849
+ """
1850
+ import trio
1851
+ from wafer_core.targets.pool import resolve_pool
1852
+
1853
+ from .target_lock import acquire_from_pool
1854
+
1855
+ matched_targets = trio.run(resolve_pool, pool)
1856
+
1857
+ if not matched_targets:
1858
+ collector.set_error("pool", "NoMatchingTargets", pool=pool)
1859
+ collector.finalize()
1860
+ raise typer.Exit(1)
1861
+
1862
+ # Filter to targets with a spec (evaluator needs spec fields)
1863
+ spec_targets = [t for t in matched_targets if t.spec_name]
1864
+ if not spec_targets:
1865
+ collector.set_error(
1866
+ "pool", "NoSpecTargets", pool=pool,
1867
+ message="Matched targets have no spec binding — evaluator needs spec fields",
1868
+ )
1869
+ collector.finalize()
1870
+ raise typer.Exit(1)
1871
+
1872
+ # Lock one by resource_id
1873
+ resource_ids = [t.resource_id for t in spec_targets]
1874
+ collector.emit("pool_acquire", pool=pool, count=len(resource_ids))
1875
+
1876
+ lock_ctx = acquire_from_pool(resource_ids)
1877
+ acquired_id = lock_ctx.__enter__()
1878
+
1879
+ if acquired_id is None:
1880
+ lock_ctx.__exit__(None, None, None)
1881
+ collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=resource_ids)
1882
+ collector.finalize()
1883
+ raise typer.Exit(1)
1884
+
1885
+ # Map resource_id back to spec_name
1886
+ acquired_target = next(t for t in spec_targets if t.resource_id == acquired_id)
1887
+ spec_name = acquired_target.spec_name
1888
+
1889
+ collector.emit("pool_acquired", target=spec_name, resource_id=acquired_id)
1890
+ return spec_name, lock_ctx
1891
+
1892
+
1893
+ def _resolve_pool_legacy(pool: str, collector) -> tuple[str, object]:
1894
+ """Resolve an old-style pool (static target name list) to a target name + lock context.
1895
+
1896
+ Old format: [pools.name] targets = ["t1", "t2"]
1897
+ """
1898
+ from .target_lock import acquire_from_pool
1899
+ from .targets import filter_pool_by_auth, get_pool
1900
+
1901
+ try:
1902
+ pool_targets = get_pool(pool)
1903
+ except FileNotFoundError as e:
1904
+ collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
1905
+ collector.finalize()
1906
+ raise typer.Exit(1) from None
1907
+
1908
+ usable_targets, skipped = filter_pool_by_auth(pool_targets)
1909
+ if skipped:
1910
+ collector.emit("pool_auth_skip", targets=skipped)
1911
+
1912
+ if not usable_targets:
1913
+ collector.set_error("pool", "NoUsableTargets", pool=pool)
1914
+ collector.finalize()
1915
+ raise typer.Exit(1) from None
1916
+
1917
+ collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
1918
+ lock_ctx = acquire_from_pool(usable_targets)
1919
+ acquired_target = lock_ctx.__enter__()
1920
+
1921
+ if acquired_target is None:
1922
+ lock_ctx.__exit__(None, None, None)
1923
+ collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
1924
+ collector.finalize()
1925
+ raise typer.Exit(1)
1926
+
1927
+ collector.emit("pool_acquired", target=acquired_target)
1928
+ return acquired_target, lock_ctx
1929
+
1930
+
1761
1931
  @kernelbench_app.callback(invoke_without_command=True)
1762
1932
  def kernelbench_evaluate( # noqa: PLR0913, PLR0915
1763
1933
  ctx: typer.Context,
@@ -1888,39 +2058,12 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
1888
2058
  pool_lock_context = None
1889
2059
 
1890
2060
  if pool:
1891
- from .target_lock import acquire_from_pool
1892
- from .targets import filter_pool_by_auth, get_pool
1893
-
1894
- try:
1895
- pool_targets = get_pool(pool)
1896
- except FileNotFoundError as e:
1897
- collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
1898
- collector.finalize()
1899
- raise typer.Exit(1) from None
1900
-
1901
- # Filter to only targets with valid auth
1902
- usable_targets, skipped = filter_pool_by_auth(pool_targets)
1903
- if skipped:
1904
- collector.emit("pool_auth_skip", targets=skipped)
1905
-
1906
- if not usable_targets:
1907
- collector.set_error("pool", "NoUsableTargets", pool=pool)
1908
- collector.finalize()
1909
- raise typer.Exit(1) from None
2061
+ from wafer_core.targets.pool import is_query_pool
1910
2062
 
1911
- collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
1912
- pool_lock_context = acquire_from_pool(usable_targets)
1913
- acquired_target = pool_lock_context.__enter__()
1914
-
1915
- if acquired_target is None:
1916
- # Exit context manager before raising to avoid resource leak
1917
- pool_lock_context.__exit__(None, None, None)
1918
- collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
1919
- collector.finalize()
1920
- raise typer.Exit(1)
1921
-
1922
- collector.emit("pool_acquired", target=acquired_target)
1923
- resolved_target = acquired_target
2063
+ if is_query_pool(pool):
2064
+ resolved_target, pool_lock_context = _resolve_pool_query(pool, collector)
2065
+ else:
2066
+ resolved_target, pool_lock_context = _resolve_pool_legacy(pool, collector)
1924
2067
 
1925
2068
  collector.target = resolved_target
1926
2069
 
@@ -2245,7 +2388,11 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
2245
2388
  None, "--reference", help="Path to reference kernel file"
2246
2389
  ),
2247
2390
  test_cases: Path | None = typer.Option(
2248
- None, "--test-cases", help="Path to test cases JSON file"
2391
+ None,
2392
+ "--test-cases",
2393
+ help="Path to test cases JSON file. "
2394
+ 'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
2395
+ "Run 'wafer evaluate make-template' to generate an example.",
2249
2396
  ),
2250
2397
  target: str | None = typer.Option(
2251
2398
  None,
@@ -2313,6 +2460,13 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
2313
2460
  err=True,
2314
2461
  )
2315
2462
  typer.echo("", err=True)
2463
+ if "--test-cases" in missing_args:
2464
+ typer.echo(
2465
+ "Tip: Run 'wafer evaluate make-template' to generate template files "
2466
+ "including test_cases.json.",
2467
+ err=True,
2468
+ )
2469
+ typer.echo("", err=True)
2316
2470
  typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
2317
2471
  typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
2318
2472
  raise typer.Exit(1)
@@ -2719,7 +2873,7 @@ def remote_run( # noqa: PLR0913
2719
2873
  # =============================================================================
2720
2874
 
2721
2875
 
2722
- @app.command("login")
2876
+ @auth_app.command("login")
2723
2877
  def login(
2724
2878
  token: str | None = typer.Option(
2725
2879
  None, "--token", "-t", help="Access token (skip browser OAuth)"
@@ -2744,7 +2898,7 @@ def login(
2744
2898
  Uses the API environment from config (see 'wafer config show').
2745
2899
 
2746
2900
  SSH Users (Easiest):
2747
- - Just run: wafer login
2901
+ - Just run: wafer auth login
2748
2902
  - Visit the URL and enter the code shown
2749
2903
  - No port forwarding needed!
2750
2904
 
@@ -2754,17 +2908,17 @@ def login(
2754
2908
 
2755
2909
  Manual token option:
2756
2910
  - Visit auth.wafer.ai, authenticate, copy token from URL
2757
- - Run: wafer login --token <paste-token>
2911
+ - Run: wafer auth login --token <paste-token>
2758
2912
 
2759
2913
  Examples:
2760
- wafer login # device code on SSH, browser on local
2761
- wafer login --no-device-code # force browser (needs port forwarding on SSH)
2762
- wafer login --port 9000 # custom port for browser flow
2763
- wafer login --token xyz # manual token (no browser)
2914
+ wafer auth login # device code on SSH, browser on local
2915
+ wafer auth login --no-device-code # force browser (needs port forwarding on SSH)
2916
+ wafer auth login --port 9000 # custom port for browser flow
2917
+ wafer auth login --token xyz # manual token (no browser)
2764
2918
 
2765
2919
  # Change environment:
2766
2920
  wafer config set api.environment staging
2767
- wafer login
2921
+ wafer auth login
2768
2922
  """
2769
2923
  import httpx
2770
2924
 
@@ -2848,7 +3002,7 @@ def login(
2848
3002
  typer.echo("Token saved to ~/.wafer/credentials.json")
2849
3003
 
2850
3004
 
2851
- @app.command("logout")
3005
+ @auth_app.command("logout")
2852
3006
  def logout() -> None:
2853
3007
  """Remove stored credentials."""
2854
3008
  from . import analytics
@@ -2865,7 +3019,7 @@ def logout() -> None:
2865
3019
  typer.echo("Not logged in (no credentials found).")
2866
3020
 
2867
3021
 
2868
- @app.command("whoami")
3022
+ @auth_app.command("whoami")
2869
3023
  def whoami(
2870
3024
  verify: bool = typer.Option(False, "--verify", "-v", help="Verify token with API"),
2871
3025
  refresh: bool = typer.Option(False, "--refresh", "-r", help="Refresh token if expired"),
@@ -2879,7 +3033,7 @@ def whoami(
2879
3033
 
2880
3034
  creds = load_credentials()
2881
3035
  if creds is None:
2882
- typer.echo("Not logged in. Run: wafer login")
3036
+ typer.echo("Not logged in. Run: wafer auth login")
2883
3037
  raise typer.Exit(1)
2884
3038
 
2885
3039
  if verify or refresh:
@@ -2887,7 +3041,7 @@ def whoami(
2887
3041
  # Try to get valid token with auto-refresh
2888
3042
  token = get_valid_token()
2889
3043
  if token is None:
2890
- typer.echo("Token expired and refresh failed. Run: wafer login", err=True)
3044
+ typer.echo("Token expired and refresh failed. Run: wafer auth login", err=True)
2891
3045
  raise typer.Exit(1)
2892
3046
  if token != creds.access_token:
2893
3047
  typer.echo("Token refreshed successfully")
@@ -2900,10 +3054,10 @@ def whoami(
2900
3054
  except Exception as e:
2901
3055
  if creds.refresh_token and not refresh:
2902
3056
  typer.echo(f"Token expired: {e}", err=True)
2903
- typer.echo("Try: wafer whoami --refresh", err=True)
3057
+ typer.echo("Try: wafer auth whoami --refresh", err=True)
2904
3058
  else:
2905
3059
  typer.echo(f"Token invalid or expired: {e}", err=True)
2906
- typer.echo("Run: wafer login", err=True)
3060
+ typer.echo("Run: wafer auth login", err=True)
2907
3061
  raise typer.Exit(1) from None
2908
3062
  elif creds.email:
2909
3063
  typer.echo(creds.email)
@@ -2911,7 +3065,7 @@ def whoami(
2911
3065
  typer.echo("Logged in (email not available)")
2912
3066
 
2913
3067
 
2914
- @app.command("guide")
3068
+ @app.command("guide", rich_help_panel="Onboarding")
2915
3069
  def guide() -> None:
2916
3070
  """Show the Wafer CLI usage guide.
2917
3071
 
@@ -2942,7 +3096,7 @@ demo_app = typer.Typer(
2942
3096
  wafer demo trace Analyze a sample performance trace
2943
3097
  wafer demo eval Run kernel evaluation on cloud GPU (requires login)"""
2944
3098
  )
2945
- app.add_typer(demo_app, name="demo")
3099
+ app.add_typer(demo_app, name="demo", rich_help_panel="Onboarding")
2946
3100
 
2947
3101
  DEMO_TRACES_URL = "https://github.com/wafer-ai/wafer/raw/main/apps/wafer-cli/wafer/demo_data"
2948
3102
  DEMO_DIR = Path.home() / ".cache" / "wafer" / "demo"
@@ -3162,7 +3316,7 @@ def demo_eval(
3162
3316
  """Demo: Evaluate a kernel on a cloud GPU.
3163
3317
 
3164
3318
  Creates a workspace, runs a sample Triton kernel evaluation, and cleans up.
3165
- Requires authentication (wafer login).
3319
+ Requires authentication (wafer auth login).
3166
3320
 
3167
3321
  Example:
3168
3322
  wafer demo eval
@@ -3177,7 +3331,7 @@ def demo_eval(
3177
3331
  # Check auth first
3178
3332
  creds = load_credentials()
3179
3333
  if not creds:
3180
- typer.echo("Error: Not authenticated. Run: wafer login")
3334
+ typer.echo("Error: Not authenticated. Run: wafer auth login")
3181
3335
  raise typer.Exit(1)
3182
3336
 
3183
3337
  if not yes:
@@ -3458,7 +3612,7 @@ def init_runpod(
3458
3612
  gpu_configs = {
3459
3613
  "MI300X": {
3460
3614
  "gpu_type_id": "AMD Instinct MI300X OAM",
3461
- "image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
3615
+ "image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
3462
3616
  "compute_capability": "9.4",
3463
3617
  },
3464
3618
  "H100": {
@@ -3554,7 +3708,7 @@ def init_digitalocean(
3554
3708
  "ssh_key": ssh_key,
3555
3709
  "region": region,
3556
3710
  "size_slug": "gpu-mi300x1-192gb-devcloud",
3557
- "image": "gpu-amd-base",
3711
+ "image": "amd-pytorchrocm7", # PyTorch (ROCm7) marketplace image
3558
3712
  "provision_timeout": 600,
3559
3713
  "eval_timeout": 600,
3560
3714
  "keep_alive": keep_alive,
@@ -3826,12 +3980,16 @@ def targets_add(
3826
3980
 
3827
3981
  @targets_app.command("list")
3828
3982
  def targets_list() -> None:
3829
- """List all configured targets.
3983
+ """List all configured targets with live provider status.
3830
3984
 
3831
3985
  Example:
3832
3986
  wafer config targets list
3833
3987
  """
3834
- from .targets import get_default_target, list_targets
3988
+ import socket
3989
+
3990
+ import trio
3991
+
3992
+ from .targets import get_default_target, list_targets, load_target, remove_target
3835
3993
 
3836
3994
  targets = list_targets()
3837
3995
  default = get_default_target()
@@ -3841,10 +3999,146 @@ def targets_list() -> None:
3841
3999
  typer.echo("Add one with: wafer config targets add <path/to/target.toml>")
3842
4000
  return
3843
4001
 
4002
+ def _parse_ssh_target(ssh_target: str) -> tuple[str, int]:
4003
+ """Extract (host, port) from user@host:port string."""
4004
+ parts = ssh_target.rsplit(":", 1)
4005
+ host_part = parts[0]
4006
+ port = int(parts[1]) if len(parts) > 1 else 22
4007
+ if "@" in host_part:
4008
+ host = host_part.split("@", 1)[1]
4009
+ else:
4010
+ host = host_part
4011
+ return (host, port)
4012
+
4013
+ async def _get_live_provider_endpoints() -> set[tuple[str, int]]:
4014
+ """Query RunPod + DO APIs. Returns set of live (ip, port) endpoints."""
4015
+ from wafer_core.targets.digitalocean import list_running_droplets
4016
+ from wafer_core.targets.runpod import sync_pods_from_api
4017
+
4018
+ live_endpoints: set[tuple[str, int]] = set()
4019
+
4020
+ async def _fetch_runpod() -> None:
4021
+ try:
4022
+ pods = await sync_pods_from_api()
4023
+ for p in pods:
4024
+ live_endpoints.add((p.public_ip, p.ssh_port))
4025
+ except Exception:
4026
+ pass
4027
+
4028
+ async def _fetch_do() -> None:
4029
+ try:
4030
+ droplets = await list_running_droplets()
4031
+ for d in droplets:
4032
+ live_endpoints.add((d.public_ip, d.ssh_port))
4033
+ except Exception:
4034
+ pass
4035
+
4036
+ async with trio.open_nursery() as nursery:
4037
+ nursery.start_soon(_fetch_runpod)
4038
+ nursery.start_soon(_fetch_do)
4039
+
4040
+ return live_endpoints
4041
+
4042
+ async def _get_target_status(
4043
+ name: str,
4044
+ live_endpoints: set[tuple[str, int]],
4045
+ ) -> tuple[str, str, str]:
4046
+ """Returns (name, status, ssh_info)."""
4047
+ from wafer_core.targets.digitalocean import (
4048
+ _remove_droplet_from_state,
4049
+ check_droplet_running,
4050
+ get_droplet_state,
4051
+ )
4052
+ from wafer_core.targets.runpod import (
4053
+ _remove_pod_from_state,
4054
+ check_pod_running,
4055
+ get_pod_state,
4056
+ )
4057
+ from wafer_core.utils.kernel_utils.targets.config import (
4058
+ BaremetalTarget,
4059
+ DigitalOceanTarget,
4060
+ ModalTarget,
4061
+ RunPodTarget,
4062
+ )
4063
+
4064
+ try:
4065
+ target = load_target(name)
4066
+ except (FileNotFoundError, ValueError, AssertionError, TypeError):
4067
+ return (name, "error", "")
4068
+
4069
+ if isinstance(target, RunPodTarget):
4070
+ pod = get_pod_state(name)
4071
+ if not pod:
4072
+ return (name, "no instance", "")
4073
+ if await check_pod_running(pod.pod_id):
4074
+ return (name, "running", f"{pod.ssh_username}@{pod.public_ip}:{pod.ssh_port}")
4075
+ _remove_pod_from_state(name)
4076
+ return (name, "stopped", "")
4077
+
4078
+ if isinstance(target, DigitalOceanTarget):
4079
+ droplet = get_droplet_state(name)
4080
+ if not droplet:
4081
+ return (name, "no instance", "")
4082
+ if await check_droplet_running(droplet.droplet_id):
4083
+ return (
4084
+ name,
4085
+ "running",
4086
+ f"{droplet.ssh_username}@{droplet.public_ip}:{droplet.ssh_port}",
4087
+ )
4088
+ _remove_droplet_from_state(name)
4089
+ return (name, "stopped", "")
4090
+
4091
+ if isinstance(target, BaremetalTarget):
4092
+ ssh_target = target.ssh_target
4093
+ host, port = _parse_ssh_target(ssh_target)
4094
+
4095
+ def _tcp_check() -> bool:
4096
+ try:
4097
+ sock = socket.create_connection((host, port), timeout=2)
4098
+ sock.close()
4099
+ return True
4100
+ except OSError:
4101
+ return False
4102
+
4103
+ reachable = await trio.to_thread.run_sync(_tcp_check)
4104
+ if reachable:
4105
+ return (name, "reachable", ssh_target)
4106
+
4107
+ # Unreachable + has a provider = backed by an ephemeral instance.
4108
+ # If not in the live provider listing, the instance is gone — remove config.
4109
+ if target.provider and (host, port) not in live_endpoints:
4110
+ remove_target(name)
4111
+ return (name, "removed (dead pod)", ssh_target)
4112
+
4113
+ return (name, "unreachable", ssh_target)
4114
+
4115
+ if isinstance(target, ModalTarget):
4116
+ return (name, "serverless", "")
4117
+
4118
+ # Unknown target type
4119
+ return (name, "unknown", "")
4120
+
4121
+ async def _gather_statuses() -> list[tuple[str, str, str]]:
4122
+ live_endpoints = await _get_live_provider_endpoints()
4123
+ results: list[tuple[str, str, str]] = [("", "", "")] * len(targets)
4124
+
4125
+ async def _check(i: int, name: str) -> None:
4126
+ results[i] = await _get_target_status(name, live_endpoints)
4127
+
4128
+ async with trio.open_nursery() as nursery:
4129
+ for i, name in enumerate(targets):
4130
+ nursery.start_soon(_check, i, name)
4131
+
4132
+ return results
4133
+
4134
+ statuses = trio.run(_gather_statuses)
4135
+
3844
4136
  typer.echo("Configured targets:")
3845
- for name in targets:
4137
+ for name, status, ssh_info in statuses:
3846
4138
  marker = " (default)" if name == default else ""
3847
- typer.echo(f" {name}{marker}")
4139
+ label = f" {name}{marker}"
4140
+ detail = f" {ssh_info}" if ssh_info else ""
4141
+ typer.echo(f"{label:<40}{status}{detail}")
3848
4142
 
3849
4143
 
3850
4144
  @targets_app.command("show")
@@ -4056,6 +4350,216 @@ def targets_cleanup(
4056
4350
  raise typer.Exit(1) from None
4057
4351
 
4058
4352
 
4353
# Known libraries that can be installed on targets
# TODO: Consider adding HipKittens to the default RunPod/DO Docker images
# so this install step isn't needed. For now, this command handles it.
# Architecture → branch mapping for libraries that ship per-arch branches.
# "default" is used when the detected arch has no explicit entry.
_ARCH_BRANCHES: dict[str, dict[str, str]] = {
    "hipkittens": {
        "gfx942": "cdna3",  # MI300X, MI325X
        "default": "main",  # MI350X, MI355X, and future CDNA4+
    },
}

# Registry consumed by `wafer config targets install` (targets_install below).
# Per-entry keys (values are typed `object`; consumers cast as needed):
#   description   - one-liner echoed to the user before installing
#   git_url       - clone source, used together with install_path
#   install_path  - destination directory on the remote target
#   custom_script - shell one-liner run instead of a git clone
#   requires_amd  - when True, targets_install refuses non-AMD targets
INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
    "hipkittens": {
        "description": "HipKittens - AMD port of ThunderKittens",
        "git_url": "https://github.com/HazyResearch/HipKittens.git",
        "install_path": "/opt/hipkittens",
        "requires_amd": True,
    },
    # CK is already installed with ROCm 7.0, no action needed
    "repair-headers": {
        "description": "Repair ROCm thrust headers (fixes hipify corruption)",
        "custom_script": "apt-get update -qq && apt-get install --reinstall -y rocthrust >/dev/null 2>&1 && echo REPAIRED",
        "requires_amd": True,
    },
}
4379
+
4380
+
4381
def _resolve_gfx_arch(target: object, ssh_cmd: list[str]) -> str | None:
    """Return the gfx architecture string for *target*.

    1. If the target config already carries a compute_capability, map it.
    2. Otherwise SSH in and probe with ``rocminfo``.
    Returns None only if detection fails entirely.

    Args:
        target: Loaded target config; only ``compute_capability`` is read.
        ssh_cmd: Full ssh argv prefix; the probe command is appended to it.
    """
    import subprocess

    from .evaluate import AMD_CC_TO_ARCH

    # Fast path: the config already declares a known AMD compute capability.
    cc = getattr(target, "compute_capability", None)
    if cc and cc in AMD_CC_TO_ARCH:
        return AMD_CC_TO_ARCH[cc]

    typer.echo(" Detecting GPU architecture via rocminfo...")
    probe_script = "rocminfo 2>/dev/null | grep -oP 'gfx\\d+' | head -1"
    try:
        result = subprocess.run(
            ssh_cmd + [probe_script],
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (subprocess.TimeoutExpired, OSError):
        # A hung or unreachable SSH probe is a detection failure, not a
        # crash: the contract above promises None when detection fails.
        typer.echo(" Warning: could not detect GPU architecture", err=True)
        return None

    arch = result.stdout.strip()
    if result.returncode == 0 and arch.startswith("gfx"):
        typer.echo(f" Detected: {arch}")
        return arch

    typer.echo(" Warning: could not detect GPU architecture", err=True)
    return None
4412
+
4413
+ @targets_app.command("install")
4414
+ def targets_install(
4415
+ name: str = typer.Argument(..., help="Target name"),
4416
+ library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
4417
+ ) -> None:
4418
+ """Install a library or run maintenance on a target (idempotent).
4419
+
4420
+ Installs header-only libraries like HipKittens on remote targets.
4421
+ Safe to run multiple times - will skip if already installed.
4422
+
4423
+ For libraries with per-architecture branches (e.g. HipKittens), the
4424
+ correct branch is selected automatically based on the target's GPU.
4425
+
4426
+ Available libraries:
4427
+ hipkittens - HipKittens (AMD ThunderKittens port)
4428
+ repair-headers - Fix ROCm thrust headers (after hipify corruption)
4429
+
4430
+ Examples:
4431
+ wafer config targets install runpod-mi300x hipkittens
4432
+ wafer config targets install runpod-mi300x repair-headers
4433
+ wafer config targets install do-mi300x hipkittens
4434
+ """
4435
+ import subprocess
4436
+
4437
+ from .targets import load_target
4438
+ from .targets_ops import get_target_ssh_info
4439
+
4440
+ if library not in INSTALLABLE_LIBRARIES:
4441
+ available = ", ".join(INSTALLABLE_LIBRARIES.keys())
4442
+ typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
4443
+ raise typer.Exit(1)
4444
+
4445
+ lib_info = INSTALLABLE_LIBRARIES[library]
4446
+
4447
+ try:
4448
+ target = load_target(name)
4449
+ except FileNotFoundError as e:
4450
+ typer.echo(f"Error: {e}", err=True)
4451
+ raise typer.Exit(1) from None
4452
+
4453
+ # Check if target is AMD (for AMD-only libraries)
4454
+ if lib_info.get("requires_amd"):
4455
+ from wafer_core.utils.kernel_utils.targets.config import (
4456
+ DigitalOceanTarget,
4457
+ RunPodTarget,
4458
+ )
4459
+
4460
+ is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
4461
+ if not is_amd and hasattr(target, "compute_capability"):
4462
+ # Check compute capability for MI300X (gfx942 = 9.4)
4463
+ is_amd = target.compute_capability.startswith("9.")
4464
+ if not is_amd:
4465
+ typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
4466
+ raise typer.Exit(1)
4467
+
4468
+ typer.echo(f"Installing {library} on {name}...")
4469
+ typer.echo(f" {lib_info['description']}")
4470
+
4471
+ async def _install() -> bool:
4472
+ # get_target_ssh_info uses pure trio async (no asyncio bridging needed)
4473
+ # and we use subprocess for SSH, not AsyncSSHClient
4474
+ ssh_info = await get_target_ssh_info(target)
4475
+
4476
+ ssh_cmd = [
4477
+ "ssh",
4478
+ "-o",
4479
+ "StrictHostKeyChecking=no",
4480
+ "-o",
4481
+ "UserKnownHostsFile=/dev/null",
4482
+ "-o",
4483
+ "ConnectTimeout=30",
4484
+ "-i",
4485
+ str(ssh_info.key_path),
4486
+ "-p",
4487
+ str(ssh_info.port),
4488
+ f"{ssh_info.user}@{ssh_info.host}",
4489
+ ]
4490
+
4491
+ # Handle custom scripts (like repair-headers) vs git installs
4492
+ if "custom_script" in lib_info:
4493
+ install_script = str(lib_info["custom_script"])
4494
+ success_marker = "REPAIRED"
4495
+ else:
4496
+ install_path = lib_info["install_path"]
4497
+ git_url = lib_info["git_url"]
4498
+
4499
+ # Resolve the branch for arch-aware libraries
4500
+ branch = "main"
4501
+ arch_map = _ARCH_BRANCHES.get(library)
4502
+ if arch_map:
4503
+ gfx = await trio.to_thread.run_sync(lambda: _resolve_gfx_arch(target, ssh_cmd))
4504
+ branch = arch_map.get(gfx, arch_map["default"]) if gfx else arch_map["default"]
4505
+ typer.echo(f" Branch: {branch} (arch={gfx or 'unknown'})")
4506
+
4507
+ # Idempotent: if already cloned, ensure correct branch & pull
4508
+ install_script = f"""
4509
+ if [ -d "{install_path}" ]; then
4510
+ echo "ALREADY_INSTALLED: {install_path} exists"
4511
+ cd {install_path} && git fetch --quiet origin && git checkout {branch} --quiet && git pull --quiet origin {branch}
4512
+ else
4513
+ echo "INSTALLING: cloning to {install_path}"
4514
+ git clone --quiet --branch {branch} {git_url} {install_path}
4515
+ fi
4516
+ echo "DONE"
4517
+ """
4518
+ success_marker = "DONE"
4519
+
4520
+ def run_ssh() -> subprocess.CompletedProcess[str]:
4521
+ return subprocess.run(
4522
+ ssh_cmd + [install_script],
4523
+ capture_output=True,
4524
+ text=True,
4525
+ timeout=120,
4526
+ )
4527
+
4528
+ result = await trio.to_thread.run_sync(run_ssh)
4529
+
4530
+ if result.returncode != 0:
4531
+ typer.echo(f"Error: {result.stderr}", err=True)
4532
+ return False
4533
+
4534
+ output = result.stdout.strip()
4535
+ if "ALREADY_INSTALLED" in output:
4536
+ typer.echo(f" Already installed at {lib_info.get('install_path', 'N/A')}")
4537
+ elif "INSTALLING" in output:
4538
+ typer.echo(f" Installed to {lib_info.get('install_path', 'N/A')}")
4539
+ elif "REPAIRED" in output:
4540
+ typer.echo(" ROCm headers repaired")
4541
+
4542
+ return success_marker in output
4543
+
4544
+ try:
4545
+ success = trio.run(_install)
4546
+ except Exception as e:
4547
+ typer.echo(f"Error: {e}", err=True)
4548
+ raise typer.Exit(1) from None
4549
+
4550
+ if success:
4551
+ typer.echo(f"✓ {library} ready on {name}")
4552
+
4553
+ # Print usage hint
4554
+ if library == "hipkittens":
4555
+ typer.echo("")
4556
+ typer.echo("Usage in load_inline:")
4557
+ typer.echo(' extra_include_paths=["/opt/hipkittens/include", "/opt/rocm/include/hip"]')
4558
+ else:
4559
+ typer.echo(f"Failed to install {library}", err=True)
4560
+ raise typer.Exit(1)
4561
+
4562
+
4059
4563
  @targets_app.command("pods")
4060
4564
  def targets_pods() -> None:
4061
4565
  """List all running RunPod pods.
@@ -4185,8 +4689,8 @@ def billing_usage(
4185
4689
  """Show current billing usage and subscription info.
4186
4690
 
4187
4691
  Example:
4188
- wafer billing
4189
- wafer billing --json
4692
+ wafer config billing
4693
+ wafer config billing --json
4190
4694
  """
4191
4695
  # Only show usage if no subcommand was invoked
4192
4696
  if ctx.invoked_subcommand is not None:
@@ -4214,9 +4718,9 @@ def billing_topup(
4214
4718
  Opens a Stripe checkout page to add credits. Default amount is $25.
4215
4719
 
4216
4720
  Example:
4217
- wafer billing topup # Add $25
4218
- wafer billing topup 100 # Add $100
4219
- wafer billing topup --no-browser # Print URL instead
4721
+ wafer config billing topup # Add $25
4722
+ wafer config billing topup 100 # Add $100
4723
+ wafer config billing topup --no-browser # Print URL instead
4220
4724
  """
4221
4725
  import webbrowser
4222
4726
 
@@ -4262,8 +4766,8 @@ def billing_portal(
4262
4766
  Manage your subscription, update payment method, or view invoices.
4263
4767
 
4264
4768
  Example:
4265
- wafer billing portal
4266
- wafer billing portal --no-browser
4769
+ wafer config billing portal
4770
+ wafer config billing portal --no-browser
4267
4771
  """
4268
4772
  import webbrowser
4269
4773
 
@@ -4300,8 +4804,8 @@ def ssh_keys_list(
4300
4804
  """List all registered SSH public keys.
4301
4805
 
4302
4806
  Example:
4303
- wafer ssh-keys list
4304
- wafer ssh-keys list --json
4807
+ wafer config ssh-keys list
4808
+ wafer config ssh-keys list --json
4305
4809
  """
4306
4810
  from .ssh_keys import list_ssh_keys
4307
4811
 
@@ -4327,9 +4831,9 @@ def ssh_keys_add(
4327
4831
  id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.
4328
4832
 
4329
4833
  Example:
4330
- wafer ssh-keys add # Auto-detect
4331
- wafer ssh-keys add ~/.ssh/id_rsa.pub # Specific file
4332
- wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
4834
+ wafer config ssh-keys add # Auto-detect
4835
+ wafer config ssh-keys add ~/.ssh/id_rsa.pub # Specific file
4836
+ wafer config ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
4333
4837
  """
4334
4838
  from .ssh_keys import add_ssh_key
4335
4839
 
@@ -4348,10 +4852,10 @@ def ssh_keys_remove(
4348
4852
  ) -> None:
4349
4853
  """Remove an SSH public key.
4350
4854
 
4351
- Get the key ID from 'wafer ssh-keys list'.
4855
+ Get the key ID from 'wafer config ssh-keys list'.
4352
4856
 
4353
4857
  Example:
4354
- wafer ssh-keys remove abc123-def456-...
4858
+ wafer config ssh-keys remove abc123-def456-...
4355
4859
  """
4356
4860
  from .ssh_keys import remove_ssh_key
4357
4861
 
@@ -4391,9 +4895,13 @@ def workspaces_list(
4391
4895
  @workspaces_app.command("create")
4392
4896
  def workspaces_create(
4393
4897
  name: str = typer.Argument(..., help="Workspace name"),
4394
- gpu_type: str = typer.Option("B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"),
4898
+ gpu_type: str = typer.Option(
4899
+ "B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"
4900
+ ),
4395
4901
  image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
4396
- wait: bool = typer.Option(False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"),
4902
+ wait: bool = typer.Option(
4903
+ False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"
4904
+ ),
4397
4905
  json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
4398
4906
  ) -> None:
4399
4907
  """Create a new workspace.
@@ -4702,19 +5210,25 @@ def workspaces_ssh(
4702
5210
  ssh_host = ws.get("ssh_host")
4703
5211
  ssh_port = ws.get("ssh_port")
4704
5212
  ssh_user = ws.get("ssh_user")
4705
-
5213
+
4706
5214
  if not ssh_host or not ssh_port or not ssh_user:
4707
5215
  typer.echo("Error: Workspace not ready. Wait a few seconds and retry.", err=True)
4708
5216
  raise typer.Exit(1)
4709
5217
 
4710
5218
  # Connect via SSH
4711
- os.execvp("ssh", [
5219
+ os.execvp(
4712
5220
  "ssh",
4713
- "-p", str(ssh_port),
4714
- "-o", "StrictHostKeyChecking=no",
4715
- "-o", "UserKnownHostsFile=/dev/null",
4716
- f"{ssh_user}@{ssh_host}",
4717
- ])
5221
+ [
5222
+ "ssh",
5223
+ "-p",
5224
+ str(ssh_port),
5225
+ "-o",
5226
+ "StrictHostKeyChecking=no",
5227
+ "-o",
5228
+ "UserKnownHostsFile=/dev/null",
5229
+ f"{ssh_user}@{ssh_host}",
5230
+ ],
5231
+ )
4718
5232
 
4719
5233
 
4720
5234
  @workspaces_app.command("sync")
@@ -4777,6 +5291,69 @@ def workspaces_sync(
4777
5291
  raise typer.Exit(1) from None
4778
5292
 
4779
5293
 
5294
@workspaces_app.command("pull")
def workspaces_pull(
    workspace: str = typer.Argument(..., help="Workspace name or ID"),
    remote_path: str = typer.Argument(
        ..., help="Remote path in workspace (relative to /workspace or absolute)"
    ),
    local_path: Path = typer.Argument(
        Path("."), help="Local destination path (default: current directory)"
    ),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
    quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
) -> None:
    """Download files from a workspace to the local machine.

    Copies from the workspace's /workspace directory using rsync over SSH.

    Examples:
        wafer workspaces pull dev kernel.py ./                # Pull single file
        wafer workspaces pull dev kernel.py ./my_kernel.py    # Pull and rename
        wafer workspaces pull dev /workspace/results ./       # Pull directory
    """
    from .global_config import get_preferences
    from .workspaces import pull_files

    # Verbosity: --quiet beats --verbose; with neither flag, follow the
    # configured preference mode ("explicit" => chatty).
    prefs = get_preferences()
    show_status = False if quiet else (True if verbose else prefs.mode == "explicit")

    def _status(msg: str) -> None:
        # Progress callback handed to pull_files; silent unless enabled.
        if show_status:
            typer.echo(f"[wafer] {msg}", err=True)

    if show_status:
        typer.echo(f"[wafer] Pulling {remote_path} from workspace {workspace}...", err=True)

    try:
        file_count = pull_files(
            workspace, remote_path, local_path.resolve(), on_progress=_status
        )
    except RuntimeError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    if show_status:
        typer.echo(f"[wafer] Pulled {file_count} files to {local_path}", err=True)
5343
+
5344
+
5345
# =============================================================================
# Live resource commands (list/terminate/reconcile/provision)
# =============================================================================

# Register pre-built Typer callbacks (imported or defined elsewhere in this
# file) on the targets-ops sub-app; each call wires one subcommand name to
# its implementation function.
targets_ops_app.command("list")(_targets_list_cmd)
targets_ops_app.command("terminate")(_targets_terminate_cmd)
targets_ops_app.command("reconcile")(_targets_reconcile_cmd)
targets_ops_app.command("provision")(_targets_provision_cmd)
targets_ops_app.command("pools")(_targets_pools_cmd)
targets_ops_app.command("probe")(_targets_probe_cmd)
5355
+
5356
+
4780
5357
  # =============================================================================
4781
5358
  # Target operations commands (exec/ssh/sync)
4782
5359
  # =============================================================================
@@ -5535,7 +6112,7 @@ def ncu_analyze(
5535
6112
  compute/memory throughput, and optimization recommendations.
5536
6113
 
5537
6114
  By default, uses local NCU if available, otherwise runs analysis
5538
- remotely via wafer-api (requires authentication: wafer login).
6115
+ remotely via wafer-api (requires authentication: wafer auth login).
5539
6116
 
5540
6117
  Use --target for direct SSH mode (like wafer remote-run --direct).
5541
6118
  Use --include-source to fetch SASS assembly with register/instruction data.
@@ -5630,7 +6207,7 @@ def nsys_analyze(
5630
6207
  Returns timeline events, kernel information, memory usage, and diagnostics.
5631
6208
 
5632
6209
  By default, uses local nsys if available, otherwise runs analysis
5633
- remotely via wafer-api (requires authentication: wafer login).
6210
+ remotely via wafer-api (requires authentication: wafer auth login).
5634
6211
 
5635
6212
  Supports multiple execution modes:
5636
6213
  - Local: Uses local nsys CLI (no GPU required for analysis)
@@ -6615,7 +7192,7 @@ def autotuner_results(
6615
7192
  raise typer.Exit(1) from None
6616
7193
 
6617
7194
 
6618
- @app.command("capture")
7195
+ @app.command("capture", rich_help_panel="Kernel Development")
6619
7196
  def capture_command( # noqa: PLR0915
6620
7197
  label: str = typer.Argument(
6621
7198
  ..., help="Label for this capture (e.g., 'baseline', 'optimized-v2')"
@@ -7280,6 +7857,203 @@ def isa_targets() -> None:
7280
7857
  typer.echo(output)
7281
7858
 
7282
7859
 
7860
+ # =============================================================================
7861
+ # Trace comparison commands
7862
+ # =============================================================================
7863
+
7864
+
7865
@compare_app.command("analyze")
def compare_analyze(
    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
    format: str = typer.Option(
        "text",
        "--format",
        "-f",
        help="Output format: text, text-layers, csv, csv-layers, json",
    ),
    output: Path | None = typer.Option(
        None, "--output", "-o", help="Output file (default: stdout)"
    ),
    phase: str = typer.Option(
        "all",
        "--phase",
        help="Filter by phase: all, prefill, decode",
    ),
    layers: bool = typer.Option(False, "--layers", help="Show layer-wise performance breakdown"),
    all: bool = typer.Option(
        False, "--all", help="Show all items (no truncation for layers, operations, kernels)"
    ),
    stack_traces: bool = typer.Option(
        False, "--stack-traces", help="Show Python stack traces for operations"
    ),
    recommendations: bool = typer.Option(
        False, "--recommendations", help="Generate prioritized recommendations for kernel team"
    ),
    json: bool = typer.Option(
        False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
    ),
) -> None:
    """Compare two GPU traces (AMD vs NVIDIA) and report performance deltas.

    Identifies which operations run faster or slower on each platform and
    drills down to kernel-level detail. Output is truncated by default
    (--all disables truncation) and can optionally include layer-wise
    breakdowns, Python stack traces, and prioritized recommendations.

    Examples:
        wafer compare analyze amd_trace.json nvidia_trace.json
        wafer compare analyze amd_trace.json nvidia_trace.json --layers --all
        wafer compare analyze amd_trace.json nvidia_trace.json --stack-traces
        wafer compare analyze amd_trace.json nvidia_trace.json --phase prefill
        wafer compare analyze amd_trace.json nvidia_trace.json --format csv -o operations.csv
        wafer compare analyze amd_trace.json nvidia_trace.json --format csv-layers -o layers.csv
        wafer compare analyze amd_trace.json nvidia_trace.json --format json -o report.json
        wafer compare analyze amd_trace.json nvidia_trace.json -o report.txt
    """
    from .trace_compare import compare_traces

    # The hidden --json flag exists only for cliExecutor compatibility and
    # is deliberately not forwarded.
    report_options = {
        "trace1": trace1,
        "trace2": trace2,
        "output": output,
        "output_format": format,
        "phase": phase,
        "show_layers": layers,
        "show_all": all,
        "show_stack_traces": stack_traces,
        "recommendations": recommendations,
    }
    compare_traces(**report_options)
    _mark_command_success()
7948
+
7949
+
7950
@compare_app.command("fusion")
def compare_fusion_cmd(
    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
    format: str = typer.Option(
        "text",
        "--format",
        "-f",
        help="Output format: text, csv, json",
    ),
    output: Path | None = typer.Option(
        None, "--output", "-o", help="Output file (default: stdout)"
    ),
    min_group_size: int = typer.Option(
        50,
        "--min-group-size",
        help="Minimum correlation group size to analyze",
    ),
    json: bool = typer.Option(
        False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
    ),
) -> None:
    """Analyze kernel fusion differences between AMD and NVIDIA traces.

    Detects which operations are fused differently on each platform by analyzing
    how many kernel launches each platform uses for the same logical operations.

    Examples:
        # Basic fusion analysis (stdout)
        wafer compare fusion amd_trace.json nvidia_trace.json

        # Save to file
        wafer compare fusion amd_trace.json nvidia_trace.json -o fusion_report.txt

        # JSON output to file
        wafer compare fusion amd_trace.json nvidia_trace.json --format json -o fusion.json

        # CSV output to file
        wafer compare fusion amd_trace.json nvidia_trace.json --format csv -o fusion.csv
    """
    from .trace_compare import compare_align

    # NOTE(review): --min-group-size is accepted but never forwarded to
    # compare_align, so it currently has no effect — confirm whether the
    # backend should receive it. The hidden --json flag is likewise ignored.
    # Fusion analysis is delegated to compare_align with phase fixed to "all".
    compare_align(
        trace1=trace1,
        trace2=trace2,
        output=output,
        output_format=format,
        phase="all",
    )
    _mark_command_success()
8000
+
8001
+
8002
@compare_app.command("align")
def compare_align_cmd(
    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
    format: str = typer.Option(
        "json",
        "--format",
        "-f",
        help="Output format: json",
    ),
    output: Path | None = typer.Option(
        None, "--output", "-o", help="Output file (default: stdout)"
    ),
    phase: str = typer.Option(
        "all",
        "--phase",
        help="Filter by phase: all, prefill, decode",
    ),
    layer: int | None = typer.Option(
        None,
        "--layer",
        help="Focus on specific layer number",
    ),
) -> None:
    """Produce a layer-level kernel-to-kernel alignment of two traces.

    Maps each kernel in one trace to its counterpart in the other at every
    layer position, enabling exact cross-platform (AMD vs NVIDIA)
    comparison. Scope can be narrowed to a phase or a single layer.

    Examples:
        wafer compare align amd_trace.json nvidia_trace.json
        wafer compare align amd_trace.json nvidia_trace.json -o alignment.json
        wafer compare align amd_trace.json nvidia_trace.json --phase decode
        wafer compare align amd_trace.json nvidia_trace.json --layer 5
    """
    from .trace_compare import compare_align

    # Collect the request once, then delegate to the backend aligner.
    alignment_request = dict(
        trace1=trace1,
        trace2=trace2,
        output=output,
        output_format=format,
        phase=phase,
        layer=layer,
    )
    compare_align(**alignment_request)
    _mark_command_success()
8055
+
8056
+
7283
8057
def main() -> None:
    """Entry point for wafer CLI.

    Hands control to the module-level Typer application, which parses
    argv and dispatches to the registered subcommands.
    """
    app()