wafer-cli 0.2.3-py3-none-any.whl → 0.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
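The headline change in 0.2.4 is how wafer-core reaches the remote evaluation host: the CLI no longer uploads the local wafer-core checkout and points PYTHONPATH at it, but installs wafer-core from the package index and invokes the evaluator as a module. A minimal sketch of the new flow, condensed from the diff below (the helper name and the bare `client`, `venv_bin`, and `run_path` parameters are illustrative assumptions, not code from the package):

    # Sketch only; condensed from the diff below, not verbatim package code.
    async def install_and_evaluate(client, venv_bin: str, run_path: str) -> None:
        # Install wafer-core into the remote venv (replaces the old rsync/SFTP upload).
        install = await client.exec(f"{venv_bin}/uv pip install wafer-core")
        if install.exit_code != 0:
            raise RuntimeError(f"Failed to install wafer-core: {install.stderr}")

        # Invoke the evaluator as an installed module (replaces PYTHONPATH + script path).
        result = await client.exec(
            f"cd {run_path} && python3 -m wafer_core.utils.kernel_utils.evaluate "
            f"--implementation implementation.py "
            f"--reference reference.py "
            f"--test-cases test_cases.json"
        )
        if result.exit_code != 0:
            raise RuntimeError(f"Evaluation failed: {result.stderr}")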
wafer/evaluate.py CHANGED
@@ -158,6 +158,8 @@ class KernelBenchEvaluateArgs:
  target_name: str
  benchmark: bool = False
  profile: bool = False
+ inputs: Path | None = None # Custom inputs file to override get_inputs()
+ seed: int = 42 # Random seed for reproducibility
  defensive: bool = False
  sync_artifacts: bool = True
  gpu_id: int | None = None
@@ -349,18 +351,6 @@ def _build_docker_pip_install_cmd(target: BaremetalTarget | VMTarget) -> str:
  return " && ".join(commands)


- def _get_wafer_root() -> Path:
- """Get wafer monorepo root directory.
-
- Walks up from this file to find the wafer repo root (contains apps/, packages/).
- """
- current = Path(__file__).resolve()
- for parent in [current] + list(current.parents):
- if (parent / "apps").is_dir() and (parent / "packages").is_dir():
- return parent
- raise RuntimeError(f"Could not find wafer root from {__file__}")
-
-
  async def run_evaluate_docker(
  args: EvaluateArgs,
  target: BaremetalTarget | VMTarget,
@@ -394,33 +384,6 @@ async def run_evaluate_docker(
  print(f"Connecting to {target.ssh_target}...")

  async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
- # Upload wafer-core to remote
- try:
- wafer_root = _get_wafer_root()
- wafer_core_path = wafer_root / "packages" / "wafer-core"
- print(f"Uploading wafer-core from {wafer_core_path}...")
-
- # Create workspace and upload
- workspace_name = wafer_core_path.name
- remote_workspace = f"{REMOTE_WORKSPACE_BASE}/{workspace_name}"
- await client.exec(f"mkdir -p {remote_workspace}")
- wafer_core_workspace = await client.expand_path(remote_workspace)
-
- upload_result = await client.upload_files(
- str(wafer_core_path), wafer_core_workspace, recursive=True
- )
- print(f"Uploaded {upload_result.files_copied} files")
- except Exception as e:
- return EvaluateResult(
- success=False,
- all_correct=False,
- correctness_score=0.0,
- geomean_speedup=0.0,
- passed_tests=0,
- total_tests=0,
- error_message=f"Failed to upload wafer-core: {e}",
- )
-
  print(f"Using Docker image: {target.docker_image}")
  print(f"Using GPU {gpu_id}...")

@@ -429,10 +392,13 @@ async def run_evaluate_docker(
  ref_code = args.reference.read_text()
  test_cases_data = json.loads(args.test_cases.read_text())

- # Create a unique run directory
+ # Create workspace for evaluation files
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  run_dir = f"wafer_eval_{timestamp}"
- run_path = f"{wafer_core_workspace}/{run_dir}"
+ eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
+ await client.exec(f"mkdir -p {eval_workspace}")
+ eval_workspace_expanded = await client.expand_path(eval_workspace)
+ run_path = f"{eval_workspace_expanded}/{run_dir}"

  print("Uploading evaluation files...")

@@ -519,17 +485,14 @@ async def run_evaluate_docker(
  container_impl_path = f"{container_run_path}/implementation.py"
  container_ref_path = f"{container_run_path}/reference.py"
  container_test_cases_path = f"{container_run_path}/test_cases.json"
- container_evaluate_script = (
- f"{CONTAINER_WORKSPACE}/wafer_core/utils/kernel_utils/evaluate.py"
- )

- # Build pip install command for torch and other deps (no wafer-core install needed)
+ # Build pip install command for torch and other deps, plus wafer-core
  pip_install_cmd = _build_docker_pip_install_cmd(target)
+ install_cmd = f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"

- # Build evaluate command - use PYTHONPATH instead of installing wafer-core
+ # Build evaluate command using installed wafer-core module
  python_cmd_parts = [
- f"PYTHONPATH={CONTAINER_WORKSPACE}:$PYTHONPATH",
- f"python3 {container_evaluate_script}",
+ "python3 -m wafer_core.utils.kernel_utils.evaluate",
  f"--implementation {container_impl_path}",
  f"--reference {container_ref_path}",
  f"--test-cases {container_test_cases_path}",
@@ -545,8 +508,8 @@ async def run_evaluate_docker(

  eval_cmd = " ".join(python_cmd_parts)

- # Full command: install torch deps, then run evaluate with PYTHONPATH
- full_cmd = f"{pip_install_cmd} && cd {container_run_path} && {eval_cmd}"
+ # Full command: install deps + wafer-core, then run evaluate
+ full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"

  # Build Docker run command
  # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
@@ -556,7 +519,7 @@ async def run_evaluate_docker(
  working_dir=container_run_path,
  env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
  gpus="all",
- volumes={wafer_core_workspace: CONTAINER_WORKSPACE},
+ volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
  cap_add=["SYS_ADMIN"] if args.profile else None,
  )

@@ -980,6 +943,7 @@ def _build_modal_sandbox_script(
  test_cases_b64: str,
  run_benchmarks: bool,
  run_defensive: bool,
+ defense_code_b64: str | None = None,
  ) -> str:
  """Build Python script to create sandbox and run evaluation.

@@ -1060,6 +1024,19 @@ print('Files written')
  print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
  return

+ # Write defense module if defensive mode is enabled
+ if {run_defensive} and "{defense_code_b64}":
+ proc = sandbox.exec("python", "-c", f"""
+ import base64
+ with open('/workspace/defense.py', 'w') as f:
+ f.write(base64.b64decode('{defense_code_b64}').decode())
+ print('Defense module written')
+ """)
+ proc.wait()
+ if proc.returncode != 0:
+ print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
+ return
+
  # Build inline evaluation script
  eval_script = """
  import json
@@ -1087,6 +1064,18 @@ generate_input = load_fn('reference.py', 'generate_input')

  import torch

+ # Load defense module if available and defensive mode is enabled
+ run_defensive = {run_defensive}
+ defense = None
+ if run_defensive:
+ try:
+ defense = load_fn('defense.py', 'run_all_defenses')
+ time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
+ print('[Defense] Defense module loaded')
+ except Exception as e:
+ print(f'[Defense] Warning: Could not load defense module: {{e}}')
+ defense = None
+
  results = []
  all_correct = True
  total_time_ms = 0.0
@@ -1114,36 +1103,63 @@ for tc in test_cases:
  impl_time_ms = 0.0
  ref_time_ms = 0.0
  if {run_benchmarks}:
- # Warmup
- for _ in range(3):
- custom_kernel(inputs)
- torch.cuda.synchronize()
-
- # Measure with defensive timing if requested
- # Defensive: sync before recording end event to catch stream injection
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
- start.record()
- for _ in range(10):
- custom_kernel(inputs)
- if {run_defensive}:
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
- end.record()
- torch.cuda.synchronize()
- impl_time_ms = start.elapsed_time(end) / 10
-
- # Reference timing (same defensive approach)
- for _ in range(3):
- ref_kernel(inputs)
- torch.cuda.synchronize()
- start.record()
- for _ in range(10):
- ref_kernel(inputs)
- if {run_defensive}:
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
- end.record()
- torch.cuda.synchronize()
- ref_time_ms = start.elapsed_time(end) / 10
+ if run_defensive and defense is not None:
+ # Use full defense suite
+ # Run defense checks on implementation kernel
+ all_passed, defense_results, _ = defense(
+ lambda: custom_kernel(inputs),
+ )
+ if not all_passed:
+ failed = [name for name, passed, _ in defense_results if not passed]
+ raise ValueError(f"Defense checks failed: {{failed}}")
+
+ # Time with defensive timing
+ impl_times, _ = time_with_defenses(
+ lambda: custom_kernel(inputs),
+ [],
+ num_warmup=3,
+ num_trials=10,
+ verbose=False,
+ run_defenses=False, # Already ran defenses above
+ )
+ impl_time_ms = sum(impl_times) / len(impl_times)
+
+ # Reference timing (no defense checks needed)
+ ref_times, _ = time_with_defenses(
+ lambda: ref_kernel(inputs),
+ [],
+ num_warmup=3,
+ num_trials=10,
+ verbose=False,
+ run_defenses=False,
+ )
+ ref_time_ms = sum(ref_times) / len(ref_times)
+ else:
+ # Standard timing without full defenses
+ # Warmup
+ for _ in range(3):
+ custom_kernel(inputs)
+ torch.cuda.synchronize()
+
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ for _ in range(10):
+ custom_kernel(inputs)
+ end.record()
+ torch.cuda.synchronize()
+ impl_time_ms = start.elapsed_time(end) / 10
+
+ # Reference timing
+ for _ in range(3):
+ ref_kernel(inputs)
+ torch.cuda.synchronize()
+ start.record()
+ for _ in range(10):
+ ref_kernel(inputs)
+ end.record()
+ torch.cuda.synchronize()
+ ref_time_ms = start.elapsed_time(end) / 10

  total_time_ms += impl_time_ms
  ref_total_time_ms += ref_time_ms
@@ -1236,6 +1252,23 @@ async def run_evaluate_modal(
  ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
  test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()

+ # Encode defense module if defensive mode is enabled
+ defense_code_b64 = None
+ if args.defensive:
+ defense_path = (
+ Path(__file__).parent.parent.parent.parent
+ / "packages"
+ / "wafer-core"
+ / "wafer_core"
+ / "utils"
+ / "kernel_utils"
+ / "defense.py"
+ )
+ if defense_path.exists():
+ defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
+ else:
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
  # Build the script that creates sandbox and runs eval
  script = _build_modal_sandbox_script(
  target=target,
@@ -1244,6 +1277,7 @@ async def run_evaluate_modal(
  test_cases_b64=test_cases_b64,
  run_benchmarks=args.benchmark,
  run_defensive=args.defensive,
+ defense_code_b64=defense_code_b64,
  )

  def _run_subprocess() -> tuple[str, str, int]:
@@ -1341,6 +1375,7 @@ def _build_workspace_eval_script(
  test_cases_json: str,
  run_benchmarks: bool,
  run_defensive: bool = False,
+ defense_code: str | None = None,
  ) -> str:
  """Build inline evaluation script for workspace exec.

@@ -1351,6 +1386,7 @@ def _build_workspace_eval_script(
  impl_b64 = base64.b64encode(impl_code.encode()).decode()
  ref_b64 = base64.b64encode(ref_code.encode()).decode()
  tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
+ defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""

  return f'''
  import base64
@@ -1370,6 +1406,14 @@ with open("/tmp/kernel.py", "w") as f:
  with open("/tmp/reference.py", "w") as f:
  f.write(ref_code)

+ # Write defense module if available
+ run_defensive = {run_defensive}
+ defense_b64 = "{defense_b64}"
+ if run_defensive and defense_b64:
+ defense_code = base64.b64decode(defense_b64).decode()
+ with open("/tmp/defense.py", "w") as f:
+ f.write(defense_code)
+
  # Load kernels
  def load_fn(path, name):
  spec = importlib.util.spec_from_file_location("mod", path)
@@ -1383,6 +1427,17 @@ generate_input = load_fn("/tmp/reference.py", "generate_input")

  import torch

+ # Load defense module if available
+ defense = None
+ if run_defensive and defense_b64:
+ try:
+ defense = load_fn("/tmp/defense.py", "run_all_defenses")
+ time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
+ print("[Defense] Defense module loaded")
+ except Exception as e:
+ print(f"[Defense] Warning: Could not load defense module: {{e}}")
+ defense = None
+
  results = []
  all_correct = True
  total_time_ms = 0.0
@@ -1410,36 +1465,60 @@ for tc in test_cases:
  impl_time_ms = 0.0
  ref_time_ms = 0.0
  if {run_benchmarks}:
- # Warmup
- for _ in range(3):
- custom_kernel(inputs)
- torch.cuda.synchronize()
-
- # Measure with defensive timing if requested
- # Defensive: sync before recording end event to catch stream injection
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
- start.record()
- for _ in range(10):
- custom_kernel(inputs)
- if {run_defensive}:
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
- end.record()
- torch.cuda.synchronize()
- impl_time_ms = start.elapsed_time(end) / 10
-
- # Reference timing (same defensive approach)
- for _ in range(3):
- ref_kernel(inputs)
- torch.cuda.synchronize()
- start.record()
- for _ in range(10):
- ref_kernel(inputs)
- if {run_defensive}:
- torch.cuda.synchronize() # DEFENSE: sync all streams before end
- end.record()
- torch.cuda.synchronize()
- ref_time_ms = start.elapsed_time(end) / 10
+ if run_defensive and defense is not None:
+ # Use full defense suite
+ all_passed, defense_results, _ = defense(
+ lambda: custom_kernel(inputs),
+ )
+ if not all_passed:
+ failed = [name for name, passed, _ in defense_results if not passed]
+ raise ValueError(f"Defense checks failed: {{failed}}")
+
+ # Time with defensive timing
+ impl_times, _ = time_with_defenses(
+ lambda: custom_kernel(inputs),
+ [],
+ num_warmup=3,
+ num_trials=10,
+ verbose=False,
+ run_defenses=False,
+ )
+ impl_time_ms = sum(impl_times) / len(impl_times)
+
+ # Reference timing
+ ref_times, _ = time_with_defenses(
+ lambda: ref_kernel(inputs),
+ [],
+ num_warmup=3,
+ num_trials=10,
+ verbose=False,
+ run_defenses=False,
+ )
+ ref_time_ms = sum(ref_times) / len(ref_times)
+ else:
+ # Standard timing
+ for _ in range(3):
+ custom_kernel(inputs)
+ torch.cuda.synchronize()
+
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ for _ in range(10):
+ custom_kernel(inputs)
+ end.record()
+ torch.cuda.synchronize()
+ impl_time_ms = start.elapsed_time(end) / 10
+
+ for _ in range(3):
+ ref_kernel(inputs)
+ torch.cuda.synchronize()
+ start.record()
+ for _ in range(10):
+ ref_kernel(inputs)
+ end.record()
+ torch.cuda.synchronize()
+ ref_time_ms = start.elapsed_time(end) / 10

  total_time_ms += impl_time_ms
  ref_total_time_ms += ref_time_ms
@@ -1501,6 +1580,23 @@ async def run_evaluate_workspace(
  ref_code = args.reference.read_text()
  test_cases_json = args.test_cases.read_text()

+ # Read defense module if defensive mode is enabled
+ defense_code = None
+ if args.defensive:
+ defense_path = (
+ Path(__file__).parent.parent.parent.parent
+ / "packages"
+ / "wafer-core"
+ / "wafer_core"
+ / "utils"
+ / "kernel_utils"
+ / "defense.py"
+ )
+ if defense_path.exists():
+ defense_code = defense_path.read_text()
+ else:
+ print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
  # Build inline eval script
  eval_script = _build_workspace_eval_script(
  impl_code=impl_code,
@@ -1508,6 +1604,7 @@ async def run_evaluate_workspace(
  test_cases_json=test_cases_json,
  run_benchmarks=args.benchmark,
  run_defensive=args.defensive,
+ defense_code=defense_code,
  )

  # Execute via workspace exec
@@ -1691,54 +1788,12 @@ async def run_evaluate_runpod(
  error_message=f"Failed to setup Python environment: {e}",
  )

- # Upload wafer-core to remote
- try:
- wafer_root = _get_wafer_root()
- wafer_core_path = wafer_root / "packages" / "wafer-core"
- print(f"Uploading wafer-core from {wafer_core_path}...")
-
- wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
- await client.exec(f"mkdir -p {wafer_core_remote}")
- wafer_core_workspace = await client.expand_path(wafer_core_remote)
-
- upload_result = await client.upload_files(
- str(wafer_core_path), wafer_core_workspace, recursive=True
- )
-
- # Wide event logging for upload result
- upload_event = {
- "event": "wafer_core_upload",
- "target": target.name,
- "target_type": "runpod",
- "ssh_host": f"{client.user}@{client.host}:{client.port}",
- "local_path": str(wafer_core_path),
- "remote_path": wafer_core_workspace,
- "success": upload_result.success,
- "files_copied": upload_result.files_copied,
- "duration_seconds": upload_result.duration_seconds,
- "error_message": upload_result.error_message,
- }
- if upload_result.debug_info:
- upload_event["debug_info"] = upload_result.debug_info
- logger.info(json.dumps(upload_event))
-
- # Fail fast if upload failed
- if not upload_result.success:
- print(f"ERROR: Upload failed: {upload_result.error_message}")
- if upload_result.debug_info:
- print(f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}")
- return EvaluateResult(
- success=False,
- all_correct=False,
- correctness_score=0.0,
- geomean_speedup=0.0,
- passed_tests=0,
- total_tests=0,
- error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
- )
-
- print(f"Uploaded {upload_result.files_copied} files")
- except Exception as e:
+ # Install wafer-core in remote venv
+ print("Installing wafer-core...")
+ install_result = await client.exec(
+ f"{env_state.venv_bin}/uv pip install wafer-core"
+ )
+ if install_result.exit_code != 0:
  return EvaluateResult(
  success=False,
  all_correct=False,
@@ -1746,7 +1801,7 @@ async def run_evaluate_runpod(
  geomean_speedup=0.0,
  passed_tests=0,
  total_tests=0,
- error_message=f"Failed to upload wafer-core: {e}",
+ error_message=f"Failed to install wafer-core: {install_result.stderr}",
  )

  # Select GPU (RunPod pods typically have GPU 0)
@@ -1853,15 +1908,12 @@ async def run_evaluate_runpod(
  # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
  venv_bin = env_state.venv_bin
  env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
- pythonpath = f"PYTHONPATH={wafer_core_workspace}"
- evaluate_script = (
- f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
- )

  # Run from run_path so reference_kernel.py is importable
+ # Use installed wafer-core module
  eval_cmd = (
  f"cd {run_path} && "
- f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
+ f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
  f"--implementation {impl_path} "
  f"--reference {ref_path} "
  f"--test-cases {test_cases_path} "
@@ -2046,61 +2098,12 @@ async def run_evaluate_digitalocean(
  error_message=f"Failed to setup Python environment: {e}",
  )

- # Upload wafer-core to remote
- try:
- wafer_root = _get_wafer_root()
- wafer_core_path = wafer_root / "packages" / "wafer-core"
- print(f"Uploading wafer-core from {wafer_core_path}...")
-
- wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
- await client.exec(f"mkdir -p {wafer_core_remote}")
- wafer_core_workspace = await client.expand_path(wafer_core_remote)
-
- # Use SFTP instead of rsync to avoid SSH subprocess timeout issues
- # (DigitalOcean may rate-limit new SSH connections)
- upload_result = await client.upload_files(
- str(wafer_core_path),
- wafer_core_workspace,
- recursive=True,
- use_sftp=True,
- )
-
- # Wide event logging for upload result
- upload_event = {
- "event": "wafer_core_upload",
- "target": target.name,
- "target_type": "digitalocean",
- "ssh_host": f"{client.user}@{client.host}:{client.port}",
- "local_path": str(wafer_core_path),
- "remote_path": wafer_core_workspace,
- "success": upload_result.success,
- "files_copied": upload_result.files_copied,
- "duration_seconds": upload_result.duration_seconds,
- "error_message": upload_result.error_message,
- }
- if upload_result.debug_info:
- upload_event["debug_info"] = upload_result.debug_info
- logger.info(json.dumps(upload_event))
-
- # Fail fast if upload failed
- if not upload_result.success:
- print(f"ERROR: Upload failed: {upload_result.error_message}")
- if upload_result.debug_info:
- print(
- f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}"
- )
- return EvaluateResult(
- success=False,
- all_correct=False,
- correctness_score=0.0,
- geomean_speedup=0.0,
- passed_tests=0,
- total_tests=0,
- error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
- )
-
- print(f"Uploaded {upload_result.files_copied} files")
- except Exception as e:
+ # Install wafer-core in remote venv
+ print("Installing wafer-core...")
+ install_result = await client.exec(
+ f"{env_state.venv_bin}/uv pip install wafer-core"
+ )
+ if install_result.exit_code != 0:
  return EvaluateResult(
  success=False,
  all_correct=False,
@@ -2108,7 +2111,7 @@ async def run_evaluate_digitalocean(
  geomean_speedup=0.0,
  passed_tests=0,
  total_tests=0,
- error_message=f"Failed to upload wafer-core: {e}",
+ error_message=f"Failed to install wafer-core: {install_result.stderr}",
  )

  # Select GPU (DigitalOcean droplets typically have GPU 0)
@@ -2217,15 +2220,12 @@ async def run_evaluate_digitalocean(
  env_vars = (
  f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
  )
- pythonpath = f"PYTHONPATH={wafer_core_workspace}"
- evaluate_script = (
- f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
- )

  # Run from run_path so reference_kernel.py is importable
+ # Use installed wafer-core module
  eval_cmd = (
  f"cd {run_path} && "
- f"{env_vars} {pythonpath} {python_exe} {evaluate_script} "
+ f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
  f"--implementation {impl_path} "
  f"--reference {ref_path} "
  f"--test-cases {test_cases_path} "
@@ -2435,10 +2435,233 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
  # This runs inside the Docker container on the remote GPU
  KERNELBENCH_EVAL_SCRIPT = """
  import json
+ import os
  import sys
  import time
  import torch
  import torch.nn as nn
+ from pathlib import Path
+
+
+ def run_profiling(model, inputs, name, output_dir):
+ '''Run torch.profiler and return summary stats.'''
+ from torch.profiler import profile, ProfilerActivity
+
+ # Determine activities based on backend
+ activities = [ProfilerActivity.CPU]
+ if torch.cuda.is_available():
+ activities.append(ProfilerActivity.CUDA)
+
+ # Warmup
+ for _ in range(3):
+ with torch.no_grad():
+ _ = model(*inputs)
+ torch.cuda.synchronize()
+
+ # Profile
+ with profile(
+ activities=activities,
+ record_shapes=True,
+ with_stack=False,
+ profile_memory=True,
+ ) as prof:
+ with torch.no_grad():
+ _ = model(*inputs)
+ torch.cuda.synchronize()
+
+ # Get key averages
+ key_averages = prof.key_averages()
+
+ # Find the main kernel (longest GPU time)
+ # Use cuda_time_total for compatibility with both CUDA and ROCm
+ def get_gpu_time(e):
+ # Try different attributes for GPU time
+ if hasattr(e, 'cuda_time_total'):
+ return e.cuda_time_total
+ if hasattr(e, 'device_time_total'):
+ return e.device_time_total
+ if hasattr(e, 'self_cuda_time_total'):
+ return e.self_cuda_time_total
+ return 0
+
+ gpu_events = [e for e in key_averages if get_gpu_time(e) > 0]
+ gpu_events.sort(key=lambda e: get_gpu_time(e), reverse=True)
+
+ stats = {
+ "name": name,
+ "total_gpu_time_ms": sum(get_gpu_time(e) for e in gpu_events) / 1000,
+ "total_cpu_time_ms": sum(e.cpu_time_total for e in key_averages) / 1000,
+ "num_gpu_kernels": len(gpu_events),
+ "top_kernels": [],
+ }
+
+ # Top 5 kernels by GPU time
+ for e in gpu_events[:5]:
+ stats["top_kernels"].append({
+ "name": e.key,
+ "gpu_time_ms": get_gpu_time(e) / 1000,
+ "cpu_time_ms": e.cpu_time_total / 1000,
+ "calls": e.count,
+ })
+
+ # Save trace for visualization
+ trace_path = Path(output_dir) / f"{name}_trace.json"
+ prof.export_chrome_trace(str(trace_path))
+ stats["trace_file"] = str(trace_path)
+
+ return stats
+
+
+ def validate_custom_inputs(original_inputs, custom_inputs):
+ '''Validate that custom inputs match the expected signature.
+
+ Returns (is_valid, error_message).
+ '''
+ if len(original_inputs) != len(custom_inputs):
+ return False, f"get_inputs() must return {len(original_inputs)} tensors, got {len(custom_inputs)}"
+
+ for i, (orig, cust) in enumerate(zip(original_inputs, custom_inputs)):
+ if not isinstance(cust, torch.Tensor):
+ if not isinstance(orig, torch.Tensor):
+ continue # Both non-tensor, ok
+ return False, f"Input {i}: expected Tensor, got {type(cust).__name__}"
+
+ if not isinstance(orig, torch.Tensor):
+ return False, f"Input {i}: expected {type(orig).__name__}, got Tensor"
+
+ if orig.dtype != cust.dtype:
+ return False, f"Input {i}: dtype mismatch - expected {orig.dtype}, got {cust.dtype}"
+
+ if orig.dim() != cust.dim():
+ return False, f"Input {i}: dimension mismatch - expected {orig.dim()}D, got {cust.dim()}D"
+
+ return True, None
+
+
+ def analyze_diff(ref_output, new_output, rtol=1e-3, atol=1e-3, max_samples=5):
+ '''Analyze differences between reference and implementation outputs.
+
+ Returns a dict with detailed diff information.
+ '''
+ diff = (ref_output - new_output).abs()
+ threshold = atol + rtol * ref_output.abs()
+ wrong_mask = diff > threshold
+
+ total_elements = ref_output.numel()
+ wrong_count = wrong_mask.sum().item()
+
+ # Basic stats
+ max_diff = diff.max().item()
+ max_diff_idx = tuple(torch.unravel_index(diff.argmax(), diff.shape))
+ max_diff_idx = tuple(int(i) for i in max_diff_idx) # Convert to Python ints
+
+ # Relative error (avoid div by zero)
+ ref_abs = ref_output.abs()
+ nonzero_mask = ref_abs > 1e-8
+ if nonzero_mask.any():
+ rel_error = diff[nonzero_mask] / ref_abs[nonzero_mask]
+ max_rel_error = rel_error.max().item()
+ mean_rel_error = rel_error.mean().item()
+ else:
+ max_rel_error = float('inf') if max_diff > 0 else 0.0
+ mean_rel_error = max_rel_error
+
+ # Error histogram (buckets: <1e-6, 1e-6 to 1e-4, 1e-4 to 1e-2, 1e-2 to 1, >1)
+ histogram = {
+ '<1e-6': int((diff < 1e-6).sum().item()),
+ '1e-6 to 1e-4': int(((diff >= 1e-6) & (diff < 1e-4)).sum().item()),
+ '1e-4 to 1e-2': int(((diff >= 1e-4) & (diff < 1e-2)).sum().item()),
+ '1e-2 to 1': int(((diff >= 1e-2) & (diff < 1)).sum().item()),
+ '>1': int((diff >= 1).sum().item()),
+ }
+
+ result = {
+ 'max_diff': max_diff,
+ 'max_diff_idx': max_diff_idx,
+ 'mean_diff': diff.mean().item(),
+ 'max_rel_error': max_rel_error,
+ 'mean_rel_error': mean_rel_error,
+ 'total_elements': total_elements,
+ 'wrong_count': int(wrong_count),
+ 'wrong_pct': 100.0 * wrong_count / total_elements,
+ 'histogram': histogram,
+ 'samples': [],
+ }
+
+ # Get indices of wrong elements
+ if wrong_count > 0:
+ wrong_indices = torch.nonzero(wrong_mask, as_tuple=False)
+
+ # Take first N samples
+ num_samples = min(max_samples, len(wrong_indices))
+ for i in range(num_samples):
+ idx = tuple(wrong_indices[i].tolist())
+ ref_val = ref_output[idx].item()
+ new_val = new_output[idx].item()
+ diff_val = diff[idx].item()
+ result['samples'].append({
+ 'index': idx,
+ 'ref': ref_val,
+ 'impl': new_val,
+ 'diff': diff_val,
+ })
+
+ # Try to detect pattern
+ if wrong_count >= total_elements * 0.99:
+ result['pattern'] = 'all_wrong'
+ elif wrong_count < total_elements * 0.01:
+ # Check if failures are at boundaries
+ shape = ref_output.shape
+ boundary_count = 0
+ for idx in wrong_indices[:min(100, len(wrong_indices))]:
+ idx_list = idx.tolist()
+ is_boundary = any(i == 0 or i == s - 1 for i, s in zip(idx_list, shape))
+ if is_boundary:
+ boundary_count += 1
+ if boundary_count > len(wrong_indices[:100]) * 0.8:
+ result['pattern'] = 'boundary_issue'
+ else:
+ result['pattern'] = 'scattered'
+ else:
+ result['pattern'] = 'partial'
+
+ return result
+
+
+ def print_diff_analysis(analysis):
+ '''Print a human-readable diff analysis.'''
+ print(f"[KernelBench] Diff analysis:")
+
+ # Max diff with location
+ idx_str = ','.join(str(i) for i in analysis['max_diff_idx'])
+ print(f" Max diff: {analysis['max_diff']:.6f} at index [{idx_str}]")
+ print(f" Mean diff: {analysis['mean_diff']:.6f}")
+
+ # Relative errors
+ print(f" Max relative error: {analysis['max_rel_error']:.2%}, Mean: {analysis['mean_rel_error']:.2%}")
+
+ # Wrong count
+ print(f" Wrong elements: {analysis['wrong_count']:,} / {analysis['total_elements']:,} ({analysis['wrong_pct']:.2f}%)")
+
+ # Histogram
+ hist = analysis['histogram']
+ print(f" Error distribution: <1e-6: {hist['<1e-6']:,} | 1e-6~1e-4: {hist['1e-6 to 1e-4']:,} | 1e-4~1e-2: {hist['1e-4 to 1e-2']:,} | 1e-2~1: {hist['1e-2 to 1']:,} | >1: {hist['>1']:,}")
+
+ if 'pattern' in analysis:
+ pattern_desc = {
+ 'all_wrong': 'ALL elements wrong - likely algorithmic error or wrong weights',
+ 'boundary_issue': 'Mostly BOUNDARY elements wrong - check edge handling',
+ 'scattered': 'SCATTERED failures - numerical precision issue?',
+ 'partial': 'PARTIAL failures - check specific conditions',
+ }
+ print(f" Pattern: {pattern_desc.get(analysis['pattern'], analysis['pattern'])}")
+
+ if analysis['samples']:
+ print(f" Sample failures:")
+ for s in analysis['samples']:
+ idx_str = ','.join(str(i) for i in s['index'])
+ print(f" [{idx_str}]: ref={s['ref']:.6f} impl={s['impl']:.6f} (diff={s['diff']:.6f})")
+

  def main():
  # Parse args
@@ -2446,12 +2669,35 @@ def main():
  parser = argparse.ArgumentParser()
  parser.add_argument("--impl", required=True)
  parser.add_argument("--reference", required=True)
+ parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
  parser.add_argument("--benchmark", action="store_true")
+ parser.add_argument("--profile", action="store_true")
+ parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
+ parser.add_argument("--defense-module", help="Path to defense.py module")
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
  parser.add_argument("--num-correct-trials", type=int, default=3)
  parser.add_argument("--num-perf-trials", type=int, default=10)
  parser.add_argument("--output", required=True)
  args = parser.parse_args()

+ # Load defense module if defensive mode is enabled
+ defense_module = None
+ if args.defensive and args.defense_module:
+ try:
+ import importlib.util
+ defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
+ defense_module = importlib.util.module_from_spec(defense_spec)
+ defense_spec.loader.exec_module(defense_module)
+ print("[KernelBench] Defense module loaded")
+ except Exception as e:
+ print(f"[KernelBench] Warning: Could not load defense module: {e}")
+
+ # Create output directory for profiles
+ output_dir = Path(args.output).parent
+ profile_dir = output_dir / "profiles"
+ if args.profile:
+ profile_dir.mkdir(exist_ok=True)
+
  results = {
  "compiled": False,
  "correct": False,
@@ -2472,6 +2718,33 @@ def main():
  get_inputs = ref_module.get_inputs
  get_init_inputs = ref_module.get_init_inputs

+ # Load custom inputs if provided
+ if args.inputs:
+ inputs_spec = importlib.util.spec_from_file_location("custom_inputs", args.inputs)
+ inputs_module = importlib.util.module_from_spec(inputs_spec)
+ inputs_spec.loader.exec_module(inputs_module)
+
+ # Validate custom inputs match expected signature
+ original_inputs = get_inputs()
+ custom_get_inputs = inputs_module.get_inputs
+ custom_inputs = custom_get_inputs()
+
+ is_valid, error_msg = validate_custom_inputs(original_inputs, custom_inputs)
+ if not is_valid:
+ print(f"[KernelBench] Custom inputs validation failed: {error_msg}")
+ results["error"] = f"Custom inputs validation failed: {error_msg}"
+ raise ValueError(error_msg)
+
+ # Override get_inputs (and optionally get_init_inputs)
+ get_inputs = custom_get_inputs
+ if hasattr(inputs_module, 'get_init_inputs'):
+ get_init_inputs = inputs_module.get_init_inputs
+
+ # Show what changed
+ orig_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in original_inputs]
+ cust_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in custom_inputs]
+ print(f"[KernelBench] Using custom inputs: {orig_shapes} -> {cust_shapes}")
+
  # Load implementation module
  impl_spec = importlib.util.spec_from_file_location("implementation", args.impl)
  impl_module = importlib.util.module_from_spec(impl_spec)
@@ -2481,12 +2754,19 @@ def main():
  results["compiled"] = True
  print("[KernelBench] Modules loaded successfully")

- # Instantiate models
+ # Instantiate models with synchronized seeds for reproducible weights
+ # (matches upstream KernelBench behavior in src/eval.py)
+ seed = args.seed
  init_inputs = get_init_inputs()
  with torch.no_grad():
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
  ref_model = Model(*init_inputs).cuda().eval()
+
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
  new_model = ModelNew(*init_inputs).cuda().eval()
- print("[KernelBench] Models instantiated")
+ print(f"[KernelBench] Models instantiated (seed={seed})")

  # Run correctness trials
  all_correct = True
@@ -2502,8 +2782,18 @@ def main():
  if isinstance(ref_output, torch.Tensor):
  if not torch.allclose(ref_output, new_output, rtol=1e-3, atol=1e-3):
  all_correct = False
- max_diff = (ref_output - new_output).abs().max().item()
- results["error"] = f"Correctness failed on trial {trial+1}: max diff = {max_diff}"
+ analysis = analyze_diff(ref_output, new_output)
+ results["error"] = f"Correctness failed on trial {trial+1}: max diff = {analysis['max_diff']}"
+ results["diff_analysis"] = analysis
+ print_diff_analysis(analysis)
+
+ # Save tensors for debugging
+ debug_dir = output_dir / "debug"
+ debug_dir.mkdir(exist_ok=True)
+ torch.save(ref_output.cpu(), debug_dir / "ref_output.pt")
+ torch.save(new_output.cpu(), debug_dir / "impl_output.pt")
+ torch.save(inputs[0].cpu() if inputs else None, debug_dir / "input.pt")
+ print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
  break
  else:
  # Handle tuple/list outputs
@@ -2511,8 +2801,17 @@ def main():
  if isinstance(r, torch.Tensor):
  if not torch.allclose(r, n, rtol=1e-3, atol=1e-3):
  all_correct = False
- max_diff = (r - n).abs().max().item()
- results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {max_diff}"
+ analysis = analyze_diff(r, n)
+ results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {analysis['max_diff']}"
+ results["diff_analysis"] = analysis
+ print_diff_analysis(analysis)
+
+ # Save tensors for debugging
+ debug_dir = output_dir / "debug"
+ debug_dir.mkdir(exist_ok=True)
+ torch.save(r.cpu(), debug_dir / f"ref_output_{i}.pt")
+ torch.save(n.cpu(), debug_dir / f"impl_output_{i}.pt")
+ print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
  break
  if not all_correct:
  break
@@ -2526,47 +2825,132 @@ def main():
  inputs = get_inputs()
  inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]

- # Warmup
- for _ in range(5):
- with torch.no_grad():
- _ = new_model(*inputs)
- torch.cuda.synchronize()
+ if args.defensive and defense_module is not None:
+ # Use full defense suite
+ print("[KernelBench] Running defense checks on implementation...")
+ run_all_defenses = defense_module.run_all_defenses
+ time_with_defenses = defense_module.time_execution_with_defenses

- # Benchmark new model
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
-
- times = []
- for _ in range(args.num_perf_trials):
- start.record()
- with torch.no_grad():
- _ = new_model(*inputs)
- end.record()
+ # Run defense checks on implementation
+ all_passed, defense_results, _ = run_all_defenses(
+ lambda *x: new_model(*x),
+ *inputs,
+ )
+ results["defense_results"] = {
+ name: {"passed": passed, "message": msg}
+ for name, passed, msg in defense_results
+ }
+ if not all_passed:
+ failed = [name for name, passed, _ in defense_results if not passed]
+ results["error"] = f"Defense checks failed: {failed}"
+ print(f"[KernelBench] Defense checks FAILED: {failed}")
+ for name, passed, msg in defense_results:
+ status = "PASS" if passed else "FAIL"
+ print(f" [{status}] {name}: {msg}")
+ else:
+ print("[KernelBench] All defense checks passed")
+
+ # Time with defensive timing
+ impl_times, _ = time_with_defenses(
+ lambda: new_model(*inputs),
+ [],
+ num_warmup=5,
+ num_trials=args.num_perf_trials,
+ verbose=False,
+ run_defenses=False, # Already ran above
+ )
+ new_time = sum(impl_times) / len(impl_times)
+ results["runtime_ms"] = new_time
+
+ # Reference timing
+ ref_times, _ = time_with_defenses(
+ lambda: ref_model(*inputs),
+ [],
+ num_warmup=5,
+ num_trials=args.num_perf_trials,
+ verbose=False,
+ run_defenses=False,
+ )
+ ref_time = sum(ref_times) / len(ref_times)
+ results["reference_runtime_ms"] = ref_time
+ results["speedup"] = ref_time / new_time if new_time > 0 else 0
+ print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+ else:
+ # Standard timing without full defenses
+ # Warmup
+ for _ in range(5):
+ with torch.no_grad():
+ _ = new_model(*inputs)
  torch.cuda.synchronize()
- times.append(start.elapsed_time(end))

- new_time = sum(times) / len(times)
- results["runtime_ms"] = new_time
-
- # Benchmark reference model
- for _ in range(5):
- with torch.no_grad():
- _ = ref_model(*inputs)
- torch.cuda.synchronize()
-
- times = []
- for _ in range(args.num_perf_trials):
- start.record()
- with torch.no_grad():
- _ = ref_model(*inputs)
- end.record()
+ # Benchmark new model
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ times = []
+ for _ in range(args.num_perf_trials):
+ start.record()
+ with torch.no_grad():
+ _ = new_model(*inputs)
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+
+ new_time = sum(times) / len(times)
+ results["runtime_ms"] = new_time
+
+ # Benchmark reference model
+ for _ in range(5):
+ with torch.no_grad():
+ _ = ref_model(*inputs)
  torch.cuda.synchronize()
- times.append(start.elapsed_time(end))

- ref_time = sum(times) / len(times)
- results["reference_runtime_ms"] = ref_time
- results["speedup"] = ref_time / new_time if new_time > 0 else 0
- print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+ times = []
+ for _ in range(args.num_perf_trials):
+ start.record()
+ with torch.no_grad():
+ _ = ref_model(*inputs)
+ end.record()
+ torch.cuda.synchronize()
+ times.append(start.elapsed_time(end))
+
+ ref_time = sum(times) / len(times)
+ results["reference_runtime_ms"] = ref_time
+ results["speedup"] = ref_time / new_time if new_time > 0 else 0
+ print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+
+ # Run profiling if requested and correctness passed
+ if args.profile and all_correct:
+ print("[KernelBench] Running profiler...")
+ inputs = get_inputs()
+ inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
+
+ try:
+ # Profile implementation
+ impl_stats = run_profiling(new_model, inputs, "implementation", str(profile_dir))
+ results["profile_impl"] = impl_stats
+ print(f"[KernelBench] Implementation profile:")
+ print(f" Total GPU time: {impl_stats['total_gpu_time_ms']:.3f}ms")
+ print(f" Kernels launched: {impl_stats['num_gpu_kernels']}")
+ if impl_stats['top_kernels']:
+ print(f" Top kernel: {impl_stats['top_kernels'][0]['name'][:60]}...")
+ print(f" {impl_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+ # Profile reference
+ ref_stats = run_profiling(ref_model, inputs, "reference", str(profile_dir))
+ results["profile_ref"] = ref_stats
+ print(f"[KernelBench] Reference profile:")
+ print(f" Total GPU time: {ref_stats['total_gpu_time_ms']:.3f}ms")
+ print(f" Kernels launched: {ref_stats['num_gpu_kernels']}")
+ if ref_stats['top_kernels']:
+ print(f" Top kernel: {ref_stats['top_kernels'][0]['name'][:60]}...")
+ print(f" {ref_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+ print(f"[KernelBench] Profile traces saved to: {profile_dir}/")
+
+ except Exception as prof_err:
+ print(f"[KernelBench] Profiling failed: {prof_err}")
+ results["profile_error"] = str(prof_err)

  except Exception as e:
  import traceback
@@ -2705,6 +3089,24 @@ async def run_evaluate_kernelbench_docker(
  error_message=f"Failed to write reference: {write_result.stderr}",
  )

+ # Write custom inputs if provided
+ if args.inputs:
+ inputs_code = args.inputs.read_text()
+ inputs_file_path = f"{run_path}/custom_inputs.py"
+ write_result = await client.exec(
+ f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+ )
+ if write_result.exit_code != 0:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
+ )
+
  # Write eval script
  eval_script_path = f"{run_path}/kernelbench_eval.py"
  write_result = await client.exec(
@@ -2721,14 +3123,40 @@ async def run_evaluate_kernelbench_docker(
  error_message=f"Failed to write eval script: {write_result.stderr}",
  )

+ # Write defense module if defensive mode is enabled
+ defense_module_path = None
+ if args.defensive:
+ defense_path = (
+ Path(__file__).parent.parent.parent.parent
+ / "packages"
+ / "wafer-core"
+ / "wafer_core"
+ / "utils"
+ / "kernel_utils"
+ / "defense.py"
+ )
+ if defense_path.exists():
+ defense_code = defense_path.read_text()
+ defense_module_path = f"{run_path}/defense.py"
+ write_result = await client.exec(
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+ )
+ if write_result.exit_code != 0:
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
+ defense_module_path = None
+ else:
+ print(f"Warning: defense.py not found at {defense_path}")
+
  print("Running KernelBench evaluation in Docker container...")

  # Paths inside container
  container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
  container_impl_path = f"{container_run_path}/implementation.py"
  container_ref_path = f"{container_run_path}/reference.py"
+ container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
  container_eval_script = f"{container_run_path}/kernelbench_eval.py"
  container_output = f"{container_run_path}/results.json"
+ container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None

  # Build eval command
  python_cmd_parts = [
@@ -2740,6 +3168,14 @@ async def run_evaluate_kernelbench_docker(

  if args.benchmark:
  python_cmd_parts.append("--benchmark")
+ if args.profile:
+ python_cmd_parts.append("--profile")
+ if container_inputs_path:
+ python_cmd_parts.append(f"--inputs {container_inputs_path}")
+ if args.defensive and container_defense_path:
+ python_cmd_parts.append("--defensive")
+ python_cmd_parts.append(f"--defense-module {container_defense_path}")
+ python_cmd_parts.append(f"--seed {args.seed}")

  eval_cmd = " ".join(python_cmd_parts)

@@ -2920,6 +3356,24 @@ async def run_evaluate_kernelbench_digitalocean(
  error_message=f"Failed to write reference: {write_result.stderr}",
  )

+ # Write custom inputs if provided
+ if args.inputs:
+ inputs_code = args.inputs.read_text()
+ inputs_file_path = f"{run_path}/custom_inputs.py"
+ write_result = await client.exec(
+ f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+ )
+ if write_result.exit_code != 0:
+ return EvaluateResult(
+ success=False,
+ all_correct=False,
+ correctness_score=0.0,
+ geomean_speedup=0.0,
+ passed_tests=0,
+ total_tests=0,
+ error_message=f"Failed to write custom inputs: {write_result.stderr}",
+ )
+
  # Write eval script
  eval_script_path = f"{run_path}/kernelbench_eval.py"
  write_result = await client.exec(
@@ -2936,14 +3390,44 @@ async def run_evaluate_kernelbench_digitalocean(
  error_message=f"Failed to write eval script: {write_result.stderr}",
  )

+ # Write defense module if defensive mode is enabled
+ defense_module_path = None
+ if args.defensive:
+ defense_path = (
+ Path(__file__).parent.parent.parent.parent
+ / "packages"
+ / "wafer-core"
+ / "wafer_core"
+ / "utils"
+ / "kernel_utils"
+ / "defense.py"
+ )
+ if defense_path.exists():
+ defense_code = defense_path.read_text()
+ defense_module_path = f"{run_path}/defense.py"
+ write_result = await client.exec(
+ f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+ )
+ if write_result.exit_code != 0:
+ print(f"Warning: Failed to write defense module: {write_result.stderr}")
+ defense_module_path = None
+ else:
+ print(f"Warning: defense.py not found at {defense_path}")
+
  print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")

  # Paths inside container
  container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
  container_impl_path = f"{container_run_path}/implementation.py"
  container_ref_path = f"{container_run_path}/reference.py"
+ container_inputs_path = (
+ f"{container_run_path}/custom_inputs.py" if args.inputs else None
+ )
  container_eval_script = f"{container_run_path}/kernelbench_eval.py"
  container_output = f"{container_run_path}/results.json"
+ container_defense_path = (
+ f"{container_run_path}/defense.py" if defense_module_path else None
+ )

  # Build eval command
  python_cmd_parts = [
@@ -2955,6 +3439,14 @@ async def run_evaluate_kernelbench_digitalocean(

  if args.benchmark:
  python_cmd_parts.append("--benchmark")
+ if args.profile:
+ python_cmd_parts.append("--profile")
+ if container_inputs_path:
+ python_cmd_parts.append(f"--inputs {container_inputs_path}")
+ if args.defensive and container_defense_path:
+ python_cmd_parts.append("--defensive")
+ python_cmd_parts.append(f"--defense-module {container_defense_path}")
+ python_cmd_parts.append(f"--seed {args.seed}")

  eval_cmd = " ".join(python_cmd_parts)
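
For reference, the flags this release adds to the inline KernelBench eval script compose like this. A hypothetical container-side invocation; the flag names come from the argparse definitions in the diff above, while the file names are illustrative:

    # Hypothetical invocation of the generated kernelbench_eval.py inside the container.
    python3 kernelbench_eval.py \
        --impl implementation.py \
        --reference reference.py \
        --inputs custom_inputs.py \
        --seed 42 \
        --benchmark \
        --profile \
        --defensive --defense-module defense.py \
        --output results.json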