wafer-cli 0.2.32__py3-none-any.whl → 0.2.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +157 -2
- wafer/billing.py +6 -6
- wafer/cli.py +432 -348
- wafer/corpus.py +6 -72
- wafer/evaluate.py +143 -81
- wafer/global_config.py +0 -13
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +6 -22
- wafer/ssh_keys.py +6 -6
- wafer/targets_ops.py +2 -29
- wafer/templates/aiter_optimize.py +59 -0
- wafer/templates/optimize_kernel.py +2 -4
- wafer/templates/optimize_kernelbench.py +62 -17
- wafer/templates/optimize_vllm.py +156 -0
- wafer/trace_compare.py +48 -139
- wafer/wevin_cli.py +1 -12
- wafer/workspaces.py +8 -8
- wafer_cli-0.2.33.dist-info/METADATA +260 -0
- {wafer_cli-0.2.32.dist-info → wafer_cli-0.2.33.dist-info}/RECORD +25 -23
- wafer_cli-0.2.32.dist-info/METADATA +0 -107
- {wafer_cli-0.2.32.dist-info → wafer_cli-0.2.33.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.32.dist-info → wafer_cli-0.2.33.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.32.dist-info → wafer_cli-0.2.33.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
|
@@ -8,7 +8,6 @@
|
|
|
8
8
|
Core commands:
|
|
9
9
|
agent AI assistant for GPU kernel development
|
|
10
10
|
evaluate Test kernel correctness and performance
|
|
11
|
-
baseline Discover what kernel PyTorch uses for an op
|
|
12
11
|
corpus Download GPU documentation for local access
|
|
13
12
|
workspaces Manage cloud GPU environments
|
|
14
13
|
|
|
@@ -195,16 +194,11 @@ def complete_target_name(incomplete: str) -> list[str]:
|
|
|
195
194
|
|
|
196
195
|
# =============================================================================
|
|
197
196
|
# Core subcommand groups (visible in --help)
|
|
198
|
-
#
|
|
199
|
-
# TODO: Further consolidate top-level commands to reduce --help surface area.
|
|
200
|
-
# Candidates:
|
|
201
|
-
# - compare → wafer nvidia compare or keep top-level (cross-platform)
|
|
202
|
-
# - guide/skill/demo → wafer onboard {guide,skill,demo}
|
|
203
197
|
# =============================================================================
|
|
204
198
|
|
|
205
199
|
# Config management (includes targets as nested subcommand)
|
|
206
200
|
config_app = typer.Typer(help="Manage CLI configuration and local GPU targets")
|
|
207
|
-
app.add_typer(config_app, name="config"
|
|
201
|
+
app.add_typer(config_app, name="config")
|
|
208
202
|
|
|
209
203
|
# Target management - nested under config
|
|
210
204
|
targets_app = typer.Typer(
|
|
@@ -224,7 +218,7 @@ config_app.add_typer(targets_app, name="targets")
|
|
|
224
218
|
workspaces_app = typer.Typer(
|
|
225
219
|
help="""Manage cloud GPU workspaces for remote development.
|
|
226
220
|
|
|
227
|
-
Workspaces are on-demand cloud GPU environments. Requires authentication (wafer
|
|
221
|
+
Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
|
|
228
222
|
|
|
229
223
|
Available GPUs:
|
|
230
224
|
MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
|
|
@@ -237,21 +231,21 @@ Commands:
|
|
|
237
231
|
wafer workspaces sync dev ./project # Sync files
|
|
238
232
|
wafer workspaces delete dev # Clean up"""
|
|
239
233
|
)
|
|
240
|
-
app.add_typer(workspaces_app, name="workspaces"
|
|
234
|
+
app.add_typer(workspaces_app, name="workspaces")
|
|
241
235
|
|
|
242
|
-
# SSH Key management (BYOK - Bring Your Own Key)
|
|
236
|
+
# SSH Key management (BYOK - Bring Your Own Key)
|
|
243
237
|
ssh_keys_app = typer.Typer(
|
|
244
238
|
help="""Manage SSH public keys for workspace access.
|
|
245
239
|
|
|
246
240
|
Register your SSH public keys here. These keys are installed in all workspaces
|
|
247
241
|
you provision, enabling SSH access from any machine with your private key.
|
|
248
242
|
|
|
249
|
-
wafer
|
|
250
|
-
wafer
|
|
251
|
-
wafer
|
|
252
|
-
wafer
|
|
243
|
+
wafer ssh-keys list # List registered keys
|
|
244
|
+
wafer ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
|
|
245
|
+
wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
|
|
246
|
+
wafer ssh-keys remove <key-id> # Remove a key"""
|
|
253
247
|
)
|
|
254
|
-
|
|
248
|
+
app.add_typer(ssh_keys_app, name="ssh-keys")
|
|
255
249
|
|
|
256
250
|
# Target operations (exec/ssh/sync on configured targets)
|
|
257
251
|
targets_ops_app = typer.Typer(
|
|
@@ -267,48 +261,22 @@ Useful for exploratory work, debugging, or custom scripts.
|
|
|
267
261
|
Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
|
|
268
262
|
Configure targets with: wafer config targets init ..."""
|
|
269
263
|
)
|
|
270
|
-
app.add_typer(targets_ops_app, name="targets"
|
|
264
|
+
app.add_typer(targets_ops_app, name="targets")
|
|
271
265
|
|
|
272
|
-
#
|
|
273
|
-
from wafer.specs_cli import specs_app
|
|
274
|
-
|
|
275
|
-
app.add_typer(specs_app, name="specs", rich_help_panel="Configuration")
|
|
276
|
-
|
|
277
|
-
# Live resource management (new: API-backed commands on `wafer targets`)
|
|
278
|
-
# These become: wafer targets list, wafer targets terminate, etc.
|
|
279
|
-
from wafer.targets_cli import (
|
|
280
|
-
targets_list as _targets_list_cmd,
|
|
281
|
-
)
|
|
282
|
-
from wafer.targets_cli import (
|
|
283
|
-
targets_pools as _targets_pools_cmd,
|
|
284
|
-
)
|
|
285
|
-
from wafer.targets_cli import (
|
|
286
|
-
targets_probe as _targets_probe_cmd,
|
|
287
|
-
)
|
|
288
|
-
from wafer.targets_cli import (
|
|
289
|
-
targets_provision as _targets_provision_cmd,
|
|
290
|
-
)
|
|
291
|
-
from wafer.targets_cli import (
|
|
292
|
-
targets_reconcile as _targets_reconcile_cmd,
|
|
293
|
-
)
|
|
294
|
-
from wafer.targets_cli import (
|
|
295
|
-
targets_terminate as _targets_terminate_cmd,
|
|
296
|
-
)
|
|
297
|
-
|
|
298
|
-
# Billing management - nested under config
|
|
266
|
+
# Billing management
|
|
299
267
|
billing_app = typer.Typer(help="Manage billing, credits, and subscription")
|
|
300
|
-
|
|
268
|
+
app.add_typer(billing_app, name="billing")
|
|
301
269
|
|
|
302
270
|
# Corpus management
|
|
303
271
|
corpus_app = typer.Typer(help="Download and manage GPU documentation")
|
|
304
|
-
app.add_typer(corpus_app, name="corpus"
|
|
272
|
+
app.add_typer(corpus_app, name="corpus")
|
|
305
273
|
|
|
306
274
|
# Evaluate (supports multiple kernel formats)
|
|
307
275
|
evaluate_app = typer.Typer(
|
|
308
276
|
help="Test kernel correctness and performance",
|
|
309
277
|
invoke_without_command=True,
|
|
310
278
|
)
|
|
311
|
-
app.add_typer(evaluate_app, name="evaluate"
|
|
279
|
+
app.add_typer(evaluate_app, name="evaluate")
|
|
312
280
|
|
|
313
281
|
# Nested subcommand for kernelbench format
|
|
314
282
|
kernelbench_app = typer.Typer(
|
|
@@ -324,11 +292,6 @@ gpumode_app = typer.Typer(
|
|
|
324
292
|
)
|
|
325
293
|
evaluate_app.add_typer(gpumode_app, name="gpumode")
|
|
326
294
|
|
|
327
|
-
# Baseline discovery (what kernel does PyTorch use?)
|
|
328
|
-
from wafer.baseline import baseline_app
|
|
329
|
-
|
|
330
|
-
app.add_typer(baseline_app, name="baseline", rich_help_panel="Kernel Development")
|
|
331
|
-
|
|
332
295
|
# =============================================================================
|
|
333
296
|
# Dev commands (internal, used by web app proxy)
|
|
334
297
|
# =============================================================================
|
|
@@ -342,7 +305,7 @@ app.add_typer(dev_app, name="dev")
|
|
|
342
305
|
# =============================================================================
|
|
343
306
|
|
|
344
307
|
nvidia_app = typer.Typer(help="NVIDIA GPU profiling and analysis tools")
|
|
345
|
-
app.add_typer(nvidia_app, name="nvidia"
|
|
308
|
+
app.add_typer(nvidia_app, name="nvidia")
|
|
346
309
|
|
|
347
310
|
# NCU analysis - under nvidia
|
|
348
311
|
ncu_app = typer.Typer(help="Nsight Compute profile analysis")
|
|
@@ -365,7 +328,7 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
|
|
|
365
328
|
# =============================================================================
|
|
366
329
|
|
|
367
330
|
amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
|
|
368
|
-
app.add_typer(amd_app, name="amd"
|
|
331
|
+
app.add_typer(amd_app, name="amd")
|
|
369
332
|
|
|
370
333
|
# Unified ISA Analyzer - supports both .co files and Triton artifacts
|
|
371
334
|
isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
|
|
@@ -376,14 +339,14 @@ amd_app.add_typer(isa_app, name="isa")
|
|
|
376
339
|
# =============================================================================
|
|
377
340
|
|
|
378
341
|
compare_app = typer.Typer(help="Compare GPU traces across platforms (AMD vs NVIDIA)")
|
|
379
|
-
app.add_typer(compare_app, name="compare"
|
|
342
|
+
app.add_typer(compare_app, name="compare")
|
|
380
343
|
|
|
381
344
|
# =============================================================================
|
|
382
345
|
# Roofline analysis (wafer roofline)
|
|
383
346
|
# =============================================================================
|
|
384
347
|
|
|
385
348
|
|
|
386
|
-
@app.command("roofline"
|
|
349
|
+
@app.command("roofline")
|
|
387
350
|
def roofline_cmd(
|
|
388
351
|
gpu: str | None = typer.Option(
|
|
389
352
|
None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
|
|
@@ -474,7 +437,7 @@ def roofline_cmd(
|
|
|
474
437
|
# =============================================================================
|
|
475
438
|
|
|
476
439
|
skill_app = typer.Typer(help="Manage AI coding assistant skills (Claude Code, Codex)")
|
|
477
|
-
app.add_typer(skill_app, name="skill"
|
|
440
|
+
app.add_typer(skill_app, name="skill")
|
|
478
441
|
|
|
479
442
|
|
|
480
443
|
@skill_app.command("install")
|
|
@@ -638,19 +601,14 @@ def skill_status() -> None:
|
|
|
638
601
|
|
|
639
602
|
|
|
640
603
|
# =============================================================================
|
|
641
|
-
#
|
|
604
|
+
# Provider auth management (wafer auth ...)
|
|
642
605
|
# =============================================================================
|
|
643
606
|
|
|
644
|
-
|
|
645
|
-
app.add_typer(
|
|
646
|
-
|
|
647
|
-
providers_app = typer.Typer(
|
|
648
|
-
help="Manage API keys for cloud GPU providers (RunPod, DigitalOcean, etc.)"
|
|
649
|
-
)
|
|
650
|
-
auth_app.add_typer(providers_app, name="providers")
|
|
607
|
+
provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
|
|
608
|
+
app.add_typer(provider_auth_app, name="auth")
|
|
651
609
|
|
|
652
610
|
|
|
653
|
-
@
|
|
611
|
+
@provider_auth_app.command("login")
|
|
654
612
|
def provider_auth_login(
|
|
655
613
|
provider: str = typer.Argument(
|
|
656
614
|
...,
|
|
@@ -669,10 +627,10 @@ def provider_auth_login(
|
|
|
669
627
|
(e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
|
|
670
628
|
|
|
671
629
|
Examples:
|
|
672
|
-
wafer auth
|
|
673
|
-
wafer auth
|
|
674
|
-
wafer auth
|
|
675
|
-
echo $API_KEY | wafer auth
|
|
630
|
+
wafer auth login anthropic --api-key sk-ant-xxx
|
|
631
|
+
wafer auth login runpod --api-key rp_xxx
|
|
632
|
+
wafer auth login openai --api-key sk-xxx
|
|
633
|
+
echo $API_KEY | wafer auth login anthropic
|
|
676
634
|
"""
|
|
677
635
|
import sys
|
|
678
636
|
|
|
@@ -702,7 +660,7 @@ def provider_auth_login(
|
|
|
702
660
|
typer.echo("Stored in: ~/.wafer/auth.json")
|
|
703
661
|
|
|
704
662
|
|
|
705
|
-
@
|
|
663
|
+
@provider_auth_app.command("logout")
|
|
706
664
|
def provider_auth_logout(
|
|
707
665
|
provider: str = typer.Argument(
|
|
708
666
|
...,
|
|
@@ -712,8 +670,8 @@ def provider_auth_logout(
|
|
|
712
670
|
"""Remove stored API key for a cloud GPU provider.
|
|
713
671
|
|
|
714
672
|
Examples:
|
|
715
|
-
wafer auth
|
|
716
|
-
wafer auth
|
|
673
|
+
wafer auth logout runpod
|
|
674
|
+
wafer auth logout digitalocean
|
|
717
675
|
"""
|
|
718
676
|
from wafer_core.auth import PROVIDERS, remove_api_key
|
|
719
677
|
|
|
@@ -729,7 +687,7 @@ def provider_auth_logout(
|
|
|
729
687
|
typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
|
|
730
688
|
|
|
731
689
|
|
|
732
|
-
@
|
|
690
|
+
@provider_auth_app.command("status")
|
|
733
691
|
def provider_auth_status() -> None:
|
|
734
692
|
"""Show authentication status for all cloud GPU providers.
|
|
735
693
|
|
|
@@ -737,7 +695,7 @@ def provider_auth_status() -> None:
|
|
|
737
695
|
the keys are coming from (environment variable or auth.json).
|
|
738
696
|
|
|
739
697
|
Example:
|
|
740
|
-
wafer auth
|
|
698
|
+
wafer auth status
|
|
741
699
|
"""
|
|
742
700
|
from wafer_core.auth import get_all_auth_status
|
|
743
701
|
|
|
@@ -752,7 +710,7 @@ def provider_auth_status() -> None:
|
|
|
752
710
|
typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
|
|
753
711
|
else:
|
|
754
712
|
typer.echo(f" {status.display_name}: ✗ Not configured")
|
|
755
|
-
typer.echo(f" Run: wafer auth
|
|
713
|
+
typer.echo(f" Run: wafer auth login {status.provider}")
|
|
756
714
|
typer.echo(f" Or set: {status.key_url}")
|
|
757
715
|
|
|
758
716
|
typer.echo("")
|
|
@@ -1297,7 +1255,7 @@ def config_show_legacy() -> None:
|
|
|
1297
1255
|
config_show_new()
|
|
1298
1256
|
|
|
1299
1257
|
|
|
1300
|
-
@app.command(
|
|
1258
|
+
@app.command()
|
|
1301
1259
|
def agent( # noqa: PLR0913
|
|
1302
1260
|
prompt: str | None = typer.Argument(
|
|
1303
1261
|
None,
|
|
@@ -1539,7 +1497,6 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1539
1497
|
template_args: list[str] | None = typer.Option(None, "--args"),
|
|
1540
1498
|
corpus: str | None = typer.Option(None, "--corpus"),
|
|
1541
1499
|
no_sandbox: bool = typer.Option(False, "--no-sandbox"),
|
|
1542
|
-
no_proxy: bool = typer.Option(False, "--no-proxy", help="Skip wafer proxy, use ANTHROPIC_API_KEY directly"),
|
|
1543
1500
|
) -> None:
|
|
1544
1501
|
agent(
|
|
1545
1502
|
prompt=prompt,
|
|
@@ -1560,7 +1517,6 @@ def _make_agent_alias(name: str, doc: str) -> None:
|
|
|
1560
1517
|
template_args=template_args,
|
|
1561
1518
|
corpus=corpus,
|
|
1562
1519
|
no_sandbox=no_sandbox,
|
|
1563
|
-
no_proxy=no_proxy, # Must explicitly pass to avoid Typer default object being truthy
|
|
1564
1520
|
)
|
|
1565
1521
|
|
|
1566
1522
|
alias_cmd.__doc__ = doc
|
|
@@ -1584,11 +1540,7 @@ def evaluate( # noqa: PLR0913
|
|
|
1584
1540
|
None, "--reference", help="Path to reference kernel file"
|
|
1585
1541
|
),
|
|
1586
1542
|
test_cases: Path | None = typer.Option(
|
|
1587
|
-
None,
|
|
1588
|
-
"--test-cases",
|
|
1589
|
-
help="Path to test cases JSON file. "
|
|
1590
|
-
'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
|
|
1591
|
-
"Run 'wafer evaluate make-template' to generate an example.",
|
|
1543
|
+
None, "--test-cases", help="Path to test cases JSON file"
|
|
1592
1544
|
),
|
|
1593
1545
|
target: str | None = typer.Option(
|
|
1594
1546
|
None,
|
|
@@ -1600,9 +1552,7 @@ def evaluate( # noqa: PLR0913
|
|
|
1600
1552
|
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
1601
1553
|
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
1602
1554
|
defensive: bool = typer.Option(
|
|
1603
|
-
|
|
1604
|
-
"--defense/--no-defense",
|
|
1605
|
-
help="Run reward hack defense checks after benchmarking. Enabled by default.",
|
|
1555
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
1606
1556
|
),
|
|
1607
1557
|
sync_artifacts: bool = typer.Option(
|
|
1608
1558
|
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
@@ -1616,24 +1566,24 @@ def evaluate( # noqa: PLR0913
|
|
|
1616
1566
|
The evaluation checks:
|
|
1617
1567
|
1. Correctness: Does the kernel produce the same output as the reference?
|
|
1618
1568
|
2. Performance (--benchmark): How fast is it compared to the reference?
|
|
1619
|
-
3. Defense: Detects
|
|
1569
|
+
3. Defense (--defensive): Detects evaluation hacking (stream injection, etc.)
|
|
1620
1570
|
|
|
1621
1571
|
Examples:
|
|
1622
1572
|
# Basic correctness check
|
|
1623
|
-
wafer evaluate
|
|
1573
|
+
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json
|
|
1624
1574
|
|
|
1625
|
-
# With benchmarking
|
|
1626
|
-
wafer evaluate
|
|
1575
|
+
# With benchmarking on a specific target
|
|
1576
|
+
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1627
1577
|
--target vultr-b200 --benchmark
|
|
1628
1578
|
|
|
1629
|
-
#
|
|
1630
|
-
wafer evaluate
|
|
1631
|
-
--benchmark --
|
|
1579
|
+
# Full evaluation with defensive timing (detects cheating)
|
|
1580
|
+
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1581
|
+
--benchmark --defensive
|
|
1632
1582
|
|
|
1633
1583
|
Subcommands:
|
|
1634
1584
|
gpumode Use GPUMode format (functional) - RECOMMENDED
|
|
1635
1585
|
kernelbench Use KernelBench format (ModelNew class)
|
|
1636
|
-
make-template Generate template files for this format
|
|
1586
|
+
make-template Generate template files for this format (deprecated)
|
|
1637
1587
|
"""
|
|
1638
1588
|
# If a subcommand is being invoked, skip the main evaluation logic
|
|
1639
1589
|
if ctx.invoked_subcommand is not None:
|
|
@@ -1787,7 +1737,7 @@ def evaluate_make_template(
|
|
|
1787
1737
|
typer.echo(f" 2. Edit {output_dir / 'reference.py'} with the ground truth + input generator")
|
|
1788
1738
|
typer.echo(f" 3. Edit {output_dir / 'test_cases.json'} with your test parameters")
|
|
1789
1739
|
typer.echo(" 4. Run:")
|
|
1790
|
-
typer.echo(f" wafer evaluate
|
|
1740
|
+
typer.echo(f" wafer evaluate --impl {output_dir / 'kernel.py'} \\")
|
|
1791
1741
|
typer.echo(f" --reference {output_dir / 'reference.py'} \\")
|
|
1792
1742
|
typer.echo(f" --test-cases {output_dir / 'test_cases.json'} --benchmark")
|
|
1793
1743
|
|
|
@@ -1851,95 +1801,6 @@ def kernelbench_list_problems() -> None:
|
|
|
1851
1801
|
raise typer.Exit(1) from None
|
|
1852
1802
|
|
|
1853
1803
|
|
|
1854
|
-
def _resolve_pool_query(pool: str, collector) -> tuple[str, object]:
|
|
1855
|
-
"""Resolve a PoolQuery pool to a target spec name + lock context.
|
|
1856
|
-
|
|
1857
|
-
Queries live providers, matches by pool query, locks one target,
|
|
1858
|
-
returns (spec_name, lock_context) for the evaluator.
|
|
1859
|
-
"""
|
|
1860
|
-
import trio
|
|
1861
|
-
from wafer_core.targets.pool import resolve_pool
|
|
1862
|
-
|
|
1863
|
-
from .target_lock import acquire_from_pool
|
|
1864
|
-
|
|
1865
|
-
matched_targets = trio.run(resolve_pool, pool)
|
|
1866
|
-
|
|
1867
|
-
if not matched_targets:
|
|
1868
|
-
collector.set_error("pool", "NoMatchingTargets", pool=pool)
|
|
1869
|
-
collector.finalize()
|
|
1870
|
-
raise typer.Exit(1)
|
|
1871
|
-
|
|
1872
|
-
# Filter to targets with a spec (evaluator needs spec fields)
|
|
1873
|
-
spec_targets = [t for t in matched_targets if t.spec_name]
|
|
1874
|
-
if not spec_targets:
|
|
1875
|
-
collector.set_error(
|
|
1876
|
-
"pool",
|
|
1877
|
-
"NoSpecTargets",
|
|
1878
|
-
pool=pool,
|
|
1879
|
-
message="Matched targets have no spec binding — evaluator needs spec fields",
|
|
1880
|
-
)
|
|
1881
|
-
collector.finalize()
|
|
1882
|
-
raise typer.Exit(1)
|
|
1883
|
-
|
|
1884
|
-
# Lock one by resource_id
|
|
1885
|
-
resource_ids = [t.resource_id for t in spec_targets]
|
|
1886
|
-
collector.emit("pool_acquire", pool=pool, count=len(resource_ids))
|
|
1887
|
-
|
|
1888
|
-
lock_ctx = acquire_from_pool(resource_ids)
|
|
1889
|
-
acquired_id = lock_ctx.__enter__()
|
|
1890
|
-
|
|
1891
|
-
if acquired_id is None:
|
|
1892
|
-
lock_ctx.__exit__(None, None, None)
|
|
1893
|
-
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=resource_ids)
|
|
1894
|
-
collector.finalize()
|
|
1895
|
-
raise typer.Exit(1)
|
|
1896
|
-
|
|
1897
|
-
# Map resource_id back to spec_name
|
|
1898
|
-
acquired_target = next(t for t in spec_targets if t.resource_id == acquired_id)
|
|
1899
|
-
spec_name = acquired_target.spec_name
|
|
1900
|
-
|
|
1901
|
-
collector.emit("pool_acquired", target=spec_name, resource_id=acquired_id)
|
|
1902
|
-
return spec_name, lock_ctx
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
def _resolve_pool_legacy(pool: str, collector) -> tuple[str, object]:
|
|
1906
|
-
"""Resolve an old-style pool (static target name list) to a target name + lock context.
|
|
1907
|
-
|
|
1908
|
-
Old format: [pools.name] targets = ["t1", "t2"]
|
|
1909
|
-
"""
|
|
1910
|
-
from .target_lock import acquire_from_pool
|
|
1911
|
-
from .targets import filter_pool_by_auth, get_pool
|
|
1912
|
-
|
|
1913
|
-
try:
|
|
1914
|
-
pool_targets = get_pool(pool)
|
|
1915
|
-
except FileNotFoundError as e:
|
|
1916
|
-
collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
|
|
1917
|
-
collector.finalize()
|
|
1918
|
-
raise typer.Exit(1) from None
|
|
1919
|
-
|
|
1920
|
-
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1921
|
-
if skipped:
|
|
1922
|
-
collector.emit("pool_auth_skip", targets=skipped)
|
|
1923
|
-
|
|
1924
|
-
if not usable_targets:
|
|
1925
|
-
collector.set_error("pool", "NoUsableTargets", pool=pool)
|
|
1926
|
-
collector.finalize()
|
|
1927
|
-
raise typer.Exit(1) from None
|
|
1928
|
-
|
|
1929
|
-
collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
|
|
1930
|
-
lock_ctx = acquire_from_pool(usable_targets)
|
|
1931
|
-
acquired_target = lock_ctx.__enter__()
|
|
1932
|
-
|
|
1933
|
-
if acquired_target is None:
|
|
1934
|
-
lock_ctx.__exit__(None, None, None)
|
|
1935
|
-
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
|
|
1936
|
-
collector.finalize()
|
|
1937
|
-
raise typer.Exit(1)
|
|
1938
|
-
|
|
1939
|
-
collector.emit("pool_acquired", target=acquired_target)
|
|
1940
|
-
return acquired_target, lock_ctx
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
1804
|
@kernelbench_app.callback(invoke_without_command=True)
|
|
1944
1805
|
def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
1945
1806
|
ctx: typer.Context,
|
|
@@ -1975,9 +1836,7 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
1975
1836
|
),
|
|
1976
1837
|
seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
|
|
1977
1838
|
defensive: bool = typer.Option(
|
|
1978
|
-
|
|
1979
|
-
"--defense/--no-defense",
|
|
1980
|
-
help="Run reward hack defense checks after benchmarking. Enabled by default.",
|
|
1839
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
1981
1840
|
),
|
|
1982
1841
|
backend: str | None = typer.Option(
|
|
1983
1842
|
None,
|
|
@@ -2017,20 +1876,16 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2017
1876
|
The evaluation checks:
|
|
2018
1877
|
1. Correctness: Does ModelNew.forward() produce same output as Model.forward()?
|
|
2019
1878
|
2. Performance (--benchmark): How fast is it compared to the reference?
|
|
2020
|
-
3. Defense: Detects
|
|
1879
|
+
3. Defense (--defensive): Detects evaluation hacking
|
|
2021
1880
|
|
|
2022
1881
|
Examples:
|
|
2023
1882
|
# Basic correctness check
|
|
2024
1883
|
wafer evaluate kernelbench --impl my_kernel.py --reference problem.py
|
|
2025
1884
|
|
|
2026
|
-
# With benchmarking
|
|
1885
|
+
# With benchmarking
|
|
2027
1886
|
wafer evaluate kernelbench --impl my_kernel.py --reference problem.py \\
|
|
2028
1887
|
--target vultr-b200 --benchmark
|
|
2029
1888
|
|
|
2030
|
-
# Benchmarking without defense checks
|
|
2031
|
-
wafer evaluate kernelbench --impl my_kernel.py --reference problem.py \\
|
|
2032
|
-
--target vultr-b200 --benchmark --no-defense
|
|
2033
|
-
|
|
2034
1889
|
Subcommands:
|
|
2035
1890
|
make-template Extract a KernelBench problem as template
|
|
2036
1891
|
"""
|
|
@@ -2076,12 +1931,39 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2076
1931
|
pool_lock_context = None
|
|
2077
1932
|
|
|
2078
1933
|
if pool:
|
|
2079
|
-
from
|
|
1934
|
+
from .target_lock import acquire_from_pool
|
|
1935
|
+
from .targets import filter_pool_by_auth, get_pool
|
|
2080
1936
|
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
1937
|
+
try:
|
|
1938
|
+
pool_targets = get_pool(pool)
|
|
1939
|
+
except FileNotFoundError as e:
|
|
1940
|
+
collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
|
|
1941
|
+
collector.finalize()
|
|
1942
|
+
raise typer.Exit(1) from None
|
|
1943
|
+
|
|
1944
|
+
# Filter to only targets with valid auth
|
|
1945
|
+
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1946
|
+
if skipped:
|
|
1947
|
+
collector.emit("pool_auth_skip", targets=skipped)
|
|
1948
|
+
|
|
1949
|
+
if not usable_targets:
|
|
1950
|
+
collector.set_error("pool", "NoUsableTargets", pool=pool)
|
|
1951
|
+
collector.finalize()
|
|
1952
|
+
raise typer.Exit(1) from None
|
|
1953
|
+
|
|
1954
|
+
collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
|
|
1955
|
+
pool_lock_context = acquire_from_pool(usable_targets)
|
|
1956
|
+
acquired_target = pool_lock_context.__enter__()
|
|
1957
|
+
|
|
1958
|
+
if acquired_target is None:
|
|
1959
|
+
# Exit context manager before raising to avoid resource leak
|
|
1960
|
+
pool_lock_context.__exit__(None, None, None)
|
|
1961
|
+
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
|
|
1962
|
+
collector.finalize()
|
|
1963
|
+
raise typer.Exit(1)
|
|
1964
|
+
|
|
1965
|
+
collector.emit("pool_acquired", target=acquired_target)
|
|
1966
|
+
resolved_target = acquired_target
|
|
2085
1967
|
|
|
2086
1968
|
collector.target = resolved_target
|
|
2087
1969
|
|
|
@@ -2090,15 +1972,12 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2090
1972
|
if stages == "all":
|
|
2091
1973
|
resolved_stages = "compile,correctness,benchmark,defense"
|
|
2092
1974
|
|
|
2093
|
-
# Handle --benchmark and --
|
|
1975
|
+
# Handle backward compat: --benchmark and --defensive flags add to stages
|
|
2094
1976
|
stage_set = set(resolved_stages.split(","))
|
|
2095
1977
|
if benchmark and "benchmark" not in stage_set:
|
|
2096
1978
|
stage_set.add("benchmark")
|
|
2097
|
-
|
|
2098
|
-
if defensive and "benchmark" in stage_set and "defense" not in stage_set:
|
|
1979
|
+
if defensive and "defense" not in stage_set:
|
|
2099
1980
|
stage_set.add("defense")
|
|
2100
|
-
if not defensive:
|
|
2101
|
-
stage_set.discard("defense")
|
|
2102
1981
|
resolved_stages = ",".join(
|
|
2103
1982
|
sorted(
|
|
2104
1983
|
stage_set,
|
|
@@ -2409,11 +2288,7 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2409
2288
|
None, "--reference", help="Path to reference kernel file"
|
|
2410
2289
|
),
|
|
2411
2290
|
test_cases: Path | None = typer.Option(
|
|
2412
|
-
None,
|
|
2413
|
-
"--test-cases",
|
|
2414
|
-
help="Path to test cases JSON file. "
|
|
2415
|
-
'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
|
|
2416
|
-
"Run 'wafer evaluate make-template' to generate an example.",
|
|
2291
|
+
None, "--test-cases", help="Path to test cases JSON file"
|
|
2417
2292
|
),
|
|
2418
2293
|
target: str | None = typer.Option(
|
|
2419
2294
|
None,
|
|
@@ -2432,9 +2307,7 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2432
2307
|
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
2433
2308
|
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
2434
2309
|
defensive: bool = typer.Option(
|
|
2435
|
-
|
|
2436
|
-
"--defense/--no-defense",
|
|
2437
|
-
help="Run reward hack defense checks after benchmarking. Enabled by default.",
|
|
2310
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
2438
2311
|
),
|
|
2439
2312
|
sync_artifacts: bool = typer.Option(
|
|
2440
2313
|
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
@@ -2483,13 +2356,6 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2483
2356
|
err=True,
|
|
2484
2357
|
)
|
|
2485
2358
|
typer.echo("", err=True)
|
|
2486
|
-
if "--test-cases" in missing_args:
|
|
2487
|
-
typer.echo(
|
|
2488
|
-
"Tip: Run 'wafer evaluate make-template' to generate template files "
|
|
2489
|
-
"including test_cases.json.",
|
|
2490
|
-
err=True,
|
|
2491
|
-
)
|
|
2492
|
-
typer.echo("", err=True)
|
|
2493
2359
|
typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
|
|
2494
2360
|
typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
|
|
2495
2361
|
raise typer.Exit(1)
|
|
@@ -2590,12 +2456,313 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2590
2456
|
else:
|
|
2591
2457
|
typer.echo(f"Error: {result.error_message}", err=True)
|
|
2592
2458
|
raise typer.Exit(1)
|
|
2459
|
+
|
|
2460
|
+
|
|
2461
|
+
# =============================================================================
|
|
2462
|
+
# Push and Remote-Run commands
|
|
2463
|
+
# =============================================================================
|
|
2464
|
+
|
|
2465
|
+
|
|
2466
|
+
@app.command("push", hidden=True)
|
|
2467
|
+
def push(
|
|
2468
|
+
local_path: Path = typer.Argument(..., help="Local directory to upload"),
|
|
2469
|
+
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace name override"),
|
|
2470
|
+
direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
|
|
2471
|
+
target_name: str | None = typer.Option(
|
|
2472
|
+
None,
|
|
2473
|
+
"--target",
|
|
2474
|
+
"-t",
|
|
2475
|
+
help="Target for --direct mode. See 'wafer config targets list'.",
|
|
2476
|
+
autocompletion=complete_target_name,
|
|
2477
|
+
),
|
|
2478
|
+
) -> None:
|
|
2479
|
+
"""Push directory to remote GPU.
|
|
2480
|
+
|
|
2481
|
+
By default, uses wafer-api. Use --direct for direct SSH mode.
|
|
2482
|
+
|
|
2483
|
+
Examples:
|
|
2484
|
+
wafer push ./my_project
|
|
2485
|
+
wafer push . --workspace my-kernel
|
|
2486
|
+
wafer push ./my_project --direct --target vultr-b200
|
|
2487
|
+
"""
|
|
2488
|
+
# Validate path
|
|
2489
|
+
if not local_path.exists():
|
|
2490
|
+
typer.echo(f"Error: Path not found: {local_path}", err=True)
|
|
2491
|
+
raise typer.Exit(1)
|
|
2492
|
+
|
|
2493
|
+
if not local_path.is_dir():
|
|
2494
|
+
typer.echo(f"Error: Not a directory: {local_path}", err=True)
|
|
2495
|
+
raise typer.Exit(1)
|
|
2496
|
+
|
|
2497
|
+
# Resolve to absolute path
|
|
2498
|
+
local_path = local_path.resolve()
|
|
2499
|
+
|
|
2500
|
+
if direct:
|
|
2501
|
+
# Direct SSH mode (requires target)
|
|
2502
|
+
if not target_name:
|
|
2503
|
+
typer.echo("Error: --target required for --direct mode", err=True)
|
|
2504
|
+
raise typer.Exit(1)
|
|
2505
|
+
|
|
2506
|
+
from wafer_core.utils.kernel_utils.targets.config import ModalTarget
|
|
2507
|
+
|
|
2508
|
+
from .gpu_run import push_directory as push_direct
|
|
2509
|
+
from .targets import load_target
|
|
2510
|
+
|
|
2511
|
+
try:
|
|
2512
|
+
target = load_target(target_name)
|
|
2513
|
+
except FileNotFoundError:
|
|
2514
|
+
typer.echo(f"Error: Target not found: {target_name}", err=True)
|
|
2515
|
+
typer.echo("List targets with: wafer config targets list", err=True)
|
|
2516
|
+
raise typer.Exit(1) from None
|
|
2517
|
+
|
|
2518
|
+
if isinstance(target, ModalTarget):
|
|
2519
|
+
typer.echo(
|
|
2520
|
+
f"Error: Target '{target_name}' is a Modal target. Direct push requires SSH.",
|
|
2521
|
+
err=True,
|
|
2522
|
+
)
|
|
2523
|
+
raise typer.Exit(1) from None
|
|
2524
|
+
|
|
2525
|
+
typer.echo(f"Connecting to {target.ssh_target}...")
|
|
2526
|
+
try:
|
|
2527
|
+
result = push_direct(local_path, target)
|
|
2528
|
+
except Exception as e:
|
|
2529
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2530
|
+
raise typer.Exit(1) from None
|
|
2531
|
+
|
|
2532
|
+
typer.echo(f"Uploading {len(result.files_uploaded)} files to {result.workspace_path}")
|
|
2533
|
+
for f in result.files_uploaded:
|
|
2534
|
+
typer.echo(f" ✓ {f}")
|
|
2535
|
+
typer.echo(f"Pushed to: {result.workspace_path}")
|
|
2536
|
+
else:
|
|
2537
|
+
# API mode (default)
|
|
2538
|
+
from .api_client import push_directory as push_api
|
|
2539
|
+
|
|
2540
|
+
workspace_name = workspace or local_path.name
|
|
2541
|
+
typer.echo(f"Pushing {local_path.name} to wafer-api...")
|
|
2542
|
+
|
|
2543
|
+
try:
|
|
2544
|
+
result = push_api(local_path, workspace_name)
|
|
2545
|
+
except Exception as e:
|
|
2546
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2547
|
+
raise typer.Exit(1) from None
|
|
2548
|
+
|
|
2549
|
+
typer.echo(f"Uploaded {len(result.files_uploaded)} files")
|
|
2550
|
+
for f in result.files_uploaded:
|
|
2551
|
+
typer.echo(f" ✓ {f}")
|
|
2552
|
+
typer.echo(f"Workspace ID: {result.workspace_id}")
|
|
2553
|
+
|
|
2554
|
+
|
|
2555
|
+
def _run_direct_mode(
    cmd_str: str,
    target_name: str,
    upload_dir: Path | None,
    workspace_id: str | None,
    gpu_id: int | None,
) -> int:
    """Run a command on an SSH target inside the target's Docker image.

    Args:
        cmd_str: Command line to execute remotely.
        target_name: Name of a configured target (see 'wafer config targets list').
            Modal targets are rejected because this mode needs SSH.
        upload_dir: Optional local directory to push first. When given, the
            command runs in the workspace created by that push.
        workspace_id: Existing workspace name to run in when no upload_dir is
            given. Falls back to "tmp" when neither is provided.
        gpu_id: GPU to run on. Defaults to the first entry of the target's
            ``gpu_ids``.

    Returns:
        The exit code returned by the remote run.

    Raises:
        typer.Exit: Code 1 on configuration, upload or run errors; code 130
            when interrupted with Ctrl-C.
    """
    from wafer_core.utils.kernel_utils.targets.config import ModalTarget

    from .gpu_run import push_directory as push_direct
    from .gpu_run import run_command as run_direct
    from .targets import load_target

    try:
        target = load_target(target_name)
    except FileNotFoundError:
        typer.echo(f"Error: Target not found: {target_name}", err=True)
        typer.echo("List targets with: wafer config targets list", err=True)
        raise typer.Exit(1) from None

    if isinstance(target, ModalTarget):
        typer.echo(
            f"Error: Target '{target_name}' is a Modal target. Direct mode requires SSH.", err=True
        )
        raise typer.Exit(1) from None

    if not target.docker_image:
        typer.echo(f"Error: Target '{target_name}' has no docker_image configured", err=True)
        raise typer.Exit(1)

    # Resolve the GPU before doing any (possibly slow) upload. Without an explicit
    # --gpu we need at least one configured GPU id; indexing an empty list would
    # otherwise crash with a raw IndexError after the upload already happened.
    if gpu_id is None and not target.gpu_ids:
        typer.echo(
            f"Error: Target '{target_name}' has no gpu_ids configured; pass --gpu explicitly",
            err=True,
        )
        raise typer.Exit(1)
    effective_gpu = gpu_id if gpu_id is not None else target.gpu_ids[0]

    # If upload_dir provided, push first and run inside the workspace it created.
    workspace_name = workspace_id
    if upload_dir:
        typer.echo(f"Uploading {upload_dir.name}...")
        try:
            push_result = push_direct(upload_dir, target)
            workspace_name = push_result.workspace_name
            typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
        except Exception as e:
            typer.echo(f"Error uploading: {e}", err=True)
            raise typer.Exit(1) from None
    elif not workspace_name:
        workspace_name = "tmp"

    typer.echo(f"Target: {target_name} (docker: {target.docker_image})")
    typer.echo(f"Workspace: {workspace_name}")
    typer.echo(f"GPU: {effective_gpu}")
    typer.echo(f"Command: {cmd_str}")
    typer.echo("-" * 60)

    try:
        # Pass the resolved GPU (not the raw, possibly-None gpu_id) so the GPU
        # announced above is the GPU the command actually runs on.
        return run_direct(cmd_str, workspace_name, target, effective_gpu)
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
|
|
2615
|
+
|
|
2616
|
+
|
|
2617
|
+
def _run_api_mode(  # noqa: PLR0913
    cmd_str: str,
    upload_dir: Path | None,
    workspace_id: str | None,
    gpu_id: int | None,
    gpu_count: int,
    docker_image: str | None,
    docker_entrypoint: str | None,
    pull_image: bool,
    require_hwc: bool,
) -> int:
    """Stream a command through wafer-api and return its exit code.

    Before running, prints a short request summary. The summary includes the
    upload directory or workspace, plus only those options that differ from
    their defaults, then the command and a separator line.

    Raises:
        typer.Exit: Code 130 when interrupted with Ctrl-C, code 1 on any other
            error raised while streaming.
    """
    from .api_client import run_command_stream

    # Where the files come from: a fresh upload wins over an existing workspace.
    if upload_dir:
        origin = f"Uploading: {upload_dir}"
    elif workspace_id:
        origin = f"Workspace: {workspace_id}"
    else:
        origin = None

    # Each entry is either a line to print or None (skipped). Order matters:
    # it mirrors the order the summary has always been printed in.
    banner = (
        origin,
        f"GPU: {gpu_id}" if gpu_id is not None else None,
        f"GPU count: {gpu_count}" if gpu_count > 1 else None,
        f"Image: {docker_image}" if docker_image else None,
        f"Entrypoint: {docker_entrypoint}" if docker_entrypoint else None,
        "Pull image: yes" if pull_image else None,
        f"Command: {cmd_str}",
        "Hardware counters: required (baremetal)" if require_hwc else None,
        "-" * 60,
    )
    for text in banner:
        if text is not None:
            typer.echo(text)

    stream_kwargs = {
        "command": cmd_str,
        "upload_dir": upload_dir,
        "workspace_id": workspace_id,
        "gpu_id": gpu_id,
        "gpu_count": gpu_count,
        "docker_image": docker_image,
        "docker_entrypoint": docker_entrypoint,
        "pull_image": pull_image,
        "require_hardware_counters": require_hwc,
    }
    try:
        return run_command_stream(**stream_kwargs)
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
|
|
2668
|
+
|
|
2669
|
+
|
|
2670
|
+
@app.command("remote-run", hidden=True)
def remote_run(  # noqa: PLR0913
    command: list[str] = typer.Argument(..., help="Command to run"),
    upload_dir: Path | None = typer.Option(
        None, "--upload-dir", "-u", help="Directory to upload (stateless mode)"
    ),
    workspace_id: str | None = typer.Option(
        None, "--workspace-id", "-w", help="Workspace ID (from wafer push)"
    ),
    gpu_id: int | None = typer.Option(None, "--gpu", "-g", help="GPU ID"),
    gpu_count: int = typer.Option(1, "--gpu-count", "-n", help="Number of GPUs (1-8)"),
    docker_image: str | None = typer.Option(None, "--image", "-i", help="Docker image override"),
    docker_entrypoint: str | None = typer.Option(
        None, "--docker-entrypoint", help="Override Docker entrypoint (e.g., 'bash')"
    ),
    pull_image: bool = typer.Option(
        False, "--pull-image", help="Pull image if not available on target"
    ),
    require_hwc: bool = typer.Option(
        False, "--require-hwc", help="Require hardware counters (baremetal)"
    ),
    direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
    target_name: str | None = typer.Option(
        None,
        "--target",
        "-t",
        help="Target for --direct mode. See 'wafer config targets list'.",
        autocompletion=complete_target_name,
    ),
) -> None:
    """Run command on remote GPU in Docker.

    Two modes:
    - High-level (stateless): --upload-dir uploads files and runs command
    - Low-level: --workspace-id uses existing workspace from 'wafer push'

    By default, uses wafer-api. Use --direct for direct SSH mode.
    In --direct mode the API-only options (--gpu-count, --image,
    --docker-entrypoint, --pull-image, --require-hwc) are not used and a
    warning is printed if they are given.

    Examples:
        # Stateless: upload and run
        wafer remote-run --upload-dir ./my_project -- python train.py

        # Run without files
        wafer remote-run -- nvidia-smi

        # Low-level: use existing workspace
        wafer remote-run --workspace-id ws_abc123 -- python train.py

        # Direct SSH mode
        wafer remote-run --upload-dir ./my_project --direct --target vultr-b200 -- python train.py
    """
    # NOTE(review): arguments are joined with single spaces, so per-argument
    # quoting is not preserved; presumably the remote side interprets the string
    # with a shell (which lets users pass e.g. "cd x && make") — confirm.
    cmd_str = " ".join(command)
    if not cmd_str.strip():
        typer.echo("Error: Empty command", err=True)
        raise typer.Exit(1)

    if upload_dir and workspace_id:
        typer.echo("Error: --upload-dir and --workspace-id are mutually exclusive", err=True)
        raise typer.Exit(1)

    if gpu_count < 1:
        typer.echo("Error: --gpu-count must be at least 1", err=True)
        raise typer.Exit(1)

    if upload_dir:
        if not upload_dir.exists():
            typer.echo(f"Error: Directory not found: {upload_dir}", err=True)
            raise typer.Exit(1)
        if not upload_dir.is_dir():
            typer.echo(f"Error: Not a directory: {upload_dir}", err=True)
            raise typer.Exit(1)
        upload_dir = upload_dir.resolve()

    if direct:
        if not target_name:
            typer.echo("Error: --target required for --direct mode", err=True)
            raise typer.Exit(1)
        # _run_direct_mode does not receive these options, so tell the user
        # instead of silently dropping them.
        ignored_flags = [
            flag
            for flag, given in (
                ("--gpu-count", gpu_count != 1),
                ("--image", docker_image is not None),
                ("--docker-entrypoint", docker_entrypoint is not None),
                ("--pull-image", pull_image),
                ("--require-hwc", require_hwc),
            )
            if given
        ]
        if ignored_flags:
            typer.echo(
                f"Warning: {', '.join(ignored_flags)} ignored in --direct mode", err=True
            )
        exit_code = _run_direct_mode(cmd_str, target_name, upload_dir, workspace_id, gpu_id)
    else:
        if target_name:
            # --target is only consulted by direct mode; the API picks the machine.
            typer.echo(
                "Warning: --target is only used with --direct; running via wafer-api",
                err=True,
            )
        exit_code = _run_api_mode(
            cmd_str,
            upload_dir,
            workspace_id,
            gpu_id,
            gpu_count,
            docker_image,
            docker_entrypoint,
            pull_image,
            require_hwc,
        )

    raise typer.Exit(exit_code)
|
|
2758
|
+
|
|
2759
|
+
|
|
2593
2760
|
# =============================================================================
|
|
2594
2761
|
# Authentication commands
|
|
2595
2762
|
# =============================================================================
|
|
2596
2763
|
|
|
2597
2764
|
|
|
2598
|
-
@
|
|
2765
|
+
@app.command("login")
|
|
2599
2766
|
def login(
|
|
2600
2767
|
token: str | None = typer.Option(
|
|
2601
2768
|
None, "--token", "-t", help="Access token (skip browser OAuth)"
|
|
@@ -2620,7 +2787,7 @@ def login(
|
|
|
2620
2787
|
Uses the API environment from config (see 'wafer config show').
|
|
2621
2788
|
|
|
2622
2789
|
SSH Users (Easiest):
|
|
2623
|
-
- Just run: wafer
|
|
2790
|
+
- Just run: wafer login
|
|
2624
2791
|
- Visit the URL and enter the code shown
|
|
2625
2792
|
- No port forwarding needed!
|
|
2626
2793
|
|
|
@@ -2630,17 +2797,17 @@ def login(
|
|
|
2630
2797
|
|
|
2631
2798
|
Manual token option:
|
|
2632
2799
|
- Visit auth.wafer.ai, authenticate, copy token from URL
|
|
2633
|
-
- Run: wafer
|
|
2800
|
+
- Run: wafer login --token <paste-token>
|
|
2634
2801
|
|
|
2635
2802
|
Examples:
|
|
2636
|
-
wafer
|
|
2637
|
-
wafer
|
|
2638
|
-
wafer
|
|
2639
|
-
wafer
|
|
2803
|
+
wafer login # device code on SSH, browser on local
|
|
2804
|
+
wafer login --no-device-code # force browser (needs port forwarding on SSH)
|
|
2805
|
+
wafer login --port 9000 # custom port for browser flow
|
|
2806
|
+
wafer login --token xyz # manual token (no browser)
|
|
2640
2807
|
|
|
2641
2808
|
# Change environment:
|
|
2642
2809
|
wafer config set api.environment staging
|
|
2643
|
-
wafer
|
|
2810
|
+
wafer login
|
|
2644
2811
|
"""
|
|
2645
2812
|
import httpx
|
|
2646
2813
|
|
|
@@ -2724,7 +2891,7 @@ def login(
|
|
|
2724
2891
|
typer.echo("Token saved to ~/.wafer/credentials.json")
|
|
2725
2892
|
|
|
2726
2893
|
|
|
2727
|
-
@
|
|
2894
|
+
@app.command("logout")
|
|
2728
2895
|
def logout() -> None:
|
|
2729
2896
|
"""Remove stored credentials."""
|
|
2730
2897
|
from . import analytics
|
|
@@ -2741,7 +2908,7 @@ def logout() -> None:
|
|
|
2741
2908
|
typer.echo("Not logged in (no credentials found).")
|
|
2742
2909
|
|
|
2743
2910
|
|
|
2744
|
-
@
|
|
2911
|
+
@app.command("whoami")
|
|
2745
2912
|
def whoami(
|
|
2746
2913
|
verify: bool = typer.Option(False, "--verify", "-v", help="Verify token with API"),
|
|
2747
2914
|
refresh: bool = typer.Option(False, "--refresh", "-r", help="Refresh token if expired"),
|
|
@@ -2755,7 +2922,7 @@ def whoami(
|
|
|
2755
2922
|
|
|
2756
2923
|
creds = load_credentials()
|
|
2757
2924
|
if creds is None:
|
|
2758
|
-
typer.echo("Not logged in. Run: wafer
|
|
2925
|
+
typer.echo("Not logged in. Run: wafer login")
|
|
2759
2926
|
raise typer.Exit(1)
|
|
2760
2927
|
|
|
2761
2928
|
if verify or refresh:
|
|
@@ -2763,7 +2930,7 @@ def whoami(
|
|
|
2763
2930
|
# Try to get valid token with auto-refresh
|
|
2764
2931
|
token = get_valid_token()
|
|
2765
2932
|
if token is None:
|
|
2766
|
-
typer.echo("Token expired and refresh failed. Run: wafer
|
|
2933
|
+
typer.echo("Token expired and refresh failed. Run: wafer login", err=True)
|
|
2767
2934
|
raise typer.Exit(1)
|
|
2768
2935
|
if token != creds.access_token:
|
|
2769
2936
|
typer.echo("Token refreshed successfully")
|
|
@@ -2776,10 +2943,10 @@ def whoami(
|
|
|
2776
2943
|
except Exception as e:
|
|
2777
2944
|
if creds.refresh_token and not refresh:
|
|
2778
2945
|
typer.echo(f"Token expired: {e}", err=True)
|
|
2779
|
-
typer.echo("Try: wafer
|
|
2946
|
+
typer.echo("Try: wafer whoami --refresh", err=True)
|
|
2780
2947
|
else:
|
|
2781
2948
|
typer.echo(f"Token invalid or expired: {e}", err=True)
|
|
2782
|
-
typer.echo("Run: wafer
|
|
2949
|
+
typer.echo("Run: wafer login", err=True)
|
|
2783
2950
|
raise typer.Exit(1) from None
|
|
2784
2951
|
elif creds.email:
|
|
2785
2952
|
typer.echo(creds.email)
|
|
@@ -2787,7 +2954,7 @@ def whoami(
|
|
|
2787
2954
|
typer.echo("Logged in (email not available)")
|
|
2788
2955
|
|
|
2789
2956
|
|
|
2790
|
-
@app.command("guide"
|
|
2957
|
+
@app.command("guide")
|
|
2791
2958
|
def guide() -> None:
|
|
2792
2959
|
"""Show the Wafer CLI usage guide.
|
|
2793
2960
|
|
|
@@ -2818,7 +2985,7 @@ demo_app = typer.Typer(
|
|
|
2818
2985
|
wafer demo trace Analyze a sample performance trace
|
|
2819
2986
|
wafer demo eval Run kernel evaluation on cloud GPU (requires login)"""
|
|
2820
2987
|
)
|
|
2821
|
-
app.add_typer(demo_app, name="demo"
|
|
2988
|
+
app.add_typer(demo_app, name="demo")
|
|
2822
2989
|
|
|
2823
2990
|
DEMO_TRACES_URL = "https://github.com/wafer-ai/wafer/raw/main/apps/wafer-cli/wafer/demo_data"
|
|
2824
2991
|
DEMO_DIR = Path.home() / ".cache" / "wafer" / "demo"
|
|
@@ -3038,7 +3205,7 @@ def demo_eval(
|
|
|
3038
3205
|
"""Demo: Evaluate a kernel on a cloud GPU.
|
|
3039
3206
|
|
|
3040
3207
|
Creates a workspace, runs a sample Triton kernel evaluation, and cleans up.
|
|
3041
|
-
Requires authentication (wafer
|
|
3208
|
+
Requires authentication (wafer login).
|
|
3042
3209
|
|
|
3043
3210
|
Example:
|
|
3044
3211
|
wafer demo eval
|
|
@@ -3053,7 +3220,7 @@ def demo_eval(
|
|
|
3053
3220
|
# Check auth first
|
|
3054
3221
|
creds = load_credentials()
|
|
3055
3222
|
if not creds:
|
|
3056
|
-
typer.echo("Error: Not authenticated. Run: wafer
|
|
3223
|
+
typer.echo("Error: Not authenticated. Run: wafer login")
|
|
3057
3224
|
raise typer.Exit(1)
|
|
3058
3225
|
|
|
3059
3226
|
if not yes:
|
|
@@ -4411,8 +4578,8 @@ def billing_usage(
|
|
|
4411
4578
|
"""Show current billing usage and subscription info.
|
|
4412
4579
|
|
|
4413
4580
|
Example:
|
|
4414
|
-
wafer
|
|
4415
|
-
wafer
|
|
4581
|
+
wafer billing
|
|
4582
|
+
wafer billing --json
|
|
4416
4583
|
"""
|
|
4417
4584
|
# Only show usage if no subcommand was invoked
|
|
4418
4585
|
if ctx.invoked_subcommand is not None:
|
|
@@ -4440,9 +4607,9 @@ def billing_topup(
|
|
|
4440
4607
|
Opens a Stripe checkout page to add credits. Default amount is $25.
|
|
4441
4608
|
|
|
4442
4609
|
Example:
|
|
4443
|
-
wafer
|
|
4444
|
-
wafer
|
|
4445
|
-
wafer
|
|
4610
|
+
wafer billing topup # Add $25
|
|
4611
|
+
wafer billing topup 100 # Add $100
|
|
4612
|
+
wafer billing topup --no-browser # Print URL instead
|
|
4446
4613
|
"""
|
|
4447
4614
|
import webbrowser
|
|
4448
4615
|
|
|
@@ -4488,8 +4655,8 @@ def billing_portal(
|
|
|
4488
4655
|
Manage your subscription, update payment method, or view invoices.
|
|
4489
4656
|
|
|
4490
4657
|
Example:
|
|
4491
|
-
wafer
|
|
4492
|
-
wafer
|
|
4658
|
+
wafer billing portal
|
|
4659
|
+
wafer billing portal --no-browser
|
|
4493
4660
|
"""
|
|
4494
4661
|
import webbrowser
|
|
4495
4662
|
|
|
@@ -4526,8 +4693,8 @@ def ssh_keys_list(
|
|
|
4526
4693
|
"""List all registered SSH public keys.
|
|
4527
4694
|
|
|
4528
4695
|
Example:
|
|
4529
|
-
wafer
|
|
4530
|
-
wafer
|
|
4696
|
+
wafer ssh-keys list
|
|
4697
|
+
wafer ssh-keys list --json
|
|
4531
4698
|
"""
|
|
4532
4699
|
from .ssh_keys import list_ssh_keys
|
|
4533
4700
|
|
|
@@ -4553,9 +4720,9 @@ def ssh_keys_add(
|
|
|
4553
4720
|
id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.
|
|
4554
4721
|
|
|
4555
4722
|
Example:
|
|
4556
|
-
wafer
|
|
4557
|
-
wafer
|
|
4558
|
-
wafer
|
|
4723
|
+
wafer ssh-keys add # Auto-detect
|
|
4724
|
+
wafer ssh-keys add ~/.ssh/id_rsa.pub # Specific file
|
|
4725
|
+
wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
|
|
4559
4726
|
"""
|
|
4560
4727
|
from .ssh_keys import add_ssh_key
|
|
4561
4728
|
|
|
@@ -4574,10 +4741,10 @@ def ssh_keys_remove(
|
|
|
4574
4741
|
) -> None:
|
|
4575
4742
|
"""Remove an SSH public key.
|
|
4576
4743
|
|
|
4577
|
-
Get the key ID from 'wafer
|
|
4744
|
+
Get the key ID from 'wafer ssh-keys list'.
|
|
4578
4745
|
|
|
4579
4746
|
Example:
|
|
4580
|
-
wafer
|
|
4747
|
+
wafer ssh-keys remove abc123-def456-...
|
|
4581
4748
|
"""
|
|
4582
4749
|
from .ssh_keys import remove_ssh_key
|
|
4583
4750
|
|
|
@@ -5064,18 +5231,6 @@ def workspaces_pull(
|
|
|
5064
5231
|
raise typer.Exit(1) from None
|
|
5065
5232
|
|
|
5066
5233
|
|
|
5067
|
-
# =============================================================================
|
|
5068
|
-
# Live resource commands (list/terminate/reconcile/provision)
|
|
5069
|
-
# =============================================================================
|
|
5070
|
-
|
|
5071
|
-
targets_ops_app.command("list")(_targets_list_cmd)
|
|
5072
|
-
targets_ops_app.command("terminate")(_targets_terminate_cmd)
|
|
5073
|
-
targets_ops_app.command("reconcile")(_targets_reconcile_cmd)
|
|
5074
|
-
targets_ops_app.command("provision")(_targets_provision_cmd)
|
|
5075
|
-
targets_ops_app.command("pools")(_targets_pools_cmd)
|
|
5076
|
-
targets_ops_app.command("probe")(_targets_probe_cmd)
|
|
5077
|
-
|
|
5078
|
-
|
|
5079
5234
|
# =============================================================================
|
|
5080
5235
|
# Target operations commands (exec/ssh/sync)
|
|
5081
5236
|
# =============================================================================
|
|
@@ -5834,9 +5989,9 @@ def ncu_analyze(
|
|
|
5834
5989
|
compute/memory throughput, and optimization recommendations.
|
|
5835
5990
|
|
|
5836
5991
|
By default, uses local NCU if available, otherwise runs analysis
|
|
5837
|
-
remotely via wafer-api (requires authentication: wafer
|
|
5992
|
+
remotely via wafer-api (requires authentication: wafer login).
|
|
5838
5993
|
|
|
5839
|
-
Use --target for direct SSH mode.
|
|
5994
|
+
Use --target for direct SSH mode (like wafer remote-run --direct).
|
|
5840
5995
|
Use --include-source to fetch SASS assembly with register/instruction data.
|
|
5841
5996
|
|
|
5842
5997
|
Examples:
|
|
@@ -5929,7 +6084,7 @@ def nsys_analyze(
|
|
|
5929
6084
|
Returns timeline events, kernel information, memory usage, and diagnostics.
|
|
5930
6085
|
|
|
5931
6086
|
By default, uses local nsys if available, otherwise runs analysis
|
|
5932
|
-
remotely via wafer-api (requires authentication: wafer
|
|
6087
|
+
remotely via wafer-api (requires authentication: wafer login).
|
|
5933
6088
|
|
|
5934
6089
|
Supports multiple execution modes:
|
|
5935
6090
|
- Local: Uses local nsys CLI (no GPU required for analysis)
|
|
@@ -6914,7 +7069,7 @@ def autotuner_results(
|
|
|
6914
7069
|
raise typer.Exit(1) from None
|
|
6915
7070
|
|
|
6916
7071
|
|
|
6917
|
-
@app.command("capture"
|
|
7072
|
+
@app.command("capture")
|
|
6918
7073
|
def capture_command( # noqa: PLR0915
|
|
6919
7074
|
label: str = typer.Argument(
|
|
6920
7075
|
..., help="Label for this capture (e.g., 'baseline', 'optimized-v2')"
|
|
@@ -7594,29 +7749,18 @@ def compare_analyze(
|
|
|
7594
7749
|
"-f",
|
|
7595
7750
|
help="Output format: text, text-layers, csv, csv-layers, json",
|
|
7596
7751
|
),
|
|
7597
|
-
output: Path | None = typer.Option(
|
|
7598
|
-
None, "--output", "-o", help="Output file (default: stdout)"
|
|
7599
|
-
),
|
|
7752
|
+
output: Path | None = typer.Option(None, "--output", "-o", help="Output file (default: stdout)"),
|
|
7600
7753
|
phase: str = typer.Option(
|
|
7601
7754
|
"all",
|
|
7602
7755
|
"--phase",
|
|
7603
7756
|
help="Filter by phase: all, prefill, decode",
|
|
7604
7757
|
),
|
|
7605
7758
|
layers: bool = typer.Option(False, "--layers", help="Show layer-wise performance breakdown"),
|
|
7606
|
-
all: bool = typer.Option(
|
|
7607
|
-
|
|
7608
|
-
),
|
|
7609
|
-
stack_traces: bool = typer.Option(
|
|
7610
|
-
False, "--stack-traces", help="Show Python stack traces for operations"
|
|
7611
|
-
),
|
|
7612
|
-
recommendations: bool = typer.Option(
|
|
7613
|
-
False, "--recommendations", help="Generate prioritized recommendations for kernel team"
|
|
7614
|
-
),
|
|
7615
|
-
json: bool = typer.Option(
|
|
7616
|
-
False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
|
|
7617
|
-
),
|
|
7759
|
+
all: bool = typer.Option(False, "--all", help="Show all items (no truncation for layers, operations, kernels)"),
|
|
7760
|
+
stack_traces: bool = typer.Option(False, "--stack-traces", help="Show Python stack traces for operations"),
|
|
7761
|
+
json: bool = typer.Option(False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"),
|
|
7618
7762
|
) -> None:
|
|
7619
|
-
"""Compare GPU traces from
|
|
7763
|
+
"""Compare GPU traces from two platforms platforms.
|
|
7620
7764
|
|
|
7621
7765
|
Analyzes performance differences between traces, identifying which operations
|
|
7622
7766
|
are faster/slower on each platform and providing kernel-level details.
|
|
@@ -7664,7 +7808,6 @@ def compare_analyze(
|
|
|
7664
7808
|
show_layers=layers,
|
|
7665
7809
|
show_all=all,
|
|
7666
7810
|
show_stack_traces=stack_traces,
|
|
7667
|
-
recommendations=recommendations,
|
|
7668
7811
|
)
|
|
7669
7812
|
_mark_command_success()
|
|
7670
7813
|
|
|
@@ -7679,17 +7822,13 @@ def compare_fusion_cmd(
|
|
|
7679
7822
|
"-f",
|
|
7680
7823
|
help="Output format: text, csv, json",
|
|
7681
7824
|
),
|
|
7682
|
-
output: Path | None = typer.Option(
|
|
7683
|
-
None, "--output", "-o", help="Output file (default: stdout)"
|
|
7684
|
-
),
|
|
7825
|
+
output: Path | None = typer.Option(None, "--output", "-o", help="Output file (default: stdout)"),
|
|
7685
7826
|
min_group_size: int = typer.Option(
|
|
7686
7827
|
50,
|
|
7687
7828
|
"--min-group-size",
|
|
7688
7829
|
help="Minimum correlation group size to analyze",
|
|
7689
7830
|
),
|
|
7690
|
-
json: bool = typer.Option(
|
|
7691
|
-
False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
|
|
7692
|
-
),
|
|
7831
|
+
json: bool = typer.Option(False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"),
|
|
7693
7832
|
) -> None:
|
|
7694
7833
|
"""Analyze kernel fusion differences between AMD and NVIDIA traces.
|
|
7695
7834
|
|
|
@@ -7709,69 +7848,14 @@ def compare_fusion_cmd(
|
|
|
7709
7848
|
# CSV output to file
|
|
7710
7849
|
wafer compare fusion amd_trace.json nvidia_trace.json --format csv -o fusion.csv
|
|
7711
7850
|
"""
|
|
7712
|
-
from .trace_compare import
|
|
7713
|
-
|
|
7714
|
-
compare_align(
|
|
7715
|
-
trace1=trace1,
|
|
7716
|
-
trace2=trace2,
|
|
7717
|
-
output=output,
|
|
7718
|
-
output_format=format,
|
|
7719
|
-
phase="all",
|
|
7720
|
-
)
|
|
7721
|
-
_mark_command_success()
|
|
7722
|
-
|
|
7723
|
-
|
|
7724
|
-
@compare_app.command("align")
|
|
7725
|
-
def compare_align_cmd(
|
|
7726
|
-
trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
|
|
7727
|
-
trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
|
|
7728
|
-
format: str = typer.Option(
|
|
7729
|
-
"json",
|
|
7730
|
-
"--format",
|
|
7731
|
-
"-f",
|
|
7732
|
-
help="Output format: json",
|
|
7733
|
-
),
|
|
7734
|
-
output: Path | None = typer.Option(
|
|
7735
|
-
None, "--output", "-o", help="Output file (default: stdout)"
|
|
7736
|
-
),
|
|
7737
|
-
phase: str = typer.Option(
|
|
7738
|
-
"all",
|
|
7739
|
-
"--phase",
|
|
7740
|
-
help="Filter by phase: all, prefill, decode",
|
|
7741
|
-
),
|
|
7742
|
-
layer: int | None = typer.Option(
|
|
7743
|
-
None,
|
|
7744
|
-
"--layer",
|
|
7745
|
-
help="Focus on specific layer number",
|
|
7746
|
-
),
|
|
7747
|
-
) -> None:
|
|
7748
|
-
"""Align kernels at layer level for exact kernel-to-kernel comparison.
|
|
7749
|
-
|
|
7750
|
-
Provides kernel-to-kernel mapping across AMD and NVIDIA platforms,
|
|
7751
|
-
showing which kernels correspond to each other at each layer position.
|
|
7752
|
-
|
|
7753
|
-
Examples:
|
|
7754
|
-
# Basic alignment (stdout JSON)
|
|
7755
|
-
wafer compare align amd_trace.json nvidia_trace.json
|
|
7756
|
-
|
|
7757
|
-
# Save to file
|
|
7758
|
-
wafer compare align amd_trace.json nvidia_trace.json -o alignment.json
|
|
7759
|
-
|
|
7760
|
-
# Focus on decode phase only
|
|
7761
|
-
wafer compare align amd_trace.json nvidia_trace.json --phase decode
|
|
7762
|
-
|
|
7763
|
-
# Focus on specific layer
|
|
7764
|
-
wafer compare align amd_trace.json nvidia_trace.json --layer 5
|
|
7765
|
-
"""
|
|
7766
|
-
from .trace_compare import compare_align
|
|
7851
|
+
from .trace_compare import compare_fusion
|
|
7767
7852
|
|
|
7768
|
-
|
|
7853
|
+
compare_fusion(
|
|
7769
7854
|
trace1=trace1,
|
|
7770
7855
|
trace2=trace2,
|
|
7771
7856
|
output=output,
|
|
7772
|
-
|
|
7773
|
-
|
|
7774
|
-
layer=layer,
|
|
7857
|
+
format_type=format,
|
|
7858
|
+
min_group_size=min_group_size,
|
|
7775
7859
|
)
|
|
7776
7860
|
_mark_command_success()
|
|
7777
7861
|
|