wafer-cli 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/auth.py +85 -0
- wafer/cli.py +1196 -160
- wafer/evaluate.py +1171 -209
- wafer/gpu_run.py +5 -1
- wafer/kernel_scope.py +453 -0
- wafer/problems.py +357 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +490 -0
- wafer/wevin_cli.py +2 -0
- wafer/workspaces.py +53 -1
- {wafer_cli-0.2.7.dist-info → wafer_cli-0.2.9.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.7.dist-info → wafer_cli-0.2.9.dist-info}/RECORD +15 -12
- {wafer_cli-0.2.7.dist-info → wafer_cli-0.2.9.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.7.dist-info → wafer_cli-0.2.9.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.7.dist-info → wafer_cli-0.2.9.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
|
@@ -30,6 +30,14 @@ import typer
|
|
|
30
30
|
|
|
31
31
|
from .config import WaferConfig, WaferEnvironment
|
|
32
32
|
from .inference import infer_upload_files, resolve_environment
|
|
33
|
+
from .problems import (
|
|
34
|
+
download_problems,
|
|
35
|
+
get_problem_path,
|
|
36
|
+
get_problems_path,
|
|
37
|
+
)
|
|
38
|
+
from .problems import (
|
|
39
|
+
list_problems as list_problems_fn,
|
|
40
|
+
)
|
|
33
41
|
|
|
34
42
|
app = typer.Typer(
|
|
35
43
|
help="GPU development toolkit for LLM coding agents",
|
|
@@ -91,11 +99,15 @@ def main_callback(ctx: typer.Context) -> None:
|
|
|
91
99
|
# Install exception hook to catch SystemExit and mark failures
|
|
92
100
|
original_excepthook = sys.excepthook
|
|
93
101
|
|
|
94
|
-
def custom_excepthook(
|
|
102
|
+
def custom_excepthook(
|
|
103
|
+
exc_type: type[BaseException],
|
|
104
|
+
exc_value: BaseException,
|
|
105
|
+
exc_traceback: object,
|
|
106
|
+
) -> None:
|
|
95
107
|
global _command_outcome
|
|
96
108
|
# Mark as failure if SystemExit with non-zero code, or any other exception
|
|
97
109
|
if exc_type is SystemExit:
|
|
98
|
-
exit_code = exc_value.code if hasattr(exc_value,
|
|
110
|
+
exit_code = exc_value.code if hasattr(exc_value, "code") else 1
|
|
99
111
|
if exit_code != 0 and exit_code is not None:
|
|
100
112
|
_command_outcome = "failure"
|
|
101
113
|
else:
|
|
@@ -200,6 +212,13 @@ kernelbench_app = typer.Typer(
|
|
|
200
212
|
)
|
|
201
213
|
evaluate_app.add_typer(kernelbench_app, name="kernelbench")
|
|
202
214
|
|
|
215
|
+
# Nested subcommand for gpumode format
|
|
216
|
+
gpumode_app = typer.Typer(
|
|
217
|
+
help="Evaluate kernels in GPUMode format (custom_kernel/ref_kernel functions)",
|
|
218
|
+
invoke_without_command=True,
|
|
219
|
+
)
|
|
220
|
+
evaluate_app.add_typer(gpumode_app, name="gpumode")
|
|
221
|
+
|
|
203
222
|
# =============================================================================
|
|
204
223
|
# Dev commands (internal, used by web app proxy)
|
|
205
224
|
# =============================================================================
|
|
@@ -242,6 +261,10 @@ app.add_typer(amd_app, name="amd")
|
|
|
242
261
|
isa_app = typer.Typer(help="ISA analysis for AMD GPU code objects (.co files)")
|
|
243
262
|
amd_app.add_typer(isa_app, name="isa")
|
|
244
263
|
|
|
264
|
+
# Kernel Scope - static ISA analysis for Triton kernels
|
|
265
|
+
kernel_scope_app = typer.Typer(help="Static ISA analysis for Triton compilation artifacts")
|
|
266
|
+
amd_app.add_typer(kernel_scope_app, name="kernel-scope")
|
|
267
|
+
|
|
245
268
|
# =============================================================================
|
|
246
269
|
# Skill management (wafer skill ...)
|
|
247
270
|
# =============================================================================
|
|
@@ -396,6 +419,122 @@ def skill_status() -> None:
|
|
|
396
419
|
typer.echo(f"{tool_name}: Not installed")
|
|
397
420
|
|
|
398
421
|
|
|
422
|
+
# =============================================================================
|
|
423
|
+
# Provider auth management (wafer auth ...)
|
|
424
|
+
# =============================================================================
|
|
425
|
+
|
|
426
|
+
provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
|
|
427
|
+
app.add_typer(provider_auth_app, name="auth")
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
@provider_auth_app.command("login")
|
|
431
|
+
def provider_auth_login(
|
|
432
|
+
provider: str = typer.Argument(
|
|
433
|
+
...,
|
|
434
|
+
help="Provider name: runpod, digitalocean, or modal",
|
|
435
|
+
),
|
|
436
|
+
api_key: str | None = typer.Option(
|
|
437
|
+
None,
|
|
438
|
+
"--api-key",
|
|
439
|
+
"-k",
|
|
440
|
+
help="API key (if not provided, reads from stdin)",
|
|
441
|
+
),
|
|
442
|
+
) -> None:
|
|
443
|
+
"""Save API key for a cloud GPU provider.
|
|
444
|
+
|
|
445
|
+
Stores the key in ~/.wafer/auth.json. Environment variables
|
|
446
|
+
(e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
|
|
447
|
+
|
|
448
|
+
Examples:
|
|
449
|
+
wafer auth login runpod --api-key rp_xxx
|
|
450
|
+
wafer auth login digitalocean --api-key dop_v1_xxx
|
|
451
|
+
echo $API_KEY | wafer auth login runpod
|
|
452
|
+
"""
|
|
453
|
+
import sys
|
|
454
|
+
|
|
455
|
+
from wafer_core.auth import PROVIDERS, save_api_key
|
|
456
|
+
|
|
457
|
+
# Validate provider
|
|
458
|
+
if provider not in PROVIDERS:
|
|
459
|
+
typer.echo(f"Error: Unknown provider '{provider}'", err=True)
|
|
460
|
+
typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
|
|
461
|
+
raise typer.Exit(1)
|
|
462
|
+
|
|
463
|
+
# Get API key from option or stdin
|
|
464
|
+
if api_key is None:
|
|
465
|
+
if sys.stdin.isatty():
|
|
466
|
+
typer.echo(f"Enter API key for {PROVIDERS[provider]['display_name']}:")
|
|
467
|
+
api_key = typer.prompt("API key", hide_input=True)
|
|
468
|
+
else:
|
|
469
|
+
api_key = sys.stdin.read().strip()
|
|
470
|
+
|
|
471
|
+
if not api_key:
|
|
472
|
+
typer.echo("Error: No API key provided", err=True)
|
|
473
|
+
raise typer.Exit(1)
|
|
474
|
+
|
|
475
|
+
# Save the key
|
|
476
|
+
save_api_key(provider, api_key)
|
|
477
|
+
typer.echo(f"API key saved for {PROVIDERS[provider]['display_name']}")
|
|
478
|
+
typer.echo("Stored in: ~/.wafer/auth.json")
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
@provider_auth_app.command("logout")
|
|
482
|
+
def provider_auth_logout(
|
|
483
|
+
provider: str = typer.Argument(
|
|
484
|
+
...,
|
|
485
|
+
help="Provider name: runpod, digitalocean, or modal",
|
|
486
|
+
),
|
|
487
|
+
) -> None:
|
|
488
|
+
"""Remove stored API key for a cloud GPU provider.
|
|
489
|
+
|
|
490
|
+
Examples:
|
|
491
|
+
wafer auth logout runpod
|
|
492
|
+
wafer auth logout digitalocean
|
|
493
|
+
"""
|
|
494
|
+
from wafer_core.auth import PROVIDERS, remove_api_key
|
|
495
|
+
|
|
496
|
+
# Validate provider
|
|
497
|
+
if provider not in PROVIDERS:
|
|
498
|
+
typer.echo(f"Error: Unknown provider '{provider}'", err=True)
|
|
499
|
+
typer.echo(f"Valid providers: {', '.join(PROVIDERS.keys())}", err=True)
|
|
500
|
+
raise typer.Exit(1)
|
|
501
|
+
|
|
502
|
+
if remove_api_key(provider):
|
|
503
|
+
typer.echo(f"API key removed for {PROVIDERS[provider]['display_name']}")
|
|
504
|
+
else:
|
|
505
|
+
typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
@provider_auth_app.command("status")
|
|
509
|
+
def provider_auth_status() -> None:
|
|
510
|
+
"""Show authentication status for all cloud GPU providers.
|
|
511
|
+
|
|
512
|
+
Displays which providers have API keys configured and where
|
|
513
|
+
the keys are coming from (environment variable or auth.json).
|
|
514
|
+
|
|
515
|
+
Example:
|
|
516
|
+
wafer auth status
|
|
517
|
+
"""
|
|
518
|
+
from wafer_core.auth import get_all_auth_status
|
|
519
|
+
|
|
520
|
+
statuses = get_all_auth_status()
|
|
521
|
+
|
|
522
|
+
typer.echo("Cloud GPU Provider Authentication Status")
|
|
523
|
+
typer.echo("=" * 45)
|
|
524
|
+
|
|
525
|
+
for status in statuses:
|
|
526
|
+
if status.is_authenticated:
|
|
527
|
+
source_str = f"({status.source})" if status.source else ""
|
|
528
|
+
typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
|
|
529
|
+
else:
|
|
530
|
+
typer.echo(f" {status.display_name}: ✗ Not configured")
|
|
531
|
+
typer.echo(f" Run: wafer auth login {status.provider}")
|
|
532
|
+
typer.echo(f" Or set: {status.key_url}")
|
|
533
|
+
|
|
534
|
+
typer.echo("")
|
|
535
|
+
typer.echo("Note: Environment variables take precedence over stored keys.")
|
|
536
|
+
|
|
537
|
+
|
|
399
538
|
@app.command(hidden=True)
|
|
400
539
|
def run(
|
|
401
540
|
command: str = typer.Argument(..., help="Command to run in Docker container"),
|
|
@@ -1289,86 +1428,37 @@ def evaluate( # noqa: PLR0913
|
|
|
1289
1428
|
--benchmark --defensive
|
|
1290
1429
|
|
|
1291
1430
|
Subcommands:
|
|
1292
|
-
|
|
1431
|
+
gpumode Use GPUMode format (functional) - RECOMMENDED
|
|
1293
1432
|
kernelbench Use KernelBench format (ModelNew class)
|
|
1433
|
+
make-template Generate template files for this format (deprecated)
|
|
1294
1434
|
"""
|
|
1295
1435
|
# If a subcommand is being invoked, skip the main evaluation logic
|
|
1296
1436
|
if ctx.invoked_subcommand is not None:
|
|
1297
1437
|
return
|
|
1298
1438
|
|
|
1299
|
-
#
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
if test_cases is None:
|
|
1306
|
-
missing_args.append("--test-cases")
|
|
1307
|
-
|
|
1308
|
-
if missing_args:
|
|
1309
|
-
typer.echo("Error: Missing required arguments", err=True)
|
|
1310
|
-
typer.echo(f" Required: {', '.join(missing_args)}", err=True)
|
|
1311
|
-
typer.echo("", err=True)
|
|
1312
|
-
typer.echo(
|
|
1313
|
-
"Usage: wafer evaluate --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
|
|
1314
|
-
err=True,
|
|
1315
|
-
)
|
|
1316
|
-
typer.echo("", err=True)
|
|
1317
|
-
typer.echo("Run 'wafer evaluate --help' for full options.", err=True)
|
|
1318
|
-
typer.echo("Run 'wafer evaluate make-template DIR' to generate starter files.", err=True)
|
|
1319
|
-
raise typer.Exit(1)
|
|
1320
|
-
|
|
1321
|
-
from .evaluate import EvaluateArgs, run_evaluate
|
|
1322
|
-
|
|
1323
|
-
args = EvaluateArgs(
|
|
1324
|
-
implementation=implementation,
|
|
1325
|
-
reference=reference,
|
|
1326
|
-
test_cases=test_cases,
|
|
1327
|
-
target_name=target or "",
|
|
1328
|
-
benchmark=benchmark,
|
|
1329
|
-
profile=profile,
|
|
1330
|
-
defensive=defensive,
|
|
1331
|
-
sync_artifacts=sync_artifacts,
|
|
1332
|
-
gpu_id=gpu_id,
|
|
1439
|
+
# Bare 'wafer evaluate' is no longer supported - must use subcommand
|
|
1440
|
+
typer.echo("Error: 'wafer evaluate' requires a subcommand.", err=True)
|
|
1441
|
+
typer.echo("", err=True)
|
|
1442
|
+
typer.echo("Available subcommands:", err=True)
|
|
1443
|
+
typer.echo(
|
|
1444
|
+
" gpumode Evaluate GPUMode format (custom_kernel/ref_kernel functions)", err=True
|
|
1333
1445
|
)
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
typer.echo(f"Error: {e}", err=True)
|
|
1351
|
-
raise typer.Exit(1) from None
|
|
1352
|
-
|
|
1353
|
-
# Print results
|
|
1354
|
-
if result.success:
|
|
1355
|
-
typer.echo("")
|
|
1356
|
-
typer.echo("=" * 60)
|
|
1357
|
-
status = "PASS" if result.all_correct else "FAIL"
|
|
1358
|
-
typer.echo(f"Result: {status}")
|
|
1359
|
-
score_pct = f"{result.correctness_score:.1%}"
|
|
1360
|
-
typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
|
|
1361
|
-
if result.geomean_speedup > 0:
|
|
1362
|
-
typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
|
|
1363
|
-
if result.artifact_path:
|
|
1364
|
-
typer.echo(f"Artifacts: {result.artifact_path}")
|
|
1365
|
-
typer.echo("=" * 60)
|
|
1366
|
-
|
|
1367
|
-
if not result.all_correct:
|
|
1368
|
-
raise typer.Exit(1)
|
|
1369
|
-
else:
|
|
1370
|
-
typer.echo(f"Error: {result.error_message}", err=True)
|
|
1371
|
-
raise typer.Exit(1)
|
|
1446
|
+
typer.echo(" kernelbench Evaluate KernelBench format (ModelNew class)", err=True)
|
|
1447
|
+
typer.echo("", err=True)
|
|
1448
|
+
typer.echo("Examples:", err=True)
|
|
1449
|
+
typer.echo(
|
|
1450
|
+
" wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json",
|
|
1451
|
+
err=True,
|
|
1452
|
+
)
|
|
1453
|
+
typer.echo(
|
|
1454
|
+
" wafer evaluate kernelbench --impl impl.py --reference ref.py --benchmark", err=True
|
|
1455
|
+
)
|
|
1456
|
+
typer.echo("", err=True)
|
|
1457
|
+
typer.echo(
|
|
1458
|
+
"Run 'wafer evaluate gpumode --help' or 'wafer evaluate kernelbench --help' for options.",
|
|
1459
|
+
err=True,
|
|
1460
|
+
)
|
|
1461
|
+
raise typer.Exit(1)
|
|
1372
1462
|
|
|
1373
1463
|
|
|
1374
1464
|
TEMPLATE_KERNEL = '''\
|
|
@@ -1503,8 +1593,59 @@ def evaluate_make_template(
|
|
|
1503
1593
|
# KernelBench format evaluation
|
|
1504
1594
|
# =============================================================================
|
|
1505
1595
|
|
|
1506
|
-
|
|
1507
|
-
|
|
1596
|
+
|
|
1597
|
+
def _get_kernelbench_root() -> Path | None:
|
|
1598
|
+
"""Get KernelBench problems root, preferring downloaded location."""
|
|
1599
|
+
# First check downloaded location
|
|
1600
|
+
downloaded = get_problems_path("kernelbench")
|
|
1601
|
+
if downloaded is not None:
|
|
1602
|
+
kb_root = downloaded / "KernelBench"
|
|
1603
|
+
if kb_root.exists():
|
|
1604
|
+
return kb_root
|
|
1605
|
+
return downloaded
|
|
1606
|
+
|
|
1607
|
+
# Fall back to legacy location (for development)
|
|
1608
|
+
legacy = Path(__file__).parent.parent.parent.parent / "research" / "KernelBench" / "KernelBench"
|
|
1609
|
+
if legacy.exists():
|
|
1610
|
+
return legacy
|
|
1611
|
+
|
|
1612
|
+
return None
|
|
1613
|
+
|
|
1614
|
+
|
|
1615
|
+
@kernelbench_app.command("download")
|
|
1616
|
+
def kernelbench_download(
|
|
1617
|
+
force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
|
|
1618
|
+
) -> None:
|
|
1619
|
+
"""Download KernelBench problems from GitHub.
|
|
1620
|
+
|
|
1621
|
+
Downloads the problem set to ~/.cache/wafer/problems/kernelbench/
|
|
1622
|
+
|
|
1623
|
+
Examples:
|
|
1624
|
+
wafer evaluate kernelbench download
|
|
1625
|
+
wafer evaluate kernelbench download --force # Re-download
|
|
1626
|
+
"""
|
|
1627
|
+
try:
|
|
1628
|
+
path = download_problems("kernelbench", force=force, verbose=True)
|
|
1629
|
+
typer.echo("")
|
|
1630
|
+
typer.echo(f"Problems available at: {path}")
|
|
1631
|
+
typer.echo("Run 'wafer evaluate kernelbench list-problems' to see available problems.")
|
|
1632
|
+
except Exception as e:
|
|
1633
|
+
typer.echo(f"Error downloading problems: {e}", err=True)
|
|
1634
|
+
raise typer.Exit(1) from None
|
|
1635
|
+
|
|
1636
|
+
|
|
1637
|
+
@kernelbench_app.command("list-problems")
|
|
1638
|
+
def kernelbench_list_problems() -> None:
|
|
1639
|
+
"""List available KernelBench problems.
|
|
1640
|
+
|
|
1641
|
+
Examples:
|
|
1642
|
+
wafer evaluate kernelbench list-problems
|
|
1643
|
+
"""
|
|
1644
|
+
try:
|
|
1645
|
+
list_problems_fn("kernelbench", verbose=True)
|
|
1646
|
+
except ValueError as e:
|
|
1647
|
+
typer.echo(str(e), err=True)
|
|
1648
|
+
raise typer.Exit(1) from None
|
|
1508
1649
|
|
|
1509
1650
|
|
|
1510
1651
|
@kernelbench_app.callback(invoke_without_command=True)
|
|
@@ -1528,9 +1669,18 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1528
1669
|
help="GPU target name. See 'wafer config targets list' for available targets.",
|
|
1529
1670
|
autocompletion=complete_target_name,
|
|
1530
1671
|
),
|
|
1672
|
+
pool: str | None = typer.Option(
|
|
1673
|
+
None,
|
|
1674
|
+
"--pool",
|
|
1675
|
+
"-p",
|
|
1676
|
+
help="Target pool name. Acquires first available target from the pool. "
|
|
1677
|
+
"Define pools in ~/.wafer/config.toml under [pools.<name>].",
|
|
1678
|
+
),
|
|
1531
1679
|
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
1532
1680
|
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
1533
|
-
inputs: Path | None = typer.Option(
|
|
1681
|
+
inputs: Path | None = typer.Option(
|
|
1682
|
+
None, "--inputs", help="Custom inputs file to override get_inputs()"
|
|
1683
|
+
),
|
|
1534
1684
|
seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
|
|
1535
1685
|
defensive: bool = typer.Option(
|
|
1536
1686
|
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
@@ -1588,12 +1738,54 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1588
1738
|
)
|
|
1589
1739
|
raise typer.Exit(1)
|
|
1590
1740
|
|
|
1741
|
+
# Validate --target and --pool are mutually exclusive
|
|
1742
|
+
if target and pool:
|
|
1743
|
+
typer.echo("Error: Cannot specify both --target and --pool", err=True)
|
|
1744
|
+
raise typer.Exit(1)
|
|
1745
|
+
|
|
1591
1746
|
from .evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench
|
|
1592
1747
|
|
|
1748
|
+
# If pool specified, acquire a target from the pool
|
|
1749
|
+
resolved_target = target or ""
|
|
1750
|
+
pool_lock_context = None
|
|
1751
|
+
|
|
1752
|
+
if pool:
|
|
1753
|
+
from .target_lock import acquire_from_pool
|
|
1754
|
+
from .targets import filter_pool_by_auth, get_pool
|
|
1755
|
+
|
|
1756
|
+
try:
|
|
1757
|
+
pool_targets = get_pool(pool)
|
|
1758
|
+
except FileNotFoundError as e:
|
|
1759
|
+
typer.echo(f"Error: {e}", err=True)
|
|
1760
|
+
raise typer.Exit(1) from None
|
|
1761
|
+
|
|
1762
|
+
# Filter to only targets with valid auth
|
|
1763
|
+
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1764
|
+
if skipped:
|
|
1765
|
+
typer.echo(f"Skipping targets without auth: {', '.join(skipped)}", err=True)
|
|
1766
|
+
|
|
1767
|
+
if not usable_targets:
|
|
1768
|
+
typer.echo(f"Error: No usable targets in pool '{pool}'", err=True)
|
|
1769
|
+
typer.echo(" All targets require authentication that is not configured.", err=True)
|
|
1770
|
+
typer.echo(" Run 'wafer auth status' to see which providers need setup.", err=True)
|
|
1771
|
+
raise typer.Exit(1) from None
|
|
1772
|
+
|
|
1773
|
+
typer.echo(f"Acquiring target from pool '{pool}' ({len(usable_targets)} targets)...")
|
|
1774
|
+
pool_lock_context = acquire_from_pool(usable_targets)
|
|
1775
|
+
acquired_target = pool_lock_context.__enter__()
|
|
1776
|
+
|
|
1777
|
+
if acquired_target is None:
|
|
1778
|
+
typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
|
|
1779
|
+
typer.echo(f" Targets: {', '.join(usable_targets)}", err=True)
|
|
1780
|
+
raise typer.Exit(1)
|
|
1781
|
+
|
|
1782
|
+
typer.echo(f"Acquired target: {acquired_target}")
|
|
1783
|
+
resolved_target = acquired_target
|
|
1784
|
+
|
|
1593
1785
|
args = KernelBenchEvaluateArgs(
|
|
1594
1786
|
implementation=implementation,
|
|
1595
1787
|
reference=reference,
|
|
1596
|
-
target_name=
|
|
1788
|
+
target_name=resolved_target,
|
|
1597
1789
|
benchmark=benchmark,
|
|
1598
1790
|
profile=profile,
|
|
1599
1791
|
inputs=inputs,
|
|
@@ -1613,6 +1805,10 @@ def kernelbench_evaluate( # noqa: PLR0913
|
|
|
1613
1805
|
except Exception as e:
|
|
1614
1806
|
typer.echo(f"Error: {e}", err=True)
|
|
1615
1807
|
raise typer.Exit(1) from None
|
|
1808
|
+
finally:
|
|
1809
|
+
# Release pool lock if we acquired one
|
|
1810
|
+
if pool_lock_context is not None:
|
|
1811
|
+
pool_lock_context.__exit__(None, None, None)
|
|
1616
1812
|
|
|
1617
1813
|
# Print results
|
|
1618
1814
|
if result.success:
|
|
@@ -1659,6 +1855,13 @@ def kernelbench_make_template(
|
|
|
1659
1855
|
# Overwrite existing
|
|
1660
1856
|
wafer evaluate kernelbench make-template level1/1 --force
|
|
1661
1857
|
"""
|
|
1858
|
+
# Get problems root (downloaded or legacy)
|
|
1859
|
+
kb_root = _get_kernelbench_root()
|
|
1860
|
+
if kb_root is None:
|
|
1861
|
+
typer.echo("Error: KernelBench problems not found.", err=True)
|
|
1862
|
+
typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
|
|
1863
|
+
raise typer.Exit(1)
|
|
1864
|
+
|
|
1662
1865
|
# Parse problem ID
|
|
1663
1866
|
parts = problem.split("/")
|
|
1664
1867
|
if len(parts) != 2:
|
|
@@ -1670,10 +1873,10 @@ def kernelbench_make_template(
|
|
|
1670
1873
|
level_str = f"level{level_str}"
|
|
1671
1874
|
|
|
1672
1875
|
# Find the problem file
|
|
1673
|
-
problem_dir =
|
|
1876
|
+
problem_dir = kb_root / level_str
|
|
1674
1877
|
if not problem_dir.exists():
|
|
1675
1878
|
typer.echo(f"Error: KernelBench level directory not found: {problem_dir}", err=True)
|
|
1676
|
-
typer.echo(
|
|
1879
|
+
typer.echo("Run 'wafer evaluate kernelbench download' to download problems.", err=True)
|
|
1677
1880
|
raise typer.Exit(1)
|
|
1678
1881
|
|
|
1679
1882
|
# Find matching problem file
|
|
@@ -1708,37 +1911,335 @@ def kernelbench_make_template(
|
|
|
1708
1911
|
|
|
1709
1912
|
output = output.resolve()
|
|
1710
1913
|
|
|
1711
|
-
# Check if exists
|
|
1712
|
-
if output.exists() and not force:
|
|
1713
|
-
typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
|
|
1914
|
+
# Check if exists
|
|
1915
|
+
if output.exists() and not force:
|
|
1916
|
+
typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
|
|
1917
|
+
raise typer.Exit(1)
|
|
1918
|
+
|
|
1919
|
+
# Copy the file
|
|
1920
|
+
content = problem_file.read_text()
|
|
1921
|
+
output.parent.mkdir(parents=True, exist_ok=True)
|
|
1922
|
+
output.write_text(content)
|
|
1923
|
+
|
|
1924
|
+
typer.echo(f"Created {output}")
|
|
1925
|
+
typer.echo("")
|
|
1926
|
+
typer.echo("Next steps:")
|
|
1927
|
+
typer.echo(f" 1. Read {output} to understand the Model interface")
|
|
1928
|
+
typer.echo(" 2. Create an implementation file with your ModelNew class:")
|
|
1929
|
+
typer.echo("")
|
|
1930
|
+
typer.echo(" import torch.nn as nn")
|
|
1931
|
+
typer.echo("")
|
|
1932
|
+
typer.echo(" class ModelNew(nn.Module):")
|
|
1933
|
+
typer.echo(" def __init__(self, ...):")
|
|
1934
|
+
typer.echo(" # Same signature as Model.__init__")
|
|
1935
|
+
typer.echo(" ...")
|
|
1936
|
+
typer.echo("")
|
|
1937
|
+
typer.echo(" def forward(self, ...):")
|
|
1938
|
+
typer.echo(" # Same signature as Model.forward")
|
|
1939
|
+
typer.echo(" # Your optimized implementation here")
|
|
1940
|
+
typer.echo(" ...")
|
|
1941
|
+
typer.echo("")
|
|
1942
|
+
typer.echo(" 3. Run evaluation:")
|
|
1943
|
+
typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
|
|
1944
|
+
|
|
1945
|
+
|
|
1946
|
+
# =============================================================================
|
|
1947
|
+
# GPUMode format evaluation
|
|
1948
|
+
# =============================================================================
|
|
1949
|
+
|
|
1950
|
+
|
|
1951
|
+
@gpumode_app.command("download")
|
|
1952
|
+
def gpumode_download(
|
|
1953
|
+
force: bool = typer.Option(False, "--force", "-f", help="Re-download even if exists"),
|
|
1954
|
+
) -> None:
|
|
1955
|
+
"""Download GPUMode reference kernels from GitHub.
|
|
1956
|
+
|
|
1957
|
+
Downloads the problem set to ~/.cache/wafer/problems/gpumode/
|
|
1958
|
+
|
|
1959
|
+
Examples:
|
|
1960
|
+
wafer evaluate gpumode download
|
|
1961
|
+
wafer evaluate gpumode download --force # Re-download
|
|
1962
|
+
"""
|
|
1963
|
+
try:
|
|
1964
|
+
path = download_problems("gpumode", force=force, verbose=True)
|
|
1965
|
+
typer.echo("")
|
|
1966
|
+
typer.echo(f"Problems available at: {path}")
|
|
1967
|
+
typer.echo("Run 'wafer evaluate gpumode list-problems' to see available problems.")
|
|
1968
|
+
except Exception as e:
|
|
1969
|
+
typer.echo(f"Error downloading problems: {e}", err=True)
|
|
1970
|
+
raise typer.Exit(1) from None
|
|
1971
|
+
|
|
1972
|
+
|
|
1973
|
+
@gpumode_app.command("list-problems")
|
|
1974
|
+
def gpumode_list_problems() -> None:
|
|
1975
|
+
"""List available GPUMode problems.
|
|
1976
|
+
|
|
1977
|
+
Examples:
|
|
1978
|
+
wafer evaluate gpumode list-problems
|
|
1979
|
+
"""
|
|
1980
|
+
try:
|
|
1981
|
+
list_problems_fn("gpumode", verbose=True)
|
|
1982
|
+
except ValueError as e:
|
|
1983
|
+
typer.echo(str(e), err=True)
|
|
1984
|
+
raise typer.Exit(1) from None
|
|
1985
|
+
|
|
1986
|
+
|
|
1987
|
+
@gpumode_app.command("make-template")
|
|
1988
|
+
def gpumode_make_template(
|
|
1989
|
+
problem: str = typer.Option(
|
|
1990
|
+
...,
|
|
1991
|
+
"--problem",
|
|
1992
|
+
"-p",
|
|
1993
|
+
help="Problem ID (e.g., 'pmpp/vectoradd_py' or 'amd/fp8-mm')",
|
|
1994
|
+
),
|
|
1995
|
+
output: Path = typer.Option(
|
|
1996
|
+
None, "--output", "-o", help="Output directory (default: ./<problem_name>/)"
|
|
1997
|
+
),
|
|
1998
|
+
force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
|
|
1999
|
+
) -> None:
|
|
2000
|
+
"""Extract a GPUMode problem as template files.
|
|
2001
|
+
|
|
2002
|
+
Creates a directory with reference.py, task.yml, and other problem files.
|
|
2003
|
+
You then create kernel.py with your custom_kernel implementation.
|
|
2004
|
+
|
|
2005
|
+
Examples:
|
|
2006
|
+
# Extract pmpp vectoradd problem
|
|
2007
|
+
wafer evaluate gpumode make-template --problem pmpp/vectoradd_py
|
|
2008
|
+
|
|
2009
|
+
# Extract to specific directory
|
|
2010
|
+
wafer evaluate gpumode make-template --problem pmpp/vectoradd_py --output ./my-kernel/
|
|
2011
|
+
"""
|
|
2012
|
+
import shutil
|
|
2013
|
+
|
|
2014
|
+
# Get problem path
|
|
2015
|
+
problem_path = get_problem_path("gpumode", problem)
|
|
2016
|
+
if problem_path is None:
|
|
2017
|
+
# Check if problems are downloaded
|
|
2018
|
+
if get_problems_path("gpumode") is None:
|
|
2019
|
+
typer.echo("Error: GPUMode problems not downloaded.", err=True)
|
|
2020
|
+
typer.echo("Run 'wafer evaluate gpumode download' first.", err=True)
|
|
2021
|
+
else:
|
|
2022
|
+
typer.echo(f"Error: Problem '{problem}' not found.", err=True)
|
|
2023
|
+
typer.echo(
|
|
2024
|
+
"Run 'wafer evaluate gpumode list-problems' to see available problems.", err=True
|
|
2025
|
+
)
|
|
2026
|
+
raise typer.Exit(1)
|
|
2027
|
+
|
|
2028
|
+
# Determine output path
|
|
2029
|
+
if output is None:
|
|
2030
|
+
output = Path.cwd() / problem.replace("/", "_")
|
|
2031
|
+
|
|
2032
|
+
output = output.resolve()
|
|
2033
|
+
|
|
2034
|
+
# Check if exists
|
|
2035
|
+
if output.exists() and not force:
|
|
2036
|
+
typer.echo(f"Error: {output} already exists. Use --force to overwrite.", err=True)
|
|
2037
|
+
raise typer.Exit(1)
|
|
2038
|
+
|
|
2039
|
+
# Copy the problem directory
|
|
2040
|
+
if output.exists():
|
|
2041
|
+
shutil.rmtree(output)
|
|
2042
|
+
shutil.copytree(problem_path, output)
|
|
2043
|
+
|
|
2044
|
+
typer.echo(f"Created {output}/")
|
|
2045
|
+
typer.echo("")
|
|
2046
|
+
typer.echo("Contents:")
|
|
2047
|
+
for f in sorted(output.iterdir()):
|
|
2048
|
+
if not f.name.startswith("."):
|
|
2049
|
+
typer.echo(f" {f.name}")
|
|
2050
|
+
typer.echo("")
|
|
2051
|
+
typer.echo("Next steps:")
|
|
2052
|
+
typer.echo(" 1. Read reference.py to understand the kernel interface")
|
|
2053
|
+
typer.echo(" 2. Create kernel.py with your custom_kernel implementation:")
|
|
2054
|
+
typer.echo("")
|
|
2055
|
+
typer.echo(" def custom_kernel(data):")
|
|
2056
|
+
typer.echo(" # Your optimized implementation")
|
|
2057
|
+
typer.echo(" ...")
|
|
2058
|
+
typer.echo("")
|
|
2059
|
+
typer.echo(" 3. Run evaluation:")
|
|
2060
|
+
typer.echo(
|
|
2061
|
+
f" wafer evaluate gpumode --impl {output}/kernel.py --reference {output}/reference.py \\"
|
|
2062
|
+
)
|
|
2063
|
+
typer.echo(f" --test-cases {output}/test_cases.json --target <target>")
|
|
2064
|
+
|
|
2065
|
+
|
|
2066
|
+
@gpumode_app.callback(invoke_without_command=True)
|
|
2067
|
+
def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
2068
|
+
ctx: typer.Context,
|
|
2069
|
+
implementation: Path | None = typer.Option(
|
|
2070
|
+
None, "--impl", "-i", help="Path to implementation kernel file"
|
|
2071
|
+
),
|
|
2072
|
+
reference: Path | None = typer.Option(
|
|
2073
|
+
None, "--reference", help="Path to reference kernel file"
|
|
2074
|
+
),
|
|
2075
|
+
test_cases: Path | None = typer.Option(
|
|
2076
|
+
None, "--test-cases", help="Path to test cases JSON file"
|
|
2077
|
+
),
|
|
2078
|
+
target: str | None = typer.Option(
|
|
2079
|
+
None,
|
|
2080
|
+
"--target",
|
|
2081
|
+
"-t",
|
|
2082
|
+
help="GPU target name. See 'wafer config targets list' for available targets.",
|
|
2083
|
+
autocompletion=complete_target_name,
|
|
2084
|
+
),
|
|
2085
|
+
pool: str | None = typer.Option(
|
|
2086
|
+
None,
|
|
2087
|
+
"--pool",
|
|
2088
|
+
"-p",
|
|
2089
|
+
help="Target pool name. Acquires first available target from the pool. "
|
|
2090
|
+
"Define pools in ~/.wafer/config.toml under [pools.<name>].",
|
|
2091
|
+
),
|
|
2092
|
+
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
2093
|
+
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
2094
|
+
defensive: bool = typer.Option(
|
|
2095
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
2096
|
+
),
|
|
2097
|
+
sync_artifacts: bool = typer.Option(
|
|
2098
|
+
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
2099
|
+
),
|
|
2100
|
+
gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
|
|
2101
|
+
) -> None:
|
|
2102
|
+
"""Run kernel evaluation in GPUMode format (functional).
|
|
2103
|
+
|
|
2104
|
+
This format expects:
|
|
2105
|
+
- Implementation: Python file with `custom_kernel(inputs)` function
|
|
2106
|
+
- Reference: Python file with `ref_kernel(inputs)` and `generate_input(**kwargs)` functions
|
|
2107
|
+
- Test cases: JSON file with test parameters
|
|
2108
|
+
|
|
2109
|
+
Examples:
|
|
2110
|
+
# Basic correctness check
|
|
2111
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
|
|
2112
|
+
|
|
2113
|
+
# With benchmarking
|
|
2114
|
+
wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
2115
|
+
--target vultr-b200 --benchmark
|
|
2116
|
+
|
|
2117
|
+
Subcommands:
|
|
2118
|
+
download Download GPUMode problems from GitHub
|
|
2119
|
+
list-problems List available problems
|
|
2120
|
+
make-template Extract a problem as template files
|
|
2121
|
+
"""
|
|
2122
|
+
# If a subcommand is being invoked, skip the main evaluation logic
|
|
2123
|
+
if ctx.invoked_subcommand is not None:
|
|
2124
|
+
return
|
|
2125
|
+
|
|
2126
|
+
# Validate required args when running evaluation (not subcommands)
|
|
2127
|
+
missing_args = []
|
|
2128
|
+
if implementation is None:
|
|
2129
|
+
missing_args.append("--impl/-i")
|
|
2130
|
+
if reference is None:
|
|
2131
|
+
missing_args.append("--reference")
|
|
2132
|
+
if test_cases is None:
|
|
2133
|
+
missing_args.append("--test-cases")
|
|
2134
|
+
|
|
2135
|
+
if missing_args:
|
|
2136
|
+
typer.echo("Error: Missing required arguments", err=True)
|
|
2137
|
+
typer.echo(f" Required: {', '.join(missing_args)}", err=True)
|
|
2138
|
+
typer.echo("", err=True)
|
|
2139
|
+
typer.echo(
|
|
2140
|
+
"Usage: wafer evaluate gpumode --impl KERNEL.py --reference REF.py --test-cases TESTS.json",
|
|
2141
|
+
err=True,
|
|
2142
|
+
)
|
|
2143
|
+
typer.echo("", err=True)
|
|
2144
|
+
typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
|
|
2145
|
+
typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
|
|
2146
|
+
raise typer.Exit(1)
|
|
2147
|
+
|
|
2148
|
+
# Validate --target and --pool are mutually exclusive
|
|
2149
|
+
if target and pool:
|
|
2150
|
+
typer.echo("Error: Cannot specify both --target and --pool", err=True)
|
|
2151
|
+
raise typer.Exit(1)
|
|
2152
|
+
|
|
2153
|
+
from .evaluate import EvaluateArgs, run_evaluate
|
|
2154
|
+
|
|
2155
|
+
# If pool specified, acquire a target from the pool
|
|
2156
|
+
resolved_target = target or ""
|
|
2157
|
+
pool_lock_context = None
|
|
2158
|
+
|
|
2159
|
+
if pool:
|
|
2160
|
+
from .target_lock import acquire_from_pool
|
|
2161
|
+
from .targets import filter_pool_by_auth, get_pool
|
|
2162
|
+
|
|
2163
|
+
try:
|
|
2164
|
+
pool_targets = get_pool(pool)
|
|
2165
|
+
except FileNotFoundError as e:
|
|
2166
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2167
|
+
raise typer.Exit(1) from None
|
|
2168
|
+
|
|
2169
|
+
# Filter to only targets with valid auth
|
|
2170
|
+
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
2171
|
+
if skipped:
|
|
2172
|
+
typer.echo(f"Skipping targets without auth: {', '.join(skipped)}", err=True)
|
|
2173
|
+
|
|
2174
|
+
if not usable_targets:
|
|
2175
|
+
typer.echo(f"Error: No usable targets in pool '{pool}'", err=True)
|
|
2176
|
+
typer.echo(" All targets require authentication that is not configured.", err=True)
|
|
2177
|
+
typer.echo(" Run 'wafer auth status' to see which providers need setup.", err=True)
|
|
2178
|
+
raise typer.Exit(1) from None
|
|
2179
|
+
|
|
2180
|
+
typer.echo(f"Acquiring target from pool '{pool}' ({len(usable_targets)} targets)...")
|
|
2181
|
+
pool_lock_context = acquire_from_pool(usable_targets)
|
|
2182
|
+
acquired_target = pool_lock_context.__enter__()
|
|
2183
|
+
|
|
2184
|
+
if acquired_target is None:
|
|
2185
|
+
typer.echo(f"Error: All targets in pool '{pool}' are busy", err=True)
|
|
2186
|
+
typer.echo(f" Targets: {', '.join(usable_targets)}", err=True)
|
|
2187
|
+
raise typer.Exit(1)
|
|
2188
|
+
|
|
2189
|
+
typer.echo(f"Acquired target: {acquired_target}")
|
|
2190
|
+
resolved_target = acquired_target
|
|
2191
|
+
|
|
2192
|
+
args = EvaluateArgs(
|
|
2193
|
+
implementation=implementation,
|
|
2194
|
+
reference=reference,
|
|
2195
|
+
test_cases=test_cases,
|
|
2196
|
+
target_name=resolved_target,
|
|
2197
|
+
benchmark=benchmark,
|
|
2198
|
+
profile=profile,
|
|
2199
|
+
defensive=defensive,
|
|
2200
|
+
sync_artifacts=sync_artifacts,
|
|
2201
|
+
gpu_id=gpu_id,
|
|
2202
|
+
)
|
|
2203
|
+
|
|
2204
|
+
try:
|
|
2205
|
+
import trio_asyncio
|
|
2206
|
+
|
|
2207
|
+
result = trio_asyncio.run(run_evaluate, args)
|
|
2208
|
+
except KeyboardInterrupt:
|
|
2209
|
+
typer.echo("\nInterrupted by user", err=True)
|
|
2210
|
+
raise typer.Exit(130) from None
|
|
2211
|
+
except Exception as e:
|
|
2212
|
+
if hasattr(e, "exceptions") and e.exceptions:
|
|
2213
|
+
for exc in e.exceptions:
|
|
2214
|
+
typer.echo(f"Error: {type(exc).__name__}: {exc}", err=True)
|
|
2215
|
+
else:
|
|
2216
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2217
|
+
raise typer.Exit(1) from None
|
|
2218
|
+
finally:
|
|
2219
|
+
# Release pool lock if we acquired one
|
|
2220
|
+
if pool_lock_context is not None:
|
|
2221
|
+
pool_lock_context.__exit__(None, None, None)
|
|
2222
|
+
|
|
2223
|
+
# Print results
|
|
2224
|
+
if result.success:
|
|
2225
|
+
typer.echo("")
|
|
2226
|
+
typer.echo("=" * 60)
|
|
2227
|
+
status = "PASS" if result.all_correct else "FAIL"
|
|
2228
|
+
typer.echo(f"Result: {status}")
|
|
2229
|
+
score_pct = f"{result.correctness_score:.1%}"
|
|
2230
|
+
typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
|
|
2231
|
+
if result.geomean_speedup > 0:
|
|
2232
|
+
typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
|
|
2233
|
+
if result.artifact_path:
|
|
2234
|
+
typer.echo(f"Artifacts: {result.artifact_path}")
|
|
2235
|
+
typer.echo("=" * 60)
|
|
2236
|
+
|
|
2237
|
+
if not result.all_correct:
|
|
2238
|
+
raise typer.Exit(1)
|
|
2239
|
+
else:
|
|
2240
|
+
typer.echo(f"Error: {result.error_message}", err=True)
|
|
1714
2241
|
raise typer.Exit(1)
|
|
1715
2242
|
|
|
1716
|
-
# Copy the file
|
|
1717
|
-
content = problem_file.read_text()
|
|
1718
|
-
output.parent.mkdir(parents=True, exist_ok=True)
|
|
1719
|
-
output.write_text(content)
|
|
1720
|
-
|
|
1721
|
-
typer.echo(f"Created {output}")
|
|
1722
|
-
typer.echo("")
|
|
1723
|
-
typer.echo("Next steps:")
|
|
1724
|
-
typer.echo(f" 1. Read {output} to understand the Model interface")
|
|
1725
|
-
typer.echo(" 2. Create an implementation file with your ModelNew class:")
|
|
1726
|
-
typer.echo("")
|
|
1727
|
-
typer.echo(" import torch.nn as nn")
|
|
1728
|
-
typer.echo("")
|
|
1729
|
-
typer.echo(" class ModelNew(nn.Module):")
|
|
1730
|
-
typer.echo(" def __init__(self, ...):")
|
|
1731
|
-
typer.echo(" # Same signature as Model.__init__")
|
|
1732
|
-
typer.echo(" ...")
|
|
1733
|
-
typer.echo("")
|
|
1734
|
-
typer.echo(" def forward(self, ...):")
|
|
1735
|
-
typer.echo(" # Same signature as Model.forward")
|
|
1736
|
-
typer.echo(" # Your optimized implementation here")
|
|
1737
|
-
typer.echo(" ...")
|
|
1738
|
-
typer.echo("")
|
|
1739
|
-
typer.echo(" 3. Run evaluation:")
|
|
1740
|
-
typer.echo(f" wafer evaluate kernelbench --impl my_kernel.py --reference {output}")
|
|
1741
|
-
|
|
1742
2243
|
|
|
1743
2244
|
# =============================================================================
|
|
1744
2245
|
# Push and Remote-Run commands
|
|
@@ -1871,7 +2372,7 @@ def _run_direct_mode(
|
|
|
1871
2372
|
typer.echo(f"Uploading {upload_dir.name}...")
|
|
1872
2373
|
try:
|
|
1873
2374
|
push_result = push_direct(upload_dir, target)
|
|
1874
|
-
workspace_name = push_result.
|
|
2375
|
+
workspace_name = push_result.workspace_name
|
|
1875
2376
|
typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
|
|
1876
2377
|
except Exception as e:
|
|
1877
2378
|
typer.echo(f"Error uploading: {e}", err=True)
|
|
@@ -2044,27 +2545,41 @@ def login(
|
|
|
2044
2545
|
None, "--token", "-t", help="Access token (skip browser OAuth)"
|
|
2045
2546
|
),
|
|
2046
2547
|
port: int | None = typer.Option(
|
|
2047
|
-
None,
|
|
2548
|
+
None,
|
|
2549
|
+
"--port",
|
|
2550
|
+
"-p",
|
|
2551
|
+
help="Port for OAuth callback server (local only, ignored for SSH)",
|
|
2552
|
+
),
|
|
2553
|
+
no_device_code: bool = typer.Option(
|
|
2554
|
+
False,
|
|
2555
|
+
"--no-device-code",
|
|
2556
|
+
help="Force browser OAuth even on SSH (requires port forwarding)",
|
|
2048
2557
|
),
|
|
2049
2558
|
) -> None:
|
|
2050
2559
|
"""Authenticate CLI with wafer-api via GitHub OAuth.
|
|
2051
2560
|
|
|
2052
|
-
Opens browser for GitHub authentication.
|
|
2561
|
+
Local: Opens browser for GitHub authentication.
|
|
2562
|
+
SSH: Uses device code flow (no port forwarding needed).
|
|
2563
|
+
|
|
2053
2564
|
Uses the API environment from config (see 'wafer config show').
|
|
2054
2565
|
|
|
2055
|
-
SSH Users:
|
|
2056
|
-
-
|
|
2057
|
-
-
|
|
2058
|
-
-
|
|
2059
|
-
|
|
2566
|
+
SSH Users (Easiest):
|
|
2567
|
+
- Just run: wafer login
|
|
2568
|
+
- Visit the URL and enter the code shown
|
|
2569
|
+
- No port forwarding needed!
|
|
2570
|
+
|
|
2571
|
+
SSH with browser (Advanced):
|
|
2572
|
+
- Use --no-device-code to force browser flow
|
|
2573
|
+
- Requires: ssh -L 8765:localhost:8765 user@host
|
|
2060
2574
|
|
|
2061
2575
|
Manual token option:
|
|
2062
2576
|
- Visit auth.wafer.ai, authenticate, copy token from URL
|
|
2063
2577
|
- Run: wafer login --token <paste-token>
|
|
2064
2578
|
|
|
2065
2579
|
Examples:
|
|
2066
|
-
wafer login #
|
|
2067
|
-
wafer login --port
|
|
2580
|
+
wafer login # device code on SSH, browser on local
|
|
2581
|
+
wafer login --no-device-code # force browser (needs port forwarding on SSH)
|
|
2582
|
+
wafer login --port 9000 # custom port for browser flow
|
|
2068
2583
|
wafer login --token xyz # manual token (no browser)
|
|
2069
2584
|
|
|
2070
2585
|
# Change environment:
|
|
@@ -2073,7 +2588,7 @@ def login(
|
|
|
2073
2588
|
"""
|
|
2074
2589
|
import httpx
|
|
2075
2590
|
|
|
2076
|
-
from .auth import browser_login, save_credentials, verify_token
|
|
2591
|
+
from .auth import browser_login, device_code_login, save_credentials, verify_token
|
|
2077
2592
|
from .global_config import get_api_url, get_supabase_url, load_global_config
|
|
2078
2593
|
|
|
2079
2594
|
# Show which environment we're logging into
|
|
@@ -2083,21 +2598,31 @@ def login(
|
|
|
2083
2598
|
typer.echo(f"Auth: {get_supabase_url()}")
|
|
2084
2599
|
typer.echo("")
|
|
2085
2600
|
|
|
2086
|
-
# Auto-detect SSH
|
|
2087
|
-
|
|
2088
|
-
is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
|
|
2089
|
-
if is_ssh:
|
|
2090
|
-
port = 8765
|
|
2091
|
-
typer.echo("🔒 SSH session detected - using port 8765 for OAuth callback")
|
|
2092
|
-
typer.echo(" Make sure you have port forwarding set up:")
|
|
2093
|
-
typer.echo(" ssh -L 8765:localhost:8765 user@host")
|
|
2094
|
-
typer.echo("")
|
|
2601
|
+
# Auto-detect SSH
|
|
2602
|
+
is_ssh = bool(os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"))
|
|
2095
2603
|
|
|
2096
|
-
#
|
|
2604
|
+
# Choose auth method
|
|
2097
2605
|
refresh_token = None
|
|
2098
2606
|
if token is None:
|
|
2099
2607
|
try:
|
|
2100
|
-
|
|
2608
|
+
if is_ssh and not no_device_code:
|
|
2609
|
+
# Use device code flow for SSH (no port forwarding needed)
|
|
2610
|
+
typer.echo("🔒 SSH session detected - using device code authentication")
|
|
2611
|
+
typer.echo(" (No port forwarding required!)")
|
|
2612
|
+
typer.echo("")
|
|
2613
|
+
token, refresh_token = device_code_login()
|
|
2614
|
+
else:
|
|
2615
|
+
# Use browser OAuth for local or if explicitly requested
|
|
2616
|
+
if is_ssh:
|
|
2617
|
+
typer.echo("🔒 SSH session detected - using browser authentication")
|
|
2618
|
+
typer.echo(" Make sure you have port forwarding set up:")
|
|
2619
|
+
if port is None:
|
|
2620
|
+
port = 8765
|
|
2621
|
+
typer.echo(f" ssh -L {port}:localhost:{port} user@host")
|
|
2622
|
+
else:
|
|
2623
|
+
typer.echo(f" ssh -L {port}:localhost:{port} user@host")
|
|
2624
|
+
typer.echo("")
|
|
2625
|
+
token, refresh_token = browser_login(port=port)
|
|
2101
2626
|
except TimeoutError as e:
|
|
2102
2627
|
typer.echo(f"Error: {e}", err=True)
|
|
2103
2628
|
raise typer.Exit(1) from None
|
|
@@ -2146,9 +2671,8 @@ def login(
|
|
|
2146
2671
|
@app.command("logout")
|
|
2147
2672
|
def logout() -> None:
|
|
2148
2673
|
"""Remove stored credentials."""
|
|
2149
|
-
from .auth import clear_credentials
|
|
2150
|
-
|
|
2151
2674
|
from . import analytics
|
|
2675
|
+
from .auth import clear_credentials
|
|
2152
2676
|
|
|
2153
2677
|
# Track logout event first (while credentials still exist for user identification)
|
|
2154
2678
|
# Note: track_logout() handles the case where user is not logged in
|
|
@@ -2625,6 +3149,7 @@ init_app = typer.Typer(
|
|
|
2625
3149
|
|
|
2626
3150
|
Choose based on your GPU access:
|
|
2627
3151
|
|
|
3152
|
+
local GPU on current machine (no SSH)
|
|
2628
3153
|
ssh Your own hardware via SSH
|
|
2629
3154
|
runpod RunPod cloud GPUs (needs WAFER_RUNPOD_API_KEY)
|
|
2630
3155
|
digitalocean DigitalOcean AMD MI300X (needs WAFER_AMD_DIGITALOCEAN_API_KEY)"""
|
|
@@ -2632,6 +3157,92 @@ Choose based on your GPU access:
|
|
|
2632
3157
|
targets_app.add_typer(init_app, name="init")
|
|
2633
3158
|
|
|
2634
3159
|
|
|
3160
|
+
@init_app.command("local")
|
|
3161
|
+
def init_local(
|
|
3162
|
+
name: str = typer.Option("local", "--name", "-n", help="Target name"),
|
|
3163
|
+
gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
|
|
3164
|
+
) -> None:
|
|
3165
|
+
"""Initialize a local target for GPU on current machine.
|
|
3166
|
+
|
|
3167
|
+
Detects your local GPU and configures a target for direct execution
|
|
3168
|
+
(no SSH). Use this when running wafer on the same machine as the GPU.
|
|
3169
|
+
|
|
3170
|
+
Examples:
|
|
3171
|
+
wafer config targets init local
|
|
3172
|
+
wafer config targets init local --name my-5090 --gpu-ids 0,1
|
|
3173
|
+
"""
|
|
3174
|
+
from .targets import save_target
|
|
3175
|
+
|
|
3176
|
+
# Parse GPU IDs
|
|
3177
|
+
try:
|
|
3178
|
+
parsed_gpu_ids = [int(g.strip()) for g in gpu_ids.split(",")]
|
|
3179
|
+
except ValueError:
|
|
3180
|
+
typer.echo(f"Error: Invalid GPU IDs '{gpu_ids}'. Use comma-separated integers.", err=True)
|
|
3181
|
+
raise typer.Exit(1) from None
|
|
3182
|
+
|
|
3183
|
+
typer.echo("Detecting local GPU...")
|
|
3184
|
+
|
|
3185
|
+
try:
|
|
3186
|
+
from wafer_core.gpu_detect import (
|
|
3187
|
+
detect_local_gpu,
|
|
3188
|
+
get_compute_capability,
|
|
3189
|
+
get_torch_requirements,
|
|
3190
|
+
)
|
|
3191
|
+
|
|
3192
|
+
detected_gpu = detect_local_gpu()
|
|
3193
|
+
|
|
3194
|
+
if detected_gpu:
|
|
3195
|
+
typer.echo(f" Found: {detected_gpu.gpu_name}")
|
|
3196
|
+
if detected_gpu.vendor == "nvidia":
|
|
3197
|
+
typer.echo(f" CUDA: {detected_gpu.driver_version}")
|
|
3198
|
+
else:
|
|
3199
|
+
typer.echo(f" ROCm: {detected_gpu.driver_version}")
|
|
3200
|
+
typer.echo(f" GPU count: {detected_gpu.gpu_count}")
|
|
3201
|
+
|
|
3202
|
+
# Get torch requirements and compute capability
|
|
3203
|
+
torch_reqs = get_torch_requirements(detected_gpu)
|
|
3204
|
+
compute_capability = get_compute_capability(detected_gpu)
|
|
3205
|
+
gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
|
|
3206
|
+
|
|
3207
|
+
typer.echo(f" PyTorch: {torch_reqs.packages[0]}")
|
|
3208
|
+
else:
|
|
3209
|
+
typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)", err=True)
|
|
3210
|
+
raise typer.Exit(1)
|
|
3211
|
+
|
|
3212
|
+
except ImportError as e:
|
|
3213
|
+
typer.echo(f"Error: Missing dependency: {e}", err=True)
|
|
3214
|
+
raise typer.Exit(1) from None
|
|
3215
|
+
|
|
3216
|
+
# Build target data
|
|
3217
|
+
target_data = {
|
|
3218
|
+
"name": name,
|
|
3219
|
+
"type": "local",
|
|
3220
|
+
"gpu_ids": parsed_gpu_ids,
|
|
3221
|
+
"gpu_type": gpu_type,
|
|
3222
|
+
"compute_capability": compute_capability,
|
|
3223
|
+
"torch_package": torch_reqs.packages[0],
|
|
3224
|
+
"torch_index_url": torch_reqs.index_url,
|
|
3225
|
+
"vendor": detected_gpu.vendor,
|
|
3226
|
+
"driver_version": detected_gpu.driver_version,
|
|
3227
|
+
}
|
|
3228
|
+
|
|
3229
|
+
try:
|
|
3230
|
+
target = save_target(target_data)
|
|
3231
|
+
typer.echo(f"✓ Created target: {target.name}")
|
|
3232
|
+
typer.echo(" Type: Local (no SSH)")
|
|
3233
|
+
typer.echo(f" GPU IDs: {parsed_gpu_ids}")
|
|
3234
|
+
typer.echo(f" GPU Type: {gpu_type}")
|
|
3235
|
+
typer.echo(f" Compute: {compute_capability}")
|
|
3236
|
+
typer.echo(f" Torch: {torch_reqs.packages[0]}")
|
|
3237
|
+
typer.echo("")
|
|
3238
|
+
typer.echo(
|
|
3239
|
+
f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
|
|
3240
|
+
)
|
|
3241
|
+
except (ValueError, AssertionError) as e:
|
|
3242
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3243
|
+
raise typer.Exit(1) from None
|
|
3244
|
+
|
|
3245
|
+
|
|
2635
3246
|
@init_app.command("runpod")
|
|
2636
3247
|
def init_runpod(
|
|
2637
3248
|
name: str = typer.Option("runpod-mi300x", "--name", "-n", help="Target name"),
|
|
@@ -2795,23 +3406,29 @@ def init_ssh(
|
|
|
2795
3406
|
host: str = typer.Option(..., "--host", "-H", help="SSH host (user@hostname:port)"),
|
|
2796
3407
|
ssh_key: str = typer.Option("~/.ssh/id_ed25519", "--ssh-key", "-k", help="Path to SSH key"),
|
|
2797
3408
|
gpu_ids: str = typer.Option("0", "--gpu-ids", "-g", help="Comma-separated GPU IDs"),
|
|
2798
|
-
gpu_type: str = typer.Option(
|
|
2799
|
-
|
|
3409
|
+
gpu_type: str | None = typer.Option(
|
|
3410
|
+
None, "--gpu-type", help="GPU type (auto-detected if not specified)"
|
|
2800
3411
|
),
|
|
2801
3412
|
docker_image: str | None = typer.Option(
|
|
2802
3413
|
None, "--docker-image", "-d", help="Docker image (optional)"
|
|
2803
3414
|
),
|
|
2804
3415
|
ncu: bool = typer.Option(False, "--ncu/--no-ncu", help="NCU profiling available"),
|
|
3416
|
+
no_detect: bool = typer.Option(False, "--no-detect", help="Skip GPU auto-detection"),
|
|
2805
3417
|
) -> None:
|
|
2806
3418
|
"""Initialize an SSH target for your own GPU hardware.
|
|
2807
3419
|
|
|
2808
3420
|
Creates a target config for direct SSH access to a GPU machine.
|
|
2809
|
-
|
|
3421
|
+
Automatically detects GPU type and selects compatible PyTorch version.
|
|
2810
3422
|
|
|
2811
3423
|
Examples:
|
|
3424
|
+
# Auto-detect GPU (recommended)
|
|
2812
3425
|
wafer config targets init ssh --name my-gpu --host user@192.168.1.100:22
|
|
3426
|
+
|
|
3427
|
+
# Multiple GPUs with NCU profiling
|
|
2813
3428
|
wafer config targets init ssh --name lab-h100 --host ubuntu@gpu.lab.com:22 --gpu-ids 0,1 --ncu
|
|
2814
|
-
|
|
3429
|
+
|
|
3430
|
+
# Skip detection, specify manually
|
|
3431
|
+
wafer config targets init ssh --name my-gpu --host user@host:22 --gpu-type H100 --no-detect
|
|
2815
3432
|
"""
|
|
2816
3433
|
from .targets import save_target
|
|
2817
3434
|
|
|
@@ -2828,17 +3445,86 @@ def init_ssh(
|
|
|
2828
3445
|
typer.echo("Example: user@192.168.1.100:22", err=True)
|
|
2829
3446
|
raise typer.Exit(1)
|
|
2830
3447
|
|
|
3448
|
+
# Auto-detect GPU if not specified
|
|
3449
|
+
detected_gpu = None
|
|
3450
|
+
torch_package = None
|
|
3451
|
+
torch_index_url = None
|
|
3452
|
+
|
|
3453
|
+
if not no_detect:
|
|
3454
|
+
typer.echo(f"Connecting to {host}...")
|
|
3455
|
+
try:
|
|
3456
|
+
import trio
|
|
3457
|
+
import trio_asyncio
|
|
3458
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
3459
|
+
from wafer_core.gpu_detect import (
|
|
3460
|
+
detect_remote_gpu,
|
|
3461
|
+
get_compute_capability,
|
|
3462
|
+
get_torch_requirements,
|
|
3463
|
+
)
|
|
3464
|
+
|
|
3465
|
+
expanded_key = str(Path(ssh_key).expanduser())
|
|
3466
|
+
|
|
3467
|
+
async def _detect() -> None:
|
|
3468
|
+
nonlocal detected_gpu, torch_package, torch_index_url
|
|
3469
|
+
# Need trio_asyncio.open_loop() for asyncssh bridge
|
|
3470
|
+
async with trio_asyncio.open_loop():
|
|
3471
|
+
async with AsyncSSHClient(host, expanded_key) as client:
|
|
3472
|
+
detected_gpu = await detect_remote_gpu(client)
|
|
3473
|
+
|
|
3474
|
+
trio.run(_detect)
|
|
3475
|
+
|
|
3476
|
+
if detected_gpu:
|
|
3477
|
+
typer.echo(f" Found: {detected_gpu.gpu_name}")
|
|
3478
|
+
if detected_gpu.vendor == "nvidia":
|
|
3479
|
+
typer.echo(f" CUDA: {detected_gpu.driver_version}")
|
|
3480
|
+
else:
|
|
3481
|
+
typer.echo(f" ROCm: {detected_gpu.driver_version}")
|
|
3482
|
+
|
|
3483
|
+
# Get torch requirements
|
|
3484
|
+
torch_reqs = get_torch_requirements(detected_gpu)
|
|
3485
|
+
torch_package = torch_reqs.packages[0] # Just torch, not all packages
|
|
3486
|
+
torch_index_url = torch_reqs.index_url
|
|
3487
|
+
typer.echo(f" PyTorch: {torch_package}")
|
|
3488
|
+
|
|
3489
|
+
# Use detected GPU type if not specified
|
|
3490
|
+
if not gpu_type:
|
|
3491
|
+
# Extract GPU name (e.g., "H100" from "NVIDIA H100 80GB HBM3")
|
|
3492
|
+
gpu_type = _extract_gpu_type(detected_gpu.gpu_name)
|
|
3493
|
+
else:
|
|
3494
|
+
typer.echo(" No GPU detected (nvidia-smi/rocm-smi not found)")
|
|
3495
|
+
if not gpu_type:
|
|
3496
|
+
gpu_type = "H100" # Default fallback
|
|
3497
|
+
typer.echo(f" Using default: {gpu_type}")
|
|
3498
|
+
|
|
3499
|
+
except Exception as e:
|
|
3500
|
+
typer.echo(f" Detection failed: {e}", err=True)
|
|
3501
|
+
if not gpu_type:
|
|
3502
|
+
gpu_type = "H100"
|
|
3503
|
+
typer.echo(f" Using default: {gpu_type}")
|
|
3504
|
+
|
|
3505
|
+
# Fallback if no detection
|
|
3506
|
+
if not gpu_type:
|
|
3507
|
+
gpu_type = "H100"
|
|
3508
|
+
|
|
2831
3509
|
# Compute capability mappings
|
|
2832
|
-
|
|
2833
|
-
|
|
2834
|
-
|
|
2835
|
-
|
|
2836
|
-
|
|
2837
|
-
|
|
2838
|
-
|
|
2839
|
-
|
|
2840
|
-
|
|
2841
|
-
|
|
3510
|
+
if detected_gpu:
|
|
3511
|
+
from wafer_core.gpu_detect import get_compute_capability
|
|
3512
|
+
|
|
3513
|
+
compute_capability = get_compute_capability(detected_gpu)
|
|
3514
|
+
else:
|
|
3515
|
+
compute_caps = {
|
|
3516
|
+
"B200": "10.0",
|
|
3517
|
+
"H100": "9.0",
|
|
3518
|
+
"A100": "8.0",
|
|
3519
|
+
"A10": "8.6",
|
|
3520
|
+
"V100": "7.0",
|
|
3521
|
+
"MI300X": "9.4",
|
|
3522
|
+
"MI250X": "9.0",
|
|
3523
|
+
"RTX 5090": "10.0",
|
|
3524
|
+
"RTX 4090": "8.9",
|
|
3525
|
+
"RTX 3090": "8.6",
|
|
3526
|
+
}
|
|
3527
|
+
compute_capability = compute_caps.get(gpu_type, "8.0")
|
|
2842
3528
|
|
|
2843
3529
|
# Build target data
|
|
2844
3530
|
target_data = {
|
|
@@ -2855,6 +3541,12 @@ def init_ssh(
|
|
|
2855
3541
|
if docker_image:
|
|
2856
3542
|
target_data["docker_image"] = docker_image
|
|
2857
3543
|
|
|
3544
|
+
# Add torch requirements if detected
|
|
3545
|
+
if torch_package:
|
|
3546
|
+
target_data["torch_package"] = torch_package
|
|
3547
|
+
if torch_index_url:
|
|
3548
|
+
target_data["torch_index_url"] = torch_index_url
|
|
3549
|
+
|
|
2858
3550
|
try:
|
|
2859
3551
|
target = save_target(target_data)
|
|
2860
3552
|
typer.echo(f"✓ Created target: {target.name}")
|
|
@@ -2862,9 +3554,12 @@ def init_ssh(
|
|
|
2862
3554
|
typer.echo(f" Host: {host}")
|
|
2863
3555
|
typer.echo(f" GPU IDs: {parsed_gpu_ids}")
|
|
2864
3556
|
typer.echo(f" GPU Type: {gpu_type}")
|
|
3557
|
+
typer.echo(f" Compute: {compute_capability}")
|
|
2865
3558
|
typer.echo(f" NCU: {'Yes' if ncu else 'No'}")
|
|
2866
3559
|
if docker_image:
|
|
2867
3560
|
typer.echo(f" Docker: {docker_image}")
|
|
3561
|
+
if torch_package:
|
|
3562
|
+
typer.echo(f" Torch: {torch_package}")
|
|
2868
3563
|
typer.echo("")
|
|
2869
3564
|
typer.echo(
|
|
2870
3565
|
f"Usage: wafer evaluate --target {name} --impl kernel.py --reference ref.py --test-cases tests.json"
|
|
@@ -2874,6 +3569,44 @@ def init_ssh(
|
|
|
2874
3569
|
raise typer.Exit(1) from None
|
|
2875
3570
|
|
|
2876
3571
|
|
|
3572
|
+
def _extract_gpu_type(gpu_name: str) -> str:
|
|
3573
|
+
"""Extract GPU type from full GPU name.
|
|
3574
|
+
|
|
3575
|
+
Examples:
|
|
3576
|
+
"NVIDIA H100 80GB HBM3" -> "H100"
|
|
3577
|
+
"NVIDIA GeForce RTX 4090" -> "RTX 4090"
|
|
3578
|
+
"AMD Instinct MI300X OAM" -> "MI300X"
|
|
3579
|
+
"""
|
|
3580
|
+
gpu_name_upper = gpu_name.upper()
|
|
3581
|
+
|
|
3582
|
+
# Check for known GPU types
|
|
3583
|
+
known_types = [
|
|
3584
|
+
"B200",
|
|
3585
|
+
"B100",
|
|
3586
|
+
"H200",
|
|
3587
|
+
"H100",
|
|
3588
|
+
"A100",
|
|
3589
|
+
"A10",
|
|
3590
|
+
"V100",
|
|
3591
|
+
"RTX 5090",
|
|
3592
|
+
"RTX 5080",
|
|
3593
|
+
"RTX 4090",
|
|
3594
|
+
"RTX 4080",
|
|
3595
|
+
"RTX 3090",
|
|
3596
|
+
"RTX 3080",
|
|
3597
|
+
"MI300X",
|
|
3598
|
+
"MI250X",
|
|
3599
|
+
"MI100",
|
|
3600
|
+
]
|
|
3601
|
+
|
|
3602
|
+
for gpu_type in known_types:
|
|
3603
|
+
if gpu_type in gpu_name_upper:
|
|
3604
|
+
return gpu_type
|
|
3605
|
+
|
|
3606
|
+
# Fallback: return cleaned name
|
|
3607
|
+
return gpu_name.replace("NVIDIA ", "").replace("AMD ", "").strip()
|
|
3608
|
+
|
|
3609
|
+
|
|
2877
3610
|
@targets_app.command("add")
|
|
2878
3611
|
def targets_add(
|
|
2879
3612
|
file_path: Path = typer.Argument(..., help="Path to target TOML file"),
|
|
@@ -2956,6 +3689,93 @@ def targets_show(
|
|
|
2956
3689
|
raise typer.Exit(1) from None
|
|
2957
3690
|
|
|
2958
3691
|
|
|
3692
|
+
@targets_app.command("probe")
|
|
3693
|
+
def targets_probe(
|
|
3694
|
+
name: str = typer.Argument(..., help="Target name"),
|
|
3695
|
+
) -> None:
|
|
3696
|
+
"""Probe a target to discover available compilation backends.
|
|
3697
|
+
|
|
3698
|
+
Connects to the target and checks what's available:
|
|
3699
|
+
- Triton
|
|
3700
|
+
- torch.compile/inductor
|
|
3701
|
+
- HIP/hipcc or CUDA/nvcc
|
|
3702
|
+
- ROCm or CUDA version
|
|
3703
|
+
- Python packages (torch, triton, etc.)
|
|
3704
|
+
|
|
3705
|
+
Example:
|
|
3706
|
+
wafer config targets probe runpod-mi300x
|
|
3707
|
+
"""
|
|
3708
|
+
import trio
|
|
3709
|
+
|
|
3710
|
+
from .targets import ProbeError, load_target, probe_target_capabilities
|
|
3711
|
+
|
|
3712
|
+
try:
|
|
3713
|
+
target = load_target(name)
|
|
3714
|
+
except FileNotFoundError as e:
|
|
3715
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3716
|
+
raise typer.Exit(1) from None
|
|
3717
|
+
|
|
3718
|
+
typer.echo(f"Probing target: {name}...")
|
|
3719
|
+
|
|
3720
|
+
try:
|
|
3721
|
+
capabilities = trio.run(probe_target_capabilities, target)
|
|
3722
|
+
except ProbeError as e:
|
|
3723
|
+
# ProbeError already has actionable context
|
|
3724
|
+
typer.echo(f"\nError: {e}", err=True)
|
|
3725
|
+
raise typer.Exit(1) from None
|
|
3726
|
+
except Exception as e:
|
|
3727
|
+
# Unexpected errors - include type for debugging
|
|
3728
|
+
typer.echo(f"\nUnexpected error probing target: {type(e).__name__}: {e}", err=True)
|
|
3729
|
+
raise typer.Exit(1) from None
|
|
3730
|
+
|
|
3731
|
+
# Display results
|
|
3732
|
+
typer.echo(f"\nTarget: {name}")
|
|
3733
|
+
|
|
3734
|
+
if capabilities.get("gpu_name"):
|
|
3735
|
+
typer.echo(f" GPU: {capabilities['gpu_name']}")
|
|
3736
|
+
if capabilities.get("compute_capability"):
|
|
3737
|
+
typer.echo(f" Compute: {capabilities['compute_capability']}")
|
|
3738
|
+
|
|
3739
|
+
typer.echo("\n Compilation Backends:")
|
|
3740
|
+
backends = capabilities.get("backends", {})
|
|
3741
|
+
|
|
3742
|
+
# Triton
|
|
3743
|
+
triton_ver = backends.get("triton")
|
|
3744
|
+
if triton_ver:
|
|
3745
|
+
typer.echo(f" ✓ Triton: {triton_ver}")
|
|
3746
|
+
else:
|
|
3747
|
+
typer.echo(" ✗ Triton: not installed")
|
|
3748
|
+
|
|
3749
|
+
# torch.compile
|
|
3750
|
+
if triton_ver and backends.get("torch"):
|
|
3751
|
+
typer.echo(" ✓ torch.compile/inductor: available")
|
|
3752
|
+
else:
|
|
3753
|
+
typer.echo(" ✗ torch.compile/inductor: requires Triton")
|
|
3754
|
+
|
|
3755
|
+
# HIP/CUDA compiler
|
|
3756
|
+
if backends.get("hipcc"):
|
|
3757
|
+
typer.echo(f" ✓ HIP/hipcc: {backends['hipcc']}")
|
|
3758
|
+
elif backends.get("nvcc"):
|
|
3759
|
+
typer.echo(f" ✓ CUDA/nvcc: {backends['nvcc']}")
|
|
3760
|
+
else:
|
|
3761
|
+
typer.echo(" ✗ No GPU compiler found")
|
|
3762
|
+
|
|
3763
|
+
# ROCm/CUDA version
|
|
3764
|
+
if capabilities.get("rocm_version"):
|
|
3765
|
+
typer.echo(f" ROCm: {capabilities['rocm_version']}")
|
|
3766
|
+
if capabilities.get("cuda_version"):
|
|
3767
|
+
typer.echo(f" CUDA: {capabilities['cuda_version']}")
|
|
3768
|
+
|
|
3769
|
+
typer.echo("\n Python Environment:")
|
|
3770
|
+
typer.echo(f" Python: {capabilities.get('python_version', 'unknown')}")
|
|
3771
|
+
|
|
3772
|
+
packages = capabilities.get("packages", {})
|
|
3773
|
+
if packages.get("torch"):
|
|
3774
|
+
typer.echo(f" PyTorch: {packages['torch']}")
|
|
3775
|
+
if triton_ver:
|
|
3776
|
+
typer.echo(f" Triton: {triton_ver}")
|
|
3777
|
+
|
|
3778
|
+
|
|
2959
3779
|
@targets_app.command("remove")
|
|
2960
3780
|
def targets_remove(
|
|
2961
3781
|
name: str = typer.Argument(..., help="Target name"),
|
|
@@ -3086,6 +3906,92 @@ def targets_pods() -> None:
|
|
|
3086
3906
|
typer.echo()
|
|
3087
3907
|
|
|
3088
3908
|
|
|
3909
|
+
# ── Pool commands ───────────────────────────────────────────────────────────
|
|
3910
|
+
|
|
3911
|
+
|
|
3912
|
+
@targets_app.command("pool-list")
|
|
3913
|
+
def targets_pool_list() -> None:
|
|
3914
|
+
"""List all configured target pools.
|
|
3915
|
+
|
|
3916
|
+
Example:
|
|
3917
|
+
wafer config targets pool-list
|
|
3918
|
+
"""
|
|
3919
|
+
from .targets import get_pool, list_pools
|
|
3920
|
+
|
|
3921
|
+
pools = list_pools()
|
|
3922
|
+
|
|
3923
|
+
if not pools:
|
|
3924
|
+
typer.echo("No pools configured")
|
|
3925
|
+
typer.echo("")
|
|
3926
|
+
typer.echo("Define pools in ~/.wafer/config.toml:")
|
|
3927
|
+
typer.echo(" [pools.my-pool]")
|
|
3928
|
+
typer.echo(' targets = ["target-1", "target-2"]')
|
|
3929
|
+
return
|
|
3930
|
+
|
|
3931
|
+
typer.echo("Configured pools:\n")
|
|
3932
|
+
for pool_name in pools:
|
|
3933
|
+
try:
|
|
3934
|
+
targets = get_pool(pool_name)
|
|
3935
|
+
typer.echo(f" {pool_name}: {', '.join(targets)}")
|
|
3936
|
+
except Exception as e:
|
|
3937
|
+
typer.echo(f" {pool_name}: (error: {e})")
|
|
3938
|
+
|
|
3939
|
+
|
|
3940
|
+
@targets_app.command("pool-create")
|
|
3941
|
+
def targets_pool_create(
|
|
3942
|
+
name: str = typer.Argument(..., help="Pool name"),
|
|
3943
|
+
targets: list[str] = typer.Argument(..., help="Target names to include in pool"),
|
|
3944
|
+
) -> None:
|
|
3945
|
+
"""Create or update a target pool.
|
|
3946
|
+
|
|
3947
|
+
Example:
|
|
3948
|
+
wafer config targets pool-create mi300x-pool mi300x-1 mi300x-2 mi300x-3
|
|
3949
|
+
"""
|
|
3950
|
+
from .targets import save_pool
|
|
3951
|
+
|
|
3952
|
+
try:
|
|
3953
|
+
save_pool(name, targets)
|
|
3954
|
+
typer.echo(f"Pool '{name}' created with {len(targets)} targets")
|
|
3955
|
+
except FileNotFoundError as e:
|
|
3956
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3957
|
+
raise typer.Exit(1) from None
|
|
3958
|
+
|
|
3959
|
+
|
|
3960
|
+
@targets_app.command("pool-status")
|
|
3961
|
+
def targets_pool_status(
|
|
3962
|
+
name: str = typer.Argument(..., help="Pool name"),
|
|
3963
|
+
) -> None:
|
|
3964
|
+
"""Show status of targets in a pool (locked/available).
|
|
3965
|
+
|
|
3966
|
+
Example:
|
|
3967
|
+
wafer config targets pool-status mi300x-pool
|
|
3968
|
+
"""
|
|
3969
|
+
from .target_lock import get_lock_holder, is_target_locked
|
|
3970
|
+
from .targets import get_pool
|
|
3971
|
+
|
|
3972
|
+
try:
|
|
3973
|
+
targets = get_pool(name)
|
|
3974
|
+
except FileNotFoundError as e:
|
|
3975
|
+
typer.echo(f"Error: {e}", err=True)
|
|
3976
|
+
raise typer.Exit(1) from None
|
|
3977
|
+
|
|
3978
|
+
typer.echo(f"Pool '{name}' ({len(targets)} targets):\n")
|
|
3979
|
+
|
|
3980
|
+
available = 0
|
|
3981
|
+
for target_name in targets:
|
|
3982
|
+
locked = is_target_locked(target_name)
|
|
3983
|
+
if locked:
|
|
3984
|
+
pid = get_lock_holder(target_name)
|
|
3985
|
+
pid_str = f" (pid {pid})" if pid else ""
|
|
3986
|
+
typer.echo(f" [busy] {target_name}{pid_str}")
|
|
3987
|
+
else:
|
|
3988
|
+
typer.echo(f" [free] {target_name}")
|
|
3989
|
+
available += 1
|
|
3990
|
+
|
|
3991
|
+
typer.echo("")
|
|
3992
|
+
typer.echo(f"Available: {available}/{len(targets)}")
|
|
3993
|
+
|
|
3994
|
+
|
|
3089
3995
|
# =============================================================================
|
|
3090
3996
|
# Billing commands
|
|
3091
3997
|
# =============================================================================
|
|
@@ -3119,7 +4025,9 @@ def billing_usage(
|
|
|
3119
4025
|
@billing_app.command("topup")
|
|
3120
4026
|
def billing_topup(
|
|
3121
4027
|
amount: int = typer.Argument(25, help="Amount in dollars ($10-$500)"),
|
|
3122
|
-
no_browser: bool = typer.Option(
|
|
4028
|
+
no_browser: bool = typer.Option(
|
|
4029
|
+
False, "--no-browser", help="Print URL instead of opening browser"
|
|
4030
|
+
),
|
|
3123
4031
|
) -> None:
|
|
3124
4032
|
"""Add credits to your account.
|
|
3125
4033
|
|
|
@@ -3165,7 +4073,9 @@ def billing_topup(
|
|
|
3165
4073
|
|
|
3166
4074
|
@billing_app.command("portal")
|
|
3167
4075
|
def billing_portal(
|
|
3168
|
-
no_browser: bool = typer.Option(
|
|
4076
|
+
no_browser: bool = typer.Option(
|
|
4077
|
+
False, "--no-browser", help="Print URL instead of opening browser"
|
|
4078
|
+
),
|
|
3169
4079
|
) -> None:
|
|
3170
4080
|
"""Open Stripe billing portal.
|
|
3171
4081
|
|
|
@@ -3319,7 +4229,7 @@ def workspaces_exec(
|
|
|
3319
4229
|
workspace: str | None = typer.Argument(
|
|
3320
4230
|
None, help="Workspace name or ID (optional if default set)"
|
|
3321
4231
|
),
|
|
3322
|
-
command: list[str] = typer.Argument(..., help="Command to execute
|
|
4232
|
+
command: list[str] = typer.Argument(..., help="Command to execute"),
|
|
3323
4233
|
timeout: int | None = typer.Option(
|
|
3324
4234
|
None,
|
|
3325
4235
|
"--timeout",
|
|
@@ -3332,13 +4242,23 @@ def workspaces_exec(
|
|
|
3332
4242
|
"-s",
|
|
3333
4243
|
help="Sync local directory to workspace before executing",
|
|
3334
4244
|
),
|
|
4245
|
+
gpu: bool = typer.Option(False, "--gpu", help="Force GPU routing (default behavior)"),
|
|
4246
|
+
cpu: bool = typer.Option(False, "--cpu", help="Run in workspace container (no GPU)"),
|
|
4247
|
+
baremetal: bool = typer.Option(
|
|
4248
|
+
False, "--baremetal", help="Force baremetal target (for hardware counters like ncu/nsys)"
|
|
4249
|
+
),
|
|
3335
4250
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
|
|
3336
4251
|
quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
|
|
3337
4252
|
) -> None:
|
|
3338
|
-
"""Execute a command in workspace
|
|
4253
|
+
"""Execute a command in workspace.
|
|
4254
|
+
|
|
4255
|
+
By default, auto-detects whether to route to GPU based on the command.
|
|
4256
|
+
Use --gpu, --cpu, or --baremetal to override.
|
|
3339
4257
|
|
|
3340
|
-
|
|
3341
|
-
|
|
4258
|
+
Routing options:
|
|
4259
|
+
--gpu Force GPU container (Modal or baremetal with GPU)
|
|
4260
|
+
--cpu Run in workspace container directly (no GPU)
|
|
4261
|
+
--baremetal Force baremetal target (for ncu, nsys, hardware counters)
|
|
3342
4262
|
|
|
3343
4263
|
If workspace is not specified, uses the default workspace from config,
|
|
3344
4264
|
or the only workspace if you have exactly one.
|
|
@@ -3353,6 +4273,21 @@ def workspaces_exec(
|
|
|
3353
4273
|
from .global_config import get_defaults, get_preferences
|
|
3354
4274
|
from .workspaces import exec_command, resolve_workspace, sync_files
|
|
3355
4275
|
|
|
4276
|
+
# Validate mutually exclusive routing flags
|
|
4277
|
+
routing_flags = sum([gpu, cpu, baremetal])
|
|
4278
|
+
if routing_flags > 1:
|
|
4279
|
+
typer.echo("Error: --gpu, --cpu, and --baremetal are mutually exclusive", err=True)
|
|
4280
|
+
raise typer.Exit(1)
|
|
4281
|
+
|
|
4282
|
+
# Determine routing (None = auto-detect)
|
|
4283
|
+
routing: str | None = None
|
|
4284
|
+
if gpu:
|
|
4285
|
+
routing = "gpu"
|
|
4286
|
+
elif cpu:
|
|
4287
|
+
routing = "cpu"
|
|
4288
|
+
elif baremetal:
|
|
4289
|
+
routing = "baremetal"
|
|
4290
|
+
|
|
3356
4291
|
# Resolve workspace (specified, config default, or single workspace)
|
|
3357
4292
|
try:
|
|
3358
4293
|
resolved_workspace = resolve_workspace(workspace)
|
|
@@ -3377,7 +4312,8 @@ def workspaces_exec(
|
|
|
3377
4312
|
show_status = prefs.mode == "explicit"
|
|
3378
4313
|
|
|
3379
4314
|
if show_status:
|
|
3380
|
-
|
|
4315
|
+
routing_label = routing or "auto"
|
|
4316
|
+
typer.echo(f"[wafer] Workspace: {resolved_workspace} (routing: {routing_label})", err=True)
|
|
3381
4317
|
|
|
3382
4318
|
# Sync files if requested
|
|
3383
4319
|
if sync is not None:
|
|
@@ -3413,8 +4349,15 @@ def workspaces_exec(
|
|
|
3413
4349
|
# Remove leading "--" if present (typer passes it through with allow_interspersed_args=False)
|
|
3414
4350
|
if command and command[0] == "--":
|
|
3415
4351
|
command = command[1:]
|
|
3416
|
-
#
|
|
3417
|
-
|
|
4352
|
+
# Handle two cases:
|
|
4353
|
+
# 1. Single element: user quoted the whole command (e.g., "echo hello world")
|
|
4354
|
+
# -> use directly, don't re-quote
|
|
4355
|
+
# 2. Multiple elements: user passed separate args (e.g., -- python -c "print(1)")
|
|
4356
|
+
# -> use shlex.join to properly quote args with spaces
|
|
4357
|
+
if len(command) == 1:
|
|
4358
|
+
command_str = command[0]
|
|
4359
|
+
else:
|
|
4360
|
+
command_str = shlex.join(command)
|
|
3418
4361
|
else:
|
|
3419
4362
|
command_str = command
|
|
3420
4363
|
|
|
@@ -3423,6 +4366,7 @@ def workspaces_exec(
|
|
|
3423
4366
|
workspace_id=resolved_workspace,
|
|
3424
4367
|
command=command_str,
|
|
3425
4368
|
timeout_seconds=effective_timeout,
|
|
4369
|
+
routing=routing,
|
|
3426
4370
|
)
|
|
3427
4371
|
except RuntimeError as e:
|
|
3428
4372
|
typer.echo(f"Error: {e}", err=True)
|
|
@@ -4441,8 +5385,8 @@ def _setup_wafer_core_env() -> None:
|
|
|
4441
5385
|
- WAFER_API_URL: If already set, uses that instead of config
|
|
4442
5386
|
- WAFER_AUTH_TOKEN: If already set, uses that instead of cached token
|
|
4443
5387
|
"""
|
|
4444
|
-
from .global_config import get_api_url
|
|
4445
5388
|
from .auth import get_valid_token
|
|
5389
|
+
from .global_config import get_api_url
|
|
4446
5390
|
|
|
4447
5391
|
# Set API URL (get_api_url already respects WAFER_API_URL env var)
|
|
4448
5392
|
os.environ["WAFER_API_URL"] = get_api_url()
|
|
@@ -4746,8 +5690,8 @@ def capture_command( # noqa: PLR0915
|
|
|
4746
5690
|
import os
|
|
4747
5691
|
import tomllib
|
|
4748
5692
|
|
|
4749
|
-
from .global_config import get_api_url
|
|
4750
5693
|
from .auth import get_valid_token
|
|
5694
|
+
from .global_config import get_api_url
|
|
4751
5695
|
|
|
4752
5696
|
# Set environment variables for wafer-core BEFORE importing it
|
|
4753
5697
|
# wafer-core backend.py reads WAFER_API_URL and WAFER_AUTH_TOKEN from env
|
|
@@ -4951,8 +5895,8 @@ def capture_list_command(
|
|
|
4951
5895
|
"""
|
|
4952
5896
|
import os
|
|
4953
5897
|
|
|
4954
|
-
from .global_config import get_api_url
|
|
4955
5898
|
from .auth import get_valid_token
|
|
5899
|
+
from .global_config import get_api_url
|
|
4956
5900
|
|
|
4957
5901
|
# Set environment variables for wafer-core BEFORE importing it
|
|
4958
5902
|
os.environ["WAFER_API_URL"] = get_api_url()
|
|
@@ -5301,6 +6245,98 @@ def isa_analyze(
|
|
|
5301
6245
|
raise typer.Exit(1) from None
|
|
5302
6246
|
|
|
5303
6247
|
|
|
6248
|
+
# =============================================================================
|
|
6249
|
+
# Kernel Scope Commands (wafer amd kernel-scope ...)
|
|
6250
|
+
# =============================================================================
|
|
6251
|
+
|
|
6252
|
+
|
|
6253
|
+
@kernel_scope_app.command("analyze")
def kernel_scope_analyze(
    path: Path = typer.Argument(..., help="Path to file or directory to analyze"),
    json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
    csv_output: bool = typer.Option(False, "--csv", help="Output as CSV"),
    recursive: bool = typer.Option(
        True, "--recursive/--no-recursive", "-r", help="Scan directories recursively"
    ),
    filter_expr: str | None = typer.Option(
        None, "--filter", "-f", help="Filter results (e.g., 'spills > 0')"
    ),
    output_file: Path | None = typer.Option(None, "--output", "-o", help="Write output to file"),
    kernel_index: int = typer.Option(0, "--kernel", "-k", help="Kernel index if multiple in file"),
) -> None:
    """Analyze Triton compilation artifacts (ISA, LLVM-IR, TTGIR).

    Performs static analysis to extract performance metrics like register
    pressure, spills, MFMA density, and occupancy limits.

    Supports:
    - AMDGCN ISA files (.s, .gcn, .asm)
    - LLVM-IR files (.ll)
    - TTGIR files (.ttgir, .ttir, .mlir)

    Examples:
        wafer amd kernel-scope analyze kernel.s
        wafer amd kernel-scope analyze kernel.s --json
        wafer amd kernel-scope analyze ~/.triton/cache/ --filter 'spills > 0'
        wafer amd kernel-scope analyze . -r --csv -o metrics.csv
    """
    # Lazy import: keeps `wafer --help` fast and defers the dependency cost
    # until this subcommand actually runs (matches the file's convention).
    from .kernel_scope import analyze_command

    try:
        output = analyze_command(
            # analyze_command takes plain strings, not Path objects.
            path=str(path),
            json_output=json_output,
            csv_output=csv_output,
            recursive=recursive,
            filter_expr=filter_expr,
            output_file=str(output_file) if output_file else None,
            kernel_index=kernel_index,
        )
        typer.echo(output)
    except Exception as e:
        # Fix: the original had three stacked handlers (FileNotFoundError,
        # RuntimeError, Exception) whose bodies were identical — the first two
        # were redundant duplication shadowed by the catch-all. Collapsed into
        # one handler with no behavior change: same message, same exit code.
        # `from None` suppresses the traceback chain for clean CLI errors.
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
|
|
6306
|
+
|
|
6307
|
+
|
|
6308
|
+
@kernel_scope_app.command("metrics")
def kernel_scope_metrics() -> None:
    """List available metrics for kernel scope analysis.

    Shows all metrics that can be extracted from Triton compilation
    artifacts, along with their derivation.

    Examples:
        wafer amd kernel-scope metrics
    """
    # Deferred import keeps CLI startup lightweight.
    from .kernel_scope import metrics_command

    # Render the metrics table and print it straight to stdout.
    typer.echo(metrics_command())
|
|
6322
|
+
|
|
6323
|
+
|
|
6324
|
+
@kernel_scope_app.command("targets")
def kernel_scope_targets() -> None:
    """List supported GPU targets and their specifications.

    Shows hardware specs (VGPRs, SGPRs, LDS, etc.) for each supported
    AMD GPU architecture.

    Examples:
        wafer amd kernel-scope targets
    """
    # Deferred import keeps CLI startup lightweight.
    from .kernel_scope import targets_command

    # Render the target spec table and print it straight to stdout.
    typer.echo(targets_command())
|
|
6338
|
+
|
|
6339
|
+
|
|
5304
6340
|
def main() -> None:
    """Entry point for wafer CLI.

    Console-script target (see the package's entry_points): simply invokes
    the module-level Typer application, which dispatches to subcommands.
    """
    app()
|