wafer-cli 0.2.3-py3-none-any.whl → 0.2.5-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
wafer/evaluate.py CHANGED
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
 from wafer_core.utils.kernel_utils.targets.config import (
     BaremetalTarget,
     DigitalOceanTarget,
+    LocalTarget,
     ModalTarget,
     RunPodTarget,
     VMTarget,
@@ -158,6 +159,8 @@ class KernelBenchEvaluateArgs:
     target_name: str
     benchmark: bool = False
     profile: bool = False
+    inputs: Path | None = None  # Custom inputs file to override get_inputs()
+    seed: int = 42  # Random seed for reproducibility
     defensive: bool = False
     sync_artifacts: bool = True
     gpu_id: int | None = None
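
The new `inputs` and `seed` fields feed the `--inputs` and `--seed` flags added to the eval script later in this diff. As an illustrative sketch (shapes and dtypes here are hypothetical), a custom inputs file has to mirror the reference's `get_inputs()` signature, since the eval script validates tensor count, dtype, and rank before overriding it:

    # custom_inputs.py (hypothetical example)
    import torch

    def get_inputs():
        # Same number of tensors, dtypes, and ranks as the reference's
        # get_inputs(), but e.g. a different batch size to stress the kernel.
        return [torch.randn(64, 1024, dtype=torch.float32)]

    def get_init_inputs():
        # Optional: if present, this also overrides the reference's
        # constructor arguments.
        return []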
@@ -394,33 +397,6 @@ async def run_evaluate_docker(
     print(f"Connecting to {target.ssh_target}...")
 
     async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
-        # Upload wafer-core to remote
-        try:
-            wafer_root = _get_wafer_root()
-            wafer_core_path = wafer_root / "packages" / "wafer-core"
-            print(f"Uploading wafer-core from {wafer_core_path}...")
-
-            # Create workspace and upload
-            workspace_name = wafer_core_path.name
-            remote_workspace = f"{REMOTE_WORKSPACE_BASE}/{workspace_name}"
-            await client.exec(f"mkdir -p {remote_workspace}")
-            wafer_core_workspace = await client.expand_path(remote_workspace)
-
-            upload_result = await client.upload_files(
-                str(wafer_core_path), wafer_core_workspace, recursive=True
-            )
-            print(f"Uploaded {upload_result.files_copied} files")
-        except Exception as e:
-            return EvaluateResult(
-                success=False,
-                all_correct=False,
-                correctness_score=0.0,
-                geomean_speedup=0.0,
-                passed_tests=0,
-                total_tests=0,
-                error_message=f"Failed to upload wafer-core: {e}",
-            )
-
         print(f"Using Docker image: {target.docker_image}")
         print(f"Using GPU {gpu_id}...")
 
@@ -429,10 +405,13 @@ async def run_evaluate_docker(
         ref_code = args.reference.read_text()
         test_cases_data = json.loads(args.test_cases.read_text())
 
-        # Create a unique run directory
+        # Create workspace for evaluation files
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         run_dir = f"wafer_eval_{timestamp}"
-        run_path = f"{wafer_core_workspace}/{run_dir}"
+        eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
+        await client.exec(f"mkdir -p {eval_workspace}")
+        eval_workspace_expanded = await client.expand_path(eval_workspace)
+        run_path = f"{eval_workspace_expanded}/{run_dir}"
 
         print("Uploading evaluation files...")
 
@@ -519,17 +498,16 @@ async def run_evaluate_docker(
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
         container_test_cases_path = f"{container_run_path}/test_cases.json"
-        container_evaluate_script = (
-            f"{CONTAINER_WORKSPACE}/wafer_core/utils/kernel_utils/evaluate.py"
-        )
 
-        # Build pip install command for torch and other deps (no wafer-core install needed)
+        # Build pip install command for torch and other deps, plus wafer-core
         pip_install_cmd = _build_docker_pip_install_cmd(target)
+        install_cmd = (
+            f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
+        )
 
-        # Build evaluate command - use PYTHONPATH instead of installing wafer-core
+        # Build evaluate command using installed wafer-core module
         python_cmd_parts = [
-            f"PYTHONPATH={CONTAINER_WORKSPACE}:$PYTHONPATH",
-            f"python3 {container_evaluate_script}",
+            "python3 -m wafer_core.utils.kernel_utils.evaluate",
             f"--implementation {container_impl_path}",
            f"--reference {container_ref_path}",
            f"--test-cases {container_test_cases_path}",
@@ -545,8 +523,8 @@ async def run_evaluate_docker(
 
         eval_cmd = " ".join(python_cmd_parts)
 
-        # Full command: install torch deps, then run evaluate with PYTHONPATH
-        full_cmd = f"{pip_install_cmd} && cd {container_run_path} && {eval_cmd}"
+        # Full command: install deps + wafer-core, then run evaluate
+        full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
 
         # Build Docker run command
         # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
@@ -556,7 +534,7 @@ async def run_evaluate_docker(
             working_dir=container_run_path,
             env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
             gpus="all",
-            volumes={wafer_core_workspace: CONTAINER_WORKSPACE},
+            volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
             cap_add=["SYS_ADMIN"] if args.profile else None,
         )
 
@@ -663,6 +641,181 @@ async def run_evaluate_docker(
         )
 
 
+async def run_evaluate_local(
+    args: EvaluateArgs,
+    target: LocalTarget,
+) -> EvaluateResult:
+    """Run evaluation locally on the current machine.
+
+    For LocalTarget - no SSH needed, runs directly.
+
+    Args:
+        args: Evaluate arguments
+        target: Local target config
+
+    Returns:
+        Evaluation result
+    """
+    import os
+    import subprocess
+    import tempfile
+    from datetime import datetime
+
+    # Select GPU
+    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+    print(f"Running local evaluation on GPU {gpu_id}...")
+
+    # Create temp directory for eval files
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
+        run_path = Path(run_path)
+
+        # Write implementation
+        impl_path = run_path / "implementation.py"
+        impl_path.write_text(args.implementation.read_text())
+
+        # Write reference
+        ref_path = run_path / "reference.py"
+        ref_path.write_text(args.reference.read_text())
+
+        # Write custom inputs if provided
+        inputs_path = None
+        if args.inputs:
+            inputs_path = run_path / "custom_inputs.py"
+            inputs_path.write_text(args.inputs.read_text())
+
+        # Write eval script
+        eval_script_path = run_path / "kernelbench_eval.py"
+        eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)
+
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_src = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_src.exists():
+                defense_module_path = run_path / "defense.py"
+                defense_module_path.write_text(defense_src.read_text())
+            else:
+                print(f"Warning: defense.py not found at {defense_src}")
+
+        # Output file
+        output_path = run_path / "results.json"
+
+        # Build eval command
+        cmd_parts = [
+            "python3",
+            str(eval_script_path),
+            "--impl",
+            str(impl_path),
+            "--reference",
+            str(ref_path),
+            "--output",
+            str(output_path),
+            "--seed",
+            str(args.seed),
+        ]
+
+        if args.benchmark:
+            cmd_parts.append("--benchmark")
+        if args.profile:
+            cmd_parts.append("--profile")
+        if inputs_path:
+            cmd_parts.extend(["--inputs", str(inputs_path)])
+        if args.defensive and defense_module_path:
+            cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])
+
+        # Set environment for GPU selection
+        env = os.environ.copy()
+        if target.vendor == "nvidia":
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+        else:  # AMD
+            env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
+            env["ROCM_PATH"] = "/opt/rocm"
+
+        print(f"Running: {' '.join(cmd_parts[:4])} ...")
+
+        # Run evaluation
+        try:
+            result = subprocess.run(
+                cmd_parts,
+                cwd=str(run_path),
+                env=env,
+                capture_output=True,
+                text=True,
+                timeout=args.timeout or 600,
+            )
+        except subprocess.TimeoutExpired:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message="Evaluation timed out",
+            )
+
+        if result.returncode != 0:
+            error_msg = result.stderr or result.stdout or "Unknown error"
+            # Truncate long errors
+            if len(error_msg) > 1000:
+                error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Evaluation failed:\n{error_msg}",
+            )
+
+        # Parse results
+        if not output_path.exists():
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message="No results.json produced",
+            )
+
+        try:
+            results = json.loads(output_path.read_text())
+        except json.JSONDecodeError as e:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to parse results: {e}",
+            )
+
+        # Extract results
+        return EvaluateResult(
+            success=True,
+            all_correct=results.get("all_correct", False),
+            correctness_score=results.get("correctness_score", 0.0),
+            geomean_speedup=results.get("geomean_speedup", 0.0),
+            passed_tests=results.get("passed_tests", 0),
+            total_tests=results.get("total_tests", 0),
+            benchmark_results=results.get("benchmark", {}),
+        )
+
+
 async def run_evaluate_ssh(
     args: EvaluateArgs,
     target: BaremetalTarget | VMTarget,
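
Note that `run_evaluate_local` pins the GPU through vendor-specific environment variables rather than Docker flags. A minimal standalone sketch of that selection logic (the `vendor` and `gpu_id` names mirror the fields used above; this is an illustration, not the shipped helper):

    import os

    def gpu_env(vendor: str, gpu_id: int) -> dict:
        env = os.environ.copy()
        if vendor == "nvidia":
            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)   # NVIDIA runtime
        else:
            env["HIP_VISIBLE_DEVICES"] = str(gpu_id)    # AMD ROCm runtime
            env["ROCM_PATH"] = "/opt/rocm"
        return env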
@@ -980,6 +1133,7 @@ def _build_modal_sandbox_script(
     test_cases_b64: str,
     run_benchmarks: bool,
     run_defensive: bool,
+    defense_code_b64: str | None = None,
 ) -> str:
     """Build Python script to create sandbox and run evaluation.
 
@@ -1060,6 +1214,20 @@ print('Files written')
         print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
         return
 
+    # Write defense module if defensive mode is enabled
+    # NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
+    if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
+        proc = sandbox.exec("python", "-c", f"""
+import base64
+with open('/workspace/defense.py', 'w') as f:
+    f.write(base64.b64decode('{defense_code_b64}').decode())
+print('Defense module written')
+""")
+        proc.wait()
+        if proc.returncode != 0:
+            print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
+            return
+
     # Build inline evaluation script
     eval_script = """
 import json
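
The sandbox receives defense.py as a base64 string because the source is spliced into an f-string-generated script, where raw quotes and braces would break quoting. A minimal round-trip sketch of the idea (file name illustrative):

    import base64

    source = open("defense.py").read()                    # host side
    encoded = base64.b64encode(source.encode()).decode()  # safe to embed in a template
    decoded = base64.b64decode(encoded).decode()          # sandbox side
    assert decoded == source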
@@ -1087,6 +1255,26 @@ generate_input = load_fn('reference.py', 'generate_input')
 
 import torch
 
+# Load defense module if available and defensive mode is enabled
+run_defensive = {run_defensive}
+defense = None
+if run_defensive:
+    try:
+        defense = load_fn('defense.py', 'run_all_defenses')
+        time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
+        print('[Defense] Defense module loaded')
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        # These wrappers repack the unpacked args back into a tuple
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
+    except Exception as e:
+        print(f'[Defense] Warning: Could not load defense module: {{e}}')
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
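
The `_wrap_for_defense` adapter bridges two calling conventions: the defense API invokes `kernel(*args)`, while the functional kernels here take one tuple argument. A toy demonstration of the same adapter in isolation:

    def kernel(inputs):                       # functional format: one tuple
        a, b = inputs
        return a + b

    def _wrap_for_defense(kernel):
        return lambda *args: kernel(args)     # repack *args into a tuple

    wrapped = _wrap_for_defense(kernel)
    assert wrapped(1, 2) == kernel((1, 2)) == 3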
@@ -1114,36 +1302,63 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
-        # Warmup
-        for _ in range(3):
-            custom_kernel(inputs)
-        torch.cuda.synchronize()
-
-        # Measure with defensive timing if requested
-        # Defensive: sync before recording end event to catch stream injection
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        start.record()
-        for _ in range(10):
-            custom_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        impl_time_ms = start.elapsed_time(end) / 10
-
-        # Reference timing (same defensive approach)
-        for _ in range(3):
-            ref_kernel(inputs)
-        torch.cuda.synchronize()
-        start.record()
-        for _ in range(10):
-            ref_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        ref_time_ms = start.elapsed_time(end) / 10
+        if run_defensive and defense is not None:
+            # Use full defense suite with wrapped kernels
+            # inputs_list unpacks the tuple so defense can infer dtype/device from tensors
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing (using wrapped kernels)
+            impl_times, _ = time_with_defenses(
+                custom_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            ref_times, _ = time_with_defenses(
+                ref_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing without full defenses
+            # Warmup
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            # Reference timing
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
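
Both the removed inline path and the retained standard path follow the CUDA-event timing pattern below; the defensive variant's extra device-wide synchronize before `end.record()` is what catches work smuggled onto side streams that the timing stream would otherwise not wait for. A standalone sketch, assuming a CUDA-capable torch build (names are illustrative):

    import torch

    def time_kernel(fn, inputs, trials=10, defensive=False):
        for _ in range(3):                    # warmup
            fn(inputs)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(trials):
            fn(inputs)
        if defensive:
            torch.cuda.synchronize()          # drain all streams before stopping the clock
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / trials   # avg milliseconds per call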
@@ -1236,6 +1451,23 @@ async def run_evaluate_modal(
     ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
     test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()
 
+    # Encode defense module if defensive mode is enabled
+    defense_code_b64 = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build the script that creates sandbox and runs eval
     script = _build_modal_sandbox_script(
         target=target,
@@ -1244,6 +1476,7 @@ async def run_evaluate_modal(
         test_cases_b64=test_cases_b64,
         run_benchmarks=args.benchmark,
         run_defensive=args.defensive,
+        defense_code_b64=defense_code_b64,
     )
 
     def _run_subprocess() -> tuple[str, str, int]:
@@ -1341,6 +1574,7 @@ def _build_workspace_eval_script(
     test_cases_json: str,
     run_benchmarks: bool,
     run_defensive: bool = False,
+    defense_code: str | None = None,
 ) -> str:
     """Build inline evaluation script for workspace exec.
 
@@ -1351,6 +1585,7 @@ def _build_workspace_eval_script(
     impl_b64 = base64.b64encode(impl_code.encode()).decode()
     ref_b64 = base64.b64encode(ref_code.encode()).decode()
     tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
+    defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""
 
     return f'''
 import base64
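
The `!= "None"` guards that appear in the generated scripts address a real f-string pitfall: formatting None into a template produces the truthy four-character string "None", not an empty string. A two-line demonstration:

    defense_code_b64 = None
    assert f'"{defense_code_b64}"' == '"None"'   # truthy, hence the explicit != "None" check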
@@ -1370,6 +1605,15 @@ with open("/tmp/kernel.py", "w") as f:
 with open("/tmp/reference.py", "w") as f:
     f.write(ref_code)
 
+# Write defense module if available
+run_defensive = {run_defensive}
+defense_b64 = "{defense_b64}"
+# NOTE: Check defense_b64 is not empty and not the string "None" (from None formatting)
+if run_defensive and defense_b64 and defense_b64 != "None":
+    defense_code = base64.b64decode(defense_b64).decode()
+    with open("/tmp/defense.py", "w") as f:
+        f.write(defense_code)
+
 # Load kernels
 def load_fn(path, name):
     spec = importlib.util.spec_from_file_location("mod", path)
@@ -1383,6 +1627,24 @@ generate_input = load_fn("/tmp/reference.py", "generate_input")
 
 import torch
 
+# Load defense module if available
+defense = None
+if run_defensive and defense_b64 and defense_b64 != "None":
+    try:
+        defense = load_fn("/tmp/defense.py", "run_all_defenses")
+        time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
+        print("[Defense] Defense module loaded")
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
+    except Exception as e:
+        print(f"[Defense] Warning: Could not load defense module: {{e}}")
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
@@ -1410,36 +1672,60 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
-        # Warmup
-        for _ in range(3):
-            custom_kernel(inputs)
-        torch.cuda.synchronize()
-
-        # Measure with defensive timing if requested
-        # Defensive: sync before recording end event to catch stream injection
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        start.record()
-        for _ in range(10):
-            custom_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        impl_time_ms = start.elapsed_time(end) / 10
-
-        # Reference timing (same defensive approach)
-        for _ in range(3):
-            ref_kernel(inputs)
-        torch.cuda.synchronize()
-        start.record()
-        for _ in range(10):
-            ref_kernel(inputs)
-        if {run_defensive}:
-            torch.cuda.synchronize()  # DEFENSE: sync all streams before end
-        end.record()
-        torch.cuda.synchronize()
-        ref_time_ms = start.elapsed_time(end) / 10
+        if run_defensive and defense is not None:
+            # Use full defense suite with wrapped kernels
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing (using wrapped kernels)
+            impl_times, _ = time_with_defenses(
+                custom_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            ref_times, _ = time_with_defenses(
+                ref_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
@@ -1501,6 +1787,23 @@ async def run_evaluate_workspace(
     ref_code = args.reference.read_text()
     test_cases_json = args.test_cases.read_text()
 
+    # Read defense module if defensive mode is enabled
+    defense_code = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code = defense_path.read_text()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build inline eval script
     eval_script = _build_workspace_eval_script(
         impl_code=impl_code,
@@ -1508,6 +1811,7 @@ async def run_evaluate_workspace(
         test_cases_json=test_cases_json,
         run_benchmarks=args.benchmark,
         run_defensive=args.defensive,
+        defense_code=defense_code,
     )
 
     # Execute via workspace exec
@@ -1853,15 +2157,12 @@ async def run_evaluate_runpod(
             # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
             venv_bin = env_state.venv_bin
             env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
-            pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-            evaluate_script = (
-                f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-            )
 
             # Run from run_path so reference_kernel.py is importable
+            # Use installed wafer-core module
             eval_cmd = (
                 f"cd {run_path} && "
-                f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
+                f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
                 f"--implementation {impl_path} "
                 f"--reference {ref_path} "
                 f"--test-cases {test_cases_path} "
@@ -2217,15 +2518,12 @@ async def run_evaluate_digitalocean(
             env_vars = (
                 f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
             )
-            pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-            evaluate_script = (
-                f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-            )
 
             # Run from run_path so reference_kernel.py is importable
+            # Use installed wafer-core module
             eval_cmd = (
                 f"cd {run_path} && "
-                f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
+                f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
                 f"--implementation {impl_path} "
                 f"--reference {ref_path} "
                 f"--test-cases {test_cases_path} "
@@ -2405,7 +2703,9 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
     print(f"Using target: {target_name}")
 
     # Dispatch to appropriate executor
-    if isinstance(target, BaremetalTarget | VMTarget):
+    if isinstance(target, LocalTarget):
+        return await run_evaluate_local(args, target)
+    elif isinstance(target, BaremetalTarget | VMTarget):
         return await run_evaluate_ssh(args, target)
     elif isinstance(target, ModalTarget):
         return await run_evaluate_modal(args, target)
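
A small note on the dispatch style: `isinstance(target, BaremetalTarget | VMTarget)` passes a PEP 604 union as the second argument, which is accepted on Python 3.10+ and equivalent to the tuple form. Toy demonstration:

    class A: ...
    class B: ...

    assert isinstance(A(), A | B)        # Python 3.10+ spelling
    assert isinstance(A(), (A, B))       # equivalent pre-3.10 spelling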
@@ -2435,10 +2735,233 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
 # This runs inside the Docker container on the remote GPU
 KERNELBENCH_EVAL_SCRIPT = """
 import json
+import os
 import sys
 import time
 import torch
 import torch.nn as nn
+from pathlib import Path
+
+
+def run_profiling(model, inputs, name, output_dir):
+    '''Run torch.profiler and return summary stats.'''
+    from torch.profiler import profile, ProfilerActivity
+
+    # Determine activities based on backend
+    activities = [ProfilerActivity.CPU]
+    if torch.cuda.is_available():
+        activities.append(ProfilerActivity.CUDA)
+
+    # Warmup
+    for _ in range(3):
+        with torch.no_grad():
+            _ = model(*inputs)
+    torch.cuda.synchronize()
+
+    # Profile
+    with profile(
+        activities=activities,
+        record_shapes=True,
+        with_stack=False,
+        profile_memory=True,
+    ) as prof:
+        with torch.no_grad():
+            _ = model(*inputs)
+        torch.cuda.synchronize()
+
+    # Get key averages
+    key_averages = prof.key_averages()
+
+    # Find the main kernel (longest GPU time)
+    # Use cuda_time_total for compatibility with both CUDA and ROCm
+    def get_gpu_time(e):
+        # Try different attributes for GPU time
+        if hasattr(e, 'cuda_time_total'):
+            return e.cuda_time_total
+        if hasattr(e, 'device_time_total'):
+            return e.device_time_total
+        if hasattr(e, 'self_cuda_time_total'):
+            return e.self_cuda_time_total
+        return 0
+
+    gpu_events = [e for e in key_averages if get_gpu_time(e) > 0]
+    gpu_events.sort(key=lambda e: get_gpu_time(e), reverse=True)
+
+    stats = {
+        "name": name,
+        "total_gpu_time_ms": sum(get_gpu_time(e) for e in gpu_events) / 1000,
+        "total_cpu_time_ms": sum(e.cpu_time_total for e in key_averages) / 1000,
+        "num_gpu_kernels": len(gpu_events),
+        "top_kernels": [],
+    }
+
+    # Top 5 kernels by GPU time
+    for e in gpu_events[:5]:
+        stats["top_kernels"].append({
+            "name": e.key,
+            "gpu_time_ms": get_gpu_time(e) / 1000,
+            "cpu_time_ms": e.cpu_time_total / 1000,
+            "calls": e.count,
+        })
+
+    # Save trace for visualization
+    trace_path = Path(output_dir) / f"{name}_trace.json"
+    prof.export_chrome_trace(str(trace_path))
+    stats["trace_file"] = str(trace_path)
+
+    return stats
+
+
+def validate_custom_inputs(original_inputs, custom_inputs):
+    '''Validate that custom inputs match the expected signature.
+
+    Returns (is_valid, error_message).
+    '''
+    if len(original_inputs) != len(custom_inputs):
+        return False, f"get_inputs() must return {len(original_inputs)} tensors, got {len(custom_inputs)}"
+
+    for i, (orig, cust) in enumerate(zip(original_inputs, custom_inputs)):
+        if not isinstance(cust, torch.Tensor):
+            if not isinstance(orig, torch.Tensor):
+                continue  # Both non-tensor, ok
+            return False, f"Input {i}: expected Tensor, got {type(cust).__name__}"
+
+        if not isinstance(orig, torch.Tensor):
+            return False, f"Input {i}: expected {type(orig).__name__}, got Tensor"
+
+        if orig.dtype != cust.dtype:
+            return False, f"Input {i}: dtype mismatch - expected {orig.dtype}, got {cust.dtype}"
+
+        if orig.dim() != cust.dim():
+            return False, f"Input {i}: dimension mismatch - expected {orig.dim()}D, got {cust.dim()}D"
+
+    return True, None
+
+
+def analyze_diff(ref_output, new_output, rtol=1e-3, atol=1e-3, max_samples=5):
+    '''Analyze differences between reference and implementation outputs.
+
+    Returns a dict with detailed diff information.
+    '''
+    diff = (ref_output - new_output).abs()
+    threshold = atol + rtol * ref_output.abs()
+    wrong_mask = diff > threshold
+
+    total_elements = ref_output.numel()
+    wrong_count = wrong_mask.sum().item()
+
+    # Basic stats
+    max_diff = diff.max().item()
+    max_diff_idx = tuple(torch.unravel_index(diff.argmax(), diff.shape))
+    max_diff_idx = tuple(int(i) for i in max_diff_idx)  # Convert to Python ints
+
+    # Relative error (avoid div by zero)
+    ref_abs = ref_output.abs()
+    nonzero_mask = ref_abs > 1e-8
+    if nonzero_mask.any():
+        rel_error = diff[nonzero_mask] / ref_abs[nonzero_mask]
+        max_rel_error = rel_error.max().item()
+        mean_rel_error = rel_error.mean().item()
+    else:
+        max_rel_error = float('inf') if max_diff > 0 else 0.0
+        mean_rel_error = max_rel_error
+
+    # Error histogram (buckets: <1e-6, 1e-6 to 1e-4, 1e-4 to 1e-2, 1e-2 to 1, >1)
+    histogram = {
+        '<1e-6': int((diff < 1e-6).sum().item()),
+        '1e-6 to 1e-4': int(((diff >= 1e-6) & (diff < 1e-4)).sum().item()),
+        '1e-4 to 1e-2': int(((diff >= 1e-4) & (diff < 1e-2)).sum().item()),
+        '1e-2 to 1': int(((diff >= 1e-2) & (diff < 1)).sum().item()),
+        '>1': int((diff >= 1).sum().item()),
+    }
+
+    result = {
+        'max_diff': max_diff,
+        'max_diff_idx': max_diff_idx,
+        'mean_diff': diff.mean().item(),
+        'max_rel_error': max_rel_error,
+        'mean_rel_error': mean_rel_error,
+        'total_elements': total_elements,
+        'wrong_count': int(wrong_count),
+        'wrong_pct': 100.0 * wrong_count / total_elements,
+        'histogram': histogram,
+        'samples': [],
+    }
+
+    # Get indices of wrong elements
+    if wrong_count > 0:
+        wrong_indices = torch.nonzero(wrong_mask, as_tuple=False)
+
+        # Take first N samples
+        num_samples = min(max_samples, len(wrong_indices))
+        for i in range(num_samples):
+            idx = tuple(wrong_indices[i].tolist())
+            ref_val = ref_output[idx].item()
+            new_val = new_output[idx].item()
+            diff_val = diff[idx].item()
+            result['samples'].append({
+                'index': idx,
+                'ref': ref_val,
+                'impl': new_val,
+                'diff': diff_val,
+            })
+
+        # Try to detect pattern
+        if wrong_count >= total_elements * 0.99:
+            result['pattern'] = 'all_wrong'
+        elif wrong_count < total_elements * 0.01:
+            # Check if failures are at boundaries
+            shape = ref_output.shape
+            boundary_count = 0
+            for idx in wrong_indices[:min(100, len(wrong_indices))]:
+                idx_list = idx.tolist()
+                is_boundary = any(i == 0 or i == s - 1 for i, s in zip(idx_list, shape))
+                if is_boundary:
+                    boundary_count += 1
+            if boundary_count > len(wrong_indices[:100]) * 0.8:
+                result['pattern'] = 'boundary_issue'
+            else:
+                result['pattern'] = 'scattered'
+        else:
+            result['pattern'] = 'partial'
+
+    return result
+
+
+def print_diff_analysis(analysis):
+    '''Print a human-readable diff analysis.'''
+    print(f"[KernelBench] Diff analysis:")
+
+    # Max diff with location
+    idx_str = ','.join(str(i) for i in analysis['max_diff_idx'])
+    print(f"  Max diff: {analysis['max_diff']:.6f} at index [{idx_str}]")
+    print(f"  Mean diff: {analysis['mean_diff']:.6f}")
+
+    # Relative errors
+    print(f"  Max relative error: {analysis['max_rel_error']:.2%}, Mean: {analysis['mean_rel_error']:.2%}")
+
+    # Wrong count
+    print(f"  Wrong elements: {analysis['wrong_count']:,} / {analysis['total_elements']:,} ({analysis['wrong_pct']:.2f}%)")
+
+    # Histogram
+    hist = analysis['histogram']
+    print(f"  Error distribution: <1e-6: {hist['<1e-6']:,} | 1e-6~1e-4: {hist['1e-6 to 1e-4']:,} | 1e-4~1e-2: {hist['1e-4 to 1e-2']:,} | 1e-2~1: {hist['1e-2 to 1']:,} | >1: {hist['>1']:,}")
+
+    if 'pattern' in analysis:
+        pattern_desc = {
+            'all_wrong': 'ALL elements wrong - likely algorithmic error or wrong weights',
+            'boundary_issue': 'Mostly BOUNDARY elements wrong - check edge handling',
+            'scattered': 'SCATTERED failures - numerical precision issue?',
+            'partial': 'PARTIAL failures - check specific conditions',
+        }
+        print(f"  Pattern: {pattern_desc.get(analysis['pattern'], analysis['pattern'])}")
+
+    if analysis['samples']:
+        print(f"  Sample failures:")
+        for s in analysis['samples']:
+            idx_str = ','.join(str(i) for i in s['index'])
+            print(f"    [{idx_str}]: ref={s['ref']:.6f} impl={s['impl']:.6f} (diff={s['diff']:.6f})")
+
 
 def main():
     # Parse args
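
The helpers above summarize mismatches instead of printing a single max diff. For instance, the bucketed error histogram inside `analyze_diff` can be reproduced standalone like this (bucket edges as in the script):

    import torch

    def error_histogram(diff):
        return {
            "<1e-6": int((diff < 1e-6).sum()),
            "1e-6 to 1e-4": int(((diff >= 1e-6) & (diff < 1e-4)).sum()),
            "1e-4 to 1e-2": int(((diff >= 1e-4) & (diff < 1e-2)).sum()),
            "1e-2 to 1": int(((diff >= 1e-2) & (diff < 1)).sum()),
            ">1": int((diff >= 1).sum()),
        }

    diff = (torch.randn(8, 8) - torch.randn(8, 8)).abs()
    hist = error_histogram(diff)
    assert sum(hist.values()) == diff.numel()   # buckets partition all elements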
@@ -2446,12 +2969,35 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--impl", required=True)
     parser.add_argument("--reference", required=True)
+    parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
     parser.add_argument("--benchmark", action="store_true")
+    parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
+    parser.add_argument("--defense-module", help="Path to defense.py module")
+    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
     parser.add_argument("--num-correct-trials", type=int, default=3)
     parser.add_argument("--num-perf-trials", type=int, default=10)
     parser.add_argument("--output", required=True)
     args = parser.parse_args()
 
+    # Load defense module if defensive mode is enabled
+    defense_module = None
+    if args.defensive and args.defense_module:
+        try:
+            import importlib.util
+            defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
+            defense_module = importlib.util.module_from_spec(defense_spec)
+            defense_spec.loader.exec_module(defense_module)
+            print("[KernelBench] Defense module loaded")
+        except Exception as e:
+            print(f"[KernelBench] Warning: Could not load defense module: {e}")
+
+    # Create output directory for profiles
+    output_dir = Path(args.output).parent
+    profile_dir = output_dir / "profiles"
+    if args.profile:
+        profile_dir.mkdir(exist_ok=True)
+
     results = {
         "compiled": False,
         "correct": False,
@@ -2472,6 +3018,33 @@ def main():
         get_inputs = ref_module.get_inputs
         get_init_inputs = ref_module.get_init_inputs
 
+        # Load custom inputs if provided
+        if args.inputs:
+            inputs_spec = importlib.util.spec_from_file_location("custom_inputs", args.inputs)
+            inputs_module = importlib.util.module_from_spec(inputs_spec)
+            inputs_spec.loader.exec_module(inputs_module)
+
+            # Validate custom inputs match expected signature
+            original_inputs = get_inputs()
+            custom_get_inputs = inputs_module.get_inputs
+            custom_inputs = custom_get_inputs()
+
+            is_valid, error_msg = validate_custom_inputs(original_inputs, custom_inputs)
+            if not is_valid:
+                print(f"[KernelBench] Custom inputs validation failed: {error_msg}")
+                results["error"] = f"Custom inputs validation failed: {error_msg}"
+                raise ValueError(error_msg)
+
+            # Override get_inputs (and optionally get_init_inputs)
+            get_inputs = custom_get_inputs
+            if hasattr(inputs_module, 'get_init_inputs'):
+                get_init_inputs = inputs_module.get_init_inputs
+
+            # Show what changed
+            orig_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in original_inputs]
+            cust_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in custom_inputs]
+            print(f"[KernelBench] Using custom inputs: {orig_shapes} -> {cust_shapes}")
+
         # Load implementation module
         impl_spec = importlib.util.spec_from_file_location("implementation", args.impl)
         impl_module = importlib.util.module_from_spec(impl_spec)
@@ -2481,12 +3054,19 @@ def main():
         results["compiled"] = True
         print("[KernelBench] Modules loaded successfully")
 
-        # Instantiate models
+        # Instantiate models with synchronized seeds for reproducible weights
+        # (matches upstream KernelBench behavior in src/eval.py)
+        seed = args.seed
         init_inputs = get_init_inputs()
         with torch.no_grad():
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
             ref_model = Model(*init_inputs).cuda().eval()
+
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
             new_model = ModelNew(*init_inputs).cuda().eval()
-        print("[KernelBench] Models instantiated")
+        print(f"[KernelBench] Models instantiated (seed={seed})")
 
         # Run correctness trials
         all_correct = True
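
Resetting the seed before each instantiation matters because nn.Module parameter init draws from the global RNG; without the second `manual_seed`, the two models would get different random weights even with the same seed set once at the top. A toy check:

    import torch

    torch.manual_seed(42)
    a = torch.nn.Linear(4, 4)
    torch.manual_seed(42)          # rewind the RNG before the second init
    b = torch.nn.Linear(4, 4)
    assert torch.equal(a.weight, b.weight)   # identical weights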
@@ -2502,8 +3082,18 @@ def main():
             if isinstance(ref_output, torch.Tensor):
                 if not torch.allclose(ref_output, new_output, rtol=1e-3, atol=1e-3):
                     all_correct = False
-                    max_diff = (ref_output - new_output).abs().max().item()
-                    results["error"] = f"Correctness failed on trial {trial+1}: max diff = {max_diff}"
+                    analysis = analyze_diff(ref_output, new_output)
+                    results["error"] = f"Correctness failed on trial {trial+1}: max diff = {analysis['max_diff']}"
+                    results["diff_analysis"] = analysis
+                    print_diff_analysis(analysis)
+
+                    # Save tensors for debugging
+                    debug_dir = output_dir / "debug"
+                    debug_dir.mkdir(exist_ok=True)
+                    torch.save(ref_output.cpu(), debug_dir / "ref_output.pt")
+                    torch.save(new_output.cpu(), debug_dir / "impl_output.pt")
+                    torch.save(inputs[0].cpu() if inputs else None, debug_dir / "input.pt")
+                    print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
                     break
             else:
                 # Handle tuple/list outputs
@@ -2511,8 +3101,17 @@ def main():
                     if isinstance(r, torch.Tensor):
                         if not torch.allclose(r, n, rtol=1e-3, atol=1e-3):
                             all_correct = False
-                            max_diff = (r - n).abs().max().item()
-                            results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {max_diff}"
+                            analysis = analyze_diff(r, n)
+                            results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {analysis['max_diff']}"
+                            results["diff_analysis"] = analysis
+                            print_diff_analysis(analysis)
+
+                            # Save tensors for debugging
+                            debug_dir = output_dir / "debug"
+                            debug_dir.mkdir(exist_ok=True)
+                            torch.save(r.cpu(), debug_dir / f"ref_output_{i}.pt")
+                            torch.save(n.cpu(), debug_dir / f"impl_output_{i}.pt")
+                            print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
                             break
             if not all_correct:
                 break
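
For reference, the correctness gate here is `torch.allclose(ref, impl, rtol=1e-3, atol=1e-3)`, i.e. elementwise `|ref - impl| <= atol + rtol * |impl|` (the tolerance scales with the second argument). A small sketch of that semantics:

    import torch

    ref = torch.tensor([1.0, 100.0])
    impl = torch.tensor([1.0005, 100.05])
    ok = torch.allclose(ref, impl, rtol=1e-3, atol=1e-3)
    manual = bool(((ref - impl).abs() <= 1e-3 + 1e-3 * impl.abs()).all())
    assert ok == manual == True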
@@ -2526,47 +3125,132 @@ def main():
         inputs = get_inputs()
         inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
 
-        # Warmup
-        for _ in range(5):
-            with torch.no_grad():
-                _ = new_model(*inputs)
-        torch.cuda.synchronize()
-
-        # Benchmark new model
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-
-        times = []
-        for _ in range(args.num_perf_trials):
-            start.record()
-            with torch.no_grad():
-                _ = new_model(*inputs)
-            end.record()
-            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))
-
-        new_time = sum(times) / len(times)
-        results["runtime_ms"] = new_time
-
-        # Benchmark reference model
-        for _ in range(5):
-            with torch.no_grad():
-                _ = ref_model(*inputs)
-        torch.cuda.synchronize()
-
-        times = []
-        for _ in range(args.num_perf_trials):
-            start.record()
-            with torch.no_grad():
-                _ = ref_model(*inputs)
-            end.record()
-            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))
-
-        ref_time = sum(times) / len(times)
-        results["reference_runtime_ms"] = ref_time
-        results["speedup"] = ref_time / new_time if new_time > 0 else 0
-        print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+        if args.defensive and defense_module is not None:
+            # Use full defense suite
+            print("[KernelBench] Running defense checks on implementation...")
+            run_all_defenses = defense_module.run_all_defenses
+            time_with_defenses = defense_module.time_execution_with_defenses
+
+            # Run defense checks on implementation
+            all_passed, defense_results, _ = run_all_defenses(
+                lambda *x: new_model(*x),
+                *inputs,
+            )
+            results["defense_results"] = {
+                name: {"passed": passed, "message": msg}
+                for name, passed, msg in defense_results
+            }
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                results["error"] = f"Defense checks failed: {failed}"
+                print(f"[KernelBench] Defense checks FAILED: {failed}")
+                for name, passed, msg in defense_results:
+                    status = "PASS" if passed else "FAIL"
+                    print(f"  [{status}] {name}: {msg}")
+            else:
+                print("[KernelBench] All defense checks passed")
+
+            # Time with defensive timing
+            impl_times, _ = time_with_defenses(
+                lambda: new_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,  # Already ran above
+            )
+            new_time = sum(impl_times) / len(impl_times)
+            results["runtime_ms"] = new_time
+
+            # Reference timing
+            ref_times, _ = time_with_defenses(
+                lambda: ref_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time = sum(ref_times) / len(ref_times)
+            results["reference_runtime_ms"] = ref_time
+            results["speedup"] = ref_time / new_time if new_time > 0 else 0
+            print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+        else:
+            # Standard timing without full defenses
+            # Warmup
+            for _ in range(5):
+                with torch.no_grad():
+                    _ = new_model(*inputs)
+            torch.cuda.synchronize()
+
+            # Benchmark new model
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+
+            times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = new_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                times.append(start.elapsed_time(end))
+
+            new_time = sum(times) / len(times)
+            results["runtime_ms"] = new_time
+
+            # Benchmark reference model
+            for _ in range(5):
+                with torch.no_grad():
+                    _ = ref_model(*inputs)
+            torch.cuda.synchronize()
+
+            times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = ref_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                times.append(start.elapsed_time(end))
+
+            ref_time = sum(times) / len(times)
+            results["reference_runtime_ms"] = ref_time
+            results["speedup"] = ref_time / new_time if new_time > 0 else 0
+            print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+
+        # Run profiling if requested and correctness passed
+        if args.profile and all_correct:
+            print("[KernelBench] Running profiler...")
+            inputs = get_inputs()
+            inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
+
+            try:
+                # Profile implementation
+                impl_stats = run_profiling(new_model, inputs, "implementation", str(profile_dir))
+                results["profile_impl"] = impl_stats
+                print(f"[KernelBench] Implementation profile:")
+                print(f"  Total GPU time: {impl_stats['total_gpu_time_ms']:.3f}ms")
+                print(f"  Kernels launched: {impl_stats['num_gpu_kernels']}")
+                if impl_stats['top_kernels']:
+                    print(f"  Top kernel: {impl_stats['top_kernels'][0]['name'][:60]}...")
+                    print(f"              {impl_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+                # Profile reference
+                ref_stats = run_profiling(ref_model, inputs, "reference", str(profile_dir))
+                results["profile_ref"] = ref_stats
+                print(f"[KernelBench] Reference profile:")
+                print(f"  Total GPU time: {ref_stats['total_gpu_time_ms']:.3f}ms")
+                print(f"  Kernels launched: {ref_stats['num_gpu_kernels']}")
+                if ref_stats['top_kernels']:
+                    print(f"  Top kernel: {ref_stats['top_kernels'][0]['name'][:60]}...")
+                    print(f"              {ref_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+                print(f"[KernelBench] Profile traces saved to: {profile_dir}/")
+
+            except Exception as prof_err:
+                print(f"[KernelBench] Profiling failed: {prof_err}")
+                results["profile_error"] = str(prof_err)
 
     except Exception as e:
         import traceback
@@ -2705,6 +3389,24 @@ async def run_evaluate_kernelbench_docker(
                 error_message=f"Failed to write reference: {write_result.stderr}",
             )
 
+        # Write custom inputs if provided
+        if args.inputs:
+            inputs_code = args.inputs.read_text()
+            inputs_file_path = f"{run_path}/custom_inputs.py"
+            write_result = await client.exec(
+                f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+            )
+            if write_result.exit_code != 0:
+                return EvaluateResult(
+                    success=False,
+                    all_correct=False,
+                    correctness_score=0.0,
+                    geomean_speedup=0.0,
+                    passed_tests=0,
+                    total_tests=0,
+                    error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                )
+
         # Write eval script
         eval_script_path = f"{run_path}/kernelbench_eval.py"
         write_result = await client.exec(
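
These remote writes use a quoted heredoc instead of a separate upload step: quoting the delimiter ('INPUTS_EOF') disables shell expansion inside the body, so Python source passes through verbatim, as long as the body itself never contains the delimiter line. A minimal sketch of the command being built (the `client.exec` call mirrors the AsyncSSHClient usage above):

    code = "print('hello')\n"
    remote_cmd = f"cat > '/tmp/custom_inputs.py' << 'INPUTS_EOF'\n{code}\nINPUTS_EOF"
    # await client.exec(remote_cmd)   # executed over SSH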
@@ -2721,14 +3423,40 @@ async def run_evaluate_kernelbench_docker(
                 error_message=f"Failed to write eval script: {write_result.stderr}",
             )
 
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_path = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_path.exists():
+                defense_code = defense_path.read_text()
+                defense_module_path = f"{run_path}/defense.py"
+                write_result = await client.exec(
+                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                )
+                if write_result.exit_code != 0:
+                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                    defense_module_path = None
+            else:
+                print(f"Warning: defense.py not found at {defense_path}")
+
         print("Running KernelBench evaluation in Docker container...")
 
         # Paths inside container
         container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
+        container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
         container_eval_script = f"{container_run_path}/kernelbench_eval.py"
         container_output = f"{container_run_path}/results.json"
+        container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None
 
         # Build eval command
         python_cmd_parts = [
@@ -2740,6 +3468,14 @@ async def run_evaluate_kernelbench_docker(
 
         if args.benchmark:
             python_cmd_parts.append("--benchmark")
+        if args.profile:
+            python_cmd_parts.append("--profile")
+        if container_inputs_path:
+            python_cmd_parts.append(f"--inputs {container_inputs_path}")
+        if args.defensive and container_defense_path:
+            python_cmd_parts.append("--defensive")
+            python_cmd_parts.append(f"--defense-module {container_defense_path}")
+        python_cmd_parts.append(f"--seed {args.seed}")
 
         eval_cmd = " ".join(python_cmd_parts)
 
@@ -2920,6 +3656,24 @@ async def run_evaluate_kernelbench_digitalocean(
                 error_message=f"Failed to write reference: {write_result.stderr}",
             )
 
+        # Write custom inputs if provided
+        if args.inputs:
+            inputs_code = args.inputs.read_text()
+            inputs_file_path = f"{run_path}/custom_inputs.py"
+            write_result = await client.exec(
+                f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+            )
+            if write_result.exit_code != 0:
+                return EvaluateResult(
+                    success=False,
+                    all_correct=False,
+                    correctness_score=0.0,
+                    geomean_speedup=0.0,
+                    passed_tests=0,
+                    total_tests=0,
+                    error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                )
+
         # Write eval script
         eval_script_path = f"{run_path}/kernelbench_eval.py"
         write_result = await client.exec(
@@ -2936,14 +3690,44 @@ async def run_evaluate_kernelbench_digitalocean(
                 error_message=f"Failed to write eval script: {write_result.stderr}",
             )
 
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_path = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_path.exists():
+                defense_code = defense_path.read_text()
+                defense_module_path = f"{run_path}/defense.py"
+                write_result = await client.exec(
+                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                )
+                if write_result.exit_code != 0:
+                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                    defense_module_path = None
+            else:
+                print(f"Warning: defense.py not found at {defense_path}")
+
         print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
 
         # Paths inside container
         container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
+        container_inputs_path = (
+            f"{container_run_path}/custom_inputs.py" if args.inputs else None
+        )
         container_eval_script = f"{container_run_path}/kernelbench_eval.py"
         container_output = f"{container_run_path}/results.json"
+        container_defense_path = (
+            f"{container_run_path}/defense.py" if defense_module_path else None
+        )
 
         # Build eval command
         python_cmd_parts = [
@@ -2955,6 +3739,14 @@ async def run_evaluate_kernelbench_digitalocean(
 
         if args.benchmark:
             python_cmd_parts.append("--benchmark")
+        if args.profile:
+            python_cmd_parts.append("--profile")
+        if container_inputs_path:
+            python_cmd_parts.append(f"--inputs {container_inputs_path}")
+        if args.defensive and container_defense_path:
+            python_cmd_parts.append("--defensive")
+            python_cmd_parts.append(f"--defense-module {container_defense_path}")
+        python_cmd_parts.append(f"--seed {args.seed}")
 
         eval_cmd = " ".join(python_cmd_parts)
 
@@ -3039,11 +3831,478 @@ async def run_evaluate_kernelbench_digitalocean(
3039
3831
  )
3040
3832
 
3041
3833
 
3042
- async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
3043
- """Run KernelBench format evaluation on configured target.
3044
-
3045
- Args:
3046
- args: KernelBench evaluate arguments
3834
+ async def run_evaluate_kernelbench_runpod(
+     args: KernelBenchEvaluateArgs,
+     target: RunPodTarget,
+ ) -> EvaluateResult:
+     """Run KernelBench format evaluation directly on RunPod AMD GPU.
+
+     Runs evaluation script directly on host (no Docker) since RunPod pods
+     already have PyTorch/ROCm installed.
+     """
+     from datetime import datetime
+
+     from wafer_core.async_ssh import AsyncSSHClient
+     from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
+
+     REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
+
+     # Select GPU
+     gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+     print(f"Provisioning RunPod ({target.gpu_type_id})...")
+
+     try:
+         async with runpod_ssh_context(target) as ssh_info:
+             ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
+             print(f"Connected to RunPod: {ssh_target}")
+
+             async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
+                 # Create workspace
+                 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                 run_dir = f"kernelbench_eval_{timestamp}"
+                 run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
+
+                 await client.exec(f"mkdir -p {run_path}")
+                 print(f"Created run directory: {run_path}")
+
+                 # Read and upload files
+                 impl_code = args.implementation.read_text()
+                 ref_code = args.reference.read_text()
+
+                 # Write implementation
+                 impl_path = f"{run_path}/implementation.py"
+                 write_result = await client.exec(
+                     f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
+                 )
+                 if write_result.exit_code != 0:
+                     return EvaluateResult(
+                         success=False,
+                         all_correct=False,
+                         correctness_score=0.0,
+                         geomean_speedup=0.0,
+                         passed_tests=0,
+                         total_tests=0,
+                         error_message=f"Failed to write implementation: {write_result.stderr}",
+                     )
+
+                 # Write reference
+                 ref_path = f"{run_path}/reference.py"
+                 write_result = await client.exec(
+                     f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
+                 )
+                 if write_result.exit_code != 0:
+                     return EvaluateResult(
+                         success=False,
+                         all_correct=False,
+                         correctness_score=0.0,
+                         geomean_speedup=0.0,
+                         passed_tests=0,
+                         total_tests=0,
+                         error_message=f"Failed to write reference: {write_result.stderr}",
+                     )
+
+                 # Write custom inputs if provided
+                 inputs_path = None
+                 if args.inputs:
+                     inputs_code = args.inputs.read_text()
+                     inputs_path = f"{run_path}/custom_inputs.py"
+                     write_result = await client.exec(
+                         f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+                     )
+                     if write_result.exit_code != 0:
+                         return EvaluateResult(
+                             success=False,
+                             all_correct=False,
+                             correctness_score=0.0,
+                             geomean_speedup=0.0,
+                             passed_tests=0,
+                             total_tests=0,
+                             error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                         )
+
+                 # Write eval script
+                 eval_script_path = f"{run_path}/kernelbench_eval.py"
+                 write_result = await client.exec(
+                     f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
+                 )
+                 if write_result.exit_code != 0:
+                     return EvaluateResult(
+                         success=False,
+                         all_correct=False,
+                         correctness_score=0.0,
+                         geomean_speedup=0.0,
+                         passed_tests=0,
+                         total_tests=0,
+                         error_message=f"Failed to write eval script: {write_result.stderr}",
+                     )
+
+                 # Write defense module if defensive mode is enabled
+                 defense_module_path = None
+                 if args.defensive:
+                     defense_path = (
+                         Path(__file__).parent.parent.parent.parent
+                         / "packages"
+                         / "wafer-core"
+                         / "wafer_core"
+                         / "utils"
+                         / "kernel_utils"
+                         / "defense.py"
+                     )
+                     if defense_path.exists():
+                         defense_code = defense_path.read_text()
+                         defense_module_path = f"{run_path}/defense.py"
+                         write_result = await client.exec(
+                             f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                         )
+                         if write_result.exit_code != 0:
+                             print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                             defense_module_path = None
+                     else:
+                         print(f"Warning: defense.py not found at {defense_path}")
+
+                 print("Running KernelBench evaluation (AMD/ROCm)...")
+
+                 # Find Python with PyTorch - check common locations on RunPod
+                 python_exe = "python3"
+                 for candidate in [
+                     "/opt/conda/envs/py_3.10/bin/python3",
+                     "/opt/conda/bin/python3",
+                 ]:
+                     check = await client.exec(
+                         f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
+                     )
+                     if "OK" in check.stdout:
+                         python_exe = candidate
+                         print(f"Using Python: {python_exe}")
+                         break
+
+                 # Build eval command - run directly on host
+                 output_path = f"{run_path}/results.json"
+                 python_cmd_parts = [
+                     f"{python_exe} {eval_script_path}",
+                     f"--impl {impl_path}",
+                     f"--reference {ref_path}",
+                     f"--output {output_path}",
+                 ]
+
+                 if args.benchmark:
+                     python_cmd_parts.append("--benchmark")
+                 if args.profile:
+                     python_cmd_parts.append("--profile")
+                 if inputs_path:
+                     python_cmd_parts.append(f"--inputs {inputs_path}")
+                 if args.defensive and defense_module_path:
+                     python_cmd_parts.append("--defensive")
+                     python_cmd_parts.append(f"--defense-module {defense_module_path}")
+                 python_cmd_parts.append(f"--seed {args.seed}")
+
+                 eval_cmd = " ".join(python_cmd_parts)
+
+                 # Set environment for AMD GPU and run
+                 env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
+                 full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
+
+                 # Run and stream output
+                 log_lines = []
+                 async for line in client.exec_stream(full_cmd):
+                     print(line)
+                     log_lines.append(line)
+
+                 # Read results
+                 cat_result = await client.exec(f"cat {output_path}")
+
+                 if cat_result.exit_code != 0:
+                     log_tail = "\n".join(log_lines[-50:])
+                     return EvaluateResult(
+                         success=False,
+                         all_correct=False,
+                         correctness_score=0.0,
+                         geomean_speedup=0.0,
+                         passed_tests=0,
+                         total_tests=0,
+                         error_message=f"Evaluation failed. Log tail:\n{log_tail}",
+                     )
+
+                 # Parse results
+                 try:
+                     results_data = json.loads(cat_result.stdout)
+                 except json.JSONDecodeError as e:
+                     return EvaluateResult(
+                         success=False,
+                         all_correct=False,
+                         correctness_score=0.0,
+                         geomean_speedup=0.0,
+                         passed_tests=0,
+                         total_tests=0,
+                         error_message=f"Failed to parse results: {e}",
+                     )
+
+                 # Convert to EvaluateResult
+                 correct = results_data.get("correct", False)
+                 speedup = results_data.get("speedup", 0.0) or 0.0
+                 error = results_data.get("error")
+
+                 if error:
+                     return EvaluateResult(
+                         success=False,
+                         all_correct=False,
+                         correctness_score=0.0,
+                         geomean_speedup=0.0,
+                         passed_tests=0,
+                         total_tests=1,
+                         error_message=error,
+                     )
+
+                 return EvaluateResult(
+                     success=True,
+                     all_correct=correct,
+                     correctness_score=1.0 if correct else 0.0,
+                     geomean_speedup=speedup,
+                     passed_tests=1 if correct else 0,
+                     total_tests=1,
+                 )
+
+     except RunPodError as e:
+         return EvaluateResult(
+             success=False,
+             all_correct=False,
+             correctness_score=0.0,
+             geomean_speedup=0.0,
+             passed_tests=0,
+             total_tests=0,
+             error_message=f"RunPod error: {e}",
+         )
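A note on the transfer mechanism above: every file is written over SSH with a quoted heredoc rather than SCP, one exec per file. A minimal sketch of the pattern (helper name hypothetical):

    def heredoc_write_cmd(remote_path: str, content: str, sentinel: str) -> str:
        # Quoting the sentinel ('IMPL_EOF', 'REF_EOF', ...) stops the shell from
        # expanding $vars and backticks inside the body; the write would still
        # break if `content` itself contained the sentinel on a line of its own.
        return f"cat > '{remote_path}' << '{sentinel}'\n{content}\n{sentinel}"

    cmd = heredoc_write_cmd("/tmp/wafer_eval/run/implementation.py", "print('kernel')", "IMPL_EOF")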
+
+
+ async def run_evaluate_kernelbench_baremetal_amd(
+     args: KernelBenchEvaluateArgs,
+     target: BaremetalTarget,
+ ) -> EvaluateResult:
+     """Run KernelBench format evaluation directly on AMD baremetal target.
+
+     Runs evaluation script directly on host (no Docker) for AMD GPUs
+     that have PyTorch/ROCm installed.
+     """
+     from datetime import datetime
+
+     from wafer_core.async_ssh import AsyncSSHClient
+
+     REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
+
+     # Select GPU
+     gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+     print(f"Connecting to {target.ssh_target}...")
+
+     async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
+         # Create workspace
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         run_dir = f"kernelbench_eval_{timestamp}"
+         run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
+
+         await client.exec(f"mkdir -p {run_path}")
+         print(f"Created run directory: {run_path}")
+
+         # Read and upload files
+         impl_code = args.implementation.read_text()
+         ref_code = args.reference.read_text()
+
+         # Write implementation
+         impl_path = f"{run_path}/implementation.py"
+         write_result = await client.exec(
+             f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
+         )
+         if write_result.exit_code != 0:
+             return EvaluateResult(
+                 success=False,
+                 all_correct=False,
+                 correctness_score=0.0,
+                 geomean_speedup=0.0,
+                 passed_tests=0,
+                 total_tests=0,
+                 error_message=f"Failed to write implementation: {write_result.stderr}",
+             )
+
+         # Write reference
+         ref_path = f"{run_path}/reference.py"
+         write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
+         if write_result.exit_code != 0:
+             return EvaluateResult(
+                 success=False,
+                 all_correct=False,
+                 correctness_score=0.0,
+                 geomean_speedup=0.0,
+                 passed_tests=0,
+                 total_tests=0,
+                 error_message=f"Failed to write reference: {write_result.stderr}",
+             )
+
+         # Write custom inputs if provided
+         inputs_path = None
+         if args.inputs:
+             inputs_code = args.inputs.read_text()
+             inputs_path = f"{run_path}/custom_inputs.py"
+             write_result = await client.exec(
+                 f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+             )
+             if write_result.exit_code != 0:
+                 return EvaluateResult(
+                     success=False,
+                     all_correct=False,
+                     correctness_score=0.0,
+                     geomean_speedup=0.0,
+                     passed_tests=0,
+                     total_tests=0,
+                     error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                 )
+
+         # Write eval script
+         eval_script_path = f"{run_path}/kernelbench_eval.py"
+         write_result = await client.exec(
+             f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
+         )
+         if write_result.exit_code != 0:
+             return EvaluateResult(
+                 success=False,
+                 all_correct=False,
+                 correctness_score=0.0,
+                 geomean_speedup=0.0,
+                 passed_tests=0,
+                 total_tests=0,
+                 error_message=f"Failed to write eval script: {write_result.stderr}",
+             )
+
+         # Write defense module if defensive mode is enabled
+         defense_module_path = None
+         if args.defensive:
+             defense_path = (
+                 Path(__file__).parent.parent.parent.parent
+                 / "packages"
+                 / "wafer-core"
+                 / "wafer_core"
+                 / "utils"
+                 / "kernel_utils"
+                 / "defense.py"
+             )
+             if defense_path.exists():
+                 defense_code = defense_path.read_text()
+                 defense_module_path = f"{run_path}/defense.py"
+                 write_result = await client.exec(
+                     f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                 )
+                 if write_result.exit_code != 0:
+                     print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                     defense_module_path = None
+             else:
+                 print(f"Warning: defense.py not found at {defense_path}")
+
+         print("Running KernelBench evaluation (AMD/ROCm)...")
+
+         # Find Python with PyTorch - check common locations
+         python_exe = "python3"
+         for candidate in [
+             "/opt/conda/envs/py_3.10/bin/python3",
+             "/opt/conda/bin/python3",
+         ]:
+             check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
+             if "OK" in check.stdout:
+                 python_exe = candidate
+                 print(f"Using Python: {python_exe}")
+                 break
+
+         # Build eval command - run directly on host
+         output_path = f"{run_path}/results.json"
+         python_cmd_parts = [
+             f"{python_exe} {eval_script_path}",
+             f"--impl {impl_path}",
+             f"--reference {ref_path}",
+             f"--output {output_path}",
+         ]
+
+         if args.benchmark:
+             python_cmd_parts.append("--benchmark")
+         if args.profile:
+             python_cmd_parts.append("--profile")
+         if inputs_path:
+             python_cmd_parts.append(f"--inputs {inputs_path}")
+         if args.defensive and defense_module_path:
+             python_cmd_parts.append("--defensive")
+             python_cmd_parts.append(f"--defense-module {defense_module_path}")
+         python_cmd_parts.append(f"--seed {args.seed}")
+
+         eval_cmd = " ".join(python_cmd_parts)
+
+         # Set environment for AMD GPU and run
+         env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
+         full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
+
+         # Run and stream output
+         log_lines = []
+         async for line in client.exec_stream(full_cmd):
+             print(line)
+             log_lines.append(line)
+
+         # Read results
+         cat_result = await client.exec(f"cat {output_path}")
+
+         if cat_result.exit_code != 0:
+             log_tail = "\n".join(log_lines[-50:])
+             return EvaluateResult(
+                 success=False,
+                 all_correct=False,
+                 correctness_score=0.0,
+                 geomean_speedup=0.0,
+                 passed_tests=0,
+                 total_tests=0,
+                 error_message=f"Evaluation failed. Log tail:\n{log_tail}",
+             )
+
+         # Parse results
+         try:
+             results_data = json.loads(cat_result.stdout)
+         except json.JSONDecodeError as e:
+             return EvaluateResult(
+                 success=False,
+                 all_correct=False,
+                 correctness_score=0.0,
+                 geomean_speedup=0.0,
+                 passed_tests=0,
+                 total_tests=0,
+                 error_message=f"Failed to parse results: {e}",
+             )
+
+         # Convert to EvaluateResult
+         correct = results_data.get("correct", False)
+         speedup = results_data.get("speedup", 0.0) or 0.0
+         error = results_data.get("error")
+
+         if error:
+             return EvaluateResult(
+                 success=False,
+                 all_correct=False,
+                 correctness_score=0.0,
+                 geomean_speedup=0.0,
+                 passed_tests=0,
+                 total_tests=1,
+                 error_message=error,
+             )
+
+         return EvaluateResult(
+             success=True,
+             all_correct=correct,
+             correctness_score=1.0 if correct else 0.0,
+             geomean_speedup=speedup,
+             passed_tests=1 if correct else 0,
+             total_tests=1,
+         )
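Both direct-on-host paths share the interpreter probe and ROCm environment setup seen above. A local, self-contained sketch of the same logic (illustrative only; the real code runs these checks over SSH):

    import subprocess

    # Probe well-known conda locations for an interpreter that can import torch,
    # falling back to whatever python3 is on PATH.
    candidates = ["/opt/conda/envs/py_3.10/bin/python3", "/opt/conda/bin/python3"]
    python_exe = "python3"
    for candidate in candidates:
        try:
            probe = subprocess.run([candidate, "-c", "import torch"], capture_output=True)
        except FileNotFoundError:
            continue  # interpreter not installed at this path
        if probe.returncode == 0:
            python_exe = candidate
            break

    # HIP_VISIBLE_DEVICES is ROCm's analogue of CUDA_VISIBLE_DEVICES: it pins the
    # process to the given GPU index before torch initializes.
    gpu_id = 0  # hypothetical
    env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"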
+
+
+ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
+     """Run KernelBench format evaluation on configured target.
+
+     Args:
+         args: KernelBench evaluate arguments

      Returns:
          Evaluation result
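One more observation before the dispatch hunk below: the .get() calls in both new functions imply a flat results.json contract with "correct", "speedup", and "error" keys. A sketch with a hypothetical payload:

    import json

    raw = '{"correct": true, "speedup": 1.37, "error": null}'  # hypothetical payload
    results_data = json.loads(raw)
    correct = results_data.get("correct", False)
    speedup = results_data.get("speedup", 0.0) or 0.0  # `or 0.0` maps null/None to 0.0
    error = results_data.get("error")
    assert correct and speedup == 1.37 and error is None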
@@ -3103,7 +4362,13 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
      if isinstance(target, DigitalOceanTarget):
          # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
          return await run_evaluate_kernelbench_digitalocean(args, target)
+     elif isinstance(target, RunPodTarget):
+         # RunPod AMD GPU - runs the eval script directly on the host (no Docker)
+         return await run_evaluate_kernelbench_runpod(args, target)
      elif isinstance(target, BaremetalTarget | VMTarget):
+         # Check if this is an AMD target (gfx* compute capability) - run directly
+         if target.compute_capability and target.compute_capability.startswith("gfx"):
+             return await run_evaluate_kernelbench_baremetal_amd(args, target)
          # NVIDIA targets - require docker_image to be set
          if not target.docker_image:
              return EvaluateResult(
@@ -3129,6 +4394,6 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
          total_tests=0,
          error_message=(
              f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
-             "Use a DigitalOcean, Baremetal, or VM target."
+             "Use a DigitalOcean, RunPod, Baremetal, or VM target."
          ),
      )
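The AMD detection in the dispatch above keys off the compute-capability string. In isolation (example values are standard LLVM gfx target names, not taken from this diff):

    # AMD GPUs are recorded with LLVM gfx targets (e.g. "gfx90a", "gfx942"),
    # while NVIDIA compute capabilities are numeric strings like "9.0".
    def is_amd_target(compute_capability: str | None) -> bool:
        return bool(compute_capability) and compute_capability.startswith("gfx")

    assert is_amd_target("gfx942")    # MI300-class
    assert not is_amd_target("9.0")   # Hopper-class
    assert not is_amd_target(None)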