wafer-cli 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -30,6 +30,14 @@ import typer
30
30
 
31
31
  from .config import WaferConfig, WaferEnvironment
32
32
  from .inference import infer_upload_files, resolve_environment
33
+ from .problems import (
34
+ download_problems,
35
+ get_problem_path,
36
+ get_problems_path,
37
+ )
38
+ from .problems import (
39
+ list_problems as list_problems_fn,
40
+ )
33
41
 
34
42
  app = typer.Typer(
35
43
  help="GPU development toolkit for LLM coding agents",
@@ -91,11 +99,15 @@ def main_callback(ctx: typer.Context) -> None:
91
99
  # Install exception hook to catch SystemExit and mark failures
92
100
  original_excepthook = sys.excepthook
93
101
 
94
- def custom_excepthook(exc_type, exc_value, exc_traceback):
102
+ def custom_excepthook(
103
+ exc_type: type[BaseException],
104
+ exc_value: BaseException,
105
+ exc_traceback: object,
106
+ ) -> None:
95
107
  global _command_outcome
96
108
  # Mark as failure if SystemExit with non-zero code, or any other exception
97
109
  if exc_type is SystemExit:
98
- exit_code = exc_value.code if hasattr(exc_value, 'code') else 1
110
+ exit_code = exc_value.code if hasattr(exc_value, "code") else 1
99
111
  if exit_code != 0 and exit_code is not None:
100
112
  _command_outcome = "failure"
101
113
  else:
@@ -200,6 +212,13 @@ kernelbench_app = typer.Typer(
200
212
  )
201
213
  evaluate_app.add_typer(kernelbench_app, name="kernelbench")
202
214
 
215
+ # Nested subcommand for gpumode format
216
+ gpumode_app = typer.Typer(
217
+ help="Evaluate kernels in GPUMode format (custom_kernel/ref_kernel functions)",
218
+ invoke_without_command=True,
219
+ )
220
+ evaluate_app.add_typer(gpumode_app, name="gpumode")
221
+
203
222
  # =============================================================================
204
223
  # Dev commands (internal, used by web app proxy)
205
224
  # =============================================================================
@@ -396,6 +415,122 @@ def skill_status() -> None:
396
415
  typer.echo(f"{tool_name}: Not installed")
397
416
 
398
417
 
418
+ # =============================================================================
419
+ # Provider auth management (wafer auth ...)
420
+ # =============================================================================
421
+
422
+ provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
423
+ app.add_typer(provider_auth_app, name="auth")
424
+
425
+
426
+ @provider_auth_app.command("login")
427
+ def provider_auth_login(
428
+ provider: str = typer.Argument(
429
+ ...,
430
+ help="Provider name: runpod, digitalocean, or modal",
431
+ ),
432
+ api_key: str | None = typer.Option(
433
+ None,
434
+ "--api-key",
435
+ "-k",
436
+ help="API key (if not provided, reads from stdin)",
437
+ ),
438
+ ) -> None:
439
+ """Save API key for a cloud GPU provider.
440
+
441
+ Stores the key in ~/.wafer/auth.json. Environment variables
442
+ (e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
443
+
444
+ Examples:
445
+ wafer auth login runpod --api-key rp_xxx
446
+ wafer auth login digitalocean --api-key dop_v1_xxx
447
+ echo $API_KEY | wafer auth login runpod
448
+ """
449
+ import sys
450
+
451
+ from wafer_core.auth import PROVIDERS, save_api_key
452
+
453
+ # Validate provider
454
+ if provider not in PROVIDERS:
455
+ typer.echo(f"Error: Unknown provider '{provider}'", err=True)
456
+ typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
457
+ raise typer.Exit(1)
458
+
459
+ # Get API key from option or stdin
460
+ if api_key is None:
461
+ if sys.stdin.isatty():
462
+ typer.echo(f"Enter API key for {PROVIDERS[provider]['display_name']}:")
463
+ api_key = typer.prompt("API key", hide_input=True)
464
+ else:
465
+ api_key = sys.stdin.read().strip()
466
+
467
+ if not api_key:
468
+ typer.echo("Error: No API key provided", err=True)
469
+ raise typer.Exit(1)
470
+
471
+ # Save the key
472
+ save_api_key(provider, api_key)
473
+ typer.echo(f"API key saved for {PROVIDERS[provider]['display_name']}")
474
+ typer.echo("Stored in: ~/.wafer/auth.json")
475
+
476
+
477
+ @provider_auth_app.command("logout")
478
+ def provider_auth_logout(
479
+ provider: str = typer.Argument(
480
+ ...,
481
+ help="Provider name: runpod, digitalocean, or modal",
482
+ ),
483
+ ) -> None:
484
+ """Remove stored API key for a cloud GPU provider.
485
+
486
+ Examples:
487
+ wafer auth logout runpod
488
+ wafer auth logout digitalocean
489
+ """
490
+ from wafer_core.auth import PROVIDERS, remove_api_key
491
+
492
+ # Validate provider
493
+ if provider not in PROVIDERS:
494
+ typer.echo(f"Error: Unknown provider '{provider}'", err=True)
495
+ typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
496
+ raise typer.Exit(1)
497
+
498
+ if remove_api_key(provider):
499
+ typer.echo(f"API key removed for {PROVIDERS[provider]['display_name']}")
500
+ else:
501
+ typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
502
+
503
+
504
+ @provider_auth_app.command("status")
505
+ def provider_auth_status() -> None:
506
+ """Show authentication status for all cloud GPU providers.
507
+
508
+ Displays which providers have API keys configured and where
509
+ the keys are coming from (environment variable or auth.json).
510
+
511
+ Example:
512
+ wafer auth status
513
+ """
514
+ from wafer_core.auth import get_all_auth_status
515
+
516
+ statuses = get_all_auth_status()
517
+
518
+ typer.echo("Cloud GPU Provider Authentication Status")
519
+ typer.echo("=" * 45)
520
+
521
+ for status in statuses:
522
+ if status.is_authenticated:
523
+ source_str = f"({status.source})" if status.source else ""
524
+ typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
525
+ else:
526
+ typer.echo(f" {status.display_name}: ✗ Not configured")
527
+ typer.echo(f" Run: wafer auth login {status.provider}")
528
+ typer.echo(f" Or set: {status.key_url}")
529
+
530
+ typer.echo("")
531
+ typer.echo("Note: Environment variables take precedence over stored keys.")
532
+
533
+
399
534
  @app.command(hidden=True)
400
535
  def run(
401
536
  command: str = typer.Argument(..., help="Command to run in Docker container"),
@@ -1289,86 +1424,27 @@ def evaluate( # noqa: PLR0913
1289
1424
  --benchmark --defensive
1290
1425
 
1291
1426
  Subcommands:
1292
- make-template Generate template files for this format
1427
+ gpumode Use GPUMode format (functional) - RECOMMENDED
1293
1428
  kernelbench Use KernelBench format (ModelNew class)
1429
+ make-template Generate template files for this format (deprecated)
1294
1430
  """
1295
1431
  # If a subcommand is being invoked, skip the main evaluation logic
1296
1432
  if ctx.invoked_subcommand is not None:
1297
1433
  return
1298
1434
 
1299
- # Validate required args when running evaluation (not subcommands)
1300
- missing_args = []
1301
- if implementation is None:
1302
- missing_args.append("--impl/-i")
1303
- if reference is None:
1304
- missing_args.append("--reference")
1305
- if test_cases is None:
1306
- missing_args.append("--test-cases")
1307
-
1308
- if missing_args:
1309
- typer.echo("Error: Missing required arguments", err=True)
1310
- typer.echo(f" Required: {', '.join(missing_args)}", err=True)
1311
- typer.echo("", err=True)
1312
- typer.echo(
1313
- "Usage: wafer evaluate --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
1314
- err=True,
1315
- )
1316
- typer.echo("", err=True)
1317
- typer.echo("Run 'wafer evaluate --help' for full options.", err=True)
1318
- typer.echo("Run 'wafer evaluate make-template DIR' to generate starter files.", err=True)
1319
- raise typer.Exit(1)
1320
-
1321
- from .evaluate import EvaluateArgs, run_evaluate
1322
-
1323
- args = EvaluateArgs(
1324
- implementation=implementation,
1325
- reference=reference,
1326
- test_cases=test_cases,
1327
- target_name=target or "",
1328
- benchmark=benchmark,
1329
- profile=profile,
1330
- defensive=defensive,
1331
- sync_artifacts=sync_artifacts,
1332
- gpu_id=gpu_id,
1333
- )
1334
-
1335
- try:
1336
- # Use trio_asyncio to run async code that uses both trio and asyncio
1337
- # (AsyncSSHClient uses asyncssh which is asyncio-based, bridged via trio_asyncio)
1338
- import trio_asyncio
1339
-
1340
- result = trio_asyncio.run(run_evaluate, args)
1341
- except KeyboardInterrupt:
1342
- typer.echo("\nInterrupted by user", err=True)
1343
- raise typer.Exit(130) from None
1344
- except Exception as e:
1345
- # Unwrap ExceptionGroup (from Trio nurseries) to show actual error
1346
- if hasattr(e, "exceptions") and e.exceptions:
1347
- for exc in e.exceptions:
1348
- typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
1349
- else:
1350
- typer.echo(f"Error: {e}", err=True)
1351
- raise typer.Exit(1) from None
1352
-
1353
- # Print results
1354
- if result.success:
1355
- typer.echo("")
1356
- typer.echo("=" * 60)
1357
- status = "PASS" if result.all_correct else "FAIL"
1358
- typer.echo(f"Result: {status}")
1359
- score_pct = f"{result.correctness_score:.1%}"
1360
- typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
1361
- if result.geomean_speedup > 0:
1362
- typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
1363
- if result.artifact_path:
1364
- typer.echo(f"Artifacts: {result.artifact_path}")
1365
- typer.echo("=" * 60)
1366
-
1367
- if not result.all_correct:
1368
- raise typer.Exit(1)
1369
- else:
1370
- typer.echo(f"Error: {result.error_message}", err=True)
1371
- raise typer.Exit(1)
1435
+ # Bare 'wafer evaluate' is no longer supported - must use subcommand
1436
+ typer.echo("Error: 'wafer evaluate' requires a subcommand.", err=True)
1437
+ typer.echo("", err=True)
1438
+ typer.echo("Available subcommands:", err=True)
1439
+ typer.echo(" gpumode Evaluate GPUMode format (custom_kernel/ref_kernel functions)", err=True)
1440
+ typer.echo(" kernelbench Evaluate KernelBench format (ModelNew class)", err=True)
1441
+ typer.echo("", err=True)
1442
+ typer.echo("Examples:", err=True)
1443
+ typer.echo(" wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json", err=True)
1444
+ typer.echo(" wafer evaluate kernelbench --impl impl.py --reference ref.py --benchmark", err=True)
1445
+ typer.echo("", err=True)
1446
+ typer.echo("Run 'wafer evaluate gpumode --help' or 'wafer evaluate kernelbench --help' for options.", err=True)
1447
+ raise typer.Exit(1)
1372
1448
 
1373
1449
 
1374
1450
  TEMPLATE_KERNEL = '''\
@@ -1503,8 +1579,59 @@ def evaluate_make_template(
1503
1579
  # KernelBench format evaluation
1504
1580
  # =============================================================================
1505
1581
 
1506
- # Path to KernelBench problems (relative to wafer root)
1507
- KERNELBENCH_ROOT = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench"
1582
+
1583
+ def _get_kernelbench_root() -> Path | None:
1584
+ """Get KernelBench problems root, preferring downloaded location."""
1585
+ # First check downloaded location
1586
+ downloaded = get_problems_path("kernelbench")
1587
+ if downloaded is not None:
1588
+ kb_root = downloaded / "KernelBench"
1589
+ if kb_root.exists():
1590
+ return kb_root
1591
+ return downloaded
1592
+
1593
+ # Fall back to legacy location (for development)
1594
+ legacy = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench" / "KernelBench"
1595
+ if legacy.exists():
1596
+ return legacy
1597
+
1598
+ return None
1599
+
1600
+
1601
+ @kernelbench_app.command("download")
1602
+ def kernelbench_download(
1603
+ force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
1604
+ ) -> None:
1605
+ """Download KernelBench problems from GitHub.
1606
+
1607
+ Downloads the problem set to ~/.cache/wafer/problems/kernelbench/
1608
+
1609
+ Examples:
1610
+ wafer evaluate kernelbench download
1611
+ wafer evaluate kernelbench download --force # Re-download
1612
+ """
1613
+ try:
1614
+ path = download_problems("kernelbench", force=force, verbose=True)
1615
+ typer.echo("")
1616
+ typer.echo(f"Problems available at: {path}")
1617
+ typer.echo("Run 'wafer evaluate kernelbench list-problems' to see available problems.")
1618
+ except Exception as e:
1619
+ typer.echo(f"Error downloading problems: {e}", err=True)
1620
+ raise typer.Exit(1) from None
1621
+
1622
+
1623
+ @kernelbench_app.command("list-problems")
1624
+ def kernelbench_list_problems() -> None:
1625
+ """List available KernelBench problems.
1626
+
1627
+ Examples:
1628
+ wafer evaluate kernelbench list-problems
1629
+ """
1630
+ try:
1631
+ list_problems_fn("kernelbench", verbose=True)
1632
+ except ValueError as e:
1633
+ typer.echo(str(e), err=True)
1634
+ raise typer.Exit(1) from None
1508
1635
 
1509
1636
 
1510
1637
  @kernelbench_app.callback(invoke_without_command=True)
@@ -1528,8 +1655,19 @@ def kernelbench_evaluate( # noqa: PLR0913
1528
1655
  help="GPU target name. See 'wafer config targets list' for available targets.",
1529
1656
  autocompletion=complete_target_name,
1530
1657
  ),
1658
+ pool: str | None = typer.Option(
1659
+ None,
1660
+ "--pool",
1661
+ "-p",
1662
+ help="Target pool name. Acquires first available target from the pool. "
1663
+ "Define pools in ~/.wafer/config.toml under [pools.<name>].",
1664
+ ),
1531
1665
  benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
1532
1666
  profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
1667
+ inputs: Path | None = typer.Option(
1668
+ None, "--inputs", help="Custom inputs file to override get_inputs()"
1669
+ ),
1670
+ seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
1533
1671
  defensive: bool = typer.Option(
1534
1672
  False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
1535
1673
  ),
@@ -1586,14 +1724,47 @@ def kernelbench_evaluate( # noqa: PLR0913
1586
1724
  )
1587
1725
  raise typer.Exit(1)
1588
1726
 
1727
+ # Validate --target and --pool are mutually exclusive
1728
+ if target and pool:
1729
+ typer.echo("Error: Cannot specify both --target and --pool", err=True)
1730
+ raise typer.Exit(1)
1731
+
1589
1732
  from .evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench
1590
1733
 
1734
+ # If pool specified, acquire a target from the pool
1735
+ resolved_target = target or ""
1736
+ pool_lock_context = None
1737
+
1738
+ if pool:
1739
+ from .target_lock import acquire_from_pool
1740
+ from .targets import get_pool
1741
+
1742
+ try:
1743
+ pool_targets = get_pool(pool)
1744
+ except FileNotFoundError as e:
1745
+ typer.echo(f"Error: {e}", err=True)
1746
+ raise typer.Exit(1) from None
1747
+
1748
+ typer.echo(f"Acquiring target from pool '{pool}' ({len(pool_targets)} targets)...")
1749
+ pool_lock_context = acquire_from_pool(pool_targets)
1750
+ acquired_target = pool_lock_context.__enter__()
1751
+
1752
+ if acquired_target is None:
1753
+ typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
1754
+ typer.echo(f" Targets: {', '.join(pool_targets)}", err=True)
1755
+ raise typer.Exit(1)
1756
+
1757
+ typer.echo(f"Acquired target: {acquired_target}")
1758
+ resolved_target = acquired_target
1759
+
1591
1760
  args = KernelBenchEvaluateArgs(
1592
1761
  implementation=implementation,
1593
1762
  reference=reference,
1594
- target_name=target or "",
1763
+ target_name=resolved_target,
1595
1764
  benchmark=benchmark,
1596
1765
  profile=profile,
1766
+ inputs=inputs,
1767
+ seed=seed,
1597
1768
  defensive=defensive,
1598
1769
  sync_artifacts=sync_artifacts,
1599
1770
  gpu_id=gpu_id,
@@ -1609,6 +1780,10 @@ def kernelbench_evaluate( # noqa: PLR0913
1609
1780
  except Exception as e:
1610
1781
  typer.echo(f"Error: {e}", err=True)
1611
1782
  raise typer.Exit(1) from None
1783
+ finally:
1784
+ # Release pool lock if we acquired one
1785
+ if pool_lock_context is not None:
1786
+ pool_lock_context.__exit__(None, None, None)
1612
1787
 
1613
1788
  # Print results
1614
1789
  if result.success:
@@ -1655,6 +1830,13 @@ def kernelbench_make_template(
1655
1830
  # Overwrite existing
1656
1831
  wafer evaluate kernelbench make-template level1/1 --force
1657
1832
  """
1833
+ # Get problems root (downloaded or legacy)
1834
+ kb_root = _get_kernelbench_root()
1835
+ if kb_root is None:
1836
+ typer.echo("Error: KernelBench problems not found.", err=True)
1837
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1838
+ raise typer.Exit(1)
1839
+
1658
1840
  # Parse problem ID
1659
1841
  parts = problem.split("/")
1660
1842
  if len(parts) != 2:
@@ -1666,10 +1848,10 @@ def kernelbench_make_template(
1666
1848
  level_str = f"level{level_str}"
1667
1849
 
1668
1850
  # Find the problem file
1669
- problem_dir = KERNELBENCH_ROOT / "KernelBench" / level_str
1851
+ problem_dir = kb_root / level_str
1670
1852
  if not problem_dir.exists():
1671
1853
  typer.echo(f"Error: KernelBench level directory not found: {problem_dir}", err=True)
1672
- typer.echo(f"Make sure KernelBench is at: {KERNELBENCH_ROOT}", err=True)
1854
+ typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
1673
1855
  raise typer.Exit(1)
1674
1856
 
1675
1857
  # Find matching problem file
@@ -1736,6 +1918,293 @@ def kernelbench_make_template(
1736
1918
  typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
1737
1919
 
1738
1920
 
1921
+ # =============================================================================
1922
+ # GPUMode format evaluation
1923
+ # =============================================================================
1924
+
1925
+
1926
+ @gpumode_app.command("download")
1927
+ def gpumode_download(
1928
+ force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
1929
+ ) -> None:
1930
+ """Download GPUMode reference kernels from GitHub.
1931
+
1932
+ Downloads the problem set to ~/.cache/wafer/problems/gpumode/
1933
+
1934
+ Examples:
1935
+ wafer evaluate gpumode download
1936
+ wafer evaluate gpumode download --force # Re-download
1937
+ """
1938
+ try:
1939
+ path = download_problems("gpumode", force=force, verbose=True)
1940
+ typer.echo("")
1941
+ typer.echo(f"Problems available at: {path}")
1942
+ typer.echo("Run 'wafer evaluate gpumode list-problems' to see available problems.")
1943
+ except Exception as e:
1944
+ typer.echo(f"Error downloading problems: {e}", err=True)
1945
+ raise typer.Exit(1) from None
1946
+
1947
+
1948
+ @gpumode_app.command("list-problems")
1949
+ def gpumode_list_problems() -> None:
1950
+ """List available GPUMode problems.
1951
+
1952
+ Examples:
1953
+ wafer evaluate gpumode list-problems
1954
+ """
1955
+ try:
1956
+ list_problems_fn("gpumode", verbose=True)
1957
+ except ValueError as e:
1958
+ typer.echo(str(e), err=True)
1959
+ raise typer.Exit(1) from None
1960
+
1961
+
1962
+ @gpumode_app.command("make-template")
1963
+ def gpumode_make_template(
1964
+ problem: str = typer.Option(
1965
+ ...,
1966
+ "--problem",
1967
+ "-p",
1968
+ help="Problem ID (e.g., 'pmpp/vectoradd_py' or 'amd/fp8-mm')",
1969
+ ),
1970
+ output: Path = typer.Option(
1971
+ None, "--output", "-o", help="Output directory (default: ./<problem_name>/)"
1972
+ ),
1973
+ force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
1974
+ ) -> None:
1975
+ """Extract a GPUMode problem as template files.
1976
+
1977
+ Creates a directory with reference.py, task.yml, and other problem files.
1978
+ You then create kernel.py with your custom_kernel implementation.
1979
+
1980
+ Examples:
1981
+ # Extract pmpp vectoradd problem
1982
+ wafer evaluate gpumode make-template --problem pmpp/vectoradd_py
1983
+
1984
+ # Extract to specific directory
1985
+ wafer evaluate gpumode make-template --problem pmpp/vectoradd_py --output ./my-kernel/
1986
+ """
1987
+ import shutil
1988
+
1989
+ # Get problem path
1990
+ problem_path = get_problem_path("gpumode", problem)
1991
+ if problem_path is None:
1992
+ # Check if problems are downloaded
1993
+ if get_problems_path("gpumode") is None:
1994
+ typer.echo("Error: GPUMode problems not downloaded.", err=True)
1995
+ typer.echo("Run 'wafer evaluate gpumode download' first.", err=True)
1996
+ else:
1997
+ typer.echo(f"Error: Problem '{problem}' not found.", err=True)
1998
+ typer.echo(
1999
+ "Run 'wafer evaluate gpumode list-problems' to see available problems.", err=True
2000
+ )
2001
+ raise typer.Exit(1)
2002
+
2003
+ # Determine output path
2004
+ if output is None:
2005
+ output = Path.cwd() / problem.replace("/", "_")
2006
+
2007
+ output = output.resolve()
2008
+
2009
+ # Check if exists
2010
+ if output.exists() and not force:
2011
+ typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
2012
+ raise typer.Exit(1)
2013
+
2014
+ # Copy the problem directory
2015
+ if output.exists():
2016
+ shutil.rmtree(output)
2017
+ shutil.copytree(problem_path, output)
2018
+
2019
+ typer.echo(f"Created {output}/")
2020
+ typer.echo("")
2021
+ typer.echo("Contents:")
2022
+ for f in sorted(output.iterdir()):
2023
+ if not f.name.startswith("."):
2024
+ typer.echo(f" {f.name}")
2025
+ typer.echo("")
2026
+ typer.echo("Next steps:")
2027
+ typer.echo(" 1. Read reference.py to understand the kernel interface")
2028
+ typer.echo(" 2. Create kernel.py with your custom_kernel implementation:")
2029
+ typer.echo("")
2030
+ typer.echo(" def custom_kernel(data):")
2031
+ typer.echo(" # Your optimized implementation")
2032
+ typer.echo(" ...")
2033
+ typer.echo("")
2034
+ typer.echo(" 3. Run evaluation:")
2035
+ typer.echo(
2036
+ f" wafer evaluate gpumode --impl {output}/kernel.py --reference {output}/reference.py \\"
2037
+ )
2038
+ typer.echo(f" --test-cases {output}/test_cases.json --target <target>")
2039
+
2040
+
2041
+ @gpumode_app.callback(invoke_without_command=True)
2042
+ def gpumode_evaluate( # noqa: PLR0913, PLR0915
2043
+ ctx: typer.Context,
2044
+ implementation: Path | None = typer.Option(
2045
+ None, "--impl", "-i", help="Path to implementation kernel file"
2046
+ ),
2047
+ reference: Path | None = typer.Option(
2048
+ None, "--reference", help="Path to reference kernel file"
2049
+ ),
2050
+ test_cases: Path | None = typer.Option(
2051
+ None, "--test-cases", help="Path to test cases JSON file"
2052
+ ),
2053
+ target: str | None = typer.Option(
2054
+ None,
2055
+ "--target",
2056
+ "-t",
2057
+ help="GPU target name. See 'wafer config targets list' for available targets.",
2058
+ autocompletion=complete_target_name,
2059
+ ),
2060
+ pool: str | None = typer.Option(
2061
+ None,
2062
+ "--pool",
2063
+ "-p",
2064
+ help="Target pool name. Acquires first available target from the pool. "
2065
+ "Define pools in ~/.wafer/config.toml under [pools.<name>].",
2066
+ ),
2067
+ benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
2068
+ profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
2069
+ defensive: bool = typer.Option(
2070
+ False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
2071
+ ),
2072
+ sync_artifacts: bool = typer.Option(
2073
+ True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
2074
+ ),
2075
+ gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
2076
+ ) -> None:
2077
+ """Run kernel evaluation in GPUMode format (functional).
2078
+
2079
+ This format expects:
2080
+ - Implementation: Python file with `custom_kernel(inputs)` function
2081
+ - Reference: Python file with `ref_kernel(inputs)` and `generate_input(**kwargs)` functions
2082
+ - Test cases: JSON file with test parameters
2083
+
2084
+ Examples:
2085
+ # Basic correctness check
2086
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
2087
+
2088
+ # With benchmarking
2089
+ wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
2090
+ --target vultr-b200 --benchmark
2091
+
2092
+ Subcommands:
2093
+ download Download GPUMode problems from GitHub
2094
+ list-problems List available problems
2095
+ make-template Extract a problem as template files
2096
+ """
2097
+ # If a subcommand is being invoked, skip the main evaluation logic
2098
+ if ctx.invoked_subcommand is not None:
2099
+ return
2100
+
2101
+ # Validate required args when running evaluation (not subcommands)
2102
+ missing_args = []
2103
+ if implementation is None:
2104
+ missing_args.append("--impl/-i")
2105
+ if reference is None:
2106
+ missing_args.append("--reference")
2107
+ if test_cases is None:
2108
+ missing_args.append("--test-cases")
2109
+
2110
+ if missing_args:
2111
+ typer.echo("Error: Missing required arguments", err=True)
2112
+ typer.echo(f" Required: {', '.join(missing_args)}", err=True)
2113
+ typer.echo("", err=True)
2114
+ typer.echo(
2115
+ "Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
2116
+ err=True,
2117
+ )
2118
+ typer.echo("", err=True)
2119
+ typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
2120
+ typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
2121
+ raise typer.Exit(1)
2122
+
2123
+ # Validate --target and --pool are mutually exclusive
2124
+ if target and pool:
2125
+ typer.echo("Error: Cannot specify both --target and --pool", err=True)
2126
+ raise typer.Exit(1)
2127
+
2128
+ from .evaluate import EvaluateArgs, run_evaluate
2129
+
2130
+ # If pool specified, acquire a target from the pool
2131
+ resolved_target = target or ""
2132
+ pool_lock_context = None
2133
+
2134
+ if pool:
2135
+ from .target_lock import acquire_from_pool
2136
+ from .targets import get_pool
2137
+
2138
+ try:
2139
+ pool_targets = get_pool(pool)
2140
+ except FileNotFoundError as e:
2141
+ typer.echo(f"Error: {e}", err=True)
2142
+ raise typer.Exit(1) from None
2143
+
2144
+ typer.echo(f"Acquiring target from pool '{pool}' ({len(pool_targets)} targets)...")
2145
+ pool_lock_context = acquire_from_pool(pool_targets)
2146
+ acquired_target = pool_lock_context.__enter__()
2147
+
2148
+ if acquired_target is None:
2149
+ typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
2150
+ typer.echo(f" Targets: {', '.join(pool_targets)}", err=True)
2151
+ raise typer.Exit(1)
2152
+
2153
+ typer.echo(f"Acquired target: {acquired_target}")
2154
+ resolved_target = acquired_target
2155
+
2156
+ args = EvaluateArgs(
2157
+ implementation=implementation,
2158
+ reference=reference,
2159
+ test_cases=test_cases,
2160
+ target_name=resolved_target,
2161
+ benchmark=benchmark,
2162
+ profile=profile,
2163
+ defensive=defensive,
2164
+ sync_artifacts=sync_artifacts,
2165
+ gpu_id=gpu_id,
2166
+ )
2167
+
2168
+ try:
2169
+ import trio_asyncio
2170
+
2171
+ result = trio_asyncio.run(run_evaluate, args)
2172
+ except KeyboardInterrupt:
2173
+ typer.echo("\nInterrupted by user", err=True)
2174
+ raise typer.Exit(130) from None
2175
+ except Exception as e:
2176
+ if hasattr(e, "exceptions") and e.exceptions:
2177
+ for exc in e.exceptions:
2178
+ typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
2179
+ else:
2180
+ typer.echo(f"Error: {e}", err=True)
2181
+ raise typer.Exit(1) from None
2182
+ finally:
2183
+ # Release pool lock if we acquired one
2184
+ if pool_lock_context is not None:
2185
+ pool_lock_context.__exit__(None, None, None)
2186
+
2187
+ # Print results
2188
+ if result.success:
2189
+ typer.echo("")
2190
+ typer.echo("=" * 60)
2191
+ status = "PASS" if result.all_correct else "FAIL"
2192
+ typer.echo(f"Result: {status}")
2193
+ score_pct = f"{result.correctness_score:.1%}"
2194
+ typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
2195
+ if result.geomean_speedup > 0:
2196
+ typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
2197
+ if result.artifact_path:
2198
+ typer.echo(f"Artifacts: {result.artifact_path}")
2199
+ typer.echo("=" * 60)
2200
+
2201
+ if not result.all_correct:
2202
+ raise typer.Exit(1)
2203
+ else:
2204
+ typer.echo(f"Error: {result.error_message}", err=True)
2205
+ raise typer.Exit(1)
2206
+
2207
+
1739
2208
  # =============================================================================
1740
2209
  # Push and Remote-Run commands
1741
2210
  # =============================================================================
@@ -1867,7 +2336,7 @@ def _run_direct_mode(
1867
2336
  typer.echo(f"Uploading {upload_dir.name}...")
1868
2337
  try:
1869
2338
  push_result = push_direct(upload_dir, target)
1870
- workspace_name = push_result.workspace_path
2339
+ workspace_name = push_result.workspace_name
1871
2340
  typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
1872
2341
  except Exception as e:
1873
2342
  typer.echo(f"Error uploading: {e}", err=True)
@@ -2040,7 +2509,10 @@ def login(
2040
2509
  None, "--token", "-t", help="Access token (skip browser OAuth)"
2041
2510
  ),
2042
2511
  port: int | None = typer.Option(
2043
- None, "--port", "-p", help="Port for OAuth callback server (default: 8765 for SSH, random for local)"
2512
+ None,
2513
+ "--port",
2514
+ "-p",
2515
+ help="Port for OAuth callback server (default: 8765 for SSH, random for local)",
2044
2516
  ),
2045
2517
  ) -> None:
2046
2518
  """Authenticate CLI with wafer-api via GitHub OAuth.
@@ -2142,9 +2614,8 @@ def login(
2142
2614
  @app.command("logout")
2143
2615
  def logout() -> None:
2144
2616
  """Remove stored credentials."""
2145
- from .auth import clear_credentials
2146
-
2147
2617
  from . import analytics
2618
+ from .auth import clear_credentials
2148
2619
 
2149
2620
  # Track logout event first (while credentials still exist for user identification)
2150
2621
  # Note: track_logout() handles the case where user is not logged in
@@ -2621,6 +3092,7 @@ init_app = typer.Typer(
2621
3092
 
2622
3093
  Choose based on your GPU access:
2623
3094
 
3095
+ local GPU on current machine (no SSH)
2624
3096
  ssh Your own hardware via SSH
2625
3097
  runpod RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
2626
3098
  digitalocean DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
@@ -2628,6 +3100,92 @@ Choose based on your GPU access:
2628
3100
  targets_app.add_typer(init_app, name="init")
2629
3101
 
2630
3102
 
3103
@init_app.command("local")
def init_local(
    name: str = typer.Option("local", "--name", "-n", help="Target name"),
    gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
) -> None:
    """Initialize a local target for GPU on current machine.

    Detects your local GPU and configures a target for direct execution
    (no SSH). Use this when running wafer on the same machine as the GPU.

    Exits with code 1 when the GPU ID list is malformed, no GPU is
    detected, wafer_core is missing, or the target fails validation.

    Examples:
        wafer config targets init local
        wafer config targets init local --name my-5090 --gpu-ids 0,1
    """
    from .targets import save_target

    # Parse GPU IDs: "0,1" -> [0, 1]; reject anything non-integer up front.
    try:
        parsed_gpu_ids = [int(g.strip()) for g in gpu_ids.split(",")]
    except ValueError:
        typer.echo(f"Error: Invalid GPU IDs '{gpu_ids}'. Use comma-separated integers.", err=True)
        raise typer.Exit(1) from None

    typer.echo("Detecting local GPU...")

    try:
        # Imported lazily so a missing wafer_core package is reported via the
        # ImportError handler below instead of breaking CLI startup.
        from wafer_core.gpu_detect import (
            detect_local_gpu,
            get_compute_capability,
            get_torch_requirements,
        )

        detected_gpu = detect_local_gpu()

        if detected_gpu:
            typer.echo(f" Found: {detected_gpu.gpu_name}")
            # driver_version is printed as CUDA for NVIDIA and ROCm otherwise
            # -- presumably only nvidia/amd vendors occur; TODO(review):
            # confirm against wafer_core.gpu_detect.
            if detected_gpu.vendor == "nvidia":
                typer.echo(f" CUDA: {detected_gpu.driver_version}")
            else:
                typer.echo(f" ROCm: {detected_gpu.driver_version}")
            typer.echo(f" GPU count: {detected_gpu.gpu_count}")

            # Derive install metadata (torch build, compute capability, short
            # GPU type) from the detected hardware.
            torch_reqs = get_torch_requirements(detected_gpu)
            compute_capability = get_compute_capability(detected_gpu)
            gpu_type = _extract_gpu_type(detected_gpu.gpu_name)

            typer.echo(f" PyTorch: {torch_reqs.packages[0]}")
        else:
            typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)", err=True)
            # typer.Exit is not an ImportError, so it propagates past the
            # except clause below and terminates the command.
            raise typer.Exit(1)

    except ImportError as e:
        typer.echo(f"Error: Missing dependency: {e}", err=True)
        raise typer.Exit(1) from None

    # Build target data for persistence. NOTE(review): only the first torch
    # package is stored -- verify save_target's expected schema.
    target_data = {
        "name": name,
        "type": "local",
        "gpu_ids": parsed_gpu_ids,
        "gpu_type": gpu_type,
        "compute_capability": compute_capability,
        "torch_package": torch_reqs.packages[0],
        "torch_index_url": torch_reqs.index_url,
        "vendor": detected_gpu.vendor,
        "driver_version": detected_gpu.driver_version,
    }

    try:
        target = save_target(target_data)
        typer.echo(f"✓ Created target: {target.name}")
        typer.echo(" Type: Local (no SSH)")
        typer.echo(f" GPU IDs: {parsed_gpu_ids}")
        typer.echo(f" GPU Type: {gpu_type}")
        typer.echo(f" Compute: {compute_capability}")
        typer.echo(f" Torch: {torch_reqs.packages[0]}")
        typer.echo("")
        typer.echo(
            f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
        )
    except (ValueError, AssertionError) as e:
        # save_target validates the target schema; both validation styles are
        # surfaced to the user the same way.
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
3187
+
3188
+
2631
3189
  @init_app.command("runpod")
2632
3190
  def init_runpod(
2633
3191
  name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
@@ -2791,23 +3349,29 @@ def init_ssh(
2791
3349
  host: str = typer.Option(..., "--host", "-H", help="SSH host (user@hostname:port)"),
2792
3350
  ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
2793
3351
  gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
2794
- gpu_type: str = typer.Option(
2795
- "H100", "--gpu-type", help="GPU type (H100, A100, B200, MI300X, etc.)"
3352
+ gpu_type: str | None = typer.Option(
3353
+ None, "--gpu-type", help="GPU type (auto-detected if not specified)"
2796
3354
  ),
2797
3355
  docker_image: str | None = typer.Option(
2798
3356
  None, "--docker-image", "-d", help="Docker image (optional)"
2799
3357
  ),
2800
3358
  ncu: bool = typer.Option(False, "--ncu/--no-ncu", help="NCU profiling available"),
3359
+ no_detect: bool = typer.Option(False, "--no-detect", help="Skip GPU auto-detection"),
2801
3360
  ) -> None:
2802
3361
  """Initialize an SSH target for your own GPU hardware.
2803
3362
 
2804
3363
  Creates a target config for direct SSH access to a GPU machine.
2805
- Use for baremetal servers, VMs, or any machine you have SSH access to.
3364
+ Automatically detects GPU type and selects compatible PyTorch version.
2806
3365
 
2807
3366
  Examples:
3367
+ # Auto-detect GPU (recommended)
2808
3368
  wafer config targets init ssh --name my-gpu --host user@192.168.1.100:22
3369
+
3370
+ # Multiple GPUs with NCU profiling
2809
3371
  wafer config targets init ssh --name lab-h100 --host ubuntu@gpu.lab.com:22 --gpu-ids 0,1 --ncu
2810
- wafer config targets init ssh --name docker-gpu --host user@host:22 --docker-image nvcr.io/nvidia/pytorch:24.01-py3
3372
+
3373
+ # Skip detection, specify manually
3374
+ wafer config targets init ssh --name my-gpu --host user@host:22 --gpu-type H100 --no-detect
2811
3375
  """
2812
3376
  from .targets import save_target
2813
3377
 
@@ -2824,17 +3388,87 @@ def init_ssh(
2824
3388
  typer.echo("Example: user@192.168.1.100:22", err=True)
2825
3389
  raise typer.Exit(1)
2826
3390
 
3391
+ # Auto-detect GPU if not specified
3392
+ detected_gpu = None
3393
+ torch_package = None
3394
+ torch_index_url = None
3395
+
3396
+ if not no_detect:
3397
+ typer.echo(f"Connecting to {host}...")
3398
+ try:
3399
+ import trio
3400
+ import trio_asyncio
3401
+
3402
+ from wafer_core.async_ssh import AsyncSSHClient
3403
+ from wafer_core.gpu_detect import (
3404
+ detect_remote_gpu,
3405
+ get_compute_capability,
3406
+ get_torch_requirements,
3407
+ )
3408
+
3409
+ expanded_key = str(Path(ssh_key).expanduser())
3410
+
3411
+ async def _detect() -> None:
3412
+ nonlocal detected_gpu, torch_package, torch_index_url
3413
+ # Need trio_asyncio.open_loop() for asyncssh bridge
3414
+ async with trio_asyncio.open_loop():
3415
+ async with AsyncSSHClient(host, expanded_key) as client:
3416
+ detected_gpu = await detect_remote_gpu(client)
3417
+
3418
+ trio.run(_detect)
3419
+
3420
+ if detected_gpu:
3421
+ typer.echo(f" Found: {detected_gpu.gpu_name}")
3422
+ if detected_gpu.vendor == "nvidia":
3423
+ typer.echo(f" CUDA: {detected_gpu.driver_version}")
3424
+ else:
3425
+ typer.echo(f" ROCm: {detected_gpu.driver_version}")
3426
+
3427
+ # Get torch requirements
3428
+ torch_reqs = get_torch_requirements(detected_gpu)
3429
+ torch_package = torch_reqs.packages[0] # Just torch, not all packages
3430
+ torch_index_url = torch_reqs.index_url
3431
+ typer.echo(f" PyTorch: {torch_package}")
3432
+
3433
+ # Use detected GPU type if not specified
3434
+ if not gpu_type:
3435
+ # Extract GPU name (e.g., "H100" from "NVIDIA H100 80GB HBM3")
3436
+ gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
3437
+ else:
3438
+ typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)")
3439
+ if not gpu_type:
3440
+ gpu_type = "H100" # Default fallback
3441
+ typer.echo(f" Using default: {gpu_type}")
3442
+
3443
+ except Exception as e:
3444
+ typer.echo(f" Detection failed: {e}", err=True)
3445
+ if not gpu_type:
3446
+ gpu_type = "H100"
3447
+ typer.echo(f" Using default: {gpu_type}")
3448
+
3449
+ # Fallback if no detection
3450
+ if not gpu_type:
3451
+ gpu_type = "H100"
3452
+
2827
3453
  # Compute capability mappings
2828
- compute_caps = {
2829
- "B200": "10.0",
2830
- "H100": "9.0",
2831
- "A100": "8.0",
2832
- "A10": "8.6",
2833
- "V100": "7.0",
2834
- "MI300X": "9.4",
2835
- "MI250X": "9.0",
2836
- }
2837
- compute_capability = compute_caps.get(gpu_type, "8.0")
3454
+ if detected_gpu:
3455
+ from wafer_core.gpu_detect import get_compute_capability
3456
+
3457
+ compute_capability = get_compute_capability(detected_gpu)
3458
+ else:
3459
+ compute_caps = {
3460
+ "B200": "10.0",
3461
+ "H100": "9.0",
3462
+ "A100": "8.0",
3463
+ "A10": "8.6",
3464
+ "V100": "7.0",
3465
+ "MI300X": "9.4",
3466
+ "MI250X": "9.0",
3467
+ "RTX 5090": "10.0",
3468
+ "RTX 4090": "8.9",
3469
+ "RTX 3090": "8.6",
3470
+ }
3471
+ compute_capability = compute_caps.get(gpu_type, "8.0")
2838
3472
 
2839
3473
  # Build target data
2840
3474
  target_data = {
@@ -2851,6 +3485,12 @@ def init_ssh(
2851
3485
  if docker_image:
2852
3486
  target_data["docker_image"] = docker_image
2853
3487
 
3488
+ # Add torch requirements if detected
3489
+ if torch_package:
3490
+ target_data["torch_package"] = torch_package
3491
+ if torch_index_url:
3492
+ target_data["torch_index_url"] = torch_index_url
3493
+
2854
3494
  try:
2855
3495
  target = save_target(target_data)
2856
3496
  typer.echo(f"✓ Created target: {target.name}")
@@ -2858,9 +3498,12 @@ def init_ssh(
2858
3498
  typer.echo(f" Host: {host}")
2859
3499
  typer.echo(f" GPU IDs: {parsed_gpu_ids}")
2860
3500
  typer.echo(f" GPU Type: {gpu_type}")
3501
+ typer.echo(f" Compute: {compute_capability}")
2861
3502
  typer.echo(f" NCU: {'Yes' if ncu else 'No'}")
2862
3503
  if docker_image:
2863
3504
  typer.echo(f" Docker: {docker_image}")
3505
+ if torch_package:
3506
+ typer.echo(f" Torch: {torch_package}")
2864
3507
  typer.echo("")
2865
3508
  typer.echo(
2866
3509
  f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
@@ -2870,6 +3513,31 @@ def init_ssh(
2870
3513
  raise typer.Exit(1) from None
2871
3514
 
2872
3515
 
3516
+ def _extract_gpu_type(gpu_name: str) -> str:
3517
+ """Extract GPU type from full GPU name.
3518
+
3519
+ Examples:
3520
+ "NVIDIA H100 80GB HBM3" -> "H100"
3521
+ "NVIDIA GeForce RTX 4090" -> "RTX 4090"
3522
+ "AMD Instinct MI300X OAM" -> "MI300X"
3523
+ """
3524
+ gpu_name_upper = gpu_name.upper()
3525
+
3526
+ # Check for known GPU types
3527
+ known_types = [
3528
+ "B200", "B100", "H200", "H100", "A100", "A10", "V100",
3529
+ "RTX 5090", "RTX 5080", "RTX 4090", "RTX 4080", "RTX 3090", "RTX 3080",
3530
+ "MI300X", "MI250X", "MI100",
3531
+ ]
3532
+
3533
+ for gpu_type in known_types:
3534
+ if gpu_type in gpu_name_upper:
3535
+ return gpu_type
3536
+
3537
+ # Fallback: return cleaned name
3538
+ return gpu_name.replace("NVIDIA ", "").replace("AMD ", "").strip()
3539
+
3540
+
2873
3541
  @targets_app.command("add")
2874
3542
  def targets_add(
2875
3543
  file_path: Path = typer.Argument(..., help="Path to target TOML file"),
@@ -3082,6 +3750,92 @@ def targets_pods() -> None:
3082
3750
  typer.echo()
3083
3751
 
3084
3752
 
3753
+ # ── Pool commands ───────────────────────────────────────────────────────────
3754
+
3755
+
3756
+ @targets_app.command("pool-list")
3757
+ def targets_pool_list() -> None:
3758
+ """List all configured target pools.
3759
+
3760
+ Example:
3761
+ wafer config targets pool-list
3762
+ """
3763
+ from .targets import get_pool, list_pools
3764
+
3765
+ pools = list_pools()
3766
+
3767
+ if not pools:
3768
+ typer.echo("No pools configured")
3769
+ typer.echo("")
3770
+ typer.echo("Define pools in ~/.wafer/config.toml:")
3771
+ typer.echo(" [pools.my-pool]")
3772
+ typer.echo(' targets = ["target-1", "target-2"]')
3773
+ return
3774
+
3775
+ typer.echo("Configured pools:\n")
3776
+ for pool_name in pools:
3777
+ try:
3778
+ targets = get_pool(pool_name)
3779
+ typer.echo(f" {pool_name}: {', '.join(targets)}")
3780
+ except Exception as e:
3781
+ typer.echo(f" {pool_name}: (error: {e})")
3782
+
3783
+
3784
+ @targets_app.command("pool-create")
3785
+ def targets_pool_create(
3786
+ name: str = typer.Argument(..., help="Pool name"),
3787
+ targets: list[str] = typer.Argument(..., help="Target names to include in pool"),
3788
+ ) -> None:
3789
+ """Create or update a target pool.
3790
+
3791
+ Example:
3792
+ wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2 mi300x-3
3793
+ """
3794
+ from .targets import save_pool
3795
+
3796
+ try:
3797
+ save_pool(name, targets)
3798
+ typer.echo(f"Pool '{name}' created with {len(targets)} targets")
3799
+ except FileNotFoundError as e:
3800
+ typer.echo(f"Error: {e}", err=True)
3801
+ raise typer.Exit(1) from None
3802
+
3803
+
3804
+ @targets_app.command("pool-status")
3805
+ def targets_pool_status(
3806
+ name: str = typer.Argument(..., help="Pool name"),
3807
+ ) -> None:
3808
+ """Show status of targets in a pool (locked/available).
3809
+
3810
+ Example:
3811
+ wafer config targets pool-status mi300x-pool
3812
+ """
3813
+ from .target_lock import get_lock_holder, is_target_locked
3814
+ from .targets import get_pool
3815
+
3816
+ try:
3817
+ targets = get_pool(name)
3818
+ except FileNotFoundError as e:
3819
+ typer.echo(f"Error: {e}", err=True)
3820
+ raise typer.Exit(1) from None
3821
+
3822
+ typer.echo(f"Pool '{name}' ({len(targets)} targets):\n")
3823
+
3824
+ available = 0
3825
+ for target_name in targets:
3826
+ locked = is_target_locked(target_name)
3827
+ if locked:
3828
+ pid = get_lock_holder(target_name)
3829
+ pid_str = f" (pid {pid})" if pid else ""
3830
+ typer.echo(f" [busy] {target_name}{pid_str}")
3831
+ else:
3832
+ typer.echo(f" [free] {target_name}")
3833
+ available += 1
3834
+
3835
+ typer.echo("")
3836
+ typer.echo(f"Available: {available}/{len(targets)}")
3837
+
3838
+
3085
3839
  # =============================================================================
3086
3840
  # Billing commands
3087
3841
  # =============================================================================
@@ -3115,7 +3869,9 @@ def billing_usage(
3115
3869
  @billing_app.command("topup")
3116
3870
  def billing_topup(
3117
3871
  amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
3118
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
3872
+ no_browser: bool = typer.Option(
3873
+ False, "--no-browser", help="Print URL instead of opening browser"
3874
+ ),
3119
3875
  ) -> None:
3120
3876
  """Add credits to your account.
3121
3877
 
@@ -3161,7 +3917,9 @@ def billing_topup(
3161
3917
 
3162
3918
  @billing_app.command("portal")
3163
3919
  def billing_portal(
3164
- no_browser: bool = typer.Option(False, "--no-browser", help="Print URL instead of opening browser"),
3920
+ no_browser: bool = typer.Option(
3921
+ False, "--no-browser", help="Print URL instead of opening browser"
3922
+ ),
3165
3923
  ) -> None:
3166
3924
  """Open Stripe billing portal.
3167
3925
 
@@ -4437,8 +5195,8 @@ def _setup_wafer_core_env() -> None:
4437
5195
  - WAFER_API_URL: If already set, uses that instead of config
4438
5196
  - WAFER_AUTH_TOKEN: If already set, uses that instead of cached token
4439
5197
  """
4440
- from .global_config import get_api_url
4441
5198
  from .auth import get_valid_token
5199
+ from .global_config import get_api_url
4442
5200
 
4443
5201
  # Set API URL (get_api_url already respects WAFER_API_URL env var)
4444
5202
  os.environ["WAFER_API_URL"] = get_api_url()
@@ -4742,8 +5500,8 @@ def capture_command( # noqa: PLR0915
4742
5500
  import os
4743
5501
  import tomllib
4744
5502
 
4745
- from .global_config import get_api_url
4746
5503
  from .auth import get_valid_token
5504
+ from .global_config import get_api_url
4747
5505
 
4748
5506
  # Set environment variables for wafer-core BEFORE importing it
4749
5507
  # wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
@@ -4947,8 +5705,8 @@ def capture_list_command(
4947
5705
  """
4948
5706
  import os
4949
5707
 
4950
- from .global_config import get_api_url
4951
5708
  from .auth import get_valid_token
5709
+ from .global_config import get_api_url
4952
5710
 
4953
5711
  # Set environment variables for wafer-core BEFORE importing it
4954
5712
  os.environ["WAFER_API_URL"] = get_api_url()