wafer-cli 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/cli.py +205 -22
- wafer/corpus.py +241 -9
- wafer/evaluate.py +426 -8
- wafer/templates/optimize_kernel.py +2 -0
- wafer/wevin_cli.py +39 -16
- {wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/RECORD +10 -10
- {wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/top_level.txt +0 -0
wafer/cli.py
CHANGED
@@ -1,6 +1,8 @@
-# ruff: noqa: PLR0913
+# ruff: noqa: PLR0913, E402
 # PLR0913 (too many arguments) is suppressed because Typer CLI commands
 # naturally have many parameters - each --flag becomes a function argument.
+# E402 (module level import not at top) is suppressed because we intentionally
+# load .env files before importing other modules that may read env vars.
 """Wafer CLI - GPU development toolkit for LLM coding agents.

 Core commands:
@@ -27,6 +29,12 @@ from pathlib import Path

 import trio
 import typer
+from dotenv import load_dotenv
+
+# Auto-load .env from current directory and ~/.wafer/.env
+# This runs at import time so env vars are available before any config is accessed
+load_dotenv()  # cwd/.env
+load_dotenv(Path.home() / ".wafer" / ".env")  # ~/.wafer/.env

 from .config import WaferConfig, WaferEnvironment
 from .inference import infer_upload_files, resolve_environment
@@ -42,6 +50,7 @@ from .problems import (
 app = typer.Typer(
     help="GPU development toolkit for LLM coding agents",
     no_args_is_help=True,
+    pretty_exceptions_show_locals=False,  # Don't dump local vars (makes tracebacks huge)
 )

 # =============================================================================
@@ -58,11 +67,11 @@ def _show_version() -> None:
     """Show CLI version and environment, then exit."""
     from .analytics import _get_cli_version
     from .global_config import load_global_config
-
+
     version = _get_cli_version()
     config = load_global_config()
     environment = config.environment
-
+
     typer.echo(f"wafer-cli {version} ({environment})")
     raise typer.Exit()

@@ -110,7 +119,7 @@ def main_callback(
     if version:
         _show_version()
         return
-
+
     global _command_start_time, _command_outcome
     _command_start_time = time.time()
     _command_outcome = "success"  # Default to success, mark failure on exceptions
@@ -121,6 +130,7 @@ def main_callback(
         analytics.init_analytics()

     # Install exception hook to catch SystemExit and mark failures
+    # Also prints error message FIRST so it's visible even when traceback is truncated
     original_excepthook = sys.excepthook

     def custom_excepthook(
@@ -136,7 +146,11 @@ def main_callback(
             _command_outcome = "failure"
         else:
             _command_outcome = "failure"
-
+        # Print error summary FIRST (before traceback) so it's visible even if truncated
+        print(
+            f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
+        )
+        # Call original excepthook (prints the full traceback)
         original_excepthook(exc_type, exc_value, exc_traceback)

     sys.excepthook = custom_excepthook
@@ -591,7 +605,7 @@ app.add_typer(provider_auth_app, name="auth")
 def provider_auth_login(
     provider: str = typer.Argument(
         ...,
-        help="Provider name: runpod, digitalocean, or
+        help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
     ),
     api_key: str | None = typer.Option(
         None,
@@ -600,15 +614,16 @@ def provider_auth_login(
         help="API key (if not provided, reads from stdin)",
     ),
 ) -> None:
-    """Save API key for a
+    """Save API key for a provider.

     Stores the key in ~/.wafer/auth.json. Environment variables
-    (e.g.,
+    (e.g., ANTHROPIC_API_KEY) take precedence over stored keys.

     Examples:
+        wafer auth login anthropic --api-key sk-ant-xxx
         wafer auth login runpod --api-key rp_xxx
-        wafer auth login
-        echo $API_KEY | wafer auth login
+        wafer auth login openai --api-key sk-xxx
+        echo $API_KEY | wafer auth login anthropic
     """
     import sys

@@ -642,7 +657,7 @@ def provider_auth_login(
 def provider_auth_logout(
     provider: str = typer.Argument(
         ...,
-        help="Provider name: runpod, digitalocean, or
+        help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
     ),
 ) -> None:
     """Remove stored API key for a cloud GPU provider.
@@ -3473,7 +3488,7 @@ def init_runpod(
     gpu_configs = {
         "MI300X": {
             "gpu_type_id": "AMD Instinct MI300X OAM",
-            "image": "
+            "image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
             "compute_capability": "9.4",
         },
         "H100": {
@@ -3569,7 +3584,7 @@ def init_digitalocean(
         "ssh_key": ssh_key,
         "region": region,
         "size_slug": "gpu-mi300x1-192gb-devcloud",
-        "image": "
+        "image": "amd-pytorchrocm7",  # PyTorch (ROCm7) marketplace image
         "provision_timeout": 600,
         "eval_timeout": 600,
         "keep_alive": keep_alive,
@@ -4071,6 +4086,164 @@ def targets_cleanup(
         raise typer.Exit(1) from None


+# Known libraries that can be installed on targets
+# TODO: Consider adding HipKittens to the default RunPod/DO Docker images
+# so this install step isn't needed. For now, this command handles it.
+INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
+    "hipkittens": {
+        "description": "HipKittens - AMD port of ThunderKittens for MI300X",
+        "git_url": "https://github.com/HazyResearch/hipkittens.git",
+        "install_path": "/opt/hipkittens",
+        "requires_amd": True,
+    },
+    # CK is already installed with ROCm 7.0, no action needed
+    "repair-headers": {
+        "description": "Repair ROCm thrust headers (fixes hipify corruption)",
+        "custom_script": "apt-get update -qq && apt-get install --reinstall -y rocthrust >/dev/null 2>&1 && echo REPAIRED",
+        "requires_amd": True,
+    },
+}
+
+
+@targets_app.command("install")
+def targets_install(
+    name: str = typer.Argument(..., help="Target name"),
+    library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
+) -> None:
+    """Install a library or run maintenance on a target (idempotent).
+
+    Installs header-only libraries like HipKittens on remote targets.
+    Safe to run multiple times - will skip if already installed.
+
+    Available libraries:
+        hipkittens - HipKittens (AMD ThunderKittens port)
+        repair-headers - Fix ROCm thrust headers (after hipify corruption)
+
+    Examples:
+        wafer config targets install runpod-mi300x hipkittens
+        wafer config targets install runpod-mi300x repair-headers
+        wafer config targets install do-mi300x hipkittens
+    """
+    import subprocess
+
+    from .targets import load_target
+    from .targets_ops import get_target_ssh_info
+
+    if library not in INSTALLABLE_LIBRARIES:
+        available = ", ".join(INSTALLABLE_LIBRARIES.keys())
+        typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
+        raise typer.Exit(1)
+
+    lib_info = INSTALLABLE_LIBRARIES[library]
+
+    try:
+        target = load_target(name)
+    except FileNotFoundError as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(1) from None
+
+    # Check if target is AMD (for AMD-only libraries)
+    if lib_info.get("requires_amd"):
+        from wafer_core.utils.kernel_utils.targets.config import (
+            DigitalOceanTarget,
+            RunPodTarget,
+        )
+
+        is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
+        if not is_amd and hasattr(target, "compute_capability"):
+            # Check compute capability for MI300X (gfx942 = 9.4)
+            is_amd = target.compute_capability.startswith("9.")
+        if not is_amd:
+            typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
+            raise typer.Exit(1)
+
+    typer.echo(f"Installing {library} on {name}...")
+    typer.echo(f"  {lib_info['description']}")
+
+    async def _install() -> bool:
+        # get_target_ssh_info uses pure trio async (no asyncio bridging needed)
+        # and we use subprocess for SSH, not AsyncSSHClient
+        ssh_info = await get_target_ssh_info(target)
+
+        ssh_cmd = [
+            "ssh",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "UserKnownHostsFile=/dev/null",
+            "-o",
+            "ConnectTimeout=30",
+            "-i",
+            str(ssh_info.key_path),
+            "-p",
+            str(ssh_info.port),
+            f"{ssh_info.user}@{ssh_info.host}",
+        ]
+
+        # Handle custom scripts (like repair-headers) vs git installs
+        if "custom_script" in lib_info:
+            install_script = str(lib_info["custom_script"])
+            success_marker = "REPAIRED"
+        else:
+            install_path = lib_info["install_path"]
+            git_url = lib_info["git_url"]
+
+            # Idempotent install script
+            install_script = f"""
+            if [ -d "{install_path}" ]; then
+                echo "ALREADY_INSTALLED: {install_path} exists"
+                cd {install_path} && git pull --quiet 2>/dev/null || true
+            else
+                echo "INSTALLING: cloning to {install_path}"
+                git clone --quiet {git_url} {install_path}
+            fi
+            echo "DONE"
+            """
+            success_marker = "DONE"
+
+        def run_ssh() -> subprocess.CompletedProcess[str]:
+            return subprocess.run(
+                ssh_cmd + [install_script],
+                capture_output=True,
+                text=True,
+                timeout=120,
+            )
+
+        result = await trio.to_thread.run_sync(run_ssh)
+
+        if result.returncode != 0:
+            typer.echo(f"Error: {result.stderr}", err=True)
+            return False
+
+        output = result.stdout.strip()
+        if "ALREADY_INSTALLED" in output:
+            typer.echo(f"  Already installed at {lib_info.get('install_path', 'N/A')}")
+        elif "INSTALLING" in output:
+            typer.echo(f"  Installed to {lib_info.get('install_path', 'N/A')}")
+        elif "REPAIRED" in output:
+            typer.echo("  ROCm headers repaired")
+
+        return success_marker in output
+
+    try:
+        success = trio.run(_install)
+    except Exception as e:
+        typer.echo(f"Error: {e}", err=True)
+        raise typer.Exit(1) from None
+
+    if success:
+        typer.echo(f"✓ {library} ready on {name}")
+
+        # Print usage hint
+        if library == "hipkittens":
+            typer.echo("")
+            typer.echo("Usage in load_inline:")
+            typer.echo(' extra_include_paths=["/opt/hipkittens/include", "/opt/rocm/include/hip"]')
+    else:
+        typer.echo(f"Failed to install {library}", err=True)
+        raise typer.Exit(1)
+
+
 @targets_app.command("pods")
 def targets_pods() -> None:
     """List all running RunPod pods.
@@ -4406,9 +4579,13 @@ def workspaces_list(
 @workspaces_app.command("create")
 def workspaces_create(
     name: str = typer.Argument(..., help="Workspace name"),
-    gpu_type: str = typer.Option(
+    gpu_type: str = typer.Option(
+        "B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"
+    ),
     image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
-    wait: bool = typer.Option(
+    wait: bool = typer.Option(
+        False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"
+    ),
     json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
 ) -> None:
     """Create a new workspace.
@@ -4717,19 +4894,25 @@ def workspaces_ssh(
     ssh_host = ws.get("ssh_host")
     ssh_port = ws.get("ssh_port")
     ssh_user = ws.get("ssh_user")
-
+
     if not ssh_host or not ssh_port or not ssh_user:
         typer.echo("Error: Workspace not ready. Wait a few seconds and retry.", err=True)
         raise typer.Exit(1)

     # Connect via SSH
-    os.execvp(
+    os.execvp(
         "ssh",
-
-
-
-
-
+        [
+            "ssh",
+            "-p",
+            str(ssh_port),
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "UserKnownHostsFile=/dev/null",
+            f"{ssh_user}@{ssh_host}",
+        ],
+    )


 @workspaces_app.command("sync")
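Note on the `.env` auto-loading added above: the sketch below shows the resulting precedence, assuming python-dotenv's default `override=False` behaviour (the key name is only an example; any variable already exported in the shell wins over both files).

```python
# Sketch: precedence of the import-time .env loading in wafer/cli.py.
# load_dotenv() never overwrites variables that are already set, so:
#   shell environment  >  cwd/.env  >  ~/.wafer/.env
import os
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()                                 # 1. cwd/.env fills unset keys
load_dotenv(Path.home() / ".wafer" / ".env")  # 2. ~/.wafer/.env fills what is still unset

print(os.environ.get("ANTHROPIC_API_KEY"))    # example key name, not mandated by the CLI
```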
wafer/corpus.py
CHANGED
@@ -3,10 +3,12 @@
 Download and manage documentation corpora for agent filesystem access.
 """

+import re
 import shutil
 import tarfile
 import tempfile
 from dataclasses import dataclass
+from html.parser import HTMLParser
 from pathlib import Path
 from typing import Literal
 from urllib.parse import urlparse
@@ -33,7 +35,7 @@ class CorpusConfig:

     name: CorpusName
     description: str
-    source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
+    source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
     urls: list[str] | None = None
     repo: str | None = None
     repo_paths: list[str] | None = None
@@ -67,10 +69,43 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
     ),
     "cutlass": CorpusConfig(
         name="cutlass",
-        description="CUTLASS
-        source_type="
-
-
+        description="CUTLASS C++ documentation, examples, and tutorials",
+        source_type="mixed",
+        # Official NVIDIA CUTLASS documentation (scraped as markdown)
+        urls=[
+            "https://docs.nvidia.com/cutlass/latest/overview.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
+            "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
+        ],
+        # NVIDIA/cutlass GitHub examples (excluding python/)
+        repos=[
+            RepoSource(
+                repo="NVIDIA/cutlass",
+                paths=["examples"],
+                branch="main",
+            ),
+        ],
     ),
     "hip": CorpusConfig(
         name="hip",
@@ -169,19 +204,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
     return base_dir / "/".join(path_parts)


+class _HTMLToMarkdown(HTMLParser):
+    """HTML to Markdown converter for NVIDIA documentation pages.
+
+    Uses stdlib HTMLParser - requires subclassing due to callback-based API.
+    The public interface is the functional `_html_to_markdown()` below.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.output: list[str] = []
+        self.current_tag: str = ""
+        self.in_code_block = False
+        self.in_pre = False
+        self.list_depth = 0
+        self.ordered_list_counters: list[int] = []
+        self.skip_content = False
+        self.link_href: str | None = None
+
+    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
+        self.current_tag = tag
+        attrs_dict = dict(attrs)
+
+        # Skip script, style, nav, footer, header
+        if tag in ("script", "style", "nav", "footer", "header", "aside"):
+            self.skip_content = True
+            return
+
+        if tag == "h1":
+            self.output.append("\n# ")
+        elif tag == "h2":
+            self.output.append("\n## ")
+        elif tag == "h3":
+            self.output.append("\n### ")
+        elif tag == "h4":
+            self.output.append("\n#### ")
+        elif tag == "h5":
+            self.output.append("\n##### ")
+        elif tag == "h6":
+            self.output.append("\n###### ")
+        elif tag == "p":
+            self.output.append("\n\n")
+        elif tag == "br":
+            self.output.append("\n")
+        elif tag == "strong" or tag == "b":
+            self.output.append("**")
+        elif tag == "em" or tag == "i":
+            self.output.append("*")
+        elif tag == "code" and not self.in_pre:
+            self.output.append("`")
+            self.in_code_block = True
+        elif tag == "pre":
+            self.in_pre = True
+            # Check for language hint in class
+            lang = ""
+            if class_attr := attrs_dict.get("class"):
+                if "python" in class_attr.lower():
+                    lang = "python"
+                elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
+                    lang = "cpp"
+                elif "cuda" in class_attr.lower():
+                    lang = "cuda"
+            self.output.append(f"\n```{lang}\n")
+        elif tag == "ul":
+            self.list_depth += 1
+            self.output.append("\n")
+        elif tag == "ol":
+            self.list_depth += 1
+            self.ordered_list_counters.append(1)
+            self.output.append("\n")
+        elif tag == "li":
+            indent = " " * (self.list_depth - 1)
+            if self.ordered_list_counters:
+                num = self.ordered_list_counters[-1]
+                self.output.append(f"{indent}{num}. ")
+                self.ordered_list_counters[-1] += 1
+            else:
+                self.output.append(f"{indent}- ")
+        elif tag == "a":
+            self.link_href = attrs_dict.get("href")
+            self.output.append("[")
+        elif tag == "img":
+            alt = attrs_dict.get("alt", "image")
+            src = attrs_dict.get("src", "")
+            self.output.append(f"![{alt}]({src})")
+        elif tag == "blockquote":
+            self.output.append("\n> ")
+        elif tag == "hr":
+            self.output.append("\n---\n")
+        elif tag == "table":
+            self.output.append("\n")
+        elif tag == "th":
+            self.output.append("| ")
+        elif tag == "td":
+            self.output.append("| ")
+        elif tag == "tr":
+            pass  # Handled in endtag
+
+    def handle_endtag(self, tag: str) -> None:
+        if tag in ("script", "style", "nav", "footer", "header", "aside"):
+            self.skip_content = False
+            return
+
+        if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
+            self.output.append("\n")
+        elif tag == "strong" or tag == "b":
+            self.output.append("**")
+        elif tag == "em" or tag == "i":
+            self.output.append("*")
+        elif tag == "code" and not self.in_pre:
+            self.output.append("`")
+            self.in_code_block = False
+        elif tag == "pre":
+            self.in_pre = False
+            self.output.append("\n```\n")
+        elif tag == "ul":
+            self.list_depth = max(0, self.list_depth - 1)
+        elif tag == "ol":
+            self.list_depth = max(0, self.list_depth - 1)
+            if self.ordered_list_counters:
+                self.ordered_list_counters.pop()
+        elif tag == "li":
+            self.output.append("\n")
+        elif tag == "a":
+            if self.link_href:
+                self.output.append(f"]({self.link_href})")
+            else:
+                self.output.append("]")
+            self.link_href = None
+        elif tag == "p":
+            self.output.append("\n")
+        elif tag == "blockquote":
+            self.output.append("\n")
+        elif tag == "tr":
+            self.output.append("|\n")
+        elif tag == "thead":
+            # Add markdown table separator after header row
+            self.output.append("|---" * 10 + "|\n")
+
+    def handle_data(self, data: str) -> None:
+        if self.skip_content:
+            return
+        # Preserve whitespace in code blocks
+        if self.in_pre:
+            self.output.append(data)
+        else:
+            # Collapse whitespace outside code
+            text = re.sub(r"\s+", " ", data)
+            if text.strip():
+                self.output.append(text)
+
+    def get_markdown(self) -> str:
+        """Get the converted markdown, cleaned up."""
+        md = "".join(self.output)
+        # Clean up excessive newlines
+        md = re.sub(r"\n{3,}", "\n\n", md)
+        # Clean up empty table separators
+        md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
+        return md.strip()
+
+
+def _html_to_markdown(html: str) -> str:
+    """Convert HTML to Markdown."""
+    parser = _HTMLToMarkdown()
+    parser.feed(html)
+    return parser.get_markdown()
+
+
 def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
-    """Download NVIDIA docs
+    """Download NVIDIA docs and convert HTML to Markdown.
+
+    NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
+    """
     assert config.urls is not None
     downloaded = 0
     with httpx.Client(timeout=30.0, follow_redirects=True) as client:
         for url in config.urls:
-            md_url = f"{url}.md"
             filepath = _url_to_filepath(url, dest)
             filepath.parent.mkdir(parents=True, exist_ok=True)
             try:
-
+                # Fetch HTML page directly
+                resp = client.get(url)
                 resp.raise_for_status()
-
+
+                # Convert HTML to Markdown
+                markdown = _html_to_markdown(resp.text)
+
+                # Add source URL as header
+                content = f"<!-- Source: {url} -->\n\n{markdown}"
+                filepath.write_text(content)
                 downloaded += 1
                 if verbose:
                     print(f"  ✓ {filepath.relative_to(dest)}")
@@ -275,6 +486,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
     return downloaded


+def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
+    """Download from mixed sources (NVIDIA docs + GitHub repos)."""
+    total = 0
+
+    # Download NVIDIA markdown docs (urls)
+    if config.urls:
+        if verbose:
+            print("  [NVIDIA docs]")
+        total += _download_nvidia_md(config, dest, verbose)
+
+    # Download GitHub repos
+    if config.repos:
+        if verbose:
+            print("  [GitHub repos]")
+        total += _download_github_multi_repo(config, dest, verbose)
+
+    return total
+
+
 def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
     """Download a corpus to local cache.

@@ -311,6 +541,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
         count = _download_github_repo(config, dest, verbose)
     elif config.source_type == "github_multi_repo":
         count = _download_github_multi_repo(config, dest, verbose)
+    elif config.source_type == "mixed":
+        count = _download_mixed(config, dest, verbose)
     else:
         raise ValueError(f"Unknown source type: {config.source_type}")
     if verbose:
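For orientation, a rough sketch of what the new HTML-to-Markdown path produces. The HTML snippet is invented, the call goes through a private helper, and the exact whitespace of the output depends on the cleanup in `get_markdown()`.

```python
# Hypothetical input run through the converter added in wafer/corpus.py.
from wafer.corpus import _html_to_markdown

html = """
<h1>Efficient GEMM</h1>
<p>See the <a href="01_layout.html">layout</a> notes and <code>cute::Tensor</code>.</p>
<pre class="highlight-cpp">Tensor t = make_tensor(ptr, make_shape(M, N));</pre>
"""
print(_html_to_markdown(html))
# Roughly:
#   # Efficient GEMM
#   See the [layout](01_layout.html) notes and `cute::Tensor`.
#   ```cpp
#   Tensor t = make_tensor(ptr, make_shape(M, N));
#   ```
```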
wafer/evaluate.py
CHANGED
@@ -1168,11 +1168,16 @@ def _build_modal_sandbox_script(
     """
     gpu_type = target.gpu_type

-    # Determine PyTorch index based on GPU type
+    # Determine PyTorch index and CUDA arch based on GPU type
     if gpu_type in ("B200", "GB200"):
-        torch_index = "https://download.pytorch.org/whl/
+        torch_index = "https://download.pytorch.org/whl/cu130"
+        cuda_arch_list = "10.0"  # Blackwell (sm_100)
+    elif gpu_type == "H100":
+        torch_index = "https://download.pytorch.org/whl/cu130"
+        cuda_arch_list = "9.0"  # Hopper (sm_90)
     else:
         torch_index = "https://download.pytorch.org/whl/cu124"
+        cuda_arch_list = "8.0"  # Default to Ampere (sm_80)

     return f'''
 import asyncio
@@ -1190,7 +1195,7 @@ async def run_eval():
             "nvidia/cuda:12.9.0-devel-ubuntu22.04",
             add_python="3.12",
         )
-        .apt_install("git", "build-essential", "cmake")
+        .apt_install("git", "build-essential", "cmake", "ripgrep")
         .pip_install(
             "torch",
             index_url="{torch_index}",
@@ -1203,6 +1208,12 @@ async def run_eval():
         )
         .env({{
             "CUDA_HOME": "/usr/local/cuda",
+            # C++ compiler needs explicit include path for cuda_runtime.h
+            "CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
+            # Linker needs lib path
+            "LIBRARY_PATH": "/usr/local/cuda/lib64",
+            # Force PyTorch to compile for correct GPU architecture
+            "TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
         }})
     )

@@ -2790,6 +2801,15 @@ if torch.cuda.is_available():
     gc.collect()
     torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
+
+    # Enable TF32 for fair benchmarking against reference kernels.
+    # PyTorch 1.12+ disables TF32 for matmul by default, which handicaps
+    # reference kernels using cuBLAS. We enable it so reference kernels
+    # run at their best performance (using tensor cores when applicable).
+    # This ensures speedup comparisons are against optimized baselines.
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    print("[KernelBench] TF32 enabled for fair benchmarking")


 def _calculate_timing_stats(times: list[float]) -> dict:
@@ -3453,6 +3473,368 @@ def _validate_kernelbench_files(args: KernelBenchEvaluateArgs) -> str | None:
     return None


+def _build_modal_kernelbench_script(
+    target: ModalTarget,
+    impl_code_b64: str,
+    ref_code_b64: str,
+    eval_script_b64: str,
+    run_benchmarks: bool,
+    run_defensive: bool,
+    defense_code_b64: str | None,
+    seed: int,
+    inputs_code_b64: str | None = None,
+) -> str:
+    """Build Python script to create Modal sandbox and run KernelBench evaluation.
+
+    This runs in a subprocess to isolate Modal's asyncio from trio.
+    """
+    gpu_type = target.gpu_type
+
+    # Determine PyTorch index and CUDA arch based on GPU type
+    if gpu_type in ("B200", "GB200"):
+        torch_index = "https://download.pytorch.org/whl/cu130"
+        cuda_arch_list = "10.0"  # Blackwell (sm_100)
+    elif gpu_type == "H100":
+        # H100 uses CUDA 13.0 (matches modal_app.py)
+        torch_index = "https://download.pytorch.org/whl/cu130"
+        cuda_arch_list = "9.0"  # Hopper (sm_90)
+    else:
+        torch_index = "https://download.pytorch.org/whl/cu124"
+        cuda_arch_list = "8.0"  # Default to Ampere (sm_80)
+
+    # Install CUTLASS headers (for cute/tensor.hpp and cutlass/util/*.h) from GitHub
+    # The nvidia-cutlass-dsl pip package doesn't include the C++ headers needed for nvcc
+    # IMPORTANT: symlink to /usr/local/cuda/include because nvcc searches there by default
+    cutlass_install = '''
+        .run_commands([
+            # Clone CUTLASS headers from GitHub (shallow clone, full include tree)
+            # Use simple shallow clone - sparse-checkout can be buggy in some environments
+            "git clone --depth 1 https://github.com/NVIDIA/cutlass.git /opt/cutlass",
+            # Verify the util headers exist (for debugging)
+            "ls -la /opt/cutlass/include/cutlass/util/ | head -5",
+            # Symlink headers to CUDA include path (nvcc searches here by default)
+            "ln -sf /opt/cutlass/include/cute /usr/local/cuda/include/cute",
+            "ln -sf /opt/cutlass/include/cutlass /usr/local/cuda/include/cutlass",
+        ])
+        .pip_install(
+            "nvidia-cutlass-dsl",
+            index_url="https://pypi.nvidia.com",
+            extra_index_url="https://pypi.org/simple",
+        )
+    '''
+
+    inputs_write = ""
+    if inputs_code_b64:
+        inputs_write = f'''
+        # Write custom inputs
+        proc = sandbox.exec("python", "-c", f"""
+import base64
+with open('/workspace/custom_inputs.py', 'w') as f:
+    f.write(base64.b64decode('{inputs_code_b64}').decode())
+print('Custom inputs written')
+""")
+        proc.wait()
+'''
+
+    defense_write = ""
+    if run_defensive and defense_code_b64:
+        defense_write = f'''
+        # Write defense module
+        proc = sandbox.exec("python", "-c", f"""
+import base64
+with open('/workspace/defense.py', 'w') as f:
+    f.write(base64.b64decode('{defense_code_b64}').decode())
+print('Defense module written')
+""")
+        proc.wait()
+'''
+
+    # Build eval command
+    eval_cmd_parts = [
+        "python /workspace/kernelbench_eval.py",
+        "--impl /workspace/implementation.py",
+        "--reference /workspace/reference.py",
+        "--output /workspace/results.json",
+        f"--seed {seed}",
+    ]
+    if run_benchmarks:
+        eval_cmd_parts.append("--benchmark")
+    if run_defensive and defense_code_b64:
+        eval_cmd_parts.append("--defensive")
+        eval_cmd_parts.append("--defense-module /workspace/defense.py")
+    if inputs_code_b64:
+        eval_cmd_parts.append("--inputs /workspace/custom_inputs.py")
+
+    eval_cmd = " ".join(eval_cmd_parts)
+
+    return f'''
+import asyncio
+import base64
+import json
+import sys
+import modal
+
+async def run_eval():
+    app = modal.App.lookup("wafer-evaluate", create_if_missing=True)
+
+    # Build image with PyTorch, CUTLASS DSL and dependencies
+    image = (
+        modal.Image.from_registry(
+            "nvidia/cuda:12.9.0-devel-ubuntu22.04",
+            add_python="3.12",
+        )
+        .apt_install("git", "build-essential", "cmake", "ninja-build", "ripgrep")
+        .pip_install(
+            "torch",
+            index_url="{torch_index}",
+            extra_index_url="https://pypi.org/simple",
+        )
+        .pip_install(
+            "numpy",
+            "triton",
+            "ninja",
+        )
+        {cutlass_install}
+        .env({{
+            "CUDA_HOME": "/usr/local/cuda",
+            # C++ compiler needs explicit include path for cuda_runtime.h
+            "CPLUS_INCLUDE_PATH": "/usr/local/cuda/include",
+            # Linker needs lib path
+            "LIBRARY_PATH": "/usr/local/cuda/lib64",
+            # Force PyTorch to compile for correct GPU architecture
+            "TORCH_CUDA_ARCH_LIST": "{cuda_arch_list}",
+        }})
+    )
+
+    # Create sandbox
+    sandbox = modal.Sandbox.create(
+        app=app,
+        image=image,
+        gpu="{gpu_type}",
+        timeout={target.timeout_seconds},
+    )
+
+    try:
+        # Create workspace directory
+        sandbox.exec("mkdir", "-p", "/workspace").wait()
+
+        # Write files to sandbox
+        proc = sandbox.exec("python", "-c", f"""
+import base64
+with open('/workspace/implementation.py', 'w') as f:
+    f.write(base64.b64decode('{impl_code_b64}').decode())
+with open('/workspace/reference.py', 'w') as f:
+    f.write(base64.b64decode('{ref_code_b64}').decode())
+with open('/workspace/kernelbench_eval.py', 'w') as f:
+    f.write(base64.b64decode('{eval_script_b64}').decode())
+print('Files written')
+""")
+        proc.wait()
+        if proc.returncode != 0:
+            print(json.dumps({{"success": False, "error": f"Failed to write files: {{proc.stderr.read()}}"}}))
+            return
+{inputs_write}
+{defense_write}
+        # Run evaluation
+        print(f"Running KernelBench evaluation on {{'{gpu_type}'}}...")
+        proc = sandbox.exec("bash", "-c", "{eval_cmd}")
+
+        # Stream output
+        for line in proc.stdout:
+            print(line, end="")
+        for line in proc.stderr:
+            print(line, end="", file=sys.stderr)
+
+        proc.wait()
+
+        if proc.returncode != 0:
+            print(json.dumps({{"success": False, "error": f"Evaluation failed with exit code {{proc.returncode}}"}}))
+            return
+
+        # Read results
+        result_proc = sandbox.exec("cat", "/workspace/results.json")
+        result_data = result_proc.stdout.read()
+        result_proc.wait()
+
+        if result_data:
+            results = json.loads(result_data)
+            print("EVAL_RESULT_JSON:" + json.dumps(results))
+        else:
+            print(json.dumps({{"success": False, "error": "No results.json found"}}))
+
+    finally:
+        sandbox.terminate()
+
+asyncio.run(run_eval())
+'''
+
+
+async def run_evaluate_kernelbench_modal(
+    args: KernelBenchEvaluateArgs,
+    target: ModalTarget,
+) -> EvaluateResult:
+    """Run KernelBench format evaluation on Modal sandbox.
+
+    Creates a Modal sandbox, uploads files, runs KernelBench eval, and parses results.
+    Uses subprocess to isolate Modal's asyncio from trio.
+    """
+    import base64
+    import subprocess
+    import sys
+
+    import trio
+
+    print(f"Creating Modal sandbox ({target.gpu_type}) for KernelBench evaluation...")
+
+    # Encode files as base64
+    impl_code_b64 = base64.b64encode(args.implementation.read_bytes()).decode()
+    ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
+    eval_script_b64 = base64.b64encode(KERNELBENCH_EVAL_SCRIPT.encode()).decode()
+
+    # Encode custom inputs if provided
+    inputs_code_b64 = None
+    if args.inputs:
+        inputs_code_b64 = base64.b64encode(args.inputs.read_bytes()).decode()
+
+    # Encode defense module if defensive mode is enabled
+    defense_code_b64 = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
+    # Build the script
+    script = _build_modal_kernelbench_script(
+        target=target,
+        impl_code_b64=impl_code_b64,
+        ref_code_b64=ref_code_b64,
+        eval_script_b64=eval_script_b64,
+        run_benchmarks=args.benchmark,
+        run_defensive=args.defensive,
+        defense_code_b64=defense_code_b64,
+        seed=args.seed,
+        inputs_code_b64=inputs_code_b64,
+    )
+
+    def _run_subprocess() -> tuple[str, str, int]:
+        result = subprocess.run(
+            [sys.executable, "-c", script],
+            capture_output=True,
+            text=True,
+            timeout=target.timeout_seconds + 120,  # Extra buffer for sandbox creation + image build
+        )
+        return result.stdout, result.stderr, result.returncode
+
+    try:
+        stdout, stderr, returncode = await trio.to_thread.run_sync(_run_subprocess)
+    except subprocess.TimeoutExpired:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=f"Modal KernelBench evaluation timed out after {target.timeout_seconds}s",
+        )
+    except Exception as e:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=f"Failed to run Modal sandbox: {e}",
+        )
+
+    # Print output for debugging
+    if stdout:
+        for line in stdout.split("\n"):
+            if not line.startswith("EVAL_RESULT_JSON:"):
+                print(line)
+    if stderr:
+        print(stderr, file=sys.stderr)
+
+    if returncode != 0:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=f"Modal sandbox failed (exit {returncode}): {stderr or stdout}",
+        )
+
+    # Parse results from stdout
+    result_json = None
+    for line in stdout.split("\n"):
+        if line.startswith("EVAL_RESULT_JSON:"):
+            result_json = line[len("EVAL_RESULT_JSON:"):]
+            break
+
+    if not result_json:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message="No results found in Modal output",
+        )
+
+    try:
+        results = json.loads(result_json)
+    except json.JSONDecodeError as e:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=f"Failed to parse results JSON: {e}",
+        )
+
+    # Check for error in results
+    if "error" in results and results.get("success") is False:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=results.get("error", "Unknown error"),
+        )
+
+    # Extract metrics from results
+    return EvaluateResult(
+        success=True,
+        all_correct=results.get("all_correct", False),
+        correctness_score=float(results.get("correctness_score", 0.0)),
+        geomean_speedup=float(results.get("geomean_speedup", 0.0)),
+        passed_tests=int(results.get("passed_tests", 0)),
+        total_tests=int(results.get("total_tests", 0)),
+        error_message=results.get("error"),
+        test_results=results.get("test_results", []),
+        compilation_time_s=results.get("compilation_time_s"),
+        profiling_stats=results.get("profiling_stats"),
+    )
+
+
 async def run_evaluate_kernelbench_docker(
     args: KernelBenchEvaluateArgs,
     target: BaremetalTarget | VMTarget,
@@ -4246,6 +4628,20 @@ async def run_evaluate_kernelbench_runpod(
     )


+async def run_evaluate_kernelbench_baremetal_direct(
+    args: KernelBenchEvaluateArgs,
+    target: BaremetalTarget,
+) -> EvaluateResult:
+    """Run KernelBench format evaluation directly on NVIDIA target (no Docker).
+
+    For targets that already have PyTorch/CUDA installed (e.g., workspace containers).
+    Uses CUDA_VISIBLE_DEVICES for GPU selection.
+    """
+    # Reuse the AMD function but with CUDA env vars
+    # The logic is identical, just the GPU env var is different
+    return await _run_evaluate_kernelbench_baremetal_direct_impl(args, target, gpu_env_var="CUDA_VISIBLE_DEVICES")
+
+
 async def run_evaluate_kernelbench_baremetal_amd(
     args: KernelBenchEvaluateArgs,
     target: BaremetalTarget,
@@ -4255,6 +4651,18 @@ async def run_evaluate_kernelbench_baremetal_amd(
     Runs evaluation script directly on host (no Docker) for AMD GPUs
     that have PyTorch/ROCm installed.
     """
+    return await _run_evaluate_kernelbench_baremetal_direct_impl(args, target, gpu_env_var="HIP_VISIBLE_DEVICES")
+
+
+async def _run_evaluate_kernelbench_baremetal_direct_impl(
+    args: KernelBenchEvaluateArgs,
+    target: BaremetalTarget,
+    gpu_env_var: str = "HIP_VISIBLE_DEVICES",
+) -> EvaluateResult:
+    """Internal implementation for direct baremetal evaluation.
+
+    Runs evaluation script directly on host (no Docker).
+    """
     from datetime import datetime

     from wafer_core.async_ssh import AsyncSSHClient
@@ -4405,11 +4813,15 @@ async def run_evaluate_kernelbench_baremetal_amd(

     eval_cmd = " ".join(python_cmd_parts)

-    # Set environment for
-
-
-
-
+    # Set environment for GPU and run
+    if gpu_env_var == "HIP_VISIBLE_DEVICES":
+        # AMD: PYTORCH_ROCM_ARCH for faster compile
+        rocm_arch = _get_rocm_arch(target.compute_capability)
+        arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
+        env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
+    else:
+        # NVIDIA: just set CUDA_VISIBLE_DEVICES
+        env_vars = f"CUDA_VISIBLE_DEVICES={gpu_id} PYTHONUNBUFFERED=1"
     full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"

     # Handle prepare-only mode
@@ -4560,10 +4972,16 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
     elif isinstance(target, RunPodTarget):
         # RunPod AMD MI300X - uses ROCm Docker with device passthrough
         return await run_evaluate_kernelbench_runpod(args, target)
+    elif isinstance(target, ModalTarget):
+        # Modal serverless - runs in Modal sandbox
+        return await run_evaluate_kernelbench_modal(args, target)
     elif isinstance(target, BaremetalTarget | VMTarget):
         # Check if this is an AMD target (gfx* compute capability) - run directly
         if target.compute_capability and target.compute_capability.startswith("gfx"):
             return await run_evaluate_kernelbench_baremetal_amd(args, target)
+        # Check for direct execution flag (workspace containers that already have everything)
+        if getattr(target, "direct", False):
+            return await run_evaluate_kernelbench_baremetal_direct(args, target)
         # NVIDIA targets - require docker_image to be set
         if not target.docker_image:
             return EvaluateResult(
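One detail worth calling out from the Modal KernelBench path above: results cross the sandbox/subprocess boundary as a single stdout line prefixed with `EVAL_RESULT_JSON:`. A small sketch of that convention follows; the helper name and sample values are illustrative, only the marker string comes from the diff.

```python
# Sketch of the stdout marker protocol used by run_evaluate_kernelbench_modal.
import json

MARKER = "EVAL_RESULT_JSON:"

def parse_eval_stdout(stdout: str) -> dict | None:
    """Return the parsed results dict, or None if the marker line is missing."""
    for line in stdout.splitlines():
        if line.startswith(MARKER):
            return json.loads(line[len(MARKER):])
    return None

sample = "Running KernelBench evaluation on B200...\n" + MARKER + json.dumps(
    {"success": True, "all_correct": True, "geomean_speedup": 1.7}
)
print(parse_eval_stdout(sample))
# -> {'success': True, 'all_correct': True, 'geomean_speedup': 1.7}
```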
wafer/wevin_cli.py
CHANGED
@@ -274,7 +274,12 @@ def _build_environment(
     from wafer_core.sandbox import SandboxMode

     working_dir = Path(corpus_path) if corpus_path else Path.cwd()
-    resolved_tools = tools_override or tpl.tools
+    resolved_tools = list(tools_override or tpl.tools)
+
+    # Add skill tool if skills are enabled
+    if tpl.include_skills and "skill" not in resolved_tools:
+        resolved_tools.append("skill")
+
     sandbox_mode = SandboxMode.DISABLED if no_sandbox else SandboxMode.ENABLED
     env: Environment = CodingEnvironment(
         working_dir=working_dir,
@@ -378,6 +383,7 @@ def main(  # noqa: PLR0913, PLR0915

     # Handle --get-session: load session by ID and print
     if get_session:
+
         async def _get_session() -> None:
             try:
                 session, err = await session_store.get(get_session)
@@ -398,16 +404,18 @@ def main(  # noqa: PLR0913, PLR0915
                     error_msg = f"Failed to serialize messages: {e}"
                     print(json.dumps({"error": error_msg}))
                     sys.exit(1)
-
-                print(
-
-
-
-
-
-
-
+
+                print(
+                    json.dumps({
+                        "session_id": session.session_id,
+                        "status": session.status.value,
+                        "model": session.endpoint.model if session.endpoint else None,
+                        "created_at": session.created_at,
+                        "updated_at": session.updated_at,
+                        "messages": messages_data,
+                        "tags": session.tags,
+                    })
+                )
             else:
                 print(f"Session: {session.session_id}")
                 print(f"Status: {session.status.value}")
@@ -495,7 +503,7 @@ def main(  # noqa: PLR0913, PLR0915
             print(f"Error loading template: {err}", file=sys.stderr)
             sys.exit(1)
         tpl = loaded_template
-
+        base_system_prompt = tpl.interpolate_prompt(template_args or {})
         # Show template info when starting without a prompt
         if not prompt and tpl.description:
             print(f"Template: {tpl.name}", file=sys.stderr)
@@ -503,7 +511,20 @@ def main(  # noqa: PLR0913, PLR0915
             print(file=sys.stderr)
     else:
         tpl = _get_default_template()
-
+        base_system_prompt = tpl.system_prompt
+
+    # Append skill metadata if skills are enabled
+    if tpl.include_skills:
+        from wafer_core.rollouts.skills import discover_skills, format_skill_metadata_for_prompt
+
+        skill_metadata = discover_skills()
+        if skill_metadata:
+            skill_section = format_skill_metadata_for_prompt(skill_metadata)
+            system_prompt = base_system_prompt + "\n\n" + skill_section
+        else:
+            system_prompt = base_system_prompt
+    else:
+        system_prompt = base_system_prompt

     # CLI args override template values
     resolved_single_turn = single_turn if single_turn is not None else tpl.single_turn
@@ -550,7 +571,7 @@ def main(  # noqa: PLR0913, PLR0915
     else:
         if json_output:
             # Emit session_start if we have a session_id (from --resume)
-            model_name = endpoint.model if hasattr(endpoint,
+            model_name = endpoint.model if hasattr(endpoint, "model") else None
             frontend = StreamingChunkFrontend(session_id=session_id, model=model_name)
         else:
             frontend = NoneFrontend(show_tool_calls=True, show_thinking=False)
@@ -565,9 +586,11 @@ def main(  # noqa: PLR0913, PLR0915
     # Emit session_start for new sessions (if session_id was None and we got one)
     # Check first state to emit as early as possible
     if json_output and isinstance(frontend, StreamingChunkFrontend):
-        first_session_id =
+        first_session_id = (
+            states[0].session_id if states and states[0].session_id else None
+        )
         if first_session_id and not session_id:  # New session created
-            model_name = endpoint.model if hasattr(endpoint,
+            model_name = endpoint.model if hasattr(endpoint, "model") else None
             frontend.emit_session_start(first_session_id, model_name)
     # Print resume command with full wafer agent prefix
     if states and states[-1].session_id:
{wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/RECORD
CHANGED

@@ -5,10 +5,10 @@ wafer/api_client.py,sha256=i_Az2b2llC3DSW8yOL-BKqa7LSKuxOr8hSN40s-oQXY,6313
 wafer/auth.py,sha256=dwss_se5P-FFc9IN38q4kh_dBrA6k-CguDBkivgcdj0,14003
 wafer/autotuner.py,sha256=41WYP41pTDvMijv2h42vm89bcHtDMJXObDlWmn6xpFU,44416
 wafer/billing.py,sha256=jbLB2lI4_9f2KD8uEFDi_ixLlowe5hasC0TIZJyIXRg,7163
-wafer/cli.py,sha256=
+wafer/cli.py,sha256=j4ODOVT_r-kyc21YOI8Yl8bkiZMGuqDpXRs7CvpNaek,261443
 wafer/config.py,sha256=h5Eo9_yfWqWGoPNdVQikI9GoZVUeysunSYiixf1mKcw,3411
-wafer/corpus.py,sha256=
-wafer/evaluate.py,sha256=
+wafer/corpus.py,sha256=oQegXA43MuyRvYxOsWhmqeP5vMb5IKFHOvM-1RcahPA,22301
+wafer/evaluate.py,sha256=SxxhiPkO6aDdfktRzJXpbWMVmIGn_gw-o5C6Zwj2zRc,190930
 wafer/global_config.py,sha256=fhaR_RU3ufMksDmOohH1OLeQ0JT0SDW1hEip_zaP75k,11345
 wafer/gpu_run.py,sha256=TwqXy72T7f2I7e6n5WWod3xgxCPnDhU0BgLsB4CUoQY,9716
 wafer/inference.py,sha256=tZCO5i05FKY27ewis3CSBHFBeFbXY3xwj0DSjdoMY9s,4314
@@ -26,16 +26,16 @@ wafer/target_lock.py,sha256=SDKhNzv2N7gsphGflcNni9FE5YYuAMuEthngAJEo4Gs,7809
 wafer/targets.py,sha256=9r-iRWoKSH5cQl1LcamaX-T7cNVOg99ngIm_hlRk-qU,26922
 wafer/targets_ops.py,sha256=jN1oIBx0mutxRNE9xpIc7SaBxPkVmOyus2eqn0kEKNI,21475
 wafer/tracelens.py,sha256=g9ZIeFyNojZn4uTd3skPqIrRiL7aMJOz_-GOd3aiyy4,7998
-wafer/wevin_cli.py,sha256=
+wafer/wevin_cli.py,sha256=Nuk7zTCiJrnpmYtdg5Hu0NbzONCqs54xtON6K7AVB9U,23189
 wafer/workspaces.py,sha256=iUdioK7kA3z_gOTMNVDn9Q87c6qpkdXF4bOhJWkUPg8,32375
 wafer/skills/wafer-guide/SKILL.md,sha256=KWetJw2TVTbz11_nzqazqOJWWRlbHRFShs4sOoreiWo,3255
 wafer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 wafer/templates/ask_docs.py,sha256=Lxs-faz9v5m4Qa4NjF2X_lE8KwM9ES9MNJkxo7ep56o,2256
-wafer/templates/optimize_kernel.py,sha256=
+wafer/templates/optimize_kernel.py,sha256=OvZgN5tm_OymO3lK8Dr0VO48e-5PfNVIIoACrPxpmqk,2446
 wafer/templates/optimize_kernelbench.py,sha256=aoOA13zWEl89r6QW03xF9NKxQ7j4mWe9rwua6-mlr4Y,4780
 wafer/templates/trace_analyze.py,sha256=XE1VqzVkIUsZbXF8EzQdDYgg-AZEYAOFpr6B_vnRELc,2880
-wafer_cli-0.2.
-wafer_cli-0.2.
-wafer_cli-0.2.
-wafer_cli-0.2.
-wafer_cli-0.2.
+wafer_cli-0.2.22.dist-info/METADATA,sha256=vjYzyQtphWxQ0JID0k5tFWoLwVjlR6X0B4UAuMhLhQc,560
+wafer_cli-0.2.22.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+wafer_cli-0.2.22.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
+wafer_cli-0.2.22.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
+wafer_cli-0.2.22.dist-info/RECORD,,

{wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/WHEEL
File without changes

{wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/entry_points.txt
File without changes

{wafer_cli-0.2.20.dist-info → wafer_cli-0.2.22.dist-info}/top_level.txt
File without changes