wafer-cli 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/evaluate.py ADDED
@@ -0,0 +1,4593 @@
1
+ """Remote kernel evaluation for Wafer CLI.
2
+
3
+ Runs evaluate.py on a remote GPU target with the same interface as local execution.
4
+ """
5
+
6
+ import json
7
+ import logging
8
+ import shlex
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ from wafer_core.utils.kernel_utils.targets.config import (
15
+ BaremetalTarget,
16
+ DigitalOceanTarget,
17
+ LocalTarget,
18
+ ModalTarget,
19
+ RunPodTarget,
20
+ VMTarget,
21
+ WorkspaceTarget,
22
+ )
23
+
24
+ # Map AMD compute capability to ROCm architecture
25
+ # Used to set PYTORCH_ROCM_ARCH for faster compilation (compile only for target arch)
26
+ AMD_CC_TO_ARCH = {
27
+ "9.4": "gfx942", # MI300X
28
+ "9.0a": "gfx90a", # MI200 series
29
+ "9.08": "gfx908", # MI100
30
+ "9.06": "gfx906", # MI50/60
31
+ "10.30": "gfx1030", # RDNA2
32
+ "11.0": "gfx1100", # RDNA3
33
+ }
34
+
35
+
36
+ def _get_rocm_arch(compute_capability: str) -> str | None:
37
+ """Get ROCm architecture string from compute capability.
38
+
39
+ Returns gfx* string for PYTORCH_ROCM_ARCH, or None if not found.
40
+ """
41
+ # Already a gfx string
42
+ if compute_capability.startswith("gfx"):
43
+ return compute_capability
44
+ # Map from numeric CC
45
+ return AMD_CC_TO_ARCH.get(compute_capability)
46
+
47
+
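+ # Illustrative sketch (not part of the evaluation flow): how _get_rocm_arch resolves
+ # either input form. An explicit gfx string passes through; a numeric compute
+ # capability is looked up in AMD_CC_TO_ARCH; anything unknown returns None
+ # (callers decide the fallback).
+ #
+ #   _get_rocm_arch("gfx942")  -> "gfx942"   (already an arch string)
+ #   _get_rocm_arch("9.4")     -> "gfx942"   (MI300X, via AMD_CC_TO_ARCH)
+ #   _get_rocm_arch("8.0")     -> None       (not in the table)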
48
+ def _build_docker_run_command(
49
+ image: str,
50
+ command: str,
51
+ *,
52
+ working_dir: str | None = None,
53
+ env: dict[str, str] | None = None,
54
+ gpus: str = "all",
55
+ volumes: dict[str, str] | None = None,
56
+ cap_add: list[str] | None = None,
57
+ ) -> str:
58
+ """Build a docker run command string for NVIDIA GPUs.
59
+
60
+ Pure function: string in, string out. No side effects.
61
+
62
+ Args:
63
+ image: Docker image name (e.g., "nvcr.io/nvidia/cutlass:4.3-devel")
64
+ command: Command to run inside container
65
+ working_dir: Container working directory (optional)
66
+ env: Environment variables as dict (optional)
67
+ gpus: GPU access string ("all", "device=0", "device=0,1", etc.)
68
+ volumes: Host:container volume mappings (optional)
69
+ cap_add: Linux capabilities to add (e.g., ["SYS_ADMIN"] for NCU profiling)
70
+
71
+ Returns:
72
+ Complete docker run command string
73
+ """
74
+ parts = ["docker", "run", "--rm"]
75
+
76
+ # Add capabilities (needed for NCU profiling)
77
+ if cap_add:
78
+ for cap in cap_add:
79
+ parts.extend(["--cap-add", cap])
80
+
81
+ # GPU access - use single quotes for the device spec to avoid shell escaping issues
82
+ if gpus:
83
+ parts.extend(["--gpus", f"'{gpus}'"])
84
+
85
+ # Volume mounts
86
+ if volumes:
87
+ for host_path, container_path in volumes.items():
88
+ parts.extend(["-v", f"{host_path}:{container_path}"])
89
+
90
+ # Working directory
91
+ if working_dir:
92
+ parts.extend(["-w", working_dir])
93
+
94
+ # Environment variables
95
+ if env:
96
+ for key, value in env.items():
97
+ parts.extend(["-e", f"{key}={shlex.quote(value)}"])
98
+
99
+ # Image and command
100
+ parts.append(image)
101
+ parts.append(f"bash -c {shlex.quote(command)}")
102
+
103
+ return " ".join(parts)
104
+
105
+
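+ # Illustrative sketch of the string the builder above produces; the image tag, paths,
+ # and command are placeholder assumptions. Flags appear in the order they are
+ # appended: capabilities, --gpus, volumes, workdir, env, then image and command.
+ #
+ #   _build_docker_run_command(
+ #       image="nvcr.io/nvidia/cutlass:4.3-devel",
+ #       command="python3 eval.py",
+ #       working_dir="/workspace",
+ #       env={"CUDA_VISIBLE_DEVICES": "0"},
+ #       volumes={"/home/user/.wafer/workspaces/eval": "/workspace"},
+ #   )
+ #   # -> docker run --rm --gpus 'all' -v /home/user/.wafer/workspaces/eval:/workspace
+ #   #    -w /workspace -e CUDA_VISIBLE_DEVICES=0 nvcr.io/nvidia/cutlass:4.3-devel
+ #   #    bash -c 'python3 eval.py'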
106
+ def _build_docker_run_command_amd(
107
+ image: str,
108
+ command: str,
109
+ *,
110
+ working_dir: str | None = None,
111
+ env: dict[str, str] | None = None,
112
+ volumes: dict[str, str] | None = None,
113
+ ) -> str:
114
+ """Build a docker run command string for AMD GPUs (ROCm).
115
+
116
+ Uses device passthrough instead of NVIDIA's --gpus flag.
117
+
118
+ Args:
119
+ image: Docker image name (e.g., "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0")
120
+ command: Command to run inside container
121
+ working_dir: Container working directory (optional)
122
+ env: Environment variables as dict (optional)
123
+ volumes: Host:container volume mappings (optional)
124
+
125
+ Returns:
126
+ Complete docker run command string
127
+ """
128
+ parts = ["docker", "run", "--rm"]
129
+
130
+ # AMD GPU access via device passthrough
131
+ parts.extend(["--device=/dev/kfd", "--device=/dev/dri", "--group-add", "video"])
132
+
133
+ # Volume mounts
134
+ if volumes:
135
+ for host_path, container_path in volumes.items():
136
+ parts.extend(["-v", f"{host_path}:{container_path}"])
137
+
138
+ # Working directory
139
+ if working_dir:
140
+ parts.extend(["-w", working_dir])
141
+
142
+ # Environment variables
143
+ if env:
144
+ for key, value in env.items():
145
+ parts.extend(["-e", f"{key}={shlex.quote(value)}"])
146
+
147
+ # Image and command
148
+ parts.append(image)
149
+ parts.append(f"bash -c {shlex.quote(command)}")
150
+
151
+ return " ".join(parts)
152
+
153
+
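+ # Illustrative sketch for the AMD variant: same shape of output, but GPU access comes
+ # from --device/--group-add passthrough instead of --gpus (the image tag is a placeholder).
+ #
+ #   _build_docker_run_command_amd(image="rocm/pytorch:<tag>", command="python3 eval.py")
+ #   # -> docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video
+ #   #    rocm/pytorch:<tag> bash -c 'python3 eval.py'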
154
+ @dataclass(frozen=True)
155
+ class EvaluateArgs:
156
+ """Arguments for evaluate command.
157
+
158
+ Mirrors evaluate.py's CLI args.
159
+ """
160
+
161
+ implementation: Path
162
+ reference: Path
163
+ test_cases: Path
164
+ target_name: str
165
+ benchmark: bool = False
166
+ profile: bool = False
167
+ defensive: bool = False
168
+ sync_artifacts: bool = True
169
+ gpu_id: int | None = None
170
+
171
+
172
+ @dataclass(frozen=True)
173
+ class KernelBenchEvaluateArgs:
174
+ """Arguments for KernelBench format evaluate command.
175
+
176
+ KernelBench format uses Model/ModelNew classes instead of functions.
177
+ No test_cases file - reference defines get_inputs()/get_init_inputs().
178
+ """
179
+
180
+ implementation: Path # Must define ModelNew class
181
+ reference: Path # Must define Model, get_inputs, get_init_inputs
182
+ target_name: str
183
+ benchmark: bool = False
184
+ profile: bool = False
185
+ inputs: Path | None = None # Custom inputs file to override get_inputs()
186
+ seed: int = 42 # Random seed for reproducibility
187
+ defensive: bool = False
188
+ backend: str | None = None # Kernel backend for static validation
189
+ sync_artifacts: bool = True
190
+ gpu_id: int | None = None
191
+ stages: str = "compile,correctness" # Stages to run: compile, correctness, benchmark, defense
192
+ prepare_only: bool = False # Sync files and generate script but don't run
193
+
194
+
195
+ @dataclass(frozen=True)
196
+ class EvaluateResult:
197
+ """Result from remote evaluation."""
198
+
199
+ success: bool
200
+ all_correct: bool | None # None when correctness wasn't checked (compile-only, prepare-only)
201
+ correctness_score: float
202
+ geomean_speedup: float
203
+ passed_tests: int
204
+ total_tests: int
205
+ error_message: str | None = None
206
+ artifact_path: Path | None = None
+ benchmark_results: dict | None = None  # optional raw benchmark payload (populated by the local runner)
207
+
208
+
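+ # Minimal sketch of wiring the argument dataclass together; the file names and target
+ # name are illustrative assumptions, not values the CLI requires.
+ def _example_evaluate_args() -> EvaluateArgs:
+     return EvaluateArgs(
+         implementation=Path("implementation.py"),
+         reference=Path("reference.py"),
+         test_cases=Path("test_cases.json"),
+         target_name="my-h100-box",
+         benchmark=True,
+     )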
209
+ def _check_python_file_has(path: Path, *names: str) -> list[str]:
210
+ """Check if a Python file exports the given names.
211
+
212
+ Uses AST parsing to find:
213
+ - Function definitions: def name(...)
214
+ - Class definitions: class name(...)
215
+ - Assignments: name = ...
216
+ - Imports: from module import name / from module import x as name
217
+
218
+ Returns:
219
+ List of names that are missing
220
+ """
221
+ import ast
222
+
223
+ content = path.read_text()
224
+ try:
225
+ tree = ast.parse(content)
226
+ except SyntaxError:
227
+ # If we can't parse, let the runtime fail with a better error
228
+ return []
229
+
230
+ defined_names: set[str] = set()
231
+ for node in ast.walk(tree):
232
+ if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
233
+ defined_names.add(node.name)
234
+ elif isinstance(node, ast.ClassDef):
235
+ defined_names.add(node.name)
236
+ elif isinstance(node, ast.Assign):
237
+ for target in node.targets:
238
+ if isinstance(target, ast.Name):
239
+ defined_names.add(target.id)
240
+ elif isinstance(node, ast.ImportFrom):
241
+ for alias in node.names:
242
+ # Use asname if present, otherwise use the original name
243
+ defined_names.add(alias.asname or alias.name)
244
+
245
+ return [name for name in names if name not in defined_names]
246
+
247
+
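+ # Illustrative sketch of what the AST check accepts: each of the following counts as
+ # exporting custom_kernel, so only a file containing none of them is reported missing
+ # (module and helper names below are made up for the example).
+ #
+ #   def custom_kernel(inputs): ...                        # function definition
+ #   custom_kernel = torch.compile(_impl)                  # top-level assignment
+ #   from my_kernels import fast_gemm as custom_kernel     # aliased import
+ #
+ #   _check_python_file_has(path, "custom_kernel")  -> []  (nothing missing)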
248
+ def _validate_files(args: EvaluateArgs) -> str | None:
249
+ """Validate that all input files exist, have correct format, and expected signatures.
250
+
251
+ Returns:
252
+ Error message if validation fails, None if all valid
253
+ """
254
+ if not args.implementation.exists():
255
+ return f"Implementation file not found: {args.implementation}"
256
+ if not args.reference.exists():
257
+ return f"Reference file not found: {args.reference}"
258
+ if not args.test_cases.exists():
259
+ return f"Test cases file not found: {args.test_cases}"
260
+
261
+ # Validate test_cases is valid JSON
262
+ try:
263
+ json.loads(args.test_cases.read_text())
264
+ except json.JSONDecodeError:
265
+ if args.test_cases.suffix == ".py":
266
+ return (
267
+ f"--test-cases must be a JSON file, not a Python file: {args.test_cases}\n"
268
+ "Hint: For KernelBench problems, use 'wafer evaluate kernelbench' instead:\n"
269
+ f" wafer evaluate kernelbench --impl <impl.py> --reference {args.test_cases}"
270
+ )
271
+ return f"--test-cases must be valid JSON: {args.test_cases}"
272
+
273
+ # Validate implementation has custom_kernel
274
+ impl_missing = _check_python_file_has(args.implementation, "custom_kernel")
275
+ if impl_missing:
276
+ # Check if it looks like KernelBench format (has ModelNew)
277
+ has_model_new = not _check_python_file_has(args.implementation, "ModelNew")
278
+ if has_model_new:
279
+ return (
280
+ f"Implementation file missing 'custom_kernel' function: {args.implementation}\n"
281
+ "Hint: This looks like KernelBench format. Use 'wafer evaluate kernelbench' instead:\n"
282
+ f" wafer evaluate kernelbench --impl {args.implementation} --reference <reference.py>"
283
+ )
284
+ return (
285
+ f"Implementation file missing 'custom_kernel' function: {args.implementation}\n"
286
+ " Required: 'def custom_kernel(inputs)' function"
287
+ )
288
+
289
+ # Validate reference has ref_kernel and generate_input
290
+ ref_missing = _check_python_file_has(args.reference, "ref_kernel", "generate_input")
291
+ if ref_missing:
292
+ # Check if it looks like KernelBench format (has Model and get_inputs)
293
+ has_kernelbench = not _check_python_file_has(args.reference, "Model", "get_inputs")
294
+ if has_kernelbench:
295
+ return (
296
+ f"Reference file missing required functions: {', '.join(ref_missing)}\n"
297
+ "Hint: This looks like KernelBench format. Use 'wafer evaluate kernelbench' instead:\n"
298
+ f" wafer evaluate kernelbench --impl <impl.py> --reference {args.reference}"
299
+ )
300
+ return (
301
+ f"Reference file missing required functions: {', '.join(ref_missing)}\n"
302
+ f" File: {args.reference}\n"
303
+ " Required: 'ref_kernel' and 'generate_input' functions"
304
+ )
305
+
306
+ return None
307
+
308
+
309
+ def _select_gpu_id(
310
+ target: BaremetalTarget | VMTarget | ModalTarget, gpu_id_override: int | None
311
+ ) -> int:
312
+ """Select GPU ID to use.
313
+
314
+ Args:
315
+ target: Target config
316
+ gpu_id_override: Optional explicit GPU ID
317
+
318
+ Returns:
319
+ GPU ID to use
320
+ """
321
+ if gpu_id_override is not None:
322
+ return gpu_id_override
323
+
324
+ # Use first GPU from target's list
325
+ if isinstance(target, BaremetalTarget | VMTarget):
326
+ return target.gpu_ids[0]
327
+
328
+ # Modal doesn't have explicit GPU IDs
329
+ return 0
330
+
331
+
332
+ def _build_docker_pip_install_cmd(target: BaremetalTarget | VMTarget) -> str:
333
+ """Build pip install command for Docker container.
334
+
335
+ Installs uv first, then uses uv to install packages (Modal-like approach).
336
+ Uses --system flag to install to container's system Python (not any venv).
337
+
338
+ Handles base CUDA images that may not have pip pre-installed.
339
+
340
+ Args:
341
+ target: Target config with pip_packages, torch_package, torch_index_url
342
+
343
+ Returns:
344
+ Shell command string to install dependencies
345
+ """
346
+ commands = []
347
+
348
+ # Some base images (like nvidia/cuda) don't have pip or git, install them first
349
+ # Use apt for Debian/Ubuntu-based images, with noninteractive to avoid prompts
350
+ commands.append(
351
+ "(which pip > /dev/null 2>&1 && which git > /dev/null 2>&1) || "
352
+ "(apt-get update && "
353
+ "DEBIAN_FRONTEND=noninteractive apt-get install -y python3 python3-pip git > /dev/null)"
354
+ )
355
+
356
+ # Install uv (fast, reliable) - use pip3 for compatibility
357
+ commands.append("pip3 install uv")
358
+
359
+ # Install torch with custom index if specified (like Modal's two-phase install)
360
+ # Use --system --break-system-packages to install to container's Python
361
+ # (needed for Python 3.12+ with PEP 668 externally managed environments)
362
+ if target.torch_package:
363
+ if target.torch_index_url:
364
+ commands.append(
365
+ f"uv pip install --system --break-system-packages --index-url {target.torch_index_url} "
366
+ f"--extra-index-url https://pypi.org/simple {target.torch_package}"
367
+ )
368
+ else:
369
+ commands.append(
370
+ f"uv pip install --system --break-system-packages {target.torch_package}"
371
+ )
372
+
373
+ # Install other packages
374
+ if target.pip_packages:
375
+ packages_str = " ".join(target.pip_packages)
376
+ commands.append(f"uv pip install --system --break-system-packages {packages_str}")
377
+
378
+ return " && ".join(commands)
379
+
380
+
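+ # Illustrative sketch of the generated install command for a target that pins torch to a
+ # CUDA wheel index; the builder returns one "&&"-joined string, shown here across lines
+ # for readability (package list and index URL are example assumptions):
+ #
+ #   (which pip ... && which git ...) || (apt-get update && ... install -y python3 python3-pip git)
+ #   && pip3 install uv
+ #   && uv pip install --system --break-system-packages \
+ #          --index-url https://download.pytorch.org/whl/cu124 \
+ #          --extra-index-url https://pypi.org/simple torch
+ #   && uv pip install --system --break-system-packages numpy ninja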
381
+ def _get_wafer_root() -> Path:
382
+ """Get wafer monorepo root directory.
383
+
384
+ Walks up from this file to find the wafer repo root (contains apps/, packages/).
385
+ """
386
+ current = Path(__file__).resolve()
387
+ for parent in [current] + list(current.parents):
388
+ if (parent / "apps").is_dir() and (parent / "packages").is_dir():
389
+ return parent
390
+ raise RuntimeError(f"Could not find wafer root from {__file__}")
391
+
392
+
393
+ async def run_evaluate_docker(
394
+ args: EvaluateArgs,
395
+ target: BaremetalTarget | VMTarget,
396
+ ) -> EvaluateResult:
397
+ """Run evaluation in Docker container on SSH-based target.
398
+
399
+ Uses async SSH client for true non-blocking I/O.
400
+ Installs torch and wafer-core inside the container with uv, then runs the
401
+ wafer_core evaluate module. Avoids pulling in the rollouts dependency.
402
+
403
+ Args:
404
+ args: Evaluate arguments
405
+ target: SSH target config with docker_image set
406
+
407
+ Returns:
408
+ Evaluation result
409
+ """
410
+ from datetime import datetime
411
+
412
+ from wafer_core.async_ssh import AsyncSSHClient
413
+
414
+ CONTAINER_WORKSPACE = "/workspace"
415
+ REMOTE_WORKSPACE_BASE = "~/.wafer/workspaces"
416
+
417
+ if not target.docker_image:
418
+ raise ValueError("docker_image must be set for Docker execution")
419
+
420
+ # Select GPU
421
+ gpu_id = _select_gpu_id(target, args.gpu_id)
422
+
423
+ print(f"Connecting to {target.ssh_target}...")
424
+
425
+ async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
426
+ print(f"Using Docker image: {target.docker_image}")
427
+ print(f"Using GPU {gpu_id}...")
428
+
429
+ # Read local files
430
+ impl_code = args.implementation.read_text()
431
+ ref_code = args.reference.read_text()
432
+ test_cases_data = json.loads(args.test_cases.read_text())
433
+
434
+ # Create workspace for evaluation files
435
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
436
+ run_dir = f"wafer_eval_{timestamp}"
437
+ eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
438
+ await client.exec(f"mkdir -p {eval_workspace}")
439
+ eval_workspace_expanded = await client.expand_path(eval_workspace)
440
+ run_path = f"{eval_workspace_expanded}/{run_dir}"
441
+
442
+ print("Uploading evaluation files...")
443
+
444
+ # Create run directory
445
+ mkdir_result = await client.exec(f"mkdir -p {run_path}")
446
+ if mkdir_result.exit_code != 0:
447
+ return EvaluateResult(
448
+ success=False,
449
+ all_correct=False,
450
+ correctness_score=0.0,
451
+ geomean_speedup=0.0,
452
+ passed_tests=0,
453
+ total_tests=0,
454
+ error_message=f"Failed to create run directory: {mkdir_result.stderr}",
455
+ )
456
+
457
+ # Write implementation
458
+ impl_path = f"{run_path}/implementation.py"
459
+ write_result = await client.exec(
460
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
461
+ )
462
+ if write_result.exit_code != 0:
463
+ return EvaluateResult(
464
+ success=False,
465
+ all_correct=False,
466
+ correctness_score=0.0,
467
+ geomean_speedup=0.0,
468
+ passed_tests=0,
469
+ total_tests=0,
470
+ error_message=f"Failed to write implementation: {write_result.stderr}",
471
+ )
472
+
473
+ # Write reference
474
+ ref_path = f"{run_path}/reference.py"
475
+ write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
476
+ if write_result.exit_code != 0:
477
+ return EvaluateResult(
478
+ success=False,
479
+ all_correct=False,
480
+ correctness_score=0.0,
481
+ geomean_speedup=0.0,
482
+ passed_tests=0,
483
+ total_tests=0,
484
+ error_message=f"Failed to write reference: {write_result.stderr}",
485
+ )
486
+
487
+ # Also write as reference_kernel.py (evaluate.py imports generate_input from this)
488
+ ref_kernel_path = f"{run_path}/reference_kernel.py"
489
+ write_result = await client.exec(
490
+ f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
491
+ )
492
+ if write_result.exit_code != 0:
493
+ return EvaluateResult(
494
+ success=False,
495
+ all_correct=False,
496
+ correctness_score=0.0,
497
+ geomean_speedup=0.0,
498
+ passed_tests=0,
499
+ total_tests=0,
500
+ error_message=f"Failed to write reference_kernel: {write_result.stderr}",
501
+ )
502
+
503
+ # Write test cases
504
+ test_cases_path = f"{run_path}/test_cases.json"
505
+ test_cases_json = json.dumps(test_cases_data, indent=2)
506
+ write_result = await client.exec(
507
+ f"cat > '{test_cases_path}' << 'TESTS_EOF'\n{test_cases_json}\nTESTS_EOF"
508
+ )
509
+ if write_result.exit_code != 0:
510
+ return EvaluateResult(
511
+ success=False,
512
+ all_correct=False,
513
+ correctness_score=0.0,
514
+ geomean_speedup=0.0,
515
+ passed_tests=0,
516
+ total_tests=0,
517
+ error_message=f"Failed to write test cases: {write_result.stderr}",
518
+ )
519
+
520
+ print("Running evaluation in Docker container...")
521
+
522
+ # Paths inside container (workspace mounted at /workspace)
523
+ container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
524
+ container_impl_path = f"{container_run_path}/implementation.py"
525
+ container_ref_path = f"{container_run_path}/reference.py"
526
+ container_test_cases_path = f"{container_run_path}/test_cases.json"
527
+
528
+ # Build pip install command for torch and other deps, plus wafer-core
529
+ pip_install_cmd = _build_docker_pip_install_cmd(target)
530
+ install_cmd = (
531
+ f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
532
+ )
533
+
534
+ # Build evaluate command using installed wafer-core module
535
+ python_cmd_parts = [
536
+ "python3 -m wafer_core.utils.kernel_utils.evaluate",
537
+ f"--implementation {container_impl_path}",
538
+ f"--reference {container_ref_path}",
539
+ f"--test-cases {container_test_cases_path}",
540
+ f"--run-dir {container_run_path}",
541
+ ]
542
+
543
+ if args.benchmark:
544
+ python_cmd_parts.append("--benchmark")
545
+ if args.profile:
546
+ python_cmd_parts.append("--profile")
547
+ if args.defensive:
548
+ python_cmd_parts.append("--defensive")
549
+
550
+ eval_cmd = " ".join(python_cmd_parts)
551
+
552
+ # Full command: install deps + wafer-core, then run evaluate
553
+ full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
554
+
555
+ # Build Docker run command
556
+ # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
557
+ docker_cmd = _build_docker_run_command(
558
+ image=target.docker_image,
559
+ command=full_cmd,
560
+ working_dir=container_run_path,
561
+ env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
562
+ gpus="all",
563
+ volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
564
+ cap_add=["SYS_ADMIN"] if args.profile else None,
565
+ )
566
+
567
+ print(f"Docker command: {docker_cmd[:100]}...")
568
+
569
+ # Run Docker command and stream output
570
+ log_lines = []
571
+ async for line in client.exec_stream(docker_cmd):
572
+ print(line, flush=True)
573
+ log_lines.append(line)
574
+
575
+ # Read results
576
+ results_path = f"{run_path}/results.json"
577
+ cat_result = await client.exec(f"cat {results_path}")
578
+
579
+ if cat_result.exit_code != 0:
580
+ log_tail = "\n".join(log_lines[-50:])
581
+ return EvaluateResult(
582
+ success=False,
583
+ all_correct=False,
584
+ correctness_score=0.0,
585
+ geomean_speedup=0.0,
586
+ passed_tests=0,
587
+ total_tests=0,
588
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
589
+ )
590
+
591
+ # Parse results
592
+ try:
593
+ results_data = json.loads(cat_result.stdout)
594
+ except json.JSONDecodeError as e:
595
+ return EvaluateResult(
596
+ success=False,
597
+ all_correct=False,
598
+ correctness_score=0.0,
599
+ geomean_speedup=0.0,
600
+ passed_tests=0,
601
+ total_tests=0,
602
+ error_message=f"Failed to parse results: {e}",
603
+ )
604
+
605
+ # Extract backend results
606
+ backends = results_data.get("backends", [])
607
+ if not backends:
608
+ return EvaluateResult(
609
+ success=False,
610
+ all_correct=False,
611
+ correctness_score=0.0,
612
+ geomean_speedup=0.0,
613
+ passed_tests=0,
614
+ total_tests=0,
615
+ error_message="No backend results found",
616
+ )
617
+
618
+ backend = backends[0]
619
+ correctness_tests = backend.get("correctness_tests", [])
620
+ passed = sum(1 for t in correctness_tests if t.get("is_correct", False))
621
+ total = len(correctness_tests)
622
+
623
+ # Sync artifacts if requested
624
+ artifact_path = None
625
+ if args.sync_artifacts:
626
+ local_artifact_dir = Path.cwd() / "wafer_artifacts" / run_dir
627
+ local_artifact_dir.mkdir(parents=True, exist_ok=True)
628
+
629
+ try:
630
+ # Download results.json
631
+ download_result = await client.download_files(
632
+ remote_path=f"{run_path}/results.json",
633
+ local_path=str(local_artifact_dir / "results.json"),
634
+ )
635
+ if download_result.success:
636
+ artifact_path = local_artifact_dir
637
+ print(f"Artifacts saved to: {artifact_path}")
638
+ else:
639
+ print(f"Warning: Failed to sync results.json: {download_result.error_message}")
640
+
641
+ # Download NCU profiles if they exist (from --profile flag)
642
+ # NCU profiles are stored in artifact/ncu/ subdirectory
643
+ ncu_check = await client.exec(f"test -d {run_path}/artifact/ncu")
644
+ if ncu_check.exit_code == 0:
645
+ local_ncu_dir = local_artifact_dir / "ncu"
646
+ local_ncu_dir.mkdir(parents=True, exist_ok=True)
647
+ ncu_result = await client.download_files(
648
+ remote_path=f"{run_path}/artifact/ncu",
649
+ local_path=str(local_ncu_dir),
650
+ recursive=True,
651
+ )
652
+ if ncu_result.success:
653
+ print(f"NCU profiles synced: {ncu_result.files_copied} files")
654
+ else:
655
+ print(f"Warning: Failed to sync NCU profiles: {ncu_result.error_message}")
656
+ except Exception as e:
657
+ print(f"Warning: Failed to sync artifacts: {e}")
658
+
659
+ return EvaluateResult(
660
+ success=True,
661
+ all_correct=backend.get("all_correct", False),
662
+ correctness_score=backend.get("correctness_score", 0.0),
663
+ geomean_speedup=backend.get("geomean_speedup", 0.0),
664
+ passed_tests=passed,
665
+ total_tests=total,
666
+ artifact_path=artifact_path,
667
+ )
668
+
669
+
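+ # Minimal sketch of driving the Docker path directly; the CLI normally resolves the target
+ # from target_name before calling this, so the concrete VMTarget here is an assumption.
+ async def _example_run_docker_eval(args: EvaluateArgs, target: VMTarget) -> None:
+     result = await run_evaluate_docker(args, target)
+     if result.success:
+         print(
+             f"{result.passed_tests}/{result.total_tests} tests passed, "
+             f"geomean speedup {result.geomean_speedup:.2f}x"
+         )
+     else:
+         print(f"Evaluation failed: {result.error_message}")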
670
+ async def run_evaluate_local(
671
+ args: KernelBenchEvaluateArgs,
672
+ target: LocalTarget,
673
+ ) -> EvaluateResult:
674
+ """Run evaluation locally on the current machine.
675
+
676
+ For LocalTarget - no SSH needed, runs directly.
677
+
678
+ Args:
679
+ args: Evaluate arguments
680
+ target: Local target config
681
+
682
+ Returns:
683
+ Evaluation result
684
+ """
685
+ import os
686
+ import subprocess
687
+ import tempfile
688
+ from datetime import datetime
689
+
690
+ # Select GPU
691
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
692
+
693
+ print(f"Running local evaluation on GPU {gpu_id}...")
694
+
695
+ # Create temp directory for eval files
696
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
697
+ with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
698
+ run_path = Path(run_path)
699
+
700
+ # Write implementation
701
+ impl_path = run_path / "implementation.py"
702
+ impl_path.write_text(args.implementation.read_text())
703
+
704
+ # Write reference
705
+ ref_path = run_path / "reference.py"
706
+ ref_path.write_text(args.reference.read_text())
707
+
708
+ # Write custom inputs if provided
709
+ inputs_path = None
710
+ if args.inputs:
711
+ inputs_path = run_path / "custom_inputs.py"
712
+ inputs_path.write_text(args.inputs.read_text())
713
+
714
+ # Write eval script
715
+ eval_script_path = run_path / "kernelbench_eval.py"
716
+ eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)
717
+
718
+ # Write defense module if defensive mode is enabled
719
+ defense_module_path = None
720
+ if args.defensive:
721
+ defense_src = (
722
+ Path(__file__).parent.parent.parent.parent
723
+ / "packages"
724
+ / "wafer-core"
725
+ / "wafer_core"
726
+ / "utils"
727
+ / "kernel_utils"
728
+ / "defense.py"
729
+ )
730
+ if defense_src.exists():
731
+ defense_module_path = run_path / "defense.py"
732
+ defense_module_path.write_text(defense_src.read_text())
733
+ else:
734
+ print(f"Warning: defense.py not found at {defense_src}")
735
+
736
+ # Output file
737
+ output_path = run_path / "results.json"
738
+
739
+ # Build eval command
740
+ cmd_parts = [
741
+ "python3",
742
+ str(eval_script_path),
743
+ "--impl",
744
+ str(impl_path),
745
+ "--reference",
746
+ str(ref_path),
747
+ "--output",
748
+ str(output_path),
749
+ "--seed",
750
+ str(args.seed),
751
+ ]
752
+
753
+ if args.benchmark:
754
+ cmd_parts.append("--benchmark")
755
+ if args.profile:
756
+ cmd_parts.append("--profile")
757
+ if inputs_path:
758
+ cmd_parts.extend(["--inputs", str(inputs_path)])
759
+ if args.defensive and defense_module_path:
760
+ cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])
761
+
762
+ # Set environment for GPU selection
763
+ env = os.environ.copy()
764
+ if target.vendor == "nvidia":
765
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
766
+ else: # AMD
767
+ env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
768
+ env["ROCM_PATH"] = "/opt/rocm"
769
+
770
+ print(f"Running: {' '.join(cmd_parts[:4])} ...")
771
+
772
+ # Run evaluation
773
+ try:
774
+ result = subprocess.run(
775
+ cmd_parts,
776
+ cwd=str(run_path),
777
+ env=env,
778
+ capture_output=True,
779
+ text=True,
780
+ timeout=getattr(args, "timeout", None) or 600,  # args defines no timeout field; default 10 min
781
+ )
782
+ except subprocess.TimeoutExpired:
783
+ return EvaluateResult(
784
+ success=False,
785
+ all_correct=False,
786
+ correctness_score=0.0,
787
+ geomean_speedup=0.0,
788
+ passed_tests=0,
789
+ total_tests=0,
790
+ error_message="Evaluation timed out",
791
+ )
792
+
793
+ if result.returncode != 0:
794
+ error_msg = result.stderr or result.stdout or "Unknown error"
795
+ # Truncate long errors
796
+ if len(error_msg) > 1000:
797
+ error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
798
+ return EvaluateResult(
799
+ success=False,
800
+ all_correct=False,
801
+ correctness_score=0.0,
802
+ geomean_speedup=0.0,
803
+ passed_tests=0,
804
+ total_tests=0,
805
+ error_message=f"Evaluation failed:\n{error_msg}",
806
+ )
807
+
808
+ # Parse results
809
+ if not output_path.exists():
810
+ return EvaluateResult(
811
+ success=False,
812
+ all_correct=False,
813
+ correctness_score=0.0,
814
+ geomean_speedup=0.0,
815
+ passed_tests=0,
816
+ total_tests=0,
817
+ error_message="No results.json produced",
818
+ )
819
+
820
+ try:
821
+ results = json.loads(output_path.read_text())
822
+ except json.JSONDecodeError as e:
823
+ return EvaluateResult(
824
+ success=False,
825
+ all_correct=False,
826
+ correctness_score=0.0,
827
+ geomean_speedup=0.0,
828
+ passed_tests=0,
829
+ total_tests=0,
830
+ error_message=f"Failed to parse results: {e}",
831
+ )
832
+
833
+ # Extract results
834
+ return EvaluateResult(
835
+ success=True,
836
+ all_correct=results.get("all_correct", False),
837
+ correctness_score=results.get("correctness_score", 0.0),
838
+ geomean_speedup=results.get("geomean_speedup", 0.0),
839
+ passed_tests=results.get("passed_tests", 0),
840
+ total_tests=results.get("total_tests", 0),
841
+ benchmark_results=results.get("benchmark", {}),
842
+ )
843
+
844
+
845
+ async def run_evaluate_ssh(
846
+ args: EvaluateArgs,
847
+ target: BaremetalTarget | VMTarget,
848
+ ) -> EvaluateResult:
849
+ """Run evaluation on SSH-based target (Baremetal or VM).
850
+
851
+ Routes to Docker or venv execution based on target.docker_image.
852
+
853
+ If docker_image is set:
854
+ - Uses Docker container with GPU passthrough
855
+ - Installs deps via uv inside container (Modal-like)
856
+
857
+ If docker_image is not set:
858
+ - Uses the existing venv-based deployment infrastructure
859
+
860
+ Args:
861
+ args: Evaluate arguments
862
+ target: SSH target config
863
+
864
+ Returns:
865
+ Evaluation result
866
+ """
867
+ # Route to Docker execution if docker_image is set
868
+ if target.docker_image:
869
+ return await run_evaluate_docker(args, target)
870
+
871
+ # Otherwise, use venv-based execution (existing path)
872
+ from datetime import datetime
873
+
874
+ from wafer_core.remote_jobs import (
875
+ LogStreamConfig,
876
+ start_tmux_session,
877
+ stream_log_until_complete,
878
+ )
879
+ from wafer_core.utils.kernel_utils.deployment import (
880
+ DeploymentConfig,
881
+ setup_deployment,
882
+ )
883
+
884
+ # Select GPU
885
+ gpu_id = _select_gpu_id(target, args.gpu_id)
886
+
887
+ # Create deployment config
888
+ config = DeploymentConfig(
889
+ ssh_target=target.ssh_target,
890
+ ssh_key=target.ssh_key,
891
+ gpu_id=gpu_id,
892
+ )
893
+
894
+ print(f"Connecting to {target.ssh_target}...")
895
+
896
+ # Setup deployment (expensive - deploys monorepo + creates venv)
897
+ state, err = await setup_deployment(config)
898
+ if err:
899
+ return EvaluateResult(
900
+ success=False,
901
+ all_correct=False,
902
+ correctness_score=0.0,
903
+ geomean_speedup=0.0,
904
+ passed_tests=0,
905
+ total_tests=0,
906
+ error_message=f"Deployment setup failed: {err}",
907
+ )
908
+
909
+ assert state is not None
910
+
911
+ print(f"Using GPU {gpu_id}...")
912
+
913
+ # Read local files
914
+ impl_code = args.implementation.read_text()
915
+ ref_code = args.reference.read_text()
916
+ test_cases_data = json.loads(args.test_cases.read_text())
917
+
918
+ # Create a unique run directory within the deployed workspace
919
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
920
+ run_dir = f"wafer_eval_{timestamp}"
921
+
922
+ # workspace_path is the project path (e.g., .../research/async-wevin/benchmarks/gpumode)
923
+ workspace = state.workspace_path
924
+ run_path = f"{workspace}/{run_dir}"
925
+
926
+ # Get SSH client from deployment state
927
+ client = state.ssh_client
928
+
929
+ print("Uploading files...")
930
+
931
+ # Create run directory
932
+ mkdir_result = client.exec(f"mkdir -p {run_path}")
933
+ if mkdir_result.exit_code != 0:
934
+ return EvaluateResult(
935
+ success=False,
936
+ all_correct=False,
937
+ correctness_score=0.0,
938
+ geomean_speedup=0.0,
939
+ passed_tests=0,
940
+ total_tests=0,
941
+ error_message=f"Failed to create run directory: {mkdir_result.stderr}",
942
+ )
943
+
944
+ # Write implementation (must define custom_kernel function)
945
+ impl_path = f"{run_path}/implementation.py"
946
+ write_result = client.exec(f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF")
947
+ if write_result.exit_code != 0:
948
+ return EvaluateResult(
949
+ success=False,
950
+ all_correct=False,
951
+ correctness_score=0.0,
952
+ geomean_speedup=0.0,
953
+ passed_tests=0,
954
+ total_tests=0,
955
+ error_message=f"Failed to write implementation: {write_result.stderr}",
956
+ )
957
+
958
+ # Write reference (must define ref_kernel function)
959
+ ref_path = f"{run_path}/reference.py"
960
+ write_result = client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
961
+ if write_result.exit_code != 0:
962
+ return EvaluateResult(
963
+ success=False,
964
+ all_correct=False,
965
+ correctness_score=0.0,
966
+ geomean_speedup=0.0,
967
+ passed_tests=0,
968
+ total_tests=0,
969
+ error_message=f"Failed to write reference: {write_result.stderr}",
970
+ )
971
+
972
+ # Also write as reference_kernel.py (evaluate.py imports generate_input from this)
973
+ ref_kernel_path = f"{run_path}/reference_kernel.py"
974
+ write_result = client.exec(f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
975
+ if write_result.exit_code != 0:
976
+ return EvaluateResult(
977
+ success=False,
978
+ all_correct=False,
979
+ correctness_score=0.0,
980
+ geomean_speedup=0.0,
981
+ passed_tests=0,
982
+ total_tests=0,
983
+ error_message=f"Failed to write reference_kernel: {write_result.stderr}",
984
+ )
985
+
986
+ # Write test cases
987
+ test_cases_path = f"{run_path}/test_cases.json"
988
+ test_cases_json = json.dumps(test_cases_data, indent=2)
989
+ write_result = client.exec(
990
+ f"cat > '{test_cases_path}' << 'TESTS_EOF'\n{test_cases_json}\nTESTS_EOF"
991
+ )
992
+ if write_result.exit_code != 0:
993
+ return EvaluateResult(
994
+ success=False,
995
+ all_correct=False,
996
+ correctness_score=0.0,
997
+ geomean_speedup=0.0,
998
+ passed_tests=0,
999
+ total_tests=0,
1000
+ error_message=f"Failed to write test cases: {write_result.stderr}",
1001
+ )
1002
+
1003
+ print("Running evaluation...")
1004
+
1005
+ # Build evaluate command
1006
+ # The deployment deploys to research/async-wevin/benchmarks/gpumode
1007
+ # evaluate.py is at research/async-wevin/wafer_utils/kernel_utils/evaluate.py
1008
+ # So we need to go up 2 levels from workspace to find async-wevin root
1009
+ # workspace = .../research/async-wevin/benchmarks/gpumode
1010
+ # async_wevin_root = .../research/async-wevin
1011
+ async_wevin_root = "/".join(workspace.rstrip("/").split("/")[:-2])
1012
+ evaluate_script = f"{async_wevin_root}/wafer_utils/kernel_utils/evaluate.py"
1013
+
1014
+ env_state = state.env_state
1015
+
1016
+ eval_cmd_parts = [
1017
+ f"cd {run_path} &&",
1018
+ f"PATH={env_state.venv_bin}:$PATH",
1019
+ f"{env_state.venv_python} {evaluate_script}",
1020
+ f"--implementation {impl_path}",
1021
+ f"--reference {ref_path}",
1022
+ f"--test-cases {test_cases_path}",
1023
+ f"--run-dir {run_path}",
1024
+ ]
1025
+
1026
+ if args.benchmark:
1027
+ eval_cmd_parts.append("--benchmark")
1028
+ if args.profile:
1029
+ eval_cmd_parts.append("--profile")
1030
+ if args.defensive:
1031
+ eval_cmd_parts.append("--defensive")
1032
+
1033
+ eval_cmd = " ".join(eval_cmd_parts)
1034
+
1035
+ # Run via tmux for streaming output
1036
+ session_name = f"wafer_eval_{datetime.now().strftime('%H%M%S')}"
1037
+ log_file = f"{run_path}/evaluate.log"
1038
+
1039
+ _, err = start_tmux_session(
1040
+ client=client,
1041
+ session_name=session_name,
1042
+ command=eval_cmd,
1043
+ workspace=run_path,
1044
+ log_file=log_file,
1045
+ env_vars={
1046
+ "CUDA_VISIBLE_DEVICES": str(gpu_id),
1047
+ "PYTHONUNBUFFERED": "1",
1048
+ },
1049
+ )
1050
+
1051
+ if err:
1052
+ return EvaluateResult(
1053
+ success=False,
1054
+ all_correct=False,
1055
+ correctness_score=0.0,
1056
+ geomean_speedup=0.0,
1057
+ passed_tests=0,
1058
+ total_tests=0,
1059
+ error_message=f"Failed to start evaluation: {err}",
1060
+ )
1061
+
1062
+ # Stream logs until completion
1063
+ stream_config = LogStreamConfig(
1064
+ session_name=session_name,
1065
+ log_file=log_file,
1066
+ timeout_sec=600, # 10 minutes max
1067
+ poll_interval_sec=2.0,
1068
+ )
1069
+
1070
+ _ = stream_log_until_complete(client=client, config=stream_config)
1071
+
1072
+ # Read results
1073
+ results_path = f"{run_path}/results.json"
1074
+ cat_result = client.exec(f"cat {results_path}")
1075
+
1076
+ if cat_result.exit_code != 0:
1077
+ # Try to get error from log
1078
+ log_result = client.exec(f"tail -50 {log_file}")
1079
+ log_tail = log_result.stdout if log_result.exit_code == 0 else ""
1080
+ return EvaluateResult(
1081
+ success=False,
1082
+ all_correct=False,
1083
+ correctness_score=0.0,
1084
+ geomean_speedup=0.0,
1085
+ passed_tests=0,
1086
+ total_tests=0,
1087
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
1088
+ )
1089
+
1090
+ # Parse results
1091
+ try:
1092
+ results_data = json.loads(cat_result.stdout)
1093
+ except json.JSONDecodeError as e:
1094
+ return EvaluateResult(
1095
+ success=False,
1096
+ all_correct=False,
1097
+ correctness_score=0.0,
1098
+ geomean_speedup=0.0,
1099
+ passed_tests=0,
1100
+ total_tests=0,
1101
+ error_message=f"Failed to parse results: {e}",
1102
+ )
1103
+
1104
+ # Extract backend results
1105
+ # Results format: {"backends": [{"backend_name": ..., "correctness_score": ..., ...}]}
1106
+ backends = results_data.get("backends", [])
1107
+ if not backends:
1108
+ return EvaluateResult(
1109
+ success=False,
1110
+ all_correct=False,
1111
+ correctness_score=0.0,
1112
+ geomean_speedup=0.0,
1113
+ passed_tests=0,
1114
+ total_tests=0,
1115
+ error_message="No backend results found",
1116
+ )
1117
+
1118
+ backend = backends[0]
1119
+ correctness_tests = backend.get("correctness_tests", [])
1120
+ passed = sum(1 for t in correctness_tests if t.get("is_correct", False))
1121
+ total = len(correctness_tests)
1122
+
1123
+ # Sync artifacts if requested
1124
+ artifact_path = None
1125
+ if args.sync_artifacts:
1126
+ local_artifact_dir = Path.cwd() / "wafer_artifacts" / run_dir
1127
+ local_artifact_dir.mkdir(parents=True, exist_ok=True)
1128
+
1129
+ # Download results and logs
1130
+ try:
1131
+ client.download_files(
1132
+ remote_path=f"{run_path}/results.json",
1133
+ local_path=str(local_artifact_dir / "results.json"),
1134
+ )
1135
+ client.download_files(
1136
+ remote_path=log_file,
1137
+ local_path=str(local_artifact_dir / "evaluate.log"),
1138
+ )
1139
+ artifact_path = local_artifact_dir
1140
+ print(f"Artifacts saved to: {artifact_path}")
1141
+ except Exception as e:
1142
+ print(f"Warning: Failed to sync artifacts: {e}")
1143
+
1144
+ return EvaluateResult(
1145
+ success=True,
1146
+ all_correct=backend.get("all_correct", False),
1147
+ correctness_score=backend.get("correctness_score", 0.0),
1148
+ geomean_speedup=backend.get("geomean_speedup", 0.0),
1149
+ passed_tests=passed,
1150
+ total_tests=total,
1151
+ artifact_path=artifact_path,
1152
+ )
1153
+
1154
+
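+ # Routing sketch: the single SSH entry point above picks the execution mode from
+ # target.docker_image alone (the image tag is illustrative):
+ #
+ #   target.docker_image == "nvcr.io/nvidia/pytorch:24.09-py3"  -> run_evaluate_docker(...)
+ #   target.docker_image is None                                -> venv deployment path above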
1155
+ def _build_modal_sandbox_script(
1156
+ target: ModalTarget,
1157
+ impl_code_b64: str,
1158
+ ref_code_b64: str,
1159
+ test_cases_b64: str,
1160
+ run_benchmarks: bool,
1161
+ run_defensive: bool,
1162
+ defense_code_b64: str | None = None,
1163
+ ) -> str:
1164
+ """Build Python script to create sandbox and run evaluation.
1165
+
1166
+ This runs in a subprocess to isolate Modal's asyncio from trio.
1167
+ """
1168
+ gpu_type = target.gpu_type
1169
+
1170
+ # Determine PyTorch index based on GPU type
1171
+ if gpu_type in ("B200", "GB200"):
1172
+ torch_index = "https://download.pytorch.org/whl/nightly/cu128"
1173
+ else:
1174
+ torch_index = "https://download.pytorch.org/whl/cu124"
1175
+
1176
+ return f'''
1177
+ import asyncio
1178
+ import base64
1179
+ import json
1180
+ import sys
1181
+ import modal
1182
+
1183
+ async def run_eval():
1184
+ app = modal.App.lookup("wafer-evaluate", create_if_missing=True)
1185
+
1186
+ # Build image with PyTorch and dependencies
1187
+ image = (
1188
+ modal.Image.from_registry(
1189
+ "nvidia/cuda:12.9.0-devel-ubuntu22.04",
1190
+ add_python="3.12",
1191
+ )
1192
+ .apt_install("git", "build-essential", "cmake")
1193
+ .pip_install(
1194
+ "torch",
1195
+ index_url="{torch_index}",
1196
+ extra_index_url="https://pypi.org/simple",
1197
+ )
1198
+ .pip_install(
1199
+ "numpy",
1200
+ "triton",
1201
+ "ninja",
1202
+ )
1203
+ .env({{
1204
+ "CUDA_HOME": "/usr/local/cuda",
1205
+ }})
1206
+ )
1207
+
1208
+ # Create sandbox
1209
+ sandbox = modal.Sandbox.create(
1210
+ app=app,
1211
+ image=image,
1212
+ gpu="{gpu_type}",
1213
+ timeout={target.timeout_seconds},
1214
+ )
1215
+
1216
+ try:
1217
+ # Decode files
1218
+ impl_code = base64.b64decode("{impl_code_b64}").decode()
1219
+ ref_code = base64.b64decode("{ref_code_b64}").decode()
1220
+ test_cases = base64.b64decode("{test_cases_b64}").decode()
1221
+
1222
+ # Write files to sandbox
1223
+ sandbox.exec("mkdir", "-p", "/workspace").wait()
1224
+
1225
+ # Write implementation
1226
+ proc = sandbox.exec("python", "-c", f"""
1227
+ import base64
1228
+ with open('/workspace/kernel.py', 'w') as f:
1229
+ f.write(base64.b64decode('{impl_code_b64}').decode())
1230
+ with open('/workspace/reference.py', 'w') as f:
1231
+ f.write(base64.b64decode('{ref_code_b64}').decode())
1232
+ with open('/workspace/reference_kernel.py', 'w') as f:
1233
+ f.write(base64.b64decode('{ref_code_b64}').decode())
1234
+ with open('/workspace/test_cases.json', 'w') as f:
1235
+ f.write(base64.b64decode('{test_cases_b64}').decode())
1236
+ print('Files written')
1237
+ """)
1238
+ proc.wait()
1239
+ if proc.returncode != 0:
1240
+ print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
1241
+ return
1242
+
1243
+ # Write defense module if defensive mode is enabled
1244
+ # NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
1245
+ if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
1246
+ proc = sandbox.exec("python", "-c", f"""
1247
+ import base64
1248
+ with open('/workspace/defense.py', 'w') as f:
1249
+ f.write(base64.b64decode('{defense_code_b64}').decode())
1250
+ print('Defense module written')
1251
+ """)
1252
+ proc.wait()
1253
+ if proc.returncode != 0:
1254
+ print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
1255
+ return
1256
+
1257
+ # Build inline evaluation script
1258
+ eval_script = """
1259
+ import json
1260
+ import sys
1261
+ import os
1262
+ import importlib.util
1263
+
1264
+ os.chdir('/workspace')
1265
+ sys.path.insert(0, '/workspace')
1266
+
1267
+ # Load test cases
1268
+ with open('test_cases.json') as f:
1269
+ test_cases = json.load(f)
1270
+
1271
+ # Load kernels
1272
+ def load_fn(path, name):
1273
+ spec = importlib.util.spec_from_file_location("mod", path)
1274
+ mod = importlib.util.module_from_spec(spec)
1275
+ spec.loader.exec_module(mod)
1276
+ return getattr(mod, name)
1277
+
1278
+ custom_kernel = load_fn('kernel.py', 'custom_kernel')
1279
+ ref_kernel = load_fn('reference.py', 'ref_kernel')
1280
+ generate_input = load_fn('reference.py', 'generate_input')
1281
+
1282
+ import torch
1283
+
1284
+ # Load defense module if available and defensive mode is enabled
1285
+ run_defensive = {run_defensive}
1286
+ defense = None
1287
+ if run_defensive:
1288
+ try:
1289
+ defense = load_fn('defense.py', 'run_all_defenses')
1290
+ time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
1291
+ print('[Defense] Defense module loaded')
1292
+
1293
+ # Wrap kernels for defense API compatibility
1294
+ # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
1295
+ # These wrappers repack the unpacked args back into a tuple
1296
+ def _wrap_for_defense(kernel):
1297
+ return lambda *args: kernel(args)
1298
+ custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
1299
+ ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
1300
+ except Exception as e:
1301
+ print(f'[Defense] Warning: Could not load defense module: {{e}}')
1302
+ defense = None
1303
+
1304
+ results = []
1305
+ all_correct = True
1306
+ total_time_ms = 0.0
1307
+ ref_total_time_ms = 0.0
1308
+
1309
+ for tc in test_cases:
1310
+ name = tc.pop('name', 'test')
1311
+ try:
1312
+ inputs = generate_input(**tc)
1313
+
1314
+ # Correctness check - pass inputs as single arg (wafer-core convention)
1315
+ with torch.no_grad():
1316
+ ref_out = ref_kernel(inputs)
1317
+ impl_out = custom_kernel(inputs)
1318
+
1319
+ if isinstance(ref_out, torch.Tensor):
1320
+ correct = torch.allclose(ref_out, impl_out, rtol=1e-3, atol=1e-3)
1321
+ else:
1322
+ correct = ref_out == impl_out
1323
+
1324
+ if not correct:
1325
+ all_correct = False
1326
+
1327
+ # Benchmark if requested
1328
+ impl_time_ms = 0.0
1329
+ ref_time_ms = 0.0
1330
+ if {run_benchmarks}:
1331
+ if run_defensive and defense is not None:
1332
+ # Use full defense suite with wrapped kernels
1333
+ # inputs_list unpacks the tuple so defense can infer dtype/device from tensors
1334
+ inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
1335
+
1336
+ # Run defense checks
1337
+ all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
1338
+ if not all_passed:
1339
+ failed = [name for name, passed, _ in defense_results if not passed]
1340
+ raise ValueError(f"Defense checks failed: {{failed}}")
1341
+
1342
+ # Time with defensive timing (using wrapped kernels)
1343
+ impl_times, _ = time_with_defenses(
1344
+ custom_kernel_for_defense,
1345
+ inputs_list,
1346
+ num_warmup=3,
1347
+ num_trials=10,
1348
+ verbose=False,
1349
+ run_defenses=False,
1350
+ )
1351
+ impl_time_ms = sum(impl_times) / len(impl_times)
1352
+
1353
+ ref_times, _ = time_with_defenses(
1354
+ ref_kernel_for_defense,
1355
+ inputs_list,
1356
+ num_warmup=3,
1357
+ num_trials=10,
1358
+ verbose=False,
1359
+ run_defenses=False,
1360
+ )
1361
+ ref_time_ms = sum(ref_times) / len(ref_times)
1362
+ else:
1363
+ # Standard timing without full defenses
1364
+ # Warmup
1365
+ for _ in range(3):
1366
+ custom_kernel(inputs)
1367
+ torch.cuda.synchronize()
1368
+
1369
+ start = torch.cuda.Event(enable_timing=True)
1370
+ end = torch.cuda.Event(enable_timing=True)
1371
+ start.record()
1372
+ for _ in range(10):
1373
+ custom_kernel(inputs)
1374
+ end.record()
1375
+ torch.cuda.synchronize()
1376
+ impl_time_ms = start.elapsed_time(end) / 10
1377
+
1378
+ # Reference timing
1379
+ for _ in range(3):
1380
+ ref_kernel(inputs)
1381
+ torch.cuda.synchronize()
1382
+ start.record()
1383
+ for _ in range(10):
1384
+ ref_kernel(inputs)
1385
+ end.record()
1386
+ torch.cuda.synchronize()
1387
+ ref_time_ms = start.elapsed_time(end) / 10
1388
+
1389
+ total_time_ms += impl_time_ms
1390
+ ref_total_time_ms += ref_time_ms
1391
+
1392
+ results.append({{
1393
+ 'name': name,
1394
+ 'correct': correct,
1395
+ 'impl_time_ms': impl_time_ms,
1396
+ 'ref_time_ms': ref_time_ms,
1397
+ }})
1398
+
1399
+ except Exception as e:
1400
+ results.append({{'name': name, 'correct': False, 'error': str(e)}})
1401
+ all_correct = False
1402
+
1403
+ # Calculate speedup
1404
+ speedup = 0.0
1405
+ if total_time_ms > 0 and ref_total_time_ms > 0:
1406
+ speedup = ref_total_time_ms / total_time_ms
1407
+
1408
+ passed = sum(1 for r in results if r.get('correct', False))
1409
+ total = len(results)
1410
+
1411
+ print(json.dumps({{
1412
+ 'success': True,
1413
+ 'all_correct': all_correct,
1414
+ 'passed': passed,
1415
+ 'total': total,
1416
+ 'speedup': speedup,
1417
+ 'results': results,
1418
+ }}))
1419
+ """
1420
+
1421
+ # Run evaluation
1422
+ proc = sandbox.exec(
1423
+ "python", "-c", eval_script,
1424
+ timeout={target.timeout_seconds},
1425
+ )
1426
+ proc.wait()
1427
+
1428
+ stdout = proc.stdout.read()
1429
+ stderr = proc.stderr.read()
1430
+
1431
+ if proc.returncode != 0:
1432
+ print(json.dumps({{"error": f"Eval failed: {{stderr or stdout}}"}}))
1433
+ return
1434
+
1435
+ # Forward the result JSON
1436
+ # Find the last JSON line in output
1437
+ for line in reversed(stdout.strip().split("\\n")):
1438
+ if line.startswith("{{"):
1439
+ print(line, flush=True)
1440
+ return
1441
+
1442
+ print(json.dumps({{"error": f"No result JSON in output: {{stdout[:500]}}"}}))
1443
+
1444
+ finally:
1445
+ sandbox.terminate()
1446
+
1447
+ asyncio.run(run_eval())
1448
+ '''
1449
+
1450
+
1451
+ async def run_evaluate_modal(
1452
+ args: EvaluateArgs,
1453
+ target: ModalTarget,
1454
+ ) -> EvaluateResult:
1455
+ """Run evaluation on Modal sandbox.
1456
+
1457
+ Creates a Modal sandbox, uploads files, runs evaluate, and parses results.
1458
+ Uses subprocess to isolate Modal's asyncio from trio.
1459
+
1460
+ Args:
1461
+ args: Evaluate arguments
1462
+ target: Modal target config
1463
+
1464
+ Returns:
1465
+ Evaluation result
1466
+ """
1467
+ import base64
1468
+ import subprocess
1469
+ import sys
1470
+
1471
+ import trio
1472
+
1473
+ print(f"Creating Modal sandbox ({target.gpu_type})...")
1474
+
1475
+ # Encode files as base64
1476
+ impl_code_b64 = base64.b64encode(args.implementation.read_bytes()).decode()
1477
+ ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
1478
+ test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()
1479
+
1480
+ # Encode defense module if defensive mode is enabled
1481
+ defense_code_b64 = None
1482
+ if args.defensive:
1483
+ defense_path = (
1484
+ Path(__file__).parent.parent.parent.parent
1485
+ / "packages"
1486
+ / "wafer-core"
1487
+ / "wafer_core"
1488
+ / "utils"
1489
+ / "kernel_utils"
1490
+ / "defense.py"
1491
+ )
1492
+ if defense_path.exists():
1493
+ defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
1494
+ else:
1495
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
1496
+
1497
+ # Build the script that creates sandbox and runs eval
1498
+ script = _build_modal_sandbox_script(
1499
+ target=target,
1500
+ impl_code_b64=impl_code_b64,
1501
+ ref_code_b64=ref_code_b64,
1502
+ test_cases_b64=test_cases_b64,
1503
+ run_benchmarks=args.benchmark,
1504
+ run_defensive=args.defensive,
1505
+ defense_code_b64=defense_code_b64,
1506
+ )
1507
+
1508
+ def _run_subprocess() -> tuple[str, str, int]:
1509
+ result = subprocess.run(
1510
+ [sys.executable, "-c", script],
1511
+ capture_output=True,
1512
+ text=True,
1513
+ timeout=target.timeout_seconds + 60, # Extra buffer for sandbox creation
1514
+ )
1515
+ return result.stdout, result.stderr, result.returncode
1516
+
1517
+ try:
1518
+ stdout, stderr, returncode = await trio.to_thread.run_sync(_run_subprocess)
1519
+ except subprocess.TimeoutExpired:
1520
+ return EvaluateResult(
1521
+ success=False,
1522
+ all_correct=False,
1523
+ correctness_score=0.0,
1524
+ geomean_speedup=0.0,
1525
+ passed_tests=0,
1526
+ total_tests=0,
1527
+ error_message=f"Modal evaluation timed out after {target.timeout_seconds}s",
1528
+ )
1529
+ except Exception as e:
1530
+ return EvaluateResult(
1531
+ success=False,
1532
+ all_correct=False,
1533
+ correctness_score=0.0,
1534
+ geomean_speedup=0.0,
1535
+ passed_tests=0,
1536
+ total_tests=0,
1537
+ error_message=f"Failed to run Modal sandbox: {e}",
1538
+ )
1539
+
1540
+ if returncode != 0:
1541
+ return EvaluateResult(
1542
+ success=False,
1543
+ all_correct=False,
1544
+ correctness_score=0.0,
1545
+ geomean_speedup=0.0,
1546
+ passed_tests=0,
1547
+ total_tests=0,
1548
+ error_message=f"Modal sandbox failed (exit {returncode}): {stderr or stdout}",
1549
+ )
1550
+
1551
+ # Parse result JSON from stdout
1552
+ result_json = None
1553
+ for line in reversed(stdout.strip().split("\n")):
1554
+ if line.startswith("{"):
1555
+ try:
1556
+ result_json = json.loads(line)
1557
+ break
1558
+ except json.JSONDecodeError:
1559
+ continue
1560
+
1561
+ if result_json is None:
1562
+ return EvaluateResult(
1563
+ success=False,
1564
+ all_correct=False,
1565
+ correctness_score=0.0,
1566
+ geomean_speedup=0.0,
1567
+ passed_tests=0,
1568
+ total_tests=0,
1569
+ error_message=f"No valid JSON result in output: {stdout[:500]}",
1570
+ )
1571
+
1572
+ if "error" in result_json:
1573
+ return EvaluateResult(
1574
+ success=False,
1575
+ all_correct=False,
1576
+ correctness_score=0.0,
1577
+ geomean_speedup=0.0,
1578
+ passed_tests=0,
1579
+ total_tests=0,
1580
+ error_message=result_json["error"],
1581
+ )
1582
+
1583
+ passed = result_json.get("passed", 0)
1584
+ total = result_json.get("total", 0)
1585
+ correctness = passed / total if total > 0 else 0.0
1586
+
1587
+ return EvaluateResult(
1588
+ success=True,
1589
+ all_correct=result_json.get("all_correct", False),
1590
+ correctness_score=correctness,
1591
+ geomean_speedup=result_json.get("speedup", 0.0),
1592
+ passed_tests=passed,
1593
+ total_tests=total,
1594
+ )
1595
+
1596
+
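+ # Minimal sketch of the "last JSON line wins" convention used by the sandbox runner above:
+ # the sandbox may print progress lines before the result, so stdout is scanned from the end.
+ def _example_parse_sandbox_stdout(stdout: str) -> dict | None:
+     for line in reversed(stdout.strip().split("\n")):
+         if line.startswith("{"):
+             try:
+                 return json.loads(line)
+             except json.JSONDecodeError:
+                 continue
+     return None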
1597
+ def _build_workspace_eval_script(
1598
+ impl_code: str,
1599
+ ref_code: str,
1600
+ test_cases_json: str,
1601
+ run_benchmarks: bool,
1602
+ run_defensive: bool = False,
1603
+ defense_code: str | None = None,
1604
+ ) -> str:
1605
+ """Build inline evaluation script for workspace exec.
1606
+
1607
+ Similar to Modal inline eval, but runs via workspace exec.
1608
+ """
1609
+ import base64
1610
+
1611
+ impl_b64 = base64.b64encode(impl_code.encode()).decode()
1612
+ ref_b64 = base64.b64encode(ref_code.encode()).decode()
1613
+ tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
1614
+ defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""
1615
+
1616
+ return f'''
1617
+ import base64
1618
+ import json
1619
+ import sys
1620
+ import os
1621
+ import importlib.util
1622
+
1623
+ # Decode files
1624
+ impl_code = base64.b64decode("{impl_b64}").decode()
1625
+ ref_code = base64.b64decode("{ref_b64}").decode()
1626
+ test_cases = json.loads(base64.b64decode("{tests_b64}").decode())
1627
+
1628
+ # Write to temp files
1629
+ with open("/tmp/kernel.py", "w") as f:
1630
+ f.write(impl_code)
1631
+ with open("/tmp/reference.py", "w") as f:
1632
+ f.write(ref_code)
1633
+
1634
+ # Write defense module if available
1635
+ run_defensive = {run_defensive}
1636
+ defense_b64 = "{defense_b64}"
1637
+ # NOTE: Check defense_b64 is not empty and not the string "None" (from None formatting)
1638
+ if run_defensive and defense_b64 and defense_b64 != "None":
1639
+ defense_code = base64.b64decode(defense_b64).decode()
1640
+ with open("/tmp/defense.py", "w") as f:
1641
+ f.write(defense_code)
1642
+
1643
+ # Load kernels
1644
+ def load_fn(path, name):
1645
+ spec = importlib.util.spec_from_file_location("mod", path)
1646
+ mod = importlib.util.module_from_spec(spec)
1647
+ spec.loader.exec_module(mod)
1648
+ return getattr(mod, name)
1649
+
1650
+ custom_kernel = load_fn("/tmp/kernel.py", "custom_kernel")
1651
+ ref_kernel = load_fn("/tmp/reference.py", "ref_kernel")
1652
+ generate_input = load_fn("/tmp/reference.py", "generate_input")
1653
+
1654
+ import torch
1655
+
1656
+ # Load defense module if available
1657
+ defense = None
1658
+ if run_defensive and defense_b64 and defense_b64 != "None":
1659
+ try:
1660
+ defense = load_fn("/tmp/defense.py", "run_all_defenses")
1661
+ time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
1662
+ print("[Defense] Defense module loaded")
1663
+
1664
+ # Wrap kernels for defense API compatibility
1665
+ # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
1666
+ def _wrap_for_defense(kernel):
1667
+ return lambda *args: kernel(args)
1668
+ custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
1669
+ ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
1670
+ except Exception as e:
1671
+ print(f"[Defense] Warning: Could not load defense module: {{e}}")
1672
+ defense = None
1673
+
1674
+ results = []
1675
+ all_correct = True
1676
+ total_time_ms = 0.0
1677
+ ref_total_time_ms = 0.0
1678
+
1679
+ for tc in test_cases:
1680
+ name = tc.pop("name", "test")
1681
+ try:
1682
+ inputs = generate_input(**tc)
1683
+
1684
+ # Correctness check - pass inputs as single arg (wafer-core convention)
1685
+ with torch.no_grad():
1686
+ ref_out = ref_kernel(inputs)
1687
+ impl_out = custom_kernel(inputs)
1688
+
1689
+ if isinstance(ref_out, torch.Tensor):
1690
+ correct = torch.allclose(ref_out, impl_out, rtol=1e-3, atol=1e-3)
1691
+ else:
1692
+ correct = ref_out == impl_out
1693
+
1694
+ if not correct:
1695
+ all_correct = False
1696
+
1697
+ # Benchmark if requested
1698
+ impl_time_ms = 0.0
1699
+ ref_time_ms = 0.0
1700
+ if {run_benchmarks}:
1701
+ if run_defensive and defense is not None:
1702
+ # Use full defense suite with wrapped kernels
1703
+ inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
1704
+
1705
+ # Run defense checks
1706
+ all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
1707
+ if not all_passed:
1708
+ failed = [name for name, passed, _ in defense_results if not passed]
1709
+ raise ValueError(f"Defense checks failed: {{failed}}")
1710
+
1711
+ # Time with defensive timing (using wrapped kernels)
1712
+ impl_times, _ = time_with_defenses(
1713
+ custom_kernel_for_defense,
1714
+ inputs_list,
1715
+ num_warmup=3,
1716
+ num_trials=10,
1717
+ verbose=False,
1718
+ run_defenses=False,
1719
+ )
1720
+ impl_time_ms = sum(impl_times) / len(impl_times)
1721
+
1722
+ ref_times, _ = time_with_defenses(
1723
+ ref_kernel_for_defense,
1724
+ inputs_list,
1725
+ num_warmup=3,
1726
+ num_trials=10,
1727
+ verbose=False,
1728
+ run_defenses=False,
1729
+ )
1730
+ ref_time_ms = sum(ref_times) / len(ref_times)
1731
+ else:
1732
+ # Standard timing
1733
+ for _ in range(3):
1734
+ custom_kernel(inputs)
1735
+ torch.cuda.synchronize()
1736
+
1737
+ start = torch.cuda.Event(enable_timing=True)
1738
+ end = torch.cuda.Event(enable_timing=True)
1739
+ start.record()
1740
+ for _ in range(10):
1741
+ custom_kernel(inputs)
1742
+ end.record()
1743
+ torch.cuda.synchronize()
1744
+ impl_time_ms = start.elapsed_time(end) / 10
1745
+
1746
+ for _ in range(3):
1747
+ ref_kernel(inputs)
1748
+ torch.cuda.synchronize()
1749
+ start.record()
1750
+ for _ in range(10):
1751
+ ref_kernel(inputs)
1752
+ end.record()
1753
+ torch.cuda.synchronize()
1754
+ ref_time_ms = start.elapsed_time(end) / 10
1755
+
1756
+ total_time_ms += impl_time_ms
1757
+ ref_total_time_ms += ref_time_ms
1758
+
1759
+ results.append({{
1760
+ "name": name,
1761
+ "correct": correct,
1762
+ "impl_time_ms": impl_time_ms,
1763
+ "ref_time_ms": ref_time_ms,
1764
+ }})
1765
+
1766
+ except Exception as e:
1767
+ results.append({{"name": name, "correct": False, "error": str(e)}})
1768
+ all_correct = False
1769
+
1770
+ # Calculate speedup
1771
+ speedup = 0.0
1772
+ if total_time_ms > 0 and ref_total_time_ms > 0:
1773
+ speedup = ref_total_time_ms / total_time_ms
1774
+
1775
+ passed = sum(1 for r in results if r.get("correct", False))
1776
+ total = len(results)
1777
+
1778
+ print(json.dumps({{
1779
+ "success": True,
1780
+ "all_correct": all_correct,
1781
+ "passed": passed,
1782
+ "total": total,
1783
+ "speedup": speedup,
1784
+ "results": results,
1785
+ }}))
1786
+ '''
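For orientation, the generated script above assumes the functional kernel format used by wafer-core: generate_input(**params) builds the input object from a test-case dict (whose optional "name" key is popped first), and both ref_kernel and custom_kernel receive that object as a single argument. A minimal illustrative pair of files (a sketch, not shipped in this package) could look like:

    # reference.py -- illustrative sketch only
    import torch

    def generate_input(m: int = 1024, n: int = 1024):
        # Each entry of test_cases.json becomes keyword arguments here,
        # e.g. {"name": "small", "m": 256, "n": 256}.
        return (torch.randn(m, n, device="cuda"), torch.randn(m, n, device="cuda"))

    def ref_kernel(inputs):
        a, b = inputs
        return a + b

    # kernel.py -- illustrative sketch only
    def custom_kernel(inputs):
        a, b = inputs
        return a + b  # a real submission would supply an optimized implementation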
1787
+
1788
+
1789
+ async def run_evaluate_workspace(
1790
+ args: EvaluateArgs,
1791
+ target: WorkspaceTarget,
1792
+ ) -> EvaluateResult:
1793
+ """Run evaluation on wafer-api managed workspace.
1794
+
1795
+ Uses inline evaluation (no file sync needed) via workspace exec.
1796
+ The eval script is passed as a Python command with base64-encoded files.
1797
+
1798
+ Args:
1799
+ args: Evaluate arguments
1800
+ target: Workspace target config
1801
+
1802
+ Returns:
1803
+ Evaluation result
1804
+ """
1805
+ import trio
1806
+
1807
+ from .workspaces import exec_command
1808
+
1809
+ print(f"Using workspace: {target.workspace_id}")
1810
+
1811
+ # Read files
1812
+ impl_code = args.implementation.read_text()
1813
+ ref_code = args.reference.read_text()
1814
+ test_cases_json = args.test_cases.read_text()
1815
+
1816
+ # Read defense module if defensive mode is enabled
1817
+ defense_code = None
1818
+ if args.defensive:
1819
+ defense_path = (
1820
+ Path(__file__).parent.parent.parent.parent
1821
+ / "packages"
1822
+ / "wafer-core"
1823
+ / "wafer_core"
1824
+ / "utils"
1825
+ / "kernel_utils"
1826
+ / "defense.py"
1827
+ )
1828
+ if defense_path.exists():
1829
+ defense_code = defense_path.read_text()
1830
+ else:
1831
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
1832
+
1833
+ # Build inline eval script
1834
+ eval_script = _build_workspace_eval_script(
1835
+ impl_code=impl_code,
1836
+ ref_code=ref_code,
1837
+ test_cases_json=test_cases_json,
1838
+ run_benchmarks=args.benchmark,
1839
+ run_defensive=args.defensive,
1840
+ defense_code=defense_code,
1841
+ )
1842
+
1843
+ # Execute via workspace exec
1844
+ # Use python -c with the script
1845
+ eval_cmd = f"python -c {shlex.quote(eval_script)}"
1846
+
1847
+ print("Running evaluation...")
1848
+
1849
+ # Capture stdout by redirecting exec output
1850
+ # exec_command prints to stdout, so we capture it here and restore stdout afterwards
1851
+ import io
1852
+ import sys
1853
+
1854
+ captured_output = io.StringIO()
1855
+ original_stdout = sys.stdout
1856
+
1857
+ def _exec() -> int:
1858
+ # Temporarily redirect stdout to capture output
1859
+ sys.stdout = captured_output
1860
+ try:
1861
+ return exec_command(
1862
+ workspace_id=target.workspace_id,
1863
+ command=eval_cmd,
1864
+ timeout_seconds=target.timeout_seconds,
1865
+ )
1866
+ finally:
1867
+ sys.stdout = original_stdout
1868
+
1869
+ try:
1870
+ exit_code = await trio.to_thread.run_sync(_exec)
1871
+ except Exception as e:
1872
+ sys.stdout = original_stdout
1873
+ return EvaluateResult(
1874
+ success=False,
1875
+ all_correct=False,
1876
+ correctness_score=0.0,
1877
+ geomean_speedup=0.0,
1878
+ passed_tests=0,
1879
+ total_tests=0,
1880
+ error_message=f"Execution failed: {e}",
1881
+ )
1882
+
1883
+ # Parse output
1884
+ output = captured_output.getvalue()
1885
+ print(output) # Show output to user
1886
+
1887
+ # Find JSON result in output
1888
+ result_json = None
1889
+ for line in reversed(output.strip().split("\n")):
1890
+ if line.startswith("{"):
1891
+ try:
1892
+ result_json = json.loads(line)
1893
+ break
1894
+ except json.JSONDecodeError:
1895
+ continue
1896
+
1897
+ if result_json is None:
1898
+ if exit_code == 0:
1899
+ return EvaluateResult(
1900
+ success=True,
1901
+ all_correct=True,
1902
+ correctness_score=1.0,
1903
+ geomean_speedup=0.0,
1904
+ passed_tests=0,
1905
+ total_tests=0,
1906
+ )
1907
+ else:
1908
+ return EvaluateResult(
1909
+ success=False,
1910
+ all_correct=False,
1911
+ correctness_score=0.0,
1912
+ geomean_speedup=0.0,
1913
+ passed_tests=0,
1914
+ total_tests=0,
1915
+ error_message=f"Evaluation failed with exit code {exit_code}",
1916
+ )
1917
+
1918
+ if "error" in result_json:
1919
+ return EvaluateResult(
1920
+ success=False,
1921
+ all_correct=False,
1922
+ correctness_score=0.0,
1923
+ geomean_speedup=0.0,
1924
+ passed_tests=0,
1925
+ total_tests=0,
1926
+ error_message=result_json["error"],
1927
+ )
1928
+
1929
+ passed = result_json.get("passed", 0)
1930
+ total = result_json.get("total", 0)
1931
+ correctness = passed / total if total > 0 else 0.0
1932
+
1933
+ return EvaluateResult(
1934
+ success=True,
1935
+ all_correct=result_json.get("all_correct", False),
1936
+ correctness_score=correctness,
1937
+ geomean_speedup=result_json.get("speedup", 0.0),
1938
+ passed_tests=passed,
1939
+ total_tests=total,
1940
+ )
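The result parsing above (and in the RunPod path below) relies on a simple convention: the evaluation prints a single-line JSON object last, so the parser scans the captured output backwards for the first line that parses as JSON. As a standalone sketch (hypothetical helper, not defined in this module):

    import json

    def _last_json_line(output: str) -> dict | None:
        # Return the last line of `output` that parses as a JSON object, or None.
        for line in reversed(output.strip().split("\n")):
            if line.startswith("{"):
                try:
                    return json.loads(line)
                except json.JSONDecodeError:
                    continue
        return None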
1941
+
1942
+
1943
+ async def run_evaluate_runpod(
1944
+ args: EvaluateArgs,
1945
+ target: RunPodTarget,
1946
+ ) -> EvaluateResult:
1947
+ """Run evaluation on RunPod target.
1948
+
1949
+ Provisions a RunPod pod (or reuses an existing one), runs evaluation via SSH,
1950
+ then cleans up based on the keep_alive setting.
1951
+
1952
+ Sets up a Python venv with ROCm torch using uv before running the evaluation.
1953
+
1954
+ Args:
1955
+ args: Evaluate arguments
1956
+ target: RunPod target config
1957
+
1958
+ Returns:
1959
+ Evaluation result
1960
+ """
1961
+ from datetime import datetime
1962
+
1963
+ from wafer_core.async_ssh import AsyncSSHClient
1964
+ from wafer_core.remote_env import async_setup_python_env
1965
+ from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
1966
+
1967
+ REMOTE_WORKSPACE = "/tmp/wafer_eval"
1968
+ ROCM_TORCH_INDEX_URL = "https://download.pytorch.org/whl/rocm6.2"
1969
+ ROCM_TORCH_VERSION_SUFFIX = "+rocm6.2"
1970
+
1971
+ print(f"Provisioning RunPod ({target.gpu_type_id})...")
1972
+
1973
+ try:
1974
+ async with runpod_ssh_context(target) as ssh_info:
1975
+ ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
1976
+ print(f"Connected to RunPod: {ssh_target}")
1977
+
1978
+ async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
1979
+ # Ensure rsync is installed (needed for file uploads)
1980
+ print("Checking rsync...")
1981
+ result = await client.exec("which rsync || echo 'NOT_FOUND'")
1982
+ if "NOT_FOUND" in result.stdout:
1983
+ print("Installing rsync...")
1984
+ await client.exec("apt-get update && apt-get install -y rsync")
1985
+
1986
+ # Setup Python environment with ROCm torch
1987
+ # Match wafer-core dependencies needed for evaluate.py
1988
+ print("Setting up Python environment with ROCm torch...")
1989
+ requirements = [
1990
+ f"torch==2.5.1{ROCM_TORCH_VERSION_SUFFIX}",
1991
+ "numpy",
1992
+ "ninja",
1993
+ "setuptools",
1994
+ # wafer_core dependencies
1995
+ "trio",
1996
+ "httpx",
1997
+ "pydantic",
1998
+ "anyio",
1999
+ "pyyaml",
2000
+ ]
2001
+
2002
+ try:
2003
+ env_state = await async_setup_python_env(
2004
+ client=client,
2005
+ workspace=REMOTE_WORKSPACE,
2006
+ requirements=requirements,
2007
+ python_version=">=3.10",
2008
+ venv_path=".venv",
2009
+ index_url=ROCM_TORCH_INDEX_URL,
2010
+ )
2011
+ python_exe = env_state.venv_python
2012
+ print(f"Using Python: {python_exe}")
2013
+ except Exception as e:
2014
+ return EvaluateResult(
2015
+ success=False,
2016
+ all_correct=False,
2017
+ correctness_score=0.0,
2018
+ geomean_speedup=0.0,
2019
+ passed_tests=0,
2020
+ total_tests=0,
2021
+ error_message=f"Failed to setup Python environment: {e}",
2022
+ )
2023
+
2024
+ # Upload wafer-core to remote
2025
+ try:
2026
+ wafer_root = _get_wafer_root()
2027
+ wafer_core_path = wafer_root / "packages" / "wafer-core"
2028
+ print(f"Uploading wafer-core from {wafer_core_path}...")
2029
+
2030
+ wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
2031
+ await client.exec(f"mkdir -p {wafer_core_remote}")
2032
+ wafer_core_workspace = await client.expand_path(wafer_core_remote)
2033
+
2034
+ upload_result = await client.upload_files(
2035
+ str(wafer_core_path), wafer_core_workspace, recursive=True
2036
+ )
2037
+
2038
+ # Wide event logging for upload result
2039
+ upload_event = {
2040
+ "event": "wafer_core_upload",
2041
+ "target": target.name,
2042
+ "target_type": "runpod",
2043
+ "ssh_host": f"{client.user}@{client.host}:{client.port}",
2044
+ "local_path": str(wafer_core_path),
2045
+ "remote_path": wafer_core_workspace,
2046
+ "success": upload_result.success,
2047
+ "files_copied": upload_result.files_copied,
2048
+ "duration_seconds": upload_result.duration_seconds,
2049
+ "error_message": upload_result.error_message,
2050
+ }
2051
+ if upload_result.debug_info:
2052
+ upload_event["debug_info"] = upload_result.debug_info
2053
+ logger.info(json.dumps(upload_event))
2054
+
2055
+ # Fail fast if upload failed
2056
+ if not upload_result.success:
2057
+ print(f"ERROR: Upload failed: {upload_result.error_message}")
2058
+ if upload_result.debug_info:
2059
+ print(f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}")
2060
+ return EvaluateResult(
2061
+ success=False,
2062
+ all_correct=False,
2063
+ correctness_score=0.0,
2064
+ geomean_speedup=0.0,
2065
+ passed_tests=0,
2066
+ total_tests=0,
2067
+ error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
2068
+ )
2069
+
2070
+ print(f"Uploaded {upload_result.files_copied} files")
2071
+ except Exception as e:
2072
+ return EvaluateResult(
2073
+ success=False,
2074
+ all_correct=False,
2075
+ correctness_score=0.0,
2076
+ geomean_speedup=0.0,
2077
+ passed_tests=0,
2078
+ total_tests=0,
2079
+ error_message=f"Failed to upload wafer-core: {e}",
2080
+ )
2081
+
2082
+ # Select GPU (RunPod pods typically have GPU 0)
2083
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
2084
+ print(f"Using GPU {gpu_id}...")
2085
+
2086
+ # Read local files
2087
+ impl_code = args.implementation.read_text()
2088
+ ref_code = args.reference.read_text()
2089
+ test_cases_data = json.loads(args.test_cases.read_text())
2090
+
2091
+ # Create a unique run directory (uuid for concurrent eval isolation)
2092
+ import uuid
2093
+
2094
+ unique_id = uuid.uuid4().hex[:8]
2095
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
2096
+ run_dir = f"wafer_eval_{timestamp}_{unique_id}"
2097
+ run_path = f"{REMOTE_WORKSPACE}/{run_dir}"
2098
+
2099
+ print("Uploading evaluation files...")
2100
+
2101
+ # Create run directory
2102
+ mkdir_result = await client.exec(f"mkdir -p {run_path}")
2103
+ if mkdir_result.exit_code != 0:
2104
+ return EvaluateResult(
2105
+ success=False,
2106
+ all_correct=False,
2107
+ correctness_score=0.0,
2108
+ geomean_speedup=0.0,
2109
+ passed_tests=0,
2110
+ total_tests=0,
2111
+ error_message=f"Failed to create run directory: {mkdir_result.stderr}",
2112
+ )
2113
+
2114
+ # Write implementation
2115
+ impl_path = f"{run_path}/implementation.py"
2116
+ write_result = await client.exec(
2117
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
2118
+ )
2119
+ if write_result.exit_code != 0:
2120
+ return EvaluateResult(
2121
+ success=False,
2122
+ all_correct=False,
2123
+ correctness_score=0.0,
2124
+ geomean_speedup=0.0,
2125
+ passed_tests=0,
2126
+ total_tests=0,
2127
+ error_message=f"Failed to write implementation: {write_result.stderr}",
2128
+ )
2129
+
2130
+ # Write reference
2131
+ ref_path = f"{run_path}/reference.py"
2132
+ write_result = await client.exec(
2133
+ f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
2134
+ )
2135
+ if write_result.exit_code != 0:
2136
+ return EvaluateResult(
2137
+ success=False,
2138
+ all_correct=False,
2139
+ correctness_score=0.0,
2140
+ geomean_speedup=0.0,
2141
+ passed_tests=0,
2142
+ total_tests=0,
2143
+ error_message=f"Failed to write reference: {write_result.stderr}",
2144
+ )
2145
+
2146
+ # Also write as reference_kernel.py (evaluate.py imports generate_input from this)
2147
+ ref_kernel_path = f"{run_path}/reference_kernel.py"
2148
+ write_result = await client.exec(
2149
+ f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
2150
+ )
2151
+ if write_result.exit_code != 0:
2152
+ return EvaluateResult(
2153
+ success=False,
2154
+ all_correct=False,
2155
+ correctness_score=0.0,
2156
+ geomean_speedup=0.0,
2157
+ passed_tests=0,
2158
+ total_tests=0,
2159
+ error_message=f"Failed to write reference_kernel: {write_result.stderr}",
2160
+ )
2161
+
2162
+ # Write test cases as JSON
2163
+ test_cases_path = f"{run_path}/test_cases.json"
2164
+ test_cases_json = json.dumps(test_cases_data)
2165
+ write_result = await client.exec(
2166
+ f"cat > '{test_cases_path}' << 'TEST_EOF'\n{test_cases_json}\nTEST_EOF"
2167
+ )
2168
+ if write_result.exit_code != 0:
2169
+ return EvaluateResult(
2170
+ success=False,
2171
+ all_correct=False,
2172
+ correctness_score=0.0,
2173
+ geomean_speedup=0.0,
2174
+ passed_tests=0,
2175
+ total_tests=0,
2176
+ error_message=f"Failed to write test cases: {write_result.stderr}",
2177
+ )
2178
+
2179
+ print("Running evaluation...")
2180
+
2181
+ # Build evaluation command
2182
+ # RunPod ROCm images use HIP_VISIBLE_DEVICES for AMD GPUs
2183
+ # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
2184
+ venv_bin = env_state.venv_bin
2185
+ env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
2186
+
2187
+ # Run from run_path so reference_kernel.py is importable
2188
+ # Use installed wafer-core module
2189
+ eval_cmd = (
2190
+ f"cd {run_path} && "
2191
+ f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
2192
+ f"--implementation {impl_path} "
2193
+ f"--reference {ref_path} "
2194
+ f"--test-cases {test_cases_path} "
2195
+ f"--run-dir {run_path}"
2196
+ )
2197
+
2198
+ if args.benchmark:
2199
+ eval_cmd += " --benchmark"
2200
+ if args.defensive:
2201
+ eval_cmd += " --defensive"
2202
+
2203
+ # Run with timeout
2204
+ import trio
2205
+
2206
+ with trio.move_on_after(target.eval_timeout) as cancel_scope:
2207
+ result = await client.exec(eval_cmd)
2208
+
2209
+ if cancel_scope.cancelled_caught:
2210
+ return EvaluateResult(
2211
+ success=False,
2212
+ all_correct=False,
2213
+ correctness_score=0.0,
2214
+ geomean_speedup=0.0,
2215
+ passed_tests=0,
2216
+ total_tests=0,
2217
+ error_message=f"Evaluation timed out after {target.eval_timeout}s",
2218
+ )
2219
+
2220
+ # Parse output
2221
+ stdout = result.stdout
2222
+ stderr = result.stderr
2223
+
2224
+ if result.exit_code != 0:
2225
+ return EvaluateResult(
2226
+ success=False,
2227
+ all_correct=False,
2228
+ correctness_score=0.0,
2229
+ geomean_speedup=0.0,
2230
+ passed_tests=0,
2231
+ total_tests=0,
2232
+ error_message=f"Evaluation failed:\nstdout: {stdout}\nstderr: {stderr}",
2233
+ )
2234
+
2235
+ # Find JSON result in output
2236
+ result_json = None
2237
+ for line in reversed(stdout.strip().split("\n")):
2238
+ if line.startswith("{"):
2239
+ try:
2240
+ result_json = json.loads(line)
2241
+ break
2242
+ except json.JSONDecodeError:
2243
+ continue
2244
+
2245
+ if result_json is None:
2246
+ return EvaluateResult(
2247
+ success=False,
2248
+ all_correct=False,
2249
+ correctness_score=0.0,
2250
+ geomean_speedup=0.0,
2251
+ passed_tests=0,
2252
+ total_tests=0,
2253
+ error_message=f"No JSON result in output:\n{stdout}",
2254
+ )
2255
+
2256
+ if "error" in result_json:
2257
+ return EvaluateResult(
2258
+ success=False,
2259
+ all_correct=False,
2260
+ correctness_score=0.0,
2261
+ geomean_speedup=0.0,
2262
+ passed_tests=0,
2263
+ total_tests=0,
2264
+ error_message=result_json["error"],
2265
+ )
2266
+
2267
+ passed = result_json.get("passed", 0)
2268
+ total = result_json.get("total", 0)
2269
+ correctness = passed / total if total > 0 else 0.0
2270
+
2271
+ return EvaluateResult(
2272
+ success=True,
2273
+ all_correct=result_json.get("all_correct", False),
2274
+ correctness_score=correctness,
2275
+ geomean_speedup=result_json.get("speedup", 0.0),
2276
+ passed_tests=passed,
2277
+ total_tests=total,
2278
+ )
2279
+
2280
+ except RunPodError as e:
2281
+ return EvaluateResult(
2282
+ success=False,
2283
+ all_correct=False,
2284
+ correctness_score=0.0,
2285
+ geomean_speedup=0.0,
2286
+ passed_tests=0,
2287
+ total_tests=0,
2288
+ error_message=f"RunPod error: {e}",
2289
+ )
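Every error branch in the RunPod and DigitalOcean paths builds the same zeroed EvaluateResult, with only the message varying; a helper along these lines (hypothetical, not part of the package) captures that pattern:

    def _failure(message: str) -> EvaluateResult:
        # Zeroed-out result used by the error paths; only the message differs.
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=0,
            error_message=message,
        )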
2290
+
2291
+
2292
+ async def run_evaluate_digitalocean(
2293
+ args: EvaluateArgs,
2294
+ target: DigitalOceanTarget,
2295
+ ) -> EvaluateResult:
2296
+ """Run evaluation on DigitalOcean target.
2297
+
2298
+ Provisions a DigitalOcean droplet (or reuses an existing one), bootstraps the
2298
+ Python environment with uv, runs evaluation via SSH, then cleans up based on
2299
+ the keep_alive setting.
2301
+
2302
+ Args:
2303
+ args: Evaluate arguments
2304
+ target: DigitalOcean target config
2305
+
2306
+ Returns:
2307
+ Evaluation result
2308
+ """
2309
+ from datetime import datetime
2310
+
2311
+ import trio_asyncio
2312
+ from wafer_core.async_ssh import AsyncSSHClient
2313
+ from wafer_core.remote_env import async_setup_python_env
2314
+ from wafer_core.targets.digitalocean import DigitalOceanError, digitalocean_ssh_context
2315
+
2316
+ REMOTE_WORKSPACE = "/tmp/wafer_eval"
2317
+ ROCM_TORCH_INDEX_URL = "https://download.pytorch.org/whl/rocm6.2"
2318
+ ROCM_TORCH_VERSION_SUFFIX = "+rocm6.2"
2319
+
2320
+ print(f"Provisioning DigitalOcean droplet ({target.size_slug})...")
2321
+
2322
+ try:
2323
+ async with digitalocean_ssh_context(target) as ssh_info:
2324
+ ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
2325
+ print(f"Connected to DigitalOcean: {ssh_target}")
2326
+
2327
+ # Need trio_asyncio for AsyncSSHClient
2328
+ async with trio_asyncio.open_loop():
2329
+ async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
2330
+ # Ensure rsync and ninja are installed
2331
+ # ninja is needed for torch.utils.cpp_extension (HIP kernel compilation)
2332
+ print("Checking system dependencies...")
2333
+ result = await client.exec("which rsync && which ninja || echo 'MISSING'")
2334
+ if "MISSING" in result.stdout:
2335
+ print("Installing rsync and ninja...")
2336
+ await client.exec("apt-get update && apt-get install -y rsync ninja-build")
2337
+
2338
+ # Setup Python environment with ROCm torch
2339
+ # Match wafer-core dependencies needed for evaluate.py
2340
+ print("Setting up Python environment with ROCm torch...")
2341
+ requirements = [
2342
+ f"torch==2.5.1{ROCM_TORCH_VERSION_SUFFIX}",
2343
+ "numpy",
2344
+ "ninja",
2345
+ "setuptools",
2346
+ # wafer_core dependencies
2347
+ "trio",
2348
+ "httpx",
2349
+ "pydantic",
2350
+ "anyio",
2351
+ "pyyaml",
2352
+ ]
2353
+
2354
+ try:
2355
+ env_state = await async_setup_python_env(
2356
+ client=client,
2357
+ workspace=REMOTE_WORKSPACE,
2358
+ requirements=requirements,
2359
+ python_version="3.10",
2360
+ venv_path=".venv",
2361
+ index_url=ROCM_TORCH_INDEX_URL,
2362
+ )
2363
+ python_exe = env_state.venv_python
2364
+ print(f"Using Python: {python_exe}")
2365
+ except Exception as e:
2366
+ return EvaluateResult(
2367
+ success=False,
2368
+ all_correct=False,
2369
+ correctness_score=0.0,
2370
+ geomean_speedup=0.0,
2371
+ passed_tests=0,
2372
+ total_tests=0,
2373
+ error_message=f"Failed to setup Python environment: {e}",
2374
+ )
2375
+
2376
+ # Upload wafer-core to remote
2377
+ try:
2378
+ wafer_root = _get_wafer_root()
2379
+ wafer_core_path = wafer_root / "packages" / "wafer-core"
2380
+ print(f"Uploading wafer-core from {wafer_core_path}...")
2381
+
2382
+ wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
2383
+ await client.exec(f"mkdir -p {wafer_core_remote}")
2384
+ wafer_core_workspace = await client.expand_path(wafer_core_remote)
2385
+
2386
+ # Use SFTP instead of rsync to avoid SSH subprocess timeout issues
2387
+ # (DigitalOcean may rate-limit new SSH connections)
2388
+ upload_result = await client.upload_files(
2389
+ str(wafer_core_path),
2390
+ wafer_core_workspace,
2391
+ recursive=True,
2392
+ use_sftp=True,
2393
+ )
2394
+
2395
+ # Wide event logging for upload result
2396
+ upload_event = {
2397
+ "event": "wafer_core_upload",
2398
+ "target": target.name,
2399
+ "target_type": "digitalocean",
2400
+ "ssh_host": f"{client.user}@{client.host}:{client.port}",
2401
+ "local_path": str(wafer_core_path),
2402
+ "remote_path": wafer_core_workspace,
2403
+ "success": upload_result.success,
2404
+ "files_copied": upload_result.files_copied,
2405
+ "duration_seconds": upload_result.duration_seconds,
2406
+ "error_message": upload_result.error_message,
2407
+ }
2408
+ if upload_result.debug_info:
2409
+ upload_event["debug_info"] = upload_result.debug_info
2410
+ logger.info(json.dumps(upload_event))
2411
+
2412
+ # Fail fast if upload failed
2413
+ if not upload_result.success:
2414
+ print(f"ERROR: Upload failed: {upload_result.error_message}")
2415
+ if upload_result.debug_info:
2416
+ print(
2417
+ f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}"
2418
+ )
2419
+ return EvaluateResult(
2420
+ success=False,
2421
+ all_correct=False,
2422
+ correctness_score=0.0,
2423
+ geomean_speedup=0.0,
2424
+ passed_tests=0,
2425
+ total_tests=0,
2426
+ error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
2427
+ )
2428
+
2429
+ print(f"Uploaded {upload_result.files_copied} files")
2430
+ except Exception as e:
2431
+ return EvaluateResult(
2432
+ success=False,
2433
+ all_correct=False,
2434
+ correctness_score=0.0,
2435
+ geomean_speedup=0.0,
2436
+ passed_tests=0,
2437
+ total_tests=0,
2438
+ error_message=f"Failed to upload wafer-core: {e}",
2439
+ )
2440
+
2441
+ # Select GPU (DigitalOcean droplets typically have GPU 0)
2442
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
2443
+ print(f"Using GPU {gpu_id}...")
2444
+
2445
+ # Read local files
2446
+ impl_code = args.implementation.read_text()
2447
+ ref_code = args.reference.read_text()
2448
+ test_cases_data = json.loads(args.test_cases.read_text())
2449
+
2450
+ # Create a unique run directory (uuid for concurrent eval isolation)
2451
+ import uuid
2452
+
2453
+ unique_id = uuid.uuid4().hex[:8]
2454
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
2455
+ run_dir = f"wafer_eval_{timestamp}_{unique_id}"
2456
+ run_path = f"{REMOTE_WORKSPACE}/{run_dir}"
2457
+
2458
+ print("Uploading evaluation files...")
2459
+
2460
+ # Create run directory
2461
+ mkdir_result = await client.exec(f"mkdir -p {run_path}")
2462
+ if mkdir_result.exit_code != 0:
2463
+ return EvaluateResult(
2464
+ success=False,
2465
+ all_correct=False,
2466
+ correctness_score=0.0,
2467
+ geomean_speedup=0.0,
2468
+ passed_tests=0,
2469
+ total_tests=0,
2470
+ error_message=f"Failed to create run directory: {mkdir_result.stderr}",
2471
+ )
2472
+
2473
+ # Write implementation
2474
+ impl_path = f"{run_path}/implementation.py"
2475
+ write_result = await client.exec(
2476
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
2477
+ )
2478
+ if write_result.exit_code != 0:
2479
+ return EvaluateResult(
2480
+ success=False,
2481
+ all_correct=False,
2482
+ correctness_score=0.0,
2483
+ geomean_speedup=0.0,
2484
+ passed_tests=0,
2485
+ total_tests=0,
2486
+ error_message=f"Failed to write implementation: {write_result.stderr}",
2487
+ )
2488
+
2489
+ # Write reference
2490
+ ref_path = f"{run_path}/reference.py"
2491
+ write_result = await client.exec(
2492
+ f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
2493
+ )
2494
+ if write_result.exit_code != 0:
2495
+ return EvaluateResult(
2496
+ success=False,
2497
+ all_correct=False,
2498
+ correctness_score=0.0,
2499
+ geomean_speedup=0.0,
2500
+ passed_tests=0,
2501
+ total_tests=0,
2502
+ error_message=f"Failed to write reference: {write_result.stderr}",
2503
+ )
2504
+
2505
+ # Also write as reference_kernel.py (evaluate.py imports generate_input from this)
2506
+ ref_kernel_path = f"{run_path}/reference_kernel.py"
2507
+ write_result = await client.exec(
2508
+ f"cat > '{ref_kernel_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
2509
+ )
2510
+ if write_result.exit_code != 0:
2511
+ return EvaluateResult(
2512
+ success=False,
2513
+ all_correct=False,
2514
+ correctness_score=0.0,
2515
+ geomean_speedup=0.0,
2516
+ passed_tests=0,
2517
+ total_tests=0,
2518
+ error_message=f"Failed to write reference_kernel: {write_result.stderr}",
2519
+ )
2520
+
2521
+ # Write test cases as JSON
2522
+ test_cases_path = f"{run_path}/test_cases.json"
2523
+ test_cases_json = json.dumps(test_cases_data)
2524
+ write_result = await client.exec(
2525
+ f"cat > '{test_cases_path}' << 'TEST_EOF'\n{test_cases_json}\nTEST_EOF"
2526
+ )
2527
+ if write_result.exit_code != 0:
2528
+ return EvaluateResult(
2529
+ success=False,
2530
+ all_correct=False,
2531
+ correctness_score=0.0,
2532
+ geomean_speedup=0.0,
2533
+ passed_tests=0,
2534
+ total_tests=0,
2535
+ error_message=f"Failed to write test cases: {write_result.stderr}",
2536
+ )
2537
+
2538
+ print("Running evaluation...")
2539
+
2540
+ # Build evaluation command
2541
+ # DigitalOcean AMD uses HIP_VISIBLE_DEVICES for AMD GPUs
2542
+ # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
2543
+ venv_bin = env_state.venv_bin
2544
+ env_vars = (
2545
+ f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
2546
+ )
2547
+
2548
+ # Run from run_path so reference_kernel.py is importable
2549
+ # Use installed wafer-core module
2550
+ eval_cmd = (
2551
+ f"cd {run_path} && "
2552
+ f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
2553
+ f"--implementation {impl_path} "
2554
+ f"--reference {ref_path} "
2555
+ f"--test-cases {test_cases_path} "
2556
+ f"--run-dir {run_path}"
2557
+ )
2558
+
2559
+ if args.benchmark:
2560
+ eval_cmd += " --benchmark"
2561
+ if args.defensive:
2562
+ eval_cmd += " --defensive"
2563
+
2564
+ # Run with timeout
2565
+ import trio
2566
+
2567
+ with trio.move_on_after(target.eval_timeout) as cancel_scope:
2568
+ result = await client.exec(eval_cmd)
2569
+
2570
+ if cancel_scope.cancelled_caught:
2571
+ return EvaluateResult(
2572
+ success=False,
2573
+ all_correct=False,
2574
+ correctness_score=0.0,
2575
+ geomean_speedup=0.0,
2576
+ passed_tests=0,
2577
+ total_tests=0,
2578
+ error_message=f"Evaluation timed out after {target.eval_timeout}s",
2579
+ )
2580
+
2581
+ # Show output to user
2582
+ stdout = result.stdout
2583
+ stderr = result.stderr
2584
+ if stdout:
2585
+ print(stdout)
2586
+
2587
+ if result.exit_code != 0:
2588
+ # Include both stdout and stderr for debugging
2589
+ error_parts = [f"Evaluation failed (exit code {result.exit_code}):"]
2590
+ if stdout:
2591
+ error_parts.append(f"stdout: {stdout}")
2592
+ if stderr:
2593
+ error_parts.append(f"stderr: {stderr}")
2594
+ return EvaluateResult(
2595
+ success=False,
2596
+ all_correct=False,
2597
+ correctness_score=0.0,
2598
+ geomean_speedup=0.0,
2599
+ passed_tests=0,
2600
+ total_tests=0,
2601
+ error_message="\n".join(error_parts),
2602
+ )
2603
+
2604
+ # Read results from results.json (like SSH path)
2605
+ results_path = f"{run_path}/results.json"
2606
+ cat_result = await client.exec(f"cat {results_path}")
2607
+
2608
+ if cat_result.exit_code != 0:
2609
+ return EvaluateResult(
2610
+ success=False,
2611
+ all_correct=False,
2612
+ correctness_score=0.0,
2613
+ geomean_speedup=0.0,
2614
+ passed_tests=0,
2615
+ total_tests=0,
2616
+ error_message=f"Failed to read results: {cat_result.stderr}",
2617
+ )
2618
+
2619
+ try:
2620
+ results_data = json.loads(cat_result.stdout)
2621
+ except json.JSONDecodeError as e:
2622
+ return EvaluateResult(
2623
+ success=False,
2624
+ all_correct=False,
2625
+ correctness_score=0.0,
2626
+ geomean_speedup=0.0,
2627
+ passed_tests=0,
2628
+ total_tests=0,
2629
+ error_message=f"Invalid JSON in results: {e}",
2630
+ )
2631
+
2632
+ # Extract backend results (same format as SSH path)
2633
+ backends = results_data.get("backends", [])
2634
+ if not backends:
2635
+ return EvaluateResult(
2636
+ success=False,
2637
+ all_correct=False,
2638
+ correctness_score=0.0,
2639
+ geomean_speedup=0.0,
2640
+ passed_tests=0,
2641
+ total_tests=0,
2642
+ error_message="No backend results found",
2643
+ )
2644
+
2645
+ backend = backends[0]
2646
+ correctness_tests = backend.get("correctness_tests", [])
2647
+ passed = sum(1 for t in correctness_tests if t.get("is_correct", False))
2648
+ total = len(correctness_tests)
2649
+
2650
+ return EvaluateResult(
2651
+ success=True,
2652
+ all_correct=backend.get("all_correct", False),
2653
+ correctness_score=backend.get("correctness_score", 0.0),
2654
+ geomean_speedup=backend.get("geomean_speedup", 0.0),
2655
+ passed_tests=passed,
2656
+ total_tests=total,
2657
+ )
2658
+
2659
+ except DigitalOceanError as e:
2660
+ return EvaluateResult(
2661
+ success=False,
2662
+ all_correct=False,
2663
+ correctness_score=0.0,
2664
+ geomean_speedup=0.0,
2665
+ passed_tests=0,
2666
+ total_tests=0,
2667
+ error_message=f"DigitalOcean error: {e}",
2668
+ )
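Both remote paths bound the SSH evaluation call with trio's cancel-scope timeout rather than an exception-based timeout; in isolation the pattern is roughly the following (illustrative sketch, with `client` standing in for the AsyncSSHClient used above):

    import trio

    async def exec_with_timeout(client, command: str, timeout_s: float):
        result = None
        with trio.move_on_after(timeout_s) as cancel_scope:
            result = await client.exec(command)
        if cancel_scope.cancelled_caught:
            # The command did not finish in time; callers translate this into a timeout error.
            return None
        return result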
2669
+
2670
+
2671
+ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
2672
+ """Run evaluation on configured target.
2673
+
2674
+ Args:
2675
+ args: Evaluate arguments
2676
+
2677
+ Returns:
2678
+ Evaluation result
2679
+ """
2680
+ from .targets import get_default_target, load_target
2681
+
2682
+ # Validate input files
2683
+ err = _validate_files(args)
2684
+ if err:
2685
+ return EvaluateResult(
2686
+ success=False,
2687
+ all_correct=False,
2688
+ correctness_score=0.0,
2689
+ geomean_speedup=0.0,
2690
+ passed_tests=0,
2691
+ total_tests=0,
2692
+ error_message=err,
2693
+ )
2694
+
2695
+ # Load target
2696
+ target_name = args.target_name
2697
+ if not target_name:
2698
+ target_name = get_default_target()
2699
+ if not target_name:
2700
+ return EvaluateResult(
2701
+ success=False,
2702
+ all_correct=False,
2703
+ correctness_score=0.0,
2704
+ geomean_speedup=0.0,
2705
+ passed_tests=0,
2706
+ total_tests=0,
2707
+ error_message=(
2708
+ "No target specified and no default set.\n"
2709
+ "Set up a target first:\n"
2710
+ " wafer config targets init ssh --name my-gpu --host user@host:22\n"
2711
+ " wafer config targets init runpod --gpu MI300X\n"
2712
+ "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
2713
+ ),
2714
+ )
2715
+
2716
+ try:
2717
+ target = load_target(target_name)
2718
+ except FileNotFoundError:
2719
+ return EvaluateResult(
2720
+ success=False,
2721
+ all_correct=False,
2722
+ correctness_score=0.0,
2723
+ geomean_speedup=0.0,
2724
+ passed_tests=0,
2725
+ total_tests=0,
2726
+ error_message=f"Target not found: {target_name}. Run: wafer config targets list",
2727
+ )
2728
+
2729
+ print(f"Using target: {target_name}")
2730
+
2731
+ # Dispatch to appropriate executor
2732
+ if isinstance(target, LocalTarget):
2733
+ return await run_evaluate_local(args, target)
2734
+ elif isinstance(target, BaremetalTarget | VMTarget):
2735
+ return await run_evaluate_ssh(args, target)
2736
+ elif isinstance(target, ModalTarget):
2737
+ return await run_evaluate_modal(args, target)
2738
+ elif isinstance(target, WorkspaceTarget):
2739
+ return await run_evaluate_workspace(args, target)
2740
+ elif isinstance(target, RunPodTarget):
2741
+ return await run_evaluate_runpod(args, target)
2742
+ elif isinstance(target, DigitalOceanTarget):
2743
+ return await run_evaluate_digitalocean(args, target)
2744
+ else:
2745
+ return EvaluateResult(
2746
+ success=False,
2747
+ all_correct=False,
2748
+ correctness_score=0.0,
2749
+ geomean_speedup=0.0,
2750
+ passed_tests=0,
2751
+ total_tests=0,
2752
+ error_message=f"Unknown target type: {type(target)}",
2753
+ )
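A typical programmatic call of the dispatcher, with EvaluateArgs field names inferred from how they are read above (a sketch, not a documented API):

    import trio
    from pathlib import Path

    args = EvaluateArgs(
        implementation=Path("implementation.py"),
        reference=Path("reference.py"),
        test_cases=Path("test_cases.json"),
        target_name="my-gpu",   # omit to fall back to the configured default target
        benchmark=True,
        defensive=False,
        gpu_id=None,
    )
    result = trio.run(run_evaluate, args)
    print(result.passed_tests, result.total_tests, result.geomean_speedup)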
2754
+
2755
+
2756
+ # =============================================================================
2757
+ # KernelBench Format Evaluation
2758
+ # =============================================================================
2759
+
2760
+ # Inline evaluation script for KernelBench format
2761
+ # This runs inside the Docker container on the remote GPU
2762
+ KERNELBENCH_EVAL_SCRIPT = """
2763
+ import gc
2764
+ import json
2765
+ import os
2766
+ import sys
2767
+ import time
2768
+ import torch
2769
+ import torch.nn as nn
2770
+ from pathlib import Path
2771
+
2772
+ # Use a unique per-run PyTorch extension cache directory to ensure fresh compilation.
2773
+ # This prevents stale cached extensions from being loaded when the pod is reused.
2774
+ # Without this, if a kernel is modified but uses the same extension name,
2775
+ # PyTorch would load the old cached .so instead of recompiling.
2776
+ # We use a UUID-based directory instead of clearing the cache to avoid race conditions
2777
+ # with other processes that might be using the cache.
2778
+ import uuid
2779
+ unique_cache_dir = f"/tmp/torch_extensions_{uuid.uuid4().hex[:8]}"
2780
+ os.environ["TORCH_EXTENSIONS_DIR"] = unique_cache_dir
2781
+ print(f"[KernelBench] Using unique extension cache: {unique_cache_dir}")
2782
+
2783
+ # Clear any stale GPU memory from previous runs at startup
2784
+ # NOTE: empty_cache only frees memory from THIS process's PyTorch allocator.
2785
+ # It won't free memory from dead/zombie processes - rocm-smi --showpids can show
2786
+ # PIDs that no longer exist but still hold GPU memory. Those require a GPU reset
2787
+ # (rocm-smi --gpureset) to fully clear. TODO: detect and warn about orphaned memory.
2788
+ if torch.cuda.is_available():
2789
+ gc.collect()
2790
+ torch.cuda.empty_cache()
2791
+ torch.cuda.reset_peak_memory_stats()
2792
+
2793
+
2794
+ def _calculate_timing_stats(times: list[float]) -> dict:
2795
+ '''Calculate median and IQR from timing samples.
2796
+
2797
+ Returns dict with median, iqr_low (25th percentile), iqr_high (75th percentile),
2798
+ mean, min, max, and std.
2799
+ '''
2800
+ import statistics
2801
+
2802
+ if not times:
2803
+ return {"median": 0, "iqr_low": 0, "iqr_high": 0, "mean": 0, "min": 0, "max": 0, "std": 0}
2804
+
2805
+ sorted_times = sorted(times)
2806
+ n = len(sorted_times)
2807
+
2808
+ # Median
2809
+ median = statistics.median(sorted_times)
2810
+
2811
+ # Quartiles (25th and 75th percentile)
2812
+ # For small samples, use simple interpolation
2813
+ q1_idx = (n - 1) * 0.25
2814
+ q3_idx = (n - 1) * 0.75
2815
+
2816
+ q1_low = int(q1_idx)
2817
+ q1_frac = q1_idx - q1_low
2818
+ iqr_low = sorted_times[q1_low] * (1 - q1_frac) + sorted_times[min(q1_low + 1, n - 1)] * q1_frac
2819
+
2820
+ q3_low = int(q3_idx)
2821
+ q3_frac = q3_idx - q3_low
2822
+ iqr_high = sorted_times[q3_low] * (1 - q3_frac) + sorted_times[min(q3_low + 1, n - 1)] * q3_frac
2823
+
2824
+ return {
2825
+ "median": median,
2826
+ "iqr_low": iqr_low,
2827
+ "iqr_high": iqr_high,
2828
+ "mean": statistics.mean(sorted_times),
2829
+ "min": min(sorted_times),
2830
+ "max": max(sorted_times),
2831
+ "std": statistics.stdev(sorted_times) if n > 1 else 0,
2832
+ }
2833
+
2834
+
2835
+ def run_profiling(model, inputs, name, output_dir):
2836
+ '''Run torch.profiler and return summary stats.'''
2837
+ from torch.profiler import profile, ProfilerActivity
2838
+
2839
+ # Determine activities based on backend
2840
+ activities = [ProfilerActivity.CPU]
2841
+ if torch.cuda.is_available():
2842
+ activities.append(ProfilerActivity.CUDA)
2843
+
2844
+ # Warmup
2845
+ for _ in range(3):
2846
+ with torch.no_grad():
2847
+ _ = model(*inputs)
2848
+ torch.cuda.synchronize()
2849
+
2850
+ # Profile
2851
+ with profile(
2852
+ activities=activities,
2853
+ record_shapes=True,
2854
+ with_stack=False,
2855
+ profile_memory=True,
2856
+ ) as prof:
2857
+ with torch.no_grad():
2858
+ _ = model(*inputs)
2859
+ torch.cuda.synchronize()
2860
+
2861
+ # Get key averages
2862
+ key_averages = prof.key_averages()
2863
+
2864
+ # Find the main kernel (longest GPU time)
2865
+ # Use cuda_time_total for compatibility with both CUDA and ROCm
2866
+ def get_gpu_time(e):
2867
+ # Try different attributes for GPU time
2868
+ if hasattr(e, 'cuda_time_total'):
2869
+ return e.cuda_time_total
2870
+ if hasattr(e, 'device_time_total'):
2871
+ return e.device_time_total
2872
+ if hasattr(e, 'self_cuda_time_total'):
2873
+ return e.self_cuda_time_total
2874
+ return 0
2875
+
2876
+ gpu_events = [e for e in key_averages if get_gpu_time(e) > 0]
2877
+ gpu_events.sort(key=lambda e: get_gpu_time(e), reverse=True)
2878
+
2879
+ stats = {
2880
+ "name": name,
2881
+ "total_gpu_time_ms": sum(get_gpu_time(e) for e in gpu_events) / 1000,
2882
+ "total_cpu_time_ms": sum(e.cpu_time_total for e in key_averages) / 1000,
2883
+ "num_gpu_kernels": len(gpu_events),
2884
+ "top_kernels": [],
2885
+ }
2886
+
2887
+ # Top 5 kernels by GPU time
2888
+ for e in gpu_events[:5]:
2889
+ stats["top_kernels"].append({
2890
+ "name": e.key,
2891
+ "gpu_time_ms": get_gpu_time(e) / 1000,
2892
+ "cpu_time_ms": e.cpu_time_total / 1000,
2893
+ "calls": e.count,
2894
+ })
2895
+
2896
+ # Save trace for visualization
2897
+ trace_path = Path(output_dir) / f"{name}_trace.json"
2898
+ prof.export_chrome_trace(str(trace_path))
2899
+ stats["trace_file"] = str(trace_path)
2900
+
2901
+ return stats
2902
+
2903
+
2904
+ def validate_custom_inputs(original_inputs, custom_inputs):
2905
+ '''Validate that custom inputs match the expected signature.
2906
+
2907
+ Returns (is_valid, error_message).
2908
+ '''
2909
+ if len(original_inputs) != len(custom_inputs):
2910
+ return False, f"get_inputs() must return {len(original_inputs)} tensors, got {len(custom_inputs)}"
2911
+
2912
+ for i, (orig, cust) in enumerate(zip(original_inputs, custom_inputs)):
2913
+ if not isinstance(cust, torch.Tensor):
2914
+ if not isinstance(orig, torch.Tensor):
2915
+ continue # Both non-tensor, ok
2916
+ return False, f"Input {i}: expected Tensor, got {type(cust).__name__}"
2917
+
2918
+ if not isinstance(orig, torch.Tensor):
2919
+ return False, f"Input {i}: expected {type(orig).__name__}, got Tensor"
2920
+
2921
+ if orig.dtype != cust.dtype:
2922
+ return False, f"Input {i}: dtype mismatch - expected {orig.dtype}, got {cust.dtype}"
2923
+
2924
+ if orig.dim() != cust.dim():
2925
+ return False, f"Input {i}: dimension mismatch - expected {orig.dim()}D, got {cust.dim()}D"
2926
+
2927
+ return True, None
2928
+
2929
+
2930
+ def analyze_diff(ref_output, new_output, rtol=1e-3, atol=1e-3, max_samples=5):
2931
+ '''Analyze differences between reference and implementation outputs.
2932
+
2933
+ Returns a dict with detailed diff information.
2934
+ '''
2935
+ diff = (ref_output - new_output).abs()
2936
+ threshold = atol + rtol * ref_output.abs()
2937
+ wrong_mask = diff > threshold
2938
+
2939
+ total_elements = ref_output.numel()
2940
+ wrong_count = wrong_mask.sum().item()
2941
+
2942
+ # Basic stats
2943
+ max_diff = diff.max().item()
2944
+ max_diff_idx = tuple(torch.unravel_index(diff.argmax(), diff.shape))
2945
+ max_diff_idx = tuple(int(i) for i in max_diff_idx) # Convert to Python ints
2946
+
2947
+ # Relative error (avoid div by zero)
2948
+ ref_abs = ref_output.abs()
2949
+ nonzero_mask = ref_abs > 1e-8
2950
+ if nonzero_mask.any():
2951
+ rel_error = diff[nonzero_mask] / ref_abs[nonzero_mask]
2952
+ max_rel_error = rel_error.max().item()
2953
+ mean_rel_error = rel_error.mean().item()
2954
+ else:
2955
+ max_rel_error = float('inf') if max_diff > 0 else 0.0
2956
+ mean_rel_error = max_rel_error
2957
+
2958
+ # Error histogram (buckets: <1e-6, 1e-6 to 1e-4, 1e-4 to 1e-2, 1e-2 to 1, >1)
2959
+ histogram = {
2960
+ '<1e-6': int((diff < 1e-6).sum().item()),
2961
+ '1e-6 to 1e-4': int(((diff >= 1e-6) & (diff < 1e-4)).sum().item()),
2962
+ '1e-4 to 1e-2': int(((diff >= 1e-4) & (diff < 1e-2)).sum().item()),
2963
+ '1e-2 to 1': int(((diff >= 1e-2) & (diff < 1)).sum().item()),
2964
+ '>1': int((diff >= 1).sum().item()),
2965
+ }
2966
+
2967
+ result = {
2968
+ 'max_diff': max_diff,
2969
+ 'max_diff_idx': max_diff_idx,
2970
+ 'mean_diff': diff.mean().item(),
2971
+ 'max_rel_error': max_rel_error,
2972
+ 'mean_rel_error': mean_rel_error,
2973
+ 'total_elements': total_elements,
2974
+ 'wrong_count': int(wrong_count),
2975
+ 'wrong_pct': 100.0 * wrong_count / total_elements,
2976
+ 'histogram': histogram,
2977
+ 'samples': [],
2978
+ }
2979
+
2980
+ # Get indices of wrong elements
2981
+ if wrong_count > 0:
2982
+ wrong_indices = torch.nonzero(wrong_mask, as_tuple=False)
2983
+
2984
+ # Take first N samples
2985
+ num_samples = min(max_samples, len(wrong_indices))
2986
+ for i in range(num_samples):
2987
+ idx = tuple(wrong_indices[i].tolist())
2988
+ ref_val = ref_output[idx].item()
2989
+ new_val = new_output[idx].item()
2990
+ diff_val = diff[idx].item()
2991
+ result['samples'].append({
2992
+ 'index': idx,
2993
+ 'ref': ref_val,
2994
+ 'impl': new_val,
2995
+ 'diff': diff_val,
2996
+ })
2997
+
2998
+ # Try to detect pattern
2999
+ if wrong_count >= total_elements * 0.99:
3000
+ result['pattern'] = 'all_wrong'
3001
+ elif wrong_count < total_elements * 0.01:
3002
+ # Check if failures are at boundaries
3003
+ shape = ref_output.shape
3004
+ boundary_count = 0
3005
+ for idx in wrong_indices[:min(100, len(wrong_indices))]:
3006
+ idx_list = idx.tolist()
3007
+ is_boundary = any(i == 0 or i == s - 1 for i, s in zip(idx_list, shape))
3008
+ if is_boundary:
3009
+ boundary_count += 1
3010
+ if boundary_count > len(wrong_indices[:100]) * 0.8:
3011
+ result['pattern'] = 'boundary_issue'
3012
+ else:
3013
+ result['pattern'] = 'scattered'
3014
+ else:
3015
+ result['pattern'] = 'partial'
3016
+
3017
+ return result
3018
+
3019
+
3020
+ def print_diff_analysis(analysis):
3021
+ '''Print a human-readable diff analysis.'''
3022
+ print(f"[KernelBench] Diff analysis:")
3023
+
3024
+ # Max diff with location
3025
+ idx_str = ','.join(str(i) for i in analysis['max_diff_idx'])
3026
+ print(f" Max diff: {analysis['max_diff']:.6f} at index [{idx_str}]")
3027
+ print(f" Mean diff: {analysis['mean_diff']:.6f}")
3028
+
3029
+ # Relative errors
3030
+ print(f" Max relative error: {analysis['max_rel_error']:.2%}, Mean: {analysis['mean_rel_error']:.2%}")
3031
+
3032
+ # Wrong count
3033
+ print(f" Wrong elements: {analysis['wrong_count']:,} / {analysis['total_elements']:,} ({analysis['wrong_pct']:.2f}%)")
3034
+
3035
+ # Histogram
3036
+ hist = analysis['histogram']
3037
+ print(f" Error distribution: <1e-6: {hist['<1e-6']:,} | 1e-6~1e-4: {hist['1e-6 to 1e-4']:,} | 1e-4~1e-2: {hist['1e-4 to 1e-2']:,} | 1e-2~1: {hist['1e-2 to 1']:,} | >1: {hist['>1']:,}")
3038
+
3039
+ if 'pattern' in analysis:
3040
+ pattern_desc = {
3041
+ 'all_wrong': 'ALL elements wrong - likely algorithmic error or wrong weights',
3042
+ 'boundary_issue': 'Mostly BOUNDARY elements wrong - check edge handling',
3043
+ 'scattered': 'SCATTERED failures - numerical precision issue?',
3044
+ 'partial': 'PARTIAL failures - check specific conditions',
3045
+ }
3046
+ print(f" Pattern: {pattern_desc.get(analysis['pattern'], analysis['pattern'])}")
3047
+
3048
+ if analysis['samples']:
3049
+ print(f" Sample failures:")
3050
+ for s in analysis['samples']:
3051
+ idx_str = ','.join(str(i) for i in s['index'])
3052
+ print(f" [{idx_str}]: ref={s['ref']:.6f} impl={s['impl']:.6f} (diff={s['diff']:.6f})")
3053
+
3054
+
3055
+ def main():
3056
+ # Parse args
3057
+ import argparse
3058
+ parser = argparse.ArgumentParser()
3059
+ parser.add_argument("--impl", required=True)
3060
+ parser.add_argument("--reference", required=True)
3061
+ parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
3062
+ parser.add_argument("--benchmark", action="store_true")
3063
+ parser.add_argument("--profile", action="store_true")
3064
+ parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
3065
+ parser.add_argument("--defense-module", help="Path to defense.py module")
3066
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
3067
+ parser.add_argument("--num-correct-trials", type=int, default=3)
3068
+ parser.add_argument("--num-perf-trials", type=int, default=10)
3069
+ parser.add_argument("--output", required=True)
3070
+ parser.add_argument("--stages", default="compile,correctness",
3071
+ help="Comma-separated stages: compile, correctness, benchmark, defense")
3072
+ args = parser.parse_args()
3073
+
3074
+ # Parse stages
3075
+ stages = set(args.stages.split(","))
3076
+ run_compile = "compile" in stages
3077
+ run_correctness = "correctness" in stages
3078
+ run_benchmark = "benchmark" in stages or args.benchmark
3079
+ run_defense = "defense" in stages or args.defensive
3080
+ print(f"[KernelBench] Stages: {args.stages}")
3081
+
3082
+ # Load defense module if defensive mode is enabled
3083
+ defense_module = None
3084
+ if args.defensive and args.defense_module:
3085
+ try:
3086
+ import importlib.util
3087
+ defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
3088
+ defense_module = importlib.util.module_from_spec(defense_spec)
3089
+ defense_spec.loader.exec_module(defense_module)
3090
+ print("[KernelBench] Defense module loaded")
3091
+ except Exception as e:
3092
+ print(f"[KernelBench] Warning: Could not load defense module: {e}")
3093
+
3094
+ # Create output directory for profiles
3095
+ output_dir = Path(args.output).parent
3096
+ profile_dir = output_dir / "profiles"
3097
+ if args.profile:
3098
+ profile_dir.mkdir(exist_ok=True)
3099
+
3100
+ results = {
3101
+ "compiled": False,
3102
+ "correct": False,
3103
+ "speedup": None,
3104
+ "runtime_ms": None,
3105
+ "reference_runtime_ms": None,
3106
+ "error": None,
3107
+ }
3108
+
3109
+ try:
3110
+ # Load reference module
3111
+ import importlib.util
3112
+ ref_spec = importlib.util.spec_from_file_location("reference", args.reference)
3113
+ ref_module = importlib.util.module_from_spec(ref_spec)
3114
+ ref_spec.loader.exec_module(ref_module)
3115
+
3116
+ Model = ref_module.Model
3117
+ get_inputs = ref_module.get_inputs
3118
+ get_init_inputs = ref_module.get_init_inputs
3119
+
3120
+ # Load custom inputs if provided
3121
+ if args.inputs:
3122
+ inputs_spec = importlib.util.spec_from_file_location("custom_inputs", args.inputs)
3123
+ inputs_module = importlib.util.module_from_spec(inputs_spec)
3124
+ inputs_spec.loader.exec_module(inputs_module)
3125
+
3126
+ # Validate custom inputs match expected signature
3127
+ original_inputs = get_inputs()
3128
+ custom_get_inputs = inputs_module.get_inputs
3129
+ custom_inputs = custom_get_inputs()
3130
+
3131
+ is_valid, error_msg = validate_custom_inputs(original_inputs, custom_inputs)
3132
+ if not is_valid:
3133
+ print(f"[KernelBench] Custom inputs validation failed: {error_msg}")
3134
+ results["error"] = f"Custom inputs validation failed: {error_msg}"
3135
+ raise ValueError(error_msg)
3136
+
3137
+ # Override get_inputs (and optionally get_init_inputs)
3138
+ get_inputs = custom_get_inputs
3139
+ if hasattr(inputs_module, 'get_init_inputs'):
3140
+ get_init_inputs = inputs_module.get_init_inputs
3141
+
3142
+ # Show what changed
3143
+ orig_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in original_inputs]
3144
+ cust_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in custom_inputs]
3145
+ print(f"[KernelBench] Using custom inputs: {orig_shapes} -> {cust_shapes}")
3146
+
3147
+ # Load implementation module
3148
+ impl_spec = importlib.util.spec_from_file_location("implementation", args.impl)
3149
+ impl_module = importlib.util.module_from_spec(impl_spec)
3150
+ impl_spec.loader.exec_module(impl_module)
3151
+
3152
+ ModelNew = impl_module.ModelNew
3153
+ results["compiled"] = True
3154
+ print("[KernelBench] Modules loaded successfully")
3155
+
3156
+ # Instantiate models with synchronized seeds for reproducible weights
3157
+ # (matches upstream KernelBench behavior in src/eval.py)
3158
+ seed = args.seed
3159
+ init_inputs = get_init_inputs()
3160
+ with torch.no_grad():
3161
+ torch.manual_seed(seed)
3162
+ torch.cuda.manual_seed(seed)
3163
+ ref_model = Model(*init_inputs).cuda().eval()
3164
+
3165
+ torch.manual_seed(seed)
3166
+ torch.cuda.manual_seed(seed)
3167
+ new_model = ModelNew(*init_inputs).cuda().eval()
3168
+ print(f"[KernelBench] Models instantiated (seed={seed})")
3169
+
3170
+ # Run correctness trials (if stage enabled)
3171
+ all_correct = True
3172
+ if not run_correctness:
3173
+ print("[KernelBench] Skipping correctness (not in stages)")
3174
+ results["correct"] = None # Unknown - not checked
3175
+ else:
3176
+ for trial in range(args.num_correct_trials):
3177
+ inputs = get_inputs()
3178
+ inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
3179
+
3180
+ with torch.no_grad():
3181
+ ref_output = ref_model(*inputs)
3182
+ new_output = new_model(*inputs)
3183
+
3184
+ # Compare outputs
3185
+ if isinstance(ref_output, torch.Tensor):
3186
+ if not torch.allclose(ref_output, new_output, rtol=1e-3, atol=1e-3):
3187
+ all_correct = False
3188
+ analysis = analyze_diff(ref_output, new_output)
3189
+ results["error"] = f"Correctness failed on trial {trial+1}: max diff = {analysis['max_diff']}"
3190
+ results["diff_analysis"] = analysis
3191
+ print_diff_analysis(analysis)
3192
+
3193
+ # Save tensors for debugging
3194
+ debug_dir = output_dir / "debug"
3195
+ debug_dir.mkdir(exist_ok=True)
3196
+ torch.save(ref_output.cpu(), debug_dir / "ref_output.pt")
3197
+ torch.save(new_output.cpu(), debug_dir / "impl_output.pt")
3198
+ torch.save(inputs[0].cpu() if inputs else None, debug_dir / "input.pt")
3199
+ print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
3200
+ break
3201
+ else:
3202
+ # Handle tuple/list outputs
3203
+ for i, (r, n) in enumerate(zip(ref_output, new_output)):
3204
+ if isinstance(r, torch.Tensor):
3205
+ if not torch.allclose(r, n, rtol=1e-3, atol=1e-3):
3206
+ all_correct = False
3207
+ analysis = analyze_diff(r, n)
3208
+ results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {analysis['max_diff']}"
3209
+ results["diff_analysis"] = analysis
3210
+ print_diff_analysis(analysis)
3211
+
3212
+ # Save tensors for debugging
3213
+ debug_dir = output_dir / "debug"
3214
+ debug_dir.mkdir(exist_ok=True)
3215
+ torch.save(r.cpu(), debug_dir / f"ref_output_{i}.pt")
3216
+ torch.save(n.cpu(), debug_dir / f"impl_output_{i}.pt")
3217
+ print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
3218
+ break
3219
+ if not all_correct:
3220
+ break
3221
+
3222
+ results["correct"] = all_correct
3223
+ print(f"[KernelBench] Correctness: {all_correct}")
3224
+
3225
+ # Run benchmark if stage enabled (and correctness passed or skipped)
3226
+ should_benchmark = run_benchmark and (all_correct or not run_correctness)
3227
+ if should_benchmark:
3228
+ print("[KernelBench] Running benchmarks...")
3229
+ inputs = get_inputs()
3230
+ inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
3231
+
3232
+ if run_defense and defense_module is not None:
3233
+ # Use full defense suite
3234
+ print("[KernelBench] Running defense checks on implementation...")
3235
+ run_all_defenses = defense_module.run_all_defenses
3236
+ time_with_defenses = defense_module.time_execution_with_defenses
3237
+
3238
+ # Run defense checks on implementation
3239
+ all_passed, defense_results, _ = run_all_defenses(
3240
+ lambda *x: new_model(*x),
3241
+ *inputs,
3242
+ )
3243
+ results["defense_results"] = {
3244
+ name: {"passed": passed, "message": msg}
3245
+ for name, passed, msg in defense_results
3246
+ }
3247
+ if not all_passed:
3248
+ failed = [name for name, passed, _ in defense_results if not passed]
3249
+ results["error"] = f"Defense checks failed: {failed}"
3250
+ print(f"[KernelBench] Defense checks FAILED: {failed}")
3251
+ for name, passed, msg in defense_results:
3252
+ status = "PASS" if passed else "FAIL"
3253
+ print(f" [{status}] {name}: {msg}")
3254
+ else:
3255
+ print("[KernelBench] All defense checks passed")
3256
+
3257
+ # Time with defensive timing
3258
+ impl_times, _ = time_with_defenses(
3259
+ lambda: new_model(*inputs),
3260
+ [],
3261
+ num_warmup=5,
3262
+ num_trials=args.num_perf_trials,
3263
+ verbose=False,
3264
+ run_defenses=False, # Already ran above
3265
+ )
3266
+ # Calculate stats for new model
3267
+ new_stats = _calculate_timing_stats(impl_times)
3268
+ results["runtime_ms"] = new_stats["median"]
3269
+ results["runtime_stats"] = new_stats
3270
+
3271
+ # Reference timing
3272
+ ref_times, _ = time_with_defenses(
3273
+ lambda: ref_model(*inputs),
3274
+ [],
3275
+ num_warmup=5,
3276
+ num_trials=args.num_perf_trials,
3277
+ verbose=False,
3278
+ run_defenses=False,
3279
+ )
3280
+ ref_stats = _calculate_timing_stats(ref_times)
3281
+ results["reference_runtime_ms"] = ref_stats["median"]
3282
+ results["reference_runtime_stats"] = ref_stats
3283
+ results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
3284
+ print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
3285
+ else:
3286
+ # Standard timing without full defenses
3287
+ # Warmup BOTH models before benchmarking either
3288
+ # This ensures consistent GPU state and avoids MIOpen cache effects
3289
+ # that cause variance when warming up models sequentially
3290
+ for _ in range(5):
3291
+ with torch.no_grad():
3292
+ _ = new_model(*inputs)
3293
+ _ = ref_model(*inputs)
3294
+ torch.cuda.synchronize()
3295
+
3296
+ # Benchmark new model
3297
+ start = torch.cuda.Event(enable_timing=True)
3298
+ end = torch.cuda.Event(enable_timing=True)
3299
+
3300
+ new_times = []
3301
+ for _ in range(args.num_perf_trials):
3302
+ start.record()
3303
+ with torch.no_grad():
3304
+ _ = new_model(*inputs)
3305
+ end.record()
3306
+ torch.cuda.synchronize()
3307
+ new_times.append(start.elapsed_time(end))
3308
+
3309
+ new_stats = _calculate_timing_stats(new_times)
3310
+ results["runtime_ms"] = new_stats["median"]
3311
+ results["runtime_stats"] = new_stats
3312
+
3313
+ # Benchmark reference model
3314
+ ref_times = []
3315
+ for _ in range(args.num_perf_trials):
3316
+ start.record()
3317
+ with torch.no_grad():
3318
+ _ = ref_model(*inputs)
3319
+ end.record()
3320
+ torch.cuda.synchronize()
3321
+ ref_times.append(start.elapsed_time(end))
3322
+
3323
+ ref_stats = _calculate_timing_stats(ref_times)
3324
+ results["reference_runtime_ms"] = ref_stats["median"]
3325
+ results["reference_runtime_stats"] = ref_stats
3326
+ results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
3327
+ print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
3328
+
3329
+ # Run profiling if requested and correctness passed
3330
+ if args.profile and all_correct:
3331
+ print("[KernelBench] Running profiler...")
3332
+ inputs = get_inputs()
3333
+ inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
3334
+
3335
+ try:
3336
+ # Profile implementation
3337
+ impl_stats = run_profiling(new_model, inputs, "implementation", str(profile_dir))
3338
+ results["profile_impl"] = impl_stats
3339
+ print(f"[KernelBench] Implementation profile:")
3340
+ print(f" Total GPU time: {impl_stats['total_gpu_time_ms']:.3f}ms")
3341
+ print(f" Kernels launched: {impl_stats['num_gpu_kernels']}")
3342
+ if impl_stats['top_kernels']:
3343
+ print(f" Top kernel: {impl_stats['top_kernels'][0]['name'][:60]}...")
3344
+ print(f" {impl_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
3345
+
3346
+ # Profile reference
3347
+ ref_stats = run_profiling(ref_model, inputs, "reference", str(profile_dir))
3348
+ results["profile_ref"] = ref_stats
3349
+ print(f"[KernelBench] Reference profile:")
3350
+ print(f" Total GPU time: {ref_stats['total_gpu_time_ms']:.3f}ms")
3351
+ print(f" Kernels launched: {ref_stats['num_gpu_kernels']}")
3352
+ if ref_stats['top_kernels']:
3353
+ print(f" Top kernel: {ref_stats['top_kernels'][0]['name'][:60]}...")
3354
+ print(f" {ref_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
3355
+
3356
+ print(f"[KernelBench] Profile traces saved to: {profile_dir}/")
3357
+
3358
+ except Exception as prof_err:
3359
+ print(f"[KernelBench] Profiling failed: {prof_err}")
3360
+ results["profile_error"] = str(prof_err)
3361
+
3362
+ except Exception as e:
3363
+ import traceback
3364
+ results["error"] = f"{type(e).__name__}: {e}\\n{traceback.format_exc()}"
3365
+ print(f"[KernelBench] Error: {results['error']}")
3366
+
3367
+ # Write results
3368
+ with open(args.output, "w") as f:
3369
+ json.dump(results, f, indent=2)
3370
+ print(f"[KernelBench] Results written to {args.output}")
3371
+
3372
+ # Cleanup GPU memory
3373
+ try:
3374
+ del ref_model, new_model
3375
+ except NameError:
3376
+ pass
3377
+ import gc
3378
+ gc.collect()
3379
+ if torch.cuda.is_available():
3380
+ torch.cuda.empty_cache()
3381
+
3382
+ if __name__ == "__main__":
3383
+ main()
3384
+ """
3385
+
3386
+
3387
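+ # Illustrative shape of the results.json the embedded script writes (keys taken from the
+ # assignments above and from the fields this module reads back; values are placeholders):
+ #   {"compiled": true, "correct": true,
+ #    "runtime_ms": 1.23, "runtime_stats": {"median": 1.23, "iqr_low": 1.20, "iqr_high": 1.27},
+ #    "reference_runtime_ms": 2.46, "reference_runtime_stats": {...},
+ #    "speedup": 2.0, "error": null}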
+ def _validate_kernelbench_files(args: KernelBenchEvaluateArgs) -> str | None:
3388
+ """Validate that KernelBench input files exist and have expected signatures.
3389
+
3390
+ Returns:
3391
+ Error message if validation fails, None if all valid
3392
+ """
3393
+ if not args.implementation.exists():
3394
+ return f"Implementation file not found: {args.implementation}"
3395
+ if not args.reference.exists():
3396
+ return f"Reference file not found: {args.reference}"
3397
+
3398
+ # Validate implementation has ModelNew class
3399
+ impl_missing = _check_python_file_has(args.implementation, "ModelNew")
3400
+ if impl_missing:
3401
+ # Check if it looks like functional format (has custom_kernel)
3402
+ has_custom_kernel = not _check_python_file_has(args.implementation, "custom_kernel")
3403
+ if has_custom_kernel:
3404
+ return (
3405
+ f"Implementation file missing 'ModelNew' class: {args.implementation}\n"
3406
+ "Hint: This looks like functional format. Use 'wafer evaluate' instead:\n"
3407
+ f" wafer evaluate --impl {args.implementation} --reference <ref.py> --test-cases <tests.json>"
3408
+ )
3409
+ return (
3410
+ f"Implementation file missing 'ModelNew' class: {args.implementation}\n"
3411
+ " KernelBench format requires a 'class ModelNew(nn.Module)' definition"
3412
+ )
3413
+
3414
+ # Validate reference has Model, get_inputs, get_init_inputs
3415
+ ref_missing = _check_python_file_has(args.reference, "Model", "get_inputs", "get_init_inputs")
3416
+ if ref_missing:
3417
+ # Check if it looks like functional format (has ref_kernel and generate_input)
3418
+ has_functional = not _check_python_file_has(args.reference, "ref_kernel", "generate_input")
3419
+ if has_functional:
3420
+ return (
3421
+ f"Reference file missing required definitions: {', '.join(ref_missing)}\n"
3422
+ "Hint: This looks like functional format. Use 'wafer evaluate' instead:\n"
3423
+ f" wafer evaluate --impl <impl.py> --reference {args.reference} --test-cases <tests.json>"
3424
+ )
3425
+ return (
3426
+ f"Reference file missing required definitions: {', '.join(ref_missing)}\n"
3427
+ f" File: {args.reference}\n"
3428
+ " KernelBench format requires: 'class Model', 'get_inputs()', 'get_init_inputs()'"
3429
+ )
3430
+
3431
+ # Static kernel validation if backend specified
3432
+ if args.backend:
3433
+ from wafer_core.utils.kernel_utils.static_checker import validate_kernel_static
3434
+
3435
+ code = args.implementation.read_text()
3436
+ valid, errors, warnings = validate_kernel_static(code, backend=args.backend)
3437
+
3438
+ # Print warnings (don't fail)
3439
+ for warning in warnings:
3440
+ logger.warning(f"Static check warning: {warning}")
3441
+
3442
+ # Fail on errors
3443
+ if not valid:
3444
+ error_list = "\n - ".join(errors)
3445
+ return (
3446
+ f"Static kernel validation failed for backend '{args.backend}':\n"
3447
+ f" - {error_list}\n\n"
3448
+ f"The implementation must use {args.backend.upper()} kernel primitives.\n"
3449
+ "See KernelBench documentation for valid kernel patterns."
3450
+ )
3451
+
3452
+ return None
3453
+
3454
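+ # Minimal sketch of the two file layouts this validator checks for (illustrative only;
+ # shapes and ops are placeholders, not taken from any real KernelBench task):
+ #
+ #   # reference.py
+ #   import torch
+ #   import torch.nn as nn
+ #   class Model(nn.Module):
+ #       def forward(self, x):
+ #           return torch.relu(x)
+ #   def get_inputs():
+ #       return [torch.randn(1024, 1024)]
+ #   def get_init_inputs():
+ #       return []
+ #
+ #   # implementation.py
+ #   class ModelNew(nn.Module):
+ #       def forward(self, x):
+ #           ...  # custom kernel goes here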
+
3455
+ async def run_evaluate_kernelbench_docker(
3456
+ args: KernelBenchEvaluateArgs,
3457
+ target: BaremetalTarget | VMTarget,
3458
+ ) -> EvaluateResult:
3459
+ """Run KernelBench format evaluation in Docker container on SSH-based target.
3460
+
3461
+ Similar to run_evaluate_docker but uses KernelBench eval script instead.
3462
+ """
3463
+ from datetime import datetime
3464
+
3465
+ from wafer_core.async_ssh import AsyncSSHClient
3466
+
3467
+ CONTAINER_WORKSPACE = "/workspace"
3468
+ REMOTE_WORKSPACE_BASE = "~/.wafer/workspaces"
3469
+
3470
+ if not target.docker_image:
3471
+ return EvaluateResult(
3472
+ success=False,
3473
+ all_correct=False,
3474
+ correctness_score=0.0,
3475
+ geomean_speedup=0.0,
3476
+ passed_tests=0,
3477
+ total_tests=0,
3478
+ error_message="docker_image must be set for Docker execution",
3479
+ )
3480
+
3481
+ # Select GPU
3482
+ gpu_id = _select_gpu_id(target, args.gpu_id)
3483
+
3484
+ print(f"Connecting to {target.ssh_target}...")
3485
+
3486
+ async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
3487
+ # Create workspace
3488
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
3489
+ run_dir = f"kernelbench_eval_{timestamp}"
3490
+ workspace_path = await client.expand_path(f"{REMOTE_WORKSPACE_BASE}/kernelbench")
3491
+ run_path = f"{workspace_path}/{run_dir}"
3492
+
3493
+ await client.exec(f"mkdir -p {run_path}")
3494
+ print(f"Created run directory: {run_path}")
3495
+
3496
+ # Read and upload files
3497
+ impl_code = args.implementation.read_text()
3498
+ ref_code = args.reference.read_text()
3499
+
3500
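+ # The quoted heredoc delimiters ('IMPL_EOF', 'REF_EOF', ...) stop the remote shell from
+ # expanding $variables and backticks inside the uploaded source.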
+ # Write implementation
3501
+ impl_path = f"{run_path}/implementation.py"
3502
+ write_result = await client.exec(
3503
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
3504
+ )
3505
+ if write_result.exit_code != 0:
3506
+ return EvaluateResult(
3507
+ success=False,
3508
+ all_correct=False,
3509
+ correctness_score=0.0,
3510
+ geomean_speedup=0.0,
3511
+ passed_tests=0,
3512
+ total_tests=0,
3513
+ error_message=f"Failed to write implementation: {write_result.stderr}",
3514
+ )
3515
+
3516
+ # Write reference
3517
+ ref_path = f"{run_path}/reference.py"
3518
+ write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
3519
+ if write_result.exit_code != 0:
3520
+ return EvaluateResult(
3521
+ success=False,
3522
+ all_correct=False,
3523
+ correctness_score=0.0,
3524
+ geomean_speedup=0.0,
3525
+ passed_tests=0,
3526
+ total_tests=0,
3527
+ error_message=f"Failed to write reference: {write_result.stderr}",
3528
+ )
3529
+
3530
+ # Write custom inputs if provided
3531
+ if args.inputs:
3532
+ inputs_code = args.inputs.read_text()
3533
+ inputs_file_path = f"{run_path}/custom_inputs.py"
3534
+ write_result = await client.exec(
3535
+ f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
3536
+ )
3537
+ if write_result.exit_code != 0:
3538
+ return EvaluateResult(
3539
+ success=False,
3540
+ all_correct=False,
3541
+ correctness_score=0.0,
3542
+ geomean_speedup=0.0,
3543
+ passed_tests=0,
3544
+ total_tests=0,
3545
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
3546
+ )
3547
+
3548
+ # Write eval script
3549
+ eval_script_path = f"{run_path}/kernelbench_eval.py"
3550
+ write_result = await client.exec(
3551
+ f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
3552
+ )
3553
+ if write_result.exit_code != 0:
3554
+ return EvaluateResult(
3555
+ success=False,
3556
+ all_correct=False,
3557
+ correctness_score=0.0,
3558
+ geomean_speedup=0.0,
3559
+ passed_tests=0,
3560
+ total_tests=0,
3561
+ error_message=f"Failed to write eval script: {write_result.stderr}",
3562
+ )
3563
+
3564
+ # Write defense module if defensive mode is enabled
3565
+ defense_module_path = None
3566
+ if args.defensive:
3567
+ defense_path = (
3568
+ Path(__file__).parent.parent.parent.parent
3569
+ / "packages"
3570
+ / "wafer-core"
3571
+ / "wafer_core"
3572
+ / "utils"
3573
+ / "kernel_utils"
3574
+ / "defense.py"
3575
+ )
3576
+ if defense_path.exists():
3577
+ defense_code = defense_path.read_text()
3578
+ defense_module_path = f"{run_path}/defense.py"
3579
+ write_result = await client.exec(
3580
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
3581
+ )
3582
+ if write_result.exit_code != 0:
3583
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
3584
+ defense_module_path = None
3585
+ else:
3586
+ print(f"Warning: defense.py not found at {defense_path}")
3587
+
3588
+ print("Running KernelBench evaluation in Docker container...")
3589
+
3590
+ # Paths inside container
3591
+ container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
3592
+ container_impl_path = f"{container_run_path}/implementation.py"
3593
+ container_ref_path = f"{container_run_path}/reference.py"
3594
+ container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
3595
+ container_eval_script = f"{container_run_path}/kernelbench_eval.py"
3596
+ container_output = f"{container_run_path}/results.json"
3597
+ container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None
3598
+
3599
+ # Build eval command
3600
+ python_cmd_parts = [
3601
+ f"python3 {container_eval_script}",
3602
+ f"--impl {container_impl_path}",
3603
+ f"--reference {container_ref_path}",
3604
+ f"--output {container_output}",
3605
+ ]
3606
+
3607
+ if args.benchmark:
3608
+ python_cmd_parts.append("--benchmark")
3609
+ if args.profile:
3610
+ python_cmd_parts.append("--profile")
3611
+ if container_inputs_path:
3612
+ python_cmd_parts.append(f"--inputs {container_inputs_path}")
3613
+ if args.defensive and container_defense_path:
3614
+ python_cmd_parts.append("--defensive")
3615
+ python_cmd_parts.append(f"--defense-module {container_defense_path}")
3616
+ python_cmd_parts.append(f"--seed {args.seed}")
3617
+ python_cmd_parts.append(f"--stages {args.stages}")
3618
+
3619
+ eval_cmd = " ".join(python_cmd_parts)
3620
+
3621
+ # Build pip install for torch dependencies if needed
3622
+ pip_install_cmd = _build_docker_pip_install_cmd(target)
3623
+ full_cmd = f"{pip_install_cmd} && cd {container_run_path} && {eval_cmd}"
3624
+
3625
+ # Build Docker command
3626
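+ # All GPUs are exposed to the container (--gpus all); the specific device is selected
+ # inside it via CUDA_VISIBLE_DEVICES.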
+ docker_cmd = _build_docker_run_command(
3627
+ image=target.docker_image,
3628
+ command=full_cmd,
3629
+ working_dir=container_run_path,
3630
+ env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
3631
+ gpus="all",
3632
+ volumes={workspace_path: CONTAINER_WORKSPACE},
3633
+ )
3634
+
3635
+ print(f"Docker command: {docker_cmd[:100]}...")
3636
+
3637
+ # Run and stream output
3638
+ log_lines = []
3639
+ async for line in client.exec_stream(docker_cmd):
3640
+ print(line, flush=True)
3641
+ log_lines.append(line)
3642
+
3643
+ # Read results
3644
+ results_path = f"{run_path}/results.json"
3645
+ cat_result = await client.exec(f"cat {results_path}")
3646
+
3647
+ if cat_result.exit_code != 0:
3648
+ log_tail = "\n".join(log_lines[-50:])
3649
+ return EvaluateResult(
3650
+ success=False,
3651
+ all_correct=False,
3652
+ correctness_score=0.0,
3653
+ geomean_speedup=0.0,
3654
+ passed_tests=0,
3655
+ total_tests=0,
3656
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
3657
+ )
3658
+
3659
+ # Parse results
3660
+ try:
3661
+ results_data = json.loads(cat_result.stdout)
3662
+ except json.JSONDecodeError as e:
3663
+ return EvaluateResult(
3664
+ success=False,
3665
+ all_correct=False,
3666
+ correctness_score=0.0,
3667
+ geomean_speedup=0.0,
3668
+ passed_tests=0,
3669
+ total_tests=0,
3670
+ error_message=f"Failed to parse results: {e}",
3671
+ )
3672
+
3673
+ # Convert to EvaluateResult
3674
+ # TODO: use compiled field - currently ignored, should affect success/error
3675
+ # compiled = results_data.get("compiled", False)
3676
+ correct = results_data.get("correct", False)
3677
+ speedup = results_data.get("speedup", 0.0) or 0.0
3678
+ error = results_data.get("error")
3679
+
3680
+ if error:
3681
+ return EvaluateResult(
3682
+ success=False,
3683
+ all_correct=False,
3684
+ correctness_score=0.0,
3685
+ geomean_speedup=0.0,
3686
+ passed_tests=0,
3687
+ total_tests=1,
3688
+ error_message=error,
3689
+ )
3690
+
3691
+ return EvaluateResult(
3692
+ success=True,
3693
+ all_correct=correct,
3694
+ correctness_score=1.0 if correct else 0.0,
3695
+ geomean_speedup=speedup,
3696
+ passed_tests=1 if correct else 0,
3697
+ total_tests=1,
3698
+ )
3699
+
3700
+
3701
+ # Default ROCm PyTorch image for DigitalOcean AMD MI300X
3702
+ DEFAULT_ROCM_DOCKER_IMAGE = "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0"
3703
+
3704
+
3705
+ async def run_evaluate_kernelbench_digitalocean(
3706
+ args: KernelBenchEvaluateArgs,
3707
+ target: DigitalOceanTarget,
3708
+ ) -> EvaluateResult:
3709
+ """Run KernelBench format evaluation in Docker container on DigitalOcean AMD GPU.
3710
+
3711
+ Uses ROCm Docker image with device passthrough for AMD GPUs.
3712
+ """
3713
+ from datetime import datetime
3714
+
3715
+ import trio_asyncio
3716
+ from wafer_core.async_ssh import AsyncSSHClient
3717
+ from wafer_core.targets.digitalocean import digitalocean_ssh_context
3718
+
3719
+ CONTAINER_WORKSPACE = "/workspace"
3720
+ REMOTE_WORKSPACE_BASE = "~/.wafer/workspaces"
3721
+
3722
+ docker_image = DEFAULT_ROCM_DOCKER_IMAGE
3723
+
3724
+ # Select GPU
3725
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
3726
+
3727
+ print("Provisioning/connecting to DigitalOcean droplet...")
3728
+
3729
+ async with digitalocean_ssh_context(target) as ssh_info:
3730
+ ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
3731
+ print(f"Connected to {ssh_target}")
3732
+
3733
+ async with trio_asyncio.open_loop():
3734
+ async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
3735
+ # Ensure Docker is installed
3736
+ docker_check = await client.exec("which docker")
3737
+ if docker_check.exit_code != 0:
3738
+ print("Docker not found, installing...")
3739
+ install_result = await client.exec(
3740
+ "apt-get update -qq && apt-get install -y -qq docker.io"
3741
+ )
3742
+ if install_result.exit_code != 0:
3743
+ return EvaluateResult(
3744
+ success=False,
3745
+ all_correct=False,
3746
+ correctness_score=0.0,
3747
+ geomean_speedup=0.0,
3748
+ passed_tests=0,
3749
+ total_tests=0,
3750
+ error_message=f"Failed to install Docker: {install_result.stderr}",
3751
+ )
3752
+
3753
+ # Create workspace
3754
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
3755
+ run_dir = f"kernelbench_eval_{timestamp}"
3756
+ workspace_path = await client.expand_path(f"{REMOTE_WORKSPACE_BASE}/kernelbench")
3757
+ run_path = f"{workspace_path}/{run_dir}"
3758
+
3759
+ await client.exec(f"mkdir -p {run_path}")
3760
+ print(f"Created run directory: {run_path}")
3761
+
3762
+ # Read and upload files
3763
+ impl_code = args.implementation.read_text()
3764
+ ref_code = args.reference.read_text()
3765
+
3766
+ # Write implementation
3767
+ impl_path = f"{run_path}/implementation.py"
3768
+ write_result = await client.exec(
3769
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
3770
+ )
3771
+ if write_result.exit_code != 0:
3772
+ return EvaluateResult(
3773
+ success=False,
3774
+ all_correct=False,
3775
+ correctness_score=0.0,
3776
+ geomean_speedup=0.0,
3777
+ passed_tests=0,
3778
+ total_tests=0,
3779
+ error_message=f"Failed to write implementation: {write_result.stderr}",
3780
+ )
3781
+
3782
+ # Write reference
3783
+ ref_path = f"{run_path}/reference.py"
3784
+ write_result = await client.exec(
3785
+ f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
3786
+ )
3787
+ if write_result.exit_code != 0:
3788
+ return EvaluateResult(
3789
+ success=False,
3790
+ all_correct=False,
3791
+ correctness_score=0.0,
3792
+ geomean_speedup=0.0,
3793
+ passed_tests=0,
3794
+ total_tests=0,
3795
+ error_message=f"Failed to write reference: {write_result.stderr}",
3796
+ )
3797
+
3798
+ # Write custom inputs if provided
3799
+ if args.inputs:
3800
+ inputs_code = args.inputs.read_text()
3801
+ inputs_file_path = f"{run_path}/custom_inputs.py"
3802
+ write_result = await client.exec(
3803
+ f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
3804
+ )
3805
+ if write_result.exit_code != 0:
3806
+ return EvaluateResult(
3807
+ success=False,
3808
+ all_correct=False,
3809
+ correctness_score=0.0,
3810
+ geomean_speedup=0.0,
3811
+ passed_tests=0,
3812
+ total_tests=0,
3813
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
3814
+ )
3815
+
3816
+ # Write eval script
3817
+ eval_script_path = f"{run_path}/kernelbench_eval.py"
3818
+ write_result = await client.exec(
3819
+ f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
3820
+ )
3821
+ if write_result.exit_code != 0:
3822
+ return EvaluateResult(
3823
+ success=False,
3824
+ all_correct=False,
3825
+ correctness_score=0.0,
3826
+ geomean_speedup=0.0,
3827
+ passed_tests=0,
3828
+ total_tests=0,
3829
+ error_message=f"Failed to write eval script: {write_result.stderr}",
3830
+ )
3831
+
3832
+ # Write defense module if defensive mode is enabled
3833
+ defense_module_path = None
3834
+ if args.defensive:
3835
+ defense_path = (
3836
+ Path(__file__).parent.parent.parent.parent
3837
+ / "packages"
3838
+ / "wafer-core"
3839
+ / "wafer_core"
3840
+ / "utils"
3841
+ / "kernel_utils"
3842
+ / "defense.py"
3843
+ )
3844
+ if defense_path.exists():
3845
+ defense_code = defense_path.read_text()
3846
+ defense_module_path = f"{run_path}/defense.py"
3847
+ write_result = await client.exec(
3848
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
3849
+ )
3850
+ if write_result.exit_code != 0:
3851
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
3852
+ defense_module_path = None
3853
+ else:
3854
+ print(f"Warning: defense.py not found at {defense_path}")
3855
+
3856
+ print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
3857
+
3858
+ # Paths inside container
3859
+ container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
3860
+ container_impl_path = f"{container_run_path}/implementation.py"
3861
+ container_ref_path = f"{container_run_path}/reference.py"
3862
+ container_inputs_path = (
3863
+ f"{container_run_path}/custom_inputs.py" if args.inputs else None
3864
+ )
3865
+ container_eval_script = f"{container_run_path}/kernelbench_eval.py"
3866
+ container_output = f"{container_run_path}/results.json"
3867
+ container_defense_path = (
3868
+ f"{container_run_path}/defense.py" if defense_module_path else None
3869
+ )
3870
+
3871
+ # Build eval command
3872
+ python_cmd_parts = [
3873
+ f"python3 {container_eval_script}",
3874
+ f"--impl {container_impl_path}",
3875
+ f"--reference {container_ref_path}",
3876
+ f"--output {container_output}",
3877
+ ]
3878
+
3879
+ if args.benchmark:
3880
+ python_cmd_parts.append("--benchmark")
3881
+ if args.profile:
3882
+ python_cmd_parts.append("--profile")
3883
+ if container_inputs_path:
3884
+ python_cmd_parts.append(f"--inputs {container_inputs_path}")
3885
+ if args.defensive and container_defense_path:
3886
+ python_cmd_parts.append("--defensive")
3887
+ python_cmd_parts.append(f"--defense-module {container_defense_path}")
3888
+ python_cmd_parts.append(f"--seed {args.seed}")
3889
+ python_cmd_parts.append(f"--stages {args.stages}")
3890
+
3891
+ eval_cmd = " ".join(python_cmd_parts)
3892
+
3893
+ # For AMD, we don't need pip install - the ROCm image has everything
3894
+ full_cmd = f"cd {container_run_path} && {eval_cmd}"
3895
+
3896
+ # Build Docker command for AMD
3897
+ # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
3898
+ rocm_arch = _get_rocm_arch(target.compute_capability)
3899
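+ # HIP_VISIBLE_DEVICES is ROCm's counterpart to CUDA_VISIBLE_DEVICES for GPU selection.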
+ env_dict = {
3900
+ "HIP_VISIBLE_DEVICES": str(gpu_id),
3901
+ "PYTHONUNBUFFERED": "1",
3902
+ }
3903
+ if rocm_arch:
3904
+ env_dict["PYTORCH_ROCM_ARCH"] = rocm_arch
3905
+
3906
+ docker_cmd = _build_docker_run_command_amd(
3907
+ image=docker_image,
3908
+ command=full_cmd,
3909
+ working_dir=container_run_path,
3910
+ env=env_dict,
3911
+ volumes={workspace_path: CONTAINER_WORKSPACE},
3912
+ )
3913
+
3914
+ print(f"Docker command: {docker_cmd[:100]}...")
3915
+
3916
+ # Run and stream output
3917
+ log_lines = []
3918
+ async for line in client.exec_stream(docker_cmd):
3919
+ print(line, flush=True)
3920
+ log_lines.append(line)
3921
+
3922
+ # Read results
3923
+ results_path = f"{run_path}/results.json"
3924
+ cat_result = await client.exec(f"cat {results_path}")
3925
+
3926
+ if cat_result.exit_code != 0:
3927
+ log_tail = "\n".join(log_lines[-50:])
3928
+ return EvaluateResult(
3929
+ success=False,
3930
+ all_correct=False,
3931
+ correctness_score=0.0,
3932
+ geomean_speedup=0.0,
3933
+ passed_tests=0,
3934
+ total_tests=0,
3935
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
3936
+ )
3937
+
3938
+ # Parse results
3939
+ try:
3940
+ results_data = json.loads(cat_result.stdout)
3941
+ except json.JSONDecodeError as e:
3942
+ return EvaluateResult(
3943
+ success=False,
3944
+ all_correct=False,
3945
+ correctness_score=0.0,
3946
+ geomean_speedup=0.0,
3947
+ passed_tests=0,
3948
+ total_tests=0,
3949
+ error_message=f"Failed to parse results: {e}",
3950
+ )
3951
+
3952
+ # Convert to EvaluateResult
3953
+ # TODO: use compiled field - currently ignored, should affect success/error
3954
+ # compiled = results_data.get("compiled", False)
3955
+ correct = results_data.get("correct", False)
3956
+ speedup = results_data.get("speedup", 0.0) or 0.0
3957
+ error = results_data.get("error")
3958
+
3959
+ if error:
3960
+ return EvaluateResult(
3961
+ success=False,
3962
+ all_correct=False,
3963
+ correctness_score=0.0,
3964
+ geomean_speedup=0.0,
3965
+ passed_tests=0,
3966
+ total_tests=1,
3967
+ error_message=error,
3968
+ )
3969
+
3970
+ return EvaluateResult(
3971
+ success=True,
3972
+ all_correct=correct,
3973
+ correctness_score=1.0 if correct else 0.0,
3974
+ geomean_speedup=speedup,
3975
+ passed_tests=1 if correct else 0,
3976
+ total_tests=1,
3977
+ )
3978
+
3979
+
3980
+ async def run_evaluate_kernelbench_runpod(
3981
+ args: KernelBenchEvaluateArgs,
3982
+ target: RunPodTarget,
3983
+ ) -> EvaluateResult:
3984
+ """Run KernelBench format evaluation directly on RunPod AMD GPU.
3985
+
3986
+ Runs evaluation script directly on host (no Docker) since RunPod pods
3987
+ already have PyTorch/ROCm installed.
3988
+ """
3989
+ from datetime import datetime
3990
+
3991
+ from wafer_core.async_ssh import AsyncSSHClient
3992
+ from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
3993
+
3994
+ REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
3995
+
3996
+ # Select GPU
3997
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
3998
+
3999
+ print(f"Provisioning RunPod ({target.gpu_type_id})...")
4000
+
4001
+ try:
4002
+ async with runpod_ssh_context(target) as ssh_info:
4003
+ ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
4004
+ print(f"Connected to RunPod: {ssh_target}")
4005
+
4006
+ async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
4007
+ # Create workspace
4008
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
4009
+ run_dir = f"kernelbench_eval_{timestamp}"
4010
+ run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
4011
+
4012
+ await client.exec(f"mkdir -p {run_path}")
4013
+ print(f"Created run directory: {run_path}")
4014
+
4015
+ # Read and upload files
4016
+ impl_code = args.implementation.read_text()
4017
+ ref_code = args.reference.read_text()
4018
+
4019
+ # Write implementation
4020
+ impl_path = f"{run_path}/implementation.py"
4021
+ write_result = await client.exec(
4022
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
4023
+ )
4024
+ if write_result.exit_code != 0:
4025
+ return EvaluateResult(
4026
+ success=False,
4027
+ all_correct=False,
4028
+ correctness_score=0.0,
4029
+ geomean_speedup=0.0,
4030
+ passed_tests=0,
4031
+ total_tests=0,
4032
+ error_message=f"Failed to write implementation: {write_result.stderr}",
4033
+ )
4034
+
4035
+ # Write reference
4036
+ ref_path = f"{run_path}/reference.py"
4037
+ write_result = await client.exec(
4038
+ f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
4039
+ )
4040
+ if write_result.exit_code != 0:
4041
+ return EvaluateResult(
4042
+ success=False,
4043
+ all_correct=False,
4044
+ correctness_score=0.0,
4045
+ geomean_speedup=0.0,
4046
+ passed_tests=0,
4047
+ total_tests=0,
4048
+ error_message=f"Failed to write reference: {write_result.stderr}",
4049
+ )
4050
+
4051
+ # Write custom inputs if provided
4052
+ inputs_path = None
4053
+ if args.inputs:
4054
+ inputs_code = args.inputs.read_text()
4055
+ inputs_path = f"{run_path}/custom_inputs.py"
4056
+ write_result = await client.exec(
4057
+ f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
4058
+ )
4059
+ if write_result.exit_code != 0:
4060
+ return EvaluateResult(
4061
+ success=False,
4062
+ all_correct=False,
4063
+ correctness_score=0.0,
4064
+ geomean_speedup=0.0,
4065
+ passed_tests=0,
4066
+ total_tests=0,
4067
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
4068
+ )
4069
+
4070
+ # Write eval script
4071
+ eval_script_path = f"{run_path}/kernelbench_eval.py"
4072
+ write_result = await client.exec(
4073
+ f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
4074
+ )
4075
+ if write_result.exit_code != 0:
4076
+ return EvaluateResult(
4077
+ success=False,
4078
+ all_correct=False,
4079
+ correctness_score=0.0,
4080
+ geomean_speedup=0.0,
4081
+ passed_tests=0,
4082
+ total_tests=0,
4083
+ error_message=f"Failed to write eval script: {write_result.stderr}",
4084
+ )
4085
+
4086
+ # Write defense module if defensive mode is enabled
4087
+ defense_module_path = None
4088
+ if args.defensive:
4089
+ defense_path = (
4090
+ Path(__file__).parent.parent.parent.parent
4091
+ / "packages"
4092
+ / "wafer-core"
4093
+ / "wafer_core"
4094
+ / "utils"
4095
+ / "kernel_utils"
4096
+ / "defense.py"
4097
+ )
4098
+ if defense_path.exists():
4099
+ defense_code = defense_path.read_text()
4100
+ defense_module_path = f"{run_path}/defense.py"
4101
+ write_result = await client.exec(
4102
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
4103
+ )
4104
+ if write_result.exit_code != 0:
4105
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
4106
+ defense_module_path = None
4107
+ else:
4108
+ print(f"Warning: defense.py not found at {defense_path}")
4109
+
4110
+ print("Running KernelBench evaluation (AMD/ROCm)...")
4111
+
4112
+ # Find Python with PyTorch - check common locations on RunPod
4113
+ python_exe = "python3"
4114
+ for candidate in [
4115
+ "/opt/conda/envs/py_3.10/bin/python3",
4116
+ "/opt/conda/bin/python3",
4117
+ ]:
4118
+ check = await client.exec(
4119
+ f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
4120
+ )
4121
+ if "OK" in check.stdout:
4122
+ python_exe = candidate
4123
+ print(f"Using Python: {python_exe}")
4124
+ break
4125
+
4126
+ # Build eval command - run directly on host
4127
+ output_path = f"{run_path}/results.json"
4128
+ python_cmd_parts = [
4129
+ f"{python_exe} {eval_script_path}",
4130
+ f"--impl {impl_path}",
4131
+ f"--reference {ref_path}",
4132
+ f"--output {output_path}",
4133
+ ]
4134
+
4135
+ if args.benchmark:
4136
+ python_cmd_parts.append("--benchmark")
4137
+ if args.profile:
4138
+ python_cmd_parts.append("--profile")
4139
+ if inputs_path:
4140
+ python_cmd_parts.append(f"--inputs {inputs_path}")
4141
+ if args.defensive and defense_module_path:
4142
+ python_cmd_parts.append("--defensive")
4143
+ python_cmd_parts.append(f"--defense-module {defense_module_path}")
4144
+ python_cmd_parts.append(f"--seed {args.seed}")
4145
+ python_cmd_parts.append(f"--stages {args.stages}")
4146
+
4147
+ eval_cmd = " ".join(python_cmd_parts)
4148
+
4149
+ # Set environment for AMD GPU and run
4150
+ # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
4151
+ rocm_arch = _get_rocm_arch(target.compute_capability)
4152
+ arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
4153
+ env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
4154
+ full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
4155
+
4156
+ # Handle prepare-only mode
4157
+ if args.prepare_only:
4158
+ print(f"\n[wafer] Prepared evaluation at: {run_path}")
4159
+ print(f"[wafer] Target: {target.name} ({client.host}:{client.port})")
4160
+ print("[wafer] To run manually:")
4161
+ print(f" ssh -p {client.port} root@{client.host} '{full_cmd}'")
4162
+ print("\n[wafer] Or wrap with rocprof:")
4163
+ print(
4164
+ f" ssh -p {client.port} root@{client.host} 'cd {run_path} && {env_vars} rocprof -i counters.txt {eval_cmd}'"
4165
+ )
4166
+ return EvaluateResult(
4167
+ success=True,
4168
+ all_correct=None, # Not checked in prepare-only mode
4169
+ correctness_score=0.0,
4170
+ geomean_speedup=0.0,
4171
+ passed_tests=0,
4172
+ total_tests=0,
4173
+ error_message=None,
4174
+ )
4175
+
4176
+ # Run and stream output
4177
+ log_lines = []
4178
+ async for line in client.exec_stream(full_cmd):
4179
+ print(line, flush=True)
4180
+ log_lines.append(line)
4181
+
4182
+ # Read results
4183
+ cat_result = await client.exec(f"cat {output_path}")
4184
+
4185
+ if cat_result.exit_code != 0:
4186
+ log_tail = "\n".join(log_lines[-50:])
4187
+ return EvaluateResult(
4188
+ success=False,
4189
+ all_correct=False,
4190
+ correctness_score=0.0,
4191
+ geomean_speedup=0.0,
4192
+ passed_tests=0,
4193
+ total_tests=0,
4194
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
4195
+ )
4196
+
4197
+ # Parse results
4198
+ try:
4199
+ results_data = json.loads(cat_result.stdout)
4200
+ except json.JSONDecodeError as e:
4201
+ return EvaluateResult(
4202
+ success=False,
4203
+ all_correct=False,
4204
+ correctness_score=0.0,
4205
+ geomean_speedup=0.0,
4206
+ passed_tests=0,
4207
+ total_tests=0,
4208
+ error_message=f"Failed to parse results: {e}",
4209
+ )
4210
+
4211
+ # Convert to EvaluateResult
4212
+ correct = results_data.get("correct", False)
4213
+ speedup = results_data.get("speedup", 0.0) or 0.0
4214
+ error = results_data.get("error")
4215
+
4216
+ if error:
4217
+ return EvaluateResult(
4218
+ success=False,
4219
+ all_correct=False,
4220
+ correctness_score=0.0,
4221
+ geomean_speedup=0.0,
4222
+ passed_tests=0,
4223
+ total_tests=1,
4224
+ error_message=error,
4225
+ )
4226
+
4227
+ return EvaluateResult(
4228
+ success=True,
4229
+ all_correct=correct,
4230
+ correctness_score=1.0 if correct else 0.0,
4231
+ geomean_speedup=speedup,
4232
+ passed_tests=1 if correct else 0,
4233
+ total_tests=1,
4234
+ )
4235
+
4236
+ except RunPodError as e:
4237
+ return EvaluateResult(
4238
+ success=False,
4239
+ all_correct=False,
4240
+ correctness_score=0.0,
4241
+ geomean_speedup=0.0,
4242
+ passed_tests=0,
4243
+ total_tests=0,
4244
+ error_message=f"RunPod error: {e}",
4245
+ )
4246
+
4247
+
4248
+ async def run_evaluate_kernelbench_baremetal_amd(
4249
+ args: KernelBenchEvaluateArgs,
4250
+ target: BaremetalTarget,
4251
+ ) -> EvaluateResult:
4252
+ """Run KernelBench format evaluation directly on AMD baremetal target.
4253
+
4254
+ Runs evaluation script directly on host (no Docker) for AMD GPUs
4255
+ that have PyTorch/ROCm installed.
4256
+ """
4257
+ from datetime import datetime
4258
+
4259
+ from wafer_core.async_ssh import AsyncSSHClient
4260
+
4261
+ REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
4262
+
4263
+ # Select GPU
4264
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
4265
+
4266
+ print(f"Connecting to {target.ssh_target}...")
4267
+
4268
+ async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
4269
+ # Create workspace
4270
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
4271
+ run_dir = f"kernelbench_eval_{timestamp}"
4272
+ run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
4273
+
4274
+ await client.exec(f"mkdir -p {run_path}")
4275
+ print(f"Created run directory: {run_path}")
4276
+
4277
+ # Read and upload files
4278
+ impl_code = args.implementation.read_text()
4279
+ ref_code = args.reference.read_text()
4280
+
4281
+ # Write implementation
4282
+ impl_path = f"{run_path}/implementation.py"
4283
+ write_result = await client.exec(
4284
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
4285
+ )
4286
+ if write_result.exit_code != 0:
4287
+ return EvaluateResult(
4288
+ success=False,
4289
+ all_correct=False,
4290
+ correctness_score=0.0,
4291
+ geomean_speedup=0.0,
4292
+ passed_tests=0,
4293
+ total_tests=0,
4294
+ error_message=f"Failed to write implementation: {write_result.stderr}",
4295
+ )
4296
+
4297
+ # Write reference
4298
+ ref_path = f"{run_path}/reference.py"
4299
+ write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
4300
+ if write_result.exit_code != 0:
4301
+ return EvaluateResult(
4302
+ success=False,
4303
+ all_correct=False,
4304
+ correctness_score=0.0,
4305
+ geomean_speedup=0.0,
4306
+ passed_tests=0,
4307
+ total_tests=0,
4308
+ error_message=f"Failed to write reference: {write_result.stderr}",
4309
+ )
4310
+
4311
+ # Write custom inputs if provided
4312
+ inputs_path = None
4313
+ if args.inputs:
4314
+ inputs_code = args.inputs.read_text()
4315
+ inputs_path = f"{run_path}/custom_inputs.py"
4316
+ write_result = await client.exec(
4317
+ f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
4318
+ )
4319
+ if write_result.exit_code != 0:
4320
+ return EvaluateResult(
4321
+ success=False,
4322
+ all_correct=False,
4323
+ correctness_score=0.0,
4324
+ geomean_speedup=0.0,
4325
+ passed_tests=0,
4326
+ total_tests=0,
4327
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
4328
+ )
4329
+
4330
+ # Write eval script
4331
+ eval_script_path = f"{run_path}/kernelbench_eval.py"
4332
+ write_result = await client.exec(
4333
+ f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
4334
+ )
4335
+ if write_result.exit_code != 0:
4336
+ return EvaluateResult(
4337
+ success=False,
4338
+ all_correct=False,
4339
+ correctness_score=0.0,
4340
+ geomean_speedup=0.0,
4341
+ passed_tests=0,
4342
+ total_tests=0,
4343
+ error_message=f"Failed to write eval script: {write_result.stderr}",
4344
+ )
4345
+
4346
+ # Write defense module if defensive mode is enabled
4347
+ defense_module_path = None
4348
+ if args.defensive:
4349
+ defense_path = (
4350
+ Path(__file__).parent.parent.parent.parent
4351
+ / "packages"
4352
+ / "wafer-core"
4353
+ / "wafer_core"
4354
+ / "utils"
4355
+ / "kernel_utils"
4356
+ / "defense.py"
4357
+ )
4358
+ if defense_path.exists():
4359
+ defense_code = defense_path.read_text()
4360
+ defense_module_path = f"{run_path}/defense.py"
4361
+ write_result = await client.exec(
4362
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
4363
+ )
4364
+ if write_result.exit_code != 0:
4365
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
4366
+ defense_module_path = None
4367
+ else:
4368
+ print(f"Warning: defense.py not found at {defense_path}")
4369
+
4370
+ print("Running KernelBench evaluation (AMD/ROCm)...")
4371
+
4372
+ # Find Python with PyTorch - check common locations
4373
+ python_exe = "python3"
4374
+ for candidate in [
4375
+ "/opt/conda/envs/py_3.10/bin/python3",
4376
+ "/opt/conda/bin/python3",
4377
+ ]:
4378
+ check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
4379
+ if "OK" in check.stdout:
4380
+ python_exe = candidate
4381
+ print(f"Using Python: {python_exe}")
4382
+ break
4383
+
4384
+ # Build eval command - run directly on host
4385
+ output_path = f"{run_path}/results.json"
4386
+ python_cmd_parts = [
4387
+ f"{python_exe} {eval_script_path}",
4388
+ f"--impl {impl_path}",
4389
+ f"--reference {ref_path}",
4390
+ f"--output {output_path}",
4391
+ ]
4392
+
4393
+ if args.benchmark:
4394
+ python_cmd_parts.append("--benchmark")
4395
+ if args.profile:
4396
+ python_cmd_parts.append("--profile")
4397
+ if inputs_path:
4398
+ python_cmd_parts.append(f"--inputs {inputs_path}")
4399
+ if args.defensive and defense_module_path:
4400
+ python_cmd_parts.append("--defensive")
4401
+ python_cmd_parts.append(f"--defense-module {defense_module_path}")
4402
+ python_cmd_parts.append(f"--seed {args.seed}")
4403
+ python_cmd_parts.append(f"--stages {args.stages}")
4404
+
4405
+ eval_cmd = " ".join(python_cmd_parts)
4406
+
4407
+ # Set environment for AMD GPU and run
4408
+ # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
4409
+ rocm_arch = _get_rocm_arch(target.compute_capability)
4410
+ arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
4411
+ env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
4412
+ full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
4413
+
4414
+ # Handle prepare-only mode
4415
+ if args.prepare_only:
4416
+ print(f"\n[wafer] Prepared evaluation at: {run_path}")
4417
+ print(f"[wafer] Target: {target.name} ({client.host}:{client.port})")
4418
+ print("[wafer] To run manually:")
4419
+ print(f" ssh -p {client.port} root@{client.host} '{full_cmd}'")
4420
+ print("\n[wafer] Or wrap with rocprof:")
4421
+ print(
4422
+ f" ssh -p {client.port} root@{client.host} 'cd {run_path} && {env_vars} rocprof -i counters.txt {eval_cmd}'"
4423
+ )
4424
+ return EvaluateResult(
4425
+ success=True,
4426
+ all_correct=None, # Not checked in prepare-only mode
4427
+ correctness_score=0.0,
4428
+ geomean_speedup=0.0,
4429
+ passed_tests=0,
4430
+ total_tests=0,
4431
+ error_message=None,
4432
+ )
4433
+
4434
+ # Run and stream output
4435
+ log_lines = []
4436
+ async for line in client.exec_stream(full_cmd):
4437
+ print(line, flush=True)
4438
+ log_lines.append(line)
4439
+
4440
+ # Read results
4441
+ cat_result = await client.exec(f"cat {output_path}")
4442
+
4443
+ if cat_result.exit_code != 0:
4444
+ log_tail = "\n".join(log_lines[-50:])
4445
+ return EvaluateResult(
4446
+ success=False,
4447
+ all_correct=False,
4448
+ correctness_score=0.0,
4449
+ geomean_speedup=0.0,
4450
+ passed_tests=0,
4451
+ total_tests=0,
4452
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
4453
+ )
4454
+
4455
+ # Parse results
4456
+ try:
4457
+ results_data = json.loads(cat_result.stdout)
4458
+ except json.JSONDecodeError as e:
4459
+ return EvaluateResult(
4460
+ success=False,
4461
+ all_correct=False,
4462
+ correctness_score=0.0,
4463
+ geomean_speedup=0.0,
4464
+ passed_tests=0,
4465
+ total_tests=0,
4466
+ error_message=f"Failed to parse results: {e}",
4467
+ )
4468
+
4469
+ # Convert to EvaluateResult
4470
+ correct = results_data.get("correct", False)
4471
+ speedup = results_data.get("speedup", 0.0) or 0.0
4472
+ error = results_data.get("error")
4473
+
4474
+ if error:
4475
+ return EvaluateResult(
4476
+ success=False,
4477
+ all_correct=False,
4478
+ correctness_score=0.0,
4479
+ geomean_speedup=0.0,
4480
+ passed_tests=0,
4481
+ total_tests=1,
4482
+ error_message=error,
4483
+ )
4484
+
4485
+ return EvaluateResult(
4486
+ success=True,
4487
+ all_correct=correct,
4488
+ correctness_score=1.0 if correct else 0.0,
4489
+ geomean_speedup=speedup,
4490
+ passed_tests=1 if correct else 0,
4491
+ total_tests=1,
4492
+ )
4493
+
4494
+
4495
+ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
4496
+ """Run KernelBench format evaluation on configured target.
4497
+
4498
+ Args:
4499
+ args: KernelBench evaluate arguments
4500
+
4501
+ Returns:
4502
+ Evaluation result
4503
+ """
4504
+ from .targets import get_default_target, load_target
4505
+
4506
+ # Validate input files
4507
+ err = _validate_kernelbench_files(args)
4508
+ if err:
4509
+ return EvaluateResult(
4510
+ success=False,
4511
+ all_correct=False,
4512
+ correctness_score=0.0,
4513
+ geomean_speedup=0.0,
4514
+ passed_tests=0,
4515
+ total_tests=0,
4516
+ error_message=err,
4517
+ )
4518
+
4519
+ # Load target
4520
+ target_name = args.target_name
4521
+ if not target_name:
4522
+ target_name = get_default_target()
4523
+ if not target_name:
4524
+ return EvaluateResult(
4525
+ success=False,
4526
+ all_correct=False,
4527
+ correctness_score=0.0,
4528
+ geomean_speedup=0.0,
4529
+ passed_tests=0,
4530
+ total_tests=0,
4531
+ error_message=(
4532
+ "No target specified and no default set.\n"
4533
+ "Set up a target first:\n"
4534
+ " wafer config targets init ssh --name my-gpu --host user@host:22\n"
4535
+ " wafer config targets init runpod --gpu MI300X\n"
4536
+ "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
4537
+ ),
4538
+ )
4539
+
4540
+ try:
4541
+ target = load_target(target_name)
4542
+ except FileNotFoundError:
4543
+ return EvaluateResult(
4544
+ success=False,
4545
+ all_correct=False,
4546
+ correctness_score=0.0,
4547
+ geomean_speedup=0.0,
4548
+ passed_tests=0,
4549
+ total_tests=0,
4550
+ error_message=f"Target not found: {target_name}. Run: wafer config targets list",
4551
+ )
4552
+
4553
+ print(f"Using target: {target_name}")
4554
+
4555
+ # Dispatch to appropriate executor
4556
+ if isinstance(target, DigitalOceanTarget):
4557
+ # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
4558
+ return await run_evaluate_kernelbench_digitalocean(args, target)
4559
+ elif isinstance(target, RunPodTarget):
4560
+ # RunPod AMD MI300X - runs the eval script directly on the pod (PyTorch/ROCm preinstalled, no Docker)
4561
+ return await run_evaluate_kernelbench_runpod(args, target)
4562
+ elif isinstance(target, BaremetalTarget | VMTarget):
4563
+ # Check if this is an AMD target (gfx* compute capability) - run directly
4564
+ if target.compute_capability and target.compute_capability.startswith("gfx"):
4565
+ return await run_evaluate_kernelbench_baremetal_amd(args, target)
4566
+ # NVIDIA targets - require docker_image to be set
4567
+ if not target.docker_image:
4568
+ return EvaluateResult(
4569
+ success=False,
4570
+ all_correct=False,
4571
+ correctness_score=0.0,
4572
+ geomean_speedup=0.0,
4573
+ passed_tests=0,
4574
+ total_tests=0,
4575
+ error_message=(
4576
+ f"Target '{target_name}' does not have docker_image set. "
4577
+ "KernelBench format requires Docker execution."
4578
+ ),
4579
+ )
4580
+ return await run_evaluate_kernelbench_docker(args, target)
4581
+ else:
4582
+ return EvaluateResult(
4583
+ success=False,
4584
+ all_correct=False,
4585
+ correctness_score=0.0,
4586
+ geomean_speedup=0.0,
4587
+ passed_tests=0,
4588
+ total_tests=0,
4589
+ error_message=(
4590
+ f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
4591
+ "Use a DigitalOcean, RunPod, Baremetal, or VM target."
4592
+ ),
4593
+ )