wafer-cli 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/cli.py +862 -104
- wafer/evaluate.py +1423 -158
- wafer/gpu_run.py +5 -1
- wafer/problems.py +357 -0
- wafer/target_lock.py +198 -0
- wafer/targets.py +158 -0
- wafer/wevin_cli.py +22 -2
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/RECORD +12 -10
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/WHEEL +1 -1
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
|
@@ -30,6 +30,14 @@ import typer
|
|
|
30
30
|
|
|
31
31
|
from .config import WaferConfig, WaferEnvironment
|
|
32
32
|
from .inference import infer_upload_files, resolve_environment
|
|
33
|
+
from .problems import (
|
|
34
|
+
download_problems,
|
|
35
|
+
get_problem_path,
|
|
36
|
+
get_problems_path,
|
|
37
|
+
)
|
|
38
|
+
from .problems import (
|
|
39
|
+
list_problems as list_problems_fn,
|
|
40
|
+
)
|
|
33
41
|
|
|
34
42
|
app = typer.Typer(
|
|
35
43
|
help="GPU development toolkit for LLM coding agents",
|
|
@@ -91,11 +99,15 @@ def main_callback(ctx: typer.Context) -> None:
|
|
|
91
99
|
# Install exception hook to catch SystemExit and mark failures
|
|
92
100
|
original_excepthook = sys.excepthook
|
|
93
101
|
|
|
94
|
-
def custom_excepthook(
|
|
102
|
+
def custom_excepthook(
|
|
103
|
+
exc_type: type[BaseException],
|
|
104
|
+
exc_value: BaseException,
|
|
105
|
+
exc_traceback: object,
|
|
106
|
+
) -> None:
|
|
95
107
|
global _command_outcome
|
|
96
108
|
# Mark as failure if SystemExit with non-zero code, or any other exception
|
|
97
109
|
if exc_type is SystemExit:
|
|
98
|
-
exit_code = exc_value.code if hasattr(exc_value,
|
|
110
|
+
exit_code = exc_value.code if hasattr(exc_value, "code") else 1
|
|
99
111
|
if exit_code != 0 and exit_code is not None:
|
|
100
112
|
_command_outcome = "failure"
|
|
101
113
|
else:
|
|
@@ -200,6 +212,13 @@ kernelbench_app = typer.Typer(
|
|
|
200
212
|
)
|
|
201
213
|
evaluate_app.add_typer(kernelbench_app, name="kernelbench")
|
|
202
214
|
|
|
215
|
+
# Nested subcommand for gpumode format
|
|
216
|
+
gpumode_app = typer.Typer(
|
|
217
|
+
help="Evaluate kernels in GPUMode format (custom_kernel/ref_kernel functions)",
|
|
218
|
+
invoke_without_command=True,
|
|
219
|
+
)
|
|
220
|
+
evaluate_app.add_typer(gpumode_app, name="gpumode")
|
|
221
|
+
|
|
203
222
|
# =============================================================================
|
|
204
223
|
# Dev commands (internal, used by web app proxy)
|
|
205
224
|
# =============================================================================
|
|
@@ -396,6 +415,122 @@ def skill_status() -> None:
|
|
|
396
415
|
typer.echo(f"{tool_name}: Not installed")
|
|
397
416
|
|
|
398
417
|
|
|
418
|
+
# =============================================================================
|
|
419
|
+
# Provider auth management (wafer auth ...)
|
|
420
|
+
# =============================================================================
|
|
421
|
+
|
|
422
|
+
provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
|
|
423
|
+
app.add_typer(provider_auth_app, name="auth")
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
@provider_auth_app.command("login")
|
|
427
|
+
def provider_auth_login(
|
|
428
|
+
provider: str = typer.Argument(
|
|
429
|
+
...,
|
|
430
|
+
help="Provider name: runpod, digitalocean, or modal",
|
|
431
|
+
),
|
|
432
|
+
api_key: str | None = typer.Option(
|
|
433
|
+
None,
|
|
434
|
+
"--api-key",
|
|
435
|
+
"-k",
|
|
436
|
+
help="API key (if not provided, reads from stdin)",
|
|
437
|
+
),
|
|
438
|
+
) -> None:
|
|
439
|
+
"""Save API key for a cloud GPU provider.
|
|
440
|
+
|
|
441
|
+
Stores the key in ~/.wafer/auth.json. Environment variables
|
|
442
|
+
(e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
|
|
443
|
+
|
|
444
|
+
Examples:
|
|
445
|
+
wafer auth login runpod --api-key rp_xxx
|
|
446
|
+
wafer auth login digitalocean --api-key dop_v1_xxx
|
|
447
|
+
echo $API_KEY | wafer auth login runpod
|
|
448
|
+
"""
|
|
449
|
+
import sys
|
|
450
|
+
|
|
451
|
+
from wafer_core.auth import PROVIDERS, save_api_key
|
|
452
|
+
|
|
453
|
+
# Validate provider
|
|
454
|
+
if provider not in PROVIDERS:
|
|
455
|
+
typer.echo(f"Error: Unknown provider '{provider}'", err=True)
|
|
456
|
+
typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
|
|
457
|
+
raise typer.Exit(1)
|
|
458
|
+
|
|
459
|
+
# Get API key from option or stdin
|
|
460
|
+
if api_key is None:
|
|
461
|
+
if sys.stdin.isatty():
|
|
462
|
+
typer.echo(f"Enter API key for {PROVIDERS[provider]['display_name']}:")
|
|
463
|
+
api_key = typer.prompt("API key", hide_input=True)
|
|
464
|
+
else:
|
|
465
|
+
api_key = sys.stdin.read().strip()
|
|
466
|
+
|
|
467
|
+
if not api_key:
|
|
468
|
+
typer.echo("Error: No API key provided", err=True)
|
|
469
|
+
raise typer.Exit(1)
|
|
470
|
+
|
|
471
|
+
# Save the key
|
|
472
|
+
save_api_key(provider, api_key)
|
|
473
|
+
typer.echo(f"API key saved for {PROVIDERS[provider]['display_name']}")
|
|
474
|
+
typer.echo("Stored in: ~/.wafer/auth.json")
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
@provider_auth_app.command("logout")
|
|
478
|
+
def provider_auth_logout(
|
|
479
|
+
provider: str = typer.Argument(
|
|
480
|
+
...,
|
|
481
|
+
help="Provider name: runpod, digitalocean, or modal",
|
|
482
|
+
),
|
|
483
|
+
) -> None:
|
|
484
|
+
"""Remove stored API key for a cloud GPU provider.
|
|
485
|
+
|
|
486
|
+
Examples:
|
|
487
|
+
wafer auth logout runpod
|
|
488
|
+
wafer auth logout digitalocean
|
|
489
|
+
"""
|
|
490
|
+
from wafer_core.auth import PROVIDERS, remove_api_key
|
|
491
|
+
|
|
492
|
+
# Validate provider
|
|
493
|
+
if provider not in PROVIDERS:
|
|
494
|
+
typer.echo(f"Error: Unknown provider '{provider}'", err=True)
|
|
495
|
+
typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
|
|
496
|
+
raise typer.Exit(1)
|
|
497
|
+
|
|
498
|
+
if remove_api_key(provider):
|
|
499
|
+
typer.echo(f"API key removed for {PROVIDERS[provider]['display_name']}")
|
|
500
|
+
else:
|
|
501
|
+
typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
@provider_auth_app.command("status")
|
|
505
|
+
def provider_auth_status() -> None:
|
|
506
|
+
"""Show authentication status for all cloud GPU providers.
|
|
507
|
+
|
|
508
|
+
Displays which providers have API keys configured and where
|
|
509
|
+
the keys are coming from (environment variable or auth.json).
|
|
510
|
+
|
|
511
|
+
Example:
|
|
512
|
+
wafer auth status
|
|
513
|
+
"""
|
|
514
|
+
from wafer_core.auth import get_all_auth_status
|
|
515
|
+
|
|
516
|
+
statuses = get_all_auth_status()
|
|
517
|
+
|
|
518
|
+
typer.echo("Cloud GPU Provider Authentication Status")
|
|
519
|
+
typer.echo("=" * 45)
|
|
520
|
+
|
|
521
|
+
for status in statuses:
|
|
522
|
+
if status.is_authenticated:
|
|
523
|
+
source_str = f"({status.source})" if status.source else ""
|
|
524
|
+
typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
|
|
525
|
+
else:
|
|
526
|
+
typer.echo(f" {status.display_name}: ✗ Not configured")
|
|
527
|
+
typer.echo(f" Run: wafer auth login {status.provider}")
|
|
528
|
+
typer.echo(f" Or set: {status.key_url}")
|
|
529
|
+
|
|
530
|
+
typer.echo("")
|
|
531
|
+
typer.echo("Note: Environment variables take precedence over stored keys.")
|
|
532
|
+
|
|
533
|
+
|
|
399
534
|
@app.command(hidden=True)
|
|
400
535
|
def run(
|
|
401
536
|
command: str = typer.Argument(..., help="Command to run in Docker container"),
|
|
@@ -1289,86 +1424,27 @@ def evaluate( # noqa: PLR0913
|
|
|
1289
1424
|
--benchmark --defensive
|
|
1290
1425
|
|
|
1291
1426
|
Subcommands:
|
|
1292
|
-
|
|
1427
|
+
gpumode Use GPUMode format (functional) - RECOMMENDED
|
|
1293
1428
|
kernelbench Use KernelBench format (ModelNew class)
|
|
1429
|
+
make-template Generate template files for this format (deprecated)
|
|
1294
1430
|
"""
|
|
1295
1431
|
# If a subcommand is being invoked, skip the main evaluation logic
|
|
1296
1432
|
if ctx.invoked_subcommand is not None:
|
|
1297
1433
|
return
|
|
1298
1434
|
|
|
1299
|
-
#
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
typer.echo(
|
|
1313
|
-
"Usage: wafer evaluate --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
|
|
1314
|
-
err=True,
|
|
1315
|
-
)
|
|
1316
|
-
typer.echo("", err=True)
|
|
1317
|
-
typer.echo("Run 'wafer evaluate --help' for full options.", err=True)
|
|
1318
|
-
typer.echo("Run 'wafer evaluate make-template DIR' to generate starter files.", err=True)
|
|
1319
|
-
raise typer.Exit(1)
|
|
1320
|
-
|
|
1321
|
-
from .evaluate import EvaluateArgs, run_evaluate
|
|
1322
|
-
|
|
1323
|
-
args = EvaluateArgs(
|
|
1324
|
-
implementation=implementation,
|
|
1325
|
-
reference=reference,
|
|
1326
|
-
test_cases=test_cases,
|
|
1327
|
-
target_name=target or "",
|
|
1328
|
-
benchmark=benchmark,
|
|
1329
|
-
profile=profile,
|
|
1330
|
-
defensive=defensive,
|
|
1331
|
-
sync_artifacts=sync_artifacts,
|
|
1332
|
-
gpu_id=gpu_id,
|
|
1333
|
-
)
|
|
1334
|
-
|
|
1335
|
-
try:
|
|
1336
|
-
# Use trio_asyncio to run async code that uses both trio and asyncio
|
|
1337
|
-
# (AsyncSSHClient uses asyncssh which is asyncio-based, bridged via trio_asyncio)
|
|
1338
|
-
import trio_asyncio
|
|
1339
|
-
|
|
1340
|
-
result = trio_asyncio.run(run_evaluate, args)
|
|
1341
|
-
except KeyboardInterrupt:
|
|
1342
|
-
typer.echo("\nInterrupted by user", err=True)
|
|
1343
|
-
raise typer.Exit(130) from None
|
|
1344
|
-
except Exception as e:
|
|
1345
|
-
# Unwrap ExceptionGroup (from Trio nurseries) to show actual error
|
|
1346
|
-
if hasattr(e, "exceptions") and e.exceptions:
|
|
1347
|
-
for exc in e.exceptions:
|
|
1348
|
-
typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
|
|
1349
|
-
else:
|
|
1350
|
-
typer.echo(f"Error: {e}", err=True)
|
|
1351
|
-
raise typer.Exit(1) from None
|
|
1352
|
-
|
|
1353
|
-
# Print results
|
|
1354
|
-
if result.success:
|
|
1355
|
-
typer.echo("")
|
|
1356
|
-
typer.echo("=" * 60)
|
|
1357
|
-
status = "PASS" if result.all_correct else "FAIL"
|
|
1358
|
-
typer.echo(f"Result: {status}")
|
|
1359
|
-
score_pct = f"{result.correctness_score:.1%}"
|
|
1360
|
-
typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
|
|
1361
|
-
if result.geomean_speedup > 0:
|
|
1362
|
-
typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
|
|
1363
|
-
if result.artifact_path:
|
|
1364
|
-
typer.echo(f"Artifacts: {result.artifact_path}")
|
|
1365
|
-
typer.echo("=" * 60)
|
|
1366
|
-
|
|
1367
|
-
if not result.all_correct:
|
|
1368
|
-
raise typer.Exit(1)
|
|
1369
|
-
else:
|
|
1370
|
-
typer.echo(f"Error: {result.error_message}", err=True)
|
|
1371
|
-
raise typer.Exit(1)
|
|
1435
|
+
# Bare 'wafer evaluate' is no longer supported - must use subcommand
|
|
1436
|
+
typer.echo("Error: 'wafer evaluate' requires a subcommand.", err=True)
|
|
1437
|
+
typer.echo("", err=True)
|
|
1438
|
+
typer.echo("Available subcommands:", err=True)
|
|
1439
|
+
typer.echo(" gpumode Evaluate GPUMode format (custom_kernel/ref_kernel functions)", err=True)
|
|
1440
|
+
typer.echo(" kernelbench Evaluate KernelBench format (ModelNew class)", err=True)
|
|
1441
|
+
typer.echo("", err=True)
|
|
1442
|
+
typer.echo("Examples:", err=True)
|
|
1443
|
+
typer.echo(" wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json", err=True)
|
|
1444
|
+
typer.echo(" wafer evaluate kernelbench --impl impl.py --reference ref.py --benchmark", err=True)
|
|
1445
|
+
typer.echo("", err=True)
|
|
1446
|
+
typer.echo("Run 'wafer evaluate gpumode --help' or 'wafer evaluate kernelbench --help' for options.", err=True)
|
|
1447
|
+
raise typer.Exit(1)
|
|
1372
1448
|
|
|
1373
1449
|
|
|
1374
1450
|
TEMPLATE_KERNEL = '''\
|
|
@@ -1503,8 +1579,59 @@ def evaluate_make_template(
|
|
|
1503
1579
|
# KernelBench format evaluation
|
|
1504
1580
|
# =============================================================================
|
|
1505
1581
|
|
|
1506
|
-
|
|
1507
|
-
|
|
1582
|
+
|
|
1583
|
+
def _get_kernelbench_root() -> Path | None:
|
|
1584
|
+
"""Get KernelBench problems root, preferring downloaded location."""
|
|
1585
|
+
# First check downloaded location
|
|
1586
|
+
downloaded = get_problems_path("kernelbench")
|
|
1587
|
+
if downloaded is not None:
|
|
1588
|
+
kb_root = downloaded / "KernelBench"
|
|
1589
|
+
if kb_root.exists():
|
|
1590
|
+
return kb_root
|
|
1591
|
+
return downloaded
|
|
1592
|
+
|
|
1593
|
+
# Fall back to legacy location (for development)
|
|
1594
|
+
legacy = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench" / "KernelBench"
|
|
1595
|
+
if legacy.exists():
|
|
1596
|
+
return legacy
|
|
1597
|
+
|
|
1598
|
+
return None
|
|
1599
|
+
|
|
1600
|
+
|
|
1601
|
+
@kernelbench_app.command("download")
|
|
1602
|
+
def kernelbench_download(
|
|
1603
|
+
force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
|
|
1604
|
+
) -> None:
|
|
1605
|
+
"""Download KernelBench problems from GitHub.
|
|
1606
|
+
|
|
1607
|
+
Downloads the problem set to ~/.cache/wafer/problems/kernelbench/
|
|
1608
|
+
|
|
1609
|
+
Examples:
|
|
1610
|
+
wafer evaluate kernelbench download
|
|
1611
|
+
wafer evaluate kernelbench download --force # Re-download
|
|
1612
|
+
"""
|
|
1613
|
+
try:
|
|
1614
|
+
path = download_problems("kernelbench", force=force, verbose=True)
|
|
1615
|
+
typer.echo("")
|
|
1616
|
+
typer.echo(f"Problems available at: {path}")
|
|
1617
|
+
typer.echo("Run 'wafer evaluate kernelbench list-problems' to see available problems.")
|
|
1618
|
+
except Exception as e:
|
|
1619
|
+
typer.echo(f"Error downloading problems: {e}", err=True)
|
|
1620
|
+
raise typer.Exit(1) from None
|
|
1621
|
+
|
|
1622
|
+
|
|
1623
|
+
@kernelbench_app.command("list-problems")
|
|
1624
|
+
def kernelbench_list_problems() -> None:
|
|
1625
|
+
"""List available KernelBench problems.
|
|
1626
|
+
|
|
1627
|
+
Examples:
|
|
1628
|
+
wafer evaluate kernelbench list-problems
|
|
1629
|
+
"""
|
|
1630
|
+
try:
|
|
1631
|
+
list_problems_fn("kernelbench", verbose=True)
|
|
1632
|
+
except ValueError as e:
|
|
1633
|
+
typer.echo(str(e), err=True)
|
|
1634
|
+
raise typer.Exit(1) from None
|
|
1508
1635
|
|
|
1509
1636
|
|
|
1510
1637
|
@kernelbench_app.callback(invoke_without_command=True)
|
|
@@ -1528,8 +1655,19 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1528
1655
|
help="GPU target name. See 'wafer config targets list' for available targets.",
|
|
1529
1656
|
autocompletion=complete_target_name,
|
|
1530
1657
|
),
|
|
1658
|
+
pool: str | None = typer.Option(
|
|
1659
|
+
None,
|
|
1660
|
+
"--pool",
|
|
1661
|
+
"-p",
|
|
1662
|
+
help="Target pool name. Acquires first available target from the pool. "
|
|
1663
|
+
"Define pools in ~/.wafer/config.toml under [pools.<name>].",
|
|
1664
|
+
),
|
|
1531
1665
|
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
1532
1666
|
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
1667
|
+
inputs: Path | None = typer.Option(
|
|
1668
|
+
None, "--inputs", help="Custom inputs file to override get_inputs()"
|
|
1669
|
+
),
|
|
1670
|
+
seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
|
|
1533
1671
|
defensive: bool = typer.Option(
|
|
1534
1672
|
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
1535
1673
|
),
|
|
@@ -1586,14 +1724,47 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1586
1724
|
)
|
|
1587
1725
|
raise typer.Exit(1)
|
|
1588
1726
|
|
|
1727
|
+
# Validate --target and --pool are mutually exclusive
|
|
1728
|
+
if target and pool:
|
|
1729
|
+
typer.echo("Error: Cannot specify both --target and --pool", err=True)
|
|
1730
|
+
raise typer.Exit(1)
|
|
1731
|
+
|
|
1589
1732
|
from .evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench
|
|
1590
1733
|
|
|
1734
|
+
# If pool specified, acquire a target from the pool
|
|
1735
|
+
resolved_target = target or ""
|
|
1736
|
+
pool_lock_context = None
|
|
1737
|
+
|
|
1738
|
+
if pool:
|
|
1739
|
+
from .target_lock import acquire_from_pool
|
|
1740
|
+
from .targets import get_pool
|
|
1741
|
+
|
|
1742
|
+
try:
|
|
1743
|
+
pool_targets = get_pool(pool)
|
|
1744
|
+
except FileNotFoundError as e:
|
|
1745
|
+
typer.echo(f"Error: {e}", err=True)
|
|
1746
|
+
raise typer.Exit(1) from None
|
|
1747
|
+
|
|
1748
|
+
typer.echo(f"Acquiring target from pool '{pool}' ({len(pool_targets)} targets)...")
|
|
1749
|
+
pool_lock_context = acquire_from_pool(pool_targets)
|
|
1750
|
+
acquired_target = pool_lock_context.__enter__()
|
|
1751
|
+
|
|
1752
|
+
if acquired_target is None:
|
|
1753
|
+
typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
|
|
1754
|
+
typer.echo(f" Targets: {', '.join(pool_targets)}", err=True)
|
|
1755
|
+
raise typer.Exit(1)
|
|
1756
|
+
|
|
1757
|
+
typer.echo(f"Acquired target: {acquired_target}")
|
|
1758
|
+
resolved_target = acquired_target
|
|
1759
|
+
|
|
1591
1760
|
args = KernelBenchEvaluateArgs(
|
|
1592
1761
|
implementation=implementation,
|
|
1593
1762
|
reference=reference,
|
|
1594
|
-
target_name=
|
|
1763
|
+
target_name=resolved_target,
|
|
1595
1764
|
benchmark=benchmark,
|
|
1596
1765
|
profile=profile,
|
|
1766
|
+
inputs=inputs,
|
|
1767
|
+
seed=seed,
|
|
1597
1768
|
defensive=defensive,
|
|
1598
1769
|
sync_artifacts=sync_artifacts,
|
|
1599
1770
|
gpu_id=gpu_id,
|
|
@@ -1609,6 +1780,10 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1609
1780
|
except Exception as e:
|
|
1610
1781
|
typer.echo(f"Error: {e}", err=True)
|
|
1611
1782
|
raise typer.Exit(1) from None
|
|
1783
|
+
finally:
|
|
1784
|
+
# Release pool lock if we acquired one
|
|
1785
|
+
if pool_lock_context is not None:
|
|
1786
|
+
pool_lock_context.__exit__(None, None, None)
|
|
1612
1787
|
|
|
1613
1788
|
# Print results
|
|
1614
1789
|
if result.success:
|
|
@@ -1655,6 +1830,13 @@ def kernelbench_make_template(
|
|
|
1655
1830
|
# Overwrite existing
|
|
1656
1831
|
wafer evaluate kernelbench make-template level1/1 --force
|
|
1657
1832
|
"""
|
|
1833
|
+
# Get problems root (downloaded or legacy)
|
|
1834
|
+
kb_root = _get_kernelbench_root()
|
|
1835
|
+
if kb_root is None:
|
|
1836
|
+
typer.echo("Error: KernelBench problems not found.", err=True)
|
|
1837
|
+
typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
|
|
1838
|
+
raise typer.Exit(1)
|
|
1839
|
+
|
|
1658
1840
|
# Parse problem ID
|
|
1659
1841
|
parts = problem.split("/")
|
|
1660
1842
|
if len(parts) != 2:
|
|
@@ -1666,10 +1848,10 @@ def kernelbench_make_template(
|
|
|
1666
1848
|
level_str = f"level{level_str}"
|
|
1667
1849
|
|
|
1668
1850
|
# Find the problem file
|
|
1669
|
-
problem_dir =
|
|
1851
|
+
problem_dir = kb_root / level_str
|
|
1670
1852
|
if not problem_dir.exists():
|
|
1671
1853
|
typer.echo(f"Error: KernelBench level directory not found: {problem_dir}", err=True)
|
|
1672
|
-
typer.echo(
|
|
1854
|
+
typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
|
|
1673
1855
|
raise typer.Exit(1)
|
|
1674
1856
|
|
|
1675
1857
|
# Find matching problem file
|
|
@@ -1736,6 +1918,293 @@ def kernelbench_make_template(
|
|
|
1736
1918
|
typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
|
|
1737
1919
|
|
|
1738
1920
|
|
|
1921
|
+
# =============================================================================
|
|
1922
|
+
# GPUMode format evaluation
|
|
1923
|
+
# =============================================================================
|
|
1924
|
+
|
|
1925
|
+
|
|
1926
|
+
@gpumode_app.command("download")
|
|
1927
|
+
def gpumode_download(
|
|
1928
|
+
force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
|
|
1929
|
+
) -> None:
|
|
1930
|
+
"""Download GPUMode reference kernels from GitHub.
|
|
1931
|
+
|
|
1932
|
+
Downloads the problem set to ~/.cache/wafer/problems/gpumode/
|
|
1933
|
+
|
|
1934
|
+
Examples:
|
|
1935
|
+
wafer evaluate gpumode download
|
|
1936
|
+
wafer evaluate gpumode download --force # Re-download
|
|
1937
|
+
"""
|
|
1938
|
+
try:
|
|
1939
|
+
path = download_problems("gpumode", force=force, verbose=True)
|
|
1940
|
+
typer.echo("")
|
|
1941
|
+
typer.echo(f"Problems available at: {path}")
|
|
1942
|
+
typer.echo("Run 'wafer evaluate gpumode list-problems' to see available problems.")
|
|
1943
|
+
except Exception as e:
|
|
1944
|
+
typer.echo(f"Error downloading problems: {e}", err=True)
|
|
1945
|
+
raise typer.Exit(1) from None
|
|
1946
|
+
|
|
1947
|
+
|
|
1948
|
+
@gpumode_app.command("list-problems")
|
|
1949
|
+
def gpumode_list_problems() -> None:
|
|
1950
|
+
"""List available GPUMode problems.
|
|
1951
|
+
|
|
1952
|
+
Examples:
|
|
1953
|
+
wafer evaluate gpumode list-problems
|
|
1954
|
+
"""
|
|
1955
|
+
try:
|
|
1956
|
+
list_problems_fn("gpumode", verbose=True)
|
|
1957
|
+
except ValueError as e:
|
|
1958
|
+
typer.echo(str(e), err=True)
|
|
1959
|
+
raise typer.Exit(1) from None
|
|
1960
|
+
|
|
1961
|
+
|
|
1962
|
+
@gpumode_app.command("make-template")
|
|
1963
|
+
def gpumode_make_template(
|
|
1964
|
+
problem: str = typer.Option(
|
|
1965
|
+
...,
|
|
1966
|
+
"--problem",
|
|
1967
|
+
"-p",
|
|
1968
|
+
help="Problem ID (e.g., 'pmpp/vectoradd_py' or 'amd/fp8-mm')",
|
|
1969
|
+
),
|
|
1970
|
+
output: Path = typer.Option(
|
|
1971
|
+
None, "--output", "-o", help="Output directory (default: ./<problem_name>/)"
|
|
1972
|
+
),
|
|
1973
|
+
force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
|
|
1974
|
+
) -> None:
|
|
1975
|
+
"""Extract a GPUMode problem as template files.
|
|
1976
|
+
|
|
1977
|
+
Creates a directory with reference.py, task.yml, and other problem files.
|
|
1978
|
+
You then create kernel.py with your custom_kernel implementation.
|
|
1979
|
+
|
|
1980
|
+
Examples:
|
|
1981
|
+
# Extract pmpp vectoradd problem
|
|
1982
|
+
wafer evaluate gpumode make-template --problem pmpp/vectoradd_py
|
|
1983
|
+
|
|
1984
|
+
# Extract to specific directory
|
|
1985
|
+
wafer evaluate gpumode make-template --problem pmpp/vectoradd_py --output ./my-kernel/
|
|
1986
|
+
"""
|
|
1987
|
+
import shutil
|
|
1988
|
+
|
|
1989
|
+
# Get problem path
|
|
1990
|
+
problem_path = get_problem_path("gpumode", problem)
|
|
1991
|
+
if problem_path is None:
|
|
1992
|
+
# Check if problems are downloaded
|
|
1993
|
+
if get_problems_path("gpumode") is None:
|
|
1994
|
+
typer.echo("Error: GPUMode problems not downloaded.", err=True)
|
|
1995
|
+
typer.echo("Run 'wafer evaluate gpumode download' first.", err=True)
|
|
1996
|
+
else:
|
|
1997
|
+
typer.echo(f"Error: Problem '{problem}' not found.", err=True)
|
|
1998
|
+
typer.echo(
|
|
1999
|
+
"Run 'wafer evaluate gpumode list-problems' to see available problems.", err=True
|
|
2000
|
+
)
|
|
2001
|
+
raise typer.Exit(1)
|
|
2002
|
+
|
|
2003
|
+
# Determine output path
|
|
2004
|
+
if output is None:
|
|
2005
|
+
output = Path.cwd() / problem.replace("/", "_")
|
|
2006
|
+
|
|
2007
|
+
output = output.resolve()
|
|
2008
|
+
|
|
2009
|
+
# Check if exists
|
|
2010
|
+
if output.exists() and not force:
|
|
2011
|
+
typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
|
|
2012
|
+
raise typer.Exit(1)
|
|
2013
|
+
|
|
2014
|
+
# Copy the problem directory
|
|
2015
|
+
if output.exists():
|
|
2016
|
+
shutil.rmtree(output)
|
|
2017
|
+
shutil.copytree(problem_path, output)
|
|
2018
|
+
|
|
2019
|
+
typer.echo(f"Created {output}/")
|
|
2020
|
+
typer.echo("")
|
|
2021
|
+
typer.echo("Contents:")
|
|
2022
|
+
for f in sorted(output.iterdir()):
|
|
2023
|
+
if not f.name.startswith("."):
|
|
2024
|
+
typer.echo(f" {f.name}")
|
|
2025
|
+
typer.echo("")
|
|
2026
|
+
typer.echo("Next steps:")
|
|
2027
|
+
typer.echo(" 1. Read reference.py to understand the kernel interface")
|
|
2028
|
+
typer.echo(" 2. Create kernel.py with your custom_kernel implementation:")
|
|
2029
|
+
typer.echo("")
|
|
2030
|
+
typer.echo(" def custom_kernel(data):")
|
|
2031
|
+
typer.echo(" # Your optimized implementation")
|
|
2032
|
+
typer.echo(" ...")
|
|
2033
|
+
typer.echo("")
|
|
2034
|
+
typer.echo(" 3. Run evaluation:")
|
|
2035
|
+
typer.echo(
|
|
2036
|
+
f" wafer evaluate gpumode --impl {output}/kernel.py --reference {output}/reference.py \\"
|
|
2037
|
+
)
|
|
2038
|
+
typer.echo(f" --test-cases {output}/test_cases.json --target <target>")
|
|
2039
|
+
|
|
2040
|
+
|
|
2041
|
+
@gpumode_app.callback(invoke_without_command=True)
|
|
2042
|
+
def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
2043
|
+
ctx: typer.Context,
|
|
2044
|
+
implementation: Path | None = typer.Option(
|
|
2045
|
+
None, "--impl", "-i", help="Path to implementation kernel file"
|
|
2046
|
+
),
|
|
2047
|
+
reference: Path | None = typer.Option(
|
|
2048
|
+
None, "--reference", help="Path to reference kernel file"
|
|
2049
|
+
),
|
|
2050
|
+
test_cases: Path | None = typer.Option(
|
|
2051
|
+
None, "--test-cases", help="Path to test cases JSON file"
|
|
2052
|
+
),
|
|
2053
|
+
target: str | None = typer.Option(
|
|
2054
|
+
None,
|
|
2055
|
+
"--target",
|
|
2056
|
+
"-t",
|
|
2057
|
+
help="GPU target name. See 'wafer config targets list' for available targets.",
|
|
2058
|
+
autocompletion=complete_target_name,
|
|
2059
|
+
),
|
|
2060
|
+
pool: str | None = typer.Option(
|
|
2061
|
+
None,
|
|
2062
|
+
"--pool",
|
|
2063
|
+
"-p",
|
|
2064
|
+
help="Target pool name. Acquires first available target from the pool. "
|
|
2065
|
+
"Define pools in ~/.wafer/config.toml under [pools.<name>].",
|
|
2066
|
+
),
|
|
2067
|
+
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
2068
|
+
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
2069
|
+
defensive: bool = typer.Option(
|
|
2070
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
2071
|
+
),
|
|
2072
|
+
sync_artifacts: bool = typer.Option(
|
|
2073
|
+
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
2074
|
+
),
|
|
2075
|
+
gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
|
|
2076
|
+
) -> None:
|
|
2077
|
+
"""Run kernel evaluation in GPUMode format (functional).
|
|
2078
|
+
|
|
2079
|
+
This format expects:
|
|
2080
|
+
- Implementation: Python file with `custom_kernel(inputs)` function
|
|
2081
|
+
- Reference: Python file with `ref_kernel(inputs)` and `generate_input(**kwargs)` functions
|
|
2082
|
+
- Test cases: JSON file with test parameters
|
|
2083
|
+
|
|
2084
|
+
Examples:
|
|
2085
|
+
# Basic correctness check
|
|
2086
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
|
|
2087
|
+
|
|
2088
|
+
# With benchmarking
|
|
2089
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
2090
|
+
--target vultr-b200 --benchmark
|
|
2091
|
+
|
|
2092
|
+
Subcommands:
|
|
2093
|
+
download Download GPUMode problems from GitHub
|
|
2094
|
+
list-problems List available problems
|
|
2095
|
+
make-template Extract a problem as template files
|
|
2096
|
+
"""
|
|
2097
|
+
# If a subcommand is being invoked, skip the main evaluation logic
|
|
2098
|
+
if ctx.invoked_subcommand is not None:
|
|
2099
|
+
return
|
|
2100
|
+
|
|
2101
|
+
# Validate required args when running evaluation (not subcommands)
|
|
2102
|
+
missing_args = []
|
|
2103
|
+
if implementation is None:
|
|
2104
|
+
missing_args.append("--impl/-i")
|
|
2105
|
+
if reference is None:
|
|
2106
|
+
missing_args.append("--reference")
|
|
2107
|
+
if test_cases is None:
|
|
2108
|
+
missing_args.append("--test-cases")
|
|
2109
|
+
|
|
2110
|
+
if missing_args:
|
|
2111
|
+
typer.echo("Error: Missing required arguments", err=True)
|
|
2112
|
+
typer.echo(f" Required: {', '.join(missing_args)}", err=True)
|
|
2113
|
+
typer.echo("", err=True)
|
|
2114
|
+
typer.echo(
|
|
2115
|
+
"Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
|
|
2116
|
+
err=True,
|
|
2117
|
+
)
|
|
2118
|
+
typer.echo("", err=True)
|
|
2119
|
+
typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
|
|
2120
|
+
typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
|
|
2121
|
+
raise typer.Exit(1)
|
|
2122
|
+
|
|
2123
|
+
# Validate --target and --pool are mutually exclusive
|
|
2124
|
+
if target and pool:
|
|
2125
|
+
typer.echo("Error: Cannot specify both --target and --pool", err=True)
|
|
2126
|
+
raise typer.Exit(1)
|
|
2127
|
+
|
|
2128
|
+
from .evaluate import EvaluateArgs, run_evaluate
|
|
2129
|
+
|
|
2130
|
+
# If pool specified, acquire a target from the pool
|
|
2131
|
+
resolved_target = target or ""
|
|
2132
|
+
pool_lock_context = None
|
|
2133
|
+
|
|
2134
|
+
if pool:
|
|
2135
|
+
from .target_lock import acquire_from_pool
|
|
2136
|
+
from .targets import get_pool
|
|
2137
|
+
|
|
2138
|
+
try:
|
|
2139
|
+
pool_targets = get_pool(pool)
|
|
2140
|
+
except FileNotFoundError as e:
|
|
2141
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2142
|
+
raise typer.Exit(1) from None
|
|
2143
|
+
|
|
2144
|
+
typer.echo(f"Acquiring target from pool '{pool}' ({len(pool_targets)} targets)...")
|
|
2145
|
+
pool_lock_context = acquire_from_pool(pool_targets)
|
|
2146
|
+
acquired_target = pool_lock_context.__enter__()
|
|
2147
|
+
|
|
2148
|
+
if acquired_target is None:
|
|
2149
|
+
typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
|
|
2150
|
+
typer.echo(f" Targets: {', '.join(pool_targets)}", err=True)
|
|
2151
|
+
raise typer.Exit(1)
|
|
2152
|
+
|
|
2153
|
+
typer.echo(f"Acquired target: {acquired_target}")
|
|
2154
|
+
resolved_target = acquired_target
|
|
2155
|
+
|
|
2156
|
+
args = EvaluateArgs(
|
|
2157
|
+
implementation=implementation,
|
|
2158
|
+
reference=reference,
|
|
2159
|
+
test_cases=test_cases,
|
|
2160
|
+
target_name=resolved_target,
|
|
2161
|
+
benchmark=benchmark,
|
|
2162
|
+
profile=profile,
|
|
2163
|
+
defensive=defensive,
|
|
2164
|
+
sync_artifacts=sync_artifacts,
|
|
2165
|
+
gpu_id=gpu_id,
|
|
2166
|
+
)
|
|
2167
|
+
|
|
2168
|
+
try:
|
|
2169
|
+
import trio_asyncio
|
|
2170
|
+
|
|
2171
|
+
result = trio_asyncio.run(run_evaluate, args)
|
|
2172
|
+
except KeyboardInterrupt:
|
|
2173
|
+
typer.echo("\nInterrupted by user", err=True)
|
|
2174
|
+
raise typer.Exit(130) from None
|
|
2175
|
+
except Exception as e:
|
|
2176
|
+
if hasattr(e, "exceptions") and e.exceptions:
|
|
2177
|
+
for exc in e.exceptions:
|
|
2178
|
+
typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
|
|
2179
|
+
else:
|
|
2180
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2181
|
+
raise typer.Exit(1) from None
|
|
2182
|
+
finally:
|
|
2183
|
+
# Release pool lock if we acquired one
|
|
2184
|
+
if pool_lock_context is not None:
|
|
2185
|
+
pool_lock_context.__exit__(None, None, None)
|
|
2186
|
+
|
|
2187
|
+
# Print results
|
|
2188
|
+
if result.success:
|
|
2189
|
+
typer.echo("")
|
|
2190
|
+
typer.echo("=" * 60)
|
|
2191
|
+
status = "PASS" if result.all_correct else "FAIL"
|
|
2192
|
+
typer.echo(f"Result: {status}")
|
|
2193
|
+
score_pct = f"{result.correctness_score:.1%}"
|
|
2194
|
+
typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
|
|
2195
|
+
if result.geomean_speedup > 0:
|
|
2196
|
+
typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
|
|
2197
|
+
if result.artifact_path:
|
|
2198
|
+
typer.echo(f"Artifacts: {result.artifact_path}")
|
|
2199
|
+
typer.echo("=" * 60)
|
|
2200
|
+
|
|
2201
|
+
if not result.all_correct:
|
|
2202
|
+
raise typer.Exit(1)
|
|
2203
|
+
else:
|
|
2204
|
+
typer.echo(f"Error: {result.error_message}", err=True)
|
|
2205
|
+
raise typer.Exit(1)
|
|
2206
|
+
|
|
2207
|
+
|
|
1739
2208
|
# =============================================================================
|
|
1740
2209
|
# Push and Remote-Run commands
|
|
1741
2210
|
# =============================================================================
|
|
@@ -1867,7 +2336,7 @@ def _run_direct_mode(
|
|
|
1867
2336
|
typer.echo(f"Uploading {upload_dir.name}...")
|
|
1868
2337
|
try:
|
|
1869
2338
|
push_result = push_direct(upload_dir, target)
|
|
1870
|
-
workspace_name = push_result.
|
|
2339
|
+
workspace_name = push_result.workspace_name
|
|
1871
2340
|
typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
|
|
1872
2341
|
except Exception as e:
|
|
1873
2342
|
typer.echo(f"Error uploading: {e}", err=True)
|
|
@@ -2040,7 +2509,10 @@ def login(
|
|
|
2040
2509
|
None, "--token", "-t", help="Access token (skip browser OAuth)"
|
|
2041
2510
|
),
|
|
2042
2511
|
port: int | None = typer.Option(
|
|
2043
|
-
None,
|
|
2512
|
+
None,
|
|
2513
|
+
"--port",
|
|
2514
|
+
"-p",
|
|
2515
|
+
help="Port for OAuth callback server (default: 8765 for SSH, random for local)",
|
|
2044
2516
|
),
|
|
2045
2517
|
) -> None:
|
|
2046
2518
|
"""Authenticate CLI with wafer-api via GitHub OAuth.
|
|
@@ -2142,9 +2614,8 @@ def login(
|
|
|
2142
2614
|
@app.command("logout")
|
|
2143
2615
|
def logout() -> None:
|
|
2144
2616
|
"""Remove stored credentials."""
|
|
2145
|
-
from .auth import clear_credentials
|
|
2146
|
-
|
|
2147
2617
|
from . import analytics
|
|
2618
|
+
from .auth import clear_credentials
|
|
2148
2619
|
|
|
2149
2620
|
# Track logout event first (while credentials still exist for user identification)
|
|
2150
2621
|
# Note: track_logout() handles the case where user is not logged in
|
|
@@ -2621,6 +3092,7 @@ init_app = typer.Typer(
|
|
|
2621
3092
|
|
|
2622
3093
|
Choose based on your GPU access:
|
|
2623
3094
|
|
|
3095
|
+
local GPU on current machine (no SSH)
|
|
2624
3096
|
ssh Your own hardware via SSH
|
|
2625
3097
|
runpod RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
|
|
2626
3098
|
digitalocean DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
|
|
@@ -2628,6 +3100,92 @@ Choose based on your GPU access:
|
|
|
2628
3100
|
targets_app.add_typer(init_app, name="init")
|
|
2629
3101
|
|
|
2630
3102
|
|
|
3103
|
+
@init_app.command("local")
|
|
3104
|
+
def init_local(
|
|
3105
|
+
name: str = typer.Option("local", "--name", "-n", help="Target name"),
|
|
3106
|
+
gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
|
|
3107
|
+
) -> None:
|
|
3108
|
+
"""Initialize a local target for GPU on current machine.
|
|
3109
|
+
|
|
3110
|
+
Detects your local GPU and configures a target for direct execution
|
|
3111
|
+
(no SSH). Use this when running wafer on the same machine as the GPU.
|
|
3112
|
+
|
|
3113
|
+
Examples:
|
|
3114
|
+
wafer config targets init local
|
|
3115
|
+
wafer config targets init local --name my-5090 --gpu-ids 0,1
|
|
3116
|
+
"""
|
|
3117
|
+
from .targets import save_target
|
|
3118
|
+
|
|
3119
|
+
# Parse GPU IDs
|
|
3120
|
+
try:
|
|
3121
|
+
parsed_gpu_ids = [int(g.strip()) for g in gpu_ids.split(",")]
|
|
3122
|
+
except ValueError:
|
|
3123
|
+
typer.echo(f"Error: Invalid GPU IDs '{gpu_ids}'. Use comma-separated integers.", err=True)
|
|
3124
|
+
raise typer.Exit(1) from None
|
|
3125
|
+
|
|
3126
|
+
typer.echo("Detecting local GPU...")
|
|
3127
|
+
|
|
3128
|
+
try:
|
|
3129
|
+
from wafer_core.gpu_detect import (
|
|
3130
|
+
detect_local_gpu,
|
|
3131
|
+
get_compute_capability,
|
|
3132
|
+
get_torch_requirements,
|
|
3133
|
+
)
|
|
3134
|
+
|
|
3135
|
+
detected_gpu = detect_local_gpu()
|
|
3136
|
+
|
|
3137
|
+
if detected_gpu:
|
|
3138
|
+
typer.echo(f" Found: {detected_gpu.gpu_name}")
|
|
3139
|
+
if detected_gpu.vendor == "nvidia":
|
|
3140
|
+
typer.echo(f" CUDA: {detected_gpu.driver_version}")
|
|
3141
|
+
else:
|
|
3142
|
+
typer.echo(f" ROCm: {detected_gpu.driver_version}")
|
|
3143
|
+
typer.echo(f" GPU count: {detected_gpu.gpu_count}")
|
|
3144
|
+
|
|
3145
|
+
# Get torch requirements and compute capability
|
|
3146
|
+
torch_reqs = get_torch_requirements(detected_gpu)
|
|
3147
|
+
compute_capability = get_compute_capability(detected_gpu)
|
|
3148
|
+
gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
|
|
3149
|
+
|
|
3150
|
+
typer.echo(f" PyTorch: {torch_reqs.packages[0]}")
|
|
3151
|
+
else:
|
|
3152
|
+
typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)", err=True)
|
|
3153
|
+
raise typer.Exit(1)
|
|
3154
|
+
|
|
3155
|
+
except ImportError as e:
|
|
3156
|
+
typer.echo(f"Error: Missing dependency: {e}", err=True)
|
|
3157
|
+
raise typer.Exit(1) from None
|
|
3158
|
+
|
|
3159
|
+
# Build target data
|
|
3160
|
+
target_data = {
|
|
3161
|
+
"name": name,
|
|
3162
|
+
"type": "local",
|
|
3163
|
+
"gpu_ids": parsed_gpu_ids,
|
|
3164
|
+
"gpu_type": gpu_type,
|
|
3165
|
+
"compute_capability": compute_capability,
|
|
3166
|
+
"torch_package": torch_reqs.packages[0],
|
|
3167
|
+
"torch_index_url": torch_reqs.index_url,
|
|
3168
|
+
"vendor": detected_gpu.vendor,
|
|
3169
|
+
"driver_version": detected_gpu.driver_version,
|
|
3170
|
+
}
|
|
3171
|
+
|
|
3172
|
+
try:
|
|
3173
|
+
target = save_target(target_data)
|
|
3174
|
+
typer.echo(f"✓ Created target: {target.name}")
|
|
3175
|
+
typer.echo(" Type: Local (no SSH)")
|
|
3176
|
+
typer.echo(f" GPU IDs: {parsed_gpu_ids}")
|
|
3177
|
+
typer.echo(f" GPU Type: {gpu_type}")
|
|
3178
|
+
typer.echo(f" Compute: {compute_capability}")
|
|
3179
|
+
typer.echo(f" Torch: {torch_reqs.packages[0]}")
|
|
3180
|
+
typer.echo("")
|
|
3181
|
+
typer.echo(
|
|
3182
|
+
f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
|
|
3183
|
+
)
|
|
3184
|
+
except (ValueError, AssertionError) as e:
|
|
3185
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3186
|
+
raise typer.Exit(1) from None
|
|
3187
|
+
|
|
3188
|
+
|
|
2631
3189
|
@init_app.command("runpod")
|
|
2632
3190
|
def init_runpod(
|
|
2633
3191
|
name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
|
|
@@ -2791,23 +3349,29 @@ def init_ssh(
|
|
|
2791
3349
|
host: str = typer.Option(..., "--host", "-H", help="SSH host (user@hostname:port)"),
|
|
2792
3350
|
ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
|
|
2793
3351
|
gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
|
|
2794
|
-
gpu_type: str = typer.Option(
|
|
2795
|
-
|
|
3352
|
+
gpu_type: str | None = typer.Option(
|
|
3353
|
+
None, "--gpu-type", help="GPU type (auto-detected if not specified)"
|
|
2796
3354
|
),
|
|
2797
3355
|
docker_image: str | None = typer.Option(
|
|
2798
3356
|
None, "--docker-image", "-d", help="Docker image (optional)"
|
|
2799
3357
|
),
|
|
2800
3358
|
ncu: bool = typer.Option(False, "--ncu/--no-ncu", help="NCU profiling available"),
|
|
3359
|
+
no_detect: bool = typer.Option(False, "--no-detect", help="Skip GPU auto-detection"),
|
|
2801
3360
|
) -> None:
|
|
2802
3361
|
"""Initialize an SSH target for your own GPU hardware.
|
|
2803
3362
|
|
|
2804
3363
|
Creates a target config for direct SSH access to a GPU machine.
|
|
2805
|
-
|
|
3364
|
+
Automatically detects GPU type and selects compatible PyTorch version.
|
|
2806
3365
|
|
|
2807
3366
|
Examples:
|
|
3367
|
+
# Auto-detect GPU (recommended)
|
|
2808
3368
|
wafer config targets init ssh --name my-gpu --host user@192.168.1.100:22
|
|
3369
|
+
|
|
3370
|
+
# Multiple GPUs with NCU profiling
|
|
2809
3371
|
wafer config targets init ssh --name lab-h100 --host ubuntu@gpu.lab.com:22 --gpu-ids 0,1 --ncu
|
|
2810
|
-
|
|
3372
|
+
|
|
3373
|
+
# Skip detection, specify manually
|
|
3374
|
+
wafer config targets init ssh --name my-gpu --host user@host:22 --gpu-type H100 --no-detect
|
|
2811
3375
|
"""
|
|
2812
3376
|
from .targets import save_target
|
|
2813
3377
|
|
|
@@ -2824,17 +3388,87 @@ def init_ssh(
|
|
|
2824
3388
|
typer.echo("Example: user@192.168.1.100:22", err=True)
|
|
2825
3389
|
raise typer.Exit(1)
|
|
2826
3390
|
|
|
3391
|
+
# Auto-detect GPU if not specified
|
|
3392
|
+
detected_gpu = None
|
|
3393
|
+
torch_package = None
|
|
3394
|
+
torch_index_url = None
|
|
3395
|
+
|
|
3396
|
+
if not no_detect:
|
|
3397
|
+
typer.echo(f"Connecting to {host}...")
|
|
3398
|
+
try:
|
|
3399
|
+
import trio
|
|
3400
|
+
import trio_asyncio
|
|
3401
|
+
|
|
3402
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
3403
|
+
from wafer_core.gpu_detect import (
|
|
3404
|
+
detect_remote_gpu,
|
|
3405
|
+
get_compute_capability,
|
|
3406
|
+
get_torch_requirements,
|
|
3407
|
+
)
|
|
3408
|
+
|
|
3409
|
+
expanded_key = str(Path(ssh_key).expanduser())
|
|
3410
|
+
|
|
3411
|
+
async def _detect() -> None:
|
|
3412
|
+
nonlocal detected_gpu, torch_package, torch_index_url
|
|
3413
|
+
# Need trio_asyncio.open_loop() for asyncssh bridge
|
|
3414
|
+
async with trio_asyncio.open_loop():
|
|
3415
|
+
async with AsyncSSHClient(host, expanded_key) as client:
|
|
3416
|
+
detected_gpu = await detect_remote_gpu(client)
|
|
3417
|
+
|
|
3418
|
+
trio.run(_detect)
|
|
3419
|
+
|
|
3420
|
+
if detected_gpu:
|
|
3421
|
+
typer.echo(f" Found: {detected_gpu.gpu_name}")
|
|
3422
|
+
if detected_gpu.vendor == "nvidia":
|
|
3423
|
+
typer.echo(f" CUDA: {detected_gpu.driver_version}")
|
|
3424
|
+
else:
|
|
3425
|
+
typer.echo(f" ROCm: {detected_gpu.driver_version}")
|
|
3426
|
+
|
|
3427
|
+
# Get torch requirements
|
|
3428
|
+
torch_reqs = get_torch_requirements(detected_gpu)
|
|
3429
|
+
torch_package = torch_reqs.packages[0] # Just torch, not all packages
|
|
3430
|
+
torch_index_url = torch_reqs.index_url
|
|
3431
|
+
typer.echo(f" PyTorch: {torch_package}")
|
|
3432
|
+
|
|
3433
|
+
# Use detected GPU type if not specified
|
|
3434
|
+
if not gpu_type:
|
|
3435
|
+
# Extract GPU name (e.g., "H100" from "NVIDIA H100 80GB HBM3")
|
|
3436
|
+
gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
|
|
3437
|
+
else:
|
|
3438
|
+
typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)")
|
|
3439
|
+
if not gpu_type:
|
|
3440
|
+
gpu_type = "H100" # Default fallback
|
|
3441
|
+
typer.echo(f" Using default: {gpu_type}")
|
|
3442
|
+
|
|
3443
|
+
except Exception as e:
|
|
3444
|
+
typer.echo(f" Detection failed: {e}", err=True)
|
|
3445
|
+
if not gpu_type:
|
|
3446
|
+
gpu_type = "H100"
|
|
3447
|
+
typer.echo(f" Using default: {gpu_type}")
|
|
3448
|
+
|
|
3449
|
+
# Fallback if no detection
|
|
3450
|
+
if not gpu_type:
|
|
3451
|
+
gpu_type = "H100"
|
|
3452
|
+
|
|
2827
3453
|
# Compute capability mappings
|
|
2828
|
-
|
|
2829
|
-
|
|
2830
|
-
|
|
2831
|
-
|
|
2832
|
-
|
|
2833
|
-
|
|
2834
|
-
|
|
2835
|
-
|
|
2836
|
-
|
|
2837
|
-
|
|
3454
|
+
if detected_gpu:
|
|
3455
|
+
from wafer_core.gpu_detect import get_compute_capability
|
|
3456
|
+
|
|
3457
|
+
compute_capability = get_compute_capability(detected_gpu)
|
|
3458
|
+
else:
|
|
3459
|
+
compute_caps = {
|
|
3460
|
+
"B200": "10.0",
|
|
3461
|
+
"H100": "9.0",
|
|
3462
|
+
"A100": "8.0",
|
|
3463
|
+
"A10": "8.6",
|
|
3464
|
+
"V100": "7.0",
|
|
3465
|
+
"MI300X": "9.4",
|
|
3466
|
+
"MI250X": "9.0",
|
|
3467
|
+
"RTX 5090": "10.0",
|
|
3468
|
+
"RTX 4090": "8.9",
|
|
3469
|
+
"RTX 3090": "8.6",
|
|
3470
|
+
}
|
|
3471
|
+
compute_capability = compute_caps.get(gpu_type, "8.0")
|
|
2838
3472
|
|
|
2839
3473
|
# Build target data
|
|
2840
3474
|
target_data = {
|
|
@@ -2851,6 +3485,12 @@ def init_ssh(
|
|
|
2851
3485
|
if docker_image:
|
|
2852
3486
|
target_data["docker_image"] = docker_image
|
|
2853
3487
|
|
|
3488
|
+
# Add torch requirements if detected
|
|
3489
|
+
if torch_package:
|
|
3490
|
+
target_data["torch_package"] = torch_package
|
|
3491
|
+
if torch_index_url:
|
|
3492
|
+
target_data["torch_index_url"] = torch_index_url
|
|
3493
|
+
|
|
2854
3494
|
try:
|
|
2855
3495
|
target = save_target(target_data)
|
|
2856
3496
|
typer.echo(f"✓ Created target: {target.name}")
|
|
@@ -2858,9 +3498,12 @@ def init_ssh(
|
|
|
2858
3498
|
typer.echo(f" Host: {host}")
|
|
2859
3499
|
typer.echo(f" GPU IDs: {parsed_gpu_ids}")
|
|
2860
3500
|
typer.echo(f" GPU Type: {gpu_type}")
|
|
3501
|
+
typer.echo(f" Compute: {compute_capability}")
|
|
2861
3502
|
typer.echo(f" NCU: {'Yes' if ncu else 'No'}")
|
|
2862
3503
|
if docker_image:
|
|
2863
3504
|
typer.echo(f" Docker: {docker_image}")
|
|
3505
|
+
if torch_package:
|
|
3506
|
+
typer.echo(f" Torch: {torch_package}")
|
|
2864
3507
|
typer.echo("")
|
|
2865
3508
|
typer.echo(
|
|
2866
3509
|
f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
|
|
@@ -2870,6 +3513,31 @@ def init_ssh(
|
|
|
2870
3513
|
raise typer.Exit(1) from None
|
|
2871
3514
|
|
|
2872
3515
|
|
|
3516
|
+
def _extract_gpu_type(gpu_name: str) -> str:
|
|
3517
|
+
"""Extract GPU type from full GPU name.
|
|
3518
|
+
|
|
3519
|
+
Examples:
|
|
3520
|
+
"NVIDIA H100 80GB HBM3" -> "H100"
|
|
3521
|
+
"NVIDIA GeForce RTX 4090" -> "RTX 4090"
|
|
3522
|
+
"AMD Instinct MI300X OAM" -> "MI300X"
|
|
3523
|
+
"""
|
|
3524
|
+
gpu_name_upper = gpu_name.upper()
|
|
3525
|
+
|
|
3526
|
+
# Check for known GPU types
|
|
3527
|
+
known_types = [
|
|
3528
|
+
"B200", "B100", "H200", "H100", "A100", "A10", "V100",
|
|
3529
|
+
"RTX 5090", "RTX 5080", "RTX 4090", "RTX 4080", "RTX 3090", "RTX 3080",
|
|
3530
|
+
"MI300X", "MI250X", "MI100",
|
|
3531
|
+
]
|
|
3532
|
+
|
|
3533
|
+
for gpu_type in known_types:
|
|
3534
|
+
if gpu_type in gpu_name_upper:
|
|
3535
|
+
return gpu_type
|
|
3536
|
+
|
|
3537
|
+
# Fallback: return cleaned name
|
|
3538
|
+
return gpu_name.replace("NVIDIA ", "").replace("AMD ", "").strip()
|
|
3539
|
+
|
|
3540
|
+
|
|
2873
3541
|
@targets_app.command("add")
|
|
2874
3542
|
def targets_add(
|
|
2875
3543
|
file_path: Path = typer.Argument(..., help="Path to target TOML file"),
|
|
@@ -3082,6 +3750,92 @@ def targets_pods() -> None:
|
|
|
3082
3750
|
typer.echo()
|
|
3083
3751
|
|
|
3084
3752
|
|
|
3753
|
+
# ── Pool commands ───────────────────────────────────────────────────────────
|
|
3754
|
+
|
|
3755
|
+
|
|
3756
|
+
@targets_app.command("pool-list")
|
|
3757
|
+
def targets_pool_list() -> None:
|
|
3758
|
+
"""List all configured target pools.
|
|
3759
|
+
|
|
3760
|
+
Example:
|
|
3761
|
+
wafer config targets pool-list
|
|
3762
|
+
"""
|
|
3763
|
+
from .targets import get_pool, list_pools
|
|
3764
|
+
|
|
3765
|
+
pools = list_pools()
|
|
3766
|
+
|
|
3767
|
+
if not pools:
|
|
3768
|
+
typer.echo("No pools configured")
|
|
3769
|
+
typer.echo("")
|
|
3770
|
+
typer.echo("Define pools in ~/.wafer/config.toml:")
|
|
3771
|
+
typer.echo(" [pools.my-pool]")
|
|
3772
|
+
typer.echo(' targets = ["target-1", "target-2"]')
|
|
3773
|
+
return
|
|
3774
|
+
|
|
3775
|
+
typer.echo("Configured pools:\n")
|
|
3776
|
+
for pool_name in pools:
|
|
3777
|
+
try:
|
|
3778
|
+
targets = get_pool(pool_name)
|
|
3779
|
+
typer.echo(f" {pool_name}: {', '.join(targets)}")
|
|
3780
|
+
except Exception as e:
|
|
3781
|
+
typer.echo(f" {pool_name}: (error: {e})")
|
|
3782
|
+
|
|
3783
|
+
|
|
3784
|
+
@targets_app.command("pool-create")
|
|
3785
|
+
def targets_pool_create(
|
|
3786
|
+
name: str = typer.Argument(..., help="Pool name"),
|
|
3787
|
+
targets: list[str] = typer.Argument(..., help="Target names to include in pool"),
|
|
3788
|
+
) -> None:
|
|
3789
|
+
"""Create or update a target pool.
|
|
3790
|
+
|
|
3791
|
+
Example:
|
|
3792
|
+
wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2 mi300x-3
|
|
3793
|
+
"""
|
|
3794
|
+
from .targets import save_pool
|
|
3795
|
+
|
|
3796
|
+
try:
|
|
3797
|
+
save_pool(name, targets)
|
|
3798
|
+
typer.echo(f"Pool '{name}' created with {len(targets)} targets")
|
|
3799
|
+
except FileNotFoundError as e:
|
|
3800
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3801
|
+
raise typer.Exit(1) from None
|
|
3802
|
+
|
|
3803
|
+
|
|
3804
|
+
@targets_app.command("pool-status")
|
|
3805
|
+
def targets_pool_status(
|
|
3806
|
+
name: str = typer.Argument(..., help="Pool name"),
|
|
3807
|
+
) -> None:
|
|
3808
|
+
"""Show status of targets in a pool (locked/available).
|
|
3809
|
+
|
|
3810
|
+
Example:
|
|
3811
|
+
wafer config targets pool-status mi300x-pool
|
|
3812
|
+
"""
|
|
3813
|
+
from .target_lock import get_lock_holder, is_target_locked
|
|
3814
|
+
from .targets import get_pool
|
|
3815
|
+
|
|
3816
|
+
try:
|
|
3817
|
+
targets = get_pool(name)
|
|
3818
|
+
except FileNotFoundError as e:
|
|
3819
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3820
|
+
raise typer.Exit(1) from None
|
|
3821
|
+
|
|
3822
|
+
typer.echo(f"Pool '{name}' ({len(targets)} targets):\n")
|
|
3823
|
+
|
|
3824
|
+
available = 0
|
|
3825
|
+
for target_name in targets:
|
|
3826
|
+
locked = is_target_locked(target_name)
|
|
3827
|
+
if locked:
|
|
3828
|
+
pid = get_lock_holder(target_name)
|
|
3829
|
+
pid_str = f" (pid {pid})" if pid else ""
|
|
3830
|
+
typer.echo(f" [busy] {target_name}{pid_str}")
|
|
3831
|
+
else:
|
|
3832
|
+
typer.echo(f" [free] {target_name}")
|
|
3833
|
+
available += 1
|
|
3834
|
+
|
|
3835
|
+
typer.echo("")
|
|
3836
|
+
typer.echo(f"Available: {available}/{len(targets)}")
|
|
3837
|
+
|
|
3838
|
+
|
|
3085
3839
|
# =============================================================================
|
|
3086
3840
|
# Billing commands
|
|
3087
3841
|
# =============================================================================
|
|
@@ -3115,7 +3869,9 @@ def billing_usage(
|
|
|
3115
3869
|
@billing_app.command("topup")
|
|
3116
3870
|
def billing_topup(
|
|
3117
3871
|
amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
|
|
3118
|
-
no_browser: bool = typer.Option(
|
|
3872
|
+
no_browser: bool = typer.Option(
|
|
3873
|
+
False, "--no-browser", help="Print URL instead of opening browser"
|
|
3874
|
+
),
|
|
3119
3875
|
) -> None:
|
|
3120
3876
|
"""Add credits to your account.
|
|
3121
3877
|
|
|
@@ -3161,7 +3917,9 @@ def billing_topup(
|
|
|
3161
3917
|
|
|
3162
3918
|
@billing_app.command("portal")
|
|
3163
3919
|
def billing_portal(
|
|
3164
|
-
no_browser: bool = typer.Option(
|
|
3920
|
+
no_browser: bool = typer.Option(
|
|
3921
|
+
False, "--no-browser", help="Print URL instead of opening browser"
|
|
3922
|
+
),
|
|
3165
3923
|
) -> None:
|
|
3166
3924
|
"""Open Stripe billing portal.
|
|
3167
3925
|
|
|
@@ -4437,8 +5195,8 @@ def _setup_wafer_core_env() -> None:
|
|
|
4437
5195
|
- WAFER_API_URL: If already set, uses that instead of config
|
|
4438
5196
|
- WAFER_AUTH_TOKEN: If already set, uses that instead of cached token
|
|
4439
5197
|
"""
|
|
4440
|
-
from .global_config import get_api_url
|
|
4441
5198
|
from .auth import get_valid_token
|
|
5199
|
+
from .global_config import get_api_url
|
|
4442
5200
|
|
|
4443
5201
|
# Set API URL (get_api_url already respects WAFER_API_URL env var)
|
|
4444
5202
|
os.environ["WAFER_API_URL"] = get_api_url()
|
|
@@ -4742,8 +5500,8 @@ def capture_command( # noqa: PLR0915
|
|
|
4742
5500
|
import os
|
|
4743
5501
|
import tomllib
|
|
4744
5502
|
|
|
4745
|
-
from .global_config import get_api_url
|
|
4746
5503
|
from .auth import get_valid_token
|
|
5504
|
+
from .global_config import get_api_url
|
|
4747
5505
|
|
|
4748
5506
|
# Set environment variables for wafer-core BEFORE importing it
|
|
4749
5507
|
# wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
|
|
@@ -4947,8 +5705,8 @@ def capture_list_command(
|
|
|
4947
5705
|
"""
|
|
4948
5706
|
import os
|
|
4949
5707
|
|
|
4950
|
-
from .global_config import get_api_url
|
|
4951
5708
|
from .auth import get_valid_token
|
|
5709
|
+
from .global_config import get_api_url
|
|
4952
5710
|
|
|
4953
5711
|
# Set environment variables for wafer-core BEFORE importing it
|
|
4954
5712
|
os.environ["WAFER_API_URL"] = get_api_url()
|