wafer-cli 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -30,6 +30,14 @@ import typer
30
30
 
31
31
  from .config import WaferConfig, WaferEnvironment
32
32
  from .inference import infer_upload_files, resolve_environment
33
+ from .problems import (
34
+ download_problems,
35
+ get_problem_path,
36
+ get_problems_path,
37
+ )
38
+ from .problems import (
39
+ list_problems as list_problems_fn,
40
+ )
33
41
 
34
42
  app = typer.Typer(
35
43
  help="GPU development toolkit for LLM coding agents",
@@ -91,11 +99,15 @@ def main_callback(ctx: typer.Context) -> None:
91
99
  # Install exception hook to catch SystemExit and mark failures
92
100
  original_excepthook = sys.excepthook
93
101
 
94
- def custom_excepthook(exc_type, exc_value, exc_traceback):
102
+ def custom_excepthook(
103
+ exc_type: type[BaseException],
104
+ exc_value: BaseException,
105
+ exc_traceback: object,
106
+ ) -> None:
95
107
  global _command_outcome
96
108
  # Mark as failure if SystemExit with non-zero code, or any other exception
97
109
  if exc_type is SystemExit:
98
- exit_code = exc_value.code if hasattr(exc_value, 'code') else 1
110
+ exit_code = exc_value.code if hasattr(exc_value, "code") else 1
99
111
  if exit_code != 0 and exit_code is not None:
100
112
  _command_outcome = "failure"
101
113
  else:
@@ -200,6 +212,13 @@ kernelbench_app = typer.Typer(
200
212
  )
201
213
  evaluate_app.add_typer(kernelbench_app, name="kernelbench")
202
214
 
215
+ # Nested subcommand for gpumode format
216
+ gpumode_app = typer.Typer(
217
+ help="Evaluate kernels in GPUMode format (custom_kernel/ref_kernel functions)",
218
+ invoke_without_command=True,
219
+ )
220
+ evaluate_app.add_typer(gpumode_app, name="gpumode")
221
+
203
222
  # =============================================================================
204
223
  # Dev commands (internal, used by web app proxy)
205
224
  # =============================================================================
@@ -242,6 +261,10 @@ app.add_typer(amd_app, name="amd")
242
261
  isa_app = typer.Typer(help="ISA analysis for AMD GPU code objects (.co files)")
243
262
  amd_app.add_typer(isa_app, name="isa")
244
263
 
264
+ # Kernel Scope - static ISA analysis for Triton kernels
265
+ kernel_scope_app = typer.Typer(help="Static ISA analysis for Triton compilation artifacts")
266
+ amd_app.add_typer(kernel_scope_app, name="kernel-scope")
267
+
245
268
  # =============================================================================
246
269
  # Skill management (wafer skill ...)
247
270
  # =============================================================================
@@ -396,6 +419,122 @@ def skill_status() -> None:
396
419
  typer.echo(f"{tool_name}: Not installed")
397
420
 
398
421
 
422
+ # =============================================================================
423
+ # Provider auth management (wafer auth ...)
424
+ # =============================================================================
425
+
426
+ provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
427
+ app.add_typer(provider_auth_app, name="auth")
428
+
429
+
430
+ @provider_auth_app.command("login")
431
+ def provider_auth_login(
432
+ provider: str = typer.Argument(
433
+ ...,
434
+ help="Provider name: runpod, digitalocean, or modal",
435
+ ),
436
+ api_key: str | None = typer.Option(
437
+ None,
438
+ "--api-key",
439
+ "-k",
440
+ help="API key (if not provided, reads from stdin)",
441
+ ),
442
+ ) -> None:
443
+ """Save API key for a cloud GPU provider.
444
+
445
+ Stores the key in ~/.wafer/auth.json. Environment variables
446
+ (e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
447
+
448
+ Examples:
449
+ wafer auth login runpod --api-key rp_xxx
450
+ wafer auth login digitalocean --api-key dop_v1_xxx
451
+ echo $API_KEY | wafer auth login runpod
452
+ """
453
+ import sys
454
+
455
+ from wafer_core.auth import PROVIDERS, save_api_key
456
+
457
+ # Validate provider
458
+ if provider not in PROVIDERS:
459
+ typer.echo(f"Error: Unknown provider '{provider}'", err=True)
460
+ typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
461
+ raise typer.Exit(1)
462
+
463
+ # Get API key from option or stdin
464
+ if api_key is None:
465
+ if sys.stdin.isatty():
466
+ typer.echo(f"Enter API key for {PROVIDERS[provider]['display_name']}:")
467
+ api_key = typer.prompt("API key", hide_input=True)
468
+ else:
469
+ api_key = sys.stdin.read().strip()
470
+
471
+ if not api_key:
472
+ typer.echo("Error: No API key provided", err=True)
473
+ raise typer.Exit(1)
474
+
475
+ # Save the key
476
+ save_api_key(provider, api_key)
477
+ typer.echo(f"API key saved for {PROVIDERS[provider]['display_name']}")
478
+ typer.echo("Stored in: ~/.wafer/auth.json")
479
+
480
+
481
+ @provider_auth_app.command("logout")
482
+ def provider_auth_logout(
483
+ provider: str = typer.Argument(
484
+ ...,
485
+ help="Provider name: runpod, digitalocean, or modal",
486
+ ),
487
+ ) -> None:
488
+ """Remove stored API key for a cloud GPU provider.
489
+
490
+ Examples:
491
+ wafer auth logout runpod
492
+ wafer auth logout digitalocean
493
+ """
494
+ from wafer_core.auth import PROVIDERS, remove_api_key
495
+
496
+ # Validate provider
497
+ if provider not in PROVIDERS:
498
+ typer.echo(f"Error: Unknown provider '{provider}'", err=True)
499
+ typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
500
+ raise typer.Exit(1)
501
+
502
+ if remove_api_key(provider):
503
+ typer.echo(f"API key removed for {PROVIDERS[provider]['display_name']}")
504
+ else:
505
+ typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
506
+
507
+
508
+ @provider_auth_app.command("status")
509
+ def provider_auth_status() -> None:
510
+ """Show authentication status for all cloud GPU providers.
511
+
512
+ Displays which providers have API keys configured and where
513
+ the keys are coming from (environment variable or auth.json).
514
+
515
+ Example:
516
+ wafer auth status
517
+ """
518
+ from wafer_core.auth import get_all_auth_status
519
+
520
+ statuses = get_all_auth_status()
521
+
522
+ typer.echo("Cloud GPU Provider Authentication Status")
523
+ typer.echo("=" * 45)
524
+
525
+ for status in statuses:
526
+ if status.is_authenticated:
527
+ source_str = f"({status.source})" if status.source else ""
528
+ typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
529
+ else:
530
+ typer.echo(f" {status.display_name}: ✗ Not configured")
531
+ typer.echo(f" Run: wafer auth login {status.provider}")
532
+ typer.echo(f" Or set: {status.key_url}")
533
+
534
+ typer.echo("")
535
+ typer.echo("Note: Environment variables take precedence over stored keys.")
536
+
537
+
399
538
  @app.command(hidden=True)
400
539
  def run(
401
540
  command: str = typer.Argument(..., help="Command to run in Docker container"),
@@ -1289,86 +1428,37 @@ def evaluate( # noqa: PLR0913
1289
1428
  --benchmark --defensive
1290
1429
 
1291
1430
  Subcommands:
1292
- make-template Generate template files for this format
1431
+ gpumode Use GPUMode format (functional) - RECOMMENDED
1293
1432
  kernelbench Use KernelBench format (ModelNew class)
1433
+ make-template Generate template files for this format (deprecated)
1294
1434
  """
1295
1435
  # If a subcommand is being invoked, skip the main evaluation logic
1296
1436
  if ctx.invoked_subcommand is not None:
1297
1437
  return
1298
1438
 
1299
- # Validate required args when running evaluation (not subcommands)
1300
- missing_args = []
1301
- if implementation is None:
1302
- missing_args.append("--impl/-i")
1303
- if reference is None:
1304
- missing_args.append("--reference")
1305
- if test_cases is None:
1306
- missing_args.append("--test-cases")
1307
-
1308
- if missing_args:
1309
- typer.echo("Error: Missing required arguments", err=True)
1310
- typer.echo(f" Required: {', '.join(missing_args)}", err=True)
1311
- typer.echo("", err=True)
1312
- typer.echo(
1313
- "Usage: wafer evaluate --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
1314
- err=True,
1315
- )
1316
- typer.echo("", err=True)
1317
- typer.echo("Run 'wafer evaluate --help' for full options.", err=True)
1318
- typer.echo("Run 'wafer evaluate make-template DIR' to generate starter files.", err=True)
1319
- raise typer.Exit(1)
1320
-
1321
- from .evaluate import EvaluateArgs, run_evaluate
1322
-
1323
- args = EvaluateArgs(
1324
- implementation=implementation,
1325
- reference=reference,
1326
- test_cases=test_cases,
1327
- target_name=target or "",
1328
- benchmark=benchmark,
1329
- profile=profile,
1330
- defensive=defensive,
1331
- sync_artifacts=sync_artifacts,
1332
- gpu_id=gpu_id,
1439
+ # Bare 'wafer evaluate' is no longer supported - must use subcommand
1440
+ typer.echo("Error: 'wafer evaluate' requires a subcommand.", err=True)
1441
+ typer.echo("", err=True)
1442
+ typer.echo("Available subcommands:", err=True)
1443
+ typer.echo(
1444
+ " gpumode Evaluate GPUMode format (custom_kernel/ref_kernel functions)", err=True
1333
1445
  )
1334
-
1335
- try:
1336
- # Use trio_asyncio to run async code that uses both trio and asyncio
1337
- # (AsyncSSHClient uses asyncssh which is asyncio-based, bridged via trio_asyncio)
1338
- import trio_asyncio
1339
-
1340
- result = trio_asyncio.run(run_evaluate, args)
1341
- except KeyboardInterrupt:
1342
- typer.echo("\nInterrupted by user", err=True)
1343
- raise typer.Exit(130) from None
1344
- except Exception as e:
1345
- # Unwrap ExceptionGroup (from Trio nurseries) to show actual error
1346
- if hasattr(e, "exceptions") and e.exceptions:
1347
- for exc in e.exceptions:
1348
- typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
1349
- else:
1350
- typer.echo(f"Error: {e}", err=True)
1351
- raise typer.Exit(1) from None
1352
-
1353
- # Print results
1354
- if result.success:
1355
- typer.echo("")
1356
- typer.echo("=" * 60)
1357
- status = "PASS" if result.all_correct else "FAIL"
1358
- typer.echo(f"Result: {status}")
1359
- score_pct = f"{result.correctness_score:.1%}"
1360
- typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
1361
- if result.geomean_speedup > 0:
1362
- typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
1363
- if result.artifact_path:
1364
- typer.echo(f"Artifacts: {result.artifact_path}")
1365
- typer.echo("=" * 60)
1366
-
1367
- if not result.all_correct:
1368
- raise typer.Exit(1)
1369
- else:
1370
- typer.echo(f"Error: {result.error_message}", err=True)
1371
- raise typer.Exit(1)
1446
+ typer.echo(" kernelbench Evaluate KernelBench format (ModelNew class)", err=True)
1447
+ typer.echo("", err=True)
1448
+ typer.echo("Examples:", err=True)
1449
+ typer.echo(
1450
+ " wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json",
1451
+ err=True,
1452
+ )
1453
+ typer.echo(
1454
+ " wafer evaluate kernelbench --impl impl.py --reference ref.py --benchmark", err=True
1455
+ )
1456
+ typer.echo("", err=True)
1457
+ typer.echo(
1458
+ "Run 'wafer evaluate gpumode --help' or 'wafer evaluate kernelbench --help' for options.",
1459
+ err=True,
1460
+ )
1461
+ raise typer.Exit(1)
1372
1462
 
1373
1463
 
1374
1464
  TEMPLATE_KERNEL = '''\
@@ -1503,8 +1593,59 @@ def evaluate_make_template(
1503
1593
  # KernelBench format evaluation
1504
1594
  # =============================================================================
1505
1595
 
1506
- # Path to KernelBench problems (relative to wafer root)
1507
- KERNELBENCH_ROOT = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench"
1596
+
1597
+ def _get_kernelbench_root() -> Path | None:
1598
+ """Get KernelBench problems root, preferring downloaded location."""
1599
+ # First check downloaded location
1600
+ downloaded = get_problems_path("kernelbench")
1601
+ if downloaded is not None:
1602
+ kb_root = downloaded / "KernelBench"
1603
+ if kb_root.exists():
1604
+ return kb_root
1605
+ return downloaded
1606
+
1607
+ # Fall back to legacy location (for development)
1608
+ legacy = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench" / "KernelBench"
1609
+ if legacy.exists():
1610
+ return legacy
1611
+
1612
+ return None
1613
+
1614
+
1615
+ @kernelbench_app.command("download")
1616
+ def kernelbench_download(
1617
+ force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
1618
+ ) -> None:
1619
+ """Download KernelBench problems from GitHub.
1620
+
1621
+ Downloads the problem set to ~/.cache/wafer/problems/kernelbench/
1622
+
1623
+ Examples:
1624
+ wafer evaluate kernelbench download
1625
+ wafer evaluate kernelbench download --force # Re-download
1626
+ """
1627
+ try:
1628
+ path = download_problems("kernelbench", force=force, verbose=True)
1629
+ typer.echo("")
1630
+ typer.echo(f"Problems available at: {path}")
1631
+ typer.echo("Run 'wafer evaluate kernelbench list-problems' to see available problems.")
1632
+ except Exception as e:
1633
+ typer.echo(f"Error downloading problems: {e}", err=True)
1634
+ raise typer.Exit(1) from None
1635
+
1636
+
1637
+ @kernelbench_app.command("list-problems")
1638
+ def kernelbench_list_problems() -> None:
1639
+ """List available KernelBench problems.
1640
+
1641
+ Examples:
1642
+ wafer evaluate kernelbench list-problems
1643
+ """
1644
+ try:
1645
+ list_problems_fn("kernelbench", verbose=True)
1646
+ except ValueError as e:
1647
+ typer.echo(str(e), err=True)
1648
+ raise typer.Exit(1) from None
1508
1649
 
1509
1650
 
1510
1651
  @kernelbench_app.callback(invoke_without_command=True)
@@ -1528,9 +1669,18 @@ def kernelbench_evaluate( # noqa: PLR0913
1528
1669
  help="GPU target name. See 'wafer config targets list' for available targets.",
1529
1670
  autocompletion=complete_target_name,
1530
1671
  ),
1672
+ pool: str | None = typer.Option(
1673
+ None,
1674
+ "--pool",
1675
+ "-p",
1676
+ help="Target pool name. Acquires first available target from the pool. "
1677
+ "Define pools in ~/.wafer/config.toml under [pools.<name>].",
1678
+ ),
1531
1679
  benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
1532
1680
  profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
1533
- inputs: Path | None = typer.Option(None, "--inputs", help="Custom inputs file to override get_inputs()"),
1681
+ inputs: Path | None = typer.Option(
1682
+ None, "--inputs", help="Custom inputs file to override get_inputs()"
1683
+ ),
1534
1684
  seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
1535
1685
  defensive: bool = typer.Option(
1536
1686
  False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
@@ -1588,12 +1738,54 @@ def kernelbench_evaluate( # noqa: PLR0913
1588
1738
  )
1589
1739
  raise typer.Exit(1)
1590
1740
 
1741
+ # Validate --target and --pool are mutually exclusive
1742
+ if target and pool:
1743
+ typer.echo("Error: Cannot specify both --target and --pool", err=True)
1744
+ raise typer.Exit(1)
1745
+
1591
1746
  from .evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench
1592
1747
 
1748
+ # If pool specified, acquire a target from the pool
1749
+ resolved_target = target or ""
1750
+ pool_lock_context = None
1751
+
1752
+ if pool:
1753
+ from .target_lock import acquire_from_pool
1754
+ from .targets import filter_pool_by_auth, get_pool
1755
+
1756
+ try:
1757
+ pool_targets = get_pool(pool)
1758
+ except FileNotFoundError as e:
1759
+ typer.echo(f"Error: {e}", err=True)
1760
+ raise typer.Exit(1) from None
1761
+
1762
+ # Filter to only targets with valid auth
1763
+ usable_targets, skipped = filter_pool_by_auth(pool_targets)
1764
+ if skipped:
1765
+ typer.echo(f"Skipping targets without auth: {', '.join(skipped)}", err=True)
1766
+
1767
+ if not usable_targets:
1768
+ typer.echo(f"Error: No usable targets in pool '{pool}'", err=True)
1769
+ typer.echo(" All targets require authentication that is not configured.", err=True)
1770
+ typer.echo(" Run 'wafer auth status' to see which providers need setup.", err=True)
1771
+ raise typer.Exit(1) from None
1772
+
1773
+ typer.echo(f"Acquiring target from pool '{pool}' ({len(usable_targets)} targets)...")
1774
+ pool_lock_context = acquire_from_pool(usable_targets)
1775
+ acquired_target = pool_lock_context.__enter__()
1776
+
1777
+ if acquired_target is None:
1778
+ typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
1779
+ typer.echo(f" Targets: {', '.join(usable_targets)}", err=True)
1780
+ raise typer.Exit(1)
1781
+
1782
+ typer.echo(f"Acquired target: {acquired_target}")
1783
+ resolved_target = acquired_target
1784
+
1593
1785
  args = KernelBenchEvaluateArgs(
1594
1786
  implementation=implementation,
1595
1787
  reference=reference,
1596
- target_name=target or "",
1788
+ target_name=resolved_target,
1597
1789
  benchmark=benchmark,
1598
1790
  profile=profile,
1599
1791
  inputs=inputs,
@@ -1613,6 +1805,10 @@ def kernelbench_evaluate( # noqa: PLR0913
1613
1805
  except Exception as e:
1614
1806
  typer.echo(f"Error: {e}", err=True)
1615
1807
  raise typer.Exit(1) from None
1808
+ finally:
1809
+ # Release pool lock if we acquired one
1810
+ if pool_lock_context is not None:
1811
+ pool_lock_context.__exit__(None, None, None)
1616
1812
 
1617
1813
  # Print results
1618
1814
  if result.success:
@@ -1659,6 +1855,13 @@ def kernelbench_make_template(
1659
1855
  # Overwrite existing
1660
1856
  wafer evaluate kernelbench make-template level1/1 --force
1661
1857
  """
1858
+ # Get problems root (downloaded or legacy)
1859
+ kb_root = _get_kernelbench_root()
1860
+ if kb_root is None:
1861
+ typer.echo("Error: KernelBench problems not found.", err=True)
1862
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1863
+ raise typer.Exit(1)
1864
+
1662
1865
  # Parse problem ID
1663
1866
  parts = problem.split("/")
1664
1867
  if len(parts) != 2:
@@ -1670,10 +1873,10 @@ def kernelbench_make_template(
1670
1873
  level_str = f"level{level_str}"
1671
1874
 
1672
1875
  # Find the problem file
1673
- problem_dir = KERNELBENCH_ROOT / "KernelBench" / level_str
1876
+ problem_dir = kb_root / level_str
1674
1877
  if not problem_dir.exists():
1675
1878
  typer.echo(f"Error: KernelBench level directory not found: {problem_dir}", err=True)
1676
- typer.echo(f"Make sure KernelBench is at: {KERNELBENCH_ROOT}", err=True)
1879
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1677
1880
  raise typer.Exit(1)
1678
1881
 
1679
1882
  # Find matching problem file
@@ -1708,37 +1911,335 @@ def kernelbench_make_template(
1708
1911
 
1709
1912
  output = output.resolve()
1710
1913
 
1711
- # Check if exists
1712
- if output.exists() and not force:
1713
- typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
1914
+ # Check if exists
1915
+ if output.exists() and not force:
1916
+ typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
1917
+ raise typer.Exit(1)
1918
+
1919
+ # Copy the file
1920
+ content = problem_file.read_text()
1921
+ output.parent.mkdir(parents=True, exist_ok=True)
1922
+ output.write_text(content)
1923
+
1924
+ typer.echo(f"Created {output}")
1925
+ typer.echo("")
1926
+ typer.echo("Next steps:")
1927
+ typer.echo(f" 1. Read {output} to understand the Model interface")
1928
+ typer.echo(" 2. Create an implementation file with your ModelNew class:")
1929
+ typer.echo("")
1930
+ typer.echo(" import torch.nn as nn")
1931
+ typer.echo("")
1932
+ typer.echo(" class ModelNew(nn.Module):")
1933
+ typer.echo(" def __init__(self, ...):")
1934
+ typer.echo(" # Same signature as Model.__init__")
1935
+ typer.echo(" ...")
1936
+ typer.echo("")
1937
+ typer.echo(" def forward(self, ...):")
1938
+ typer.echo(" # Same signature as Model.forward")
1939
+ typer.echo(" # Your optimized implementation here")
1940
+ typer.echo(" ...")
1941
+ typer.echo("")
1942
+ typer.echo(" 3. Run evaluation:")
1943
+ typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
1944
+
1945
+
1946
+ # =============================================================================
1947
+ # GPUMode format evaluation
1948
+ # =============================================================================
1949
+
1950
+
1951
+ @gpumode_app.command("download")
1952
+ def gpumode_download(
1953
+ force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
1954
+ ) -> None:
1955
+ """Download GPUMode reference kernels from GitHub.
1956
+
1957
+ Downloads the problem set to ~/.cache/wafer/problems/gpumode/
1958
+
1959
+ Examples:
1960
+ wafer evaluate gpumode download
1961
+ wafer evaluate gpumode download --force # Re-download
1962
+ """
1963
+ try:
1964
+ path = download_problems("gpumode", force=force, verbose=True)
1965
+ typer.echo("")
1966
+ typer.echo(f"Problems available at: {path}")
1967
+ typer.echo("Run 'wafer evaluate gpumode list-problems' to see available problems.")
1968
+ except Exception as e:
1969
+ typer.echo(f"Error downloading problems: {e}", err=True)
1970
+ raise typer.Exit(1) from None
1971
+
1972
+
1973
+ @gpumode_app.command("list-problems")
1974
+ def gpumode_list_problems() -> None:
1975
+ """List available GPUMode problems.
1976
+
1977
+ Examples:
1978
+ wafer evaluate gpumode list-problems
1979
+ """
1980
+ try:
1981
+ list_problems_fn("gpumode", verbose=True)
1982
+ except ValueError as e:
1983
+ typer.echo(str(e), err=True)
1984
+ raise typer.Exit(1) from None
1985
+
1986
+
1987
+ @gpumode_app.command("make-template")
1988
+ def gpumode_make_template(
1989
+ problem: str = typer.Option(
1990
+ ...,
1991
+ "--problem",
1992
+ "-p",
1993
+ help="Problem ID (e.g., 'pmpp/vectoradd_py' or 'amd/fp8-mm')",
1994
+ ),
1995
+ output: Path = typer.Option(
1996
+ None, "--output", "-o", help="Output directory (default: ./<problem_name>/)"
1997
+ ),
1998
+ force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
1999
+ ) -> None:
2000
+ """Extract a GPUMode problem as template files.
2001
+
2002
+ Creates a directory with reference.py, task.yml, and other problem files.
2003
+ You then create kernel.py with your custom_kernel implementation.
2004
+
2005
+ Examples:
2006
+ # Extract pmpp vectoradd problem
2007
+ wafer evaluate gpumode make-template --problem pmpp/vectoradd_py
2008
+
2009
+ # Extract to specific directory
2010
+ wafer evaluate gpumode make-template --problem pmpp/vectoradd_py --output ./my-kernel/
2011
+ """
2012
+ import shutil
2013
+
2014
+ # Get problem path
2015
+ problem_path = get_problem_path("gpumode", problem)
2016
+ if problem_path is None:
2017
+ # Check if problems are downloaded
2018
+ if get_problems_path("gpumode") is None:
2019
+ typer.echo("Error: GPUMode problems not downloaded.", err=True)
2020
+ typer.echo("Run 'wafer evaluate gpumode download' first.", err=True)
2021
+ else:
2022
+ typer.echo(f"Error: Problem '{problem}' not found.", err=True)
2023
+ typer.echo(
2024
+ "Run 'wafer evaluate gpumode list-problems' to see available problems.", err=True
2025
+ )
2026
+ raise typer.Exit(1)
2027
+
2028
+ # Determine output path
2029
+ if output is None:
2030
+ output = Path.cwd() / problem.replace("/", "_")
2031
+
2032
+ output = output.resolve()
2033
+
2034
+ # Check if exists
2035
+ if output.exists() and not force:
2036
+ typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
2037
+ raise typer.Exit(1)
2038
+
2039
+ # Copy the problem directory
2040
+ if output.exists():
2041
+ shutil.rmtree(output)
2042
+ shutil.copytree(problem_path, output)
2043
+
2044
+ typer.echo(f"Created {output}/")
2045
+ typer.echo("")
2046
+ typer.echo("Contents:")
2047
+ for f in sorted(output.iterdir()):
2048
+ if not f.name.startswith("."):
2049
+ typer.echo(f" {f.name}")
2050
+ typer.echo("")
2051
+ typer.echo("Next steps:")
2052
+ typer.echo(" 1. Read reference.py to understand the kernel interface")
2053
+ typer.echo(" 2. Create kernel.py with your custom_kernel implementation:")
2054
+ typer.echo("")
2055
+ typer.echo(" def custom_kernel(data):")
2056
+ typer.echo(" # Your optimized implementation")
2057
+ typer.echo(" ...")
2058
+ typer.echo("")
2059
+ typer.echo(" 3. Run evaluation:")
2060
+ typer.echo(
2061
+ f" wafer evaluate gpumode --impl {output}/kernel.py --reference {output}/reference.py \\"
2062
+ )
2063
+ typer.echo(f" --test-cases {output}/test_cases.json --target <target>")
2064
+
2065
+
2066
+ @gpumode_app.callback(invoke_without_command=True)
2067
+ def gpumode_evaluate( # noqa: PLR0913, PLR0915
2068
+ ctx: typer.Context,
2069
+ implementation: Path | None = typer.Option(
2070
+ None, "--impl", "-i", help="Path to implementation kernel file"
2071
+ ),
2072
+ reference: Path | None = typer.Option(
2073
+ None, "--reference", help="Path to reference kernel file"
2074
+ ),
2075
+ test_cases: Path | None = typer.Option(
2076
+ None, "--test-cases", help="Path to test cases JSON file"
2077
+ ),
2078
+ target: str | None = typer.Option(
2079
+ None,
2080
+ "--target",
2081
+ "-t",
2082
+ help="GPU target name. See 'wafer config targets list' for available targets.",
2083
+ autocompletion=complete_target_name,
2084
+ ),
2085
+ pool: str | None = typer.Option(
2086
+ None,
2087
+ "--pool",
2088
+ "-p",
2089
+ help="Target pool name. Acquires first available target from the pool. "
2090
+ "Define pools in ~/.wafer/config.toml under [pools.<name>].",
2091
+ ),
2092
+ benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
2093
+ profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
2094
+ defensive: bool = typer.Option(
2095
+ False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
2096
+ ),
2097
+ sync_artifacts: bool = typer.Option(
2098
+ True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
2099
+ ),
2100
+ gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
2101
+ ) -> None:
2102
+ """Run kernel evaluation in GPUMode format (functional).
2103
+
2104
+ This format expects:
2105
+ - Implementation: Python file with `custom_kernel(inputs)` function
2106
+ - Reference: Python file with `ref_kernel(inputs)` and `generate_input(**kwargs)` functions
2107
+ - Test cases: JSON file with test parameters
2108
+
2109
+ Examples:
2110
+ # Basic correctness check
2111
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
2112
+
2113
+ # With benchmarking
2114
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
2115
+ --target vultr-b200 --benchmark
2116
+
2117
+ Subcommands:
2118
+ download Download GPUMode problems from GitHub
2119
+ list-problems List available problems
2120
+ make-template Extract a problem as template files
2121
+ """
2122
+ # If a subcommand is being invoked, skip the main evaluation logic
2123
+ if ctx.invoked_subcommand is not None:
2124
+ return
2125
+
2126
+ # Validate required args when running evaluation (not subcommands)
2127
+ missing_args = []
2128
+ if implementation is None:
2129
+ missing_args.append("--impl/-i")
2130
+ if reference is None:
2131
+ missing_args.append("--reference")
2132
+ if test_cases is None:
2133
+ missing_args.append("--test-cases")
2134
+
2135
+ if missing_args:
2136
+ typer.echo("Error: Missing required arguments", err=True)
2137
+ typer.echo(f" Required: {', '.join(missing_args)}", err=True)
2138
+ typer.echo("", err=True)
2139
+ typer.echo(
2140
+ "Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
2141
+ err=True,
2142
+ )
2143
+ typer.echo("", err=True)
2144
+ typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
2145
+ typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
2146
+ raise typer.Exit(1)
2147
+
2148
+ # Validate --target and --pool are mutually exclusive
2149
+ if target and pool:
2150
+ typer.echo("Error: Cannot specify both --target and --pool", err=True)
2151
+ raise typer.Exit(1)
2152
+
2153
+ from .evaluate import EvaluateArgs, run_evaluate
2154
+
2155
+ # If pool specified, acquire a target from the pool
2156
+ resolved_target = target or ""
2157
+ pool_lock_context = None
2158
+
2159
+ if pool:
2160
+ from .target_lock import acquire_from_pool
2161
+ from .targets import filter_pool_by_auth, get_pool
2162
+
2163
+ try:
2164
+ pool_targets = get_pool(pool)
2165
+ except FileNotFoundError as e:
2166
+ typer.echo(f"Error: {e}", err=True)
2167
+ raise typer.Exit(1) from None
2168
+
2169
+ # Filter to only targets with valid auth
2170
+ usable_targets, skipped = filter_pool_by_auth(pool_targets)
2171
+ if skipped:
2172
+ typer.echo(f"Skipping targets without auth: {', '.join(skipped)}", err=True)
2173
+
2174
+ if not usable_targets:
2175
+ typer.echo(f"Error: No usable targets in pool '{pool}'", err=True)
2176
+ typer.echo(" All targets require authentication that is not configured.", err=True)
2177
+ typer.echo(" Run 'wafer auth status' to see which providers need setup.", err=True)
2178
+ raise typer.Exit(1) from None
2179
+
2180
+ typer.echo(f"Acquiring target from pool '{pool}' ({len(usable_targets)} targets)...")
2181
+ pool_lock_context = acquire_from_pool(usable_targets)
2182
+ acquired_target = pool_lock_context.__enter__()
2183
+
2184
+ if acquired_target is None:
2185
+ typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
2186
+ typer.echo(f" Targets: {', '.join(usable_targets)}", err=True)
2187
+ raise typer.Exit(1)
2188
+
2189
+ typer.echo(f"Acquired target: {acquired_target}")
2190
+ resolved_target = acquired_target
2191
+
2192
+ args = EvaluateArgs(
2193
+ implementation=implementation,
2194
+ reference=reference,
2195
+ test_cases=test_cases,
2196
+ target_name=resolved_target,
2197
+ benchmark=benchmark,
2198
+ profile=profile,
2199
+ defensive=defensive,
2200
+ sync_artifacts=sync_artifacts,
2201
+ gpu_id=gpu_id,
2202
+ )
2203
+
2204
+ try:
2205
+ import trio_asyncio
2206
+
2207
+ result = trio_asyncio.run(run_evaluate, args)
2208
+ except KeyboardInterrupt:
2209
+ typer.echo("\nInterrupted by user", err=True)
2210
+ raise typer.Exit(130) from None
2211
+ except Exception as e:
2212
+ if hasattr(e, "exceptions") and e.exceptions:
2213
+ for exc in e.exceptions:
2214
+ typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
2215
+ else:
2216
+ typer.echo(f"Error: {e}", err=True)
2217
+ raise typer.Exit(1) from None
2218
+ finally:
2219
+ # Release pool lock if we acquired one
2220
+ if pool_lock_context is not None:
2221
+ pool_lock_context.__exit__(None, None, None)
2222
+
2223
+ # Print results
2224
+ if result.success:
2225
+ typer.echo("")
2226
+ typer.echo("=" * 60)
2227
+ status = "PASS" if result.all_correct else "FAIL"
2228
+ typer.echo(f"Result: {status}")
2229
+ score_pct = f"{result.correctness_score:.1%}"
2230
+ typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
2231
+ if result.geomean_speedup > 0:
2232
+ typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
2233
+ if result.artifact_path:
2234
+ typer.echo(f"Artifacts: {result.artifact_path}")
2235
+ typer.echo("=" * 60)
2236
+
2237
+ if not result.all_correct:
2238
+ raise typer.Exit(1)
2239
+ else:
2240
+ typer.echo(f"Error: {result.error_message}", err=True)
1714
2241
  raise typer.Exit(1)
1715
2242
 
1716
- # Copy the file
1717
- content = problem_file.read_text()
1718
- output.parent.mkdir(parents=True, exist_ok=True)
1719
- output.write_text(content)
1720
-
1721
- typer.echo(f"Created {output}")
1722
- typer.echo("")
1723
- typer.echo("Next steps:")
1724
- typer.echo(f" 1. Read {output} to understand the Model interface")
1725
- typer.echo(" 2. Create an implementation file with your ModelNew class:")
1726
- typer.echo("")
1727
- typer.echo(" import torch.nn as nn")
1728
- typer.echo("")
1729
- typer.echo(" class ModelNew(nn.Module):")
1730
- typer.echo(" def __init__(self, ...):")
1731
- typer.echo(" # Same signature as Model.__init__")
1732
- typer.echo(" ...")
1733
- typer.echo("")
1734
- typer.echo(" def forward(self, ...):")
1735
- typer.echo(" # Same signature as Model.forward")
1736
- typer.echo(" # Your optimized implementation here")
1737
- typer.echo(" ...")
1738
- typer.echo("")
1739
- typer.echo(" 3. Run evaluation:")
1740
- typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
1741
-
1742
2243
 
1743
2244
  # =============================================================================
1744
2245
  # Push and Remote-Run commands
@@ -1871,7 +2372,7 @@ def _run_direct_mode(
1871
2372
  typer.echo(f"Uploading {upload_dir.name}...")
1872
2373
  try:
1873
2374
  push_result = push_direct(upload_dir, target)
1874
- workspace_name = push_result.workspace_path
2375
+ workspace_name = push_result.workspace_name
1875
2376
  typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
1876
2377
  except Exception as e:
1877
2378
  typer.echo(f"Error uploading: {e}", err=True)
@@ -2044,27 +2545,41 @@ def login(
2044
2545
  None, "--token", "-t", help="Access token (skip browser OAuth)"
2045
2546
  ),
2046
2547
  port: int | None = typer.Option(
2047
- None, "--port", "-p", help="Port for OAuth callback server (default: 8765 for SSH, random for local)"
2548
+ None,
2549
+ "--port",
2550
+ "-p",
2551
+ help="Port for OAuth callback server (local only, ignored for SSH)",
2552
+ ),
2553
+ no_device_code: bool = typer.Option(
2554
+ False,
2555
+ "--no-device-code",
2556
+ help="Force browser OAuth even on SSH (requires port forwarding)",
2048
2557
  ),
2049
2558
  ) -> None:
2050
2559
  """Authenticate CLI with wafer-api via GitHub OAuth.
2051
2560
 
2052
- Opens browser for GitHub authentication. Use --token to skip browser.
2561
+ Local: Opens browser for GitHub authentication.
2562
+ SSH: Uses device code flow (no port forwarding needed).
2563
+
2053
2564
  Uses the API environment from config (see 'wafer config show').
2054
2565
 
2055
- SSH Users:
2056
- - Automatically uses port 8765 (just set up port forwarding once)
2057
- - On local machine: ssh -L 8765:localhost:8765 user@host
2058
- - On remote machine: wafer login
2059
- - Browser opens locally, redirect works through tunnel
2566
+ SSH Users (Easiest):
2567
+ - Just run: wafer login
2568
+ - Visit the URL and enter the code shown
2569
+ - No port forwarding needed!
2570
+
2571
+ SSH with browser (Advanced):
2572
+ - Use --no-device-code to force browser flow
2573
+ - Requires: ssh -L 8765:localhost:8765 user@host
2060
2574
 
2061
2575
  Manual token option:
2062
2576
  - Visit auth.wafer.ai, authenticate, copy token from URL
2063
2577
  - Run: wafer login --token <paste-token>
2064
2578
 
2065
2579
  Examples:
2066
- wafer login # auto-detects SSH, uses appropriate port
2067
- wafer login --port 9000 # override port
2580
+ wafer login # device code on SSH, browser on local
2581
+ wafer login --no-device-code # force browser (needs port forwarding on SSH)
2582
+ wafer login --port 9000 # custom port for browser flow
2068
2583
  wafer login --token xyz # manual token (no browser)
2069
2584
 
2070
2585
  # Change environment:
@@ -2073,7 +2588,7 @@ def login(
2073
2588
  """
2074
2589
  import httpx
2075
2590
 
2076
- from .auth import browser_login, save_credentials, verify_token
2591
+ from .auth import browser_login, device_code_login, save_credentials, verify_token
2077
2592
  from .global_config import get_api_url, get_supabase_url, load_global_config
2078
2593
 
2079
2594
  # Show which environment we're logging into
@@ -2083,21 +2598,31 @@ def login(
2083
2598
  typer.echo(f"Auth: {get_supabase_url()}")
2084
2599
  typer.echo("")
2085
2600
 
2086
- # Auto-detect SSH and use fixed port
2087
- if port is None:
2088
- is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
2089
- if is_ssh:
2090
- port = 8765
2091
- typer.echo("🔒 SSH session detected - using port 8765 for OAuth callback")
2092
- typer.echo(" Make sure you have port forwarding set up:")
2093
- typer.echo(" ssh -L 8765:localhost:8765 user@host")
2094
- typer.echo("")
2601
+ # Auto-detect SSH
2602
+ is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
2095
2603
 
2096
- # Browser OAuth if no token provided
2604
+ # Choose auth method
2097
2605
  refresh_token = None
2098
2606
  if token is None:
2099
2607
  try:
2100
- token, refresh_token = browser_login(port=port)
2608
+ if is_ssh and not no_device_code:
2609
+ # Use device code flow for SSH (no port forwarding needed)
2610
+ typer.echo("🔒 SSH session detected - using device code authentication")
2611
+ typer.echo(" (No port forwarding required!)")
2612
+ typer.echo("")
2613
+ token, refresh_token = device_code_login()
2614
+ else:
2615
+ # Use browser OAuth for local or if explicitly requested
2616
+ if is_ssh:
2617
+ typer.echo("🔒 SSH session detected - using browser authentication")
2618
+ typer.echo(" Make sure you have port forwarding set up:")
2619
+ if port is None:
2620
+ port = 8765
2621
+ typer.echo(f" ssh -L {port}:localhost:{port} user@host")
2622
+ else:
2623
+ typer.echo(f" ssh -L {port}:localhost:{port} user@host")
2624
+ typer.echo("")
2625
+ token, refresh_token = browser_login(port=port)
2101
2626
  except TimeoutError as e:
2102
2627
  typer.echo(f"Error: {e}", err=True)
2103
2628
  raise typer.Exit(1) from None
@@ -2146,9 +2671,8 @@ def login(
2146
2671
  @app.command("logout")
2147
2672
  def logout() -> None:
2148
2673
  """Remove stored credentials."""
2149
- from .auth import clear_credentials
2150
-
2151
2674
  from . import analytics
2675
+ from .auth import clear_credentials
2152
2676
 
2153
2677
  # Track logout event first (while credentials still exist for user identification)
2154
2678
  # Note: track_logout() handles the case where user is not logged in
@@ -2625,6 +3149,7 @@ init_app = typer.Typer(
2625
3149
 
2626
3150
  Choose based on your GPU access:
2627
3151
 
3152
+ local GPU on current machine (no SSH)
2628
3153
  ssh Your own hardware via SSH
2629
3154
  runpod RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
2630
3155
  digitalocean DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
@@ -2632,6 +3157,92 @@ Choose based on your GPU access:
2632
3157
  targets_app.add_typer(init_app, name="init")
2633
3158
 
2634
3159
 
3160
@init_app.command("local")
def init_local(
    name: str = typer.Option("local", "--name", "-n", help="Target name"),
    gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
) -> None:
    """Initialize a local target for GPU on current machine.

    Detects your local GPU and configures a target for direct execution
    (no SSH). Use this when running wafer on the same machine as the GPU.

    Examples:
        wafer config targets init local
        wafer config targets init local --name my-5090 --gpu-ids 0,1
    """
    from .targets import save_target

    # Turn the comma-separated option into a list of ints; any non-integer
    # piece aborts the command with a usage error.
    parsed_gpu_ids: list[int] = []
    try:
        for piece in gpu_ids.split(","):
            parsed_gpu_ids.append(int(piece.strip()))
    except ValueError:
        typer.echo(f"Error: Invalid GPU IDs '{gpu_ids}'. Use comma-separated integers.", err=True)
        raise typer.Exit(1) from None

    typer.echo("Detecting local GPU...")

    try:
        # Imported lazily so a missing wafer_core is reported as a clean error.
        from wafer_core.gpu_detect import (
            detect_local_gpu,
            get_compute_capability,
            get_torch_requirements,
        )

        detected_gpu = detect_local_gpu()

        # Guard clause: nothing to configure without a detected GPU.
        if not detected_gpu:
            typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)", err=True)
            raise typer.Exit(1)

        typer.echo(f" Found: {detected_gpu.gpu_name}")
        driver_line = (
            f" CUDA: {detected_gpu.driver_version}"
            if detected_gpu.vendor == "nvidia"
            else f" ROCm: {detected_gpu.driver_version}"
        )
        typer.echo(driver_line)
        typer.echo(f" GPU count: {detected_gpu.gpu_count}")

        # Derive the torch requirements, compute capability and short GPU type
        # from the detection result.
        torch_reqs = get_torch_requirements(detected_gpu)
        compute_capability = get_compute_capability(detected_gpu)
        gpu_type = _extract_gpu_type(detected_gpu.gpu_name)

        torch_pkg = torch_reqs.packages[0]
        typer.echo(f" PyTorch: {torch_pkg}")

    except ImportError as e:
        typer.echo(f"Error: Missing dependency: {e}", err=True)
        raise typer.Exit(1) from None

    # Assemble the target description handed to save_target().
    target_data = {
        "name": name,
        "type": "local",
        "gpu_ids": parsed_gpu_ids,
        "gpu_type": gpu_type,
        "compute_capability": compute_capability,
        "torch_package": torch_pkg,
        "torch_index_url": torch_reqs.index_url,
        "vendor": detected_gpu.vendor,
        "driver_version": detected_gpu.driver_version,
    }

    try:
        target = save_target(target_data)
        typer.echo(f"✓ Created target: {target.name}")
        typer.echo(" Type: Local (no SSH)")
        typer.echo(f" GPU IDs: {parsed_gpu_ids}")
        typer.echo(f" GPU Type: {gpu_type}")
        typer.echo(f" Compute: {compute_capability}")
        typer.echo(f" Torch: {torch_pkg}")
        typer.echo("")
        typer.echo(
            f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
        )
    except (ValueError, AssertionError) as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
3244
+
3245
+
2635
3246
  @init_app.command("runpod")
2636
3247
  def init_runpod(
2637
3248
  name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
@@ -2795,23 +3406,29 @@ def init_ssh(
2795
3406
  host: str = typer.Option(..., "--host", "-H", help="SSH host (user@hostname:port)"),
2796
3407
  ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
2797
3408
  gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
2798
- gpu_type: str = typer.Option(
2799
- "H100", "--gpu-type", help="GPU type (H100, A100, B200, MI300X, etc.)"
3409
+ gpu_type: str | None = typer.Option(
3410
+ None, "--gpu-type", help="GPU type (auto-detected if not specified)"
2800
3411
  ),
2801
3412
  docker_image: str | None = typer.Option(
2802
3413
  None, "--docker-image", "-d", help="Docker image (optional)"
2803
3414
  ),
2804
3415
  ncu: bool = typer.Option(False, "--ncu/--no-ncu", help="NCU profiling available"),
3416
+ no_detect: bool = typer.Option(False, "--no-detect", help="Skip GPU auto-detection"),
2805
3417
  ) -> None:
2806
3418
  """Initialize an SSH target for your own GPU hardware.
2807
3419
 
2808
3420
  Creates a target config for direct SSH access to a GPU machine.
2809
- Use for baremetal servers, VMs, or any machine you have SSH access to.
3421
+ Automatically detects GPU type and selects compatible PyTorch version.
2810
3422
 
2811
3423
  Examples:
3424
+ # Auto-detect GPU (recommended)
2812
3425
  wafer config targets init ssh --name my-gpu --host user@192.168.1.100:22
3426
+
3427
+ # Multiple GPUs with NCU profiling
2813
3428
  wafer config targets init ssh --name lab-h100 --host ubuntu@gpu.lab.com:22 --gpu-ids 0,1 --ncu
2814
- wafer config targets init ssh --name docker-gpu --host user@host:22 --docker-image nvcr.io/nvidia/pytorch:24.01-py3
3429
+
3430
+ # Skip detection, specify manually
3431
+ wafer config targets init ssh --name my-gpu --host user@host:22 --gpu-type H100 --no-detect
2815
3432
  """
2816
3433
  from .targets import save_target
2817
3434
 
@@ -2828,17 +3445,86 @@ def init_ssh(
2828
3445
  typer.echo("Example: user@192.168.1.100:22", err=True)
2829
3446
  raise typer.Exit(1)
2830
3447
 
3448
+ # Auto-detect GPU if not specified
3449
+ detected_gpu = None
3450
+ torch_package = None
3451
+ torch_index_url = None
3452
+
3453
+ if not no_detect:
3454
+ typer.echo(f"Connecting to {host}...")
3455
+ try:
3456
+ import trio
3457
+ import trio_asyncio
3458
+ from wafer_core.async_ssh import AsyncSSHClient
3459
+ from wafer_core.gpu_detect import (
3460
+ detect_remote_gpu,
3461
+ get_compute_capability,
3462
+ get_torch_requirements,
3463
+ )
3464
+
3465
+ expanded_key = str(Path(ssh_key).expanduser())
3466
+
3467
+ async def _detect() -> None:
3468
+ nonlocal detected_gpu, torch_package, torch_index_url
3469
+ # Need trio_asyncio.open_loop() for asyncssh bridge
3470
+ async with trio_asyncio.open_loop():
3471
+ async with AsyncSSHClient(host, expanded_key) as client:
3472
+ detected_gpu = await detect_remote_gpu(client)
3473
+
3474
+ trio.run(_detect)
3475
+
3476
+ if detected_gpu:
3477
+ typer.echo(f" Found: {detected_gpu.gpu_name}")
3478
+ if detected_gpu.vendor == "nvidia":
3479
+ typer.echo(f" CUDA: {detected_gpu.driver_version}")
3480
+ else:
3481
+ typer.echo(f" ROCm: {detected_gpu.driver_version}")
3482
+
3483
+ # Get torch requirements
3484
+ torch_reqs = get_torch_requirements(detected_gpu)
3485
+ torch_package = torch_reqs.packages[0] # Just torch, not all packages
3486
+ torch_index_url = torch_reqs.index_url
3487
+ typer.echo(f" PyTorch: {torch_package}")
3488
+
3489
+ # Use detected GPU type if not specified
3490
+ if not gpu_type:
3491
+ # Extract GPU name (e.g., "H100" from "NVIDIA H100 80GB HBM3")
3492
+ gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
3493
+ else:
3494
+ typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)")
3495
+ if not gpu_type:
3496
+ gpu_type = "H100" # Default fallback
3497
+ typer.echo(f" Using default: {gpu_type}")
3498
+
3499
+ except Exception as e:
3500
+ typer.echo(f" Detection failed: {e}", err=True)
3501
+ if not gpu_type:
3502
+ gpu_type = "H100"
3503
+ typer.echo(f" Using default: {gpu_type}")
3504
+
3505
+ # Fallback if no detection
3506
+ if not gpu_type:
3507
+ gpu_type = "H100"
3508
+
2831
3509
  # Compute capability mappings
2832
- compute_caps = {
2833
- "B200": "10.0",
2834
- "H100": "9.0",
2835
- "A100": "8.0",
2836
- "A10": "8.6",
2837
- "V100": "7.0",
2838
- "MI300X": "9.4",
2839
- "MI250X": "9.0",
2840
- }
2841
- compute_capability = compute_caps.get(gpu_type, "8.0")
3510
+ if detected_gpu:
3511
+ from wafer_core.gpu_detect import get_compute_capability
3512
+
3513
+ compute_capability = get_compute_capability(detected_gpu)
3514
+ else:
3515
+ compute_caps = {
3516
+ "B200": "10.0",
3517
+ "H100": "9.0",
3518
+ "A100": "8.0",
3519
+ "A10": "8.6",
3520
+ "V100": "7.0",
3521
+ "MI300X": "9.4",
3522
+ "MI250X": "9.0",
3523
+ "RTX 5090": "10.0",
3524
+ "RTX 4090": "8.9",
3525
+ "RTX 3090": "8.6",
3526
+ }
3527
+ compute_capability = compute_caps.get(gpu_type, "8.0")
2842
3528
 
2843
3529
  # Build target data
2844
3530
  target_data = {
@@ -2855,6 +3541,12 @@ def init_ssh(
2855
3541
  if docker_image:
2856
3542
  target_data["docker_image"] = docker_image
2857
3543
 
3544
+ # Add torch requirements if detected
3545
+ if torch_package:
3546
+ target_data["torch_package"] = torch_package
3547
+ if torch_index_url:
3548
+ target_data["torch_index_url"] = torch_index_url
3549
+
2858
3550
  try:
2859
3551
  target = save_target(target_data)
2860
3552
  typer.echo(f"✓ Created target: {target.name}")
@@ -2862,9 +3554,12 @@ def init_ssh(
2862
3554
  typer.echo(f" Host: {host}")
2863
3555
  typer.echo(f" GPU IDs: {parsed_gpu_ids}")
2864
3556
  typer.echo(f" GPU Type: {gpu_type}")
3557
+ typer.echo(f" Compute: {compute_capability}")
2865
3558
  typer.echo(f" NCU: {'Yes' if ncu else 'No'}")
2866
3559
  if docker_image:
2867
3560
  typer.echo(f" Docker: {docker_image}")
3561
+ if torch_package:
3562
+ typer.echo(f" Torch: {torch_package}")
2868
3563
  typer.echo("")
2869
3564
  typer.echo(
2870
3565
  f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
@@ -2874,6 +3569,44 @@ def init_ssh(
2874
3569
  raise typer.Exit(1) from None
2875
3570
 
2876
3571
 
3572
+ def _extract_gpu_type(gpu_name: str) -> str:
3573
+ """Extract GPU type from full GPU name.
3574
+
3575
+ Examples:
3576
+ "NVIDIA H100 80GB HBM3" -> "H100"
3577
+ "NVIDIA GeForce RTX 4090" -> "RTX 4090"
3578
+ "AMD Instinct MI300X OAM" -> "MI300X"
3579
+ """
3580
+ gpu_name_upper = gpu_name.upper()
3581
+
3582
+ # Check for known GPU types
3583
+ known_types = [
3584
+ "B200",
3585
+ "B100",
3586
+ "H200",
3587
+ "H100",
3588
+ "A100",
3589
+ "A10",
3590
+ "V100",
3591
+ "RTX 5090",
3592
+ "RTX 5080",
3593
+ "RTX 4090",
3594
+ "RTX 4080",
3595
+ "RTX 3090",
3596
+ "RTX 3080",
3597
+ "MI300X",
3598
+ "MI250X",
3599
+ "MI100",
3600
+ ]
3601
+
3602
+ for gpu_type in known_types:
3603
+ if gpu_type in gpu_name_upper:
3604
+ return gpu_type
3605
+
3606
+ # Fallback: return cleaned name
3607
+ return gpu_name.replace("NVIDIA ", "").replace("AMD ", "").strip()
3608
+
3609
+
2877
3610
  @targets_app.command("add")
2878
3611
  def targets_add(
2879
3612
  file_path: Path = typer.Argument(..., help="Path to target TOML file"),
@@ -2956,6 +3689,93 @@ def targets_show(
2956
3689
  raise typer.Exit(1) from None
2957
3690
 
2958
3691
 
3692
@targets_app.command("probe")
def targets_probe(
    name: str = typer.Argument(..., help="Target name"),
) -> None:
    """Probe a target to discover available compilation backends.

    Connects to the target and checks what's available:
    - Triton
    - torch.compile/inductor
    - HIP/hipcc or CUDA/nvcc
    - ROCm or CUDA version
    - Python packages (torch, triton, etc.)

    Example:
        wafer config targets probe runpod-mi300x
    """
    import trio

    from .targets import ProbeError, load_target, probe_target_capabilities

    try:
        target = load_target(name)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    typer.echo(f"Probing target: {name}...")

    try:
        capabilities = trio.run(probe_target_capabilities, target)
    except ProbeError as e:
        # ProbeError already has actionable context
        typer.echo(f"\nError: {e}", err=True)
        raise typer.Exit(1) from None
    except Exception as e:
        # Unexpected errors - include type for debugging
        typer.echo(f"\nUnexpected error probing target: {type(e).__name__}: {e}", err=True)
        raise typer.Exit(1) from None

    # Display results
    typer.echo(f"\nTarget: {name}")

    if capabilities.get("gpu_name"):
        typer.echo(f" GPU: {capabilities['gpu_name']}")
    if capabilities.get("compute_capability"):
        typer.echo(f" Compute: {capabilities['compute_capability']}")

    typer.echo("\n Compilation Backends:")
    backends = capabilities.get("backends", {})

    # Triton
    triton_ver = backends.get("triton")
    if triton_ver:
        typer.echo(f" ✓ Triton: {triton_ver}")
    else:
        typer.echo(" ✗ Triton: not installed")

    # torch.compile is reported as available only when both the Triton and
    # torch backend entries are present. Name the prerequisite that is
    # actually missing instead of always blaming Triton.
    if triton_ver and backends.get("torch"):
        typer.echo(" ✓ torch.compile/inductor: available")
    elif not triton_ver:
        typer.echo(" ✗ torch.compile/inductor: requires Triton")
    else:
        typer.echo(" ✗ torch.compile/inductor: requires PyTorch")

    # HIP/CUDA compiler (hipcc is reported in preference to nvcc)
    if backends.get("hipcc"):
        typer.echo(f" ✓ HIP/hipcc: {backends['hipcc']}")
    elif backends.get("nvcc"):
        typer.echo(f" ✓ CUDA/nvcc: {backends['nvcc']}")
    else:
        typer.echo(" ✗ No GPU compiler found")

    # ROCm/CUDA version
    if capabilities.get("rocm_version"):
        typer.echo(f" ROCm: {capabilities['rocm_version']}")
    if capabilities.get("cuda_version"):
        typer.echo(f" CUDA: {capabilities['cuda_version']}")

    typer.echo("\n Python Environment:")
    typer.echo(f" Python: {capabilities.get('python_version', 'unknown')}")

    packages = capabilities.get("packages", {})
    if packages.get("torch"):
        typer.echo(f" PyTorch: {packages['torch']}")
    if triton_ver:
        typer.echo(f" Triton: {triton_ver}")
3777
+
3778
+
2959
3779
  @targets_app.command("remove")
2960
3780
  def targets_remove(
2961
3781
  name: str = typer.Argument(..., help="Target name"),
@@ -3086,6 +3906,92 @@ def targets_pods() -> None:
3086
3906
  typer.echo()
3087
3907
 
3088
3908
 
3909
+ # ── Pool commands ───────────────────────────────────────────────────────────
3910
+
3911
+
3912
@targets_app.command("pool-list")
def targets_pool_list() -> None:
    """List all configured target pools.

    Example:
        wafer config targets pool-list
    """
    from .targets import get_pool, list_pools

    pool_names = list_pools()

    if pool_names:
        typer.echo("Configured pools:\n")
        for pool_name in pool_names:
            # A pool that fails to load is reported inline so the remaining
            # pools are still listed.
            try:
                members = get_pool(pool_name)
                typer.echo(f" {pool_name}: {', '.join(members)}")
            except Exception as e:
                typer.echo(f" {pool_name}: (error: {e})")
        return

    # Nothing configured: show how to define a pool.
    for hint in (
        "No pools configured",
        "",
        "Define pools in ~/.wafer/config.toml:",
        " [pools.my-pool]",
        ' targets = ["target-1", "target-2"]',
    ):
        typer.echo(hint)
3938
+
3939
+
3940
@targets_app.command("pool-create")
def targets_pool_create(
    name: str = typer.Argument(..., help="Pool name"),
    targets: list[str] = typer.Argument(..., help="Target names to include in pool"),
) -> None:
    """Create or update a target pool.

    Example:
        wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2 mi300x-3
    """
    from .targets import save_pool

    try:
        save_pool(name, targets)
        member_count = len(targets)
        typer.echo(f"Pool '{name}' created with {member_count} targets")
    except FileNotFoundError as missing:
        # Reported as a user-facing error with a non-zero exit, no traceback.
        typer.echo(f"Error: {missing}", err=True)
        raise typer.Exit(1) from None
3958
+
3959
+
3960
@targets_app.command("pool-status")
def targets_pool_status(
    name: str = typer.Argument(..., help="Pool name"),
) -> None:
    """Show status of targets in a pool (locked/available).

    Example:
        wafer config targets pool-status mi300x-pool
    """
    from .target_lock import get_lock_holder, is_target_locked
    from .targets import get_pool

    try:
        pool_targets = get_pool(name)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    total = len(pool_targets)
    typer.echo(f"Pool '{name}' ({total} targets):\n")

    free_count = 0
    for target_name in pool_targets:
        if not is_target_locked(target_name):
            typer.echo(f" [free] {target_name}")
            free_count += 1
            continue

        # Locked: append the holder's pid when one is reported.
        holder_pid = get_lock_holder(target_name)
        suffix = f" (pid {holder_pid})" if holder_pid else ""
        typer.echo(f" [busy] {target_name}{suffix}")

    typer.echo("")
    typer.echo(f"Available: {free_count}/{total}")
3993
+
3994
+
3089
3995
  # =============================================================================
3090
3996
  # Billing commands
3091
3997
  # =============================================================================
@@ -3119,7 +4025,9 @@ def billing_usage(
3119
4025
  @billing_app.command("topup")
3120
4026
  def billing_topup(
3121
4027
  amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
3122
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
4028
+ no_browser: bool = typer.Option(
4029
+ False, "--no-browser", help="Print URL instead of opening browser"
4030
+ ),
3123
4031
  ) -> None:
3124
4032
  """Add credits to your account.
3125
4033
 
@@ -3165,7 +4073,9 @@ def billing_topup(
3165
4073
 
3166
4074
  @billing_app.command("portal")
3167
4075
  def billing_portal(
3168
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
4076
+ no_browser: bool = typer.Option(
4077
+ False, "--no-browser", help="Print URL instead of opening browser"
4078
+ ),
3169
4079
  ) -> None:
3170
4080
  """Open Stripe billing portal.
3171
4081
 
@@ -3319,7 +4229,7 @@ def workspaces_exec(
3319
4229
  workspace: str | None = typer.Argument(
3320
4230
  None, help="Workspace name or ID (optional if default set)"
3321
4231
  ),
3322
- command: list[str] = typer.Argument(..., help="Command to execute on GPU"),
4232
+ command: list[str] = typer.Argument(..., help="Command to execute"),
3323
4233
  timeout: int | None = typer.Option(
3324
4234
  None,
3325
4235
  "--timeout",
@@ -3332,13 +4242,23 @@ def workspaces_exec(
3332
4242
  "-s",
3333
4243
  help="Sync local directory to workspace before executing",
3334
4244
  ),
4245
+ gpu: bool = typer.Option(False, "--gpu", help="Force GPU routing (default behavior)"),
4246
+ cpu: bool = typer.Option(False, "--cpu", help="Run in workspace container (no GPU)"),
4247
+ baremetal: bool = typer.Option(
4248
+ False, "--baremetal", help="Force baremetal target (for hardware counters like ncu/nsys)"
4249
+ ),
3335
4250
  verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
3336
4251
  quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
3337
4252
  ) -> None:
3338
- """Execute a command in workspace with GPU routing.
4253
+ """Execute a command in workspace.
4254
+
4255
+ By default, auto-detects whether to route to GPU based on the command.
4256
+ Use --gpu, --cpu, or --baremetal to override.
3339
4257
 
3340
- Runs the command on the workspace's configured GPU target (Modal, baremetal, etc.)
3341
- and streams output back. No SSH or zsh plugin required.
4258
+ Routing options:
4259
+ --gpu Force GPU container (Modal or baremetal with GPU)
4260
+ --cpu Run in workspace container directly (no GPU)
4261
+ --baremetal Force baremetal target (for ncu, nsys, hardware counters)
3342
4262
 
3343
4263
  If workspace is not specified, uses the default workspace from config,
3344
4264
  or the only workspace if you have exactly one.
@@ -3353,6 +4273,21 @@ def workspaces_exec(
3353
4273
  from .global_config import get_defaults, get_preferences
3354
4274
  from .workspaces import exec_command, resolve_workspace, sync_files
3355
4275
 
4276
+ # Validate mutually exclusive routing flags
4277
+ routing_flags = sum([gpu, cpu, baremetal])
4278
+ if routing_flags > 1:
4279
+ typer.echo("Error: --gpu, --cpu, and --baremetal are mutually exclusive", err=True)
4280
+ raise typer.Exit(1)
4281
+
4282
+ # Determine routing (None = auto-detect)
4283
+ routing: str | None = None
4284
+ if gpu:
4285
+ routing = "gpu"
4286
+ elif cpu:
4287
+ routing = "cpu"
4288
+ elif baremetal:
4289
+ routing = "baremetal"
4290
+
3356
4291
  # Resolve workspace (specified, config default, or single workspace)
3357
4292
  try:
3358
4293
  resolved_workspace = resolve_workspace(workspace)
@@ -3377,7 +4312,8 @@ def workspaces_exec(
3377
4312
  show_status = prefs.mode == "explicit"
3378
4313
 
3379
4314
  if show_status:
3380
- typer.echo(f"[wafer] Workspace: {resolved_workspace}", err=True)
4315
+ routing_label = routing or "auto"
4316
+ typer.echo(f"[wafer] Workspace: {resolved_workspace} (routing: {routing_label})", err=True)
3381
4317
 
3382
4318
  # Sync files if requested
3383
4319
  if sync is not None:
@@ -3413,8 +4349,15 @@ def workspaces_exec(
3413
4349
  # Remove leading "--" if present (typer passes it through with allow_interspersed_args=False)
3414
4350
  if command and command[0] == "--":
3415
4351
  command = command[1:]
3416
- # Use shlex.join to properly quote args containing spaces/special chars
3417
- command_str = shlex.join(command)
4352
+ # Handle two cases:
4353
+ # 1. Single element: user quoted the whole command (e.g., "echo hello world")
4354
+ # -> use directly, don't re-quote
4355
+ # 2. Multiple elements: user passed separate args (e.g., -- python -c "print(1)")
4356
+ # -> use shlex.join to properly quote args with spaces
4357
+ if len(command) == 1:
4358
+ command_str = command[0]
4359
+ else:
4360
+ command_str = shlex.join(command)
3418
4361
  else:
3419
4362
  command_str = command
3420
4363
 
@@ -3423,6 +4366,7 @@ def workspaces_exec(
3423
4366
  workspace_id=resolved_workspace,
3424
4367
  command=command_str,
3425
4368
  timeout_seconds=effective_timeout,
4369
+ routing=routing,
3426
4370
  )
3427
4371
  except RuntimeError as e:
3428
4372
  typer.echo(f"Error: {e}", err=True)
@@ -4441,8 +5385,8 @@ def _setup_wafer_core_env() -> None:
4441
5385
  - WAFER_API_URL: If already set, uses that instead of config
4442
5386
  - WAFER_AUTH_TOKEN: If already set, uses that instead of cached token
4443
5387
  """
4444
- from .global_config import get_api_url
4445
5388
  from .auth import get_valid_token
5389
+ from .global_config import get_api_url
4446
5390
 
4447
5391
  # Set API URL (get_api_url already respects WAFER_API_URL env var)
4448
5392
  os.environ["WAFER_API_URL"] = get_api_url()
@@ -4746,8 +5690,8 @@ def capture_command( # noqa: PLR0915
4746
5690
  import os
4747
5691
  import tomllib
4748
5692
 
4749
- from .global_config import get_api_url
4750
5693
  from .auth import get_valid_token
5694
+ from .global_config import get_api_url
4751
5695
 
4752
5696
  # Set environment variables for wafer-core BEFORE importing it
4753
5697
  # wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
@@ -4951,8 +5895,8 @@ def capture_list_command(
4951
5895
  """
4952
5896
  import os
4953
5897
 
4954
- from .global_config import get_api_url
4955
5898
  from .auth import get_valid_token
5899
+ from .global_config import get_api_url
4956
5900
 
4957
5901
  # Set environment variables for wafer-core BEFORE importing it
4958
5902
  os.environ["WAFER_API_URL"] = get_api_url()
@@ -5301,6 +6245,98 @@ def isa_analyze(
5301
6245
  raise typer.Exit(1) from None
5302
6246
 
5303
6247
 
6248
+ # =============================================================================
6249
+ # Kernel Scope Commands (wafer amd kernel-scope ...)
6250
+ # =============================================================================
6251
+
6252
+
6253
@kernel_scope_app.command("analyze")
def kernel_scope_analyze(
    path: Path = typer.Argument(..., help="Path to file or directory to analyze"),
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
    csv_output: bool = typer.Option(False, "--csv", help="Output as CSV"),
    recursive: bool = typer.Option(
        True, "--recursive/--no-recursive", "-r", help="Scan directories recursively"
    ),
    filter_expr: str | None = typer.Option(
        None, "--filter", "-f", help="Filter results (e.g., 'spills > 0')"
    ),
    output_file: Path | None = typer.Option(None, "--output", "-o", help="Write output to file"),
    kernel_index: int = typer.Option(0, "--kernel", "-k", help="Kernel index if multiple in file"),
) -> None:
    """Analyze Triton compilation artifacts (ISA, LLVM-IR, TTGIR).

    Performs static analysis to extract performance metrics like register
    pressure, spills, MFMA density, and occupancy limits.

    Supports:
        - AMDGCN ISA files (.s, .gcn, .asm)
        - LLVM-IR files (.ll)
        - TTGIR files (.ttgir, .ttir, .mlir)

    Examples:
        wafer amd kernel-scope analyze kernel.s
        wafer amd kernel-scope analyze kernel.s --json
        wafer amd kernel-scope analyze ~/.triton/cache/ --filter 'spills > 0'
        wafer amd kernel-scope analyze . -r --csv -o metrics.csv
    """
    # Imported at call time rather than at module import time.
    from .kernel_scope import analyze_command

    # analyze_command is handed plain strings for both paths; an omitted
    # --output option is forwarded as None.
    if output_file:
        output_target = str(output_file)
    else:
        output_target = None

    try:
        report = analyze_command(
            path=str(path),
            json_output=json_output,
            csv_output=csv_output,
            recursive=recursive,
            filter_expr=filter_expr,
            output_file=output_target,
            kernel_index=kernel_index,
        )
        typer.echo(report)
    except Exception as e:
        # FileNotFoundError and RuntimeError are subclasses of Exception and
        # were handled identically, so one handler covers every failure:
        # print the message to stderr and exit with status 1, no traceback.
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
6306
+
6307
+
6308
@kernel_scope_app.command("metrics")
def kernel_scope_metrics() -> None:
    """List available metrics for kernel scope analysis.

    Shows all metrics that can be extracted from Triton compilation
    artifacts, along with their derivation.

    Examples:
        wafer amd kernel-scope metrics
    """
    # Imported at call time rather than at module import time.
    from .kernel_scope import metrics_command

    # Print whatever metrics_command() returns straight to stdout.
    typer.echo(metrics_command())
6322
+
6323
+
6324
@kernel_scope_app.command("targets")
def kernel_scope_targets() -> None:
    """List supported GPU targets and their specifications.

    Shows hardware specs (VGPRs, SGPRs, LDS, etc.) for each supported
    AMD GPU architecture.

    Examples:
        wafer amd kernel-scope targets
    """
    # Imported at call time rather than at module import time.
    from .kernel_scope import targets_command

    # Print whatever targets_command() returns straight to stdout.
    typer.echo(targets_command())
6338
+
6339
+
5304
6340
  def main() -> None:
5305
6341
  """Entry point for wafer CLI."""
5306
6342
  app()