wafer-cli 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/analytics.py +0 -1
- wafer/auth.py +1 -1
- wafer/autotuner.py +21 -17
- wafer/cli.py +41 -3
- wafer/evaluate.py +113 -53
- wafer/kernel_scope.py +7 -9
- wafer/nsys_profile.py +2 -3
- wafer/output.py +10 -3
- wafer/rocprof_compute.py +50 -42
- wafer/rocprof_sdk.py +1 -1
- wafer/targets_ops.py +0 -1
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/wevin_cli.py +1 -1
- wafer/workspaces.py +0 -2
- {wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/RECORD +19 -18
- {wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/top_level.txt +0 -0
wafer/analytics.py
CHANGED
wafer/auth.py
CHANGED
```diff
@@ -419,7 +419,7 @@ def device_code_login(timeout: int = 600) -> tuple[str, str | None]:
                 print(f" {CROSS}\n")
                 raise RuntimeError(f"CLI auth flow failed: {response.status_code} {response.text}")

-        except httpx.RequestError
+        except httpx.RequestError:
             # Network error, retry
             print("!", end="", flush=True)
             last_poll = time.time()
```
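This one-character fix matters: without the colon, `except httpx.RequestError` is a syntax error. The handler's job is to swallow transient network errors while polling the device-code endpoint. A minimal sketch of the pattern, with hypothetical names (`poll_once` is illustrative, not the real function):

```python
import httpx


def poll_once(client: httpx.Client, url: str) -> str | None:
    """One polling tick: return a token on success, None to retry later."""
    try:
        response = client.post(url)
        if response.status_code == 200:
            return response.json().get("access_token")
        raise RuntimeError(f"CLI auth flow failed: {response.status_code} {response.text}")
    except httpx.RequestError:
        # Transient network error: print a marker and let the caller retry
        print("!", end="", flush=True)
        return None
```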
wafer/autotuner.py
CHANGED
```diff
@@ -5,6 +5,7 @@ This module provides the implementation for the `wafer autotuner` commands.

 import asyncio
 import json
+from datetime import UTC
 from pathlib import Path
 from typing import Any

@@ -32,13 +33,14 @@ def run_sweep_command(
         raise FileNotFoundError(f"Config file not found: {config_file}")

     # Import autotuner core
-    from datetime import datetime
+    from datetime import datetime
     from uuid import uuid4
+
     import trio
     from wafer_core.tools.autotuner import AutotunerConfig, run_sweep
     from wafer_core.tools.autotuner.dtypes import Sweep, Trial
     from wafer_core.tools.autotuner.search import generate_grid_trials
-    from wafer_core.tools.autotuner.storage import
+    from wafer_core.tools.autotuner.storage import add_trial, create_sweep, get_sweep, get_trials

     # Load or reconstruct config
     if resume_sweep_id:
@@ -189,8 +191,8 @@ def run_sweep_command(
         status="running",
         total_trials=total_trials,
         completed_trials=0,
-        created_at=datetime.now(
-        updated_at=datetime.now(
+        created_at=datetime.now(UTC),
+        updated_at=datetime.now(UTC),
     )

     # Create sweep and get the actual ID from the API
@@ -245,7 +247,7 @@ def run_sweep_command(
     # Helper to update sweep status
     async def update_sweep_status(status: str) -> None:
         import httpx
-        from wafer_core.tools.autotuner.storage import
+        from wafer_core.tools.autotuner.storage import _get_auth_headers, get_api_url

         api_url = get_api_url()
         headers = _get_auth_headers()
@@ -260,7 +262,7 @@ def run_sweep_command(
     # Note: working_dir already set based on is_resume flag

     try:
-
+        await run_sweep(
             config=config,
             sweep_id=actual_sweep_id,
             working_dir=working_dir,
@@ -273,7 +275,7 @@ def run_sweep_command(

     # Print final summary
     print()
-    print(
+    print("✅ Sweep completed!")
     print(f"  Total: {total_trials} trials")
     print(f"  Success: {success_count}")
     print(f"  Failed: {failed_count}")
@@ -297,7 +299,7 @@ def run_sweep_command(
     except KeyboardInterrupt:
         # User pressed Ctrl+C
         print()
-        print(
+        print("❌ Sweep interrupted by user (Ctrl+C)")
         print(f"  Completed: {completed_count}/{total_trials} trials")
         await update_sweep_status("failed")
         raise
@@ -350,8 +352,8 @@ def results_command(
         Formatted string with results
     """
     from wafer_core.tools.autotuner import compute_pareto_frontier
-    from wafer_core.tools.autotuner.storage import get_sweep, get_trials
     from wafer_core.tools.autotuner.aggregation import aggregate_trials_by_config
+    from wafer_core.tools.autotuner.storage import get_sweep, get_trials

     try:
         # Get sweep and trials
@@ -501,7 +503,10 @@ def results_command(
     # Use aggregated config scoring
     if len(objectives_data) > 1:
         # Multi-objective: compute Pareto
-        from wafer_core.tools.autotuner.scoring import
+        from wafer_core.tools.autotuner.scoring import (
+            compute_pareto_frontier_configs,
+            rank_pareto_configs,
+        )
         objectives = [
             Objective(
                 metric=obj["metric"],
@@ -513,7 +518,7 @@ def results_command(
     pareto_configs = compute_pareto_frontier_configs(aggregated_configs, objectives)
     ranked_configs = rank_pareto_configs(pareto_configs, objectives)

-    lines.append(
+    lines.append("Pareto Frontier (using config objectives):")
     lines.append(f"Found {len(ranked_configs)} non-dominated configurations.")
     lines.append("")
@@ -563,7 +568,7 @@ def results_command(
     ]
     pareto_trials = compute_pareto_frontier(completed_trials, objectives)

-    lines.append(
+    lines.append("Pareto Frontier (using config objectives):")
     lines.append(f"Found {len(pareto_trials)} non-dominated configurations.")
     lines.append("")
@@ -699,8 +704,8 @@ def best_command(
     Returns:
         Formatted string with best config
     """
-    from wafer_core.tools.autotuner.storage import get_sweep, get_trials
     from wafer_core.tools.autotuner.aggregation import aggregate_trials_by_config
+    from wafer_core.tools.autotuner.storage import get_sweep, get_trials

     try:
         # Get sweep and trials
@@ -991,7 +996,7 @@ def delete_command(sweep_id: str) -> str:
         Success message
     """
     import httpx
-    from wafer_core.tools.autotuner.storage import
+    from wafer_core.tools.autotuner.storage import _get_auth_headers, get_api_url

     try:
         api_url = get_api_url()
@@ -1008,8 +1013,7 @@ def delete_command(sweep_id: str) -> str:
     except httpx.HTTPStatusError as e:
         if e.response.status_code == 404:
             raise ValueError(f"Sweep {sweep_id} not found")
-
-        raise ValueError(f"Failed to delete sweep: {e}")
+        raise ValueError(f"Failed to delete sweep: {e}")
     except Exception as e:
         raise ValueError(f"Failed to delete sweep: {e}") from e

@@ -1024,7 +1028,7 @@ def delete_all_command(status_filter: str | None = None) -> str:
         Summary of deletions
     """
     import httpx
-    from wafer_core.tools.autotuner.storage import
+    from wafer_core.tools.autotuner.storage import _get_auth_headers, get_api_url, list_sweeps

     try:
         # Get all sweeps
```
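The recurring change here is timestamp handling: the truncated old calls (`datetime.now(`) are replaced by `datetime.now(UTC)`, using the `UTC` alias added in Python 3.11. The practical difference, as a quick illustration:

```python
from datetime import UTC, datetime

naive = datetime.now()      # local wall-clock time, tzinfo is None
aware = datetime.now(UTC)   # timezone-aware UTC timestamp

print(naive.tzinfo)         # None
print(aware.isoformat())    # e.g. '2025-01-01T12:00:00.123456+00:00'

# Aware timestamps serialize with an explicit offset, so the sweep's
# created_at/updated_at mean the same instant on every client and server.
```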
wafer/cli.py
CHANGED
```diff
@@ -1805,6 +1805,18 @@ def kernelbench_evaluate(  # noqa: PLR0913, PLR0915
         True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
     ),
     gpu_id: int | None = typer.Option(None, "--gpu-id", help="Override GPU ID"),
+    stages: str = typer.Option(
+        "compile,correctness",
+        "--stages",
+        help="Comma-separated stages to run: compile, correctness, benchmark, defense. "
+        "Use 'all' for compile,correctness,benchmark,defense. Default: compile,correctness",
+    ),
+    prepare_only: bool = typer.Option(
+        False,
+        "--prepare-only",
+        help="Sync files and generate eval script but don't run. "
+        "Prints the command to run manually (useful for wrapping with rocprof, etc.)",
+    ),
     json_output: bool = typer.Option(
         False, "--json", help="Output as single JSON object (machine-readable)"
     ),
@@ -1912,18 +1924,42 @@ def kernelbench_evaluate(  # noqa: PLR0913, PLR0915

     collector.target = resolved_target

+    # Expand 'all' stages shorthand
+    resolved_stages = stages
+    if stages == "all":
+        resolved_stages = "compile,correctness,benchmark,defense"
+
+    # Handle backward compat: --benchmark and --defensive flags add to stages
+    stage_set = set(resolved_stages.split(","))
+    if benchmark and "benchmark" not in stage_set:
+        stage_set.add("benchmark")
+    if defensive and "defense" not in stage_set:
+        stage_set.add("defense")
+    resolved_stages = ",".join(
+        sorted(
+            stage_set,
+            key=lambda s: (
+                ["compile", "correctness", "benchmark", "defense"].index(s)
+                if s in ["compile", "correctness", "benchmark", "defense"]
+                else 99
+            ),
+        )
+    )
+
     args = KernelBenchEvaluateArgs(
         implementation=implementation,
         reference=reference,
         target_name=resolved_target,
-        benchmark=benchmark,
+        benchmark=benchmark or "benchmark" in stage_set,
         profile=profile,
         inputs=inputs,
         seed=seed,
-        defensive=defensive,
+        defensive=defensive or "defense" in stage_set,
         backend=backend,
         sync_artifacts=sync_artifacts,
         gpu_id=gpu_id,
+        stages=resolved_stages,
+        prepare_only=prepare_only,
     )

     collector.emit("started", target=resolved_target)
@@ -1955,7 +1991,9 @@ def kernelbench_evaluate(  # noqa: PLR0913, PLR0915
             collector.output_text_result(result)
         collector.finalize()

-
+        # For prepare-only mode, success means we prepared successfully (don't check correctness)
+        # For compile-only (all_correct is None), also treat as success
+        if not prepare_only and result.all_correct is not None and not result.all_correct:
             raise typer.Exit(1)
     else:
         collector.output_text_error(result.error_message or "Unknown error")
```
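Condensed, the new flag handling above folds `--stages`, the `all` shorthand, and the legacy `--benchmark`/`--defensive` flags into one canonically ordered stage string. A standalone sketch of that normalization (same logic as the hunk, not the literal code):

```python
STAGE_ORDER = ["compile", "correctness", "benchmark", "defense"]


def resolve_stages(stages: str, benchmark: bool = False, defensive: bool = False) -> str:
    """Normalize --stages plus the legacy --benchmark/--defensive flags."""
    if stages == "all":
        stages = ",".join(STAGE_ORDER)
    stage_set = set(stages.split(","))
    if benchmark:
        stage_set.add("benchmark")
    if defensive:
        stage_set.add("defense")
    # Canonical order; unknown stage names sort last
    return ",".join(sorted(stage_set, key=lambda s: STAGE_ORDER.index(s) if s in STAGE_ORDER else 99))


assert resolve_stages("all") == "compile,correctness,benchmark,defense"
assert resolve_stages("compile,correctness", benchmark=True) == "compile,correctness,benchmark"
```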
wafer/evaluate.py
CHANGED
```diff
@@ -21,7 +21,6 @@ from wafer_core.utils.kernel_utils.targets.config import (
     WorkspaceTarget,
 )

-
 # Map AMD compute capability to ROCm architecture
 # Used to set PYTORCH_ROCM_ARCH for faster compilation (compile only for target arch)
 AMD_CC_TO_ARCH = {
@@ -189,6 +188,8 @@ class KernelBenchEvaluateArgs:
     backend: str | None = None  # Kernel backend for static validation
     sync_artifacts: bool = True
     gpu_id: int | None = None
+    stages: str = "compile,correctness"  # Stages to run: compile, correctness, benchmark, defense
+    prepare_only: bool = False  # Sync files and generate script but don't run


 @dataclass(frozen=True)
@@ -196,7 +197,7 @@ class EvaluateResult:
     """Result from remote evaluation."""

     success: bool
-    all_correct: bool
+    all_correct: bool | None  # None when correctness wasn't checked (compile-only, prepare-only)
     correctness_score: float
     geomean_speedup: float
     passed_tests: int
@@ -3066,8 +3067,18 @@ def main():
     parser.add_argument("--num-correct-trials", type=int, default=3)
     parser.add_argument("--num-perf-trials", type=int, default=10)
     parser.add_argument("--output", required=True)
+    parser.add_argument("--stages", default="compile,correctness",
+                        help="Comma-separated stages: compile, correctness, benchmark, defense")
     args = parser.parse_args()

+    # Parse stages
+    stages = set(args.stages.split(","))
+    run_compile = "compile" in stages
+    run_correctness = "correctness" in stages
+    run_benchmark = "benchmark" in stages or args.benchmark
+    run_defense = "defense" in stages or args.defensive
+    print(f"[KernelBench] Stages: {args.stages}")
+
     # Load defense module if defensive mode is enabled
     defense_module = None
     if args.defensive and args.defense_module:
@@ -3156,64 +3167,69 @@ def main():
     new_model = ModelNew(*init_inputs).cuda().eval()
     print(f"[KernelBench] Models instantiated (seed={seed})")

-    # Run correctness trials
+    # Run correctness trials (if stage enabled)
     all_correct = True
-    … (old correctness-trial block; removed lines truncated in the diff rendering)
+    if not run_correctness:
+        print("[KernelBench] Skipping correctness (not in stages)")
+        results["correct"] = None  # Unknown - not checked
+    else:
+        for trial in range(args.num_correct_trials):
+            inputs = get_inputs()
+            inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
+
+            with torch.no_grad():
+                ref_output = ref_model(*inputs)
+                new_output = new_model(*inputs)
+
+            # Compare outputs
+            if isinstance(ref_output, torch.Tensor):
+                if not torch.allclose(ref_output, new_output, rtol=1e-3, atol=1e-3):
+                    all_correct = False
+                    analysis = analyze_diff(ref_output, new_output)
+                    results["error"] = f"Correctness failed on trial {trial+1}: max diff = {analysis['max_diff']}"
+                    results["diff_analysis"] = analysis
+                    print_diff_analysis(analysis)
+
+                    # Save tensors for debugging
+                    debug_dir = output_dir / "debug"
+                    debug_dir.mkdir(exist_ok=True)
+                    torch.save(ref_output.cpu(), debug_dir / "ref_output.pt")
+                    torch.save(new_output.cpu(), debug_dir / "impl_output.pt")
+                    torch.save(inputs[0].cpu() if inputs else None, debug_dir / "input.pt")
+                    print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
+                    break
+            else:
+                # Handle tuple/list outputs
+                for i, (r, n) in enumerate(zip(ref_output, new_output)):
+                    if isinstance(r, torch.Tensor):
+                        if not torch.allclose(r, n, rtol=1e-3, atol=1e-3):
+                            all_correct = False
+                            analysis = analyze_diff(r, n)
+                            results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {analysis['max_diff']}"
+                            results["diff_analysis"] = analysis
+                            print_diff_analysis(analysis)
+
+                            # Save tensors for debugging
+                            debug_dir = output_dir / "debug"
+                            debug_dir.mkdir(exist_ok=True)
+                            torch.save(r.cpu(), debug_dir / f"ref_output_{i}.pt")
+                            torch.save(n.cpu(), debug_dir / f"impl_output_{i}.pt")
+                            print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
+                            break
+                if not all_correct:
+                    break

-    … (old lines truncated in the diff rendering)
+        results["correct"] = all_correct
+        print(f"[KernelBench] Correctness: {all_correct}")

-    # Run benchmark if
+    # Run benchmark if stage enabled (and correctness passed or skipped)
+    should_benchmark = run_benchmark and (all_correct or not run_correctness)
+    if should_benchmark:
         print("[KernelBench] Running benchmarks...")
         inputs = get_inputs()
         inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]

-        if
+        if run_defense and defense_module is not None:
             # Use full defense suite
             print("[KernelBench] Running defense checks on implementation...")
             run_all_defenses = defense_module.run_all_defenses
@@ -3598,6 +3614,7 @@ async def run_evaluate_kernelbench_docker(
             python_cmd_parts.append("--defensive")
             python_cmd_parts.append(f"--defense-module {container_defense_path}")
         python_cmd_parts.append(f"--seed {args.seed}")
+        python_cmd_parts.append(f"--stages {args.stages}")

         eval_cmd = " ".join(python_cmd_parts)
@@ -3869,6 +3886,7 @@ async def run_evaluate_kernelbench_digitalocean(
             python_cmd_parts.append("--defensive")
             python_cmd_parts.append(f"--defense-module {container_defense_path}")
         python_cmd_parts.append(f"--seed {args.seed}")
+        python_cmd_parts.append(f"--stages {args.stages}")

         eval_cmd = " ".join(python_cmd_parts)
@@ -4124,6 +4142,7 @@ async def run_evaluate_kernelbench_runpod(
             python_cmd_parts.append("--defensive")
             python_cmd_parts.append(f"--defense-module {defense_module_path}")
         python_cmd_parts.append(f"--seed {args.seed}")
+        python_cmd_parts.append(f"--stages {args.stages}")

         eval_cmd = " ".join(python_cmd_parts)
@@ -4134,6 +4153,26 @@ async def run_evaluate_kernelbench_runpod(
     env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
     full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"

+    # Handle prepare-only mode
+    if args.prepare_only:
+        print(f"\n[wafer] Prepared evaluation at: {run_path}")
+        print(f"[wafer] Target: {target.name} ({client.host}:{client.port})")
+        print("[wafer] To run manually:")
+        print(f"  ssh -p {client.port} root@{client.host} '{full_cmd}'")
+        print("\n[wafer] Or wrap with rocprof:")
+        print(
+            f"  ssh -p {client.port} root@{client.host} 'cd {run_path} && {env_vars} rocprof -i counters.txt {eval_cmd}'"
+        )
+        return EvaluateResult(
+            success=True,
+            all_correct=None,  # Not checked in prepare-only mode
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=None,
+        )
+
     # Run and stream output
     log_lines = []
     async for line in client.exec_stream(full_cmd):
@@ -4361,6 +4400,7 @@ async def run_evaluate_kernelbench_baremetal_amd(
             python_cmd_parts.append("--defensive")
             python_cmd_parts.append(f"--defense-module {defense_module_path}")
         python_cmd_parts.append(f"--seed {args.seed}")
+        python_cmd_parts.append(f"--stages {args.stages}")

         eval_cmd = " ".join(python_cmd_parts)
@@ -4371,6 +4411,26 @@ async def run_evaluate_kernelbench_baremetal_amd(
     env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
     full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"

+    # Handle prepare-only mode
+    if args.prepare_only:
+        print(f"\n[wafer] Prepared evaluation at: {run_path}")
+        print(f"[wafer] Target: {target.name} ({client.host}:{client.port})")
+        print("[wafer] To run manually:")
+        print(f"  ssh -p {client.port} root@{client.host} '{full_cmd}'")
+        print("\n[wafer] Or wrap with rocprof:")
+        print(
+            f"  ssh -p {client.port} root@{client.host} 'cd {run_path} && {env_vars} rocprof -i counters.txt {eval_cmd}'"
+        )
+        return EvaluateResult(
+            success=True,
+            all_correct=None,  # Not checked in prepare-only mode
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=None,
+        )
+
     # Run and stream output
     log_lines = []
     async for line in client.exec_stream(full_cmd):
```
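One subtlety in the remote script is the benchmark gate: benchmarking runs only when the stage is requested and correctness either passed or was deliberately skipped, which is what makes `--stages compile,benchmark` runs possible. Condensed from the hunk above:

```python
def should_benchmark(run_benchmark: bool, run_correctness: bool, all_correct: bool) -> bool:
    """Benchmark only when requested, and only if correctness passed
    or the correctness stage was deliberately skipped."""
    return run_benchmark and (all_correct or not run_correctness)


# Correctness failed: never benchmark a wrong kernel.
assert not should_benchmark(True, True, False)
# Correctness skipped (e.g., --stages compile,benchmark): benchmark anyway.
assert should_benchmark(True, False, False)
```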
wafer/kernel_scope.py
CHANGED
```diff
@@ -10,10 +10,8 @@ It supports analysis of:

 Design: Wafer-436 - AMD Kernel Scope / ISA Analyzer
 """

-import json
 import sys
 from pathlib import Path
-from typing import Optional


 def print_usage() -> None:
@@ -54,11 +52,11 @@ def analyze_command(
     json_output: bool = False,
     csv_output: bool = False,
     recursive: bool = True,
-    filter_expr:
-    output_file:
+    filter_expr: str | None = None,
+    output_file: str | None = None,
     kernel_index: int = 0,
-    api_url:
-    auth_headers:
+    api_url: str | None = None,
+    auth_headers: dict[str, str] | None = None,
 ) -> str:
     """Analyze ISA/LLVM-IR/TTGIR/.co file or directory.

@@ -77,10 +75,10 @@ def analyze_command(
     Analysis output string
     """
     from wafer_core.lib.kernel_scope import (
-        analyze_isa_file,
         analyze_code_object,
         analyze_directory,
         analyze_file,
+        analyze_isa_file,
     )

     target_path = Path(path).expanduser()
@@ -249,7 +247,7 @@ def _result_to_text(result) -> str:
     lines.extend([
         f"Kernel: {a.kernel_name}",
         f"Architecture: {a.architecture}",
-
+        "Source: Code Object (.co)",
         "",
         "=== Registers ===",
         f"  VGPRs: {a.vgpr_count}",
@@ -289,7 +287,7 @@ def _result_to_text(result) -> str:
     lines.extend([
         f"Kernel: {a.kernel_name}",
         f"Architecture: {a.architecture}",
-
+        "Source: ISA Assembly (.s)",
         "",
         "=== Registers ===",
         f"  VGPRs: {a.vgpr_count}",
```
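Removing `from typing import Optional` matches the new signatures: the truncated old parameter lines presumably used `Optional[...]`, while the replacements use PEP 604 unions, which need no import on Python 3.10+. Side by side:

```python
from typing import Optional


def old_style(filter_expr: Optional[str] = None) -> None: ...


def new_style(filter_expr: str | None = None) -> None: ...  # same type, no typing import needed
```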
wafer/nsys_profile.py
CHANGED
```diff
@@ -18,7 +18,6 @@ from .nsys_analyze import (
     NSYSAnalysisResult,
     _find_nsys,
     _get_install_command,
-    _get_platform,
     _parse_target,
     is_macos,
 )
@@ -316,11 +315,11 @@ def profile_remote_ssh(
     Returns:
         NSYSProfileResult with success status and output path
     """
+    import trio
+
     from .targets import load_target
     from .targets_ops import TargetExecError, exec_on_target_sync, get_target_ssh_info

-    import trio
-
     # Load target
     try:
         target_config = load_target(target)
```
wafer/output.py
CHANGED
```diff
@@ -127,10 +127,17 @@ class OutputCollector:

         typer.echo("")
         typer.echo("=" * 60)
-
+        # Handle None (correctness not run), True (pass), False (fail)
+        if result.all_correct is None:
+            status = "OK"  # Correctness wasn't checked (e.g., compile-only or prepare-only)
+        elif result.all_correct:
+            status = "PASS"
+        else:
+            status = "FAIL"
         typer.echo(f"Result: {status}")
-
-
+        if result.total_tests > 0:
+            score_pct = f"{result.correctness_score:.1%}"
+            typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
         if result.geomean_speedup > 0:
             typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
         typer.echo("=" * 60)
```
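The summary block now distinguishes three outcomes instead of two. A compact restatement of the mapping, plus the `:.1%` format spec the correctness line relies on:

```python
def status_for(all_correct: bool | None) -> str:
    """Tri-state result label used by the summary above."""
    if all_correct is None:
        return "OK"    # correctness stage never ran
    return "PASS" if all_correct else "FAIL"


assert status_for(None) == "OK"
assert status_for(True) == "PASS"
assert status_for(False) == "FAIL"

print(f"{2 / 3:.1%}")  # '66.7%' - the :.1% spec scales by 100 and appends '%'
```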
wafer/rocprof_compute.py
CHANGED
```diff
@@ -15,8 +15,8 @@ Architecture follows similar patterns from the codebase.

 import json
 import subprocess
 import sys
-from pathlib import Path
 from dataclasses import asdict
+from pathlib import Path


 def print_usage() -> None:
@@ -67,9 +67,12 @@ def check_command(json_output: bool = False) -> str:
     Returns:
         Status message or JSON string
     """
-    from wafer_core.lib.rocprofiler.compute import check_installation as core_check  # pragma: no cover
     from dataclasses import asdict

+    from wafer_core.lib.rocprofiler.compute import (
+        check_installation as core_check,  # pragma: no cover
+    )
+
     result = core_check()

     if json_output:
@@ -77,27 +80,27 @@ def check_command(json_output: bool = False) -> str:
         return json.dumps(result_dict, indent=2)
     else:
         if result.installed:
-            print(
+            print("✓ rocprof-compute is installed", file=sys.stderr)
             if result.path:
                 print(f"  Path: {result.path}", file=sys.stderr)
             if result.version:
                 print(f"  Version: {result.version}", file=sys.stderr)
             return "rocprof-compute is installed"
         else:
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
-            print(
+            print("✗ rocprof-compute is not installed", file=sys.stderr)
+            print("", file=sys.stderr)
+            print("rocprof-compute is required to use this feature.", file=sys.stderr)
+            print("", file=sys.stderr)
+            print("Installation options:", file=sys.stderr)
+            print("  1. Install ROCm toolkit (includes rocprof-compute):", file=sys.stderr)
+            print("     sudo apt-get install rocm-dev", file=sys.stderr)
+            print("", file=sys.stderr)
+            print("  2. Install rocprofiler-compute package:", file=sys.stderr)
+            print("     sudo apt-get install rocprofiler-compute", file=sys.stderr)
+            print("", file=sys.stderr)
+            print("  3. Add ROCm to PATH if already installed:", file=sys.stderr)
+            print("     export PATH=/opt/rocm/bin:$PATH", file=sys.stderr)
+            print("", file=sys.stderr)
             if result.install_command:
                 print(f"Suggested command: {result.install_command}", file=sys.stderr)
             return "rocprof-compute is not installed"
@@ -105,9 +108,12 @@ def check_command(json_output: bool = False) -> str:

 def check_installation() -> dict:
     """Legacy function for backward compatibility."""
-    from wafer_core.lib.rocprofiler.compute import check_installation as core_check  # pragma: no cover
     from dataclasses import asdict

+    from wafer_core.lib.rocprofiler.compute import (
+        check_installation as core_check,  # pragma: no cover
+    )
+
     result = core_check()
     if hasattr(result, "__dataclass_fields__"):
         return asdict(result)
@@ -160,13 +166,13 @@ def gui_command(
     if json_output:
         return json.dumps(result_dict, indent=2)
     else:
-        print(
+        print("Launching bundled rocprof-compute GUI viewer...", file=sys.stderr)
         print(f"Folder: {launch_result.folder}", file=sys.stderr)
         print(f"Port: {launch_result.port}", file=sys.stderr)
         print(f"URL: {launch_result.url}", file=sys.stderr)
-        print(
+        print("", file=sys.stderr)
         print(f"Open {launch_result.url} in your browser", file=sys.stderr)
-        print(
+        print("Press Ctrl+C to stop the server", file=sys.stderr)

         # The launch_gui_server with background=False is blocking, so we never reach here
         # unless there's an error
@@ -186,7 +192,7 @@ def gui_command(
     if json_output:
         return json.dumps(result_dict, indent=2)
     else:
-        print(
+        print("Launching external rocprof-compute GUI...", file=sys.stderr)
         print(f"Folder: {launch_result.folder}", file=sys.stderr)
         print(f"Port: {launch_result.port}", file=sys.stderr)
         print(f"URL: {launch_result.url}", file=sys.stderr)
@@ -247,9 +253,10 @@ def profile_command(
     Raises:
         RuntimeError: If profiling fails
     """
-    from wafer_core.lib.rocprofiler.compute import run_profile  # pragma: no cover
     import shlex

+    from wafer_core.lib.rocprofiler.compute import run_profile  # pragma: no cover
+
     # Parse command string
     cmd_list = shlex.split(command)

@@ -276,30 +283,30 @@ def profile_command(
         return json.dumps(result_dict, indent=2)
     else:
         if result.success:
-            print(
+            print("✓ Profiling completed", file=sys.stderr)
             if result.workload_path:
                 print(f"  Workload: {result.workload_path}", file=sys.stderr)
             if result.output_files:
                 print(f"  Generated {len(result.output_files)} files", file=sys.stderr)
             return f"Results in: {result.workload_path}"
         else:
-            print(
-            print(
+            print("✗ Profiling failed", file=sys.stderr)
+            print("", file=sys.stderr)

             # Show stderr output (contains actual error details)
             # Note: rocprof-compute may write errors to stdout instead of stderr
             error_output = result.stderr or result.stdout
             if error_output and error_output.strip():
-                print(
-                print(
+                print("rocprof-compute output:", file=sys.stderr)
+                print("─" * 60, file=sys.stderr)
                 print(error_output.strip(), file=sys.stderr)
-                print(
-                print(
+                print("─" * 60, file=sys.stderr)
+                print("", file=sys.stderr)

             # Show command that was run
             if result.command:
                 print(f"Command: {' '.join(result.command)}", file=sys.stderr)
-                print(
+                print("", file=sys.stderr)

             # Show high-level error
             if result.error:
@@ -359,7 +366,7 @@ def analyze_command(
     Raises:
         RuntimeError: If analysis fails
    """
-    from wafer_core.lib.rocprofiler.compute import
+    from wafer_core.lib.rocprofiler.compute import parse_workload, run_analysis  # pragma: no cover

     # If GUI mode, delegate to GUI launch
     if gui:
@@ -396,23 +403,23 @@ def analyze_command(
         # Just return success message
         return "Analysis completed"
     else:
-        print(
-        print(
+        print("✗ Analysis failed", file=sys.stderr)
+        print("", file=sys.stderr)

         # Show stderr output (contains actual error details)
         # Note: rocprof-compute may write errors to stdout instead of stderr
         error_output = result.stderr or result.stdout
         if error_output and error_output.strip():
-            print(
-            print(
+            print("rocprof-compute output:", file=sys.stderr)
+            print("─" * 60, file=sys.stderr)
             print(error_output.strip(), file=sys.stderr)
-            print(
-            print(
+            print("─" * 60, file=sys.stderr)
+            print("", file=sys.stderr)

         # Show command that was run
         if result.command:
             print(f"Command: {' '.join(result.command)}", file=sys.stderr)
-            print(
+            print("", file=sys.stderr)
@@ -443,10 +450,11 @@ def list_metrics_command(arch: str) -> str:
     Returns:
         Metrics list output
     """
-    from wafer_core.lib.rocprofiler.compute import find_rocprof_compute  # pragma: no cover
-    import subprocess
     import os
     import shutil
+    import subprocess
+
+    from wafer_core.lib.rocprofiler.compute import find_rocprof_compute  # pragma: no cover

     rocprof_path = find_rocprof_compute()
     if not rocprof_path:
@@ -475,7 +483,7 @@ def list_metrics_command(arch: str) -> str:
     else:
         print(f"✗ Failed to list metrics for {arch}", file=sys.stderr)
         if result.stderr:
-            print(
+            print("Error output:", file=sys.stderr)
             print(result.stderr, file=sys.stderr)
         if result.stdout:
             print(result.stdout, file=sys.stderr)
```
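A pattern worth naming: nearly all of the import churn in this file (and in nsys_profile.py, wevin_cli.py, and kernel_scope.py above) follows the standard isort/ruff ordering: standard-library imports first, a blank line, then third-party and first-party imports, each group alphabetized, with function-level imports held to the same rule. A minimal illustration:

```python
# Conventional import grouping the 0.2.15 diffs converge on (isort/ruff "I" rules):
import os          # 1. standard library, alphabetized
import shutil
import subprocess

from wafer_core.lib.rocprofiler.compute import find_rocprof_compute  # 2. first-party, after a blank line
```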
wafer/rocprof_sdk.py
CHANGED
```diff
@@ -109,7 +109,7 @@ def list_counters_command() -> str:
         print(output)
         return output
     else:
-        print(
+        print("✗ Failed to list counters", file=sys.stderr)
         print(f"  {error}", file=sys.stderr)
         raise RuntimeError(error)
```
wafer/targets_ops.py
CHANGED
wafer/templates/optimize_kernelbench.py
ADDED
````diff
@@ -0,0 +1,137 @@
+"""Template for KernelBench optimization - matches eval system prompt.
+
+Usage:
+    # Run on a specific problem
+    wafer agent -t optimize-kernelbench \
+        --args reference=/path/to/problem.py \
+        --args pool=kernelbench-pool \
+        --args backend=hip \
+        --json \
+        "Optimize the Softmax kernel"
+
+    # Watch in real-time with JSON streaming
+    wafer agent -t optimize-kernelbench \
+        --args reference=./23_Softmax.py \
+        --json
+
+Variables:
+    - reference: Path to the KernelBench problem file (required)
+    - pool: Target pool name (default: kernelbench-pool)
+    - target: Single target name (alternative to pool)
+    - backend: Backend type - hip or cuda (default: hip)
+"""
+
+try:
+    from wafer_core.rollouts.templates import TemplateConfig
+except ImportError:
+    from rollouts.templates import TemplateConfig
+
+# System prompt matches optimize_kernelbench_eval/base_config.py SYSTEM_PROMPT
+SYSTEM_PROMPT = """\
+You are a GPU kernel optimization expert. Your task is to write optimized GPU kernels that are correct and faster than the PyTorch baseline.
+
+IMPORTANT: You do NOT have a local GPU. You MUST use `wafer evaluate kernelbench` to test kernels on remote GPU hardware.
+
+## Kernel Format (KernelBench)
+
+The reference file contains a PyTorch `Model` class. You must write a `ModelNew` class that:
+1. Has the same `__init__` signature as `Model`
+2. Has a `forward()` method with the same input/output signature
+3. Uses custom $backend_upper kernels for the computation (NOT PyTorch ops like F.scaled_dot_product_attention or torch.matmul)
+
+The reference file also provides:
+- `get_inputs()` - generates test inputs for forward()
+- `get_init_inputs()` - generates constructor arguments
+
+## Available Tools
+
+- read(file_path): Read source files
+- write(file_path, content): Write your optimized kernel
+- glob(pattern): Find files by pattern
+- grep(pattern): Search code
+- bash(command): Run shell commands including wafer CLI
+
+## Workflow
+
+1. Read the reference problem file to understand what `Model` does
+2. Analyze the computation and identify optimization opportunities
+3. Write an optimized `ModelNew` class with custom $backend_upper kernels using `__global__` kernel definitions and `torch.utils.cpp_extension.load_inline`
+4. Test with: `wafer evaluate kernelbench $target_flag --backend $backend --impl <your_file.py> --reference <problem.py> --benchmark`
+5. Iterate based on feedback until correct and fast
+
+## Example Command
+
+```bash
+wafer evaluate kernelbench \\
+    $target_flag \\
+    --backend $backend \\
+    --impl optimized_kernel.py \\
+    --reference $reference \\
+    --benchmark
+```
+
+## Profiling Tools (USE THESE!)
+
+When your kernel is slower than expected, use profiling to understand WHY:
+
+- `wafer rocprof profile --impl <file> --reference <ref>` - AMD GPU profiling
+- `wafer nvidia ncu --impl <file> --reference <ref>` - NVIDIA NCU profiling
+
+## CRITICAL: Reactive Debugging
+
+After EVERY `wafer evaluate` call:
+1. Check the speedup result
+2. If speedup < 1.0x (slowdown), STOP and analyze:
+   - Run profiling to identify the bottleneck
+   - Ask: "Why is this slow?" before trying another approach
+3. Don't just try random optimizations - understand the root cause
+
+Your kernel MUST:
+- Pass correctness tests (outputs match reference within tolerance)
+- Achieve speedup > 1.0x over PyTorch baseline
+- Use actual $backend_upper kernels (with `__global__` definitions), NOT PyTorch ops
+
+You MUST run `wafer evaluate kernelbench` to verify your kernel. Your score depends on actual measured results."""
+
+template = TemplateConfig(
+    # Identity
+    name="optimize-kernelbench",
+    description="Optimize KernelBench problems (matches eval system prompt)",
+    # System prompt
+    system_prompt=SYSTEM_PROMPT,
+    # Tools
+    tools=["read", "write", "edit", "glob", "grep", "bash"],
+    bash_allowlist=[
+        "wafer evaluate",
+        "wafer nvidia ncu",
+        "wafer nvidia nsys",
+        "wafer rocprof",
+        "wafer compiler-analyze",
+        "python",
+        "python3",
+        "timeout",
+        "ls",
+        "cat",
+        "head",
+        "tail",
+        "wc",
+        "pwd",
+        "which",
+    ],
+    # Model config - match eval settings
+    model="anthropic/claude-opus-4-5-20251101",
+    max_tokens=8192,
+    # No thinking by default (match eval), can override with --thinking
+    thinking=False,
+    # Multi-turn for iterative optimization
+    single_turn=False,
+    # Template variables
+    defaults={
+        "reference": "./problem.py",
+        "pool": "kernelbench-pool",
+        "target": "",  # If set, overrides pool
+        "backend": "hip",
+        "backend_upper": "HIP",  # Auto-computed from backend
+        "target_flag": "--pool kernelbench-pool",  # Auto-computed
+    },
+)
````
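The `$reference`, `$backend`, `$backend_upper`, and `$target_flag` placeholders in the system prompt suggest `string.Template`-style substitution of the `defaults` mapping (overridden by `--args key=value`); the real mechanism lives in `wafer_core.rollouts.templates`, so treat this sketch as an assumption about how it works:

```python
from string import Template

defaults = {
    "reference": "./problem.py",
    "backend": "hip",
    "backend_upper": "HIP",
    "target_flag": "--pool kernelbench-pool",
}
overrides = {"backend": "cuda", "backend_upper": "CUDA"}  # e.g. from --args backend=cuda

prompt = Template("Test with: wafer evaluate kernelbench $target_flag --backend $backend --reference $reference")
print(prompt.safe_substitute({**defaults, **overrides}))
# Test with: wafer evaluate kernelbench --pool kernelbench-pool --backend cuda --reference ./problem.py
```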
wafer/wevin_cli.py
CHANGED
```diff
@@ -364,9 +364,9 @@ def main(  # noqa: PLR0913, PLR0915
     json_output: bool = False,
 ) -> None:
     """Run wevin agent in-process via rollouts."""
-    import trio
     from dataclasses import asdict

+    import trio
     from wafer_core.rollouts import FileSessionStore

     session_store = FileSessionStore()
```
wafer/workspaces.py
CHANGED
{wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/RECORD
CHANGED
```diff
@@ -1,40 +1,41 @@
 wafer/GUIDE.md,sha256=G6P4aFZslEXiHmVjtTB3_OIpGK5d1tSiqxtawASVUZg,3588
 wafer/__init__.py,sha256=kBM_ONCpU6UUMBOH8Tmg4A88sNFnbaD59o61cJs-uYM,90
-wafer/analytics.py,sha256=
+wafer/analytics.py,sha256=qLY6Z16usVHFD8TCv7XBuz7l47vXVdXk-qhOzA-hW_8,8179
 wafer/api_client.py,sha256=i_Az2b2llC3DSW8yOL-BKqa7LSKuxOr8hSN40s-oQXY,6313
-wafer/auth.py,sha256=
-wafer/autotuner.py,sha256=
+wafer/auth.py,sha256=nneKUjGwb5ggJEHRdF_GlFkT1ZozHP4kGyuXjhZjtgM,13677
+wafer/autotuner.py,sha256=41WYP41pTDvMijv2h42vm89bcHtDMJXObDlWmn6xpFU,44416
 wafer/billing.py,sha256=jbLB2lI4_9f2KD8uEFDi_ixLlowe5hasC0TIZJyIXRg,7163
-wafer/cli.py,sha256=
+wafer/cli.py,sha256=lBBTQCcmKREqZDOQh27qSq8i6NedjHW5oh1JiuT9aho,254241
 wafer/config.py,sha256=h5Eo9_yfWqWGoPNdVQikI9GoZVUeysunSYiixf1mKcw,3411
 wafer/corpus.py,sha256=x5aFhCsTSAtgzFG9AMFpqq92Ej63mXofL-vvvpjj1sM,12913
-wafer/evaluate.py,sha256=
+wafer/evaluate.py,sha256=bLTfL7jAGQlfqLL39hSGSB7bnBp5THTCY7nl6giVMkQ,176005
 wafer/global_config.py,sha256=fhaR_RU3ufMksDmOohH1OLeQ0JT0SDW1hEip_zaP75k,11345
 wafer/gpu_run.py,sha256=TwqXy72T7f2I7e6n5WWod3xgxCPnDhU0BgLsB4CUoQY,9716
 wafer/inference.py,sha256=tZCO5i05FKY27ewis3CSBHFBeFbXY3xwj0DSjdoMY9s,4314
-wafer/kernel_scope.py,sha256=
+wafer/kernel_scope.py,sha256=YtnxknAChkJoeU_vIdxiqWsAITGBeabp9OGIK-X32i0,20796
 wafer/ncu_analyze.py,sha256=rAWzKQRZEY6E_CL3gAWUaW3uZ4kvQVZskVCPDpsFJuE,24633
 wafer/nsys_analyze.py,sha256=AhNcjPaapB0QCbqiHRXvyy-ccjevvVwEyxes84D28JU,36124
-wafer/nsys_profile.py,sha256=
-wafer/output.py,sha256=
+wafer/nsys_profile.py,sha256=QFBl8pkr8r4uRNdNUO9gY-obj9slqpOgVYFZ_sXu6Nw,15478
+wafer/output.py,sha256=8jw5ifvIMK8ldyBMGW4NhrKvJPl66TV2Y2fJ5Tlhh1I,8293
 wafer/problems.py,sha256=ce2sy10A1nnNUG3VGsseTS8jL7LZsku4dE8zVf9JHQ4,11296
-wafer/rocprof_compute.py,sha256=
-wafer/rocprof_sdk.py,sha256=
+wafer/rocprof_compute.py,sha256=n_yOGZaFbOXna_ghhmYWXeyUoSabgH4KkjlYq38DlHo,19888
+wafer/rocprof_sdk.py,sha256=0Q7Ye6dUfa1anFZbqKc21rItgqva8V8VIZoSB7wqbmA,10085
 wafer/rocprof_systems.py,sha256=4IWbMcbYk1x_8iS7P3FC_u5sgH6EXADCtR2lV9id80M,18629
 wafer/ssh_keys.py,sha256=9kSdhV_dg9T6pQu2JmNQptarkkwGtN9rLyRkI1bW4i4,8094
 wafer/target_lock.py,sha256=SDKhNzv2N7gsphGflcNni9FE5YYuAMuEthngAJEo4Gs,7809
 wafer/targets.py,sha256=9r-iRWoKSH5cQl1LcamaX-T7cNVOg99ngIm_hlRk-qU,26922
-wafer/targets_ops.py,sha256=
+wafer/targets_ops.py,sha256=jN1oIBx0mutxRNE9xpIc7SaBxPkVmOyus2eqn0kEKNI,21475
 wafer/tracelens.py,sha256=g9ZIeFyNojZn4uTd3skPqIrRiL7aMJOz_-GOd3aiyy4,7998
-wafer/wevin_cli.py,sha256=
-wafer/workspaces.py,sha256=
+wafer/wevin_cli.py,sha256=vF3GNH-qWXO4hAlXaDg98VZpS4uFexVUp94BHsJjjMU,22179
+wafer/workspaces.py,sha256=XZvN-13oq40fkpoJTB2UWTG9KkD-eO47ptXK0FY6360,30083
 wafer/skills/wafer-guide/SKILL.md,sha256=KWetJw2TVTbz11_nzqazqOJWWRlbHRFShs4sOoreiWo,3255
 wafer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 wafer/templates/ask_docs.py,sha256=Lxs-faz9v5m4Qa4NjF2X_lE8KwM9ES9MNJkxo7ep56o,2256
 wafer/templates/optimize_kernel.py,sha256=u6AL7Q3uttqlnBLzcoFdsiPq5lV2TV3bgqwCYYlK9gk,2357
+wafer/templates/optimize_kernelbench.py,sha256=aoOA13zWEl89r6QW03xF9NKxQ7j4mWe9rwua6-mlr4Y,4780
 wafer/templates/trace_analyze.py,sha256=XE1VqzVkIUsZbXF8EzQdDYgg-AZEYAOFpr6B_vnRELc,2880
-wafer_cli-0.2.
-wafer_cli-0.2.
-wafer_cli-0.2.
-wafer_cli-0.2.
-wafer_cli-0.2.
+wafer_cli-0.2.15.dist-info/METADATA,sha256=z1TLYbZzeOJpMMaG3TqJd_M5WbeLRVEtSmoTO0qhPc4,560
+wafer_cli-0.2.15.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+wafer_cli-0.2.15.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
+wafer_cli-0.2.15.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
+wafer_cli-0.2.15.dist-info/RECORD,,
```
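For readers unfamiliar with RECORD: each row is `path,sha256=<digest>,<size>`, where the digest is an unpadded URL-safe base64 encoding of the file's SHA-256, per the wheel spec. A small sketch of how one row is computed:

```python
import base64
import hashlib
from pathlib import Path


def record_row(path: str) -> str:
    """Build one RECORD row: path, unpadded urlsafe-b64 sha256, size in bytes."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"
```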
{wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/WHEEL
File without changes
{wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/entry_points.txt
File without changes
{wafer_cli-0.2.13.dist-info → wafer_cli-0.2.15.dist-info}/top_level.txt
File without changes