wafer-core 0.1.38__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. wafer_core/lib/trace_compare/fusion_analyzer.py +2 -0
  2. wafer_core/rollouts/_logging/__init__.py +5 -1
  3. wafer_core/rollouts/_logging/logging_config.py +95 -3
  4. wafer_core/rollouts/_logging/sample_handler.py +66 -0
  5. wafer_core/rollouts/_pytui/__init__.py +114 -0
  6. wafer_core/rollouts/_pytui/app.py +809 -0
  7. wafer_core/rollouts/_pytui/console.py +291 -0
  8. wafer_core/rollouts/_pytui/renderer.py +210 -0
  9. wafer_core/rollouts/_pytui/spinner.py +73 -0
  10. wafer_core/rollouts/_pytui/terminal.py +489 -0
  11. wafer_core/rollouts/_pytui/text.py +470 -0
  12. wafer_core/rollouts/_pytui/theme.py +241 -0
  13. wafer_core/rollouts/evaluation.py +142 -177
  14. wafer_core/rollouts/progress_app.py +395 -0
  15. wafer_core/rollouts/tui/DESIGN.md +251 -115
  16. wafer_core/rollouts/tui/monitor.py +64 -20
  17. wafer_core/tools/compile/__init__.py +30 -0
  18. wafer_core/tools/compile/benchmark.py +636 -0
  19. wafer_core/tools/compile/compiler.py +301 -0
  20. wafer_core/tools/compile/modal_compile.py +369 -0
  21. wafer_core/tools/compile/tests/__init__.py +1 -0
  22. wafer_core/tools/compile/tests/test_compiler.py +675 -0
  23. wafer_core/tools/compile/tests/test_data/utils.cuh +10 -0
  24. wafer_core/tools/compile/tests/test_data/vector_add.cu +7 -0
  25. wafer_core/tools/compile/tests/test_data/with_header.cu +9 -0
  26. wafer_core/tools/compile/tests/test_modal_integration.py +326 -0
  27. wafer_core/tools/compile/types.py +117 -0
  28. {wafer_core-0.1.38.dist-info → wafer_core-0.1.40.dist-info}/METADATA +1 -1
  29. {wafer_core-0.1.38.dist-info → wafer_core-0.1.40.dist-info}/RECORD +30 -12
  30. wafer_core/rollouts/events.py +0 -240
  31. wafer_core/rollouts/progress_display.py +0 -476
  32. wafer_core/utils/event_streaming.py +0 -63
  33. {wafer_core-0.1.38.dist-info → wafer_core-0.1.40.dist-info}/WHEEL +0 -0
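The headline addition is the new wafer_core/tools/compile package, which routes CUDA compilation through a deployed Modal function. As a minimal sketch of the call shape, based on how the new benchmark.py (shown below) invokes it: the payload keys ("files", "arch", "flags", "output") and result keys ("success", "ptx", "sass", "stderr") are taken from that script; everything else here is illustrative.

import modal

# Look up the deployed compile function by app and function name
# (this is exactly how benchmark.py resolves it).
compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")

# Request/response shape as exercised by benchmark.py; the authoritative
# schema lives in wafer_core/tools/compile/types.py.
result = compile_fn.remote({
    "files": {"kernel.cu": "__global__ void noop() {}"},  # illustrative kernel
    "arch": "sm_90a",
    "flags": ["-O3", "-lineinfo"],
    "output": ["ptx", "sass"],
})

if result["success"]:
    print(result["ptx"])
else:
    print(result.get("stderr", "Unknown error"))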
@@ -0,0 +1,636 @@
+ """Benchmark script for CUDA compilation performance.
+
+ This script measures compilation time for kernels of different sizes
+ to track performance improvements from optimizations.
+
+ Usage:
+     python -m wafer_core.tools.compile.benchmark
+
+     # Or with a specific kernel:
+     python -m wafer_core.tools.compile.benchmark --kernel simple
+     python -m wafer_core.tools.compile.benchmark --kernel medium
+     python -m wafer_core.tools.compile.benchmark --kernel complex
+ """
+
+ import argparse
+ import statistics
+ import time
+ from typing import NamedTuple
+
+ # ============================================================================
+ # Test Kernels
+ # ============================================================================
+
+ SIMPLE_KERNEL = """\
+ // Simple vector addition kernel (~20 lines)
+ __global__ void vector_add(float* a, float* b, float* c, int n) {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx < n) {
+         c[idx] = a[idx] + b[idx];
+     }
+ }
+ """
+
+ MEDIUM_KERNEL = """\
+ // Medium complexity kernel with shared memory (~100 lines)
+ #include <cuda_runtime.h>
+
+ #define TILE_SIZE 16
+
+ __global__ void tiled_matmul(
+     const float* __restrict__ A,
+     const float* __restrict__ B,
+     float* __restrict__ C,
+     int M, int N, int K
+ ) {
+     __shared__ float As[TILE_SIZE][TILE_SIZE];
+     __shared__ float Bs[TILE_SIZE][TILE_SIZE];
+
+     int bx = blockIdx.x, by = blockIdx.y;
+     int tx = threadIdx.x, ty = threadIdx.y;
+
+     int row = by * TILE_SIZE + ty;
+     int col = bx * TILE_SIZE + tx;
+
+     float sum = 0.0f;
+
+     for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
+         // Load tile from A
+         if (row < M && t * TILE_SIZE + tx < K) {
+             As[ty][tx] = A[row * K + t * TILE_SIZE + tx];
+         } else {
+             As[ty][tx] = 0.0f;
+         }
+
+         // Load tile from B
+         if (t * TILE_SIZE + ty < K && col < N) {
+             Bs[ty][tx] = B[(t * TILE_SIZE + ty) * N + col];
+         } else {
+             Bs[ty][tx] = 0.0f;
+         }
+
+         __syncthreads();
+
+         // Compute partial dot product
+         #pragma unroll
+         for (int k = 0; k < TILE_SIZE; k++) {
+             sum = fmaf(As[ty][k], Bs[k][tx], sum);
+         }
+
+         __syncthreads();
+     }
+
+     if (row < M && col < N) {
+         C[row * N + col] = sum;
+     }
+ }
+
+ // Reduction kernel
+ __global__ void reduce_sum(const float* input, float* output, int n) {
+     extern __shared__ float sdata[];
+
+     unsigned int tid = threadIdx.x;
+     unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
+
+     float mySum = (i < n) ? input[i] : 0.0f;
+     if (i + blockDim.x < n) {
+         mySum += input[i + blockDim.x];
+     }
+     sdata[tid] = mySum;
+     __syncthreads();
+
+     for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
+         if (tid < s) {
+             sdata[tid] = mySum = mySum + sdata[tid + s];
+         }
+         __syncthreads();
+     }
+
+     if (tid < 32) {
+         volatile float* smem = sdata;
+         smem[tid] = mySum = mySum + smem[tid + 32];
+         smem[tid] = mySum = mySum + smem[tid + 16];
+         smem[tid] = mySum = mySum + smem[tid + 8];
+         smem[tid] = mySum = mySum + smem[tid + 4];
+         smem[tid] = mySum = mySum + smem[tid + 2];
+         smem[tid] = mySum = mySum + smem[tid + 1];
+     }
+
+     if (tid == 0) {
+         output[blockIdx.x] = sdata[0];
+     }
+ }
+ """
+
+ COMPLEX_KERNEL = """\
+ // Complex kernel with multiple features (~500 lines)
+ #include <cuda_runtime.h>
+ #include <cooperative_groups.h>
+
+ namespace cg = cooperative_groups;
+
+ // Constants
+ constexpr int BLOCK_SIZE = 256;
+ constexpr int TILE_SIZE = 16;
+
+ // ============================================================================
+ // Kernel 1: Vector operations with shared memory and reduction
+ // ============================================================================
+ template <typename T, int BlockSize>
+ __global__ void reduceSum(const T* __restrict__ input, T* __restrict__ output, int N) {
+     __shared__ T sdata[BlockSize];
+
+     unsigned int tid = threadIdx.x;
+     unsigned int i = blockIdx.x * (blockDim.x * 2) + threadIdx.x;
+
+     T mySum = (i < N) ? input[i] : T(0);
+     if (i + blockDim.x < N) {
+         mySum += input[i + blockDim.x];
+     }
+     sdata[tid] = mySum;
+     __syncthreads();
+
+     #pragma unroll
+     for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
+         if (tid < s) {
+             sdata[tid] = mySum = mySum + sdata[tid + s];
+         }
+         __syncthreads();
+     }
+
+     if (tid < 32) {
+         volatile T* smem = sdata;
+         if (BlockSize >= 64) mySum += smem[tid + 32];
+         smem[tid] = mySum;
+         if (BlockSize >= 32) mySum += smem[tid + 16];
+         smem[tid] = mySum;
+         if (BlockSize >= 16) mySum += smem[tid + 8];
+         smem[tid] = mySum;
+         if (BlockSize >= 8) mySum += smem[tid + 4];
+         smem[tid] = mySum;
+         if (BlockSize >= 4) mySum += smem[tid + 2];
+         smem[tid] = mySum;
+         if (BlockSize >= 2) mySum += smem[tid + 1];
+         smem[tid] = mySum;
+     }
+
+     if (tid == 0) {
+         output[blockIdx.x] = sdata[0];
+     }
+ }
+
+ // ============================================================================
+ // Kernel 2: Matrix transpose with shared memory
+ // ============================================================================
+ __global__ void matrixTranspose(const float* __restrict__ input,
+                                 float* __restrict__ output,
+                                 int width, int height) {
+     __shared__ float tile[TILE_SIZE][TILE_SIZE + 1];
+
+     int xIndex = blockIdx.x * TILE_SIZE + threadIdx.x;
+     int yIndex = blockIdx.y * TILE_SIZE + threadIdx.y;
+
+     if (xIndex < width && yIndex < height) {
+         tile[threadIdx.y][threadIdx.x] = input[yIndex * width + xIndex];
+     }
+
+     __syncthreads();
+
+     xIndex = blockIdx.y * TILE_SIZE + threadIdx.x;
+     yIndex = blockIdx.x * TILE_SIZE + threadIdx.y;
+
+     if (xIndex < height && yIndex < width) {
+         output[yIndex * height + xIndex] = tile[threadIdx.x][threadIdx.y];
+     }
+ }
+
+ // ============================================================================
+ // Kernel 3: Softmax with cooperative groups
+ // ============================================================================
+ __global__ void softmaxKernel(const float* __restrict__ input,
+                               float* __restrict__ output,
+                               int N, int stride) {
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     extern __shared__ float shared[];
+
+     int row = blockIdx.x;
+     const float* rowInput = input + row * stride;
+     float* rowOutput = output + row * stride;
+
+     float maxVal = -INFINITY;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         maxVal = fmaxf(maxVal, rowInput[i]);
+     }
+
+     for (int offset = warp.size() / 2; offset > 0; offset /= 2) {
+         maxVal = fmaxf(maxVal, warp.shfl_down(maxVal, offset));
+     }
+
+     if (warp.thread_rank() == 0) {
+         shared[threadIdx.x / 32] = maxVal;
+     }
+     block.sync();
+
+     if (threadIdx.x < blockDim.x / 32) {
+         maxVal = shared[threadIdx.x];
+     } else {
+         maxVal = -INFINITY;
+     }
+
+     for (int offset = 16; offset > 0; offset /= 2) {
+         maxVal = fmaxf(maxVal, __shfl_down_sync(0xffffffff, maxVal, offset));
+     }
+     maxVal = __shfl_sync(0xffffffff, maxVal, 0);
+
+     float sum = 0.0f;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         float val = expf(rowInput[i] - maxVal);
+         rowOutput[i] = val;
+         sum += val;
+     }
+
+     for (int offset = warp.size() / 2; offset > 0; offset /= 2) {
+         sum += warp.shfl_down(sum, offset);
+     }
+
+     if (warp.thread_rank() == 0) {
+         shared[threadIdx.x / 32] = sum;
+     }
+     block.sync();
+
+     if (threadIdx.x < blockDim.x / 32) {
+         sum = shared[threadIdx.x];
+     } else {
+         sum = 0.0f;
+     }
+
+     for (int offset = 16; offset > 0; offset /= 2) {
+         sum += __shfl_down_sync(0xffffffff, sum, offset);
+     }
+     sum = __shfl_sync(0xffffffff, sum, 0);
+
+     float invSum = 1.0f / sum;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         rowOutput[i] *= invSum;
+     }
+ }
+
+ // ============================================================================
+ // Kernel 4: Fused multiply-add with vectorized loads
+ // ============================================================================
+ __global__ void fusedMulAddVec4(const float4* __restrict__ A,
+                                 const float4* __restrict__ B,
+                                 float4* __restrict__ C,
+                                 float alpha, float beta, int N) {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+     if (idx < N) {
+         float4 a = A[idx];
+         float4 b = B[idx];
+         float4 c;
+
+         c.x = fmaf(alpha, a.x, beta * b.x);
+         c.y = fmaf(alpha, a.y, beta * b.y);
+         c.z = fmaf(alpha, a.z, beta * b.z);
+         c.w = fmaf(alpha, a.w, beta * b.w);
+
+         C[idx] = c;
+     }
+ }
+
+ // ============================================================================
+ // Kernel 5: Simple GEMM with shared memory tiling
+ // ============================================================================
+ __global__ void matmulTiled(const float* __restrict__ A,
+                             const float* __restrict__ B,
+                             float* __restrict__ C,
+                             int M, int N, int K) {
+     __shared__ float As[TILE_SIZE][TILE_SIZE];
+     __shared__ float Bs[TILE_SIZE][TILE_SIZE];
+
+     int row = blockIdx.y * TILE_SIZE + threadIdx.y;
+     int col = blockIdx.x * TILE_SIZE + threadIdx.x;
+
+     float sum = 0.0f;
+
+     for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
+         int tiledCol = t * TILE_SIZE + threadIdx.x;
+         int tiledRow = t * TILE_SIZE + threadIdx.y;
+
+         As[threadIdx.y][threadIdx.x] = (row < M && tiledCol < K) ?
+             A[row * K + tiledCol] : 0.0f;
+         Bs[threadIdx.y][threadIdx.x] = (tiledRow < K && col < N) ?
+             B[tiledRow * N + col] : 0.0f;
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int k = 0; k < TILE_SIZE; k++) {
+             sum = fmaf(As[threadIdx.y][k], Bs[k][threadIdx.x], sum);
+         }
+
+         __syncthreads();
+     }
+
+     if (row < M && col < N) {
+         C[row * N + col] = sum;
+     }
+ }
+
+ // ============================================================================
+ // Device helper functions
+ // ============================================================================
+ __device__ __forceinline__ float warpReduceSum(float val) {
+     for (int offset = 16; offset > 0; offset /= 2) {
+         val += __shfl_down_sync(0xffffffff, val, offset);
+     }
+     return val;
+ }
+
+ __device__ __forceinline__ float blockReduceSum(float val) {
+     __shared__ float shared[32];
+
+     int lane = threadIdx.x % 32;
+     int wid = threadIdx.x / 32;
+
+     val = warpReduceSum(val);
+
+     if (lane == 0) shared[wid] = val;
+     __syncthreads();
+
+     val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
+
+     if (wid == 0) val = warpReduceSum(val);
+
+     return val;
+ }
+
+ // ============================================================================
+ // Kernel 6: Layer normalization
+ // ============================================================================
+ __global__ void layerNorm(const float* __restrict__ input,
+                           const float* __restrict__ gamma,
+                           const float* __restrict__ beta,
+                           float* __restrict__ output,
+                           int N, float eps) {
+     int row = blockIdx.x;
+     const float* rowInput = input + row * N;
+     float* rowOutput = output + row * N;
+
+     float sum = 0.0f;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         sum += rowInput[i];
+     }
+     sum = blockReduceSum(sum);
+
+     __shared__ float s_mean, s_var;
+     if (threadIdx.x == 0) {
+         s_mean = sum / N;
+     }
+     __syncthreads();
+
+     float var = 0.0f;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         float diff = rowInput[i] - s_mean;
+         var += diff * diff;
+     }
+     var = blockReduceSum(var);
+
+     if (threadIdx.x == 0) {
+         s_var = rsqrtf(var / N + eps);
+     }
+     __syncthreads();
+
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         float normalized = (rowInput[i] - s_mean) * s_var;
+         rowOutput[i] = fmaf(normalized, gamma[i], beta[i]);
+     }
+ }
+ """
+
+
+ # ============================================================================
+ # Benchmark Results
+ # ============================================================================
+
+ class BenchmarkResult(NamedTuple):
+     """Result of a single benchmark run."""
+     kernel_name: str
+     kernel_lines: int
+     compile_time_ms: int
+     success: bool
+     ptx_lines: int | None
+     sass_lines: int | None
+     error: str | None
+
+
+ def count_lines(code: str) -> int:
+     """Count non-empty lines in code."""
+     return len([line for line in code.split('\n') if line.strip()])
+
+
+ def run_benchmark(
+     kernel_name: str,
+     kernel_code: str,
+     arch: str = "sm_90a",
+     output_formats: list[str] | None = None,
+     num_runs: int = 3,
+ ) -> list[BenchmarkResult]:
+     """Run benchmark for a kernel.
+
+     Args:
+         kernel_name: Name of the kernel for reporting
+         kernel_code: CUDA source code
+         arch: Target architecture
+         output_formats: Output formats to request (default: ["ptx", "sass"])
+         num_runs: Number of benchmark runs
+
+     Returns:
+         List of BenchmarkResult, one per run
+     """
+     import modal
+
+     if output_formats is None:
+         output_formats = ["ptx", "sass"]
+
+     # Get the deployed function
+     compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
+
+     kernel_lines = count_lines(kernel_code)
+     results: list[BenchmarkResult] = []
+
+     for run in range(num_runs):
+         print(f" Run {run + 1}/{num_runs}...", end=" ", flush=True)
+
+         start_time = time.time()
+
+         try:
+             result = compile_fn.remote({
+                 "files": {"kernel.cu": kernel_code},
+                 "arch": arch,
+                 "flags": ["-O3", "-lineinfo"],
+                 "output": output_formats,
+             })
+
+             elapsed_ms = int((time.time() - start_time) * 1000)
+
+             if result["success"]:
+                 ptx_lines = count_lines(result["ptx"]) if result.get("ptx") else None
+                 sass_lines = count_lines(result["sass"]) if result.get("sass") else None
+
+                 results.append(BenchmarkResult(
+                     kernel_name=kernel_name,
+                     kernel_lines=kernel_lines,
+                     compile_time_ms=elapsed_ms,
+                     success=True,
+                     ptx_lines=ptx_lines,
+                     sass_lines=sass_lines,
+                     error=None,
+                 ))
+                 print(f"{elapsed_ms}ms")
+             else:
+                 results.append(BenchmarkResult(
+                     kernel_name=kernel_name,
+                     kernel_lines=kernel_lines,
+                     compile_time_ms=elapsed_ms,
+                     success=False,
+                     ptx_lines=None,
+                     sass_lines=None,
+                     error=result.get("stderr", "Unknown error"),
+                 ))
+                 print(f"FAILED ({elapsed_ms}ms)")
+
+         except Exception as e:
+             elapsed_ms = int((time.time() - start_time) * 1000)
+             results.append(BenchmarkResult(
+                 kernel_name=kernel_name,
+                 kernel_lines=kernel_lines,
+                 compile_time_ms=elapsed_ms,
+                 success=False,
+                 ptx_lines=None,
+                 sass_lines=None,
+                 error=str(e),
+             ))
+             print(f"ERROR: {e}")
+
+     return results
+
+
+ def print_summary(results: list[BenchmarkResult]) -> None:
+     """Print benchmark summary."""
+     successful = [r for r in results if r.success]
+
+     if not successful:
+         print("\n No successful runs!")
+         if results:
+             print(f" Error: {results[0].error}")
+         return
+
+     times = [r.compile_time_ms for r in successful]
+     mean_time = statistics.mean(times)
+
+     if len(times) > 1:
+         stdev = statistics.stdev(times)
+         min_time = min(times)
+         max_time = max(times)
+         print(f"\n Results: {mean_time:.0f}ms avg (min: {min_time}ms, max: {max_time}ms, stdev: {stdev:.0f}ms)")
+     else:
+         print(f"\n Results: {mean_time:.0f}ms")
+
+     # Show output sizes
+     if successful[0].ptx_lines:
+         print(f" PTX output: {successful[0].ptx_lines} lines")
+     if successful[0].sass_lines:
+         print(f" SASS output: {successful[0].sass_lines} lines")
+
+
+ def run_all_benchmarks(num_runs: int = 3) -> dict[str, list[BenchmarkResult]]:
+     """Run benchmarks for all kernel sizes."""
+     print("=" * 60)
+     print("CUDA Compilation Benchmark")
+     print("=" * 60)
+
+     kernels = [
+         ("simple", SIMPLE_KERNEL),
+         ("medium", MEDIUM_KERNEL),
+         ("complex", COMPLEX_KERNEL),
+     ]
+
+     all_results: dict[str, list[BenchmarkResult]] = {}
+
+     for name, code in kernels:
+         lines = count_lines(code)
+         print(f"\n{name.upper()} KERNEL ({lines} lines)")
+         print("-" * 40)
+
+         results = run_benchmark(name, code, num_runs=num_runs)
+         all_results[name] = results
+         print_summary(results)
+
+     # Print final summary
+     print("\n" + "=" * 60)
+     print("SUMMARY")
+     print("=" * 60)
+
+     for name in ["simple", "medium", "complex"]:
+         results = all_results.get(name, [])
+         successful = [r for r in results if r.success]
+         if successful:
+             avg_time = statistics.mean([r.compile_time_ms for r in successful])
+             print(f" {name:10s}: {avg_time:6.0f}ms ({results[0].kernel_lines} lines)")
+         else:
+             print(f" {name:10s}: FAILED")
+
+     return all_results
+
+
+ def main() -> None:
+     """Main entry point."""
+     parser = argparse.ArgumentParser(description="Benchmark CUDA compilation")
+     parser.add_argument(
+         "--kernel",
+         choices=["simple", "medium", "complex", "all"],
+         default="all",
+         help="Which kernel to benchmark",
+     )
+     parser.add_argument(
+         "--runs",
+         type=int,
+         default=3,
+         help="Number of benchmark runs per kernel",
+     )
+     parser.add_argument(
+         "--arch",
+         default="sm_90a",
+         help="Target GPU architecture",
+     )
+
+     args = parser.parse_args()
+
+     if args.kernel == "all":
+         run_all_benchmarks(num_runs=args.runs)
+     else:
+         kernel_map = {
+             "simple": SIMPLE_KERNEL,
+             "medium": MEDIUM_KERNEL,
+             "complex": COMPLEX_KERNEL,
+         }
+         code = kernel_map[args.kernel]
+         lines = count_lines(code)
+
+         print(f"\n{args.kernel.upper()} KERNEL ({lines} lines)")
+         print("-" * 40)
+
+         results = run_benchmark(
+             args.kernel,
+             code,
+             arch=args.arch,
+             num_runs=args.runs,
+         )
+         print_summary(results)
+
+
+ if __name__ == "__main__":
+     main()
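For reference, the helpers above can also be driven programmatically rather than through the CLI shown in the module docstring. A minimal sketch, assuming the "cuda-compile" Modal app is already deployed and reachable:

# Hypothetical programmatic use of the benchmark helpers defined above.
from wafer_core.tools.compile.benchmark import (
    SIMPLE_KERNEL,
    print_summary,
    run_benchmark,
)

# Five timed compilations of the simple kernel against sm_90a.
results = run_benchmark("simple", SIMPLE_KERNEL, arch="sm_90a", num_runs=5)
print_summary(results)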