wafer-cli 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/cli.py +479 -18
- wafer/evaluate.py +760 -268
- wafer/gpu_run.py +5 -1
- wafer/problems.py +357 -0
- wafer/wevin_cli.py +22 -2
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.4.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.4.dist-info}/RECORD +10 -9
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.4.dist-info}/WHEEL +1 -1
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.4.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.4.dist-info}/top_level.txt +0 -0
wafer/evaluate.py
CHANGED
@@ -158,6 +158,8 @@ class KernelBenchEvaluateArgs:
     target_name: str
     benchmark: bool = False
     profile: bool = False
+    inputs: Path | None = None  # Custom inputs file to override get_inputs()
+    seed: int = 42  # Random seed for reproducibility
     defensive: bool = False
     sync_artifacts: bool = True
     gpu_id: int | None = None
@@ -349,18 +351,6 @@ def _build_docker_pip_install_cmd(target: BaremetalTarget | VMTarget) -> str:
     return " && ".join(commands)
 
 
-def _get_wafer_root() -> Path:
-    """Get wafer monorepo root directory.
-
-    Walks up from this file to find the wafer repo root (contains apps/, packages/).
-    """
-    current = Path(__file__).resolve()
-    for parent in [current] + list(current.parents):
-        if (parent / "apps").is_dir() and (parent / "packages").is_dir():
-            return parent
-    raise RuntimeError(f"Could not find wafer root from {__file__}")
-
-
 async def run_evaluate_docker(
     args: EvaluateArgs,
     target: BaremetalTarget | VMTarget,
@@ -394,33 +384,6 @@ async def run_evaluate_docker(
     print(f"Connecting to {target.ssh_target}...")
 
     async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
-        # Upload wafer-core to remote
-        try:
-            wafer_root = _get_wafer_root()
-            wafer_core_path = wafer_root / "packages" / "wafer-core"
-            print(f"Uploading wafer-core from {wafer_core_path}...")
-
-            # Create workspace and upload
-            workspace_name = wafer_core_path.name
-            remote_workspace = f"{REMOTE_WORKSPACE_BASE}/{workspace_name}"
-            await client.exec(f"mkdir -p {remote_workspace}")
-            wafer_core_workspace = await client.expand_path(remote_workspace)
-
-            upload_result = await client.upload_files(
-                str(wafer_core_path), wafer_core_workspace, recursive=True
-            )
-            print(f"Uploaded {upload_result.files_copied} files")
-        except Exception as e:
-            return EvaluateResult(
-                success=False,
-                all_correct=False,
-                correctness_score=0.0,
-                geomean_speedup=0.0,
-                passed_tests=0,
-                total_tests=0,
-                error_message=f"Failed to upload wafer-core: {e}",
-            )
-
         print(f"Using Docker image: {target.docker_image}")
         print(f"Using GPU {gpu_id}...")
 
@@ -429,10 +392,13 @@ async def run_evaluate_docker(
         ref_code = args.reference.read_text()
         test_cases_data = json.loads(args.test_cases.read_text())
 
-        # Create
+        # Create workspace for evaluation files
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         run_dir = f"wafer_eval_{timestamp}"
-
+        eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
+        await client.exec(f"mkdir -p {eval_workspace}")
+        eval_workspace_expanded = await client.expand_path(eval_workspace)
+        run_path = f"{eval_workspace_expanded}/{run_dir}"
 
         print("Uploading evaluation files...")
 
@@ -519,17 +485,14 @@ async def run_evaluate_docker(
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
         container_test_cases_path = f"{container_run_path}/test_cases.json"
-        container_evaluate_script = (
-            f"{CONTAINER_WORKSPACE}/wafer_core/utils/kernel_utils/evaluate.py"
-        )
 
-        # Build pip install command for torch and other deps
+        # Build pip install command for torch and other deps, plus wafer-core
         pip_install_cmd = _build_docker_pip_install_cmd(target)
+        install_cmd = f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
 
-        # Build evaluate command
+        # Build evaluate command using installed wafer-core module
         python_cmd_parts = [
-
-            f"python3 {container_evaluate_script}",
+            "python3 -m wafer_core.utils.kernel_utils.evaluate",
             f"--implementation {container_impl_path}",
             f"--reference {container_ref_path}",
             f"--test-cases {container_test_cases_path}",
@@ -545,8 +508,8 @@ async def run_evaluate_docker(
 
         eval_cmd = " ".join(python_cmd_parts)
 
-        # Full command: install
-        full_cmd = f"{
+        # Full command: install deps + wafer-core, then run evaluate
+        full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
 
         # Build Docker run command
         # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
@@ -556,7 +519,7 @@ async def run_evaluate_docker(
             working_dir=container_run_path,
             env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
             gpus="all",
-            volumes={
+            volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
             cap_add=["SYS_ADMIN"] if args.profile else None,
         )
 
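Note: the rewritten full_cmd above chains the dependency install, a `uv pip install` of wafer-core, and the module-style evaluate invocation into one shell string that runs inside the container, with the eval workspace mounted as a volume. As a rough illustration only (this is not the CLI's actual command builder; the image name, paths and flags below are placeholders), an equivalent `docker run` assembled in Python could look like:

    import shlex

    # Hypothetical stand-ins for target.docker_image, the eval workspace, CONTAINER_WORKSPACE, etc.
    image = "pytorch/pytorch:latest"
    host_workspace = "/root/wafer_workspaces/eval_20240101_000000"
    container_workspace = "/workspace"
    gpu_id = 0
    install_cmd = "pip install -r requirements.txt && uv pip install --system wafer-core"
    eval_cmd = "python3 -m wafer_core.utils.kernel_utils.evaluate --implementation implementation.py"

    docker_cmd = [
        "docker", "run", "--rm",
        "--gpus", "all",                                 # expose GPUs to the container
        "-e", f"CUDA_VISIBLE_DEVICES={gpu_id}",          # pin the evaluation to one GPU
        "-e", "PYTHONUNBUFFERED=1",
        "-v", f"{host_workspace}:{container_workspace}",  # mount the uploaded eval files
        "-w", f"{container_workspace}/run",
        image,
        "bash", "-c", f"{install_cmd} && {eval_cmd}",     # install, then evaluate, in one shot
    ]
    print(" ".join(shlex.quote(part) for part in docker_cmd))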
@@ -980,6 +943,7 @@ def _build_modal_sandbox_script(
     test_cases_b64: str,
     run_benchmarks: bool,
     run_defensive: bool,
+    defense_code_b64: str | None = None,
 ) -> str:
     """Build Python script to create sandbox and run evaluation.
 
@@ -1060,6 +1024,19 @@ print('Files written')
        print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
        return
 
+    # Write defense module if defensive mode is enabled
+    if {run_defensive} and "{defense_code_b64}":
+        proc = sandbox.exec("python", "-c", f"""
+import base64
+with open('/workspace/defense.py', 'w') as f:
+    f.write(base64.b64decode('{defense_code_b64}').decode())
+print('Defense module written')
+""")
+        proc.wait()
+        if proc.returncode != 0:
+            print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
+            return
+
     # Build inline evaluation script
     eval_script = """
 import json
@@ -1087,6 +1064,18 @@ generate_input = load_fn('reference.py', 'generate_input')
 
 import torch
 
+# Load defense module if available and defensive mode is enabled
+run_defensive = {run_defensive}
+defense = None
+if run_defensive:
+    try:
+        defense = load_fn('defense.py', 'run_all_defenses')
+        time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
+        print('[Defense] Defense module loaded')
+    except Exception as e:
+        print(f'[Defense] Warning: Could not load defense module: {{e}}')
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
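Note: the pattern introduced above — base64-encoding defense.py on the host, splicing the encoded string into a generated script, and decoding it back to a file inside the sandbox — sidesteps quoting and escaping problems when source code is shipped through another script. A minimal standalone sketch of the round trip (stand-in module text, not the wafer helpers themselves):

    import base64

    source = "def run_all_defenses(fn):\n    return True, [], None\n"  # stand-in module text

    # Host side: encode the module so it can be embedded into a generated script as a literal.
    encoded = base64.b64encode(source.encode()).decode()

    # Remote side (inside the generated script): decode the literal back to the original text.
    decoded = base64.b64decode(encoded).decode()
    assert decoded == source  # the round trip is lossless, regardless of quotes or newlines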
@@ -1114,36 +1103,63 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
[… 30 removed lines whose old content is not rendered in this diff view …]
+        if run_defensive and defense is not None:
+            # Use full defense suite
+            # Run defense checks on implementation kernel
+            all_passed, defense_results, _ = defense(
+                lambda: custom_kernel(inputs),
+            )
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing
+            impl_times, _ = time_with_defenses(
+                lambda: custom_kernel(inputs),
+                [],
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,  # Already ran defenses above
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            # Reference timing (no defense checks needed)
+            ref_times, _ = time_with_defenses(
+                lambda: ref_kernel(inputs),
+                [],
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing without full defenses
+            # Warmup
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            # Reference timing
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
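Note: the non-defensive branch above times kernels with CUDA events rather than wall-clock time, so the measurement covers GPU work between the two recorded events. A minimal, self-contained sketch of that pattern (any callable kernel and inputs; the 3 warmup / 10 timed iterations mirror the counts used above):

    import torch

    def time_kernel_ms(kernel, *args, warmup: int = 3, iters: int = 10) -> float:
        """Average per-call GPU time in milliseconds using CUDA events."""
        for _ in range(warmup):            # warm up caches, JIT, autotuning
            kernel(*args)
        torch.cuda.synchronize()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            kernel(*args)
        end.record()
        torch.cuda.synchronize()           # wait until the end event has actually occurred
        return start.elapsed_time(end) / iters  # elapsed_time() returns milliseconds

    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")
        print(time_kernel_ms(torch.matmul, x, x))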
@@ -1236,6 +1252,23 @@ async def run_evaluate_modal(
     ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
     test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()
 
+    # Encode defense module if defensive mode is enabled
+    defense_code_b64 = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build the script that creates sandbox and runs eval
     script = _build_modal_sandbox_script(
         target=target,
@@ -1244,6 +1277,7 @@ async def run_evaluate_modal(
         test_cases_b64=test_cases_b64,
         run_benchmarks=args.benchmark,
         run_defensive=args.defensive,
+        defense_code_b64=defense_code_b64,
     )
 
     def _run_subprocess() -> tuple[str, str, int]:
@@ -1341,6 +1375,7 @@ def _build_workspace_eval_script(
     test_cases_json: str,
     run_benchmarks: bool,
     run_defensive: bool = False,
+    defense_code: str | None = None,
 ) -> str:
     """Build inline evaluation script for workspace exec.
 
@@ -1351,6 +1386,7 @@ def _build_workspace_eval_script(
     impl_b64 = base64.b64encode(impl_code.encode()).decode()
     ref_b64 = base64.b64encode(ref_code.encode()).decode()
     tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
+    defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""
 
     return f'''
 import base64
@@ -1370,6 +1406,14 @@ with open("/tmp/kernel.py", "w") as f:
 with open("/tmp/reference.py", "w") as f:
     f.write(ref_code)
 
+# Write defense module if available
+run_defensive = {run_defensive}
+defense_b64 = "{defense_b64}"
+if run_defensive and defense_b64:
+    defense_code = base64.b64decode(defense_b64).decode()
+    with open("/tmp/defense.py", "w") as f:
+        f.write(defense_code)
+
 # Load kernels
 def load_fn(path, name):
     spec = importlib.util.spec_from_file_location("mod", path)
@@ -1383,6 +1427,17 @@ generate_input = load_fn("/tmp/reference.py", "generate_input")
 
 import torch
 
+# Load defense module if available
+defense = None
+if run_defensive and defense_b64:
+    try:
+        defense = load_fn("/tmp/defense.py", "run_all_defenses")
+        time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
+        print("[Defense] Defense module loaded")
+    except Exception as e:
+        print(f"[Defense] Warning: Could not load defense module: {{e}}")
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
@@ -1410,36 +1465,60 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
[… 30 removed lines whose old content is not rendered in this diff view …]
+        if run_defensive and defense is not None:
+            # Use full defense suite
+            all_passed, defense_results, _ = defense(
+                lambda: custom_kernel(inputs),
+            )
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing
+            impl_times, _ = time_with_defenses(
+                lambda: custom_kernel(inputs),
+                [],
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            # Reference timing
+            ref_times, _ = time_with_defenses(
+                lambda: ref_kernel(inputs),
+                [],
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
@@ -1501,6 +1580,23 @@ async def run_evaluate_workspace(
     ref_code = args.reference.read_text()
     test_cases_json = args.test_cases.read_text()
 
+    # Read defense module if defensive mode is enabled
+    defense_code = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code = defense_path.read_text()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build inline eval script
     eval_script = _build_workspace_eval_script(
         impl_code=impl_code,
@@ -1508,6 +1604,7 @@ async def run_evaluate_workspace(
         test_cases_json=test_cases_json,
         run_benchmarks=args.benchmark,
         run_defensive=args.defensive,
+        defense_code=defense_code,
     )
 
     # Execute via workspace exec
@@ -1691,54 +1788,12 @@ async def run_evaluate_runpod(
             error_message=f"Failed to setup Python environment: {e}",
         )
 
-    #
[… 5 removed lines whose old content is not rendered in this diff view …]
-        wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
-        await client.exec(f"mkdir -p {wafer_core_remote}")
-        wafer_core_workspace = await client.expand_path(wafer_core_remote)
-
-        upload_result = await client.upload_files(
-            str(wafer_core_path), wafer_core_workspace, recursive=True
-        )
-
-        # Wide event logging for upload result
-        upload_event = {
-            "event": "wafer_core_upload",
-            "target": target.name,
-            "target_type": "runpod",
-            "ssh_host": f"{client.user}@{client.host}:{client.port}",
-            "local_path": str(wafer_core_path),
-            "remote_path": wafer_core_workspace,
-            "success": upload_result.success,
-            "files_copied": upload_result.files_copied,
-            "duration_seconds": upload_result.duration_seconds,
-            "error_message": upload_result.error_message,
-        }
-        if upload_result.debug_info:
-            upload_event["debug_info"] = upload_result.debug_info
-        logger.info(json.dumps(upload_event))
-
-        # Fail fast if upload failed
-        if not upload_result.success:
-            print(f"ERROR: Upload failed: {upload_result.error_message}")
-            if upload_result.debug_info:
-                print(f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}")
-            return EvaluateResult(
-                success=False,
-                all_correct=False,
-                correctness_score=0.0,
-                geomean_speedup=0.0,
-                passed_tests=0,
-                total_tests=0,
-                error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
-            )
-
-        print(f"Uploaded {upload_result.files_copied} files")
-    except Exception as e:
+    # Install wafer-core in remote venv
+    print("Installing wafer-core...")
+    install_result = await client.exec(
+        f"{env_state.venv_bin}/uv pip install wafer-core"
+    )
+    if install_result.exit_code != 0:
         return EvaluateResult(
             success=False,
             all_correct=False,
@@ -1746,7 +1801,7 @@ async def run_evaluate_runpod(
             geomean_speedup=0.0,
             passed_tests=0,
             total_tests=0,
-            error_message=f"Failed to
+            error_message=f"Failed to install wafer-core: {install_result.stderr}",
         )
 
     # Select GPU (RunPod pods typically have GPU 0)
@@ -1853,15 +1908,12 @@ async def run_evaluate_runpod(
     # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
     venv_bin = env_state.venv_bin
     env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
-    pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-    evaluate_script = (
-        f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-    )
 
     # Run from run_path so reference_kernel.py is importable
+    # Use installed wafer-core module
    eval_cmd = (
         f"cd {run_path} && "
-        f"{env_vars} {
+        f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
         f"--implementation {impl_path} "
         f"--reference {ref_path} "
         f"--test-cases {test_cases_path} "
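Note: the RunPod (and, below, DigitalOcean) paths stop uploading the wafer-core source tree and pointing PYTHONPATH at it; instead they pip-install the package into the remote venv and invoke it with `python -m`, so imports resolve from site-packages. A hedged sketch of that general pattern (placeholder venv path and run directory, not the exact commands the CLI emits):

    # Sketch: install a distribution into a venv and run one of its modules with `-m`,
    # so the import path comes from the installed package rather than an uploaded checkout.
    venv_bin = "/root/venv/bin"                          # placeholder venv location
    package = "wafer-core"                               # distribution name on the index
    module = "wafer_core.utils.kernel_utils.evaluate"    # runnable module inside it

    install_cmd = f"{venv_bin}/uv pip install {package}"
    run_cmd = (
        f"cd /root/run && "
        f"{venv_bin}/python -m {module} --implementation implementation.py"
    )

    # Over SSH these would be executed as shell strings, e.g.:
    #   await client.exec(install_cmd)
    #   await client.exec(run_cmd)
    print(install_cmd)
    print(run_cmd)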
@@ -2046,61 +2098,12 @@ async def run_evaluate_digitalocean(
             error_message=f"Failed to setup Python environment: {e}",
         )
 
-    #
[… 5 removed lines whose old content is not rendered in this diff view …]
-        wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
-        await client.exec(f"mkdir -p {wafer_core_remote}")
-        wafer_core_workspace = await client.expand_path(wafer_core_remote)
-
-        # Use SFTP instead of rsync to avoid SSH subprocess timeout issues
-        # (DigitalOcean may rate-limit new SSH connections)
-        upload_result = await client.upload_files(
-            str(wafer_core_path),
-            wafer_core_workspace,
-            recursive=True,
-            use_sftp=True,
-        )
-
-        # Wide event logging for upload result
-        upload_event = {
-            "event": "wafer_core_upload",
-            "target": target.name,
-            "target_type": "digitalocean",
-            "ssh_host": f"{client.user}@{client.host}:{client.port}",
-            "local_path": str(wafer_core_path),
-            "remote_path": wafer_core_workspace,
-            "success": upload_result.success,
-            "files_copied": upload_result.files_copied,
-            "duration_seconds": upload_result.duration_seconds,
-            "error_message": upload_result.error_message,
-        }
-        if upload_result.debug_info:
-            upload_event["debug_info"] = upload_result.debug_info
-        logger.info(json.dumps(upload_event))
-
-        # Fail fast if upload failed
-        if not upload_result.success:
-            print(f"ERROR: Upload failed: {upload_result.error_message}")
-            if upload_result.debug_info:
-                print(
-                    f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}"
-                )
-            return EvaluateResult(
-                success=False,
-                all_correct=False,
-                correctness_score=0.0,
-                geomean_speedup=0.0,
-                passed_tests=0,
-                total_tests=0,
-                error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
-            )
-
-        print(f"Uploaded {upload_result.files_copied} files")
-    except Exception as e:
+    # Install wafer-core in remote venv
+    print("Installing wafer-core...")
+    install_result = await client.exec(
+        f"{env_state.venv_bin}/uv pip install wafer-core"
+    )
+    if install_result.exit_code != 0:
         return EvaluateResult(
             success=False,
             all_correct=False,
@@ -2108,7 +2111,7 @@ async def run_evaluate_digitalocean(
             geomean_speedup=0.0,
             passed_tests=0,
             total_tests=0,
-            error_message=f"Failed to
+            error_message=f"Failed to install wafer-core: {install_result.stderr}",
         )
 
     # Select GPU (DigitalOcean droplets typically have GPU 0)
@@ -2217,15 +2220,12 @@ async def run_evaluate_digitalocean(
     env_vars = (
         f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
     )
-    pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-    evaluate_script = (
-        f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-    )
 
     # Run from run_path so reference_kernel.py is importable
+    # Use installed wafer-core module
     eval_cmd = (
         f"cd {run_path} && "
-        f"{env_vars} {
+        f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
         f"--implementation {impl_path} "
         f"--reference {ref_path} "
         f"--test-cases {test_cases_path} "
@@ -2435,10 +2435,233 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
 # This runs inside the Docker container on the remote GPU
 KERNELBENCH_EVAL_SCRIPT = """
 import json
+import os
 import sys
 import time
 import torch
 import torch.nn as nn
+from pathlib import Path
+
+
+def run_profiling(model, inputs, name, output_dir):
+    '''Run torch.profiler and return summary stats.'''
+    from torch.profiler import profile, ProfilerActivity
+
+    # Determine activities based on backend
+    activities = [ProfilerActivity.CPU]
+    if torch.cuda.is_available():
+        activities.append(ProfilerActivity.CUDA)
+
+    # Warmup
+    for _ in range(3):
+        with torch.no_grad():
+            _ = model(*inputs)
+    torch.cuda.synchronize()
+
+    # Profile
+    with profile(
+        activities=activities,
+        record_shapes=True,
+        with_stack=False,
+        profile_memory=True,
+    ) as prof:
+        with torch.no_grad():
+            _ = model(*inputs)
+        torch.cuda.synchronize()
+
+    # Get key averages
+    key_averages = prof.key_averages()
+
+    # Find the main kernel (longest GPU time)
+    # Use cuda_time_total for compatibility with both CUDA and ROCm
+    def get_gpu_time(e):
+        # Try different attributes for GPU time
+        if hasattr(e, 'cuda_time_total'):
+            return e.cuda_time_total
+        if hasattr(e, 'device_time_total'):
+            return e.device_time_total
+        if hasattr(e, 'self_cuda_time_total'):
+            return e.self_cuda_time_total
+        return 0
+
+    gpu_events = [e for e in key_averages if get_gpu_time(e) > 0]
+    gpu_events.sort(key=lambda e: get_gpu_time(e), reverse=True)
+
+    stats = {
+        "name": name,
+        "total_gpu_time_ms": sum(get_gpu_time(e) for e in gpu_events) / 1000,
+        "total_cpu_time_ms": sum(e.cpu_time_total for e in key_averages) / 1000,
+        "num_gpu_kernels": len(gpu_events),
+        "top_kernels": [],
+    }
+
+    # Top 5 kernels by GPU time
+    for e in gpu_events[:5]:
+        stats["top_kernels"].append({
+            "name": e.key,
+            "gpu_time_ms": get_gpu_time(e) / 1000,
+            "cpu_time_ms": e.cpu_time_total / 1000,
+            "calls": e.count,
+        })
+
+    # Save trace for visualization
+    trace_path = Path(output_dir) / f"{name}_trace.json"
+    prof.export_chrome_trace(str(trace_path))
+    stats["trace_file"] = str(trace_path)
+
+    return stats
+
+
+def validate_custom_inputs(original_inputs, custom_inputs):
+    '''Validate that custom inputs match the expected signature.
+
+    Returns (is_valid, error_message).
+    '''
+    if len(original_inputs) != len(custom_inputs):
+        return False, f"get_inputs() must return {len(original_inputs)} tensors, got {len(custom_inputs)}"
+
+    for i, (orig, cust) in enumerate(zip(original_inputs, custom_inputs)):
+        if not isinstance(cust, torch.Tensor):
+            if not isinstance(orig, torch.Tensor):
+                continue  # Both non-tensor, ok
+            return False, f"Input {i}: expected Tensor, got {type(cust).__name__}"
+
+        if not isinstance(orig, torch.Tensor):
+            return False, f"Input {i}: expected {type(orig).__name__}, got Tensor"
+
+        if orig.dtype != cust.dtype:
+            return False, f"Input {i}: dtype mismatch - expected {orig.dtype}, got {cust.dtype}"
+
+        if orig.dim() != cust.dim():
+            return False, f"Input {i}: dimension mismatch - expected {orig.dim()}D, got {cust.dim()}D"
+
+    return True, None
+
+
+def analyze_diff(ref_output, new_output, rtol=1e-3, atol=1e-3, max_samples=5):
+    '''Analyze differences between reference and implementation outputs.
+
+    Returns a dict with detailed diff information.
+    '''
+    diff = (ref_output - new_output).abs()
+    threshold = atol + rtol * ref_output.abs()
+    wrong_mask = diff > threshold
+
+    total_elements = ref_output.numel()
+    wrong_count = wrong_mask.sum().item()
+
+    # Basic stats
+    max_diff = diff.max().item()
+    max_diff_idx = tuple(torch.unravel_index(diff.argmax(), diff.shape))
+    max_diff_idx = tuple(int(i) for i in max_diff_idx)  # Convert to Python ints
+
+    # Relative error (avoid div by zero)
+    ref_abs = ref_output.abs()
+    nonzero_mask = ref_abs > 1e-8
+    if nonzero_mask.any():
+        rel_error = diff[nonzero_mask] / ref_abs[nonzero_mask]
+        max_rel_error = rel_error.max().item()
+        mean_rel_error = rel_error.mean().item()
+    else:
+        max_rel_error = float('inf') if max_diff > 0 else 0.0
+        mean_rel_error = max_rel_error
+
+    # Error histogram (buckets: <1e-6, 1e-6 to 1e-4, 1e-4 to 1e-2, 1e-2 to 1, >1)
+    histogram = {
+        '<1e-6': int((diff < 1e-6).sum().item()),
+        '1e-6 to 1e-4': int(((diff >= 1e-6) & (diff < 1e-4)).sum().item()),
+        '1e-4 to 1e-2': int(((diff >= 1e-4) & (diff < 1e-2)).sum().item()),
+        '1e-2 to 1': int(((diff >= 1e-2) & (diff < 1)).sum().item()),
+        '>1': int((diff >= 1).sum().item()),
+    }
+
+    result = {
+        'max_diff': max_diff,
+        'max_diff_idx': max_diff_idx,
+        'mean_diff': diff.mean().item(),
+        'max_rel_error': max_rel_error,
+        'mean_rel_error': mean_rel_error,
+        'total_elements': total_elements,
+        'wrong_count': int(wrong_count),
+        'wrong_pct': 100.0 * wrong_count / total_elements,
+        'histogram': histogram,
+        'samples': [],
+    }
+
+    # Get indices of wrong elements
+    if wrong_count > 0:
+        wrong_indices = torch.nonzero(wrong_mask, as_tuple=False)
+
+        # Take first N samples
+        num_samples = min(max_samples, len(wrong_indices))
+        for i in range(num_samples):
+            idx = tuple(wrong_indices[i].tolist())
+            ref_val = ref_output[idx].item()
+            new_val = new_output[idx].item()
+            diff_val = diff[idx].item()
+            result['samples'].append({
+                'index': idx,
+                'ref': ref_val,
+                'impl': new_val,
+                'diff': diff_val,
+            })
+
+        # Try to detect pattern
+        if wrong_count >= total_elements * 0.99:
+            result['pattern'] = 'all_wrong'
+        elif wrong_count < total_elements * 0.01:
+            # Check if failures are at boundaries
+            shape = ref_output.shape
+            boundary_count = 0
+            for idx in wrong_indices[:min(100, len(wrong_indices))]:
+                idx_list = idx.tolist()
+                is_boundary = any(i == 0 or i == s - 1 for i, s in zip(idx_list, shape))
+                if is_boundary:
+                    boundary_count += 1
+            if boundary_count > len(wrong_indices[:100]) * 0.8:
+                result['pattern'] = 'boundary_issue'
+            else:
+                result['pattern'] = 'scattered'
+        else:
+            result['pattern'] = 'partial'
+
+    return result
+
+
+def print_diff_analysis(analysis):
+    '''Print a human-readable diff analysis.'''
+    print(f"[KernelBench] Diff analysis:")
+
+    # Max diff with location
+    idx_str = ','.join(str(i) for i in analysis['max_diff_idx'])
+    print(f"  Max diff: {analysis['max_diff']:.6f} at index [{idx_str}]")
+    print(f"  Mean diff: {analysis['mean_diff']:.6f}")
+
+    # Relative errors
+    print(f"  Max relative error: {analysis['max_rel_error']:.2%}, Mean: {analysis['mean_rel_error']:.2%}")
+
+    # Wrong count
+    print(f"  Wrong elements: {analysis['wrong_count']:,} / {analysis['total_elements']:,} ({analysis['wrong_pct']:.2f}%)")
+
+    # Histogram
+    hist = analysis['histogram']
+    print(f"  Error distribution: <1e-6: {hist['<1e-6']:,} | 1e-6~1e-4: {hist['1e-6 to 1e-4']:,} | 1e-4~1e-2: {hist['1e-4 to 1e-2']:,} | 1e-2~1: {hist['1e-2 to 1']:,} | >1: {hist['>1']:,}")
+
+    if 'pattern' in analysis:
+        pattern_desc = {
+            'all_wrong': 'ALL elements wrong - likely algorithmic error or wrong weights',
+            'boundary_issue': 'Mostly BOUNDARY elements wrong - check edge handling',
+            'scattered': 'SCATTERED failures - numerical precision issue?',
+            'partial': 'PARTIAL failures - check specific conditions',
+        }
+        print(f"  Pattern: {pattern_desc.get(analysis['pattern'], analysis['pattern'])}")
+
+    if analysis['samples']:
+        print(f"  Sample failures:")
+        for s in analysis['samples']:
+            idx_str = ','.join(str(i) for i in s['index'])
+            print(f"    [{idx_str}]: ref={s['ref']:.6f} impl={s['impl']:.6f} (diff={s['diff']:.6f})")
+
 
 def main():
     # Parse args
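Note: analyze_diff above flags an element as wrong when |ref − impl| exceeds atol + rtol·|ref|, essentially the mixed absolute/relative criterion that torch.allclose applies (allclose scales by the magnitude of its second argument, the helper by the reference tensor). A tiny self-contained check of that criterion:

    import torch

    def count_mismatches(ref: torch.Tensor, impl: torch.Tensor,
                         rtol: float = 1e-3, atol: float = 1e-3) -> int:
        """Elements where |ref - impl| > atol + rtol * |ref| (the helper's criterion)."""
        diff = (ref - impl).abs()
        threshold = atol + rtol * ref.abs()
        return int((diff > threshold).sum().item())

    ref = torch.tensor([1.0, 100.0, 0.0])
    impl = torch.tensor([1.0005, 100.2, 0.01])
    # element 0: 5e-4 <= 1e-3 + 1e-3*1.0    -> ok
    # element 1: 0.2   > 1e-3 + 1e-3*100.0  -> mismatch
    # element 2: 0.01  > 1e-3 + 1e-3*0.0    -> mismatch
    print(count_mismatches(ref, impl))  # 2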
@@ -2446,12 +2669,35 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--impl", required=True)
     parser.add_argument("--reference", required=True)
+    parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
     parser.add_argument("--benchmark", action="store_true")
+    parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
+    parser.add_argument("--defense-module", help="Path to defense.py module")
+    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
     parser.add_argument("--num-correct-trials", type=int, default=3)
     parser.add_argument("--num-perf-trials", type=int, default=10)
     parser.add_argument("--output", required=True)
     args = parser.parse_args()
 
+    # Load defense module if defensive mode is enabled
+    defense_module = None
+    if args.defensive and args.defense_module:
+        try:
+            import importlib.util
+            defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
+            defense_module = importlib.util.module_from_spec(defense_spec)
+            defense_spec.loader.exec_module(defense_module)
+            print("[KernelBench] Defense module loaded")
+        except Exception as e:
+            print(f"[KernelBench] Warning: Could not load defense module: {e}")
+
+    # Create output directory for profiles
+    output_dir = Path(args.output).parent
+    profile_dir = output_dir / "profiles"
+    if args.profile:
+        profile_dir.mkdir(exist_ok=True)
+
     results = {
         "compiled": False,
         "correct": False,
@@ -2472,6 +2718,33 @@ def main():
         get_inputs = ref_module.get_inputs
         get_init_inputs = ref_module.get_init_inputs
 
+        # Load custom inputs if provided
+        if args.inputs:
+            inputs_spec = importlib.util.spec_from_file_location("custom_inputs", args.inputs)
+            inputs_module = importlib.util.module_from_spec(inputs_spec)
+            inputs_spec.loader.exec_module(inputs_module)
+
+            # Validate custom inputs match expected signature
+            original_inputs = get_inputs()
+            custom_get_inputs = inputs_module.get_inputs
+            custom_inputs = custom_get_inputs()
+
+            is_valid, error_msg = validate_custom_inputs(original_inputs, custom_inputs)
+            if not is_valid:
+                print(f"[KernelBench] Custom inputs validation failed: {error_msg}")
+                results["error"] = f"Custom inputs validation failed: {error_msg}"
+                raise ValueError(error_msg)
+
+            # Override get_inputs (and optionally get_init_inputs)
+            get_inputs = custom_get_inputs
+            if hasattr(inputs_module, 'get_init_inputs'):
+                get_init_inputs = inputs_module.get_init_inputs
+
+            # Show what changed
+            orig_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in original_inputs]
+            cust_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in custom_inputs]
+            print(f"[KernelBench] Using custom inputs: {orig_shapes} -> {cust_shapes}")
+
         # Load implementation module
         impl_spec = importlib.util.spec_from_file_location("implementation", args.impl)
         impl_module = importlib.util.module_from_spec(impl_spec)
@@ -2481,12 +2754,19 @@ def main():
         results["compiled"] = True
         print("[KernelBench] Modules loaded successfully")
 
-        # Instantiate models
+        # Instantiate models with synchronized seeds for reproducible weights
+        # (matches upstream KernelBench behavior in src/eval.py)
+        seed = args.seed
         init_inputs = get_init_inputs()
         with torch.no_grad():
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
             ref_model = Model(*init_inputs).cuda().eval()
+
+            torch.manual_seed(seed)
+            torch.cuda.manual_seed(seed)
             new_model = ModelNew(*init_inputs).cuda().eval()
-        print("[KernelBench] Models instantiated")
+        print(f"[KernelBench] Models instantiated (seed={seed})")
 
         # Run correctness trials
         all_correct = True
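Note: re-seeding immediately before each constructor is what makes the reference Model and the candidate ModelNew start from identical randomly initialized weights, so a correctness failure cannot be an artifact of different initializations. A minimal illustration of the idea (generic nn.Linear modules, not the KernelBench classes):

    import torch
    import torch.nn as nn

    def make_linear(seed: int) -> nn.Linear:
        torch.manual_seed(seed)           # controls the CPU RNG used for parameter init
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)  # and the CUDA RNG, if init happens on the GPU
        return nn.Linear(16, 16)

    a = make_linear(42)
    b = make_linear(42)
    # Same seed before each construction -> identical initial weights.
    print(torch.equal(a.weight, b.weight))  # True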
@@ -2502,8 +2782,18 @@ def main():
             if isinstance(ref_output, torch.Tensor):
                 if not torch.allclose(ref_output, new_output, rtol=1e-3, atol=1e-3):
                     all_correct = False
-
-                    results["error"] = f"Correctness failed on trial {trial+1}: max diff = {max_diff}"
+                    analysis = analyze_diff(ref_output, new_output)
+                    results["error"] = f"Correctness failed on trial {trial+1}: max diff = {analysis['max_diff']}"
+                    results["diff_analysis"] = analysis
+                    print_diff_analysis(analysis)
+
+                    # Save tensors for debugging
+                    debug_dir = output_dir / "debug"
+                    debug_dir.mkdir(exist_ok=True)
+                    torch.save(ref_output.cpu(), debug_dir / "ref_output.pt")
+                    torch.save(new_output.cpu(), debug_dir / "impl_output.pt")
+                    torch.save(inputs[0].cpu() if inputs else None, debug_dir / "input.pt")
+                    print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
                     break
             else:
                 # Handle tuple/list outputs
@@ -2511,8 +2801,17 @@ def main():
                     if isinstance(r, torch.Tensor):
                         if not torch.allclose(r, n, rtol=1e-3, atol=1e-3):
                             all_correct = False
-
-                            results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {max_diff}"
+                            analysis = analyze_diff(r, n)
+                            results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {analysis['max_diff']}"
+                            results["diff_analysis"] = analysis
+                            print_diff_analysis(analysis)
+
+                            # Save tensors for debugging
+                            debug_dir = output_dir / "debug"
+                            debug_dir.mkdir(exist_ok=True)
+                            torch.save(r.cpu(), debug_dir / f"ref_output_{i}.pt")
+                            torch.save(n.cpu(), debug_dir / f"impl_output_{i}.pt")
+                            print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
                             break
                 if not all_correct:
                     break
@@ -2526,47 +2825,132 @@ def main():
         inputs = get_inputs()
         inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
 
[… 5 removed lines whose old content is not rendered in this diff view …]
+        if args.defensive and defense_module is not None:
+            # Use full defense suite
+            print("[KernelBench] Running defense checks on implementation...")
+            run_all_defenses = defense_module.run_all_defenses
+            time_with_defenses = defense_module.time_execution_with_defenses
 
[… 10 removed lines whose old content is not rendered in this diff view …]
+            # Run defense checks on implementation
+            all_passed, defense_results, _ = run_all_defenses(
+                lambda *x: new_model(*x),
+                *inputs,
+            )
+            results["defense_results"] = {
+                name: {"passed": passed, "message": msg}
+                for name, passed, msg in defense_results
+            }
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                results["error"] = f"Defense checks failed: {failed}"
+                print(f"[KernelBench] Defense checks FAILED: {failed}")
+                for name, passed, msg in defense_results:
+                    status = "PASS" if passed else "FAIL"
+                    print(f"  [{status}] {name}: {msg}")
+            else:
+                print("[KernelBench] All defense checks passed")
+
+            # Time with defensive timing
+            impl_times, _ = time_with_defenses(
+                lambda: new_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,  # Already ran above
+            )
+            new_time = sum(impl_times) / len(impl_times)
+            results["runtime_ms"] = new_time
+
+            # Reference timing
+            ref_times, _ = time_with_defenses(
+                lambda: ref_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time = sum(ref_times) / len(ref_times)
+            results["reference_runtime_ms"] = ref_time
+            results["speedup"] = ref_time / new_time if new_time > 0 else 0
+            print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+        else:
+            # Standard timing without full defenses
+            # Warmup
+            for _ in range(5):
+                with torch.no_grad():
+                    _ = new_model(*inputs)
            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))
 
[… 15 removed lines whose old content is not rendered in this diff view …]
+            # Benchmark new model
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+
+            times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = new_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                times.append(start.elapsed_time(end))
+
+            new_time = sum(times) / len(times)
+            results["runtime_ms"] = new_time
+
+            # Benchmark reference model
+            for _ in range(5):
+                with torch.no_grad():
+                    _ = ref_model(*inputs)
            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))
 
[… 4 removed lines whose old content is not rendered in this diff view …]
+            times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = ref_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                times.append(start.elapsed_time(end))
+
+            ref_time = sum(times) / len(times)
+            results["reference_runtime_ms"] = ref_time
+            results["speedup"] = ref_time / new_time if new_time > 0 else 0
+            print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+
+        # Run profiling if requested and correctness passed
+        if args.profile and all_correct:
+            print("[KernelBench] Running profiler...")
+            inputs = get_inputs()
+            inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
+
+            try:
+                # Profile implementation
+                impl_stats = run_profiling(new_model, inputs, "implementation", str(profile_dir))
+                results["profile_impl"] = impl_stats
+                print(f"[KernelBench] Implementation profile:")
+                print(f"  Total GPU time: {impl_stats['total_gpu_time_ms']:.3f}ms")
+                print(f"  Kernels launched: {impl_stats['num_gpu_kernels']}")
+                if impl_stats['top_kernels']:
+                    print(f"  Top kernel: {impl_stats['top_kernels'][0]['name'][:60]}...")
+                    print(f"              {impl_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+                # Profile reference
+                ref_stats = run_profiling(ref_model, inputs, "reference", str(profile_dir))
+                results["profile_ref"] = ref_stats
+                print(f"[KernelBench] Reference profile:")
+                print(f"  Total GPU time: {ref_stats['total_gpu_time_ms']:.3f}ms")
+                print(f"  Kernels launched: {ref_stats['num_gpu_kernels']}")
+                if ref_stats['top_kernels']:
+                    print(f"  Top kernel: {ref_stats['top_kernels'][0]['name'][:60]}...")
+                    print(f"              {ref_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+                print(f"[KernelBench] Profile traces saved to: {profile_dir}/")
+
+            except Exception as prof_err:
+                print(f"[KernelBench] Profiling failed: {prof_err}")
+                results["profile_error"] = str(prof_err)
 
     except Exception as e:
         import traceback
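Note: for a single problem the script reports speedup = ref_time / new_time; the EvaluateResult fields elsewhere in this diff (geomean_speedup) suggest per-test speedups are later rolled up as a geometric mean, the usual way to average ratios. A small sketch of that aggregation (illustrative only; the actual roll-up is not shown in this diff):

    import math

    def geomean_speedup(ref_times_ms, impl_times_ms):
        """Geometric mean of per-test speedups (ratios of reference to implementation time)."""
        speedups = [r / i for r, i in zip(ref_times_ms, impl_times_ms) if i > 0]
        if not speedups:
            return 0.0
        return math.exp(sum(math.log(s) for s in speedups) / len(speedups))

    print(geomean_speedup([2.0, 8.0], [1.0, 8.0]))  # sqrt(2 * 1) ≈ 1.414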
@@ -2705,6 +3089,24 @@ async def run_evaluate_kernelbench_docker(
             error_message=f"Failed to write reference: {write_result.stderr}",
         )
 
+    # Write custom inputs if provided
+    if args.inputs:
+        inputs_code = args.inputs.read_text()
+        inputs_file_path = f"{run_path}/custom_inputs.py"
+        write_result = await client.exec(
+            f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write custom inputs: {write_result.stderr}",
+            )
+
     # Write eval script
     eval_script_path = f"{run_path}/kernelbench_eval.py"
     write_result = await client.exec(
@@ -2721,14 +3123,40 @@ async def run_evaluate_kernelbench_docker(
             error_message=f"Failed to write eval script: {write_result.stderr}",
         )
 
+    # Write defense module if defensive mode is enabled
+    defense_module_path = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code = defense_path.read_text()
+            defense_module_path = f"{run_path}/defense.py"
+            write_result = await client.exec(
+                f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+            )
+            if write_result.exit_code != 0:
+                print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                defense_module_path = None
+        else:
+            print(f"Warning: defense.py not found at {defense_path}")
+
     print("Running KernelBench evaluation in Docker container...")
 
     # Paths inside container
     container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
     container_impl_path = f"{container_run_path}/implementation.py"
     container_ref_path = f"{container_run_path}/reference.py"
+    container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
     container_eval_script = f"{container_run_path}/kernelbench_eval.py"
     container_output = f"{container_run_path}/results.json"
+    container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None
 
     # Build eval command
     python_cmd_parts = [
@@ -2740,6 +3168,14 @@ async def run_evaluate_kernelbench_docker(
 
     if args.benchmark:
         python_cmd_parts.append("--benchmark")
+    if args.profile:
+        python_cmd_parts.append("--profile")
+    if container_inputs_path:
+        python_cmd_parts.append(f"--inputs {container_inputs_path}")
+    if args.defensive and container_defense_path:
+        python_cmd_parts.append("--defensive")
+        python_cmd_parts.append(f"--defense-module {container_defense_path}")
+    python_cmd_parts.append(f"--seed {args.seed}")
 
     eval_cmd = " ".join(python_cmd_parts)
 
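Note: the remote file writes above use a quoted heredoc (`<< 'INPUTS_EOF'`), so the shell copies the payload verbatim without expanding `$variables` or backticks inside the Python source. A hedged sketch of building such a command (a hypothetical helper; it assumes the payload never contains the sentinel line itself):

    def heredoc_write_cmd(remote_path: str, content: str, sentinel: str = "PAYLOAD_EOF") -> str:
        """Return a shell command that writes `content` to `remote_path` via a quoted heredoc."""
        if sentinel in content:
            raise ValueError("sentinel collides with payload; pick another marker")
        # Quoting the sentinel ('PAYLOAD_EOF') disables parameter expansion inside the heredoc.
        return f"cat > '{remote_path}' << '{sentinel}'\n{content}\n{sentinel}"

    cmd = heredoc_write_cmd(
        "/tmp/custom_inputs.py",
        "import torch\n\ndef get_inputs():\n    return [torch.randn(4, 4)]",
    )
    print(cmd)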
@@ -2920,6 +3356,24 @@ async def run_evaluate_kernelbench_digitalocean(
             error_message=f"Failed to write reference: {write_result.stderr}",
         )
 
+    # Write custom inputs if provided
+    if args.inputs:
+        inputs_code = args.inputs.read_text()
+        inputs_file_path = f"{run_path}/custom_inputs.py"
+        write_result = await client.exec(
+            f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write custom inputs: {write_result.stderr}",
+            )
+
     # Write eval script
     eval_script_path = f"{run_path}/kernelbench_eval.py"
     write_result = await client.exec(
@@ -2936,14 +3390,44 @@ async def run_evaluate_kernelbench_digitalocean(
             error_message=f"Failed to write eval script: {write_result.stderr}",
         )
 
+    # Write defense module if defensive mode is enabled
+    defense_module_path = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code = defense_path.read_text()
+            defense_module_path = f"{run_path}/defense.py"
+            write_result = await client.exec(
+                f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+            )
+            if write_result.exit_code != 0:
+                print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                defense_module_path = None
+        else:
+            print(f"Warning: defense.py not found at {defense_path}")
+
     print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
 
     # Paths inside container
     container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
     container_impl_path = f"{container_run_path}/implementation.py"
     container_ref_path = f"{container_run_path}/reference.py"
+    container_inputs_path = (
+        f"{container_run_path}/custom_inputs.py" if args.inputs else None
+    )
     container_eval_script = f"{container_run_path}/kernelbench_eval.py"
     container_output = f"{container_run_path}/results.json"
+    container_defense_path = (
+        f"{container_run_path}/defense.py" if defense_module_path else None
+    )
 
     # Build eval command
     python_cmd_parts = [
@@ -2955,6 +3439,14 @@ async def run_evaluate_kernelbench_digitalocean(
 
     if args.benchmark:
         python_cmd_parts.append("--benchmark")
+    if args.profile:
+        python_cmd_parts.append("--profile")
+    if container_inputs_path:
+        python_cmd_parts.append(f"--inputs {container_inputs_path}")
+    if args.defensive and container_defense_path:
+        python_cmd_parts.append("--defensive")
+        python_cmd_parts.append(f"--defense-module {container_defense_path}")
+    python_cmd_parts.append(f"--seed {args.seed}")
 
     eval_cmd = " ".join(python_cmd_parts)
 