wafer-cli 0.2.6-py3-none-any.whl → 0.2.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/evaluate.py CHANGED
@@ -14,7 +14,6 @@ logger = logging.getLogger(__name__)
  from wafer_core.utils.kernel_utils.targets.config import (
  BaremetalTarget,
  DigitalOceanTarget,
- LocalTarget,
  ModalTarget,
  RunPodTarget,
  VMTarget,
@@ -397,6 +396,33 @@ async def run_evaluate_docker(
  print(f"Connecting to {target.ssh_target}...")

  async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
+ # Upload wafer-core to remote
+ try:
+ wafer_root = _get_wafer_root()
+ wafer_core_path = wafer_root / "packages" / "wafer-core"
+ print(f"Uploading wafer-core from {wafer_core_path}...")
+
+ # Create workspace and upload
+ workspace_name = wafer_core_path.name
+ remote_workspace = f"{REMOTE_WORKSPACE_BASE}/{workspace_name}"
+ await client.exec(f"mkdir -p {remote_workspace}")
+ wafer_core_workspace = await client.expand_path(remote_workspace)
+
+ upload_result = await client.upload_files(
+ str(wafer_core_path), wafer_core_workspace, recursive=True
+ )
+ print(f"Uploaded {upload_result.files_copied} files")
+ except Exception as e:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to upload wafer-core: {e}",
+ )
+
  print(f"Using Docker image: {target.docker_image}")
  print(f"Using GPU {gpu_id}...")

@@ -405,13 +431,10 @@ async def run_evaluate_docker(
  ref_code = args.reference.read_text()
  test_cases_data = json.loads(args.test_cases.read_text())

- # Create workspace for evaluation files
+ # Create a unique run directory
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  run_dir = f"wafer_eval_{timestamp}"
- eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
- await client.exec(f"mkdir -p {eval_workspace}")
- eval_workspace_expanded = await client.expand_path(eval_workspace)
- run_path = f"{eval_workspace_expanded}/{run_dir}"
+ run_path = f"{wafer_core_workspace}/{run_dir}"

  print("Uploading evaluation files...")

@@ -498,16 +521,17 @@ async def run_evaluate_docker(
  container_impl_path = f"{container_run_path}/implementation.py"
  container_ref_path = f"{container_run_path}/reference.py"
  container_test_cases_path = f"{container_run_path}/test_cases.json"
+ container_evaluate_script = (
+ f"{CONTAINER_WORKSPACE}/wafer_core/utils/kernel_utils/evaluate.py"
+ )

- # Build pip install command for torch and other deps, plus wafer-core
+ # Build pip install command for torch and other deps (no wafer-core install needed)
  pip_install_cmd = _build_docker_pip_install_cmd(target)
- install_cmd = (
- f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
- )

- # Build evaluate command using installed wafer-core module
+ # Build evaluate command - use PYTHONPATH instead of installing wafer-core
  python_cmd_parts = [
- "python3 -m wafer_core.utils.kernel_utils.evaluate",
+ f"PYTHONPATH={CONTAINER_WORKSPACE}:$PYTHONPATH",
+ f"python3 {container_evaluate_script}",
  f"--implementation {container_impl_path}",
  f"--reference {container_ref_path}",
  f"--test-cases {container_test_cases_path}",
@@ -523,8 +547,8 @@ async def run_evaluate_docker(

  eval_cmd = " ".join(python_cmd_parts)

- # Full command: install deps + wafer-core, then run evaluate
- full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
+ # Full command: install torch deps, then run evaluate with PYTHONPATH
+ full_cmd = f"{pip_install_cmd} && cd {container_run_path} && {eval_cmd}"

  # Build Docker run command
  # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
@@ -534,7 +558,7 @@ async def run_evaluate_docker(
  working_dir=container_run_path,
  env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
  gpus="all",
- volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
+ volumes={wafer_core_workspace: CONTAINER_WORKSPACE},
  cap_add=["SYS_ADMIN"] if args.profile else None,
  )

@@ -641,181 +665,6 @@ async def run_evaluate_docker(
641
665
  )
642
666
 
643
667
 
644
- async def run_evaluate_local(
645
- args: EvaluateArgs,
646
- target: LocalTarget,
647
- ) -> EvaluateResult:
648
- """Run evaluation locally on the current machine.
649
-
650
- For LocalTarget - no SSH needed, runs directly.
651
-
652
- Args:
653
- args: Evaluate arguments
654
- target: Local target config
655
-
656
- Returns:
657
- Evaluation result
658
- """
659
- import os
660
- import subprocess
661
- import tempfile
662
- from datetime import datetime
663
-
664
- # Select GPU
665
- gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
666
-
667
- print(f"Running local evaluation on GPU {gpu_id}...")
668
-
669
- # Create temp directory for eval files
670
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
671
- with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
672
- run_path = Path(run_path)
673
-
674
- # Write implementation
675
- impl_path = run_path / "implementation.py"
676
- impl_path.write_text(args.implementation.read_text())
677
-
678
- # Write reference
679
- ref_path = run_path / "reference.py"
680
- ref_path.write_text(args.reference.read_text())
681
-
682
- # Write custom inputs if provided
683
- inputs_path = None
684
- if args.inputs:
685
- inputs_path = run_path / "custom_inputs.py"
686
- inputs_path.write_text(args.inputs.read_text())
687
-
688
- # Write eval script
689
- eval_script_path = run_path / "kernelbench_eval.py"
690
- eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)
691
-
692
- # Write defense module if defensive mode is enabled
693
- defense_module_path = None
694
- if args.defensive:
695
- defense_src = (
696
- Path(__file__).parent.parent.parent.parent
697
- / "packages"
698
- / "wafer-core"
699
- / "wafer_core"
700
- / "utils"
701
- / "kernel_utils"
702
- / "defense.py"
703
- )
704
- if defense_src.exists():
705
- defense_module_path = run_path / "defense.py"
706
- defense_module_path.write_text(defense_src.read_text())
707
- else:
708
- print(f"Warning: defense.py not found at {defense_src}")
709
-
710
- # Output file
711
- output_path = run_path / "results.json"
712
-
713
- # Build eval command
714
- cmd_parts = [
715
- "python3",
716
- str(eval_script_path),
717
- "--impl",
718
- str(impl_path),
719
- "--reference",
720
- str(ref_path),
721
- "--output",
722
- str(output_path),
723
- "--seed",
724
- str(args.seed),
725
- ]
726
-
727
- if args.benchmark:
728
- cmd_parts.append("--benchmark")
729
- if args.profile:
730
- cmd_parts.append("--profile")
731
- if inputs_path:
732
- cmd_parts.extend(["--inputs", str(inputs_path)])
733
- if args.defensive and defense_module_path:
734
- cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])
735
-
736
- # Set environment for GPU selection
737
- env = os.environ.copy()
738
- if target.vendor == "nvidia":
739
- env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
740
- else: # AMD
741
- env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
742
- env["ROCM_PATH"] = "/opt/rocm"
743
-
744
- print(f"Running: {' '.join(cmd_parts[:4])} ...")
745
-
746
- # Run evaluation
747
- try:
748
- result = subprocess.run(
749
- cmd_parts,
750
- cwd=str(run_path),
751
- env=env,
752
- capture_output=True,
753
- text=True,
754
- timeout=args.timeout or 600,
755
- )
756
- except subprocess.TimeoutExpired:
757
- return EvaluateResult(
758
- success=False,
759
- all_correct=False,
760
- correctness_score=0.0,
761
- geomean_speedup=0.0,
762
- passed_tests=0,
763
- total_tests=0,
764
- error_message="Evaluation timed out",
765
- )
766
-
767
- if result.returncode != 0:
768
- error_msg = result.stderr or result.stdout or "Unknown error"
769
- # Truncate long errors
770
- if len(error_msg) > 1000:
771
- error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
772
- return EvaluateResult(
773
- success=False,
774
- all_correct=False,
775
- correctness_score=0.0,
776
- geomean_speedup=0.0,
777
- passed_tests=0,
778
- total_tests=0,
779
- error_message=f"Evaluation failed:\n{error_msg}",
780
- )
781
-
782
- # Parse results
783
- if not output_path.exists():
784
- return EvaluateResult(
785
- success=False,
786
- all_correct=False,
787
- correctness_score=0.0,
788
- geomean_speedup=0.0,
789
- passed_tests=0,
790
- total_tests=0,
791
- error_message="No results.json produced",
792
- )
793
-
794
- try:
795
- results = json.loads(output_path.read_text())
796
- except json.JSONDecodeError as e:
797
- return EvaluateResult(
798
- success=False,
799
- all_correct=False,
800
- correctness_score=0.0,
801
- geomean_speedup=0.0,
802
- passed_tests=0,
803
- total_tests=0,
804
- error_message=f"Failed to parse results: {e}",
805
- )
806
-
807
- # Extract results
808
- return EvaluateResult(
809
- success=True,
810
- all_correct=results.get("all_correct", False),
811
- correctness_score=results.get("correctness_score", 0.0),
812
- geomean_speedup=results.get("geomean_speedup", 0.0),
813
- passed_tests=results.get("passed_tests", 0),
814
- total_tests=results.get("total_tests", 0),
815
- benchmark_results=results.get("benchmark", {}),
816
- )
817
-
818
-
819
668
  async def run_evaluate_ssh(
820
669
  args: EvaluateArgs,
821
670
  target: BaremetalTarget | VMTarget,
@@ -1133,7 +982,6 @@ def _build_modal_sandbox_script(
  test_cases_b64: str,
  run_benchmarks: bool,
  run_defensive: bool,
- defense_code_b64: str | None = None,
  ) -> str:
  """Build Python script to create sandbox and run evaluation.

@@ -1214,20 +1062,6 @@ print('Files written')
  print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
  return

- # Write defense module if defensive mode is enabled
- # NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
- if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
- proc = sandbox.exec("python", "-c", f"""
- import base64
- with open('/workspace/defense.py', 'w') as f:
- f.write(base64.b64decode('{defense_code_b64}').decode())
- print('Defense module written')
- """)
- proc.wait()
- if proc.returncode != 0:
- print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
- return
-
  # Build inline evaluation script
  eval_script = """
  import json
@@ -1255,26 +1089,6 @@ generate_input = load_fn('reference.py', 'generate_input')

  import torch

- # Load defense module if available and defensive mode is enabled
- run_defensive = {run_defensive}
- defense = None
- if run_defensive:
- try:
- defense = load_fn('defense.py', 'run_all_defenses')
- time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
- print('[Defense] Defense module loaded')
-
- # Wrap kernels for defense API compatibility
- # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
- # These wrappers repack the unpacked args back into a tuple
- def _wrap_for_defense(kernel):
- return lambda *args: kernel(args)
- custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
- ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
- except Exception as e:
- print(f'[Defense] Warning: Could not load defense module: {{e}}')
- defense = None
-
  results = []
  all_correct = True
  total_time_ms = 0.0
@@ -1302,63 +1116,36 @@ for tc in test_cases:
  impl_time_ms = 0.0
  ref_time_ms = 0.0
  if {run_benchmarks}:
- if run_defensive and defense is not None:
- # Use full defense suite with wrapped kernels
- # inputs_list unpacks the tuple so defense can infer dtype/device from tensors
- inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
-
- # Run defense checks
- all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
- if not all_passed:
- failed = [name for name, passed, _ in defense_results if not passed]
- raise ValueError(f"Defense checks failed: {{failed}}")
-
- # Time with defensive timing (using wrapped kernels)
- impl_times, _ = time_with_defenses(
- custom_kernel_for_defense,
- inputs_list,
- num_warmup=3,
- num_trials=10,
- verbose=False,
- run_defenses=False,
- )
- impl_time_ms = sum(impl_times) / len(impl_times)
-
- ref_times, _ = time_with_defenses(
- ref_kernel_for_defense,
- inputs_list,
- num_warmup=3,
- num_trials=10,
- verbose=False,
- run_defenses=False,
- )
- ref_time_ms = sum(ref_times) / len(ref_times)
- else:
- # Standard timing without full defenses
- # Warmup
- for _ in range(3):
- custom_kernel(inputs)
- torch.cuda.synchronize()
-
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
- start.record()
- for _ in range(10):
- custom_kernel(inputs)
- end.record()
- torch.cuda.synchronize()
- impl_time_ms = start.elapsed_time(end) / 10
-
- # Reference timing
- for _ in range(3):
- ref_kernel(inputs)
- torch.cuda.synchronize()
- start.record()
- for _ in range(10):
- ref_kernel(inputs)
- end.record()
- torch.cuda.synchronize()
- ref_time_ms = start.elapsed_time(end) / 10
+ # Warmup
+ for _ in range(3):
+ custom_kernel(inputs)
+ torch.cuda.synchronize()
+
+ # Measure with defensive timing if requested
+ # Defensive: sync before recording end event to catch stream injection
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ for _ in range(10):
+ custom_kernel(inputs)
+ if {run_defensive}:
+ torch.cuda.synchronize() # DEFENSE: sync all streams before end
+ end.record()
+ torch.cuda.synchronize()
+ impl_time_ms = start.elapsed_time(end) / 10
+
+ # Reference timing (same defensive approach)
+ for _ in range(3):
+ ref_kernel(inputs)
+ torch.cuda.synchronize()
+ start.record()
+ for _ in range(10):
+ ref_kernel(inputs)
+ if {run_defensive}:
+ torch.cuda.synchronize() # DEFENSE: sync all streams before end
+ end.record()
+ torch.cuda.synchronize()
+ ref_time_ms = start.elapsed_time(end) / 10

  total_time_ms += impl_time_ms
  ref_total_time_ms += ref_time_ms
@@ -1451,23 +1238,6 @@ async def run_evaluate_modal(
  ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
  test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()

- # Encode defense module if defensive mode is enabled
- defense_code_b64 = None
- if args.defensive:
- defense_path = (
- Path(__file__).parent.parent.parent.parent
- / "packages"
- / "wafer-core"
- / "wafer_core"
- / "utils"
- / "kernel_utils"
- / "defense.py"
- )
- if defense_path.exists():
- defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
- else:
- print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
-
  # Build the script that creates sandbox and runs eval
  script = _build_modal_sandbox_script(
  target=target,
@@ -1476,7 +1246,6 @@ async def run_evaluate_modal(
  test_cases_b64=test_cases_b64,
  run_benchmarks=args.benchmark,
  run_defensive=args.defensive,
- defense_code_b64=defense_code_b64,
  )

  def _run_subprocess() -> tuple[str, str, int]:
@@ -1574,7 +1343,6 @@ def _build_workspace_eval_script(
  test_cases_json: str,
  run_benchmarks: bool,
  run_defensive: bool = False,
- defense_code: str | None = None,
  ) -> str:
  """Build inline evaluation script for workspace exec.

@@ -1585,7 +1353,6 @@ def _build_workspace_eval_script(
  impl_b64 = base64.b64encode(impl_code.encode()).decode()
  ref_b64 = base64.b64encode(ref_code.encode()).decode()
  tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
- defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""

  return f'''
  import base64
@@ -1605,15 +1372,6 @@ with open("/tmp/kernel.py", "w") as f:
  with open("/tmp/reference.py", "w") as f:
  f.write(ref_code)

- # Write defense module if available
- run_defensive = {run_defensive}
- defense_b64 = "{defense_b64}"
- # NOTE: Check defense_b64 is not empty and not the string "None" (from None formatting)
- if run_defensive and defense_b64 and defense_b64 != "None":
- defense_code = base64.b64decode(defense_b64).decode()
- with open("/tmp/defense.py", "w") as f:
- f.write(defense_code)
-
  # Load kernels
  def load_fn(path, name):
  spec = importlib.util.spec_from_file_location("mod", path)
@@ -1627,24 +1385,6 @@ generate_input = load_fn("/tmp/reference.py", "generate_input")

  import torch

- # Load defense module if available
- defense = None
- if run_defensive and defense_b64 and defense_b64 != "None":
- try:
- defense = load_fn("/tmp/defense.py", "run_all_defenses")
- time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
- print("[Defense] Defense module loaded")
-
- # Wrap kernels for defense API compatibility
- # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
- def _wrap_for_defense(kernel):
- return lambda *args: kernel(args)
- custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
- ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
- except Exception as e:
- print(f"[Defense] Warning: Could not load defense module: {{e}}")
- defense = None
-
  results = []
  all_correct = True
  total_time_ms = 0.0
@@ -1672,60 +1412,36 @@ for tc in test_cases:
1672
1412
  impl_time_ms = 0.0
1673
1413
  ref_time_ms = 0.0
1674
1414
  if {run_benchmarks}:
1675
- if run_defensive and defense is not None:
1676
- # Use full defense suite with wrapped kernels
1677
- inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
1678
-
1679
- # Run defense checks
1680
- all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
1681
- if not all_passed:
1682
- failed = [name for name, passed, _ in defense_results if not passed]
1683
- raise ValueError(f"Defense checks failed: {{failed}}")
1684
-
1685
- # Time with defensive timing (using wrapped kernels)
1686
- impl_times, _ = time_with_defenses(
1687
- custom_kernel_for_defense,
1688
- inputs_list,
1689
- num_warmup=3,
1690
- num_trials=10,
1691
- verbose=False,
1692
- run_defenses=False,
1693
- )
1694
- impl_time_ms = sum(impl_times) / len(impl_times)
1695
-
1696
- ref_times, _ = time_with_defenses(
1697
- ref_kernel_for_defense,
1698
- inputs_list,
1699
- num_warmup=3,
1700
- num_trials=10,
1701
- verbose=False,
1702
- run_defenses=False,
1703
- )
1704
- ref_time_ms = sum(ref_times) / len(ref_times)
1705
- else:
1706
- # Standard timing
1707
- for _ in range(3):
1708
- custom_kernel(inputs)
1709
- torch.cuda.synchronize()
1710
-
1711
- start = torch.cuda.Event(enable_timing=True)
1712
- end = torch.cuda.Event(enable_timing=True)
1713
- start.record()
1714
- for _ in range(10):
1715
- custom_kernel(inputs)
1716
- end.record()
1717
- torch.cuda.synchronize()
1718
- impl_time_ms = start.elapsed_time(end) / 10
1719
-
1720
- for _ in range(3):
1721
- ref_kernel(inputs)
1722
- torch.cuda.synchronize()
1723
- start.record()
1724
- for _ in range(10):
1725
- ref_kernel(inputs)
1726
- end.record()
1727
- torch.cuda.synchronize()
1728
- ref_time_ms = start.elapsed_time(end) / 10
1415
+ # Warmup
1416
+ for _ in range(3):
1417
+ custom_kernel(inputs)
1418
+ torch.cuda.synchronize()
1419
+
1420
+ # Measure with defensive timing if requested
1421
+ # Defensive: sync before recording end event to catch stream injection
1422
+ start = torch.cuda.Event(enable_timing=True)
1423
+ end = torch.cuda.Event(enable_timing=True)
1424
+ start.record()
1425
+ for _ in range(10):
1426
+ custom_kernel(inputs)
1427
+ if {run_defensive}:
1428
+ torch.cuda.synchronize() # DEFENSE: sync all streams before end
1429
+ end.record()
1430
+ torch.cuda.synchronize()
1431
+ impl_time_ms = start.elapsed_time(end) / 10
1432
+
1433
+ # Reference timing (same defensive approach)
1434
+ for _ in range(3):
1435
+ ref_kernel(inputs)
1436
+ torch.cuda.synchronize()
1437
+ start.record()
1438
+ for _ in range(10):
1439
+ ref_kernel(inputs)
1440
+ if {run_defensive}:
1441
+ torch.cuda.synchronize() # DEFENSE: sync all streams before end
1442
+ end.record()
1443
+ torch.cuda.synchronize()
1444
+ ref_time_ms = start.elapsed_time(end) / 10
1729
1445
 
1730
1446
  total_time_ms += impl_time_ms
1731
1447
  ref_total_time_ms += ref_time_ms
@@ -1787,23 +1503,6 @@ async def run_evaluate_workspace(
  ref_code = args.reference.read_text()
  test_cases_json = args.test_cases.read_text()

- # Read defense module if defensive mode is enabled
- defense_code = None
- if args.defensive:
- defense_path = (
- Path(__file__).parent.parent.parent.parent
- / "packages"
- / "wafer-core"
- / "wafer_core"
- / "utils"
- / "kernel_utils"
- / "defense.py"
- )
- if defense_path.exists():
- defense_code = defense_path.read_text()
- else:
- print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
-
  # Build inline eval script
  eval_script = _build_workspace_eval_script(
  impl_code=impl_code,
@@ -1811,7 +1510,6 @@ async def run_evaluate_workspace(
  test_cases_json=test_cases_json,
  run_benchmarks=args.benchmark,
  run_defensive=args.defensive,
- defense_code=defense_code,
  )

  # Execute via workspace exec
@@ -2157,12 +1855,15 @@ async def run_evaluate_runpod(
  # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
  venv_bin = env_state.venv_bin
  env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
+ pythonpath = f"PYTHONPATH={wafer_core_workspace}"
+ evaluate_script = (
+ f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
+ )

  # Run from run_path so reference_kernel.py is importable
- # Use installed wafer-core module
  eval_cmd = (
  f"cd {run_path} && "
- f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
+ f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
  f"--implementation {impl_path} "
  f"--reference {ref_path} "
  f"--test-cases {test_cases_path} "
@@ -2518,12 +2219,15 @@ async def run_evaluate_digitalocean(
  env_vars = (
  f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
  )
+ pythonpath = f"PYTHONPATH={wafer_core_workspace}"
+ evaluate_script = (
+ f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
+ )

  # Run from run_path so reference_kernel.py is importable
- # Use installed wafer-core module
  eval_cmd = (
  f"cd {run_path} && "
- f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
+ f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
  f"--implementation {impl_path} "
  f"--reference {ref_path} "
  f"--test-cases {test_cases_path} "
@@ -2703,9 +2407,7 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
  print(f"Using target: {target_name}")

  # Dispatch to appropriate executor
- if isinstance(target, LocalTarget):
- return await run_evaluate_local(args, target)
- elif isinstance(target, BaremetalTarget | VMTarget):
+ if isinstance(target, BaremetalTarget | VMTarget):
  return await run_evaluate_ssh(args, target)
  elif isinstance(target, ModalTarget):
  return await run_evaluate_modal(args, target)
@@ -2972,26 +2674,12 @@ def main():
  parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
  parser.add_argument("--benchmark", action="store_true")
  parser.add_argument("--profile", action="store_true")
- parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
- parser.add_argument("--defense-module", help="Path to defense.py module")
  parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
  parser.add_argument("--num-correct-trials", type=int, default=3)
  parser.add_argument("--num-perf-trials", type=int, default=10)
  parser.add_argument("--output", required=True)
  args = parser.parse_args()

- # Load defense module if defensive mode is enabled
- defense_module = None
- if args.defensive and args.defense_module:
- try:
- import importlib.util
- defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
- defense_module = importlib.util.module_from_spec(defense_spec)
- defense_spec.loader.exec_module(defense_module)
- print("[KernelBench] Defense module loaded")
- except Exception as e:
- print(f"[KernelBench] Warning: Could not load defense module: {e}")
-
  # Create output directory for profiles
  output_dir = Path(args.output).parent
  profile_dir = output_dir / "profiles"
@@ -3125,99 +2813,47 @@ def main():
3125
2813
  inputs = get_inputs()
3126
2814
  inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
3127
2815
 
3128
- if args.defensive and defense_module is not None:
3129
- # Use full defense suite
3130
- print("[KernelBench] Running defense checks on implementation...")
3131
- run_all_defenses = defense_module.run_all_defenses
3132
- time_with_defenses = defense_module.time_execution_with_defenses
2816
+ # Warmup
2817
+ for _ in range(5):
2818
+ with torch.no_grad():
2819
+ _ = new_model(*inputs)
2820
+ torch.cuda.synchronize()
3133
2821
 
3134
- # Run defense checks on implementation
3135
- all_passed, defense_results, _ = run_all_defenses(
3136
- lambda *x: new_model(*x),
3137
- *inputs,
3138
- )
3139
- results["defense_results"] = {
3140
- name: {"passed": passed, "message": msg}
3141
- for name, passed, msg in defense_results
3142
- }
3143
- if not all_passed:
3144
- failed = [name for name, passed, _ in defense_results if not passed]
3145
- results["error"] = f"Defense checks failed: {failed}"
3146
- print(f"[KernelBench] Defense checks FAILED: {failed}")
3147
- for name, passed, msg in defense_results:
3148
- status = "PASS" if passed else "FAIL"
3149
- print(f" [{status}] {name}: {msg}")
3150
- else:
3151
- print("[KernelBench] All defense checks passed")
3152
-
3153
- # Time with defensive timing
3154
- impl_times, _ = time_with_defenses(
3155
- lambda: new_model(*inputs),
3156
- [],
3157
- num_warmup=5,
3158
- num_trials=args.num_perf_trials,
3159
- verbose=False,
3160
- run_defenses=False, # Already ran above
3161
- )
3162
- new_time = sum(impl_times) / len(impl_times)
3163
- results["runtime_ms"] = new_time
3164
-
3165
- # Reference timing
3166
- ref_times, _ = time_with_defenses(
3167
- lambda: ref_model(*inputs),
3168
- [],
3169
- num_warmup=5,
3170
- num_trials=args.num_perf_trials,
3171
- verbose=False,
3172
- run_defenses=False,
3173
- )
3174
- ref_time = sum(ref_times) / len(ref_times)
3175
- results["reference_runtime_ms"] = ref_time
3176
- results["speedup"] = ref_time / new_time if new_time > 0 else 0
3177
- print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
3178
- else:
3179
- # Standard timing without full defenses
3180
- # Warmup
3181
- for _ in range(5):
3182
- with torch.no_grad():
3183
- _ = new_model(*inputs)
3184
- torch.cuda.synchronize()
2822
+ # Benchmark new model
2823
+ start = torch.cuda.Event(enable_timing=True)
2824
+ end = torch.cuda.Event(enable_timing=True)
3185
2825
 
3186
- # Benchmark new model
3187
- start = torch.cuda.Event(enable_timing=True)
3188
- end = torch.cuda.Event(enable_timing=True)
3189
-
3190
- times = []
3191
- for _ in range(args.num_perf_trials):
3192
- start.record()
3193
- with torch.no_grad():
3194
- _ = new_model(*inputs)
3195
- end.record()
3196
- torch.cuda.synchronize()
3197
- times.append(start.elapsed_time(end))
3198
-
3199
- new_time = sum(times) / len(times)
3200
- results["runtime_ms"] = new_time
3201
-
3202
- # Benchmark reference model
3203
- for _ in range(5):
3204
- with torch.no_grad():
3205
- _ = ref_model(*inputs)
2826
+ times = []
2827
+ for _ in range(args.num_perf_trials):
2828
+ start.record()
2829
+ with torch.no_grad():
2830
+ _ = new_model(*inputs)
2831
+ end.record()
3206
2832
  torch.cuda.synchronize()
2833
+ times.append(start.elapsed_time(end))
2834
+
2835
+ new_time = sum(times) / len(times)
2836
+ results["runtime_ms"] = new_time
2837
+
2838
+ # Benchmark reference model
2839
+ for _ in range(5):
2840
+ with torch.no_grad():
2841
+ _ = ref_model(*inputs)
2842
+ torch.cuda.synchronize()
3207
2843
 
3208
- times = []
3209
- for _ in range(args.num_perf_trials):
3210
- start.record()
3211
- with torch.no_grad():
3212
- _ = ref_model(*inputs)
3213
- end.record()
3214
- torch.cuda.synchronize()
3215
- times.append(start.elapsed_time(end))
2844
+ times = []
2845
+ for _ in range(args.num_perf_trials):
2846
+ start.record()
2847
+ with torch.no_grad():
2848
+ _ = ref_model(*inputs)
2849
+ end.record()
2850
+ torch.cuda.synchronize()
2851
+ times.append(start.elapsed_time(end))
3216
2852
 
3217
- ref_time = sum(times) / len(times)
3218
- results["reference_runtime_ms"] = ref_time
3219
- results["speedup"] = ref_time / new_time if new_time > 0 else 0
3220
- print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
2853
+ ref_time = sum(times) / len(times)
2854
+ results["reference_runtime_ms"] = ref_time
2855
+ results["speedup"] = ref_time / new_time if new_time > 0 else 0
2856
+ print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
3221
2857
 
3222
2858
  # Run profiling if requested and correctness passed
3223
2859
  if args.profile and all_correct:
@@ -3423,30 +3059,6 @@ async def run_evaluate_kernelbench_docker(
  error_message=f"Failed to write eval script: {write_result.stderr}",
  )

- # Write defense module if defensive mode is enabled
- defense_module_path = None
- if args.defensive:
- defense_path = (
- Path(__file__).parent.parent.parent.parent
- / "packages"
- / "wafer-core"
- / "wafer_core"
- / "utils"
- / "kernel_utils"
- / "defense.py"
- )
- if defense_path.exists():
- defense_code = defense_path.read_text()
- defense_module_path = f"{run_path}/defense.py"
- write_result = await client.exec(
- f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
- )
- if write_result.exit_code != 0:
- print(f"Warning: Failed to write defense module: {write_result.stderr}")
- defense_module_path = None
- else:
- print(f"Warning: defense.py not found at {defense_path}")
-
  print("Running KernelBench evaluation in Docker container...")

  # Paths inside container
@@ -3456,7 +3068,6 @@ async def run_evaluate_kernelbench_docker(
  container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
  container_eval_script = f"{container_run_path}/kernelbench_eval.py"
  container_output = f"{container_run_path}/results.json"
- container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None

  # Build eval command
  python_cmd_parts = [
@@ -3472,9 +3083,6 @@ async def run_evaluate_kernelbench_docker(
  python_cmd_parts.append("--profile")
  if container_inputs_path:
  python_cmd_parts.append(f"--inputs {container_inputs_path}")
- if args.defensive and container_defense_path:
- python_cmd_parts.append("--defensive")
- python_cmd_parts.append(f"--defense-module {container_defense_path}")
  python_cmd_parts.append(f"--seed {args.seed}")

  eval_cmd = " ".join(python_cmd_parts)
@@ -3690,44 +3298,15 @@ async def run_evaluate_kernelbench_digitalocean(
3690
3298
  error_message=f"Failed to write eval script: {write_result.stderr}",
3691
3299
  )
3692
3300
 
3693
- # Write defense module if defensive mode is enabled
3694
- defense_module_path = None
3695
- if args.defensive:
3696
- defense_path = (
3697
- Path(__file__).parent.parent.parent.parent
3698
- / "packages"
3699
- / "wafer-core"
3700
- / "wafer_core"
3701
- / "utils"
3702
- / "kernel_utils"
3703
- / "defense.py"
3704
- )
3705
- if defense_path.exists():
3706
- defense_code = defense_path.read_text()
3707
- defense_module_path = f"{run_path}/defense.py"
3708
- write_result = await client.exec(
3709
- f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
3710
- )
3711
- if write_result.exit_code != 0:
3712
- print(f"Warning: Failed to write defense module: {write_result.stderr}")
3713
- defense_module_path = None
3714
- else:
3715
- print(f"Warning: defense.py not found at {defense_path}")
3716
-
3717
3301
  print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
3718
3302
 
3719
3303
  # Paths inside container
3720
3304
  container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
3721
3305
  container_impl_path = f"{container_run_path}/implementation.py"
3722
3306
  container_ref_path = f"{container_run_path}/reference.py"
3723
- container_inputs_path = (
3724
- f"{container_run_path}/custom_inputs.py" if args.inputs else None
3725
- )
3307
+ container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
3726
3308
  container_eval_script = f"{container_run_path}/kernelbench_eval.py"
3727
3309
  container_output = f"{container_run_path}/results.json"
3728
- container_defense_path = (
3729
- f"{container_run_path}/defense.py" if defense_module_path else None
3730
- )
3731
3310
 
3732
3311
  # Build eval command
3733
3312
  python_cmd_parts = [
@@ -3743,9 +3322,6 @@ async def run_evaluate_kernelbench_digitalocean(
  python_cmd_parts.append("--profile")
  if container_inputs_path:
  python_cmd_parts.append(f"--inputs {container_inputs_path}")
- if args.defensive and container_defense_path:
- python_cmd_parts.append("--defensive")
- python_cmd_parts.append(f"--defense-module {container_defense_path}")
  python_cmd_parts.append(f"--seed {args.seed}")

  eval_cmd = " ".join(python_cmd_parts)
@@ -3831,544 +3407,71 @@ async def run_evaluate_kernelbench_digitalocean(
3831
3407
  )
3832
3408
 
3833
3409
 
3834
- async def run_evaluate_kernelbench_runpod(
3835
- args: KernelBenchEvaluateArgs,
3836
- target: RunPodTarget,
3837
- ) -> EvaluateResult:
3838
- """Run KernelBench format evaluation directly on RunPod AMD GPU.
3839
-
3840
- Runs evaluation script directly on host (no Docker) since RunPod pods
3841
- already have PyTorch/ROCm installed.
3842
- """
3843
- from datetime import datetime
3410
+ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
3411
+ """Run KernelBench format evaluation on configured target.
3844
3412
 
3845
- from wafer_core.async_ssh import AsyncSSHClient
3846
- from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
3413
+ Args:
3414
+ args: KernelBench evaluate arguments
3847
3415
 
3848
- REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
3416
+ Returns:
3417
+ Evaluation result
3418
+ """
3419
+ from .targets import get_default_target, load_target
3849
3420
 
3850
- # Select GPU
3851
- gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
3421
+ # Validate input files
3422
+ err = _validate_kernelbench_files(args)
3423
+ if err:
3424
+ return EvaluateResult(
3425
+ success=False,
3426
+ all_correct=False,
3427
+ correctness_score=0.0,
3428
+ geomean_speedup=0.0,
3429
+ passed_tests=0,
3430
+ total_tests=0,
3431
+ error_message=err,
3432
+ )
3852
3433
 
3853
- print(f"Provisioning RunPod ({target.gpu_type_id})...")
3434
+ # Load target
3435
+ target_name = args.target_name
3436
+ if not target_name:
3437
+ target_name = get_default_target()
3438
+ if not target_name:
3439
+ return EvaluateResult(
3440
+ success=False,
3441
+ all_correct=False,
3442
+ correctness_score=0.0,
3443
+ geomean_speedup=0.0,
3444
+ passed_tests=0,
3445
+ total_tests=0,
3446
+ error_message=(
3447
+ "No target specified and no default set.\n"
3448
+ "Set up a target first:\n"
3449
+ " wafer config targets init ssh --name my-gpu --host user@host:22\n"
3450
+ " wafer config targets init runpod --gpu MI300X\n"
3451
+ "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
3452
+ ),
3453
+ )
3854
3454
 
3855
3455
  try:
3856
- async with runpod_ssh_context(target) as ssh_info:
3857
- ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
3858
- print(f"Connected to RunPod: {ssh_target}")
3859
-
3860
- async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
3861
- # Create workspace
3862
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
3863
- run_dir = f"kernelbench_eval_{timestamp}"
3864
- run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
3456
+ target = load_target(target_name)
3457
+ except FileNotFoundError:
3458
+ return EvaluateResult(
3459
+ success=False,
3460
+ all_correct=False,
3461
+ correctness_score=0.0,
3462
+ geomean_speedup=0.0,
3463
+ passed_tests=0,
3464
+ total_tests=0,
3465
+ error_message=f"Target not found: {target_name}. Run: wafer config targets list",
3466
+ )
3865
3467
 
3866
- await client.exec(f"mkdir -p {run_path}")
3867
- print(f"Created run directory: {run_path}")
3868
-
3869
- # Read and upload files
3870
- impl_code = args.implementation.read_text()
3871
- ref_code = args.reference.read_text()
3872
-
3873
- # Write implementation
3874
- impl_path = f"{run_path}/implementation.py"
3875
- write_result = await client.exec(
3876
- f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
3877
- )
3878
- if write_result.exit_code != 0:
3879
- return EvaluateResult(
3880
- success=False,
3881
- all_correct=False,
3882
- correctness_score=0.0,
3883
- geomean_speedup=0.0,
3884
- passed_tests=0,
3885
- total_tests=0,
3886
- error_message=f"Failed to write implementation: {write_result.stderr}",
3887
- )
3888
-
3889
- # Write reference
3890
- ref_path = f"{run_path}/reference.py"
3891
- write_result = await client.exec(
3892
- f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
3893
- )
3894
- if write_result.exit_code != 0:
3895
- return EvaluateResult(
3896
- success=False,
3897
- all_correct=False,
3898
- correctness_score=0.0,
3899
- geomean_speedup=0.0,
3900
- passed_tests=0,
3901
- total_tests=0,
3902
- error_message=f"Failed to write reference: {write_result.stderr}",
3903
- )
3904
-
3905
- # Write custom inputs if provided
3906
- inputs_path = None
3907
- if args.inputs:
3908
- inputs_code = args.inputs.read_text()
3909
- inputs_path = f"{run_path}/custom_inputs.py"
3910
- write_result = await client.exec(
3911
- f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
3912
- )
3913
- if write_result.exit_code != 0:
3914
- return EvaluateResult(
3915
- success=False,
3916
- all_correct=False,
3917
- correctness_score=0.0,
3918
- geomean_speedup=0.0,
3919
- passed_tests=0,
3920
- total_tests=0,
3921
- error_message=f"Failed to write custom inputs: {write_result.stderr}",
3922
- )
3923
-
3924
- # Write eval script
3925
- eval_script_path = f"{run_path}/kernelbench_eval.py"
3926
- write_result = await client.exec(
3927
- f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
3928
- )
3929
- if write_result.exit_code != 0:
3930
- return EvaluateResult(
3931
- success=False,
3932
- all_correct=False,
3933
- correctness_score=0.0,
3934
- geomean_speedup=0.0,
3935
- passed_tests=0,
3936
- total_tests=0,
3937
- error_message=f"Failed to write eval script: {write_result.stderr}",
3938
- )
3939
-
3940
- # Write defense module if defensive mode is enabled
3941
- defense_module_path = None
3942
- if args.defensive:
3943
- defense_path = (
3944
- Path(__file__).parent.parent.parent.parent
3945
- / "packages"
3946
- / "wafer-core"
3947
- / "wafer_core"
3948
- / "utils"
3949
- / "kernel_utils"
3950
- / "defense.py"
3951
- )
3952
- if defense_path.exists():
3953
- defense_code = defense_path.read_text()
3954
- defense_module_path = f"{run_path}/defense.py"
3955
- write_result = await client.exec(
3956
- f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
3957
- )
3958
- if write_result.exit_code != 0:
3959
- print(f"Warning: Failed to write defense module: {write_result.stderr}")
3960
- defense_module_path = None
3961
- else:
3962
- print(f"Warning: defense.py not found at {defense_path}")
3963
-
3964
- print("Running KernelBench evaluation (AMD/ROCm)...")
3965
-
3966
- # Find Python with PyTorch - check common locations on RunPod
3967
- python_exe = "python3"
3968
- for candidate in [
3969
- "/opt/conda/envs/py_3.10/bin/python3",
3970
- "/opt/conda/bin/python3",
3971
- ]:
3972
- check = await client.exec(
3973
- f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
3974
- )
3975
- if "OK" in check.stdout:
3976
- python_exe = candidate
3977
- print(f"Using Python: {python_exe}")
3978
- break
3979
-
3980
- # Build eval command - run directly on host
3981
- output_path = f"{run_path}/results.json"
3982
- python_cmd_parts = [
3983
- f"{python_exe} {eval_script_path}",
3984
- f"--impl {impl_path}",
3985
- f"--reference {ref_path}",
3986
- f"--output {output_path}",
3987
- ]
3988
-
3989
- if args.benchmark:
3990
- python_cmd_parts.append("--benchmark")
3991
- if args.profile:
3992
- python_cmd_parts.append("--profile")
3993
- if inputs_path:
3994
- python_cmd_parts.append(f"--inputs {inputs_path}")
3995
- if args.defensive and defense_module_path:
3996
- python_cmd_parts.append("--defensive")
3997
- python_cmd_parts.append(f"--defense-module {defense_module_path}")
3998
- python_cmd_parts.append(f"--seed {args.seed}")
3999
-
4000
- eval_cmd = " ".join(python_cmd_parts)
4001
-
4002
- # Set environment for AMD GPU and run
4003
- env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
4004
- full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
4005
-
4006
- # Run and stream output
4007
- log_lines = []
4008
- async for line in client.exec_stream(full_cmd):
4009
- print(line)
4010
- log_lines.append(line)
4011
-
4012
- # Read results
4013
- cat_result = await client.exec(f"cat {output_path}")
4014
-
4015
- if cat_result.exit_code != 0:
4016
- log_tail = "\n".join(log_lines[-50:])
4017
- return EvaluateResult(
4018
- success=False,
4019
- all_correct=False,
4020
- correctness_score=0.0,
4021
- geomean_speedup=0.0,
4022
- passed_tests=0,
4023
- total_tests=0,
4024
- error_message=f"Evaluation failed. Log tail:\n{log_tail}",
4025
- )
4026
-
4027
- # Parse results
4028
- try:
4029
- results_data = json.loads(cat_result.stdout)
4030
- except json.JSONDecodeError as e:
4031
- return EvaluateResult(
4032
- success=False,
4033
- all_correct=False,
4034
- correctness_score=0.0,
4035
- geomean_speedup=0.0,
4036
- passed_tests=0,
4037
- total_tests=0,
4038
- error_message=f"Failed to parse results: {e}",
4039
- )
4040
-
4041
- # Convert to EvaluateResult
4042
- correct = results_data.get("correct", False)
4043
- speedup = results_data.get("speedup", 0.0) or 0.0
4044
- error = results_data.get("error")
4045
-
4046
- if error:
4047
- return EvaluateResult(
4048
- success=False,
4049
- all_correct=False,
4050
- correctness_score=0.0,
4051
- geomean_speedup=0.0,
4052
- passed_tests=0,
4053
- total_tests=1,
4054
- error_message=error,
4055
- )
4056
-
4057
- return EvaluateResult(
4058
- success=True,
4059
- all_correct=correct,
4060
- correctness_score=1.0 if correct else 0.0,
4061
- geomean_speedup=speedup,
4062
- passed_tests=1 if correct else 0,
4063
- total_tests=1,
4064
- )
4065
-
4066
- except RunPodError as e:
4067
- return EvaluateResult(
4068
- success=False,
4069
- all_correct=False,
4070
- correctness_score=0.0,
4071
- geomean_speedup=0.0,
4072
- passed_tests=0,
4073
- total_tests=0,
4074
- error_message=f"RunPod error: {e}",
4075
- )
4076
-
4077
-
4078
- async def run_evaluate_kernelbench_baremetal_amd(
4079
- args: KernelBenchEvaluateArgs,
4080
- target: BaremetalTarget,
4081
- ) -> EvaluateResult:
4082
- """Run KernelBench format evaluation directly on AMD baremetal target.
4083
-
4084
- Runs evaluation script directly on host (no Docker) for AMD GPUs
4085
- that have PyTorch/ROCm installed.
4086
- """
4087
- from datetime import datetime
4088
-
4089
- from wafer_core.async_ssh import AsyncSSHClient
4090
-
4091
- REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
4092
-
4093
- # Select GPU
4094
- gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
4095
-
4096
- print(f"Connecting to {target.ssh_target}...")
4097
-
4098
- async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
4099
- # Create workspace
4100
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
4101
- run_dir = f"kernelbench_eval_{timestamp}"
4102
- run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
4103
-
4104
- await client.exec(f"mkdir -p {run_path}")
4105
- print(f"Created run directory: {run_path}")
4106
-
4107
- # Read and upload files
4108
- impl_code = args.implementation.read_text()
4109
- ref_code = args.reference.read_text()
4110
-
4111
- # Write implementation
4112
- impl_path = f"{run_path}/implementation.py"
4113
- write_result = await client.exec(
4114
- f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
4115
- )
4116
- if write_result.exit_code != 0:
4117
- return EvaluateResult(
4118
- success=False,
4119
- all_correct=False,
4120
- correctness_score=0.0,
4121
- geomean_speedup=0.0,
4122
- passed_tests=0,
4123
- total_tests=0,
4124
- error_message=f"Failed to write implementation: {write_result.stderr}",
4125
- )
4126
-
4127
- # Write reference
4128
- ref_path = f"{run_path}/reference.py"
4129
- write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
4130
- if write_result.exit_code != 0:
4131
- return EvaluateResult(
4132
- success=False,
4133
- all_correct=False,
4134
- correctness_score=0.0,
4135
- geomean_speedup=0.0,
4136
- passed_tests=0,
4137
- total_tests=0,
4138
- error_message=f"Failed to write reference: {write_result.stderr}",
4139
- )
4140
-
4141
- # Write custom inputs if provided
4142
- inputs_path = None
4143
- if args.inputs:
4144
- inputs_code = args.inputs.read_text()
4145
- inputs_path = f"{run_path}/custom_inputs.py"
4146
- write_result = await client.exec(
4147
- f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
4148
- )
4149
- if write_result.exit_code != 0:
4150
- return EvaluateResult(
4151
- success=False,
4152
- all_correct=False,
4153
- correctness_score=0.0,
4154
- geomean_speedup=0.0,
4155
- passed_tests=0,
4156
- total_tests=0,
4157
- error_message=f"Failed to write custom inputs: {write_result.stderr}",
4158
- )
4159
-
4160
- # Write eval script
4161
- eval_script_path = f"{run_path}/kernelbench_eval.py"
4162
- write_result = await client.exec(
4163
- f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
4164
- )
4165
- if write_result.exit_code != 0:
4166
- return EvaluateResult(
4167
- success=False,
4168
- all_correct=False,
4169
- correctness_score=0.0,
4170
- geomean_speedup=0.0,
4171
- passed_tests=0,
4172
- total_tests=0,
4173
- error_message=f"Failed to write eval script: {write_result.stderr}",
4174
- )
4175
-
4176
- # Write defense module if defensive mode is enabled
4177
- defense_module_path = None
4178
- if args.defensive:
4179
- defense_path = (
4180
- Path(__file__).parent.parent.parent.parent
4181
- / "packages"
4182
- / "wafer-core"
4183
- / "wafer_core"
4184
- / "utils"
4185
- / "kernel_utils"
4186
- / "defense.py"
4187
- )
4188
- if defense_path.exists():
4189
- defense_code = defense_path.read_text()
4190
- defense_module_path = f"{run_path}/defense.py"
4191
- write_result = await client.exec(
4192
- f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
4193
- )
4194
- if write_result.exit_code != 0:
4195
- print(f"Warning: Failed to write defense module: {write_result.stderr}")
4196
- defense_module_path = None
4197
- else:
4198
- print(f"Warning: defense.py not found at {defense_path}")
4199
-
4200
- print("Running KernelBench evaluation (AMD/ROCm)...")
4201
-
4202
- # Find Python with PyTorch - check common locations
4203
- python_exe = "python3"
4204
- for candidate in [
4205
- "/opt/conda/envs/py_3.10/bin/python3",
4206
- "/opt/conda/bin/python3",
4207
- ]:
4208
- check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
4209
- if "OK" in check.stdout:
4210
- python_exe = candidate
4211
- print(f"Using Python: {python_exe}")
4212
- break
4213
-
4214
- # Build eval command - run directly on host
4215
- output_path = f"{run_path}/results.json"
4216
- python_cmd_parts = [
4217
- f"{python_exe} {eval_script_path}",
4218
- f"--impl {impl_path}",
4219
- f"--reference {ref_path}",
4220
- f"--output {output_path}",
4221
- ]
4222
-
4223
- if args.benchmark:
4224
- python_cmd_parts.append("--benchmark")
4225
- if args.profile:
4226
- python_cmd_parts.append("--profile")
4227
- if inputs_path:
4228
- python_cmd_parts.append(f"--inputs {inputs_path}")
4229
- if args.defensive and defense_module_path:
4230
- python_cmd_parts.append("--defensive")
4231
- python_cmd_parts.append(f"--defense-module {defense_module_path}")
4232
- python_cmd_parts.append(f"--seed {args.seed}")
4233
-
4234
- eval_cmd = " ".join(python_cmd_parts)
4235
-
4236
- # Set environment for AMD GPU and run
4237
- env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
4238
- full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
4239
-
4240
- # Run and stream output
4241
- log_lines = []
4242
- async for line in client.exec_stream(full_cmd):
4243
- print(line)
4244
- log_lines.append(line)
4245
-
4246
- # Read results
4247
- cat_result = await client.exec(f"cat {output_path}")
4248
-
4249
- if cat_result.exit_code != 0:
4250
- log_tail = "\n".join(log_lines[-50:])
4251
- return EvaluateResult(
4252
- success=False,
4253
- all_correct=False,
4254
- correctness_score=0.0,
4255
- geomean_speedup=0.0,
4256
- passed_tests=0,
4257
- total_tests=0,
4258
- error_message=f"Evaluation failed. Log tail:\n{log_tail}",
4259
- )
4260
-
4261
- # Parse results
4262
- try:
4263
- results_data = json.loads(cat_result.stdout)
4264
- except json.JSONDecodeError as e:
4265
- return EvaluateResult(
4266
- success=False,
4267
- all_correct=False,
4268
- correctness_score=0.0,
4269
- geomean_speedup=0.0,
4270
- passed_tests=0,
4271
- total_tests=0,
4272
- error_message=f"Failed to parse results: {e}",
4273
- )
4274
-
4275
- # Convert to EvaluateResult
4276
- correct = results_data.get("correct", False)
4277
- speedup = results_data.get("speedup", 0.0) or 0.0
4278
- error = results_data.get("error")
4279
-
4280
- if error:
4281
- return EvaluateResult(
4282
- success=False,
4283
- all_correct=False,
4284
- correctness_score=0.0,
4285
- geomean_speedup=0.0,
4286
- passed_tests=0,
4287
- total_tests=1,
4288
- error_message=error,
4289
- )
4290
-
4291
- return EvaluateResult(
4292
- success=True,
4293
- all_correct=correct,
4294
- correctness_score=1.0 if correct else 0.0,
4295
- geomean_speedup=speedup,
4296
- passed_tests=1 if correct else 0,
4297
- total_tests=1,
4298
- )
4299
-
4300
-
4301
- async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
4302
- """Run KernelBench format evaluation on configured target.
4303
-
4304
- Args:
4305
- args: KernelBench evaluate arguments
4306
-
4307
- Returns:
4308
- Evaluation result
4309
- """
4310
- from .targets import get_default_target, load_target
4311
-
4312
- # Validate input files
4313
- err = _validate_kernelbench_files(args)
4314
- if err:
4315
- return EvaluateResult(
4316
- success=False,
4317
- all_correct=False,
4318
- correctness_score=0.0,
4319
- geomean_speedup=0.0,
4320
- passed_tests=0,
4321
- total_tests=0,
4322
- error_message=err,
4323
- )
4324
-
4325
- # Load target
4326
- target_name = args.target_name
4327
- if not target_name:
4328
- target_name = get_default_target()
4329
- if not target_name:
4330
- return EvaluateResult(
4331
- success=False,
4332
- all_correct=False,
4333
- correctness_score=0.0,
4334
- geomean_speedup=0.0,
4335
- passed_tests=0,
4336
- total_tests=0,
4337
- error_message=(
4338
- "No target specified and no default set.\n"
4339
- "Set up a target first:\n"
4340
- " wafer config targets init ssh --name my-gpu --host user@host:22\n"
4341
- " wafer config targets init runpod --gpu MI300X\n"
4342
- "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
4343
- ),
4344
- )
4345
-
4346
- try:
4347
- target = load_target(target_name)
4348
- except FileNotFoundError:
4349
- return EvaluateResult(
4350
- success=False,
4351
- all_correct=False,
4352
- correctness_score=0.0,
4353
- geomean_speedup=0.0,
4354
- passed_tests=0,
4355
- total_tests=0,
4356
- error_message=f"Target not found: {target_name}. Run: wafer config targets list",
4357
- )
4358
-
4359
- print(f"Using target: {target_name}")
3468
+ print(f"Using target: {target_name}")
4360
3469
 
4361
3470
  # Dispatch to appropriate executor
4362
3471
  if isinstance(target, DigitalOceanTarget):
4363
3472
  # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
4364
3473
  return await run_evaluate_kernelbench_digitalocean(args, target)
4365
- elif isinstance(target, RunPodTarget):
4366
- # RunPod AMD MI300X - uses ROCm Docker with device passthrough
4367
- return await run_evaluate_kernelbench_runpod(args, target)
4368
3474
  elif isinstance(target, BaremetalTarget | VMTarget):
4369
- # Check if this is an AMD target (gfx* compute capability) - run directly
4370
- if target.compute_capability and target.compute_capability.startswith("gfx"):
4371
- return await run_evaluate_kernelbench_baremetal_amd(args, target)
4372
3475
  # NVIDIA targets - require docker_image to be set
4373
3476
  if not target.docker_image:
4374
3477
  return EvaluateResult(
@@ -4394,6 +3497,6 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
  total_tests=0,
  error_message=(
  f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
- "Use a DigitalOcean, RunPod, Baremetal, or VM target."
+ "Use a DigitalOcean, Baremetal, or VM target."
  ),
  )