wafer-cli 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/cli.py +163 -3
- wafer/corpus.py +241 -9
- wafer/evaluate.py +426 -8
- {wafer_cli-0.2.21.dist-info → wafer_cli-0.2.22.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.21.dist-info → wafer_cli-0.2.22.dist-info}/RECORD +8 -8
- {wafer_cli-0.2.21.dist-info → wafer_cli-0.2.22.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.21.dist-info → wafer_cli-0.2.22.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.21.dist-info → wafer_cli-0.2.22.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
|
@@ -147,7 +147,9 @@ def main_callback(
|
|
|
147
147
|
else:
|
|
148
148
|
_command_outcome = "failure"
|
|
149
149
|
# Print error summary FIRST (before traceback) so it's visible even if truncated
|
|
150
|
-
print(
|
|
150
|
+
print(
|
|
151
|
+
f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
|
|
152
|
+
)
|
|
151
153
|
# Call original excepthook (prints the full traceback)
|
|
152
154
|
original_excepthook(exc_type, exc_value, exc_traceback)
|
|
153
155
|
|
|
@@ -3486,7 +3488,7 @@ def init_runpod(
|
|
|
3486
3488
|
gpu_configs = {
|
|
3487
3489
|
"MI300X": {
|
|
3488
3490
|
"gpu_type_id": "AMD Instinct MI300X OAM",
|
|
3489
|
-
"image": "
|
|
3491
|
+
"image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
|
|
3490
3492
|
"compute_capability": "9.4",
|
|
3491
3493
|
},
|
|
3492
3494
|
"H100": {
|
|
@@ -3582,7 +3584,7 @@ def init_digitalocean(
|
|
|
3582
3584
|
"ssh_key": ssh_key,
|
|
3583
3585
|
"region": region,
|
|
3584
3586
|
"size_slug": "gpu-mi300x1-192gb-devcloud",
|
|
3585
|
-
"image": "
|
|
3587
|
+
"image": "amd-pytorchrocm7", # PyTorch (ROCm7) marketplace image
|
|
3586
3588
|
"provision_timeout": 600,
|
|
3587
3589
|
"eval_timeout": 600,
|
|
3588
3590
|
"keep_alive": keep_alive,
|
|
@@ -4084,6 +4086,164 @@ def targets_cleanup(
|
|
|
4084
4086
|
raise typer.Exit(1) from None
|
|
4085
4087
|
|
|
4086
4088
|
|
|
4089
|
+
# Known libraries that can be installed on targets
# TODO: Consider adding HipKittens to the default RunPod/DO Docker images
# so this install step isn't needed. For now, this command handles it.
#
# Entries are consumed by `targets_install`:
#   description   - one-line summary echoed before running
#   requires_amd  - refuse to run against non-AMD targets
#   git_url + install_path - clone git_url into install_path (idempotent)
#   custom_script - instead of a clone, a shell command run verbatim over SSH;
#                   it must print REPAIRED on success
INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
    "hipkittens": {
        "description": "HipKittens - AMD port of ThunderKittens for MI300X",
        "git_url": "https://github.com/HazyResearch/hipkittens.git",
        "install_path": "/opt/hipkittens",
        "requires_amd": True,
    },
    # CK is already installed with ROCm 7.0, no action needed
    "repair-headers": {
        "description": "Repair ROCm thrust headers (fixes hipify corruption)",
        # Only stdout is discarded: stderr is kept so the caller can show why
        # apt-get failed. DEBIAN_FRONTEND=noninteractive keeps debconf from
        # trying to prompt over a non-interactive SSH session.
        "custom_script": (
            "export DEBIAN_FRONTEND=noninteractive && "
            "apt-get update -qq && "
            "apt-get install --reinstall -y rocthrust >/dev/null && "
            "echo REPAIRED"
        ),
        "requires_amd": True,
    },
}
|
|
4106
|
+
|
|
4107
|
+
|
|
4108
|
+
@targets_app.command("install")
def targets_install(
    name: str = typer.Argument(..., help="Target name"),
    library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
) -> None:
    """Install a library or run maintenance on a target (idempotent).

    Installs header-only libraries like HipKittens on remote targets.
    Safe to run multiple times - will skip if already installed.

    Available libraries:
        hipkittens      - HipKittens (AMD ThunderKittens port)
        repair-headers  - Fix ROCm thrust headers (after hipify corruption)

    Examples:
        wafer config targets install runpod-mi300x hipkittens
        wafer config targets install runpod-mi300x repair-headers
        wafer config targets install do-mi300x hipkittens
    """
    import subprocess

    from .targets import load_target
    from .targets_ops import get_target_ssh_info

    # Upper bound (seconds) for the remote clone / apt-get command.
    ssh_timeout_s = 120

    if library not in INSTALLABLE_LIBRARIES:
        available = ", ".join(INSTALLABLE_LIBRARIES.keys())
        typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
        raise typer.Exit(1)

    lib_info = INSTALLABLE_LIBRARIES[library]

    try:
        target = load_target(name)
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    # Check if target is AMD (for AMD-only libraries)
    if lib_info.get("requires_amd"):
        from wafer_core.utils.kernel_utils.targets.config import (
            DigitalOceanTarget,
            RunPodTarget,
        )

        # RunPod and DigitalOcean targets are accepted as AMD outright.
        is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
        if not is_amd:
            # Fall back to the compute capability. MI300X (gfx942) reports "9.4",
            # but NVIDIA Hopper (sm_90) reports "9.0", so a bare "9." prefix test
            # would wrongly accept H100 targets.
            capability = str(getattr(target, "compute_capability", "") or "")
            is_amd = capability.startswith("9.") and not capability.startswith("9.0")
        if not is_amd:
            typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
            raise typer.Exit(1)

    typer.echo(f"Installing {library} on {name}...")
    typer.echo(f" {lib_info['description']}")

    async def _install() -> bool:
        """Run the install script on the target; return True if it reported success."""
        # get_target_ssh_info uses pure trio async (no asyncio bridging needed)
        # and we use subprocess for SSH, not AsyncSSHClient
        ssh_info = await get_target_ssh_info(target)

        ssh_cmd = [
            "ssh",
            "-o",
            "StrictHostKeyChecking=no",
            "-o",
            "UserKnownHostsFile=/dev/null",
            "-o",
            "ConnectTimeout=30",
            "-i",
            str(ssh_info.key_path),
            "-p",
            str(ssh_info.port),
            f"{ssh_info.user}@{ssh_info.host}",
        ]

        # Handle custom scripts (like repair-headers) vs git installs
        if "custom_script" in lib_info:
            install_script = str(lib_info["custom_script"])
            success_marker = "REPAIRED"
        else:
            install_path = lib_info["install_path"]
            git_url = lib_info["git_url"]

            # Idempotent install script. `set -e` makes a failed clone exit
            # non-zero instead of falling through to the DONE marker; updating
            # an existing checkout stays best-effort via `|| true`.
            install_script = f"""
set -e
if [ -d "{install_path}" ]; then
    echo "ALREADY_INSTALLED: {install_path} exists"
    cd {install_path} && git pull --quiet 2>/dev/null || true
else
    echo "INSTALLING: cloning to {install_path}"
    git clone --quiet {git_url} {install_path}
fi
echo "DONE"
"""
            success_marker = "DONE"

        def run_ssh() -> subprocess.CompletedProcess[str]:
            return subprocess.run(
                ssh_cmd + [install_script],
                capture_output=True,
                text=True,
                timeout=ssh_timeout_s,
            )

        try:
            result = await trio.to_thread.run_sync(run_ssh)
        except subprocess.TimeoutExpired:
            # The default exception text would dump the whole ssh command line.
            typer.echo(f"Error: remote command timed out after {ssh_timeout_s}s", err=True)
            return False

        if result.returncode != 0:
            # Prefer stderr, but never report an empty error.
            detail = (
                result.stderr.strip()
                or result.stdout.strip()
                or f"exit code {result.returncode}"
            )
            typer.echo(f"Error: {detail}", err=True)
            return False

        output = result.stdout.strip()
        if "ALREADY_INSTALLED" in output:
            typer.echo(f" Already installed at {lib_info.get('install_path', 'N/A')}")
        elif "INSTALLING" in output:
            typer.echo(f" Installed to {lib_info.get('install_path', 'N/A')}")
        elif "REPAIRED" in output:
            typer.echo(" ROCm headers repaired")

        return success_marker in output

    try:
        success = trio.run(_install)
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1) from None

    if success:
        typer.echo(f"✓ {library} ready on {name}")

        # Print usage hint, derived from the configured install location
        if library == "hipkittens":
            include_dir = f"{lib_info['install_path']}/include"
            typer.echo("")
            typer.echo("Usage in load_inline:")
            typer.echo(f' extra_include_paths=["{include_dir}", "/opt/rocm/include/hip"]')
    else:
        typer.echo(f"Failed to install {library}", err=True)
        raise typer.Exit(1)
|
|
4245
|
+
|
|
4246
|
+
|
|
4087
4247
|
@targets_app.command("pods")
|
|
4088
4248
|
def targets_pods() -> None:
|
|
4089
4249
|
"""List all running RunPod pods.
|
wafer/corpus.py
CHANGED
|
@@ -3,10 +3,12 @@
|
|
|
3
3
|
Download and manage documentation corpora for agent filesystem access.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import re
|
|
6
7
|
import shutil
|
|
7
8
|
import tarfile
|
|
8
9
|
import tempfile
|
|
9
10
|
from dataclasses import dataclass
|
|
11
|
+
from html.parser import HTMLParser
|
|
10
12
|
from pathlib import Path
|
|
11
13
|
from typing import Literal
|
|
12
14
|
from urllib.parse import urlparse
|
|
@@ -33,7 +35,7 @@ class CorpusConfig:
|
|
|
33
35
|
|
|
34
36
|
name: CorpusName
|
|
35
37
|
description: str
|
|
36
|
-
source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
|
|
38
|
+
source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
|
|
37
39
|
urls: list[str] | None = None
|
|
38
40
|
repo: str | None = None
|
|
39
41
|
repo_paths: list[str] | None = None
|
|
@@ -67,10 +69,43 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
|
|
|
67
69
|
),
|
|
68
70
|
"cutlass": CorpusConfig(
|
|
69
71
|
name="cutlass",
|
|
70
|
-
description="CUTLASS
|
|
71
|
-
source_type="
|
|
72
|
-
|
|
73
|
-
|
|
72
|
+
description="CUTLASS C++ documentation, examples, and tutorials",
|
|
73
|
+
source_type="mixed",
|
|
74
|
+
# Official NVIDIA CUTLASS documentation (scraped as markdown)
|
|
75
|
+
urls=[
|
|
76
|
+
"https://docs.nvidia.com/cutlass/latest/overview.html",
|
|
77
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
|
|
78
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
|
|
79
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
|
|
80
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
|
|
81
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
|
|
82
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
|
|
83
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
|
|
84
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
|
|
85
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
|
|
86
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
|
|
87
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
|
|
88
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
|
|
89
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
|
|
90
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
|
|
91
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
|
|
92
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
|
|
93
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
|
|
94
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
|
|
95
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
|
|
96
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
|
|
97
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
|
|
98
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
|
|
99
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
|
|
100
|
+
],
|
|
101
|
+
# NVIDIA/cutlass GitHub examples (excluding python/)
|
|
102
|
+
repos=[
|
|
103
|
+
RepoSource(
|
|
104
|
+
repo="NVIDIA/cutlass",
|
|
105
|
+
paths=["examples"],
|
|
106
|
+
branch="main",
|
|
107
|
+
),
|
|
108
|
+
],
|
|
74
109
|
),
|
|
75
110
|
"hip": CorpusConfig(
|
|
76
111
|
name="hip",
|
|
@@ -169,19 +204,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
|
|
|
169
204
|
return base_dir / "/".join(path_parts)
|
|
170
205
|
|
|
171
206
|
|
|
207
|
+
class _HTMLToMarkdown(HTMLParser):
|
|
208
|
+
"""HTML to Markdown converter for NVIDIA documentation pages.
|
|
209
|
+
|
|
210
|
+
Uses stdlib HTMLParser - requires subclassing due to callback-based API.
|
|
211
|
+
The public interface is the functional `_html_to_markdown()` below.
|
|
212
|
+
"""
|
|
213
|
+
|
|
214
|
+
def __init__(self) -> None:
|
|
215
|
+
super().__init__()
|
|
216
|
+
self.output: list[str] = []
|
|
217
|
+
self.current_tag: str = ""
|
|
218
|
+
self.in_code_block = False
|
|
219
|
+
self.in_pre = False
|
|
220
|
+
self.list_depth = 0
|
|
221
|
+
self.ordered_list_counters: list[int] = []
|
|
222
|
+
self.skip_content = False
|
|
223
|
+
self.link_href: str | None = None
|
|
224
|
+
|
|
225
|
+
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
|
226
|
+
self.current_tag = tag
|
|
227
|
+
attrs_dict = dict(attrs)
|
|
228
|
+
|
|
229
|
+
# Skip script, style, nav, footer, header
|
|
230
|
+
if tag in ("script", "style", "nav", "footer", "header", "aside"):
|
|
231
|
+
self.skip_content = True
|
|
232
|
+
return
|
|
233
|
+
|
|
234
|
+
if tag == "h1":
|
|
235
|
+
self.output.append("\n# ")
|
|
236
|
+
elif tag == "h2":
|
|
237
|
+
self.output.append("\n## ")
|
|
238
|
+
elif tag == "h3":
|
|
239
|
+
self.output.append("\n### ")
|
|
240
|
+
elif tag == "h4":
|
|
241
|
+
self.output.append("\n#### ")
|
|
242
|
+
elif tag == "h5":
|
|
243
|
+
self.output.append("\n##### ")
|
|
244
|
+
elif tag == "h6":
|
|
245
|
+
self.output.append("\n###### ")
|
|
246
|
+
elif tag == "p":
|
|
247
|
+
self.output.append("\n\n")
|
|
248
|
+
elif tag == "br":
|
|
249
|
+
self.output.append("\n")
|
|
250
|
+
elif tag == "strong" or tag == "b":
|
|
251
|
+
self.output.append("**")
|
|
252
|
+
elif tag == "em" or tag == "i":
|
|
253
|
+
self.output.append("*")
|
|
254
|
+
elif tag == "code" and not self.in_pre:
|
|
255
|
+
self.output.append("`")
|
|
256
|
+
self.in_code_block = True
|
|
257
|
+
elif tag == "pre":
|
|
258
|
+
self.in_pre = True
|
|
259
|
+
# Check for language hint in class
|
|
260
|
+
lang = ""
|
|
261
|
+
if class_attr := attrs_dict.get("class"):
|
|
262
|
+
if "python" in class_attr.lower():
|
|
263
|
+
lang = "python"
|
|
264
|
+
elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
|
|
265
|
+
lang = "cpp"
|
|
266
|
+
elif "cuda" in class_attr.lower():
|
|
267
|
+
lang = "cuda"
|
|
268
|
+
self.output.append(f"\n```{lang}\n")
|
|
269
|
+
elif tag == "ul":
|
|
270
|
+
self.list_depth += 1
|
|
271
|
+
self.output.append("\n")
|
|
272
|
+
elif tag == "ol":
|
|
273
|
+
self.list_depth += 1
|
|
274
|
+
self.ordered_list_counters.append(1)
|
|
275
|
+
self.output.append("\n")
|
|
276
|
+
elif tag == "li":
|
|
277
|
+
indent = " " * (self.list_depth - 1)
|
|
278
|
+
if self.ordered_list_counters:
|
|
279
|
+
num = self.ordered_list_counters[-1]
|
|
280
|
+
self.output.append(f"{indent}{num}. ")
|
|
281
|
+
self.ordered_list_counters[-1] += 1
|
|
282
|
+
else:
|
|
283
|
+
self.output.append(f"{indent}- ")
|
|
284
|
+
elif tag == "a":
|
|
285
|
+
self.link_href = attrs_dict.get("href")
|
|
286
|
+
self.output.append("[")
|
|
287
|
+
elif tag == "img":
|
|
288
|
+
alt = attrs_dict.get("alt", "image")
|
|
289
|
+
src = attrs_dict.get("src", "")
|
|
290
|
+
self.output.append(f"")
|
|
291
|
+
elif tag == "blockquote":
|
|
292
|
+
self.output.append("\n> ")
|
|
293
|
+
elif tag == "hr":
|
|
294
|
+
self.output.append("\n---\n")
|
|
295
|
+
elif tag == "table":
|
|
296
|
+
self.output.append("\n")
|
|
297
|
+
elif tag == "th":
|
|
298
|
+
self.output.append("| ")
|
|
299
|
+
elif tag == "td":
|
|
300
|
+
self.output.append("| ")
|
|
301
|
+
elif tag == "tr":
|
|
302
|
+
pass # Handled in endtag
|
|
303
|
+
|
|
304
|
+
def handle_endtag(self, tag: str) -> None:
|
|
305
|
+
if tag in ("script", "style", "nav", "footer", "header", "aside"):
|
|
306
|
+
self.skip_content = False
|
|
307
|
+
return
|
|
308
|
+
|
|
309
|
+
if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
|
|
310
|
+
self.output.append("\n")
|
|
311
|
+
elif tag == "strong" or tag == "b":
|
|
312
|
+
self.output.append("**")
|
|
313
|
+
elif tag == "em" or tag == "i":
|
|
314
|
+
self.output.append("*")
|
|
315
|
+
elif tag == "code" and not self.in_pre:
|
|
316
|
+
self.output.append("`")
|
|
317
|
+
self.in_code_block = False
|
|
318
|
+
elif tag == "pre":
|
|
319
|
+
self.in_pre = False
|
|
320
|
+
self.output.append("\n```\n")
|
|
321
|
+
elif tag == "ul":
|
|
322
|
+
self.list_depth = max(0, self.list_depth - 1)
|
|
323
|
+
elif tag == "ol":
|
|
324
|
+
self.list_depth = max(0, self.list_depth - 1)
|
|
325
|
+
if self.ordered_list_counters:
|
|
326
|
+
self.ordered_list_counters.pop()
|
|
327
|
+
elif tag == "li":
|
|
328
|
+
self.output.append("\n")
|
|
329
|
+
elif tag == "a":
|
|
330
|
+
if self.link_href:
|
|
331
|
+
self.output.append(f"]({self.link_href})")
|
|
332
|
+
else:
|
|
333
|
+
self.output.append("]")
|
|
334
|
+
self.link_href = None
|
|
335
|
+
elif tag == "p":
|
|
336
|
+
self.output.append("\n")
|
|
337
|
+
elif tag == "blockquote":
|
|
338
|
+
self.output.append("\n")
|
|
339
|
+
elif tag == "tr":
|
|
340
|
+
self.output.append("|\n")
|
|
341
|
+
elif tag == "thead":
|
|
342
|
+
# Add markdown table separator after header row
|
|
343
|
+
self.output.append("|---" * 10 + "|\n")
|
|
344
|
+
|
|
345
|
+
def handle_data(self, data: str) -> None:
|
|
346
|
+
if self.skip_content:
|
|
347
|
+
return
|
|
348
|
+
# Preserve whitespace in code blocks
|
|
349
|
+
if self.in_pre:
|
|
350
|
+
self.output.append(data)
|
|
351
|
+
else:
|
|
352
|
+
# Collapse whitespace outside code
|
|
353
|
+
text = re.sub(r"\s+", " ", data)
|
|
354
|
+
if text.strip():
|
|
355
|
+
self.output.append(text)
|
|
356
|
+
|
|
357
|
+
def get_markdown(self) -> str:
|
|
358
|
+
"""Get the converted markdown, cleaned up."""
|
|
359
|
+
md = "".join(self.output)
|
|
360
|
+
# Clean up excessive newlines
|
|
361
|
+
md = re.sub(r"\n{3,}", "\n\n", md)
|
|
362
|
+
# Clean up empty table separators
|
|
363
|
+
md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
|
|
364
|
+
return md.strip()
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _html_to_markdown(html: str) -> str:
    """Convert an HTML document to Markdown.

    Args:
        html: HTML page text.

    Returns:
        The cleaned-up Markdown produced by `_HTMLToMarkdown`.
    """
    parser = _HTMLToMarkdown()
    parser.feed(html)
    # close() flushes text HTMLParser may still be buffering (e.g. a trailing
    # run containing an unterminated "&"), which feed() alone would drop.
    parser.close()
    return parser.get_markdown()
|
|
372
|
+
|
|
373
|
+
|
|
172
374
|
def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
|
|
173
|
-
"""Download NVIDIA docs
|
|
375
|
+
"""Download NVIDIA docs and convert HTML to Markdown.
|
|
376
|
+
|
|
377
|
+
NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
|
|
378
|
+
"""
|
|
174
379
|
assert config.urls is not None
|
|
175
380
|
downloaded = 0
|
|
176
381
|
with httpx.Client(timeout=30.0, follow_redirects=True) as client:
|
|
177
382
|
for url in config.urls:
|
|
178
|
-
md_url = f"{url}.md"
|
|
179
383
|
filepath = _url_to_filepath(url, dest)
|
|
180
384
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
181
385
|
try:
|
|
182
|
-
|
|
386
|
+
# Fetch HTML page directly
|
|
387
|
+
resp = client.get(url)
|
|
183
388
|
resp.raise_for_status()
|
|
184
|
-
|
|
389
|
+
|
|
390
|
+
# Convert HTML to Markdown
|
|
391
|
+
markdown = _html_to_markdown(resp.text)
|
|
392
|
+
|
|
393
|
+
# Add source URL as header
|
|
394
|
+
content = f"<!-- Source: {url} -->\n\n{markdown}"
|
|
395
|
+
filepath.write_text(content)
|
|
185
396
|
downloaded += 1
|
|
186
397
|
if verbose:
|
|
187
398
|
print(f" ✓ {filepath.relative_to(dest)}")
|
|
@@ -275,6 +486,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
|
|
|
275
486
|
return downloaded
|
|
276
487
|
|
|
277
488
|
|
|
489
|
+
def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
    """Download a corpus assembled from several source kinds.

    The NVIDIA documentation pages in ``config.urls`` are fetched first, then
    the GitHub sources in ``config.repos``. Either part is skipped when its
    field is empty or None.

    Returns:
        The sum of the counts returned by the two downloaders.
    """
    # (config attribute, progress banner, downloader), in download order
    sources = (
        ("urls", " [NVIDIA docs]", _download_nvidia_md),
        ("repos", " [GitHub repos]", _download_github_multi_repo),
    )
    downloaded_total = 0
    for field_name, banner, downloader in sources:
        if not getattr(config, field_name):
            continue
        if verbose:
            print(banner)
        downloaded_total += downloader(config, dest, verbose)
    return downloaded_total
|
|
506
|
+
|
|
507
|
+
|
|
278
508
|
def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
|
|
279
509
|
"""Download a corpus to local cache.
|
|
280
510
|
|
|
@@ -311,6 +541,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
|
|
|
311
541
|
count = _download_github_repo(config, dest, verbose)
|
|
312
542
|
elif config.source_type == "github_multi_repo":
|
|
313
543
|
count = _download_github_multi_repo(config, dest, verbose)
|
|
544
|
+
elif config.source_type == "mixed":
|
|
545
|
+
count = _download_mixed(config, dest, verbose)
|
|
314
546
|
else:
|
|
315
547
|
raise ValueError(f"Unknown source type: {config.source_type}")
|
|
316
548
|
if verbose:
|
wafer/evaluate.py
CHANGED
|
@@ -1168,11 +1168,16 @@ def _build_modal_sandbox_script(
|
|
|
1168
1168
|
"""
|
|
1169
1169
|
gpu_type = target.gpu_type
|
|
1170
1170
|
|
|
1171
|
-
# Determine PyTorch index based on GPU type
|
|
1171
|
+
# Determine PyTorch index and CUDA arch based on GPU type
|
|
1172
1172
|
if gpu_type in ("B200", "GB200"):
|
|
1173
|
-
torch_index = "https://download.pytorch.org/whl/
|
|
1173
|
+
torch_index = "https://download.pytorch.org/whl/cu130"
|
|
1174
|
+
cuda_arch_list = "10.0" # Blackwell (sm_100)
|
|
1175
|
+
elif gpu_type == "H100":
|
|
1176
|
+
torch_index = "https://download.pytorch.org/whl/cu130"
|
|
1177
|
+
cuda_arch_list = "9.0" # Hopper (sm_90)
|
|
1174
1178
|
else:
|
|
1175
1179
|
torch_index = "https://download.pytorch.org/whl/cu124"
|
|
1180
|
+
cuda_arch_list = "8.0" # Default to Ampere (sm_80)
|
|
1176
1181
|
|
|
1177
1182
|
return f'''
|
|
1178
1183
|
import asyncio
|
|
@@ -1190,7 +1195,7 @@ async def run_eval():
|
|
|
1190
1195
|
"nvidia/cuda:12.9.0-devel-ubuntu22.04",
|
|
1191
1196
|
add_python="3.12",
|
|
1192
1197
|
)
|
|
1193
|
-
.apt_install("git", "build-essential", "cmake")
|
|
1198
|
+
.apt_install("git", "build-essential", "cmake", "ripgrep")
|
|
1194
1199
|
.pip_install(
|
|
1195
1200
|
"torch",
|
|
1196
1201
|
index_url="{torch_index}",
|
|
@@ -1203,6 +1208,12 @@ async def run_eval():
|
|
|
1203
1208
|
)
|
|
1204
1209
|
.env({{
|
|
1205
1210
|
"CUDA_HOME": "/usr/local/cuda",
|
|
1211
|
+
# C++ compiler needs explicit include path for cuda_runtime.h
|
|
1212
|
+
"CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
|
|
1213
|
+
# Linker needs lib path
|
|
1214
|
+
"LIBRARY_PATH": "/usr/local/cuda/lib64",
|
|
1215
|
+
# Force PyTorch to compile for correct GPU architecture
|
|
1216
|
+
"TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
|
|
1206
1217
|
}})
|
|
1207
1218
|
)
|
|
1208
1219
|
|
|
@@ -2790,6 +2801,15 @@ if torch.cuda.is_available():
|
|
|
2790
2801
|
gc.collect()
|
|
2791
2802
|
torch.cuda.empty_cache()
|
|
2792
2803
|
torch.cuda.reset_peak_memory_stats()
|
|
2804
|
+
|
|
2805
|
+
# Enable TF32 for fair benchmarking against reference kernels.
|
|
2806
|
+
# PyTorch 1.12+ disables TF32 for matmul by default, which handicaps
|
|
2807
|
+
# reference kernels using cuBLAS. We enable it so reference kernels
|
|
2808
|
+
# run at their best performance (using tensor cores when applicable).
|
|
2809
|
+
# This ensures speedup comparisons are against optimized baselines.
|
|
2810
|
+
torch.backends.cuda.matmul.allow_tf32 = True
|
|
2811
|
+
torch.backends.cudnn.allow_tf32 = True
|
|
2812
|
+
print("[KernelBench] TF32 enabled for fair benchmarking")
|
|
2793
2813
|
|
|
2794
2814
|
|
|
2795
2815
|
def _calculate_timing_stats(times: list[float]) -> dict:
|
|
@@ -3453,6 +3473,368 @@ def _validate_kernelbench_files(args: KernelBenchEvaluateArgs) -> str | None:
|
|
|
3453
3473
|
return None
|
|
3454
3474
|
|
|
3455
3475
|
|
|
3476
|
+
def _build_modal_kernelbench_script(
    target: ModalTarget,
    impl_code_b64: str,
    ref_code_b64: str,
    eval_script_b64: str,
    run_benchmarks: bool,
    run_defensive: bool,
    defense_code_b64: str | None,
    seed: int,
    inputs_code_b64: str | None = None,
) -> str:
    """Build Python script to create Modal sandbox and run KernelBench evaluation.

    This runs in a subprocess to isolate Modal's asyncio from trio.

    Args:
        target: Modal target. ``gpu_type`` selects the PyTorch wheel index and
            ``TORCH_CUDA_ARCH_LIST`` and is passed as the sandbox GPU;
            ``timeout_seconds`` becomes the sandbox timeout.
        impl_code_b64: Base64 source written to /workspace/implementation.py.
        ref_code_b64: Base64 source written to /workspace/reference.py.
        eval_script_b64: Base64 source written to /workspace/kernelbench_eval.py.
        run_benchmarks: Adds ``--benchmark`` to the eval command.
        run_defensive: Only together with ``defense_code_b64``: writes
            /workspace/defense.py and adds ``--defensive --defense-module``.
        defense_code_b64: Base64 defense module; ignored unless run_defensive.
        seed: Passed to the evaluator as ``--seed``.
        inputs_code_b64: Optional base64 source written to
            /workspace/custom_inputs.py and passed as ``--inputs``.

    Returns:
        Source of a standalone script. On success it prints a line starting
        with ``EVAL_RESULT_JSON:`` followed by the contents of results.json;
        on failure it prints a JSON object with ``"success": False`` and an
        ``"error"`` message. The sandbox is terminated in a ``finally``.
    """
    gpu_type = target.gpu_type

    # Determine PyTorch index and CUDA arch based on GPU type
    if gpu_type in ("B200", "GB200"):
        torch_index = "https://download.pytorch.org/whl/cu130"
        cuda_arch_list = "10.0"  # Blackwell (sm_100)
    elif gpu_type == "H100":
        # H100 uses CUDA 13.0 (matches modal_app.py)
        torch_index = "https://download.pytorch.org/whl/cu130"
        cuda_arch_list = "9.0"  # Hopper (sm_90)
    else:
        torch_index = "https://download.pytorch.org/whl/cu124"
        cuda_arch_list = "8.0"  # Default to Ampere (sm_80)

    # Install CUTLASS headers (for cute/tensor.hpp and cutlass/util/*.h) from GitHub
    # The nvidia-cutlass-dsl pip package doesn't include the C++ headers needed for nvcc
    # IMPORTANT: symlink to /usr/local/cuda/include because nvcc searches there by default
    # This plain (non-f) string is spliced verbatim into the generated script's
    # image-builder chain, so it must remain a sequence of `.method(...)` calls.
    cutlass_install = '''
        .run_commands([
            # Clone CUTLASS headers from GitHub (shallow clone, full include tree)
            # Use simple shallow clone - sparse-checkout can be buggy in some environments
            "git clone --depth 1 https://github.com/NVIDIA/cutlass.git /opt/cutlass",
            # Verify the util headers exist (for debugging)
            "ls -la /opt/cutlass/include/cutlass/util/ | head -5",
            # Symlink headers to CUDA include path (nvcc searches here by default)
            "ln -sf /opt/cutlass/include/cute /usr/local/cuda/include/cute",
            "ln -sf /opt/cutlass/include/cutlass /usr/local/cuda/include/cutlass",
        ])
        .pip_install(
            "nvidia-cutlass-dsl",
            index_url="https://pypi.nvidia.com",
            extra_index_url="https://pypi.org/simple",
        )
'''

    # Optional snippet spliced into the generated `try:` block; its statement
    # lines carry that block's indentation. Its proc.returncode is not checked.
    inputs_write = ""
    if inputs_code_b64:
        inputs_write = f'''
        # Write custom inputs
        proc = sandbox.exec("python", "-c", f"""
import base64
with open('/workspace/custom_inputs.py', 'w') as f:
    f.write(base64.b64decode('{inputs_code_b64}').decode())
print('Custom inputs written')
""")
        proc.wait()
'''

    # Same splicing as inputs_write; emitted only when defensive mode has a module.
    defense_write = ""
    if run_defensive and defense_code_b64:
        defense_write = f'''
        # Write defense module
        proc = sandbox.exec("python", "-c", f"""
import base64
with open('/workspace/defense.py', 'w') as f:
    f.write(base64.b64decode('{defense_code_b64}').decode())
print('Defense module written')
""")
        proc.wait()
'''

    # Build eval command
    # The joined command is embedded inside double quotes in the generated
    # script, so none of these parts may contain a double quote.
    eval_cmd_parts = [
        "python /workspace/kernelbench_eval.py",
        "--impl /workspace/implementation.py",
        "--reference /workspace/reference.py",
        "--output /workspace/results.json",
        f"--seed {seed}",
    ]
    if run_benchmarks:
        eval_cmd_parts.append("--benchmark")
    if run_defensive and defense_code_b64:
        eval_cmd_parts.append("--defensive")
        eval_cmd_parts.append("--defense-module /workspace/defense.py")
    if inputs_code_b64:
        eval_cmd_parts.append("--inputs /workspace/custom_inputs.py")

    eval_cmd = " ".join(eval_cmd_parts)

    # Template for the generated script: `{{`/`}}` become literal braces, and the
    # base64 payloads are substituted into the nested `python -c` programs.
    return f'''
import asyncio
import base64
import json
import sys
import modal

async def run_eval():
    app = modal.App.lookup("wafer-evaluate", create_if_missing=True)

    # Build image with PyTorch, CUTLASS DSL and dependencies
    image = (
        modal.Image.from_registry(
            "nvidia/cuda:12.9.0-devel-ubuntu22.04",
            add_python="3.12",
        )
        .apt_install("git", "build-essential", "cmake", "ninja-build", "ripgrep")
        .pip_install(
            "torch",
            index_url="{torch_index}",
            extra_index_url="https://pypi.org/simple",
        )
        .pip_install(
            "numpy",
            "triton",
            "ninja",
        )
        {cutlass_install}
        .env({{
            "CUDA_HOME": "/usr/local/cuda",
            # C++ compiler needs explicit include path for cuda_runtime.h
            "CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
            # Linker needs lib path
            "LIBRARY_PATH": "/usr/local/cuda/lib64",
            # Force PyTorch to compile for correct GPU architecture
            "TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
        }})
    )

    # Create sandbox
    sandbox = modal.Sandbox.create(
        app=app,
        image=image,
        gpu="{gpu_type}",
        timeout={target.timeout_seconds},
    )

    try:
        # Create workspace directory
        sandbox.exec("mkdir", "-p", "/workspace").wait()

        # Write files to sandbox
        proc = sandbox.exec("python", "-c", f"""
import base64
with open('/workspace/implementation.py', 'w') as f:
    f.write(base64.b64decode('{impl_code_b64}').decode())
with open('/workspace/reference.py', 'w') as f:
    f.write(base64.b64decode('{ref_code_b64}').decode())
with open('/workspace/kernelbench_eval.py', 'w') as f:
    f.write(base64.b64decode('{eval_script_b64}').decode())
print('Files written')
""")
        proc.wait()
        if proc.returncode != 0:
            print(json.dumps({{"success": False, "error": f"Failed to write files: {{proc.stderr.read()}}"}}))
            return
        {inputs_write}
        {defense_write}
        # Run evaluation
        print(f"Running KernelBench evaluation on {{'{gpu_type}'}}...")
        proc = sandbox.exec("bash", "-c", "{eval_cmd}")

        # Stream output
        for line in proc.stdout:
            print(line, end="")
        for line in proc.stderr:
            print(line, end="", file=sys.stderr)

        proc.wait()

        if proc.returncode != 0:
            print(json.dumps({{"success": False, "error": f"Evaluation failed with exit code {{proc.returncode}}"}}))
            return

        # Read results
        result_proc = sandbox.exec("cat", "/workspace/results.json")
        result_data = result_proc.stdout.read()
        result_proc.wait()

        if result_data:
            results = json.loads(result_data)
            print("EVAL_RESULT_JSON:" + json.dumps(results))
        else:
            print(json.dumps({{"success": False, "error": "No results.json found"}}))

    finally:
        sandbox.terminate()

asyncio.run(run_eval())
'''
|
|
3670
|
+
|
|
3671
|
+
|
|
3672
|
+
async def run_evaluate_kernelbench_modal(
    args: KernelBenchEvaluateArgs,
    target: ModalTarget,
) -> EvaluateResult:
    """Run KernelBench format evaluation on a Modal sandbox.

    Builds a standalone driver script with ``_build_modal_kernelbench_script``.
    The driver creates a Modal sandbox, uploads the implementation, reference
    and eval files, runs the KernelBench eval, and prints ``results.json`` on a
    line prefixed with ``EVAL_RESULT_JSON:``. The driver runs in a separate
    Python subprocess so Modal's asyncio event loop stays isolated from the
    caller's trio loop.

    Args:
        args: Evaluation arguments. Supplies the implementation and reference
            paths, optional custom inputs, benchmark/defensive flags and seed.
        target: Modal target. Its ``gpu_type`` and ``timeout_seconds`` are
            used here and passed through to the driver script.

    Returns:
        An ``EvaluateResult``. Every failure path returns ``success=False``
        with an ``error_message`` instead of raising. Failure paths include:
        subprocess timeout or launch error, non-zero driver exit, missing or
        malformed results, and an error reported by the sandbox.
    """
    import base64
    import subprocess
    import sys

    import trio

    # Marker the driver script prints in front of the final results JSON line.
    result_prefix = "EVAL_RESULT_JSON:"
    # Extra buffer on top of the sandbox timeout for sandbox creation + image build.
    subprocess_timeout = target.timeout_seconds + 120

    def _failure(message: str) -> EvaluateResult:
        """Build the zeroed-out result shared by every failure path."""
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=message,
        )

    def _reported_error(lines: list[str]) -> str | None:
        """Return the last unprefixed ``{"success": false, "error": ...}`` message.

        The driver reports early failures (file upload, non-zero eval exit,
        missing results.json) by printing such an object WITHOUT the result
        prefix and then returning normally. Without this lookup those errors
        would surface only as "No results found".
        """
        for line in reversed(lines):
            candidate = line.strip()
            if not candidate.startswith("{"):
                continue
            try:
                payload = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(payload, dict) and payload.get("success") is False and "error" in payload:
                return str(payload["error"])
        return None

    print(f"Creating Modal sandbox ({target.gpu_type}) for KernelBench evaluation...")

    # Encode files as base64 so they can be embedded in the generated script.
    impl_code_b64 = base64.b64encode(args.implementation.read_bytes()).decode()
    ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
    eval_script_b64 = base64.b64encode(KERNELBENCH_EVAL_SCRIPT.encode()).decode()

    # Encode custom inputs if provided
    inputs_code_b64 = None
    if args.inputs:
        inputs_code_b64 = base64.b64encode(args.inputs.read_bytes()).decode()

    # Encode defense module if defensive mode is enabled
    defense_code_b64 = None
    if args.defensive:
        # NOTE(review): this path is resolved relative to this file's location
        # in a source checkout; presumably it does not exist when the package
        # is installed from a wheel. Confirm.
        defense_path = (
            Path(__file__).parent.parent.parent.parent
            / "packages"
            / "wafer-core"
            / "wafer_core"
            / "utils"
            / "kernel_utils"
            / "defense.py"
        )
        if defense_path.exists():
            defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
        else:
            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")

    # Build the script
    script = _build_modal_kernelbench_script(
        target=target,
        impl_code_b64=impl_code_b64,
        ref_code_b64=ref_code_b64,
        eval_script_b64=eval_script_b64,
        run_benchmarks=args.benchmark,
        run_defensive=args.defensive,
        defense_code_b64=defense_code_b64,
        seed=args.seed,
        inputs_code_b64=inputs_code_b64,
    )

    def _run_subprocess() -> tuple[str, str, int]:
        result = subprocess.run(
            [sys.executable, "-c", script],
            capture_output=True,
            text=True,
            timeout=subprocess_timeout,
        )
        return result.stdout, result.stderr, result.returncode

    try:
        # Blocking subprocess call runs in a worker thread so the trio loop stays responsive.
        stdout, stderr, returncode = await trio.to_thread.run_sync(_run_subprocess)
    except subprocess.TimeoutExpired:
        return _failure(
            f"Modal KernelBench evaluation timed out after {subprocess_timeout}s "
            f"(sandbox timeout {target.timeout_seconds}s + setup buffer)"
        )
    except Exception as e:
        return _failure(f"Failed to run Modal sandbox: {e}")

    stdout_lines = stdout.splitlines() if stdout else []

    # Print output for debugging (the machine-readable result line is hidden).
    for line in stdout_lines:
        if not line.startswith(result_prefix):
            print(line)
    if stderr:
        print(stderr, file=sys.stderr)

    if returncode != 0:
        return _failure(f"Modal sandbox failed (exit {returncode}): {stderr or stdout}")

    # Parse results from the first prefixed stdout line.
    result_json = next(
        (line.removeprefix(result_prefix) for line in stdout_lines if line.startswith(result_prefix)),
        None,
    )
    if not result_json:
        return _failure(_reported_error(stdout_lines) or "No results found in Modal output")

    try:
        results = json.loads(result_json)
    except json.JSONDecodeError as e:
        return _failure(f"Failed to parse results JSON: {e}")

    if not isinstance(results, dict):
        return _failure(f"Unexpected results JSON (expected an object): {result_json[:200]}")

    # Check for error in results
    if "error" in results and results.get("success") is False:
        return _failure(results.get("error", "Unknown error"))

    # Extract metrics from results
    return EvaluateResult(
        success=True,
        all_correct=results.get("all_correct", False),
        correctness_score=float(results.get("correctness_score", 0.0)),
        geomean_speedup=float(results.get("geomean_speedup", 0.0)),
        passed_tests=int(results.get("passed_tests", 0)),
        total_tests=int(results.get("total_tests", 0)),
        error_message=results.get("error"),
        test_results=results.get("test_results", []),
        compilation_time_s=results.get("compilation_time_s"),
        profiling_stats=results.get("profiling_stats"),
    )
|
|
3836
|
+
|
|
3837
|
+
|
|
3456
3838
|
async def run_evaluate_kernelbench_docker(
|
|
3457
3839
|
args: KernelBenchEvaluateArgs,
|
|
3458
3840
|
target: BaremetalTarget | VMTarget,
|
|
@@ -4246,6 +4628,20 @@ async def run_evaluate_kernelbench_runpod(
|
|
|
4246
4628
|
)
|
|
4247
4629
|
|
|
4248
4630
|
|
|
4631
|
+
async def run_evaluate_kernelbench_baremetal_direct(
    args: KernelBenchEvaluateArgs,
    target: BaremetalTarget,
) -> EvaluateResult:
    """Evaluate a KernelBench submission on an NVIDIA host without Docker.

    Meant for targets whose environment already provides PyTorch and CUDA
    (e.g. workspace containers that set the ``direct`` flag). The GPU is
    selected through ``CUDA_VISIBLE_DEVICES``.

    This is a thin wrapper around ``_run_evaluate_kernelbench_baremetal_direct_impl``.
    The AMD path calls the same helper with ``HIP_VISIBLE_DEVICES`` instead; only
    the GPU-selection environment variable differs.
    """
    nvidia_gpu_env_var = "CUDA_VISIBLE_DEVICES"
    return await _run_evaluate_kernelbench_baremetal_direct_impl(
        args,
        target,
        gpu_env_var=nvidia_gpu_env_var,
    )
|
|
4643
|
+
|
|
4644
|
+
|
|
4249
4645
|
async def run_evaluate_kernelbench_baremetal_amd(
|
|
4250
4646
|
args: KernelBenchEvaluateArgs,
|
|
4251
4647
|
target: BaremetalTarget,
|
|
@@ -4255,6 +4651,18 @@ async def run_evaluate_kernelbench_baremetal_amd(
|
|
|
4255
4651
|
Runs evaluation script directly on host (no Docker) for AMD GPUs
|
|
4256
4652
|
that have PyTorch/ROCm installed.
|
|
4257
4653
|
"""
|
|
4654
|
+
return await _run_evaluate_kernelbench_baremetal_direct_impl(args, target, gpu_env_var="HIP_VISIBLE_DEVICES")
|
|
4655
|
+
|
|
4656
|
+
|
|
4657
|
+
async def _run_evaluate_kernelbench_baremetal_direct_impl(
|
|
4658
|
+
args: KernelBenchEvaluateArgs,
|
|
4659
|
+
target: BaremetalTarget,
|
|
4660
|
+
gpu_env_var: str = "HIP_VISIBLE_DEVICES",
|
|
4661
|
+
) -> EvaluateResult:
|
|
4662
|
+
"""Internal implementation for direct baremetal evaluation.
|
|
4663
|
+
|
|
4664
|
+
Runs evaluation script directly on host (no Docker).
|
|
4665
|
+
"""
|
|
4258
4666
|
from datetime import datetime
|
|
4259
4667
|
|
|
4260
4668
|
from wafer_core.async_ssh import AsyncSSHClient
|
|
@@ -4405,11 +4813,15 @@ async def run_evaluate_kernelbench_baremetal_amd(
|
|
|
4405
4813
|
|
|
4406
4814
|
eval_cmd = " ".join(python_cmd_parts)
|
|
4407
4815
|
|
|
4408
|
-
# Set environment for
|
|
4409
|
-
|
|
4410
|
-
|
|
4411
|
-
|
|
4412
|
-
|
|
4816
|
+
# Set environment for GPU and run
|
|
4817
|
+
if gpu_env_var == "HIP_VISIBLE_DEVICES":
|
|
4818
|
+
# AMD: PYTORCH_ROCM_ARCH for faster compile
|
|
4819
|
+
rocm_arch = _get_rocm_arch(target.compute_capability)
|
|
4820
|
+
arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
|
|
4821
|
+
env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
|
|
4822
|
+
else:
|
|
4823
|
+
# NVIDIA: just set CUDA_VISIBLE_DEVICES
|
|
4824
|
+
env_vars = f"CUDA_VISIBLE_DEVICES={gpu_id} PYTHONUNBUFFERED=1"
|
|
4413
4825
|
full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
|
|
4414
4826
|
|
|
4415
4827
|
# Handle prepare-only mode
|
|
@@ -4560,10 +4972,16 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
|
|
|
4560
4972
|
elif isinstance(target, RunPodTarget):
|
|
4561
4973
|
# RunPod AMD MI300X - uses ROCm Docker with device passthrough
|
|
4562
4974
|
return await run_evaluate_kernelbench_runpod(args, target)
|
|
4975
|
+
elif isinstance(target, ModalTarget):
|
|
4976
|
+
# Modal serverless - runs in Modal sandbox
|
|
4977
|
+
return await run_evaluate_kernelbench_modal(args, target)
|
|
4563
4978
|
elif isinstance(target, BaremetalTarget | VMTarget):
|
|
4564
4979
|
# Check if this is an AMD target (gfx* compute capability) - run directly
|
|
4565
4980
|
if target.compute_capability and target.compute_capability.startswith("gfx"):
|
|
4566
4981
|
return await run_evaluate_kernelbench_baremetal_amd(args, target)
|
|
4982
|
+
# Check for direct execution flag (workspace containers that already have everything)
|
|
4983
|
+
if getattr(target, "direct", False):
|
|
4984
|
+
return await run_evaluate_kernelbench_baremetal_direct(args, target)
|
|
4567
4985
|
# NVIDIA targets - require docker_image to be set
|
|
4568
4986
|
if not target.docker_image:
|
|
4569
4987
|
return EvaluateResult(
|
|
@@ -5,10 +5,10 @@ wafer/api_client.py,sha256=i_Az2b2llC3DSW8yOL-BKqa7LSKuxOr8hSN40s-oQXY,6313
|
|
|
5
5
|
wafer/auth.py,sha256=dwss_se5P-FFc9IN38q4kh_dBrA6k-CguDBkivgcdj0,14003
|
|
6
6
|
wafer/autotuner.py,sha256=41WYP41pTDvMijv2h42vm89bcHtDMJXObDlWmn6xpFU,44416
|
|
7
7
|
wafer/billing.py,sha256=jbLB2lI4_9f2KD8uEFDi_ixLlowe5hasC0TIZJyIXRg,7163
|
|
8
|
-
wafer/cli.py,sha256=
|
|
8
|
+
wafer/cli.py,sha256=j4ODOVT_r-kyc21YOI8Yl8bkiZMGuqDpXRs7CvpNaek,261443
|
|
9
9
|
wafer/config.py,sha256=h5Eo9_yfWqWGoPNdVQikI9GoZVUeysunSYiixf1mKcw,3411
|
|
10
|
-
wafer/corpus.py,sha256=
|
|
11
|
-
wafer/evaluate.py,sha256=
|
|
10
|
+
wafer/corpus.py,sha256=oQegXA43MuyRvYxOsWhmqeP5vMb5IKFHOvM-1RcahPA,22301
|
|
11
|
+
wafer/evaluate.py,sha256=SxxhiPkO6aDdfktRzJXpbWMVmIGn_gw-o5C6Zwj2zRc,190930
|
|
12
12
|
wafer/global_config.py,sha256=fhaR_RU3ufMksDmOohH1OLeQ0JT0SDW1hEip_zaP75k,11345
|
|
13
13
|
wafer/gpu_run.py,sha256=TwqXy72T7f2I7e6n5WWod3xgxCPnDhU0BgLsB4CUoQY,9716
|
|
14
14
|
wafer/inference.py,sha256=tZCO5i05FKY27ewis3CSBHFBeFbXY3xwj0DSjdoMY9s,4314
|
|
@@ -34,8 +34,8 @@ wafer/templates/ask_docs.py,sha256=Lxs-faz9v5m4Qa4NjF2X_lE8KwM9ES9MNJkxo7ep56o,2
|
|
|
34
34
|
wafer/templates/optimize_kernel.py,sha256=OvZgN5tm_OymO3lK8Dr0VO48e-5PfNVIIoACrPxpmqk,2446
|
|
35
35
|
wafer/templates/optimize_kernelbench.py,sha256=aoOA13zWEl89r6QW03xF9NKxQ7j4mWe9rwua6-mlr4Y,4780
|
|
36
36
|
wafer/templates/trace_analyze.py,sha256=XE1VqzVkIUsZbXF8EzQdDYgg-AZEYAOFpr6B_vnRELc,2880
|
|
37
|
-
wafer_cli-0.2.
|
|
38
|
-
wafer_cli-0.2.
|
|
39
|
-
wafer_cli-0.2.
|
|
40
|
-
wafer_cli-0.2.
|
|
41
|
-
wafer_cli-0.2.
|
|
37
|
+
wafer_cli-0.2.22.dist-info/METADATA,sha256=vjYzyQtphWxQ0JID0k5tFWoLwVjlR6X0B4UAuMhLhQc,560
|
|
38
|
+
wafer_cli-0.2.22.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
39
|
+
wafer_cli-0.2.22.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
|
|
40
|
+
wafer_cli-0.2.22.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
|
|
41
|
+
wafer_cli-0.2.22.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|