wafer-cli 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/evaluate.py CHANGED
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
 from wafer_core.utils.kernel_utils.targets.config import (
     BaremetalTarget,
     DigitalOceanTarget,
+    LocalTarget,
     ModalTarget,
     RunPodTarget,
     VMTarget,
@@ -21,6 +22,30 @@ from wafer_core.utils.kernel_utils.targets.config import (
 )
 
 
+# Map AMD compute capability to ROCm architecture
+# Used to set PYTORCH_ROCM_ARCH for faster compilation (compile only for target arch)
+AMD_CC_TO_ARCH = {
+    "9.4": "gfx942",  # MI300X
+    "9.0a": "gfx90a",  # MI200 series
+    "9.08": "gfx908",  # MI100
+    "9.06": "gfx906",  # MI50/60
+    "10.30": "gfx1030",  # RDNA2
+    "11.0": "gfx1100",  # RDNA3
+}
+
+
+def _get_rocm_arch(compute_capability: str) -> str | None:
+    """Get ROCm architecture string from compute capability.
+
+    Returns gfx* string for PYTORCH_ROCM_ARCH, or None if not found.
+    """
+    # Already a gfx string
+    if compute_capability.startswith("gfx"):
+        return compute_capability
+    # Map from numeric CC
+    return AMD_CC_TO_ARCH.get(compute_capability)
+
+
 def _build_docker_run_command(
     image: str,
     command: str,
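Aside: a minimal sketch (not part of the package) of how the mapping above is meant to be consumed. The table here is abbreviated, and the environment handling mirrors the `env_dict` wiring later in this diff:

```python
# Sketch: pin PYTORCH_ROCM_ARCH so HIP extensions compile for one arch only.
# Abbreviated copy of the AMD_CC_TO_ARCH table from this diff.
import os

AMD_CC_TO_ARCH = {"9.4": "gfx942", "9.0a": "gfx90a"}

def _get_rocm_arch(compute_capability: str) -> str | None:
    if compute_capability.startswith("gfx"):  # already a gfx string
        return compute_capability
    return AMD_CC_TO_ARCH.get(compute_capability)

arch = _get_rocm_arch("9.4")
if arch:
    # Restrict the HIP build to the target architecture instead of all gfx targets
    os.environ["PYTORCH_ROCM_ARCH"] = arch  # "gfx942" on MI300X
```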
@@ -161,6 +186,7 @@ class KernelBenchEvaluateArgs:
     inputs: Path | None = None  # Custom inputs file to override get_inputs()
     seed: int = 42  # Random seed for reproducibility
     defensive: bool = False
+    backend: str | None = None  # Kernel backend for static validation
     sync_artifacts: bool = True
     gpu_id: int | None = None
 
@@ -396,33 +422,6 @@ async def run_evaluate_docker(
     print(f"Connecting to {target.ssh_target}...")
 
     async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
-        # Upload wafer-core to remote
-        try:
-            wafer_root = _get_wafer_root()
-            wafer_core_path = wafer_root / "packages" / "wafer-core"
-            print(f"Uploading wafer-core from {wafer_core_path}...")
-
-            # Create workspace and upload
-            workspace_name = wafer_core_path.name
-            remote_workspace = f"{REMOTE_WORKSPACE_BASE}/{workspace_name}"
-            await client.exec(f"mkdir -p {remote_workspace}")
-            wafer_core_workspace = await client.expand_path(remote_workspace)
-
-            upload_result = await client.upload_files(
-                str(wafer_core_path), wafer_core_workspace, recursive=True
-            )
-            print(f"Uploaded {upload_result.files_copied} files")
-        except Exception as e:
-            return EvaluateResult(
-                success=False,
-                all_correct=False,
-                correctness_score=0.0,
-                geomean_speedup=0.0,
-                passed_tests=0,
-                total_tests=0,
-                error_message=f"Failed to upload wafer-core: {e}",
-            )
-
         print(f"Using Docker image: {target.docker_image}")
         print(f"Using GPU {gpu_id}...")
 
@@ -431,10 +430,13 @@ async def run_evaluate_docker(
         ref_code = args.reference.read_text()
         test_cases_data = json.loads(args.test_cases.read_text())
 
-        # Create a unique run directory
+        # Create workspace for evaluation files
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         run_dir = f"wafer_eval_{timestamp}"
-        run_path = f"{wafer_core_workspace}/{run_dir}"
+        eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
+        await client.exec(f"mkdir -p {eval_workspace}")
+        eval_workspace_expanded = await client.expand_path(eval_workspace)
+        run_path = f"{eval_workspace_expanded}/{run_dir}"
 
         print("Uploading evaluation files...")
 
@@ -521,17 +523,16 @@ async def run_evaluate_docker(
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
         container_test_cases_path = f"{container_run_path}/test_cases.json"
-        container_evaluate_script = (
-            f"{CONTAINER_WORKSPACE}/wafer_core/utils/kernel_utils/evaluate.py"
-        )
 
-        # Build pip install command for torch and other deps (no wafer-core install needed)
+        # Build pip install command for torch and other deps, plus wafer-core
         pip_install_cmd = _build_docker_pip_install_cmd(target)
+        install_cmd = (
+            f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
+        )
 
-        # Build evaluate command - use PYTHONPATH instead of installing wafer-core
+        # Build evaluate command using installed wafer-core module
         python_cmd_parts = [
-            f"PYTHONPATH={CONTAINER_WORKSPACE}:$PYTHONPATH",
-            f"python3 {container_evaluate_script}",
+            "python3 -m wafer_core.utils.kernel_utils.evaluate",
             f"--implementation {container_impl_path}",
             f"--reference {container_ref_path}",
             f"--test-cases {container_test_cases_path}",
@@ -547,8 +548,8 @@ async def run_evaluate_docker(
 
         eval_cmd = " ".join(python_cmd_parts)
 
-        # Full command: install torch deps, then run evaluate with PYTHONPATH
-        full_cmd = f"{pip_install_cmd} && cd {container_run_path} && {eval_cmd}"
+        # Full command: install deps + wafer-core, then run evaluate
+        full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
 
         # Build Docker run command
         # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
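For orientation, the shell string this assembles looks roughly like the following; the values are illustrative placeholders, not the CLI's actual strings:

```python
# Sketch of the assembled container command (placeholder values).
pip_install_cmd = "uv pip install --system --break-system-packages torch"  # stand-in for _build_docker_pip_install_cmd(target)
install_cmd = f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
container_run_path = "/workspace/wafer_eval_20250101_000000"  # hypothetical path
eval_cmd = (
    "python3 -m wafer_core.utils.kernel_utils.evaluate "
    f"--implementation {container_run_path}/implementation.py"
)
full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
print(full_cmd)  # one shell invocation: install deps, cd, run the module
```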
@@ -558,7 +559,7 @@ async def run_evaluate_docker(
558
559
  working_dir=container_run_path,
559
560
  env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
560
561
  gpus="all",
561
- volumes={wafer_core_workspace: CONTAINER_WORKSPACE},
562
+ volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
562
563
  cap_add=["SYS_ADMIN"] if args.profile else None,
563
564
  )
564
565
 
@@ -567,7 +568,7 @@ async def run_evaluate_docker(
567
568
  # Run Docker command and stream output
568
569
  log_lines = []
569
570
  async for line in client.exec_stream(docker_cmd):
570
- print(line)
571
+ print(line, flush=True)
571
572
  log_lines.append(line)
572
573
 
573
574
  # Read results
@@ -665,6 +666,181 @@ async def run_evaluate_docker(
665
666
  )
666
667
 
667
668
 
669
+ async def run_evaluate_local(
670
+ args: EvaluateArgs,
671
+ target: LocalTarget,
672
+ ) -> EvaluateResult:
673
+ """Run evaluation locally on the current machine.
674
+
675
+ For LocalTarget - no SSH needed, runs directly.
676
+
677
+ Args:
678
+ args: Evaluate arguments
679
+ target: Local target config
680
+
681
+ Returns:
682
+ Evaluation result
683
+ """
684
+ import os
685
+ import subprocess
686
+ import tempfile
687
+ from datetime import datetime
688
+
689
+ # Select GPU
690
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
691
+
692
+ print(f"Running local evaluation on GPU {gpu_id}...")
693
+
694
+ # Create temp directory for eval files
695
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
696
+ with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
697
+ run_path = Path(run_path)
698
+
699
+ # Write implementation
700
+ impl_path = run_path / "implementation.py"
701
+ impl_path.write_text(args.implementation.read_text())
702
+
703
+ # Write reference
704
+ ref_path = run_path / "reference.py"
705
+ ref_path.write_text(args.reference.read_text())
706
+
707
+ # Write custom inputs if provided
708
+ inputs_path = None
709
+ if args.inputs:
710
+ inputs_path = run_path / "custom_inputs.py"
711
+ inputs_path.write_text(args.inputs.read_text())
712
+
713
+ # Write eval script
714
+ eval_script_path = run_path / "kernelbench_eval.py"
715
+ eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)
716
+
717
+ # Write defense module if defensive mode is enabled
718
+ defense_module_path = None
719
+ if args.defensive:
720
+ defense_src = (
721
+ Path(__file__).parent.parent.parent.parent
722
+ / "packages"
723
+ / "wafer-core"
724
+ / "wafer_core"
725
+ / "utils"
726
+ / "kernel_utils"
727
+ / "defense.py"
728
+ )
729
+ if defense_src.exists():
730
+ defense_module_path = run_path / "defense.py"
731
+ defense_module_path.write_text(defense_src.read_text())
732
+ else:
733
+ print(f"Warning: defense.py not found at {defense_src}")
734
+
735
+ # Output file
736
+ output_path = run_path / "results.json"
737
+
738
+ # Build eval command
739
+ cmd_parts = [
740
+ "python3",
741
+ str(eval_script_path),
742
+ "--impl",
743
+ str(impl_path),
744
+ "--reference",
745
+ str(ref_path),
746
+ "--output",
747
+ str(output_path),
748
+ "--seed",
749
+ str(args.seed),
750
+ ]
751
+
752
+ if args.benchmark:
753
+ cmd_parts.append("--benchmark")
754
+ if args.profile:
755
+ cmd_parts.append("--profile")
756
+ if inputs_path:
757
+ cmd_parts.extend(["--inputs", str(inputs_path)])
758
+ if args.defensive and defense_module_path:
759
+ cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])
760
+
761
+ # Set environment for GPU selection
762
+ env = os.environ.copy()
763
+ if target.vendor == "nvidia":
764
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
765
+ else: # AMD
766
+ env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
767
+ env["ROCM_PATH"] = "/opt/rocm"
768
+
769
+ print(f"Running: {' '.join(cmd_parts[:4])} ...")
770
+
771
+ # Run evaluation
772
+ try:
773
+ result = subprocess.run(
774
+ cmd_parts,
775
+ cwd=str(run_path),
776
+ env=env,
777
+ capture_output=True,
778
+ text=True,
779
+ timeout=args.timeout or 600,
780
+ )
781
+ except subprocess.TimeoutExpired:
782
+ return EvaluateResult(
783
+ success=False,
784
+ all_correct=False,
785
+ correctness_score=0.0,
786
+ geomean_speedup=0.0,
787
+ passed_tests=0,
788
+ total_tests=0,
789
+ error_message="Evaluation timed out",
790
+ )
791
+
792
+ if result.returncode != 0:
793
+ error_msg = result.stderr or result.stdout or "Unknown error"
794
+ # Truncate long errors
795
+ if len(error_msg) > 1000:
796
+ error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
797
+ return EvaluateResult(
798
+ success=False,
799
+ all_correct=False,
800
+ correctness_score=0.0,
801
+ geomean_speedup=0.0,
802
+ passed_tests=0,
803
+ total_tests=0,
804
+ error_message=f"Evaluation failed:\n{error_msg}",
805
+ )
806
+
807
+ # Parse results
808
+ if not output_path.exists():
809
+ return EvaluateResult(
810
+ success=False,
811
+ all_correct=False,
812
+ correctness_score=0.0,
813
+ geomean_speedup=0.0,
814
+ passed_tests=0,
815
+ total_tests=0,
816
+ error_message="No results.json produced",
817
+ )
818
+
819
+ try:
820
+ results = json.loads(output_path.read_text())
821
+ except json.JSONDecodeError as e:
822
+ return EvaluateResult(
823
+ success=False,
824
+ all_correct=False,
825
+ correctness_score=0.0,
826
+ geomean_speedup=0.0,
827
+ passed_tests=0,
828
+ total_tests=0,
829
+ error_message=f"Failed to parse results: {e}",
830
+ )
831
+
832
+ # Extract results
833
+ return EvaluateResult(
834
+ success=True,
835
+ all_correct=results.get("all_correct", False),
836
+ correctness_score=results.get("correctness_score", 0.0),
837
+ geomean_speedup=results.get("geomean_speedup", 0.0),
838
+ passed_tests=results.get("passed_tests", 0),
839
+ total_tests=results.get("total_tests", 0),
840
+ benchmark_results=results.get("benchmark", {}),
841
+ )
842
+
843
+
668
844
  async def run_evaluate_ssh(
669
845
  args: EvaluateArgs,
670
846
  target: BaremetalTarget | VMTarget,
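A hypothetical driver for the new local path; the field names come from this diff, but the exact `EvaluateArgs`/`LocalTarget` constructors are assumptions:

```python
# Hypothetical usage sketch for run_evaluate_local (constructor details assumed).
import asyncio
from pathlib import Path

async def demo() -> None:
    args = EvaluateArgs(  # assumed constructor; fields per this diff
        implementation=Path("impl.py"),
        reference=Path("ref.py"),
        test_cases=Path("tests.json"),
    )
    target = LocalTarget(vendor="nvidia", gpu_ids=[0])  # assumed constructor
    result = await run_evaluate_local(args, target)
    print(result.all_correct, result.geomean_speedup)

# asyncio.run(demo())
```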
@@ -982,6 +1158,7 @@ def _build_modal_sandbox_script(
     test_cases_b64: str,
     run_benchmarks: bool,
     run_defensive: bool,
+    defense_code_b64: str | None = None,
 ) -> str:
     """Build Python script to create sandbox and run evaluation.
 
@@ -1062,6 +1239,20 @@ print('Files written')
         print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
         return
 
+    # Write defense module if defensive mode is enabled
+    # NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
+    if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
+        proc = sandbox.exec("python", "-c", f"""
+import base64
+with open('/workspace/defense.py', 'w') as f:
+    f.write(base64.b64decode('{defense_code_b64}').decode())
+print('Defense module written')
+""")
+        proc.wait()
+        if proc.returncode != 0:
+            print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
+            return
+
     # Build inline evaluation script
     eval_script = """
 import json
@@ -1089,6 +1280,26 @@ generate_input = load_fn('reference.py', 'generate_input')
 
 import torch
 
+# Load defense module if available and defensive mode is enabled
+run_defensive = {run_defensive}
+defense = None
+if run_defensive:
+    try:
+        defense = load_fn('defense.py', 'run_all_defenses')
+        time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
+        print('[Defense] Defense module loaded')
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        # These wrappers repack the unpacked args back into a tuple
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
+    except Exception as e:
+        print(f'[Defense] Warning: Could not load defense module: {{e}}')
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
@@ -1116,36 +1327,63 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
-        # Warmup
-        for _ in range(3):
-            custom_kernel(inputs)
-        torch.cuda.synchronize()
-
-        # Measure with defensive timing if requested
-        # Defensive: sync before recording end event to catch stream injection
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        start.record()
-        for _ in range(10):
-            custom_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        impl_time_ms = start.elapsed_time(end) / 10
-
-        # Reference timing (same defensive approach)
-        for _ in range(3):
-            ref_kernel(inputs)
-        torch.cuda.synchronize()
-        start.record()
-        for _ in range(10):
-            ref_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        ref_time_ms = start.elapsed_time(end) / 10
+        if run_defensive and defense is not None:
+            # Use full defense suite with wrapped kernels
+            # inputs_list unpacks the tuple so defense can infer dtype/device from tensors
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing (using wrapped kernels)
+            impl_times, _ = time_with_defenses(
+                custom_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            ref_times, _ = time_with_defenses(
+                ref_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing without full defenses
+            # Warmup
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            # Reference timing
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
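The timing split above replaces the older inline defensive flag. The idea the removed lines encoded, synchronizing all streams before recording the end event so work queued on a side stream still counts, can be sketched on its own (needs a CUDA/ROCm device):

```python
# Sketch of event-based timing with the optional pre-end synchronize defense.
import torch

def time_kernel(kernel, inputs, trials=10, defensive=False):
    for _ in range(3):  # warmup
        kernel(inputs)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(trials):
        kernel(inputs)
    if defensive:
        torch.cuda.synchronize()  # flush every stream before stopping the clock
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / trials  # ms per call
```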
@@ -1197,7 +1435,7 @@ print(json.dumps({{
     # Find the last JSON line in output
     for line in reversed(stdout.strip().split("\\n")):
         if line.startswith("{{"):
-            print(line)
+            print(line, flush=True)
             return
 
     print(json.dumps({{"error": f"No result JSON in output: {{stdout[:500]}}"}}))
@@ -1238,6 +1476,23 @@ async def run_evaluate_modal(
     ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
     test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()
 
+    # Encode defense module if defensive mode is enabled
+    defense_code_b64 = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build the script that creates sandbox and runs eval
     script = _build_modal_sandbox_script(
         target=target,
@@ -1246,6 +1501,7 @@ async def run_evaluate_modal(
         test_cases_b64=test_cases_b64,
         run_benchmarks=args.benchmark,
         run_defensive=args.defensive,
+        defense_code_b64=defense_code_b64,
     )
 
     def _run_subprocess() -> tuple[str, str, int]:
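Why base64 here: the defense source is spliced into a generated script as a literal, so it must survive quoting and brace formatting; base64 keeps it inert. The round trip is just:

```python
# Round trip illustrating the encode-embed-decode flow above.
import base64

defense_src = 'print("defense loaded")\n'
encoded = base64.b64encode(defense_src.encode()).decode()  # safe inside any quoted literal
assert base64.b64decode(encoded).decode() == defense_src
```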
@@ -1343,6 +1599,7 @@ def _build_workspace_eval_script(
     test_cases_json: str,
     run_benchmarks: bool,
     run_defensive: bool = False,
+    defense_code: str | None = None,
 ) -> str:
     """Build inline evaluation script for workspace exec.
 
@@ -1353,6 +1610,7 @@ def _build_workspace_eval_script(
     impl_b64 = base64.b64encode(impl_code.encode()).decode()
     ref_b64 = base64.b64encode(ref_code.encode()).decode()
     tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
+    defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""
 
     return f'''
 import base64
@@ -1372,6 +1630,15 @@ with open("/tmp/kernel.py", "w") as f:
 with open("/tmp/reference.py", "w") as f:
     f.write(ref_code)
 
+# Write defense module if available
+run_defensive = {run_defensive}
+defense_b64 = "{defense_b64}"
+# NOTE: Check defense_b64 is not empty and not the string "None" (from None formatting)
+if run_defensive and defense_b64 and defense_b64 != "None":
+    defense_code = base64.b64decode(defense_b64).decode()
+    with open("/tmp/defense.py", "w") as f:
+        f.write(defense_code)
+
 # Load kernels
 def load_fn(path, name):
     spec = importlib.util.spec_from_file_location("mod", path)
@@ -1385,6 +1652,24 @@ generate_input = load_fn("/tmp/reference.py", "generate_input")
 
 import torch
 
+# Load defense module if available
+defense = None
+if run_defensive and defense_b64 and defense_b64 != "None":
+    try:
+        defense = load_fn("/tmp/defense.py", "run_all_defenses")
+        time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
+        print("[Defense] Defense module loaded")
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
+    except Exception as e:
+        print(f"[Defense] Warning: Could not load defense module: {{e}}")
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
@@ -1412,36 +1697,60 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
-        # Warmup
-        for _ in range(3):
-            custom_kernel(inputs)
-        torch.cuda.synchronize()
-
-        # Measure with defensive timing if requested
-        # Defensive: sync before recording end event to catch stream injection
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        start.record()
-        for _ in range(10):
-            custom_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        impl_time_ms = start.elapsed_time(end) / 10
-
-        # Reference timing (same defensive approach)
-        for _ in range(3):
-            ref_kernel(inputs)
-        torch.cuda.synchronize()
-        start.record()
-        for _ in range(10):
-            ref_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        ref_time_ms = start.elapsed_time(end) / 10
+        if run_defensive and defense is not None:
+            # Use full defense suite with wrapped kernels
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing (using wrapped kernels)
+            impl_times, _ = time_with_defenses(
+                custom_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            ref_times, _ = time_with_defenses(
+                ref_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
@@ -1503,6 +1812,23 @@ async def run_evaluate_workspace(
     ref_code = args.reference.read_text()
     test_cases_json = args.test_cases.read_text()
 
+    # Read defense module if defensive mode is enabled
+    defense_code = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code = defense_path.read_text()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build inline eval script
     eval_script = _build_workspace_eval_script(
         impl_code=impl_code,
@@ -1510,6 +1836,7 @@ async def run_evaluate_workspace(
         test_cases_json=test_cases_json,
         run_benchmarks=args.benchmark,
        run_defensive=args.defensive,
+        defense_code=defense_code,
     )
 
     # Execute via workspace exec
@@ -1855,15 +2182,12 @@ async def run_evaluate_runpod(
     # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
     venv_bin = env_state.venv_bin
     env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
-    pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-    evaluate_script = (
-        f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-    )
 
     # Run from run_path so reference_kernel.py is importable
+    # Use installed wafer-core module
     eval_cmd = (
         f"cd {run_path} && "
-        f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
+        f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
         f"--implementation {impl_path} "
         f"--reference {ref_path} "
         f"--test-cases {test_cases_path} "
@@ -2219,15 +2543,12 @@ async def run_evaluate_digitalocean(
     env_vars = (
         f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
     )
-    pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-    evaluate_script = (
-        f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-    )
 
     # Run from run_path so reference_kernel.py is importable
+    # Use installed wafer-core module
    eval_cmd = (
         f"cd {run_path} && "
-        f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
+        f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
         f"--implementation {impl_path} "
         f"--reference {ref_path} "
         f"--test-cases {test_cases_path} "
@@ -2407,7 +2728,9 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
     print(f"Using target: {target_name}")
 
     # Dispatch to appropriate executor
-    if isinstance(target, BaremetalTarget | VMTarget):
+    if isinstance(target, LocalTarget):
+        return await run_evaluate_local(args, target)
+    elif isinstance(target, BaremetalTarget | VMTarget):
         return await run_evaluate_ssh(args, target)
     elif isinstance(target, ModalTarget):
         return await run_evaluate_modal(args, target)
@@ -2436,6 +2759,7 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
 # Inline evaluation script for KernelBench format
 # This runs inside the Docker container on the remote GPU
 KERNELBENCH_EVAL_SCRIPT = """
+import gc
 import json
 import os
 import sys
@@ -2444,6 +2768,68 @@ import torch
 import torch.nn as nn
 from pathlib import Path
 
+# Use a unique per-run PyTorch extension cache directory to ensure fresh compilation.
+# This prevents stale cached extensions from being loaded when the pod is reused.
+# Without this, if a kernel is modified but uses the same extension name,
+# PyTorch would load the old cached .so instead of recompiling.
+# We use a UUID-based directory instead of clearing the cache to avoid race conditions
+# with other processes that might be using the cache.
+import uuid
+unique_cache_dir = f"/tmp/torch_extensions_{uuid.uuid4().hex[:8]}"
+os.environ["TORCH_EXTENSIONS_DIR"] = unique_cache_dir
+print(f"[KernelBench] Using unique extension cache: {unique_cache_dir}")
+
+# Clear any stale GPU memory from previous runs at startup
+# NOTE: empty_cache only frees memory from THIS process's PyTorch allocator.
+# It won't free memory from dead/zombie processes - rocm-smi --showpids can show
+# PIDs that no longer exist but still hold GPU memory. Those require a GPU reset
+# (rocm-smi --gpureset) to fully clear. TODO: detect and warn about orphaned memory.
+if torch.cuda.is_available():
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+
+
+def _calculate_timing_stats(times: list[float]) -> dict:
+    '''Calculate median and IQR from timing samples.
+
+    Returns dict with median, iqr_low (25th percentile), iqr_high (75th percentile),
+    mean, min, max, and std.
+    '''
+    import statistics
+
+    if not times:
+        return {"median": 0, "iqr_low": 0, "iqr_high": 0, "mean": 0, "min": 0, "max": 0, "std": 0}
+
+    sorted_times = sorted(times)
+    n = len(sorted_times)
+
+    # Median
+    median = statistics.median(sorted_times)
+
+    # Quartiles (25th and 75th percentile)
+    # For small samples, use simple interpolation
+    q1_idx = (n - 1) * 0.25
+    q3_idx = (n - 1) * 0.75
+
+    q1_low = int(q1_idx)
+    q1_frac = q1_idx - q1_low
+    iqr_low = sorted_times[q1_low] * (1 - q1_frac) + sorted_times[min(q1_low + 1, n - 1)] * q1_frac
+
+    q3_low = int(q3_idx)
+    q3_frac = q3_idx - q3_low
+    iqr_high = sorted_times[q3_low] * (1 - q3_frac) + sorted_times[min(q3_low + 1, n - 1)] * q3_frac
+
+    return {
+        "median": median,
+        "iqr_low": iqr_low,
+        "iqr_high": iqr_high,
+        "mean": statistics.mean(sorted_times),
+        "min": min(sorted_times),
+        "max": max(sorted_times),
+        "std": statistics.stdev(sorted_times) if n > 1 else 0,
+    }
+
 
 def run_profiling(model, inputs, name, output_dir):
     '''Run torch.profiler and return summary stats.'''
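A quick sanity check of the interpolation in `_calculate_timing_stats`: for five samples, `q1_idx = 1.0` and `q3_idx = 3.0`, so both quartiles land exactly on sorted elements with no fractional blending:

```python
# Verifies the (n - 1) * q linear-interpolation quartiles used above.
times = [5.0, 1.0, 4.0, 2.0, 3.0]
s = sorted(times)
n = len(s)
for q, expected in ((0.25, 2.0), (0.75, 4.0)):
    idx = (n - 1) * q
    lo, frac = int(idx), idx - int(idx)
    value = s[lo] * (1 - frac) + s[min(lo + 1, n - 1)] * frac
    assert value == expected
```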
@@ -2674,12 +3060,26 @@ def main():
     parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
     parser.add_argument("--benchmark", action="store_true")
     parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
+    parser.add_argument("--defense-module", help="Path to defense.py module")
     parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
     parser.add_argument("--num-correct-trials", type=int, default=3)
     parser.add_argument("--num-perf-trials", type=int, default=10)
     parser.add_argument("--output", required=True)
     args = parser.parse_args()
 
+    # Load defense module if defensive mode is enabled
+    defense_module = None
+    if args.defensive and args.defense_module:
+        try:
+            import importlib.util
+            defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
+            defense_module = importlib.util.module_from_spec(defense_spec)
+            defense_spec.loader.exec_module(defense_module)
+            print("[KernelBench] Defense module loaded")
+        except Exception as e:
+            print(f"[KernelBench] Warning: Could not load defense module: {e}")
+
     # Create output directory for profiles
     output_dir = Path(args.output).parent
     profile_dir = output_dir / "profiles"
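Putting the new flags together, an invocation of the inline eval script would look like this (paths are illustrative):

```python
# Example invocation of kernelbench_eval.py with the new defense flags.
import subprocess

subprocess.run(
    [
        "python3", "kernelbench_eval.py",
        "--impl", "implementation.py",
        "--reference", "reference.py",
        "--output", "results.json",
        "--defensive",                     # enable the full defense suite
        "--defense-module", "defense.py",  # written alongside the script
        "--seed", "42",
    ],
    check=True,
)
```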
@@ -2813,47 +3213,102 @@ def main():
         inputs = get_inputs()
         inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
 
-        # Warmup
-        for _ in range(5):
-            with torch.no_grad():
-                _ = new_model(*inputs)
-        torch.cuda.synchronize()
-
-        # Benchmark new model
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-
-        times = []
-        for _ in range(args.num_perf_trials):
-            start.record()
-            with torch.no_grad():
-                _ = new_model(*inputs)
-            end.record()
-            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))
-
-        new_time = sum(times) / len(times)
-        results["runtime_ms"] = new_time
-
-        # Benchmark reference model
-        for _ in range(5):
-            with torch.no_grad():
-                _ = ref_model(*inputs)
-        torch.cuda.synchronize()
-
-        times = []
-        for _ in range(args.num_perf_trials):
-            start.record()
-            with torch.no_grad():
-                _ = ref_model(*inputs)
-            end.record()
-            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))
-
-        ref_time = sum(times) / len(times)
-        results["reference_runtime_ms"] = ref_time
-        results["speedup"] = ref_time / new_time if new_time > 0 else 0
-        print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+        if args.defensive and defense_module is not None:
+            # Use full defense suite
+            print("[KernelBench] Running defense checks on implementation...")
+            run_all_defenses = defense_module.run_all_defenses
+            time_with_defenses = defense_module.time_execution_with_defenses
+
+            # Run defense checks on implementation
+            all_passed, defense_results, _ = run_all_defenses(
+                lambda *x: new_model(*x),
+                *inputs,
+            )
+            results["defense_results"] = {
+                name: {"passed": passed, "message": msg}
+                for name, passed, msg in defense_results
+            }
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                results["error"] = f"Defense checks failed: {failed}"
+                print(f"[KernelBench] Defense checks FAILED: {failed}")
+                for name, passed, msg in defense_results:
+                    status = "PASS" if passed else "FAIL"
+                    print(f"  [{status}] {name}: {msg}")
+            else:
+                print("[KernelBench] All defense checks passed")
+
+            # Time with defensive timing
+            impl_times, _ = time_with_defenses(
+                lambda: new_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,  # Already ran above
+            )
+            # Calculate stats for new model
+            new_stats = _calculate_timing_stats(impl_times)
+            results["runtime_ms"] = new_stats["median"]
+            results["runtime_stats"] = new_stats
+
+            # Reference timing
+            ref_times, _ = time_with_defenses(
+                lambda: ref_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_stats = _calculate_timing_stats(ref_times)
+            results["reference_runtime_ms"] = ref_stats["median"]
+            results["reference_runtime_stats"] = ref_stats
+            results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
+            print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
+        else:
+            # Standard timing without full defenses
+            # Warmup BOTH models before benchmarking either
+            # This ensures consistent GPU state and avoids MIOpen cache effects
+            # that cause variance when warming up models sequentially
+            for _ in range(5):
+                with torch.no_grad():
+                    _ = new_model(*inputs)
+                    _ = ref_model(*inputs)
+            torch.cuda.synchronize()
+
+            # Benchmark new model
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+
+            new_times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = new_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                new_times.append(start.elapsed_time(end))
+
+            new_stats = _calculate_timing_stats(new_times)
+            results["runtime_ms"] = new_stats["median"]
+            results["runtime_stats"] = new_stats
+
+            # Benchmark reference model
+            ref_times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = ref_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                ref_times.append(start.elapsed_time(end))
+
+            ref_stats = _calculate_timing_stats(ref_times)
+            results["reference_runtime_ms"] = ref_stats["median"]
+            results["reference_runtime_stats"] = ref_stats
+            results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
+            print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
 
         # Run profiling if requested and correctness passed
         if args.profile and all_correct:
@@ -2898,6 +3353,16 @@ def main():
         json.dump(results, f, indent=2)
     print(f"[KernelBench] Results written to {args.output}")
 
+    # Cleanup GPU memory
+    try:
+        del ref_model, new_model
+    except NameError:
+        pass
+    import gc
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 if __name__ == "__main__":
     main()
 """
@@ -2947,6 +3412,27 @@ def _validate_kernelbench_files(args: KernelBenchEvaluateArgs) -> str | None:
             " KernelBench format requires: 'class Model', 'get_inputs()', 'get_init_inputs()'"
         )
 
+    # Static kernel validation if backend specified
+    if args.backend:
+        from wafer_core.utils.kernel_utils.static_checker import validate_kernel_static
+
+        code = args.implementation.read_text()
+        valid, errors, warnings = validate_kernel_static(code, backend=args.backend)
+
+        # Print warnings (don't fail)
+        for warning in warnings:
+            logger.warning(f"Static check warning: {warning}")
+
+        # Fail on errors
+        if not valid:
+            error_list = "\n  - ".join(errors)
+            return (
+                f"Static kernel validation failed for backend '{args.backend}':\n"
+                f"  - {error_list}\n\n"
+                f"The implementation must use {args.backend.upper()} kernel primitives.\n"
+                "See KernelBench documentation for valid kernel patterns."
+            )
+
     return None
 
 
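A sketch of calling the static gate directly; the `(valid, errors, warnings)` contract is taken from this diff, while the backend value and file name are illustrative:

```python
# Illustrative standalone use of the static checker imported above.
from pathlib import Path

from wafer_core.utils.kernel_utils.static_checker import validate_kernel_static

code = Path("implementation.py").read_text()  # hypothetical kernel file
valid, errors, warnings = validate_kernel_static(code, backend="triton")  # backend value assumed
for w in warnings:
    print(f"warning: {w}")  # warnings are reported but do not fail validation
if not valid:
    raise SystemExit("static validation failed:\n - " + "\n - ".join(errors))
```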
@@ -3059,6 +3545,30 @@ async def run_evaluate_kernelbench_docker(
                 error_message=f"Failed to write eval script: {write_result.stderr}",
             )
 
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_path = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_path.exists():
+                defense_code = defense_path.read_text()
+                defense_module_path = f"{run_path}/defense.py"
+                write_result = await client.exec(
+                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                )
+                if write_result.exit_code != 0:
+                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                    defense_module_path = None
+            else:
+                print(f"Warning: defense.py not found at {defense_path}")
+
         print("Running KernelBench evaluation in Docker container...")
 
         # Paths inside container
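The remote writes above rely on a quoted heredoc; a minimal sketch of the pattern follows. One caveat, stated as an assumption: a payload containing the delimiter line itself would terminate the heredoc early.

```python
# Sketch of the quoted-heredoc remote write used above.
defense_code = "print('hi')"                 # payload to place on the remote host
defense_module_path = "/tmp/run/defense.py"  # hypothetical destination
# The quoted 'DEFENSE_EOF' delimiter disables shell expansion inside the body.
cmd = f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
print(cmd)  # this string is what gets passed to client.exec(...) over SSH
```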
@@ -3068,6 +3578,7 @@ async def run_evaluate_kernelbench_docker(
         container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
         container_eval_script = f"{container_run_path}/kernelbench_eval.py"
         container_output = f"{container_run_path}/results.json"
+        container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None
 
         # Build eval command
         python_cmd_parts = [
@@ -3083,6 +3594,9 @@ async def run_evaluate_kernelbench_docker(
             python_cmd_parts.append("--profile")
         if container_inputs_path:
             python_cmd_parts.append(f"--inputs {container_inputs_path}")
+        if args.defensive and container_defense_path:
+            python_cmd_parts.append("--defensive")
+            python_cmd_parts.append(f"--defense-module {container_defense_path}")
         python_cmd_parts.append(f"--seed {args.seed}")
 
         eval_cmd = " ".join(python_cmd_parts)
@@ -3106,7 +3620,7 @@ async def run_evaluate_kernelbench_docker(
         # Run and stream output
         log_lines = []
         async for line in client.exec_stream(docker_cmd):
-            print(line)
+            print(line, flush=True)
             log_lines.append(line)
 
         # Read results
@@ -3298,15 +3812,44 @@ async def run_evaluate_kernelbench_digitalocean(
                 error_message=f"Failed to write eval script: {write_result.stderr}",
             )
 
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_path = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_path.exists():
+                defense_code = defense_path.read_text()
+                defense_module_path = f"{run_path}/defense.py"
+                write_result = await client.exec(
+                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                )
+                if write_result.exit_code != 0:
+                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                    defense_module_path = None
+            else:
+                print(f"Warning: defense.py not found at {defense_path}")
+
         print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
 
         # Paths inside container
         container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
-        container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
+        container_inputs_path = (
+            f"{container_run_path}/custom_inputs.py" if args.inputs else None
+        )
         container_eval_script = f"{container_run_path}/kernelbench_eval.py"
         container_output = f"{container_run_path}/results.json"
+        container_defense_path = (
+            f"{container_run_path}/defense.py" if defense_module_path else None
+        )
 
         # Build eval command
         python_cmd_parts = [
@@ -3322,6 +3865,9 @@ async def run_evaluate_kernelbench_digitalocean(
             python_cmd_parts.append("--profile")
         if container_inputs_path:
             python_cmd_parts.append(f"--inputs {container_inputs_path}")
+        if args.defensive and container_defense_path:
+            python_cmd_parts.append("--defensive")
+            python_cmd_parts.append(f"--defense-module {container_defense_path}")
         python_cmd_parts.append(f"--seed {args.seed}")
 
         eval_cmd = " ".join(python_cmd_parts)
@@ -3330,14 +3876,20 @@ async def run_evaluate_kernelbench_digitalocean(
         full_cmd = f"cd {container_run_path} && {eval_cmd}"
 
         # Build Docker command for AMD
+        # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
+        rocm_arch = _get_rocm_arch(target.compute_capability)
+        env_dict = {
+            "HIP_VISIBLE_DEVICES": str(gpu_id),
+            "PYTHONUNBUFFERED": "1",
+        }
+        if rocm_arch:
+            env_dict["PYTORCH_ROCM_ARCH"] = rocm_arch
+
         docker_cmd = _build_docker_run_command_amd(
             image=docker_image,
             command=full_cmd,
             working_dir=container_run_path,
-            env={
-                "HIP_VISIBLE_DEVICES": str(gpu_id),
-                "PYTHONUNBUFFERED": "1",
-            },
+            env=env_dict,
             volumes={workspace_path: CONTAINER_WORKSPACE},
         )
 
@@ -3346,7 +3898,7 @@ async def run_evaluate_kernelbench_digitalocean(
         # Run and stream output
         log_lines = []
         async for line in client.exec_stream(docker_cmd):
-            print(line)
+            print(line, flush=True)
             log_lines.append(line)
 
         # Read results
@@ -3407,55 +3959,528 @@ async def run_evaluate_kernelbench_digitalocean(
3407
3959
  )
3408
3960
 
3409
3961
 
3410
- async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
3411
- """Run KernelBench format evaluation on configured target.
3412
-
3413
- Args:
3414
- args: KernelBench evaluate arguments
3962
+ async def run_evaluate_kernelbench_runpod(
3963
+ args: KernelBenchEvaluateArgs,
3964
+ target: RunPodTarget,
3965
+ ) -> EvaluateResult:
3966
+ """Run KernelBench format evaluation directly on RunPod AMD GPU.
3415
3967
 
3416
- Returns:
3417
- Evaluation result
3968
+ Runs evaluation script directly on host (no Docker) since RunPod pods
3969
+ already have PyTorch/ROCm installed.
3418
3970
  """
3419
- from .targets import get_default_target, load_target
3971
+ from datetime import datetime
3420
3972
 
3421
- # Validate input files
3422
- err = _validate_kernelbench_files(args)
3423
- if err:
3424
- return EvaluateResult(
3425
- success=False,
3426
- all_correct=False,
3427
- correctness_score=0.0,
3428
- geomean_speedup=0.0,
3429
- passed_tests=0,
3430
- total_tests=0,
3431
- error_message=err,
3432
- )
3973
+ from wafer_core.async_ssh import AsyncSSHClient
3974
+ from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
3433
3975
 
3434
- # Load target
3435
- target_name = args.target_name
3436
- if not target_name:
3437
- target_name = get_default_target()
3438
- if not target_name:
3439
- return EvaluateResult(
3440
- success=False,
3441
- all_correct=False,
3442
- correctness_score=0.0,
3443
- geomean_speedup=0.0,
3444
- passed_tests=0,
3445
- total_tests=0,
3446
- error_message=(
3447
- "No target specified and no default set.\n"
3448
- "Set up a target first:\n"
3449
- " wafer config targets init ssh --name my-gpu --host user@host:22\n"
3450
- " wafer config targets init runpod --gpu MI300X\n"
3451
- "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
3452
- ),
3453
- )
3976
+ REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
3977
+
3978
+ # Select GPU
3979
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
3980
+
3981
+ print(f"Provisioning RunPod ({target.gpu_type_id})...")
3454
3982
 
3455
3983
  try:
3456
- target = load_target(target_name)
3457
- except FileNotFoundError:
3458
- return EvaluateResult(
3984
+ async with runpod_ssh_context(target) as ssh_info:
3985
+ ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
3986
+ print(f"Connected to RunPod: {ssh_target}")
3987
+
3988
+ async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
3989
+ # Create workspace
3990
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
3991
+ run_dir = f"kernelbench_eval_{timestamp}"
3992
+ run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
3993
+
3994
+ await client.exec(f"mkdir -p {run_path}")
3995
+ print(f"Created run directory: {run_path}")
3996
+
3997
+ # Read and upload files
3998
+ impl_code = args.implementation.read_text()
3999
+ ref_code = args.reference.read_text()
4000
+
4001
+ # Write implementation
4002
+ impl_path = f"{run_path}/implementation.py"
4003
+ write_result = await client.exec(
4004
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
4005
+ )
4006
+ if write_result.exit_code != 0:
4007
+ return EvaluateResult(
4008
+ success=False,
4009
+ all_correct=False,
4010
+ correctness_score=0.0,
4011
+ geomean_speedup=0.0,
4012
+ passed_tests=0,
4013
+ total_tests=0,
4014
+ error_message=f"Failed to write implementation: {write_result.stderr}",
4015
+ )
4016
+
4017
+ # Write reference
4018
+ ref_path = f"{run_path}/reference.py"
4019
+ write_result = await client.exec(
4020
+ f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
4021
+ )
4022
+ if write_result.exit_code != 0:
4023
+ return EvaluateResult(
4024
+ success=False,
4025
+ all_correct=False,
4026
+ correctness_score=0.0,
4027
+ geomean_speedup=0.0,
4028
+ passed_tests=0,
4029
+ total_tests=0,
4030
+ error_message=f"Failed to write reference: {write_result.stderr}",
4031
+ )
4032
+
4033
+ # Write custom inputs if provided
4034
+ inputs_path = None
4035
+ if args.inputs:
4036
+ inputs_code = args.inputs.read_text()
4037
+ inputs_path = f"{run_path}/custom_inputs.py"
4038
+ write_result = await client.exec(
4039
+ f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
4040
+ )
4041
+ if write_result.exit_code != 0:
4042
+ return EvaluateResult(
4043
+ success=False,
4044
+ all_correct=False,
4045
+ correctness_score=0.0,
4046
+ geomean_speedup=0.0,
4047
+ passed_tests=0,
4048
+ total_tests=0,
4049
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
4050
+ )
4051
+
4052
+ # Write eval script
4053
+ eval_script_path = f"{run_path}/kernelbench_eval.py"
4054
+ write_result = await client.exec(
4055
+ f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
4056
+ )
4057
+ if write_result.exit_code != 0:
4058
+ return EvaluateResult(
4059
+ success=False,
4060
+ all_correct=False,
4061
+ correctness_score=0.0,
4062
+ geomean_speedup=0.0,
4063
+ passed_tests=0,
4064
+ total_tests=0,
4065
+ error_message=f"Failed to write eval script: {write_result.stderr}",
4066
+ )
4067
+
4068
+ # Write defense module if defensive mode is enabled
4069
+ defense_module_path = None
4070
+ if args.defensive:
4071
+ defense_path = (
4072
+ Path(__file__).parent.parent.parent.parent
4073
+ / "packages"
4074
+ / "wafer-core"
4075
+ / "wafer_core"
4076
+ / "utils"
4077
+ / "kernel_utils"
4078
+ / "defense.py"
4079
+ )
4080
+ if defense_path.exists():
4081
+ defense_code = defense_path.read_text()
4082
+ defense_module_path = f"{run_path}/defense.py"
4083
+ write_result = await client.exec(
4084
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
4085
+ )
4086
+ if write_result.exit_code != 0:
4087
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
4088
+ defense_module_path = None
4089
+ else:
4090
+ print(f"Warning: defense.py not found at {defense_path}")
4091
+
4092
+ print("Running KernelBench evaluation (AMD/ROCm)...")
4093
+
4094
+ # Find Python with PyTorch - check common locations on RunPod
4095
+ python_exe = "python3"
4096
+ for candidate in [
4097
+ "/opt/conda/envs/py_3.10/bin/python3",
4098
+ "/opt/conda/bin/python3",
4099
+ ]:
4100
+ check = await client.exec(
4101
+ f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
4102
+ )
4103
+ if "OK" in check.stdout:
4104
+ python_exe = candidate
4105
+ print(f"Using Python: {python_exe}")
4106
+ break
4107
+
4108
+ # Build eval command - run directly on host
4109
+ output_path = f"{run_path}/results.json"
4110
+ python_cmd_parts = [
4111
+ f"{python_exe} {eval_script_path}",
4112
+ f"--impl {impl_path}",
4113
+ f"--reference {ref_path}",
4114
+ f"--output {output_path}",
4115
+ ]
4116
+
4117
+ if args.benchmark:
4118
+ python_cmd_parts.append("--benchmark")
4119
+ if args.profile:
4120
+ python_cmd_parts.append("--profile")
4121
+ if inputs_path:
4122
+ python_cmd_parts.append(f"--inputs {inputs_path}")
4123
+ if args.defensive and defense_module_path:
4124
+ python_cmd_parts.append("--defensive")
4125
+ python_cmd_parts.append(f"--defense-module {defense_module_path}")
4126
+ python_cmd_parts.append(f"--seed {args.seed}")
4127
+
4128
+ eval_cmd = " ".join(python_cmd_parts)
4129
+
4130
+ # Set environment for AMD GPU and run
4131
+ # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
4132
+ rocm_arch = _get_rocm_arch(target.compute_capability)
4133
+ arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
4134
+ env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
4135
+ full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
4136
+
4137
+ # Run and stream output
4138
+ log_lines = []
4139
+ async for line in client.exec_stream(full_cmd):
4140
+ print(line, flush=True)
4141
+ log_lines.append(line)
4142
+
4143
+ # Read results
4144
+ cat_result = await client.exec(f"cat {output_path}")
4145
+
4146
+ if cat_result.exit_code != 0:
4147
+ log_tail = "\n".join(log_lines[-50:])
4148
+ return EvaluateResult(
4149
+ success=False,
4150
+ all_correct=False,
4151
+ correctness_score=0.0,
4152
+ geomean_speedup=0.0,
4153
+ passed_tests=0,
4154
+ total_tests=0,
4155
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
4156
+ )
4157
+
4158
+ # Parse results
4159
+ try:
4160
+ results_data = json.loads(cat_result.stdout)
4161
+ except json.JSONDecodeError as e:
4162
+ return EvaluateResult(
4163
+ success=False,
4164
+ all_correct=False,
4165
+ correctness_score=0.0,
4166
+ geomean_speedup=0.0,
4167
+ passed_tests=0,
4168
+ total_tests=0,
4169
+ error_message=f"Failed to parse results: {e}",
4170
+ )
4171
+
4172
+ # Convert to EvaluateResult
4173
+ correct = results_data.get("correct", False)
4174
+ speedup = results_data.get("speedup", 0.0) or 0.0
4175
+ error = results_data.get("error")
4176
+
4177
+ if error:
4178
+ return EvaluateResult(
4179
+ success=False,
4180
+ all_correct=False,
4181
+ correctness_score=0.0,
4182
+ geomean_speedup=0.0,
4183
+ passed_tests=0,
4184
+ total_tests=1,
4185
+ error_message=error,
4186
+ )
4187
+
4188
+ return EvaluateResult(
4189
+ success=True,
4190
+ all_correct=correct,
4191
+ correctness_score=1.0 if correct else 0.0,
4192
+ geomean_speedup=speedup,
4193
+ passed_tests=1 if correct else 0,
4194
+ total_tests=1,
4195
+ )
4196
+
4197
+ except RunPodError as e:
4198
+ return EvaluateResult(
4199
+ success=False,
4200
+ all_correct=False,
4201
+ correctness_score=0.0,
4202
+ geomean_speedup=0.0,
4203
+ passed_tests=0,
4204
+ total_tests=0,
4205
+ error_message=f"RunPod error: {e}",
4206
+ )
4207
+
4208
+
4209
+ async def run_evaluate_kernelbench_baremetal_amd(
4210
+ args: KernelBenchEvaluateArgs,
4211
+ target: BaremetalTarget,
4212
+ ) -> EvaluateResult:
4213
+ """Run KernelBench format evaluation directly on AMD baremetal target.
4214
+
4215
+ Runs evaluation script directly on host (no Docker) for AMD GPUs
4216
+ that have PyTorch/ROCm installed.
4217
+ """
4218
+ from datetime import datetime
4219
+
+ from wafer_core.async_ssh import AsyncSSHClient
+
+ REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
+
+ # Select GPU
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+ print(f"Connecting to {target.ssh_target}...")
+
+ async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
+ # Create workspace
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ run_dir = f"kernelbench_eval_{timestamp}"
+ run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
+
+ await client.exec(f"mkdir -p {run_path}")
+ print(f"Created run directory: {run_path}")
+
+ # Read and upload files
+ impl_code = args.implementation.read_text()
+ ref_code = args.reference.read_text()
+
+ # Write implementation
+ impl_path = f"{run_path}/implementation.py"
+ write_result = await client.exec(
+ f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
+ )
+ if write_result.exit_code != 0:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to write implementation: {write_result.stderr}",
+ )
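The quoted heredoc delimiter ('IMPL_EOF') makes the shell write the payload verbatim, with no $variable or backtick expansion. The write would still truncate if the kernel source itself contained a line reading IMPL_EOF; a common workaround — sketched here, not what this diff does — is to ship the bytes base64-encoded:

import base64

def encode_upload(remote_path: str, code: str) -> str:
    """Build a shell command that writes `code` to `remote_path` byte-for-byte."""
    encoded = base64.b64encode(code.encode()).decode()
    return f"echo '{encoded}' | base64 -d > '{remote_path}'"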
+
+ # Write reference
+ ref_path = f"{run_path}/reference.py"
+ write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
+ if write_result.exit_code != 0:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to write reference: {write_result.stderr}",
+ )
+
+ # Write custom inputs if provided
+ inputs_path = None
+ if args.inputs:
+ inputs_code = args.inputs.read_text()
+ inputs_path = f"{run_path}/custom_inputs.py"
+ write_result = await client.exec(
+ f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+ )
+ if write_result.exit_code != 0:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
+ )
+
+ # Write eval script
+ eval_script_path = f"{run_path}/kernelbench_eval.py"
+ write_result = await client.exec(
+ f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
+ )
+ if write_result.exit_code != 0:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to write eval script: {write_result.stderr}",
+ )
+
+ # Write defense module if defensive mode is enabled
+ defense_module_path = None
+ if args.defensive:
+ defense_path = (
+ Path(__file__).parent.parent.parent.parent
+ / "packages"
+ / "wafer-core"
+ / "wafer_core"
+ / "utils"
+ / "kernel_utils"
+ / "defense.py"
+ )
+ if defense_path.exists():
+ defense_code = defense_path.read_text()
+ defense_module_path = f"{run_path}/defense.py"
+ write_result = await client.exec(
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+ )
+ if write_result.exit_code != 0:
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
+ defense_module_path = None
+ else:
+ print(f"Warning: defense.py not found at {defense_path}")
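The four .parent hops resolve defense.py relative to a source-tree layout (repo root, then packages/wafer-core/...), which only exists in a monorepo checkout — hence the warning fallback. When wafer-core is installed as a wheel, the usual approach is to resolve the file from the package itself; a sketch, assuming defense.py actually ships inside the wafer_core.utils.kernel_utils package:

from importlib import resources

# Sketch only: reads defense.py out of the installed wafer_core package
# instead of walking up the source tree (assumes the file is packaged).
defense_code = (
    resources.files("wafer_core.utils.kernel_utils")
    .joinpath("defense.py")
    .read_text()
)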
+
+ print("Running KernelBench evaluation (AMD/ROCm)...")
+
+ # Find Python with PyTorch - check common locations
+ python_exe = "python3"
+ for candidate in [
+ "/opt/conda/envs/py_3.10/bin/python3",
+ "/opt/conda/bin/python3",
+ ]:
+ check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
+ if "OK" in check.stdout:
+ python_exe = candidate
+ print(f"Using Python: {python_exe}")
+ break
+
+ # Build eval command - run directly on host
+ output_path = f"{run_path}/results.json"
+ python_cmd_parts = [
+ f"{python_exe} {eval_script_path}",
+ f"--impl {impl_path}",
+ f"--reference {ref_path}",
+ f"--output {output_path}",
+ ]
+
+ if args.benchmark:
+ python_cmd_parts.append("--benchmark")
+ if args.profile:
+ python_cmd_parts.append("--profile")
+ if inputs_path:
+ python_cmd_parts.append(f"--inputs {inputs_path}")
+ if args.defensive and defense_module_path:
+ python_cmd_parts.append("--defensive")
+ python_cmd_parts.append(f"--defense-module {defense_module_path}")
+ python_cmd_parts.append(f"--seed {args.seed}")
+
+ eval_cmd = " ".join(python_cmd_parts)
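With --benchmark set and the other options left at their defaults, the joined command comes out roughly as below (a sketch; the interpreter and run-directory paths are illustrative placeholders):

# Sketch of the flag assembly above with benchmark enabled.
parts = [
    "/opt/conda/bin/python3 /tmp/wafer_eval/run/kernelbench_eval.py",
    "--impl /tmp/wafer_eval/run/implementation.py",
    "--reference /tmp/wafer_eval/run/reference.py",
    "--output /tmp/wafer_eval/run/results.json",
    "--benchmark",
    "--seed 42",
]
print(" ".join(parts))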
+
+ # Set environment for AMD GPU and run
+ # PYTORCH_ROCM_ARCH: compile only for target arch (5-7x faster compile)
+ rocm_arch = _get_rocm_arch(target.compute_capability)
+ arch_env = f"PYTORCH_ROCM_ARCH={rocm_arch}" if rocm_arch else ""
+ env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1 {arch_env}"
+ full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
+
+ # Run and stream output
+ log_lines = []
+ async for line in client.exec_stream(full_cmd):
+ print(line, flush=True)
+ log_lines.append(line)
+
+ # Read results
+ cat_result = await client.exec(f"cat {output_path}")
+
+ if cat_result.exit_code != 0:
+ log_tail = "\n".join(log_lines[-50:])
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Evaluation failed. Log tail:\n{log_tail}",
+ )
+
+ # Parse results
+ try:
+ results_data = json.loads(cat_result.stdout)
+ except json.JSONDecodeError as e:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to parse results: {e}",
+ )
+
+ # Convert to EvaluateResult
+ correct = results_data.get("correct", False)
+ speedup = results_data.get("speedup", 0.0) or 0.0
+ error = results_data.get("error")
+
+ if error:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=1,
+ error_message=error,
+ )
+
+ return EvaluateResult(
+ success=True,
+ all_correct=correct,
+ correctness_score=1.0 if correct else 0.0,
+ geomean_speedup=speedup,
+ passed_tests=1 if correct else 0,
+ total_tests=1,
+ )
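This read/parse/convert tail is repeated verbatim from the RunPod path earlier in the diff. A shared helper — hypothetical, not part of the diff — would keep the two code paths in sync:

# Hypothetical helper (not in the diff) that both paths could share to turn
# a parsed results.json payload into an EvaluateResult.
def _result_from_json(results_data: dict) -> EvaluateResult:
    error = results_data.get("error")
    if error:
        return EvaluateResult(
            success=False, all_correct=False, correctness_score=0.0,
            geomean_speedup=0.0, passed_tests=0, total_tests=1,
            error_message=error,
        )
    correct = results_data.get("correct", False)
    return EvaluateResult(
        success=True, all_correct=correct,
        correctness_score=1.0 if correct else 0.0,
        geomean_speedup=results_data.get("speedup", 0.0) or 0.0,
        passed_tests=1 if correct else 0, total_tests=1,
    )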
+
+
+ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
+ """Run KernelBench format evaluation on the configured target.
+
+ Args:
+ args: KernelBench evaluate arguments
+
+ Returns:
+ Evaluation result
+ """
+ from .targets import get_default_target, load_target
+
+ # Validate input files
+ err = _validate_kernelbench_files(args)
+ if err:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=err,
+ )
+
+ # Load target
+ target_name = args.target_name
+ if not target_name:
+ target_name = get_default_target()
+ if not target_name:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=(
+ "No target specified and no default set.\n"
+ "Set up a target first:\n"
+ " wafer config targets init ssh --name my-gpu --host user@host:22\n"
+ " wafer config targets init runpod --gpu MI300X\n"
+ "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
+ ),
+ )
+
+ try:
+ target = load_target(target_name)
+ except FileNotFoundError:
+ return EvaluateResult(
  success=False,
  all_correct=False,
  correctness_score=0.0,
@@ -3471,7 +4496,13 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
  if isinstance(target, DigitalOceanTarget):
  # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
  return await run_evaluate_kernelbench_digitalocean(args, target)
+ elif isinstance(target, RunPodTarget):
+ # RunPod AMD MI300X - uses ROCm Docker with device passthrough
+ return await run_evaluate_kernelbench_runpod(args, target)
  elif isinstance(target, BaremetalTarget | VMTarget):
+ # Check if this is an AMD target (gfx* compute capability) - run directly
+ if target.compute_capability and target.compute_capability.startswith("gfx"):
+ return await run_evaluate_kernelbench_baremetal_amd(args, target)
  # NVIDIA targets - require docker_image to be set
  if not target.docker_image:
  return EvaluateResult(
@@ -3497,6 +4528,6 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
  total_tests=0,
  error_message=(
  f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
- "Use a DigitalOcean, Baremetal, or VM target."
+ "Use a DigitalOcean, RunPod, Baremetal, or VM target."
  ),
  )
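End to end, the dispatcher above routes a Baremetal/VM target whose compute_capability starts with "gfx" to the direct-host AMD path, and DigitalOcean/RunPod targets to their Docker-based runners. A usage sketch, assuming a configured target named "my-mi300x" and that KernelBenchEvaluateArgs accepts the fields the code above reads (implementation, reference, target_name, benchmark); other required fields may exist:

import asyncio
from pathlib import Path

# Illustrative values only; "my-mi300x" is a hypothetical target name.
args = KernelBenchEvaluateArgs(
    implementation=Path("implementation.py"),
    reference=Path("reference.py"),
    target_name="my-mi300x",  # e.g. a BaremetalTarget with compute_capability "gfx942"
    benchmark=True,
)
result = asyncio.run(run_evaluate_kernelbench(args))
print(result.success, result.geomean_speedup)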