wafer-cli 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/cli.py +862 -104
- wafer/evaluate.py +1423 -158
- wafer/gpu_run.py +5 -1
- wafer/problems.py +357 -0
- wafer/target_lock.py +198 -0
- wafer/targets.py +158 -0
- wafer/wevin_cli.py +22 -2
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/RECORD +12 -10
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/WHEEL +1 -1
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.3.dist-info → wafer_cli-0.2.5.dist-info}/top_level.txt +0 -0
wafer/evaluate.py
CHANGED
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
 from wafer_core.utils.kernel_utils.targets.config import (
     BaremetalTarget,
     DigitalOceanTarget,
+    LocalTarget,
     ModalTarget,
     RunPodTarget,
     VMTarget,
@@ -158,6 +159,8 @@ class KernelBenchEvaluateArgs:
     target_name: str
     benchmark: bool = False
     profile: bool = False
+    inputs: Path | None = None  # Custom inputs file to override get_inputs()
+    seed: int = 42  # Random seed for reproducibility
     defensive: bool = False
     sync_artifacts: bool = True
     gpu_id: int | None = None
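Note: the new `inputs` field points at a module the eval script loads in place of the reference `get_inputs()`, after checking it returns the same number of tensors with matching dtypes and ranks (see `validate_custom_inputs` further down). A minimal sketch of such a file, with purely illustrative shapes:

    # custom_inputs.py (hypothetical example; get_inputs/get_init_inputs mirror
    # the KernelBench reference module, the shapes below are made up)
    import torch

    def get_inputs():
        # Must match the reference: same tensor count, dtype, and rank.
        return [torch.randn(64, 1024, dtype=torch.float32)]

    def get_init_inputs():
        # Optional override; only used if the model takes constructor args.
        return []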
@@ -394,33 +397,6 @@ async def run_evaluate_docker(
     print(f"Connecting to {target.ssh_target}...")
 
     async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
-        # Upload wafer-core to remote
-        try:
-            wafer_root = _get_wafer_root()
-            wafer_core_path = wafer_root / "packages" / "wafer-core"
-            print(f"Uploading wafer-core from {wafer_core_path}...")
-
-            # Create workspace and upload
-            workspace_name = wafer_core_path.name
-            remote_workspace = f"{REMOTE_WORKSPACE_BASE}/{workspace_name}"
-            await client.exec(f"mkdir -p {remote_workspace}")
-            wafer_core_workspace = await client.expand_path(remote_workspace)
-
-            upload_result = await client.upload_files(
-                str(wafer_core_path), wafer_core_workspace, recursive=True
-            )
-            print(f"Uploaded {upload_result.files_copied} files")
-        except Exception as e:
-            return EvaluateResult(
-                success=False,
-                all_correct=False,
-                correctness_score=0.0,
-                geomean_speedup=0.0,
-                passed_tests=0,
-                total_tests=0,
-                error_message=f"Failed to upload wafer-core: {e}",
-            )
-
         print(f"Using Docker image: {target.docker_image}")
         print(f"Using GPU {gpu_id}...")
 
@@ -429,10 +405,13 @@ async def run_evaluate_docker(
         ref_code = args.reference.read_text()
         test_cases_data = json.loads(args.test_cases.read_text())
 
-        # Create
+        # Create workspace for evaluation files
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         run_dir = f"wafer_eval_{timestamp}"
-
+        eval_workspace = f"{REMOTE_WORKSPACE_BASE}/eval_{timestamp}"
+        await client.exec(f"mkdir -p {eval_workspace}")
+        eval_workspace_expanded = await client.expand_path(eval_workspace)
+        run_path = f"{eval_workspace_expanded}/{run_dir}"
 
         print("Uploading evaluation files...")
 
@@ -519,17 +498,16 @@ async def run_evaluate_docker(
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
         container_test_cases_path = f"{container_run_path}/test_cases.json"
-        container_evaluate_script = (
-            f"{CONTAINER_WORKSPACE}/wafer_core/utils/kernel_utils/evaluate.py"
-        )
 
-        # Build pip install command for torch and other deps
+        # Build pip install command for torch and other deps, plus wafer-core
         pip_install_cmd = _build_docker_pip_install_cmd(target)
+        install_cmd = (
+            f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
+        )
 
-        # Build evaluate command
+        # Build evaluate command using installed wafer-core module
         python_cmd_parts = [
-
-            f"python3 {container_evaluate_script}",
+            "python3 -m wafer_core.utils.kernel_utils.evaluate",
             f"--implementation {container_impl_path}",
             f"--reference {container_ref_path}",
             f"--test-cases {container_test_cases_path}",
@@ -545,8 +523,8 @@ async def run_evaluate_docker(
 
         eval_cmd = " ".join(python_cmd_parts)
 
-        # Full command: install
-        full_cmd = f"{
+        # Full command: install deps + wafer-core, then run evaluate
+        full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
 
         # Build Docker run command
         # Add SYS_ADMIN capability when profiling (needed for NCU GPU performance counters)
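Put together, `full_cmd` now interpolates to a single shell string; a small Python sketch mirroring the construction above with placeholder values (the real pip command and workspace path are computed at runtime):

    # Illustrative values only -- not the package's actual runtime values.
    pip_install_cmd = "uv pip install --system --break-system-packages torch"
    container_run_path = "/workspace/wafer_eval_20250101_000000"
    eval_cmd = (
        "python3 -m wafer_core.utils.kernel_utils.evaluate "
        f"--implementation {container_run_path}/implementation.py "
        f"--reference {container_run_path}/reference.py "
        f"--test-cases {container_run_path}/test_cases.json"
    )
    install_cmd = f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
    full_cmd = f"{install_cmd} && cd {container_run_path} && {eval_cmd}"
    print(full_cmd)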
@@ -556,7 +534,7 @@ async def run_evaluate_docker(
             working_dir=container_run_path,
             env={"CUDA_VISIBLE_DEVICES": str(gpu_id), "PYTHONUNBUFFERED": "1"},
             gpus="all",
-            volumes={
+            volumes={eval_workspace_expanded: CONTAINER_WORKSPACE},
             cap_add=["SYS_ADMIN"] if args.profile else None,
         )
 
@@ -663,6 +641,181 @@ async def run_evaluate_docker(
     )
 
 
+async def run_evaluate_local(
+    args: EvaluateArgs,
+    target: LocalTarget,
+) -> EvaluateResult:
+    """Run evaluation locally on the current machine.
+
+    For LocalTarget - no SSH needed, runs directly.
+
+    Args:
+        args: Evaluate arguments
+        target: Local target config
+
+    Returns:
+        Evaluation result
+    """
+    import os
+    import subprocess
+    import tempfile
+    from datetime import datetime
+
+    # Select GPU
+    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+    print(f"Running local evaluation on GPU {gpu_id}...")
+
+    # Create temp directory for eval files
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
+        run_path = Path(run_path)
+
+        # Write implementation
+        impl_path = run_path / "implementation.py"
+        impl_path.write_text(args.implementation.read_text())
+
+        # Write reference
+        ref_path = run_path / "reference.py"
+        ref_path.write_text(args.reference.read_text())
+
+        # Write custom inputs if provided
+        inputs_path = None
+        if args.inputs:
+            inputs_path = run_path / "custom_inputs.py"
+            inputs_path.write_text(args.inputs.read_text())
+
+        # Write eval script
+        eval_script_path = run_path / "kernelbench_eval.py"
+        eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)
+
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_src = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_src.exists():
+                defense_module_path = run_path / "defense.py"
+                defense_module_path.write_text(defense_src.read_text())
+            else:
+                print(f"Warning: defense.py not found at {defense_src}")
+
+        # Output file
+        output_path = run_path / "results.json"
+
+        # Build eval command
+        cmd_parts = [
+            "python3",
+            str(eval_script_path),
+            "--impl",
+            str(impl_path),
+            "--reference",
+            str(ref_path),
+            "--output",
+            str(output_path),
+            "--seed",
+            str(args.seed),
+        ]
+
+        if args.benchmark:
+            cmd_parts.append("--benchmark")
+        if args.profile:
+            cmd_parts.append("--profile")
+        if inputs_path:
+            cmd_parts.extend(["--inputs", str(inputs_path)])
+        if args.defensive and defense_module_path:
+            cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])
+
+        # Set environment for GPU selection
+        env = os.environ.copy()
+        if target.vendor == "nvidia":
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+        else:  # AMD
+            env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
+            env["ROCM_PATH"] = "/opt/rocm"
+
+        print(f"Running: {' '.join(cmd_parts[:4])} ...")
+
+        # Run evaluation
+        try:
+            result = subprocess.run(
+                cmd_parts,
+                cwd=str(run_path),
+                env=env,
+                capture_output=True,
+                text=True,
+                timeout=args.timeout or 600,
+            )
+        except subprocess.TimeoutExpired:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message="Evaluation timed out",
+            )
+
+        if result.returncode != 0:
+            error_msg = result.stderr or result.stdout or "Unknown error"
+            # Truncate long errors
+            if len(error_msg) > 1000:
+                error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Evaluation failed:\n{error_msg}",
+            )
+
+        # Parse results
+        if not output_path.exists():
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message="No results.json produced",
+            )
+
+        try:
+            results = json.loads(output_path.read_text())
+        except json.JSONDecodeError as e:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to parse results: {e}",
+            )
+
+        # Extract results
+        return EvaluateResult(
+            success=True,
+            all_correct=results.get("all_correct", False),
+            correctness_score=results.get("correctness_score", 0.0),
+            geomean_speedup=results.get("geomean_speedup", 0.0),
+            passed_tests=results.get("passed_tests", 0),
+            total_tests=results.get("total_tests", 0),
+            benchmark_results=results.get("benchmark", {}),
+        )
+
+
 async def run_evaluate_ssh(
     args: EvaluateArgs,
     target: BaremetalTarget | VMTarget,
@@ -980,6 +1133,7 @@ def _build_modal_sandbox_script(
     test_cases_b64: str,
     run_benchmarks: bool,
     run_defensive: bool,
+    defense_code_b64: str | None = None,
 ) -> str:
     """Build Python script to create sandbox and run evaluation.
 
@@ -1060,6 +1214,20 @@ print('Files written')
         print(json.dumps({{"error": f"Failed to write files: {{proc.stderr.read()}}"}}))
         return
 
+    # Write defense module if defensive mode is enabled
+    # NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
+    if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
+        proc = sandbox.exec("python", "-c", f"""
+import base64
+with open('/workspace/defense.py', 'w') as f:
+    f.write(base64.b64decode('{defense_code_b64}').decode())
+print('Defense module written')
+""")
+        proc.wait()
+        if proc.returncode != 0:
+            print(json.dumps({{"error": f"Failed to write defense module: {{proc.stderr.read()}}"}}))
+            return
+
     # Build inline evaluation script
     eval_script = """
 import json
@@ -1087,6 +1255,26 @@ generate_input = load_fn('reference.py', 'generate_input')
 
 import torch
 
+# Load defense module if available and defensive mode is enabled
+run_defensive = {run_defensive}
+defense = None
+if run_defensive:
+    try:
+        defense = load_fn('defense.py', 'run_all_defenses')
+        time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
+        print('[Defense] Defense module loaded')
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        # These wrappers repack the unpacked args back into a tuple
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
+    except Exception as e:
+        print(f'[Defense] Warning: Could not load defense module: {{e}}')
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
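The `_wrap_for_defense` adapter bridges two calling conventions: the defense helpers invoke `kernel(*args)` with the inputs unpacked, while the functional kernels here expect a single tuple argument. A standalone illustration of the same pattern (toy kernel, not from the package):

    import torch

    # Functional-style kernel: takes one tuple of tensors.
    def custom_kernel(inputs):
        a, b = inputs
        return a + b

    # Repack unpacked args back into a tuple before calling the kernel.
    def _wrap_for_defense(kernel):
        return lambda *args: kernel(args)

    wrapped = _wrap_for_defense(custom_kernel)
    a, b = torch.ones(2), torch.ones(2)
    assert torch.equal(wrapped(a, b), custom_kernel((a, b)))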
@@ -1114,36 +1302,63 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
-        (30 old lines not rendered in the source view)
+        if run_defensive and defense is not None:
+            # Use full defense suite with wrapped kernels
+            # inputs_list unpacks the tuple so defense can infer dtype/device from tensors
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing (using wrapped kernels)
+            impl_times, _ = time_with_defenses(
+                custom_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            ref_times, _ = time_with_defenses(
+                ref_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing without full defenses
+            # Warmup
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            # Reference timing
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
@@ -1236,6 +1451,23 @@ async def run_evaluate_modal(
     ref_code_b64 = base64.b64encode(args.reference.read_bytes()).decode()
     test_cases_b64 = base64.b64encode(args.test_cases.read_bytes()).decode()
 
+    # Encode defense module if defensive mode is enabled
+    defense_code_b64 = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code_b64 = base64.b64encode(defense_path.read_bytes()).decode()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build the script that creates sandbox and runs eval
     script = _build_modal_sandbox_script(
         target=target,
@@ -1244,6 +1476,7 @@ async def run_evaluate_modal(
         test_cases_b64=test_cases_b64,
         run_benchmarks=args.benchmark,
         run_defensive=args.defensive,
+        defense_code_b64=defense_code_b64,
     )
 
     def _run_subprocess() -> tuple[str, str, int]:
@@ -1341,6 +1574,7 @@ def _build_workspace_eval_script(
     test_cases_json: str,
     run_benchmarks: bool,
     run_defensive: bool = False,
+    defense_code: str | None = None,
 ) -> str:
     """Build inline evaluation script for workspace exec.
 
@@ -1351,6 +1585,7 @@ def _build_workspace_eval_script(
     impl_b64 = base64.b64encode(impl_code.encode()).decode()
     ref_b64 = base64.b64encode(ref_code.encode()).decode()
     tests_b64 = base64.b64encode(test_cases_json.encode()).decode()
+    defense_b64 = base64.b64encode(defense_code.encode()).decode() if defense_code else ""
 
     return f'''
 import base64
@@ -1370,6 +1605,15 @@ with open("/tmp/kernel.py", "w") as f:
 with open("/tmp/reference.py", "w") as f:
     f.write(ref_code)
 
+# Write defense module if available
+run_defensive = {run_defensive}
+defense_b64 = "{defense_b64}"
+# NOTE: Check defense_b64 is not empty and not the string "None" (from None formatting)
+if run_defensive and defense_b64 and defense_b64 != "None":
+    defense_code = base64.b64decode(defense_b64).decode()
+    with open("/tmp/defense.py", "w") as f:
+        f.write(defense_code)
+
 # Load kernels
 def load_fn(path, name):
     spec = importlib.util.spec_from_file_location("mod", path)
@@ -1383,6 +1627,24 @@ generate_input = load_fn("/tmp/reference.py", "generate_input")
 
 import torch
 
+# Load defense module if available
+defense = None
+if run_defensive and defense_b64 and defense_b64 != "None":
+    try:
+        defense = load_fn("/tmp/defense.py", "run_all_defenses")
+        time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
+        print("[Defense] Defense module loaded")
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
+    except Exception as e:
+        print(f"[Defense] Warning: Could not load defense module: {{e}}")
+        defense = None
+
 results = []
 all_correct = True
 total_time_ms = 0.0
@@ -1410,36 +1672,60 @@ for tc in test_cases:
     impl_time_ms = 0.0
     ref_time_ms = 0.0
     if {run_benchmarks}:
-        (30 old lines not rendered in the source view)
+        if run_defensive and defense is not None:
+            # Use full defense suite with wrapped kernels
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                raise ValueError(f"Defense checks failed: {{failed}}")
+
+            # Time with defensive timing (using wrapped kernels)
+            impl_times, _ = time_with_defenses(
+                custom_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            impl_time_ms = sum(impl_times) / len(impl_times)
+
+            ref_times, _ = time_with_defenses(
+                ref_kernel_for_defense,
+                inputs_list,
+                num_warmup=3,
+                num_trials=10,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time_ms = sum(ref_times) / len(ref_times)
+        else:
+            # Standard timing
+            for _ in range(3):
+                custom_kernel(inputs)
+            torch.cuda.synchronize()
+
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for _ in range(10):
+                custom_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            impl_time_ms = start.elapsed_time(end) / 10
+
+            for _ in range(3):
+                ref_kernel(inputs)
+            torch.cuda.synchronize()
+            start.record()
+            for _ in range(10):
+                ref_kernel(inputs)
+            end.record()
+            torch.cuda.synchronize()
+            ref_time_ms = start.elapsed_time(end) / 10
 
     total_time_ms += impl_time_ms
     ref_total_time_ms += ref_time_ms
@@ -1501,6 +1787,23 @@ async def run_evaluate_workspace(
     ref_code = args.reference.read_text()
     test_cases_json = args.test_cases.read_text()
 
+    # Read defense module if defensive mode is enabled
+    defense_code = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code = defense_path.read_text()
+        else:
+            print(f"Warning: defense.py not found at {defense_path}, falling back to basic defense")
+
     # Build inline eval script
     eval_script = _build_workspace_eval_script(
         impl_code=impl_code,
@@ -1508,6 +1811,7 @@ async def run_evaluate_workspace(
         test_cases_json=test_cases_json,
         run_benchmarks=args.benchmark,
         run_defensive=args.defensive,
+        defense_code=defense_code,
     )
 
     # Execute via workspace exec
@@ -1853,15 +2157,12 @@ async def run_evaluate_runpod(
     # Add venv bin to PATH so ninja (from pip) is found by torch.utils.cpp_extension
     venv_bin = env_state.venv_bin
     env_vars = f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
-    pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-    evaluate_script = (
-        f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-    )
 
     # Run from run_path so reference_kernel.py is importable
+    # Use installed wafer-core module
     eval_cmd = (
         f"cd {run_path} && "
-        f"{env_vars} {
+        f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
         f"--implementation {impl_path} "
         f"--reference {ref_path} "
         f"--test-cases {test_cases_path} "
@@ -2217,15 +2518,12 @@ async def run_evaluate_digitalocean(
     env_vars = (
         f"PATH={venv_bin}:$PATH HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm"
     )
-    pythonpath = f"PYTHONPATH={wafer_core_workspace}"
-    evaluate_script = (
-        f"{wafer_core_workspace}/wafer_core/utils/kernel_utils/evaluate.py"
-    )
 
     # Run from run_path so reference_kernel.py is importable
+    # Use installed wafer-core module
     eval_cmd = (
         f"cd {run_path} && "
-        f"{env_vars} {
+        f"{env_vars} {python_exe} -m wafer_core.utils.kernel_utils.evaluate "
         f"--implementation {impl_path} "
         f"--reference {ref_path} "
         f"--test-cases {test_cases_path} "
@@ -2405,7 +2703,9 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
     print(f"Using target: {target_name}")
 
     # Dispatch to appropriate executor
-    if isinstance(target,
+    if isinstance(target, LocalTarget):
+        return await run_evaluate_local(args, target)
+    elif isinstance(target, BaremetalTarget | VMTarget):
         return await run_evaluate_ssh(args, target)
     elif isinstance(target, ModalTarget):
         return await run_evaluate_modal(args, target)
@@ -2435,10 +2735,233 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
 # This runs inside the Docker container on the remote GPU
 KERNELBENCH_EVAL_SCRIPT = """
 import json
+import os
 import sys
 import time
 import torch
 import torch.nn as nn
+from pathlib import Path
+
+
+def run_profiling(model, inputs, name, output_dir):
+    '''Run torch.profiler and return summary stats.'''
+    from torch.profiler import profile, ProfilerActivity
+
+    # Determine activities based on backend
+    activities = [ProfilerActivity.CPU]
+    if torch.cuda.is_available():
+        activities.append(ProfilerActivity.CUDA)
+
+    # Warmup
+    for _ in range(3):
+        with torch.no_grad():
+            _ = model(*inputs)
+    torch.cuda.synchronize()
+
+    # Profile
+    with profile(
+        activities=activities,
+        record_shapes=True,
+        with_stack=False,
+        profile_memory=True,
+    ) as prof:
+        with torch.no_grad():
+            _ = model(*inputs)
+        torch.cuda.synchronize()
+
+    # Get key averages
+    key_averages = prof.key_averages()
+
+    # Find the main kernel (longest GPU time)
+    # Use cuda_time_total for compatibility with both CUDA and ROCm
+    def get_gpu_time(e):
+        # Try different attributes for GPU time
+        if hasattr(e, 'cuda_time_total'):
+            return e.cuda_time_total
+        if hasattr(e, 'device_time_total'):
+            return e.device_time_total
+        if hasattr(e, 'self_cuda_time_total'):
+            return e.self_cuda_time_total
+        return 0
+
+    gpu_events = [e for e in key_averages if get_gpu_time(e) > 0]
+    gpu_events.sort(key=lambda e: get_gpu_time(e), reverse=True)
+
+    stats = {
+        "name": name,
+        "total_gpu_time_ms": sum(get_gpu_time(e) for e in gpu_events) / 1000,
+        "total_cpu_time_ms": sum(e.cpu_time_total for e in key_averages) / 1000,
+        "num_gpu_kernels": len(gpu_events),
+        "top_kernels": [],
+    }
+
+    # Top 5 kernels by GPU time
+    for e in gpu_events[:5]:
+        stats["top_kernels"].append({
+            "name": e.key,
+            "gpu_time_ms": get_gpu_time(e) / 1000,
+            "cpu_time_ms": e.cpu_time_total / 1000,
+            "calls": e.count,
+        })
+
+    # Save trace for visualization
+    trace_path = Path(output_dir) / f"{name}_trace.json"
+    prof.export_chrome_trace(str(trace_path))
+    stats["trace_file"] = str(trace_path)
+
+    return stats
+
+
+def validate_custom_inputs(original_inputs, custom_inputs):
+    '''Validate that custom inputs match the expected signature.
+
+    Returns (is_valid, error_message).
+    '''
+    if len(original_inputs) != len(custom_inputs):
+        return False, f"get_inputs() must return {len(original_inputs)} tensors, got {len(custom_inputs)}"
+
+    for i, (orig, cust) in enumerate(zip(original_inputs, custom_inputs)):
+        if not isinstance(cust, torch.Tensor):
+            if not isinstance(orig, torch.Tensor):
+                continue  # Both non-tensor, ok
+            return False, f"Input {i}: expected Tensor, got {type(cust).__name__}"
+
+        if not isinstance(orig, torch.Tensor):
+            return False, f"Input {i}: expected {type(orig).__name__}, got Tensor"
+
+        if orig.dtype != cust.dtype:
+            return False, f"Input {i}: dtype mismatch - expected {orig.dtype}, got {cust.dtype}"
+
+        if orig.dim() != cust.dim():
+            return False, f"Input {i}: dimension mismatch - expected {orig.dim()}D, got {cust.dim()}D"
+
+    return True, None
+
+
+def analyze_diff(ref_output, new_output, rtol=1e-3, atol=1e-3, max_samples=5):
+    '''Analyze differences between reference and implementation outputs.
+
+    Returns a dict with detailed diff information.
+    '''
+    diff = (ref_output - new_output).abs()
+    threshold = atol + rtol * ref_output.abs()
+    wrong_mask = diff > threshold
+
+    total_elements = ref_output.numel()
+    wrong_count = wrong_mask.sum().item()
+
+    # Basic stats
+    max_diff = diff.max().item()
+    max_diff_idx = tuple(torch.unravel_index(diff.argmax(), diff.shape))
+    max_diff_idx = tuple(int(i) for i in max_diff_idx)  # Convert to Python ints
+
+    # Relative error (avoid div by zero)
+    ref_abs = ref_output.abs()
+    nonzero_mask = ref_abs > 1e-8
+    if nonzero_mask.any():
+        rel_error = diff[nonzero_mask] / ref_abs[nonzero_mask]
+        max_rel_error = rel_error.max().item()
+        mean_rel_error = rel_error.mean().item()
+    else:
+        max_rel_error = float('inf') if max_diff > 0 else 0.0
+        mean_rel_error = max_rel_error
+
+    # Error histogram (buckets: <1e-6, 1e-6 to 1e-4, 1e-4 to 1e-2, 1e-2 to 1, >1)
+    histogram = {
+        '<1e-6': int((diff < 1e-6).sum().item()),
+        '1e-6 to 1e-4': int(((diff >= 1e-6) & (diff < 1e-4)).sum().item()),
+        '1e-4 to 1e-2': int(((diff >= 1e-4) & (diff < 1e-2)).sum().item()),
+        '1e-2 to 1': int(((diff >= 1e-2) & (diff < 1)).sum().item()),
+        '>1': int((diff >= 1).sum().item()),
+    }
+
+    result = {
+        'max_diff': max_diff,
+        'max_diff_idx': max_diff_idx,
+        'mean_diff': diff.mean().item(),
+        'max_rel_error': max_rel_error,
+        'mean_rel_error': mean_rel_error,
+        'total_elements': total_elements,
+        'wrong_count': int(wrong_count),
+        'wrong_pct': 100.0 * wrong_count / total_elements,
+        'histogram': histogram,
+        'samples': [],
+    }
+
+    # Get indices of wrong elements
+    if wrong_count > 0:
+        wrong_indices = torch.nonzero(wrong_mask, as_tuple=False)
+
+        # Take first N samples
+        num_samples = min(max_samples, len(wrong_indices))
+        for i in range(num_samples):
+            idx = tuple(wrong_indices[i].tolist())
+            ref_val = ref_output[idx].item()
+            new_val = new_output[idx].item()
+            diff_val = diff[idx].item()
+            result['samples'].append({
+                'index': idx,
+                'ref': ref_val,
+                'impl': new_val,
+                'diff': diff_val,
+            })
+
+        # Try to detect pattern
+        if wrong_count >= total_elements * 0.99:
+            result['pattern'] = 'all_wrong'
+        elif wrong_count < total_elements * 0.01:
+            # Check if failures are at boundaries
+            shape = ref_output.shape
+            boundary_count = 0
+            for idx in wrong_indices[:min(100, len(wrong_indices))]:
+                idx_list = idx.tolist()
+                is_boundary = any(i == 0 or i == s - 1 for i, s in zip(idx_list, shape))
+                if is_boundary:
+                    boundary_count += 1
+            if boundary_count > len(wrong_indices[:100]) * 0.8:
+                result['pattern'] = 'boundary_issue'
+            else:
+                result['pattern'] = 'scattered'
+        else:
+            result['pattern'] = 'partial'
+
+    return result
+
+
+def print_diff_analysis(analysis):
+    '''Print a human-readable diff analysis.'''
+    print(f"[KernelBench] Diff analysis:")
+
+    # Max diff with location
+    idx_str = ','.join(str(i) for i in analysis['max_diff_idx'])
+    print(f" Max diff: {analysis['max_diff']:.6f} at index [{idx_str}]")
+    print(f" Mean diff: {analysis['mean_diff']:.6f}")
+
+    # Relative errors
+    print(f" Max relative error: {analysis['max_rel_error']:.2%}, Mean: {analysis['mean_rel_error']:.2%}")
+
+    # Wrong count
+    print(f" Wrong elements: {analysis['wrong_count']:,} / {analysis['total_elements']:,} ({analysis['wrong_pct']:.2f}%)")
+
+    # Histogram
+    hist = analysis['histogram']
+    print(f" Error distribution: <1e-6: {hist['<1e-6']:,} | 1e-6~1e-4: {hist['1e-6 to 1e-4']:,} | 1e-4~1e-2: {hist['1e-4 to 1e-2']:,} | 1e-2~1: {hist['1e-2 to 1']:,} | >1: {hist['>1']:,}")
+
+    if 'pattern' in analysis:
+        pattern_desc = {
+            'all_wrong': 'ALL elements wrong - likely algorithmic error or wrong weights',
+            'boundary_issue': 'Mostly BOUNDARY elements wrong - check edge handling',
+            'scattered': 'SCATTERED failures - numerical precision issue?',
+            'partial': 'PARTIAL failures - check specific conditions',
+        }
+        print(f" Pattern: {pattern_desc.get(analysis['pattern'], analysis['pattern'])}")
+
+    if analysis['samples']:
+        print(f" Sample failures:")
+        for s in analysis['samples']:
+            idx_str = ','.join(str(i) for i in s['index'])
+            print(f" [{idx_str}]: ref={s['ref']:.6f} impl={s['impl']:.6f} (diff={s['diff']:.6f})")
+
 
 def main():
     # Parse args
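The correctness threshold used by `allclose` and `analyze_diff` above is the elementwise test `|ref - impl| > atol + rtol * |ref|` with rtol = atol = 1e-3; a self-contained sketch of that bookkeeping on toy tensors (values chosen only to illustrate the reported fields):

    import torch

    ref = torch.zeros(4, 4)
    impl = ref.clone()
    impl[0, 0] = 1.0                                  # one large mismatch

    rtol, atol = 1e-3, 1e-3
    diff = (ref - impl).abs()
    wrong = diff > (atol + rtol * ref.abs())          # same mask analyze_diff builds

    print(diff.max().item())                          # 1.0  -> reported as max_diff
    print(int(wrong.sum()))                           # 1    -> wrong_count
    print(100.0 * wrong.sum().item() / ref.numel())   # 6.25 -> wrong_pct (a 'partial' pattern)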
@@ -2446,12 +2969,35 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--impl", required=True)
     parser.add_argument("--reference", required=True)
+    parser.add_argument("--inputs", help="Custom inputs file to override get_inputs()/get_init_inputs()")
     parser.add_argument("--benchmark", action="store_true")
+    parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--defensive", action="store_true", help="Run full defense checks against reward hacking")
+    parser.add_argument("--defense-module", help="Path to defense.py module")
+    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
     parser.add_argument("--num-correct-trials", type=int, default=3)
     parser.add_argument("--num-perf-trials", type=int, default=10)
     parser.add_argument("--output", required=True)
     args = parser.parse_args()
 
+    # Load defense module if defensive mode is enabled
+    defense_module = None
+    if args.defensive and args.defense_module:
+        try:
+            import importlib.util
+            defense_spec = importlib.util.spec_from_file_location("defense", args.defense_module)
+            defense_module = importlib.util.module_from_spec(defense_spec)
+            defense_spec.loader.exec_module(defense_module)
+            print("[KernelBench] Defense module loaded")
+        except Exception as e:
+            print(f"[KernelBench] Warning: Could not load defense module: {e}")
+
+    # Create output directory for profiles
+    output_dir = Path(args.output).parent
+    profile_dir = output_dir / "profiles"
+    if args.profile:
+        profile_dir.mkdir(exist_ok=True)
+
     results = {
         "compiled": False,
         "correct": False,
@@ -2472,6 +3018,33 @@ def main():
     get_inputs = ref_module.get_inputs
     get_init_inputs = ref_module.get_init_inputs
 
+    # Load custom inputs if provided
+    if args.inputs:
+        inputs_spec = importlib.util.spec_from_file_location("custom_inputs", args.inputs)
+        inputs_module = importlib.util.module_from_spec(inputs_spec)
+        inputs_spec.loader.exec_module(inputs_module)
+
+        # Validate custom inputs match expected signature
+        original_inputs = get_inputs()
+        custom_get_inputs = inputs_module.get_inputs
+        custom_inputs = custom_get_inputs()
+
+        is_valid, error_msg = validate_custom_inputs(original_inputs, custom_inputs)
+        if not is_valid:
+            print(f"[KernelBench] Custom inputs validation failed: {error_msg}")
+            results["error"] = f"Custom inputs validation failed: {error_msg}"
+            raise ValueError(error_msg)
+
+        # Override get_inputs (and optionally get_init_inputs)
+        get_inputs = custom_get_inputs
+        if hasattr(inputs_module, 'get_init_inputs'):
+            get_init_inputs = inputs_module.get_init_inputs
+
+        # Show what changed
+        orig_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in original_inputs]
+        cust_shapes = [tuple(t.shape) if hasattr(t, 'shape') else type(t).__name__ for t in custom_inputs]
+        print(f"[KernelBench] Using custom inputs: {orig_shapes} -> {cust_shapes}")
+
     # Load implementation module
     impl_spec = importlib.util.spec_from_file_location("implementation", args.impl)
     impl_module = importlib.util.module_from_spec(impl_spec)
@@ -2481,12 +3054,19 @@ def main():
     results["compiled"] = True
     print("[KernelBench] Modules loaded successfully")
 
-    # Instantiate models
+    # Instantiate models with synchronized seeds for reproducible weights
+    # (matches upstream KernelBench behavior in src/eval.py)
+    seed = args.seed
     init_inputs = get_init_inputs()
     with torch.no_grad():
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
         ref_model = Model(*init_inputs).cuda().eval()
+
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
         new_model = ModelNew(*init_inputs).cuda().eval()
-    print("[KernelBench] Models instantiated")
+    print(f"[KernelBench] Models instantiated (seed={seed})")
 
     # Run correctness trials
     all_correct = True
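Re-seeding immediately before each constructor makes both models draw the same random weight initialization, so correctness comparisons are not polluted by init noise. The effect is easy to verify in isolation (CPU example, not tied to this script):

    import torch
    import torch.nn as nn

    def build(seed):
        torch.manual_seed(seed)          # reset the RNG right before construction
        return nn.Linear(16, 16)

    a, b = build(42), build(42)
    # Same seed at construction time -> identical initial weights.
    assert torch.equal(a.weight, b.weight) and torch.equal(a.bias, b.bias)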
@@ -2502,8 +3082,18 @@ def main():
             if isinstance(ref_output, torch.Tensor):
                 if not torch.allclose(ref_output, new_output, rtol=1e-3, atol=1e-3):
                     all_correct = False
-                    (1 old line not rendered in the source view)
-                    results["error"] = f"Correctness failed on trial {trial+1}: max diff = {max_diff}"
+                    analysis = analyze_diff(ref_output, new_output)
+                    results["error"] = f"Correctness failed on trial {trial+1}: max diff = {analysis['max_diff']}"
+                    results["diff_analysis"] = analysis
+                    print_diff_analysis(analysis)
+
+                    # Save tensors for debugging
+                    debug_dir = output_dir / "debug"
+                    debug_dir.mkdir(exist_ok=True)
+                    torch.save(ref_output.cpu(), debug_dir / "ref_output.pt")
+                    torch.save(new_output.cpu(), debug_dir / "impl_output.pt")
+                    torch.save(inputs[0].cpu() if inputs else None, debug_dir / "input.pt")
+                    print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
                     break
             else:
                 # Handle tuple/list outputs
@@ -2511,8 +3101,17 @@ def main():
                     if isinstance(r, torch.Tensor):
                         if not torch.allclose(r, n, rtol=1e-3, atol=1e-3):
                             all_correct = False
-                            (1 old line not rendered in the source view)
-                            results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {max_diff}"
+                            analysis = analyze_diff(r, n)
+                            results["error"] = f"Correctness failed on trial {trial+1}, output {i}: max diff = {analysis['max_diff']}"
+                            results["diff_analysis"] = analysis
+                            print_diff_analysis(analysis)
+
+                            # Save tensors for debugging
+                            debug_dir = output_dir / "debug"
+                            debug_dir.mkdir(exist_ok=True)
+                            torch.save(r.cpu(), debug_dir / f"ref_output_{i}.pt")
+                            torch.save(n.cpu(), debug_dir / f"impl_output_{i}.pt")
+                            print(f"[KernelBench] Debug tensors saved to: {debug_dir}/")
                             break
             if not all_correct:
                 break
@@ -2526,47 +3125,132 @@ def main():
         inputs = get_inputs()
         inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
 
-        (6 old lines not rendered in the source view)
-        # Benchmark new model
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
+        if args.defensive and defense_module is not None:
+            # Use full defense suite
+            print("[KernelBench] Running defense checks on implementation...")
+            run_all_defenses = defense_module.run_all_defenses
+            time_with_defenses = defense_module.time_execution_with_defenses
 
-        (6 old lines not rendered in the source view)
+            # Run defense checks on implementation
+            all_passed, defense_results, _ = run_all_defenses(
+                lambda *x: new_model(*x),
+                *inputs,
+            )
+            results["defense_results"] = {
+                name: {"passed": passed, "message": msg}
+                for name, passed, msg in defense_results
+            }
+            if not all_passed:
+                failed = [name for name, passed, _ in defense_results if not passed]
+                results["error"] = f"Defense checks failed: {failed}"
+                print(f"[KernelBench] Defense checks FAILED: {failed}")
+                for name, passed, msg in defense_results:
+                    status = "PASS" if passed else "FAIL"
+                    print(f" [{status}] {name}: {msg}")
+            else:
+                print("[KernelBench] All defense checks passed")
+
+            # Time with defensive timing
+            impl_times, _ = time_with_defenses(
+                lambda: new_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,  # Already ran above
+            )
+            new_time = sum(impl_times) / len(impl_times)
+            results["runtime_ms"] = new_time
+
+            # Reference timing
+            ref_times, _ = time_with_defenses(
+                lambda: ref_model(*inputs),
+                [],
+                num_warmup=5,
+                num_trials=args.num_perf_trials,
+                verbose=False,
+                run_defenses=False,
+            )
+            ref_time = sum(ref_times) / len(ref_times)
+            results["reference_runtime_ms"] = ref_time
+            results["speedup"] = ref_time / new_time if new_time > 0 else 0
+            print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+        else:
+            # Standard timing without full defenses
+            # Warmup
+            for _ in range(5):
+                with torch.no_grad():
+                    _ = new_model(*inputs)
             torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))
-        (1 old line not rendered in the source view)
-        new_time = sum(times) / len(times)
-        results["runtime_ms"] = new_time
 
-        (12 old lines not rendered in the source view)
+            # Benchmark new model
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+
+            times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = new_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                times.append(start.elapsed_time(end))
+
+            new_time = sum(times) / len(times)
+            results["runtime_ms"] = new_time
+
+            # Benchmark reference model
+            for _ in range(5):
+                with torch.no_grad():
+                    _ = ref_model(*inputs)
             torch.cuda.synchronize()
-        times.append(start.elapsed_time(end))
 
-        (4 old lines not rendered in the source view)
+            times = []
+            for _ in range(args.num_perf_trials):
+                start.record()
+                with torch.no_grad():
+                    _ = ref_model(*inputs)
+                end.record()
+                torch.cuda.synchronize()
+                times.append(start.elapsed_time(end))
+
+            ref_time = sum(times) / len(times)
+            results["reference_runtime_ms"] = ref_time
+            results["speedup"] = ref_time / new_time if new_time > 0 else 0
+            print(f"[KernelBench] New: {new_time:.3f}ms, Ref: {ref_time:.3f}ms, Speedup: {results['speedup']:.2f}x")
+
+        # Run profiling if requested and correctness passed
+        if args.profile and all_correct:
+            print("[KernelBench] Running profiler...")
+            inputs = get_inputs()
+            inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
+
+            try:
+                # Profile implementation
+                impl_stats = run_profiling(new_model, inputs, "implementation", str(profile_dir))
+                results["profile_impl"] = impl_stats
+                print(f"[KernelBench] Implementation profile:")
+                print(f" Total GPU time: {impl_stats['total_gpu_time_ms']:.3f}ms")
+                print(f" Kernels launched: {impl_stats['num_gpu_kernels']}")
+                if impl_stats['top_kernels']:
+                    print(f" Top kernel: {impl_stats['top_kernels'][0]['name'][:60]}...")
+                    print(f" {impl_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+                # Profile reference
+                ref_stats = run_profiling(ref_model, inputs, "reference", str(profile_dir))
+                results["profile_ref"] = ref_stats
+                print(f"[KernelBench] Reference profile:")
+                print(f" Total GPU time: {ref_stats['total_gpu_time_ms']:.3f}ms")
+                print(f" Kernels launched: {ref_stats['num_gpu_kernels']}")
+                if ref_stats['top_kernels']:
+                    print(f" Top kernel: {ref_stats['top_kernels'][0]['name'][:60]}...")
+                    print(f" {ref_stats['top_kernels'][0]['gpu_time_ms']:.3f}ms")
+
+                print(f"[KernelBench] Profile traces saved to: {profile_dir}/")
+
+            except Exception as prof_err:
+                print(f"[KernelBench] Profiling failed: {prof_err}")
+                results["profile_error"] = str(prof_err)
 
     except Exception as e:
         import traceback
@@ -2705,6 +3389,24 @@ async def run_evaluate_kernelbench_docker(
             error_message=f"Failed to write reference: {write_result.stderr}",
         )
 
+    # Write custom inputs if provided
+    if args.inputs:
+        inputs_code = args.inputs.read_text()
+        inputs_file_path = f"{run_path}/custom_inputs.py"
+        write_result = await client.exec(
+            f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write custom inputs: {write_result.stderr}",
+            )
+
     # Write eval script
     eval_script_path = f"{run_path}/kernelbench_eval.py"
     write_result = await client.exec(
@@ -2721,14 +3423,40 @@ async def run_evaluate_kernelbench_docker(
             error_message=f"Failed to write eval script: {write_result.stderr}",
         )
 
+    # Write defense module if defensive mode is enabled
+    defense_module_path = None
+    if args.defensive:
+        defense_path = (
+            Path(__file__).parent.parent.parent.parent
+            / "packages"
+            / "wafer-core"
+            / "wafer_core"
+            / "utils"
+            / "kernel_utils"
+            / "defense.py"
+        )
+        if defense_path.exists():
+            defense_code = defense_path.read_text()
+            defense_module_path = f"{run_path}/defense.py"
+            write_result = await client.exec(
+                f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+            )
+            if write_result.exit_code != 0:
+                print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                defense_module_path = None
+        else:
+            print(f"Warning: defense.py not found at {defense_path}")
+
     print("Running KernelBench evaluation in Docker container...")
 
     # Paths inside container
     container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
     container_impl_path = f"{container_run_path}/implementation.py"
     container_ref_path = f"{container_run_path}/reference.py"
+    container_inputs_path = f"{container_run_path}/custom_inputs.py" if args.inputs else None
     container_eval_script = f"{container_run_path}/kernelbench_eval.py"
     container_output = f"{container_run_path}/results.json"
+    container_defense_path = f"{container_run_path}/defense.py" if defense_module_path else None
 
     # Build eval command
     python_cmd_parts = [
@@ -2740,6 +3468,14 @@ async def run_evaluate_kernelbench_docker(
 
     if args.benchmark:
         python_cmd_parts.append("--benchmark")
+    if args.profile:
+        python_cmd_parts.append("--profile")
+    if container_inputs_path:
+        python_cmd_parts.append(f"--inputs {container_inputs_path}")
+    if args.defensive and container_defense_path:
+        python_cmd_parts.append("--defensive")
+        python_cmd_parts.append(f"--defense-module {container_defense_path}")
+    python_cmd_parts.append(f"--seed {args.seed}")
 
     eval_cmd = " ".join(python_cmd_parts)
 
@@ -2920,6 +3656,24 @@ async def run_evaluate_kernelbench_digitalocean(
             error_message=f"Failed to write reference: {write_result.stderr}",
         )
 
+    # Write custom inputs if provided
+    if args.inputs:
+        inputs_code = args.inputs.read_text()
+        inputs_file_path = f"{run_path}/custom_inputs.py"
+        write_result = await client.exec(
+            f"cat > '{inputs_file_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write custom inputs: {write_result.stderr}",
+            )
+
     # Write eval script
     eval_script_path = f"{run_path}/kernelbench_eval.py"
     write_result = await client.exec(
@@ -2936,14 +3690,44 @@ async def run_evaluate_kernelbench_digitalocean(
                 error_message=f"Failed to write eval script: {write_result.stderr}",
             )
 
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_path = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_path.exists():
+                defense_code = defense_path.read_text()
+                defense_module_path = f"{run_path}/defense.py"
+                write_result = await client.exec(
+                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                )
+                if write_result.exit_code != 0:
+                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                    defense_module_path = None
+            else:
+                print(f"Warning: defense.py not found at {defense_path}")
+
         print("Running KernelBench evaluation in Docker container (AMD/ROCm)...")
 
         # Paths inside container
         container_run_path = f"{CONTAINER_WORKSPACE}/{run_dir}"
         container_impl_path = f"{container_run_path}/implementation.py"
         container_ref_path = f"{container_run_path}/reference.py"
+        container_inputs_path = (
+            f"{container_run_path}/custom_inputs.py" if args.inputs else None
+        )
         container_eval_script = f"{container_run_path}/kernelbench_eval.py"
         container_output = f"{container_run_path}/results.json"
+        container_defense_path = (
+            f"{container_run_path}/defense.py" if defense_module_path else None
+        )
 
         # Build eval command
         python_cmd_parts = [
@@ -2955,6 +3739,14 @@ async def run_evaluate_kernelbench_digitalocean(
 
         if args.benchmark:
             python_cmd_parts.append("--benchmark")
+        if args.profile:
+            python_cmd_parts.append("--profile")
+        if container_inputs_path:
+            python_cmd_parts.append(f"--inputs {container_inputs_path}")
+        if args.defensive and container_defense_path:
+            python_cmd_parts.append("--defensive")
+            python_cmd_parts.append(f"--defense-module {container_defense_path}")
+        python_cmd_parts.append(f"--seed {args.seed}")
 
         eval_cmd = " ".join(python_cmd_parts)
 
@@ -3039,11 +3831,478 @@ async def run_evaluate_kernelbench_digitalocean(
         )
 
 
-async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
-    """Run KernelBench format evaluation on configured target.
-
-    Args:
-        args: KernelBench evaluate arguments
+async def run_evaluate_kernelbench_runpod(
+    args: KernelBenchEvaluateArgs,
+    target: RunPodTarget,
+) -> EvaluateResult:
+    """Run KernelBench format evaluation directly on RunPod AMD GPU.
+
+    Runs evaluation script directly on host (no Docker) since RunPod pods
+    already have PyTorch/ROCm installed.
+    """
+    from datetime import datetime
+
+    from wafer_core.async_ssh import AsyncSSHClient
+    from wafer_core.targets.runpod import RunPodError, runpod_ssh_context
+
+    REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
+
+    # Select GPU
+    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+    print(f"Provisioning RunPod ({target.gpu_type_id})...")
+
+    try:
+        async with runpod_ssh_context(target) as ssh_info:
+            ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
+            print(f"Connected to RunPod: {ssh_target}")
+
+            async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
+                # Create workspace
+                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                run_dir = f"kernelbench_eval_{timestamp}"
+                run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
+
+                await client.exec(f"mkdir -p {run_path}")
+                print(f"Created run directory: {run_path}")
+
+                # Read and upload files
+                impl_code = args.implementation.read_text()
+                ref_code = args.reference.read_text()
+
+                # Write implementation
+                impl_path = f"{run_path}/implementation.py"
+                write_result = await client.exec(
+                    f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
+                )
+                if write_result.exit_code != 0:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to write implementation: {write_result.stderr}",
+                    )
+
+                # Write reference
+                ref_path = f"{run_path}/reference.py"
+                write_result = await client.exec(
+                    f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
+                )
+                if write_result.exit_code != 0:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to write reference: {write_result.stderr}",
+                    )
+
+                # Write custom inputs if provided
+                inputs_path = None
+                if args.inputs:
+                    inputs_code = args.inputs.read_text()
+                    inputs_path = f"{run_path}/custom_inputs.py"
+                    write_result = await client.exec(
+                        f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+                    )
+                    if write_result.exit_code != 0:
+                        return EvaluateResult(
+                            success=False,
+                            all_correct=False,
+                            correctness_score=0.0,
+                            geomean_speedup=0.0,
+                            passed_tests=0,
+                            total_tests=0,
+                            error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                        )
+
+                # Write eval script
+                eval_script_path = f"{run_path}/kernelbench_eval.py"
+                write_result = await client.exec(
+                    f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
+                )
+                if write_result.exit_code != 0:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to write eval script: {write_result.stderr}",
+                    )
+
+                # Write defense module if defensive mode is enabled
+                defense_module_path = None
+                if args.defensive:
+                    defense_path = (
+                        Path(__file__).parent.parent.parent.parent
+                        / "packages"
+                        / "wafer-core"
+                        / "wafer_core"
+                        / "utils"
+                        / "kernel_utils"
+                        / "defense.py"
+                    )
+                    if defense_path.exists():
+                        defense_code = defense_path.read_text()
+                        defense_module_path = f"{run_path}/defense.py"
+                        write_result = await client.exec(
+                            f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                        )
+                        if write_result.exit_code != 0:
+                            print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                            defense_module_path = None
+                    else:
+                        print(f"Warning: defense.py not found at {defense_path}")
+
+                print("Running KernelBench evaluation (AMD/ROCm)...")
+
+                # Find Python with PyTorch - check common locations on RunPod
+                python_exe = "python3"
+                for candidate in [
+                    "/opt/conda/envs/py_3.10/bin/python3",
+                    "/opt/conda/bin/python3",
+                ]:
+                    check = await client.exec(
+                        f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
+                    )
+                    if "OK" in check.stdout:
+                        python_exe = candidate
+                        print(f"Using Python: {python_exe}")
+                        break
+
+                # Build eval command - run directly on host
+                output_path = f"{run_path}/results.json"
+                python_cmd_parts = [
+                    f"{python_exe} {eval_script_path}",
+                    f"--impl {impl_path}",
+                    f"--reference {ref_path}",
+                    f"--output {output_path}",
+                ]
+
+                if args.benchmark:
+                    python_cmd_parts.append("--benchmark")
+                if args.profile:
+                    python_cmd_parts.append("--profile")
+                if inputs_path:
+                    python_cmd_parts.append(f"--inputs {inputs_path}")
+                if args.defensive and defense_module_path:
+                    python_cmd_parts.append("--defensive")
+                    python_cmd_parts.append(f"--defense-module {defense_module_path}")
+                python_cmd_parts.append(f"--seed {args.seed}")
+
+                eval_cmd = " ".join(python_cmd_parts)
+
+                # Set environment for AMD GPU and run
+                env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
+                full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
+
+                # Run and stream output
+                log_lines = []
+                async for line in client.exec_stream(full_cmd):
+                    print(line)
+                    log_lines.append(line)
+
+                # Read results
+                cat_result = await client.exec(f"cat {output_path}")
+
+                if cat_result.exit_code != 0:
+                    log_tail = "\n".join(log_lines[-50:])
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Evaluation failed. Log tail:\n{log_tail}",
+                    )
+
+                # Parse results
+                try:
+                    results_data = json.loads(cat_result.stdout)
+                except json.JSONDecodeError as e:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to parse results: {e}",
+                    )
+
+                # Convert to EvaluateResult
+                correct = results_data.get("correct", False)
+                speedup = results_data.get("speedup", 0.0) or 0.0
+                error = results_data.get("error")
+
+                if error:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=1,
+                        error_message=error,
+                    )
+
+                return EvaluateResult(
+                    success=True,
+                    all_correct=correct,
+                    correctness_score=1.0 if correct else 0.0,
+                    geomean_speedup=speedup,
+                    passed_tests=1 if correct else 0,
+                    total_tests=1,
+                )
+
+    except RunPodError as e:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=f"RunPod error: {e}",
+        )
+
+
+async def run_evaluate_kernelbench_baremetal_amd(
+    args: KernelBenchEvaluateArgs,
+    target: BaremetalTarget,
+) -> EvaluateResult:
+    """Run KernelBench format evaluation directly on AMD baremetal target.
+
+    Runs evaluation script directly on host (no Docker) for AMD GPUs
+    that have PyTorch/ROCm installed.
+    """
+    from datetime import datetime
+
+    from wafer_core.async_ssh import AsyncSSHClient
+
+    REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
+
+    # Select GPU
+    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+    print(f"Connecting to {target.ssh_target}...")
+
+    async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
+        # Create workspace
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        run_dir = f"kernelbench_eval_{timestamp}"
+        run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
+
+        await client.exec(f"mkdir -p {run_path}")
+        print(f"Created run directory: {run_path}")
+
+        # Read and upload files
+        impl_code = args.implementation.read_text()
+        ref_code = args.reference.read_text()
+
+        # Write implementation
+        impl_path = f"{run_path}/implementation.py"
+        write_result = await client.exec(
+            f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write implementation: {write_result.stderr}",
+            )
+
+        # Write reference
+        ref_path = f"{run_path}/reference.py"
+        write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write reference: {write_result.stderr}",
+            )
+
+        # Write custom inputs if provided
+        inputs_path = None
+        if args.inputs:
+            inputs_code = args.inputs.read_text()
+            inputs_path = f"{run_path}/custom_inputs.py"
+            write_result = await client.exec(
+                f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+            )
+            if write_result.exit_code != 0:
+                return EvaluateResult(
+                    success=False,
+                    all_correct=False,
+                    correctness_score=0.0,
+                    geomean_speedup=0.0,
+                    passed_tests=0,
+                    total_tests=0,
+                    error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                )
+
+        # Write eval script
+        eval_script_path = f"{run_path}/kernelbench_eval.py"
+        write_result = await client.exec(
+            f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write eval script: {write_result.stderr}",
+            )
+
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_path = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_path.exists():
+                defense_code = defense_path.read_text()
+                defense_module_path = f"{run_path}/defense.py"
+                write_result = await client.exec(
+                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                )
+                if write_result.exit_code != 0:
+                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                    defense_module_path = None
+            else:
+                print(f"Warning: defense.py not found at {defense_path}")
+
+        print("Running KernelBench evaluation (AMD/ROCm)...")
+
+        # Find Python with PyTorch - check common locations
+        python_exe = "python3"
+        for candidate in [
+            "/opt/conda/envs/py_3.10/bin/python3",
+            "/opt/conda/bin/python3",
+        ]:
+            check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
+            if "OK" in check.stdout:
+                python_exe = candidate
+                print(f"Using Python: {python_exe}")
+                break
+
+        # Build eval command - run directly on host
+        output_path = f"{run_path}/results.json"
+        python_cmd_parts = [
+            f"{python_exe} {eval_script_path}",
+            f"--impl {impl_path}",
+            f"--reference {ref_path}",
+            f"--output {output_path}",
+        ]
+
+        if args.benchmark:
+            python_cmd_parts.append("--benchmark")
+        if args.profile:
+            python_cmd_parts.append("--profile")
+        if inputs_path:
+            python_cmd_parts.append(f"--inputs {inputs_path}")
+        if args.defensive and defense_module_path:
+            python_cmd_parts.append("--defensive")
+            python_cmd_parts.append(f"--defense-module {defense_module_path}")
+        python_cmd_parts.append(f"--seed {args.seed}")
+
+        eval_cmd = " ".join(python_cmd_parts)
+
+        # Set environment for AMD GPU and run
+        env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
+        full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
+
+        # Run and stream output
+        log_lines = []
+        async for line in client.exec_stream(full_cmd):
+            print(line)
+            log_lines.append(line)
+
+        # Read results
+        cat_result = await client.exec(f"cat {output_path}")
+
+        if cat_result.exit_code != 0:
+            log_tail = "\n".join(log_lines[-50:])
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Evaluation failed. Log tail:\n{log_tail}",
+            )
+
+        # Parse results
+        try:
+            results_data = json.loads(cat_result.stdout)
+        except json.JSONDecodeError as e:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to parse results: {e}",
+            )
+
+        # Convert to EvaluateResult
+        correct = results_data.get("correct", False)
+        speedup = results_data.get("speedup", 0.0) or 0.0
+        error = results_data.get("error")
+
+        if error:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=1,
+                error_message=error,
+            )
+
+        return EvaluateResult(
+            success=True,
+            all_correct=correct,
+            correctness_score=1.0 if correct else 0.0,
+            geomean_speedup=speedup,
+            passed_tests=1 if correct else 0,
+            total_tests=1,
+        )
+
+
+async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
+    """Run KernelBench format evaluation on configured target.
+
+    Args:
+        args: KernelBench evaluate arguments
 
     Returns:
         Evaluation result
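Both host-run paths added above (RunPod and AMD baremetal) probe a short list of conda interpreters for a working `import torch` before falling back to plain python3. The sketch below reproduces that probe as a standalone, synchronous function; subprocess stands in for the SSH client used in the diff, the function name is hypothetical, and the candidate paths are the same ones the new code checks.

    import subprocess

    def find_torch_python(candidates: list[str], default: str = "python3") -> str:
        # Return the first candidate interpreter that can import torch, else the default.
        for candidate in candidates:
            try:
                probe = subprocess.run([candidate, "-c", "import torch"], capture_output=True)
            except OSError:
                continue  # interpreter not present at this path
            if probe.returncode == 0:
                return candidate
        return default

    print(find_torch_python(["/opt/conda/envs/py_3.10/bin/python3", "/opt/conda/bin/python3"]))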
@@ -3103,7 +4362,13 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
     if isinstance(target, DigitalOceanTarget):
         # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
         return await run_evaluate_kernelbench_digitalocean(args, target)
+    elif isinstance(target, RunPodTarget):
+        # RunPod AMD MI300X - runs the eval script directly on the pod (PyTorch/ROCm preinstalled)
+        return await run_evaluate_kernelbench_runpod(args, target)
     elif isinstance(target, BaremetalTarget | VMTarget):
+        # Check if this is an AMD target (gfx* compute capability) - run directly
+        if target.compute_capability and target.compute_capability.startswith("gfx"):
+            return await run_evaluate_kernelbench_baremetal_amd(args, target)
         # NVIDIA targets - require docker_image to be set
         if not target.docker_image:
             return EvaluateResult(
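The new Baremetal/VM branch decides between the AMD host-run path and the NVIDIA Docker path purely on whether compute_capability starts with "gfx". A minimal sketch of that predicate, using a hypothetical stand-in dataclass rather than the real wafer_core target classes:

    from dataclasses import dataclass

    @dataclass
    class FakeTarget:
        # Stand-in for BaremetalTarget/VMTarget; only the field the check reads.
        compute_capability: str | None

    def is_amd_host_run(target: FakeTarget) -> bool:
        # Mirrors the check in the diff: gfx* compute capability => AMD, run on host.
        return bool(target.compute_capability) and target.compute_capability.startswith("gfx")

    assert is_amd_host_run(FakeTarget("gfx942"))
    assert not is_amd_host_run(FakeTarget("9.0"))   # NVIDIA-style numeric capability
    assert not is_amd_host_run(FakeTarget(None))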
@@ -3129,6 +4394,6 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateRes
         total_tests=0,
         error_message=(
             f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
-            "Use a DigitalOcean, Baremetal, or VM target."
+            "Use a DigitalOcean, RunPod, Baremetal, or VM target."
         ),
     )
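Taken together, the updated entry point can be driven roughly as follows. This is a hedged usage sketch, not documented API: the keyword names mirror what the new code reads (args.implementation, args.reference, args.inputs, args.seed, and so on), while the exact set of required KernelBenchEvaluateArgs fields, and valid target names, may differ in the released package.

    import asyncio
    from pathlib import Path

    from wafer.evaluate import KernelBenchEvaluateArgs, run_evaluate_kernelbench

    args = KernelBenchEvaluateArgs(
        implementation=Path("implementation.py"),
        reference=Path("reference.py"),
        target_name="mi300x-runpod",       # assumed target name
        benchmark=True,
        inputs=Path("custom_inputs.py"),   # overrides the problem's get_inputs()
        seed=123,                          # reproducibility seed (defaults to 42)
    )

    result = asyncio.run(run_evaluate_kernelbench(args))
    print(result.all_correct, result.geomean_speedup)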