wafer-cli 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/billing.py +6 -6
- wafer/cli.py +454 -86
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +7 -1
- wafer/evaluate.py +13 -6
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/ssh_keys.py +6 -6
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +1 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +183 -0
- wafer/wevin_cli.py +68 -9
- wafer/workspaces.py +8 -8
- wafer_cli-0.2.25.dist-info/METADATA +107 -0
- wafer_cli-0.2.25.dist-info/RECORD +45 -0
- wafer_cli-0.2.24.dist-info/METADATA +0 -16
- wafer_cli-0.2.24.dist-info/RECORD +0 -41
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.25.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.25.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.24.dist-info → wafer_cli-0.2.25.dist-info}/top_level.txt +0 -0

wafer/cli_instructions.py
ADDED
@@ -0,0 +1,143 @@
+"""Generate agent system prompt instructions from the wafer CLI's own --help text.
+
+Walks the typer/click command tree and extracts help text for commands
+matching the bash_allowlist. This ensures agent instructions stay in sync
+with the CLI — the --help text is the single source of truth for both
+human users and AI agents.
+
+Usage:
+    from wafer.cli_instructions import build_cli_instructions
+
+    instructions = build_cli_instructions([
+        "wafer evaluate",
+        "wafer nvidia ncu",
+        "wafer rocprof profile",
+        "python",  # non-wafer commands are skipped
+    ])
+"""
+
+from __future__ import annotations
+
+import click
+import typer.main
+
+
+def _resolve_command(root: click.BaseCommand, parts: list[str]) -> click.BaseCommand | None:
+    """Walk the click command tree to find a (sub)command by name parts.
+
+    Args:
+        root: The root click command (from typer.main.get_command)
+        parts: Command path segments, e.g. ["evaluate", "kernelbench"]
+
+    Returns:
+        The click command at that path, or None if not found.
+    """
+    cmd = root
+    for part in parts:
+        if not isinstance(cmd, click.MultiCommand):
+            return None
+        ctx = click.Context(cmd, info_name=part)
+        child = cmd.get_command(ctx, part)
+        if child is None:
+            return None
+        cmd = child
+    return cmd
+
+
+def _format_command_help(cmd_path: str, cmd: click.BaseCommand) -> str:
+    """Format a single command's help text for inclusion in a system prompt.
+
+    Extracts the description and option help text (skipping --help itself).
+    """
+    lines = [f"### `{cmd_path}`"]
+
+    if cmd.help:
+        lines.append(cmd.help.strip())
+
+    # Extract option help
+    option_lines = []
+    for param in getattr(cmd, "params", []):
+        if not isinstance(param, click.Option):
+            continue
+        # Skip --help
+        if param.name == "help":
+            continue
+        name = "/".join(param.opts)
+        type_name = param.type.name.upper() if hasattr(param.type, "name") else ""
+        help_text = param.help or ""
+        is_flag = type_name in ("BOOL", "BOOLEAN") or param.is_flag
+        if type_name and not is_flag:
+            option_lines.append(f" {name} {type_name} {help_text}")
+        else:
+            option_lines.append(f" {name} {help_text}")
+
+    if option_lines:
+        lines.append("")
+        lines.append("Options:")
+        lines.extend(option_lines)
+
+    # List subcommands if this is a group
+    if isinstance(cmd, click.MultiCommand):
+        ctx = click.Context(cmd, info_name=cmd_path.split()[-1])
+        subcmd_names = cmd.list_commands(ctx)
+        if subcmd_names:
+            subcmd_lines = []
+            for name in subcmd_names:
+                subcmd = cmd.get_command(ctx, name)
+                if subcmd:
+                    desc = (subcmd.help or subcmd.short_help or "").strip().split("\n")[0]
+                    subcmd_lines.append(f" {cmd_path} {name} {desc}")
+            if subcmd_lines:
+                lines.append("")
+                lines.append("Subcommands:")
+                lines.extend(subcmd_lines)
+
+    return "\n".join(lines)
+
+
+def build_cli_instructions(bash_allowlist: list[str]) -> str:
+    """Generate CLI instruction text from --help for allowed wafer commands.
+
+    Walks the typer/click command tree and extracts help text for each
+    wafer command in the bash_allowlist. Non-wafer commands (python, ls, etc.)
+    are skipped.
+
+    Args:
+        bash_allowlist: List of allowed bash command prefixes.
+            Example: ["wafer evaluate", "wafer nvidia ncu", "python"]
+
+    Returns:
+        Markdown-formatted CLI instructions, or empty string if no wafer
+        commands are in the allowlist.
+    """
+    if not bash_allowlist:
+        return ""
+
+    # Filter to wafer commands only
+    wafer_commands = [cmd for cmd in bash_allowlist if cmd.startswith("wafer ")]
+    if not wafer_commands:
+        return ""
+
+    # Lazy import to avoid circular deps at module level
+    from wafer.cli import app
+
+    root = typer.main.get_command(app)
+
+    sections = []
+    for cmd_str in wafer_commands:
+        # "wafer evaluate kernelbench" -> ["evaluate", "kernelbench"]
+        parts = cmd_str.split()[1:]  # drop "wafer" prefix
+        cmd = _resolve_command(root, parts)
+        if cmd is None:
+            # Command not found in tree — skip silently
+            continue
+        sections.append(_format_command_help(cmd_str, cmd))
+
+    if not sections:
+        return ""
+
+    header = (
+        "## Wafer CLI Commands\n\n"
+        "You do not have a local GPU. Use the wafer CLI to run on remote GPU hardware.\n"
+    )
+    return header + "\n\n".join(sections)

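For orientation, here is a minimal sketch of how this module might be wired into an agent prompt. Only `build_cli_instructions` and the `wafer.agent_defaults` names appear elsewhere in this release; the composition below is an assumption about what `wevin_cli.py` does, not a copy of it.

```python
# Hypothetical wiring sketch; not taken verbatim from wevin_cli.py.
from wafer.agent_defaults import KERNELBENCH_BASH_ALLOWLIST  # new module in 0.2.25
from wafer.cli_instructions import build_cli_instructions

TASK_PROMPT = "You are a GPU kernel optimization expert..."  # task-specific part

# Auto-generate the "## Wafer CLI Commands" section from the CLI's own --help text,
# then append it so the agent instructions never drift from the installed CLI.
cli_docs = build_cli_instructions(KERNELBENCH_BASH_ALLOWLIST)
system_prompt = TASK_PROMPT + ("\n\n" + cli_docs if cli_docs else "")
```
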
wafer/corpus.py
CHANGED
@@ -160,11 +160,17 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
             paths=["docs"],
             branch="develop_deprecated",
         ),
-        # HipKittens - high-performance AMD kernels
+        # HipKittens - high-performance AMD kernels (main branch: MI350X/CDNA4+)
         RepoSource(
             repo="HazyResearch/HipKittens",
             paths=["docs", "kernels", "include"],
         ),
+        # HipKittens cdna3 branch - MI300X/MI325X (gfx942)
+        RepoSource(
+            repo="HazyResearch/HipKittens",
+            paths=["kernels", "include", "tests"],
+            branch="cdna3",
+        ),
         # vLLM AMD kernels
         RepoSource(
             repo="vllm-project/vllm",

wafer/evaluate.py
CHANGED
@@ -3496,7 +3496,7 @@ def _build_modal_kernelbench_script(
     # Install CUTLASS headers (for cute/tensor.hpp and cutlass/util/*.h) from GitHub
     # The nvidia-cutlass-dsl pip package doesn't include the C++ headers needed for nvcc
     # IMPORTANT: symlink to /usr/local/cuda/include because nvcc searches there by default
-    cutlass_install =
+    cutlass_install = """
     .run_commands([
         # Clone CUTLASS headers from GitHub (shallow clone, full include tree)
         # Use simple shallow clone - sparse-checkout can be buggy in some environments
@@ -3512,7 +3512,7 @@ def _build_modal_kernelbench_script(
         index_url="https://pypi.nvidia.com",
         extra_index_url="https://pypi.org/simple",
     )
-
+    """

     inputs_write = ""
     if inputs_code_b64:
@@ -3772,7 +3772,7 @@ async def run_evaluate_kernelbench_modal(
     result_json = None
     for line in stdout.split("\n"):
         if line.startswith("EVAL_RESULT_JSON:"):
-            result_json = line[len("EVAL_RESULT_JSON:"):]
+            result_json = line[len("EVAL_RESULT_JSON:") :]
             break

     if not result_json:
@@ -4486,6 +4486,7 @@ async def run_evaluate_kernelbench_runpod(
     # Find Python with PyTorch - check common locations on RunPod
     python_exe = "python3"
     for candidate in [
+        "/opt/venv/bin/python3",
         "/opt/conda/envs/py_3.10/bin/python3",
         "/opt/conda/bin/python3",
     ]:
@@ -4630,7 +4631,9 @@ async def run_evaluate_kernelbench_baremetal_direct(
     """
     # Reuse the AMD function but with CUDA env vars
     # The logic is identical, just the GPU env var is different
-    return await _run_evaluate_kernelbench_baremetal_direct_impl(
+    return await _run_evaluate_kernelbench_baremetal_direct_impl(
+        args, target, gpu_env_var="CUDA_VISIBLE_DEVICES"
+    )


 async def run_evaluate_kernelbench_baremetal_amd(
@@ -4642,7 +4645,9 @@ async def run_evaluate_kernelbench_baremetal_amd(
     Runs evaluation script directly on host (no Docker) for AMD GPUs
     that have PyTorch/ROCm installed.
     """
-    return await _run_evaluate_kernelbench_baremetal_direct_impl(
+    return await _run_evaluate_kernelbench_baremetal_direct_impl(
+        args, target, gpu_env_var="HIP_VISIBLE_DEVICES"
+    )


 async def _run_evaluate_kernelbench_baremetal_direct_impl(
@@ -4809,7 +4814,9 @@ async def _run_evaluate_kernelbench_baremetal_direct_impl(
         # AMD: PYTORCH_ROCM_ARCH for faster compile
         rocm_arch = _get_rocm_arch(target.compute_capability)
         arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
-        env_vars =
+        env_vars = (
+            f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
+        )
     else:
         # NVIDIA: just set CUDA_VISIBLE_DEVICES
         env_vars = f"CUDA_VISIBLE_DEVICES={gpu_id} PYTHONUNBUFFERED=1"

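The RunPod hunk above prepends `/opt/venv/bin/python3` to the interpreter candidates. A rough sketch of the probing pattern it extends is shown below; the helper name and the `import torch` check are illustrative, not lifted from `wafer/evaluate.py`.

```python
# Illustrative sketch of picking a PyTorch-capable interpreter on a RunPod host.
import subprocess

CANDIDATES = [
    "/opt/venv/bin/python3",                # added in 0.2.25
    "/opt/conda/envs/py_3.10/bin/python3",
    "/opt/conda/bin/python3",
]

def find_python_with_torch() -> str:
    for candidate in CANDIDATES:
        try:
            probe = subprocess.run([candidate, "-c", "import torch"], capture_output=True)
        except FileNotFoundError:
            continue  # candidate interpreter not present on this image
        if probe.returncode == 0:
            return candidate
    return "python3"  # fall back to whatever is on PATH
```
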
wafer/kernel_scope.py
CHANGED
@@ -95,7 +95,7 @@ def analyze_command(
     if not api_url or not auth_headers:
         raise RuntimeError(
             "API authentication required for .co file analysis. "
-            "Run 'wafer login' first."
+            "Run 'wafer auth login' first."
         )
     result = analyze_code_object(target_path, api_url, auth_headers)
     # ISA files - use kernel_index parameter

wafer/ncu_analyze.py
CHANGED
@@ -520,7 +520,7 @@ def _analyze_remote_api(

     except httpx.HTTPStatusError as e:
         if e.response.status_code == 401:
-            raise RuntimeError("Not authenticated. Run: wafer login") from e
+            raise RuntimeError("Not authenticated. Run: wafer auth login") from e
         raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
     except httpx.RequestError as e:
         raise RuntimeError(f"Could not reach API: {e}") from e

wafer/nsys_analyze.py
CHANGED
@@ -844,7 +844,7 @@ def _analyze_remote_api(

     except httpx.HTTPStatusError as e:
         if e.response.status_code == 401:
-            raise RuntimeError("Not authenticated. Run: wafer login") from e
+            raise RuntimeError("Not authenticated. Run: wafer auth login") from e
         raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
     except httpx.RequestError as e:
         raise RuntimeError(f"Could not reach API: {e}") from e

wafer/skills/wafer-guide/SKILL.md
CHANGED
@@ -16,7 +16,7 @@ Before using Wafer CLI commands, install the tool:
 uv tool install wafer-cli

 # Authenticate (one-time setup)
-wafer login
+wafer auth login

 ```

@@ -71,15 +71,31 @@ Test correctness and measure speedup against a reference:
 wafer evaluate make-template ./my-kernel
 # Creates: kernel.py, reference.py, test_cases.json

-#
-
+# test_cases.json format:
+# [{"name": "small", "n": 1024, "seed": 42}, {"name": "large", "n": 1048576, "seed": 42}]
+# Each dict is passed as **kwargs to generate_input() in reference.py
+
+# Run correctness check (GPUMode functional format)
+wafer evaluate gpumode \
   --impl ./my-kernel/kernel.py \
   --reference ./my-kernel/reference.py \
   --test-cases ./my-kernel/test_cases.json \
   --target <target-name>

-#
-wafer evaluate
+# Run correctness + benchmark (measures speedup vs reference)
+wafer evaluate gpumode \
+  --impl ./my-kernel/kernel.py \
+  --reference ./my-kernel/reference.py \
+  --test-cases ./my-kernel/test_cases.json \
+  --target <target-name> --benchmark
+
+# Run with defensive timing (detects evaluation hacking)
+wafer evaluate gpumode ... --benchmark --defensive
+
+# KernelBench format (ModelNew class)
+wafer evaluate kernelbench \
+  --impl my_kernel.py --reference problem.py \
+  --target <target-name> --stages all
 ```

 ### 4. AI-Assisted Optimization
@@ -126,4 +142,4 @@ wafer config targets init runpod # RunPod cloud GPUs
 wafer config targets init digitalocean # DigitalOcean AMD GPUs
 ```

-Then use: `wafer evaluate --target <name> ...`
+Then use: `wafer evaluate gpumode --target <name> ...`

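The guide above notes that each `test_cases.json` entry is forwarded as keyword arguments to `generate_input()` in `reference.py`. A hypothetical `reference.py` sketch for the GPUMode functional format follows; only `generate_input()` is named in the guide, so the reference entry point `ref_kernel()` and the exact scaffold produced by `wafer evaluate make-template` are assumptions.

```python
# Hypothetical reference.py sketch; function names other than generate_input are assumed.
import torch

def generate_input(n: int, seed: int, **kwargs) -> torch.Tensor:
    # Each test_cases.json entry, e.g. {"name": "small", "n": 1024, "seed": 42},
    # arrives here as keyword arguments; extra keys land in **kwargs.
    gen = torch.Generator(device="cuda").manual_seed(seed)
    return torch.randn(n, device="cuda", generator=gen)

def ref_kernel(x: torch.Tensor) -> torch.Tensor:
    # Reference computation the optimized kernel.py is checked and benchmarked against.
    return torch.square(x)
```
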
wafer/ssh_keys.py
CHANGED
@@ -1,6 +1,6 @@
 """SSH Keys CLI - Manage SSH public keys for workspace access.

-This module provides the implementation for the `wafer ssh-keys` subcommand.
+This module provides the implementation for the `wafer config ssh-keys` subcommand.
 Users register their SSH public keys here, which are then installed in all
 workspaces they attach to (BYOK - Bring Your Own Key model).
 """
@@ -94,7 +94,7 @@ def list_ssh_keys(json_output: bool = False) -> str:
         keys = response.json()
     except httpx.HTTPStatusError as e:
         if e.response.status_code == 401:
-            raise RuntimeError("Not authenticated. Run: wafer login") from e
+            raise RuntimeError("Not authenticated. Run: wafer auth login") from e
         raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
     except httpx.RequestError as e:
         raise RuntimeError(f"Could not reach API: {e}") from e
@@ -107,7 +107,7 @@ def list_ssh_keys(json_output: bool = False) -> str:
             "No SSH keys registered.\n"
             "\n"
             "Add your SSH key:\n"
-            " wafer ssh-keys add\n"
+            " wafer config ssh-keys add\n"
             "\n"
             "This will auto-detect your key from ~/.ssh/"
         )
@@ -149,7 +149,7 @@ def add_ssh_key(
             " ssh-keygen -t ed25519\n"
             "\n"
             "Or specify a path:\n"
-            " wafer ssh-keys add /path/to/key.pub"
+            " wafer config ssh-keys add /path/to/key.pub"
         )
         pubkey_path = detected[0]

@@ -202,7 +202,7 @@ def add_ssh_key(
         key_data = response.json()
     except httpx.HTTPStatusError as e:
         if e.response.status_code == 401:
-            raise RuntimeError("Not authenticated. Run: wafer login") from e
+            raise RuntimeError("Not authenticated. Run: wafer auth login") from e
         if e.response.status_code == 400:
             # Parse error detail
             try:
@@ -248,7 +248,7 @@ def remove_ssh_key(key_id: str, json_output: bool = False) -> str:
         response.raise_for_status()
     except httpx.HTTPStatusError as e:
         if e.response.status_code == 401:
-            raise RuntimeError("Not authenticated. Run: wafer login") from e
+            raise RuntimeError("Not authenticated. Run: wafer auth login") from e
         if e.response.status_code == 404:
             raise RuntimeError(f"SSH key not found: {key_id}") from e
         raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e

wafer/templates/ask_docs.py
CHANGED
@@ -51,7 +51,7 @@ Output your answer directly. Be concise but thorough. Include code examples when
         "python -c",
     ],
     # Model config
-    model="anthropic/claude-
+    model="anthropic/claude-opus-4-5-20251101",
     max_tokens=8192,
     # Thinking config - disabled for simple doc queries
     thinking=False,

wafer/templates/optimize_kernel.py
CHANGED
@@ -56,7 +56,7 @@ IMPORTANT: Always verify correctness with wafer evaluate before claiming success
         "python -c",
     ],
     # Model config - use thinking for complex optimization reasoning
-    model="anthropic/claude-
+    model="anthropic/claude-opus-4-5-20251101",
     max_tokens=16384,
     # Thinking config - enabled for complex kernel optimization
     thinking=True,

wafer/templates/optimize_kernelbench.py
CHANGED
@@ -1,4 +1,4 @@
-"""Template for KernelBench optimization
+"""Template for KernelBench optimization.

 Usage:
     # Run on a specific problem
@@ -26,12 +26,18 @@ try:
 except ImportError:
     from rollouts.templates import TemplateConfig

-
+from wafer.agent_defaults import ENABLED_TOOLS, KERNELBENCH_BASH_ALLOWLIST
+
+# Task-specific instructions only — must stay in sync with the eval's SYSTEM_PROMPT
+# in research/evals/optimize_kernelbench_eval/.../base_config.py.
+# Run test_eval_cli_parity.py to verify.
+# Wafer CLI command docs are auto-generated from --help text and composed
+# at runtime by wevin_cli.py (see wafer.cli_instructions.build_cli_instructions).
+# TODO: Consider having both eval and template import SYSTEM_PROMPT from a shared
+# module so there's only one copy to maintain.
 SYSTEM_PROMPT = """\
 You are a GPU kernel optimization expert. Your task is to write optimized GPU kernels that are correct and faster than the PyTorch baseline.

-IMPORTANT: You do NOT have a local GPU. You MUST use `wafer evaluate kernelbench` to test kernels on remote GPU hardware.
-
 ## Kernel Format (KernelBench)

 The reference file contains a PyTorch `Model` class. You must write a `ModelNew` class that:
@@ -43,49 +49,14 @@ The reference file also provides:
 - `get_inputs()` - generates test inputs for forward()
 - `get_init_inputs()` - generates constructor arguments

-## Available Tools
-
-- read(file_path): Read source files
-- write(file_path, content): Write your optimized kernel
-- glob(pattern): Find files by pattern
-- grep(pattern): Search code
-- bash(command): Run shell commands including wafer CLI
-
 ## Workflow

 1. Read the reference problem file to understand what `Model` does
 2. Analyze the computation and identify optimization opportunities
 3. Write an optimized `ModelNew` class with custom $backend_upper kernels using `__global__` kernel definitions and `torch.utils.cpp_extension.load_inline`
-4. Test with: `wafer evaluate kernelbench $target_flag --backend $backend --impl
+4. Test with: `wafer evaluate kernelbench $target_flag --backend $backend --impl optimized.py --reference <problem.py> --benchmark`
 5. Iterate based on feedback until correct and fast

-## Example Command
-
-```bash
-wafer evaluate kernelbench \\
-    $target_flag \\
-    --backend $backend \\
-    --impl optimized_kernel.py \\
-    --reference $reference \\
-    --benchmark
-```
-
-## Profiling Tools (USE THESE!)
-
-When your kernel is slower than expected, use profiling to understand WHY:
-
-- `wafer rocprof profile --impl <file> --reference <ref>` - AMD GPU profiling
-- `wafer nvidia ncu --impl <file> --reference <ref>` - NVIDIA NCU profiling
-
-## CRITICAL: Reactive Debugging
-
-After EVERY `wafer evaluate` call:
-1. Check the speedup result
-2. If speedup < 1.0x (slowdown), STOP and analyze:
-   - Run profiling to identify the bottleneck
-   - Ask: "Why is this slow?" before trying another approach
-3. Don't just try random optimizations - understand the root cause
-
 Your kernel MUST:
 - Pass correctness tests (outputs match reference within tolerance)
 - Achieve speedup > 1.0x over PyTorch baseline
@@ -96,32 +67,16 @@ You MUST run `wafer evaluate kernelbench` to verify your kernel. Your score depe
 template = TemplateConfig(
     # Identity
     name="optimize-kernelbench",
-    description="Optimize KernelBench problems
-    # System prompt
+    description="Optimize KernelBench problems",
+    # System prompt (task-specific; CLI docs appended at runtime)
     system_prompt=SYSTEM_PROMPT,
     # Tools
-    tools=
-    bash_allowlist=
-
-        "wafer nvidia ncu",
-        "wafer nvidia nsys",
-        "wafer rocprof",
-        "wafer compiler-analyze",
-        "python",
-        "python3",
-        "timeout",
-        "ls",
-        "cat",
-        "head",
-        "tail",
-        "wc",
-        "pwd",
-        "which",
-    ],
-    # Model config - match eval settings
+    tools=ENABLED_TOOLS,
+    bash_allowlist=KERNELBENCH_BASH_ALLOWLIST,
+    # Model config
     model="anthropic/claude-opus-4-5-20251101",
     max_tokens=8192,
-    # No thinking by default
+    # No thinking by default, can override with --thinking
     thinking=False,
     # Multi-turn for iterative optimization
     single_turn=False,

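This template now imports its tool list and bash allowlist from the new `wafer/agent_defaults.py`, whose diff is not shown above. The sketch below is a rough guess at its shape, reconstructed from the inline lists the template dropped; treat every entry as an assumption rather than the released contents.

```python
# wafer/agent_defaults.py (hypothetical reconstruction, not the actual file contents)

# Tool names inferred from the "## Available Tools" section removed from the prompt.
ENABLED_TOOLS = ["read", "write", "glob", "grep", "bash"]

# Entries inferred from the allowlist removed above; the first entry is assumed from
# the workflow step that tests with `wafer evaluate kernelbench`.
KERNELBENCH_BASH_ALLOWLIST = [
    "wafer evaluate kernelbench",
    "wafer nvidia ncu",
    "wafer nvidia nsys",
    "wafer rocprof",
    "wafer compiler-analyze",
    "python",
    "python3",
    "timeout",
    "ls",
    "cat",
    "head",
    "tail",
    "wc",
    "pwd",
    "which",
]
```
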
wafer/templates/trace_analyze.py
CHANGED
@@ -60,7 +60,7 @@ Use `--json` flags when available for structured output that's easier to parse.
         "python -c",
     ],
     # Model config
-    model="anthropic/claude-
+    model="anthropic/claude-opus-4-5-20251101",
    max_tokens=8192,
     # Thinking config - disabled for trace analysis (mostly parsing)
     thinking=False,