wafer-cli 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/billing.py +6 -6
- wafer/cli.py +502 -85
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +7 -1
- wafer/evaluate.py +13 -15
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/ssh_keys.py +6 -6
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +1 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +183 -0
- wafer/wevin_cli.py +80 -9
- wafer/workspaces.py +104 -8
- wafer_cli-0.2.25.dist-info/METADATA +107 -0
- wafer_cli-0.2.25.dist-info/RECORD +45 -0
- wafer_cli-0.2.23.dist-info/METADATA +0 -16
- wafer_cli-0.2.23.dist-info/RECORD +0 -41
- {wafer_cli-0.2.23.dist-info → wafer_cli-0.2.25.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.23.dist-info → wafer_cli-0.2.25.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.23.dist-info → wafer_cli-0.2.25.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
@@ -194,11 +194,16 @@ def complete_target_name(incomplete: str) -> list[str]:
 
 # =============================================================================
 # Core subcommand groups (visible in --help)
+#
+# TODO: Further consolidate top-level commands to reduce --help surface area.
+# Candidates:
+# - compare → wafer nvidia compare or keep top-level (cross-platform)
+# - guide/skill/demo → wafer onboard {guide,skill,demo}
 # =============================================================================
 
 # Config management (includes targets as nested subcommand)
 config_app = typer.Typer(help="Manage CLI configuration and local GPU targets")
-app.add_typer(config_app, name="config")
+app.add_typer(config_app, name="config", rich_help_panel="Configuration")
 
 # Target management - nested under config
 targets_app = typer.Typer(
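The recurring change in this hunk and the ones below is the `rich_help_panel` argument, which groups sub-apps and commands under named panels in `wafer --help`. A minimal standalone sketch of the same pattern (panel names match the ones used in the diff; the command body is a placeholder, not wafer's real implementation):

```python
import typer

app = typer.Typer()

config_app = typer.Typer(help="Manage CLI configuration")
# The sub-app is listed under a "Configuration" panel instead of the default commands list.
app.add_typer(config_app, name="config", rich_help_panel="Configuration")


# Plain commands can be grouped the same way.
@app.command("roofline", rich_help_panel="Kernel Development")
def roofline() -> None:
    """Placeholder body - the real command lives in wafer/cli.py."""
    typer.echo("roofline")


if __name__ == "__main__":
    app()  # `python demo.py --help` renders the grouped panels (rich must be installed)
```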
@@ -218,7 +223,7 @@ config_app.add_typer(targets_app, name="targets")
 workspaces_app = typer.Typer(
     help="""Manage cloud GPU workspaces for remote development.
 
-Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
+Workspaces are on-demand cloud GPU environments. Requires authentication (wafer auth login).
 
 Available GPUs:
   MI300X  AMD Instinct MI300X (192GB HBM3, ROCm)
@@ -231,21 +236,21 @@ Commands:
   wafer workspaces sync dev ./project   # Sync files
   wafer workspaces delete dev           # Clean up"""
 )
-app.add_typer(workspaces_app, name="workspaces")
+app.add_typer(workspaces_app, name="workspaces", rich_help_panel="Infrastructure")
 
-# SSH Key management (BYOK - Bring Your Own Key)
+# SSH Key management (BYOK - Bring Your Own Key) - nested under config
 ssh_keys_app = typer.Typer(
     help="""Manage SSH public keys for workspace access.
 
 Register your SSH public keys here. These keys are installed in all workspaces
 you provision, enabling SSH access from any machine with your private key.
 
-  wafer ssh-keys list                                  # List registered keys
-  wafer ssh-keys add                                   # Add key (auto-detects ~/.ssh/id_ed25519.pub)
-  wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop   # Add specific key
-  wafer ssh-keys remove <key-id>                       # Remove a key"""
+  wafer config ssh-keys list                                  # List registered keys
+  wafer config ssh-keys add                                   # Add key (auto-detects ~/.ssh/id_ed25519.pub)
+  wafer config ssh-keys add ~/.ssh/id_rsa.pub --name laptop   # Add specific key
+  wafer config ssh-keys remove <key-id>                       # Remove a key"""
 )
-
+config_app.add_typer(ssh_keys_app, name="ssh-keys")
 
 # Target operations (exec/ssh/sync on configured targets)
 targets_ops_app = typer.Typer(
@@ -261,22 +266,22 @@ Useful for exploratory work, debugging, or custom scripts.
 Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
 Configure targets with: wafer config targets init ..."""
 )
-app.add_typer(targets_ops_app, name="targets")
+app.add_typer(targets_ops_app, name="targets", rich_help_panel="Infrastructure")
 
-# Billing management
+# Billing management - nested under config
 billing_app = typer.Typer(help="Manage billing, credits, and subscription")
-
+config_app.add_typer(billing_app, name="billing")
 
 # Corpus management
 corpus_app = typer.Typer(help="Download and manage GPU documentation")
-app.add_typer(corpus_app, name="corpus")
+app.add_typer(corpus_app, name="corpus", rich_help_panel="Kernel Development")
 
 # Evaluate (supports multiple kernel formats)
 evaluate_app = typer.Typer(
     help="Test kernel correctness and performance",
     invoke_without_command=True,
 )
-app.add_typer(evaluate_app, name="evaluate")
+app.add_typer(evaluate_app, name="evaluate", rich_help_panel="Kernel Development")
 
 # Nested subcommand for kernelbench format
 kernelbench_app = typer.Typer(
@@ -305,7 +310,7 @@ app.add_typer(dev_app, name="dev")
 # =============================================================================
 
 nvidia_app = typer.Typer(help="NVIDIA GPU profiling and analysis tools")
-app.add_typer(nvidia_app, name="nvidia")
+app.add_typer(nvidia_app, name="nvidia", rich_help_panel="Profiling")
 
 # NCU analysis - under nvidia
 ncu_app = typer.Typer(help="Nsight Compute profile analysis")
@@ -328,18 +333,25 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
 # =============================================================================
 
 amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
-app.add_typer(amd_app, name="amd")
+app.add_typer(amd_app, name="amd", rich_help_panel="Profiling")
 
 # Unified ISA Analyzer - supports both .co files and Triton artifacts
 isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
 amd_app.add_typer(isa_app, name="isa")
 
+# =============================================================================
+# Trace comparison (wafer compare)
+# =============================================================================
+
+compare_app = typer.Typer(help="Compare GPU traces across platforms (AMD vs NVIDIA)")
+app.add_typer(compare_app, name="compare", rich_help_panel="Profiling")
+
 # =============================================================================
 # Roofline analysis (wafer roofline)
 # =============================================================================
 
 
-@app.command("roofline")
+@app.command("roofline", rich_help_panel="Kernel Development")
 def roofline_cmd(
     gpu: str | None = typer.Option(
         None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
@@ -430,7 +442,7 @@ def roofline_cmd(
 # =============================================================================
 
 skill_app = typer.Typer(help="Manage AI coding assistant skills (Claude Code, Codex)")
-app.add_typer(skill_app, name="skill")
+app.add_typer(skill_app, name="skill", rich_help_panel="Onboarding")
 
 
 @skill_app.command("install")
@@ -594,14 +606,17 @@ def skill_status() -> None:
 
 
 # =============================================================================
-#
+# Authentication (wafer auth ...)
 # =============================================================================
 
-
-app.add_typer(
+auth_app = typer.Typer(help="Authenticate with Wafer and cloud GPU providers")
+app.add_typer(auth_app, name="auth", rich_help_panel="Configuration")
+
+providers_app = typer.Typer(help="Manage API keys for cloud GPU providers (RunPod, DigitalOcean, etc.)")
+auth_app.add_typer(providers_app, name="providers")
 
 
-@
+@providers_app.command("login")
 def provider_auth_login(
     provider: str = typer.Argument(
         ...,
@@ -620,10 +635,10 @@ def provider_auth_login(
     (e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
 
     Examples:
-        wafer auth login anthropic --api-key sk-ant-xxx
-        wafer auth login runpod --api-key rp_xxx
-        wafer auth login openai --api-key sk-xxx
-        echo $API_KEY | wafer auth login anthropic
+        wafer auth providers login anthropic --api-key sk-ant-xxx
+        wafer auth providers login runpod --api-key rp_xxx
+        wafer auth providers login openai --api-key sk-xxx
+        echo $API_KEY | wafer auth providers login anthropic
     """
     import sys
 
@@ -653,7 +668,7 @@ def provider_auth_login(
     typer.echo("Stored in: ~/.wafer/auth.json")
 
 
-@
+@providers_app.command("logout")
 def provider_auth_logout(
     provider: str = typer.Argument(
         ...,
@@ -663,8 +678,8 @@ def provider_auth_logout(
     """Remove stored API key for a cloud GPU provider.
 
     Examples:
-        wafer auth logout runpod
-        wafer auth logout digitalocean
+        wafer auth providers logout runpod
+        wafer auth providers logout digitalocean
     """
     from wafer_core.auth import PROVIDERS, remove_api_key
 
@@ -680,7 +695,7 @@ def provider_auth_logout(
         typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
 
 
-@
+@providers_app.command("status")
 def provider_auth_status() -> None:
     """Show authentication status for all cloud GPU providers.
 
@@ -688,7 +703,7 @@ def provider_auth_status() -> None:
     the keys are coming from (environment variable or auth.json).
 
     Example:
-        wafer auth status
+        wafer auth providers status
     """
    from wafer_core.auth import get_all_auth_status
 
@@ -703,7 +718,7 @@ def provider_auth_status() -> None:
             typer.echo(f"  {status.display_name}: ✓ {status.key_preview} {source_str}")
         else:
             typer.echo(f"  {status.display_name}: ✗ Not configured")
-            typer.echo(f"    Run: wafer auth login {status.provider}")
+            typer.echo(f"    Run: wafer auth providers login {status.provider}")
             typer.echo(f"    Or set: {status.key_url}")
 
     typer.echo("")
@@ -1248,7 +1263,7 @@ def config_show_legacy() -> None:
     config_show_new()
 
 
-@app.command()
+@app.command(rich_help_panel="Kernel Development")
 def agent(  # noqa: PLR0913
     prompt: str | None = typer.Argument(
         None,
@@ -1318,7 +1333,7 @@ def agent(  # noqa: PLR0913
         None,
         "--model",
         "-m",
-        help="Model override (default: claude-
+        help="Model override (default: claude-opus-4-5)",
     ),
     json_output: bool = typer.Option(
         False,
@@ -1347,6 +1362,11 @@ def agent(  # noqa: PLR0913
         "--no-sandbox",
         help="Disable OS-level sandboxing (YOU accept liability for any damage caused by the agent)",
     ),
+    no_proxy: bool = typer.Option(
+        False,
+        "--no-proxy",
+        help="Skip wafer proxy, use ANTHROPIC_API_KEY directly",
+    ),
 ) -> None:
     """AI assistant for GPU kernel development.
 
@@ -1453,6 +1473,7 @@ def agent(  # noqa: PLR0913
         template_args=parsed_template_args,
         corpus_path=corpus_path,
         no_sandbox=no_sandbox,
+        no_proxy=no_proxy,
     )
 
 
@@ -1527,7 +1548,11 @@ def evaluate(  # noqa: PLR0913
         None, "--reference", help="Path to reference kernel file"
     ),
     test_cases: Path | None = typer.Option(
-        None,
+        None,
+        "--test-cases",
+        help="Path to test cases JSON file. "
+        'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
+        "Run 'wafer evaluate make-template' to generate an example.",
     ),
     target: str | None = typer.Option(
         None,
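The new `--test-cases` help text documents the expected JSON shape. A hedged sketch of producing a file in that format; the keys beyond "name" and "seed" (here "n") are whatever the reference kernel's input generator consumes, which depends on your reference.py:

```python
import json
from pathlib import Path

# Hypothetical test_cases.json matching the format quoted by the new --test-cases help:
# [{"name": "small", "n": 1024, "seed": 42}, ...]
test_cases = [
    {"name": "small", "n": 1024, "seed": 42},
    {"name": "large", "n": 1 << 20, "seed": 7},
]
Path("test_cases.json").write_text(json.dumps(test_cases, indent=2))
```

Running `wafer evaluate make-template` (as the help string suggests) generates a worked example alongside kernel.py and reference.py.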
@@ -1557,20 +1582,20 @@ def evaluate(  # noqa: PLR0913
 
     Examples:
         # Basic correctness check
-        wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json
+        wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
 
         # With benchmarking on a specific target
-        wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
+        wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
             --target vultr-b200 --benchmark
 
         # Full evaluation with defensive timing (detects cheating)
-        wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
+        wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
             --benchmark --defensive
 
     Subcommands:
         gpumode        Use GPUMode format (functional) - RECOMMENDED
         kernelbench    Use KernelBench format (ModelNew class)
-        make-template  Generate template files for this format
+        make-template  Generate template files for this format
     """
     # If a subcommand is being invoked, skip the main evaluation logic
     if ctx.invoked_subcommand is not None:
@@ -1724,7 +1749,7 @@ def evaluate_make_template(
     typer.echo(f"  2. Edit {output_dir / 'reference.py'} with the ground truth + input generator")
     typer.echo(f"  3. Edit {output_dir / 'test_cases.json'} with your test parameters")
     typer.echo("  4. Run:")
-    typer.echo(f"     wafer evaluate --impl {output_dir / 'kernel.py'} \\")
+    typer.echo(f"     wafer evaluate gpumode --impl {output_dir / 'kernel.py'} \\")
     typer.echo(f"       --reference {output_dir / 'reference.py'} \\")
     typer.echo(f"       --test-cases {output_dir / 'test_cases.json'} --benchmark")
 
@@ -2275,7 +2300,11 @@ def gpumode_evaluate(  # noqa: PLR0913, PLR0915
         None, "--reference", help="Path to reference kernel file"
     ),
     test_cases: Path | None = typer.Option(
-        None,
+        None,
+        "--test-cases",
+        help="Path to test cases JSON file. "
+        'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
+        "Run 'wafer evaluate make-template' to generate an example.",
     ),
     target: str | None = typer.Option(
         None,
@@ -2343,6 +2372,13 @@ def gpumode_evaluate(  # noqa: PLR0913, PLR0915
             err=True,
         )
         typer.echo("", err=True)
+        if "--test-cases" in missing_args:
+            typer.echo(
+                "Tip: Run 'wafer evaluate make-template' to generate template files "
+                "including test_cases.json.",
+                err=True,
+            )
+            typer.echo("", err=True)
         typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
         typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
         raise typer.Exit(1)
@@ -2749,7 +2785,7 @@ def remote_run(  # noqa: PLR0913
 # =============================================================================
 
 
-@
+@auth_app.command("login")
 def login(
     token: str | None = typer.Option(
         None, "--token", "-t", help="Access token (skip browser OAuth)"
@@ -2774,7 +2810,7 @@ def login(
     Uses the API environment from config (see 'wafer config show').
 
     SSH Users (Easiest):
-    - Just run: wafer login
+    - Just run: wafer auth login
    - Visit the URL and enter the code shown
    - No port forwarding needed!
 
@@ -2784,17 +2820,17 @@ def login(
 
     Manual token option:
     - Visit auth.wafer.ai, authenticate, copy token from URL
-    - Run: wafer login --token <paste-token>
+    - Run: wafer auth login --token <paste-token>
 
     Examples:
-        wafer login                    # device code on SSH, browser on local
-        wafer login --no-device-code   # force browser (needs port forwarding on SSH)
-        wafer login --port 9000        # custom port for browser flow
-        wafer login --token xyz        # manual token (no browser)
+        wafer auth login                    # device code on SSH, browser on local
+        wafer auth login --no-device-code   # force browser (needs port forwarding on SSH)
+        wafer auth login --port 9000        # custom port for browser flow
+        wafer auth login --token xyz        # manual token (no browser)
 
         # Change environment:
         wafer config set api.environment staging
-        wafer login
+        wafer auth login
     """
     import httpx
 
@@ -2878,7 +2914,7 @@ def login(
     typer.echo("Token saved to ~/.wafer/credentials.json")
 
 
-@
+@auth_app.command("logout")
 def logout() -> None:
     """Remove stored credentials."""
     from . import analytics
@@ -2895,7 +2931,7 @@ def logout() -> None:
         typer.echo("Not logged in (no credentials found).")
 
 
-@
+@auth_app.command("whoami")
 def whoami(
     verify: bool = typer.Option(False, "--verify", "-v", help="Verify token with API"),
     refresh: bool = typer.Option(False, "--refresh", "-r", help="Refresh token if expired"),
@@ -2909,7 +2945,7 @@ def whoami(
 
     creds = load_credentials()
     if creds is None:
-        typer.echo("Not logged in. Run: wafer login")
+        typer.echo("Not logged in. Run: wafer auth login")
         raise typer.Exit(1)
 
     if verify or refresh:
@@ -2917,7 +2953,7 @@ def whoami(
         # Try to get valid token with auto-refresh
         token = get_valid_token()
         if token is None:
-            typer.echo("Token expired and refresh failed. Run: wafer login", err=True)
+            typer.echo("Token expired and refresh failed. Run: wafer auth login", err=True)
             raise typer.Exit(1)
         if token != creds.access_token:
             typer.echo("Token refreshed successfully")
@@ -2930,10 +2966,10 @@ def whoami(
     except Exception as e:
         if creds.refresh_token and not refresh:
             typer.echo(f"Token expired: {e}", err=True)
-            typer.echo("Try: wafer whoami --refresh", err=True)
+            typer.echo("Try: wafer auth whoami --refresh", err=True)
         else:
             typer.echo(f"Token invalid or expired: {e}", err=True)
-            typer.echo("Run: wafer login", err=True)
+            typer.echo("Run: wafer auth login", err=True)
         raise typer.Exit(1) from None
     elif creds.email:
         typer.echo(creds.email)
@@ -2941,7 +2977,7 @@ def whoami(
         typer.echo("Logged in (email not available)")
 
 
-@app.command("guide")
+@app.command("guide", rich_help_panel="Onboarding")
 def guide() -> None:
     """Show the Wafer CLI usage guide.
 
@@ -2972,7 +3008,7 @@ demo_app = typer.Typer(
   wafer demo trace    Analyze a sample performance trace
   wafer demo eval     Run kernel evaluation on cloud GPU (requires login)"""
 )
-app.add_typer(demo_app, name="demo")
+app.add_typer(demo_app, name="demo", rich_help_panel="Onboarding")
 
 DEMO_TRACES_URL = "https://github.com/wafer-ai/wafer/raw/main/apps/wafer-cli/wafer/demo_data"
 DEMO_DIR = Path.home() / ".cache" / "wafer" / "demo"
@@ -3192,7 +3228,7 @@ def demo_eval(
     """Demo: Evaluate a kernel on a cloud GPU.
 
     Creates a workspace, runs a sample Triton kernel evaluation, and cleans up.
-    Requires authentication (wafer login).
+    Requires authentication (wafer auth login).
 
     Example:
         wafer demo eval
@@ -3207,7 +3243,7 @@ def demo_eval(
     # Check auth first
     creds = load_credentials()
     if not creds:
-        typer.echo("Error: Not authenticated. Run: wafer login")
+        typer.echo("Error: Not authenticated. Run: wafer auth login")
         raise typer.Exit(1)
 
     if not yes:
@@ -3856,12 +3892,16 @@ def targets_add(
 
 @targets_app.command("list")
 def targets_list() -> None:
-    """List all configured targets.
+    """List all configured targets with live provider status.
 
     Example:
         wafer config targets list
     """
-
+    import socket
+
+    import trio
+
+    from .targets import get_default_target, list_targets, load_target, remove_target
 
     targets = list_targets()
     default = get_default_target()
@@ -3871,10 +3911,146 @@ def targets_list() -> None:
         typer.echo("Add one with: wafer config targets add <path/to/target.toml>")
         return
 
+    def _parse_ssh_target(ssh_target: str) -> tuple[str, int]:
+        """Extract (host, port) from user@host:port string."""
+        parts = ssh_target.rsplit(":", 1)
+        host_part = parts[0]
+        port = int(parts[1]) if len(parts) > 1 else 22
+        if "@" in host_part:
+            host = host_part.split("@", 1)[1]
+        else:
+            host = host_part
+        return (host, port)
+
+    async def _get_live_provider_endpoints() -> set[tuple[str, int]]:
+        """Query RunPod + DO APIs. Returns set of live (ip, port) endpoints."""
+        from wafer_core.targets.digitalocean import list_running_droplets
+        from wafer_core.targets.runpod import sync_pods_from_api
+
+        live_endpoints: set[tuple[str, int]] = set()
+
+        async def _fetch_runpod() -> None:
+            try:
+                pods = await sync_pods_from_api()
+                for p in pods:
+                    live_endpoints.add((p.public_ip, p.ssh_port))
+            except Exception:
+                pass
+
+        async def _fetch_do() -> None:
+            try:
+                droplets = await list_running_droplets()
+                for d in droplets:
+                    live_endpoints.add((d.public_ip, d.ssh_port))
+            except Exception:
+                pass
+
+        async with trio.open_nursery() as nursery:
+            nursery.start_soon(_fetch_runpod)
+            nursery.start_soon(_fetch_do)
+
+        return live_endpoints
+
+    async def _get_target_status(
+        name: str,
+        live_endpoints: set[tuple[str, int]],
+    ) -> tuple[str, str, str]:
+        """Returns (name, status, ssh_info)."""
+        from wafer_core.targets.digitalocean import (
+            _remove_droplet_from_state,
+            check_droplet_running,
+            get_droplet_state,
+        )
+        from wafer_core.targets.runpod import (
+            _remove_pod_from_state,
+            check_pod_running,
+            get_pod_state,
+        )
+        from wafer_core.utils.kernel_utils.targets.config import (
+            BaremetalTarget,
+            DigitalOceanTarget,
+            ModalTarget,
+            RunPodTarget,
+        )
+
+        try:
+            target = load_target(name)
+        except (FileNotFoundError, ValueError, AssertionError, TypeError):
+            return (name, "error", "")
+
+        if isinstance(target, RunPodTarget):
+            pod = get_pod_state(name)
+            if not pod:
+                return (name, "no instance", "")
+            if await check_pod_running(pod.pod_id):
+                return (name, "running", f"{pod.ssh_username}@{pod.public_ip}:{pod.ssh_port}")
+            _remove_pod_from_state(name)
+            return (name, "stopped", "")
+
+        if isinstance(target, DigitalOceanTarget):
+            droplet = get_droplet_state(name)
+            if not droplet:
+                return (name, "no instance", "")
+            if await check_droplet_running(droplet.droplet_id):
+                return (
+                    name,
+                    "running",
+                    f"{droplet.ssh_username}@{droplet.public_ip}:{droplet.ssh_port}",
+                )
+            _remove_droplet_from_state(name)
+            return (name, "stopped", "")
+
+        if isinstance(target, BaremetalTarget):
+            ssh_target = target.ssh_target
+            host, port = _parse_ssh_target(ssh_target)
+
+            def _tcp_check() -> bool:
+                try:
+                    sock = socket.create_connection((host, port), timeout=2)
+                    sock.close()
+                    return True
+                except OSError:
+                    return False
+
+            reachable = await trio.to_thread.run_sync(_tcp_check)
+            if reachable:
+                return (name, "reachable", ssh_target)
+
+            # Unreachable + has a provider = backed by an ephemeral instance.
+            # If not in the live provider listing, the instance is gone — remove config.
+            if target.provider and (host, port) not in live_endpoints:
+                remove_target(name)
+                return (name, "removed (dead pod)", ssh_target)
+
+            return (name, "unreachable", ssh_target)
+
+        if isinstance(target, ModalTarget):
+            return (name, "serverless", "")
+
+        # Unknown target type
+        return (name, "unknown", "")
+
+    async def _gather_statuses() -> list[tuple[str, str, str]]:
+        live_endpoints = await _get_live_provider_endpoints()
+        results: list[tuple[str, str, str]] = [("", "", "")] * len(targets)
+
+        async def _check(i: int, name: str) -> None:
+            results[i] = await _get_target_status(name, live_endpoints)
+
+        async with trio.open_nursery() as nursery:
+            for i, name in enumerate(targets):
+                nursery.start_soon(_check, i, name)
+
+        return results
+
+    statuses = trio.run(_gather_statuses)
+
     typer.echo("Configured targets:")
-    for name in
+    for name, status, ssh_info in statuses:
         marker = " (default)" if name == default else ""
-
+        label = f"  {name}{marker}"
+        detail = f"  {ssh_info}" if ssh_info else ""
+        typer.echo(f"{label:<40}{status}{detail}")
 
 
 @targets_app.command("show")
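The new `targets list` probes every configured target concurrently inside a trio nursery and parses `user@host:port` strings with `_parse_ssh_target`. A standalone sketch of those two ideas, independent of wafer's provider modules (the hostnames below are placeholders and `parse_ssh_target` restates the parsing rule from the hunk rather than importing it):

```python
import socket

import trio


def parse_ssh_target(ssh_target: str) -> tuple[str, int]:
    # Same rule as the helper above: optional "user@" prefix, optional ":port", default 22.
    host_part, _, port = ssh_target.rpartition(":")
    if not host_part:  # no ":" present
        host_part, port = ssh_target, ""
    host = host_part.split("@", 1)[-1]
    return (host, int(port) if port else 22)


async def check_all(ssh_targets: list[str]) -> list[bool]:
    results = [False] * len(ssh_targets)

    async def probe(i: int, ssh_target: str) -> None:
        host, port = parse_ssh_target(ssh_target)

        def tcp_check() -> bool:
            try:
                socket.create_connection((host, port), timeout=2).close()
                return True
            except OSError:
                return False

        # Blocking socket call runs in a worker thread, as in the diff.
        results[i] = await trio.to_thread.run_sync(tcp_check)

    async with trio.open_nursery() as nursery:  # all probes run concurrently
        for i, t in enumerate(ssh_targets):
            nursery.start_soon(probe, i, t)
    return results


print(trio.run(check_all, ["root@203.0.113.7:2222", "gpubox.example.com"]))
```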
@@ -4089,10 +4265,19 @@ def targets_cleanup(
 # Known libraries that can be installed on targets
 # TODO: Consider adding HipKittens to the default RunPod/DO Docker images
 # so this install step isn't needed. For now, this command handles it.
+# Architecture → branch mapping for libraries that ship per-arch branches.
+# "default" is used when the detected arch has no explicit entry.
+_ARCH_BRANCHES: dict[str, dict[str, str]] = {
+    "hipkittens": {
+        "gfx942": "cdna3",  # MI300X, MI325X
+        "default": "main",  # MI350X, MI355X, and future CDNA4+
+    },
+}
+
 INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
     "hipkittens": {
-        "description": "HipKittens - AMD port of ThunderKittens
-        "git_url": "https://github.com/HazyResearch/
+        "description": "HipKittens - AMD port of ThunderKittens",
+        "git_url": "https://github.com/HazyResearch/HipKittens.git",
         "install_path": "/opt/hipkittens",
         "requires_amd": True,
     },
@@ -4105,6 +4290,38 @@ INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
 }
 
 
+def _resolve_gfx_arch(target: object, ssh_cmd: list[str]) -> str | None:
+    """Return the gfx architecture string for *target*.
+
+    1. If the target config already carries a compute_capability, map it.
+    2. Otherwise SSH in and probe with ``rocminfo``.
+    Returns None only if detection fails entirely.
+    """
+    import subprocess
+
+    from .evaluate import AMD_CC_TO_ARCH
+
+    cc = getattr(target, "compute_capability", None)
+    if cc and cc in AMD_CC_TO_ARCH:
+        return AMD_CC_TO_ARCH[cc]
+
+    typer.echo("  Detecting GPU architecture via rocminfo...")
+    probe_script = "rocminfo 2>/dev/null | grep -oP 'gfx\\d+' | head -1"
+    result = subprocess.run(
+        ssh_cmd + [probe_script],
+        capture_output=True,
+        text=True,
+        timeout=30,
+    )
+    arch = result.stdout.strip()
+    if result.returncode == 0 and arch.startswith("gfx"):
+        typer.echo(f"  Detected: {arch}")
+        return arch
+
+    typer.echo("  Warning: could not detect GPU architecture", err=True)
+    return None
+
+
 @targets_app.command("install")
 def targets_install(
     name: str = typer.Argument(..., help="Target name"),
@@ -4115,6 +4332,9 @@ def targets_install(
     Installs header-only libraries like HipKittens on remote targets.
     Safe to run multiple times - will skip if already installed.
 
+    For libraries with per-architecture branches (e.g. HipKittens), the
+    correct branch is selected automatically based on the target's GPU.
+
     Available libraries:
         hipkittens      - HipKittens (AMD ThunderKittens port)
         repair-headers  - Fix ROCm thrust headers (after hipify corruption)
@@ -4188,14 +4408,22 @@ def targets_install(
     install_path = lib_info["install_path"]
     git_url = lib_info["git_url"]
 
-    #
+    # Resolve the branch for arch-aware libraries
+    branch = "main"
+    arch_map = _ARCH_BRANCHES.get(library)
+    if arch_map:
+        gfx = await trio.to_thread.run_sync(lambda: _resolve_gfx_arch(target, ssh_cmd))
+        branch = arch_map.get(gfx, arch_map["default"]) if gfx else arch_map["default"]
+        typer.echo(f"  Branch: {branch} (arch={gfx or 'unknown'})")
+
+    # Idempotent: if already cloned, ensure correct branch & pull
     install_script = f"""
 if [ -d "{install_path}" ]; then
     echo "ALREADY_INSTALLED: {install_path} exists"
-    cd {install_path} && git pull --quiet
+    cd {install_path} && git fetch --quiet origin && git checkout {branch} --quiet && git pull --quiet origin {branch}
 else
     echo "INSTALLING: cloning to {install_path}"
-    git clone --quiet {git_url} {install_path}
+    git clone --quiet --branch {branch} {git_url} {install_path}
 fi
 echo "DONE"
 """
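The branch selection added here maps the detected `gfx` architecture to a HipKittens branch through `_ARCH_BRANCHES`, falling back to the "default" entry when the arch is unmapped or detection fails. A standalone sketch of that lookup (the table values are copied from the diff; the `resolve_branch` helper and the arch strings in the asserts are illustrative):

```python
# Shape of the _ARCH_BRANCHES table from the diff; the lookup mirrors the added logic.
ARCH_BRANCHES = {
    "hipkittens": {
        "gfx942": "cdna3",   # MI300X, MI325X
        "default": "main",   # other CDNA3+/CDNA4 parts
    },
}


def resolve_branch(library: str, gfx: str | None) -> str:
    arch_map = ARCH_BRANCHES.get(library)
    if not arch_map:
        return "main"  # libraries without per-arch branches clone main
    return arch_map.get(gfx, arch_map["default"]) if gfx else arch_map["default"]


assert resolve_branch("hipkittens", "gfx942") == "cdna3"  # mapped arch → cdna3 branch
assert resolve_branch("hipkittens", "gfx999") == "main"   # unmapped arch → default
assert resolve_branch("hipkittens", None) == "main"       # detection failed → default
```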
@@ -4373,8 +4601,8 @@ def billing_usage(
     """Show current billing usage and subscription info.
 
     Example:
-        wafer billing
-        wafer billing --json
+        wafer config billing
+        wafer config billing --json
     """
     # Only show usage if no subcommand was invoked
     if ctx.invoked_subcommand is not None:
@@ -4402,9 +4630,9 @@ def billing_topup(
     Opens a Stripe checkout page to add credits. Default amount is $25.
 
     Example:
-        wafer billing topup              # Add $25
-        wafer billing topup 100          # Add $100
-        wafer billing topup --no-browser # Print URL instead
+        wafer config billing topup              # Add $25
+        wafer config billing topup 100          # Add $100
+        wafer config billing topup --no-browser # Print URL instead
     """
     import webbrowser
 
@@ -4450,8 +4678,8 @@ def billing_portal(
     Manage your subscription, update payment method, or view invoices.
 
     Example:
-        wafer billing portal
-        wafer billing portal --no-browser
+        wafer config billing portal
+        wafer config billing portal --no-browser
     """
     import webbrowser
 
@@ -4488,8 +4716,8 @@ def ssh_keys_list(
     """List all registered SSH public keys.
 
     Example:
-        wafer ssh-keys list
-        wafer ssh-keys list --json
+        wafer config ssh-keys list
+        wafer config ssh-keys list --json
     """
     from .ssh_keys import list_ssh_keys
 
@@ -4515,9 +4743,9 @@ def ssh_keys_add(
     id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.
 
     Example:
-        wafer ssh-keys add                        # Auto-detect
-        wafer ssh-keys add ~/.ssh/id_rsa.pub      # Specific file
-        wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
+        wafer config ssh-keys add                        # Auto-detect
+        wafer config ssh-keys add ~/.ssh/id_rsa.pub      # Specific file
+        wafer config ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
     """
     from .ssh_keys import add_ssh_key
 
@@ -4536,10 +4764,10 @@ def ssh_keys_remove(
 ) -> None:
     """Remove an SSH public key.
 
-    Get the key ID from 'wafer ssh-keys list'.
+    Get the key ID from 'wafer config ssh-keys list'.
 
     Example:
-        wafer ssh-keys remove abc123-def456-...
+        wafer config ssh-keys remove abc123-def456-...
     """
     from .ssh_keys import remove_ssh_key
 
@@ -4975,6 +5203,57 @@ def workspaces_sync(
         raise typer.Exit(1) from None
 
 
+@workspaces_app.command("pull")
+def workspaces_pull(
+    workspace: str = typer.Argument(..., help="Workspace name or ID"),
+    remote_path: str = typer.Argument(
+        ..., help="Remote path in workspace (relative to /workspace or absolute)"
+    ),
+    local_path: Path = typer.Argument(
+        Path("."), help="Local destination path (default: current directory)"
+    ),
+    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show [wafer] status messages"),
+    quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress [wafer] status messages"),
+) -> None:
+    """Pull files from workspace to local machine.
+
+    Uses rsync over SSH to download files from the workspace's /workspace directory.
+
+    Examples:
+        wafer workspaces pull dev kernel.py ./               # Pull single file
+        wafer workspaces pull dev kernel.py ./my_kernel.py   # Pull and rename
+        wafer workspaces pull dev /workspace/results ./      # Pull directory
+    """
+    from .global_config import get_preferences
+    from .workspaces import pull_files
+
+    # Determine verbosity based on mode
+    prefs = get_preferences()
+    if quiet:
+        show_status = False
+    elif verbose:
+        show_status = True
+    else:
+        show_status = prefs.mode == "explicit"
+
+    if show_status:
+        typer.echo(f"[wafer] Pulling {remote_path} from workspace {workspace}...", err=True)
+
+    def on_progress(msg: str) -> None:
+        if show_status:
+            typer.echo(f"[wafer] {msg}", err=True)
+
+    try:
+        file_count = pull_files(
+            workspace, remote_path, local_path.resolve(), on_progress=on_progress
+        )
+        if show_status:
+            typer.echo(f"[wafer] Pulled {file_count} files to {local_path}", err=True)
+    except RuntimeError as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(1) from None
+
+
 # =============================================================================
 # Target operations commands (exec/ssh/sync)
 # =============================================================================
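The new `workspaces pull` subcommand is a thin wrapper around `pull_files` from `wafer/workspaces.py`. A hedged sketch of driving the same helper from a script, assuming only the call shape visible in the hunk (positional workspace, remote path, resolved local path, `on_progress` callback, integer file count returned, `RuntimeError` on failure) and that the module is importable as `wafer.workspaces`:

```python
# Sketch only: signature and error behavior assumed from the diff above.
from pathlib import Path

from wafer.workspaces import pull_files


def log(msg: str) -> None:
    print(f"[wafer] {msg}")


try:
    count = pull_files("dev", "results", Path("./results").resolve(), on_progress=log)
    print(f"pulled {count} files")
except RuntimeError as err:
    print(f"pull failed: {err}")
```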
@@ -5733,7 +6012,7 @@ def ncu_analyze(
     compute/memory throughput, and optimization recommendations.
 
     By default, uses local NCU if available, otherwise runs analysis
-    remotely via wafer-api (requires authentication: wafer login).
+    remotely via wafer-api (requires authentication: wafer auth login).
 
     Use --target for direct SSH mode (like wafer remote-run --direct).
     Use --include-source to fetch SASS assembly with register/instruction data.
@@ -5828,7 +6107,7 @@ def nsys_analyze(
     Returns timeline events, kernel information, memory usage, and diagnostics.
 
     By default, uses local nsys if available, otherwise runs analysis
-    remotely via wafer-api (requires authentication: wafer login).
+    remotely via wafer-api (requires authentication: wafer auth login).
 
     Supports multiple execution modes:
     - Local: Uses local nsys CLI (no GPU required for analysis)
@@ -6813,7 +7092,7 @@ def autotuner_results(
     raise typer.Exit(1) from None
 
 
-@app.command("capture")
+@app.command("capture", rich_help_panel="Kernel Development")
 def capture_command(  # noqa: PLR0915
     label: str = typer.Argument(
         ..., help="Label for this capture (e.g., 'baseline', 'optimized-v2')"
@@ -7478,6 +7757,144 @@ def isa_targets() -> None:
     typer.echo(output)
 
 
+# =============================================================================
+# Trace comparison commands
+# =============================================================================
+
+
+@compare_app.command("analyze")
+def compare_analyze(
+    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
+    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
+    format: str = typer.Option(
+        "text",
+        "--format",
+        "-f",
+        help="Output format: text, text-layers, csv, csv-layers, json",
+    ),
+    output: Path | None = typer.Option(
+        None, "--output", "-o", help="Output file (default: stdout)"
+    ),
+    phase: str = typer.Option(
+        "all",
+        "--phase",
+        help="Filter by phase: all, prefill, decode",
+    ),
+    layers: bool = typer.Option(False, "--layers", help="Show layer-wise performance breakdown"),
+    all: bool = typer.Option(
+        False, "--all", help="Show all items (no truncation for layers, operations, kernels)"
+    ),
+    stack_traces: bool = typer.Option(
+        False, "--stack-traces", help="Show Python stack traces for operations"
+    ),
+    json: bool = typer.Option(
+        False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
+    ),
+) -> None:
+    """Compare GPU traces from AMD and NVIDIA platforms.
+
+    Analyzes performance differences between traces, identifying which operations
+    are faster/slower on each platform and providing kernel-level details.
+
+    Examples:
+        # Basic comparison (stdout)
+        wafer compare analyze amd_trace.json nvidia_trace.json
+
+        # Show layer-wise breakdown
+        wafer compare analyze amd_trace.json nvidia_trace.json --layers
+        wafer compare analyze amd_trace.json nvidia_trace.json --format text-layers
+
+        # Show all layers without truncation
+        wafer compare analyze amd_trace.json nvidia_trace.json --layers --all
+
+        # Show Python stack traces
+        wafer compare analyze amd_trace.json nvidia_trace.json --stack-traces
+
+        # Show all stack traces without truncation
+        wafer compare analyze amd_trace.json nvidia_trace.json --stack-traces --all
+
+        # Save to file
+        wafer compare analyze amd_trace.json nvidia_trace.json -o report.txt
+
+        # CSV output (operations) to file
+        wafer compare analyze amd_trace.json nvidia_trace.json --format csv -o operations.csv
+
+        # CSV output (layers) to file
+        wafer compare analyze amd_trace.json nvidia_trace.json --format csv-layers -o layers.csv
+
+        # JSON output to file
+        wafer compare analyze amd_trace.json nvidia_trace.json --format json -o report.json
+
+        # Analyze only prefill phase
+        wafer compare analyze amd_trace.json nvidia_trace.json --phase prefill
+    """
+    from .trace_compare import compare_traces
+
+    compare_traces(
+        trace1=trace1,
+        trace2=trace2,
+        output=output,
+        output_format=format,
+        phase=phase,
+        show_layers=layers,
+        show_all=all,
+        show_stack_traces=stack_traces,
+    )
+    _mark_command_success()
+
+
+@compare_app.command("fusion")
+def compare_fusion_cmd(
+    trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
+    trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
+    format: str = typer.Option(
+        "text",
+        "--format",
+        "-f",
+        help="Output format: text, csv, json",
+    ),
+    output: Path | None = typer.Option(
+        None, "--output", "-o", help="Output file (default: stdout)"
+    ),
+    min_group_size: int = typer.Option(
+        50,
+        "--min-group-size",
+        help="Minimum correlation group size to analyze",
+    ),
+    json: bool = typer.Option(
+        False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
+    ),
+) -> None:
+    """Analyze kernel fusion differences between AMD and NVIDIA traces.
+
+    Detects which operations are fused differently on each platform by analyzing
+    how many kernel launches each platform uses for the same logical operations.
+
+    Examples:
+        # Basic fusion analysis (stdout)
+        wafer compare fusion amd_trace.json nvidia_trace.json
+
+        # Save to file
+        wafer compare fusion amd_trace.json nvidia_trace.json -o fusion_report.txt
+
+        # JSON output to file
+        wafer compare fusion amd_trace.json nvidia_trace.json --format json -o fusion.json
+
+        # CSV output to file
+        wafer compare fusion amd_trace.json nvidia_trace.json --format csv -o fusion.csv
+    """
+    from .trace_compare import compare_fusion
+
+    compare_fusion(
+        trace1=trace1,
+        trace2=trace2,
+        output=output,
+        format_type=format,
+        min_group_size=min_group_size,
+    )
+    _mark_command_success()
+
+
 def main() -> None:
     """Entry point for wafer CLI."""
    app()