wafer-cli 0.2.21__tar.gz → 0.2.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/PKG-INFO +1 -1
  2. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/pyproject.toml +1 -1
  3. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/cli.py +163 -3
  4. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/corpus.py +241 -9
  5. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/evaluate.py +426 -8
  6. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer_cli.egg-info/PKG-INFO +1 -1
  7. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/README.md +0 -0
  8. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/setup.cfg +0 -0
  9. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_analytics.py +0 -0
  10. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_auth.py +0 -0
  11. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_billing.py +0 -0
  12. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_cli_coverage.py +0 -0
  13. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_cli_parity_integration.py +0 -0
  14. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_config_integration.py +0 -0
  15. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_file_operations_integration.py +0 -0
  16. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_kernel_scope_cli.py +0 -0
  17. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_nsys_analyze.py +0 -0
  18. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_nsys_profile.py +0 -0
  19. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_output.py +0 -0
  20. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_rocprof_compute_integration.py +0 -0
  21. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_skill_commands.py +0 -0
  22. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_ssh_integration.py +0 -0
  23. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_targets_ops.py +0 -0
  24. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_wevin_cli.py +0 -0
  25. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/tests/test_workflow_integration.py +0 -0
  26. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/GUIDE.md +0 -0
  27. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/__init__.py +0 -0
  28. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/analytics.py +0 -0
  29. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/api_client.py +0 -0
  30. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/auth.py +0 -0
  31. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/autotuner.py +0 -0
  32. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/billing.py +0 -0
  33. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/config.py +0 -0
  34. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/global_config.py +0 -0
  35. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/gpu_run.py +0 -0
  36. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/inference.py +0 -0
  37. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/kernel_scope.py +0 -0
  38. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/ncu_analyze.py +0 -0
  39. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/nsys_analyze.py +0 -0
  40. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/nsys_profile.py +0 -0
  41. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/output.py +0 -0
  42. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/problems.py +0 -0
  43. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/rocprof_compute.py +0 -0
  44. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/rocprof_sdk.py +0 -0
  45. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/rocprof_systems.py +0 -0
  46. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/skills/wafer-guide/SKILL.md +0 -0
  47. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/ssh_keys.py +0 -0
  48. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/target_lock.py +0 -0
  49. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/targets.py +0 -0
  50. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/targets_ops.py +0 -0
  51. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/templates/__init__.py +0 -0
  52. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/templates/ask_docs.py +0 -0
  53. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/templates/optimize_kernel.py +0 -0
  54. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/templates/optimize_kernelbench.py +0 -0
  55. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/templates/trace_analyze.py +0 -0
  56. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/tracelens.py +0 -0
  57. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/wevin_cli.py +0 -0
  58. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer/workspaces.py +0 -0
  59. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer_cli.egg-info/SOURCES.txt +0 -0
  60. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer_cli.egg-info/dependency_links.txt +0 -0
  61. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer_cli.egg-info/entry_points.txt +0 -0
  62. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer_cli.egg-info/requires.txt +0 -0
  63. {wafer_cli-0.2.21 → wafer_cli-0.2.23}/wafer_cli.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: wafer-cli
- Version: 0.2.21
+ Version: 0.2.23
  Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
  Requires-Python: >=3.11
  Requires-Dist: typer>=0.12.0
@@ -1,6 +1,6 @@
  [project]
  name = "wafer-cli"
- version = "0.2.21"
+ version = "0.2.23"
  description = "CLI tool for running commands on remote GPUs and GPU kernel optimization agent"
  requires-python = ">=3.11"
  dependencies = [
@@ -147,7 +147,9 @@ def main_callback(
  else:
  _command_outcome = "failure"
  # Print error summary FIRST (before traceback) so it's visible even if truncated
- print(f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr)
+ print(
+ f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
+ )
  # Call original excepthook (prints the full traceback)
  original_excepthook(exc_type, exc_value, exc_traceback)

@@ -3486,7 +3488,7 @@ def init_runpod(
  gpu_configs = {
  "MI300X": {
  "gpu_type_id": "AMD Instinct MI300X OAM",
- "image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
+ "image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
  "compute_capability": "9.4",
  },
  "H100": {
@@ -3582,7 +3584,7 @@ def init_digitalocean(
  "ssh_key": ssh_key,
  "region": region,
  "size_slug": "gpu-mi300x1-192gb-devcloud",
- "image": "gpu-amd-base",
+ "image": "amd-pytorchrocm7", # PyTorch (ROCm7) marketplace image
  "provision_timeout": 600,
  "eval_timeout": 600,
  "keep_alive": keep_alive,
@@ -4084,6 +4086,164 @@ def targets_cleanup(
  raise typer.Exit(1) from None


+ # Known libraries that can be installed on targets
+ # TODO: Consider adding HipKittens to the default RunPod/DO Docker images
+ # so this install step isn't needed. For now, this command handles it.
+ INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
+ "hipkittens": {
+ "description": "HipKittens - AMD port of ThunderKittens for MI300X",
+ "git_url": "https://github.com/HazyResearch/hipkittens.git",
+ "install_path": "/opt/hipkittens",
+ "requires_amd": True,
+ },
+ # CK is already installed with ROCm 7.0, no action needed
+ "repair-headers": {
+ "description": "Repair ROCm thrust headers (fixes hipify corruption)",
+ "custom_script": "apt-get update -qq && apt-get install --reinstall -y rocthrust >/dev/null 2>&1 && echo REPAIRED",
+ "requires_amd": True,
+ },
+ }
+
+
+ @targets_app.command("install")
+ def targets_install(
+ name: str = typer.Argument(..., help="Target name"),
+ library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
+ ) -> None:
+ """Install a library or run maintenance on a target (idempotent).
+
+ Installs header-only libraries like HipKittens on remote targets.
+ Safe to run multiple times - will skip if already installed.
+
+ Available libraries:
+ hipkittens - HipKittens (AMD ThunderKittens port)
+ repair-headers - Fix ROCm thrust headers (after hipify corruption)
+
+ Examples:
+ wafer config targets install runpod-mi300x hipkittens
+ wafer config targets install runpod-mi300x repair-headers
+ wafer config targets install do-mi300x hipkittens
+ """
+ import subprocess
+
+ from .targets import load_target
+ from .targets_ops import get_target_ssh_info
+
+ if library not in INSTALLABLE_LIBRARIES:
+ available = ", ".join(INSTALLABLE_LIBRARIES.keys())
+ typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
+ raise typer.Exit(1)
+
+ lib_info = INSTALLABLE_LIBRARIES[library]
+
+ try:
+ target = load_target(name)
+ except FileNotFoundError as e:
+ typer.echo(f"Error: {e}", err=True)
+ raise typer.Exit(1) from None
+
+ # Check if target is AMD (for AMD-only libraries)
+ if lib_info.get("requires_amd"):
+ from wafer_core.utils.kernel_utils.targets.config import (
+ DigitalOceanTarget,
+ RunPodTarget,
+ )
+
+ is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
+ if not is_amd and hasattr(target, "compute_capability"):
+ # Check compute capability for MI300X (gfx942 = 9.4)
+ is_amd = target.compute_capability.startswith("9.")
+ if not is_amd:
+ typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
+ raise typer.Exit(1)
+
+ typer.echo(f"Installing {library} on {name}...")
+ typer.echo(f" {lib_info['description']}")
+
+ async def _install() -> bool:
+ # get_target_ssh_info uses pure trio async (no asyncio bridging needed)
+ # and we use subprocess for SSH, not AsyncSSHClient
+ ssh_info = await get_target_ssh_info(target)
+
+ ssh_cmd = [
+ "ssh",
+ "-o",
+ "StrictHostKeyChecking=no",
+ "-o",
+ "UserKnownHostsFile=/dev/null",
+ "-o",
+ "ConnectTimeout=30",
+ "-i",
+ str(ssh_info.key_path),
+ "-p",
+ str(ssh_info.port),
+ f"{ssh_info.user}@{ssh_info.host}",
+ ]
+
+ # Handle custom scripts (like repair-headers) vs git installs
+ if "custom_script" in lib_info:
+ install_script = str(lib_info["custom_script"])
+ success_marker = "REPAIRED"
+ else:
+ install_path = lib_info["install_path"]
+ git_url = lib_info["git_url"]
+
+ # Idempotent install script
+ install_script = f"""
+ if [ -d "{install_path}" ]; then
+ echo "ALREADY_INSTALLED: {install_path} exists"
+ cd {install_path} && git pull --quiet 2>/dev/null || true
+ else
+ echo "INSTALLING: cloning to {install_path}"
+ git clone --quiet {git_url} {install_path}
+ fi
+ echo "DONE"
+ """
+ success_marker = "DONE"
+
+ def run_ssh() -> subprocess.CompletedProcess[str]:
+ return subprocess.run(
+ ssh_cmd + [install_script],
+ capture_output=True,
+ text=True,
+ timeout=120,
+ )
+
+ result = await trio.to_thread.run_sync(run_ssh)
+
+ if result.returncode != 0:
+ typer.echo(f"Error: {result.stderr}", err=True)
+ return False
+
+ output = result.stdout.strip()
+ if "ALREADY_INSTALLED" in output:
+ typer.echo(f" Already installed at {lib_info.get('install_path', 'N/A')}")
+ elif "INSTALLING" in output:
+ typer.echo(f" Installed to {lib_info.get('install_path', 'N/A')}")
+ elif "REPAIRED" in output:
+ typer.echo(" ROCm headers repaired")
+
+ return success_marker in output
+
+ try:
+ success = trio.run(_install)
+ except Exception as e:
+ typer.echo(f"Error: {e}", err=True)
+ raise typer.Exit(1) from None
+
+ if success:
+ typer.echo(f"✓ {library} ready on {name}")
+
+ # Print usage hint
+ if library == "hipkittens":
+ typer.echo("")
+ typer.echo("Usage in load_inline:")
+ typer.echo(' extra_include_paths=["/opt/hipkittens/include", "/opt/rocm/include/hip"]')
+ else:
+ typer.echo(f"Failed to install {library}", err=True)
+ raise typer.Exit(1)
+
+
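Note: the INSTALLABLE_LIBRARIES table above drives the new `wafer config targets install` command. As a minimal sketch of how a further git-based entry would fit the same shape, the name, URL, and path below are hypothetical and not part of this release:

    INSTALLABLE_LIBRARIES["example-lib"] = {
        # Hypothetical entry for illustration only -- not shipped in 0.2.23.
        "description": "Example header-only library cloned onto the target",
        "git_url": "https://github.com/example/example-lib.git",  # placeholder URL
        "install_path": "/opt/example-lib",
        "requires_amd": True,
    }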
  @targets_app.command("pods")
  def targets_pods() -> None:
  """List all running RunPod pods.
@@ -3,10 +3,12 @@
  Download and manage documentation corpora for agent filesystem access.
  """

+ import re
  import shutil
  import tarfile
  import tempfile
  from dataclasses import dataclass
+ from html.parser import HTMLParser
  from pathlib import Path
  from typing import Literal
  from urllib.parse import urlparse
@@ -33,7 +35,7 @@ class CorpusConfig:

  name: CorpusName
  description: str
- source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
+ source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
  urls: list[str] | None = None
  repo: str | None = None
  repo_paths: list[str] | None = None
@@ -67,10 +69,43 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
  ),
  "cutlass": CorpusConfig(
  name="cutlass",
- description="CUTLASS and CuTe DSL documentation",
- source_type="github_repo",
- repo="NVIDIA/cutlass",
- repo_paths=["media/docs", "python/cutlass/docs"],
+ description="CUTLASS C++ documentation, examples, and tutorials",
+ source_type="mixed",
+ # Official NVIDIA CUTLASS documentation (scraped as markdown)
+ urls=[
+ "https://docs.nvidia.com/cutlass/latest/overview.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
+ ],
+ # NVIDIA/cutlass GitHub examples (excluding python/)
+ repos=[
+ RepoSource(
+ repo="NVIDIA/cutlass",
+ paths=["examples"],
+ branch="main",
+ ),
+ ],
  ),
  "hip": CorpusConfig(
  name="hip",
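Note: with the corpus entry above switched to source_type="mixed", a single download now pulls both the scraped NVIDIA documentation pages and the NVIDIA/cutlass examples tree. A minimal usage sketch, assuming the module is importable as wafer.corpus and using the download_corpus() signature shown later in this diff:

    from wafer.corpus import download_corpus

    # Downloads into the local corpus cache; force=True would re-fetch an existing copy.
    corpus_path = download_corpus("cutlass", force=False, verbose=True)
    print(f"CUTLASS corpus available at {corpus_path}")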
@@ -169,19 +204,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
  return base_dir / "/".join(path_parts)


+ class _HTMLToMarkdown(HTMLParser):
+ """HTML to Markdown converter for NVIDIA documentation pages.
+
+ Uses stdlib HTMLParser - requires subclassing due to callback-based API.
+ The public interface is the functional `_html_to_markdown()` below.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.output: list[str] = []
+ self.current_tag: str = ""
+ self.in_code_block = False
+ self.in_pre = False
+ self.list_depth = 0
+ self.ordered_list_counters: list[int] = []
+ self.skip_content = False
+ self.link_href: str | None = None
+
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
+ self.current_tag = tag
+ attrs_dict = dict(attrs)
+
+ # Skip script, style, nav, footer, header
+ if tag in ("script", "style", "nav", "footer", "header", "aside"):
+ self.skip_content = True
+ return
+
+ if tag == "h1":
+ self.output.append("\n# ")
+ elif tag == "h2":
+ self.output.append("\n## ")
+ elif tag == "h3":
+ self.output.append("\n### ")
+ elif tag == "h4":
+ self.output.append("\n#### ")
+ elif tag == "h5":
+ self.output.append("\n##### ")
+ elif tag == "h6":
+ self.output.append("\n###### ")
+ elif tag == "p":
+ self.output.append("\n\n")
+ elif tag == "br":
+ self.output.append("\n")
+ elif tag == "strong" or tag == "b":
+ self.output.append("**")
+ elif tag == "em" or tag == "i":
+ self.output.append("*")
+ elif tag == "code" and not self.in_pre:
+ self.output.append("`")
+ self.in_code_block = True
+ elif tag == "pre":
+ self.in_pre = True
+ # Check for language hint in class
+ lang = ""
+ if class_attr := attrs_dict.get("class"):
+ if "python" in class_attr.lower():
+ lang = "python"
+ elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
+ lang = "cpp"
+ elif "cuda" in class_attr.lower():
+ lang = "cuda"
+ self.output.append(f"\n```{lang}\n")
+ elif tag == "ul":
+ self.list_depth += 1
+ self.output.append("\n")
+ elif tag == "ol":
+ self.list_depth += 1
+ self.ordered_list_counters.append(1)
+ self.output.append("\n")
+ elif tag == "li":
+ indent = " " * (self.list_depth - 1)
+ if self.ordered_list_counters:
+ num = self.ordered_list_counters[-1]
+ self.output.append(f"{indent}{num}. ")
+ self.ordered_list_counters[-1] += 1
+ else:
+ self.output.append(f"{indent}- ")
+ elif tag == "a":
+ self.link_href = attrs_dict.get("href")
+ self.output.append("[")
+ elif tag == "img":
+ alt = attrs_dict.get("alt", "image")
+ src = attrs_dict.get("src", "")
+ self.output.append(f"![{alt}]({src})")
+ elif tag == "blockquote":
+ self.output.append("\n> ")
+ elif tag == "hr":
+ self.output.append("\n---\n")
+ elif tag == "table":
+ self.output.append("\n")
+ elif tag == "th":
+ self.output.append("| ")
+ elif tag == "td":
+ self.output.append("| ")
+ elif tag == "tr":
+ pass # Handled in endtag
+
+ def handle_endtag(self, tag: str) -> None:
+ if tag in ("script", "style", "nav", "footer", "header", "aside"):
+ self.skip_content = False
+ return
+
+ if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
+ self.output.append("\n")
+ elif tag == "strong" or tag == "b":
+ self.output.append("**")
+ elif tag == "em" or tag == "i":
+ self.output.append("*")
+ elif tag == "code" and not self.in_pre:
+ self.output.append("`")
+ self.in_code_block = False
+ elif tag == "pre":
+ self.in_pre = False
+ self.output.append("\n```\n")
+ elif tag == "ul":
+ self.list_depth = max(0, self.list_depth - 1)
+ elif tag == "ol":
+ self.list_depth = max(0, self.list_depth - 1)
+ if self.ordered_list_counters:
+ self.ordered_list_counters.pop()
+ elif tag == "li":
+ self.output.append("\n")
+ elif tag == "a":
+ if self.link_href:
+ self.output.append(f"]({self.link_href})")
+ else:
+ self.output.append("]")
+ self.link_href = None
+ elif tag == "p":
+ self.output.append("\n")
+ elif tag == "blockquote":
+ self.output.append("\n")
+ elif tag == "tr":
+ self.output.append("|\n")
+ elif tag == "thead":
+ # Add markdown table separator after header row
+ self.output.append("|---" * 10 + "|\n")
+
+ def handle_data(self, data: str) -> None:
+ if self.skip_content:
+ return
+ # Preserve whitespace in code blocks
+ if self.in_pre:
+ self.output.append(data)
+ else:
+ # Collapse whitespace outside code
+ text = re.sub(r"\s+", " ", data)
+ if text.strip():
+ self.output.append(text)
+
+ def get_markdown(self) -> str:
+ """Get the converted markdown, cleaned up."""
+ md = "".join(self.output)
+ # Clean up excessive newlines
+ md = re.sub(r"\n{3,}", "\n\n", md)
+ # Clean up empty table separators
+ md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
+ return md.strip()
+
+
+ def _html_to_markdown(html: str) -> str:
+ """Convert HTML to Markdown."""
+ parser = _HTMLToMarkdown()
+ parser.feed(html)
+ return parser.get_markdown()
+
+
  def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
- """Download NVIDIA docs using .md endpoint."""
+ """Download NVIDIA docs and convert HTML to Markdown.
+
+ NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
+ """
  assert config.urls is not None
  downloaded = 0
  with httpx.Client(timeout=30.0, follow_redirects=True) as client:
  for url in config.urls:
- md_url = f"{url}.md"
  filepath = _url_to_filepath(url, dest)
  filepath.parent.mkdir(parents=True, exist_ok=True)
  try:
- resp = client.get(md_url)
+ # Fetch HTML page directly
+ resp = client.get(url)
  resp.raise_for_status()
- filepath.write_text(resp.text)
+
+ # Convert HTML to Markdown
+ markdown = _html_to_markdown(resp.text)
+
+ # Add source URL as header
+ content = f"<!-- Source: {url} -->\n\n{markdown}"
+ filepath.write_text(content)
  downloaded += 1
  if verbose:
  print(f" ✓ {filepath.relative_to(dest)}")
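Note: a rough sketch of how the converter introduced above behaves on a small fragment; the expected output in the comments is approximate and inferred from the handler logic in this diff, and the import path assumes the private helper stays in wafer.corpus:

    from wafer.corpus import _html_to_markdown  # private helper defined above

    html = "<h1>CuTe Layouts</h1><p>A <code>Layout</code> maps coordinates to offsets.</p>"
    print(_html_to_markdown(html))
    # Roughly:
    # # CuTe Layouts
    #
    # A `Layout` maps coordinates to offsets.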
@@ -275,6 +486,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
  return downloaded


+ def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
+ """Download from mixed sources (NVIDIA docs + GitHub repos)."""
+ total = 0
+
+ # Download NVIDIA markdown docs (urls)
+ if config.urls:
+ if verbose:
+ print(" [NVIDIA docs]")
+ total += _download_nvidia_md(config, dest, verbose)
+
+ # Download GitHub repos
+ if config.repos:
+ if verbose:
+ print(" [GitHub repos]")
+ total += _download_github_multi_repo(config, dest, verbose)
+
+ return total
+
+
  def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
  """Download a corpus to local cache.

@@ -311,6 +541,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
  count = _download_github_repo(config, dest, verbose)
  elif config.source_type == "github_multi_repo":
  count = _download_github_multi_repo(config, dest, verbose)
+ elif config.source_type == "mixed":
+ count = _download_mixed(config, dest, verbose)
  else:
  raise ValueError(f"Unknown source type: {config.source_type}")
  if verbose:
@@ -1168,11 +1168,16 @@ def _build_modal_sandbox_script(
  """
  gpu_type = target.gpu_type

- # Determine PyTorch index based on GPU type
+ # Determine PyTorch index and CUDA arch based on GPU type
  if gpu_type in ("B200", "GB200"):
- torch_index = "https://download.pytorch.org/whl/nightly/cu128"
+ torch_index = "https://download.pytorch.org/whl/cu130"
+ cuda_arch_list = "10.0" # Blackwell (sm_100)
+ elif gpu_type == "H100":
+ torch_index = "https://download.pytorch.org/whl/cu130"
+ cuda_arch_list = "9.0" # Hopper (sm_90)
  else:
  torch_index = "https://download.pytorch.org/whl/cu124"
+ cuda_arch_list = "8.0" # Default to Ampere (sm_80)

  return f'''
  import asyncio
@@ -1190,7 +1195,7 @@ async def run_eval():
  "nvidia/cuda:12.9.0-devel-ubuntu22.04",
  add_python="3.12",
  )
- .apt_install("git", "build-essential", "cmake")
+ .apt_install("git", "build-essential", "cmake", "ripgrep")
  .pip_install(
  "torch",
  index_url="{torch_index}",
@@ -1203,6 +1208,12 @@
  )
  .env({{
  "CUDA_HOME": "/usr/local/cuda",
+ # C++ compiler needs explicit include path for cuda_runtime.h
+ "CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
+ # Linker needs lib path
+ "LIBRARY_PATH": "/usr/local/cuda/lib64",
+ # Force PyTorch to compile for correct GPU architecture
+ "TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
  }})
  )

@@ -2790,6 +2801,15 @@ if torch.cuda.is_available():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()
+
+ # Enable TF32 for fair benchmarking against reference kernels.
+ # PyTorch 1.12+ disables TF32 for matmul by default, which handicaps
+ # reference kernels using cuBLAS. We enable it so reference kernels
+ # run at their best performance (using tensor cores when applicable).
+ # This ensures speedup comparisons are against optimized baselines.
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ print("[KernelBench] TF32 enabled for fair benchmarking")


  def _calculate_timing_stats(times: list[float]) -> dict:
@@ -3453,6 +3473,368 @@ def _validate_kernelbench_files(args: KernelBenchEvaluateArgs) -> str | None:
  return None


+ def _build_modal_kernelbench_script(
+ target: ModalTarget,
+ impl_code_b64: str,
+ ref_code_b64: str,
+ eval_script_b64: str,
+ run_benchmarks: bool,
+ run_defensive: bool,
+ defense_code_b64: str | None,
+ seed: int,
+ inputs_code_b64: str | None = None,
+ ) -> str:
+ """Build Python script to create Modal sandbox and run KernelBench evaluation.
+
+ This runs in a subprocess to isolate Modal's asyncio from trio.
+ """
+ gpu_type = target.gpu_type
+
+ # Determine PyTorch index and CUDA arch based on GPU type
+ if gpu_type in ("B200", "GB200"):
+ torch_index = "https://download.pytorch.org/whl/cu130"
+ cuda_arch_list = "10.0" # Blackwell (sm_100)
+ elif gpu_type == "H100":
+ # H100 uses CUDA 13.0 (matches modal_app.py)
+ torch_index = "https://download.pytorch.org/whl/cu130"
+ cuda_arch_list = "9.0" # Hopper (sm_90)
+ else:
+ torch_index = "https://download.pytorch.org/whl/cu124"
+ cuda_arch_list = "8.0" # Default to Ampere (sm_80)
+
+ # Install CUTLASS headers (for cute/tensor.hpp and cutlass/util/*.h) from GitHub
+ # The nvidia-cutlass-dsl pip package doesn't include the C++ headers needed for nvcc
+ # IMPORTANT: symlink to /usr/local/cuda/include because nvcc searches there by default
+ cutlass_install = '''
+ .run_commands([
+ # Clone CUTLASS headers from GitHub (shallow clone, full include tree)
+ # Use simple shallow clone - sparse-checkout can be buggy in some environments
+ "git clone --depth 1 https://github.com/NVIDIA/cutlass.git /opt/cutlass",
+ # Verify the util headers exist (for debugging)
+ "ls -la /opt/cutlass/include/cutlass/util/ | head -5",
+ # Symlink headers to CUDA include path (nvcc searches here by default)
+ "ln -sf /opt/cutlass/include/cute /usr/local/cuda/include/cute",
+ "ln -sf /opt/cutlass/include/cutlass /usr/local/cuda/include/cutlass",
+ ])
+ .pip_install(
+ "nvidia-cutlass-dsl",
+ index_url="https://pypi.nvidia.com",
+ extra_index_url="https://pypi.org/simple",
+ )
+ '''
+
+ inputs_write = ""
+ if inputs_code_b64:
+ inputs_write = f'''
+ # Write custom inputs
+ proc = sandbox.exec("python", "-c", f"""
+ import base64
+ with open('/workspace/custom_inputs.py', 'w') as f:
+ f.write(base64.b64decode('{inputs_code_b64}').decode())
+ print('Custom inputs written')
+ """)
+ proc.wait()
+ '''
+
+ defense_write = ""
+ if run_defensive and defense_code_b64:
+ defense_write = f'''
+ # Write defense module
+ proc = sandbox.exec("python", "-c", f"""
+ import base64
+ with open('/workspace/defense.py', 'w') as f:
+ f.write(base64.b64decode('{defense_code_b64}').decode())
+ print('Defense module written')
+ """)
+ proc.wait()
+ '''
+
+ # Build eval command
+ eval_cmd_parts = [
+ "python /workspace/kernelbench_eval.py",
+ "--impl /workspace/implementation.py",
+ "--reference /workspace/reference.py",
+ "--output /workspace/results.json",
+ f"--seed {seed}",
+ ]
+ if run_benchmarks:
+ eval_cmd_parts.append("--benchmark")
+ if run_defensive and defense_code_b64:
+ eval_cmd_parts.append("--defensive")
+ eval_cmd_parts.append("--defense-module /workspace/defense.py")
+ if inputs_code_b64:
+ eval_cmd_parts.append("--inputs /workspace/custom_inputs.py")
+
+ eval_cmd = " ".join(eval_cmd_parts)
+
+ return f'''
+ import asyncio
+ import base64
+ import json
+ import sys
+ import modal
+
+ async def run_eval():
+ app = modal.App.lookup("wafer-evaluate", create_if_missing=True)
+
+ # Build image with PyTorch, CUTLASS DSL and dependencies
+ image = (
+ modal.Image.from_registry(
+ "nvidia/cuda:12.9.0-devel-ubuntu22.04",
+ add_python="3.12",
+ )
+ .apt_install("git", "build-essential", "cmake", "ninja-build", "ripgrep")
+ .pip_install(
+ "torch",
+ index_url="{torch_index}",
+ extra_index_url="https://pypi.org/simple",
+ )
+ .pip_install(
+ "numpy",
+ "triton",
+ "ninja",
+ )
+ {cutlass_install}
+ .env({{
+ "CUDA_HOME": "/usr/local/cuda",
+ # C++ compiler needs explicit include path for cuda_runtime.h
+ "CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
+ # Linker needs lib path
+ "LIBRARY_PATH": "/usr/local/cuda/lib64",
+ # Force PyTorch to compile for correct GPU architecture
+ "TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
+ }})
+ )
+
+ # Create sandbox
+ sandbox = modal.Sandbox.create(
+ app=app,
+ image=image,
+ gpu="{gpu_type}",
+ timeout={target.timeout_seconds},
+ )
+
+ try:
+ # Create workspace directory
+ sandbox.exec("mkdir", "-p", "/workspace").wait()
+
+ # Write files to sandbox
+ proc = sandbox.exec("python", "-c", f"""
+ import base64
+ with open('/workspace/implementation.py', 'w') as f:
+ f.write(base64.b64decode('{impl_code_b64}').decode())
+ with open('/workspace/reference.py', 'w') as f:
+ f.write(base64.b64decode('{ref_code_b64}').decode())
+ with open('/workspace/kernelbench_eval.py', 'w') as f:
+ f.write(base64.b64decode('{eval_script_b64}').decode())
+ print('Files written')
+ """)
+ proc.wait()
+ if proc.returncode != 0:
+ print(json.dumps({{"success": False, "error": f"Failed to write files: {{proc.stderr.read()}}"}}))
+ return
+ {inputs_write}
+ {defense_write}
+ # Run evaluation
+ print(f"Running KernelBench evaluation on {{'{gpu_type}'}}...")
+ proc = sandbox.exec("bash", "-c", "{eval_cmd}")
+
+ # Stream output
+ for line in proc.stdout:
+ print(line, end="")
+ for line in proc.stderr:
+ print(line, end="", file=sys.stderr)
+
+ proc.wait()
+
+ if proc.returncode != 0:
+ print(json.dumps({{"success": False, "error": f"Evaluation failed with exit code {{proc.returncode}}"}}))
+ return
+
+ # Read results
+ result_proc = sandbox.exec("cat", "/workspace/results.json")
+ result_data = result_proc.stdout.read()
+ result_proc.wait()
+
+ if result_data:
+ results = json.loads(result_data)
+ print("EVAL_RESULT_JSON:" + json.dumps(results))
+ else:
+ print(json.dumps({{"success": False, "error": "No results.json found"}}))
+
+ finally:
+ sandbox.terminate()
+
+ asyncio.run(run_eval())
+ '''
+
+
+ async def run_evaluate_kernelbench_modal(
+ args: KernelBenchEvaluateArgs,
+ target: ModalTarget,
+ ) -> EvaluateResult:
+ """Run KernelBench format evaluation on Modal sandbox.
+
+ Creates a Modal sandbox, uploads files, runs KernelBench eval, and parses results.
+ Uses subprocess to isolate Modal's asyncio from trio.
+ """
+ import base64
+ import subprocess
+ import sys
+
+ import trio
+
+ print(f"Creating Modal sandbox ({target.gpu_type}) for KernelBench evaluation...")
+
+ # Encode files as base64
+ impl_code_b64 = base64.b64encode(args.implementation.read_bytes()).decode()
+ ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
+ eval_script_b64 = base64.b64encode(KERNELBENCH_EVAL_SCRIPT.encode()).decode()
+
+ # Encode custom inputs if provided
+ inputs_code_b64 = None
+ if args.inputs:
+ inputs_code_b64 = base64.b64encode(args.inputs.read_bytes()).decode()
+
+ # Encode defense module if defensive mode is enabled
+ defense_code_b64 = None
+ if args.defensive:
+ defense_path = (
+ Path(__file__).parent.parent.parent.parent
+ / "packages"
+ / "wafer-core"
+ / "wafer_core"
+ / "utils"
+ / "kernel_utils"
+ / "defense.py"
+ )
+ if defense_path.exists():
+ defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
+ else:
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
+ # Build the script
+ script = _build_modal_kernelbench_script(
+ target=target,
+ impl_code_b64=impl_code_b64,
+ ref_code_b64=ref_code_b64,
+ eval_script_b64=eval_script_b64,
+ run_benchmarks=args.benchmark,
+ run_defensive=args.defensive,
+ defense_code_b64=defense_code_b64,
+ seed=args.seed,
+ inputs_code_b64=inputs_code_b64,
+ )
+
+ def _run_subprocess() -> tuple[str, str, int]:
+ result = subprocess.run(
+ [sys.executable, "-c", script],
+ capture_output=True,
+ text=True,
+ timeout=target.timeout_seconds + 120, # Extra buffer for sandbox creation + image build
+ )
+ return result.stdout, result.stderr, result.returncode
+
+ try:
+ stdout, stderr, returncode = await trio.to_thread.run_sync(_run_subprocess)
+ except subprocess.TimeoutExpired:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Modal KernelBench evaluation timed out after {target.timeout_seconds}s",
+ )
+ except Exception as e:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to run Modal sandbox: {e}",
+ )
+
+ # Print output for debugging
+ if stdout:
+ for line in stdout.split("\n"):
+ if not line.startswith("EVAL_RESULT_JSON:"):
+ print(line)
+ if stderr:
+ print(stderr, file=sys.stderr)
+
+ if returncode != 0:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Modal sandbox failed (exit {returncode}): {stderr or stdout}",
+ )
+
+ # Parse results from stdout
+ result_json = None
+ for line in stdout.split("\n"):
+ if line.startswith("EVAL_RESULT_JSON:"):
+ result_json = line[len("EVAL_RESULT_JSON:"):]
+ break
+
+ if not result_json:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message="No results found in Modal output",
+ )
+
+ try:
+ results = json.loads(result_json)
+ except json.JSONDecodeError as e:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to parse results JSON: {e}",
+ )
+
+ # Check for error in results
+ if "error" in results and results.get("success") is False:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=results.get("error", "Unknown error"),
+ )
+
+ # Extract metrics from results
+ return EvaluateResult(
+ success=True,
+ all_correct=results.get("all_correct", False),
+ correctness_score=float(results.get("correctness_score", 0.0)),
+ geomean_speedup=float(results.get("geomean_speedup", 0.0)),
+ passed_tests=int(results.get("passed_tests", 0)),
+ total_tests=int(results.get("total_tests", 0)),
+ error_message=results.get("error"),
+ test_results=results.get("test_results", []),
+ compilation_time_s=results.get("compilation_time_s"),
+ profiling_stats=results.get("profiling_stats"),
+ )
+
+
  async def run_evaluate_kernelbench_docker(
  args: KernelBenchEvaluateArgs,
  target: BaremetalTarget | VMTarget,
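Note: the parsing logic above expects the sandbox to print a single EVAL_RESULT_JSON: line whose payload mirrors the fields read into EvaluateResult. A sketch of such a payload, with illustrative values only (the numbers are made up, not produced by this release):

    import json

    example_results = {
        "all_correct": True,
        "correctness_score": 1.0,
        "geomean_speedup": 1.7,      # illustrative value only
        "passed_tests": 5,
        "total_tests": 5,
        "test_results": [],
        "compilation_time_s": 42.0,  # illustrative value only
        "profiling_stats": None,
    }
    # This is the marker line run_evaluate_kernelbench_modal() scans stdout for.
    print("EVAL_RESULT_JSON:" + json.dumps(example_results))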
@@ -4246,6 +4628,20 @@ async def run_evaluate_kernelbench_runpod(
  )


+ async def run_evaluate_kernelbench_baremetal_direct(
+ args: KernelBenchEvaluateArgs,
+ target: BaremetalTarget,
+ ) -> EvaluateResult:
+ """Run KernelBench format evaluation directly on NVIDIA target (no Docker).
+
+ For targets that already have PyTorch/CUDA installed (e.g., workspace containers).
+ Uses CUDA_VISIBLE_DEVICES for GPU selection.
+ """
+ # Reuse the AMD function but with CUDA env vars
+ # The logic is identical, just the GPU env var is different
+ return await _run_evaluate_kernelbench_baremetal_direct_impl(args, target, gpu_env_var="CUDA_VISIBLE_DEVICES")
+
+
  async def run_evaluate_kernelbench_baremetal_amd(
  args: KernelBenchEvaluateArgs,
  target: BaremetalTarget,
@@ -4255,6 +4651,18 @@ async def run_evaluate_kernelbench_baremetal_amd(
  Runs evaluation script directly on host (no Docker) for AMD GPUs
  that have PyTorch/ROCm installed.
  """
+ return await _run_evaluate_kernelbench_baremetal_direct_impl(args, target, gpu_env_var="HIP_VISIBLE_DEVICES")
+
+
+ async def _run_evaluate_kernelbench_baremetal_direct_impl(
+ args: KernelBenchEvaluateArgs,
+ target: BaremetalTarget,
+ gpu_env_var: str = "HIP_VISIBLE_DEVICES",
+ ) -> EvaluateResult:
+ """Internal implementation for direct baremetal evaluation.
+
+ Runs evaluation script directly on host (no Docker).
+ """
  from datetime import datetime

  from wafer_core.async_ssh import AsyncSSHClient
@@ -4405,11 +4813,15 @@ async def run_evaluate_kernelbench_baremetal_amd(

  eval_cmd = " ".join(python_cmd_parts)

- # Set environment for AMD GPU and run
- # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
- rocm_arch = _get_rocm_arch(target.compute_capability)
- arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
- env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
+ # Set environment for GPU and run
+ if gpu_env_var == "HIP_VISIBLE_DEVICES":
+ # AMD: PYTORCH_ROCM_ARCH for faster compile
+ rocm_arch = _get_rocm_arch(target.compute_capability)
+ arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
+ env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
+ else:
+ # NVIDIA: just set CUDA_VISIBLE_DEVICES
+ env_vars = f"CUDA_VISIBLE_DEVICES={gpu_id} PYTHONUNBUFFERED=1"
  full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"

  # Handle prepare-only mode
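Note: purely for illustration, assuming gpu_id=0 and that _get_rocm_arch() returns "gfx942" for an MI300X (per the gfx942 = 9.4 comment in the cli.py hunk above), the two branches produce env strings roughly like:

    amd_env = "HIP_VISIBLE_DEVICES=0 ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 PYTORCH_ROCM_ARCH=gfx942"
    nvidia_env = "CUDA_VISIBLE_DEVICES=0 PYTHONUNBUFFERED=1"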
@@ -4560,10 +4972,16 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
  elif isinstance(target, RunPodTarget):
  # RunPod AMD MI300X - uses ROCm Docker with device passthrough
  return await run_evaluate_kernelbench_runpod(args, target)
+ elif isinstance(target, ModalTarget):
+ # Modal serverless - runs in Modal sandbox
+ return await run_evaluate_kernelbench_modal(args, target)
  elif isinstance(target, BaremetalTarget | VMTarget):
  # Check if this is an AMD target (gfx* compute capability) - run directly
  if target.compute_capability and target.compute_capability.startswith("gfx"):
  return await run_evaluate_kernelbench_baremetal_amd(args, target)
+ # Check for direct execution flag (workspace containers that already have everything)
+ if getattr(target, "direct", False):
+ return await run_evaluate_kernelbench_baremetal_direct(args, target)
  # NVIDIA targets - require docker_image to be set
  if not target.docker_image:
  return EvaluateResult(
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: wafer-cli
- Version: 0.2.21
+ Version: 0.2.23
  Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
  Requires-Python: >=3.11
  Requires-Dist: typer>=0.12.0