wafer-cli 0.2.24-py3-none-any.whl → 0.2.25-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -194,11 +194,16 @@ def complete_target_name(incomplete: str) -> list[str]:

  # =============================================================================
  # Core subcommand groups (visible in --help)
+ #
+ # TODO: Further consolidate top-level commands to reduce --help surface area.
+ # Candidates:
+ # - compare → wafer nvidia compare or keep top-level (cross-platform)
+ # - guide/skill/demo → wafer onboard {guide,skill,demo}
  # =============================================================================

  # Config management (includes targets as nested subcommand)
  config_app = typer.Typer(help="Manage CLI configuration and local GPU targets")
- app.add_typer(config_app, name="config")
+ app.add_typer(config_app, name="config", rich_help_panel="Configuration")

  # Target management - nested under config
  targets_app = typer.Typer(
@@ -218,7 +223,7 @@ config_app.add_typer(targets_app, name="targets")
  workspaces_app = typer.Typer(
  help="""Manage cloud GPU workspaces for remote development.

- Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
+ Workspaces are on-demand cloud GPU environments. Requires authentication (wafer auth login).

  Available GPUs:
  MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
@@ -231,21 +236,21 @@ Commands:
  wafer workspaces sync dev ./project # Sync files
  wafer workspaces delete dev # Clean up"""
  )
- app.add_typer(workspaces_app, name="workspaces")
+ app.add_typer(workspaces_app, name="workspaces", rich_help_panel="Infrastructure")

- # SSH Key management (BYOK - Bring Your Own Key)
+ # SSH Key management (BYOK - Bring Your Own Key) - nested under config
  ssh_keys_app = typer.Typer(
  help="""Manage SSH public keys for workspace access.

  Register your SSH public keys here. These keys are installed in all workspaces
  you provision, enabling SSH access from any machine with your private key.

- wafer ssh-keys list # List registered keys
- wafer ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
- wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
- wafer ssh-keys remove <key-id> # Remove a key"""
+ wafer config ssh-keys list # List registered keys
+ wafer config ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
+ wafer config ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
+ wafer config ssh-keys remove <key-id> # Remove a key"""
  )
- app.add_typer(ssh_keys_app, name="ssh-keys")
+ config_app.add_typer(ssh_keys_app, name="ssh-keys")

  # Target operations (exec/ssh/sync on configured targets)
  targets_ops_app = typer.Typer(
@@ -261,22 +266,22 @@ Useful for exploratory work, debugging, or custom scripts.
  Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
  Configure targets with: wafer config targets init ..."""
  )
- app.add_typer(targets_ops_app, name="targets")
+ app.add_typer(targets_ops_app, name="targets", rich_help_panel="Infrastructure")

- # Billing management
+ # Billing management - nested under config
  billing_app = typer.Typer(help="Manage billing, credits, and subscription")
- app.add_typer(billing_app, name="billing")
+ config_app.add_typer(billing_app, name="billing")

  # Corpus management
  corpus_app = typer.Typer(help="Download and manage GPU documentation")
- app.add_typer(corpus_app, name="corpus")
+ app.add_typer(corpus_app, name="corpus", rich_help_panel="Kernel Development")

  # Evaluate (supports multiple kernel formats)
  evaluate_app = typer.Typer(
  help="Test kernel correctness and performance",
  invoke_without_command=True,
  )
- app.add_typer(evaluate_app, name="evaluate")
+ app.add_typer(evaluate_app, name="evaluate", rich_help_panel="Kernel Development")

  # Nested subcommand for kernelbench format
  kernelbench_app = typer.Typer(
@@ -305,7 +310,7 @@ app.add_typer(dev_app, name="dev")
  # =============================================================================

  nvidia_app = typer.Typer(help="NVIDIA GPU profiling and analysis tools")
- app.add_typer(nvidia_app, name="nvidia")
+ app.add_typer(nvidia_app, name="nvidia", rich_help_panel="Profiling")

  # NCU analysis - under nvidia
  ncu_app = typer.Typer(help="Nsight Compute profile analysis")
@@ -328,18 +333,25 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
  # =============================================================================

  amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
- app.add_typer(amd_app, name="amd")
+ app.add_typer(amd_app, name="amd", rich_help_panel="Profiling")

  # Unified ISA Analyzer - supports both .co files and Triton artifacts
  isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
  amd_app.add_typer(isa_app, name="isa")

+ # =============================================================================
+ # Trace comparison (wafer compare)
+ # =============================================================================
+
+ compare_app = typer.Typer(help="Compare GPU traces across platforms (AMD vs NVIDIA)")
+ app.add_typer(compare_app, name="compare", rich_help_panel="Profiling")
+
  # =============================================================================
  # Roofline analysis (wafer roofline)
  # =============================================================================


- @app.command("roofline")
+ @app.command("roofline", rich_help_panel="Kernel Development")
  def roofline_cmd(
  gpu: str | None = typer.Option(
  None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
@@ -430,7 +442,7 @@ def roofline_cmd(
  # =============================================================================

  skill_app = typer.Typer(help="Manage AI coding assistant skills (Claude Code, Codex)")
- app.add_typer(skill_app, name="skill")
+ app.add_typer(skill_app, name="skill", rich_help_panel="Onboarding")


  @skill_app.command("install")
@@ -594,14 +606,17 @@ def skill_status() -> None:


  # =============================================================================
- # Provider auth management (wafer auth ...)
+ # Authentication (wafer auth ...)
  # =============================================================================

- provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
- app.add_typer(provider_auth_app, name="auth")
+ auth_app = typer.Typer(help="Authenticate with Wafer and cloud GPU providers")
+ app.add_typer(auth_app, name="auth", rich_help_panel="Configuration")

+ providers_app = typer.Typer(help="Manage API keys for cloud GPU providers (RunPod, DigitalOcean, etc.)")
+ auth_app.add_typer(providers_app, name="providers")

- @provider_auth_app.command("login")
+
+ @providers_app.command("login")
  def provider_auth_login(
  provider: str = typer.Argument(
  ...,
@@ -620,10 +635,10 @@ def provider_auth_login(
  (e.g., ANTHROPIC_API_KEY) take precedence over stored keys.

  Examples:
- wafer auth login anthropic --api-key sk-ant-xxx
- wafer auth login runpod --api-key rp_xxx
- wafer auth login openai --api-key sk-xxx
- echo $API_KEY | wafer auth login anthropic
+ wafer auth providers login anthropic --api-key sk-ant-xxx
+ wafer auth providers login runpod --api-key rp_xxx
+ wafer auth providers login openai --api-key sk-xxx
+ echo $API_KEY | wafer auth providers login anthropic
  """
  import sys

@@ -653,7 +668,7 @@ def provider_auth_login(
  typer.echo("Stored in: ~/.wafer/auth.json")


- @provider_auth_app.command("logout")
+ @providers_app.command("logout")
  def provider_auth_logout(
  provider: str = typer.Argument(
  ...,
@@ -663,8 +678,8 @@ def provider_auth_logout(
  """Remove stored API key for a cloud GPU provider.

  Examples:
- wafer auth logout runpod
- wafer auth logout digitalocean
+ wafer auth providers logout runpod
+ wafer auth providers logout digitalocean
  """
  from wafer_core.auth import PROVIDERS, remove_api_key

@@ -680,7 +695,7 @@ def provider_auth_logout(
  typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")


- @provider_auth_app.command("status")
+ @providers_app.command("status")
  def provider_auth_status() -> None:
  """Show authentication status for all cloud GPU providers.

@@ -688,7 +703,7 @@ def provider_auth_status() -> None:
  the keys are coming from (environment variable or auth.json).

  Example:
- wafer auth status
+ wafer auth providers status
  """
  from wafer_core.auth import get_all_auth_status

@@ -703,7 +718,7 @@ def provider_auth_status() -> None:
  typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
  else:
  typer.echo(f" {status.display_name}: ✗ Not configured")
- typer.echo(f" Run: wafer auth login {status.provider}")
+ typer.echo(f" Run: wafer auth providers login {status.provider}")
  typer.echo(f" Or set: {status.key_url}")

  typer.echo("")
@@ -1248,7 +1263,7 @@ def config_show_legacy() -> None:
  config_show_new()


- @app.command()
+ @app.command(rich_help_panel="Kernel Development")
  def agent( # noqa: PLR0913
  prompt: str | None = typer.Argument(
  None,
@@ -1318,7 +1333,7 @@ def agent( # noqa: PLR0913
  None,
  "--model",
  "-m",
- help="Model override (default: claude-sonnet-4-5)",
+ help="Model override (default: claude-opus-4-5)",
  ),
  json_output: bool = typer.Option(
  False,
@@ -1347,6 +1362,11 @@ def agent( # noqa: PLR0913
  "--no-sandbox",
  help="Disable OS-level sandboxing (YOU accept liability for any damage caused by the agent)",
  ),
+ no_proxy: bool = typer.Option(
+ False,
+ "--no-proxy",
+ help="Skip wafer proxy, use ANTHROPIC_API_KEY directly",
+ ),
  ) -> None:
  """AI assistant for GPU kernel development.

@@ -1453,6 +1473,7 @@ def agent( # noqa: PLR0913
  template_args=parsed_template_args,
  corpus_path=corpus_path,
  no_sandbox=no_sandbox,
+ no_proxy=no_proxy,
  )


@@ -1527,7 +1548,11 @@ def evaluate( # noqa: PLR0913
  None, "--reference", help="Path to reference kernel file"
  ),
  test_cases: Path | None = typer.Option(
- None, "--test-cases", help="Path to test cases JSON file"
+ None,
+ "--test-cases",
+ help="Path to test cases JSON file. "
+ 'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
+ "Run 'wafer evaluate make-template' to generate an example.",
  ),
  target: str | None = typer.Option(
  None,
@@ -1557,20 +1582,20 @@ def evaluate( # noqa: PLR0913

  Examples:
  # Basic correctness check
- wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json

  # With benchmarking on a specific target
- wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
  --target vultr-b200 --benchmark

  # Full evaluation with defensive timing (detects cheating)
- wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
  --benchmark --defensive

  Subcommands:
  gpumode Use GPUMode format (functional) - RECOMMENDED
  kernelbench Use KernelBench format (ModelNew class)
- make-template Generate template files for this format (deprecated)
+ make-template Generate template files for this format
  """
  # If a subcommand is being invoked, skip the main evaluation logic
  if ctx.invoked_subcommand is not None:
@@ -1724,7 +1749,7 @@ def evaluate_make_template(
  typer.echo(f" 2. Edit {output_dir / 'reference.py'} with the ground truth + input generator")
  typer.echo(f" 3. Edit {output_dir / 'test_cases.json'} with your test parameters")
  typer.echo(" 4. Run:")
- typer.echo(f" wafer evaluate --impl {output_dir / 'kernel.py'} \\")
+ typer.echo(f" wafer evaluate gpumode --impl {output_dir / 'kernel.py'} \\")
  typer.echo(f" --reference {output_dir / 'reference.py'} \\")
  typer.echo(f" --test-cases {output_dir / 'test_cases.json'} --benchmark")

@@ -2275,7 +2300,11 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
  None, "--reference", help="Path to reference kernel file"
  ),
  test_cases: Path | None = typer.Option(
- None, "--test-cases", help="Path to test cases JSON file"
+ None,
+ "--test-cases",
+ help="Path to test cases JSON file. "
+ 'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
+ "Run 'wafer evaluate make-template' to generate an example.",
  ),
  target: str | None = typer.Option(
  None,
@@ -2343,6 +2372,13 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
  err=True,
  )
  typer.echo("", err=True)
+ if "--test-cases" in missing_args:
+ typer.echo(
+ "Tip: Run 'wafer evaluate make-template' to generate template files "
+ "including test_cases.json.",
+ err=True,
+ )
+ typer.echo("", err=True)
  typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
  typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
  raise typer.Exit(1)
@@ -2749,7 +2785,7 @@ def remote_run( # noqa: PLR0913
  # =============================================================================


- @app.command("login")
+ @auth_app.command("login")
  def login(
  token: str | None = typer.Option(
  None, "--token", "-t", help="Access token (skip browser OAuth)"
@@ -2774,7 +2810,7 @@ def login(
  Uses the API environment from config (see 'wafer config show').

  SSH Users (Easiest):
- - Just run: wafer login
+ - Just run: wafer auth login
  - Visit the URL and enter the code shown
  - No port forwarding needed!

@@ -2784,17 +2820,17 @@ def login(

  Manual token option:
  - Visit auth.wafer.ai, authenticate, copy token from URL
- - Run: wafer login --token <paste-token>
+ - Run: wafer auth login --token <paste-token>

  Examples:
- wafer login # device code on SSH, browser on local
- wafer login --no-device-code # force browser (needs port forwarding on SSH)
- wafer login --port 9000 # custom port for browser flow
- wafer login --token xyz # manual token (no browser)
+ wafer auth login # device code on SSH, browser on local
+ wafer auth login --no-device-code # force browser (needs port forwarding on SSH)
+ wafer auth login --port 9000 # custom port for browser flow
+ wafer auth login --token xyz # manual token (no browser)

  # Change environment:
  wafer config set api.environment staging
- wafer login
+ wafer auth login
  """
  import httpx

@@ -2878,7 +2914,7 @@ def login(
  typer.echo("Token saved to ~/.wafer/credentials.json")


- @app.command("logout")
+ @auth_app.command("logout")
  def logout() -> None:
  """Remove stored credentials."""
  from . import analytics
@@ -2895,7 +2931,7 @@ def logout() -> None:
  typer.echo("Not logged in (no credentials found).")


- @app.command("whoami")
+ @auth_app.command("whoami")
  def whoami(
  verify: bool = typer.Option(False, "--verify", "-v", help="Verify token with API"),
  refresh: bool = typer.Option(False, "--refresh", "-r", help="Refresh token if expired"),
@@ -2909,7 +2945,7 @@ def whoami(

  creds = load_credentials()
  if creds is None:
- typer.echo("Not logged in. Run: wafer login")
+ typer.echo("Not logged in. Run: wafer auth login")
  raise typer.Exit(1)

  if verify or refresh:
@@ -2917,7 +2953,7 @@ def whoami(
  # Try to get valid token with auto-refresh
  token = get_valid_token()
  if token is None:
- typer.echo("Token expired and refresh failed. Run: wafer login", err=True)
+ typer.echo("Token expired and refresh failed. Run: wafer auth login", err=True)
  raise typer.Exit(1)
  if token != creds.access_token:
  typer.echo("Token refreshed successfully")
@@ -2930,10 +2966,10 @@ def whoami(
  except Exception as e:
  if creds.refresh_token and not refresh:
  typer.echo(f"Token expired: {e}", err=True)
- typer.echo("Try: wafer whoami --refresh", err=True)
+ typer.echo("Try: wafer auth whoami --refresh", err=True)
  else:
  typer.echo(f"Token invalid or expired: {e}", err=True)
- typer.echo("Run: wafer login", err=True)
+ typer.echo("Run: wafer auth login", err=True)
  raise typer.Exit(1) from None
  elif creds.email:
  typer.echo(creds.email)
@@ -2941,7 +2977,7 @@ def whoami(
  typer.echo("Logged in (email not available)")


- @app.command("guide")
+ @app.command("guide", rich_help_panel="Onboarding")
  def guide() -> None:
  """Show the Wafer CLI usage guide.

@@ -2972,7 +3008,7 @@ demo_app = typer.Typer(
  wafer demo trace Analyze a sample performance trace
  wafer demo eval Run kernel evaluation on cloud GPU (requires login)"""
  )
- app.add_typer(demo_app, name="demo")
+ app.add_typer(demo_app, name="demo", rich_help_panel="Onboarding")

  DEMO_TRACES_URL = "https://github.com/wafer-ai/wafer/raw/main/apps/wafer-cli/wafer/demo_data"
  DEMO_DIR = Path.home() / ".cache" / "wafer" / "demo"
@@ -3192,7 +3228,7 @@ def demo_eval(
  """Demo: Evaluate a kernel on a cloud GPU.

  Creates a workspace, runs a sample Triton kernel evaluation, and cleans up.
- Requires authentication (wafer login).
+ Requires authentication (wafer auth login).

  Example:
  wafer demo eval
@@ -3207,7 +3243,7 @@ def demo_eval(
  # Check auth first
  creds = load_credentials()
  if not creds:
- typer.echo("Error: Not authenticated. Run: wafer login")
+ typer.echo("Error: Not authenticated. Run: wafer auth login")
  raise typer.Exit(1)

  if not yes:
@@ -3856,12 +3892,16 @@ def targets_add(

  @targets_app.command("list")
  def targets_list() -> None:
- """List all configured targets.
+ """List all configured targets with live provider status.

  Example:
  wafer config targets list
  """
- from .targets import get_default_target, list_targets
+ import socket
+
+ import trio
+
+ from .targets import get_default_target, list_targets, load_target, remove_target

  targets = list_targets()
  default = get_default_target()
@@ -3871,10 +3911,146 @@ def targets_list() -> None:
  typer.echo("Add one with: wafer config targets add <path/to/target.toml>")
  return

+ def _parse_ssh_target(ssh_target: str) -> tuple[str, int]:
+ """Extract (host, port) from user@host:port string."""
+ parts = ssh_target.rsplit(":", 1)
+ host_part = parts[0]
+ port = int(parts[1]) if len(parts) > 1 else 22
+ if "@" in host_part:
+ host = host_part.split("@", 1)[1]
+ else:
+ host = host_part
+ return (host, port)
+
+ async def _get_live_provider_endpoints() -> set[tuple[str, int]]:
+ """Query RunPod + DO APIs. Returns set of live (ip, port) endpoints."""
+ from wafer_core.targets.digitalocean import list_running_droplets
+ from wafer_core.targets.runpod import sync_pods_from_api
+
+ live_endpoints: set[tuple[str, int]] = set()
+
+ async def _fetch_runpod() -> None:
+ try:
+ pods = await sync_pods_from_api()
+ for p in pods:
+ live_endpoints.add((p.public_ip, p.ssh_port))
+ except Exception:
+ pass
+
+ async def _fetch_do() -> None:
+ try:
+ droplets = await list_running_droplets()
+ for d in droplets:
+ live_endpoints.add((d.public_ip, d.ssh_port))
+ except Exception:
+ pass
+
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(_fetch_runpod)
+ nursery.start_soon(_fetch_do)
+
+ return live_endpoints
+
+ async def _get_target_status(
+ name: str,
+ live_endpoints: set[tuple[str, int]],
+ ) -> tuple[str, str, str]:
+ """Returns (name, status, ssh_info)."""
+ from wafer_core.targets.digitalocean import (
+ _remove_droplet_from_state,
+ check_droplet_running,
+ get_droplet_state,
+ )
+ from wafer_core.targets.runpod import (
+ _remove_pod_from_state,
+ check_pod_running,
+ get_pod_state,
+ )
+ from wafer_core.utils.kernel_utils.targets.config import (
+ BaremetalTarget,
+ DigitalOceanTarget,
+ ModalTarget,
+ RunPodTarget,
+ )
+
+ try:
+ target = load_target(name)
+ except (FileNotFoundError, ValueError, AssertionError, TypeError):
+ return (name, "error", "")
+
+ if isinstance(target, RunPodTarget):
+ pod = get_pod_state(name)
+ if not pod:
+ return (name, "no instance", "")
+ if await check_pod_running(pod.pod_id):
+ return (name, "running", f"{pod.ssh_username}@{pod.public_ip}:{pod.ssh_port}")
+ _remove_pod_from_state(name)
+ return (name, "stopped", "")
+
+ if isinstance(target, DigitalOceanTarget):
+ droplet = get_droplet_state(name)
+ if not droplet:
+ return (name, "no instance", "")
+ if await check_droplet_running(droplet.droplet_id):
+ return (
+ name,
+ "running",
+ f"{droplet.ssh_username}@{droplet.public_ip}:{droplet.ssh_port}",
+ )
+ _remove_droplet_from_state(name)
+ return (name, "stopped", "")
+
+ if isinstance(target, BaremetalTarget):
+ ssh_target = target.ssh_target
+ host, port = _parse_ssh_target(ssh_target)
+
+ def _tcp_check() -> bool:
+ try:
+ sock = socket.create_connection((host, port), timeout=2)
+ sock.close()
+ return True
+ except OSError:
+ return False
+
+ reachable = await trio.to_thread.run_sync(_tcp_check)
+ if reachable:
+ return (name, "reachable", ssh_target)
+
+ # Unreachable + has a provider = backed by an ephemeral instance.
+ # If not in the live provider listing, the instance is gone — remove config.
+ if target.provider and (host, port) not in live_endpoints:
+ remove_target(name)
+ return (name, "removed (dead pod)", ssh_target)
+
+ return (name, "unreachable", ssh_target)
+
+ if isinstance(target, ModalTarget):
+ return (name, "serverless", "")
+
+ # Unknown target type
+ return (name, "unknown", "")
+
+ async def _gather_statuses() -> list[tuple[str, str, str]]:
+ live_endpoints = await _get_live_provider_endpoints()
+ results: list[tuple[str, str, str]] = [("", "", "")] * len(targets)
+
+ async def _check(i: int, name: str) -> None:
+ results[i] = await _get_target_status(name, live_endpoints)
+
+ async with trio.open_nursery() as nursery:
+ for i, name in enumerate(targets):
+ nursery.start_soon(_check, i, name)
+
+ return results
+
+ statuses = trio.run(_gather_statuses)
+
  typer.echo("Configured targets:")
- for name in targets:
+ for name, status, ssh_info in statuses:
  marker = " (default)" if name == default else ""
- typer.echo(f" {name}{marker}")
+ label = f" {name}{marker}"
+ detail = f" {ssh_info}" if ssh_info else ""
+ typer.echo(f"{label:<40}{status}{detail}")


  @targets_app.command("show")
@@ -4089,10 +4265,19 @@ def targets_cleanup(
  # Known libraries that can be installed on targets
  # TODO: Consider adding HipKittens to the default RunPod/DO Docker images
  # so this install step isn't needed. For now, this command handles it.
+ # Architecture → branch mapping for libraries that ship per-arch branches.
+ # "default" is used when the detected arch has no explicit entry.
+ _ARCH_BRANCHES: dict[str, dict[str, str]] = {
+ "hipkittens": {
+ "gfx942": "cdna3", # MI300X, MI325X
+ "default": "main", # MI350X, MI355X, and future CDNA4+
+ },
+ }
+
  INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
  "hipkittens": {
- "description": "HipKittens - AMD port of ThunderKittens for MI300X",
- "git_url": "https://github.com/HazyResearch/hipkittens.git",
+ "description": "HipKittens - AMD port of ThunderKittens",
+ "git_url": "https://github.com/HazyResearch/HipKittens.git",
  "install_path": "/opt/hipkittens",
  "requires_amd": True,
  },
@@ -4105,6 +4290,38 @@ INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
  }


+ def _resolve_gfx_arch(target: object, ssh_cmd: list[str]) -> str | None:
+ """Return the gfx architecture string for *target*.
+
+ 1. If the target config already carries a compute_capability, map it.
+ 2. Otherwise SSH in and probe with ``rocminfo``.
+ Returns None only if detection fails entirely.
+ """
+ import subprocess
+
+ from .evaluate import AMD_CC_TO_ARCH
+
+ cc = getattr(target, "compute_capability", None)
+ if cc and cc in AMD_CC_TO_ARCH:
+ return AMD_CC_TO_ARCH[cc]
+
+ typer.echo(" Detecting GPU architecture via rocminfo...")
+ probe_script = "rocminfo 2>/dev/null | grep -oP 'gfx\\d+' | head -1"
+ result = subprocess.run(
+ ssh_cmd + [probe_script],
+ capture_output=True,
+ text=True,
+ timeout=30,
+ )
+ arch = result.stdout.strip()
+ if result.returncode == 0 and arch.startswith("gfx"):
+ typer.echo(f" Detected: {arch}")
+ return arch
+
+ typer.echo(" Warning: could not detect GPU architecture", err=True)
+ return None
+
+
  @targets_app.command("install")
  def targets_install(
  name: str = typer.Argument(..., help="Target name"),
@@ -4115,6 +4332,9 @@ def targets_install(
  Installs header-only libraries like HipKittens on remote targets.
  Safe to run multiple times - will skip if already installed.

+ For libraries with per-architecture branches (e.g. HipKittens), the
+ correct branch is selected automatically based on the target's GPU.
+
  Available libraries:
  hipkittens - HipKittens (AMD ThunderKittens port)
  repair-headers - Fix ROCm thrust headers (after hipify corruption)
@@ -4188,14 +4408,22 @@ def targets_install(
  install_path = lib_info["install_path"]
  git_url = lib_info["git_url"]

- # Idempotent install script
+ # Resolve the branch for arch-aware libraries
+ branch = "main"
+ arch_map = _ARCH_BRANCHES.get(library)
+ if arch_map:
+ gfx = await trio.to_thread.run_sync(lambda: _resolve_gfx_arch(target, ssh_cmd))
+ branch = arch_map.get(gfx, arch_map["default"]) if gfx else arch_map["default"]
+ typer.echo(f" Branch: {branch} (arch={gfx or 'unknown'})")
+
+ # Idempotent: if already cloned, ensure correct branch & pull
  install_script = f"""
  if [ -d "{install_path}" ]; then
  echo "ALREADY_INSTALLED: {install_path} exists"
- cd {install_path} && git pull --quiet 2>/dev/null || true
+ cd {install_path} && git fetch --quiet origin && git checkout {branch} --quiet && git pull --quiet origin {branch}
  else
  echo "INSTALLING: cloning to {install_path}"
- git clone --quiet {git_url} {install_path}
+ git clone --quiet --branch {branch} {git_url} {install_path}
  fi
  echo "DONE"
  """
@@ -4373,8 +4601,8 @@ def billing_usage(
  """Show current billing usage and subscription info.

  Example:
- wafer billing
- wafer billing --json
+ wafer config billing
+ wafer config billing --json
  """
  # Only show usage if no subcommand was invoked
  if ctx.invoked_subcommand is not None:
@@ -4402,9 +4630,9 @@ def billing_topup(
  Opens a Stripe checkout page to add credits. Default amount is $25.

  Example:
- wafer billing topup # Add $25
- wafer billing topup 100 # Add $100
- wafer billing topup --no-browser # Print URL instead
+ wafer config billing topup # Add $25
+ wafer config billing topup 100 # Add $100
+ wafer config billing topup --no-browser # Print URL instead
  """
  import webbrowser

@@ -4450,8 +4678,8 @@ def billing_portal(
  Manage your subscription, update payment method, or view invoices.

  Example:
- wafer billing portal
- wafer billing portal --no-browser
+ wafer config billing portal
+ wafer config billing portal --no-browser
  """
  import webbrowser

@@ -4488,8 +4716,8 @@ def ssh_keys_list(
  """List all registered SSH public keys.

  Example:
- wafer ssh-keys list
- wafer ssh-keys list --json
+ wafer config ssh-keys list
+ wafer config ssh-keys list --json
  """
  from .ssh_keys import list_ssh_keys

@@ -4515,9 +4743,9 @@ def ssh_keys_add(
  id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.

  Example:
- wafer ssh-keys add # Auto-detect
- wafer ssh-keys add ~/.ssh/id_rsa.pub # Specific file
- wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
+ wafer config ssh-keys add # Auto-detect
+ wafer config ssh-keys add ~/.ssh/id_rsa.pub # Specific file
+ wafer config ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
  """
  from .ssh_keys import add_ssh_key

@@ -4536,10 +4764,10 @@ def ssh_keys_remove(
  ) -> None:
  """Remove an SSH public key.

- Get the key ID from 'wafer ssh-keys list'.
+ Get the key ID from 'wafer config ssh-keys list'.

  Example:
- wafer ssh-keys remove abc123-def456-...
+ wafer config ssh-keys remove abc123-def456-...
  """
  from .ssh_keys import remove_ssh_key

@@ -4978,7 +5206,9 @@ def workspaces_sync(
  @workspaces_app.command("pull")
  def workspaces_pull(
  workspace: str = typer.Argument(..., help="Workspace name or ID"),
- remote_path: str = typer.Argument(..., help="Remote path in workspace (relative to /workspace or absolute)"),
+ remote_path: str = typer.Argument(
+ ..., help="Remote path in workspace (relative to /workspace or absolute)"
+ ),
  local_path: Path = typer.Argument(
  Path("."), help="Local destination path (default: current directory)"
  ),
@@ -5782,7 +6012,7 @@ def ncu_analyze(
  compute/memory throughput, and optimization recommendations.

  By default, uses local NCU if available, otherwise runs analysis
- remotely via wafer-api (requires authentication: wafer login).
+ remotely via wafer-api (requires authentication: wafer auth login).

  Use --target for direct SSH mode (like wafer remote-run --direct).
  Use --include-source to fetch SASS assembly with register/instruction data.
@@ -5877,7 +6107,7 @@ def nsys_analyze(
  Returns timeline events, kernel information, memory usage, and diagnostics.

  By default, uses local nsys if available, otherwise runs analysis
- remotely via wafer-api (requires authentication: wafer login).
+ remotely via wafer-api (requires authentication: wafer auth login).

  Supports multiple execution modes:
  - Local: Uses local nsys CLI (no GPU required for analysis)
@@ -6862,7 +7092,7 @@ def autotuner_results(
  raise typer.Exit(1) from None


- @app.command("capture")
+ @app.command("capture", rich_help_panel="Kernel Development")
  def capture_command( # noqa: PLR0915
  label: str = typer.Argument(
  ..., help="Label for this capture (e.g., 'baseline', 'optimized-v2')"
@@ -7527,6 +7757,144 @@ def isa_targets() -> None:
  typer.echo(output)


+ # =============================================================================
+ # Trace comparison commands
+ # =============================================================================
+
+
+ @compare_app.command("analyze")
+ def compare_analyze(
+ trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
+ trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
+ format: str = typer.Option(
+ "text",
+ "--format",
+ "-f",
+ help="Output format: text, text-layers, csv, csv-layers, json",
+ ),
+ output: Path | None = typer.Option(
+ None, "--output", "-o", help="Output file (default: stdout)"
+ ),
+ phase: str = typer.Option(
+ "all",
+ "--phase",
+ help="Filter by phase: all, prefill, decode",
+ ),
+ layers: bool = typer.Option(False, "--layers", help="Show layer-wise performance breakdown"),
+ all: bool = typer.Option(
+ False, "--all", help="Show all items (no truncation for layers, operations, kernels)"
+ ),
+ stack_traces: bool = typer.Option(
+ False, "--stack-traces", help="Show Python stack traces for operations"
+ ),
+ json: bool = typer.Option(
+ False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
+ ),
+ ) -> None:
+ """Compare GPU traces from AMD and NVIDIA platforms.
+
+ Analyzes performance differences between traces, identifying which operations
+ are faster/slower on each platform and providing kernel-level details.
+
+ Examples:
+ # Basic comparison (stdout)
+ wafer compare analyze amd_trace.json nvidia_trace.json
+
+ # Show layer-wise breakdown
+ wafer compare analyze amd_trace.json nvidia_trace.json --layers
+ wafer compare analyze amd_trace.json nvidia_trace.json --format text-layers
+
+ # Show all layers without truncation
+ wafer compare analyze amd_trace.json nvidia_trace.json --layers --all
+
+ # Show Python stack traces
+ wafer compare analyze amd_trace.json nvidia_trace.json --stack-traces
+
+ # Show all stack traces without truncation
+ wafer compare analyze amd_trace.json nvidia_trace.json --stack-traces --all
+
+ # Save to file
+ wafer compare analyze amd_trace.json nvidia_trace.json -o report.txt
+
+ # CSV output (operations) to file
+ wafer compare analyze amd_trace.json nvidia_trace.json --format csv -o operations.csv
+
+ # CSV output (layers) to file
+ wafer compare analyze amd_trace.json nvidia_trace.json --format csv-layers -o layers.csv
+
+ # JSON output to file
+ wafer compare analyze amd_trace.json nvidia_trace.json --format json -o report.json
+
+ # Analyze only prefill phase
+ wafer compare analyze amd_trace.json nvidia_trace.json --phase prefill
+ """
+ from .trace_compare import compare_traces
+
+ compare_traces(
+ trace1=trace1,
+ trace2=trace2,
+ output=output,
+ output_format=format,
+ phase=phase,
+ show_layers=layers,
+ show_all=all,
+ show_stack_traces=stack_traces,
+ )
+ _mark_command_success()
+
+
+ @compare_app.command("fusion")
+ def compare_fusion_cmd(
+ trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
+ trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
+ format: str = typer.Option(
+ "text",
+ "--format",
+ "-f",
+ help="Output format: text, csv, json",
+ ),
+ output: Path | None = typer.Option(
+ None, "--output", "-o", help="Output file (default: stdout)"
+ ),
+ min_group_size: int = typer.Option(
+ 50,
+ "--min-group-size",
+ help="Minimum correlation group size to analyze",
+ ),
+ json: bool = typer.Option(
+ False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
+ ),
+ ) -> None:
+ """Analyze kernel fusion differences between AMD and NVIDIA traces.
+
+ Detects which operations are fused differently on each platform by analyzing
+ how many kernel launches each platform uses for the same logical operations.
+
+ Examples:
+ # Basic fusion analysis (stdout)
+ wafer compare fusion amd_trace.json nvidia_trace.json
+
+ # Save to file
+ wafer compare fusion amd_trace.json nvidia_trace.json -o fusion_report.txt
+
+ # JSON output to file
+ wafer compare fusion amd_trace.json nvidia_trace.json --format json -o fusion.json
+
+ # CSV output to file
+ wafer compare fusion amd_trace.json nvidia_trace.json --format csv -o fusion.csv
+ """
+ from .trace_compare import compare_fusion
+
+ compare_fusion(
+ trace1=trace1,
+ trace2=trace2,
+ output=output,
+ format_type=format,
+ min_group_size=min_group_size,
+ )
+ _mark_command_success()
+
+
  def main() -> None:
  """Entry point for wafer CLI."""
  app()
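
Most of the changes above follow one pattern: top-level Typer sub-apps and commands gain a rich_help_panel so that wafer --help groups them into panels (Configuration, Infrastructure, Kernel Development, Profiling, Onboarding), and several commands move under nested sub-apps (wafer auth providers ..., wafer config ssh-keys ..., wafer config billing ...). The sketch below is not taken from the package; it is a minimal illustration of that Typer pattern, using made-up app and command names, assuming only the rich_help_panel and add_typer behavior the diff itself relies on.

# sketch.py - illustrative only; names do not come from wafer-cli
import typer

app = typer.Typer()

config_app = typer.Typer(help="Manage CLI configuration")
auth_app = typer.Typer(help="Authenticate with the service and cloud providers")
providers_app = typer.Typer(help="Manage provider API keys")

# Sub-apps grouped into named panels in --help output.
app.add_typer(config_app, name="config", rich_help_panel="Configuration")
app.add_typer(auth_app, name="auth", rich_help_panel="Configuration")

# Nested sub-app: `sketch auth providers login ...`
auth_app.add_typer(providers_app, name="providers")


@app.command("roofline", rich_help_panel="Kernel Development")
def roofline() -> None:
    """Plain commands can join a panel the same way."""
    typer.echo("roofline analysis placeholder")


@providers_app.command("login")
def providers_login(provider: str) -> None:
    """Nested command, e.g. `sketch auth providers login runpod`."""
    typer.echo(f"would store an API key for {provider}")


if __name__ == "__main__":
    app()

Running `python sketch.py --help` renders each panel as its own section of the help output, which is presumably why commands such as `wafer login` become `wafer auth login` throughout the updated help text in this release.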