wafer_cli-0.2.14-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +118 -0
- wafer/__init__.py +3 -0
- wafer/analytics.py +306 -0
- wafer/api_client.py +195 -0
- wafer/auth.py +432 -0
- wafer/autotuner.py +1080 -0
- wafer/billing.py +233 -0
- wafer/cli.py +7289 -0
- wafer/config.py +105 -0
- wafer/corpus.py +366 -0
- wafer/evaluate.py +4593 -0
- wafer/global_config.py +350 -0
- wafer/gpu_run.py +307 -0
- wafer/inference.py +148 -0
- wafer/kernel_scope.py +552 -0
- wafer/ncu_analyze.py +651 -0
- wafer/nsys_analyze.py +1042 -0
- wafer/nsys_profile.py +510 -0
- wafer/output.py +248 -0
- wafer/problems.py +357 -0
- wafer/rocprof_compute.py +490 -0
- wafer/rocprof_sdk.py +274 -0
- wafer/rocprof_systems.py +520 -0
- wafer/skills/wafer-guide/SKILL.md +129 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +842 -0
- wafer/targets_ops.py +717 -0
- wafer/templates/__init__.py +0 -0
- wafer/templates/ask_docs.py +61 -0
- wafer/templates/optimize_kernel.py +71 -0
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/templates/trace_analyze.py +74 -0
- wafer/tracelens.py +218 -0
- wafer/wevin_cli.py +577 -0
- wafer/workspaces.py +852 -0
- wafer_cli-0.2.14.dist-info/METADATA +16 -0
- wafer_cli-0.2.14.dist-info/RECORD +41 -0
- wafer_cli-0.2.14.dist-info/WHEEL +5 -0
- wafer_cli-0.2.14.dist-info/entry_points.txt +2 -0
- wafer_cli-0.2.14.dist-info/top_level.txt +1 -0
wafer/evaluate.py
ADDED
@@ -0,0 +1,4593 @@

"""Remote kernel evaluation for Wafer CLI.

Runs evaluate.py on a remote GPU target with the same interface as local execution.
"""

import json
import logging
import shlex
from dataclasses import dataclass
from pathlib import Path

logger = logging.getLogger(__name__)

from wafer_core.utils.kernel_utils.targets.config import (
    BaremetalTarget,
    DigitalOceanTarget,
    LocalTarget,
    ModalTarget,
    RunPodTarget,
    VMTarget,
    WorkspaceTarget,
)

# Map AMD compute capability to ROCm architecture
# Used to set PYTORCH_ROCM_ARCH for faster compilation (compile only for target arch)
AMD_CC_TO_ARCH = {
    "9.4": "gfx942",  # MI300X
    "9.0a": "gfx90a",  # MI200 series
    "9.08": "gfx908",  # MI100
    "9.06": "gfx906",  # MI50/60
    "10.30": "gfx1030",  # RDNA2
    "11.0": "gfx1100",  # RDNA3
}


def _get_rocm_arch(compute_capability: str) -> str | None:
    """Get ROCm architecture string from compute capability.

    Returns gfx* string for PYTORCH_ROCM_ARCH, or None if not found.
    """
    # Already a gfx string
    if compute_capability.startswith("gfx"):
        return compute_capability
    # Map from numeric CC
    return AMD_CC_TO_ARCH.get(compute_capability)


def _build_docker_run_command(
    image: str,
    command: str,
    *,
    working_dir: str | None = None,
    env: dict[str, str] | None = None,
    gpus: str = "all",
    volumes: dict[str, str] | None = None,
    cap_add: list[str] | None = None,
) -> str:
    """Build a docker run command string for NVIDIA GPUs.

    Pure function: string in, string out. No side effects.

    Args:
        image: Docker image name (e.g., "nvcr.io/nvidia/cutlass:4.3-devel")
        command: Command to run inside container
        working_dir: Container working directory (optional)
        env: Environment variables as dict (optional)
        gpus: GPU access string ("all", "device=0", "device=0,1", etc.)
        volumes: Host:container volume mappings (optional)
        cap_add: Linux capabilities to add (e.g., ["SYS_ADMIN"] for NCU profiling)

    Returns:
        Complete docker run command string
    """
    parts = ["docker", "run", "--rm"]

    # Add capabilities (needed for NCU profiling)
    if cap_add:
        for cap in cap_add:
            parts.extend(["--cap-add", cap])

    # GPU access - use single quotes for the device spec to avoid shell escaping issues
    if gpus:
        parts.extend(["--gpus", f"'{gpus}'"])

    # Volume mounts
    if volumes:
        for host_path, container_path in volumes.items():
            parts.extend(["-v", f"{host_path}:{container_path}"])

    # Working directory
    if working_dir:
        parts.extend(["-w", working_dir])

    # Environment variables
    if env:
        for key, value in env.items():
            parts.extend(["-e", f"{key}={shlex.quote(value)}"])

    # Image and command
    parts.append(image)
    parts.append(f"bash -c {shlex.quote(command)}")

    return " ".join(parts)


def _build_docker_run_command_amd(
    image: str,
    command: str,
    *,
    working_dir: str | None = None,
    env: dict[str, str] | None = None,
    volumes: dict[str, str] | None = None,
) -> str:
    """Build a docker run command string for AMD GPUs (ROCm).

    Uses device passthrough instead of NVIDIA's --gpus flag.

    Args:
        image: Docker image name (e.g., "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0")
        command: Command to run inside container
        working_dir: Container working directory (optional)
        env: Environment variables as dict (optional)
        volumes: Host:container volume mappings (optional)

    Returns:
        Complete docker run command string
    """
    parts = ["docker", "run", "--rm"]

    # AMD GPU access via device passthrough
    parts.extend(["--device=/dev/kfd", "--device=/dev/dri", "--group-add", "video"])

    # Volume mounts
    if volumes:
        for host_path, container_path in volumes.items():
            parts.extend(["-v", f"{host_path}:{container_path}"])

    # Working directory
    if working_dir:
        parts.extend(["-w", working_dir])

    # Environment variables
    if env:
        for key, value in env.items():
            parts.extend(["-e", f"{key}={shlex.quote(value)}"])

    # Image and command
    parts.append(image)
    parts.append(f"bash -c {shlex.quote(command)}")

    return " ".join(parts)

@dataclass(frozen=True)
class EvaluateArgs:
    """Arguments for evaluate command.

    Mirrors evaluate.py's CLI args.
    """

    implementation: Path
    reference: Path
    test_cases: Path
    target_name: str
    benchmark: bool = False
    profile: bool = False
    defensive: bool = False
    sync_artifacts: bool = True
    gpu_id: int | None = None


@dataclass(frozen=True)
class KernelBenchEvaluateArgs:
    """Arguments for KernelBench format evaluate command.

    KernelBench format uses Model/ModelNew classes instead of functions.
    No test_cases file - reference defines get_inputs()/get_init_inputs().
    """

    implementation: Path  # Must define ModelNew class
    reference: Path  # Must define Model, get_inputs, get_init_inputs
    target_name: str
    benchmark: bool = False
    profile: bool = False
    inputs: Path | None = None  # Custom inputs file to override get_inputs()
    seed: int = 42  # Random seed for reproducibility
    defensive: bool = False
    backend: str | None = None  # Kernel backend for static validation
    sync_artifacts: bool = True
    gpu_id: int | None = None
    stages: str = "compile,correctness"  # Stages to run: compile, correctness, benchmark, defense
    prepare_only: bool = False  # Sync files and generate script but don't run


@dataclass(frozen=True)
class EvaluateResult:
    """Result from remote evaluation."""

    success: bool
    all_correct: bool | None  # None when correctness wasn't checked (compile-only, prepare-only)
    correctness_score: float
    geomean_speedup: float
    passed_tests: int
    total_tests: int
    error_message: str | None = None
    artifact_path: Path | None = None


def _check_python_file_has(path: Path, *names: str) -> list[str]:
    """Check if a Python file exports the given names.

    Uses AST parsing to find:
    - Function definitions: def name(...)
    - Class definitions: class name(...)
    - Assignments: name = ...
    - Imports: from module import name / from module import x as name

    Returns:
        List of names that are missing
    """
    import ast

    content = path.read_text()
    try:
        tree = ast.parse(content)
    except SyntaxError:
        # If we can't parse, let the runtime fail with a better error
        return []

    defined_names: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
            defined_names.add(node.name)
        elif isinstance(node, ast.ClassDef):
            defined_names.add(node.name)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    defined_names.add(target.id)
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                # Use asname if present, otherwise use the original name
                defined_names.add(alias.asname or alias.name)

    return [name for name in names if name not in defined_names]

def _validate_files(args: EvaluateArgs) -> str | None:
    """Validate that all input files exist, have correct format, and expected signatures.

    Returns:
        Error message if validation fails, None if all valid
    """
    if not args.implementation.exists():
        return f"Implementation file not found: {args.implementation}"
    if not args.reference.exists():
        return f"Reference file not found: {args.reference}"
    if not args.test_cases.exists():
        return f"Test cases file not found: {args.test_cases}"

    # Validate test_cases is valid JSON
    try:
        json.loads(args.test_cases.read_text())
    except json.JSONDecodeError:
        if args.test_cases.suffix == ".py":
            return (
                f"--test-cases must be a JSON file, not a Python file: {args.test_cases}\n"
                "Hint: For KernelBench problems, use 'wafer evaluate kernelbench' instead:\n"
                f"  wafer evaluate kernelbench --impl <impl.py> --reference {args.test_cases}"
            )
        return f"--test-cases must be valid JSON: {args.test_cases}"

    # Validate implementation has custom_kernel
    impl_missing = _check_python_file_has(args.implementation, "custom_kernel")
    if impl_missing:
        # Check if it looks like KernelBench format (has ModelNew)
        has_model_new = not _check_python_file_has(args.implementation, "ModelNew")
        if has_model_new:
            return (
                f"Implementation file missing 'custom_kernel' function: {args.implementation}\n"
                "Hint: This looks like KernelBench format. Use 'wafer evaluate kernelbench' instead:\n"
                f"  wafer evaluate kernelbench --impl {args.implementation} --reference <reference.py>"
            )
        return (
            f"Implementation file missing 'custom_kernel' function: {args.implementation}\n"
            "  Required: 'def custom_kernel(inputs)' function"
        )

    # Validate reference has ref_kernel and generate_input
    ref_missing = _check_python_file_has(args.reference, "ref_kernel", "generate_input")
    if ref_missing:
        # Check if it looks like KernelBench format (has Model and get_inputs)
        has_kernelbench = not _check_python_file_has(args.reference, "Model", "get_inputs")
        if has_kernelbench:
            return (
                f"Reference file missing required functions: {', '.join(ref_missing)}\n"
                "Hint: This looks like KernelBench format. Use 'wafer evaluate kernelbench' instead:\n"
                f"  wafer evaluate kernelbench --impl <impl.py> --reference {args.reference}"
            )
        return (
            f"Reference file missing required functions: {', '.join(ref_missing)}\n"
            f"  File: {args.reference}\n"
            "  Required: 'ref_kernel' and 'generate_input' functions"
        )

    return None


def _select_gpu_id(
    target: BaremetalTarget | VMTarget | ModalTarget, gpu_id_override: int | None
) -> int:
    """Select GPU ID to use.

    Args:
        target: Target config
        gpu_id_override: Optional explicit GPU ID

    Returns:
        GPU ID to use
    """
    if gpu_id_override is not None:
        return gpu_id_override

    # Use first GPU from target's list
    if isinstance(target, BaremetalTarget | VMTarget):
        return target.gpu_ids[0]

    # Modal doesn't have explicit GPU IDs
    return 0


def _build_docker_pip_install_cmd(target: BaremetalTarget | VMTarget) -> str:
    """Build pip install command for Docker container.

    Installs uv first, then uses uv to install packages (Modal-like approach).
    Uses --system flag to install to container's system Python (not any venv).

    Handles base CUDA images that may not have pip pre-installed.

    Args:
        target: Target config with pip_packages, torch_package, torch_index_url

    Returns:
        Shell command string to install dependencies
    """
    commands = []

    # Some base images (like nvidia/cuda) don't have pip or git, install them first
    # Use apt for Debian/Ubuntu-based images, with noninteractive to avoid prompts
    commands.append(
        "(which pip > /dev/null 2>&1 && which git > /dev/null 2>&1) || "
        "(apt-get update && "
        "DEBIAN_FRONTEND=noninteractive apt-get install -y python3 python3-pip git > /dev/null)"
    )

    # Install uv (fast, reliable) - use pip3 for compatibility
    commands.append("pip3 install uv")

    # Install torch with custom index if specified (like Modal's two-phase install)
    # Use --system --break-system-packages to install to container's Python
    # (needed for Python 3.12+ with PEP 668 externally managed environments)
    if target.torch_package:
        if target.torch_index_url:
            commands.append(
                f"uv pip install --system --break-system-packages --index-url {target.torch_index_url} "
                f"--extra-index-url https://pypi.org/simple {target.torch_package}"
            )
        else:
            commands.append(
                f"uv pip install --system --break-system-packages {target.torch_package}"
            )

    # Install other packages
    if target.pip_packages:
        packages_str = " ".join(target.pip_packages)
        commands.append(f"uv pip install --system --break-system-packages {packages_str}")

    return " && ".join(commands)


def _get_wafer_root() -> Path:
    """Get wafer monorepo root directory.

    Walks up from this file to find the wafer repo root (contains apps/, packages/).
    """
    current = Path(__file__).resolve()
    for parent in [current] + list(current.parents):
        if (parent / "apps").is_dir() and (parent / "packages").is_dir():
            return parent
    raise RuntimeError(f"Could not find wafer root from {__file__}")

async def run_evaluate_docker(
    args: EvaluateArgs,
    target: BaremetalTarget | VMTarget,
) -> EvaluateResult:
    """Run evaluation in Docker container on SSH-based target.

    Uses async SSH client for true non-blocking I/O.
    Uploads wafer-core and runs evaluate.py directly with PYTHONPATH.
    No package installation needed - avoids rollouts dependency.

    Args:
        args: Evaluate arguments
        target: SSH target config with docker_image set

    Returns:
        Evaluation result
    """
    from datetime import datetime

    from wafer_core.async_ssh import AsyncSSHClient

    CONTAINER_WORKSPACE = "/workspace"
    REMOTE_WORKSPACE_BASE = "~/.wafer/workspaces"

    if not target.docker_image:
        raise ValueError("docker_image must be set for Docker execution")

    # Select GPU
    gpu_id = _select_gpu_id(target, args.gpu_id)

    print(f"Connecting to {target.ssh_target}...")

    async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
        print(f"Using Docker image: {target.docker_image}")
        print(f"Using GPU {gpu_id}...")

        # Read local files
        impl_code = args.implementation.read_text()
        ref_code = args.reference.read_text()
        test_cases_data = json.loads(args.test_cases.read_text())

        # Create workspace for evaluation files
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir = f"wafer_eval_{timestamp}"
        eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
        await client.exec(f"mkdir -p {eval_workspace}")
        eval_workspace_expanded = await client.expand_path(eval_workspace)
        run_path = f"{eval_workspace_expanded}/{run_dir}"

        print("Uploading evaluation files...")

        # Create run directory
        mkdir_result = await client.exec(f"mkdir -p {run_path}")
        if mkdir_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to create run directory: {mkdir_result.stderr}",
            )

        # Write implementation
        impl_path = f"{run_path}/implementation.py"
        write_result = await client.exec(
            f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
        )
        if write_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to write implementation: {write_result.stderr}",
            )

        # Write reference
        ref_path = f"{run_path}/reference.py"
        write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
        if write_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to write reference: {write_result.stderr}",
            )

        # Also write as reference_kernel.py (evaluate.py imports generate_input from this)
        ref_kernel_path = f"{run_path}/reference_kernel.py"
        write_result = await client.exec(
            f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
        )
        if write_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to write reference_kernel: {write_result.stderr}",
            )

        # Write test cases
        test_cases_path = f"{run_path}/test_cases.json"
        test_cases_json = json.dumps(test_cases_data, indent=2)
        write_result = await client.exec(
            f"cat > '{test_cases_path}' << 'TESTS_EOF'\n{test_cases_json}\nTESTS_EOF"
        )
        if write_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to write test cases: {write_result.stderr}",
            )

        print("Running evaluation in Docker container...")

        # Paths inside container (workspace mounted at /workspace)
        container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
        container_impl_path = f"{container_run_path}/implementation.py"
        container_ref_path = f"{container_run_path}/reference.py"
        container_test_cases_path = f"{container_run_path}/test_cases.json"

        # Build pip install command for torch and other deps, plus wafer-core
        pip_install_cmd = _build_docker_pip_install_cmd(target)
        install_cmd = (
            f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
        )

        # Build evaluate command using installed wafer-core module
        python_cmd_parts = [
            "python3 -m wafer_core.utils.kernel_utils.evaluate",
            f"--implementation {container_impl_path}",
            f"--reference {container_ref_path}",
            f"--test-cases {container_test_cases_path}",
            f"--run-dir {container_run_path}",
        ]

        if args.benchmark:
            python_cmd_parts.append("--benchmark")
        if args.profile:
            python_cmd_parts.append("--profile")
        if args.defensive:
            python_cmd_parts.append("--defensive")

        eval_cmd = " ".join(python_cmd_parts)

        # Full command: install deps + wafer-core, then run evaluate
        full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"

        # Build Docker run command
        # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
        docker_cmd = _build_docker_run_command(
            image=target.docker_image,
            command=full_cmd,
            working_dir=container_run_path,
            env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
            gpus="all",
            volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
            cap_add=["SYS_ADMIN"] if args.profile else None,
        )

        print(f"Docker command: {docker_cmd[:100]}...")

        # Run Docker command and stream output
        log_lines = []
        async for line in client.exec_stream(docker_cmd):
            print(line, flush=True)
            log_lines.append(line)

        # Read results
        results_path = f"{run_path}/results.json"
        cat_result = await client.exec(f"cat {results_path}")

        if cat_result.exit_code != 0:
            log_tail = "\n".join(log_lines[-50:])
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Evaluation failed. Log tail:\n{log_tail}",
            )

        # Parse results
        try:
            results_data = json.loads(cat_result.stdout)
        except json.JSONDecodeError as e:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to parse results: {e}",
            )

        # Extract backend results
        backends = results_data.get("backends", [])
        if not backends:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message="No backend results found",
            )

        backend = backends[0]
        correctness_tests = backend.get("correctness_tests", [])
        passed = sum(1 for t in correctness_tests if t.get("is_correct", False))
        total = len(correctness_tests)

        # Sync artifacts if requested
        artifact_path = None
        if args.sync_artifacts:
            local_artifact_dir = Path.cwd() / "wafer_artifacts" / run_dir
            local_artifact_dir.mkdir(parents=True, exist_ok=True)

            try:
                # Download results.json
                download_result = await client.download_files(
                    remote_path=f"{run_path}/results.json",
                    local_path=str(local_artifact_dir / "results.json"),
                )
                if download_result.success:
                    artifact_path = local_artifact_dir
                    print(f"Artifacts saved to: {artifact_path}")
                else:
                    print(f"Warning: Failed to sync results.json: {download_result.error_message}")

                # Download NCU profiles if they exist (from --profile flag)
                # NCU profiles are stored in artifact/ncu/ subdirectory
                ncu_check = await client.exec(f"test -d {run_path}/artifact/ncu")
                if ncu_check.exit_code == 0:
                    local_ncu_dir = local_artifact_dir / "ncu"
                    local_ncu_dir.mkdir(parents=True, exist_ok=True)
                    ncu_result = await client.download_files(
                        remote_path=f"{run_path}/artifact/ncu",
                        local_path=str(local_ncu_dir),
                        recursive=True,
                    )
                    if ncu_result.success:
                        print(f"NCU profiles synced: {ncu_result.files_copied} files")
                    else:
                        print(f"Warning: Failed to sync NCU profiles: {ncu_result.error_message}")
            except Exception as e:
                print(f"Warning: Failed to sync artifacts: {e}")

        return EvaluateResult(
            success=True,
            all_correct=backend.get("all_correct", False),
            correctness_score=backend.get("correctness_score", 0.0),
            geomean_speedup=backend.get("geomean_speedup", 0.0),
            passed_tests=passed,
            total_tests=total,
            artifact_path=artifact_path,
        )

async def run_evaluate_local(
    args: EvaluateArgs,
    target: LocalTarget,
) -> EvaluateResult:
    """Run evaluation locally on the current machine.

    For LocalTarget - no SSH needed, runs directly.

    Args:
        args: Evaluate arguments
        target: Local target config

    Returns:
        Evaluation result
    """
    import os
    import subprocess
    import tempfile
    from datetime import datetime

    # Select GPU
    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]

    print(f"Running local evaluation on GPU {gpu_id}...")

    # Create temp directory for eval files
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
        run_path = Path(run_path)

        # Write implementation
        impl_path = run_path / "implementation.py"
        impl_path.write_text(args.implementation.read_text())

        # Write reference
        ref_path = run_path / "reference.py"
        ref_path.write_text(args.reference.read_text())

        # Write custom inputs if provided
        inputs_path = None
        if args.inputs:
            inputs_path = run_path / "custom_inputs.py"
            inputs_path.write_text(args.inputs.read_text())

        # Write eval script
        eval_script_path = run_path / "kernelbench_eval.py"
        eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)

        # Write defense module if defensive mode is enabled
        defense_module_path = None
        if args.defensive:
            defense_src = (
                Path(__file__).parent.parent.parent.parent
                / "packages"
                / "wafer-core"
                / "wafer_core"
                / "utils"
                / "kernel_utils"
                / "defense.py"
            )
            if defense_src.exists():
                defense_module_path = run_path / "defense.py"
                defense_module_path.write_text(defense_src.read_text())
            else:
                print(f"Warning: defense.py not found at {defense_src}")

        # Output file
        output_path = run_path / "results.json"

        # Build eval command
        cmd_parts = [
            "python3",
            str(eval_script_path),
            "--impl",
            str(impl_path),
            "--reference",
            str(ref_path),
            "--output",
            str(output_path),
            "--seed",
            str(args.seed),
        ]

        if args.benchmark:
            cmd_parts.append("--benchmark")
        if args.profile:
            cmd_parts.append("--profile")
        if inputs_path:
            cmd_parts.extend(["--inputs", str(inputs_path)])
        if args.defensive and defense_module_path:
            cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])

        # Set environment for GPU selection
        env = os.environ.copy()
        if target.vendor == "nvidia":
            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        else:  # AMD
            env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
            env["ROCM_PATH"] = "/opt/rocm"

        print(f"Running: {' '.join(cmd_parts[:4])} ...")

        # Run evaluation
        try:
            result = subprocess.run(
                cmd_parts,
                cwd=str(run_path),
                env=env,
                capture_output=True,
                text=True,
                timeout=args.timeout or 600,
            )
        except subprocess.TimeoutExpired:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message="Evaluation timed out",
            )

        if result.returncode != 0:
            error_msg = result.stderr or result.stdout or "Unknown error"
            # Truncate long errors
            if len(error_msg) > 1000:
                error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Evaluation failed:\n{error_msg}",
            )

        # Parse results
        if not output_path.exists():
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message="No results.json produced",
            )

        try:
            results = json.loads(output_path.read_text())
        except json.JSONDecodeError as e:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to parse results: {e}",
            )

        # Extract results
        return EvaluateResult(
            success=True,
            all_correct=results.get("all_correct", False),
            correctness_score=results.get("correctness_score", 0.0),
            geomean_speedup=results.get("geomean_speedup", 0.0),
            passed_tests=results.get("passed_tests", 0),
            total_tests=results.get("total_tests", 0),
            benchmark_results=results.get("benchmark", {}),
        )

async def run_evaluate_ssh(
    args: EvaluateArgs,
    target: BaremetalTarget | VMTarget,
) -> EvaluateResult:
    """Run evaluation on SSH-based target (Baremetal or VM).

    Routes to Docker or venv execution based on target.docker_image.

    If docker_image is set:
    - Uses Docker container with GPU passthrough
    - Installs deps via uv inside container (Modal-like)

    If docker_image is not set:
    - Uses the existing venv-based deployment infrastructure

    Args:
        args: Evaluate arguments
        target: SSH target config

    Returns:
        Evaluation result
    """
    # Route to Docker execution if docker_image is set
    if target.docker_image:
        return await run_evaluate_docker(args, target)

    # Otherwise, use venv-based execution (existing path)
    from datetime import datetime

    from wafer_core.remote_jobs import (
        LogStreamConfig,
        start_tmux_session,
        stream_log_until_complete,
    )
    from wafer_core.utils.kernel_utils.deployment import (
        DeploymentConfig,
        setup_deployment,
    )

    # Select GPU
    gpu_id = _select_gpu_id(target, args.gpu_id)

    # Create deployment config
    config = DeploymentConfig(
        ssh_target=target.ssh_target,
        ssh_key=target.ssh_key,
        gpu_id=gpu_id,
    )

    print(f"Connecting to {target.ssh_target}...")

    # Setup deployment (expensive - deploys monorepo + creates venv)
    state, err = await setup_deployment(config)
    if err:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Deployment setup failed: {err}",
        )

    assert state is not None

    print(f"Using GPU {gpu_id}...")

    # Read local files
    impl_code = args.implementation.read_text()
    ref_code = args.reference.read_text()
    test_cases_data = json.loads(args.test_cases.read_text())

    # Create a unique run directory within the deployed workspace
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = f"wafer_eval_{timestamp}"

    # workspace_path is the project path (e.g., .../research/async-wevin/benchmarks/gpumode)
    workspace = state.workspace_path
    run_path = f"{workspace}/{run_dir}"

    # Get SSH client from deployment state
    client = state.ssh_client

    print("Uploading files...")

    # Create run directory
    mkdir_result = client.exec(f"mkdir -p {run_path}")
    if mkdir_result.exit_code != 0:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Failed to create run directory: {mkdir_result.stderr}",
        )

    # Write implementation (must define custom_kernel function)
    impl_path = f"{run_path}/implementation.py"
    write_result = client.exec(f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF")
    if write_result.exit_code != 0:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Failed to write implementation: {write_result.stderr}",
        )

    # Write reference (must define ref_kernel function)
    ref_path = f"{run_path}/reference.py"
    write_result = client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
    if write_result.exit_code != 0:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Failed to write reference: {write_result.stderr}",
        )

    # Also write as reference_kernel.py (evaluate.py imports generate_input from this)
    ref_kernel_path = f"{run_path}/reference_kernel.py"
    write_result = client.exec(f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
    if write_result.exit_code != 0:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Failed to write reference_kernel: {write_result.stderr}",
        )

    # Write test cases
    test_cases_path = f"{run_path}/test_cases.json"
    test_cases_json = json.dumps(test_cases_data, indent=2)
    write_result = client.exec(
        f"cat > '{test_cases_path}' << 'TESTS_EOF'\n{test_cases_json}\nTESTS_EOF"
    )
    if write_result.exit_code != 0:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Failed to write test cases: {write_result.stderr}",
        )

    print("Running evaluation...")

    # Build evaluate command
    # The deployment deploys to research/async-wevin/benchmarks/gpumode
    # evaluate.py is at research/async-wevin/wafer_utils/kernel_utils/evaluate.py
    # So we need to go up 2 levels from workspace to find async-wevin root
    # workspace = .../research/async-wevin/benchmarks/gpumode
    # async_wevin_root = .../research/async-wevin
    async_wevin_root = "/".join(workspace.rstrip("/").split("/")[:-2])
    evaluate_script = f"{async_wevin_root}/wafer_utils/kernel_utils/evaluate.py"

    env_state = state.env_state

    eval_cmd_parts = [
        f"cd {run_path} &&",
        f"PATH={env_state.venv_bin}:$PATH",
        f"{env_state.venv_python} {evaluate_script}",
        f"--implementation {impl_path}",
        f"--reference {ref_path}",
        f"--test-cases {test_cases_path}",
        f"--run-dir {run_path}",
    ]

    if args.benchmark:
        eval_cmd_parts.append("--benchmark")
    if args.profile:
        eval_cmd_parts.append("--profile")
    if args.defensive:
        eval_cmd_parts.append("--defensive")

    eval_cmd = " ".join(eval_cmd_parts)

    # Run via tmux for streaming output
    session_name = f"wafer_eval_{datetime.now().strftime('%H%M%S')}"
    log_file = f"{run_path}/evaluate.log"

    _, err = start_tmux_session(
        client=client,
        session_name=session_name,
        command=eval_cmd,
        workspace=run_path,
        log_file=log_file,
        env_vars={
            "CUDA_VISIBLE_DEVICES": str(gpu_id),
            "PYTHONUNBUFFERED": "1",
        },
    )

    if err:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Failed to start evaluation: {err}",
        )

    # Stream logs until completion
    stream_config = LogStreamConfig(
        session_name=session_name,
        log_file=log_file,
        timeout_sec=600,  # 10 minutes max
        poll_interval_sec=2.0,
    )

    _ = stream_log_until_complete(client=client, config=stream_config)

    # Read results
    results_path = f"{run_path}/results.json"
    cat_result = client.exec(f"cat {results_path}")

    if cat_result.exit_code != 0:
        # Try to get error from log
        log_result = client.exec(f"tail -50 {log_file}")
        log_tail = log_result.stdout if log_result.exit_code == 0 else ""
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Evaluation failed. Log tail:\n{log_tail}",
        )

    # Parse results
    try:
        results_data = json.loads(cat_result.stdout)
    except json.JSONDecodeError as e:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Failed to parse results: {e}",
        )

    # Extract backend results
    # Results format: {"backends": [{"backend_name": ..., "correctness_score": ..., ...}]}
    backends = results_data.get("backends", [])
    if not backends:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message="No backend results found",
        )

    backend = backends[0]
    correctness_tests = backend.get("correctness_tests", [])
    passed = sum(1 for t in correctness_tests if t.get("is_correct", False))
    total = len(correctness_tests)

    # Sync artifacts if requested
    artifact_path = None
    if args.sync_artifacts:
        local_artifact_dir = Path.cwd() / "wafer_artifacts" / run_dir
        local_artifact_dir.mkdir(parents=True, exist_ok=True)

        # Download results and logs
        try:
            client.download_files(
                remote_path=f"{run_path}/results.json",
                local_path=str(local_artifact_dir / "results.json"),
            )
            client.download_files(
                remote_path=log_file,
                local_path=str(local_artifact_dir / "evaluate.log"),
            )
            artifact_path = local_artifact_dir
            print(f"Artifacts saved to: {artifact_path}")
        except Exception as e:
            print(f"Warning: Failed to sync artifacts: {e}")

    return EvaluateResult(
        success=True,
        all_correct=backend.get("all_correct", False),
        correctness_score=backend.get("correctness_score", 0.0),
        geomean_speedup=backend.get("geomean_speedup", 0.0),
        passed_tests=passed,
        total_tests=total,
        artifact_path=artifact_path,
    )

1155
|
+
def _build_modal_sandbox_script(
|
|
1156
|
+
target: ModalTarget,
|
|
1157
|
+
impl_code_b64: str,
|
|
1158
|
+
ref_code_b64: str,
|
|
1159
|
+
test_cases_b64: str,
|
|
1160
|
+
run_benchmarks: bool,
|
|
1161
|
+
run_defensive: bool,
|
|
1162
|
+
defense_code_b64: str | None = None,
|
|
1163
|
+
) -> str:
|
|
1164
|
+
"""Build Python script to create sandbox and run evaluation.
|
|
1165
|
+
|
|
1166
|
+
This runs in a subprocess to isolate Modal's asyncio from trio.
|
|
1167
|
+
"""
|
|
1168
|
+
gpu_type = target.gpu_type
|
|
1169
|
+
|
|
1170
|
+
# Determine PyTorch index based on GPU type
|
|
1171
|
+
if gpu_type in ("B200", "GB200"):
|
|
1172
|
+
torch_index = "https://download.pytorch.org/whl/nightly/cu128"
|
|
1173
|
+
else:
|
|
1174
|
+
torch_index = "https://download.pytorch.org/whl/cu124"
|
|
1175
|
+
|
|
1176
|
+
return f'''
|
|
1177
|
+
import asyncio
|
|
1178
|
+
import base64
|
|
1179
|
+
import json
|
|
1180
|
+
import sys
|
|
1181
|
+
import modal
|
|
1182
|
+
|
|
1183
|
+
async def run_eval():
|
|
1184
|
+
app = modal.App.lookup("wafer-evaluate", create_if_missing=True)
|
|
1185
|
+
|
|
1186
|
+
# Build image with PyTorch and dependencies
|
|
1187
|
+
image = (
|
|
1188
|
+
modal.Image.from_registry(
|
|
1189
|
+
"nvidia/cuda:12.9.0-devel-ubuntu22.04",
|
|
1190
|
+
add_python="3.12",
|
|
1191
|
+
)
|
|
1192
|
+
.apt_install("git", "build-essential", "cmake")
|
|
1193
|
+
.pip_install(
|
|
1194
|
+
"torch",
|
|
1195
|
+
index_url="{torch_index}",
|
|
1196
|
+
extra_index_url="https://pypi.org/simple",
|
|
1197
|
+
)
|
|
1198
|
+
.pip_install(
|
|
1199
|
+
"numpy",
|
|
1200
|
+
"triton",
|
|
1201
|
+
"ninja",
|
|
1202
|
+
)
|
|
1203
|
+
.env({{
|
|
1204
|
+
"CUDA_HOME": "/usr/local/cuda",
|
|
1205
|
+
}})
|
|
1206
|
+
)
|
|
1207
|
+
|
|
1208
|
+
# Create sandbox
|
|
1209
|
+
sandbox = modal.Sandbox.create(
|
|
1210
|
+
app=app,
|
|
1211
|
+
image=image,
|
|
1212
|
+
gpu="{gpu_type}",
|
|
1213
|
+
timeout={target.timeout_seconds},
|
|
1214
|
+
)
|
|
1215
|
+
|
|
1216
|
+
try:
|
|
1217
|
+
# Decode files
|
|
1218
|
+
impl_code = base64.b64decode("{impl_code_b64}").decode()
|
|
1219
|
+
ref_code = base64.b64decode("{ref_code_b64}").decode()
|
|
1220
|
+
test_cases = base64.b64decode("{test_cases_b64}").decode()
|
|
1221
|
+
|
|
1222
|
+
# Write files to sandbox
|
|
1223
|
+
sandbox.exec("mkdir", "-p", "/workspace").wait()
|
|
1224
|
+
|
|
1225
|
+
# Write implementation
|
|
1226
|
+
proc = sandbox.exec("python", "-c", f"""
|
|
1227
|
+
import base64
|
|
1228
|
+
with open('/workspace/kernel.py', 'w') as f:
|
|
1229
|
+
f.write(base64.b64decode('{impl_code_b64}').decode())
|
|
1230
|
+
with open('/workspace/reference.py', 'w') as f:
|
|
1231
|
+
f.write(base64.b64decode('{ref_code_b64}').decode())
|
|
1232
|
+
with open('/workspace/reference_kernel.py', 'w') as f:
|
|
1233
|
+
f.write(base64.b64decode('{ref_code_b64}').decode())
|
|
1234
|
+
with open('/workspace/test_cases.json', 'w') as f:
|
|
1235
|
+
f.write(base64.b64decode('{test_cases_b64}').decode())
|
|
1236
|
+
print('Files written')
|
|
1237
|
+
""")
|
|
1238
|
+
proc.wait()
|
|
1239
|
+
if proc.returncode != 0:
|
|
1240
|
+
print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
|
|
1241
|
+
return
|
|
1242
|
+
|
|
1243
|
+
# Write defense module if defensive mode is enabled
|
|
1244
|
+
# NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
|
|
1245
|
+
if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
|
|
1246
|
+
proc = sandbox.exec("python", "-c", f"""
|
|
1247
|
+
import base64
|
|
1248
|
+
with open('/workspace/defense.py', 'w') as f:
|
|
1249
|
+
f.write(base64.b64decode('{defense_code_b64}').decode())
|
|
1250
|
+
print('Defense module written')
|
|
1251
|
+
""")
|
|
1252
|
+
proc.wait()
|
|
1253
|
+
if proc.returncode != 0:
|
|
1254
|
+
print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
|
|
1255
|
+
return
|
|
1256
|
+
|
|
1257
|
+
# Build inline evaluation script
|
|
1258
|
+
eval_script = """
|
|
1259
|
+
import json
|
|
1260
|
+
import sys
|
|
1261
|
+
import os
|
|
1262
|
+
import importlib.util
|
|
1263
|
+
|
|
1264
|
+
os.chdir('/workspace')
|
|
1265
|
+
sys.path.insert(0, '/workspace')
|
|
1266
|
+
|
|
1267
|
+
# Load test cases
|
|
1268
|
+
with open('test_cases.json') as f:
|
|
1269
|
+
test_cases = json.load(f)
|
|
1270
|
+
|
|
1271
|
+
# Load kernels
|
|
1272
|
+
def load_fn(path, name):
|
|
1273
|
+
spec = importlib.util.spec_from_file_location("mod", path)
|
|
1274
|
+
mod = importlib.util.module_from_spec(spec)
|
|
1275
|
+
spec.loader.exec_module(mod)
|
|
1276
|
+
return getattr(mod, name)
|
|
1277
|
+
|
|
1278
|
+
custom_kernel = load_fn('kernel.py', 'custom_kernel')
|
|
1279
|
+
ref_kernel = load_fn('reference.py', 'ref_kernel')
|
|
1280
|
+
generate_input = load_fn('reference.py', 'generate_input')
|
|
1281
|
+
|
|
1282
|
+
import torch
|
|
1283
|
+
|
|
1284
|
+
# Load defense module if available and defensive mode is enabled
|
|
1285
|
+
run_defensive = {run_defensive}
|
|
1286
|
+
defense = None
|
|
1287
|
+
if run_defensive:
|
|
1288
|
+
try:
|
|
1289
|
+
defense = load_fn('defense.py', 'run_all_defenses')
|
|
1290
|
+
time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
|
|
1291
|
+
print('[Defense] Defense module loaded')
|
|
1292
|
+
|
|
1293
|
+
# Wrap kernels for defense API compatibility
|
|
1294
|
+
# Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
|
|
1295
|
+
# These wrappers repack the unpacked args back into a tuple
|
|
1296
|
+
def _wrap_for_defense(kernel):
|
|
1297
|
+
return lambda *args: kernel(args)
|
|
1298
|
+
custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
|
|
1299
|
+
ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
|
|
1300
|
+
except Exception as e:
|
|
1301
|
+
print(f'[Defense] Warning: Could not load defense module: {{e}}')
|
|
1302
|
+
defense = None
|
|
1303
|
+
|
|
1304
|
+
results = []
|
|
1305
|
+
all_correct = True
|
|
1306
|
+
total_time_ms = 0.0
|
|
1307
|
+
ref_total_time_ms = 0.0
|
|
1308
|
+
|
|
1309
|
+
for tc in test_cases:
|
|
1310
|
+
name = tc.pop('name', 'test')
|
|
1311
|
+
try:
|
|
1312
|
+
inputs = generate_input(**tc)
|
|
1313
|
+
|
|
1314
|
+
# Correctness check - pass inputs as single arg (wafer-core convention)
|
|
1315
|
+
with torch.no_grad():
|
|
1316
|
+
ref_out = ref_kernel(inputs)
|
|
1317
|
+
impl_out = custom_kernel(inputs)
|
|
1318
|
+
|
|
1319
|
+
if isinstance(ref_out, torch.Tensor):
|
|
1320
|
+
correct = torch.allclose(ref_out, impl_out, rtol=1e-3, atol=1e-3)
|
|
1321
|
+
else:
|
|
1322
|
+
correct = ref_out == impl_out
|
|
1323
|
+
|
|
1324
|
+
if not correct:
|
|
1325
|
+
all_correct = False
|
|
1326
|
+
|
|
1327
|
+
# Benchmark if requested
|
|
1328
|
+
impl_time_ms = 0.0
|
|
1329
|
+
ref_time_ms = 0.0
|
|
1330
|
+
if {run_benchmarks}:
|
|
1331
|
+
if run_defensive and defense is not None:
|
|
1332
|
+
# Use full defense suite with wrapped kernels
|
|
1333
|
+
# inputs_list unpacks the tuple so defense can infer dtype/device from tensors
|
|
1334
|
+
inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
|
|
1335
|
+
|
|
1336
|
+
# Run defense checks
|
|
1337
|
+
all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
|
|
1338
|
+
if not all_passed:
|
|
1339
|
+
failed = [name for name, passed, _ in defense_results if not passed]
|
|
1340
|
+
raise ValueError(f"Defense checks failed: {{failed}}")
|
|
1341
|
+
|
|
1342
|
+
# Time with defensive timing (using wrapped kernels)
|
|
1343
|
+
impl_times, _ = time_with_defenses(
|
|
1344
|
+
custom_kernel_for_defense,
|
|
1345
|
+
inputs_list,
|
|
1346
|
+
num_warmup=3,
|
|
1347
|
+
num_trials=10,
|
|
1348
|
+
verbose=False,
|
|
1349
|
+
run_defenses=False,
|
|
1350
|
+
)
|
|
1351
|
+
impl_time_ms = sum(impl_times) / len(impl_times)
|
|
1352
|
+
|
|
1353
|
+
ref_times, _ = time_with_defenses(
|
|
1354
|
+
ref_kernel_for_defense,
|
|
1355
|
+
inputs_list,
|
|
1356
|
+
num_warmup=3,
|
|
1357
|
+
num_trials=10,
|
|
1358
|
+
verbose=False,
|
|
1359
|
+
run_defenses=False,
|
|
1360
|
+
)
|
|
1361
|
+
ref_time_ms = sum(ref_times) / len(ref_times)
|
|
1362
|
+
else:
|
|
1363
|
+
# Standard timing without full defenses
|
|
1364
|
+
# Warmup
|
|
1365
|
+
for _ in range(3):
|
|
1366
|
+
custom_kernel(inputs)
|
|
1367
|
+
torch.cuda.synchronize()
|
|
1368
|
+
|
|
1369
|
+
start = torch.cuda.Event(enable_timing=True)
|
|
1370
|
+
end = torch.cuda.Event(enable_timing=True)
|
|
1371
|
+
start.record()
|
|
1372
|
+
for _ in range(10):
|
|
1373
|
+
custom_kernel(inputs)
|
|
1374
|
+
end.record()
|
|
1375
|
+
torch.cuda.synchronize()
|
|
1376
|
+
impl_time_ms = start.elapsed_time(end) / 10
|
|
1377
|
+
|
|
1378
|
+
# Reference timing
|
|
1379
|
+
for _ in range(3):
|
|
1380
|
+
ref_kernel(inputs)
|
|
1381
|
+
torch.cuda.synchronize()
|
|
1382
|
+
start.record()
|
|
1383
|
+
for _ in range(10):
|
|
1384
|
+
ref_kernel(inputs)
|
|
1385
|
+
end.record()
|
|
1386
|
+
torch.cuda.synchronize()
|
|
1387
|
+
ref_time_ms = start.elapsed_time(end) / 10
|
|
1388
|
+
|
|
1389
|
+
total_time_ms += impl_time_ms
|
|
1390
|
+
ref_total_time_ms += ref_time_ms
|
|
1391
|
+
|
|
1392
|
+
results.append({{
|
|
1393
|
+
'name': name,
|
|
1394
|
+
'correct': correct,
|
|
1395
|
+
'impl_time_ms': impl_time_ms,
|
|
1396
|
+
'ref_time_ms': ref_time_ms,
|
|
1397
|
+
}})
|
|
1398
|
+
|
|
1399
|
+
except Exception as e:
|
|
1400
|
+
results.append({{'name': name, 'correct': False, 'error': str(e)}})
|
|
1401
|
+
all_correct = False
|
|
1402
|
+
|
|
1403
|
+
# Calculate speedup
|
|
1404
|
+
speedup = 0.0
|
|
1405
|
+
if total_time_ms > 0 and ref_total_time_ms > 0:
|
|
1406
|
+
speedup = ref_total_time_ms / total_time_ms
|
|
1407
|
+
|
|
1408
|
+
passed = sum(1 for r in results if r.get('correct', False))
|
|
1409
|
+
total = len(results)
|
|
1410
|
+
|
|
1411
|
+
print(json.dumps({{
|
|
1412
|
+
'success': True,
|
|
1413
|
+
'all_correct': all_correct,
|
|
1414
|
+
'passed': passed,
|
|
1415
|
+
'total': total,
|
|
1416
|
+
'speedup': speedup,
|
|
1417
|
+
'results': results,
|
|
1418
|
+
}}))
|
|
1419
|
+
"""
|
|
1420
|
+
|
|
1421
|
+
# Run evaluation
|
|
1422
|
+
proc = sandbox.exec(
|
|
1423
|
+
"python", "-c", eval_script,
|
|
1424
|
+
timeout={target.timeout_seconds},
|
|
1425
|
+
)
|
|
1426
|
+
proc.wait()
|
|
1427
|
+
|
|
1428
|
+
stdout = proc.stdout.read()
|
|
1429
|
+
stderr = proc.stderr.read()
|
|
1430
|
+
|
|
1431
|
+
if proc.returncode != 0:
|
|
1432
|
+
print(json.dumps({{"error": f"Eval failed: {{stderr or stdout}}"}}))
|
|
1433
|
+
return
|
|
1434
|
+
|
|
1435
|
+
# Forward the result JSON
|
|
1436
|
+
# Find the last JSON line in output
|
|
1437
|
+
for line in reversed(stdout.strip().split("\\n")):
|
|
1438
|
+
if line.startswith("{{"):
|
|
1439
|
+
print(line, flush=True)
|
|
1440
|
+
return
|
|
1441
|
+
|
|
1442
|
+
print(json.dumps({{"error": f"No result JSON in output: {{stdout[:500]}}"}}))
|
|
1443
|
+
|
|
1444
|
+
finally:
|
|
1445
|
+
sandbox.terminate()
|
|
1446
|
+
|
|
1447
|
+
asyncio.run(run_eval())
|
|
1448
|
+
'''
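# Result contract of the generated sandbox script above: the evaluation result is the
# last stdout line that starts with "{". A success payload looks roughly like
# (illustrative values):
#   {"success": true, "all_correct": true, "passed": 4, "total": 4, "speedup": 1.8,
#    "results": [{"name": "test", "correct": true, "impl_time_ms": 0.4, "ref_time_ms": 0.7}]}
# and failures collapse to a single object of the form {"error": "..."}.
# run_evaluate_modal below relies on this convention when parsing subprocess output.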
|
|
1449
|
+
|
|
1450
|
+
|
|
1451
|
+
async def run_evaluate_modal(
|
|
1452
|
+
args: EvaluateArgs,
|
|
1453
|
+
target: ModalTarget,
|
|
1454
|
+
) -> EvaluateResult:
|
|
1455
|
+
"""Run evaluation on Modal sandbox.
|
|
1456
|
+
|
|
1457
|
+
Creates a Modal sandbox, uploads files, runs evaluate, and parses results.
|
|
1458
|
+
Uses subprocess to isolate Modal's asyncio from trio.
|
|
1459
|
+
|
|
1460
|
+
Args:
|
|
1461
|
+
args: Evaluate arguments
|
|
1462
|
+
target: Modal target config
|
|
1463
|
+
|
|
1464
|
+
Returns:
|
|
1465
|
+
Evaluation result
|
|
1466
|
+
"""
|
|
1467
|
+
import base64
|
|
1468
|
+
import subprocess
|
|
1469
|
+
import sys
|
|
1470
|
+
|
|
1471
|
+
import trio
|
|
1472
|
+
|
|
1473
|
+
print(f"Creating Modal sandbox ({target.gpu_type})...")
|
|
1474
|
+
|
|
1475
|
+
# Encode files as base64
|
|
1476
|
+
impl_code_b64 = base64.b64encode(args.implementation.read_bytes()).decode()
|
|
1477
|
+
ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
|
|
1478
|
+
test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()
|
|
1479
|
+
|
|
1480
|
+
# Encode defense module if defensive mode is enabled
|
|
1481
|
+
defense_code_b64 = None
|
|
1482
|
+
if args.defensive:
|
|
1483
|
+
defense_path = (
|
|
1484
|
+
Path(__file__).parent.parent.parent.parent
|
|
1485
|
+
/ "packages"
|
|
1486
|
+
/ "wafer-core"
|
|
1487
|
+
/ "wafer_core"
|
|
1488
|
+
/ "utils"
|
|
1489
|
+
/ "kernel_utils"
|
|
1490
|
+
/ "defense.py"
|
|
1491
|
+
)
|
|
1492
|
+
if defense_path.exists():
|
|
1493
|
+
defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
|
|
1494
|
+
else:
|
|
1495
|
+
print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
|
|
1496
|
+
|
|
1497
|
+
# Build the script that creates sandbox and runs eval
|
|
1498
|
+
script = _build_modal_sandbox_script(
|
|
1499
|
+
target=target,
|
|
1500
|
+
impl_code_b64=impl_code_b64,
|
|
1501
|
+
ref_code_b64=ref_code_b64,
|
|
1502
|
+
test_cases_b64=test_cases_b64,
|
|
1503
|
+
run_benchmarks=args.benchmark,
|
|
1504
|
+
run_defensive=args.defensive,
|
|
1505
|
+
defense_code_b64=defense_code_b64,
|
|
1506
|
+
)
|
|
1507
|
+
|
|
1508
|
+
def _run_subprocess() -> tuple[str, str, int]:
|
|
1509
|
+
result = subprocess.run(
|
|
1510
|
+
[sys.executable, "-c", script],
|
|
1511
|
+
capture_output=True,
|
|
1512
|
+
text=True,
|
|
1513
|
+
timeout=target.timeout_seconds + 60, # Extra buffer for sandbox creation
|
|
1514
|
+
)
|
|
1515
|
+
return result.stdout, result.stderr, result.returncode
|
|
1516
|
+
|
|
1517
|
+
try:
|
|
1518
|
+
stdout, stderr, returncode = await trio.to_thread.run_sync(_run_subprocess)
|
|
1519
|
+
except subprocess.TimeoutExpired:
|
|
1520
|
+
return EvaluateResult(
|
|
1521
|
+
success=False,
|
|
1522
|
+
all_correct=False,
|
|
1523
|
+
correctness_score=0.0,
|
|
1524
|
+
geomean_speedup=0.0,
|
|
1525
|
+
passed_tests=0,
|
|
1526
|
+
total_tests=0,
|
|
1527
|
+
error_message=f"Modal evaluation timed out after {target.timeout_seconds}s",
|
|
1528
|
+
)
|
|
1529
|
+
except Exception as e:
|
|
1530
|
+
return EvaluateResult(
|
|
1531
|
+
success=False,
|
|
1532
|
+
all_correct=False,
|
|
1533
|
+
correctness_score=0.0,
|
|
1534
|
+
geomean_speedup=0.0,
|
|
1535
|
+
passed_tests=0,
|
|
1536
|
+
total_tests=0,
|
|
1537
|
+
error_message=f"Failed to run Modal sandbox: {e}",
|
|
1538
|
+
)
|
|
1539
|
+
|
|
1540
|
+
if returncode != 0:
|
|
1541
|
+
return EvaluateResult(
|
|
1542
|
+
success=False,
|
|
1543
|
+
all_correct=False,
|
|
1544
|
+
correctness_score=0.0,
|
|
1545
|
+
geomean_speedup=0.0,
|
|
1546
|
+
passed_tests=0,
|
|
1547
|
+
total_tests=0,
|
|
1548
|
+
error_message=f"Modal sandbox failed (exit {returncode}): {stderr or stdout}",
|
|
1549
|
+
)
|
|
1550
|
+
|
|
1551
|
+
# Parse result JSON from stdout
|
|
1552
|
+
result_json = None
|
|
1553
|
+
for line in reversed(stdout.strip().split("\n")):
|
|
1554
|
+
if line.startswith("{"):
|
|
1555
|
+
try:
|
|
1556
|
+
result_json = json.loads(line)
|
|
1557
|
+
break
|
|
1558
|
+
except json.JSONDecodeError:
|
|
1559
|
+
continue
|
|
1560
|
+
|
|
1561
|
+
if result_json is None:
|
|
1562
|
+
return EvaluateResult(
|
|
1563
|
+
success=False,
|
|
1564
|
+
all_correct=False,
|
|
1565
|
+
correctness_score=0.0,
|
|
1566
|
+
geomean_speedup=0.0,
|
|
1567
|
+
passed_tests=0,
|
|
1568
|
+
total_tests=0,
|
|
1569
|
+
error_message=f"No valid JSON result in output: {stdout[:500]}",
|
|
1570
|
+
)
|
|
1571
|
+
|
|
1572
|
+
if "error" in result_json:
|
|
1573
|
+
return EvaluateResult(
|
|
1574
|
+
success=False,
|
|
1575
|
+
all_correct=False,
|
|
1576
|
+
correctness_score=0.0,
|
|
1577
|
+
geomean_speedup=0.0,
|
|
1578
|
+
passed_tests=0,
|
|
1579
|
+
total_tests=0,
|
|
1580
|
+
error_message=result_json["error"],
|
|
1581
|
+
)
|
|
1582
|
+
|
|
1583
|
+
passed = result_json.get("passed", 0)
|
|
1584
|
+
total = result_json.get("total", 0)
|
|
1585
|
+
correctness = passed / total if total > 0 else 0.0
|
|
1586
|
+
|
|
1587
|
+
return EvaluateResult(
|
|
1588
|
+
success=True,
|
|
1589
|
+
all_correct=result_json.get("all_correct", False),
|
|
1590
|
+
correctness_score=correctness,
|
|
1591
|
+
geomean_speedup=result_json.get("speedup", 0.0),
|
|
1592
|
+
passed_tests=passed,
|
|
1593
|
+
total_tests=total,
|
|
1594
|
+
)
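# Illustrative only: the "take the last stdout line that starts with '{'" parsing step
# above is repeated by the workspace, RunPod, and DigitalOcean executors below.
# A minimal sketch of that step as a standalone helper (hypothetical name, not used
# by the executors in this module):
def _parse_last_json_line_example(stdout: str) -> dict | None:
    # Scan from the end so progress/log lines printed earlier are ignored.
    for line in reversed(stdout.strip().split("\n")):
        if line.startswith("{"):
            try:
                return json.loads(line)
            except json.JSONDecodeError:
                continue
    return None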
|
|
1595
|
+
|
|
1596
|
+
|
|
1597
|
+
def _build_workspace_eval_script(
|
|
1598
|
+
impl_code: str,
|
|
1599
|
+
ref_code: str,
|
|
1600
|
+
test_cases_json: str,
|
|
1601
|
+
run_benchmarks: bool,
|
|
1602
|
+
run_defensive: bool = False,
|
|
1603
|
+
defense_code: str | None = None,
|
|
1604
|
+
) -> str:
|
|
1605
|
+
"""Build inline evaluation script for workspace exec.
|
|
1606
|
+
|
|
1607
|
+
Similar to Modal inline eval, but runs via workspace exec.
|
|
1608
|
+
"""
|
|
1609
|
+
import base64
|
|
1610
|
+
|
|
1611
|
+
impl_b64 = base64.b64encode(impl_code.encode()).decode()
|
|
1612
|
+
ref_b64 = base64.b64encode(ref_code.encode()).decode()
|
|
1613
|
+
tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
|
|
1614
|
+
defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""
|
|
1615
|
+
|
|
1616
|
+
return f'''
|
|
1617
|
+
import base64
|
|
1618
|
+
import json
|
|
1619
|
+
import sys
|
|
1620
|
+
import os
|
|
1621
|
+
import importlib.util
|
|
1622
|
+
|
|
1623
|
+
# Decode files
|
|
1624
|
+
impl_code = base64.b64decode("{impl_b64}").decode()
|
|
1625
|
+
ref_code = base64.b64decode("{ref_b64}").decode()
|
|
1626
|
+
test_cases = json.loads(base64.b64decode("{tests_b64}").decode())
|
|
1627
|
+
|
|
1628
|
+
# Write to temp files
|
|
1629
|
+
with open("/tmp/kernel.py", "w") as f:
|
|
1630
|
+
f.write(impl_code)
|
|
1631
|
+
with open("/tmp/reference.py", "w") as f:
|
|
1632
|
+
f.write(ref_code)
|
|
1633
|
+
|
|
1634
|
+
# Write defense module if available
|
|
1635
|
+
run_defensive = {run_defensive}
|
|
1636
|
+
defense_b64 = "{defense_b64}"
|
|
1637
|
+
# NOTE: defense_b64 is "" when no defense module was provided; also guard against the literal string "None" in case a None value is ever formatted into the template
|
|
1638
|
+
if run_defensive and defense_b64 and defense_b64 != "None":
|
|
1639
|
+
defense_code = base64.b64decode(defense_b64).decode()
|
|
1640
|
+
with open("/tmp/defense.py", "w") as f:
|
|
1641
|
+
f.write(defense_code)
|
|
1642
|
+
|
|
1643
|
+
# Load kernels
|
|
1644
|
+
def load_fn(path, name):
|
|
1645
|
+
spec = importlib.util.spec_from_file_location("mod", path)
|
|
1646
|
+
mod = importlib.util.module_from_spec(spec)
|
|
1647
|
+
spec.loader.exec_module(mod)
|
|
1648
|
+
return getattr(mod, name)
|
|
1649
|
+
|
|
1650
|
+
custom_kernel = load_fn("/tmp/kernel.py", "custom_kernel")
|
|
1651
|
+
ref_kernel = load_fn("/tmp/reference.py", "ref_kernel")
|
|
1652
|
+
generate_input = load_fn("/tmp/reference.py", "generate_input")
|
|
1653
|
+
|
|
1654
|
+
import torch
|
|
1655
|
+
|
|
1656
|
+
# Load defense module if available
|
|
1657
|
+
defense = None
|
|
1658
|
+
if run_defensive and defense_b64 and defense_b64 != "None":
|
|
1659
|
+
try:
|
|
1660
|
+
defense = load_fn("/tmp/defense.py", "run_all_defenses")
|
|
1661
|
+
time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
|
|
1662
|
+
print("[Defense] Defense module loaded")
|
|
1663
|
+
|
|
1664
|
+
# Wrap kernels for defense API compatibility
|
|
1665
|
+
# Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
|
|
1666
|
+
def _wrap_for_defense(kernel):
|
|
1667
|
+
return lambda *args: kernel(args)
|
|
1668
|
+
custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
|
|
1669
|
+
ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
|
|
1670
|
+
except Exception as e:
|
|
1671
|
+
print(f"[Defense] Warning: Could not load defense module: {{e}}")
|
|
1672
|
+
defense = None
|
|
1673
|
+
|
|
1674
|
+
results = []
|
|
1675
|
+
all_correct = True
|
|
1676
|
+
total_time_ms = 0.0
|
|
1677
|
+
ref_total_time_ms = 0.0
|
|
1678
|
+
|
|
1679
|
+
for tc in test_cases:
|
|
1680
|
+
name = tc.pop("name", "test")
|
|
1681
|
+
try:
|
|
1682
|
+
inputs = generate_input(**tc)
|
|
1683
|
+
|
|
1684
|
+
# Correctness check - pass inputs as single arg (wafer-core convention)
|
|
1685
|
+
with torch.no_grad():
|
|
1686
|
+
ref_out = ref_kernel(inputs)
|
|
1687
|
+
impl_out = custom_kernel(inputs)
|
|
1688
|
+
|
|
1689
|
+
if isinstance(ref_out, torch.Tensor):
|
|
1690
|
+
correct = torch.allclose(ref_out, impl_out, rtol=1e-3, atol=1e-3)
|
|
1691
|
+
else:
|
|
1692
|
+
correct = ref_out == impl_out
|
|
1693
|
+
|
|
1694
|
+
if not correct:
|
|
1695
|
+
all_correct = False
|
|
1696
|
+
|
|
1697
|
+
# Benchmark if requested
|
|
1698
|
+
impl_time_ms = 0.0
|
|
1699
|
+
ref_time_ms = 0.0
|
|
1700
|
+
if {run_benchmarks}:
|
|
1701
|
+
if run_defensive and defense is not None:
|
|
1702
|
+
# Use full defense suite with wrapped kernels
|
|
1703
|
+
inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
|
|
1704
|
+
|
|
1705
|
+
# Run defense checks
|
|
1706
|
+
all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
|
|
1707
|
+
if not all_passed:
|
|
1708
|
+
failed = [name for name, passed, _ in defense_results if not passed]
|
|
1709
|
+
raise ValueError(f"Defense checks failed: {{failed}}")
|
|
1710
|
+
|
|
1711
|
+
# Time with defensive timing (using wrapped kernels)
|
|
1712
|
+
impl_times, _ = time_with_defenses(
|
|
1713
|
+
custom_kernel_for_defense,
|
|
1714
|
+
inputs_list,
|
|
1715
|
+
num_warmup=3,
|
|
1716
|
+
num_trials=10,
|
|
1717
|
+
verbose=False,
|
|
1718
|
+
run_defenses=False,
|
|
1719
|
+
)
|
|
1720
|
+
impl_time_ms = sum(impl_times) / len(impl_times)
|
|
1721
|
+
|
|
1722
|
+
ref_times, _ = time_with_defenses(
|
|
1723
|
+
ref_kernel_for_defense,
|
|
1724
|
+
inputs_list,
|
|
1725
|
+
num_warmup=3,
|
|
1726
|
+
num_trials=10,
|
|
1727
|
+
verbose=False,
|
|
1728
|
+
run_defenses=False,
|
|
1729
|
+
)
|
|
1730
|
+
ref_time_ms = sum(ref_times) / len(ref_times)
|
|
1731
|
+
else:
|
|
1732
|
+
# Standard timing
|
|
1733
|
+
for _ in range(3):
|
|
1734
|
+
custom_kernel(inputs)
|
|
1735
|
+
torch.cuda.synchronize()
|
|
1736
|
+
|
|
1737
|
+
start = torch.cuda.Event(enable_timing=True)
|
|
1738
|
+
end = torch.cuda.Event(enable_timing=True)
|
|
1739
|
+
start.record()
|
|
1740
|
+
for _ in range(10):
|
|
1741
|
+
custom_kernel(inputs)
|
|
1742
|
+
end.record()
|
|
1743
|
+
torch.cuda.synchronize()
|
|
1744
|
+
impl_time_ms = start.elapsed_time(end) / 10
|
|
1745
|
+
|
|
1746
|
+
for _ in range(3):
|
|
1747
|
+
ref_kernel(inputs)
|
|
1748
|
+
torch.cuda.synchronize()
|
|
1749
|
+
start.record()
|
|
1750
|
+
for _ in range(10):
|
|
1751
|
+
ref_kernel(inputs)
|
|
1752
|
+
end.record()
|
|
1753
|
+
torch.cuda.synchronize()
|
|
1754
|
+
ref_time_ms = start.elapsed_time(end) / 10
|
|
1755
|
+
|
|
1756
|
+
total_time_ms += impl_time_ms
|
|
1757
|
+
ref_total_time_ms += ref_time_ms
|
|
1758
|
+
|
|
1759
|
+
results.append({{
|
|
1760
|
+
"name": name,
|
|
1761
|
+
"correct": correct,
|
|
1762
|
+
"impl_time_ms": impl_time_ms,
|
|
1763
|
+
"ref_time_ms": ref_time_ms,
|
|
1764
|
+
}})
|
|
1765
|
+
|
|
1766
|
+
except Exception as e:
|
|
1767
|
+
results.append({{"name": name, "correct": False, "error": str(e)}})
|
|
1768
|
+
all_correct = False
|
|
1769
|
+
|
|
1770
|
+
# Calculate speedup
|
|
1771
|
+
speedup = 0.0
|
|
1772
|
+
if total_time_ms > 0 and ref_total_time_ms > 0:
|
|
1773
|
+
speedup = ref_total_time_ms / total_time_ms
|
|
1774
|
+
|
|
1775
|
+
passed = sum(1 for r in results if r.get("correct", False))
|
|
1776
|
+
total = len(results)
|
|
1777
|
+
|
|
1778
|
+
print(json.dumps({{
|
|
1779
|
+
"success": True,
|
|
1780
|
+
"all_correct": all_correct,
|
|
1781
|
+
"passed": passed,
|
|
1782
|
+
"total": total,
|
|
1783
|
+
"speedup": speedup,
|
|
1784
|
+
"results": results,
|
|
1785
|
+
}}))
|
|
1786
|
+
'''
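# The script returned above is shipped to the workspace as a single `python -c`
# argument, so it must survive shell quoting intact. Call-site convention (mirrors
# run_evaluate_workspace below; variable values illustrative):
#
#   eval_script = _build_workspace_eval_script(impl_code=..., ref_code=...,
#                                               test_cases_json=..., run_benchmarks=True)
#   eval_cmd = f"python -c {shlex.quote(eval_script)}"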
|
|
1787
|
+
|
|
1788
|
+
|
|
1789
|
+
async def run_evaluate_workspace(
|
|
1790
|
+
args: EvaluateArgs,
|
|
1791
|
+
target: WorkspaceTarget,
|
|
1792
|
+
) -> EvaluateResult:
|
|
1793
|
+
"""Run evaluation on wafer-api managed workspace.
|
|
1794
|
+
|
|
1795
|
+
Uses inline evaluation (no file sync needed) via workspace exec.
|
|
1796
|
+
The eval script is passed as a Python command with base64-encoded files.
|
|
1797
|
+
|
|
1798
|
+
Args:
|
|
1799
|
+
args: Evaluate arguments
|
|
1800
|
+
target: Workspace target config
|
|
1801
|
+
|
|
1802
|
+
Returns:
|
|
1803
|
+
Evaluation result
|
|
1804
|
+
"""
|
|
1805
|
+
import trio
|
|
1806
|
+
|
|
1807
|
+
from .workspaces import exec_command
|
|
1808
|
+
|
|
1809
|
+
print(f"Using workspace: {target.workspace_id}")
|
|
1810
|
+
|
|
1811
|
+
# Read files
|
|
1812
|
+
impl_code = args.implementation.read_text()
|
|
1813
|
+
ref_code = args.reference.read_text()
|
|
1814
|
+
test_cases_json = args.test_cases.read_text()
|
|
1815
|
+
|
|
1816
|
+
# Read defense module if defensive mode is enabled
|
|
1817
|
+
defense_code = None
|
|
1818
|
+
if args.defensive:
|
|
1819
|
+
defense_path = (
|
|
1820
|
+
Path(__file__).parent.parent.parent.parent
|
|
1821
|
+
/ "packages"
|
|
1822
|
+
/ "wafer-core"
|
|
1823
|
+
/ "wafer_core"
|
|
1824
|
+
/ "utils"
|
|
1825
|
+
/ "kernel_utils"
|
|
1826
|
+
/ "defense.py"
|
|
1827
|
+
)
|
|
1828
|
+
if defense_path.exists():
|
|
1829
|
+
defense_code = defense_path.read_text()
|
|
1830
|
+
else:
|
|
1831
|
+
print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
|
|
1832
|
+
|
|
1833
|
+
# Build inline eval script
|
|
1834
|
+
eval_script = _build_workspace_eval_script(
|
|
1835
|
+
impl_code=impl_code,
|
|
1836
|
+
ref_code=ref_code,
|
|
1837
|
+
test_cases_json=test_cases_json,
|
|
1838
|
+
run_benchmarks=args.benchmark,
|
|
1839
|
+
run_defensive=args.defensive,
|
|
1840
|
+
defense_code=defense_code,
|
|
1841
|
+
)
|
|
1842
|
+
|
|
1843
|
+
# Execute via workspace exec
|
|
1844
|
+
# Use python -c with the script
|
|
1845
|
+
eval_cmd = f"python -c {shlex.quote(eval_script)}"
|
|
1846
|
+
|
|
1847
|
+
print("Running evaluation...")
|
|
1848
|
+
|
|
1849
|
+
# Capture stdout by redirecting exec output
|
|
1850
|
+
# exec_command prints the remote output to stdout, so we capture and re-parse it here
|
|
1851
|
+
import io
|
|
1852
|
+
import sys
|
|
1853
|
+
|
|
1854
|
+
captured_output = io.StringIO()
|
|
1855
|
+
original_stdout = sys.stdout
|
|
1856
|
+
|
|
1857
|
+
def _exec() -> int:
|
|
1858
|
+
# Temporarily redirect stdout to capture output
|
|
1859
|
+
sys.stdout = captured_output
|
|
1860
|
+
try:
|
|
1861
|
+
return exec_command(
|
|
1862
|
+
workspace_id=target.workspace_id,
|
|
1863
|
+
command=eval_cmd,
|
|
1864
|
+
timeout_seconds=target.timeout_seconds,
|
|
1865
|
+
)
|
|
1866
|
+
finally:
|
|
1867
|
+
sys.stdout = original_stdout
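# Note (descriptive): this manual save/restore is functionally equivalent to wrapping
# the exec_command call in contextlib.redirect_stdout(captured_output).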
|
|
1868
|
+
|
|
1869
|
+
try:
|
|
1870
|
+
exit_code = await trio.to_thread.run_sync(_exec)
|
|
1871
|
+
except Exception as e:
|
|
1872
|
+
sys.stdout = original_stdout
|
|
1873
|
+
return EvaluateResult(
|
|
1874
|
+
success=False,
|
|
1875
|
+
all_correct=False,
|
|
1876
|
+
correctness_score=0.0,
|
|
1877
|
+
geomean_speedup=0.0,
|
|
1878
|
+
passed_tests=0,
|
|
1879
|
+
total_tests=0,
|
|
1880
|
+
error_message=f"Execution failed: {e}",
|
|
1881
|
+
)
|
|
1882
|
+
|
|
1883
|
+
# Parse output
|
|
1884
|
+
output = captured_output.getvalue()
|
|
1885
|
+
print(output) # Show output to user
|
|
1886
|
+
|
|
1887
|
+
# Find JSON result in output
|
|
1888
|
+
result_json = None
|
|
1889
|
+
for line in reversed(output.strip().split("\n")):
|
|
1890
|
+
if line.startswith("{"):
|
|
1891
|
+
try:
|
|
1892
|
+
result_json = json.loads(line)
|
|
1893
|
+
break
|
|
1894
|
+
except json.JSONDecodeError:
|
|
1895
|
+
continue
|
|
1896
|
+
|
|
1897
|
+
if result_json is None:
|
|
1898
|
+
if exit_code == 0:
|
|
1899
|
+
return EvaluateResult(
|
|
1900
|
+
success=True,
|
|
1901
|
+
all_correct=True,
|
|
1902
|
+
correctness_score=1.0,
|
|
1903
|
+
geomean_speedup=0.0,
|
|
1904
|
+
passed_tests=0,
|
|
1905
|
+
total_tests=0,
|
|
1906
|
+
)
|
|
1907
|
+
else:
|
|
1908
|
+
return EvaluateResult(
|
|
1909
|
+
success=False,
|
|
1910
|
+
all_correct=False,
|
|
1911
|
+
correctness_score=0.0,
|
|
1912
|
+
geomean_speedup=0.0,
|
|
1913
|
+
passed_tests=0,
|
|
1914
|
+
total_tests=0,
|
|
1915
|
+
error_message=f"Evaluation failed with exit code {exit_code}",
|
|
1916
|
+
)
|
|
1917
|
+
|
|
1918
|
+
if "error" in result_json:
|
|
1919
|
+
return EvaluateResult(
|
|
1920
|
+
success=False,
|
|
1921
|
+
all_correct=False,
|
|
1922
|
+
correctness_score=0.0,
|
|
1923
|
+
geomean_speedup=0.0,
|
|
1924
|
+
passed_tests=0,
|
|
1925
|
+
total_tests=0,
|
|
1926
|
+
error_message=result_json["error"],
|
|
1927
|
+
)
|
|
1928
|
+
|
|
1929
|
+
passed = result_json.get("passed", 0)
|
|
1930
|
+
total = result_json.get("total", 0)
|
|
1931
|
+
correctness = passed / total if total > 0 else 0.0
|
|
1932
|
+
|
|
1933
|
+
return EvaluateResult(
|
|
1934
|
+
success=True,
|
|
1935
|
+
all_correct=result_json.get("all_correct", False),
|
|
1936
|
+
correctness_score=correctness,
|
|
1937
|
+
geomean_speedup=result_json.get("speedup", 0.0),
|
|
1938
|
+
passed_tests=passed,
|
|
1939
|
+
total_tests=total,
|
|
1940
|
+
)
|
|
1941
|
+
|
|
1942
|
+
|
|
1943
|
+
async def run_evaluate_runpod(
|
|
1944
|
+
args: EvaluateArgs,
|
|
1945
|
+
target: RunPodTarget,
|
|
1946
|
+
) -> EvaluateResult:
|
|
1947
|
+
"""Run evaluation on RunPod target.
|
|
1948
|
+
|
|
1949
|
+
Provisions a RunPod pod (or reuses an existing one), runs evaluation via SSH,
|
|
1950
|
+
then cleans up based on the keep_alive setting.
|
|
1951
|
+
|
|
1952
|
+
Sets up a Python venv with ROCm torch using uv, then runs evaluation.
|
|
1953
|
+
|
|
1954
|
+
Args:
|
|
1955
|
+
args: Evaluate arguments
|
|
1956
|
+
target: RunPod target config
|
|
1957
|
+
|
|
1958
|
+
Returns:
|
|
1959
|
+
Evaluation result
|
|
1960
|
+
"""
|
|
1961
|
+
from datetime import datetime
|
|
1962
|
+
|
|
1963
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
1964
|
+
from wafer_core.remote_env import async_setup_python_env
|
|
1965
|
+
from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
|
|
1966
|
+
|
|
1967
|
+
REMOTE_WORKSPACE = "/tmp/wafer_eval"
|
|
1968
|
+
ROCM_TORCH_INDEX_URL = "https://download.pytorch.org/whl/rocm6.2"
|
|
1969
|
+
ROCM_TORCH_VERSION_SUFFIX = "+rocm6.2"
|
|
1970
|
+
|
|
1971
|
+
print(f"Provisioning RunPod ({target.gpu_type_id})...")
|
|
1972
|
+
|
|
1973
|
+
try:
|
|
1974
|
+
async with runpod_ssh_context(target) as ssh_info:
|
|
1975
|
+
ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
|
|
1976
|
+
print(f"Connected to RunPod: {ssh_target}")
|
|
1977
|
+
|
|
1978
|
+
async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
|
|
1979
|
+
# Ensure rsync is installed (needed for file uploads)
|
|
1980
|
+
print("Checking rsync...")
|
|
1981
|
+
result = await client.exec("which rsync || echo 'NOT_FOUND'")
|
|
1982
|
+
if "NOT_FOUND" in result.stdout:
|
|
1983
|
+
print("Installing rsync...")
|
|
1984
|
+
await client.exec("apt-get update && apt-get install -y rsync")
|
|
1985
|
+
|
|
1986
|
+
# Setup Python environment with ROCm torch
|
|
1987
|
+
# Match wafer-core dependencies needed for evaluate.py
|
|
1988
|
+
print("Setting up Python environment with ROCm torch...")
|
|
1989
|
+
requirements = [
|
|
1990
|
+
f"torch==2.5.1{ROCM_TORCH_VERSION_SUFFIX}",
|
|
1991
|
+
"numpy",
|
|
1992
|
+
"ninja",
|
|
1993
|
+
"setuptools",
|
|
1994
|
+
# wafer_core dependencies
|
|
1995
|
+
"trio",
|
|
1996
|
+
"httpx",
|
|
1997
|
+
"pydantic",
|
|
1998
|
+
"anyio",
|
|
1999
|
+
"pyyaml",
|
|
2000
|
+
]
|
|
2001
|
+
|
|
2002
|
+
try:
|
|
2003
|
+
env_state = await async_setup_python_env(
|
|
2004
|
+
client=client,
|
|
2005
|
+
workspace=REMOTE_WORKSPACE,
|
|
2006
|
+
requirements=requirements,
|
|
2007
|
+
python_version=">=3.10",
|
|
2008
|
+
venv_path=".venv",
|
|
2009
|
+
index_url=ROCM_TORCH_INDEX_URL,
|
|
2010
|
+
)
|
|
2011
|
+
python_exe = env_state.venv_python
|
|
2012
|
+
print(f"Using Python: {python_exe}")
|
|
2013
|
+
except Exception as e:
|
|
2014
|
+
return EvaluateResult(
|
|
2015
|
+
success=False,
|
|
2016
|
+
all_correct=False,
|
|
2017
|
+
correctness_score=0.0,
|
|
2018
|
+
geomean_speedup=0.0,
|
|
2019
|
+
passed_tests=0,
|
|
2020
|
+
total_tests=0,
|
|
2021
|
+
error_message=f"Failed to setup Python environment: {e}",
|
|
2022
|
+
)
|
|
2023
|
+
|
|
2024
|
+
# Upload wafer-core to remote
|
|
2025
|
+
try:
|
|
2026
|
+
wafer_root = _get_wafer_root()
|
|
2027
|
+
wafer_core_path = wafer_root / "packages" / "wafer-core"
|
|
2028
|
+
print(f"Uploading wafer-core from {wafer_core_path}...")
|
|
2029
|
+
|
|
2030
|
+
wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
|
|
2031
|
+
await client.exec(f"mkdir -p {wafer_core_remote}")
|
|
2032
|
+
wafer_core_workspace = await client.expand_path(wafer_core_remote)
|
|
2033
|
+
|
|
2034
|
+
upload_result = await client.upload_files(
|
|
2035
|
+
str(wafer_core_path), wafer_core_workspace, recursive=True
|
|
2036
|
+
)
|
|
2037
|
+
|
|
2038
|
+
# Wide event logging for upload result
|
|
2039
|
+
upload_event = {
|
|
2040
|
+
"event": "wafer_core_upload",
|
|
2041
|
+
"target": target.name,
|
|
2042
|
+
"target_type": "runpod",
|
|
2043
|
+
"ssh_host": f"{client.user}@{client.host}:{client.port}",
|
|
2044
|
+
"local_path": str(wafer_core_path),
|
|
2045
|
+
"remote_path": wafer_core_workspace,
|
|
2046
|
+
"success": upload_result.success,
|
|
2047
|
+
"files_copied": upload_result.files_copied,
|
|
2048
|
+
"duration_seconds": upload_result.duration_seconds,
|
|
2049
|
+
"error_message": upload_result.error_message,
|
|
2050
|
+
}
|
|
2051
|
+
if upload_result.debug_info:
|
|
2052
|
+
upload_event["debug_info"] = upload_result.debug_info
|
|
2053
|
+
logger.info(json.dumps(upload_event))
|
|
2054
|
+
|
|
2055
|
+
# Fail fast if upload failed
|
|
2056
|
+
if not upload_result.success:
|
|
2057
|
+
print(f"ERROR: Upload failed: {upload_result.error_message}")
|
|
2058
|
+
if upload_result.debug_info:
|
|
2059
|
+
print(f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}")
|
|
2060
|
+
return EvaluateResult(
|
|
2061
|
+
success=False,
|
|
2062
|
+
all_correct=False,
|
|
2063
|
+
correctness_score=0.0,
|
|
2064
|
+
geomean_speedup=0.0,
|
|
2065
|
+
passed_tests=0,
|
|
2066
|
+
total_tests=0,
|
|
2067
|
+
error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
|
|
2068
|
+
)
|
|
2069
|
+
|
|
2070
|
+
print(f"Uploaded {upload_result.files_copied} files")
|
|
2071
|
+
except Exception as e:
|
|
2072
|
+
return EvaluateResult(
|
|
2073
|
+
success=False,
|
|
2074
|
+
all_correct=False,
|
|
2075
|
+
correctness_score=0.0,
|
|
2076
|
+
geomean_speedup=0.0,
|
|
2077
|
+
passed_tests=0,
|
|
2078
|
+
total_tests=0,
|
|
2079
|
+
error_message=f"Failed to upload wafer-core: {e}",
|
|
2080
|
+
)
|
|
2081
|
+
|
|
2082
|
+
# Select GPU (RunPod pods typically have GPU 0)
|
|
2083
|
+
gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
|
|
2084
|
+
print(f"Using GPU {gpu_id}...")
|
|
2085
|
+
|
|
2086
|
+
# Read local files
|
|
2087
|
+
impl_code = args.implementation.read_text()
|
|
2088
|
+
ref_code = args.reference.read_text()
|
|
2089
|
+
test_cases_data = json.loads(args.test_cases.read_text())
|
|
2090
|
+
|
|
2091
|
+
# Create a unique run directory (uuid for concurrent eval isolation)
|
|
2092
|
+
import uuid
|
|
2093
|
+
|
|
2094
|
+
unique_id = uuid.uuid4().hex[:8]
|
|
2095
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
2096
|
+
run_dir = f"wafer_eval_{timestamp}_{unique_id}"
|
|
2097
|
+
run_path = f"{REMOTE_WORKSPACE}/{run_dir}"
|
|
2098
|
+
|
|
2099
|
+
print("Uploading evaluation files...")
|
|
2100
|
+
|
|
2101
|
+
# Create run directory
|
|
2102
|
+
mkdir_result = await client.exec(f"mkdir -p {run_path}")
|
|
2103
|
+
if mkdir_result.exit_code != 0:
|
|
2104
|
+
return EvaluateResult(
|
|
2105
|
+
success=False,
|
|
2106
|
+
all_correct=False,
|
|
2107
|
+
correctness_score=0.0,
|
|
2108
|
+
geomean_speedup=0.0,
|
|
2109
|
+
passed_tests=0,
|
|
2110
|
+
total_tests=0,
|
|
2111
|
+
error_message=f"Failed to create run directory: {mkdir_result.stderr}",
|
|
2112
|
+
)
|
|
2113
|
+
|
|
2114
|
+
# Write implementation
|
|
2115
|
+
impl_path = f"{run_path}/implementation.py"
|
|
2116
|
+
write_result = await client.exec(
|
|
2117
|
+
f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
|
|
2118
|
+
)
|
|
2119
|
+
if write_result.exit_code != 0:
|
|
2120
|
+
return EvaluateResult(
|
|
2121
|
+
success=False,
|
|
2122
|
+
all_correct=False,
|
|
2123
|
+
correctness_score=0.0,
|
|
2124
|
+
geomean_speedup=0.0,
|
|
2125
|
+
passed_tests=0,
|
|
2126
|
+
total_tests=0,
|
|
2127
|
+
error_message=f"Failed to write implementation: {write_result.stderr}",
|
|
2128
|
+
)
|
|
2129
|
+
|
|
2130
|
+
# Write reference
|
|
2131
|
+
ref_path = f"{run_path}/reference.py"
|
|
2132
|
+
write_result = await client.exec(
|
|
2133
|
+
f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
|
|
2134
|
+
)
|
|
2135
|
+
if write_result.exit_code != 0:
|
|
2136
|
+
return EvaluateResult(
|
|
2137
|
+
success=False,
|
|
2138
|
+
all_correct=False,
|
|
2139
|
+
correctness_score=0.0,
|
|
2140
|
+
geomean_speedup=0.0,
|
|
2141
|
+
passed_tests=0,
|
|
2142
|
+
total_tests=0,
|
|
2143
|
+
error_message=f"Failed to write reference: {write_result.stderr}",
|
|
2144
|
+
)
|
|
2145
|
+
|
|
2146
|
+
# Also write as reference_kernel.py (evaluate.py imports generate_input from this)
|
|
2147
|
+
ref_kernel_path = f"{run_path}/reference_kernel.py"
|
|
2148
|
+
write_result = await client.exec(
|
|
2149
|
+
f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
|
|
2150
|
+
)
|
|
2151
|
+
if write_result.exit_code != 0:
|
|
2152
|
+
return EvaluateResult(
|
|
2153
|
+
success=False,
|
|
2154
|
+
all_correct=False,
|
|
2155
|
+
correctness_score=0.0,
|
|
2156
|
+
geomean_speedup=0.0,
|
|
2157
|
+
passed_tests=0,
|
|
2158
|
+
total_tests=0,
|
|
2159
|
+
error_message=f"Failed to write reference_kernel: {write_result.stderr}",
|
|
2160
|
+
)
|
|
2161
|
+
|
|
2162
|
+
# Write test cases as JSON
|
|
2163
|
+
test_cases_path = f"{run_path}/test_cases.json"
|
|
2164
|
+
test_cases_json = json.dumps(test_cases_data)
|
|
2165
|
+
write_result = await client.exec(
|
|
2166
|
+
f"cat > '{test_cases_path}' << 'TEST_EOF'\n{test_cases_json}\nTEST_EOF"
|
|
2167
|
+
)
|
|
2168
|
+
if write_result.exit_code != 0:
|
|
2169
|
+
return EvaluateResult(
|
|
2170
|
+
success=False,
|
|
2171
|
+
all_correct=False,
|
|
2172
|
+
correctness_score=0.0,
|
|
2173
|
+
geomean_speedup=0.0,
|
|
2174
|
+
passed_tests=0,
|
|
2175
|
+
total_tests=0,
|
|
2176
|
+
error_message=f"Failed to write test cases: {write_result.stderr}",
|
|
2177
|
+
)
|
|
2178
|
+
|
|
2179
|
+
print("Running evaluation...")
|
|
2180
|
+
|
|
2181
|
+
# Build evaluation command
|
|
2182
|
+
# RunPod ROCm images use HIP_VISIBLE_DEVICES for AMD GPUs
|
|
2183
|
+
# Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
|
|
2184
|
+
venv_bin = env_state.venv_bin
|
|
2185
|
+
env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
|
|
2186
|
+
|
|
2187
|
+
# Run from run_path so reference_kernel.py is importable
|
|
2188
|
+
# Use installed wafer-core module
|
|
2189
|
+
eval_cmd = (
|
|
2190
|
+
f"cd {run_path} && "
|
|
2191
|
+
f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
|
|
2192
|
+
f"--implementation {impl_path} "
|
|
2193
|
+
f"--reference {ref_path} "
|
|
2194
|
+
f"--test-cases {test_cases_path} "
|
|
2195
|
+
f"--run-dir {run_path}"
|
|
2196
|
+
)
|
|
2197
|
+
|
|
2198
|
+
if args.benchmark:
|
|
2199
|
+
eval_cmd += " --benchmark"
|
|
2200
|
+
if args.defensive:
|
|
2201
|
+
eval_cmd += " --defensive"
|
|
2202
|
+
|
|
2203
|
+
# Run with timeout
|
|
2204
|
+
import trio
|
|
2205
|
+
|
|
2206
|
+
with trio.move_on_after(target.eval_timeout) as cancel_scope:
|
|
2207
|
+
result = await client.exec(eval_cmd)
|
|
2208
|
+
|
|
2209
|
+
if cancel_scope.cancelled_caught:
|
|
2210
|
+
return EvaluateResult(
|
|
2211
|
+
success=False,
|
|
2212
|
+
all_correct=False,
|
|
2213
|
+
correctness_score=0.0,
|
|
2214
|
+
geomean_speedup=0.0,
|
|
2215
|
+
passed_tests=0,
|
|
2216
|
+
total_tests=0,
|
|
2217
|
+
error_message=f"Evaluation timed out after {target.eval_timeout}s",
|
|
2218
|
+
)
|
|
2219
|
+
|
|
2220
|
+
# Parse output
|
|
2221
|
+
stdout = result.stdout
|
|
2222
|
+
stderr = result.stderr
|
|
2223
|
+
|
|
2224
|
+
if result.exit_code != 0:
|
|
2225
|
+
return EvaluateResult(
|
|
2226
|
+
success=False,
|
|
2227
|
+
all_correct=False,
|
|
2228
|
+
correctness_score=0.0,
|
|
2229
|
+
geomean_speedup=0.0,
|
|
2230
|
+
passed_tests=0,
|
|
2231
|
+
total_tests=0,
|
|
2232
|
+
error_message=f"Evaluation failed:\nstdout: {stdout}\nstderr: {stderr}",
|
|
2233
|
+
)
|
|
2234
|
+
|
|
2235
|
+
# Find JSON result in output
|
|
2236
|
+
result_json = None
|
|
2237
|
+
for line in reversed(stdout.strip().split("\n")):
|
|
2238
|
+
if line.startswith("{"):
|
|
2239
|
+
try:
|
|
2240
|
+
result_json = json.loads(line)
|
|
2241
|
+
break
|
|
2242
|
+
except json.JSONDecodeError:
|
|
2243
|
+
continue
|
|
2244
|
+
|
|
2245
|
+
if result_json is None:
|
|
2246
|
+
return EvaluateResult(
|
|
2247
|
+
success=False,
|
|
2248
|
+
all_correct=False,
|
|
2249
|
+
correctness_score=0.0,
|
|
2250
|
+
geomean_speedup=0.0,
|
|
2251
|
+
passed_tests=0,
|
|
2252
|
+
total_tests=0,
|
|
2253
|
+
error_message=f"No JSON result in output:\n{stdout}",
|
|
2254
|
+
)
|
|
2255
|
+
|
|
2256
|
+
if "error" in result_json:
|
|
2257
|
+
return EvaluateResult(
|
|
2258
|
+
success=False,
|
|
2259
|
+
all_correct=False,
|
|
2260
|
+
correctness_score=0.0,
|
|
2261
|
+
geomean_speedup=0.0,
|
|
2262
|
+
passed_tests=0,
|
|
2263
|
+
total_tests=0,
|
|
2264
|
+
error_message=result_json["error"],
|
|
2265
|
+
)
|
|
2266
|
+
|
|
2267
|
+
passed = result_json.get("passed", 0)
|
|
2268
|
+
total = result_json.get("total", 0)
|
|
2269
|
+
correctness = passed / total if total > 0 else 0.0
|
|
2270
|
+
|
|
2271
|
+
return EvaluateResult(
|
|
2272
|
+
success=True,
|
|
2273
|
+
all_correct=result_json.get("all_correct", False),
|
|
2274
|
+
correctness_score=correctness,
|
|
2275
|
+
geomean_speedup=result_json.get("speedup", 0.0),
|
|
2276
|
+
passed_tests=passed,
|
|
2277
|
+
total_tests=total,
|
|
2278
|
+
)
|
|
2279
|
+
|
|
2280
|
+
except RunPodError as e:
|
|
2281
|
+
return EvaluateResult(
|
|
2282
|
+
success=False,
|
|
2283
|
+
all_correct=False,
|
|
2284
|
+
correctness_score=0.0,
|
|
2285
|
+
geomean_speedup=0.0,
|
|
2286
|
+
passed_tests=0,
|
|
2287
|
+
total_tests=0,
|
|
2288
|
+
error_message=f"RunPod error: {e}",
|
|
2289
|
+
)
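# File-transfer convention used above (and again by the DigitalOcean path below):
# each small text file is written through a quoted heredoc, so the remote shell does
# not expand $vars or backticks inside the payload. Minimal sketch (path illustrative):
#   cat > '/tmp/wafer_eval/<run>/implementation.py' << 'IMPL_EOF'
#   ...file contents...
#   IMPL_EOF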
|
|
2290
|
+
|
|
2291
|
+
|
|
2292
|
+
async def run_evaluate_digitalocean(
|
|
2293
|
+
args: EvaluateArgs,
|
|
2294
|
+
target: DigitalOceanTarget,
|
|
2295
|
+
) -> EvaluateResult:
|
|
2296
|
+
"""Run evaluation on DigitalOcean target.
|
|
2297
|
+
|
|
2298
|
+
Provisions a DigitalOcean droplet (or reuses an existing one), bootstraps a Python
|
|
2299
|
+
environment with uv, runs evaluation via SSH, then cleans up based on
|
|
2300
|
+
the keep_alive setting.
|
|
2301
|
+
|
|
2302
|
+
Args:
|
|
2303
|
+
args: Evaluate arguments
|
|
2304
|
+
target: DigitalOcean target config
|
|
2305
|
+
|
|
2306
|
+
Returns:
|
|
2307
|
+
Evaluation result
|
|
2308
|
+
"""
|
|
2309
|
+
from datetime import datetime
|
|
2310
|
+
|
|
2311
|
+
import trio_asyncio
|
|
2312
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
2313
|
+
from wafer_core.remote_env import async_setup_python_env
|
|
2314
|
+
from wafer_core.targets.digitalocean import DigitalOceanError, digitalocean_ssh_context
|
|
2315
|
+
|
|
2316
|
+
REMOTE_WORKSPACE = "/tmp/wafer_eval"
|
|
2317
|
+
ROCM_TORCH_INDEX_URL = "https://download.pytorch.org/whl/rocm6.2"
|
|
2318
|
+
ROCM_TORCH_VERSION_SUFFIX = "+rocm6.2"
|
|
2319
|
+
|
|
2320
|
+
print(f"Provisioning DigitalOcean droplet ({target.size_slug})...")
|
|
2321
|
+
|
|
2322
|
+
try:
|
|
2323
|
+
async with digitalocean_ssh_context(target) as ssh_info:
|
|
2324
|
+
ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
|
|
2325
|
+
print(f"Connected to DigitalOcean: {ssh_target}")
|
|
2326
|
+
|
|
2327
|
+
# Need trio_asyncio for AsyncSSHClient
|
|
2328
|
+
async with trio_asyncio.open_loop():
|
|
2329
|
+
async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
|
|
2330
|
+
# Ensure rsync and ninja are installed
|
|
2331
|
+
# ninja is needed for torch.utils.cpp_extension (HIP kernel compilation)
|
|
2332
|
+
print("Checking system dependencies...")
|
|
2333
|
+
result = await client.exec("which rsync && which ninja || echo 'MISSING'")
|
|
2334
|
+
if "MISSING" in result.stdout:
|
|
2335
|
+
print("Installing rsync and ninja...")
|
|
2336
|
+
await client.exec("apt-get update && apt-get install -y rsync ninja-build")
|
|
2337
|
+
|
|
2338
|
+
# Setup Python environment with ROCm torch
|
|
2339
|
+
# Match wafer-core dependencies needed for evaluate.py
|
|
2340
|
+
print("Setting up Python environment with ROCm torch...")
|
|
2341
|
+
requirements = [
|
|
2342
|
+
f"torch==2.5.1{ROCM_TORCH_VERSION_SUFFIX}",
|
|
2343
|
+
"numpy",
|
|
2344
|
+
"ninja",
|
|
2345
|
+
"setuptools",
|
|
2346
|
+
# wafer_core dependencies
|
|
2347
|
+
"trio",
|
|
2348
|
+
"httpx",
|
|
2349
|
+
"pydantic",
|
|
2350
|
+
"anyio",
|
|
2351
|
+
"pyyaml",
|
|
2352
|
+
]
|
|
2353
|
+
|
|
2354
|
+
try:
|
|
2355
|
+
env_state = await async_setup_python_env(
|
|
2356
|
+
client=client,
|
|
2357
|
+
workspace=REMOTE_WORKSPACE,
|
|
2358
|
+
requirements=requirements,
|
|
2359
|
+
python_version="3.10",
|
|
2360
|
+
venv_path=".venv",
|
|
2361
|
+
index_url=ROCM_TORCH_INDEX_URL,
|
|
2362
|
+
)
|
|
2363
|
+
python_exe = env_state.venv_python
|
|
2364
|
+
print(f"Using Python: {python_exe}")
|
|
2365
|
+
except Exception as e:
|
|
2366
|
+
return EvaluateResult(
|
|
2367
|
+
success=False,
|
|
2368
|
+
all_correct=False,
|
|
2369
|
+
correctness_score=0.0,
|
|
2370
|
+
geomean_speedup=0.0,
|
|
2371
|
+
passed_tests=0,
|
|
2372
|
+
total_tests=0,
|
|
2373
|
+
error_message=f"Failed to setup Python environment: {e}",
|
|
2374
|
+
)
|
|
2375
|
+
|
|
2376
|
+
# Upload wafer-core to remote
|
|
2377
|
+
try:
|
|
2378
|
+
wafer_root = _get_wafer_root()
|
|
2379
|
+
wafer_core_path = wafer_root / "packages" / "wafer-core"
|
|
2380
|
+
print(f"Uploading wafer-core from {wafer_core_path}...")
|
|
2381
|
+
|
|
2382
|
+
wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
|
|
2383
|
+
await client.exec(f"mkdir -p {wafer_core_remote}")
|
|
2384
|
+
wafer_core_workspace = await client.expand_path(wafer_core_remote)
|
|
2385
|
+
|
|
2386
|
+
# Use SFTP instead of rsync to avoid SSH subprocess timeout issues
|
|
2387
|
+
# (DigitalOcean may rate-limit new SSH connections)
|
|
2388
|
+
upload_result = await client.upload_files(
|
|
2389
|
+
str(wafer_core_path),
|
|
2390
|
+
wafer_core_workspace,
|
|
2391
|
+
recursive=True,
|
|
2392
|
+
use_sftp=True,
|
|
2393
|
+
)
|
|
2394
|
+
|
|
2395
|
+
# Wide event logging for upload result
|
|
2396
|
+
upload_event = {
|
|
2397
|
+
"event": "wafer_core_upload",
|
|
2398
|
+
"target": target.name,
|
|
2399
|
+
"target_type": "digitalocean",
|
|
2400
|
+
"ssh_host": f"{client.user}@{client.host}:{client.port}",
|
|
2401
|
+
"local_path": str(wafer_core_path),
|
|
2402
|
+
"remote_path": wafer_core_workspace,
|
|
2403
|
+
"success": upload_result.success,
|
|
2404
|
+
"files_copied": upload_result.files_copied,
|
|
2405
|
+
"duration_seconds": upload_result.duration_seconds,
|
|
2406
|
+
"error_message": upload_result.error_message,
|
|
2407
|
+
}
|
|
2408
|
+
if upload_result.debug_info:
|
|
2409
|
+
upload_event["debug_info"] = upload_result.debug_info
|
|
2410
|
+
logger.info(json.dumps(upload_event))
|
|
2411
|
+
|
|
2412
|
+
# Fail fast if upload failed
|
|
2413
|
+
if not upload_result.success:
|
|
2414
|
+
print(f"ERROR: Upload failed: {upload_result.error_message}")
|
|
2415
|
+
if upload_result.debug_info:
|
|
2416
|
+
print(
|
|
2417
|
+
f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}"
|
|
2418
|
+
)
|
|
2419
|
+
return EvaluateResult(
|
|
2420
|
+
success=False,
|
|
2421
|
+
all_correct=False,
|
|
2422
|
+
correctness_score=0.0,
|
|
2423
|
+
geomean_speedup=0.0,
|
|
2424
|
+
passed_tests=0,
|
|
2425
|
+
total_tests=0,
|
|
2426
|
+
error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
|
|
2427
|
+
)
|
|
2428
|
+
|
|
2429
|
+
print(f"Uploaded {upload_result.files_copied} files")
|
|
2430
|
+
except Exception as e:
|
|
2431
|
+
return EvaluateResult(
|
|
2432
|
+
success=False,
|
|
2433
|
+
all_correct=False,
|
|
2434
|
+
correctness_score=0.0,
|
|
2435
|
+
geomean_speedup=0.0,
|
|
2436
|
+
passed_tests=0,
|
|
2437
|
+
total_tests=0,
|
|
2438
|
+
error_message=f"Failed to upload wafer-core: {e}",
|
|
2439
|
+
)
|
|
2440
|
+
|
|
2441
|
+
# Select GPU (DigitalOcean droplets typically have GPU 0)
|
|
2442
|
+
gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
|
|
2443
|
+
print(f"Using GPU {gpu_id}...")
|
|
2444
|
+
|
|
2445
|
+
# Read local files
|
|
2446
|
+
impl_code = args.implementation.read_text()
|
|
2447
|
+
ref_code = args.reference.read_text()
|
|
2448
|
+
test_cases_data = json.loads(args.test_cases.read_text())
|
|
2449
|
+
|
|
2450
|
+
# Create a unique run directory (uuid for concurrent eval isolation)
|
|
2451
|
+
import uuid
|
|
2452
|
+
|
|
2453
|
+
unique_id = uuid.uuid4().hex[:8]
|
|
2454
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
2455
|
+
run_dir = f"wafer_eval_{timestamp}_{unique_id}"
|
|
2456
|
+
run_path = f"{REMOTE_WORKSPACE}/{run_dir}"
|
|
2457
|
+
|
|
2458
|
+
print("Uploading evaluation files...")
|
|
2459
|
+
|
|
2460
|
+
# Create run directory
|
|
2461
|
+
mkdir_result = await client.exec(f"mkdir -p {run_path}")
|
|
2462
|
+
if mkdir_result.exit_code != 0:
|
|
2463
|
+
return EvaluateResult(
|
|
2464
|
+
success=False,
|
|
2465
|
+
all_correct=False,
|
|
2466
|
+
correctness_score=0.0,
|
|
2467
|
+
geomean_speedup=0.0,
|
|
2468
|
+
passed_tests=0,
|
|
2469
|
+
total_tests=0,
|
|
2470
|
+
error_message=f"Failed to create run directory: {mkdir_result.stderr}",
|
|
2471
|
+
)
|
|
2472
|
+
|
|
2473
|
+
# Write implementation
|
|
2474
|
+
impl_path = f"{run_path}/implementation.py"
|
|
2475
|
+
write_result = await client.exec(
|
|
2476
|
+
f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
|
|
2477
|
+
)
|
|
2478
|
+
if write_result.exit_code != 0:
|
|
2479
|
+
return EvaluateResult(
|
|
2480
|
+
success=False,
|
|
2481
|
+
all_correct=False,
|
|
2482
|
+
correctness_score=0.0,
|
|
2483
|
+
geomean_speedup=0.0,
|
|
2484
|
+
passed_tests=0,
|
|
2485
|
+
total_tests=0,
|
|
2486
|
+
error_message=f"Failed to write implementation: {write_result.stderr}",
|
|
2487
|
+
)
|
|
2488
|
+
|
|
2489
|
+
# Write reference
|
|
2490
|
+
ref_path = f"{run_path}/reference.py"
|
|
2491
|
+
write_result = await client.exec(
|
|
2492
|
+
f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
|
|
2493
|
+
)
|
|
2494
|
+
if write_result.exit_code != 0:
|
|
2495
|
+
return EvaluateResult(
|
|
2496
|
+
success=False,
|
|
2497
|
+
all_correct=False,
|
|
2498
|
+
correctness_score=0.0,
|
|
2499
|
+
geomean_speedup=0.0,
|
|
2500
|
+
passed_tests=0,
|
|
2501
|
+
total_tests=0,
|
|
2502
|
+
error_message=f"Failed to write reference: {write_result.stderr}",
|
|
2503
|
+
)
|
|
2504
|
+
|
|
2505
|
+
# Also write as reference_kernel.py (evaluate.py imports generate_input from this)
|
|
2506
|
+
ref_kernel_path = f"{run_path}/reference_kernel.py"
|
|
2507
|
+
write_result = await client.exec(
|
|
2508
|
+
f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
|
|
2509
|
+
)
|
|
2510
|
+
if write_result.exit_code != 0:
|
|
2511
|
+
return EvaluateResult(
|
|
2512
|
+
success=False,
|
|
2513
|
+
all_correct=False,
|
|
2514
|
+
correctness_score=0.0,
|
|
2515
|
+
geomean_speedup=0.0,
|
|
2516
|
+
passed_tests=0,
|
|
2517
|
+
total_tests=0,
|
|
2518
|
+
error_message=f"Failed to write reference_kernel: {write_result.stderr}",
|
|
2519
|
+
)
|
|
2520
|
+
|
|
2521
|
+
# Write test cases as JSON
|
|
2522
|
+
test_cases_path = f"{run_path}/test_cases.json"
|
|
2523
|
+
test_cases_json = json.dumps(test_cases_data)
|
|
2524
|
+
write_result = await client.exec(
|
|
2525
|
+
f"cat > '{test_cases_path}' << 'TEST_EOF'\n{test_cases_json}\nTEST_EOF"
|
|
2526
|
+
)
|
|
2527
|
+
if write_result.exit_code != 0:
|
|
2528
|
+
return EvaluateResult(
|
|
2529
|
+
success=False,
|
|
2530
|
+
all_correct=False,
|
|
2531
|
+
correctness_score=0.0,
|
|
2532
|
+
geomean_speedup=0.0,
|
|
2533
|
+
passed_tests=0,
|
|
2534
|
+
total_tests=0,
|
|
2535
|
+
error_message=f"Failed to write test cases: {write_result.stderr}",
|
|
2536
|
+
)
|
|
2537
|
+
|
|
2538
|
+
print("Running evaluation...")
|
|
2539
|
+
|
|
2540
|
+
# Build evaluation command
|
|
2541
|
+
# DigitalOcean AMD droplets use HIP_VISIBLE_DEVICES to select the GPU
|
|
2542
|
+
# Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
|
|
2543
|
+
venv_bin = env_state.venv_bin
|
|
2544
|
+
env_vars = (
|
|
2545
|
+
f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
|
|
2546
|
+
)
|
|
2547
|
+
|
|
2548
|
+
# Run from run_path so reference_kernel.py is importable
|
|
2549
|
+
# Use installed wafer-core module
|
|
2550
|
+
eval_cmd = (
|
|
2551
|
+
f"cd {run_path} && "
|
|
2552
|
+
f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
|
|
2553
|
+
f"--implementation {impl_path} "
|
|
2554
|
+
f"--reference {ref_path} "
|
|
2555
|
+
f"--test-cases {test_cases_path} "
|
|
2556
|
+
f"--run-dir {run_path}"
|
|
2557
|
+
)
|
|
2558
|
+
|
|
2559
|
+
if args.benchmark:
|
|
2560
|
+
eval_cmd += " --benchmark"
|
|
2561
|
+
if args.defensive:
|
|
2562
|
+
eval_cmd += " --defensive"
|
|
2563
|
+
|
|
2564
|
+
# Run with timeout
|
|
2565
|
+
import trio
|
|
2566
|
+
|
|
2567
|
+
with trio.move_on_after(target.eval_timeout) as cancel_scope:
|
|
2568
|
+
result = await client.exec(eval_cmd)
|
|
2569
|
+
|
|
2570
|
+
if cancel_scope.cancelled_caught:
|
|
2571
|
+
return EvaluateResult(
|
|
2572
|
+
success=False,
|
|
2573
|
+
all_correct=False,
|
|
2574
|
+
correctness_score=0.0,
|
|
2575
|
+
geomean_speedup=0.0,
|
|
2576
|
+
passed_tests=0,
|
|
2577
|
+
total_tests=0,
|
|
2578
|
+
error_message=f"Evaluation timed out after {target.eval_timeout}s",
|
|
2579
|
+
)
|
|
2580
|
+
|
|
2581
|
+
# Show output to user
|
|
2582
|
+
stdout = result.stdout
|
|
2583
|
+
stderr = result.stderr
|
|
2584
|
+
if stdout:
|
|
2585
|
+
print(stdout)
|
|
2586
|
+
|
|
2587
|
+
if result.exit_code != 0:
|
|
2588
|
+
# Include both stdout and stderr for debugging
|
|
2589
|
+
error_parts = [f"Evaluation failed (exit code {result.exit_code}):"]
|
|
2590
|
+
if stdout:
|
|
2591
|
+
error_parts.append(f"stdout: {stdout}")
|
|
2592
|
+
if stderr:
|
|
2593
|
+
error_parts.append(f"stderr: {stderr}")
|
|
2594
|
+
return EvaluateResult(
|
|
2595
|
+
success=False,
|
|
2596
|
+
all_correct=False,
|
|
2597
|
+
correctness_score=0.0,
|
|
2598
|
+
geomean_speedup=0.0,
|
|
2599
|
+
passed_tests=0,
|
|
2600
|
+
total_tests=0,
|
|
2601
|
+
error_message="\n".join(error_parts),
|
|
2602
|
+
)
|
|
2603
|
+
|
|
2604
|
+
# Read results from results.json (like SSH path)
|
|
2605
|
+
results_path = f"{run_path}/results.json"
|
|
2606
|
+
cat_result = await client.exec(f"cat {results_path}")
|
|
2607
|
+
|
|
2608
|
+
if cat_result.exit_code != 0:
|
|
2609
|
+
return EvaluateResult(
|
|
2610
|
+
success=False,
|
|
2611
|
+
all_correct=False,
|
|
2612
|
+
correctness_score=0.0,
|
|
2613
|
+
geomean_speedup=0.0,
|
|
2614
|
+
passed_tests=0,
|
|
2615
|
+
total_tests=0,
|
|
2616
|
+
error_message=f"Failed to read results: {cat_result.stderr}",
|
|
2617
|
+
)
|
|
2618
|
+
|
|
2619
|
+
try:
|
|
2620
|
+
results_data = json.loads(cat_result.stdout)
|
|
2621
|
+
except json.JSONDecodeError as e:
|
|
2622
|
+
return EvaluateResult(
|
|
2623
|
+
success=False,
|
|
2624
|
+
all_correct=False,
|
|
2625
|
+
correctness_score=0.0,
|
|
2626
|
+
geomean_speedup=0.0,
|
|
2627
|
+
passed_tests=0,
|
|
2628
|
+
total_tests=0,
|
|
2629
|
+
error_message=f"Invalid JSON in results: {e}",
|
|
2630
|
+
)
|
|
2631
|
+
|
|
2632
|
+
# Extract backend results (same format as SSH path)
|
|
2633
|
+
backends = results_data.get("backends", [])
|
|
2634
|
+
if not backends:
|
|
2635
|
+
return EvaluateResult(
|
|
2636
|
+
success=False,
|
|
2637
|
+
all_correct=False,
|
|
2638
|
+
correctness_score=0.0,
|
|
2639
|
+
geomean_speedup=0.0,
|
|
2640
|
+
passed_tests=0,
|
|
2641
|
+
total_tests=0,
|
|
2642
|
+
error_message="No backend results found",
|
|
2643
|
+
)
|
|
2644
|
+
|
|
2645
|
+
backend = backends[0]
|
|
2646
|
+
correctness_tests = backend.get("correctness_tests", [])
|
|
2647
|
+
passed = sum(1 for t in correctness_tests if t.get("is_correct", False))
|
|
2648
|
+
total = len(correctness_tests)
|
|
2649
|
+
|
|
2650
|
+
return EvaluateResult(
|
|
2651
|
+
success=True,
|
|
2652
|
+
all_correct=backend.get("all_correct", False),
|
|
2653
|
+
correctness_score=backend.get("correctness_score", 0.0),
|
|
2654
|
+
geomean_speedup=backend.get("geomean_speedup", 0.0),
|
|
2655
|
+
passed_tests=passed,
|
|
2656
|
+
total_tests=total,
|
|
2657
|
+
)
|
|
2658
|
+
|
|
2659
|
+
except DigitalOceanError as e:
|
|
2660
|
+
return EvaluateResult(
|
|
2661
|
+
success=False,
|
|
2662
|
+
all_correct=False,
|
|
2663
|
+
correctness_score=0.0,
|
|
2664
|
+
geomean_speedup=0.0,
|
|
2665
|
+
passed_tests=0,
|
|
2666
|
+
total_tests=0,
|
|
2667
|
+
error_message=f"DigitalOcean error: {e}",
|
|
2668
|
+
)
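# Shape of the results.json consumed above, inferred from the fields read in this
# function (keys and values illustrative only):
#   {
#     "backends": [
#       {
#         "all_correct": true,
#         "correctness_score": 1.0,
#         "geomean_speedup": 1.4,
#         "correctness_tests": [{"is_correct": true, ...}, ...]
#       }
#     ]
#   }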
|
|
2669
|
+
|
|
2670
|
+
|
|
2671
|
+
async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
|
|
2672
|
+
"""Run evaluation on configured target.
|
|
2673
|
+
|
|
2674
|
+
Args:
|
|
2675
|
+
args: Evaluate arguments
|
|
2676
|
+
|
|
2677
|
+
Returns:
|
|
2678
|
+
Evaluation result
|
|
2679
|
+
"""
|
|
2680
|
+
from .targets import get_default_target, load_target
|
|
2681
|
+
|
|
2682
|
+
# Validate input files
|
|
2683
|
+
err = _validate_files(args)
|
|
2684
|
+
if err:
|
|
2685
|
+
return EvaluateResult(
|
|
2686
|
+
success=False,
|
|
2687
|
+
all_correct=False,
|
|
2688
|
+
correctness_score=0.0,
|
|
2689
|
+
geomean_speedup=0.0,
|
|
2690
|
+
passed_tests=0,
|
|
2691
|
+
total_tests=0,
|
|
2692
|
+
error_message=err,
|
|
2693
|
+
)
|
|
2694
|
+
|
|
2695
|
+
# Load target
|
|
2696
|
+
target_name = args.target_name
|
|
2697
|
+
if not target_name:
|
|
2698
|
+
target_name = get_default_target()
|
|
2699
|
+
if not target_name:
|
|
2700
|
+
return EvaluateResult(
|
|
2701
|
+
success=False,
|
|
2702
|
+
all_correct=False,
|
|
2703
|
+
correctness_score=0.0,
|
|
2704
|
+
geomean_speedup=0.0,
|
|
2705
|
+
passed_tests=0,
|
|
2706
|
+
total_tests=0,
|
|
2707
|
+
error_message=(
|
|
2708
|
+
"No target specified and no default set.\n"
|
|
2709
|
+
"Set up a target first:\n"
|
|
2710
|
+
" wafer config targets init ssh --name my-gpu --host user@host:22\n"
|
|
2711
|
+
" wafer config targets init runpod --gpu MI300X\n"
|
|
2712
|
+
"Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
|
|
2713
|
+
),
|
|
2714
|
+
)
|
|
2715
|
+
|
|
2716
|
+
try:
|
|
2717
|
+
target = load_target(target_name)
|
|
2718
|
+
except FileNotFoundError:
|
|
2719
|
+
return EvaluateResult(
|
|
2720
|
+
success=False,
|
|
2721
|
+
all_correct=False,
|
|
2722
|
+
correctness_score=0.0,
|
|
2723
|
+
geomean_speedup=0.0,
|
|
2724
|
+
passed_tests=0,
|
|
2725
|
+
total_tests=0,
|
|
2726
|
+
error_message=f"Target not found: {target_name}. Run: wafer config targets list",
|
|
2727
|
+
)
|
|
2728
|
+
|
|
2729
|
+
print(f"Using target: {target_name}")
|
|
2730
|
+
|
|
2731
|
+
# Dispatch to appropriate executor
|
|
2732
|
+
if isinstance(target, LocalTarget):
|
|
2733
|
+
return await run_evaluate_local(args, target)
|
|
2734
|
+
elif isinstance(target, BaremetalTarget | VMTarget):
|
|
2735
|
+
return await run_evaluate_ssh(args, target)
|
|
2736
|
+
elif isinstance(target, ModalTarget):
|
|
2737
|
+
return await run_evaluate_modal(args, target)
|
|
2738
|
+
elif isinstance(target, WorkspaceTarget):
|
|
2739
|
+
return await run_evaluate_workspace(args, target)
|
|
2740
|
+
elif isinstance(target, RunPodTarget):
|
|
2741
|
+
return await run_evaluate_runpod(args, target)
|
|
2742
|
+
elif isinstance(target, DigitalOceanTarget):
|
|
2743
|
+
return await run_evaluate_digitalocean(args, target)
|
|
2744
|
+
else:
|
|
2745
|
+
return EvaluateResult(
|
|
2746
|
+
success=False,
|
|
2747
|
+
all_correct=False,
|
|
2748
|
+
correctness_score=0.0,
|
|
2749
|
+
geomean_speedup=0.0,
|
|
2750
|
+
passed_tests=0,
|
|
2751
|
+
total_tests=0,
|
|
2752
|
+
error_message=f"Unknown target type: {type(target)}",
|
|
2753
|
+
)
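# Illustrative call site (not part of this module): run_evaluate is a trio-style
# async function, so a minimal driver looks like the sketch below. The EvaluateArgs
# field names match the attributes accessed above; the exact constructor signature
# is an assumption here.
#
#   import trio
#   from pathlib import Path
#
#   args = EvaluateArgs(
#       implementation=Path("kernel.py"),
#       reference=Path("reference.py"),
#       test_cases=Path("test_cases.json"),
#       benchmark=True,
#       defensive=False,
#       target_name="my-gpu",
#       gpu_id=None,
#   )
#   result = trio.run(run_evaluate, args)
#   print(result.all_correct, result.geomean_speedup)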
|
|
2754
|
+
|
|
2755
|
+
|
|
2756
|
+
# =============================================================================
|
|
2757
|
+
# KernelBench Format Evaluation
|
|
2758
|
+
# =============================================================================
|
|
2759
|
+
|
|
2760
|
+
# Inline evaluation script for KernelBench format
|
|
2761
|
+
# This runs inside the Docker container on the remote GPU
|
|
2762
|
+
KERNELBENCH_EVAL_SCRIPT = """
|
|
2763
|
+
import gc
|
|
2764
|
+
import json
|
|
2765
|
+
import os
|
|
2766
|
+
import sys
|
|
2767
|
+
import time
|
|
2768
|
+
import torch
|
|
2769
|
+
import torch.nn as nn
|
|
2770
|
+
from pathlib import Path
|
|
2771
|
+
|
|
2772
|
+
# Use a unique per-run PyTorch extension cache directory to ensure fresh compilation.
|
|
2773
|
+
# This prevents stale cached extensions from being loaded when the pod is reused.
|
|
2774
|
+
# Without this, if a kernel is modified but uses the same extension name,
|
|
2775
|
+
# PyTorch would load the old cached .so instead of recompiling.
|
|
2776
|
+
# We use a UUID-based directory instead of clearing the cache to avoid race conditions
|
|
2777
|
+
# with other processes that might be using the cache.
|
|
2778
|
+
import uuid
|
|
2779
|
+
unique_cache_dir = f"/tmp/torch_extensions_{uuid.uuid4().hex[:8]}"
|
|
2780
|
+
os.environ["TORCH_EXTENSIONS_DIR"] = unique_cache_dir
|
|
2781
|
+
print(f"[KernelBench] Using unique extension cache: {unique_cache_dir}")
|
|
2782
|
+
|
|
2783
|
+
# Clear any stale GPU memory from previous runs at startup
|
|
2784
|
+
# NOTE: empty_cache only frees memory from THIS process's PyTorch allocator.
|
|
2785
|
+
# It won't free memory from dead/zombie processes - rocm-smi --showpids can show
|
|
2786
|
+
# PIDs that no longer exist but still hold GPU memory. Those require a GPU reset
|
|
2787
|
+
# (rocm-smi --gpureset) to fully clear. TODO: detect and warn about orphaned memory.
|
|
2788
|
+
if torch.cuda.is_available():
|
|
2789
|
+
gc.collect()
|
|
2790
|
+
torch.cuda.empty_cache()
|
|
2791
|
+
torch.cuda.reset_peak_memory_stats()
|
|
2792
|
+
|
|
2793
|
+
|
|
2794
|
+
def _calculate_timing_stats(times: list[float]) -> dict:
|
|
2795
|
+
'''Calculate median and IQR from timing samples.
|
|
2796
|
+
|
|
2797
|
+
Returns dict with median, iqr_low (25th percentile), iqr_high (75th percentile),
|
|
2798
|
+
mean, min, max, and std.
|
|
2799
|
+
'''
|
|
2800
|
+
import statistics
|
|
2801
|
+
|
|
2802
|
+
if not times:
|
|
2803
|
+
return {"median": 0, "iqr_low": 0, "iqr_high": 0, "mean": 0, "min": 0, "max": 0, "std": 0}
|
|
2804
|
+
|
|
2805
|
+
sorted_times = sorted(times)
|
|
2806
|
+
n = len(sorted_times)
|
|
2807
|
+
|
|
2808
|
+
# Median
|
|
2809
|
+
median = statistics.median(sorted_times)
|
|
2810
|
+
|
|
2811
|
+
# Quartiles (25th and 75th percentile)
|
|
2812
|
+
# For small samples, use simple interpolation
|
|
2813
|
+
q1_idx = (n - 1) * 0.25
|
|
2814
|
+
q3_idx = (n - 1) * 0.75
|
|
2815
|
+
|
|
2816
|
+
q1_low = int(q1_idx)
|
|
2817
|
+
q1_frac = q1_idx - q1_low
|
|
2818
|
+
iqr_low = sorted_times[q1_low] * (1 - q1_frac) + sorted_times[min(q1_low + 1, n - 1)] * q1_frac
|
|
2819
|
+
|
|
2820
|
+
q3_low = int(q3_idx)
|
|
2821
|
+
q3_frac = q3_idx - q3_low
|
|
2822
|
+
iqr_high = sorted_times[q3_low] * (1 - q3_frac) + sorted_times[min(q3_low + 1, n - 1)] * q3_frac
|
|
2823
|
+
|
|
2824
|
+
return {
|
|
2825
|
+
"median": median,
|
|
2826
|
+
"iqr_low": iqr_low,
|
|
2827
|
+
"iqr_high": iqr_high,
|
|
2828
|
+
"mean": statistics.mean(sorted_times),
|
|
2829
|
+
"min": min(sorted_times),
|
|
2830
|
+
"max": max(sorted_times),
|
|
2831
|
+
"std": statistics.stdev(sorted_times) if n > 1 else 0,
|
|
2832
|
+
}
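# Hand-checked example: for times = [1.0, 2.0, 3.0, 4.0] the stats above give
# median = 2.5, iqr_low = 1.75, iqr_high = 3.25, mean = 2.5, min = 1.0, max = 4.0,
# and std ~= 1.29 (sample standard deviation).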
|
|
2833
|
+
|
|
2834
|
+
|
|
2835
|
+
def run_profiling(model, inputs, name, output_dir):
|
|
2836
|
+
'''Run torch.profiler and return summary stats.'''
|
|
2837
|
+
from torch.profiler import profile, ProfilerActivity
|
|
2838
|
+
|
|
2839
|
+
# Determine activities based on backend
|
|
2840
|
+
activities = [ProfilerActivity.CPU]
|
|
2841
|
+
if torch.cuda.is_available():
|
|
2842
|
+
activities.append(ProfilerActivity.CUDA)
|
|
2843
|
+
|
|
2844
|
+
# Warmup
|
|
2845
|
+
for _ in range(3):
|
|
2846
|
+
with torch.no_grad():
|
|
2847
|
+
_ = model(*inputs)
|
|
2848
|
+
torch.cuda.synchronize()
|
|
2849
|
+
|
|
2850
|
+
# Profile
|
|
2851
|
+
with profile(
|
|
2852
|
+
activities=activities,
|
|
2853
|
+
record_shapes=True,
|
|
2854
|
+
with_stack=False,
|
|
2855
|
+
profile_memory=True,
|
|
2856
|
+
) as prof:
|
|
2857
|
+
with torch.no_grad():
|
|
2858
|
+
_ = model(*inputs)
|
|
2859
|
+
torch.cuda.synchronize()
|
|
2860
|
+
|
|
2861
|
+
# Get key averages
|
|
2862
|
+
key_averages = prof.key_averages()
|
|
2863
|
+
|
|
2864
|
+
# Find the main kernel (longest GPU time)
|
|
2865
|
+
# Use cuda_time_total for compatibility with both CUDA and ROCm
|
|
2866
|
+
def get_gpu_time(e):
|
|
2867
|
+
# Try different attributes for GPU time
|
|
2868
|
+
if hasattr(e, 'cuda_time_total'):
|
|
2869
|
+
return e.cuda_time_total
|
|
2870
|
+
if hasattr(e, 'device_time_total'):
|
|
2871
|
+
return e.device_time_total
|
|
2872
|
+
if hasattr(e, 'self_cuda_time_total'):
|
|
2873
|
+
return e.self_cuda_time_total
|
|
2874
|
+
return 0
|
|
2875
|
+
|
|
2876
|
+
gpu_events = [e for e in key_averages if get_gpu_time(e) > 0]
|
|
2877
|
+
gpu_events.sort(key=lambda e: get_gpu_time(e), reverse=True)
|
|
2878
|
+
|
|
2879
|
+
stats = {
|
|
2880
|
+
"name": name,
|
|
2881
|
+
"total_gpu_time_ms": sum(get_gpu_time(e) for e in gpu_events) / 1000,
|
|
2882
|
+
"total_cpu_time_ms": sum(e.cpu_time_total for e in key_averages) / 1000,
|
|
2883
|
+
"num_gpu_kernels": len(gpu_events),
|
|
2884
|
+
"top_kernels": [],
|
|
2885
|
+
}
|
|
2886
|
+
|
|
2887
|
+
# Top 5 kernels by GPU time
|
|
2888
|
+
for e in gpu_events[:5]:
|
|
2889
|
+
stats["top_kernels"].append({
|
|
2890
|
+
"name": e.key,
|
|
2891
|
+
"gpu_time_ms": get_gpu_time(e) / 1000,
|
|
2892
|
+
"cpu_time_ms": e.cpu_time_total / 1000,
|
|
2893
|
+
"calls": e.count,
|
|
2894
|
+
})
|
|
2895
|
+
|
|
2896
|
+
# Save trace for visualization
|
|
2897
|
+
trace_path = Path(output_dir) / f"{name}_trace.json"
|
|
2898
|
+
prof.export_chrome_trace(str(trace_path))
|
|
2899
|
+
stats["trace_file"] = str(trace_path)
|
|
2900
|
+
|
|
2901
|
+
return stats
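As a side note, the trace written by run_profiling is standard Chrome-trace JSON and can be opened in chrome://tracing or Perfetto; a minimal editorial sketch of inspecting it offline, assuming the profiles/ layout used above:

# Illustrative only: load the exported trace and count its events.
import json

with open("profiles/implementation_trace.json") as f:  # path assumes the layout used above
    trace = json.load(f)
print(len(trace.get("traceEvents", [])), "trace events recorded")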
|
|
2902
|
+
|
|
2903
|
+
|
|
2904
|
+
def validate_custom_inputs(original_inputs, custom_inputs):
|
|
2905
|
+
'''Validate that custom inputs match the expected signature.
|
|
2906
|
+
|
|
2907
|
+
Returns (is_valid, error_message).
|
|
2908
|
+
'''
|
|
2909
|
+
if len(original_inputs) != len(custom_inputs):
|
|
2910
|
+
return False, f"get_inputs() must return {len(original_inputs)} tensors, got {len(custom_inputs)}"
|
|
2911
|
+
|
|
2912
|
+
for i, (orig, cust) in enumerate(zip(original_inputs, custom_inputs)):
|
|
2913
|
+
if not isinstance(cust, torch.Tensor):
|
|
2914
|
+
if not isinstance(orig, torch.Tensor):
|
|
2915
|
+
continue # Both non-tensor, ok
|
|
2916
|
+
return False, f"Input {i}: expected Tensor, got {type(cust).__name__}"
|
|
2917
|
+
|
|
2918
|
+
if not isinstance(orig, torch.Tensor):
|
|
2919
|
+
return False, f"Input {i}: expected {type(orig).__name__}, got Tensor"
|
|
2920
|
+
|
|
2921
|
+
if orig.dtype != cust.dtype:
|
|
2922
|
+
return False, f"Input {i}: dtype mismatch - expected {orig.dtype}, got {cust.dtype}"
|
|
2923
|
+
|
|
2924
|
+
if orig.dim() != cust.dim():
|
|
2925
|
+
return False, f"Input {i}: dimension mismatch - expected {orig.dim()}D, got {cust.dim()}D"
|
|
2926
|
+
|
|
2927
|
+
return True, None
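A minimal editorial sketch of the contract this check enforces (custom inputs may change sizes but must preserve count, dtype, and rank); validate_custom_inputs is assumed to be in scope:

# Illustrative only.
import torch

orig = [torch.randn(128, 64), torch.randn(64, 32)]
custom_ok = [torch.randn(256, 64), torch.randn(64, 8)]   # same dtype/rank, different sizes
custom_bad = [torch.randn(256, 64)]                      # wrong number of inputs

print(validate_custom_inputs(orig, custom_ok))   # (True, None)
print(validate_custom_inputs(orig, custom_bad))  # (False, 'get_inputs() must return 2 tensors, got 1')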
|
|
2928
|
+
|
|
2929
|
+
|
|
2930
|
+
def analyze_diff(ref_output, new_output, rtol=1e-3, atol=1e-3, max_samples=5):
|
|
2931
|
+
'''Analyze differences between reference and implementation outputs.
|
|
2932
|
+
|
|
2933
|
+
Returns a dict with detailed diff information.
|
|
2934
|
+
'''
|
|
2935
|
+
diff = (ref_output - new_output).abs()
|
|
2936
|
+
threshold = atol + rtol * ref_output.abs()
|
|
2937
|
+
wrong_mask = diff > threshold
|
|
2938
|
+
|
|
2939
|
+
total_elements = ref_output.numel()
|
|
2940
|
+
wrong_count = wrong_mask.sum().item()
|
|
2941
|
+
|
|
2942
|
+
# Basic stats
|
|
2943
|
+
max_diff = diff.max().item()
|
|
2944
|
+
max_diff_idx = tuple(torch.unravel_index(diff.argmax(), diff.shape))
|
|
2945
|
+
max_diff_idx = tuple(int(i) for i in max_diff_idx) # Convert to Python ints
|
|
2946
|
+
|
|
2947
|
+
# Relative error (avoid div by zero)
|
|
2948
|
+
ref_abs = ref_output.abs()
|
|
2949
|
+
nonzero_mask = ref_abs > 1e-8
|
|
2950
|
+
if nonzero_mask.any():
|
|
2951
|
+
rel_error = diff[nonzero_mask] / ref_abs[nonzero_mask]
|
|
2952
|
+
max_rel_error = rel_error.max().item()
|
|
2953
|
+
mean_rel_error = rel_error.mean().item()
|
|
2954
|
+
else:
|
|
2955
|
+
max_rel_error = float('inf') if max_diff > 0 else 0.0
|
|
2956
|
+
mean_rel_error = max_rel_error
|
|
2957
|
+
|
|
2958
|
+
# Error histogram (buckets: <1e-6, 1e-6 to 1e-4, 1e-4 to 1e-2, 1e-2 to 1, >1)
|
|
2959
|
+
histogram = {
|
|
2960
|
+
'<1e-6': int((diff < 1e-6).sum().item()),
|
|
2961
|
+
'1e-6 to 1e-4': int(((diff >= 1e-6) & (diff < 1e-4)).sum().item()),
|
|
2962
|
+
'1e-4 to 1e-2': int(((diff >= 1e-4) & (diff < 1e-2)).sum().item()),
|
|
2963
|
+
'1e-2 to 1': int(((diff >= 1e-2) & (diff < 1)).sum().item()),
|
|
2964
|
+
'>1': int((diff >= 1).sum().item()),
|
|
2965
|
+
}
|
|
2966
|
+
|
|
2967
|
+
result = {
|
|
2968
|
+
'max_diff': max_diff,
|
|
2969
|
+
'max_diff_idx': max_diff_idx,
|
|
2970
|
+
'mean_diff': diff.mean().item(),
|
|
2971
|
+
'max_rel_error': max_rel_error,
|
|
2972
|
+
'mean_rel_error': mean_rel_error,
|
|
2973
|
+
'total_elements': total_elements,
|
|
2974
|
+
'wrong_count': int(wrong_count),
|
|
2975
|
+
'wrong_pct': 100.0 * wrong_count / total_elements,
|
|
2976
|
+
'histogram': histogram,
|
|
2977
|
+
'samples': [],
|
|
2978
|
+
}
|
|
2979
|
+
|
|
2980
|
+
# Get indices of wrong elements
|
|
2981
|
+
if wrong_count > 0:
|
|
2982
|
+
wrong_indices = torch.nonzero(wrong_mask, as_tuple=False)
|
|
2983
|
+
|
|
2984
|
+
# Take first N samples
|
|
2985
|
+
num_samples = min(max_samples, len(wrong_indices))
|
|
2986
|
+
for i in range(num_samples):
|
|
2987
|
+
idx = tuple(wrong_indices[i].tolist())
|
|
2988
|
+
ref_val = ref_output[idx].item()
|
|
2989
|
+
new_val = new_output[idx].item()
|
|
2990
|
+
diff_val = diff[idx].item()
|
|
2991
|
+
result['samples'].append({
|
|
2992
|
+
'index': idx,
|
|
2993
|
+
'ref': ref_val,
|
|
2994
|
+
'impl': new_val,
|
|
2995
|
+
'diff': diff_val,
|
|
2996
|
+
})
|
|
2997
|
+
|
|
2998
|
+
# Try to detect pattern
|
|
2999
|
+
if wrong_count >= total_elements * 0.99:
|
|
3000
|
+
result['pattern'] = 'all_wrong'
|
|
3001
|
+
elif wrong_count < total_elements * 0.01:
|
|
3002
|
+
# Check if failures are at boundaries
|
|
3003
|
+
shape = ref_output.shape
|
|
3004
|
+
boundary_count = 0
|
|
3005
|
+
for idx in wrong_indices[:min(100, len(wrong_indices))]:
|
|
3006
|
+
idx_list = idx.tolist()
|
|
3007
|
+
is_boundary = any(i == 0 or i == s - 1 for i, s in zip(idx_list, shape))
|
|
3008
|
+
if is_boundary:
|
|
3009
|
+
boundary_count += 1
|
|
3010
|
+
if boundary_count > len(wrong_indices[:100]) * 0.8:
|
|
3011
|
+
result['pattern'] = 'boundary_issue'
|
|
3012
|
+
else:
|
|
3013
|
+
result['pattern'] = 'scattered'
|
|
3014
|
+
else:
|
|
3015
|
+
result['pattern'] = 'partial'
|
|
3016
|
+
|
|
3017
|
+
return result
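For illustration, an editorial sketch of how the report from analyze_diff reads when a single element is wrong (the function is assumed to be in scope; values are made up):

# Illustrative only: one deliberately corrupted element out of 16.
import torch

ref = torch.ones(4, 4)
impl = ref.clone()
impl[0, 0] += 1.0
report = analyze_diff(ref, impl)
print(report["wrong_count"], "/", report["total_elements"])  # 1 / 16
print(report["wrong_pct"])                                   # 6.25
print(report["samples"][0]["index"])                         # (0, 0)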
|
|
3018
|
+
|
|
3019
|
+
|
|
3020
|
+
def print_diff_analysis(analysis):
|
|
3021
|
+
'''Print a human-readable diff analysis.'''
|
|
3022
|
+
print(f"[KernelBench] Diff analysis:")
|
|
3023
|
+
|
|
3024
|
+
# Max diff with location
|
|
3025
|
+
idx_str = ','.join(str(i) for i in analysis['max_diff_idx'])
|
|
3026
|
+
print(f" Max diff: {analysis['max_diff']:.6f} at index [{idx_str}]")
|
|
3027
|
+
print(f" Mean diff: {analysis['mean_diff']:.6f}")
|
|
3028
|
+
|
|
3029
|
+
# Relative errors
|
|
3030
|
+
print(f" Max relative error: {analysis['max_rel_error']:.2%}, Mean: {analysis['mean_rel_error']:.2%}")
|
|
3031
|
+
|
|
3032
|
+
# Wrong count
|
|
3033
|
+
print(f" Wrong elements: {analysis['wrong_count']:,} / {analysis['total_elements']:,} ({analysis['wrong_pct']:.2f}%)")
|
|
3034
|
+
|
|
3035
|
+
# Histogram
|
|
3036
|
+
hist = analysis['histogram']
|
|
3037
|
+
print(f" Error distribution: <1e-6: {hist['<1e-6']:,} | 1e-6~1e-4: {hist['1e-6 to 1e-4']:,} | 1e-4~1e-2: {hist['1e-4 to 1e-2']:,} | 1e-2~1: {hist['1e-2 to 1']:,} | >1: {hist['>1']:,}")
|
|
3038
|
+
|
|
3039
|
+
if 'pattern' in analysis:
|
|
3040
|
+
pattern_desc = {
|
|
3041
|
+
'all_wrong': 'ALL elements wrong - likely algorithmic error or wrong weights',
|
|
3042
|
+
'boundary_issue': 'Mostly BOUNDARY elements wrong - check edge handling',
|
|
3043
|
+
'scattered': 'SCATTERED failures - numerical precision issue?',
|
|
3044
|
+
'partial': 'PARTIAL failures - check specific conditions',
|
|
3045
|
+
}
|
|
3046
|
+
print(f" Pattern: {pattern_desc.get(analysis['pattern'], analysis['pattern'])}")
|
|
3047
|
+
|
|
3048
|
+
if analysis['samples']:
|
|
3049
|
+
print(f" Sample failures:")
|
|
3050
|
+
for s in analysis['samples']:
|
|
3051
|
+
idx_str = ','.join(str(i) for i in s['index'])
|
|
3052
|
+
print(f" [{idx_str}]: ref={s['ref']:.6f} impl={s['impl']:.6f} (diff={s['diff']:.6f})")
|
|
3053
|
+
|
|
3054
|
+
|
|
3055
|
+
def main():
|
|
3056
|
+
# Parse args
|
|
3057
|
+
import argparse
|
|
3058
|
+
parser = argparse.ArgumentParser()
|
|
3059
|
+
parser.add_argument("--impl", required=True)
|
|
3060
|
+
parser.add_argument("--reference", required=True)
|
|
3061
|
+
parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
|
|
3062
|
+
parser.add_argument("--benchmark", action="store_true")
|
|
3063
|
+
parser.add_argument("--profile", action="store_true")
|
|
3064
|
+
parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
|
|
3065
|
+
parser.add_argument("--defense-module", help="Path to defense.py module")
|
|
3066
|
+
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
|
3067
|
+
parser.add_argument("--num-correct-trials", type=int, default=3)
|
|
3068
|
+
parser.add_argument("--num-perf-trials", type=int, default=10)
|
|
3069
|
+
parser.add_argument("--output", required=True)
|
|
3070
|
+
parser.add_argument("--stages", default="compile,correctness",
|
|
3071
|
+
help="Comma-separated stages: compile, correctness, benchmark, defense")
|
|
3072
|
+
args = parser.parse_args()
|
|
3073
|
+
|
|
3074
|
+
# Parse stages
|
|
3075
|
+
stages = set(args.stages.split(","))
|
|
3076
|
+
run_compile = "compile" in stages
|
|
3077
|
+
run_correctness = "correctness" in stages
|
|
3078
|
+
run_benchmark = "benchmark" in stages or args.benchmark
|
|
3079
|
+
run_defense = "defense" in stages or args.defensive
|
|
3080
|
+
print(f"[KernelBench] Stages: {args.stages}")
|
|
3081
|
+
|
|
3082
|
+
# Load defense module if defensive mode is enabled
|
|
3083
|
+
defense_module = None
|
|
3084
|
+
if args.defensive and args.defense_module:
|
|
3085
|
+
try:
|
|
3086
|
+
import importlib.util
|
|
3087
|
+
defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
|
|
3088
|
+
defense_module = importlib.util.module_from_spec(defense_spec)
|
|
3089
|
+
defense_spec.loader.exec_module(defense_module)
|
|
3090
|
+
print("[KernelBench] Defense module loaded")
|
|
3091
|
+
except Exception as e:
|
|
3092
|
+
print(f"[KernelBench] Warning: Could not load defense module: {e}")
|
|
3093
|
+
|
|
3094
|
+
# Create output directory for profiles
|
|
3095
|
+
output_dir = Path(args.output).parent
|
|
3096
|
+
profile_dir = output_dir / "profiles"
|
|
3097
|
+
if args.profile:
|
|
3098
|
+
profile_dir.mkdir(exist_ok=True)
|
|
3099
|
+
|
|
3100
|
+
results = {
|
|
3101
|
+
"compiled": False,
|
|
3102
|
+
"correct": False,
|
|
3103
|
+
"speedup": None,
|
|
3104
|
+
"runtime_ms": None,
|
|
3105
|
+
"reference_runtime_ms": None,
|
|
3106
|
+
"error": None,
|
|
3107
|
+
}
|
|
3108
|
+
|
|
3109
|
+
try:
|
|
3110
|
+
# Load reference module
|
|
3111
|
+
import importlib.util
|
|
3112
|
+
ref_spec = importlib.util.spec_from_file_location("reference", args.reference)
|
|
3113
|
+
ref_module = importlib.util.module_from_spec(ref_spec)
|
|
3114
|
+
ref_spec.loader.exec_module(ref_module)
|
|
3115
|
+
|
|
3116
|
+
Model = ref_module.Model
|
|
3117
|
+
get_inputs = ref_module.get_inputs
|
|
3118
|
+
get_init_inputs = ref_module.get_init_inputs
|
|
3119
|
+
|
|
3120
|
+
# Load custom inputs if provided
|
|
3121
|
+
if args.inputs:
|
|
3122
|
+
inputs_spec = importlib.util.spec_from_file_location("custom_inputs", args.inputs)
|
|
3123
|
+
inputs_module = importlib.util.module_from_spec(inputs_spec)
|
|
3124
|
+
inputs_spec.loader.exec_module(inputs_module)
|
|
3125
|
+
|
|
3126
|
+
# Validate custom inputs match expected signature
|
|
3127
|
+
original_inputs = get_inputs()
|
|
3128
|
+
custom_get_inputs = inputs_module.get_inputs
|
|
3129
|
+
custom_inputs = custom_get_inputs()
|
|
3130
|
+
|
|
3131
|
+
is_valid, error_msg = validate_custom_inputs(original_inputs, custom_inputs)
|
|
3132
|
+
if not is_valid:
|
|
3133
|
+
print(f"[KernelBench] Custom inputs validation failed: {error_msg}")
|
|
3134
|
+
results["error"] = f"Custom inputs validation failed: {error_msg}"
|
|
3135
|
+
raise ValueError(error_msg)
|
|
3136
|
+
|
|
3137
|
+
# Override get_inputs (and optionally get_init_inputs)
|
|
3138
|
+
get_inputs = custom_get_inputs
|
|
3139
|
+
if hasattr(inputs_module, 'get_init_inputs'):
|
|
3140
|
+
get_init_inputs = inputs_module.get_init_inputs
|
|
3141
|
+
|
|
3142
|
+
# Show what changed
|
|
3143
|
+
orig_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in original_inputs]
|
|
3144
|
+
cust_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in custom_inputs]
|
|
3145
|
+
print(f"[KernelBench] Using custom inputs: {orig_shapes} -> {cust_shapes}")
|
|
3146
|
+
|
|
3147
|
+
# Load implementation module
|
|
3148
|
+
impl_spec = importlib.util.spec_from_file_location("implementation", args.impl)
|
|
3149
|
+
impl_module = importlib.util.module_from_spec(impl_spec)
|
|
3150
|
+
impl_spec.loader.exec_module(impl_module)
|
|
3151
|
+
|
|
3152
|
+
ModelNew = impl_module.ModelNew
|
|
3153
|
+
results["compiled"] = True
|
|
3154
|
+
print("[KernelBench] Modules loaded successfully")
|
|
3155
|
+
|
|
3156
|
+
# Instantiate models with synchronized seeds for reproducible weights
|
|
3157
|
+
# (matches upstream KernelBench behavior in src/eval.py)
|
|
3158
|
+
seed = args.seed
|
|
3159
|
+
init_inputs = get_init_inputs()
|
|
3160
|
+
with torch.no_grad():
|
|
3161
|
+
torch.manual_seed(seed)
|
|
3162
|
+
torch.cuda.manual_seed(seed)
|
|
3163
|
+
ref_model = Model(*init_inputs).cuda().eval()
|
|
3164
|
+
|
|
3165
|
+
torch.manual_seed(seed)
|
|
3166
|
+
torch.cuda.manual_seed(seed)
|
|
3167
|
+
new_model = ModelNew(*init_inputs).cuda().eval()
|
|
3168
|
+
print(f"[KernelBench] Models instantiated (seed={seed})")
|
|
3169
|
+
|
|
3170
|
+
# Run correctness trials (if stage enabled)
|
|
3171
|
+
all_correct = True
|
|
3172
|
+
if not run_correctness:
|
|
3173
|
+
print("[KernelBench] Skipping correctness (not in stages)")
|
|
3174
|
+
results["correct"] = None # Unknown - not checked
|
|
3175
|
+
else:
|
|
3176
|
+
for trial in range(args.num_correct_trials):
|
|
3177
|
+
inputs = get_inputs()
|
|
3178
|
+
inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
|
|
3179
|
+
|
|
3180
|
+
with torch.no_grad():
|
|
3181
|
+
ref_output = ref_model(*inputs)
|
|
3182
|
+
new_output = new_model(*inputs)
|
|
3183
|
+
|
|
3184
|
+
# Compare outputs
|
|
3185
|
+
if isinstance(ref_output, torch.Tensor):
|
|
3186
|
+
if not torch.allclose(ref_output, new_output, rtol=1e-3, atol=1e-3):
|
|
3187
|
+
all_correct = False
|
|
3188
|
+
analysis = analyze_diff(ref_output, new_output)
|
|
3189
|
+
results["error"] = f"Correctness failed on trial {trial+1}: max diff = {analysis['max_diff']}"
|
|
3190
|
+
results["diff_analysis"] = analysis
|
|
3191
|
+
print_diff_analysis(analysis)
|
|
3192
|
+
|
|
3193
|
+
# Save tensors for debugging
|
|
3194
|
+
debug_dir = output_dir / "debug"
|
|
3195
|
+
debug_dir.mkdir(exist_ok=True)
|
|
3196
|
+
torch.save(ref_output.cpu(), debug_dir / "ref_output.pt")
|
|
3197
|
+
torch.save(new_output.cpu(), debug_dir / "impl_output.pt")
|
|
3198
|
+
torch.save(inputs[0].cpu() if inputs else None, debug_dir / "input.pt")
|
|
3199
|
+
print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
|
|
3200
|
+
break
|
|
3201
|
+
else:
|
|
3202
|
+
# Handle tuple/list outputs
|
|
3203
|
+
for i, (r, n) in enumerate(zip(ref_output, new_output)):
|
|
3204
|
+
if isinstance(r, torch.Tensor):
|
|
3205
|
+
if not torch.allclose(r, n, rtol=1e-3, atol=1e-3):
|
|
3206
|
+
all_correct = False
|
|
3207
|
+
analysis = analyze_diff(r, n)
|
|
3208
|
+
results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {analysis['max_diff']}"
|
|
3209
|
+
results["diff_analysis"] = analysis
|
|
3210
|
+
print_diff_analysis(analysis)
|
|
3211
|
+
|
|
3212
|
+
# Save tensors for debugging
|
|
3213
|
+
debug_dir = output_dir / "debug"
|
|
3214
|
+
debug_dir.mkdir(exist_ok=True)
|
|
3215
|
+
torch.save(r.cpu(), debug_dir / f"ref_output_{i}.pt")
|
|
3216
|
+
torch.save(n.cpu(), debug_dir / f"impl_output_{i}.pt")
|
|
3217
|
+
print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
|
|
3218
|
+
break
|
|
3219
|
+
if not all_correct:
|
|
3220
|
+
break
|
|
3221
|
+
|
|
3222
|
+
results["correct"] = all_correct
|
|
3223
|
+
print(f"[KernelBench] Correctness: {all_correct}")
|
|
3224
|
+
|
|
3225
|
+
# Run benchmark if stage enabled (and correctness passed or skipped)
|
|
3226
|
+
should_benchmark = run_benchmark and (all_correct or not run_correctness)
|
|
3227
|
+
if should_benchmark:
|
|
3228
|
+
print("[KernelBench] Running benchmarks...")
|
|
3229
|
+
inputs = get_inputs()
|
|
3230
|
+
inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
|
|
3231
|
+
|
|
3232
|
+
if run_defense and defense_module is not None:
|
|
3233
|
+
# Use full defense suite
|
|
3234
|
+
print("[KernelBench] Running defense checks on implementation...")
|
|
3235
|
+
run_all_defenses = defense_module.run_all_defenses
|
|
3236
|
+
time_with_defenses = defense_module.time_execution_with_defenses
|
|
3237
|
+
|
|
3238
|
+
# Run defense checks on implementation
|
|
3239
|
+
all_passed, defense_results, _ = run_all_defenses(
|
|
3240
|
+
lambda *x: new_model(*x),
|
|
3241
|
+
*inputs,
|
|
3242
|
+
)
|
|
3243
|
+
results["defense_results"] = {
|
|
3244
|
+
name: {"passed": passed, "message": msg}
|
|
3245
|
+
for name, passed, msg in defense_results
|
|
3246
|
+
}
|
|
3247
|
+
if not all_passed:
|
|
3248
|
+
failed = [name for name, passed, _ in defense_results if not passed]
|
|
3249
|
+
results["error"] = f"Defense checks failed: {failed}"
|
|
3250
|
+
print(f"[KernelBench] Defense checks FAILED: {failed}")
|
|
3251
|
+
for name, passed, msg in defense_results:
|
|
3252
|
+
status = "PASS" if passed else "FAIL"
|
|
3253
|
+
print(f" [{status}] {name}: {msg}")
|
|
3254
|
+
else:
|
|
3255
|
+
print("[KernelBench] All defense checks passed")
|
|
3256
|
+
|
|
3257
|
+
# Time with defensive timing
|
|
3258
|
+
impl_times, _ = time_with_defenses(
|
|
3259
|
+
lambda: new_model(*inputs),
|
|
3260
|
+
[],
|
|
3261
|
+
num_warmup=5,
|
|
3262
|
+
num_trials=args.num_perf_trials,
|
|
3263
|
+
verbose=False,
|
|
3264
|
+
run_defenses=False, # Already ran above
|
|
3265
|
+
)
|
|
3266
|
+
# Calculate stats for new model
|
|
3267
|
+
new_stats = _calculate_timing_stats(impl_times)
|
|
3268
|
+
results["runtime_ms"] = new_stats["median"]
|
|
3269
|
+
results["runtime_stats"] = new_stats
|
|
3270
|
+
|
|
3271
|
+
# Reference timing
|
|
3272
|
+
ref_times, _ = time_with_defenses(
|
|
3273
|
+
lambda: ref_model(*inputs),
|
|
3274
|
+
[],
|
|
3275
|
+
num_warmup=5,
|
|
3276
|
+
num_trials=args.num_perf_trials,
|
|
3277
|
+
verbose=False,
|
|
3278
|
+
run_defenses=False,
|
|
3279
|
+
)
|
|
3280
|
+
ref_stats = _calculate_timing_stats(ref_times)
|
|
3281
|
+
results["reference_runtime_ms"] = ref_stats["median"]
|
|
3282
|
+
results["reference_runtime_stats"] = ref_stats
|
|
3283
|
+
results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
|
|
3284
|
+
print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
|
|
3285
|
+
else:
|
|
3286
|
+
# Standard timing without full defenses
|
|
3287
|
+
# Warmup BOTH models before benchmarking either
|
|
3288
|
+
# This ensures consistent GPU state and avoids MIOpen cache effects
|
|
3289
|
+
# that cause variance when warming up models sequentially
|
|
3290
|
+
for _ in range(5):
|
|
3291
|
+
with torch.no_grad():
|
|
3292
|
+
_ = new_model(*inputs)
|
|
3293
|
+
_ = ref_model(*inputs)
|
|
3294
|
+
torch.cuda.synchronize()
|
|
3295
|
+
|
|
3296
|
+
# Benchmark new model
|
|
3297
|
+
start = torch.cuda.Event(enable_timing=True)
|
|
3298
|
+
end = torch.cuda.Event(enable_timing=True)
|
|
3299
|
+
|
|
3300
|
+
new_times = []
|
|
3301
|
+
for _ in range(args.num_perf_trials):
|
|
3302
|
+
start.record()
|
|
3303
|
+
with torch.no_grad():
|
|
3304
|
+
_ = new_model(*inputs)
|
|
3305
|
+
end.record()
|
|
3306
|
+
torch.cuda.synchronize()
|
|
3307
|
+
new_times.append(start.elapsed_time(end))
|
|
3308
|
+
|
|
3309
|
+
new_stats = _calculate_timing_stats(new_times)
|
|
3310
|
+
results["runtime_ms"] = new_stats["median"]
|
|
3311
|
+
results["runtime_stats"] = new_stats
|
|
3312
|
+
|
|
3313
|
+
# Benchmark reference model
|
|
3314
|
+
ref_times = []
|
|
3315
|
+
for _ in range(args.num_perf_trials):
|
|
3316
|
+
start.record()
|
|
3317
|
+
with torch.no_grad():
|
|
3318
|
+
_ = ref_model(*inputs)
|
|
3319
|
+
end.record()
|
|
3320
|
+
torch.cuda.synchronize()
|
|
3321
|
+
ref_times.append(start.elapsed_time(end))
|
|
3322
|
+
|
|
3323
|
+
ref_stats = _calculate_timing_stats(ref_times)
|
|
3324
|
+
results["reference_runtime_ms"] = ref_stats["median"]
|
|
3325
|
+
results["reference_runtime_stats"] = ref_stats
|
|
3326
|
+
results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
|
|
3327
|
+
print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
|
|
3328
|
+
|
|
3329
|
+
# Run profiling if requested and correctness passed
|
|
3330
|
+
if args.profile and all_correct:
|
|
3331
|
+
print("[KernelBench] Running profiler...")
|
|
3332
|
+
inputs = get_inputs()
|
|
3333
|
+
inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
|
|
3334
|
+
|
|
3335
|
+
try:
|
|
3336
|
+
# Profile implementation
|
|
3337
|
+
impl_stats = run_profiling(new_model, inputs, "implementation", str(profile_dir))
|
|
3338
|
+
results["profile_impl"] = impl_stats
|
|
3339
|
+
print(f"[KernelBench] Implementation profile:")
|
|
3340
|
+
print(f" Total GPU time: {impl_stats['total_gpu_time_ms']:.3f}ms")
|
|
3341
|
+
print(f" Kernels launched: {impl_stats['num_gpu_kernels']}")
|
|
3342
|
+
if impl_stats['top_kernels']:
|
|
3343
|
+
print(f" Top kernel: {impl_stats['top_kernels'][0]['name'][:60]}...")
|
|
3344
|
+
print(f" {impl_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
|
|
3345
|
+
|
|
3346
|
+
# Profile reference
|
|
3347
|
+
ref_stats = run_profiling(ref_model, inputs, "reference", str(profile_dir))
|
|
3348
|
+
results["profile_ref"] = ref_stats
|
|
3349
|
+
print(f"[KernelBench] Reference profile:")
|
|
3350
|
+
print(f" Total GPU time: {ref_stats['total_gpu_time_ms']:.3f}ms")
|
|
3351
|
+
print(f" Kernels launched: {ref_stats['num_gpu_kernels']}")
|
|
3352
|
+
if ref_stats['top_kernels']:
|
|
3353
|
+
print(f" Top kernel: {ref_stats['top_kernels'][0]['name'][:60]}...")
|
|
3354
|
+
print(f" {ref_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
|
|
3355
|
+
|
|
3356
|
+
print(f"[KernelBench] Profile traces saved to: {profile_dir}/")
|
|
3357
|
+
|
|
3358
|
+
except Exception as prof_err:
|
|
3359
|
+
print(f"[KernelBench] Profiling failed: {prof_err}")
|
|
3360
|
+
results["profile_error"] = str(prof_err)
|
|
3361
|
+
|
|
3362
|
+
except Exception as e:
|
|
3363
|
+
import traceback
|
|
3364
|
+
results["error"] = f"{type(e).__name__}: {e}\\n{traceback.format_exc()}"
|
|
3365
|
+
print(f"[KernelBench] Error: {results['error']}")
|
|
3366
|
+
|
|
3367
|
+
# Write results
|
|
3368
|
+
with open(args.output, "w") as f:
|
|
3369
|
+
json.dump(results, f, indent=2)
|
|
3370
|
+
print(f"[KernelBench] Results written to {args.output}")
|
|
3371
|
+
|
|
3372
|
+
# Cleanup GPU memory
|
|
3373
|
+
try:
|
|
3374
|
+
del ref_model, new_model
|
|
3375
|
+
except NameError:
|
|
3376
|
+
pass
|
|
3377
|
+
import gc
|
|
3378
|
+
gc.collect()
|
|
3379
|
+
if torch.cuda.is_available():
|
|
3380
|
+
torch.cuda.empty_cache()
|
|
3381
|
+
|
|
3382
|
+
if __name__ == "__main__":
|
|
3383
|
+
main()
|
|
3384
|
+
"""
|
|
3385
|
+
|
|
3386
|
+
|
|
3387
|
+
def _validate_kernelbench_files(args: KernelBenchEvaluateArgs) -> str | None:
|
|
3388
|
+
"""Validate that KernelBench input files exist and have expected signatures.
|
|
3389
|
+
|
|
3390
|
+
Returns:
|
|
3391
|
+
Error message if validation fails, None if all valid
|
|
3392
|
+
"""
|
|
3393
|
+
if not args.implementation.exists():
|
|
3394
|
+
return f"Implementation file not found: {args.implementation}"
|
|
3395
|
+
if not args.reference.exists():
|
|
3396
|
+
return f"Reference file not found: {args.reference}"
|
|
3397
|
+
|
|
3398
|
+
# Validate implementation has ModelNew class
|
|
3399
|
+
impl_missing = _check_python_file_has(args.implementation, "ModelNew")
|
|
3400
|
+
if impl_missing:
|
|
3401
|
+
# Check if it looks like functional format (has custom_kernel)
|
|
3402
|
+
has_custom_kernel = not _check_python_file_has(args.implementation, "custom_kernel")
|
|
3403
|
+
if has_custom_kernel:
|
|
3404
|
+
return (
|
|
3405
|
+
f"Implementation file missing 'ModelNew' class: {args.implementation}\n"
|
|
3406
|
+
"Hint: This looks like functional format. Use 'wafer evaluate' instead:\n"
|
|
3407
|
+
f" wafer evaluate --impl {args.implementation} --reference <ref.py> --test-cases <tests.json>"
|
|
3408
|
+
)
|
|
3409
|
+
return (
|
|
3410
|
+
f"Implementation file missing 'ModelNew' class: {args.implementation}\n"
|
|
3411
|
+
" KernelBench format requires a 'class ModelNew(nn.Module)' definition"
|
|
3412
|
+
)
|
|
3413
|
+
|
|
3414
|
+
# Validate reference has Model, get_inputs, get_init_inputs
|
|
3415
|
+
ref_missing = _check_python_file_has(args.reference, "Model", "get_inputs", "get_init_inputs")
|
|
3416
|
+
if ref_missing:
|
|
3417
|
+
# Check if it looks like functional format (has ref_kernel and generate_input)
|
|
3418
|
+
has_functional = not _check_python_file_has(args.reference, "ref_kernel", "generate_input")
|
|
3419
|
+
if has_functional:
|
|
3420
|
+
return (
|
|
3421
|
+
f"Reference file missing required definitions: {', '.join(ref_missing)}\n"
|
|
3422
|
+
"Hint: This looks like functional format. Use 'wafer evaluate' instead:\n"
|
|
3423
|
+
f" wafer evaluate --impl <impl.py> --reference {args.reference} --test-cases <tests.json>"
|
|
3424
|
+
)
|
|
3425
|
+
return (
|
|
3426
|
+
f"Reference file missing required definitions: {', '.join(ref_missing)}\n"
|
|
3427
|
+
f" File: {args.reference}\n"
|
|
3428
|
+
" KernelBench format requires: 'class Model', 'get_inputs()', 'get_init_inputs()'"
|
|
3429
|
+
)
|
|
3430
|
+
|
|
3431
|
+
# Static kernel validation if backend specified
|
|
3432
|
+
if args.backend:
|
|
3433
|
+
from wafer_core.utils.kernel_utils.static_checker import validate_kernel_static
|
|
3434
|
+
|
|
3435
|
+
code = args.implementation.read_text()
|
|
3436
|
+
valid, errors, warnings = validate_kernel_static(code, backend=args.backend)
|
|
3437
|
+
|
|
3438
|
+
# Print warnings (don't fail)
|
|
3439
|
+
for warning in warnings:
|
|
3440
|
+
logger.warning(f"Static check warning: {warning}")
|
|
3441
|
+
|
|
3442
|
+
# Fail on errors
|
|
3443
|
+
if not valid:
|
|
3444
|
+
error_list = "\n - ".join(errors)
|
|
3445
|
+
return (
|
|
3446
|
+
f"Static kernel validation failed for backend '{args.backend}':\n"
|
|
3447
|
+
f" - {error_list}\n\n"
|
|
3448
|
+
f"The implementation must use {args.backend.upper()} kernel primitives.\n"
|
|
3449
|
+
"See KernelBench documentation for valid kernel patterns."
|
|
3450
|
+
)
|
|
3451
|
+
|
|
3452
|
+
return None
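The checks above imply a minimal file shape for KernelBench-format inputs; the following is an editorial sketch of the smallest reference/implementation pair that would pass the structural validation (it would still fail a backend-specific static kernel check, since it uses no kernel primitives):

# Illustrative only -- reference.py must define Model, get_inputs(), get_init_inputs().
import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, x):
        return torch.relu(x)

def get_inputs():
    return [torch.randn(1024, 1024)]

def get_init_inputs():
    return []

# Illustrative only -- implementation.py must define ModelNew, the optimized variant.
class ModelNew(nn.Module):
    def forward(self, x):
        return torch.clamp(x, min=0.0)  # same semantics as relu; stands in for a custom kernel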
|
|
3453
|
+
|
|
3454
|
+
|
|
3455
|
+
async def run_evaluate_kernelbench_docker(
|
|
3456
|
+
args: KernelBenchEvaluateArgs,
|
|
3457
|
+
target: BaremetalTarget | VMTarget,
|
|
3458
|
+
) -> EvaluateResult:
|
|
3459
|
+
"""Run KernelBench format evaluation in Docker container on SSH-based target.
|
|
3460
|
+
|
|
3461
|
+
Similar to run_evaluate_docker but uses the KernelBench eval script instead.
|
|
3462
|
+
"""
|
|
3463
|
+
from datetime import datetime
|
|
3464
|
+
|
|
3465
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
3466
|
+
|
|
3467
|
+
CONTAINER_WORKSPACE = "/workspace"
|
|
3468
|
+
REMOTE_WORKSPACE_BASE = "~/.wafer/workspaces"
|
|
3469
|
+
|
|
3470
|
+
if not target.docker_image:
|
|
3471
|
+
return EvaluateResult(
|
|
3472
|
+
success=False,
|
|
3473
|
+
all_correct=False,
|
|
3474
|
+
correctness_score=0.0,
|
|
3475
|
+
geomean_speedup=0.0,
|
|
3476
|
+
passed_tests=0,
|
|
3477
|
+
total_tests=0,
|
|
3478
|
+
error_message="docker_image must be set for Docker execution",
|
|
3479
|
+
)
|
|
3480
|
+
|
|
3481
|
+
# Select GPU
|
|
3482
|
+
gpu_id = _select_gpu_id(target, args.gpu_id)
|
|
3483
|
+
|
|
3484
|
+
print(f"Connecting to {target.ssh_target}...")
|
|
3485
|
+
|
|
3486
|
+
async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
|
|
3487
|
+
# Create workspace
|
|
3488
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
3489
|
+
run_dir = f"kernelbench_eval_{timestamp}"
|
|
3490
|
+
workspace_path = await client.expand_path(f"{REMOTE_WORKSPACE_BASE}/kernelbench")
|
|
3491
|
+
run_path = f"{workspace_path}/{run_dir}"
|
|
3492
|
+
|
|
3493
|
+
await client.exec(f"mkdir -p {run_path}")
|
|
3494
|
+
print(f"Created run directory: {run_path}")
|
|
3495
|
+
|
|
3496
|
+
# Read and upload files
|
|
3497
|
+
impl_code = args.implementation.read_text()
|
|
3498
|
+
ref_code = args.reference.read_text()
|
|
3499
|
+
|
|
3500
|
+
# Write implementation
|
|
3501
|
+
impl_path = f"{run_path}/implementation.py"
|
|
3502
|
+
write_result = await client.exec(
|
|
3503
|
+
f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
|
|
3504
|
+
)
|
|
3505
|
+
if write_result.exit_code != 0:
|
|
3506
|
+
return EvaluateResult(
|
|
3507
|
+
success=False,
|
|
3508
|
+
all_correct=False,
|
|
3509
|
+
correctness_score=0.0,
|
|
3510
|
+
geomean_speedup=0.0,
|
|
3511
|
+
passed_tests=0,
|
|
3512
|
+
total_tests=0,
|
|
3513
|
+
error_message=f"Failed to write implementation: {write_result.stderr}",
|
|
3514
|
+
)
|
|
3515
|
+
|
|
3516
|
+
# Write reference
|
|
3517
|
+
ref_path = f"{run_path}/reference.py"
|
|
3518
|
+
write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
|
|
3519
|
+
if write_result.exit_code != 0:
|
|
3520
|
+
return EvaluateResult(
|
|
3521
|
+
success=False,
|
|
3522
|
+
all_correct=False,
|
|
3523
|
+
correctness_score=0.0,
|
|
3524
|
+
geomean_speedup=0.0,
|
|
3525
|
+
passed_tests=0,
|
|
3526
|
+
total_tests=0,
|
|
3527
|
+
error_message=f"Failed to write reference: {write_result.stderr}",
|
|
3528
|
+
)
|
|
3529
|
+
|
|
3530
|
+
# Write custom inputs if provided
|
|
3531
|
+
if args.inputs:
|
|
3532
|
+
inputs_code = args.inputs.read_text()
|
|
3533
|
+
inputs_file_path = f"{run_path}/custom_inputs.py"
|
|
3534
|
+
write_result = await client.exec(
|
|
3535
|
+
f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
|
|
3536
|
+
)
|
|
3537
|
+
if write_result.exit_code != 0:
|
|
3538
|
+
return EvaluateResult(
|
|
3539
|
+
success=False,
|
|
3540
|
+
all_correct=False,
|
|
3541
|
+
correctness_score=0.0,
|
|
3542
|
+
geomean_speedup=0.0,
|
|
3543
|
+
passed_tests=0,
|
|
3544
|
+
total_tests=0,
|
|
3545
|
+
error_message=f"Failed to write custom inputs: {write_result.stderr}",
|
|
3546
|
+
)
|
|
3547
|
+
|
|
3548
|
+
# Write eval script
|
|
3549
|
+
eval_script_path = f"{run_path}/kernelbench_eval.py"
|
|
3550
|
+
write_result = await client.exec(
|
|
3551
|
+
f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
|
|
3552
|
+
)
|
|
3553
|
+
if write_result.exit_code != 0:
|
|
3554
|
+
return EvaluateResult(
|
|
3555
|
+
success=False,
|
|
3556
|
+
all_correct=False,
|
|
3557
|
+
correctness_score=0.0,
|
|
3558
|
+
geomean_speedup=0.0,
|
|
3559
|
+
passed_tests=0,
|
|
3560
|
+
total_tests=0,
|
|
3561
|
+
error_message=f"Failed to write eval script: {write_result.stderr}",
|
|
3562
|
+
)
|
|
3563
|
+
|
|
3564
|
+
# Write defense module if defensive mode is enabled
|
|
3565
|
+
defense_module_path = None
|
|
3566
|
+
if args.defensive:
|
|
3567
|
+
defense_path = (
|
|
3568
|
+
Path(__file__).parent.parent.parent.parent
|
|
3569
|
+
/ "packages"
|
|
3570
|
+
/ "wafer-core"
|
|
3571
|
+
/ "wafer_core"
|
|
3572
|
+
/ "utils"
|
|
3573
|
+
/ "kernel_utils"
|
|
3574
|
+
/ "defense.py"
|
|
3575
|
+
)
|
|
3576
|
+
if defense_path.exists():
|
|
3577
|
+
defense_code = defense_path.read_text()
|
|
3578
|
+
defense_module_path = f"{run_path}/defense.py"
|
|
3579
|
+
write_result = await client.exec(
|
|
3580
|
+
f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
|
|
3581
|
+
)
|
|
3582
|
+
if write_result.exit_code != 0:
|
|
3583
|
+
print(f"Warning: Failed to write defense module: {write_result.stderr}")
|
|
3584
|
+
defense_module_path = None
|
|
3585
|
+
else:
|
|
3586
|
+
print(f"Warning: defense.py not found at {defense_path}")
|
|
3587
|
+
|
|
3588
|
+
print("Running KernelBench evaluation in Docker container...")
|
|
3589
|
+
|
|
3590
|
+
# Paths inside container
|
|
3591
|
+
container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
|
|
3592
|
+
container_impl_path = f"{container_run_path}/implementation.py"
|
|
3593
|
+
container_ref_path = f"{container_run_path}/reference.py"
|
|
3594
|
+
container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
|
|
3595
|
+
container_eval_script = f"{container_run_path}/kernelbench_eval.py"
|
|
3596
|
+
container_output = f"{container_run_path}/results.json"
|
|
3597
|
+
container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None
|
|
3598
|
+
|
|
3599
|
+
# Build eval command
|
|
3600
|
+
python_cmd_parts = [
|
|
3601
|
+
f"python3 {container_eval_script}",
|
|
3602
|
+
f"--impl {container_impl_path}",
|
|
3603
|
+
f"--reference {container_ref_path}",
|
|
3604
|
+
f"--output {container_output}",
|
|
3605
|
+
]
|
|
3606
|
+
|
|
3607
|
+
if args.benchmark:
|
|
3608
|
+
python_cmd_parts.append("--benchmark")
|
|
3609
|
+
if args.profile:
|
|
3610
|
+
python_cmd_parts.append("--profile")
|
|
3611
|
+
if container_inputs_path:
|
|
3612
|
+
python_cmd_parts.append(f"--inputs {container_inputs_path}")
|
|
3613
|
+
if args.defensive and container_defense_path:
|
|
3614
|
+
python_cmd_parts.append("--defensive")
|
|
3615
|
+
python_cmd_parts.append(f"--defense-module {container_defense_path}")
|
|
3616
|
+
python_cmd_parts.append(f"--seed {args.seed}")
|
|
3617
|
+
python_cmd_parts.append(f"--stages {args.stages}")
|
|
3618
|
+
|
|
3619
|
+
eval_cmd = " ".join(python_cmd_parts)
|
|
3620
|
+
|
|
3621
|
+
# Build pip install for torch dependencies if needed
|
|
3622
|
+
pip_install_cmd = _build_docker_pip_install_cmd(target)
|
|
3623
|
+
full_cmd = f"{pip_install_cmd} && cd {container_run_path} && {eval_cmd}"
|
|
3624
|
+
|
|
3625
|
+
# Build Docker command
|
|
3626
|
+
docker_cmd = _build_docker_run_command(
|
|
3627
|
+
image=target.docker_image,
|
|
3628
|
+
command=full_cmd,
|
|
3629
|
+
working_dir=container_run_path,
|
|
3630
|
+
env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
|
|
3631
|
+
gpus="all",
|
|
3632
|
+
volumes={workspace_path: CONTAINER_WORKSPACE},
|
|
3633
|
+
)
|
|
3634
|
+
|
|
3635
|
+
print(f"Docker command: {docker_cmd[:100]}...")
|
|
3636
|
+
|
|
3637
|
+
# Run and stream output
|
|
3638
|
+
log_lines = []
|
|
3639
|
+
async for line in client.exec_stream(docker_cmd):
|
|
3640
|
+
print(line, flush=True)
|
|
3641
|
+
log_lines.append(line)
|
|
3642
|
+
|
|
3643
|
+
# Read results
|
|
3644
|
+
results_path = f"{run_path}/results.json"
|
|
3645
|
+
cat_result = await client.exec(f"cat {results_path}")
|
|
3646
|
+
|
|
3647
|
+
if cat_result.exit_code != 0:
|
|
3648
|
+
log_tail = "\n".join(log_lines[-50:])
|
|
3649
|
+
return EvaluateResult(
|
|
3650
|
+
success=False,
|
|
3651
|
+
all_correct=False,
|
|
3652
|
+
correctness_score=0.0,
|
|
3653
|
+
geomean_speedup=0.0,
|
|
3654
|
+
passed_tests=0,
|
|
3655
|
+
total_tests=0,
|
|
3656
|
+
error_message=f"Evaluation failed. Log tail:\n{log_tail}",
|
|
3657
|
+
)
|
|
3658
|
+
|
|
3659
|
+
# Parse results
|
|
3660
|
+
try:
|
|
3661
|
+
results_data = json.loads(cat_result.stdout)
|
|
3662
|
+
except json.JSONDecodeError as e:
|
|
3663
|
+
return EvaluateResult(
|
|
3664
|
+
success=False,
|
|
3665
|
+
all_correct=False,
|
|
3666
|
+
correctness_score=0.0,
|
|
3667
|
+
geomean_speedup=0.0,
|
|
3668
|
+
passed_tests=0,
|
|
3669
|
+
total_tests=0,
|
|
3670
|
+
error_message=f"Failed to parse results: {e}",
|
|
3671
|
+
)
|
|
3672
|
+
|
|
3673
|
+
# Convert to EvaluateResult
|
|
3674
|
+
# TODO: use compiled field - currently ignored, should affect success/error
|
|
3675
|
+
# compiled = results_data.get("compiled", False)
|
|
3676
|
+
correct = results_data.get("correct", False)
|
|
3677
|
+
speedup = results_data.get("speedup", 0.0) or 0.0
|
|
3678
|
+
error = results_data.get("error")
|
|
3679
|
+
|
|
3680
|
+
if error:
|
|
3681
|
+
return EvaluateResult(
|
|
3682
|
+
success=False,
|
|
3683
|
+
all_correct=False,
|
|
3684
|
+
correctness_score=0.0,
|
|
3685
|
+
geomean_speedup=0.0,
|
|
3686
|
+
passed_tests=0,
|
|
3687
|
+
total_tests=1,
|
|
3688
|
+
error_message=error,
|
|
3689
|
+
)
|
|
3690
|
+
|
|
3691
|
+
return EvaluateResult(
|
|
3692
|
+
success=True,
|
|
3693
|
+
all_correct=correct,
|
|
3694
|
+
correctness_score=1.0 if correct else 0.0,
|
|
3695
|
+
geomean_speedup=speedup,
|
|
3696
|
+
passed_tests=1 if correct else 0,
|
|
3697
|
+
total_tests=1,
|
|
3698
|
+
)
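The uploads above repeat the same quoted here-doc pattern; an editorial sketch of that pattern factored into a helper, assuming only the AsyncSSHClient.exec() interface already used in this function:

# Illustrative only: the quoted 'EOF' delimiter stops the remote shell from expanding
# $vars/backticks in the payload; it still breaks if the payload contains the delimiter line.
async def write_remote_file(client, remote_path: str, content: str) -> bool:
    result = await client.exec(
        f"cat > '{remote_path}' << 'WAFER_EOF'\n{content}\nWAFER_EOF"
    )
    return result.exit_code == 0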
|
|
3699
|
+
|
|
3700
|
+
|
|
3701
|
+
# Default ROCm PyTorch image for DigitalOcean AMD MI300X
|
|
3702
|
+
DEFAULT_ROCM_DOCKER_IMAGE = "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0"
|
|
3703
|
+
|
|
3704
|
+
|
|
3705
|
+
async def run_evaluate_kernelbench_digitalocean(
|
|
3706
|
+
args: KernelBenchEvaluateArgs,
|
|
3707
|
+
target: DigitalOceanTarget,
|
|
3708
|
+
) -> EvaluateResult:
|
|
3709
|
+
"""Run KernelBench format evaluation in Docker container on DigitalOcean AMD GPU.
|
|
3710
|
+
|
|
3711
|
+
Uses a ROCm Docker image with device passthrough for AMD GPUs.
|
|
3712
|
+
"""
|
|
3713
|
+
from datetime import datetime
|
|
3714
|
+
|
|
3715
|
+
import trio_asyncio
|
|
3716
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
3717
|
+
from wafer_core.targets.digitalocean import digitalocean_ssh_context
|
|
3718
|
+
|
|
3719
|
+
CONTAINER_WORKSPACE = "/workspace"
|
|
3720
|
+
REMOTE_WORKSPACE_BASE = "~/.wafer/workspaces"
|
|
3721
|
+
|
|
3722
|
+
docker_image = DEFAULT_ROCM_DOCKER_IMAGE
|
|
3723
|
+
|
|
3724
|
+
# Select GPU
|
|
3725
|
+
gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
|
|
3726
|
+
|
|
3727
|
+
print("Provisioning/connecting to DigitalOcean droplet...")
|
|
3728
|
+
|
|
3729
|
+
async with digitalocean_ssh_context(target) as ssh_info:
|
|
3730
|
+
ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
|
|
3731
|
+
print(f"Connected to {ssh_target}")
|
|
3732
|
+
|
|
3733
|
+
async with trio_asyncio.open_loop():
|
|
3734
|
+
async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
|
|
3735
|
+
# Ensure Docker is installed
|
|
3736
|
+
docker_check = await client.exec("which docker")
|
|
3737
|
+
if docker_check.exit_code != 0:
|
|
3738
|
+
print("Docker not found, installing...")
|
|
3739
|
+
install_result = await client.exec(
|
|
3740
|
+
"apt-get update -qq && apt-get install -y -qq docker.io"
|
|
3741
|
+
)
|
|
3742
|
+
if install_result.exit_code != 0:
|
|
3743
|
+
return EvaluateResult(
|
|
3744
|
+
success=False,
|
|
3745
|
+
all_correct=False,
|
|
3746
|
+
correctness_score=0.0,
|
|
3747
|
+
geomean_speedup=0.0,
|
|
3748
|
+
passed_tests=0,
|
|
3749
|
+
total_tests=0,
|
|
3750
|
+
error_message=f"Failed to install Docker: {install_result.stderr}",
|
|
3751
|
+
)
|
|
3752
|
+
|
|
3753
|
+
# Create workspace
|
|
3754
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
3755
|
+
run_dir = f"kernelbench_eval_{timestamp}"
|
|
3756
|
+
workspace_path = await client.expand_path(f"{REMOTE_WORKSPACE_BASE}/kernelbench")
|
|
3757
|
+
run_path = f"{workspace_path}/{run_dir}"
|
|
3758
|
+
|
|
3759
|
+
await client.exec(f"mkdir -p {run_path}")
|
|
3760
|
+
print(f"Created run directory: {run_path}")
|
|
3761
|
+
|
|
3762
|
+
# Read and upload files
|
|
3763
|
+
impl_code = args.implementation.read_text()
|
|
3764
|
+
ref_code = args.reference.read_text()
|
|
3765
|
+
|
|
3766
|
+
# Write implementation
|
|
3767
|
+
impl_path = f"{run_path}/implementation.py"
|
|
3768
|
+
write_result = await client.exec(
|
|
3769
|
+
f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
|
|
3770
|
+
)
|
|
3771
|
+
if write_result.exit_code != 0:
|
|
3772
|
+
return EvaluateResult(
|
|
3773
|
+
success=False,
|
|
3774
|
+
all_correct=False,
|
|
3775
|
+
correctness_score=0.0,
|
|
3776
|
+
geomean_speedup=0.0,
|
|
3777
|
+
passed_tests=0,
|
|
3778
|
+
total_tests=0,
|
|
3779
|
+
error_message=f"Failed to write implementation: {write_result.stderr}",
|
|
3780
|
+
)
|
|
3781
|
+
|
|
3782
|
+
# Write reference
|
|
3783
|
+
ref_path = f"{run_path}/reference.py"
|
|
3784
|
+
write_result = await client.exec(
|
|
3785
|
+
f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
|
|
3786
|
+
)
|
|
3787
|
+
if write_result.exit_code != 0:
|
|
3788
|
+
return EvaluateResult(
|
|
3789
|
+
success=False,
|
|
3790
|
+
all_correct=False,
|
|
3791
|
+
correctness_score=0.0,
|
|
3792
|
+
geomean_speedup=0.0,
|
|
3793
|
+
passed_tests=0,
|
|
3794
|
+
total_tests=0,
|
|
3795
|
+
error_message=f"Failed to write reference: {write_result.stderr}",
|
|
3796
|
+
)
|
|
3797
|
+
|
|
3798
|
+
# Write custom inputs if provided
|
|
3799
|
+
if args.inputs:
|
|
3800
|
+
inputs_code = args.inputs.read_text()
|
|
3801
|
+
inputs_file_path = f"{run_path}/custom_inputs.py"
|
|
3802
|
+
write_result = await client.exec(
|
|
3803
|
+
f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
|
|
3804
|
+
)
|
|
3805
|
+
if write_result.exit_code != 0:
|
|
3806
|
+
return EvaluateResult(
|
|
3807
|
+
success=False,
|
|
3808
|
+
all_correct=False,
|
|
3809
|
+
correctness_score=0.0,
|
|
3810
|
+
geomean_speedup=0.0,
|
|
3811
|
+
passed_tests=0,
|
|
3812
|
+
total_tests=0,
|
|
3813
|
+
error_message=f"Failed to write custom inputs: {write_result.stderr}",
|
|
3814
|
+
)
|
|
3815
|
+
|
|
3816
|
+
# Write eval script
|
|
3817
|
+
eval_script_path = f"{run_path}/kernelbench_eval.py"
|
|
3818
|
+
write_result = await client.exec(
|
|
3819
|
+
f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
|
|
3820
|
+
)
|
|
3821
|
+
if write_result.exit_code != 0:
|
|
3822
|
+
return EvaluateResult(
|
|
3823
|
+
success=False,
|
|
3824
|
+
all_correct=False,
|
|
3825
|
+
correctness_score=0.0,
|
|
3826
|
+
geomean_speedup=0.0,
|
|
3827
|
+
passed_tests=0,
|
|
3828
|
+
total_tests=0,
|
|
3829
|
+
error_message=f"Failed to write eval script: {write_result.stderr}",
|
|
3830
|
+
)
|
|
3831
|
+
|
|
3832
|
+
# Write defense module if defensive mode is enabled
|
|
3833
|
+
defense_module_path = None
|
|
3834
|
+
if args.defensive:
|
|
3835
|
+
defense_path = (
|
|
3836
|
+
Path(__file__).parent.parent.parent.parent
|
|
3837
|
+
/ "packages"
|
|
3838
|
+
/ "wafer-core"
|
|
3839
|
+
/ "wafer_core"
|
|
3840
|
+
/ "utils"
|
|
3841
|
+
/ "kernel_utils"
|
|
3842
|
+
/ "defense.py"
|
|
3843
|
+
)
|
|
3844
|
+
if defense_path.exists():
|
|
3845
|
+
defense_code = defense_path.read_text()
|
|
3846
|
+
defense_module_path = f"{run_path}/defense.py"
|
|
3847
|
+
write_result = await client.exec(
|
|
3848
|
+
f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
|
|
3849
|
+
)
|
|
3850
|
+
if write_result.exit_code != 0:
|
|
3851
|
+
print(f"Warning: Failed to write defense module: {write_result.stderr}")
|
|
3852
|
+
defense_module_path = None
|
|
3853
|
+
else:
|
|
3854
|
+
print(f"Warning: defense.py not found at {defense_path}")
|
|
3855
|
+
|
|
3856
|
+
print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
|
|
3857
|
+
|
|
3858
|
+
# Paths inside container
|
|
3859
|
+
container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
|
|
3860
|
+
container_impl_path = f"{container_run_path}/implementation.py"
|
|
3861
|
+
container_ref_path = f"{container_run_path}/reference.py"
|
|
3862
|
+
container_inputs_path = (
|
|
3863
|
+
f"{container_run_path}/custom_inputs.py" if args.inputs else None
|
|
3864
|
+
)
|
|
3865
|
+
container_eval_script = f"{container_run_path}/kernelbench_eval.py"
|
|
3866
|
+
container_output = f"{container_run_path}/results.json"
|
|
3867
|
+
container_defense_path = (
|
|
3868
|
+
f"{container_run_path}/defense.py" if defense_module_path else None
|
|
3869
|
+
)
|
|
3870
|
+
|
|
3871
|
+
# Build eval command
|
|
3872
|
+
python_cmd_parts = [
|
|
3873
|
+
f"python3 {container_eval_script}",
|
|
3874
|
+
f"--impl {container_impl_path}",
|
|
3875
|
+
f"--reference {container_ref_path}",
|
|
3876
|
+
f"--output {container_output}",
|
|
3877
|
+
]
|
|
3878
|
+
|
|
3879
|
+
if args.benchmark:
|
|
3880
|
+
python_cmd_parts.append("--benchmark")
|
|
3881
|
+
if args.profile:
|
|
3882
|
+
python_cmd_parts.append("--profile")
|
|
3883
|
+
if container_inputs_path:
|
|
3884
|
+
python_cmd_parts.append(f"--inputs {container_inputs_path}")
|
|
3885
|
+
if args.defensive and container_defense_path:
|
|
3886
|
+
python_cmd_parts.append("--defensive")
|
|
3887
|
+
python_cmd_parts.append(f"--defense-module {container_defense_path}")
|
|
3888
|
+
python_cmd_parts.append(f"--seed {args.seed}")
|
|
3889
|
+
python_cmd_parts.append(f"--stages {args.stages}")
|
|
3890
|
+
|
|
3891
|
+
eval_cmd = " ".join(python_cmd_parts)
|
|
3892
|
+
|
|
3893
|
+
# For AMD, we don't need pip install - the ROCm image has everything
|
|
3894
|
+
full_cmd = f"cd {container_run_path} && {eval_cmd}"
|
|
3895
|
+
|
|
3896
|
+
# Build Docker command for AMD
|
|
3897
|
+
# PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
|
|
3898
|
+
rocm_arch = _get_rocm_arch(target.compute_capability)
|
|
3899
|
+
env_dict = {
|
|
3900
|
+
"HIP_VISIBLE_DEVICES": str(gpu_id),
|
|
3901
|
+
"PYTHONUNBUFFERED": "1",
|
|
3902
|
+
}
|
|
3903
|
+
if rocm_arch:
|
|
3904
|
+
env_dict["PYTORCH_ROCM_ARCH"] = rocm_arch
|
|
3905
|
+
|
|
3906
|
+
docker_cmd = _build_docker_run_command_amd(
|
|
3907
|
+
image=docker_image,
|
|
3908
|
+
command=full_cmd,
|
|
3909
|
+
working_dir=container_run_path,
|
|
3910
|
+
env=env_dict,
|
|
3911
|
+
volumes={workspace_path: CONTAINER_WORKSPACE},
|
|
3912
|
+
)
|
|
3913
|
+
|
|
3914
|
+
print(f"Docker command: {docker_cmd[:100]}...")
|
|
3915
|
+
|
|
3916
|
+
# Run and stream output
|
|
3917
|
+
log_lines = []
|
|
3918
|
+
async for line in client.exec_stream(docker_cmd):
|
|
3919
|
+
print(line, flush=True)
|
|
3920
|
+
log_lines.append(line)
|
|
3921
|
+
|
|
3922
|
+
# Read results
|
|
3923
|
+
results_path = f"{run_path}/results.json"
|
|
3924
|
+
cat_result = await client.exec(f"cat {results_path}")
|
|
3925
|
+
|
|
3926
|
+
if cat_result.exit_code != 0:
|
|
3927
|
+
log_tail = "\n".join(log_lines[-50:])
|
|
3928
|
+
return EvaluateResult(
|
|
3929
|
+
success=False,
|
|
3930
|
+
all_correct=False,
|
|
3931
|
+
correctness_score=0.0,
|
|
3932
|
+
geomean_speedup=0.0,
|
|
3933
|
+
passed_tests=0,
|
|
3934
|
+
total_tests=0,
|
|
3935
|
+
error_message=f"Evaluation failed. Log tail:\n{log_tail}",
|
|
3936
|
+
)
|
|
3937
|
+
|
|
3938
|
+
# Parse results
|
|
3939
|
+
try:
|
|
3940
|
+
results_data = json.loads(cat_result.stdout)
|
|
3941
|
+
except json.JSONDecodeError as e:
|
|
3942
|
+
return EvaluateResult(
|
|
3943
|
+
success=False,
|
|
3944
|
+
all_correct=False,
|
|
3945
|
+
correctness_score=0.0,
|
|
3946
|
+
geomean_speedup=0.0,
|
|
3947
|
+
passed_tests=0,
|
|
3948
|
+
total_tests=0,
|
|
3949
|
+
error_message=f"Failed to parse results: {e}",
|
|
3950
|
+
)
|
|
3951
|
+
|
|
3952
|
+
# Convert to EvaluateResult
|
|
3953
|
+
# TODO: use compiled field - currently ignored, should affect success/error
|
|
3954
|
+
# compiled = results_data.get("compiled", False)
|
|
3955
|
+
correct = results_data.get("correct", False)
|
|
3956
|
+
speedup = results_data.get("speedup", 0.0) or 0.0
|
|
3957
|
+
error = results_data.get("error")
|
|
3958
|
+
|
|
3959
|
+
if error:
|
|
3960
|
+
return EvaluateResult(
|
|
3961
|
+
success=False,
|
|
3962
|
+
all_correct=False,
|
|
3963
|
+
correctness_score=0.0,
|
|
3964
|
+
geomean_speedup=0.0,
|
|
3965
|
+
passed_tests=0,
|
|
3966
|
+
total_tests=1,
|
|
3967
|
+
error_message=error,
|
|
3968
|
+
)
|
|
3969
|
+
|
|
3970
|
+
return EvaluateResult(
|
|
3971
|
+
success=True,
|
|
3972
|
+
all_correct=correct,
|
|
3973
|
+
correctness_score=1.0 if correct else 0.0,
|
|
3974
|
+
geomean_speedup=speedup,
|
|
3975
|
+
passed_tests=1 if correct else 0,
|
|
3976
|
+
total_tests=1,
|
|
3977
|
+
)
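_build_docker_run_command_amd is defined elsewhere in this module and its exact output is not shown here; as a rough editorial sketch only, ROCm containers are typically launched with /dev/kfd and /dev/dri passed through, along these lines (every flag below is an assumption, not the package's actual command):

# Illustrative only: a plausible shape for an AMD docker invocation, not the package's code.
import shlex

def sketch_docker_run_amd(image, command, workdir, env, volumes):
    parts = [
        "docker run --rm",
        "--device=/dev/kfd", "--device=/dev/dri",   # ROCm device passthrough
        "--security-opt seccomp=unconfined",
        f"-w {shlex.quote(workdir)}",
    ]
    parts += [f"-e {k}={shlex.quote(str(v))}" for k, v in env.items()]
    parts += [f"-v {host}:{container}" for host, container in volumes.items()]
    parts += [image, f"bash -c {shlex.quote(command)}"]
    return " ".join(parts)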
|
|
3978
|
+
|
|
3979
|
+
|
|
3980
|
+
async def run_evaluate_kernelbench_runpod(
|
|
3981
|
+
args: KernelBenchEvaluateArgs,
|
|
3982
|
+
target: RunPodTarget,
|
|
3983
|
+
) -> EvaluateResult:
|
|
3984
|
+
"""Run KernelBench format evaluation directly on RunPod AMD GPU.
|
|
3985
|
+
|
|
3986
|
+
Runs the evaluation script directly on the host (no Docker) since RunPod pods
|
|
3987
|
+
already have PyTorch/ROCm installed.
|
|
3988
|
+
"""
|
|
3989
|
+
from datetime import datetime
|
|
3990
|
+
|
|
3991
|
+
from wafer_core.async_ssh import AsyncSSHClient
|
|
3992
|
+
from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
|
|
3993
|
+
|
|
3994
|
+
REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
|
|
3995
|
+
|
|
3996
|
+
# Select GPU
|
|
3997
|
+
gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
|
|
3998
|
+
|
|
3999
|
+
print(f"Provisioning RunPod ({target.gpu_type_id})...")
|
|
4000
|
+
|
|
4001
|
+
try:
|
|
4002
|
+
async with runpod_ssh_context(target) as ssh_info:
|
|
4003
|
+
ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
|
|
4004
|
+
print(f"Connected to RunPod: {ssh_target}")
|
|
4005
|
+
|
|
4006
|
+
async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
|
|
4007
|
+
# Create workspace
|
|
4008
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
4009
|
+
run_dir = f"kernelbench_eval_{timestamp}"
|
|
4010
|
+
run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
|
|
4011
|
+
|
|
4012
|
+
await client.exec(f"mkdir -p {run_path}")
|
|
4013
|
+
print(f"Created run directory: {run_path}")
|
|
4014
|
+
|
|
4015
|
+
# Read and upload files
|
|
4016
|
+
impl_code = args.implementation.read_text()
|
|
4017
|
+
ref_code = args.reference.read_text()
|
|
4018
|
+
|
|
4019
|
+
# Write implementation
|
|
4020
|
+
impl_path = f"{run_path}/implementation.py"
|
|
4021
|
+
write_result = await client.exec(
|
|
4022
|
+
f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
|
|
4023
|
+
)
|
|
4024
|
+
if write_result.exit_code != 0:
|
|
4025
|
+
return EvaluateResult(
|
|
4026
|
+
success=False,
|
|
4027
|
+
all_correct=False,
|
|
4028
|
+
correctness_score=0.0,
|
|
4029
|
+
geomean_speedup=0.0,
|
|
4030
|
+
passed_tests=0,
|
|
4031
|
+
total_tests=0,
|
|
4032
|
+
error_message=f"Failed to write implementation: {write_result.stderr}",
|
|
4033
|
+
)
|
|
4034
|
+
|
|
4035
|
+
# Write reference
|
|
4036
|
+
ref_path = f"{run_path}/reference.py"
|
|
4037
|
+
write_result = await client.exec(
|
|
4038
|
+
f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
|
|
4039
|
+
)
|
|
4040
|
+
if write_result.exit_code != 0:
|
|
4041
|
+
return EvaluateResult(
|
|
4042
|
+
success=False,
|
|
4043
|
+
all_correct=False,
|
|
4044
|
+
correctness_score=0.0,
|
|
4045
|
+
geomean_speedup=0.0,
|
|
4046
|
+
passed_tests=0,
|
|
4047
|
+
total_tests=0,
|
|
4048
|
+
error_message=f"Failed to write reference: {write_result.stderr}",
|
|
4049
|
+
)
|
|
4050
|
+
|
|
4051
|
+
# Write custom inputs if provided
|
|
4052
|
+
inputs_path = None
|
|
4053
|
+
if args.inputs:
|
|
4054
|
+
inputs_code = args.inputs.read_text()
|
|
4055
|
+
inputs_path = f"{run_path}/custom_inputs.py"
|
|
4056
|
+
write_result = await client.exec(
|
|
4057
|
+
f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
|
|
4058
|
+
)
|
|
4059
|
+
if write_result.exit_code != 0:
|
|
4060
|
+
return EvaluateResult(
|
|
4061
|
+
success=False,
|
|
4062
|
+
all_correct=False,
|
|
4063
|
+
correctness_score=0.0,
|
|
4064
|
+
geomean_speedup=0.0,
|
|
4065
|
+
passed_tests=0,
|
|
4066
|
+
total_tests=0,
|
|
4067
|
+
error_message=f"Failed to write custom inputs: {write_result.stderr}",
|
|
4068
|
+
)
|
|
4069
|
+
|
|
4070
|
+
# Write eval script
|
|
4071
|
+
eval_script_path = f"{run_path}/kernelbench_eval.py"
|
|
4072
|
+
write_result = await client.exec(
|
|
4073
|
+
f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
|
|
4074
|
+
)
|
|
4075
|
+
if write_result.exit_code != 0:
|
|
4076
|
+
return EvaluateResult(
|
|
4077
|
+
success=False,
|
|
4078
|
+
all_correct=False,
|
|
4079
|
+
correctness_score=0.0,
|
|
4080
|
+
geomean_speedup=0.0,
|
|
4081
|
+
passed_tests=0,
|
|
4082
|
+
total_tests=0,
|
|
4083
|
+
error_message=f"Failed to write eval script: {write_result.stderr}",
|
|
4084
|
+
)
|
|
4085
|
+
|
|
4086
|
+
# Write defense module if defensive mode is enabled
|
|
4087
|
+
defense_module_path = None
|
|
4088
|
+
if args.defensive:
|
|
4089
|
+
defense_path = (
|
|
4090
|
+
Path(__file__).parent.parent.parent.parent
|
|
4091
|
+
/ "packages"
|
|
4092
|
+
/ "wafer-core"
|
|
4093
|
+
/ "wafer_core"
|
|
4094
|
+
/ "utils"
|
|
4095
|
+
/ "kernel_utils"
|
|
4096
|
+
/ "defense.py"
|
|
4097
|
+
)
|
|
4098
|
+
if defense_path.exists():
|
|
4099
|
+
defense_code = defense_path.read_text()
|
|
4100
|
+
defense_module_path = f"{run_path}/defense.py"
|
|
4101
|
+
write_result = await client.exec(
|
|
4102
|
+
f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
|
|
4103
|
+
)
|
|
4104
|
+
if write_result.exit_code != 0:
|
|
4105
|
+
print(f"Warning: Failed to write defense module: {write_result.stderr}")
|
|
4106
|
+
defense_module_path = None
|
|
4107
|
+
else:
|
|
4108
|
+
print(f"Warning: defense.py not found at {defense_path}")
|
|
4109
|
+
|
|
4110
|
+
print("Running KernelBench evaluation (AMD/ROCm)...")
|
|
4111
|
+
|
|
4112
|
+
# Find Python with PyTorch - check common locations on RunPod
|
|
4113
|
+
python_exe = "python3"
|
|
4114
|
+
for candidate in [
|
|
4115
|
+
"/opt/conda/envs/py_3.10/bin/python3",
|
|
4116
|
+
"/opt/conda/bin/python3",
|
|
4117
|
+
]:
|
|
4118
|
+
check = await client.exec(
|
|
4119
|
+
f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
|
|
4120
|
+
)
|
|
4121
|
+
if "OK" in check.stdout:
|
|
4122
|
+
python_exe = candidate
|
|
4123
|
+
print(f"Using Python: {python_exe}")
|
|
4124
|
+
break
|
|
4125
|
+
|
|
4126
|
+
# Build eval command - run directly on host
|
|
4127
|
+
output_path = f"{run_path}/results.json"
|
|
4128
|
+
python_cmd_parts = [
|
|
4129
|
+
f"{python_exe} {eval_script_path}",
|
|
4130
|
+
f"--impl {impl_path}",
|
|
4131
|
+
f"--reference {ref_path}",
|
|
4132
|
+
f"--output {output_path}",
|
|
4133
|
+
]
|
|
4134
|
+
|
|
4135
|
+
if args.benchmark:
|
|
4136
|
+
python_cmd_parts.append("--benchmark")
|
|
4137
|
+
if args.profile:
|
|
4138
|
+
python_cmd_parts.append("--profile")
|
|
4139
|
+
if inputs_path:
|
|
4140
|
+
python_cmd_parts.append(f"--inputs {inputs_path}")
|
|
4141
|
+
if args.defensive and defense_module_path:
|
|
4142
|
+
python_cmd_parts.append("--defensive")
|
|
4143
|
+
python_cmd_parts.append(f"--defense-module {defense_module_path}")
|
|
4144
|
+
python_cmd_parts.append(f"--seed {args.seed}")
|
|
4145
|
+
python_cmd_parts.append(f"--stages {args.stages}")
|
|
4146
|
+
|
|
4147
|
+
eval_cmd = " ".join(python_cmd_parts)
|
|
4148
|
+
|
|
4149
|
+
# Set environment for AMD GPU and run
|
|
4150
|
+
# PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
|
|
4151
|
+
rocm_arch = _get_rocm_arch(target.compute_capability)
|
|
4152
|
+
arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
|
|
4153
|
+
env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
|
|
4154
|
+
full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
|
|
4155
|
+
|
|
4156
|
+
# Handle prepare-only mode
|
|
4157
|
+
if args.prepare_only:
|
|
4158
|
+
print(f"\n[wafer] Prepared evaluation at: {run_path}")
|
|
4159
|
+
print(f"[wafer] Target: {target.name} ({client.host}:{client.port})")
|
|
4160
|
+
print("[wafer] To run manually:")
|
|
4161
|
+
print(f" ssh -p {client.port} root@{client.host} '{full_cmd}'")
|
|
4162
|
+
print("\n[wafer] Or wrap with rocprof:")
|
|
4163
|
+
print(
|
|
4164
|
+
f" ssh -p {client.port} root@{client.host} 'cd {run_path} && {env_vars} rocprof -i counters.txt {eval_cmd}'"
|
|
4165
|
+
)
|
|
4166
|
+
return EvaluateResult(
|
|
4167
|
+
success=True,
|
|
4168
|
+
all_correct=None, # Not checked in prepare-only mode
|
|
4169
|
+
correctness_score=0.0,
|
|
4170
|
+
geomean_speedup=0.0,
|
|
4171
|
+
passed_tests=0,
|
|
4172
|
+
total_tests=0,
|
|
4173
|
+
error_message=None,
|
|
4174
|
+
)
|
|
4175
|
+
|
|
4176
|
+
# Run and stream output
|
|
4177
|
+
log_lines = []
|
|
4178
|
+
async for line in client.exec_stream(full_cmd):
|
|
4179
|
+
print(line, flush=True)
|
|
4180
|
+
log_lines.append(line)
|
|
4181
|
+
|
|
4182
|
+
# Read results
|
|
4183
|
+
cat_result = await client.exec(f"cat {output_path}")
|
|
4184
|
+
|
|
4185
|
+
if cat_result.exit_code != 0:
|
|
4186
|
+
log_tail = "\n".join(log_lines[-50:])
|
|
4187
|
+
return EvaluateResult(
|
|
4188
|
+
success=False,
|
|
4189
|
+
all_correct=False,
|
|
4190
|
+
correctness_score=0.0,
|
|
4191
|
+
geomean_speedup=0.0,
|
|
4192
|
+
passed_tests=0,
|
|
4193
|
+
total_tests=0,
|
|
4194
|
+
error_message=f"Evaluation failed. Log tail:\n{log_tail}",
|
|
4195
|
+
)
|
|
4196
|
+
|
|
4197
|
+
# Parse results
|
|
4198
|
+
try:
|
|
4199
|
+
results_data = json.loads(cat_result.stdout)
|
|
4200
|
+
except json.JSONDecodeError as e:
|
|
4201
|
+
return EvaluateResult(
|
|
4202
|
+
success=False,
|
|
4203
|
+
all_correct=False,
|
|
4204
|
+
correctness_score=0.0,
|
|
4205
|
+
geomean_speedup=0.0,
|
|
4206
|
+
passed_tests=0,
|
|
4207
|
+
total_tests=0,
|
|
4208
|
+
error_message=f"Failed to parse results: {e}",
|
|
4209
|
+
)
|
|
4210
|
+
|
|
4211
|
+
# Convert to EvaluateResult
|
|
4212
|
+
correct = results_data.get("correct", False)
|
|
4213
|
+
speedup = results_data.get("speedup", 0.0) or 0.0
|
|
4214
|
+
error = results_data.get("error")
|
|
4215
|
+
|
|
4216
|
+
if error:
|
|
4217
|
+
return EvaluateResult(
|
|
4218
|
+
success=False,
|
|
4219
|
+
all_correct=False,
|
|
4220
|
+
correctness_score=0.0,
|
|
4221
|
+
geomean_speedup=0.0,
|
|
4222
|
+
passed_tests=0,
|
|
4223
|
+
total_tests=1,
|
|
4224
|
+
error_message=error,
|
|
4225
|
+
)
|
|
4226
|
+
|
|
4227
|
+
return EvaluateResult(
|
|
4228
|
+
success=True,
|
|
4229
|
+
all_correct=correct,
|
|
4230
|
+
correctness_score=1.0 if correct else 0.0,
|
|
4231
|
+
geomean_speedup=speedup,
|
|
4232
|
+
passed_tests=1 if correct else 0,
|
|
4233
|
+
total_tests=1,
|
|
4234
|
+
)
|
|
4235
|
+
|
|
4236
|
+
except RunPodError as e:
|
|
4237
|
+
return EvaluateResult(
|
|
4238
|
+
success=False,
|
|
4239
|
+
all_correct=False,
|
|
4240
|
+
correctness_score=0.0,
|
|
4241
|
+
geomean_speedup=0.0,
|
|
4242
|
+
passed_tests=0,
|
|
4243
|
+
total_tests=0,
|
|
4244
|
+
error_message=f"RunPod error: {e}",
|
|
4245
|
+
)
|
|
4246
|
+
|
|
4247
|
+
|
|
4248
|
+
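
# --- Editor's illustrative sketch (not part of wafer-cli) ---------------------
# Both the RunPod path above and the baremetal path below upload files by piping
# a quoted heredoc through the SSH client rather than via SCP/SFTP. A minimal
# standalone version of that pattern is sketched here; it assumes an exec-style
# client whose results expose .exit_code and .stderr (as AsyncSSHClient does in
# this module), and the helper name `_write_remote_file` is hypothetical.
async def _write_remote_file(client, remote_path: str, content: str) -> bool:
    # The quoted 'EOF' sentinel stops the remote shell from expanding $vars
    # inside the payload. Like the code above, this breaks if `content` itself
    # contains the sentinel line; wafer reduces that risk by using distinct
    # sentinels per file (IMPL_EOF, REF_EOF, INPUTS_EOF, DEFENSE_EOF).
    result = await client.exec(f"cat > '{remote_path}' << 'EOF'\n{content}\nEOF")
    if result.exit_code != 0:
        print(f"Warning: failed to write {remote_path}: {result.stderr}")
        return False
    return True
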
async def run_evaluate_kernelbench_baremetal_amd(
    args: KernelBenchEvaluateArgs,
    target: BaremetalTarget,
) -> EvaluateResult:
    """Run KernelBench format evaluation directly on AMD baremetal target.

    Runs evaluation script directly on host (no Docker) for AMD GPUs
    that have PyTorch/ROCm installed.
    """
    from datetime import datetime

    from wafer_core.async_ssh import AsyncSSHClient

    REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"

    # Select GPU
    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]

    print(f"Connecting to {target.ssh_target}...")

    async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
        # Create workspace
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir = f"kernelbench_eval_{timestamp}"
        run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"

        await client.exec(f"mkdir -p {run_path}")
        print(f"Created run directory: {run_path}")

        # Read and upload files
        impl_code = args.implementation.read_text()
        ref_code = args.reference.read_text()

        # Write implementation
        impl_path = f"{run_path}/implementation.py"
        write_result = await client.exec(
            f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
        )
        if write_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to write implementation: {write_result.stderr}",
            )

        # Write reference
        ref_path = f"{run_path}/reference.py"
        write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
        if write_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to write reference: {write_result.stderr}",
            )

        # Write custom inputs if provided
        inputs_path = None
        if args.inputs:
            inputs_code = args.inputs.read_text()
            inputs_path = f"{run_path}/custom_inputs.py"
            write_result = await client.exec(
                f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
            )
            if write_result.exit_code != 0:
                return EvaluateResult(
                    success=False,
                    all_correct=False,
                    correctness_score=0.0,
                    geomean_speedup=0.0,
                    passed_tests=0,
                    total_tests=0,
                    error_message=f"Failed to write custom inputs: {write_result.stderr}",
                )

        # Write eval script
        eval_script_path = f"{run_path}/kernelbench_eval.py"
        write_result = await client.exec(
            f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
        )
        if write_result.exit_code != 0:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to write eval script: {write_result.stderr}",
            )

        # Write defense module if defensive mode is enabled
        defense_module_path = None
        if args.defensive:
            defense_path = (
                Path(__file__).parent.parent.parent.parent
                / "packages"
                / "wafer-core"
                / "wafer_core"
                / "utils"
                / "kernel_utils"
                / "defense.py"
            )
            if defense_path.exists():
                defense_code = defense_path.read_text()
                defense_module_path = f"{run_path}/defense.py"
                write_result = await client.exec(
                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
                )
                if write_result.exit_code != 0:
                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
                    defense_module_path = None
            else:
                print(f"Warning: defense.py not found at {defense_path}")

        print("Running KernelBench evaluation (AMD/ROCm)...")

        # Find Python with PyTorch - check common locations
        python_exe = "python3"
        for candidate in [
            "/opt/conda/envs/py_3.10/bin/python3",
            "/opt/conda/bin/python3",
        ]:
            check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
            if "OK" in check.stdout:
                python_exe = candidate
                print(f"Using Python: {python_exe}")
                break

        # Build eval command - run directly on host
        output_path = f"{run_path}/results.json"
        python_cmd_parts = [
            f"{python_exe} {eval_script_path}",
            f"--impl {impl_path}",
            f"--reference {ref_path}",
            f"--output {output_path}",
        ]

        if args.benchmark:
            python_cmd_parts.append("--benchmark")
        if args.profile:
            python_cmd_parts.append("--profile")
        if inputs_path:
            python_cmd_parts.append(f"--inputs {inputs_path}")
        if args.defensive and defense_module_path:
            python_cmd_parts.append("--defensive")
            python_cmd_parts.append(f"--defense-module {defense_module_path}")
        python_cmd_parts.append(f"--seed {args.seed}")
        python_cmd_parts.append(f"--stages {args.stages}")

        eval_cmd = " ".join(python_cmd_parts)

        # Set environment for AMD GPU and run
        # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
        rocm_arch = _get_rocm_arch(target.compute_capability)
        arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
        env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
        full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"

        # Handle prepare-only mode
        if args.prepare_only:
            print(f"\n[wafer] Prepared evaluation at: {run_path}")
            print(f"[wafer] Target: {target.name} ({client.host}:{client.port})")
            print("[wafer] To run manually:")
            print(f"  ssh -p {client.port} root@{client.host} '{full_cmd}'")
            print("\n[wafer] Or wrap with rocprof:")
            print(
                f"  ssh -p {client.port} root@{client.host} 'cd {run_path} && {env_vars} rocprof -i counters.txt {eval_cmd}'"
            )
            return EvaluateResult(
                success=True,
                all_correct=None,  # Not checked in prepare-only mode
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=None,
            )

        # Run and stream output
        log_lines = []
        async for line in client.exec_stream(full_cmd):
            print(line, flush=True)
            log_lines.append(line)

        # Read results
        cat_result = await client.exec(f"cat {output_path}")

        if cat_result.exit_code != 0:
            log_tail = "\n".join(log_lines[-50:])
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Evaluation failed. Log tail:\n{log_tail}",
            )

        # Parse results
        try:
            results_data = json.loads(cat_result.stdout)
        except json.JSONDecodeError as e:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=f"Failed to parse results: {e}",
            )

        # Convert to EvaluateResult
        correct = results_data.get("correct", False)
        speedup = results_data.get("speedup", 0.0) or 0.0
        error = results_data.get("error")

        if error:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=1,
                error_message=error,
            )

        return EvaluateResult(
            success=True,
            all_correct=correct,
            correctness_score=1.0 if correct else 0.0,
            geomean_speedup=speedup,
            passed_tests=1 if correct else 0,
            total_tests=1,
        )

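
# --- Editor's illustrative sketch (not part of wafer-cli) ---------------------
# The parsing above reads exactly three keys out of results.json, so the remote
# eval script is expected to emit at least a payload shaped like the following.
# Values are made up for illustration; the real KERNELBENCH_EVAL_SCRIPT may
# include additional fields.
_EXAMPLE_RESULTS_JSON = {
    "correct": True,   # mapped to all_correct / passed_tests
    "speedup": 1.42,   # mapped to geomean_speedup; None or missing becomes 0.0
    "error": None,     # any non-empty string marks the whole run as failed
}
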
async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
    """Run KernelBench format evaluation on configured target.

    Args:
        args: KernelBench evaluate arguments

    Returns:
        Evaluation result
    """
    from .targets import get_default_target, load_target

    # Validate input files
    err = _validate_kernelbench_files(args)
    if err:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=err,
        )

    # Load target
    target_name = args.target_name
    if not target_name:
        target_name = get_default_target()
        if not target_name:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=(
                    "No target specified and no default set.\n"
                    "Set up a target first:\n"
                    "  wafer config targets init ssh --name my-gpu --host user@host:22\n"
                    "  wafer config targets init runpod --gpu MI300X\n"
                    "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
                ),
            )

    try:
        target = load_target(target_name)
    except FileNotFoundError:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=f"Target not found: {target_name}. Run: wafer config targets list",
        )

    print(f"Using target: {target_name}")

    # Dispatch to appropriate executor
    if isinstance(target, DigitalOceanTarget):
        # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
        return await run_evaluate_kernelbench_digitalocean(args, target)
    elif isinstance(target, RunPodTarget):
        # RunPod AMD MI300X - uses ROCm Docker with device passthrough
        return await run_evaluate_kernelbench_runpod(args, target)
    elif isinstance(target, BaremetalTarget | VMTarget):
        # Check if this is an AMD target (gfx* compute capability) - run directly
        if target.compute_capability and target.compute_capability.startswith("gfx"):
            return await run_evaluate_kernelbench_baremetal_amd(args, target)
        # NVIDIA targets - require docker_image to be set
        if not target.docker_image:
            return EvaluateResult(
                success=False,
                all_correct=False,
                correctness_score=0.0,
                geomean_speedup=0.0,
                passed_tests=0,
                total_tests=0,
                error_message=(
                    f"Target '{target_name}' does not have docker_image set. "
                    "KernelBench format requires Docker execution."
                ),
            )
        return await run_evaluate_kernelbench_docker(args, target)
    else:
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=(
                f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
                "Use a DigitalOcean, RunPod, Baremetal, or VM target."
            ),
        )