wafer-cli 0.2.31__py3-none-any.whl → 0.2.33__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +157 -2
- wafer/billing.py +6 -6
- wafer/cli.py +432 -346
- wafer/corpus.py +6 -72
- wafer/evaluate.py +143 -81
- wafer/global_config.py +0 -13
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +6 -22
- wafer/ssh_keys.py +6 -6
- wafer/targets_ops.py +2 -29
- wafer/templates/aiter_optimize.py +59 -0
- wafer/templates/optimize_kernel.py +2 -4
- wafer/templates/optimize_kernelbench.py +62 -17
- wafer/templates/optimize_vllm.py +156 -0
- wafer/trace_compare.py +48 -139
- wafer/wevin_cli.py +1 -12
- wafer/workspaces.py +8 -8
- wafer_cli-0.2.33.dist-info/METADATA +260 -0
- {wafer_cli-0.2.31.dist-info → wafer_cli-0.2.33.dist-info}/RECORD +25 -23
- wafer_cli-0.2.31.dist-info/METADATA +0 -107
- {wafer_cli-0.2.31.dist-info → wafer_cli-0.2.33.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.31.dist-info → wafer_cli-0.2.33.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.31.dist-info → wafer_cli-0.2.33.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
|
@@ -8,7 +8,6 @@
|
|
|
8
8
|
Core commands:
|
|
9
9
|
agent AI assistant for GPU kernel development
|
|
10
10
|
evaluate Test kernel correctness and performance
|
|
11
|
-
baseline Discover what kernel PyTorch uses for an op
|
|
12
11
|
corpus Download GPU documentation for local access
|
|
13
12
|
workspaces Manage cloud GPU environments
|
|
14
13
|
|
|
@@ -195,16 +194,11 @@ def complete_target_name(incomplete: str) -> list[str]:
|
|
|
195
194
|
|
|
196
195
|
# =============================================================================
|
|
197
196
|
# Core subcommand groups (visible in --help)
|
|
198
|
-
#
|
|
199
|
-
# TODO: Further consolidate top-level commands to reduce --help surface area.
|
|
200
|
-
# Candidates:
|
|
201
|
-
# - compare → wafer nvidia compare or keep top-level (cross-platform)
|
|
202
|
-
# - guide/skill/demo → wafer onboard {guide,skill,demo}
|
|
203
197
|
# =============================================================================
|
|
204
198
|
|
|
205
199
|
# Config management (includes targets as nested subcommand)
|
|
206
200
|
config_app = typer.Typer(help="Manage CLI configuration and local GPU targets")
|
|
207
|
-
app.add_typer(config_app, name="config"
|
|
201
|
+
app.add_typer(config_app, name="config")
|
|
208
202
|
|
|
209
203
|
# Target management - nested under config
|
|
210
204
|
targets_app = typer.Typer(
|
|
@@ -224,7 +218,7 @@ config_app.add_typer(targets_app, name="targets")
|
|
|
224
218
|
workspaces_app = typer.Typer(
|
|
225
219
|
help="""Manage cloud GPU workspaces for remote development.
|
|
226
220
|
|
|
227
|
-
Workspaces are on-demand cloud GPU environments. Requires authentication (wafer
|
|
221
|
+
Workspaces are on-demand cloud GPU environments. Requires authentication (wafer login).
|
|
228
222
|
|
|
229
223
|
Available GPUs:
|
|
230
224
|
MI300X AMD Instinct MI300X (192GB HBM3, ROCm)
|
|
@@ -237,21 +231,21 @@ Commands:
|
|
|
237
231
|
wafer workspaces sync dev ./project # Sync files
|
|
238
232
|
wafer workspaces delete dev # Clean up"""
|
|
239
233
|
)
|
|
240
|
-
app.add_typer(workspaces_app, name="workspaces"
|
|
234
|
+
app.add_typer(workspaces_app, name="workspaces")
|
|
241
235
|
|
|
242
|
-
# SSH Key management (BYOK - Bring Your Own Key)
|
|
236
|
+
# SSH Key management (BYOK - Bring Your Own Key)
|
|
243
237
|
ssh_keys_app = typer.Typer(
|
|
244
238
|
help="""Manage SSH public keys for workspace access.
|
|
245
239
|
|
|
246
240
|
Register your SSH public keys here. These keys are installed in all workspaces
|
|
247
241
|
you provision, enabling SSH access from any machine with your private key.
|
|
248
242
|
|
|
249
|
-
wafer
|
|
250
|
-
wafer
|
|
251
|
-
wafer
|
|
252
|
-
wafer
|
|
243
|
+
wafer ssh-keys list # List registered keys
|
|
244
|
+
wafer ssh-keys add # Add key (auto-detects ~/.ssh/id_ed25519.pub)
|
|
245
|
+
wafer ssh-keys add ~/.ssh/id_rsa.pub --name laptop # Add specific key
|
|
246
|
+
wafer ssh-keys remove <key-id> # Remove a key"""
|
|
253
247
|
)
|
|
254
|
-
|
|
248
|
+
app.add_typer(ssh_keys_app, name="ssh-keys")
|
|
255
249
|
|
|
256
250
|
# Target operations (exec/ssh/sync on configured targets)
|
|
257
251
|
targets_ops_app = typer.Typer(
|
|
@@ -267,48 +261,22 @@ Useful for exploratory work, debugging, or custom scripts.
|
|
|
267
261
|
Supports: RunPod, DigitalOcean (auto-provisions), SSH targets (baremetal/vm).
|
|
268
262
|
Configure targets with: wafer config targets init ..."""
|
|
269
263
|
)
|
|
270
|
-
app.add_typer(targets_ops_app, name="targets"
|
|
264
|
+
app.add_typer(targets_ops_app, name="targets")
|
|
271
265
|
|
|
272
|
-
#
|
|
273
|
-
from wafer.specs_cli import specs_app
|
|
274
|
-
|
|
275
|
-
app.add_typer(specs_app, name="specs", rich_help_panel="Configuration")
|
|
276
|
-
|
|
277
|
-
# Live resource management (new: API-backed commands on `wafer targets`)
|
|
278
|
-
# These become: wafer targets list, wafer targets terminate, etc.
|
|
279
|
-
from wafer.targets_cli import (
|
|
280
|
-
targets_list as _targets_list_cmd,
|
|
281
|
-
)
|
|
282
|
-
from wafer.targets_cli import (
|
|
283
|
-
targets_pools as _targets_pools_cmd,
|
|
284
|
-
)
|
|
285
|
-
from wafer.targets_cli import (
|
|
286
|
-
targets_probe as _targets_probe_cmd,
|
|
287
|
-
)
|
|
288
|
-
from wafer.targets_cli import (
|
|
289
|
-
targets_provision as _targets_provision_cmd,
|
|
290
|
-
)
|
|
291
|
-
from wafer.targets_cli import (
|
|
292
|
-
targets_reconcile as _targets_reconcile_cmd,
|
|
293
|
-
)
|
|
294
|
-
from wafer.targets_cli import (
|
|
295
|
-
targets_terminate as _targets_terminate_cmd,
|
|
296
|
-
)
|
|
297
|
-
|
|
298
|
-
# Billing management - nested under config
|
|
266
|
+
# Billing management
|
|
299
267
|
billing_app = typer.Typer(help="Manage billing, credits, and subscription")
|
|
300
|
-
|
|
268
|
+
app.add_typer(billing_app, name="billing")
|
|
301
269
|
|
|
302
270
|
# Corpus management
|
|
303
271
|
corpus_app = typer.Typer(help="Download and manage GPU documentation")
|
|
304
|
-
app.add_typer(corpus_app, name="corpus"
|
|
272
|
+
app.add_typer(corpus_app, name="corpus")
|
|
305
273
|
|
|
306
274
|
# Evaluate (supports multiple kernel formats)
|
|
307
275
|
evaluate_app = typer.Typer(
|
|
308
276
|
help="Test kernel correctness and performance",
|
|
309
277
|
invoke_without_command=True,
|
|
310
278
|
)
|
|
311
|
-
app.add_typer(evaluate_app, name="evaluate"
|
|
279
|
+
app.add_typer(evaluate_app, name="evaluate")
|
|
312
280
|
|
|
313
281
|
# Nested subcommand for kernelbench format
|
|
314
282
|
kernelbench_app = typer.Typer(
|
|
@@ -324,11 +292,6 @@ gpumode_app = typer.Typer(
|
|
|
324
292
|
)
|
|
325
293
|
evaluate_app.add_typer(gpumode_app, name="gpumode")
|
|
326
294
|
|
|
327
|
-
# Baseline discovery (what kernel does PyTorch use?)
|
|
328
|
-
from wafer.baseline import baseline_app
|
|
329
|
-
|
|
330
|
-
app.add_typer(baseline_app, name="baseline", rich_help_panel="Kernel Development")
|
|
331
|
-
|
|
332
295
|
# =============================================================================
|
|
333
296
|
# Dev commands (internal, used by web app proxy)
|
|
334
297
|
# =============================================================================
|
|
@@ -342,7 +305,7 @@ app.add_typer(dev_app, name="dev")
|
|
|
342
305
|
# =============================================================================
|
|
343
306
|
|
|
344
307
|
nvidia_app = typer.Typer(help="NVIDIA GPU profiling and analysis tools")
|
|
345
|
-
app.add_typer(nvidia_app, name="nvidia"
|
|
308
|
+
app.add_typer(nvidia_app, name="nvidia")
|
|
346
309
|
|
|
347
310
|
# NCU analysis - under nvidia
|
|
348
311
|
ncu_app = typer.Typer(help="Nsight Compute profile analysis")
|
|
@@ -365,7 +328,7 @@ nvidia_app.add_typer(tracelens_app, name="tracelens")
|
|
|
365
328
|
# =============================================================================
|
|
366
329
|
|
|
367
330
|
amd_app = typer.Typer(help="AMD GPU profiling and analysis tools")
|
|
368
|
-
app.add_typer(amd_app, name="amd"
|
|
331
|
+
app.add_typer(amd_app, name="amd")
|
|
369
332
|
|
|
370
333
|
# Unified ISA Analyzer - supports both .co files and Triton artifacts
|
|
371
334
|
isa_app = typer.Typer(help="ISA analysis for AMD GPU kernels (.co, .s, .ll, .ttgir files)")
|
|
@@ -376,14 +339,14 @@ amd_app.add_typer(isa_app, name="isa")
|
|
|
376
339
|
# =============================================================================
|
|
377
340
|
|
|
378
341
|
compare_app = typer.Typer(help="Compare GPU traces across platforms (AMD vs NVIDIA)")
|
|
379
|
-
app.add_typer(compare_app, name="compare"
|
|
342
|
+
app.add_typer(compare_app, name="compare")
|
|
380
343
|
|
|
381
344
|
# =============================================================================
|
|
382
345
|
# Roofline analysis (wafer roofline)
|
|
383
346
|
# =============================================================================
|
|
384
347
|
|
|
385
348
|
|
|
386
|
-
@app.command("roofline"
|
|
349
|
+
@app.command("roofline")
|
|
387
350
|
def roofline_cmd(
|
|
388
351
|
gpu: str | None = typer.Option(
|
|
389
352
|
None, "--gpu", "-g", help="GPU name (e.g., H100, B200, MI300X, A100)"
|
|
@@ -474,7 +437,7 @@ def roofline_cmd(
|
|
|
474
437
|
# =============================================================================
|
|
475
438
|
|
|
476
439
|
skill_app = typer.Typer(help="Manage AI coding assistant skills (Claude Code, Codex)")
|
|
477
|
-
app.add_typer(skill_app, name="skill"
|
|
440
|
+
app.add_typer(skill_app, name="skill")
|
|
478
441
|
|
|
479
442
|
|
|
480
443
|
@skill_app.command("install")
|
|
@@ -638,19 +601,14 @@ def skill_status() -> None:
|
|
|
638
601
|
|
|
639
602
|
|
|
640
603
|
# =============================================================================
|
|
641
|
-
#
|
|
604
|
+
# Provider auth management (wafer auth ...)
|
|
642
605
|
# =============================================================================
|
|
643
606
|
|
|
644
|
-
|
|
645
|
-
app.add_typer(
|
|
646
|
-
|
|
647
|
-
providers_app = typer.Typer(
|
|
648
|
-
help="Manage API keys for cloud GPU providers (RunPod, DigitalOcean, etc.)"
|
|
649
|
-
)
|
|
650
|
-
auth_app.add_typer(providers_app, name="providers")
|
|
607
|
+
provider_auth_app = typer.Typer(help="Manage API keys for cloud GPU providers")
|
|
608
|
+
app.add_typer(provider_auth_app, name="auth")
|
|
651
609
|
|
|
652
610
|
|
|
653
|
-
@
|
|
611
|
+
@provider_auth_app.command("login")
|
|
654
612
|
def provider_auth_login(
|
|
655
613
|
provider: str = typer.Argument(
|
|
656
614
|
...,
|
|
@@ -669,10 +627,10 @@ def provider_auth_login(
|
|
|
669
627
|
(e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
|
|
670
628
|
|
|
671
629
|
Examples:
|
|
672
|
-
wafer auth
|
|
673
|
-
wafer auth
|
|
674
|
-
wafer auth
|
|
675
|
-
echo $API_KEY | wafer auth
|
|
630
|
+
wafer auth login anthropic --api-key sk-ant-xxx
|
|
631
|
+
wafer auth login runpod --api-key rp_xxx
|
|
632
|
+
wafer auth login openai --api-key sk-xxx
|
|
633
|
+
echo $API_KEY | wafer auth login anthropic
|
|
676
634
|
"""
|
|
677
635
|
import sys
|
|
678
636
|
|
|
@@ -702,7 +660,7 @@ def provider_auth_login(
|
|
|
702
660
|
typer.echo("Stored in: ~/.wafer/auth.json")
|
|
703
661
|
|
|
704
662
|
|
|
705
|
-
@
|
|
663
|
+
@provider_auth_app.command("logout")
|
|
706
664
|
def provider_auth_logout(
|
|
707
665
|
provider: str = typer.Argument(
|
|
708
666
|
...,
|
|
@@ -712,8 +670,8 @@ def provider_auth_logout(
|
|
|
712
670
|
"""Remove stored API key for a cloud GPU provider.
|
|
713
671
|
|
|
714
672
|
Examples:
|
|
715
|
-
wafer auth
|
|
716
|
-
wafer auth
|
|
673
|
+
wafer auth logout runpod
|
|
674
|
+
wafer auth logout digitalocean
|
|
717
675
|
"""
|
|
718
676
|
from wafer_core.auth import PROVIDERS, remove_api_key
|
|
719
677
|
|
|
@@ -729,7 +687,7 @@ def provider_auth_logout(
|
|
|
729
687
|
typer.echo(f"No stored API key found for {PROVIDERS[provider]['display_name']}")
|
|
730
688
|
|
|
731
689
|
|
|
732
|
-
@
|
|
690
|
+
@provider_auth_app.command("status")
|
|
733
691
|
def provider_auth_status() -> None:
|
|
734
692
|
"""Show authentication status for all cloud GPU providers.
|
|
735
693
|
|
|
@@ -737,7 +695,7 @@ def provider_auth_status() -> None:
|
|
|
737
695
|
the keys are coming from (environment variable or auth.json).
|
|
738
696
|
|
|
739
697
|
Example:
|
|
740
|
-
wafer auth
|
|
698
|
+
wafer auth status
|
|
741
699
|
"""
|
|
742
700
|
from wafer_core.auth import get_all_auth_status
|
|
743
701
|
|
|
@@ -752,7 +710,7 @@ def provider_auth_status() -> None:
|
|
|
752
710
|
typer.echo(f" {status.display_name}: ✓ {status.key_preview} {source_str}")
|
|
753
711
|
else:
|
|
754
712
|
typer.echo(f" {status.display_name}: ✗ Not configured")
|
|
755
|
-
typer.echo(f" Run: wafer auth
|
|
713
|
+
typer.echo(f" Run: wafer auth login {status.provider}")
|
|
756
714
|
typer.echo(f" Or set: {status.key_url}")
|
|
757
715
|
|
|
758
716
|
typer.echo("")
|
|
@@ -1297,7 +1255,7 @@ def config_show_legacy() -> None:
|
|
|
1297
1255
|
config_show_new()
|
|
1298
1256
|
|
|
1299
1257
|
|
|
1300
|
-
@app.command(
|
|
1258
|
+
@app.command()
|
|
1301
1259
|
def agent( # noqa: PLR0913
|
|
1302
1260
|
prompt: str | None = typer.Argument(
|
|
1303
1261
|
None,
|
|
@@ -1582,11 +1540,7 @@ def evaluate( # noqa: PLR0913
|
|
|
1582
1540
|
None, "--reference", help="Path to reference kernel file"
|
|
1583
1541
|
),
|
|
1584
1542
|
test_cases: Path | None = typer.Option(
|
|
1585
|
-
None,
|
|
1586
|
-
"--test-cases",
|
|
1587
|
-
help="Path to test cases JSON file. "
|
|
1588
|
-
'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
|
|
1589
|
-
"Run 'wafer evaluate make-template' to generate an example.",
|
|
1543
|
+
None, "--test-cases", help="Path to test cases JSON file"
|
|
1590
1544
|
),
|
|
1591
1545
|
target: str | None = typer.Option(
|
|
1592
1546
|
None,
|
|
@@ -1598,9 +1552,7 @@ def evaluate( # noqa: PLR0913
|
|
|
1598
1552
|
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
1599
1553
|
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
1600
1554
|
defensive: bool = typer.Option(
|
|
1601
|
-
|
|
1602
|
-
"--defense/--no-defense",
|
|
1603
|
-
help="Run reward hack defense checks after benchmarking. Enabled by default.",
|
|
1555
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
1604
1556
|
),
|
|
1605
1557
|
sync_artifacts: bool = typer.Option(
|
|
1606
1558
|
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
@@ -1614,24 +1566,24 @@ def evaluate( # noqa: PLR0913
|
|
|
1614
1566
|
The evaluation checks:
|
|
1615
1567
|
1. Correctness: Does the kernel produce the same output as the reference?
|
|
1616
1568
|
2. Performance (--benchmark): How fast is it compared to the reference?
|
|
1617
|
-
3. Defense: Detects
|
|
1569
|
+
3. Defense (--defensive): Detects evaluation hacking (stream injection, etc.)
|
|
1618
1570
|
|
|
1619
1571
|
Examples:
|
|
1620
1572
|
# Basic correctness check
|
|
1621
|
-
wafer evaluate
|
|
1573
|
+
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json
|
|
1622
1574
|
|
|
1623
|
-
# With benchmarking
|
|
1624
|
-
wafer evaluate
|
|
1575
|
+
# With benchmarking on a specific target
|
|
1576
|
+
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1625
1577
|
--target vultr-b200 --benchmark
|
|
1626
1578
|
|
|
1627
|
-
#
|
|
1628
|
-
wafer evaluate
|
|
1629
|
-
--benchmark --
|
|
1579
|
+
# Full evaluation with defensive timing (detects cheating)
|
|
1580
|
+
wafer evaluate --impl kernel.py --reference ref.py --test-cases tests.json \\
|
|
1581
|
+
--benchmark --defensive
|
|
1630
1582
|
|
|
1631
1583
|
Subcommands:
|
|
1632
1584
|
gpumode Use GPUMode format (functional) - RECOMMENDED
|
|
1633
1585
|
kernelbench Use KernelBench format (ModelNew class)
|
|
1634
|
-
make-template Generate template files for this format
|
|
1586
|
+
make-template Generate template files for this format (deprecated)
|
|
1635
1587
|
"""
|
|
1636
1588
|
# If a subcommand is being invoked, skip the main evaluation logic
|
|
1637
1589
|
if ctx.invoked_subcommand is not None:
|
|
@@ -1785,7 +1737,7 @@ def evaluate_make_template(
|
|
|
1785
1737
|
typer.echo(f" 2. Edit {output_dir / 'reference.py'} with the ground truth + input generator")
|
|
1786
1738
|
typer.echo(f" 3. Edit {output_dir / 'test_cases.json'} with your test parameters")
|
|
1787
1739
|
typer.echo(" 4. Run:")
|
|
1788
|
-
typer.echo(f" wafer evaluate
|
|
1740
|
+
typer.echo(f" wafer evaluate --impl {output_dir / 'kernel.py'} \\")
|
|
1789
1741
|
typer.echo(f" --reference {output_dir / 'reference.py'} \\")
|
|
1790
1742
|
typer.echo(f" --test-cases {output_dir / 'test_cases.json'} --benchmark")
|
|
1791
1743
|
|
|
@@ -1849,95 +1801,6 @@ def kernelbench_list_problems() -> None:
|
|
|
1849
1801
|
raise typer.Exit(1) from None
|
|
1850
1802
|
|
|
1851
1803
|
|
|
1852
|
-
def _resolve_pool_query(pool: str, collector) -> tuple[str, object]:
|
|
1853
|
-
"""Resolve a PoolQuery pool to a target spec name + lock context.
|
|
1854
|
-
|
|
1855
|
-
Queries live providers, matches by pool query, locks one target,
|
|
1856
|
-
returns (spec_name, lock_context) for the evaluator.
|
|
1857
|
-
"""
|
|
1858
|
-
import trio
|
|
1859
|
-
from wafer_core.targets.pool import resolve_pool
|
|
1860
|
-
|
|
1861
|
-
from .target_lock import acquire_from_pool
|
|
1862
|
-
|
|
1863
|
-
matched_targets = trio.run(resolve_pool, pool)
|
|
1864
|
-
|
|
1865
|
-
if not matched_targets:
|
|
1866
|
-
collector.set_error("pool", "NoMatchingTargets", pool=pool)
|
|
1867
|
-
collector.finalize()
|
|
1868
|
-
raise typer.Exit(1)
|
|
1869
|
-
|
|
1870
|
-
# Filter to targets with a spec (evaluator needs spec fields)
|
|
1871
|
-
spec_targets = [t for t in matched_targets if t.spec_name]
|
|
1872
|
-
if not spec_targets:
|
|
1873
|
-
collector.set_error(
|
|
1874
|
-
"pool",
|
|
1875
|
-
"NoSpecTargets",
|
|
1876
|
-
pool=pool,
|
|
1877
|
-
message="Matched targets have no spec binding — evaluator needs spec fields",
|
|
1878
|
-
)
|
|
1879
|
-
collector.finalize()
|
|
1880
|
-
raise typer.Exit(1)
|
|
1881
|
-
|
|
1882
|
-
# Lock one by resource_id
|
|
1883
|
-
resource_ids = [t.resource_id for t in spec_targets]
|
|
1884
|
-
collector.emit("pool_acquire", pool=pool, count=len(resource_ids))
|
|
1885
|
-
|
|
1886
|
-
lock_ctx = acquire_from_pool(resource_ids)
|
|
1887
|
-
acquired_id = lock_ctx.__enter__()
|
|
1888
|
-
|
|
1889
|
-
if acquired_id is None:
|
|
1890
|
-
lock_ctx.__exit__(None, None, None)
|
|
1891
|
-
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=resource_ids)
|
|
1892
|
-
collector.finalize()
|
|
1893
|
-
raise typer.Exit(1)
|
|
1894
|
-
|
|
1895
|
-
# Map resource_id back to spec_name
|
|
1896
|
-
acquired_target = next(t for t in spec_targets if t.resource_id == acquired_id)
|
|
1897
|
-
spec_name = acquired_target.spec_name
|
|
1898
|
-
|
|
1899
|
-
collector.emit("pool_acquired", target=spec_name, resource_id=acquired_id)
|
|
1900
|
-
return spec_name, lock_ctx
|
|
1901
|
-
|
|
1902
|
-
|
|
1903
|
-
def _resolve_pool_legacy(pool: str, collector) -> tuple[str, object]:
|
|
1904
|
-
"""Resolve an old-style pool (static target name list) to a target name + lock context.
|
|
1905
|
-
|
|
1906
|
-
Old format: [pools.name] targets = ["t1", "t2"]
|
|
1907
|
-
"""
|
|
1908
|
-
from .target_lock import acquire_from_pool
|
|
1909
|
-
from .targets import filter_pool_by_auth, get_pool
|
|
1910
|
-
|
|
1911
|
-
try:
|
|
1912
|
-
pool_targets = get_pool(pool)
|
|
1913
|
-
except FileNotFoundError as e:
|
|
1914
|
-
collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
|
|
1915
|
-
collector.finalize()
|
|
1916
|
-
raise typer.Exit(1) from None
|
|
1917
|
-
|
|
1918
|
-
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1919
|
-
if skipped:
|
|
1920
|
-
collector.emit("pool_auth_skip", targets=skipped)
|
|
1921
|
-
|
|
1922
|
-
if not usable_targets:
|
|
1923
|
-
collector.set_error("pool", "NoUsableTargets", pool=pool)
|
|
1924
|
-
collector.finalize()
|
|
1925
|
-
raise typer.Exit(1) from None
|
|
1926
|
-
|
|
1927
|
-
collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
|
|
1928
|
-
lock_ctx = acquire_from_pool(usable_targets)
|
|
1929
|
-
acquired_target = lock_ctx.__enter__()
|
|
1930
|
-
|
|
1931
|
-
if acquired_target is None:
|
|
1932
|
-
lock_ctx.__exit__(None, None, None)
|
|
1933
|
-
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
|
|
1934
|
-
collector.finalize()
|
|
1935
|
-
raise typer.Exit(1)
|
|
1936
|
-
|
|
1937
|
-
collector.emit("pool_acquired", target=acquired_target)
|
|
1938
|
-
return acquired_target, lock_ctx
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
1804
|
@kernelbench_app.callback(invoke_without_command=True)
|
|
1942
1805
|
def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
1943
1806
|
ctx: typer.Context,
|
|
@@ -1973,9 +1836,7 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
1973
1836
|
),
|
|
1974
1837
|
seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
|
|
1975
1838
|
defensive: bool = typer.Option(
|
|
1976
|
-
|
|
1977
|
-
"--defense/--no-defense",
|
|
1978
|
-
help="Run reward hack defense checks after benchmarking. Enabled by default.",
|
|
1839
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
1979
1840
|
),
|
|
1980
1841
|
backend: str | None = typer.Option(
|
|
1981
1842
|
None,
|
|
@@ -2015,20 +1876,16 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2015
1876
|
The evaluation checks:
|
|
2016
1877
|
1. Correctness: Does ModelNew.forward() produce same output as Model.forward()?
|
|
2017
1878
|
2. Performance (--benchmark): How fast is it compared to the reference?
|
|
2018
|
-
3. Defense: Detects
|
|
1879
|
+
3. Defense (--defensive): Detects evaluation hacking
|
|
2019
1880
|
|
|
2020
1881
|
Examples:
|
|
2021
1882
|
# Basic correctness check
|
|
2022
1883
|
wafer evaluate kernelbench --impl my_kernel.py --reference problem.py
|
|
2023
1884
|
|
|
2024
|
-
# With benchmarking
|
|
1885
|
+
# With benchmarking
|
|
2025
1886
|
wafer evaluate kernelbench --impl my_kernel.py --reference problem.py \\
|
|
2026
1887
|
--target vultr-b200 --benchmark
|
|
2027
1888
|
|
|
2028
|
-
# Benchmarking without defense checks
|
|
2029
|
-
wafer evaluate kernelbench --impl my_kernel.py --reference problem.py \\
|
|
2030
|
-
--target vultr-b200 --benchmark --no-defense
|
|
2031
|
-
|
|
2032
1889
|
Subcommands:
|
|
2033
1890
|
make-template Extract a KernelBench problem as template
|
|
2034
1891
|
"""
|
|
@@ -2074,12 +1931,39 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2074
1931
|
pool_lock_context = None
|
|
2075
1932
|
|
|
2076
1933
|
if pool:
|
|
2077
|
-
from
|
|
1934
|
+
from .target_lock import acquire_from_pool
|
|
1935
|
+
from .targets import filter_pool_by_auth, get_pool
|
|
2078
1936
|
|
|
2079
|
-
|
|
2080
|
-
|
|
2081
|
-
|
|
2082
|
-
|
|
1937
|
+
try:
|
|
1938
|
+
pool_targets = get_pool(pool)
|
|
1939
|
+
except FileNotFoundError as e:
|
|
1940
|
+
collector.set_error("pool", "PoolNotFound", pool=pool, message=str(e))
|
|
1941
|
+
collector.finalize()
|
|
1942
|
+
raise typer.Exit(1) from None
|
|
1943
|
+
|
|
1944
|
+
# Filter to only targets with valid auth
|
|
1945
|
+
usable_targets, skipped = filter_pool_by_auth(pool_targets)
|
|
1946
|
+
if skipped:
|
|
1947
|
+
collector.emit("pool_auth_skip", targets=skipped)
|
|
1948
|
+
|
|
1949
|
+
if not usable_targets:
|
|
1950
|
+
collector.set_error("pool", "NoUsableTargets", pool=pool)
|
|
1951
|
+
collector.finalize()
|
|
1952
|
+
raise typer.Exit(1) from None
|
|
1953
|
+
|
|
1954
|
+
collector.emit("pool_acquire", pool=pool, count=len(usable_targets))
|
|
1955
|
+
pool_lock_context = acquire_from_pool(usable_targets)
|
|
1956
|
+
acquired_target = pool_lock_context.__enter__()
|
|
1957
|
+
|
|
1958
|
+
if acquired_target is None:
|
|
1959
|
+
# Exit context manager before raising to avoid resource leak
|
|
1960
|
+
pool_lock_context.__exit__(None, None, None)
|
|
1961
|
+
collector.set_error("pool", "AllTargetsBusy", pool=pool, targets=usable_targets)
|
|
1962
|
+
collector.finalize()
|
|
1963
|
+
raise typer.Exit(1)
|
|
1964
|
+
|
|
1965
|
+
collector.emit("pool_acquired", target=acquired_target)
|
|
1966
|
+
resolved_target = acquired_target
|
|
2083
1967
|
|
|
2084
1968
|
collector.target = resolved_target
|
|
2085
1969
|
|
|
@@ -2088,15 +1972,12 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2088
1972
|
if stages == "all":
|
|
2089
1973
|
resolved_stages = "compile,correctness,benchmark,defense"
|
|
2090
1974
|
|
|
2091
|
-
# Handle --benchmark and --
|
|
1975
|
+
# Handle backward compat: --benchmark and --defensive flags add to stages
|
|
2092
1976
|
stage_set = set(resolved_stages.split(","))
|
|
2093
1977
|
if benchmark and "benchmark" not in stage_set:
|
|
2094
1978
|
stage_set.add("benchmark")
|
|
2095
|
-
|
|
2096
|
-
if defensive and "benchmark" in stage_set and "defense" not in stage_set:
|
|
1979
|
+
if defensive and "defense" not in stage_set:
|
|
2097
1980
|
stage_set.add("defense")
|
|
2098
|
-
if not defensive:
|
|
2099
|
-
stage_set.discard("defense")
|
|
2100
1981
|
resolved_stages = ",".join(
|
|
2101
1982
|
sorted(
|
|
2102
1983
|
stage_set,
|
|
@@ -2407,11 +2288,7 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2407
2288
|
None, "--reference", help="Path to reference kernel file"
|
|
2408
2289
|
),
|
|
2409
2290
|
test_cases: Path | None = typer.Option(
|
|
2410
|
-
None,
|
|
2411
|
-
"--test-cases",
|
|
2412
|
-
help="Path to test cases JSON file. "
|
|
2413
|
-
'Format: [{"name": "small", "n": 1024, "seed": 42}, ...]. '
|
|
2414
|
-
"Run 'wafer evaluate make-template' to generate an example.",
|
|
2291
|
+
None, "--test-cases", help="Path to test cases JSON file"
|
|
2415
2292
|
),
|
|
2416
2293
|
target: str | None = typer.Option(
|
|
2417
2294
|
None,
|
|
@@ -2430,9 +2307,7 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2430
2307
|
benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
|
|
2431
2308
|
profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
|
|
2432
2309
|
defensive: bool = typer.Option(
|
|
2433
|
-
|
|
2434
|
-
"--defense/--no-defense",
|
|
2435
|
-
help="Run reward hack defense checks after benchmarking. Enabled by default.",
|
|
2310
|
+
False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
|
|
2436
2311
|
),
|
|
2437
2312
|
sync_artifacts: bool = typer.Option(
|
|
2438
2313
|
True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
|
|
@@ -2481,13 +2356,6 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2481
2356
|
err=True,
|
|
2482
2357
|
)
|
|
2483
2358
|
typer.echo("", err=True)
|
|
2484
|
-
if "--test-cases" in missing_args:
|
|
2485
|
-
typer.echo(
|
|
2486
|
-
"Tip: Run 'wafer evaluate make-template' to generate template files "
|
|
2487
|
-
"including test_cases.json.",
|
|
2488
|
-
err=True,
|
|
2489
|
-
)
|
|
2490
|
-
typer.echo("", err=True)
|
|
2491
2359
|
typer.echo("Run 'wafer evaluate gpumode --help' for full options.", err=True)
|
|
2492
2360
|
typer.echo("Run 'wafer evaluate gpumode download' to download problem sets.", err=True)
|
|
2493
2361
|
raise typer.Exit(1)
|
|
@@ -2588,12 +2456,313 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
|
|
|
2588
2456
|
else:
|
|
2589
2457
|
typer.echo(f"Error: {result.error_message}", err=True)
|
|
2590
2458
|
raise typer.Exit(1)
|
|
2459
|
+
|
|
2460
|
+
|
|
2461
|
+
# =============================================================================
|
|
2462
|
+
# Push and Remote-Run commands
|
|
2463
|
+
# =============================================================================
|
|
2464
|
+
|
|
2465
|
+
|
|
2466
|
+
@app.command("push", hidden=True)
|
|
2467
|
+
def push(
|
|
2468
|
+
local_path: Path = typer.Argument(..., help="Local directory to upload"),
|
|
2469
|
+
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace name override"),
|
|
2470
|
+
direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
|
|
2471
|
+
target_name: str | None = typer.Option(
|
|
2472
|
+
None,
|
|
2473
|
+
"--target",
|
|
2474
|
+
"-t",
|
|
2475
|
+
help="Target for --direct mode. See 'wafer config targets list'.",
|
|
2476
|
+
autocompletion=complete_target_name,
|
|
2477
|
+
),
|
|
2478
|
+
) -> None:
|
|
2479
|
+
"""Push directory to remote GPU.
|
|
2480
|
+
|
|
2481
|
+
By default, uses wafer-api. Use --direct for direct SSH mode.
|
|
2482
|
+
|
|
2483
|
+
Examples:
|
|
2484
|
+
wafer push ./my_project
|
|
2485
|
+
wafer push . --workspace my-kernel
|
|
2486
|
+
wafer push ./my_project --direct --target vultr-b200
|
|
2487
|
+
"""
|
|
2488
|
+
# Validate path
|
|
2489
|
+
if not local_path.exists():
|
|
2490
|
+
typer.echo(f"Error: Path not found: {local_path}", err=True)
|
|
2491
|
+
raise typer.Exit(1)
|
|
2492
|
+
|
|
2493
|
+
if not local_path.is_dir():
|
|
2494
|
+
typer.echo(f"Error: Not a directory: {local_path}", err=True)
|
|
2495
|
+
raise typer.Exit(1)
|
|
2496
|
+
|
|
2497
|
+
# Resolve to absolute path
|
|
2498
|
+
local_path = local_path.resolve()
|
|
2499
|
+
|
|
2500
|
+
if direct:
|
|
2501
|
+
# Direct SSH mode (requires target)
|
|
2502
|
+
if not target_name:
|
|
2503
|
+
typer.echo("Error: --target required for --direct mode", err=True)
|
|
2504
|
+
raise typer.Exit(1)
|
|
2505
|
+
|
|
2506
|
+
from wafer_core.utils.kernel_utils.targets.config import ModalTarget
|
|
2507
|
+
|
|
2508
|
+
from .gpu_run import push_directory as push_direct
|
|
2509
|
+
from .targets import load_target
|
|
2510
|
+
|
|
2511
|
+
try:
|
|
2512
|
+
target = load_target(target_name)
|
|
2513
|
+
except FileNotFoundError:
|
|
2514
|
+
typer.echo(f"Error: Target not found: {target_name}", err=True)
|
|
2515
|
+
typer.echo("List targets with: wafer config targets list", err=True)
|
|
2516
|
+
raise typer.Exit(1) from None
|
|
2517
|
+
|
|
2518
|
+
if isinstance(target, ModalTarget):
|
|
2519
|
+
typer.echo(
|
|
2520
|
+
f"Error: Target '{target_name}' is a Modal target. Direct push requires SSH.",
|
|
2521
|
+
err=True,
|
|
2522
|
+
)
|
|
2523
|
+
raise typer.Exit(1) from None
|
|
2524
|
+
|
|
2525
|
+
typer.echo(f"Connecting to {target.ssh_target}...")
|
|
2526
|
+
try:
|
|
2527
|
+
result = push_direct(local_path, target)
|
|
2528
|
+
except Exception as e:
|
|
2529
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2530
|
+
raise typer.Exit(1) from None
|
|
2531
|
+
|
|
2532
|
+
typer.echo(f"Uploading {len(result.files_uploaded)} files to {result.workspace_path}")
|
|
2533
|
+
for f in result.files_uploaded:
|
|
2534
|
+
typer.echo(f" ✓ {f}")
|
|
2535
|
+
typer.echo(f"Pushed to: {result.workspace_path}")
|
|
2536
|
+
else:
|
|
2537
|
+
# API mode (default)
|
|
2538
|
+
from .api_client import push_directory as push_api
|
|
2539
|
+
|
|
2540
|
+
workspace_name = workspace or local_path.name
|
|
2541
|
+
typer.echo(f"Pushing {local_path.name} to wafer-api...")
|
|
2542
|
+
|
|
2543
|
+
try:
|
|
2544
|
+
result = push_api(local_path, workspace_name)
|
|
2545
|
+
except Exception as e:
|
|
2546
|
+
typer.echo(f"Error: {e}", err=True)
|
|
2547
|
+
raise typer.Exit(1) from None
|
|
2548
|
+
|
|
2549
|
+
typer.echo(f"Uploaded {len(result.files_uploaded)} files")
|
|
2550
|
+
for f in result.files_uploaded:
|
|
2551
|
+
typer.echo(f" ✓ {f}")
|
|
2552
|
+
typer.echo(f"Workspace ID: {result.workspace_id}")
|
|
2553
|
+
|
|
2554
|
+
|
|
2555
|
+
def _run_direct_mode(
    cmd_str: str,
    target_name: str,
    upload_dir: Path | None,
    workspace_id: str | None,
    gpu_id: int | None,
) -> int:
    """Run command via direct SSH mode. Returns exit code.

    Args:
        cmd_str: Command line to execute on the target.
        target_name: Name of a configured target (see 'wafer config targets list').
            Modal targets are rejected; the target must have a docker_image.
        upload_dir: Optional local directory to push first. When given, the
            command runs in the workspace created by the push.
        workspace_id: Existing workspace name used when upload_dir is not given.
            Falls back to "tmp" when neither is provided.
        gpu_id: GPU to run on. When None, the first entry of the target's
            gpu_ids is reported as the effective GPU.

    Returns:
        The exit code returned by gpu_run.run_command.

    Raises:
        typer.Exit: 1 for an unknown/unsuitable target, an upload failure, or a
            run failure; 130 when interrupted with Ctrl-C.
    """
    from wafer_core.utils.kernel_utils.targets.config import ModalTarget

    from .gpu_run import push_directory as push_direct
    from .gpu_run import run_command as run_direct
    from .targets import load_target

    try:
        target = load_target(target_name)
    except FileNotFoundError:
        typer.echo(f"Error: Target not found: {target_name}", err=True)
        typer.echo("List targets with: wafer config targets list", err=True)
        raise typer.Exit(1) from None

    if isinstance(target, ModalTarget):
        typer.echo(
            f"Error: Target '{target_name}' is a Modal target. Direct mode requires SSH.", err=True
        )
        raise typer.Exit(1) from None

    if not target.docker_image:
        typer.echo(f"Error: Target '{target_name}' has no docker_image configured", err=True)
        raise typer.Exit(1)

    # Resolve the GPU before uploading anything. Indexing an empty or missing
    # gpu_ids list would otherwise crash with a raw traceback, and only after
    # the upload had already finished.
    if gpu_id is not None:
        effective_gpu = gpu_id
    elif target.gpu_ids:
        effective_gpu = target.gpu_ids[0]
    else:
        typer.echo(f"Error: Target '{target_name}' has no gpu_ids configured", err=True)
        raise typer.Exit(1)

    # If upload_dir provided, push first; its workspace name wins over workspace_id.
    workspace_name = workspace_id
    if upload_dir:
        typer.echo(f"Uploading {upload_dir.name}...")
        try:
            push_result = push_direct(upload_dir, target)
            workspace_name = push_result.workspace_name
            typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
        except Exception as e:
            typer.echo(f"Error uploading: {e}", err=True)
            raise typer.Exit(1) from None
    elif not workspace_name:
        workspace_name = "tmp"

    typer.echo(f"Target: {target_name} (docker: {target.docker_image})")
    typer.echo(f"Workspace: {workspace_name}")
    typer.echo(f"GPU: {effective_gpu}")
    typer.echo(f"Command: {cmd_str}")
    typer.echo("-" * 60)

    try:
        # NOTE(review): the banner shows effective_gpu, but the raw gpu_id
        # (possibly None) is forwarded. Presumably run_command applies the same
        # gpu_ids[0] default -- confirm in gpu_run.
        return run_direct(cmd_str, workspace_name, target, gpu_id)
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
|
|
2615
|
+
|
|
2616
|
+
|
|
2617
|
+
def _run_api_mode(  # noqa: PLR0913
    cmd_str: str,
    upload_dir: Path | None,
    workspace_id: str | None,
    gpu_id: int | None,
    gpu_count: int,
    docker_image: str | None,
    docker_entrypoint: str | None,
    pull_image: bool,
    require_hwc: bool,
) -> int:
    """Run command via wafer-api. Returns exit code.

    Before the streamed run starts, a short summary is echoed. It covers the
    file source, GPU selection, image settings and the command itself; only
    the options that were actually set appear.

    Returns:
        The exit code reported by api_client.run_command_stream.

    Raises:
        typer.Exit: 130 when interrupted with Ctrl-C, 1 on any other failure.
    """
    from .api_client import run_command_stream

    # (should_print, text) pairs, echoed in order. Uploading and Workspace
    # are mutually exclusive: uploading takes precedence.
    summary_rows: list[tuple[bool, str]] = [
        (bool(upload_dir), f"Uploading: {upload_dir}"),
        (not upload_dir and bool(workspace_id), f"Workspace: {workspace_id}"),
        (gpu_id is not None, f"GPU: {gpu_id}"),
        (gpu_count > 1, f"GPU count: {gpu_count}"),
        (bool(docker_image), f"Image: {docker_image}"),
        (bool(docker_entrypoint), f"Entrypoint: {docker_entrypoint}"),
        (bool(pull_image), "Pull image: yes"),
        (True, f"Command: {cmd_str}"),
        (bool(require_hwc), "Hardware counters: required (baremetal)"),
        (True, "-" * 60),
    ]
    for should_print, text in summary_rows:
        if should_print:
            typer.echo(text)

    stream_kwargs = {
        "command": cmd_str,
        "upload_dir": upload_dir,
        "workspace_id": workspace_id,
        "gpu_id": gpu_id,
        "gpu_count": gpu_count,
        "docker_image": docker_image,
        "docker_entrypoint": docker_entrypoint,
        "pull_image": pull_image,
        "require_hardware_counters": require_hwc,
    }
    try:
        return run_command_stream(**stream_kwargs)
    except KeyboardInterrupt:
        typer.echo("\nInterrupted by user", err=True)
        raise typer.Exit(130) from None
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None
|
|
2668
|
+
|
|
2669
|
+
|
|
2670
|
+
@app.command("remote-run", hidden=True)
def remote_run(  # noqa: PLR0913
    command: list[str] = typer.Argument(..., help="Command to run"),
    upload_dir: Path | None = typer.Option(
        None, "--upload-dir", "-u", help="Directory to upload (stateless mode)"
    ),
    workspace_id: str | None = typer.Option(
        None, "--workspace-id", "-w", help="Workspace ID (from wafer push)"
    ),
    gpu_id: int | None = typer.Option(None, "--gpu", "-g", help="GPU ID"),
    gpu_count: int = typer.Option(1, "--gpu-count", "-n", help="Number of GPUs (1-8)"),
    docker_image: str | None = typer.Option(None, "--image", "-i", help="Docker image override"),
    docker_entrypoint: str | None = typer.Option(
        None, "--docker-entrypoint", help="Override Docker entrypoint (e.g., 'bash')"
    ),
    pull_image: bool = typer.Option(
        False, "--pull-image", help="Pull image if not available on target"
    ),
    require_hwc: bool = typer.Option(
        False, "--require-hwc", help="Require hardware counters (baremetal)"
    ),
    direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
    target_name: str | None = typer.Option(
        None,
        "--target",
        "-t",
        help="Target for --direct mode. See 'wafer config targets list'.",
        autocompletion=complete_target_name,
    ),
) -> None:
    """Run command on remote GPU in Docker.

    Two modes:
    - High-level (stateless): --upload-dir uploads files and runs command
    - Low-level: --workspace-id uses existing workspace from 'wafer push'

    By default, uses wafer-api. Use --direct for direct SSH mode.

    Examples:
        # Stateless: upload and run
        wafer remote-run --upload-dir ./my_project -- python train.py

        # Run without files
        wafer remote-run -- nvidia-smi

        # Low-level: use existing workspace
        wafer remote-run --workspace-id ws_abc123 -- python train.py

        # Direct SSH mode
        wafer remote-run --upload-dir ./my_project --direct --target vultr-b200 -- python train.py
    """
    # Arguments are joined with single spaces into one command string. Any
    # quoting the local shell removed is not re-applied here.
    cmd_str = " ".join(command)
    if not cmd_str.strip():
        typer.echo("Error: Empty command", err=True)
        raise typer.Exit(1)

    if upload_dir and workspace_id:
        typer.echo("Error: --upload-dir and --workspace-id are mutually exclusive", err=True)
        raise typer.Exit(1)

    if upload_dir:
        if not upload_dir.exists():
            typer.echo(f"Error: Directory not found: {upload_dir}", err=True)
            raise typer.Exit(1)
        if not upload_dir.is_dir():
            typer.echo(f"Error: Not a directory: {upload_dir}", err=True)
            raise typer.Exit(1)
        upload_dir = upload_dir.resolve()

    if direct:
        if not target_name:
            typer.echo("Error: --target required for --direct mode", err=True)
            raise typer.Exit(1)
        # _run_direct_mode takes the image from the target config and does not
        # accept these options. Tell the user rather than silently dropping them.
        ignored_flags = [
            flag
            for flag, was_given in (
                ("--gpu-count", gpu_count != 1),
                ("--image", bool(docker_image)),
                ("--docker-entrypoint", bool(docker_entrypoint)),
                ("--pull-image", pull_image),
                ("--require-hwc", require_hwc),
            )
            if was_given
        ]
        if ignored_flags:
            typer.echo(
                f"Warning: {', '.join(ignored_flags)} ignored in --direct mode", err=True
            )
        exit_code = _run_direct_mode(cmd_str, target_name, upload_dir, workspace_id, gpu_id)
    else:
        # In API mode the target is chosen server-side; --target only applies to --direct.
        if target_name:
            typer.echo("Warning: --target is only used with --direct; ignoring", err=True)
        exit_code = _run_api_mode(
            cmd_str,
            upload_dir,
            workspace_id,
            gpu_id,
            gpu_count,
            docker_image,
            docker_entrypoint,
            pull_image,
            require_hwc,
        )

    raise typer.Exit(exit_code)
|
|
2758
|
+
|
|
2759
|
+
|
|
2591
2760
|
# =============================================================================
|
|
2592
2761
|
# Authentication commands
|
|
2593
2762
|
# =============================================================================
|
|
2594
2763
|
|
|
2595
2764
|
|
|
2596
|
-
@
|
|
2765
|
+
@app.command("login")
|
|
2597
2766
|
def login(
|
|
2598
2767
|
token: str | None = typer.Option(
|
|
2599
2768
|
None, "--token", "-t", help="Access token (skip browser OAuth)"
|
|
@@ -2618,7 +2787,7 @@ def login(
|
|
|
2618
2787
|
Uses the API environment from config (see 'wafer config show').
|
|
2619
2788
|
|
|
2620
2789
|
SSH Users (Easiest):
|
|
2621
|
-
- Just run: wafer
|
|
2790
|
+
- Just run: wafer login
|
|
2622
2791
|
- Visit the URL and enter the code shown
|
|
2623
2792
|
- No port forwarding needed!
|
|
2624
2793
|
|
|
@@ -2628,17 +2797,17 @@ def login(
|
|
|
2628
2797
|
|
|
2629
2798
|
Manual token option:
|
|
2630
2799
|
- Visit auth.wafer.ai, authenticate, copy token from URL
|
|
2631
|
-
- Run: wafer
|
|
2800
|
+
- Run: wafer login --token <paste-token>
|
|
2632
2801
|
|
|
2633
2802
|
Examples:
|
|
2634
|
-
wafer
|
|
2635
|
-
wafer
|
|
2636
|
-
wafer
|
|
2637
|
-
wafer
|
|
2803
|
+
wafer login # device code on SSH, browser on local
|
|
2804
|
+
wafer login --no-device-code # force browser (needs port forwarding on SSH)
|
|
2805
|
+
wafer login --port 9000 # custom port for browser flow
|
|
2806
|
+
wafer login --token xyz # manual token (no browser)
|
|
2638
2807
|
|
|
2639
2808
|
# Change environment:
|
|
2640
2809
|
wafer config set api.environment staging
|
|
2641
|
-
wafer
|
|
2810
|
+
wafer login
|
|
2642
2811
|
"""
|
|
2643
2812
|
import httpx
|
|
2644
2813
|
|
|
@@ -2722,7 +2891,7 @@ def login(
|
|
|
2722
2891
|
typer.echo("Token saved to ~/.wafer/credentials.json")
|
|
2723
2892
|
|
|
2724
2893
|
|
|
2725
|
-
@
|
|
2894
|
+
@app.command("logout")
|
|
2726
2895
|
def logout() -> None:
|
|
2727
2896
|
"""Remove stored credentials."""
|
|
2728
2897
|
from . import analytics
|
|
@@ -2739,7 +2908,7 @@ def logout() -> None:
|
|
|
2739
2908
|
typer.echo("Not logged in (no credentials found).")
|
|
2740
2909
|
|
|
2741
2910
|
|
|
2742
|
-
@
|
|
2911
|
+
@app.command("whoami")
|
|
2743
2912
|
def whoami(
|
|
2744
2913
|
verify: bool = typer.Option(False, "--verify", "-v", help="Verify token with API"),
|
|
2745
2914
|
refresh: bool = typer.Option(False, "--refresh", "-r", help="Refresh token if expired"),
|
|
@@ -2753,7 +2922,7 @@ def whoami(
|
|
|
2753
2922
|
|
|
2754
2923
|
creds = load_credentials()
|
|
2755
2924
|
if creds is None:
|
|
2756
|
-
typer.echo("Not logged in. Run: wafer
|
|
2925
|
+
typer.echo("Not logged in. Run: wafer login")
|
|
2757
2926
|
raise typer.Exit(1)
|
|
2758
2927
|
|
|
2759
2928
|
if verify or refresh:
|
|
@@ -2761,7 +2930,7 @@ def whoami(
|
|
|
2761
2930
|
# Try to get valid token with auto-refresh
|
|
2762
2931
|
token = get_valid_token()
|
|
2763
2932
|
if token is None:
|
|
2764
|
-
typer.echo("Token expired and refresh failed. Run: wafer
|
|
2933
|
+
typer.echo("Token expired and refresh failed. Run: wafer login", err=True)
|
|
2765
2934
|
raise typer.Exit(1)
|
|
2766
2935
|
if token != creds.access_token:
|
|
2767
2936
|
typer.echo("Token refreshed successfully")
|
|
@@ -2774,10 +2943,10 @@ def whoami(
|
|
|
2774
2943
|
except Exception as e:
|
|
2775
2944
|
if creds.refresh_token and not refresh:
|
|
2776
2945
|
typer.echo(f"Token expired: {e}", err=True)
|
|
2777
|
-
typer.echo("Try: wafer
|
|
2946
|
+
typer.echo("Try: wafer whoami --refresh", err=True)
|
|
2778
2947
|
else:
|
|
2779
2948
|
typer.echo(f"Token invalid or expired: {e}", err=True)
|
|
2780
|
-
typer.echo("Run: wafer
|
|
2949
|
+
typer.echo("Run: wafer login", err=True)
|
|
2781
2950
|
raise typer.Exit(1) from None
|
|
2782
2951
|
elif creds.email:
|
|
2783
2952
|
typer.echo(creds.email)
|
|
@@ -2785,7 +2954,7 @@ def whoami(
|
|
|
2785
2954
|
typer.echo("Logged in (email not available)")
|
|
2786
2955
|
|
|
2787
2956
|
|
|
2788
|
-
@app.command("guide"
|
|
2957
|
+
@app.command("guide")
|
|
2789
2958
|
def guide() -> None:
|
|
2790
2959
|
"""Show the Wafer CLI usage guide.
|
|
2791
2960
|
|
|
@@ -2816,7 +2985,7 @@ demo_app = typer.Typer(
|
|
|
2816
2985
|
wafer demo trace Analyze a sample performance trace
|
|
2817
2986
|
wafer demo eval Run kernel evaluation on cloud GPU (requires login)"""
|
|
2818
2987
|
)
|
|
2819
|
-
app.add_typer(demo_app, name="demo"
|
|
2988
|
+
app.add_typer(demo_app, name="demo")
|
|
2820
2989
|
|
|
2821
2990
|
DEMO_TRACES_URL = "https://github.com/wafer-ai/wafer/raw/main/apps/wafer-cli/wafer/demo_data"
|
|
2822
2991
|
DEMO_DIR = Path.home() / ".cache" / "wafer" / "demo"
|
|
@@ -3036,7 +3205,7 @@ def demo_eval(
|
|
|
3036
3205
|
"""Demo: Evaluate a kernel on a cloud GPU.
|
|
3037
3206
|
|
|
3038
3207
|
Creates a workspace, runs a sample Triton kernel evaluation, and cleans up.
|
|
3039
|
-
Requires authentication (wafer
|
|
3208
|
+
Requires authentication (wafer login).
|
|
3040
3209
|
|
|
3041
3210
|
Example:
|
|
3042
3211
|
wafer demo eval
|
|
@@ -3051,7 +3220,7 @@ def demo_eval(
|
|
|
3051
3220
|
# Check auth first
|
|
3052
3221
|
creds = load_credentials()
|
|
3053
3222
|
if not creds:
|
|
3054
|
-
typer.echo("Error: Not authenticated. Run: wafer
|
|
3223
|
+
typer.echo("Error: Not authenticated. Run: wafer login")
|
|
3055
3224
|
raise typer.Exit(1)
|
|
3056
3225
|
|
|
3057
3226
|
if not yes:
|
|
@@ -4409,8 +4578,8 @@ def billing_usage(
|
|
|
4409
4578
|
"""Show current billing usage and subscription info.
|
|
4410
4579
|
|
|
4411
4580
|
Example:
|
|
4412
|
-
wafer
|
|
4413
|
-
wafer
|
|
4581
|
+
wafer billing
|
|
4582
|
+
wafer billing --json
|
|
4414
4583
|
"""
|
|
4415
4584
|
# Only show usage if no subcommand was invoked
|
|
4416
4585
|
if ctx.invoked_subcommand is not None:
|
|
@@ -4438,9 +4607,9 @@ def billing_topup(
|
|
|
4438
4607
|
Opens a Stripe checkout page to add credits. Default amount is $25.
|
|
4439
4608
|
|
|
4440
4609
|
Example:
|
|
4441
|
-
wafer
|
|
4442
|
-
wafer
|
|
4443
|
-
wafer
|
|
4610
|
+
wafer billing topup # Add $25
|
|
4611
|
+
wafer billing topup 100 # Add $100
|
|
4612
|
+
wafer billing topup --no-browser # Print URL instead
|
|
4444
4613
|
"""
|
|
4445
4614
|
import webbrowser
|
|
4446
4615
|
|
|
@@ -4486,8 +4655,8 @@ def billing_portal(
|
|
|
4486
4655
|
Manage your subscription, update payment method, or view invoices.
|
|
4487
4656
|
|
|
4488
4657
|
Example:
|
|
4489
|
-
wafer
|
|
4490
|
-
wafer
|
|
4658
|
+
wafer billing portal
|
|
4659
|
+
wafer billing portal --no-browser
|
|
4491
4660
|
"""
|
|
4492
4661
|
import webbrowser
|
|
4493
4662
|
|
|
@@ -4524,8 +4693,8 @@ def ssh_keys_list(
|
|
|
4524
4693
|
"""List all registered SSH public keys.
|
|
4525
4694
|
|
|
4526
4695
|
Example:
|
|
4527
|
-
wafer
|
|
4528
|
-
wafer
|
|
4696
|
+
wafer ssh-keys list
|
|
4697
|
+
wafer ssh-keys list --json
|
|
4529
4698
|
"""
|
|
4530
4699
|
from .ssh_keys import list_ssh_keys
|
|
4531
4700
|
|
|
@@ -4551,9 +4720,9 @@ def ssh_keys_add(
|
|
|
4551
4720
|
id_ed25519.pub, id_rsa.pub, id_ecdsa.pub.
|
|
4552
4721
|
|
|
4553
4722
|
Example:
|
|
4554
|
-
wafer
|
|
4555
|
-
wafer
|
|
4556
|
-
wafer
|
|
4723
|
+
wafer ssh-keys add # Auto-detect
|
|
4724
|
+
wafer ssh-keys add ~/.ssh/id_rsa.pub # Specific file
|
|
4725
|
+
wafer ssh-keys add ~/.ssh/id_ed25519.pub --name laptop
|
|
4557
4726
|
"""
|
|
4558
4727
|
from .ssh_keys import add_ssh_key
|
|
4559
4728
|
|
|
@@ -4572,10 +4741,10 @@ def ssh_keys_remove(
|
|
|
4572
4741
|
) -> None:
|
|
4573
4742
|
"""Remove an SSH public key.
|
|
4574
4743
|
|
|
4575
|
-
Get the key ID from 'wafer
|
|
4744
|
+
Get the key ID from 'wafer ssh-keys list'.
|
|
4576
4745
|
|
|
4577
4746
|
Example:
|
|
4578
|
-
wafer
|
|
4747
|
+
wafer ssh-keys remove abc123-def456-...
|
|
4579
4748
|
"""
|
|
4580
4749
|
from .ssh_keys import remove_ssh_key
|
|
4581
4750
|
|
|
@@ -5062,18 +5231,6 @@ def workspaces_pull(
|
|
|
5062
5231
|
raise typer.Exit(1) from None
|
|
5063
5232
|
|
|
5064
5233
|
|
|
5065
|
-
# =============================================================================
|
|
5066
|
-
# Live resource commands (list/terminate/reconcile/provision)
|
|
5067
|
-
# =============================================================================
|
|
5068
|
-
|
|
5069
|
-
targets_ops_app.command("list")(_targets_list_cmd)
|
|
5070
|
-
targets_ops_app.command("terminate")(_targets_terminate_cmd)
|
|
5071
|
-
targets_ops_app.command("reconcile")(_targets_reconcile_cmd)
|
|
5072
|
-
targets_ops_app.command("provision")(_targets_provision_cmd)
|
|
5073
|
-
targets_ops_app.command("pools")(_targets_pools_cmd)
|
|
5074
|
-
targets_ops_app.command("probe")(_targets_probe_cmd)
|
|
5075
|
-
|
|
5076
|
-
|
|
5077
5234
|
# =============================================================================
|
|
5078
5235
|
# Target operations commands (exec/ssh/sync)
|
|
5079
5236
|
# =============================================================================
|
|
@@ -5832,9 +5989,9 @@ def ncu_analyze(
|
|
|
5832
5989
|
compute/memory throughput, and optimization recommendations.
|
|
5833
5990
|
|
|
5834
5991
|
By default, uses local NCU if available, otherwise runs analysis
|
|
5835
|
-
remotely via wafer-api (requires authentication: wafer
|
|
5992
|
+
remotely via wafer-api (requires authentication: wafer login).
|
|
5836
5993
|
|
|
5837
|
-
Use --target for direct SSH mode.
|
|
5994
|
+
Use --target for direct SSH mode (like wafer remote-run --direct).
|
|
5838
5995
|
Use --include-source to fetch SASS assembly with register/instruction data.
|
|
5839
5996
|
|
|
5840
5997
|
Examples:
|
|
@@ -5927,7 +6084,7 @@ def nsys_analyze(
|
|
|
5927
6084
|
Returns timeline events, kernel information, memory usage, and diagnostics.
|
|
5928
6085
|
|
|
5929
6086
|
By default, uses local nsys if available, otherwise runs analysis
|
|
5930
|
-
remotely via wafer-api (requires authentication: wafer
|
|
6087
|
+
remotely via wafer-api (requires authentication: wafer login).
|
|
5931
6088
|
|
|
5932
6089
|
Supports multiple execution modes:
|
|
5933
6090
|
- Local: Uses local nsys CLI (no GPU required for analysis)
|
|
@@ -6912,7 +7069,7 @@ def autotuner_results(
|
|
|
6912
7069
|
raise typer.Exit(1) from None
|
|
6913
7070
|
|
|
6914
7071
|
|
|
6915
|
-
@app.command("capture"
|
|
7072
|
+
@app.command("capture")
|
|
6916
7073
|
def capture_command( # noqa: PLR0915
|
|
6917
7074
|
label: str = typer.Argument(
|
|
6918
7075
|
..., help="Label for this capture (e.g., 'baseline', 'optimized-v2')"
|
|
@@ -7592,29 +7749,18 @@ def compare_analyze(
|
|
|
7592
7749
|
"-f",
|
|
7593
7750
|
help="Output format: text, text-layers, csv, csv-layers, json",
|
|
7594
7751
|
),
|
|
7595
|
-
output: Path | None = typer.Option(
|
|
7596
|
-
None, "--output", "-o", help="Output file (default: stdout)"
|
|
7597
|
-
),
|
|
7752
|
+
output: Path | None = typer.Option(None, "--output", "-o", help="Output file (default: stdout)"),
|
|
7598
7753
|
phase: str = typer.Option(
|
|
7599
7754
|
"all",
|
|
7600
7755
|
"--phase",
|
|
7601
7756
|
help="Filter by phase: all, prefill, decode",
|
|
7602
7757
|
),
|
|
7603
7758
|
layers: bool = typer.Option(False, "--layers", help="Show layer-wise performance breakdown"),
|
|
7604
|
-
all: bool = typer.Option(
|
|
7605
|
-
|
|
7606
|
-
),
|
|
7607
|
-
stack_traces: bool = typer.Option(
|
|
7608
|
-
False, "--stack-traces", help="Show Python stack traces for operations"
|
|
7609
|
-
),
|
|
7610
|
-
recommendations: bool = typer.Option(
|
|
7611
|
-
False, "--recommendations", help="Generate prioritized recommendations for kernel team"
|
|
7612
|
-
),
|
|
7613
|
-
json: bool = typer.Option(
|
|
7614
|
-
False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
|
|
7615
|
-
),
|
|
7759
|
+
all: bool = typer.Option(False, "--all", help="Show all items (no truncation for layers, operations, kernels)"),
|
|
7760
|
+
stack_traces: bool = typer.Option(False, "--stack-traces", help="Show Python stack traces for operations"),
|
|
7761
|
+
json: bool = typer.Option(False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"),
|
|
7616
7762
|
) -> None:
|
|
7617
|
-
"""Compare GPU traces from
|
|
7763
|
+
"""Compare GPU traces from two platforms platforms.
|
|
7618
7764
|
|
|
7619
7765
|
Analyzes performance differences between traces, identifying which operations
|
|
7620
7766
|
are faster/slower on each platform and providing kernel-level details.
|
|
@@ -7662,7 +7808,6 @@ def compare_analyze(
|
|
|
7662
7808
|
show_layers=layers,
|
|
7663
7809
|
show_all=all,
|
|
7664
7810
|
show_stack_traces=stack_traces,
|
|
7665
|
-
recommendations=recommendations,
|
|
7666
7811
|
)
|
|
7667
7812
|
_mark_command_success()
|
|
7668
7813
|
|
|
@@ -7677,17 +7822,13 @@ def compare_fusion_cmd(
|
|
|
7677
7822
|
"-f",
|
|
7678
7823
|
help="Output format: text, csv, json",
|
|
7679
7824
|
),
|
|
7680
|
-
output: Path | None = typer.Option(
|
|
7681
|
-
None, "--output", "-o", help="Output file (default: stdout)"
|
|
7682
|
-
),
|
|
7825
|
+
output: Path | None = typer.Option(None, "--output", "-o", help="Output file (default: stdout)"),
|
|
7683
7826
|
min_group_size: int = typer.Option(
|
|
7684
7827
|
50,
|
|
7685
7828
|
"--min-group-size",
|
|
7686
7829
|
help="Minimum correlation group size to analyze",
|
|
7687
7830
|
),
|
|
7688
|
-
json: bool = typer.Option(
|
|
7689
|
-
False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"
|
|
7690
|
-
),
|
|
7831
|
+
json: bool = typer.Option(False, "--json", hidden=True, help="Ignored (for compatibility with cliExecutor)"),
|
|
7691
7832
|
) -> None:
|
|
7692
7833
|
"""Analyze kernel fusion differences between AMD and NVIDIA traces.
|
|
7693
7834
|
|
|
@@ -7707,69 +7848,14 @@ def compare_fusion_cmd(
|
|
|
7707
7848
|
# CSV output to file
|
|
7708
7849
|
wafer compare fusion amd_trace.json nvidia_trace.json --format csv -o fusion.csv
|
|
7709
7850
|
"""
|
|
7710
|
-
from .trace_compare import
|
|
7711
|
-
|
|
7712
|
-
compare_align(
|
|
7713
|
-
trace1=trace1,
|
|
7714
|
-
trace2=trace2,
|
|
7715
|
-
output=output,
|
|
7716
|
-
output_format=format,
|
|
7717
|
-
phase="all",
|
|
7718
|
-
)
|
|
7719
|
-
_mark_command_success()
|
|
7720
|
-
|
|
7721
|
-
|
|
7722
|
-
@compare_app.command("align")
|
|
7723
|
-
def compare_align_cmd(
|
|
7724
|
-
trace1: Path = typer.Argument(..., help="First trace file (AMD or NVIDIA)", exists=True),
|
|
7725
|
-
trace2: Path = typer.Argument(..., help="Second trace file (AMD or NVIDIA)", exists=True),
|
|
7726
|
-
format: str = typer.Option(
|
|
7727
|
-
"json",
|
|
7728
|
-
"--format",
|
|
7729
|
-
"-f",
|
|
7730
|
-
help="Output format: json",
|
|
7731
|
-
),
|
|
7732
|
-
output: Path | None = typer.Option(
|
|
7733
|
-
None, "--output", "-o", help="Output file (default: stdout)"
|
|
7734
|
-
),
|
|
7735
|
-
phase: str = typer.Option(
|
|
7736
|
-
"all",
|
|
7737
|
-
"--phase",
|
|
7738
|
-
help="Filter by phase: all, prefill, decode",
|
|
7739
|
-
),
|
|
7740
|
-
layer: int | None = typer.Option(
|
|
7741
|
-
None,
|
|
7742
|
-
"--layer",
|
|
7743
|
-
help="Focus on specific layer number",
|
|
7744
|
-
),
|
|
7745
|
-
) -> None:
|
|
7746
|
-
"""Align kernels at layer level for exact kernel-to-kernel comparison.
|
|
7747
|
-
|
|
7748
|
-
Provides kernel-to-kernel mapping across AMD and NVIDIA platforms,
|
|
7749
|
-
showing which kernels correspond to each other at each layer position.
|
|
7750
|
-
|
|
7751
|
-
Examples:
|
|
7752
|
-
# Basic alignment (stdout JSON)
|
|
7753
|
-
wafer compare align amd_trace.json nvidia_trace.json
|
|
7754
|
-
|
|
7755
|
-
# Save to file
|
|
7756
|
-
wafer compare align amd_trace.json nvidia_trace.json -o alignment.json
|
|
7757
|
-
|
|
7758
|
-
# Focus on decode phase only
|
|
7759
|
-
wafer compare align amd_trace.json nvidia_trace.json --phase decode
|
|
7760
|
-
|
|
7761
|
-
# Focus on specific layer
|
|
7762
|
-
wafer compare align amd_trace.json nvidia_trace.json --layer 5
|
|
7763
|
-
"""
|
|
7764
|
-
from .trace_compare import compare_align
|
|
7851
|
+
from .trace_compare import compare_fusion
|
|
7765
7852
|
|
|
7766
|
-
|
|
7853
|
+
compare_fusion(
|
|
7767
7854
|
trace1=trace1,
|
|
7768
7855
|
trace2=trace2,
|
|
7769
7856
|
output=output,
|
|
7770
|
-
|
|
7771
|
-
|
|
7772
|
-
layer=layer,
|
|
7857
|
+
format_type=format,
|
|
7858
|
+
min_group_size=min_group_size,
|
|
7773
7859
|
)
|
|
7774
7860
|
_mark_command_success()
|
|
7775
7861
|
|