wafer-cli 0.2.4__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/PKG-INFO +1 -1
  2. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/pyproject.toml +1 -1
  3. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/cli.py +403 -106
  4. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/evaluate.py +871 -98
  5. wafer_cli-0.2.6/wafer/target_lock.py +198 -0
  6. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/targets.py +158 -0
  7. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer_cli.egg-info/PKG-INFO +1 -1
  8. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer_cli.egg-info/SOURCES.txt +1 -0
  9. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/README.md +0 -0
  10. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/setup.cfg +0 -0
  11. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_analytics.py +0 -0
  12. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_billing.py +0 -0
  13. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_cli_coverage.py +0 -0
  14. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_cli_parity_integration.py +0 -0
  15. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_config_integration.py +0 -0
  16. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_file_operations_integration.py +0 -0
  17. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_isa_cli.py +0 -0
  18. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_rocprof_compute_integration.py +0 -0
  19. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_ssh_integration.py +0 -0
  20. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_wevin_cli.py +0 -0
  21. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/tests/test_workflow_integration.py +0 -0
  22. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/GUIDE.md +0 -0
  23. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/__init__.py +0 -0
  24. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/analytics.py +0 -0
  25. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/api_client.py +0 -0
  26. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/auth.py +0 -0
  27. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/autotuner.py +0 -0
  28. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/billing.py +0 -0
  29. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/config.py +0 -0
  30. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/corpus.py +0 -0
  31. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/global_config.py +0 -0
  32. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/gpu_run.py +0 -0
  33. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/inference.py +0 -0
  34. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/ncu_analyze.py +0 -0
  35. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/nsys_analyze.py +0 -0
  36. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/problems.py +0 -0
  37. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/rocprof_compute.py +0 -0
  38. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/rocprof_sdk.py +0 -0
  39. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/rocprof_systems.py +0 -0
  40. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/skills/wafer-guide/SKILL.md +0 -0
  41. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/templates/__init__.py +0 -0
  42. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/templates/ask_docs.py +0 -0
  43. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/templates/optimize_kernel.py +0 -0
  44. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/templates/trace_analyze.py +0 -0
  45. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/tracelens.py +0 -0
  46. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/wevin_cli.py +0 -0
  47. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/workspaces.py +0 -0
  48. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer_cli.egg-info/dependency_links.txt +0 -0
  49. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer_cli.egg-info/entry_points.txt +0 -0
  50. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer_cli.egg-info/requires.txt +0 -0
  51. {wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer_cli.egg-info/top_level.txt +0 -0
{wafer_cli-0.2.4 → wafer_cli-0.2.6}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: wafer-cli
-Version: 0.2.4
+Version: 0.2.6
 Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
 Requires-Python: >=3.11
 Requires-Dist: typer>=0.12.0
{wafer_cli-0.2.4 → wafer_cli-0.2.6}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "wafer-cli"
-version = "0.2.4"
+version = "0.2.6"
 description = "CLI tool for running commands on remote GPUs and GPU kernel optimization agent"
 requires-python = ">=3.11"
 dependencies = [
{wafer_cli-0.2.4 → wafer_cli-0.2.6}/wafer/cli.py
@@ -99,7 +99,11 @@ def main_callback(ctx: typer.Context) -> None:
     # Install exception hook to catch SystemExit and mark failures
     original_excepthook = sys.excepthook
 
-    def custom_excepthook(exc_type, exc_value, exc_traceback):
+    def custom_excepthook(
+        exc_type: type[BaseException],
+        exc_value: BaseException,
+        exc_traceback: object,
+    ) -> None:
         global _command_outcome
         # Mark as failure if SystemExit with non-zero code, or any other exception
         if exc_type is SystemExit:
@@ -467,7 +471,7 @@ def provider_auth_login(
     # Save the key
     save_api_key(provider, api_key)
     typer.echo(f"API key saved for {PROVIDERS[provider]['display_name']}")
-    typer.echo(f"Stored in: ~/.wafer/auth.json")
+    typer.echo("Stored in: ~/.wafer/auth.json")
 
 
 @provider_auth_app.command("logout")
@@ -517,9 +521,7 @@ def provider_auth_status() -> None:
     for status in statuses:
         if status.is_authenticated:
             source_str = f"({status.source})" if status.source else ""
-            typer.echo(
-                f"  {status.display_name}: ✓ {status.key_preview} {source_str}"
-            )
+            typer.echo(f"  {status.display_name}: ✓ {status.key_preview} {source_str}")
         else:
             typer.echo(f"  {status.display_name}: ✗ Not configured")
             typer.echo(f"    Run: wafer auth login {status.provider}")
@@ -1430,90 +1432,19 @@ def evaluate(  # noqa: PLR0913
     if ctx.invoked_subcommand is not None:
         return
 
-    # Deprecation warning for bare evaluate
-    typer.echo(
-        "⚠️ Deprecation warning: 'wafer evaluate' will be removed in a future version.",
-        err=True,
-    )
-    typer.echo(
-        "   Use 'wafer evaluate gpumode' instead for the functional format.",
-        err=True,
-    )
+    # Bare 'wafer evaluate' is no longer supported - must use subcommand
+    typer.echo("Error: 'wafer evaluate' requires a subcommand.", err=True)
     typer.echo("", err=True)
-
-    # Validate required args when running evaluation (not subcommands)
-    missing_args = []
-    if implementation is None:
-        missing_args.append("--impl/-i")
-    if reference is None:
-        missing_args.append("--reference")
-    if test_cases is None:
-        missing_args.append("--test-cases")
-
-    if missing_args:
-        typer.echo("Error: Missing required arguments", err=True)
-        typer.echo(f"  Required: {', '.join(missing_args)}", err=True)
-        typer.echo("", err=True)
-        typer.echo(
-            "Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
-            err=True,
-        )
-        typer.echo("", err=True)
-        typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
-        typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
-        raise typer.Exit(1)
-
-    from .evaluate import EvaluateArgs, run_evaluate
-
-    args = EvaluateArgs(
-        implementation=implementation,
-        reference=reference,
-        test_cases=test_cases,
-        target_name=target or "",
-        benchmark=benchmark,
-        profile=profile,
-        defensive=defensive,
-        sync_artifacts=sync_artifacts,
-        gpu_id=gpu_id,
-    )
-
-    try:
-        # Use trio_asyncio to run async code that uses both trio and asyncio
-        # (AsyncSSHClient uses asyncssh which is asyncio-based, bridged via trio_asyncio)
-        import trio_asyncio
-
-        result = trio_asyncio.run(run_evaluate, args)
-    except KeyboardInterrupt:
-        typer.echo("\nInterrupted by user", err=True)
-        raise typer.Exit(130) from None
-    except Exception as e:
-        # Unwrap ExceptionGroup (from Trio nurseries) to show actual error
-        if hasattr(e, "exceptions") and e.exceptions:
-            for exc in e.exceptions:
-                typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
-        else:
-            typer.echo(f"Error: {e}", err=True)
-        raise typer.Exit(1) from None
-
-    # Print results
-    if result.success:
-        typer.echo("")
-        typer.echo("=" * 60)
-        status = "PASS" if result.all_correct else "FAIL"
-        typer.echo(f"Result: {status}")
-        score_pct = f"{result.correctness_score:.1%}"
-        typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
-        if result.geomean_speedup > 0:
-            typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
-        if result.artifact_path:
-            typer.echo(f"Artifacts: {result.artifact_path}")
-        typer.echo("=" * 60)
-
-        if not result.all_correct:
-            raise typer.Exit(1)
-    else:
-        typer.echo(f"Error: {result.error_message}", err=True)
-        raise typer.Exit(1)
+    typer.echo("Available subcommands:", err=True)
+    typer.echo("  gpumode      Evaluate GPUMode format (custom_kernel/ref_kernel functions)", err=True)
+    typer.echo("  kernelbench  Evaluate KernelBench format (ModelNew class)", err=True)
+    typer.echo("", err=True)
+    typer.echo("Examples:", err=True)
+    typer.echo("  wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json", err=True)
+    typer.echo("  wafer evaluate kernelbench --impl impl.py --reference ref.py --benchmark", err=True)
+    typer.echo("", err=True)
+    typer.echo("Run 'wafer evaluate gpumode --help' or 'wafer evaluate kernelbench --help' for options.", err=True)
+    raise typer.Exit(1)
 
 
 TEMPLATE_KERNEL = '''\
@@ -1724,6 +1655,13 @@ def kernelbench_evaluate(  # noqa: PLR0913
         help="GPU target name. See 'wafer config targets list' for available targets.",
         autocompletion=complete_target_name,
     ),
+    pool: str | None = typer.Option(
+        None,
+        "--pool",
+        "-p",
+        help="Target pool name. Acquires first available target from the pool. "
+        "Define pools in ~/.wafer/config.toml under [pools.<name>].",
+    ),
     benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
     profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
    inputs: Path | None = typer.Option(
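Editor's note: the new --pool help points at ~/.wafer/config.toml. The pool-list command added later in this diff prints the expected shape when no pools exist, so a definition looks like this (pool and target names are placeholders):

    # ~/.wafer/config.toml
    [pools.my-pool]
    targets = ["target-1", "target-2"]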
@@ -1786,12 +1724,43 @@ def kernelbench_evaluate(  # noqa: PLR0913
         )
         raise typer.Exit(1)
 
+    # Validate --target and --pool are mutually exclusive
+    if target and pool:
+        typer.echo("Error: Cannot specify both --target and --pool", err=True)
+        raise typer.Exit(1)
+
     from .evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench
 
+    # If pool specified, acquire a target from the pool
+    resolved_target = target or ""
+    pool_lock_context = None
+
+    if pool:
+        from .target_lock import acquire_from_pool
+        from .targets import get_pool
+
+        try:
+            pool_targets = get_pool(pool)
+        except FileNotFoundError as e:
+            typer.echo(f"Error: {e}", err=True)
+            raise typer.Exit(1) from None
+
+        typer.echo(f"Acquiring target from pool '{pool}' ({len(pool_targets)} targets)...")
+        pool_lock_context = acquire_from_pool(pool_targets)
+        acquired_target = pool_lock_context.__enter__()
+
+        if acquired_target is None:
+            typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
+            typer.echo(f"  Targets: {', '.join(pool_targets)}", err=True)
+            raise typer.Exit(1)
+
+        typer.echo(f"Acquired target: {acquired_target}")
+        resolved_target = acquired_target
+
     args = KernelBenchEvaluateArgs(
         implementation=implementation,
         reference=reference,
-        target_name=target or "",
+        target_name=resolved_target,
         benchmark=benchmark,
         profile=profile,
         inputs=inputs,
@@ -1811,6 +1780,10 @@ def kernelbench_evaluate(  # noqa: PLR0913
     except Exception as e:
         typer.echo(f"Error: {e}", err=True)
         raise typer.Exit(1) from None
+    finally:
+        # Release pool lock if we acquired one
+        if pool_lock_context is not None:
+            pool_lock_context.__exit__(None, None, None)
 
     # Print results
     if result.success:
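Editor's note: wafer/target_lock.py is new in 0.2.6 (+198 lines) but its body is not part of this section. The call sites pin down its contract: acquire_from_pool(targets) returns a context manager whose __enter__ yields the first free target name (or None when every target is busy), and is_target_locked / get_lock_holder (used by pool-status below) expose lock state and the holder's PID. A minimal sketch consistent with those call sites follows; the lock directory, file format, and absence of stale-lock cleanup are assumptions, not the package's actual implementation:

    # Hypothetical sketch of wafer/target_lock.py's contract; the real module
    # in 0.2.6 may differ in layout, naming, and stale-lock handling.
    import os
    from contextlib import contextmanager
    from pathlib import Path

    LOCK_DIR = Path.home() / ".wafer" / "locks"  # assumed location


    def _lock_path(target_name: str) -> Path:
        return LOCK_DIR / f"{target_name}.lock"


    def is_target_locked(target_name: str) -> bool:
        return _lock_path(target_name).exists()


    def get_lock_holder(target_name: str) -> int | None:
        """PID recorded in the lock file, if readable."""
        try:
            return int(_lock_path(target_name).read_text().strip())
        except (FileNotFoundError, ValueError):
            return None


    @contextmanager
    def acquire_from_pool(targets: list[str]):
        """Yield the first free target name, or None if every target is busy."""
        LOCK_DIR.mkdir(parents=True, exist_ok=True)
        acquired = None
        for name in targets:
            try:
                # O_CREAT | O_EXCL makes lock creation atomic: one process wins.
                fd = os.open(_lock_path(name), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            except FileExistsError:
                continue  # busy, try the next target
            os.write(fd, str(os.getpid()).encode())
            os.close(fd)
            acquired = name
            break
        try:
            yield acquired
        finally:
            if acquired is not None:
                _lock_path(acquired).unlink(missing_ok=True)

Calling __enter__/__exit__ by hand in cli.py (rather than a with block) is what lets the lock span the try/finally around the evaluation run.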
@@ -2066,7 +2039,7 @@ def gpumode_make_template(
 
 
 @gpumode_app.callback(invoke_without_command=True)
-def gpumode_evaluate(  # noqa: PLR0913
+def gpumode_evaluate(  # noqa: PLR0913, PLR0915
     ctx: typer.Context,
     implementation: Path | None = typer.Option(
         None, "--impl", "-i", help="Path to implementation kernel file"
@@ -2084,6 +2057,13 @@ def gpumode_evaluate(  # noqa: PLR0913
         help="GPU target name. See 'wafer config targets list' for available targets.",
         autocompletion=complete_target_name,
     ),
+    pool: str | None = typer.Option(
+        None,
+        "--pool",
+        "-p",
+        help="Target pool name. Acquires first available target from the pool. "
+        "Define pools in ~/.wafer/config.toml under [pools.<name>].",
+    ),
     benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
     profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
     defensive: bool = typer.Option(
@@ -2140,14 +2120,44 @@ def gpumode_evaluate(  # noqa: PLR0913
         typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
         raise typer.Exit(1)
 
-    # Reuse the existing evaluate logic (same format)
+    # Validate --target and --pool are mutually exclusive
+    if target and pool:
+        typer.echo("Error: Cannot specify both --target and --pool", err=True)
+        raise typer.Exit(1)
+
     from .evaluate import EvaluateArgs, run_evaluate
 
+    # If pool specified, acquire a target from the pool
+    resolved_target = target or ""
+    pool_lock_context = None
+
+    if pool:
+        from .target_lock import acquire_from_pool
+        from .targets import get_pool
+
+        try:
+            pool_targets = get_pool(pool)
+        except FileNotFoundError as e:
+            typer.echo(f"Error: {e}", err=True)
+            raise typer.Exit(1) from None
+
+        typer.echo(f"Acquiring target from pool '{pool}' ({len(pool_targets)} targets)...")
+        pool_lock_context = acquire_from_pool(pool_targets)
+        acquired_target = pool_lock_context.__enter__()
+
+        if acquired_target is None:
+            typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
+            typer.echo(f"  Targets: {', '.join(pool_targets)}", err=True)
+            raise typer.Exit(1)
+
+        typer.echo(f"Acquired target: {acquired_target}")
+        resolved_target = acquired_target
+
     args = EvaluateArgs(
         implementation=implementation,
         reference=reference,
         test_cases=test_cases,
-        target_name=target or "",
+        target_name=resolved_target,
         benchmark=benchmark,
         profile=profile,
         defensive=defensive,
@@ -2169,6 +2179,10 @@ def gpumode_evaluate(  # noqa: PLR0913
         else:
             typer.echo(f"Error: {e}", err=True)
         raise typer.Exit(1) from None
+    finally:
+        # Release pool lock if we acquired one
+        if pool_lock_context is not None:
+            pool_lock_context.__exit__(None, None, None)
 
     # Print results
     if result.success:
@@ -3078,6 +3092,7 @@ init_app = typer.Typer(
 
 Choose based on your GPU access:
 
+  local         GPU on current machine (no SSH)
   ssh           Your own hardware via SSH
   runpod        RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
   digitalocean  DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
@@ -3085,6 +3100,92 @@ Choose based on your GPU access:
 targets_app.add_typer(init_app, name="init")
 
 
+@init_app.command("local")
+def init_local(
+    name: str = typer.Option("local", "--name", "-n", help="Target name"),
+    gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
+) -> None:
+    """Initialize a local target for GPU on current machine.
+
+    Detects your local GPU and configures a target for direct execution
+    (no SSH). Use this when running wafer on the same machine as the GPU.
+
+    Examples:
+        wafer config targets init local
+        wafer config targets init local --name my-5090 --gpu-ids 0,1
+    """
+    from .targets import save_target
+
+    # Parse GPU IDs
+    try:
+        parsed_gpu_ids = [int(g.strip()) for g in gpu_ids.split(",")]
+    except ValueError:
+        typer.echo(f"Error: Invalid GPU IDs '{gpu_ids}'. Use comma-separated integers.", err=True)
+        raise typer.Exit(1) from None
+
+    typer.echo("Detecting local GPU...")
+
+    try:
+        from wafer_core.gpu_detect import (
+            detect_local_gpu,
+            get_compute_capability,
+            get_torch_requirements,
+        )
+
+        detected_gpu = detect_local_gpu()
+
+        if detected_gpu:
+            typer.echo(f"  Found: {detected_gpu.gpu_name}")
+            if detected_gpu.vendor == "nvidia":
+                typer.echo(f"  CUDA: {detected_gpu.driver_version}")
+            else:
+                typer.echo(f"  ROCm: {detected_gpu.driver_version}")
+            typer.echo(f"  GPU count: {detected_gpu.gpu_count}")
+
+            # Get torch requirements and compute capability
+            torch_reqs = get_torch_requirements(detected_gpu)
+            compute_capability = get_compute_capability(detected_gpu)
+            gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
+
+            typer.echo(f"  PyTorch: {torch_reqs.packages[0]}")
+        else:
+            typer.echo("  No GPU detected (nvidia-smi/rocm-smi not found)", err=True)
+            raise typer.Exit(1)
+
+    except ImportError as e:
+        typer.echo(f"Error: Missing dependency: {e}", err=True)
+        raise typer.Exit(1) from None
+
+    # Build target data
+    target_data = {
+        "name": name,
+        "type": "local",
+        "gpu_ids": parsed_gpu_ids,
+        "gpu_type": gpu_type,
+        "compute_capability": compute_capability,
+        "torch_package": torch_reqs.packages[0],
+        "torch_index_url": torch_reqs.index_url,
+        "vendor": detected_gpu.vendor,
+        "driver_version": detected_gpu.driver_version,
+    }
+
+    try:
+        target = save_target(target_data)
+        typer.echo(f"✓ Created target: {target.name}")
+        typer.echo("  Type: Local (no SSH)")
+        typer.echo(f"  GPU IDs: {parsed_gpu_ids}")
+        typer.echo(f"  GPU Type: {gpu_type}")
+        typer.echo(f"  Compute: {compute_capability}")
+        typer.echo(f"  Torch: {torch_reqs.packages[0]}")
+        typer.echo("")
+        typer.echo(
+            f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
+        )
+    except (ValueError, AssertionError) as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(1) from None
+
+
 @init_app.command("runpod")
 def init_runpod(
     name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
@@ -3248,23 +3349,29 @@ def init_ssh(
     host: str = typer.Option(..., "--host", "-H", help="SSH host (user@hostname:port)"),
     ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
     gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
-    gpu_type: str = typer.Option(
-        "H100", "--gpu-type", help="GPU type (H100, A100, B200, MI300X, etc.)"
+    gpu_type: str | None = typer.Option(
+        None, "--gpu-type", help="GPU type (auto-detected if not specified)"
     ),
     docker_image: str | None = typer.Option(
         None, "--docker-image", "-d", help="Docker image (optional)"
     ),
     ncu: bool = typer.Option(False, "--ncu/--no-ncu", help="NCU profiling available"),
+    no_detect: bool = typer.Option(False, "--no-detect", help="Skip GPU auto-detection"),
 ) -> None:
     """Initialize an SSH target for your own GPU hardware.
 
     Creates a target config for direct SSH access to a GPU machine.
-    Use for baremetal servers, VMs, or any machine you have SSH access to.
+    Automatically detects GPU type and selects compatible PyTorch version.
 
     Examples:
+        # Auto-detect GPU (recommended)
         wafer config targets init ssh --name my-gpu --host user@192.168.1.100:22
+
+        # Multiple GPUs with NCU profiling
         wafer config targets init ssh --name lab-h100 --host ubuntu@gpu.lab.com:22 --gpu-ids 0,1 --ncu
-        wafer config targets init ssh --name docker-gpu --host user@host:22 --docker-image nvcr.io/nvidia/pytorch:24.01-py3
+
+        # Skip detection, specify manually
+        wafer config targets init ssh --name my-gpu --host user@host:22 --gpu-type H100 --no-detect
     """
     from .targets import save_target
 
@@ -3281,17 +3388,87 @@ def init_ssh(
         typer.echo("Example: user@192.168.1.100:22", err=True)
         raise typer.Exit(1)
 
+    # Auto-detect GPU if not specified
+    detected_gpu = None
+    torch_package = None
+    torch_index_url = None
+
+    if not no_detect:
+        typer.echo(f"Connecting to {host}...")
+        try:
+            import trio
+            import trio_asyncio
+
+            from wafer_core.async_ssh import AsyncSSHClient
+            from wafer_core.gpu_detect import (
+                detect_remote_gpu,
+                get_compute_capability,
+                get_torch_requirements,
+            )
+
+            expanded_key = str(Path(ssh_key).expanduser())
+
+            async def _detect() -> None:
+                nonlocal detected_gpu, torch_package, torch_index_url
+                # Need trio_asyncio.open_loop() for asyncssh bridge
+                async with trio_asyncio.open_loop():
+                    async with AsyncSSHClient(host, expanded_key) as client:
+                        detected_gpu = await detect_remote_gpu(client)
+
+            trio.run(_detect)
+
+            if detected_gpu:
+                typer.echo(f"  Found: {detected_gpu.gpu_name}")
+                if detected_gpu.vendor == "nvidia":
+                    typer.echo(f"  CUDA: {detected_gpu.driver_version}")
+                else:
+                    typer.echo(f"  ROCm: {detected_gpu.driver_version}")
+
+                # Get torch requirements
+                torch_reqs = get_torch_requirements(detected_gpu)
+                torch_package = torch_reqs.packages[0]  # Just torch, not all packages
+                torch_index_url = torch_reqs.index_url
+                typer.echo(f"  PyTorch: {torch_package}")
+
+                # Use detected GPU type if not specified
+                if not gpu_type:
+                    # Extract GPU name (e.g., "H100" from "NVIDIA H100 80GB HBM3")
+                    gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
+            else:
+                typer.echo("  No GPU detected (nvidia-smi/rocm-smi not found)")
+                if not gpu_type:
+                    gpu_type = "H100"  # Default fallback
+                    typer.echo(f"  Using default: {gpu_type}")
+
+        except Exception as e:
+            typer.echo(f"  Detection failed: {e}", err=True)
+            if not gpu_type:
+                gpu_type = "H100"
+                typer.echo(f"  Using default: {gpu_type}")
+
+    # Fallback if no detection
+    if not gpu_type:
+        gpu_type = "H100"
+
     # Compute capability mappings
-    compute_caps = {
-        "B200": "10.0",
-        "H100": "9.0",
-        "A100": "8.0",
-        "A10": "8.6",
-        "V100": "7.0",
-        "MI300X": "9.4",
-        "MI250X": "9.0",
-    }
-    compute_capability = compute_caps.get(gpu_type, "8.0")
+    if detected_gpu:
+        from wafer_core.gpu_detect import get_compute_capability
+
+        compute_capability = get_compute_capability(detected_gpu)
+    else:
+        compute_caps = {
+            "B200": "10.0",
+            "H100": "9.0",
+            "A100": "8.0",
+            "A10": "8.6",
+            "V100": "7.0",
+            "MI300X": "9.4",
+            "MI250X": "9.0",
+            "RTX 5090": "10.0",
+            "RTX 4090": "8.9",
+            "RTX 3090": "8.6",
+        }
+        compute_capability = compute_caps.get(gpu_type, "8.0")
 
     # Build target data
     target_data = {
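Editor's note: the _detect helper above is the same trio/asyncio bridge the evaluate path uses; asyncssh is asyncio-based, so it has to run inside trio_asyncio.open_loop(). A self-contained illustration of the pattern, for readers unfamiliar with it (some_asyncio_task is made up; open_loop and aio_as_trio are real trio_asyncio entry points):

    # Standalone illustration of hosting asyncio code inside trio
    import asyncio

    import trio
    import trio_asyncio


    async def some_asyncio_task() -> str:
        await asyncio.sleep(0.1)  # plain asyncio code
        return "done"


    async def main() -> None:
        # open_loop() runs an asyncio event loop inside the trio run
        async with trio_asyncio.open_loop():
            # aio_as_trio adapts an asyncio coroutine function for trio
            result = await trio_asyncio.aio_as_trio(some_asyncio_task)()
            print(result)


    trio.run(main)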
@@ -3308,6 +3485,12 @@ def init_ssh(
     if docker_image:
         target_data["docker_image"] = docker_image
 
+    # Add torch requirements if detected
+    if torch_package:
+        target_data["torch_package"] = torch_package
+    if torch_index_url:
+        target_data["torch_index_url"] = torch_index_url
+
     try:
         target = save_target(target_data)
         typer.echo(f"✓ Created target: {target.name}")
@@ -3315,9 +3498,12 @@ def init_ssh(
         typer.echo(f"  Host: {host}")
         typer.echo(f"  GPU IDs: {parsed_gpu_ids}")
         typer.echo(f"  GPU Type: {gpu_type}")
+        typer.echo(f"  Compute: {compute_capability}")
         typer.echo(f"  NCU: {'Yes' if ncu else 'No'}")
         if docker_image:
             typer.echo(f"  Docker: {docker_image}")
+        if torch_package:
+            typer.echo(f"  Torch: {torch_package}")
         typer.echo("")
         typer.echo(
             f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
@@ -3327,6 +3513,31 @@ def init_ssh(
         raise typer.Exit(1) from None
 
 
+def _extract_gpu_type(gpu_name: str) -> str:
+    """Extract GPU type from full GPU name.
+
+    Examples:
+        "NVIDIA H100 80GB HBM3" -> "H100"
+        "NVIDIA GeForce RTX 4090" -> "RTX 4090"
+        "AMD Instinct MI300X OAM" -> "MI300X"
+    """
+    gpu_name_upper = gpu_name.upper()
+
+    # Check for known GPU types
+    known_types = [
+        "B200", "B100", "H200", "H100", "A100", "A10", "V100",
+        "RTX 5090", "RTX 5080", "RTX 4090", "RTX 4080", "RTX 3090", "RTX 3080",
+        "MI300X", "MI250X", "MI100",
+    ]
+
+    for gpu_type in known_types:
+        if gpu_type in gpu_name_upper:
+            return gpu_type
+
+    # Fallback: return cleaned name
+    return gpu_name.replace("NVIDIA ", "").replace("AMD ", "").strip()
+
+
 @targets_app.command("add")
 def targets_add(
     file_path: Path = typer.Argument(..., help="Path to target TOML file"),
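Editor's note: _extract_gpu_type matches by substring against the upper-cased name in list order, which is why "A100" is listed before "A10". A quick sanity check of the documented behavior (the first three cases come from the docstring; the TITAN V case is derived from the fallback branch):

    assert _extract_gpu_type("NVIDIA H100 80GB HBM3") == "H100"
    assert _extract_gpu_type("NVIDIA GeForce RTX 4090") == "RTX 4090"
    assert _extract_gpu_type("AMD Instinct MI300X OAM") == "MI300X"
    # Unrecognized names fall back to the vendor-stripped string
    assert _extract_gpu_type("NVIDIA TITAN V") == "TITAN V"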
@@ -3539,6 +3750,92 @@ def targets_pods() -> None:
     typer.echo()
 
 
+# ── Pool commands ───────────────────────────────────────────────────────────
+
+
+@targets_app.command("pool-list")
+def targets_pool_list() -> None:
+    """List all configured target pools.
+
+    Example:
+        wafer config targets pool-list
+    """
+    from .targets import get_pool, list_pools
+
+    pools = list_pools()
+
+    if not pools:
+        typer.echo("No pools configured")
+        typer.echo("")
+        typer.echo("Define pools in ~/.wafer/config.toml:")
+        typer.echo("  [pools.my-pool]")
+        typer.echo('  targets = ["target-1", "target-2"]')
+        return
+
+    typer.echo("Configured pools:\n")
+    for pool_name in pools:
+        try:
+            targets = get_pool(pool_name)
+            typer.echo(f"  {pool_name}: {', '.join(targets)}")
+        except Exception as e:
+            typer.echo(f"  {pool_name}: (error: {e})")
+
+
+@targets_app.command("pool-create")
+def targets_pool_create(
+    name: str = typer.Argument(..., help="Pool name"),
+    targets: list[str] = typer.Argument(..., help="Target names to include in pool"),
+) -> None:
+    """Create or update a target pool.
+
+    Example:
+        wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2 mi300x-3
+    """
+    from .targets import save_pool
+
+    try:
+        save_pool(name, targets)
+        typer.echo(f"Pool '{name}' created with {len(targets)} targets")
+    except FileNotFoundError as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(1) from None
+
+
+@targets_app.command("pool-status")
+def targets_pool_status(
+    name: str = typer.Argument(..., help="Pool name"),
+) -> None:
+    """Show status of targets in a pool (locked/available).
+
+    Example:
+        wafer config targets pool-status mi300x-pool
+    """
+    from .target_lock import get_lock_holder, is_target_locked
+    from .targets import get_pool
+
+    try:
+        targets = get_pool(name)
+    except FileNotFoundError as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(1) from None
+
+    typer.echo(f"Pool '{name}' ({len(targets)} targets):\n")
+
+    available = 0
+    for target_name in targets:
+        locked = is_target_locked(target_name)
+        if locked:
+            pid = get_lock_holder(target_name)
+            pid_str = f" (pid {pid})" if pid else ""
+            typer.echo(f"  [busy] {target_name}{pid_str}")
+        else:
+            typer.echo(f"  [free] {target_name}")
+            available += 1
+
+    typer.echo("")
+    typer.echo(f"Available: {available}/{len(targets)}")
+
+
 # =============================================================================
 # Billing commands
 # =============================================================================
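Editor's note: putting the 0.2.6 pool feature together, a typical workflow built from the commands added in this release would be (target names are placeholders):

    # Define a pool from existing targets
    wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2

    # Inspect lock state
    wafer config targets pool-status mi300x-pool

    # Let the evaluator grab the first free target
    wafer evaluate gpumode --pool mi300x-pool --impl kernel.py --reference ref.py --test-cases tests.json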