wafer-cli 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli.py CHANGED
@@ -1,6 +1,8 @@
1
- # ruff: noqa: PLR0913
1
+ # ruff: noqa: PLR0913, E402
2
2
  # PLR0913 (too many arguments) is suppressed because Typer CLI commands
3
3
  # naturally have many parameters - each --flag becomes a function argument.
4
+ # E402 (module level import not at top) is suppressed because we intentionally
5
+ # load .env files before importing other modules that may read env vars.
4
6
  """Wafer CLI - GPU development toolkit for LLM coding agents.
5
7
 
6
8
  Core commands:
@@ -27,6 +29,12 @@ from pathlib import Path
27
29
 
28
30
  import trio
29
31
  import typer
32
+ from dotenv import load_dotenv
33
+
34
+ # Auto-load .env from current directory and ~/.wafer/.env
35
+ # This runs at import time so env vars are available before any config is accessed
36
+ load_dotenv() # cwd/.env
37
+ load_dotenv(Path.home() / ".wafer" / ".env") # ~/.wafer/.env
30
38
 
31
39
  from .config import WaferConfig, WaferEnvironment
32
40
  from .inference import infer_upload_files, resolve_environment
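A minimal sketch of the resulting precedence, assuming python-dotenv's default semantics (load_dotenv never overrides variables that are already set, so the shell environment wins, then cwd/.env, then ~/.wafer/.env); the WAFER_API_KEY name is purely illustrative:

import os
from pathlib import Path
from dotenv import load_dotenv

load_dotenv()                                 # cwd/.env fills only unset variables
load_dotenv(Path.home() / ".wafer" / ".env")  # global file fills whatever is still unset
print(os.environ.get("WAFER_API_KEY"))        # value from the highest-precedence source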
@@ -42,6 +50,7 @@ from .problems import (
42
50
  app = typer.Typer(
43
51
  help="GPU development toolkit for LLM coding agents",
44
52
  no_args_is_help=True,
53
+ pretty_exceptions_show_locals=False, # Don't dump local vars (makes tracebacks huge)
45
54
  )
46
55
 
47
56
  # =============================================================================
@@ -58,11 +67,11 @@ def _show_version() -> None:
58
67
  """Show CLI version and environment, then exit."""
59
68
  from .analytics import _get_cli_version
60
69
  from .global_config import load_global_config
61
-
70
+
62
71
  version = _get_cli_version()
63
72
  config = load_global_config()
64
73
  environment = config.environment
65
-
74
+
66
75
  typer.echo(f"wafer-cli {version} ({environment})")
67
76
  raise typer.Exit()
68
77
 
@@ -110,7 +119,7 @@ def main_callback(
110
119
  if version:
111
120
  _show_version()
112
121
  return
113
-
122
+
114
123
  global _command_start_time, _command_outcome
115
124
  _command_start_time = time.time()
116
125
  _command_outcome = "success" # Default to success, mark failure on exceptions
@@ -121,6 +130,7 @@ def main_callback(
121
130
  analytics.init_analytics()
122
131
 
123
132
  # Install exception hook to catch SystemExit and mark failures
133
+ # Also prints error message FIRST so it's visible even when traceback is truncated
124
134
  original_excepthook = sys.excepthook
125
135
 
126
136
  def custom_excepthook(
@@ -136,7 +146,11 @@ def main_callback(
136
146
  _command_outcome = "failure"
137
147
  else:
138
148
  _command_outcome = "failure"
139
- # Call original excepthook
149
+ # Print error summary FIRST (before traceback) so it's visible even if truncated
150
+ print(
151
+ f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
152
+ )
153
+ # Call original excepthook (prints the full traceback)
140
154
  original_excepthook(exc_type, exc_value, exc_traceback)
141
155
 
142
156
  sys.excepthook = custom_excepthook
@@ -591,7 +605,7 @@ app.add_typer(provider_auth_app, name="auth")
591
605
  def provider_auth_login(
592
606
  provider: str = typer.Argument(
593
607
  ...,
594
- help="Provider name: runpod, digitalocean, or modal",
608
+ help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
595
609
  ),
596
610
  api_key: str | None = typer.Option(
597
611
  None,
@@ -600,15 +614,16 @@ def provider_auth_login(
600
614
  help="API key (if not provided, reads from stdin)",
601
615
  ),
602
616
  ) -> None:
603
- """Save API key for a cloud GPU provider.
617
+ """Save API key for a provider.
604
618
 
605
619
  Stores the key in ~/.wafer/auth.json. Environment variables
606
- (e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
620
+ (e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
607
621
 
608
622
  Examples:
623
+ wafer auth login anthropic --api-key sk-ant-xxx
609
624
  wafer auth login runpod --api-key rp_xxx
610
- wafer auth login digitalocean --api-key dop_v1_xxx
611
- echo $API_KEY | wafer auth login runpod
625
+ wafer auth login openai --api-key sk-xxx
626
+ echo $API_KEY | wafer auth login anthropic
612
627
  """
613
628
  import sys
614
629
 
@@ -642,7 +657,7 @@ def provider_auth_login(
642
657
  def provider_auth_logout(
643
658
  provider: str = typer.Argument(
644
659
  ...,
645
- help="Provider name: runpod, digitalocean, or modal",
660
+ help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
646
661
  ),
647
662
  ) -> None:
648
663
  """Remove stored API key for a cloud GPU provider.
@@ -3473,7 +3488,7 @@ def init_runpod(
3473
3488
  gpu_configs = {
3474
3489
  "MI300X": {
3475
3490
  "gpu_type_id": "AMD Instinct MI300X OAM",
3476
- "image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
3491
+ "image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
3477
3492
  "compute_capability": "9.4",
3478
3493
  },
3479
3494
  "H100": {
@@ -3569,7 +3584,7 @@ def init_digitalocean(
3569
3584
  "ssh_key": ssh_key,
3570
3585
  "region": region,
3571
3586
  "size_slug": "gpu-mi300x1-192gb-devcloud",
3572
- "image": "gpu-amd-base",
3587
+ "image": "amd-pytorchrocm7", # PyTorch (ROCm7) marketplace image
3573
3588
  "provision_timeout": 600,
3574
3589
  "eval_timeout": 600,
3575
3590
  "keep_alive": keep_alive,
@@ -4071,6 +4086,164 @@ def targets_cleanup(
4071
4086
  raise typer.Exit(1) from None
4072
4087
 
4073
4088
 
4089
+ # Known libraries that can be installed on targets
4090
+ # TODO: Consider adding HipKittens to the default RunPod/DO Docker images
4091
+ # so this install step isn't needed. For now, this command handles it.
4092
+ INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
4093
+ "hipkittens": {
4094
+ "description": "HipKittens - AMD port of ThunderKittens for MI300X",
4095
+ "git_url": "https://github.com/HazyResearch/hipkittens.git",
4096
+ "install_path": "/opt/hipkittens",
4097
+ "requires_amd": True,
4098
+ },
4099
+ # CK is already installed with ROCm 7.0, no action needed
4100
+ "repair-headers": {
4101
+ "description": "Repair ROCm thrust headers (fixes hipify corruption)",
4102
+ "custom_script": "apt-get update -qq && apt-get install --reinstall -y rocthrust >/dev/null 2>&1 && echo REPAIRED",
4103
+ "requires_amd": True,
4104
+ },
4105
+ }
4106
+
4107
+
4108
+ @targets_app.command("install")
4109
+ def targets_install(
4110
+ name: str = typer.Argument(..., help="Target name"),
4111
+ library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
4112
+ ) -> None:
4113
+ """Install a library or run maintenance on a target (idempotent).
4114
+
4115
+ Installs header-only libraries like HipKittens on remote targets.
4116
+ Safe to run multiple times - will skip if already installed.
4117
+
4118
+ Available libraries:
4119
+ hipkittens - HipKittens (AMD ThunderKittens port)
4120
+ repair-headers - Fix ROCm thrust headers (after hipify corruption)
4121
+
4122
+ Examples:
4123
+ wafer config targets install runpod-mi300x hipkittens
4124
+ wafer config targets install runpod-mi300x repair-headers
4125
+ wafer config targets install do-mi300x hipkittens
4126
+ """
4127
+ import subprocess
4128
+
4129
+ from .targets import load_target
4130
+ from .targets_ops import get_target_ssh_info
4131
+
4132
+ if library not in INSTALLABLE_LIBRARIES:
4133
+ available = ", ".join(INSTALLABLE_LIBRARIES.keys())
4134
+ typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
4135
+ raise typer.Exit(1)
4136
+
4137
+ lib_info = INSTALLABLE_LIBRARIES[library]
4138
+
4139
+ try:
4140
+ target = load_target(name)
4141
+ except FileNotFoundError as e:
4142
+ typer.echo(f"Error: {e}", err=True)
4143
+ raise typer.Exit(1) from None
4144
+
4145
+ # Check if target is AMD (for AMD-only libraries)
4146
+ if lib_info.get("requires_amd"):
4147
+ from wafer_core.utils.kernel_utils.targets.config import (
4148
+ DigitalOceanTarget,
4149
+ RunPodTarget,
4150
+ )
4151
+
4152
+ is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
4153
+ if not is_amd and hasattr(target, "compute_capability"):
4154
+ # Check compute capability for MI300X (gfx942 = 9.4)
4155
+ is_amd = target.compute_capability.startswith("9.")
4156
+ if not is_amd:
4157
+ typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
4158
+ raise typer.Exit(1)
4159
+
4160
+ typer.echo(f"Installing {library} on {name}...")
4161
+ typer.echo(f" {lib_info['description']}")
4162
+
4163
+ async def _install() -> bool:
4164
+ # get_target_ssh_info uses pure trio async (no asyncio bridging needed)
4165
+ # and we use subprocess for SSH, not AsyncSSHClient
4166
+ ssh_info = await get_target_ssh_info(target)
4167
+
4168
+ ssh_cmd = [
4169
+ "ssh",
4170
+ "-o",
4171
+ "StrictHostKeyChecking=no",
4172
+ "-o",
4173
+ "UserKnownHostsFile=/dev/null",
4174
+ "-o",
4175
+ "ConnectTimeout=30",
4176
+ "-i",
4177
+ str(ssh_info.key_path),
4178
+ "-p",
4179
+ str(ssh_info.port),
4180
+ f"{ssh_info.user}@{ssh_info.host}",
4181
+ ]
4182
+
4183
+ # Handle custom scripts (like repair-headers) vs git installs
4184
+ if "custom_script" in lib_info:
4185
+ install_script = str(lib_info["custom_script"])
4186
+ success_marker = "REPAIRED"
4187
+ else:
4188
+ install_path = lib_info["install_path"]
4189
+ git_url = lib_info["git_url"]
4190
+
4191
+ # Idempotent install script
4192
+ install_script = f"""
4193
+ if [ -d "{install_path}" ]; then
4194
+ echo "ALREADY_INSTALLED: {install_path} exists"
4195
+ cd {install_path} && git pull --quiet 2>/dev/null || true
4196
+ else
4197
+ echo "INSTALLING: cloning to {install_path}"
4198
+ git clone --quiet {git_url} {install_path}
4199
+ fi
4200
+ echo "DONE"
4201
+ """
4202
+ success_marker = "DONE"
4203
+
4204
+ def run_ssh() -> subprocess.CompletedProcess[str]:
4205
+ return subprocess.run(
4206
+ ssh_cmd + [install_script],
4207
+ capture_output=True,
4208
+ text=True,
4209
+ timeout=120,
4210
+ )
4211
+
4212
+ result = await trio.to_thread.run_sync(run_ssh)
4213
+
4214
+ if result.returncode != 0:
4215
+ typer.echo(f"Error: {result.stderr}", err=True)
4216
+ return False
4217
+
4218
+ output = result.stdout.strip()
4219
+ if "ALREADY_INSTALLED" in output:
4220
+ typer.echo(f" Already installed at {lib_info.get('install_path', 'N/A')}")
4221
+ elif "INSTALLING" in output:
4222
+ typer.echo(f" Installed to {lib_info.get('install_path', 'N/A')}")
4223
+ elif "REPAIRED" in output:
4224
+ typer.echo(" ROCm headers repaired")
4225
+
4226
+ return success_marker in output
4227
+
4228
+ try:
4229
+ success = trio.run(_install)
4230
+ except Exception as e:
4231
+ typer.echo(f"Error: {e}", err=True)
4232
+ raise typer.Exit(1) from None
4233
+
4234
+ if success:
4235
+ typer.echo(f"✓ {library} ready on {name}")
4236
+
4237
+ # Print usage hint
4238
+ if library == "hipkittens":
4239
+ typer.echo("")
4240
+ typer.echo("Usage in load_inline:")
4241
+ typer.echo(' extra_include_paths=["/opt/hipkittens/include", "/opt/rocm/include/hip"]')
4242
+ else:
4243
+ typer.echo(f"Failed to install {library}", err=True)
4244
+ raise typer.Exit(1)
4245
+
4246
+
4074
4247
  @targets_app.command("pods")
4075
4248
  def targets_pods() -> None:
4076
4249
  """List all running RunPod pods.
@@ -4406,9 +4579,13 @@ def workspaces_list(
4406
4579
  @workspaces_app.command("create")
4407
4580
  def workspaces_create(
4408
4581
  name: str = typer.Argument(..., help="Workspace name"),
4409
- gpu_type: str = typer.Option("B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"),
4582
+ gpu_type: str = typer.Option(
4583
+ "B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"
4584
+ ),
4410
4585
  image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
4411
- wait: bool = typer.Option(False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"),
4586
+ wait: bool = typer.Option(
4587
+ False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"
4588
+ ),
4412
4589
  json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
4413
4590
  ) -> None:
4414
4591
  """Create a new workspace.
@@ -4717,19 +4894,25 @@ def workspaces_ssh(
4717
4894
  ssh_host = ws.get("ssh_host")
4718
4895
  ssh_port = ws.get("ssh_port")
4719
4896
  ssh_user = ws.get("ssh_user")
4720
-
4897
+
4721
4898
  if not ssh_host or not ssh_port or not ssh_user:
4722
4899
  typer.echo("Error: Workspace not ready. Wait a few seconds and retry.", err=True)
4723
4900
  raise typer.Exit(1)
4724
4901
 
4725
4902
  # Connect via SSH
4726
- os.execvp("ssh", [
4903
+ os.execvp(
4727
4904
  "ssh",
4728
- "-p", str(ssh_port),
4729
- "-o", "StrictHostKeyChecking=no",
4730
- "-o", "UserKnownHostsFile=/dev/null",
4731
- f"{ssh_user}@{ssh_host}",
4732
- ])
4905
+ [
4906
+ "ssh",
4907
+ "-p",
4908
+ str(ssh_port),
4909
+ "-o",
4910
+ "StrictHostKeyChecking=no",
4911
+ "-o",
4912
+ "UserKnownHostsFile=/dev/null",
4913
+ f"{ssh_user}@{ssh_host}",
4914
+ ],
4915
+ )
4733
4916
 
4734
4917
 
4735
4918
  @workspaces_app.command("sync")
wafer/corpus.py CHANGED
@@ -3,10 +3,12 @@
3
3
  Download and manage documentation corpora for agent filesystem access.
4
4
  """
5
5
 
6
+ import re
6
7
  import shutil
7
8
  import tarfile
8
9
  import tempfile
9
10
  from dataclasses import dataclass
11
+ from html.parser import HTMLParser
10
12
  from pathlib import Path
11
13
  from typing import Literal
12
14
  from urllib.parse import urlparse
@@ -33,7 +35,7 @@ class CorpusConfig:
33
35
 
34
36
  name: CorpusName
35
37
  description: str
36
- source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
38
+ source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
37
39
  urls: list[str] | None = None
38
40
  repo: str | None = None
39
41
  repo_paths: list[str] | None = None
@@ -67,10 +69,43 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
67
69
  ),
68
70
  "cutlass": CorpusConfig(
69
71
  name="cutlass",
70
- description="CUTLASS and CuTe DSL documentation",
71
- source_type="github_repo",
72
- repo="NVIDIA/cutlass",
73
- repo_paths=["media/docs", "python/cutlass/docs"],
72
+ description="CUTLASS C++ documentation, examples, and tutorials",
73
+ source_type="mixed",
74
+ # Official NVIDIA CUTLASS documentation (scraped as markdown)
75
+ urls=[
76
+ "https://docs.nvidia.com/cutlass/latest/overview.html",
77
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
78
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
79
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
80
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
81
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
82
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
83
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
84
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
85
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
86
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
87
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
88
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
89
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
90
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
91
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
92
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
93
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
94
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
95
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
96
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
97
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
98
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
99
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
100
+ ],
101
+ # NVIDIA/cutlass GitHub examples (excluding python/)
102
+ repos=[
103
+ RepoSource(
104
+ repo="NVIDIA/cutlass",
105
+ paths=["examples"],
106
+ branch="main",
107
+ ),
108
+ ],
74
109
  ),
75
110
  "hip": CorpusConfig(
76
111
  name="hip",
@@ -169,19 +204,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
169
204
  return base_dir / "/".join(path_parts)
170
205
 
171
206
 
207
+ class _HTMLToMarkdown(HTMLParser):
208
+ """HTML to Markdown converter for NVIDIA documentation pages.
209
+
210
+ Uses stdlib HTMLParser - requires subclassing due to callback-based API.
211
+ The public interface is the functional `_html_to_markdown()` below.
212
+ """
213
+
214
+ def __init__(self) -> None:
215
+ super().__init__()
216
+ self.output: list[str] = []
217
+ self.current_tag: str = ""
218
+ self.in_code_block = False
219
+ self.in_pre = False
220
+ self.list_depth = 0
221
+ self.ordered_list_counters: list[int] = []
222
+ self.skip_content = False
223
+ self.link_href: str | None = None
224
+
225
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
226
+ self.current_tag = tag
227
+ attrs_dict = dict(attrs)
228
+
229
+ # Skip script, style, nav, footer, header
230
+ if tag in ("script", "style", "nav", "footer", "header", "aside"):
231
+ self.skip_content = True
232
+ return
233
+
234
+ if tag == "h1":
235
+ self.output.append("\n# ")
236
+ elif tag == "h2":
237
+ self.output.append("\n## ")
238
+ elif tag == "h3":
239
+ self.output.append("\n### ")
240
+ elif tag == "h4":
241
+ self.output.append("\n#### ")
242
+ elif tag == "h5":
243
+ self.output.append("\n##### ")
244
+ elif tag == "h6":
245
+ self.output.append("\n###### ")
246
+ elif tag == "p":
247
+ self.output.append("\n\n")
248
+ elif tag == "br":
249
+ self.output.append("\n")
250
+ elif tag == "strong" or tag == "b":
251
+ self.output.append("**")
252
+ elif tag == "em" or tag == "i":
253
+ self.output.append("*")
254
+ elif tag == "code" and not self.in_pre:
255
+ self.output.append("`")
256
+ self.in_code_block = True
257
+ elif tag == "pre":
258
+ self.in_pre = True
259
+ # Check for language hint in class
260
+ lang = ""
261
+ if class_attr := attrs_dict.get("class"):
262
+ if "python" in class_attr.lower():
263
+ lang = "python"
264
+ elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
265
+ lang = "cpp"
266
+ elif "cuda" in class_attr.lower():
267
+ lang = "cuda"
268
+ self.output.append(f"\n```{lang}\n")
269
+ elif tag == "ul":
270
+ self.list_depth += 1
271
+ self.output.append("\n")
272
+ elif tag == "ol":
273
+ self.list_depth += 1
274
+ self.ordered_list_counters.append(1)
275
+ self.output.append("\n")
276
+ elif tag == "li":
277
+ indent = " " * (self.list_depth - 1)
278
+ if self.ordered_list_counters:
279
+ num = self.ordered_list_counters[-1]
280
+ self.output.append(f"{indent}{num}. ")
281
+ self.ordered_list_counters[-1] += 1
282
+ else:
283
+ self.output.append(f"{indent}- ")
284
+ elif tag == "a":
285
+ self.link_href = attrs_dict.get("href")
286
+ self.output.append("[")
287
+ elif tag == "img":
288
+ alt = attrs_dict.get("alt", "image")
289
+ src = attrs_dict.get("src", "")
290
+ self.output.append(f"![{alt}]({src})")
291
+ elif tag == "blockquote":
292
+ self.output.append("\n> ")
293
+ elif tag == "hr":
294
+ self.output.append("\n---\n")
295
+ elif tag == "table":
296
+ self.output.append("\n")
297
+ elif tag == "th":
298
+ self.output.append("| ")
299
+ elif tag == "td":
300
+ self.output.append("| ")
301
+ elif tag == "tr":
302
+ pass # Handled in endtag
303
+
304
+ def handle_endtag(self, tag: str) -> None:
305
+ if tag in ("script", "style", "nav", "footer", "header", "aside"):
306
+ self.skip_content = False
307
+ return
308
+
309
+ if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
310
+ self.output.append("\n")
311
+ elif tag == "strong" or tag == "b":
312
+ self.output.append("**")
313
+ elif tag == "em" or tag == "i":
314
+ self.output.append("*")
315
+ elif tag == "code" and not self.in_pre:
316
+ self.output.append("`")
317
+ self.in_code_block = False
318
+ elif tag == "pre":
319
+ self.in_pre = False
320
+ self.output.append("\n```\n")
321
+ elif tag == "ul":
322
+ self.list_depth = max(0, self.list_depth - 1)
323
+ elif tag == "ol":
324
+ self.list_depth = max(0, self.list_depth - 1)
325
+ if self.ordered_list_counters:
326
+ self.ordered_list_counters.pop()
327
+ elif tag == "li":
328
+ self.output.append("\n")
329
+ elif tag == "a":
330
+ if self.link_href:
331
+ self.output.append(f"]({self.link_href})")
332
+ else:
333
+ self.output.append("]")
334
+ self.link_href = None
335
+ elif tag == "p":
336
+ self.output.append("\n")
337
+ elif tag == "blockquote":
338
+ self.output.append("\n")
339
+ elif tag == "tr":
340
+ self.output.append("|\n")
341
+ elif tag == "thead":
342
+ # Add markdown table separator after header row
343
+ self.output.append("|---" * 10 + "|\n")
344
+
345
+ def handle_data(self, data: str) -> None:
346
+ if self.skip_content:
347
+ return
348
+ # Preserve whitespace in code blocks
349
+ if self.in_pre:
350
+ self.output.append(data)
351
+ else:
352
+ # Collapse whitespace outside code
353
+ text = re.sub(r"\s+", " ", data)
354
+ if text.strip():
355
+ self.output.append(text)
356
+
357
+ def get_markdown(self) -> str:
358
+ """Get the converted markdown, cleaned up."""
359
+ md = "".join(self.output)
360
+ # Clean up excessive newlines
361
+ md = re.sub(r"\n{3,}", "\n\n", md)
362
+ # Clean up empty table separators
363
+ md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
364
+ return md.strip()
365
+
366
+
367
+ def _html_to_markdown(html: str) -> str:
368
+ """Convert HTML to Markdown."""
369
+ parser = _HTMLToMarkdown()
370
+ parser.feed(html)
371
+ return parser.get_markdown()
372
+
373
+
172
374
  def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
173
- """Download NVIDIA docs using .md endpoint."""
375
+ """Download NVIDIA docs and convert HTML to Markdown.
376
+
377
+ NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
378
+ """
174
379
  assert config.urls is not None
175
380
  downloaded = 0
176
381
  with httpx.Client(timeout=30.0, follow_redirects=True) as client:
177
382
  for url in config.urls:
178
- md_url = f"{url}.md"
179
383
  filepath = _url_to_filepath(url, dest)
180
384
  filepath.parent.mkdir(parents=True, exist_ok=True)
181
385
  try:
182
- resp = client.get(md_url)
386
+ # Fetch HTML page directly
387
+ resp = client.get(url)
183
388
  resp.raise_for_status()
184
- filepath.write_text(resp.text)
389
+
390
+ # Convert HTML to Markdown
391
+ markdown = _html_to_markdown(resp.text)
392
+
393
+ # Add source URL as header
394
+ content = f"<!-- Source: {url} -->\n\n{markdown}"
395
+ filepath.write_text(content)
185
396
  downloaded += 1
186
397
  if verbose:
187
398
  print(f" ✓ {filepath.relative_to(dest)}")
@@ -275,6 +486,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
275
486
  return downloaded
276
487
 
277
488
 
489
+ def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
490
+ """Download from mixed sources (NVIDIA docs + GitHub repos)."""
491
+ total = 0
492
+
493
+ # Download NVIDIA markdown docs (urls)
494
+ if config.urls:
495
+ if verbose:
496
+ print(" [NVIDIA docs]")
497
+ total += _download_nvidia_md(config, dest, verbose)
498
+
499
+ # Download GitHub repos
500
+ if config.repos:
501
+ if verbose:
502
+ print(" [GitHub repos]")
503
+ total += _download_github_multi_repo(config, dest, verbose)
504
+
505
+ return total
506
+
507
+
278
508
  def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
279
509
  """Download a corpus to local cache.
280
510
 
@@ -311,6 +541,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
311
541
  count = _download_github_repo(config, dest, verbose)
312
542
  elif config.source_type == "github_multi_repo":
313
543
  count = _download_github_multi_repo(config, dest, verbose)
544
+ elif config.source_type == "mixed":
545
+ count = _download_mixed(config, dest, verbose)
314
546
  else:
315
547
  raise ValueError(f"Unknown source type: {config.source_type}")
316
548
  if verbose:
wafer/evaluate.py CHANGED
@@ -1168,11 +1168,16 @@ def _build_modal_sandbox_script(
1168
1168
  """
1169
1169
  gpu_type = target.gpu_type
1170
1170
 
1171
- # Determine PyTorch index based on GPU type
1171
+ # Determine PyTorch index and CUDA arch based on GPU type
1172
1172
  if gpu_type in ("B200", "GB200"):
1173
- torch_index = "https://download.pytorch.org/whl/nightly/cu128"
1173
+ torch_index = "https://download.pytorch.org/whl/cu130"
1174
+ cuda_arch_list = "10.0" # Blackwell (sm_100)
1175
+ elif gpu_type == "H100":
1176
+ torch_index = "https://download.pytorch.org/whl/cu130"
1177
+ cuda_arch_list = "9.0" # Hopper (sm_90)
1174
1178
  else:
1175
1179
  torch_index = "https://download.pytorch.org/whl/cu124"
1180
+ cuda_arch_list = "8.0" # Default to Ampere (sm_80)
1176
1181
 
1177
1182
  return f'''
1178
1183
  import asyncio
@@ -1190,7 +1195,7 @@ async def run_eval():
1190
1195
  "nvidia/cuda:12.9.0-devel-ubuntu22.04",
1191
1196
  add_python="3.12",
1192
1197
  )
1193
- .apt_install("git", "build-essential", "cmake")
1198
+ .apt_install("git", "build-essential", "cmake", "ripgrep")
1194
1199
  .pip_install(
1195
1200
  "torch",
1196
1201
  index_url="{torch_index}",
@@ -1203,6 +1208,12 @@ async def run_eval():
1203
1208
  )
1204
1209
  .env({{
1205
1210
  "CUDA_HOME": "/usr/local/cuda",
1211
+ # C++ compiler needs explicit include path for cuda_runtime.h
1212
+ "CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
1213
+ # Linker needs lib path
1214
+ "LIBRARY_PATH": "/usr/local/cuda/lib64",
1215
+ # Force PyTorch to compile for correct GPU architecture
1216
+ "TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
1206
1217
  }})
1207
1218
  )
1208
1219
 
@@ -2790,6 +2801,15 @@ if torch.cuda.is_available():
2790
2801
  gc.collect()
2791
2802
  torch.cuda.empty_cache()
2792
2803
  torch.cuda.reset_peak_memory_stats()
2804
+
2805
+ # Enable TF32 for fair benchmarking against reference kernels.
2806
+ # PyTorch 1.12+ disables TF32 for matmul by default, which handicaps
2807
+ # reference kernels using cuBLAS. We enable it so reference kernels
2808
+ # run at their best performance (using tensor cores when applicable).
2809
+ # This ensures speedup comparisons are against optimized baselines.
2810
+ torch.backends.cuda.matmul.allow_tf32 = True
2811
+ torch.backends.cudnn.allow_tf32 = True
2812
+ print("[KernelBench] TF32 enabled for fair benchmarking")
2793
2813
 
2794
2814
 
2795
2815
  def _calculate_timing_stats(times: list[float]) -> dict:
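A small sketch of what the TF32 toggle above changes, assuming an Ampere-or-newer GPU: once enabled, float32 matmuls may run on tensor cores at TF32 precision, which is the fast path a tuned cuBLAS baseline would normally use:

import torch

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = False
exact = a @ b    # strict FP32 path

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
fast = a @ b     # may use TF32 tensor cores: faster, slightly lower precision

print((exact - fast).abs().max())  # small but typically nonzero on Ampere+ hardware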
@@ -3453,6 +3473,368 @@ def _validate_kernelbench_files(args: KernelBenchEvaluateArgs) -> str | None:
3453
3473
  return None
3454
3474
 
3455
3475
 
3476
+ def _build_modal_kernelbench_script(
3477
+ target: ModalTarget,
3478
+ impl_code_b64: str,
3479
+ ref_code_b64: str,
3480
+ eval_script_b64: str,
3481
+ run_benchmarks: bool,
3482
+ run_defensive: bool,
3483
+ defense_code_b64: str | None,
3484
+ seed: int,
3485
+ inputs_code_b64: str | None = None,
3486
+ ) -> str:
3487
+ """Build Python script to create Modal sandbox and run KernelBench evaluation.
3488
+
3489
+ This runs in a subprocess to isolate Modal's asyncio from trio.
3490
+ """
3491
+ gpu_type = target.gpu_type
3492
+
3493
+ # Determine PyTorch index and CUDA arch based on GPU type
3494
+ if gpu_type in ("B200", "GB200"):
3495
+ torch_index = "https://download.pytorch.org/whl/cu130"
3496
+ cuda_arch_list = "10.0" # Blackwell (sm_100)
3497
+ elif gpu_type == "H100":
3498
+ # H100 uses CUDA 13.0 (matches modal_app.py)
3499
+ torch_index = "https://download.pytorch.org/whl/cu130"
3500
+ cuda_arch_list = "9.0" # Hopper (sm_90)
3501
+ else:
3502
+ torch_index = "https://download.pytorch.org/whl/cu124"
3503
+ cuda_arch_list = "8.0" # Default to Ampere (sm_80)
3504
+
3505
+ # Install CUTLASS headers (for cute/tensor.hpp and cutlass/util/*.h) from GitHub
3506
+ # The nvidia-cutlass-dsl pip package doesn't include the C++ headers needed for nvcc
3507
+ # IMPORTANT: symlink to /usr/local/cuda/include because nvcc searches there by default
3508
+ cutlass_install = '''
3509
+ .run_commands([
3510
+ # Clone CUTLASS headers from GitHub (shallow clone, full include tree)
3511
+ # Use simple shallow clone - sparse-checkout can be buggy in some environments
3512
+ "git clone --depth 1 https://github.com/NVIDIA/cutlass.git /opt/cutlass",
3513
+ # Verify the util headers exist (for debugging)
3514
+ "ls -la /opt/cutlass/include/cutlass/util/ | head -5",
3515
+ # Symlink headers to CUDA include path (nvcc searches here by default)
3516
+ "ln -sf /opt/cutlass/include/cute /usr/local/cuda/include/cute",
3517
+ "ln -sf /opt/cutlass/include/cutlass /usr/local/cuda/include/cutlass",
3518
+ ])
3519
+ .pip_install(
3520
+ "nvidia-cutlass-dsl",
3521
+ index_url="https://pypi.nvidia.com",
3522
+ extra_index_url="https://pypi.org/simple",
3523
+ )
3524
+ '''
3525
+
3526
+ inputs_write = ""
3527
+ if inputs_code_b64:
3528
+ inputs_write = f'''
3529
+ # Write custom inputs
3530
+ proc = sandbox.exec("python", "-c", f"""
3531
+ import base64
3532
+ with open('/workspace/custom_inputs.py', 'w') as f:
3533
+ f.write(base64.b64decode('{inputs_code_b64}').decode())
3534
+ print('Custom inputs written')
3535
+ """)
3536
+ proc.wait()
3537
+ '''
3538
+
3539
+ defense_write = ""
3540
+ if run_defensive and defense_code_b64:
3541
+ defense_write = f'''
3542
+ # Write defense module
3543
+ proc = sandbox.exec("python", "-c", f"""
3544
+ import base64
3545
+ with open('/workspace/defense.py', 'w') as f:
3546
+ f.write(base64.b64decode('{defense_code_b64}').decode())
3547
+ print('Defense module written')
3548
+ """)
3549
+ proc.wait()
3550
+ '''
3551
+
3552
+ # Build eval command
3553
+ eval_cmd_parts = [
3554
+ "python /workspace/kernelbench_eval.py",
3555
+ "--impl /workspace/implementation.py",
3556
+ "--reference /workspace/reference.py",
3557
+ "--output /workspace/results.json",
3558
+ f"--seed {seed}",
3559
+ ]
3560
+ if run_benchmarks:
3561
+ eval_cmd_parts.append("--benchmark")
3562
+ if run_defensive and defense_code_b64:
3563
+ eval_cmd_parts.append("--defensive")
3564
+ eval_cmd_parts.append("--defense-module /workspace/defense.py")
3565
+ if inputs_code_b64:
3566
+ eval_cmd_parts.append("--inputs /workspace/custom_inputs.py")
3567
+
3568
+ eval_cmd = " ".join(eval_cmd_parts)
3569
+
3570
+ return f'''
3571
+ import asyncio
3572
+ import base64
3573
+ import json
3574
+ import sys
3575
+ import modal
3576
+
3577
+ async def run_eval():
3578
+ app = modal.App.lookup("wafer-evaluate", create_if_missing=True)
3579
+
3580
+ # Build image with PyTorch, CUTLASS DSL and dependencies
3581
+ image = (
3582
+ modal.Image.from_registry(
3583
+ "nvidia/cuda:12.9.0-devel-ubuntu22.04",
3584
+ add_python="3.12",
3585
+ )
3586
+ .apt_install("git", "build-essential", "cmake", "ninja-build", "ripgrep")
3587
+ .pip_install(
3588
+ "torch",
3589
+ index_url="{torch_index}",
3590
+ extra_index_url="https://pypi.org/simple",
3591
+ )
3592
+ .pip_install(
3593
+ "numpy",
3594
+ "triton",
3595
+ "ninja",
3596
+ )
3597
+ {cutlass_install}
3598
+ .env({{
3599
+ "CUDA_HOME": "/usr/local/cuda",
3600
+ # C++ compiler needs explicit include path for cuda_runtime.h
3601
+ "CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
3602
+ # Linker needs lib path
3603
+ "LIBRARY_PATH": "/usr/local/cuda/lib64",
3604
+ # Force PyTorch to compile for correct GPU architecture
3605
+ "TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
3606
+ }})
3607
+ )
3608
+
3609
+ # Create sandbox
3610
+ sandbox = modal.Sandbox.create(
3611
+ app=app,
3612
+ image=image,
3613
+ gpu="{gpu_type}",
3614
+ timeout={target.timeout_seconds},
3615
+ )
3616
+
3617
+ try:
3618
+ # Create workspace directory
3619
+ sandbox.exec("mkdir", "-p", "/workspace").wait()
3620
+
3621
+ # Write files to sandbox
3622
+ proc = sandbox.exec("python", "-c", f"""
3623
+ import base64
3624
+ with open('/workspace/implementation.py', 'w') as f:
3625
+ f.write(base64.b64decode('{impl_code_b64}').decode())
3626
+ with open('/workspace/reference.py', 'w') as f:
3627
+ f.write(base64.b64decode('{ref_code_b64}').decode())
3628
+ with open('/workspace/kernelbench_eval.py', 'w') as f:
3629
+ f.write(base64.b64decode('{eval_script_b64}').decode())
3630
+ print('Files written')
3631
+ """)
3632
+ proc.wait()
3633
+ if proc.returncode != 0:
3634
+ print(json.dumps({{"success": False, "error": f"Failed to write files: {{proc.stderr.read()}}"}}))
3635
+ return
3636
+ {inputs_write}
3637
+ {defense_write}
3638
+ # Run evaluation
3639
+ print(f"Running KernelBench evaluation on {{'{gpu_type}'}}...")
3640
+ proc = sandbox.exec("bash", "-c", "{eval_cmd}")
3641
+
3642
+ # Stream output
3643
+ for line in proc.stdout:
3644
+ print(line, end="")
3645
+ for line in proc.stderr:
3646
+ print(line, end="", file=sys.stderr)
3647
+
3648
+ proc.wait()
3649
+
3650
+ if proc.returncode != 0:
3651
+ print(json.dumps({{"success": False, "error": f"Evaluation failed with exit code {{proc.returncode}}"}}))
3652
+ return
3653
+
3654
+ # Read results
3655
+ result_proc = sandbox.exec("cat", "/workspace/results.json")
3656
+ result_data = result_proc.stdout.read()
3657
+ result_proc.wait()
3658
+
3659
+ if result_data:
3660
+ results = json.loads(result_data)
3661
+ print("EVAL_RESULT_JSON:" + json.dumps(results))
3662
+ else:
3663
+ print(json.dumps({{"success": False, "error": "No results.json found"}}))
3664
+
3665
+ finally:
3666
+ sandbox.terminate()
3667
+
3668
+ asyncio.run(run_eval())
3669
+ '''
3670
+
3671
+
3672
+ async def run_evaluate_kernelbench_modal(
3673
+ args: KernelBenchEvaluateArgs,
3674
+ target: ModalTarget,
3675
+ ) -> EvaluateResult:
3676
+ """Run KernelBench format evaluation on Modal sandbox.
3677
+
3678
+ Creates a Modal sandbox, uploads files, runs KernelBench eval, and parses results.
3679
+ Uses subprocess to isolate Modal's asyncio from trio.
3680
+ """
3681
+ import base64
3682
+ import subprocess
3683
+ import sys
3684
+
3685
+ import trio
3686
+
3687
+ print(f"Creating Modal sandbox ({target.gpu_type}) for KernelBench evaluation...")
3688
+
3689
+ # Encode files as base64
3690
+ impl_code_b64 = base64.b64encode(args.implementation.read_bytes()).decode()
3691
+ ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
3692
+ eval_script_b64 = base64.b64encode(KERNELBENCH_EVAL_SCRIPT.encode()).decode()
3693
+
3694
+ # Encode custom inputs if provided
3695
+ inputs_code_b64 = None
3696
+ if args.inputs:
3697
+ inputs_code_b64 = base64.b64encode(args.inputs.read_bytes()).decode()
3698
+
3699
+ # Encode defense module if defensive mode is enabled
3700
+ defense_code_b64 = None
3701
+ if args.defensive:
3702
+ defense_path = (
3703
+ Path(__file__).parent.parent.parent.parent
3704
+ / "packages"
3705
+ / "wafer-core"
3706
+ / "wafer_core"
3707
+ / "utils"
3708
+ / "kernel_utils"
3709
+ / "defense.py"
3710
+ )
3711
+ if defense_path.exists():
3712
+ defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
3713
+ else:
3714
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
3715
+
3716
+ # Build the script
3717
+ script = _build_modal_kernelbench_script(
3718
+ target=target,
3719
+ impl_code_b64=impl_code_b64,
3720
+ ref_code_b64=ref_code_b64,
3721
+ eval_script_b64=eval_script_b64,
3722
+ run_benchmarks=args.benchmark,
3723
+ run_defensive=args.defensive,
3724
+ defense_code_b64=defense_code_b64,
3725
+ seed=args.seed,
3726
+ inputs_code_b64=inputs_code_b64,
3727
+ )
3728
+
3729
+ def _run_subprocess() -> tuple[str, str, int]:
3730
+ result = subprocess.run(
3731
+ [sys.executable, "-c", script],
3732
+ capture_output=True,
3733
+ text=True,
3734
+ timeout=target.timeout_seconds + 120, # Extra buffer for sandbox creation + image build
3735
+ )
3736
+ return result.stdout, result.stderr, result.returncode
3737
+
3738
+ try:
3739
+ stdout, stderr, returncode = await trio.to_thread.run_sync(_run_subprocess)
3740
+ except subprocess.TimeoutExpired:
3741
+ return EvaluateResult(
3742
+ success=False,
3743
+ all_correct=False,
3744
+ correctness_score=0.0,
3745
+ geomean_speedup=0.0,
3746
+ passed_tests=0,
3747
+ total_tests=0,
3748
+ error_message=f"Modal KernelBench evaluation timed out after {target.timeout_seconds}s",
3749
+ )
3750
+ except Exception as e:
3751
+ return EvaluateResult(
3752
+ success=False,
3753
+ all_correct=False,
3754
+ correctness_score=0.0,
3755
+ geomean_speedup=0.0,
3756
+ passed_tests=0,
3757
+ total_tests=0,
3758
+ error_message=f"Failed to run Modal sandbox: {e}",
3759
+ )
3760
+
3761
+ # Print output for debugging
3762
+ if stdout:
3763
+ for line in stdout.split("\n"):
3764
+ if not line.startswith("EVAL_RESULT_JSON:"):
3765
+ print(line)
3766
+ if stderr:
3767
+ print(stderr, file=sys.stderr)
3768
+
3769
+ if returncode != 0:
3770
+ return EvaluateResult(
3771
+ success=False,
3772
+ all_correct=False,
3773
+ correctness_score=0.0,
3774
+ geomean_speedup=0.0,
3775
+ passed_tests=0,
3776
+ total_tests=0,
3777
+ error_message=f"Modal sandbox failed (exit {returncode}): {stderr or stdout}",
3778
+ )
3779
+
3780
+ # Parse results from stdout
3781
+ result_json = None
3782
+ for line in stdout.split("\n"):
3783
+ if line.startswith("EVAL_RESULT_JSON:"):
3784
+ result_json = line[len("EVAL_RESULT_JSON:"):]
3785
+ break
3786
+
3787
+ if not result_json:
3788
+ return EvaluateResult(
3789
+ success=False,
3790
+ all_correct=False,
3791
+ correctness_score=0.0,
3792
+ geomean_speedup=0.0,
3793
+ passed_tests=0,
3794
+ total_tests=0,
3795
+ error_message="No results found in Modal output",
3796
+ )
3797
+
3798
+ try:
3799
+ results = json.loads(result_json)
3800
+ except json.JSONDecodeError as e:
3801
+ return EvaluateResult(
3802
+ success=False,
3803
+ all_correct=False,
3804
+ correctness_score=0.0,
3805
+ geomean_speedup=0.0,
3806
+ passed_tests=0,
3807
+ total_tests=0,
3808
+ error_message=f"Failed to parse results JSON: {e}",
3809
+ )
3810
+
3811
+ # Check for error in results
3812
+ if "error" in results and results.get("success") is False:
3813
+ return EvaluateResult(
3814
+ success=False,
3815
+ all_correct=False,
3816
+ correctness_score=0.0,
3817
+ geomean_speedup=0.0,
3818
+ passed_tests=0,
3819
+ total_tests=0,
3820
+ error_message=results.get("error", "Unknown error"),
3821
+ )
3822
+
3823
+ # Extract metrics from results
3824
+ return EvaluateResult(
3825
+ success=True,
3826
+ all_correct=results.get("all_correct", False),
3827
+ correctness_score=float(results.get("correctness_score", 0.0)),
3828
+ geomean_speedup=float(results.get("geomean_speedup", 0.0)),
3829
+ passed_tests=int(results.get("passed_tests", 0)),
3830
+ total_tests=int(results.get("total_tests", 0)),
3831
+ error_message=results.get("error"),
3832
+ test_results=results.get("test_results", []),
3833
+ compilation_time_s=results.get("compilation_time_s"),
3834
+ profiling_stats=results.get("profiling_stats"),
3835
+ )
3836
+
3837
+
3456
3838
  async def run_evaluate_kernelbench_docker(
3457
3839
  args: KernelBenchEvaluateArgs,
3458
3840
  target: BaremetalTarget | VMTarget,
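The sandbox script above reports results by printing a single sentinel-prefixed JSON line into its log output; a minimal sketch of both sides of that protocol, using the marker from the code above:

import json

# Sandbox side: one machine-readable line among ordinary log output.
results = {"all_correct": True, "passed_tests": 5, "total_tests": 5}
print("EVAL_RESULT_JSON:" + json.dumps(results))

# CLI side: scan captured stdout for the marker and parse the payload.
def parse_results(stdout: str) -> dict | None:
    for line in stdout.splitlines():
        if line.startswith("EVAL_RESULT_JSON:"):
            return json.loads(line[len("EVAL_RESULT_JSON:"):])
    return None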
@@ -4246,6 +4628,20 @@ async def run_evaluate_kernelbench_runpod(
4246
4628
  )
4247
4629
 
4248
4630
 
4631
+ async def run_evaluate_kernelbench_baremetal_direct(
4632
+ args: KernelBenchEvaluateArgs,
4633
+ target: BaremetalTarget,
4634
+ ) -> EvaluateResult:
4635
+ """Run KernelBench format evaluation directly on NVIDIA target (no Docker).
4636
+
4637
+ For targets that already have PyTorch/CUDA installed (e.g., workspace containers).
4638
+ Uses CUDA_VISIBLE_DEVICES for GPU selection.
4639
+ """
4640
+ # Reuse the AMD function but with CUDA env vars
4641
+ # The logic is identical, just the GPU env var is different
4642
+ return await _run_evaluate_kernelbench_baremetal_direct_impl(args, target, gpu_env_var="CUDA_VISIBLE_DEVICES")
4643
+
4644
+
4249
4645
  async def run_evaluate_kernelbench_baremetal_amd(
4250
4646
  args: KernelBenchEvaluateArgs,
4251
4647
  target: BaremetalTarget,
@@ -4255,6 +4651,18 @@ async def run_evaluate_kernelbench_baremetal_amd(
4255
4651
  Runs evaluation script directly on host (no Docker) for AMD GPUs
4256
4652
  that have PyTorch/ROCm installed.
4257
4653
  """
4654
+ return await _run_evaluate_kernelbench_baremetal_direct_impl(args, target, gpu_env_var="HIP_VISIBLE_DEVICES")
4655
+
4656
+
4657
+ async def _run_evaluate_kernelbench_baremetal_direct_impl(
4658
+ args: KernelBenchEvaluateArgs,
4659
+ target: BaremetalTarget,
4660
+ gpu_env_var: str = "HIP_VISIBLE_DEVICES",
4661
+ ) -> EvaluateResult:
4662
+ """Internal implementation for direct baremetal evaluation.
4663
+
4664
+ Runs evaluation script directly on host (no Docker).
4665
+ """
4258
4666
  from datetime import datetime
4259
4667
 
4260
4668
  from wafer_core.async_ssh import AsyncSSHClient
@@ -4405,11 +4813,15 @@ async def run_evaluate_kernelbench_baremetal_amd(
4405
4813
 
4406
4814
  eval_cmd = " ".join(python_cmd_parts)
4407
4815
 
4408
- # Set environment for AMD GPU and run
4409
- # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
4410
- rocm_arch = _get_rocm_arch(target.compute_capability)
4411
- arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
4412
- env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
4816
+ # Set environment for GPU and run
4817
+ if gpu_env_var == "HIP_VISIBLE_DEVICES":
4818
+ # AMD: PYTORCH_ROCM_ARCH for faster compile
4819
+ rocm_arch = _get_rocm_arch(target.compute_capability)
4820
+ arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
4821
+ env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
4822
+ else:
4823
+ # NVIDIA: just set CUDA_VISIBLE_DEVICES
4824
+ env_vars = f"CUDA_VISIBLE_DEVICES={gpu_id} PYTHONUNBUFFERED=1"
4413
4825
  full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
4414
4826
 
4415
4827
  # Handle prepare-only mode
@@ -4560,10 +4972,16 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
4560
4972
  elif isinstance(target, RunPodTarget):
4561
4973
  # RunPod AMD MI300X - uses ROCm Docker with device passthrough
4562
4974
  return await run_evaluate_kernelbench_runpod(args, target)
4975
+ elif isinstance(target, ModalTarget):
4976
+ # Modal serverless - runs in Modal sandbox
4977
+ return await run_evaluate_kernelbench_modal(args, target)
4563
4978
  elif isinstance(target, BaremetalTarget | VMTarget):
4564
4979
  # Check if this is an AMD target (gfx* compute capability) - run directly
4565
4980
  if target.compute_capability and target.compute_capability.startswith("gfx"):
4566
4981
  return await run_evaluate_kernelbench_baremetal_amd(args, target)
4982
+ # Check for direct execution flag (workspace containers that already have everything)
4983
+ if getattr(target, "direct", False):
4984
+ return await run_evaluate_kernelbench_baremetal_direct(args, target)
4567
4985
  # NVIDIA targets - require docker_image to be set
4568
4986
  if not target.docker_image:
4569
4987
  return EvaluateResult(
wafer/templates/optimize_kernel.py CHANGED
@@ -68,4 +68,6 @@ IMPORTANT: Always verify correctness with wafer evaluate before claiming success
68
68
  "kernel": "./kernel.cu",
69
69
  "target": "H100",
70
70
  },
71
+ # Enable skill discovery (agent can load wafer-guide, etc.)
72
+ include_skills=True,
71
73
  )
wafer/wevin_cli.py CHANGED
@@ -274,7 +274,12 @@ def _build_environment(
274
274
  from wafer_core.sandbox import SandboxMode
275
275
 
276
276
  working_dir = Path(corpus_path) if corpus_path else Path.cwd()
277
- resolved_tools = tools_override or tpl.tools
277
+ resolved_tools = list(tools_override or tpl.tools)
278
+
279
+ # Add skill tool if skills are enabled
280
+ if tpl.include_skills and "skill" not in resolved_tools:
281
+ resolved_tools.append("skill")
282
+
278
283
  sandbox_mode = SandboxMode.DISABLED if no_sandbox else SandboxMode.ENABLED
279
284
  env: Environment = CodingEnvironment(
280
285
  working_dir=working_dir,
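A small note on the list() copy introduced above: without it, appending "skill" would mutate the template's shared tools list and leak into later invocations; a minimal illustration with hypothetical tool names:

template_tools = ["bash", "edit"]   # stands in for tpl.tools

resolved = list(template_tools)     # copy, as in the change above
resolved.append("skill")
print(template_tools)               # ['bash', 'edit']  (unchanged)

aliased = template_tools            # no copy
aliased.append("skill")
print(template_tools)               # ['bash', 'edit', 'skill']  (mutated)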
@@ -378,6 +383,7 @@ def main( # noqa: PLR0913, PLR0915
378
383
 
379
384
  # Handle --get-session: load session by ID and print
380
385
  if get_session:
386
+
381
387
  async def _get_session() -> None:
382
388
  try:
383
389
  session, err = await session_store.get(get_session)
@@ -398,16 +404,18 @@ def main( # noqa: PLR0913, PLR0915
398
404
  error_msg = f"Failed to serialize messages: {e}"
399
405
  print(json.dumps({"error": error_msg}))
400
406
  sys.exit(1)
401
-
402
- print(json.dumps({
403
- "session_id": session.session_id,
404
- "status": session.status.value,
405
- "model": session.endpoint.model if session.endpoint else None,
406
- "created_at": session.created_at,
407
- "updated_at": session.updated_at,
408
- "messages": messages_data,
409
- "tags": session.tags,
410
- }))
407
+
408
+ print(
409
+ json.dumps({
410
+ "session_id": session.session_id,
411
+ "status": session.status.value,
412
+ "model": session.endpoint.model if session.endpoint else None,
413
+ "created_at": session.created_at,
414
+ "updated_at": session.updated_at,
415
+ "messages": messages_data,
416
+ "tags": session.tags,
417
+ })
418
+ )
411
419
  else:
412
420
  print(f"Session: {session.session_id}")
413
421
  print(f"Status: {session.status.value}")
@@ -495,7 +503,7 @@ def main( # noqa: PLR0913, PLR0915
495
503
  print(f"Error loading template: {err}", file=sys.stderr)
496
504
  sys.exit(1)
497
505
  tpl = loaded_template
498
- system_prompt = tpl.interpolate_prompt(template_args or {})
506
+ base_system_prompt = tpl.interpolate_prompt(template_args or {})
499
507
  # Show template info when starting without a prompt
500
508
  if not prompt and tpl.description:
501
509
  print(f"Template: {tpl.name}", file=sys.stderr)
@@ -503,7 +511,20 @@ def main( # noqa: PLR0913, PLR0915
503
511
  print(file=sys.stderr)
504
512
  else:
505
513
  tpl = _get_default_template()
506
- system_prompt = tpl.system_prompt
514
+ base_system_prompt = tpl.system_prompt
515
+
516
+ # Append skill metadata if skills are enabled
517
+ if tpl.include_skills:
518
+ from wafer_core.rollouts.skills import discover_skills, format_skill_metadata_for_prompt
519
+
520
+ skill_metadata = discover_skills()
521
+ if skill_metadata:
522
+ skill_section = format_skill_metadata_for_prompt(skill_metadata)
523
+ system_prompt = base_system_prompt + "\n\n" + skill_section
524
+ else:
525
+ system_prompt = base_system_prompt
526
+ else:
527
+ system_prompt = base_system_prompt
507
528
 
508
529
  # CLI args override template values
509
530
  resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn
@@ -550,7 +571,7 @@ def main( # noqa: PLR0913, PLR0915
550
571
  else:
551
572
  if json_output:
552
573
  # Emit session_start if we have a session_id (from --resume)
553
- model_name = endpoint.model if hasattr(endpoint, 'model') else None
574
+ model_name = endpoint.model if hasattr(endpoint, "model") else None
554
575
  frontend = StreamingChunkFrontend(session_id=session_id, model=model_name)
555
576
  else:
556
577
  frontend = NoneFrontend(show_tool_calls=True, show_thinking=False)
@@ -565,9 +586,11 @@ def main( # noqa: PLR0913, PLR0915
565
586
  # Emit session_start for new sessions (if session_id was None and we got one)
566
587
  # Check first state to emit as early as possible
567
588
  if json_output and isinstance(frontend, StreamingChunkFrontend):
568
- first_session_id = states[0].session_id if states and states[0].session_id else None
589
+ first_session_id = (
590
+ states[0].session_id if states and states[0].session_id else None
591
+ )
569
592
  if first_session_id and not session_id: # New session created
570
- model_name = endpoint.model if hasattr(endpoint, 'model') else None
593
+ model_name = endpoint.model if hasattr(endpoint, "model") else None
571
594
  frontend.emit_session_start(first_session_id, model_name)
572
595
  # Print resume command with full wafer agent prefix
573
596
  if states and states[-1].session_id:
wafer_cli-0.2.20.dist-info/METADATA → wafer_cli-0.2.22.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-cli
3
- Version: 0.2.20
3
+ Version: 0.2.22
4
4
  Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
5
5
  Requires-Python: >=3.11
6
6
  Requires-Dist: typer>=0.12.0
wafer_cli-0.2.20.dist-info/RECORD → wafer_cli-0.2.22.dist-info/RECORD CHANGED
@@ -5,10 +5,10 @@ wafer/api_client.py,sha256=i_Az2b2llC3DSW8yOL-BKqa7LSKuxOr8hSN40s-oQXY,6313
5
5
  wafer/auth.py,sha256=dwss_se5P-FFc9IN38q4kh_dBrA6k-CguDBkivgcdj0,14003
6
6
  wafer/autotuner.py,sha256=41WYP41pTDvMijv2h42vm89bcHtDMJXObDlWmn6xpFU,44416
7
7
  wafer/billing.py,sha256=jbLB2lI4_9f2KD8uEFDi_ixLlowe5hasC0TIZJyIXRg,7163
8
- wafer/cli.py,sha256=cNScdwOsyaSHnaRPtzSIcES6IEx4kWpMqMpZMIbrp3g,254768
8
+ wafer/cli.py,sha256=j4ODOVT_r-kyc21YOI8Yl8bkiZMGuqDpXRs7CvpNaek,261443
9
9
  wafer/config.py,sha256=h5Eo9_yfWqWGoPNdVQikI9GoZVUeysunSYiixf1mKcw,3411
10
- wafer/corpus.py,sha256=x5aFhCsTSAtgzFG9AMFpqq92Ej63mXofL-vvvpjj1sM,12913
11
- wafer/evaluate.py,sha256=s1NszUBtxdWRonbi8YR3XWfCiCjNm14g2Pp1lu4kmtY,176125
10
+ wafer/corpus.py,sha256=oQegXA43MuyRvYxOsWhmqeP5vMb5IKFHOvM-1RcahPA,22301
11
+ wafer/evaluate.py,sha256=SxxhiPkO6aDdfktRzJXpbWMVmIGn_gw-o5C6Zwj2zRc,190930
12
12
  wafer/global_config.py,sha256=fhaR_RU3ufMksDmOohH1OLeQ0JT0SDW1hEip_zaP75k,11345
13
13
  wafer/gpu_run.py,sha256=TwqXy72T7f2I7e6n5WWod3xgxCPnDhU0BgLsB4CUoQY,9716
14
14
  wafer/inference.py,sha256=tZCO5i05FKY27ewis3CSBHFBeFbXY3xwj0DSjdoMY9s,4314
@@ -26,16 +26,16 @@ wafer/target_lock.py,sha256=SDKhNzv2N7gsphGflcNni9FE5YYuAMuEthngAJEo4Gs,7809
26
26
  wafer/targets.py,sha256=9r-iRWoKSH5cQl1LcamaX-T7cNVOg99ngIm_hlRk-qU,26922
27
27
  wafer/targets_ops.py,sha256=jN1oIBx0mutxRNE9xpIc7SaBxPkVmOyus2eqn0kEKNI,21475
28
28
  wafer/tracelens.py,sha256=g9ZIeFyNojZn4uTd3skPqIrRiL7aMJOz_-GOd3aiyy4,7998
29
- wafer/wevin_cli.py,sha256=VnGVt__7kpVe2n_UctURSIpael_2TgsAwmqoQjz6CN0,22412
29
+ wafer/wevin_cli.py,sha256=Nuk7zTCiJrnpmYtdg5Hu0NbzONCqs54xtON6K7AVB9U,23189
30
30
  wafer/workspaces.py,sha256=iUdioK7kA3z_gOTMNVDn9Q87c6qpkdXF4bOhJWkUPg8,32375
31
31
  wafer/skills/wafer-guide/SKILL.md,sha256=KWetJw2TVTbz11_nzqazqOJWWRlbHRFShs4sOoreiWo,3255
32
32
  wafer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
33
  wafer/templates/ask_docs.py,sha256=Lxs-faz9v5m4Qa4NjF2X_lE8KwM9ES9MNJkxo7ep56o,2256
34
- wafer/templates/optimize_kernel.py,sha256=u6AL7Q3uttqlnBLzcoFdsiPq5lV2TV3bgqwCYYlK9gk,2357
34
+ wafer/templates/optimize_kernel.py,sha256=OvZgN5tm_OymO3lK8Dr0VO48e-5PfNVIIoACrPxpmqk,2446
35
35
  wafer/templates/optimize_kernelbench.py,sha256=aoOA13zWEl89r6QW03xF9NKxQ7j4mWe9rwua6-mlr4Y,4780
36
36
  wafer/templates/trace_analyze.py,sha256=XE1VqzVkIUsZbXF8EzQdDYgg-AZEYAOFpr6B_vnRELc,2880
37
- wafer_cli-0.2.20.dist-info/METADATA,sha256=rZ94ea_wCkSGAhT0X1wN9DFhCr5ojeXucvROQLX0Ox4,560
38
- wafer_cli-0.2.20.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
39
- wafer_cli-0.2.20.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
40
- wafer_cli-0.2.20.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
41
- wafer_cli-0.2.20.dist-info/RECORD,,
37
+ wafer_cli-0.2.22.dist-info/METADATA,sha256=vjYzyQtphWxQ0JID0k5tFWoLwVjlR6X0B4UAuMhLhQc,560
38
+ wafer_cli-0.2.22.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
39
+ wafer_cli-0.2.22.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
40
+ wafer_cli-0.2.22.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
41
+ wafer_cli-0.2.22.dist-info/RECORD,,