wafer-cli 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/evaluate.py CHANGED
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
14
14
  from wafer_core.utils.kernel_utils.targets.config import (
15
15
  BaremetalTarget,
16
16
  DigitalOceanTarget,
17
+ LocalTarget,
17
18
  ModalTarget,
18
19
  RunPodTarget,
19
20
  VMTarget,
@@ -396,33 +397,6 @@ async def run_evaluate_docker(
396
397
  print(f"Connecting to {target.ssh_target}...")
397
398
 
398
399
  async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
399
- # Upload wafer-core to remote
400
- try:
401
- wafer_root = _get_wafer_root()
402
- wafer_core_path = wafer_root / "packages" / "wafer-core"
403
- print(f"Uploading wafer-core from {wafer_core_path}...")
404
-
405
- # Create workspace and upload
406
- workspace_name = wafer_core_path.name
407
- remote_workspace = f"{REMOTE_WORKSPACE_BASE}/{workspace_name}"
408
- await client.exec(f"mkdir -p {remote_workspace}")
409
- wafer_core_workspace = await client.expand_path(remote_workspace)
410
-
411
- upload_result = await client.upload_files(
412
- str(wafer_core_path), wafer_core_workspace, recursive=True
413
- )
414
- print(f"Uploaded {upload_result.files_copied} files")
415
- except Exception as e:
416
- return EvaluateResult(
417
- success=False,
418
- all_correct=False,
419
- correctness_score=0.0,
420
- geomean_speedup=0.0,
421
- passed_tests=0,
422
- total_tests=0,
423
- error_message=f"Failed to upload wafer-core: {e}",
424
- )
425
-
426
400
  print(f"Using Docker image: {target.docker_image}")
427
401
  print(f"Using GPU {gpu_id}...")
428
402
 
@@ -431,10 +405,13 @@ async def run_evaluate_docker(
431
405
  ref_code = args.reference.read_text()
432
406
  test_cases_data = json.loads(args.test_cases.read_text())
433
407
 
434
- # Create a unique run directory
408
+ # Create workspace for evaluation files
435
409
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
436
410
  run_dir = f"wafer_eval_{timestamp}"
437
- run_path = f"{wafer_core_workspace}/{run_dir}"
411
+ eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
412
+ await client.exec(f"mkdir -p {eval_workspace}")
413
+ eval_workspace_expanded = await client.expand_path(eval_workspace)
414
+ run_path = f"{eval_workspace_expanded}/{run_dir}"
438
415
 
439
416
  print("Uploading evaluation files...")
440
417
 
@@ -521,17 +498,16 @@ async def run_evaluate_docker(
521
498
  container_impl_path = f"{container_run_path}/implementation.py"
522
499
  container_ref_path = f"{container_run_path}/reference.py"
523
500
  container_test_cases_path = f"{container_run_path}/test_cases.json"
524
- container_evaluate_script = (
525
- f"{CONTAINER_WORKSPACE}/wafer_core/utils/kernel_utils/evaluate.py"
526
- )
527
501
 
528
- # Build pip install command for torch and other deps (no wafer-core install needed)
502
+ # Build pip install command for torch and other deps, plus wafer-core
529
503
  pip_install_cmd = _build_docker_pip_install_cmd(target)
504
+ install_cmd = (
505
+ f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
506
+ )
530
507
 
531
- # Build evaluate command - use PYTHONPATH instead of installing wafer-core
508
+ # Build evaluate command using installed wafer-core module
532
509
  python_cmd_parts = [
533
- f"PYTHONPATH={CONTAINER_WORKSPACE}:$PYTHONPATH",
534
- f"python3 {container_evaluate_script}",
510
+ "python3 -m wafer_core.utils.kernel_utils.evaluate",
535
511
  f"--implementation {container_impl_path}",
536
512
  f"--reference {container_ref_path}",
537
513
  f"--test-cases {container_test_cases_path}",
@@ -547,8 +523,8 @@ async def run_evaluate_docker(
547
523
 
548
524
  eval_cmd = " ".join(python_cmd_parts)
549
525
 
550
- # Full command: install torch deps, then run evaluate with PYTHONPATH
551
- full_cmd = f"{pip_install_cmd} && cd {container_run_path} && {eval_cmd}"
526
+ # Full command: install deps + wafer-core, then run evaluate
527
+ full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
552
528
 
553
529
  # Build Docker run command
554
530
  # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
@@ -558,7 +534,7 @@ async def run_evaluate_docker(
558
534
  working_dir=container_run_path,
559
535
  env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
560
536
  gpus="all",
561
- volumes={wafer_core_workspace: CONTAINER_WORKSPACE},
537
+ volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
562
538
  cap_add=["SYS_ADMIN"] if args.profile else None,
563
539
  )
564
540
 
@@ -567,7 +543,7 @@ async def run_evaluate_docker(
567
543
  # Run Docker command and stream output
568
544
  log_lines = []
569
545
  async for line in client.exec_stream(docker_cmd):
570
- print(line)
546
+ print(line, flush=True)
571
547
  log_lines.append(line)
572
548
 
573
549
  # Read results
@@ -665,6 +641,181 @@ async def run_evaluate_docker(
665
641
  )
666
642
 
667
643
 
644
+ async def run_evaluate_local(
645
+ args: EvaluateArgs,
646
+ target: LocalTarget,
647
+ ) -> EvaluateResult:
648
+ """Run evaluation locally on the current machine.
649
+
650
+ For LocalTarget - no SSH needed, runs directly.
651
+
652
+ Args:
653
+ args: Evaluate arguments
654
+ target: Local target config
655
+
656
+ Returns:
657
+ Evaluation result
658
+ """
659
+ import os
660
+ import subprocess
661
+ import tempfile
662
+ from datetime import datetime
663
+
664
+ # Select GPU
665
+ gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
666
+
667
+ print(f"Running local evaluation on GPU {gpu_id}...")
668
+
669
+ # Create temp directory for eval files
670
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
671
+ with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
672
+ run_path = Path(run_path)
673
+
674
+ # Write implementation
675
+ impl_path = run_path / "implementation.py"
676
+ impl_path.write_text(args.implementation.read_text())
677
+
678
+ # Write reference
679
+ ref_path = run_path / "reference.py"
680
+ ref_path.write_text(args.reference.read_text())
681
+
682
+ # Write custom inputs if provided
683
+ inputs_path = None
684
+ if args.inputs:
685
+ inputs_path = run_path / "custom_inputs.py"
686
+ inputs_path.write_text(args.inputs.read_text())
687
+
688
+ # Write eval script
689
+ eval_script_path = run_path / "kernelbench_eval.py"
690
+ eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)
691
+
692
+ # Write defense module if defensive mode is enabled
693
+ defense_module_path = None
694
+ if args.defensive:
695
+ defense_src = (
696
+ Path(__file__).parent.parent.parent.parent
697
+ / "packages"
698
+ / "wafer-core"
699
+ / "wafer_core"
700
+ / "utils"
701
+ / "kernel_utils"
702
+ / "defense.py"
703
+ )
704
+ if defense_src.exists():
705
+ defense_module_path = run_path / "defense.py"
706
+ defense_module_path.write_text(defense_src.read_text())
707
+ else:
708
+ print(f"Warning: defense.py not found at {defense_src}")
709
+
710
+ # Output file
711
+ output_path = run_path / "results.json"
712
+
713
+ # Build eval command
714
+ cmd_parts = [
715
+ "python3",
716
+ str(eval_script_path),
717
+ "--impl",
718
+ str(impl_path),
719
+ "--reference",
720
+ str(ref_path),
721
+ "--output",
722
+ str(output_path),
723
+ "--seed",
724
+ str(args.seed),
725
+ ]
726
+
727
+ if args.benchmark:
728
+ cmd_parts.append("--benchmark")
729
+ if args.profile:
730
+ cmd_parts.append("--profile")
731
+ if inputs_path:
732
+ cmd_parts.extend(["--inputs", str(inputs_path)])
733
+ if args.defensive and defense_module_path:
734
+ cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])
735
+
736
+ # Set environment for GPU selection
737
+ env = os.environ.copy()
738
+ if target.vendor == "nvidia":
739
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
740
+ else: # AMD
741
+ env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
742
+ env["ROCM_PATH"] = "/opt/rocm"
743
+
744
+ print(f"Running: {' '.join(cmd_parts[:4])} ...")
745
+
746
+ # Run evaluation
747
+ try:
748
+ result = subprocess.run(
749
+ cmd_parts,
750
+ cwd=str(run_path),
751
+ env=env,
752
+ capture_output=True,
753
+ text=True,
754
+ timeout=args.timeout or 600,
755
+ )
756
+ except subprocess.TimeoutExpired:
757
+ return EvaluateResult(
758
+ success=False,
759
+ all_correct=False,
760
+ correctness_score=0.0,
761
+ geomean_speedup=0.0,
762
+ passed_tests=0,
763
+ total_tests=0,
764
+ error_message="Evaluation timed out",
765
+ )
766
+
767
+ if result.returncode != 0:
768
+ error_msg = result.stderr or result.stdout or "Unknown error"
769
+ # Truncate long errors
770
+ if len(error_msg) > 1000:
771
+ error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
772
+ return EvaluateResult(
773
+ success=False,
774
+ all_correct=False,
775
+ correctness_score=0.0,
776
+ geomean_speedup=0.0,
777
+ passed_tests=0,
778
+ total_tests=0,
779
+ error_message=f"Evaluation failed:\n{error_msg}",
780
+ )
781
+
782
+ # Parse results
783
+ if not output_path.exists():
784
+ return EvaluateResult(
785
+ success=False,
786
+ all_correct=False,
787
+ correctness_score=0.0,
788
+ geomean_speedup=0.0,
789
+ passed_tests=0,
790
+ total_tests=0,
791
+ error_message="No results.json produced",
792
+ )
793
+
794
+ try:
795
+ results = json.loads(output_path.read_text())
796
+ except json.JSONDecodeError as e:
797
+ return EvaluateResult(
798
+ success=False,
799
+ all_correct=False,
800
+ correctness_score=0.0,
801
+ geomean_speedup=0.0,
802
+ passed_tests=0,
803
+ total_tests=0,
804
+ error_message=f"Failed to parse results: {e}",
805
+ )
806
+
807
+ # Extract results
808
+ return EvaluateResult(
809
+ success=True,
810
+ all_correct=results.get("all_correct", False),
811
+ correctness_score=results.get("correctness_score", 0.0),
812
+ geomean_speedup=results.get("geomean_speedup", 0.0),
813
+ passed_tests=results.get("passed_tests", 0),
814
+ total_tests=results.get("total_tests", 0),
815
+ benchmark_results=results.get("benchmark", {}),
816
+ )
817
+
818
+
668
819
  async def run_evaluate_ssh(
669
820
  args: EvaluateArgs,
670
821
  target: BaremetalTarget | VMTarget,
@@ -982,6 +1133,7 @@ def _build_modal_sandbox_script(
982
1133
  test_cases_b64: str,
983
1134
  run_benchmarks: bool,
984
1135
  run_defensive: bool,
1136
+ defense_code_b64: str | None = None,
985
1137
  ) -> str:
986
1138
  """Build Python script to create sandbox and run evaluation.
987
1139
 
@@ -1062,6 +1214,20 @@ print('Files written')
1062
1214
  print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
1063
1215
  return
1064
1216
 
1217
+ # Write defense module if defensive mode is enabled
1218
+ # NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
1219
+ if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
1220
+ proc = sandbox.exec("python", "-c", f"""
1221
+ import base64
1222
+ with open('/workspace/defense.py', 'w') as f:
1223
+ f.write(base64.b64decode('{defense_code_b64}').decode())
1224
+ print('Defense module written')
1225
+ """)
1226
+ proc.wait()
1227
+ if proc.returncode != 0:
1228
+ print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
1229
+ return
1230
+
1065
1231
  # Build inline evaluation script
1066
1232
  eval_script = """
1067
1233
  import json
@@ -1089,6 +1255,26 @@ generate_input = load_fn('reference.py', 'generate_input')
1089
1255
 
1090
1256
  import torch
1091
1257
 
1258
+ # Load defense module if available and defensive mode is enabled
1259
+ run_defensive = {run_defensive}
1260
+ defense = None
1261
+ if run_defensive:
1262
+ try:
1263
+ defense = load_fn('defense.py', 'run_all_defenses')
1264
+ time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
1265
+ print('[Defense] Defense module loaded')
1266
+
1267
+ # Wrap kernels for defense API compatibility
1268
+ # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
1269
+ # These wrappers repack the unpacked args back into a tuple
1270
+ def _wrap_for_defense(kernel):
1271
+ return lambda *args: kernel(args)
1272
+ custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
1273
+ ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
1274
+ except Exception as e:
1275
+ print(f'[Defense] Warning: Could not load defense module: {{e}}')
1276
+ defense = None
1277
+
1092
1278
  results = []
1093
1279
  all_correct = True
1094
1280
  total_time_ms = 0.0
@@ -1116,36 +1302,63 @@ for tc in test_cases:
1116
1302
  impl_time_ms = 0.0
1117
1303
  ref_time_ms = 0.0
1118
1304
  if {run_benchmarks}:
1119
- # Warmup
1120
- for _ in range(3):
1121
- custom_kernel(inputs)
1122
- torch.cuda.synchronize()
1123
-
1124
- # Measure with defensive timing if requested
1125
- # Defensive: sync before recording end event to catch stream injection
1126
- start = torch.cuda.Event(enable_timing=True)
1127
- end = torch.cuda.Event(enable_timing=True)
1128
- start.record()
1129
- for _ in range(10):
1130
- custom_kernel(inputs)
1131
- if {run_defensive}:
1132
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
1133
- end.record()
1134
- torch.cuda.synchronize()
1135
- impl_time_ms = start.elapsed_time(end) / 10
1136
-
1137
- # Reference timing (same defensive approach)
1138
- for _ in range(3):
1139
- ref_kernel(inputs)
1140
- torch.cuda.synchronize()
1141
- start.record()
1142
- for _ in range(10):
1143
- ref_kernel(inputs)
1144
- if {run_defensive}:
1145
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
1146
- end.record()
1147
- torch.cuda.synchronize()
1148
- ref_time_ms = start.elapsed_time(end) / 10
1305
+ if run_defensive and defense is not None:
1306
+ # Use full defense suite with wrapped kernels
1307
+ # inputs_list unpacks the tuple so defense can infer dtype/device from tensors
1308
+ inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
1309
+
1310
+ # Run defense checks
1311
+ all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
1312
+ if not all_passed:
1313
+ failed = [name for name, passed, _ in defense_results if not passed]
1314
+ raise ValueError(f"Defense checks failed: {{failed}}")
1315
+
1316
+ # Time with defensive timing (using wrapped kernels)
1317
+ impl_times, _ = time_with_defenses(
1318
+ custom_kernel_for_defense,
1319
+ inputs_list,
1320
+ num_warmup=3,
1321
+ num_trials=10,
1322
+ verbose=False,
1323
+ run_defenses=False,
1324
+ )
1325
+ impl_time_ms = sum(impl_times) / len(impl_times)
1326
+
1327
+ ref_times, _ = time_with_defenses(
1328
+ ref_kernel_for_defense,
1329
+ inputs_list,
1330
+ num_warmup=3,
1331
+ num_trials=10,
1332
+ verbose=False,
1333
+ run_defenses=False,
1334
+ )
1335
+ ref_time_ms = sum(ref_times) / len(ref_times)
1336
+ else:
1337
+ # Standard timing without full defenses
1338
+ # Warmup
1339
+ for _ in range(3):
1340
+ custom_kernel(inputs)
1341
+ torch.cuda.synchronize()
1342
+
1343
+ start = torch.cuda.Event(enable_timing=True)
1344
+ end = torch.cuda.Event(enable_timing=True)
1345
+ start.record()
1346
+ for _ in range(10):
1347
+ custom_kernel(inputs)
1348
+ end.record()
1349
+ torch.cuda.synchronize()
1350
+ impl_time_ms = start.elapsed_time(end) / 10
1351
+
1352
+ # Reference timing
1353
+ for _ in range(3):
1354
+ ref_kernel(inputs)
1355
+ torch.cuda.synchronize()
1356
+ start.record()
1357
+ for _ in range(10):
1358
+ ref_kernel(inputs)
1359
+ end.record()
1360
+ torch.cuda.synchronize()
1361
+ ref_time_ms = start.elapsed_time(end) / 10
1149
1362
 
1150
1363
  total_time_ms += impl_time_ms
1151
1364
  ref_total_time_ms += ref_time_ms
@@ -1197,7 +1410,7 @@ print(json.dumps({{
1197
1410
  # Find the last JSON line in output
1198
1411
  for line in reversed(stdout.strip().split("\\n")):
1199
1412
  if line.startswith("{{"):
1200
- print(line)
1413
+ print(line, flush=True)
1201
1414
  return
1202
1415
 
1203
1416
  print(json.dumps({{"error": f"No result JSON in output: {{stdout[:500]}}"}}))
@@ -1238,6 +1451,23 @@ async def run_evaluate_modal(
1238
1451
  ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
1239
1452
  test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()
1240
1453
 
1454
+ # Encode defense module if defensive mode is enabled
1455
+ defense_code_b64 = None
1456
+ if args.defensive:
1457
+ defense_path = (
1458
+ Path(__file__).parent.parent.parent.parent
1459
+ / "packages"
1460
+ / "wafer-core"
1461
+ / "wafer_core"
1462
+ / "utils"
1463
+ / "kernel_utils"
1464
+ / "defense.py"
1465
+ )
1466
+ if defense_path.exists():
1467
+ defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
1468
+ else:
1469
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
1470
+
1241
1471
  # Build the script that creates sandbox and runs eval
1242
1472
  script = _build_modal_sandbox_script(
1243
1473
  target=target,
@@ -1246,6 +1476,7 @@ async def run_evaluate_modal(
1246
1476
  test_cases_b64=test_cases_b64,
1247
1477
  run_benchmarks=args.benchmark,
1248
1478
  run_defensive=args.defensive,
1479
+ defense_code_b64=defense_code_b64,
1249
1480
  )
1250
1481
 
1251
1482
  def _run_subprocess() -> tuple[str, str, int]:
@@ -1343,6 +1574,7 @@ def _build_workspace_eval_script(
1343
1574
  test_cases_json: str,
1344
1575
  run_benchmarks: bool,
1345
1576
  run_defensive: bool = False,
1577
+ defense_code: str | None = None,
1346
1578
  ) -> str:
1347
1579
  """Build inline evaluation script for workspace exec.
1348
1580
 
@@ -1353,6 +1585,7 @@ def _build_workspace_eval_script(
1353
1585
  impl_b64 = base64.b64encode(impl_code.encode()).decode()
1354
1586
  ref_b64 = base64.b64encode(ref_code.encode()).decode()
1355
1587
  tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
1588
+ defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""
1356
1589
 
1357
1590
  return f'''
1358
1591
  import base64
@@ -1372,6 +1605,15 @@ with open("/tmp/kernel.py", "w") as f:
1372
1605
  with open("/tmp/reference.py", "w") as f:
1373
1606
  f.write(ref_code)
1374
1607
 
1608
+ # Write defense module if available
1609
+ run_defensive = {run_defensive}
1610
+ defense_b64 = "{defense_b64}"
1611
+ # NOTE: Check defense_b64 is not empty and not the string "None" (from None formatting)
1612
+ if run_defensive and defense_b64 and defense_b64 != "None":
1613
+ defense_code = base64.b64decode(defense_b64).decode()
1614
+ with open("/tmp/defense.py", "w") as f:
1615
+ f.write(defense_code)
1616
+
1375
1617
  # Load kernels
1376
1618
  def load_fn(path, name):
1377
1619
  spec = importlib.util.spec_from_file_location("mod", path)
@@ -1385,6 +1627,24 @@ generate_input = load_fn("/tmp/reference.py", "generate_input")
1385
1627
 
1386
1628
  import torch
1387
1629
 
1630
+ # Load defense module if available
1631
+ defense = None
1632
+ if run_defensive and defense_b64 and defense_b64 != "None":
1633
+ try:
1634
+ defense = load_fn("/tmp/defense.py", "run_all_defenses")
1635
+ time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
1636
+ print("[Defense] Defense module loaded")
1637
+
1638
+ # Wrap kernels for defense API compatibility
1639
+ # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
1640
+ def _wrap_for_defense(kernel):
1641
+ return lambda *args: kernel(args)
1642
+ custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
1643
+ ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
1644
+ except Exception as e:
1645
+ print(f"[Defense] Warning: Could not load defense module: {{e}}")
1646
+ defense = None
1647
+
1388
1648
  results = []
1389
1649
  all_correct = True
1390
1650
  total_time_ms = 0.0
@@ -1412,36 +1672,60 @@ for tc in test_cases:
1412
1672
  impl_time_ms = 0.0
1413
1673
  ref_time_ms = 0.0
1414
1674
  if {run_benchmarks}:
1415
- # Warmup
1416
- for _ in range(3):
1417
- custom_kernel(inputs)
1418
- torch.cuda.synchronize()
1419
-
1420
- # Measure with defensive timing if requested
1421
- # Defensive: sync before recording end event to catch stream injection
1422
- start = torch.cuda.Event(enable_timing=True)
1423
- end = torch.cuda.Event(enable_timing=True)
1424
- start.record()
1425
- for _ in range(10):
1426
- custom_kernel(inputs)
1427
- if {run_defensive}:
1428
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
1429
- end.record()
1430
- torch.cuda.synchronize()
1431
- impl_time_ms = start.elapsed_time(end) / 10
1432
-
1433
- # Reference timing (same defensive approach)
1434
- for _ in range(3):
1435
- ref_kernel(inputs)
1436
- torch.cuda.synchronize()
1437
- start.record()
1438
- for _ in range(10):
1439
- ref_kernel(inputs)
1440
- if {run_defensive}:
1441
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
1442
- end.record()
1443
- torch.cuda.synchronize()
1444
- ref_time_ms = start.elapsed_time(end) / 10
1675
+ if run_defensive and defense is not None:
1676
+ # Use full defense suite with wrapped kernels
1677
+ inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
1678
+
1679
+ # Run defense checks
1680
+ all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
1681
+ if not all_passed:
1682
+ failed = [name for name, passed, _ in defense_results if not passed]
1683
+ raise ValueError(f"Defense checks failed: {{failed}}")
1684
+
1685
+ # Time with defensive timing (using wrapped kernels)
1686
+ impl_times, _ = time_with_defenses(
1687
+ custom_kernel_for_defense,
1688
+ inputs_list,
1689
+ num_warmup=3,
1690
+ num_trials=10,
1691
+ verbose=False,
1692
+ run_defenses=False,
1693
+ )
1694
+ impl_time_ms = sum(impl_times) / len(impl_times)
1695
+
1696
+ ref_times, _ = time_with_defenses(
1697
+ ref_kernel_for_defense,
1698
+ inputs_list,
1699
+ num_warmup=3,
1700
+ num_trials=10,
1701
+ verbose=False,
1702
+ run_defenses=False,
1703
+ )
1704
+ ref_time_ms = sum(ref_times) / len(ref_times)
1705
+ else:
1706
+ # Standard timing
1707
+ for _ in range(3):
1708
+ custom_kernel(inputs)
1709
+ torch.cuda.synchronize()
1710
+
1711
+ start = torch.cuda.Event(enable_timing=True)
1712
+ end = torch.cuda.Event(enable_timing=True)
1713
+ start.record()
1714
+ for _ in range(10):
1715
+ custom_kernel(inputs)
1716
+ end.record()
1717
+ torch.cuda.synchronize()
1718
+ impl_time_ms = start.elapsed_time(end) / 10
1719
+
1720
+ for _ in range(3):
1721
+ ref_kernel(inputs)
1722
+ torch.cuda.synchronize()
1723
+ start.record()
1724
+ for _ in range(10):
1725
+ ref_kernel(inputs)
1726
+ end.record()
1727
+ torch.cuda.synchronize()
1728
+ ref_time_ms = start.elapsed_time(end) / 10
1445
1729
 
1446
1730
  total_time_ms += impl_time_ms
1447
1731
  ref_total_time_ms += ref_time_ms
@@ -1503,6 +1787,23 @@ async def run_evaluate_workspace(
1503
1787
  ref_code = args.reference.read_text()
1504
1788
  test_cases_json = args.test_cases.read_text()
1505
1789
 
1790
+ # Read defense module if defensive mode is enabled
1791
+ defense_code = None
1792
+ if args.defensive:
1793
+ defense_path = (
1794
+ Path(__file__).parent.parent.parent.parent
1795
+ / "packages"
1796
+ / "wafer-core"
1797
+ / "wafer_core"
1798
+ / "utils"
1799
+ / "kernel_utils"
1800
+ / "defense.py"
1801
+ )
1802
+ if defense_path.exists():
1803
+ defense_code = defense_path.read_text()
1804
+ else:
1805
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
1806
+
1506
1807
  # Build inline eval script
1507
1808
  eval_script = _build_workspace_eval_script(
1508
1809
  impl_code=impl_code,
@@ -1510,6 +1811,7 @@ async def run_evaluate_workspace(
1510
1811
  test_cases_json=test_cases_json,
1511
1812
  run_benchmarks=args.benchmark,
1512
1813
  run_defensive=args.defensive,
1814
+ defense_code=defense_code,
1513
1815
  )
1514
1816
 
1515
1817
  # Execute via workspace exec
@@ -1855,15 +2157,12 @@ async def run_evaluate_runpod(
1855
2157
  # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
1856
2158
  venv_bin = env_state.venv_bin
1857
2159
  env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
1858
- pythonpath = f"PYTHONPATH={wafer_core_workspace}"
1859
- evaluate_script = (
1860
- f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
1861
- )
1862
2160
 
1863
2161
  # Run from run_path so reference_kernel.py is importable
2162
+ # Use installed wafer-core module
1864
2163
  eval_cmd = (
1865
2164
  f"cd {run_path} && "
1866
- f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
2165
+ f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
1867
2166
  f"--implementation {impl_path} "
1868
2167
  f"--reference {ref_path} "
1869
2168
  f"--test-cases {test_cases_path} "
@@ -2219,15 +2518,12 @@ async def run_evaluate_digitalocean(
2219
2518
  env_vars = (
2220
2519
  f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
2221
2520
  )
2222
- pythonpath = f"PYTHONPATH={wafer_core_workspace}"
2223
- evaluate_script = (
2224
- f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
2225
- )
2226
2521
 
2227
2522
  # Run from run_path so reference_kernel.py is importable
2523
+ # Use installed wafer-core module
2228
2524
  eval_cmd = (
2229
2525
  f"cd {run_path} && "
2230
- f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
2526
+ f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
2231
2527
  f"--implementation {impl_path} "
2232
2528
  f"--reference {ref_path} "
2233
2529
  f"--test-cases {test_cases_path} "
@@ -2407,7 +2703,9 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
2407
2703
  print(f"Using target: {target_name}")
2408
2704
 
2409
2705
  # Dispatch to appropriate executor
2410
- if isinstance(target, BaremetalTarget | VMTarget):
2706
+ if isinstance(target, LocalTarget):
2707
+ return await run_evaluate_local(args, target)
2708
+ elif isinstance(target, BaremetalTarget | VMTarget):
2411
2709
  return await run_evaluate_ssh(args, target)
2412
2710
  elif isinstance(target, ModalTarget):
2413
2711
  return await run_evaluate_modal(args, target)
@@ -2436,6 +2734,7 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
2436
2734
  # Inline evaluation script for KernelBench format
2437
2735
  # This runs inside the Docker container on the remote GPU
2438
2736
  KERNELBENCH_EVAL_SCRIPT = """
2737
+ import gc
2439
2738
  import json
2440
2739
  import os
2441
2740
  import sys
@@ -2444,6 +2743,57 @@ import torch
2444
2743
  import torch.nn as nn
2445
2744
  from pathlib import Path
2446
2745
 
2746
+ # Clear any stale GPU memory from previous runs at startup
2747
+ # NOTE: empty_cache only frees memory from THIS process's PyTorch allocator.
2748
+ # It won't free memory from dead/zombie processes - rocm-smi --showpids can show
2749
+ # PIDs that no longer exist but still hold GPU memory. Those require a GPU reset
2750
+ # (rocm-smi --gpureset) to fully clear. TODO: detect and warn about orphaned memory.
2751
+ if torch.cuda.is_available():
2752
+ gc.collect()
2753
+ torch.cuda.empty_cache()
2754
+ torch.cuda.reset_peak_memory_stats()
2755
+
2756
+
2757
+ def _calculate_timing_stats(times: list[float]) -> dict:
2758
+ '''Calculate median and IQR from timing samples.
2759
+
2760
+ Returns dict with median, iqr_low (25th percentile), iqr_high (75th percentile),
2761
+ mean, min, max, and std.
2762
+ '''
2763
+ import statistics
2764
+
2765
+ if not times:
2766
+ return {"median": 0, "iqr_low": 0, "iqr_high": 0, "mean": 0, "min": 0, "max": 0, "std": 0}
2767
+
2768
+ sorted_times = sorted(times)
2769
+ n = len(sorted_times)
2770
+
2771
+ # Median
2772
+ median = statistics.median(sorted_times)
2773
+
2774
+ # Quartiles (25th and 75th percentile)
2775
+ # For small samples, use simple interpolation
2776
+ q1_idx = (n - 1) * 0.25
2777
+ q3_idx = (n - 1) * 0.75
2778
+
2779
+ q1_low = int(q1_idx)
2780
+ q1_frac = q1_idx - q1_low
2781
+ iqr_low = sorted_times[q1_low] * (1 - q1_frac) + sorted_times[min(q1_low + 1, n - 1)] * q1_frac
2782
+
2783
+ q3_low = int(q3_idx)
2784
+ q3_frac = q3_idx - q3_low
2785
+ iqr_high = sorted_times[q3_low] * (1 - q3_frac) + sorted_times[min(q3_low + 1, n - 1)] * q3_frac
2786
+
2787
+ return {
2788
+ "median": median,
2789
+ "iqr_low": iqr_low,
2790
+ "iqr_high": iqr_high,
2791
+ "mean": statistics.mean(sorted_times),
2792
+ "min": min(sorted_times),
2793
+ "max": max(sorted_times),
2794
+ "std": statistics.stdev(sorted_times) if n > 1 else 0,
2795
+ }
2796
+
2447
2797
 
2448
2798
  def run_profiling(model, inputs, name, output_dir):
2449
2799
  '''Run torch.profiler and return summary stats.'''
@@ -2674,12 +3024,26 @@ def main():
2674
3024
  parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
2675
3025
  parser.add_argument("--benchmark", action="store_true")
2676
3026
  parser.add_argument("--profile", action="store_true")
3027
+ parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
3028
+ parser.add_argument("--defense-module", help="Path to defense.py module")
2677
3029
  parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
2678
3030
  parser.add_argument("--num-correct-trials", type=int, default=3)
2679
3031
  parser.add_argument("--num-perf-trials", type=int, default=10)
2680
3032
  parser.add_argument("--output", required=True)
2681
3033
  args = parser.parse_args()
2682
3034
 
3035
+ # Load defense module if defensive mode is enabled
3036
+ defense_module = None
3037
+ if args.defensive and args.defense_module:
3038
+ try:
3039
+ import importlib.util
3040
+ defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
3041
+ defense_module = importlib.util.module_from_spec(defense_spec)
3042
+ defense_spec.loader.exec_module(defense_module)
3043
+ print("[KernelBench] Defense module loaded")
3044
+ except Exception as e:
3045
+ print(f"[KernelBench] Warning: Could not load defense module: {e}")
3046
+
2683
3047
  # Create output directory for profiles
2684
3048
  output_dir = Path(args.output).parent
2685
3049
  profile_dir = output_dir / "profiles"
@@ -2813,47 +3177,102 @@ def main():
2813
3177
  inputs = get_inputs()
2814
3178
  inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
2815
3179
 
2816
- # Warmup
2817
- for _ in range(5):
2818
- with torch.no_grad():
2819
- _ = new_model(*inputs)
2820
- torch.cuda.synchronize()
3180
+ if args.defensive and defense_module is not None:
3181
+ # Use full defense suite
3182
+ print("[KernelBench] Running defense checks on implementation...")
3183
+ run_all_defenses = defense_module.run_all_defenses
3184
+ time_with_defenses = defense_module.time_execution_with_defenses
2821
3185
 
2822
- # Benchmark new model
2823
- start = torch.cuda.Event(enable_timing=True)
2824
- end = torch.cuda.Event(enable_timing=True)
2825
-
2826
- times = []
2827
- for _ in range(args.num_perf_trials):
2828
- start.record()
2829
- with torch.no_grad():
2830
- _ = new_model(*inputs)
2831
- end.record()
2832
- torch.cuda.synchronize()
2833
- times.append(start.elapsed_time(end))
2834
-
2835
- new_time = sum(times) / len(times)
2836
- results["runtime_ms"] = new_time
2837
-
2838
- # Benchmark reference model
2839
- for _ in range(5):
2840
- with torch.no_grad():
2841
- _ = ref_model(*inputs)
2842
- torch.cuda.synchronize()
2843
-
2844
- times = []
2845
- for _ in range(args.num_perf_trials):
2846
- start.record()
2847
- with torch.no_grad():
2848
- _ = ref_model(*inputs)
2849
- end.record()
3186
+ # Run defense checks on implementation
3187
+ all_passed, defense_results, _ = run_all_defenses(
3188
+ lambda *x: new_model(*x),
3189
+ *inputs,
3190
+ )
3191
+ results["defense_results"] = {
3192
+ name: {"passed": passed, "message": msg}
3193
+ for name, passed, msg in defense_results
3194
+ }
3195
+ if not all_passed:
3196
+ failed = [name for name, passed, _ in defense_results if not passed]
3197
+ results["error"] = f"Defense checks failed: {failed}"
3198
+ print(f"[KernelBench] Defense checks FAILED: {failed}")
3199
+ for name, passed, msg in defense_results:
3200
+ status = "PASS" if passed else "FAIL"
3201
+ print(f" [{status}] {name}: {msg}")
3202
+ else:
3203
+ print("[KernelBench] All defense checks passed")
3204
+
3205
+ # Time with defensive timing
3206
+ impl_times, _ = time_with_defenses(
3207
+ lambda: new_model(*inputs),
3208
+ [],
3209
+ num_warmup=5,
3210
+ num_trials=args.num_perf_trials,
3211
+ verbose=False,
3212
+ run_defenses=False, # Already ran above
3213
+ )
3214
+ # Calculate stats for new model
3215
+ new_stats = _calculate_timing_stats(impl_times)
3216
+ results["runtime_ms"] = new_stats["median"]
3217
+ results["runtime_stats"] = new_stats
3218
+
3219
+ # Reference timing
3220
+ ref_times, _ = time_with_defenses(
3221
+ lambda: ref_model(*inputs),
3222
+ [],
3223
+ num_warmup=5,
3224
+ num_trials=args.num_perf_trials,
3225
+ verbose=False,
3226
+ run_defenses=False,
3227
+ )
3228
+ ref_stats = _calculate_timing_stats(ref_times)
3229
+ results["reference_runtime_ms"] = ref_stats["median"]
3230
+ results["reference_runtime_stats"] = ref_stats
3231
+ results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
3232
+ print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
3233
+ else:
3234
+ # Standard timing without full defenses
3235
+ # Warmup BOTH models before benchmarking either
3236
+ # This ensures consistent GPU state and avoids MIOpen cache effects
3237
+ # that cause variance when warming up models sequentially
3238
+ for _ in range(5):
3239
+ with torch.no_grad():
3240
+ _ = new_model(*inputs)
3241
+ _ = ref_model(*inputs)
2850
3242
  torch.cuda.synchronize()
2851
- times.append(start.elapsed_time(end))
2852
3243
 
2853
- ref_time = sum(times) / len(times)
2854
- results["reference_runtime_ms"] = ref_time
2855
- results["speedup"] = ref_time / new_time if new_time > 0 else 0
2856
- print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
3244
+ # Benchmark new model
3245
+ start = torch.cuda.Event(enable_timing=True)
3246
+ end = torch.cuda.Event(enable_timing=True)
3247
+
3248
+ new_times = []
3249
+ for _ in range(args.num_perf_trials):
3250
+ start.record()
3251
+ with torch.no_grad():
3252
+ _ = new_model(*inputs)
3253
+ end.record()
3254
+ torch.cuda.synchronize()
3255
+ new_times.append(start.elapsed_time(end))
3256
+
3257
+ new_stats = _calculate_timing_stats(new_times)
3258
+ results["runtime_ms"] = new_stats["median"]
3259
+ results["runtime_stats"] = new_stats
3260
+
3261
+ # Benchmark reference model
3262
+ ref_times = []
3263
+ for _ in range(args.num_perf_trials):
3264
+ start.record()
3265
+ with torch.no_grad():
3266
+ _ = ref_model(*inputs)
3267
+ end.record()
3268
+ torch.cuda.synchronize()
3269
+ ref_times.append(start.elapsed_time(end))
3270
+
3271
+ ref_stats = _calculate_timing_stats(ref_times)
3272
+ results["reference_runtime_ms"] = ref_stats["median"]
3273
+ results["reference_runtime_stats"] = ref_stats
3274
+ results["speedup"] = ref_stats["median"] / new_stats["median"] if new_stats["median"] > 0 else 0
3275
+ print(f"[KernelBench] New: {new_stats['median']:.3f}ms (IQR: {new_stats['iqr_low']:.3f}-{new_stats['iqr_high']:.3f}), Ref: {ref_stats['median']:.3f}ms (IQR: {ref_stats['iqr_low']:.3f}-{ref_stats['iqr_high']:.3f}), Speedup: {results['speedup']:.2f}x")
2857
3276
 
2858
3277
  # Run profiling if requested and correctness passed
2859
3278
  if args.profile and all_correct:
@@ -2898,6 +3317,16 @@ def main():
2898
3317
  json.dump(results, f, indent=2)
2899
3318
  print(f"[KernelBench] Results written to {args.output}")
2900
3319
 
3320
+ # Cleanup GPU memory
3321
+ try:
3322
+ del ref_model, new_model
3323
+ except NameError:
3324
+ pass
3325
+ import gc
3326
+ gc.collect()
3327
+ if torch.cuda.is_available():
3328
+ torch.cuda.empty_cache()
3329
+
2901
3330
  if __name__ == "__main__":
2902
3331
  main()
2903
3332
  """
@@ -3059,6 +3488,30 @@ async def run_evaluate_kernelbench_docker(
3059
3488
  error_message=f"Failed to write eval script: {write_result.stderr}",
3060
3489
  )
3061
3490
 
3491
+ # Write defense module if defensive mode is enabled
3492
+ defense_module_path = None
3493
+ if args.defensive:
3494
+ defense_path = (
3495
+ Path(__file__).parent.parent.parent.parent
3496
+ / "packages"
3497
+ / "wafer-core"
3498
+ / "wafer_core"
3499
+ / "utils"
3500
+ / "kernel_utils"
3501
+ / "defense.py"
3502
+ )
3503
+ if defense_path.exists():
3504
+ defense_code = defense_path.read_text()
3505
+ defense_module_path = f"{run_path}/defense.py"
3506
+ write_result = await client.exec(
3507
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
3508
+ )
3509
+ if write_result.exit_code != 0:
3510
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
3511
+ defense_module_path = None
3512
+ else:
3513
+ print(f"Warning: defense.py not found at {defense_path}")
3514
+
3062
3515
  print("Running KernelBench evaluation in Docker container...")
3063
3516
 
3064
3517
  # Paths inside container
@@ -3068,6 +3521,7 @@ async def run_evaluate_kernelbench_docker(
3068
3521
  container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
3069
3522
  container_eval_script = f"{container_run_path}/kernelbench_eval.py"
3070
3523
  container_output = f"{container_run_path}/results.json"
3524
+ container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None
3071
3525
 
3072
3526
  # Build eval command
3073
3527
  python_cmd_parts = [
@@ -3083,6 +3537,9 @@ async def run_evaluate_kernelbench_docker(
3083
3537
  python_cmd_parts.append("--profile")
3084
3538
  if container_inputs_path:
3085
3539
  python_cmd_parts.append(f"--inputs {container_inputs_path}")
3540
+ if args.defensive and container_defense_path:
3541
+ python_cmd_parts.append("--defensive")
3542
+ python_cmd_parts.append(f"--defense-module {container_defense_path}")
3086
3543
  python_cmd_parts.append(f"--seed {args.seed}")
3087
3544
 
3088
3545
  eval_cmd = " ".join(python_cmd_parts)
@@ -3106,7 +3563,7 @@ async def run_evaluate_kernelbench_docker(
3106
3563
  # Run and stream output
3107
3564
  log_lines = []
3108
3565
  async for line in client.exec_stream(docker_cmd):
3109
- print(line)
3566
+ print(line, flush=True)
3110
3567
  log_lines.append(line)
3111
3568
 
3112
3569
  # Read results
@@ -3298,15 +3755,44 @@ async def run_evaluate_kernelbench_digitalocean(
3298
3755
  error_message=f"Failed to write eval script: {write_result.stderr}",
3299
3756
  )
3300
3757
 
3758
+ # Write defense module if defensive mode is enabled
3759
+ defense_module_path = None
3760
+ if args.defensive:
3761
+ defense_path = (
3762
+ Path(__file__).parent.parent.parent.parent
3763
+ / "packages"
3764
+ / "wafer-core"
3765
+ / "wafer_core"
3766
+ / "utils"
3767
+ / "kernel_utils"
3768
+ / "defense.py"
3769
+ )
3770
+ if defense_path.exists():
3771
+ defense_code = defense_path.read_text()
3772
+ defense_module_path = f"{run_path}/defense.py"
3773
+ write_result = await client.exec(
3774
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
3775
+ )
3776
+ if write_result.exit_code != 0:
3777
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
3778
+ defense_module_path = None
3779
+ else:
3780
+ print(f"Warning: defense.py not found at {defense_path}")
3781
+
3301
3782
  print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
3302
3783
 
3303
3784
  # Paths inside container
3304
3785
  container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
3305
3786
  container_impl_path = f"{container_run_path}/implementation.py"
3306
3787
  container_ref_path = f"{container_run_path}/reference.py"
3307
- container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
3788
+ container_inputs_path = (
3789
+ f"{container_run_path}/custom_inputs.py" if args.inputs else None
3790
+ )
3308
3791
  container_eval_script = f"{container_run_path}/kernelbench_eval.py"
3309
3792
  container_output = f"{container_run_path}/results.json"
3793
+ container_defense_path = (
3794
+ f"{container_run_path}/defense.py" if defense_module_path else None
3795
+ )
3310
3796
 
3311
3797
  # Build eval command
3312
3798
  python_cmd_parts = [
@@ -3322,6 +3808,9 @@ async def run_evaluate_kernelbench_digitalocean(
3322
3808
  python_cmd_parts.append("--profile")
3323
3809
  if container_inputs_path:
3324
3810
  python_cmd_parts.append(f"--inputs {container_inputs_path}")
3811
+ if args.defensive and container_defense_path:
3812
+ python_cmd_parts.append("--defensive")
3813
+ python_cmd_parts.append(f"--defense-module {container_defense_path}")
3325
3814
  python_cmd_parts.append(f"--seed {args.seed}")
3326
3815
 
3327
3816
  eval_cmd = " ".join(python_cmd_parts)
@@ -3346,7 +3835,7 @@ async def run_evaluate_kernelbench_digitalocean(
3346
3835
  # Run and stream output
3347
3836
  log_lines = []
3348
3837
  async for line in client.exec_stream(docker_cmd):
3349
- print(line)
3838
+ print(line, flush=True)
3350
3839
  log_lines.append(line)
3351
3840
 
3352
3841
  # Read results
@@ -3407,71 +3896,544 @@ async def run_evaluate_kernelbench_digitalocean(
3407
3896
  )
3408
3897
 
3409
3898
 
3410
- async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
3411
- """Run KernelBench format evaluation on configured target.
3412
-
3413
- Args:
3414
- args: KernelBench evaluate arguments
3899
async def run_evaluate_kernelbench_runpod(
    args: KernelBenchEvaluateArgs,
    target: RunPodTarget,
) -> EvaluateResult:
    """Run KernelBench format evaluation directly on RunPod AMD GPU.

    Runs evaluation script directly on host (no Docker) since RunPod pods
    already have PyTorch/ROCm installed.

    Flow: provision/connect with ``runpod_ssh_context``, create a fresh run
    directory under ``/tmp/wafer_eval``, upload the implementation, reference,
    optional custom inputs, the eval script and (when ``args.defensive``) the
    defense module, run the eval script with ``HIP_VISIBLE_DEVICES`` pinned to
    the selected GPU, then convert ``results.json`` into an ``EvaluateResult``.

    Args:
        args: KernelBench evaluate arguments.
        target: RunPod target configuration.

    Returns:
        Evaluation result. Upload, execution, parsing and RunPod provisioning
        failures are reported as ``success=False`` rather than raised.
    """
    import importlib.util
    import uuid
    from datetime import datetime

    from wafer_core.async_ssh import AsyncSSHClient
    from wafer_core.targets.runpod import RunPodError, runpod_ssh_context

    REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"

    def _failure(message: str, total_tests: int = 0) -> EvaluateResult:
        """Build an unsuccessful EvaluateResult carrying ``message``."""
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=total_tests,
            error_message=message,
        )

    def _locate_defense_source() -> Path | None:
        """Find defense.py in the installed wafer_core package, else a repo checkout."""
        candidates: list[Path] = []
        try:
            # find_spec does not execute defense.py itself (it may need torch).
            spec = importlib.util.find_spec("wafer_core.utils.kernel_utils.defense")
        except (ImportError, ValueError):
            spec = None
        if spec is not None and spec.origin:
            candidates.append(Path(spec.origin))
        # Fallback: running from a monorepo checkout (not present in a wheel install).
        candidates.append(
            Path(__file__).parent.parent.parent.parent
            / "packages"
            / "wafer-core"
            / "wafer_core"
            / "utils"
            / "kernel_utils"
            / "defense.py"
        )
        for candidate in candidates:
            if candidate.is_file():
                return candidate
        return None

    async def _write_remote(client, path: str, content: str, tag: str):
        """Write ``content`` to remote ``path`` through a quoted heredoc.

        A line equal to the heredoc delimiter would end the heredoc early and
        the rest of ``content`` would run as shell commands, so the delimiter
        is randomized whenever ``content`` contains ``tag`` as a whole line.
        """
        content_lines = set(content.split("\n"))
        delimiter = tag
        while delimiter in content_lines:
            delimiter = f"{tag}_{uuid.uuid4().hex[:8]}"
        return await client.exec(
            f"cat > '{path}' << '{delimiter}'\n{content}\n{delimiter}"
        )

    async def _evaluate(client, gpu_id) -> EvaluateResult:
        """Upload inputs, run the eval script on the host and parse its results."""
        # Random suffix keeps concurrent runs started in the same second apart.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir = f"kernelbench_eval_{timestamp}_{uuid.uuid4().hex[:8]}"
        run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"

        await client.exec(f"mkdir -p {run_path}")
        print(f"Created run directory: {run_path}")

        # Write implementation
        impl_path = f"{run_path}/implementation.py"
        write_result = await _write_remote(
            client, impl_path, args.implementation.read_text(), "IMPL_EOF"
        )
        if write_result.exit_code != 0:
            return _failure(f"Failed to write implementation: {write_result.stderr}")

        # Write reference
        ref_path = f"{run_path}/reference.py"
        write_result = await _write_remote(
            client, ref_path, args.reference.read_text(), "REF_EOF"
        )
        if write_result.exit_code != 0:
            return _failure(f"Failed to write reference: {write_result.stderr}")

        # Write custom inputs if provided
        inputs_path = None
        if args.inputs:
            inputs_path = f"{run_path}/custom_inputs.py"
            write_result = await _write_remote(
                client, inputs_path, args.inputs.read_text(), "INPUTS_EOF"
            )
            if write_result.exit_code != 0:
                return _failure(f"Failed to write custom inputs: {write_result.stderr}")

        # Write eval script
        eval_script_path = f"{run_path}/kernelbench_eval.py"
        write_result = await _write_remote(
            client, eval_script_path, KERNELBENCH_EVAL_SCRIPT, "EVAL_EOF"
        )
        if write_result.exit_code != 0:
            return _failure(f"Failed to write eval script: {write_result.stderr}")

        # Write defense module if defensive mode is enabled. A missing or
        # unwritable module only warns; the run continues without --defensive.
        defense_module_path = None
        if args.defensive:
            defense_source = _locate_defense_source()
            if defense_source is None:
                print(
                    "Warning: defense.py not found in the installed wafer_core "
                    "package or the repository checkout"
                )
            else:
                defense_module_path = f"{run_path}/defense.py"
                write_result = await _write_remote(
                    client, defense_module_path, defense_source.read_text(), "DEFENSE_EOF"
                )
                if write_result.exit_code != 0:
                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
                    defense_module_path = None

        print("Running KernelBench evaluation (AMD/ROCm)...")

        # Find Python with PyTorch - check common locations on RunPod
        python_exe = "python3"
        for candidate in (
            "/opt/conda/envs/py_3.10/bin/python3",
            "/opt/conda/bin/python3",
        ):
            check = await client.exec(
                f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
            )
            if "OK" in check.stdout:
                python_exe = candidate
                print(f"Using Python: {python_exe}")
                break

        # Build eval command - run directly on host
        output_path = f"{run_path}/results.json"
        python_cmd_parts = [
            f"{python_exe} {eval_script_path}",
            f"--impl {impl_path}",
            f"--reference {ref_path}",
            f"--output {output_path}",
        ]
        if args.benchmark:
            python_cmd_parts.append("--benchmark")
        if args.profile:
            python_cmd_parts.append("--profile")
        if inputs_path:
            python_cmd_parts.append(f"--inputs {inputs_path}")
        if args.defensive and defense_module_path:
            python_cmd_parts.append("--defensive")
            python_cmd_parts.append(f"--defense-module {defense_module_path}")
        python_cmd_parts.append(f"--seed {args.seed}")
        eval_cmd = " ".join(python_cmd_parts)

        # Set environment for AMD GPU and run
        env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
        full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"

        # Run and stream output; keep the log for error reporting.
        log_lines = []
        async for line in client.exec_stream(full_cmd):
            print(line, flush=True)
            log_lines.append(line)

        # Read results
        cat_result = await client.exec(f"cat {output_path}")
        if cat_result.exit_code != 0:
            log_tail = "\n".join(log_lines[-50:])
            return _failure(f"Evaluation failed. Log tail:\n{log_tail}")

        # Parse results
        try:
            results_data = json.loads(cat_result.stdout)
        except json.JSONDecodeError as e:
            return _failure(f"Failed to parse results: {e}")
        if not isinstance(results_data, dict):
            return _failure(
                f"Failed to parse results: expected a JSON object, got "
                f"{type(results_data).__name__}"
            )

        # Convert to EvaluateResult
        correct = results_data.get("correct", False)
        speedup = results_data.get("speedup", 0.0) or 0.0
        error = results_data.get("error")
        if error:
            return _failure(error, total_tests=1)

        return EvaluateResult(
            success=True,
            all_correct=correct,
            correctness_score=1.0 if correct else 0.0,
            geomean_speedup=speedup,
            passed_tests=1 if correct else 0,
            total_tests=1,
        )

    # Select GPU
    if args.gpu_id is not None:
        gpu_id = args.gpu_id
    elif target.gpu_ids:
        gpu_id = target.gpu_ids[0]
    else:
        return _failure("No GPU id specified and the target has no gpu_ids configured")

    print(f"Provisioning RunPod ({target.gpu_type_id})...")

    try:
        async with runpod_ssh_context(target) as ssh_info:
            ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
            print(f"Connected to RunPod: {ssh_target}")

            async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
                return await _evaluate(client, gpu_id)
    except RunPodError as e:
        return _failure(f"RunPod error: {e}")
4141
+
4142
+
4143
async def run_evaluate_kernelbench_baremetal_amd(
    args: KernelBenchEvaluateArgs,
    target: BaremetalTarget,
) -> EvaluateResult:
    """Run KernelBench format evaluation directly on AMD baremetal target.

    Runs evaluation script directly on host (no Docker) for AMD GPUs
    that have PyTorch/ROCm installed.

    Flow: connect over SSH, create a fresh run directory under
    ``/tmp/wafer_eval``, upload the implementation, reference, optional custom
    inputs, the eval script and (when ``args.defensive``) the defense module,
    run the eval script with ``HIP_VISIBLE_DEVICES`` pinned to the selected
    GPU, then convert ``results.json`` into an ``EvaluateResult``.

    Args:
        args: KernelBench evaluate arguments.
        target: Baremetal (or VM) target configuration with SSH access.

    Returns:
        Evaluation result. Upload, execution and parsing failures are
        reported as ``success=False`` rather than raised.
    """
    import importlib.util
    import uuid
    from datetime import datetime

    from wafer_core.async_ssh import AsyncSSHClient

    REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"

    def _failure(message: str, total_tests: int = 0) -> EvaluateResult:
        """Build an unsuccessful EvaluateResult carrying ``message``."""
        return EvaluateResult(
            success=False,
            all_correct=False,
            correctness_score=0.0,
            geomean_speedup=0.0,
            passed_tests=0,
            total_tests=total_tests,
            error_message=message,
        )

    def _locate_defense_source() -> Path | None:
        """Find defense.py in the installed wafer_core package, else a repo checkout."""
        candidates: list[Path] = []
        try:
            # find_spec does not execute defense.py itself (it may need torch).
            spec = importlib.util.find_spec("wafer_core.utils.kernel_utils.defense")
        except (ImportError, ValueError):
            spec = None
        if spec is not None and spec.origin:
            candidates.append(Path(spec.origin))
        # Fallback: running from a monorepo checkout (not present in a wheel install).
        candidates.append(
            Path(__file__).parent.parent.parent.parent
            / "packages"
            / "wafer-core"
            / "wafer_core"
            / "utils"
            / "kernel_utils"
            / "defense.py"
        )
        for candidate in candidates:
            if candidate.is_file():
                return candidate
        return None

    async def _write_remote(client, path: str, content: str, tag: str):
        """Write ``content`` to remote ``path`` through a quoted heredoc.

        A line equal to the heredoc delimiter would end the heredoc early and
        the rest of ``content`` would run as shell commands, so the delimiter
        is randomized whenever ``content`` contains ``tag`` as a whole line.
        """
        content_lines = set(content.split("\n"))
        delimiter = tag
        while delimiter in content_lines:
            delimiter = f"{tag}_{uuid.uuid4().hex[:8]}"
        return await client.exec(
            f"cat > '{path}' << '{delimiter}'\n{content}\n{delimiter}"
        )

    # Select GPU
    if args.gpu_id is not None:
        gpu_id = args.gpu_id
    elif target.gpu_ids:
        gpu_id = target.gpu_ids[0]
    else:
        return _failure("No GPU id specified and the target has no gpu_ids configured")

    print(f"Connecting to {target.ssh_target}...")

    async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
        # Random suffix keeps concurrent runs started in the same second apart
        # (baremetal hosts are commonly shared).
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir = f"kernelbench_eval_{timestamp}_{uuid.uuid4().hex[:8]}"
        run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"

        await client.exec(f"mkdir -p {run_path}")
        print(f"Created run directory: {run_path}")

        # Write implementation
        impl_path = f"{run_path}/implementation.py"
        write_result = await _write_remote(
            client, impl_path, args.implementation.read_text(), "IMPL_EOF"
        )
        if write_result.exit_code != 0:
            return _failure(f"Failed to write implementation: {write_result.stderr}")

        # Write reference
        ref_path = f"{run_path}/reference.py"
        write_result = await _write_remote(
            client, ref_path, args.reference.read_text(), "REF_EOF"
        )
        if write_result.exit_code != 0:
            return _failure(f"Failed to write reference: {write_result.stderr}")

        # Write custom inputs if provided
        inputs_path = None
        if args.inputs:
            inputs_path = f"{run_path}/custom_inputs.py"
            write_result = await _write_remote(
                client, inputs_path, args.inputs.read_text(), "INPUTS_EOF"
            )
            if write_result.exit_code != 0:
                return _failure(f"Failed to write custom inputs: {write_result.stderr}")

        # Write eval script
        eval_script_path = f"{run_path}/kernelbench_eval.py"
        write_result = await _write_remote(
            client, eval_script_path, KERNELBENCH_EVAL_SCRIPT, "EVAL_EOF"
        )
        if write_result.exit_code != 0:
            return _failure(f"Failed to write eval script: {write_result.stderr}")

        # Write defense module if defensive mode is enabled. A missing or
        # unwritable module only warns; the run continues without --defensive.
        defense_module_path = None
        if args.defensive:
            defense_source = _locate_defense_source()
            if defense_source is None:
                print(
                    "Warning: defense.py not found in the installed wafer_core "
                    "package or the repository checkout"
                )
            else:
                defense_module_path = f"{run_path}/defense.py"
                write_result = await _write_remote(
                    client, defense_module_path, defense_source.read_text(), "DEFENSE_EOF"
                )
                if write_result.exit_code != 0:
                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
                    defense_module_path = None

        print("Running KernelBench evaluation (AMD/ROCm)...")

        # Find Python with PyTorch - check common locations
        python_exe = "python3"
        for candidate in (
            "/opt/conda/envs/py_3.10/bin/python3",
            "/opt/conda/bin/python3",
        ):
            check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
            if "OK" in check.stdout:
                python_exe = candidate
                print(f"Using Python: {python_exe}")
                break

        # Build eval command - run directly on host
        output_path = f"{run_path}/results.json"
        python_cmd_parts = [
            f"{python_exe} {eval_script_path}",
            f"--impl {impl_path}",
            f"--reference {ref_path}",
            f"--output {output_path}",
        ]
        if args.benchmark:
            python_cmd_parts.append("--benchmark")
        if args.profile:
            python_cmd_parts.append("--profile")
        if inputs_path:
            python_cmd_parts.append(f"--inputs {inputs_path}")
        if args.defensive and defense_module_path:
            python_cmd_parts.append("--defensive")
            python_cmd_parts.append(f"--defense-module {defense_module_path}")
        python_cmd_parts.append(f"--seed {args.seed}")
        eval_cmd = " ".join(python_cmd_parts)

        # Set environment for AMD GPU and run
        env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
        full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"

        # Run and stream output; keep the log for error reporting.
        log_lines = []
        async for line in client.exec_stream(full_cmd):
            print(line, flush=True)
            log_lines.append(line)

        # Read results
        cat_result = await client.exec(f"cat {output_path}")
        if cat_result.exit_code != 0:
            log_tail = "\n".join(log_lines[-50:])
            return _failure(f"Evaluation failed. Log tail:\n{log_tail}")

        # Parse results
        try:
            results_data = json.loads(cat_result.stdout)
        except json.JSONDecodeError as e:
            return _failure(f"Failed to parse results: {e}")
        if not isinstance(results_data, dict):
            return _failure(
                f"Failed to parse results: expected a JSON object, got "
                f"{type(results_data).__name__}"
            )

        # Convert to EvaluateResult
        correct = results_data.get("correct", False)
        speedup = results_data.get("speedup", 0.0) or 0.0
        error = results_data.get("error")
        if error:
            return _failure(error, total_tests=1)

        return EvaluateResult(
            success=True,
            all_correct=correct,
            correctness_score=1.0 if correct else 0.0,
            geomean_speedup=speedup,
            passed_tests=1 if correct else 0,
            total_tests=1,
        )
4364
+
4365
+
4366
+ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
4367
+ """Run KernelBench format evaluation on configured target.
4368
+
4369
+ Args:
4370
+ args: KernelBench evaluate arguments
4371
+
4372
+ Returns:
4373
+ Evaluation result
4374
+ """
4375
+ from .targets import get_default_target, load_target
4376
+
4377
+ # Validate input files
4378
+ err = _validate_kernelbench_files(args)
4379
+ if err:
4380
+ return EvaluateResult(
4381
+ success=False,
4382
+ all_correct=False,
4383
+ correctness_score=0.0,
4384
+ geomean_speedup=0.0,
4385
+ passed_tests=0,
4386
+ total_tests=0,
4387
+ error_message=err,
4388
+ )
4389
+
4390
+ # Load target
4391
+ target_name = args.target_name
4392
+ if not target_name:
4393
+ target_name = get_default_target()
4394
+ if not target_name:
4395
+ return EvaluateResult(
4396
+ success=False,
4397
+ all_correct=False,
4398
+ correctness_score=0.0,
4399
+ geomean_speedup=0.0,
4400
+ passed_tests=0,
4401
+ total_tests=0,
4402
+ error_message=(
4403
+ "No target specified and no default set.\n"
4404
+ "Set up a target first:\n"
4405
+ " wafer config targets init ssh --name my-gpu --host user@host:22\n"
4406
+ " wafer config targets init runpod --gpu MI300X\n"
4407
+ "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
4408
+ ),
4409
+ )
4410
+
4411
+ try:
4412
+ target = load_target(target_name)
4413
+ except FileNotFoundError:
4414
+ return EvaluateResult(
4415
+ success=False,
4416
+ all_correct=False,
4417
+ correctness_score=0.0,
4418
+ geomean_speedup=0.0,
4419
+ passed_tests=0,
4420
+ total_tests=0,
4421
+ error_message=f"Target not found: {target_name}. Run: wafer config targets list",
4422
+ )
4423
+
4424
+ print(f"Using target: {target_name}")
4425
+
4426
+ # Dispatch to appropriate executor
4427
+ if isinstance(target, DigitalOceanTarget):
3472
4428
  # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
3473
4429
  return await run_evaluate_kernelbench_digitalocean(args, target)
4430
+ elif isinstance(target, RunPodTarget):
4431
+ # RunPod AMD MI300X - uses ROCm Docker with device passthrough
4432
+ return await run_evaluate_kernelbench_runpod(args, target)
3474
4433
  elif isinstance(target, BaremetalTarget | VMTarget):
4434
+ # Check if this is an AMD target (gfx* compute capability) - run directly
4435
+ if target.compute_capability and target.compute_capability.startswith("gfx"):
4436
+ return await run_evaluate_kernelbench_baremetal_amd(args, target)
3475
4437
  # NVIDIA targets - require docker_image to be set
3476
4438
  if not target.docker_image:
3477
4439
  return EvaluateResult(
@@ -3497,6 +4459,6 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
3497
4459
  total_tests=0,
3498
4460
  error_message=(
3499
4461
  f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
3500
- "Use a DigitalOcean, Baremetal, or VM target."
4462
+ "Use a DigitalOcean, RunPod, Baremetal, or VM target."
3501
4463
  ),
3502
4464
  )