wafer-cli 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/baseline.py ADDED
@@ -0,0 +1,661 @@
1
+ """Baseline CLI commands.
2
+
3
+ Discover what kernel PyTorch dispatches to for a given operation.
4
+ Helps understand the baseline performance you need to beat.
5
+ """
6
+
7
+ import asyncio
8
+
9
+ import typer
10
+
11
+ from wafer_core.tools.dispatch_baseline.client import (
12
+ lookup_baseline,
13
+ store_baseline,
14
+ )
15
+ from wafer_core.tools.dispatch_baseline.codegen import (
16
+ parse_op_string,
17
+ update_dtypes,
18
+ update_shapes,
19
+ )
20
+ from wafer_core.tools.dispatch_baseline.dtypes import KernelTraceConfig
21
+ from wafer_core.tools.dispatch_baseline.executor import trace_kernel_local
22
+ from wafer_core.tools.dispatch_baseline.roofline import HARDWARE_SPECS, get_hardware_spec
23
+
24
+ baseline_app = typer.Typer(
25
+ help="""Discover what kernel PyTorch dispatches to for a given operation.
26
+
27
+ This helps you understand the baseline performance you need to beat when writing
28
+ custom kernels. Run a PyTorch op, profile it, and see:
29
+ - What kernel PyTorch uses (cuBLAS, cuDNN, Triton, etc.)
30
+ - How fast it runs
31
+ - What % of peak hardware performance it achieves
32
+
33
+ Results are stored in a shared database - once traced, everyone benefits.
34
+
35
+ Examples:
36
+ # Run baseline trace
37
+ wafer baseline run "torch.matmul(A, B)" -s A=4096,4096 -s B=4096,4096 --target b200-dev
38
+
39
+ # Show supported hardware
40
+ wafer baseline hardware"""
41
+ )
42
+
43
+
44
+ def _parse_shape(shape_str: str) -> tuple[str, tuple[int, ...]]:
45
+ """Parse shape string like 'A=4096,4096' into (name, shape)."""
46
+ if "=" not in shape_str:
47
+ raise typer.BadParameter(f"Invalid shape format: {shape_str}. Expected: name=dim1,dim2,...")
48
+
49
+ name, dims_str = shape_str.split("=", 1)
50
+ try:
51
+ dims = tuple(int(d.strip()) for d in dims_str.split(","))
52
+ except ValueError:
53
+ raise typer.BadParameter(f"Invalid dimensions in shape: {dims_str}")
54
+
55
+ return name.strip(), dims
56
+
57
+
58
+ def _complete_target_name(incomplete: str) -> list[str]:
59
+ """Autocomplete target names from ~/.wafer/targets/*.toml"""
60
+ from pathlib import Path
61
+
62
+ targets_dir = Path.home() / ".wafer" / "targets"
63
+ if not targets_dir.exists():
64
+ return []
65
+ return [f.stem for f in targets_dir.glob("*.toml") if f.stem.startswith(incomplete)]
66
+
67
+
68
+ @baseline_app.command("run")
69
+ def baseline_run_cmd(
70
+ op: str = typer.Argument(
71
+ ...,
72
+ help='PyTorch operation to trace, e.g., "torch.matmul(A, B)"',
73
+ ),
74
+ shape: list[str] = typer.Option(
75
+ [],
76
+ "--shape",
77
+ "-s",
78
+ help="Tensor shape as name=dim1,dim2,... (can specify multiple)",
79
+ ),
80
+ dtype: str = typer.Option(
81
+ "float16",
82
+ "--dtype",
83
+ "-d",
84
+ help="Data type for tensors (float16, float32, bfloat16, etc.)",
85
+ ),
86
+ hardware: str = typer.Option(
87
+ None,
88
+ "--hardware",
89
+ help="Hardware name for roofline analysis (auto-detected from target if not specified)",
90
+ ),
91
+ target: str = typer.Option(
92
+ None,
93
+ "--target",
94
+ "-t",
95
+ help="GPU target name (see 'wafer config targets list')",
96
+ autocompletion=_complete_target_name,
97
+ ),
98
+ workspace: str = typer.Option(
99
+ None,
100
+ "--workspace",
101
+ "-w",
102
+ help="Workspace name (see 'wafer workspaces list')",
103
+ ),
104
+ num_warmup: int = typer.Option(
105
+ 10,
106
+ "--warmup",
107
+ help="Number of warmup iterations",
108
+ ),
109
+ num_runs: int = typer.Option(
110
+ 100,
111
+ "--runs",
112
+ help="Number of profiling runs",
113
+ ),
114
+ no_cache: bool = typer.Option(
115
+ False,
116
+ "--no-cache",
117
+ help="Skip cache and always run fresh trace",
118
+ ),
119
+ json_output: bool = typer.Option(
120
+ False,
121
+ "--json",
122
+ help="Output as JSON for programmatic use",
123
+ ),
124
+ verbose: bool = typer.Option(
125
+ False,
126
+ "--verbose",
127
+ "-v",
128
+ help="Show verbose output including raw profiler data",
129
+ ),
130
+ timeout: int = typer.Option(
131
+ 120,
132
+ "--timeout",
133
+ help="Timeout in seconds for profiling (default: 120)",
134
+ ),
135
+ ) -> None:
136
+ """Discover what kernel PyTorch dispatches to for a given operation.
137
+
138
+ This runs the operation on your GPU with profiling and reports:
139
+ - Which kernel(s) PyTorch dispatches to
140
+ - Duration of each kernel
141
+ - Library that provides the kernel (cuBLAS, cuDNN, etc.)
142
+ - Roofline analysis (% of peak compute/memory bandwidth)
143
+
144
+ Examples:
145
+ # Run on a target
146
+ wafer baseline run "torch.matmul(A, B)" -s A=4096,4096 -s B=4096,4096 --target b200-dev
147
+
148
+ # Run on a workspace
149
+ wafer baseline run "torch.matmul(A, B)" -s A=4096,4096 -s B=4096,4096 --workspace cutlass-b200-eval
150
+
151
+ # Run locally (requires local GPU)
152
+ wafer baseline run "torch.matmul(A, B)" -s A=4096,4096 -s B=4096,4096
153
+
154
+ # With specific hardware for roofline
155
+ wafer baseline run "torch.matmul(A, B)" -s A=4096,4096 -s B=4096,4096 --target b200-dev --hardware B200
156
+ """
157
+ # Validate mutually exclusive options
158
+ if target and workspace:
159
+ typer.echo("Error: Cannot specify both --target and --workspace", err=True)
160
+ raise typer.Exit(1)
161
+
162
+ # Dispatch to appropriate execution mode
163
+ if target:
164
+ asyncio.run(_run_on_target(
165
+ op, shape, dtype, hardware, target, num_warmup, num_runs, no_cache, json_output, verbose, timeout
166
+ ))
167
+ elif workspace:
168
+ asyncio.run(_run_on_workspace(
169
+ op, shape, dtype, hardware, workspace, num_warmup, num_runs, no_cache, json_output, verbose, timeout
170
+ ))
171
+ else:
172
+ _run_locally(op, shape, dtype, hardware, num_warmup, num_runs, no_cache, json_output, verbose, timeout)
173
+
174
+
175
+ def _run_locally(
176
+ op: str,
177
+ shape: list[str],
178
+ dtype: str,
179
+ hardware: str | None,
180
+ num_warmup: int,
181
+ num_runs: int,
182
+ no_cache: bool,
183
+ json_output: bool,
184
+ verbose: bool,
185
+ timeout: int,
186
+ ) -> None:
187
+ """Run baseline trace on local GPU."""
188
+ import torch
189
+
190
+ # Check CUDA availability
191
+ if not torch.cuda.is_available():
192
+ typer.echo("Error: CUDA not available on this machine", err=True)
193
+ typer.echo("Use --target or --workspace to run on a remote GPU.", err=True)
194
+ raise typer.Exit(1)
195
+
196
+ # Auto-detect hardware if not specified
197
+ if hardware is None:
198
+ hardware = _detect_local_hardware()
199
+ if hardware:
200
+ if not json_output:
201
+ typer.echo(f"Auto-detected hardware: {hardware}")
202
+ else:
203
+ gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown"
204
+ if not json_output:
205
+ typer.echo(f"Warning: No roofline specs for '{gpu_name}'", err=True)
206
+ typer.echo(f"Supported hardware: {', '.join(HARDWARE_SPECS.keys())}", err=True)
207
+ typer.echo("Roofline analysis will be skipped.", err=True)
208
+ typer.echo("")
209
+
210
+ # Parse operation
211
+ try:
212
+ op_spec = parse_op_string(op)
213
+ except ValueError as e:
214
+ typer.echo(f"Error parsing operation: {e}", err=True)
215
+ raise typer.Exit(1)
216
+
217
+ # Parse shapes
218
+ shapes: dict[str, tuple[int, ...]] = {}
219
+ for shape_str in shape:
220
+ try:
221
+ name, dims = _parse_shape(shape_str)
222
+ shapes[name] = dims
223
+ except typer.BadParameter as e:
224
+ typer.echo(f"Error: {e}", err=True)
225
+ raise typer.Exit(1)
226
+
227
+ # Update op_spec with shapes and dtype
228
+ if shapes:
229
+ op_spec = update_shapes(op_spec, shapes)
230
+ op_spec = update_dtypes(op_spec, dtype)
231
+
232
+ # Validate hardware
233
+ hw_spec = get_hardware_spec(hardware)
234
+ if hw_spec is None:
235
+ typer.echo(f"Warning: Unknown hardware '{hardware}', roofline analysis will be skipped", err=True)
236
+ typer.echo(f"Supported hardware: {', '.join(HARDWARE_SPECS.keys())}", err=True)
237
+
238
+ # Get current environment for cache lookup
239
+ pytorch_version = torch.__version__
240
+ props = torch.cuda.get_device_properties(0)
241
+
242
+ # Detect runtime version and architecture (CUDA vs ROCm)
243
+ if hasattr(torch.version, 'hip') and torch.version.hip:
244
+ runtime_version = torch.version.hip
245
+ gpu_arch = getattr(props, 'gcnArchName', f"gfx{props.major}{props.minor}")
246
+ else:
247
+ runtime_version = torch.version.cuda or "unknown"
248
+ gpu_arch = f"sm_{props.major}{props.minor}"
249
+
250
+ # Check cache first (unless --no-cache)
251
+ from_cache = False
252
+ if not no_cache:
253
+ cached = lookup_baseline(op_spec, hardware, pytorch_version, runtime_version, gpu_arch)
254
+ if cached is not None:
255
+ from_cache = True
256
+ # Re-compute roofline with current hardware specs (in case they've been updated)
257
+ config = KernelTraceConfig(op_spec=op_spec, hardware=hardware, num_warmup=0, num_runs=0)
258
+ from wafer_core.tools.dispatch_baseline.executor import _add_roofline_analysis
259
+ result = _add_roofline_analysis(cached, config)
260
+ if not json_output:
261
+ typer.echo(f"Using cached result (key: {pytorch_version}/{runtime_version}/{gpu_arch})")
262
+ typer.echo("")
263
+
264
+ if not from_cache:
265
+ # Create config
266
+ config = KernelTraceConfig(
267
+ op_spec=op_spec,
268
+ hardware=hardware,
269
+ num_warmup=num_warmup,
270
+ num_runs=num_runs,
271
+ timeout_seconds=timeout,
272
+ )
273
+
274
+ # Run trace
275
+ if not json_output:
276
+ typer.echo(f"Profiling: {op_spec}")
277
+ typer.echo(f"Hardware: {hardware}")
278
+ typer.echo("")
279
+
280
+ exec_result = trace_kernel_local(config)
281
+ result = exec_result.result
282
+
283
+ # Cache the result
284
+ if not result.error:
285
+ store_baseline(
286
+ result,
287
+ exec_result.pytorch_version,
288
+ exec_result.runtime_version,
289
+ exec_result.gpu_arch,
290
+ )
291
+
292
+ # Output results
293
+ _output_result(result, json_output, verbose, from_cache)
294
+
295
+
296
+ def _detect_local_hardware() -> str | None:
297
+ """Detect GPU hardware name from local CUDA device.
298
+
299
+ Only returns hardware names that we have specs for (B200, MI300X).
300
+ Returns None for unsupported hardware.
301
+ """
302
+ import torch
303
+
304
+ if not torch.cuda.is_available():
305
+ return None
306
+
307
+ gpu_name = torch.cuda.get_device_name(0).upper()
308
+
309
+ # Only return hardware we have roofline specs for
310
+ if "B200" in gpu_name:
311
+ return "B200"
312
+ elif "MI300X" in gpu_name:
313
+ return "MI300X"
314
+ else:
315
+ return None # Unsupported hardware
316
+
317
+
318
+ def _detect_hardware_from_target(target_config) -> str | None:
319
+ """Detect hardware from target configuration.
320
+
321
+ Only returns hardware names that we have specs for (B200, MI300X).
322
+ """
323
+ gpu_type = getattr(target_config, "gpu_type", None)
324
+ if gpu_type:
325
+ gpu_upper = gpu_type.upper()
326
+ if gpu_upper in HARDWARE_SPECS:
327
+ return gpu_upper
328
+ return None
329
+
330
+
331
+ async def _run_on_target(
332
+ op: str,
333
+ shape: list[str],
334
+ dtype: str,
335
+ hardware: str | None,
336
+ target_name: str,
337
+ num_warmup: int,
338
+ num_runs: int,
339
+ no_cache: bool,
340
+ json_output: bool,
341
+ verbose: bool,
342
+ timeout: int,
343
+ ) -> None:
344
+ """Run baseline trace on a configured target via SSH."""
345
+ from wafer_core.ssh import SSHClient
346
+ from wafer_core.tools.dispatch_baseline.codegen import generate_trace_script
347
+ from wafer_core.tools.dispatch_baseline.executor import trace_kernel_remote
348
+
349
+ from .targets import load_target
350
+ from .targets_ops import TargetExecError, get_target_ssh_info
351
+
352
+ # Load target config
353
+ try:
354
+ target_config = load_target(target_name)
355
+ except FileNotFoundError:
356
+ typer.echo(f"Error: Target '{target_name}' not found", err=True)
357
+ typer.echo("Run 'wafer config targets list' to see available targets", err=True)
358
+ raise typer.Exit(1)
359
+
360
+ # Auto-detect hardware from target if not specified
361
+ if hardware is None:
362
+ hardware = _detect_hardware_from_target(target_config)
363
+ if hardware:
364
+ if not json_output:
365
+ typer.echo(f"Auto-detected hardware from target: {hardware}")
366
+ else:
367
+ if not json_output:
368
+ typer.echo(f"Warning: No roofline specs for target's GPU", err=True)
369
+ typer.echo(f"Supported hardware: {', '.join(HARDWARE_SPECS.keys())}", err=True)
370
+ typer.echo("Roofline analysis will be skipped.", err=True)
371
+ typer.echo("")
372
+
373
+ # Get SSH info
374
+ try:
375
+ ssh_info = await get_target_ssh_info(target_config)
376
+ except TargetExecError as e:
377
+ typer.echo(f"Error: {e}", err=True)
378
+ raise typer.Exit(1)
379
+
380
+ # Parse operation and create config
381
+ try:
382
+ op_spec = parse_op_string(op)
383
+ except ValueError as e:
384
+ typer.echo(f"Error parsing operation: {e}", err=True)
385
+ raise typer.Exit(1)
386
+
387
+ shapes: dict[str, tuple[int, ...]] = {}
388
+ for shape_str in shape:
389
+ try:
390
+ name, dims = _parse_shape(shape_str)
391
+ shapes[name] = dims
392
+ except typer.BadParameter as e:
393
+ typer.echo(f"Error: {e}", err=True)
394
+ raise typer.Exit(1)
395
+
396
+ if shapes:
397
+ op_spec = update_shapes(op_spec, shapes)
398
+ op_spec = update_dtypes(op_spec, dtype)
399
+
400
+ config = KernelTraceConfig(
401
+ op_spec=op_spec,
402
+ hardware=hardware,
403
+ num_warmup=num_warmup,
404
+ num_runs=num_runs,
405
+ )
406
+
407
+ if not json_output:
408
+ typer.echo(f"Profiling: {op_spec}")
409
+ typer.echo(f"Target: {target_name}")
410
+ typer.echo(f"Hardware: {hardware}")
411
+ typer.echo("")
412
+
413
+ # Create SSH client and run trace
414
+ ssh_client = SSHClient(
415
+ host=ssh_info.host,
416
+ port=ssh_info.port,
417
+ username=ssh_info.user,
418
+ key_path=str(ssh_info.key_path),
419
+ )
420
+
421
+ try:
422
+ ssh_client.connect()
423
+ exec_result = trace_kernel_remote(config, ssh_client)
424
+ result = exec_result.result
425
+
426
+ # Cache the result
427
+ if not result.error and not no_cache:
428
+ store_baseline(
429
+ result,
430
+ exec_result.pytorch_version,
431
+ exec_result.runtime_version,
432
+ exec_result.gpu_arch,
433
+ )
434
+ finally:
435
+ ssh_client.close()
436
+
437
+ _output_result(result, json_output, verbose, from_cache=False)
438
+
439
+
440
+ async def _run_on_workspace(
441
+ op: str,
442
+ shape: list[str],
443
+ dtype: str,
444
+ hardware: str | None,
445
+ workspace_name: str,
446
+ num_warmup: int,
447
+ num_runs: int,
448
+ no_cache: bool,
449
+ json_output: bool,
450
+ verbose: bool,
451
+ timeout: int,
452
+ ) -> None:
453
+ """Run baseline trace on a workspace."""
454
+ import subprocess
455
+ import tempfile
456
+ from pathlib import Path
457
+
458
+ from wafer_core.tools.dispatch_baseline.analyzer import parse_trace_output
459
+ from wafer_core.tools.dispatch_baseline.codegen import generate_trace_script
460
+ from wafer_core.tools.dispatch_baseline.executor import _add_roofline_analysis
461
+
462
+ # Parse operation and create config
463
+ try:
464
+ op_spec = parse_op_string(op)
465
+ except ValueError as e:
466
+ typer.echo(f"Error parsing operation: {e}", err=True)
467
+ raise typer.Exit(1)
468
+
469
+ shapes: dict[str, tuple[int, ...]] = {}
470
+ for shape_str in shape:
471
+ try:
472
+ name, dims = _parse_shape(shape_str)
473
+ shapes[name] = dims
474
+ except typer.BadParameter as e:
475
+ typer.echo(f"Error: {e}", err=True)
476
+ raise typer.Exit(1)
477
+
478
+ if shapes:
479
+ op_spec = update_shapes(op_spec, shapes)
480
+ op_spec = update_dtypes(op_spec, dtype)
481
+
482
+ # Default hardware for workspaces (can be overridden)
483
+ if hardware is None:
484
+ # Try to detect from workspace name (only supported hardware)
485
+ ws_lower = workspace_name.lower()
486
+ if "b200" in ws_lower:
487
+ hardware = "B200"
488
+ elif "mi300" in ws_lower:
489
+ hardware = "MI300X"
490
+ else:
491
+ hardware = None
492
+
493
+ if hardware:
494
+ if not json_output:
495
+ typer.echo(f"Auto-detected hardware from workspace name: {hardware}")
496
+ else:
497
+ if not json_output:
498
+ typer.echo(f"Warning: Could not detect hardware from workspace name '{workspace_name}'", err=True)
499
+ typer.echo(f"Supported hardware: {', '.join(HARDWARE_SPECS.keys())}", err=True)
500
+ typer.echo("Roofline analysis will be skipped.", err=True)
501
+ typer.echo("")
502
+
503
+ config = KernelTraceConfig(
504
+ op_spec=op_spec,
505
+ hardware=hardware,
506
+ num_warmup=num_warmup,
507
+ num_runs=num_runs,
508
+ )
509
+
510
+ if not json_output:
511
+ typer.echo(f"Profiling: {op_spec}")
512
+ typer.echo(f"Workspace: {workspace_name}")
513
+ typer.echo(f"Hardware: {hardware}")
514
+ typer.echo("")
515
+
516
+ # Generate script
517
+ script = generate_trace_script(config)
518
+
519
+ # Write to temp file and sync to workspace
520
+ with tempfile.TemporaryDirectory() as tmpdir:
521
+ script_path = Path(tmpdir) / "baseline_trace.py"
522
+ script_path.write_text(script)
523
+
524
+ # Sync to workspace using wafer CLI
525
+ sync_result = subprocess.run(
526
+ ["wafer", "workspaces", "sync", workspace_name, str(tmpdir)],
527
+ capture_output=True,
528
+ text=True,
529
+ )
530
+ if sync_result.returncode != 0:
531
+ typer.echo(f"Error syncing to workspace: {sync_result.stderr}", err=True)
532
+ raise typer.Exit(1)
533
+
534
+ # Execute on workspace
535
+ exec_result = subprocess.run(
536
+ ["wafer", "workspaces", "exec", "--timeout", str(timeout), workspace_name,
537
+ "python /workspace/baseline_trace.py"],
538
+ capture_output=True,
539
+ text=True,
540
+ )
541
+
542
+ output = exec_result.stdout + exec_result.stderr
543
+
544
+ # Parse result
545
+ parsed = parse_trace_output(output, op_spec, hardware)
546
+ result = _add_roofline_analysis(parsed.result, config)
547
+
548
+ # Cache the result
549
+ if not result.error and not no_cache:
550
+ store_baseline(
551
+ result,
552
+ parsed.pytorch_version,
553
+ parsed.runtime_version,
554
+ parsed.gpu_arch,
555
+ )
556
+
557
+ _output_result(result, json_output, verbose, from_cache=False)
558
+
559
+
560
+ def _output_result(result, json_output: bool, verbose: bool, from_cache: bool = False) -> None:
561
+ """Output trace result in the requested format."""
562
+ if json_output:
563
+ import json
564
+
565
+ output = {
566
+ "op": str(result.op_spec),
567
+ "hardware": result.hardware,
568
+ "total_duration_us": result.total_duration_us,
569
+ "from_cache": from_cache,
570
+ "kernels": [
571
+ {
572
+ "name": k.name,
573
+ "duration_us": k.duration_us,
574
+ }
575
+ for k in result.kernels
576
+ ],
577
+ "primary_kernel": {
578
+ "name": result.primary_kernel.name,
579
+ "duration_us": result.primary_kernel.duration_us,
580
+ }
581
+ if result.primary_kernel
582
+ else None,
583
+ "roofline": {
584
+ "achieved_tflops": result.roofline.achieved_tflops,
585
+ "achieved_memory_bw_tbps": result.roofline.achieved_memory_bw_tbps,
586
+ "compute_pct_of_peak": result.roofline.compute_pct_of_peak,
587
+ "memory_bw_pct_of_peak": result.roofline.memory_bw_pct_of_peak,
588
+ "bottleneck": result.roofline.bottleneck,
589
+ }
590
+ if result.roofline
591
+ else None,
592
+ "error": result.error,
593
+ }
594
+ typer.echo(json.dumps(output, indent=2))
595
+ else:
596
+ if result.error:
597
+ typer.echo(f"Error: {result.error}", err=True)
598
+ if verbose and result.raw_output:
599
+ typer.echo("\nRaw output:")
600
+ typer.echo(result.raw_output)
601
+ raise typer.Exit(1)
602
+
603
+ if from_cache:
604
+ typer.echo("(from cache)")
605
+ typer.echo("")
606
+
607
+ typer.echo(result.summary())
608
+
609
+ if verbose and result.raw_output:
610
+ typer.echo("\n--- Raw Profiler Output ---")
611
+ typer.echo(result.raw_output)
612
+
613
+
614
+ @baseline_app.command("hardware")
615
+ def hardware_cmd(
616
+ json_output: bool = typer.Option(
617
+ False,
618
+ "--json",
619
+ help="Output as JSON",
620
+ ),
621
+ ) -> None:
622
+ """List supported hardware and their specifications.
623
+
624
+ Shows peak FLOPS and memory bandwidth for each supported GPU,
625
+ used for roofline analysis calculations.
626
+
627
+ Examples:
628
+ wafer baseline hardware
629
+ wafer baseline hardware --json
630
+ """
631
+ if json_output:
632
+ import json
633
+
634
+ output = {
635
+ name: {
636
+ "peak_fp16_tflops": spec.peak_fp16_tflops,
637
+ "peak_fp32_tflops": spec.peak_fp32_tflops,
638
+ "peak_memory_bw_tbps": spec.peak_memory_bw_tbps,
639
+ "peak_fp8_tflops": spec.peak_fp8_tflops,
640
+ "peak_int8_tops": spec.peak_int8_tops,
641
+ }
642
+ for name, spec in HARDWARE_SPECS.items()
643
+ }
644
+ typer.echo(json.dumps(output, indent=2))
645
+ else:
646
+ typer.echo("Supported Hardware for Roofline Analysis")
647
+ typer.echo("=" * 60)
648
+ typer.echo("")
649
+ typer.echo(f"{'Name':<12} {'FP16 TFLOPS':<14} {'FP32 TFLOPS':<14} {'Mem BW (TB/s)':<14}")
650
+ typer.echo("-" * 60)
651
+
652
+ for name, spec in sorted(HARDWARE_SPECS.items()):
653
+ typer.echo(
654
+ f"{name:<12} {spec.peak_fp16_tflops:<14.1f} {spec.peak_fp32_tflops:<14.1f} {spec.peak_memory_bw_tbps:<14.2f}"
655
+ )
656
+
657
+ typer.echo("")
658
+ typer.echo("Note: FP16 TFLOPS shown without sparsity for most GPUs.")
659
+ typer.echo("Use --json for complete specifications.")
660
+
661
+
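Taken together, the new module wires a small pipeline: parse the op string, attach shapes and dtypes, build a KernelTraceConfig, run the trace, then cache and print the result. A minimal sketch of driving the same building blocks programmatically, assuming they behave exactly as they are used in baseline_run_cmd above (the matmul shapes are illustrative only):

    # Hedged sketch: mirrors the local path of `wafer baseline run`.
    from wafer_core.tools.dispatch_baseline.codegen import (
        parse_op_string, update_dtypes, update_shapes,
    )
    from wafer_core.tools.dispatch_baseline.dtypes import KernelTraceConfig
    from wafer_core.tools.dispatch_baseline.executor import trace_kernel_local

    op_spec = parse_op_string("torch.matmul(A, B)")
    op_spec = update_shapes(op_spec, {"A": (4096, 4096), "B": (4096, 4096)})
    op_spec = update_dtypes(op_spec, "float16")

    config = KernelTraceConfig(op_spec=op_spec, hardware="B200", num_warmup=10, num_runs=100)
    exec_result = trace_kernel_local(config)  # needs a local CUDA/ROCm GPU
    print(exec_result.result.summary())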
wafer/cli.py CHANGED
@@ -8,6 +8,7 @@
8
8
  Core commands:
9
9
  agent AI assistant for GPU kernel development
10
10
  evaluate Test kernel correctness and performance
11
+ baseline Discover what kernel PyTorch uses for an op
11
12
  corpus Download GPU documentation for local access
12
13
  workspaces Manage cloud GPU environments
13
14
 
@@ -279,19 +280,19 @@ from wafer.targets_cli import (
279
280
  targets_list as _targets_list_cmd,
280
281
  )
281
282
  from wafer.targets_cli import (
282
- targets_provision as _targets_provision_cmd,
283
+ targets_pools as _targets_pools_cmd,
283
284
  )
284
285
  from wafer.targets_cli import (
285
- targets_reconcile as _targets_reconcile_cmd,
286
+ targets_probe as _targets_probe_cmd,
286
287
  )
287
288
  from wafer.targets_cli import (
288
- targets_terminate as _targets_terminate_cmd,
289
+ targets_provision as _targets_provision_cmd,
289
290
  )
290
291
  from wafer.targets_cli import (
291
- targets_pools as _targets_pools_cmd,
292
+ targets_reconcile as _targets_reconcile_cmd,
292
293
  )
293
294
  from wafer.targets_cli import (
294
- targets_probe as _targets_probe_cmd,
295
+ targets_terminate as _targets_terminate_cmd,
295
296
  )
296
297
 
297
298
  # Billing management - nested under config
@@ -323,6 +324,11 @@ gpumode_app = typer.Typer(
323
324
  )
324
325
  evaluate_app.add_typer(gpumode_app, name="gpumode")
325
326
 
327
+ # Baseline discovery (what kernel does PyTorch use?)
328
+ from wafer.baseline import baseline_app
329
+
330
+ app.add_typer(baseline_app, name="baseline", rich_help_panel="Kernel Development")
331
+
326
332
  # =============================================================================
327
333
  # Dev commands (internal, used by web app proxy)
328
334
  # =============================================================================
@@ -1533,6 +1539,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
1533
1539
  template_args: list[str] | None = typer.Option(None, "--args"),
1534
1540
  corpus: str | None = typer.Option(None, "--corpus"),
1535
1541
  no_sandbox: bool = typer.Option(False, "--no-sandbox"),
1542
+ no_proxy: bool = typer.Option(False, "--no-proxy", help="Skip wafer proxy, use ANTHROPIC_API_KEY directly"),
1536
1543
  ) -> None:
1537
1544
  agent(
1538
1545
  prompt=prompt,
@@ -1553,6 +1560,7 @@ def _make_agent_alias(name: str, doc: str) -> None:
1553
1560
  template_args=template_args,
1554
1561
  corpus=corpus,
1555
1562
  no_sandbox=no_sandbox,
1563
+ no_proxy=no_proxy, # Must explicitly pass to avoid Typer default object being truthy
1556
1564
  )
1557
1565
 
1558
1566
  alias_cmd.__doc__ = doc
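The `no_proxy=no_proxy` pass-through above guards against a standard Typer gotcha: when a command function is called as plain Python (as the alias does), parameters that are not passed explicitly keep their typer.Option default objects, and those OptionInfo objects are truthy even when the declared default is False. A small illustration, not taken from this codebase:

    # Hedged illustration of "avoid Typer default object being truthy".
    import typer

    def agent(no_proxy: bool = typer.Option(False, "--no-proxy")) -> None:
        # Called directly (not through the CLI), no_proxy is an OptionInfo
        # instance, so `if no_proxy:` would take the True branch.
        print(type(no_proxy).__name__, bool(no_proxy))

    agent()  # prints: OptionInfo True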
@@ -1592,7 +1600,9 @@ def evaluate( # noqa: PLR0913
1592
1600
  benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
1593
1601
  profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
1594
1602
  defensive: bool = typer.Option(
1595
- False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
1603
+ True,
1604
+ "--defense/--no-defense",
1605
+ help="Run reward hack defense checks after benchmarking. Enabled by default.",
1596
1606
  ),
1597
1607
  sync_artifacts: bool = typer.Option(
1598
1608
  True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
@@ -1606,19 +1616,19 @@ def evaluate( # noqa: PLR0913
1606
1616
  The evaluation checks:
1607
1617
  1. Correctness: Does the kernel produce the same output as the reference?
1608
1618
  2. Performance (--benchmark): How fast is it compared to the reference?
1609
- 3. Defense (--defensive): Detects evaluation hacking (stream injection, etc.)
1619
+ 3. Defense: Detects reward hacking (runs automatically with benchmark, disable with --no-defense)
1610
1620
 
1611
1621
  Examples:
1612
1622
  # Basic correctness check
1613
1623
  wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json
1614
1624
 
1615
- # With benchmarking on a specific target
1625
+ # With benchmarking (defense checks run automatically)
1616
1626
  wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
1617
1627
  --target vultr-b200 --benchmark
1618
1628
 
1619
- # Full evaluation with defensive timing (detects cheating)
1629
+ # Benchmarking without defense checks
1620
1630
  wafer evaluate gpumode --impl kernel.py --reference ref.py --test-cases tests.json \\
1621
- --benchmark --defensive
1631
+ --benchmark --no-defense
1622
1632
 
1623
1633
  Subcommands:
1624
1634
  gpumode Use GPUMode format (functional) - RECOMMENDED
@@ -1863,7 +1873,9 @@ def _resolve_pool_query(pool: str, collector) -> tuple[str, object]:
1863
1873
  spec_targets = [t for t in matched_targets if t.spec_name]
1864
1874
  if not spec_targets:
1865
1875
  collector.set_error(
1866
- "pool", "NoSpecTargets", pool=pool,
1876
+ "pool",
1877
+ "NoSpecTargets",
1878
+ pool=pool,
1867
1879
  message="Matched targets have no spec binding — evaluator needs spec fields",
1868
1880
  )
1869
1881
  collector.finalize()
@@ -1963,7 +1975,9 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
1963
1975
  ),
1964
1976
  seed: int = typer.Option(42, "--seed", help="Random seed for weight initialization"),
1965
1977
  defensive: bool = typer.Option(
1966
- False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
1978
+ True,
1979
+ "--defense/--no-defense",
1980
+ help="Run reward hack defense checks after benchmarking. Enabled by default.",
1967
1981
  ),
1968
1982
  backend: str | None = typer.Option(
1969
1983
  None,
@@ -2003,16 +2017,20 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
2003
2017
  The evaluation checks:
2004
2018
  1. Correctness: Does ModelNew.forward() produce same output as Model.forward()?
2005
2019
  2. Performance (--benchmark): How fast is it compared to the reference?
2006
- 3. Defense (--defensive): Detects evaluation hacking
2020
+ 3. Defense: Detects reward hacking (runs automatically with benchmark, disable with --no-defense)
2007
2021
 
2008
2022
  Examples:
2009
2023
  # Basic correctness check
2010
2024
  wafer evaluate kernelbench --impl my_kernel.py --reference problem.py
2011
2025
 
2012
- # With benchmarking
2026
+ # With benchmarking (defense checks run automatically)
2013
2027
  wafer evaluate kernelbench --impl my_kernel.py --reference problem.py \\
2014
2028
  --target vultr-b200 --benchmark
2015
2029
 
2030
+ # Benchmarking without defense checks
2031
+ wafer evaluate kernelbench --impl my_kernel.py --reference problem.py \\
2032
+ --target vultr-b200 --benchmark --no-defense
2033
+
2016
2034
  Subcommands:
2017
2035
  make-template Extract a KernelBench problem as template
2018
2036
  """
@@ -2072,12 +2090,15 @@ def kernelbench_evaluate( # noqa: PLR0913, PLR0915
2072
2090
  if stages == "all":
2073
2091
  resolved_stages = "compile,correctness,benchmark,defense"
2074
2092
 
2075
- # Handle backward compat: --benchmark and --defensive flags add to stages
2093
+ # Handle --benchmark and --defense/--no-defense flags
2076
2094
  stage_set = set(resolved_stages.split(","))
2077
2095
  if benchmark and "benchmark" not in stage_set:
2078
2096
  stage_set.add("benchmark")
2079
- if defensive and "defense" not in stage_set:
2097
+ # Defense runs automatically when benchmarking, unless --no-defense
2098
+ if defensive and "benchmark" in stage_set and "defense" not in stage_set:
2080
2099
  stage_set.add("defense")
2100
+ if not defensive:
2101
+ stage_set.discard("defense")
2081
2102
  resolved_stages = ",".join(
2082
2103
  sorted(
2083
2104
  stage_set,
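The net behavior of the reworked flags: defense is added whenever benchmarking is requested, and --no-defense strips it regardless of how the stage list was assembled. A minimal sketch of that resolution, assuming the surrounding code matches the hunk above (the final sort key is not shown there and is omitted here):

    # Hedged sketch of --benchmark / --defense stage resolution.
    def resolve_stages(stages: str, benchmark: bool, defensive: bool) -> set[str]:
        if stages == "all":
            stages = "compile,correctness,benchmark,defense"
        stage_set = set(stages.split(","))
        if benchmark:
            stage_set.add("benchmark")
        if defensive and "benchmark" in stage_set:
            stage_set.add("defense")  # defense piggybacks on benchmarking
        if not defensive:
            stage_set.discard("defense")
        return stage_set

    assert resolve_stages("compile,correctness", benchmark=True, defensive=True) == {
        "compile", "correctness", "benchmark", "defense"
    }
    assert resolve_stages("all", benchmark=False, defensive=False) == {
        "compile", "correctness", "benchmark"
    }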
@@ -2411,7 +2432,9 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
2411
2432
  benchmark: bool = typer.Option(False, "--benchmark", help="Run performance benchmarks"),
2412
2433
  profile: bool = typer.Option(False, "--profile", help="Enable profiling"),
2413
2434
  defensive: bool = typer.Option(
2414
- False, "--defensive", help="Enable defensive timing to detect evaluation hacking"
2435
+ True,
2436
+ "--defense/--no-defense",
2437
+ help="Run reward hack defense checks after benchmarking. Enabled by default.",
2415
2438
  ),
2416
2439
  sync_artifacts: bool = typer.Option(
2417
2440
  True, "--sync-artifacts/--no-sync-artifacts", help="Download artifacts"
@@ -2567,307 +2590,6 @@ def gpumode_evaluate( # noqa: PLR0913, PLR0915
2567
2590
  else:
2568
2591
  typer.echo(f"Error: {result.error_message}", err=True)
2569
2592
  raise typer.Exit(1)
2570
-
2571
-
2572
- # =============================================================================
2573
- # Push and Remote-Run commands
2574
- # =============================================================================
2575
-
2576
-
2577
- @app.command("push", hidden=True)
2578
- def push(
2579
- local_path: Path = typer.Argument(..., help="Local directory to upload"),
2580
- workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace name override"),
2581
- direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
2582
- target_name: str | None = typer.Option(
2583
- None,
2584
- "--target",
2585
- "-t",
2586
- help="Target for --direct mode. See 'wafer config targets list'.",
2587
- autocompletion=complete_target_name,
2588
- ),
2589
- ) -> None:
2590
- """Push directory to remote GPU.
2591
-
2592
- By default, uses wafer-api. Use --direct for direct SSH mode.
2593
-
2594
- Examples:
2595
- wafer push ./my_project
2596
- wafer push . --workspace my-kernel
2597
- wafer push ./my_project --direct --target vultr-b200
2598
- """
2599
- # Validate path
2600
- if not local_path.exists():
2601
- typer.echo(f"Error: Path not found: {local_path}", err=True)
2602
- raise typer.Exit(1)
2603
-
2604
- if not local_path.is_dir():
2605
- typer.echo(f"Error: Not a directory: {local_path}", err=True)
2606
- raise typer.Exit(1)
2607
-
2608
- # Resolve to absolute path
2609
- local_path = local_path.resolve()
2610
-
2611
- if direct:
2612
- # Direct SSH mode (requires target)
2613
- if not target_name:
2614
- typer.echo("Error: --target required for --direct mode", err=True)
2615
- raise typer.Exit(1)
2616
-
2617
- from wafer_core.utils.kernel_utils.targets.config import ModalTarget
2618
-
2619
- from .gpu_run import push_directory as push_direct
2620
- from .targets import load_target
2621
-
2622
- try:
2623
- target = load_target(target_name)
2624
- except FileNotFoundError:
2625
- typer.echo(f"Error: Target not found: {target_name}", err=True)
2626
- typer.echo("List targets with: wafer config targets list", err=True)
2627
- raise typer.Exit(1) from None
2628
-
2629
- if isinstance(target, ModalTarget):
2630
- typer.echo(
2631
- f"Error: Target '{target_name}' is a Modal target. Direct push requires SSH.",
2632
- err=True,
2633
- )
2634
- raise typer.Exit(1) from None
2635
-
2636
- typer.echo(f"Connecting to {target.ssh_target}...")
2637
- try:
2638
- result = push_direct(local_path, target)
2639
- except Exception as e:
2640
- typer.echo(f"Error: {e}", err=True)
2641
- raise typer.Exit(1) from None
2642
-
2643
- typer.echo(f"Uploading {len(result.files_uploaded)} files to {result.workspace_path}")
2644
- for f in result.files_uploaded:
2645
- typer.echo(f" ✓ {f}")
2646
- typer.echo(f"Pushed to: {result.workspace_path}")
2647
- else:
2648
- # API mode (default)
2649
- from .api_client import push_directory as push_api
2650
-
2651
- workspace_name = workspace or local_path.name
2652
- typer.echo(f"Pushing {local_path.name} to wafer-api...")
2653
-
2654
- try:
2655
- result = push_api(local_path, workspace_name)
2656
- except Exception as e:
2657
- typer.echo(f"Error: {e}", err=True)
2658
- raise typer.Exit(1) from None
2659
-
2660
- typer.echo(f"Uploaded {len(result.files_uploaded)} files")
2661
- for f in result.files_uploaded:
2662
- typer.echo(f" ✓ {f}")
2663
- typer.echo(f"Workspace ID: {result.workspace_id}")
2664
-
2665
-
2666
- def _run_direct_mode(
2667
- cmd_str: str,
2668
- target_name: str,
2669
- upload_dir: Path | None,
2670
- workspace_id: str | None,
2671
- gpu_id: int | None,
2672
- ) -> int:
2673
- """Run command via direct SSH mode. Returns exit code."""
2674
- from wafer_core.utils.kernel_utils.targets.config import ModalTarget
2675
-
2676
- from .gpu_run import push_directory as push_direct
2677
- from .gpu_run import run_command as run_direct
2678
- from .targets import load_target
2679
-
2680
- try:
2681
- target = load_target(target_name)
2682
- except FileNotFoundError:
2683
- typer.echo(f"Error: Target not found: {target_name}", err=True)
2684
- typer.echo("List targets with: wafer config targets list", err=True)
2685
- raise typer.Exit(1) from None
2686
-
2687
- if isinstance(target, ModalTarget):
2688
- typer.echo(
2689
- f"Error: Target '{target_name}' is a Modal target. Direct mode requires SSH.", err=True
2690
- )
2691
- raise typer.Exit(1) from None
2692
-
2693
- if not target.docker_image:
2694
- typer.echo(f"Error: Target '{target_name}' has no docker_image configured", err=True)
2695
- raise typer.Exit(1)
2696
-
2697
- # If upload_dir provided, push first
2698
- workspace_name = workspace_id
2699
- if upload_dir:
2700
- typer.echo(f"Uploading {upload_dir.name}...")
2701
- try:
2702
- push_result = push_direct(upload_dir, target)
2703
- workspace_name = push_result.workspace_name
2704
- typer.echo(f"Uploaded {len(push_result.files_uploaded)} files")
2705
- except Exception as e:
2706
- typer.echo(f"Error uploading: {e}", err=True)
2707
- raise typer.Exit(1) from None
2708
- elif not workspace_name:
2709
- workspace_name = "tmp"
2710
-
2711
- effective_gpu = gpu_id if gpu_id is not None else target.gpu_ids[0]
2712
- typer.echo(f"Target: {target_name} (docker: {target.docker_image})")
2713
- typer.echo(f"Workspace: {workspace_name}")
2714
- typer.echo(f"GPU: {effective_gpu}")
2715
- typer.echo(f"Command: {cmd_str}")
2716
- typer.echo("-" * 60)
2717
-
2718
- try:
2719
- return run_direct(cmd_str, workspace_name, target, gpu_id)
2720
- except KeyboardInterrupt:
2721
- typer.echo("\nInterrupted by user", err=True)
2722
- raise typer.Exit(130) from None
2723
- except Exception as e:
2724
- typer.echo(f"Error: {e}", err=True)
2725
- raise typer.Exit(1) from None
2726
-
2727
-
2728
- def _run_api_mode( # noqa: PLR0913
2729
- cmd_str: str,
2730
- upload_dir: Path | None,
2731
- workspace_id: str | None,
2732
- gpu_id: int | None,
2733
- gpu_count: int,
2734
- docker_image: str | None,
2735
- docker_entrypoint: str | None,
2736
- pull_image: bool,
2737
- require_hwc: bool,
2738
- ) -> int:
2739
- """Run command via wafer-api. Returns exit code."""
2740
- from .api_client import run_command_stream
2741
-
2742
- if upload_dir:
2743
- typer.echo(f"Uploading: {upload_dir}")
2744
- elif workspace_id:
2745
- typer.echo(f"Workspace: {workspace_id}")
2746
- if gpu_id is not None:
2747
- typer.echo(f"GPU: {gpu_id}")
2748
- if gpu_count > 1:
2749
- typer.echo(f"GPU count: {gpu_count}")
2750
- if docker_image:
2751
- typer.echo(f"Image: {docker_image}")
2752
- if docker_entrypoint:
2753
- typer.echo(f"Entrypoint: {docker_entrypoint}")
2754
- if pull_image:
2755
- typer.echo("Pull image: yes")
2756
- typer.echo(f"Command: {cmd_str}")
2757
- if require_hwc:
2758
- typer.echo("Hardware counters: required (baremetal)")
2759
- typer.echo("-" * 60)
2760
-
2761
- try:
2762
- return run_command_stream(
2763
- command=cmd_str,
2764
- upload_dir=upload_dir,
2765
- workspace_id=workspace_id,
2766
- gpu_id=gpu_id,
2767
- gpu_count=gpu_count,
2768
- docker_image=docker_image,
2769
- docker_entrypoint=docker_entrypoint,
2770
- pull_image=pull_image,
2771
- require_hardware_counters=require_hwc,
2772
- )
2773
- except KeyboardInterrupt:
2774
- typer.echo("\nInterrupted by user", err=True)
2775
- raise typer.Exit(130) from None
2776
- except Exception as e:
2777
- typer.echo(f"Error: {e}", err=True)
2778
- raise typer.Exit(1) from None
2779
-
2780
-
2781
- @app.command("remote-run", hidden=True)
2782
- def remote_run( # noqa: PLR0913
2783
- command: list[str] = typer.Argument(..., help="Command to run"),
2784
- upload_dir: Path | None = typer.Option(
2785
- None, "--upload-dir", "-u", help="Directory to upload (stateless mode)"
2786
- ),
2787
- workspace_id: str | None = typer.Option(
2788
- None, "--workspace-id", "-w", help="Workspace ID (from wafer push)"
2789
- ),
2790
- gpu_id: int | None = typer.Option(None, "--gpu", "-g", help="GPU ID"),
2791
- gpu_count: int = typer.Option(1, "--gpu-count", "-n", help="Number of GPUs (1-8)"),
2792
- docker_image: str | None = typer.Option(None, "--image", "-i", help="Docker image override"),
2793
- docker_entrypoint: str | None = typer.Option(
2794
- None, "--docker-entrypoint", help="Override Docker entrypoint (e.g., 'bash')"
2795
- ),
2796
- pull_image: bool = typer.Option(
2797
- False, "--pull-image", help="Pull image if not available on target"
2798
- ),
2799
- require_hwc: bool = typer.Option(
2800
- False, "--require-hwc", help="Require hardware counters (baremetal)"
2801
- ),
2802
- direct: bool = typer.Option(False, "--direct", "-d", help="Use direct SSH instead of API"),
2803
- target_name: str | None = typer.Option(
2804
- None,
2805
- "--target",
2806
- "-t",
2807
- help="Target for --direct mode. See 'wafer config targets list'.",
2808
- autocompletion=complete_target_name,
2809
- ),
2810
- ) -> None:
2811
- """Run command on remote GPU in Docker.
2812
-
2813
- Two modes:
2814
- - High-level (stateless): --upload-dir uploads files and runs command
2815
- - Low-level: --workspace-id uses existing workspace from 'wafer push'
2816
-
2817
- By default, uses wafer-api. Use --direct for direct SSH mode.
2818
-
2819
- Examples:
2820
- # Stateless: upload and run
2821
- wafer remote-run --upload-dir ./my_project -- python train.py
2822
-
2823
- # Run without files
2824
- wafer remote-run -- nvidia-smi
2825
-
2826
- # Low-level: use existing workspace
2827
- wafer remote-run --workspace-id ws_abc123 -- python train.py
2828
-
2829
- # Direct SSH mode
2830
- wafer remote-run --upload-dir ./my_project --direct --target vultr-b200 -- python train.py
2831
- """
2832
- cmd_str = " ".join(command)
2833
- if not cmd_str.strip():
2834
- typer.echo("Error: Empty command", err=True)
2835
- raise typer.Exit(1)
2836
-
2837
- if upload_dir and workspace_id:
2838
- typer.echo("Error: --upload-dir and --workspace-id are mutually exclusive", err=True)
2839
- raise typer.Exit(1)
2840
-
2841
- if upload_dir:
2842
- if not upload_dir.exists():
2843
- typer.echo(f"Error: Directory not found: {upload_dir}", err=True)
2844
- raise typer.Exit(1)
2845
- if not upload_dir.is_dir():
2846
- typer.echo(f"Error: Not a directory: {upload_dir}", err=True)
2847
- raise typer.Exit(1)
2848
- upload_dir = upload_dir.resolve()
2849
-
2850
- if direct:
2851
- if not target_name:
2852
- typer.echo("Error: --target required for --direct mode", err=True)
2853
- raise typer.Exit(1)
2854
- exit_code = _run_direct_mode(cmd_str, target_name, upload_dir, workspace_id, gpu_id)
2855
- else:
2856
- exit_code = _run_api_mode(
2857
- cmd_str,
2858
- upload_dir,
2859
- workspace_id,
2860
- gpu_id,
2861
- gpu_count,
2862
- docker_image,
2863
- docker_entrypoint,
2864
- pull_image,
2865
- require_hwc,
2866
- )
2867
-
2868
- raise typer.Exit(exit_code)
2869
-
2870
-
2871
2593
  # =============================================================================
2872
2594
  # Authentication commands
2873
2595
  # =============================================================================
@@ -6114,7 +5836,7 @@ def ncu_analyze(
6114
5836
  By default, uses local NCU if available, otherwise runs analysis
6115
5837
  remotely via wafer-api (requires authentication: wafer auth login).
6116
5838
 
6117
- Use --target for direct SSH mode (like wafer remote-run --direct).
5839
+ Use --target for direct SSH mode.
6118
5840
  Use --include-source to fetch SASS assembly with register/instruction data.
6119
5841
 
6120
5842
  Examples:
@@ -7988,7 +7710,7 @@ def compare_fusion_cmd(
7988
7710
  wafer compare fusion amd_trace.json nvidia_trace.json --format csv -o fusion.csv
7989
7711
  """
7990
7712
  from .trace_compare import compare_align
7991
-
7713
+
7992
7714
  compare_align(
7993
7715
  trace1=trace1,
7994
7716
  trace2=trace2,
@@ -8042,7 +7764,7 @@ def compare_align_cmd(
8042
7764
  wafer compare align amd_trace.json nvidia_trace.json --layer 5
8043
7765
  """
8044
7766
  from .trace_compare import compare_align
8045
-
7767
+
8046
7768
  compare_align(
8047
7769
  trace1=trace1,
8048
7770
  trace2=trace2,
wafer/evaluate.py CHANGED
@@ -78,9 +78,10 @@ def _build_docker_run_command(
78
78
  for cap in cap_add:
79
79
  parts.extend(["--cap-add", cap])
80
80
 
81
- # GPU access - use single quotes for the device spec to avoid shell escaping issues
81
+ # GPU access - use --runtime=nvidia alongside --gpus for compatibility
82
+ # with newer NVIDIA drivers (580+) where --gpus alone may not initialize CUDA
82
83
  if gpus:
83
- parts.extend(["--gpus", f"'{gpus}'"])
84
+ parts.extend(["--runtime=nvidia", "--gpus", f"'{gpus}'"])
84
85
 
85
86
  # Volume mounts
86
87
  if volumes:
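The effect of the change is easiest to see on the assembled command line. An illustrative sketch, assuming `parts` already holds the `docker run` prefix built earlier in _build_docker_run_command (that prefix is not shown in this hunk):

    # Hedged sketch of the GPU flags after this change.
    parts: list[str] = ["docker", "run", "--rm"]  # assumed prefix, for illustration
    gpus = "device=0"
    if gpus:
        # --runtime=nvidia is added alongside --gpus because on newer NVIDIA
        # drivers (580+), --gpus alone may leave CUDA uninitialized in the container.
        parts.extend(["--runtime=nvidia", "--gpus", f"'{gpus}'"])
    # parts -> ['docker', 'run', '--rm', '--runtime=nvidia', '--gpus', "'device=0'"]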
@@ -3159,15 +3160,35 @@ def main():
3159
3160
  inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
3160
3161
 
3161
3162
  if run_defense and defense_module is not None:
3162
- # Use full defense suite
3163
+ # Use extended defense suite (Makora taxonomy + CUDA-L2)
3163
3164
  print("[KernelBench] Running defense checks on implementation...")
3164
- run_all_defenses = defense_module.run_all_defenses
3165
+ run_extended = defense_module.run_all_defenses_extended
3165
3166
  time_with_defenses = defense_module.time_execution_with_defenses
3166
3167
 
3167
- # Run defense checks on implementation
3168
- all_passed, defense_results, _ = run_all_defenses(
3168
+ # Read source code for LLM adversarial evaluator
3169
+ _problem_code = None
3170
+ _kernel_code = None
3171
+ try:
3172
+ _problem_code = Path(args.reference).read_text()
3173
+ _kernel_code = Path(args.impl).read_text()
3174
+ except Exception:
3175
+ pass
3176
+
3177
+ # Input generator for caching/multi-input checks
3178
+ def _input_generator():
3179
+ _ins = get_inputs()
3180
+ return tuple(x.cuda() if isinstance(x, torch.Tensor) else x for x in _ins)
3181
+
3182
+ # Run all defense checks (original + extended)
3183
+ all_passed, defense_results, _ = run_extended(
3169
3184
  lambda *x: new_model(*x),
3170
3185
  *inputs,
3186
+ reference_fn=lambda *x: ref_model(*x),
3187
+ input_generator=_input_generator,
3188
+ test_shapes=[(128, 128), (256, 256), (512, 512)],
3189
+ check_precision_ulp=True,
3190
+ problem_code=_problem_code,
3191
+ kernel_code=_kernel_code,
3171
3192
  )
3172
3193
  results["defense_results"] = {
3173
3194
  name: {"passed": passed, "message": msg}
@@ -35,7 +35,8 @@ Strategy:
35
35
  Commands:
36
36
  - `wafer evaluate --impl <file> --reference <ref> --test-cases <tests>` - Run evaluation
37
37
  - `wafer evaluate --impl <file> --reference <ref> --test-cases <tests> --profile` - With NCU profiling
38
- - `wafer remote-run "<command>"` - Run arbitrary commands on remote GPU
38
+ - `wafer workspaces exec -- <command>` - Run arbitrary commands on remote GPU
39
+ - `wafer targets exec <target> -- <command>` - Run commands on a configured target via SSH
39
40
 
40
41
  Output:
41
42
  - Summary of optimizations applied
@@ -48,7 +49,8 @@ IMPORTANT: Always verify correctness with wafer evaluate before claiming success
48
49
  tools=["read", "write", "edit", "glob", "grep", "bash"],
49
50
  bash_allowlist=[
50
51
  "wafer evaluate",
51
- "wafer remote-run",
52
+ "wafer workspaces exec",
53
+ "wafer targets exec",
52
54
  "wafer nvidia ncu",
53
55
  "wafer nvidia nsys",
54
56
  "wafer nvidia perfetto",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-cli
3
- Version: 0.2.30
3
+ Version: 0.2.32
4
4
  Summary: CLI for running GPU workloads, managing remote workspaces, and evaluating/optimizing kernels
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -5,12 +5,13 @@ wafer/analytics.py,sha256=qLY6Z16usVHFD8TCv7XBuz7l47vXVdXk-qhOzA-hW_8,8179
5
5
  wafer/api_client.py,sha256=i_Az2b2llC3DSW8yOL-BKqa7LSKuxOr8hSN40s-oQXY,6313
6
6
  wafer/auth.py,sha256=dwss_se5P-FFc9IN38q4kh_dBrA6k-CguDBkivgcdj0,14003
7
7
  wafer/autotuner.py,sha256=41WYP41pTDvMijv2h42vm89bcHtDMJXObDlWmn6xpFU,44416
8
+ wafer/baseline.py,sha256=OrGCAut_xtkH9Ogx4mMU5-94Q0oClIXqac94YRwqERY,21534
8
9
  wafer/billing.py,sha256=hEEwtrtIsbPQ3lLJNcyTLMsapUbcuvcVW_e9_0SxzVo,7199
9
- wafer/cli.py,sha256=zuVZhPdML5AOBtLUqLwAwjl8XMNe9EwQkffZxtBGLx4,282748
10
+ wafer/cli.py,sha256=jHh4EcCGheDq14E11rdSHXImMdriMSFb2vNcvhsV59A,273228
10
11
  wafer/cli_instructions.py,sha256=bziUKDNDAXABVMvKPLEMXm-hFSD2TcFSh-FKRYa949k,4693
11
12
  wafer/config.py,sha256=h5Eo9_yfWqWGoPNdVQikI9GoZVUeysunSYiixf1mKcw,3411
12
13
  wafer/corpus.py,sha256=CY9T7wXENNDJxnrtI-XsQmXeptrFfKG4x-lngrc9_3s,24748
13
- wafer/evaluate.py,sha256=QswzCD0CZRT2jpzpeekjNezEPbKHZnvVI7KRQZas8LA,186310
14
+ wafer/evaluate.py,sha256=i15PliAVI3W04_4eju46PBDdh2BwSToLME5n7yGu7dU,187355
14
15
  wafer/global_config.py,sha256=iu1HbTDr1695tSeDG2NfkK7PiY7XD6vjCk37w1wHbgk,11920
15
16
  wafer/gpu_run.py,sha256=TwqXy72T7f2I7e6n5WWod3xgxCPnDhU0BgLsB4CUoQY,9716
16
17
  wafer/inference.py,sha256=tZCO5i05FKY27ewis3CSBHFBeFbXY3xwj0DSjdoMY9s,4314
@@ -36,12 +37,12 @@ wafer/workspaces.py,sha256=J-TXGwHXSZlzRWCew63KNvk6HLJ-zTSELRgzjryTkMk,35710
36
37
  wafer/skills/wafer-guide/SKILL.md,sha256=UDsXCD5Kb-lDParKCTf2WkE3kodVs-rja8XeumSBO5U,3934
37
38
  wafer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
39
  wafer/templates/ask_docs.py,sha256=15t1Aa4WBMwMox8XmFdzyosOZfBLMdXyaxo3GDb7nTE,2254
39
- wafer/templates/optimize_kernel.py,sha256=4-MaKm_C9BQHQEllrNLLYkcdhJpcj6D-8zbJ4FdLUEY,2444
40
+ wafer/templates/optimize_kernel.py,sha256=Q4FA_8ECEegW_3DS51mkLCX6Vk1dcWWzY3A_RQ4NW8U,2576
40
41
  wafer/templates/optimize_kernelbench.py,sha256=T3co9Y9eSLWDrZG66gwQVFMdnGVoyUQos-TxnMMBLL8,3747
41
42
  wafer/templates/trace_analyze.py,sha256=B7CiRlsokERzBjLL-k49kGjpU2zlJZqzTE05xbRS1WI,2878
42
43
  wafer/tests/test_eval_cli_parity.py,sha256=SGmaj2NGBZ7GdDF53bXsECvQbV21iHZw8YeL_MJOLk0,7206
43
- wafer_cli-0.2.30.dist-info/METADATA,sha256=4mvnQyUVD_irJIdhwj0ECm5FEDVXlcmFjklLzZqq1V8,2799
44
- wafer_cli-0.2.30.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
45
- wafer_cli-0.2.30.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
46
- wafer_cli-0.2.30.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
47
- wafer_cli-0.2.30.dist-info/RECORD,,
44
+ wafer_cli-0.2.32.dist-info/METADATA,sha256=snTvnaN37WG1rXCxW2YibK6CtBX9lhe8mOxEi9Th5iA,2799
45
+ wafer_cli-0.2.32.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
46
+ wafer_cli-0.2.32.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
47
+ wafer_cli-0.2.32.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
48
+ wafer_cli-0.2.32.dist-info/RECORD,,