wafer-cli 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/evaluate.py CHANGED
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
 from wafer_core.utils.kernel_utils.targets.config import (
     BaremetalTarget,
     DigitalOceanTarget,
+    LocalTarget,
     ModalTarget,
     RunPodTarget,
     VMTarget,
@@ -351,6 +352,18 @@ def _build_docker_pip_install_cmd(target: BaremetalTarget | VMTarget) -> str:
     return " && ".join(commands)


+def _get_wafer_root() -> Path:
+    """Get wafer monorepo root directory.
+
+    Walks up from this file to find the wafer repo root (contains apps/, packages/).
+    """
+    current = Path(__file__).resolve()
+    for parent in [current] + list(current.parents):
+        if (parent / "apps").is_dir() and (parent / "packages").is_dir():
+            return parent
+    raise RuntimeError(f"Could not find wafer root from {__file__}")
+
+
 async def run_evaluate_docker(
     args: EvaluateArgs,
     target: BaremetalTarget | VMTarget,
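The new _get_wafer_root() helper walks upward from the installed evaluate.py until it finds a directory containing both apps/ and packages/. A minimal sketch of the same walk, assuming a hypothetical checkout at /home/dev/wafer (layout invented for illustration):

    from pathlib import Path

    # Hypothetical monorepo layout:
    #   /home/dev/wafer/apps/wafer-cli/wafer/evaluate.py
    #   /home/dev/wafer/packages/wafer-core/...
    current = Path("/home/dev/wafer/apps/wafer-cli/wafer/evaluate.py")
    for parent in [current] + list(current.parents):
        if (parent / "apps").is_dir() and (parent / "packages").is_dir():
            print(parent)  # /home/dev/wafer
            break

Note that a wheel installed into site-packages has no apps/ or packages/ ancestor, so outside a monorepo checkout the RuntimeError branch would likely be hit.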
@@ -488,7 +501,9 @@ async def run_evaluate_docker(

     # Build pip install command for torch and other deps, plus wafer-core
     pip_install_cmd = _build_docker_pip_install_cmd(target)
-    install_cmd = f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
+    install_cmd = (
+        f"{pip_install_cmd} && uv pip install --system --break-system-packages wafer-core"
+    )

     # Build evaluate command using installed wafer-core module
     python_cmd_parts = [
@@ -626,6 +641,181 @@ async def run_evaluate_docker(
     )


+async def run_evaluate_local(
+    args: EvaluateArgs,
+    target: LocalTarget,
+) -> EvaluateResult:
+    """Run evaluation locally on the current machine.
+
+    For LocalTarget - no SSH needed, runs directly.
+
+    Args:
+        args: Evaluate arguments
+        target: Local target config
+
+    Returns:
+        Evaluation result
+    """
+    import os
+    import subprocess
+    import tempfile
+    from datetime import datetime
+
+    # Select GPU
+    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+    print(f"Running local evaluation on GPU {gpu_id}...")
+
+    # Create temp directory for eval files
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    with tempfile.TemporaryDirectory(prefix=f"wafer_eval_{timestamp}_") as run_path:
+        run_path = Path(run_path)
+
+        # Write implementation
+        impl_path = run_path / "implementation.py"
+        impl_path.write_text(args.implementation.read_text())
+
+        # Write reference
+        ref_path = run_path / "reference.py"
+        ref_path.write_text(args.reference.read_text())
+
+        # Write custom inputs if provided
+        inputs_path = None
+        if args.inputs:
+            inputs_path = run_path / "custom_inputs.py"
+            inputs_path.write_text(args.inputs.read_text())
+
+        # Write eval script
+        eval_script_path = run_path / "kernelbench_eval.py"
+        eval_script_path.write_text(KERNELBENCH_EVAL_SCRIPT)
+
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_src = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_src.exists():
+                defense_module_path = run_path / "defense.py"
+                defense_module_path.write_text(defense_src.read_text())
+            else:
+                print(f"Warning: defense.py not found at {defense_src}")
+
+        # Output file
+        output_path = run_path / "results.json"
+
+        # Build eval command
+        cmd_parts = [
+            "python3",
+            str(eval_script_path),
+            "--impl",
+            str(impl_path),
+            "--reference",
+            str(ref_path),
+            "--output",
+            str(output_path),
+            "--seed",
+            str(args.seed),
+        ]
+
+        if args.benchmark:
+            cmd_parts.append("--benchmark")
+        if args.profile:
+            cmd_parts.append("--profile")
+        if inputs_path:
+            cmd_parts.extend(["--inputs", str(inputs_path)])
+        if args.defensive and defense_module_path:
+            cmd_parts.extend(["--defensive", "--defense-module", str(defense_module_path)])
+
+        # Set environment for GPU selection
+        env = os.environ.copy()
+        if target.vendor == "nvidia":
+            env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+        else:  # AMD
+            env["HIP_VISIBLE_DEVICES"] = str(gpu_id)
+            env["ROCM_PATH"] = "/opt/rocm"
+
+        print(f"Running: {' '.join(cmd_parts[:4])} ...")
+
+        # Run evaluation
+        try:
+            result = subprocess.run(
+                cmd_parts,
+                cwd=str(run_path),
+                env=env,
+                capture_output=True,
+                text=True,
+                timeout=args.timeout or 600,
+            )
+        except subprocess.TimeoutExpired:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message="Evaluation timed out",
+            )
+
+        if result.returncode != 0:
+            error_msg = result.stderr or result.stdout or "Unknown error"
+            # Truncate long errors
+            if len(error_msg) > 1000:
+                error_msg = error_msg[:500] + "\n...\n" + error_msg[-500:]
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Evaluation failed:\n{error_msg}",
+            )
+
+        # Parse results
+        if not output_path.exists():
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message="No results.json produced",
+            )
+
+        try:
+            results = json.loads(output_path.read_text())
+        except json.JSONDecodeError as e:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to parse results: {e}",
+            )
+
+        # Extract results
+        return EvaluateResult(
+            success=True,
+            all_correct=results.get("all_correct", False),
+            correctness_score=results.get("correctness_score", 0.0),
+            geomean_speedup=results.get("geomean_speedup", 0.0),
+            passed_tests=results.get("passed_tests", 0),
+            total_tests=results.get("total_tests", 0),
+            benchmark_results=results.get("benchmark", {}),
+        )
+
+
 async def run_evaluate_ssh(
     args: EvaluateArgs,
     target: BaremetalTarget | VMTarget,
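run_evaluate_local stages everything in a TemporaryDirectory and shells out to python3 rather than going over SSH. A rough usage sketch; the EvaluateArgs and LocalTarget constructors are not shown in this diff, so the keyword arguments below are inferred from the attribute reads above and may not match the real signatures:

    import asyncio
    from pathlib import Path

    # Hypothetical constructors; field names taken from args.* / target.*
    # accesses inside run_evaluate_local.
    args = EvaluateArgs(
        implementation=Path("impl.py"),
        reference=Path("ref.py"),
        inputs=None,
        seed=0,
        benchmark=True,
        profile=False,
        defensive=False,
        timeout=600,
        gpu_id=None,
    )
    target = LocalTarget(gpu_ids=[0], vendor="nvidia")
    result = asyncio.run(run_evaluate_local(args, target))
    print(result.geomean_speedup)

The function picks CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES based on target.vendor, so gpu_ids and vendor are the only target fields it needs for GPU selection.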
@@ -1025,7 +1215,8 @@ print('Files written')
        return

    # Write defense module if defensive mode is enabled
-   if {run_defensive} and "{defense_code_b64}":
+   # NOTE: Check for actual base64 content, not just truthy string (None becomes "None")
+   if {run_defensive} and "{defense_code_b64}" and "{defense_code_b64}" != "None":
        proc = sandbox.exec("python", "-c", f"""
 import base64
 with open('/workspace/defense.py', 'w') as f:
@@ -1072,6 +1263,14 @@ if run_defensive:
         defense = load_fn('defense.py', 'run_all_defenses')
         time_with_defenses = load_fn('defense.py', 'time_execution_with_defenses')
         print('[Defense] Defense module loaded')
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        # These wrappers repack the unpacked args back into a tuple
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
     except Exception as e:
         print(f'[Defense] Warning: Could not load defense module: {{e}}')
         defense = None
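The _wrap_for_defense shim only changes the calling convention: the defense utilities invoke kernel(*args), while functional-format kernels take a single tuple. A self-contained illustration, independent of the defense module:

    def _wrap_for_defense(kernel):
        # Repack the unpacked args into the single tuple the kernel expects
        return lambda *args: kernel(args)

    def add_kernel(inputs):  # functional format: one tuple argument
        a, b = inputs
        return a + b

    wrapped = _wrap_for_defense(add_kernel)
    assert wrapped(1, 2) == add_kernel((1, 2)) == 3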
@@ -1104,30 +1303,30 @@ for tc in test_cases:
     ref_time_ms = 0.0
     if {run_benchmarks}:
         if run_defensive and defense is not None:
-            # Use full defense suite
-            # Run defense checks on implementation kernel
-            all_passed, defense_results, _ = defense(
-                lambda: custom_kernel(inputs),
-            )
+            # Use full defense suite with wrapped kernels
+            # inputs_list unpacks the tuple so defense can infer dtype/device from tensors
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
             if not all_passed:
                 failed = [name for name, passed, _ in defense_results if not passed]
                 raise ValueError(f"Defense checks failed: {{failed}}")

-            # Time with defensive timing
+            # Time with defensive timing (using wrapped kernels)
             impl_times, _ = time_with_defenses(
-                lambda: custom_kernel(inputs),
-                [],
+                custom_kernel_for_defense,
+                inputs_list,
                 num_warmup=3,
                 num_trials=10,
                 verbose=False,
-                run_defenses=False,  # Already ran defenses above
+                run_defenses=False,
             )
             impl_time_ms = sum(impl_times) / len(impl_times)

-            # Reference timing (no defense checks needed)
             ref_times, _ = time_with_defenses(
-                lambda: ref_kernel(inputs),
-                [],
+                ref_kernel_for_defense,
+                inputs_list,
                 num_warmup=3,
                 num_trials=10,
                 verbose=False,
@@ -1409,7 +1608,8 @@ with open("/tmp/reference.py", "w") as f:
 # Write defense module if available
 run_defensive = {run_defensive}
 defense_b64 = "{defense_b64}"
-if run_defensive and defense_b64:
+# NOTE: Check defense_b64 is not empty and not the string "None" (from None formatting)
+if run_defensive and defense_b64 and defense_b64 != "None":
     defense_code = base64.b64decode(defense_b64).decode()
     with open("/tmp/defense.py", "w") as f:
         f.write(defense_code)
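The extra != "None" guard addresses a real f-string pitfall: formatting None into a template produces the four-character string "None", which is truthy in the generated code. A minimal reproduction:

    defense_b64 = None
    generated = f'if True and "{defense_b64}":\n    print("would load defense")'
    print(generated)
    # if True and "None":
    #     print("would load defense")
    # "None" is a non-empty string, so the naive check passes even though
    # no defense payload was supplied; hence the explicit != "None" test.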
@@ -1429,11 +1629,18 @@ import torch

 # Load defense module if available
 defense = None
-if run_defensive and defense_b64:
+if run_defensive and defense_b64 and defense_b64 != "None":
     try:
         defense = load_fn("/tmp/defense.py", "run_all_defenses")
         time_with_defenses = load_fn("/tmp/defense.py", "time_execution_with_defenses")
         print("[Defense] Defense module loaded")
+
+        # Wrap kernels for defense API compatibility
+        # Defense API calls kernel(*args), but functional format expects kernel(inputs_tuple)
+        def _wrap_for_defense(kernel):
+            return lambda *args: kernel(args)
+        custom_kernel_for_defense = _wrap_for_defense(custom_kernel)
+        ref_kernel_for_defense = _wrap_for_defense(ref_kernel)
     except Exception as e:
         print(f"[Defense] Warning: Could not load defense module: {{e}}")
         defense = None
@@ -1466,18 +1673,19 @@ for tc in test_cases:
     ref_time_ms = 0.0
     if {run_benchmarks}:
         if run_defensive and defense is not None:
-            # Use full defense suite
-            all_passed, defense_results, _ = defense(
-                lambda: custom_kernel(inputs),
-            )
+            # Use full defense suite with wrapped kernels
+            inputs_list = list(inputs) if hasattr(inputs, '__iter__') and not isinstance(inputs, torch.Tensor) else [inputs]
+
+            # Run defense checks
+            all_passed, defense_results, _ = defense(custom_kernel_for_defense, *inputs_list)
             if not all_passed:
                 failed = [name for name, passed, _ in defense_results if not passed]
                 raise ValueError(f"Defense checks failed: {{failed}}")

-            # Time with defensive timing
+            # Time with defensive timing (using wrapped kernels)
             impl_times, _ = time_with_defenses(
-                lambda: custom_kernel(inputs),
-                [],
+                custom_kernel_for_defense,
+                inputs_list,
                 num_warmup=3,
                 num_trials=10,
                 verbose=False,
@@ -1485,10 +1693,9 @@ for tc in test_cases:
             )
             impl_time_ms = sum(impl_times) / len(impl_times)

-            # Reference timing
             ref_times, _ = time_with_defenses(
-                lambda: ref_kernel(inputs),
-                [],
+                ref_kernel_for_defense,
+                inputs_list,
                 num_warmup=3,
                 num_trials=10,
                 verbose=False,
@@ -1788,12 +1995,54 @@ async def run_evaluate_runpod(
             error_message=f"Failed to setup Python environment: {e}",
         )

-    # Install wafer-core in remote venv
-    print("Installing wafer-core...")
-    install_result = await client.exec(
-        f"{env_state.venv_bin}/uv pip install wafer-core"
-    )
-    if install_result.exit_code != 0:
+    # Upload wafer-core to remote
+    try:
+        wafer_root = _get_wafer_root()
+        wafer_core_path = wafer_root / "packages" / "wafer-core"
+        print(f"Uploading wafer-core from {wafer_core_path}...")
+
+        wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
+        await client.exec(f"mkdir -p {wafer_core_remote}")
+        wafer_core_workspace = await client.expand_path(wafer_core_remote)
+
+        upload_result = await client.upload_files(
+            str(wafer_core_path), wafer_core_workspace, recursive=True
+        )
+
+        # Wide event logging for upload result
+        upload_event = {
+            "event": "wafer_core_upload",
+            "target": target.name,
+            "target_type": "runpod",
+            "ssh_host": f"{client.user}@{client.host}:{client.port}",
+            "local_path": str(wafer_core_path),
+            "remote_path": wafer_core_workspace,
+            "success": upload_result.success,
+            "files_copied": upload_result.files_copied,
+            "duration_seconds": upload_result.duration_seconds,
+            "error_message": upload_result.error_message,
+        }
+        if upload_result.debug_info:
+            upload_event["debug_info"] = upload_result.debug_info
+        logger.info(json.dumps(upload_event))
+
+        # Fail fast if upload failed
+        if not upload_result.success:
+            print(f"ERROR: Upload failed: {upload_result.error_message}")
+            if upload_result.debug_info:
+                print(f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}")
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
+            )
+
+        print(f"Uploaded {upload_result.files_copied} files")
+    except Exception as e:
         return EvaluateResult(
             success=False,
             all_correct=False,
@@ -1801,7 +2050,7 @@ async def run_evaluate_runpod(
             geomean_speedup=0.0,
             passed_tests=0,
             total_tests=0,
-            error_message=f"Failed to install wafer-core: {install_result.stderr}",
+            error_message=f"Failed to upload wafer-core: {e}",
         )

     # Select GPU (RunPod pods typically have GPU 0)
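Both upload paths emit a single "wide event": one JSON object per upload carrying every field needed to debug it, logged on success and failure alike. With invented values, the line written by logger.info(json.dumps(upload_event)) would look roughly like:

    # Illustrative only; every value below is hypothetical
    # {"event": "wafer_core_upload", "target": "my-pod", "target_type": "runpod",
    #  "ssh_host": "root@203.0.113.7:22",
    #  "local_path": "/home/dev/wafer/packages/wafer-core",
    #  "remote_path": "/workspace/wafer-core", "success": true,
    #  "files_copied": 42, "duration_seconds": 3.1, "error_message": null}

One greppable event name ("wafer_core_upload") covers both outcomes, which is the point of the wide-event style.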
@@ -2098,12 +2347,61 @@ async def run_evaluate_digitalocean(
             error_message=f"Failed to setup Python environment: {e}",
         )

-    # Install wafer-core in remote venv
-    print("Installing wafer-core...")
-    install_result = await client.exec(
-        f"{env_state.venv_bin}/uv pip install wafer-core"
-    )
-    if install_result.exit_code != 0:
+    # Upload wafer-core to remote
+    try:
+        wafer_root = _get_wafer_root()
+        wafer_core_path = wafer_root / "packages" / "wafer-core"
+        print(f"Uploading wafer-core from {wafer_core_path}...")
+
+        wafer_core_remote = f"{REMOTE_WORKSPACE}/wafer-core"
+        await client.exec(f"mkdir -p {wafer_core_remote}")
+        wafer_core_workspace = await client.expand_path(wafer_core_remote)
+
+        # Use SFTP instead of rsync to avoid SSH subprocess timeout issues
+        # (DigitalOcean may rate-limit new SSH connections)
+        upload_result = await client.upload_files(
+            str(wafer_core_path),
+            wafer_core_workspace,
+            recursive=True,
+            use_sftp=True,
+        )
+
+        # Wide event logging for upload result
+        upload_event = {
+            "event": "wafer_core_upload",
+            "target": target.name,
+            "target_type": "digitalocean",
+            "ssh_host": f"{client.user}@{client.host}:{client.port}",
+            "local_path": str(wafer_core_path),
+            "remote_path": wafer_core_workspace,
+            "success": upload_result.success,
+            "files_copied": upload_result.files_copied,
+            "duration_seconds": upload_result.duration_seconds,
+            "error_message": upload_result.error_message,
+        }
+        if upload_result.debug_info:
+            upload_event["debug_info"] = upload_result.debug_info
+        logger.info(json.dumps(upload_event))
+
+        # Fail fast if upload failed
+        if not upload_result.success:
+            print(f"ERROR: Upload failed: {upload_result.error_message}")
+            if upload_result.debug_info:
+                print(
+                    f"Debug info: {json.dumps(upload_result.debug_info, indent=2)}"
+                )
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to upload wafer-core: {upload_result.error_message}",
+            )
+
+        print(f"Uploaded {upload_result.files_copied} files")
+    except Exception as e:
         return EvaluateResult(
             success=False,
             all_correct=False,
@@ -2111,7 +2409,7 @@ async def run_evaluate_digitalocean(
             geomean_speedup=0.0,
             passed_tests=0,
             total_tests=0,
-            error_message=f"Failed to install wafer-core: {install_result.stderr}",
+            error_message=f"Failed to upload wafer-core: {e}",
         )

     # Select GPU (DigitalOcean droplets typically have GPU 0)
@@ -2405,7 +2703,9 @@ async def run_evaluate(args: EvaluateArgs) -> EvaluateResult:
     print(f"Using target: {target_name}")

     # Dispatch to appropriate executor
-    if isinstance(target, BaremetalTarget | VMTarget):
+    if isinstance(target, LocalTarget):
+        return await run_evaluate_local(args, target)
+    elif isinstance(target, BaremetalTarget | VMTarget):
         return await run_evaluate_ssh(args, target)
     elif isinstance(target, ModalTarget):
         return await run_evaluate_modal(args, target)
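The dispatch now checks LocalTarget before the SSH-backed types. One detail worth noting: isinstance with a PEP 604 union (BaremetalTarget | VMTarget) requires Python 3.10+; on older interpreters the equivalent spelling uses a tuple:

    # Equivalent dispatch on any supported Python version:
    if isinstance(target, LocalTarget):
        return await run_evaluate_local(args, target)
    elif isinstance(target, (BaremetalTarget, VMTarget)):
        return await run_evaluate_ssh(args, target)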
@@ -3531,71 +3831,544 @@ async def run_evaluate_kernelbench_digitalocean(
     )


-async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
-    """Run KernelBench format evaluation on configured target.
-
-    Args:
-        args: KernelBench evaluate arguments
+async def run_evaluate_kernelbench_runpod(
+    args: KernelBenchEvaluateArgs,
+    target: RunPodTarget,
+) -> EvaluateResult:
+    """Run KernelBench format evaluation directly on RunPod AMD GPU.

-    Returns:
-        Evaluation result
+    Runs evaluation script directly on host (no Docker) since RunPod pods
+    already have PyTorch/ROCm installed.
     """
-    from .targets import get_default_target, load_target
+    from datetime import datetime

-    # Validate input files
-    err = _validate_kernelbench_files(args)
-    if err:
-        return EvaluateResult(
-            success=False,
-            all_correct=False,
-            correctness_score=0.0,
-            geomean_speedup=0.0,
-            passed_tests=0,
-            total_tests=0,
-            error_message=err,
-        )
+    from wafer_core.async_ssh import AsyncSSHClient
+    from wafer_core.targets.runpod import RunPodError, runpod_ssh_context

-    # Load target
-    target_name = args.target_name
-    if not target_name:
-        target_name = get_default_target()
-        if not target_name:
-            return EvaluateResult(
-                success=False,
-                all_correct=False,
-                correctness_score=0.0,
-                geomean_speedup=0.0,
-                passed_tests=0,
-                total_tests=0,
-                error_message=(
-                    "No target specified and no default set.\n"
-                    "Set up a target first:\n"
-                    "  wafer config targets init ssh --name my-gpu --host user@host:22\n"
-                    "  wafer config targets init runpod --gpu MI300X\n"
-                    "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
-                ),
-            )
+    REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
+
+    # Select GPU
+    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+    print(f"Provisioning RunPod ({target.gpu_type_id})...")

     try:
-        target = load_target(target_name)
-    except FileNotFoundError:
-        return EvaluateResult(
-            success=False,
-            all_correct=False,
-            correctness_score=0.0,
-            geomean_speedup=0.0,
-            passed_tests=0,
-            total_tests=0,
-            error_message=f"Target not found: {target_name}. Run: wafer config targets list",
-        )
+        async with runpod_ssh_context(target) as ssh_info:
+            ssh_target = f"{ssh_info.user}@{ssh_info.host}:{ssh_info.port}"
+            print(f"Connected to RunPod: {ssh_target}")

-    print(f"Using target: {target_name}")
+            async with AsyncSSHClient(ssh_target, target.ssh_key) as client:
+                # Create workspace
+                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                run_dir = f"kernelbench_eval_{timestamp}"
+                run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"

-    # Dispatch to appropriate executor
-    if isinstance(target, DigitalOceanTarget):
-        # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
-        return await run_evaluate_kernelbench_digitalocean(args, target)
+                await client.exec(f"mkdir -p {run_path}")
+                print(f"Created run directory: {run_path}")
+
+                # Read and upload files
+                impl_code = args.implementation.read_text()
+                ref_code = args.reference.read_text()
+
+                # Write implementation
+                impl_path = f"{run_path}/implementation.py"
+                write_result = await client.exec(
+                    f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
+                )
+                if write_result.exit_code != 0:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to write implementation: {write_result.stderr}",
+                    )
+
+                # Write reference
+                ref_path = f"{run_path}/reference.py"
+                write_result = await client.exec(
+                    f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF"
+                )
+                if write_result.exit_code != 0:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to write reference: {write_result.stderr}",
+                    )
+
+                # Write custom inputs if provided
+                inputs_path = None
+                if args.inputs:
+                    inputs_code = args.inputs.read_text()
+                    inputs_path = f"{run_path}/custom_inputs.py"
+                    write_result = await client.exec(
+                        f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+                    )
+                    if write_result.exit_code != 0:
+                        return EvaluateResult(
+                            success=False,
+                            all_correct=False,
+                            correctness_score=0.0,
+                            geomean_speedup=0.0,
+                            passed_tests=0,
+                            total_tests=0,
+                            error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                        )
+
+                # Write eval script
+                eval_script_path = f"{run_path}/kernelbench_eval.py"
+                write_result = await client.exec(
+                    f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
+                )
+                if write_result.exit_code != 0:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to write eval script: {write_result.stderr}",
+                    )
+
+                # Write defense module if defensive mode is enabled
+                defense_module_path = None
+                if args.defensive:
+                    defense_path = (
+                        Path(__file__).parent.parent.parent.parent
+                        / "packages"
+                        / "wafer-core"
+                        / "wafer_core"
+                        / "utils"
+                        / "kernel_utils"
+                        / "defense.py"
+                    )
+                    if defense_path.exists():
+                        defense_code = defense_path.read_text()
+                        defense_module_path = f"{run_path}/defense.py"
+                        write_result = await client.exec(
+                            f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                        )
+                        if write_result.exit_code != 0:
+                            print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                            defense_module_path = None
+                    else:
+                        print(f"Warning: defense.py not found at {defense_path}")
+
+                print("Running KernelBench evaluation (AMD/ROCm)...")
+
+                # Find Python with PyTorch - check common locations on RunPod
+                python_exe = "python3"
+                for candidate in [
+                    "/opt/conda/envs/py_3.10/bin/python3",
+                    "/opt/conda/bin/python3",
+                ]:
+                    check = await client.exec(
+                        f"{candidate} -c 'import torch' 2>/dev/null && echo OK"
+                    )
+                    if "OK" in check.stdout:
+                        python_exe = candidate
+                        print(f"Using Python: {python_exe}")
+                        break
+
+                # Build eval command - run directly on host
+                output_path = f"{run_path}/results.json"
+                python_cmd_parts = [
+                    f"{python_exe} {eval_script_path}",
+                    f"--impl {impl_path}",
+                    f"--reference {ref_path}",
+                    f"--output {output_path}",
+                ]
+
+                if args.benchmark:
+                    python_cmd_parts.append("--benchmark")
+                if args.profile:
+                    python_cmd_parts.append("--profile")
+                if inputs_path:
+                    python_cmd_parts.append(f"--inputs {inputs_path}")
+                if args.defensive and defense_module_path:
+                    python_cmd_parts.append("--defensive")
+                    python_cmd_parts.append(f"--defense-module {defense_module_path}")
+                python_cmd_parts.append(f"--seed {args.seed}")
+
+                eval_cmd = " ".join(python_cmd_parts)
+
+                # Set environment for AMD GPU and run
+                env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
+                full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
+
+                # Run and stream output
+                log_lines = []
+                async for line in client.exec_stream(full_cmd):
+                    print(line)
+                    log_lines.append(line)
+
+                # Read results
+                cat_result = await client.exec(f"cat {output_path}")
+
+                if cat_result.exit_code != 0:
+                    log_tail = "\n".join(log_lines[-50:])
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Evaluation failed. Log tail:\n{log_tail}",
+                    )
+
+                # Parse results
+                try:
+                    results_data = json.loads(cat_result.stdout)
+                except json.JSONDecodeError as e:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=0,
+                        error_message=f"Failed to parse results: {e}",
+                    )
+
+                # Convert to EvaluateResult
+                correct = results_data.get("correct", False)
+                speedup = results_data.get("speedup", 0.0) or 0.0
+                error = results_data.get("error")
+
+                if error:
+                    return EvaluateResult(
+                        success=False,
+                        all_correct=False,
+                        correctness_score=0.0,
+                        geomean_speedup=0.0,
+                        passed_tests=0,
+                        total_tests=1,
+                        error_message=error,
+                    )
+
+                return EvaluateResult(
+                    success=True,
+                    all_correct=correct,
+                    correctness_score=1.0 if correct else 0.0,
+                    geomean_speedup=speedup,
+                    passed_tests=1 if correct else 0,
+                    total_tests=1,
+                )
+
+    except RunPodError as e:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=f"RunPod error: {e}",
+        )
+
+
+async def run_evaluate_kernelbench_baremetal_amd(
+    args: KernelBenchEvaluateArgs,
+    target: BaremetalTarget,
+) -> EvaluateResult:
+    """Run KernelBench format evaluation directly on AMD baremetal target.
+
+    Runs evaluation script directly on host (no Docker) for AMD GPUs
+    that have PyTorch/ROCm installed.
+    """
+    from datetime import datetime
+
+    from wafer_core.async_ssh import AsyncSSHClient
+
+    REMOTE_WORKSPACE_BASE = "/tmp/wafer_eval"
+
+    # Select GPU
+    gpu_id = args.gpu_id if args.gpu_id is not None else target.gpu_ids[0]
+
+    print(f"Connecting to {target.ssh_target}...")
+
+    async with AsyncSSHClient(target.ssh_target, target.ssh_key) as client:
+        # Create workspace
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        run_dir = f"kernelbench_eval_{timestamp}"
+        run_path = f"{REMOTE_WORKSPACE_BASE}/{run_dir}"
+
+        await client.exec(f"mkdir -p {run_path}")
+        print(f"Created run directory: {run_path}")
+
+        # Read and upload files
+        impl_code = args.implementation.read_text()
+        ref_code = args.reference.read_text()
+
+        # Write implementation
+        impl_path = f"{run_path}/implementation.py"
+        write_result = await client.exec(
+            f"cat > '{impl_path}' << 'IMPL_EOF'\n{impl_code}\nIMPL_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write implementation: {write_result.stderr}",
+            )
+
+        # Write reference
+        ref_path = f"{run_path}/reference.py"
+        write_result = await client.exec(f"cat > '{ref_path}' << 'REF_EOF'\n{ref_code}\nREF_EOF")
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write reference: {write_result.stderr}",
+            )
+
+        # Write custom inputs if provided
+        inputs_path = None
+        if args.inputs:
+            inputs_code = args.inputs.read_text()
+            inputs_path = f"{run_path}/custom_inputs.py"
+            write_result = await client.exec(
+                f"cat > '{inputs_path}' << 'INPUTS_EOF'\n{inputs_code}\nINPUTS_EOF"
+            )
+            if write_result.exit_code != 0:
+                return EvaluateResult(
+                    success=False,
+                    all_correct=False,
+                    correctness_score=0.0,
+                    geomean_speedup=0.0,
+                    passed_tests=0,
+                    total_tests=0,
+                    error_message=f"Failed to write custom inputs: {write_result.stderr}",
+                )
+
+        # Write eval script
+        eval_script_path = f"{run_path}/kernelbench_eval.py"
+        write_result = await client.exec(
+            f"cat > '{eval_script_path}' << 'EVAL_EOF'\n{KERNELBENCH_EVAL_SCRIPT}\nEVAL_EOF"
+        )
+        if write_result.exit_code != 0:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to write eval script: {write_result.stderr}",
+            )
+
+        # Write defense module if defensive mode is enabled
+        defense_module_path = None
+        if args.defensive:
+            defense_path = (
+                Path(__file__).parent.parent.parent.parent
+                / "packages"
+                / "wafer-core"
+                / "wafer_core"
+                / "utils"
+                / "kernel_utils"
+                / "defense.py"
+            )
+            if defense_path.exists():
+                defense_code = defense_path.read_text()
+                defense_module_path = f"{run_path}/defense.py"
+                write_result = await client.exec(
+                    f"cat > '{defense_module_path}' << 'DEFENSE_EOF'\n{defense_code}\nDEFENSE_EOF"
+                )
+                if write_result.exit_code != 0:
+                    print(f"Warning: Failed to write defense module: {write_result.stderr}")
+                    defense_module_path = None
+            else:
+                print(f"Warning: defense.py not found at {defense_path}")
+
+        print("Running KernelBench evaluation (AMD/ROCm)...")
+
+        # Find Python with PyTorch - check common locations
+        python_exe = "python3"
+        for candidate in [
+            "/opt/conda/envs/py_3.10/bin/python3",
+            "/opt/conda/bin/python3",
+        ]:
+            check = await client.exec(f"{candidate} -c 'import torch' 2>/dev/null && echo OK")
+            if "OK" in check.stdout:
+                python_exe = candidate
+                print(f"Using Python: {python_exe}")
+                break
+
+        # Build eval command - run directly on host
+        output_path = f"{run_path}/results.json"
+        python_cmd_parts = [
+            f"{python_exe} {eval_script_path}",
+            f"--impl {impl_path}",
+            f"--reference {ref_path}",
+            f"--output {output_path}",
+        ]
+
+        if args.benchmark:
+            python_cmd_parts.append("--benchmark")
+        if args.profile:
+            python_cmd_parts.append("--profile")
+        if inputs_path:
+            python_cmd_parts.append(f"--inputs {inputs_path}")
+        if args.defensive and defense_module_path:
+            python_cmd_parts.append("--defensive")
+            python_cmd_parts.append(f"--defense-module {defense_module_path}")
+        python_cmd_parts.append(f"--seed {args.seed}")
+
+        eval_cmd = " ".join(python_cmd_parts)
+
+        # Set environment for AMD GPU and run
+        env_vars = f"HIP_VISIBLE_DEVICES={gpu_id} ROCM_PATH=/opt/rocm PYTHONUNBUFFERED=1"
+        full_cmd = f"cd {run_path} && {env_vars} {eval_cmd}"
+
+        # Run and stream output
+        log_lines = []
+        async for line in client.exec_stream(full_cmd):
+            print(line)
+            log_lines.append(line)
+
+        # Read results
+        cat_result = await client.exec(f"cat {output_path}")
+
+        if cat_result.exit_code != 0:
+            log_tail = "\n".join(log_lines[-50:])
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Evaluation failed. Log tail:\n{log_tail}",
+            )
+
+        # Parse results
+        try:
+            results_data = json.loads(cat_result.stdout)
+        except json.JSONDecodeError as e:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=f"Failed to parse results: {e}",
+            )
+
+        # Convert to EvaluateResult
+        correct = results_data.get("correct", False)
+        speedup = results_data.get("speedup", 0.0) or 0.0
+        error = results_data.get("error")
+
+        if error:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=1,
+                error_message=error,
+            )
+
+        return EvaluateResult(
+            success=True,
+            all_correct=correct,
+            correctness_score=1.0 if correct else 0.0,
+            geomean_speedup=speedup,
+            passed_tests=1 if correct else 0,
+            total_tests=1,
+        )
+
+
+async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
+    """Run KernelBench format evaluation on configured target.
+
+    Args:
+        args: KernelBench evaluate arguments
+
+    Returns:
+        Evaluation result
+    """
+    from .targets import get_default_target, load_target
+
+    # Validate input files
+    err = _validate_kernelbench_files(args)
+    if err:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=err,
+        )
+
+    # Load target
+    target_name = args.target_name
+    if not target_name:
+        target_name = get_default_target()
+        if not target_name:
+            return EvaluateResult(
+                success=False,
+                all_correct=False,
+                correctness_score=0.0,
+                geomean_speedup=0.0,
+                passed_tests=0,
+                total_tests=0,
+                error_message=(
+                    "No target specified and no default set.\n"
+                    "Set up a target first:\n"
+                    "  wafer config targets init ssh --name my-gpu --host user@host:22\n"
+                    "  wafer config targets init runpod --gpu MI300X\n"
+                    "Then use: --target my-gpu (or set default: wafer config targets default my-gpu)"
+                ),
+            )
+
+    try:
+        target = load_target(target_name)
+    except FileNotFoundError:
+        return EvaluateResult(
+            success=False,
+            all_correct=False,
+            correctness_score=0.0,
+            geomean_speedup=0.0,
+            passed_tests=0,
+            total_tests=0,
+            error_message=f"Target not found: {target_name}. Run: wafer config targets list",
+        )
+
+    print(f"Using target: {target_name}")
+
+    # Dispatch to appropriate executor
+    if isinstance(target, DigitalOceanTarget):
+        # DigitalOcean AMD MI300X - uses ROCm Docker with device passthrough
+        return await run_evaluate_kernelbench_digitalocean(args, target)
+    elif isinstance(target, RunPodTarget):
+        # RunPod AMD MI300X - uses ROCm Docker with device passthrough
+        return await run_evaluate_kernelbench_runpod(args, target)
     elif isinstance(target, BaremetalTarget | VMTarget):
+        # Check if this is an AMD target (gfx* compute capability) - run directly
+        if target.compute_capability and target.compute_capability.startswith("gfx"):
+            return await run_evaluate_kernelbench_baremetal_amd(args, target)
         # NVIDIA targets - require docker_image to be set
         if not target.docker_image:
             return EvaluateResult(
@@ -3621,6 +4394,6 @@ async def run_evaluate_kernelbench(args: KernelBenchEvaluateArgs) -> EvaluateResult:
             total_tests=0,
             error_message=(
                 f"Target type '{type(target).__name__}' not yet supported for KernelBench format. "
-                "Use a DigitalOcean, Baremetal, or VM target."
+                "Use a DigitalOcean, RunPod, Baremetal, or VM target."
             ),
         )
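The AMD detection in the dispatch above keys off the compute_capability string: NVIDIA capabilities are numeric ("9.0", "8.6"), while AMD architectures are gfx-prefixed ("gfx942" for MI300X, "gfx90a" for MI250). A standalone sketch of the same predicate:

    def is_amd_target(compute_capability: str | None) -> bool:
        # "gfx942" -> True; "9.0" or None -> False
        return bool(compute_capability) and compute_capability.startswith("gfx")

    assert is_amd_target("gfx942")
    assert not is_amd_target("9.0")
    assert not is_amd_target(None)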