wafer-core 0.1.39__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.
@@ -0,0 +1,636 @@
+ """Benchmark script for CUDA compilation performance.
+
+ This script measures compilation time for kernels of different sizes
+ to track performance improvements from optimizations.
+
+ Usage:
+     python -m wafer_core.tools.compile.benchmark
+
+     # Or with specific test:
+     python -m wafer_core.tools.compile.benchmark --kernel simple
+     python -m wafer_core.tools.compile.benchmark --kernel medium
+     python -m wafer_core.tools.compile.benchmark --kernel complex
+ """
+
+ import argparse
+ import statistics
+ import time
+ from typing import NamedTuple
+
+ # ============================================================================
+ # Test Kernels
+ # ============================================================================
+
+ SIMPLE_KERNEL = """\
+ // Simple vector addition kernel (~20 lines)
+ __global__ void vector_add(float* a, float* b, float* c, int n) {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx < n) {
+         c[idx] = a[idx] + b[idx];
+     }
+ }
+ """
+
+ MEDIUM_KERNEL = """\
+ // Medium complexity kernel with shared memory (~100 lines)
+ #include <cuda_runtime.h>
+
+ #define TILE_SIZE 16
+
+ __global__ void tiled_matmul(
+     const float* __restrict__ A,
+     const float* __restrict__ B,
+     float* __restrict__ C,
+     int M, int N, int K
+ ) {
+     __shared__ float As[TILE_SIZE][TILE_SIZE];
+     __shared__ float Bs[TILE_SIZE][TILE_SIZE];
+
+     int bx = blockIdx.x, by = blockIdx.y;
+     int tx = threadIdx.x, ty = threadIdx.y;
+
+     int row = by * TILE_SIZE + ty;
+     int col = bx * TILE_SIZE + tx;
+
+     float sum = 0.0f;
+
+     for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
+         // Load tile from A
+         if (row < M && t * TILE_SIZE + tx < K) {
+             As[ty][tx] = A[row * K + t * TILE_SIZE + tx];
+         } else {
+             As[ty][tx] = 0.0f;
+         }
+
+         // Load tile from B
+         if (t * TILE_SIZE + ty < K && col < N) {
+             Bs[ty][tx] = B[(t * TILE_SIZE + ty) * N + col];
+         } else {
+             Bs[ty][tx] = 0.0f;
+         }
+
+         __syncthreads();
+
+         // Compute partial dot product
+         #pragma unroll
+         for (int k = 0; k < TILE_SIZE; k++) {
+             sum = fmaf(As[ty][k], Bs[k][tx], sum);
+         }
+
+         __syncthreads();
+     }
+
+     if (row < M && col < N) {
+         C[row * N + col] = sum;
+     }
+ }
+
+ // Reduction kernel
+ __global__ void reduce_sum(const float* input, float* output, int n) {
+     extern __shared__ float sdata[];
+
+     unsigned int tid = threadIdx.x;
+     unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
+
+     float mySum = (i < n) ? input[i] : 0.0f;
+     if (i + blockDim.x < n) {
+         mySum += input[i + blockDim.x];
+     }
+     sdata[tid] = mySum;
+     __syncthreads();
+
+     for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
+         if (tid < s) {
+             sdata[tid] = mySum = mySum + sdata[tid + s];
+         }
+         __syncthreads();
+     }
+
+     if (tid < 32) {
+         volatile float* smem = sdata;
+         smem[tid] = mySum = mySum + smem[tid + 32];
+         smem[tid] = mySum = mySum + smem[tid + 16];
+         smem[tid] = mySum = mySum + smem[tid + 8];
+         smem[tid] = mySum = mySum + smem[tid + 4];
+         smem[tid] = mySum = mySum + smem[tid + 2];
+         smem[tid] = mySum = mySum + smem[tid + 1];
+     }
+
+     if (tid == 0) {
+         output[blockIdx.x] = sdata[0];
+     }
+ }
+ """
+
+ COMPLEX_KERNEL = """\
+ // Complex kernel with multiple features (~500 lines)
+ #include <cuda_runtime.h>
+ #include <cooperative_groups.h>
+
+ namespace cg = cooperative_groups;
+
+ // Constants
+ constexpr int BLOCK_SIZE = 256;
+ constexpr int TILE_SIZE = 16;
+
+ // ============================================================================
+ // Kernel 1: Vector operations with shared memory and reduction
+ // ============================================================================
+ template <typename T, int BlockSize>
+ __global__ void reduceSum(const T* __restrict__ input, T* __restrict__ output, int N) {
+     __shared__ T sdata[BlockSize];
+
+     unsigned int tid = threadIdx.x;
+     unsigned int i = blockIdx.x * (blockDim.x * 2) + threadIdx.x;
+
+     T mySum = (i < N) ? input[i] : T(0);
+     if (i + blockDim.x < N) {
+         mySum += input[i + blockDim.x];
+     }
+     sdata[tid] = mySum;
+     __syncthreads();
+
+     #pragma unroll
+     for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
+         if (tid < s) {
+             sdata[tid] = mySum = mySum + sdata[tid + s];
+         }
+         __syncthreads();
+     }
+
+     if (tid < 32) {
+         volatile T* smem = sdata;
+         if (BlockSize >= 64) mySum += smem[tid + 32];
+         smem[tid] = mySum;
+         if (BlockSize >= 32) mySum += smem[tid + 16];
+         smem[tid] = mySum;
+         if (BlockSize >= 16) mySum += smem[tid + 8];
+         smem[tid] = mySum;
+         if (BlockSize >= 8) mySum += smem[tid + 4];
+         smem[tid] = mySum;
+         if (BlockSize >= 4) mySum += smem[tid + 2];
+         smem[tid] = mySum;
+         if (BlockSize >= 2) mySum += smem[tid + 1];
+         smem[tid] = mySum;
+     }
+
+     if (tid == 0) {
+         output[blockIdx.x] = sdata[0];
+     }
+ }
+
+ // ============================================================================
+ // Kernel 2: Matrix transpose with shared memory
+ // ============================================================================
+ __global__ void matrixTranspose(const float* __restrict__ input,
+                                 float* __restrict__ output,
+                                 int width, int height) {
+     __shared__ float tile[TILE_SIZE][TILE_SIZE + 1];
+
+     int xIndex = blockIdx.x * TILE_SIZE + threadIdx.x;
+     int yIndex = blockIdx.y * TILE_SIZE + threadIdx.y;
+
+     if (xIndex < width && yIndex < height) {
+         tile[threadIdx.y][threadIdx.x] = input[yIndex * width + xIndex];
+     }
+
+     __syncthreads();
+
+     xIndex = blockIdx.y * TILE_SIZE + threadIdx.x;
+     yIndex = blockIdx.x * TILE_SIZE + threadIdx.y;
+
+     if (xIndex < height && yIndex < width) {
+         output[yIndex * height + xIndex] = tile[threadIdx.x][threadIdx.y];
+     }
+ }
+
+ // ============================================================================
+ // Kernel 3: Softmax with cooperative groups
+ // ============================================================================
+ __global__ void softmaxKernel(const float* __restrict__ input,
+                               float* __restrict__ output,
+                               int N, int stride) {
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     extern __shared__ float shared[];
+
+     int row = blockIdx.x;
+     const float* rowInput = input + row * stride;
+     float* rowOutput = output + row * stride;
+
+     float maxVal = -INFINITY;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         maxVal = fmaxf(maxVal, rowInput[i]);
+     }
+
+     for (int offset = warp.size() / 2; offset > 0; offset /= 2) {
+         maxVal = fmaxf(maxVal, warp.shfl_down(maxVal, offset));
+     }
+
+     if (warp.thread_rank() == 0) {
+         shared[threadIdx.x / 32] = maxVal;
+     }
+     block.sync();
+
+     if (threadIdx.x < blockDim.x / 32) {
+         maxVal = shared[threadIdx.x];
+     } else {
+         maxVal = -INFINITY;
+     }
+
+     for (int offset = 16; offset > 0; offset /= 2) {
+         maxVal = fmaxf(maxVal, __shfl_down_sync(0xffffffff, maxVal, offset));
+     }
+     maxVal = __shfl_sync(0xffffffff, maxVal, 0);
+
+     float sum = 0.0f;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         float val = expf(rowInput[i] - maxVal);
+         rowOutput[i] = val;
+         sum += val;
+     }
+
+     for (int offset = warp.size() / 2; offset > 0; offset /= 2) {
+         sum += warp.shfl_down(sum, offset);
+     }
+
+     if (warp.thread_rank() == 0) {
+         shared[threadIdx.x / 32] = sum;
+     }
+     block.sync();
+
+     if (threadIdx.x < blockDim.x / 32) {
+         sum = shared[threadIdx.x];
+     } else {
+         sum = 0.0f;
+     }
+
+     for (int offset = 16; offset > 0; offset /= 2) {
+         sum += __shfl_down_sync(0xffffffff, sum, offset);
+     }
+     sum = __shfl_sync(0xffffffff, sum, 0);
+
+     float invSum = 1.0f / sum;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         rowOutput[i] *= invSum;
+     }
+ }
+
+ // ============================================================================
+ // Kernel 4: Fused multiply-add with vectorized loads
+ // ============================================================================
+ __global__ void fusedMulAddVec4(const float4* __restrict__ A,
+                                 const float4* __restrict__ B,
+                                 float4* __restrict__ C,
+                                 float alpha, float beta, int N) {
+     int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+     if (idx < N) {
+         float4 a = A[idx];
+         float4 b = B[idx];
+         float4 c;
+
+         c.x = fmaf(alpha, a.x, beta * b.x);
+         c.y = fmaf(alpha, a.y, beta * b.y);
+         c.z = fmaf(alpha, a.z, beta * b.z);
+         c.w = fmaf(alpha, a.w, beta * b.w);
+
+         C[idx] = c;
+     }
+ }
+
+ // ============================================================================
+ // Kernel 5: Simple GEMM with shared memory tiling
+ // ============================================================================
+ __global__ void matmulTiled(const float* __restrict__ A,
+                             const float* __restrict__ B,
+                             float* __restrict__ C,
+                             int M, int N, int K) {
+     __shared__ float As[TILE_SIZE][TILE_SIZE];
+     __shared__ float Bs[TILE_SIZE][TILE_SIZE];
+
+     int row = blockIdx.y * TILE_SIZE + threadIdx.y;
+     int col = blockIdx.x * TILE_SIZE + threadIdx.x;
+
+     float sum = 0.0f;
+
+     for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
+         int tiledCol = t * TILE_SIZE + threadIdx.x;
+         int tiledRow = t * TILE_SIZE + threadIdx.y;
+
+         As[threadIdx.y][threadIdx.x] = (row < M && tiledCol < K) ?
+             A[row * K + tiledCol] : 0.0f;
+         Bs[threadIdx.y][threadIdx.x] = (tiledRow < K && col < N) ?
+             B[tiledRow * N + col] : 0.0f;
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int k = 0; k < TILE_SIZE; k++) {
+             sum = fmaf(As[threadIdx.y][k], Bs[k][threadIdx.x], sum);
+         }
+
+         __syncthreads();
+     }
+
+     if (row < M && col < N) {
+         C[row * N + col] = sum;
+     }
+ }
+
+ // ============================================================================
+ // Device helper functions
+ // ============================================================================
+ __device__ __forceinline__ float warpReduceSum(float val) {
+     for (int offset = 16; offset > 0; offset /= 2) {
+         val += __shfl_down_sync(0xffffffff, val, offset);
+     }
+     return val;
+ }
+
+ __device__ __forceinline__ float blockReduceSum(float val) {
+     __shared__ float shared[32];
+
+     int lane = threadIdx.x % 32;
+     int wid = threadIdx.x / 32;
+
+     val = warpReduceSum(val);
+
+     if (lane == 0) shared[wid] = val;
+     __syncthreads();
+
+     val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
+
+     if (wid == 0) val = warpReduceSum(val);
+
+     return val;
+ }
+
+ // ============================================================================
+ // Kernel 6: Layer normalization
+ // ============================================================================
+ __global__ void layerNorm(const float* __restrict__ input,
+                           const float* __restrict__ gamma,
+                           const float* __restrict__ beta,
+                           float* __restrict__ output,
+                           int N, float eps) {
+     int row = blockIdx.x;
+     const float* rowInput = input + row * N;
+     float* rowOutput = output + row * N;
+
+     float sum = 0.0f;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         sum += rowInput[i];
+     }
+     sum = blockReduceSum(sum);
+
+     __shared__ float s_mean, s_var;
+     if (threadIdx.x == 0) {
+         s_mean = sum / N;
+     }
+     __syncthreads();
+
+     float var = 0.0f;
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         float diff = rowInput[i] - s_mean;
+         var += diff * diff;
+     }
+     var = blockReduceSum(var);
+
+     if (threadIdx.x == 0) {
+         s_var = rsqrtf(var / N + eps);
+     }
+     __syncthreads();
+
+     for (int i = threadIdx.x; i < N; i += blockDim.x) {
+         float normalized = (rowInput[i] - s_mean) * s_var;
+         rowOutput[i] = fmaf(normalized, gamma[i], beta[i]);
+     }
+ }
+ """
+
+
+ # ============================================================================
+ # Benchmark Results
+ # ============================================================================
+
+ class BenchmarkResult(NamedTuple):
+     """Result of a single benchmark run."""
+     kernel_name: str
+     kernel_lines: int
+     compile_time_ms: int
+     success: bool
+     ptx_lines: int | None
+     sass_lines: int | None
+     error: str | None
+
+
+ def count_lines(code: str) -> int:
+     """Count non-empty lines in code."""
+     return len([line for line in code.split('\n') if line.strip()])
+
+
+ def run_benchmark(
+     kernel_name: str,
+     kernel_code: str,
+     arch: str = "sm_90a",
+     output_formats: list[str] | None = None,
+     num_runs: int = 3,
+ ) -> list[BenchmarkResult]:
+     """Run benchmark for a kernel.
+
+     Args:
+         kernel_name: Name of the kernel for reporting
+         kernel_code: CUDA source code
+         arch: Target architecture
+         output_formats: Output formats to request (default: ["ptx", "sass"])
+         num_runs: Number of benchmark runs
+
+     Returns:
+         List of BenchmarkResult for each run
+     """
+     import modal
+
+     if output_formats is None:
+         output_formats = ["ptx", "sass"]
+
+     # Get the deployed function
+     compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
+
+     kernel_lines = count_lines(kernel_code)
+     results: list[BenchmarkResult] = []
+
+     for run in range(num_runs):
+         print(f" Run {run + 1}/{num_runs}...", end=" ", flush=True)
+
+         start_time = time.time()
+
+         try:
+             result = compile_fn.remote({
+                 "files": {"kernel.cu": kernel_code},
+                 "arch": arch,
+                 "flags": ["-O3", "-lineinfo"],
+                 "output": output_formats,
+             })
+
+             elapsed_ms = int((time.time() - start_time) * 1000)
+
+             if result["success"]:
+                 ptx_lines = count_lines(result["ptx"]) if result.get("ptx") else None
+                 sass_lines = count_lines(result["sass"]) if result.get("sass") else None
+
+                 results.append(BenchmarkResult(
+                     kernel_name=kernel_name,
+                     kernel_lines=kernel_lines,
+                     compile_time_ms=elapsed_ms,
+                     success=True,
+                     ptx_lines=ptx_lines,
+                     sass_lines=sass_lines,
+                     error=None,
+                 ))
+                 print(f"{elapsed_ms}ms")
+             else:
+                 results.append(BenchmarkResult(
+                     kernel_name=kernel_name,
+                     kernel_lines=kernel_lines,
+                     compile_time_ms=elapsed_ms,
+                     success=False,
+                     ptx_lines=None,
+                     sass_lines=None,
+                     error=result.get("stderr", "Unknown error"),
+                 ))
+                 print(f"FAILED ({elapsed_ms}ms)")
+
+         except Exception as e:
+             elapsed_ms = int((time.time() - start_time) * 1000)
+             results.append(BenchmarkResult(
+                 kernel_name=kernel_name,
+                 kernel_lines=kernel_lines,
+                 compile_time_ms=elapsed_ms,
+                 success=False,
+                 ptx_lines=None,
+                 sass_lines=None,
+                 error=str(e),
+             ))
+             print(f"ERROR: {e}")
+
+     return results
+
+
+ def print_summary(results: list[BenchmarkResult]) -> None:
+     """Print benchmark summary."""
+     successful = [r for r in results if r.success]
+
+     if not successful:
+         print("\n No successful runs!")
+         if results:
+             print(f" Error: {results[0].error}")
+         return
+
+     times = [r.compile_time_ms for r in successful]
+     mean_time = statistics.mean(times)
+
+     if len(times) > 1:
+         stdev = statistics.stdev(times)
+         min_time = min(times)
+         max_time = max(times)
+         print(f"\n Results: {mean_time:.0f}ms avg (min: {min_time}ms, max: {max_time}ms, stdev: {stdev:.0f}ms)")
+     else:
+         print(f"\n Results: {mean_time:.0f}ms")
+
+     # Show output sizes
+     if successful[0].ptx_lines:
+         print(f" PTX output: {successful[0].ptx_lines} lines")
+     if successful[0].sass_lines:
+         print(f" SASS output: {successful[0].sass_lines} lines")
+
+
+ def run_all_benchmarks(num_runs: int = 3) -> dict[str, list[BenchmarkResult]]:
+     """Run benchmarks for all kernel sizes."""
+     print("=" * 60)
+     print("CUDA Compilation Benchmark")
+     print("=" * 60)
+
+     kernels = [
+         ("simple", SIMPLE_KERNEL),
+         ("medium", MEDIUM_KERNEL),
+         ("complex", COMPLEX_KERNEL),
+     ]
+
+     all_results: dict[str, list[BenchmarkResult]] = {}
+
+     for name, code in kernels:
+         lines = count_lines(code)
+         print(f"\n{name.upper()} KERNEL ({lines} lines)")
+         print("-" * 40)
+
+         results = run_benchmark(name, code, num_runs=num_runs)
+         all_results[name] = results
+         print_summary(results)
+
+     # Print final summary
+     print("\n" + "=" * 60)
+     print("SUMMARY")
+     print("=" * 60)
+
+     for name in ["simple", "medium", "complex"]:
+         results = all_results.get(name, [])
+         successful = [r for r in results if r.success]
+         if successful:
+             avg_time = statistics.mean([r.compile_time_ms for r in successful])
+             print(f" {name:10s}: {avg_time:6.0f}ms ({results[0].kernel_lines} lines)")
+         else:
+             print(f" {name:10s}: FAILED")
+
+     return all_results
+
+
+ def main() -> None:
+     """Main entry point."""
+     parser = argparse.ArgumentParser(description="Benchmark CUDA compilation")
+     parser.add_argument(
+         "--kernel",
+         choices=["simple", "medium", "complex", "all"],
+         default="all",
+         help="Which kernel to benchmark",
+     )
+     parser.add_argument(
+         "--runs",
+         type=int,
+         default=3,
+         help="Number of benchmark runs per kernel",
+     )
+     parser.add_argument(
+         "--arch",
+         default="sm_90a",
+         help="Target GPU architecture",
+     )
+
+     args = parser.parse_args()
+
+     if args.kernel == "all":
+         run_all_benchmarks(num_runs=args.runs)
+     else:
+         kernel_map = {
+             "simple": SIMPLE_KERNEL,
+             "medium": MEDIUM_KERNEL,
+             "complex": COMPLEX_KERNEL,
+         }
+         code = kernel_map[args.kernel]
+         lines = count_lines(code)
+
+         print(f"\n{args.kernel.upper()} KERNEL ({lines} lines)")
+         print("-" * 40)
+
+         results = run_benchmark(
+             args.kernel,
+             code,
+             arch=args.arch,
+             num_runs=args.runs,
+         )
+         print_summary(results)
+
+
+ if __name__ == "__main__":
+     main()
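
Note: the new benchmark module can also be driven programmatically. A minimal sketch, assuming the "cuda-compile" Modal app is deployed and Modal credentials are configured in the environment (all names come from the file above):

    # Run a single benchmark pass against the deployed compiler.
    from wafer_core.tools.compile.benchmark import SIMPLE_KERNEL, run_benchmark

    results = run_benchmark("simple", SIMPLE_KERNEL, arch="sm_90a", num_runs=1)
    for r in results:
        # Each entry is a BenchmarkResult NamedTuple defined in the module.
        status = f"{r.compile_time_ms}ms" if r.success else f"failed: {r.error}"
        print(f"{r.kernel_name} ({r.kernel_lines} lines): {status}")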
@@ -58,8 +58,8 @@ async def compile_cuda_remote(
  ) -> CompileResponse:
      """Compile CUDA code using Modal (remote execution).

-     This function spawns a subprocess to call Modal, avoiding event loop
-     conflicts between the caller's event loop and Modal's asyncio.
+     This function calls the deployed Modal function directly using asyncio.to_thread
+     to avoid blocking the event loop.

      Args:
          request: The compile request
@@ -70,92 +70,74 @@ async def compile_cuda_remote(
          CompileResponse with PTX/SASS or error
      """
      import asyncio
-     import json
      import os
-     import tempfile
      import time
-     from pathlib import Path
+     from contextlib import contextmanager

-     start_time = time.time()
+     @contextmanager
+     def temporary_env_vars(env_updates: dict[str, str]):
+         """Context manager to temporarily set environment variables.
+
+         Saves original values, sets new values, yields, then restores originals.
+         This ensures we don't leak credentials between concurrent requests.
+         """
+         original_values: dict[str, str | None] = {}
+         for key, value in env_updates.items():
+             original_values[key] = os.environ.get(key)
+             os.environ[key] = value

-     # Write request to temp file
+         try:
+             yield
+         finally:
+             for key, original in original_values.items():
+                 if original is None:
+                     os.environ.pop(key, None)
+                 else:
+                     os.environ[key] = original
+
+     start_time = time.time()
      request_dict = request_to_dict(request)

-     with tempfile.NamedTemporaryFile(
-         mode="w", suffix=".json", delete=False
-     ) as request_file:
-         json.dump(request_dict, request_file)
-         request_path = request_file.name
+     # Build env updates for credentials (only if provided)
+     env_updates: dict[str, str] = {}
+     if modal_token_id:
+         env_updates["MODAL_TOKEN_ID"] = modal_token_id
+     if modal_token_secret:
+         env_updates["MODAL_TOKEN_SECRET"] = modal_token_secret
+
+     def call_modal() -> dict:
+         """Call Modal function synchronously (runs in thread pool)."""
+         import modal
+
+         # Look up the deployed function
+         compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
+
+         # Call the function remotely
+         return compile_fn.remote(request_dict)

      try:
-         # Create a Python script that calls Modal using Function.lookup
-         # This calls the deployed function without needing to rebuild the image
-         script = f'''
- import json
- import modal
-
- # Load request
- with open("{request_path}") as f:
-     request = json.load(f)
-
- # Look up the deployed function
- compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
-
- # Call the function remotely
- result = compile_fn.remote(request)
-
- # Output result as JSON
- print(json.dumps(result))
- '''
-
-         # Run in subprocess to avoid event loop conflicts
-         env = os.environ.copy()
-         if modal_token_id:
-             env["MODAL_TOKEN_ID"] = modal_token_id
-         if modal_token_secret:
-             env["MODAL_TOKEN_SECRET"] = modal_token_secret
-
-         # Use the same Python interpreter that's running this code
-         import sys
-         python_executable = sys.executable
-
-         # Use asyncio.create_subprocess_exec for async subprocess execution
-         proc = await asyncio.create_subprocess_exec(
-             python_executable, "-c", script,
-             stdout=asyncio.subprocess.PIPE,
-             stderr=asyncio.subprocess.PIPE,
-             env=env,
-         )
-         stdout_bytes, stderr_bytes = await proc.communicate()
+         # Run Modal call in thread pool with temporary credentials
+         # The context manager ensures env vars are restored after the call
+         def call_modal_with_env() -> dict:
+             with temporary_env_vars(env_updates):
+                 return call_modal()

-         if proc.returncode != 0:
-             stderr = stderr_bytes.decode() if stderr_bytes else "Unknown error"
-             # Check for common Modal auth errors
-             if "MODAL_TOKEN" in stderr or "AuthError" in stderr or "not authenticated" in stderr.lower():
-                 return CompileResponse.error(
-                     "Modal not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables, "
-                     "or run 'modal token new' to authenticate.",
-                     compilation_time_ms=int((time.time() - start_time) * 1000),
-                 )
-             return CompileResponse.error(
-                 f"Compilation failed: {stderr}",
-                 compilation_time_ms=int((time.time() - start_time) * 1000),
-             )
+         result = await asyncio.to_thread(call_modal_with_env)
+         return response_from_dict(result)

-         # Parse result
-         stdout = stdout_bytes.decode() if stdout_bytes else "{}"
-         try:
-             response_dict = json.loads(stdout)
-             return response_from_dict(response_dict)
-         except json.JSONDecodeError as e:
+     except Exception as e:
+         error_str = str(e)
+         # Check for common Modal auth errors
+         if "MODAL_TOKEN" in error_str or "AuthError" in error_str or "not authenticated" in error_str.lower():
              return CompileResponse.error(
-                 f"Failed to parse Modal response: {e}\nOutput: {stdout[:500]}",
+                 "Modal not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables, "
+                 "or run 'modal token new' to authenticate.",
                  compilation_time_ms=int((time.time() - start_time) * 1000),
              )
-
-     finally:
-         # Clean up temp file
-         Path(request_path).unlink(missing_ok=True)
+         return CompileResponse.error(
+             f"Compilation failed: {error_str}",
+             compilation_time_ms=int((time.time() - start_time) * 1000),
+         )


  def compile_cuda_local(request: CompileRequest) -> CompileResponse:
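
Note: the replacement above trades the subprocess hop for asyncio.to_thread plus a save-and-restore wrapper around os.environ. A self-contained sketch of that pattern; fetch_remote and EXAMPLE_TOKEN are illustrative stand-ins, not wafer-core APIs:

    import asyncio
    import os
    from contextlib import contextmanager
    from typing import Iterator

    @contextmanager
    def temporary_env_vars(env_updates: dict[str, str]) -> Iterator[None]:
        # Save originals, apply overrides, restore on exit.
        originals = {key: os.environ.get(key) for key in env_updates}
        os.environ.update(env_updates)
        try:
            yield
        finally:
            for key, original in originals.items():
                if original is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = original

    def fetch_remote() -> str:
        # Stand-in for a blocking SDK call that reads credentials from the env.
        return os.environ["EXAMPLE_TOKEN"]

    async def main() -> None:
        def call_with_env() -> str:
            with temporary_env_vars({"EXAMPLE_TOKEN": "abc123"}):
                return fetch_remote()

        # to_thread runs the blocking call on a worker thread, keeping the
        # event loop responsive.
        print(await asyncio.to_thread(call_with_env))

    asyncio.run(main())

One caveat on the design: os.environ is process-wide, so the context manager bounds the lifetime of an override but does not isolate overlapping calls that touch the same variables.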
@@ -197,7 +179,12 @@ def compile_cuda_local(request: CompileRequest) -> CompileResponse:

          # Write all files to temp directory
          for filename, content in request.files.items():
-             file_path = tmp_path / filename
+             file_path = (tmp_path / filename).resolve()
+             if not file_path.is_relative_to(tmp_path):
+                 return CompileResponse.error(
+                     f"Invalid filename: {filename}",
+                     compilation_time_ms=int((time.time() - start_time) * 1000),
+                 )
              file_path.parent.mkdir(parents=True, exist_ok=True)
              file_path.write_text(content)

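Note: the resolve-then-check guard added here (and mirrored in the Modal function further down) rejects path traversal in user-supplied filenames. A quick illustration of the check; Path.is_relative_to requires Python 3.9+:

    import tempfile
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmpdir:
        # Resolve the base too, in case the temp dir itself sits behind a symlink.
        tmp_path = Path(tmpdir).resolve()
        for filename in ("kernel.cu", "sub/helper.cuh", "../escape.cu"):
            file_path = (tmp_path / filename).resolve()
            print(f"{filename!r} allowed: {file_path.is_relative_to(tmp_path)}")

The first two names stay inside the sandbox directory and print True; the traversal attempt resolves outside it and prints False.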
@@ -79,7 +79,10 @@ app = modal.App(name="cuda-compile", image=compile_image)
      cpu=4,
      memory=8192,  # 8GB RAM
      timeout=120,  # 2 minute timeout
+     # Keep one container warm to avoid cold starts (~5-10s savings)
+     min_containers=1,
  )
+ @modal.concurrent(max_inputs=4)  # Allow concurrent compilations for better throughput
  def compile_cuda(request: dict) -> dict:
      """Compile CUDA code and return PTX/SASS.

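Note: assembled for readability, the decorator stack on the Modal function now reads roughly as below. compile_image stands in for the package's image definition, which lies outside this hunk:

    import modal

    compile_image = modal.Image.debian_slim()  # stand-in; real image elided
    app = modal.App(name="cuda-compile", image=compile_image)

    @app.function(
        cpu=4,
        memory=8192,       # 8GB RAM
        timeout=120,       # 2 minute timeout
        min_containers=1,  # keep one container warm to avoid cold starts
    )
    @modal.concurrent(max_inputs=4)  # concurrent compilations per container
    def compile_cuda(request: dict) -> dict:
        ...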
@@ -105,6 +108,7 @@ def compile_cuda(request: dict) -> dict:
      import subprocess
      import tempfile
      import time
+     from concurrent.futures import ThreadPoolExecutor, as_completed
      from pathlib import Path

      start_time = time.time()
@@ -138,13 +142,92 @@ def compile_cuda(request: dict) -> dict:

      main_cu_file = cu_files[0]

+     # Build environment for nvcc
+     nvcc_env = {
+         **os.environ,
+         "CUDA_HOME": "/usr/local/cuda",
+         "PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
+     }
+
+     def compile_ptx(tmpdir: str, base_cmd: list[str], main_cu_path: Path) -> tuple[str | None, str | None]:
+         """Compile to PTX. Returns (ptx_content, error_message)."""
+         ptx_output = Path(tmpdir) / "output.ptx"
+         ptx_cmd = base_cmd + [
+             "--ptx",
+             "-o",
+             str(ptx_output),
+             str(main_cu_path),
+         ]
+
+         ptx_result = subprocess.run(
+             ptx_cmd,
+             capture_output=True,
+             text=True,
+             timeout=60,
+             cwd=tmpdir,
+             env=nvcc_env,
+         )
+
+         if ptx_result.returncode != 0:
+             return None, ptx_result.stderr or ptx_result.stdout
+
+         if ptx_output.exists():
+             return ptx_output.read_text(), None
+         return None, "PTX output file not created"
+
+     def compile_sass(tmpdir: str, base_cmd: list[str], main_cu_path: Path) -> tuple[str | None, str | None]:
+         """Compile to SASS (via cubin). Returns (sass_content, error_message)."""
+         cubin_output = Path(tmpdir) / "output.cubin"
+         cubin_cmd = base_cmd + [
+             "--cubin",
+             "-o",
+             str(cubin_output),
+             str(main_cu_path),
+         ]
+
+         cubin_result = subprocess.run(
+             cubin_cmd,
+             capture_output=True,
+             text=True,
+             timeout=60,
+             cwd=tmpdir,
+             env=nvcc_env,
+         )
+
+         if cubin_result.returncode != 0:
+             return None, cubin_result.stderr or cubin_result.stdout
+
+         if not cubin_output.exists():
+             return None, "cubin output file not created"
+
+         # Disassemble cubin to SASS
+         sass_result = subprocess.run(
+             ["cuobjdump", "--dump-sass", str(cubin_output)],
+             capture_output=True,
+             text=True,
+             timeout=30,
+             cwd=tmpdir,
+         )
+
+         if sass_result.returncode == 0:
+             return sass_result.stdout, None
+         return None, f"SASS disassembly failed: {sass_result.stderr}"
+
      try:
          with tempfile.TemporaryDirectory() as tmpdir:
              tmp_path = Path(tmpdir)

              # Write all files to temp directory, preserving subdirectory structure
              for filename, content in files.items():
-                 file_path = tmp_path / filename
+                 file_path = (tmp_path / filename).resolve()
+                 if not file_path.is_relative_to(tmp_path):
+                     return {
+                         "success": False,
+                         "ptx": None,
+                         "sass": None,
+                         "stderr": f"Invalid filename: {filename}",
+                         "compilation_time_ms": int((time.time() - start_time) * 1000),
+                     }
                  file_path.parent.mkdir(parents=True, exist_ok=True)
                  file_path.write_text(content)

@@ -152,151 +235,78 @@ def compile_cuda(request: dict) -> dict:
              main_cu_path = tmp_path / main_cu_file
              include_dir = main_cu_path.parent

-             results: dict[str, str | None] = {"ptx": None, "sass": None}
-
              # Build base nvcc command with common flags
              base_cmd = [
                  "nvcc",
                  "-arch",
                  arch,
-                 # Include the temp directory for user headers
                  f"-I{include_dir}",
-                 # Include PyTorch headers
                  "-I/usr/local/lib/python3.12/site-packages/torch/include",
                  "-I/usr/local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include",
-                 # Include CUTLASS headers
                  "-I/usr/local/cutlass/include",
-                 # Standard CUDA headers are already in the default path
              ]
-
-             # Add user-specified flags
              base_cmd.extend(flags)

-             # Generate PTX if requested
-             if OutputFormat.PTX.value in output_formats:
-                 ptx_output = tmp_path / "output.ptx"
-                 ptx_cmd = base_cmd + [
-                     "--ptx",  # Generate PTX
-                     "-o",
-                     str(ptx_output),
-                     str(main_cu_path),
-                 ]
-
-                 ptx_result = subprocess.run(
-                     ptx_cmd,
-                     capture_output=True,
-                     text=True,
-                     timeout=60,
-                     cwd=tmpdir,
-                     env={
-                         **os.environ,
-                         "CUDA_HOME": "/usr/local/cuda",
-                         "PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
-                     },
-                 )
-
-                 if ptx_result.returncode != 0:
-                     return {
-                         "success": False,
-                         "ptx": None,
-                         "sass": None,
-                         "stderr": ptx_result.stderr or ptx_result.stdout,
-                         "compilation_time_ms": int((time.time() - start_time) * 1000),
-                     }
+             # Determine what to compile
+             want_ptx = OutputFormat.PTX.value in output_formats
+             want_sass = OutputFormat.SASS.value in output_formats

-                 if ptx_output.exists():
-                     results["ptx"] = ptx_output.read_text()
-
-             # Generate SASS if requested
-             if OutputFormat.SASS.value in output_formats:
-                 # First compile to cubin, then disassemble to SASS
-                 cubin_output = tmp_path / "output.cubin"
-                 cubin_cmd = base_cmd + [
-                     "--cubin",  # Generate cubin (binary)
-                     "-o",
-                     str(cubin_output),
-                     str(main_cu_path),
-                 ]
-
-                 cubin_result = subprocess.run(
-                     cubin_cmd,
-                     capture_output=True,
-                     text=True,
-                     timeout=60,
-                     cwd=tmpdir,
-                     env={
-                         **os.environ,
-                         "CUDA_HOME": "/usr/local/cuda",
-                         "PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
-                     },
-                 )
-
-                 if cubin_result.returncode != 0:
-                     # If we already have PTX, that's a partial success
-                     if results["ptx"]:
-                         return {
-                             "success": True,
-                             "ptx": results["ptx"],
-                             "sass": None,
-                             "stderr": f"SASS generation failed: {cubin_result.stderr}",
-                             "compilation_time_ms": int(
-                                 (time.time() - start_time) * 1000
-                             ),
-                         }
-                     return {
-                         "success": False,
-                         "ptx": None,
-                         "sass": None,
-                         "stderr": cubin_result.stderr or cubin_result.stdout,
-                         "compilation_time_ms": int((time.time() - start_time) * 1000),
+             results: dict[str, str | None] = {"ptx": None, "sass": None}
+             errors: list[str] = []
+
+             # Run compilations in parallel if both are requested
+             if want_ptx and want_sass:
+                 with ThreadPoolExecutor(max_workers=2) as executor:
+                     futures = {
+                         executor.submit(compile_ptx, tmpdir, base_cmd, main_cu_path): "ptx",
+                         executor.submit(compile_sass, tmpdir, base_cmd, main_cu_path): "sass",
                      }

-                 # Disassemble cubin to SASS using cuobjdump
-                 if cubin_output.exists():
-                     sass_cmd = [
-                         "cuobjdump",
-                         "--dump-sass",
-                         str(cubin_output),
-                     ]
-
-                     sass_result = subprocess.run(
-                         sass_cmd,
-                         capture_output=True,
-                         text=True,
-                         timeout=30,
-                         cwd=tmpdir,
-                     )
-
-                     if sass_result.returncode == 0:
-                         results["sass"] = sass_result.stdout
-                     else:
-                         # SASS generation failed but we might have PTX
-                         if results["ptx"]:
-                             return {
-                                 "success": True,
-                                 "ptx": results["ptx"],
-                                 "sass": None,
-                                 "stderr": f"SASS disassembly failed: {sass_result.stderr}",
-                                 "compilation_time_ms": int(
-                                     (time.time() - start_time) * 1000
-                                 ),
-                             }
-
-             # Check if we got any output
+                     for future in as_completed(futures):
+                         output_type = futures[future]
+                         try:
+                             content, error = future.result()
+                             if content:
+                                 results[output_type] = content
+                             if error:
+                                 errors.append(f"{output_type.upper()}: {error}")
+                         except Exception as e:
+                             errors.append(f"{output_type.upper()} compilation error: {e}")
+
+             elif want_ptx:
+                 content, error = compile_ptx(tmpdir, base_cmd, main_cu_path)
+                 if content:
+                     results["ptx"] = content
+                 if error:
+                     errors.append(error)
+
+             elif want_sass:
+                 content, error = compile_sass(tmpdir, base_cmd, main_cu_path)
+                 if content:
+                     results["sass"] = content
+                 if error:
+                     errors.append(error)
+
+             # Check results
              if not results["ptx"] and not results["sass"]:
                  return {
                      "success": False,
                      "ptx": None,
                      "sass": None,
-                     "stderr": "No output generated",
+                     "stderr": "\n".join(errors) if errors else "No output generated",
                      "compilation_time_ms": int((time.time() - start_time) * 1000),
                  }

+             # Partial success if we got at least one output
+             stderr = ""
+             if errors and (results["ptx"] or results["sass"]):
+                 stderr = "\n".join(errors)
+
              return {
                  "success": True,
                  "ptx": results["ptx"],
                  "sass": results["sass"],
-                 "stderr": "",
+                 "stderr": stderr,
                  "compilation_time_ms": int((time.time() - start_time) * 1000),
              }

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: wafer-core
- Version: 0.1.39
+ Version: 0.1.40
  Summary: Core utilities and environments for Wafer GPU kernel optimization
  Requires-Python: >=3.10
  Requires-Dist: aiohttp>=3.9.0
@@ -645,8 +645,9 @@ wafer_core/tools/capture_tool/dtypes.py,sha256=1Vm5obOCYc-Njuwkp7uqh_W4lqtYurT3b
  wafer_core/tools/capture_tool/executor.py,sha256=n1DVfbsP60yJAazx9C9Kwed9LB7AcKXJcoDnhno7ydU,1495
  wafer_core/tools/capture_tool/metrics.py,sha256=BFZNmdE-kh3LneYdWXTNZmlLuo-DCrP5aEBHxEQYJDU,10890
  wafer_core/tools/compile/__init__.py,sha256=8VyaMDDPxg4DcT-rwMf9lcNhAanWnmsqijUJYsuzJNg,615
- wafer_core/tools/compile/compiler.py,sha256=rGPvfqLTg-7y3hyFEihF6lxiEOfbIsRwfvOZSaVJ2_A,10192
- wafer_core/tools/compile/modal_compile.py,sha256=zYrkAtGYkDiM6tJfH_hD-mJ0LqCW5HCSsf_6fADJIbI,13310
+ wafer_core/tools/compile/benchmark.py,sha256=6_nfhl24vTWt59EwGievbyMHZK2l4wfslP77BHWsoQ4,19408
+ wafer_core/tools/compile/compiler.py,sha256=Y7iwfQkSBc4fmKXpv97ce1grw5L4tJ_VqWFFyYolRAg,10054
+ wafer_core/tools/compile/modal_compile.py,sha256=lYMxdrvEQctA1Om6yESetjUAsSyv0W0evNVb8WOY2Ps,13384
  wafer_core/tools/compile/types.py,sha256=8Hjh6Mz2a7s2JjtKYQq-l3X41gmywnbKk3tc1wvbMLM,3277
  wafer_core/tools/compile/tests/__init__.py,sha256=gSuBMN-7VayQ9HgyNuUXRumenwk7jtq86ZxdCgFjeYE,41
  wafer_core/tools/compile/tests/test_compiler.py,sha256=kQ-YTLY8ETnS83nQ8xVSygKY532epxqRTsGx311SG7w,20795
@@ -722,6 +723,6 @@ wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdh
  wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
  wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
  wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
- wafer_core-0.1.39.dist-info/METADATA,sha256=OcMn8TZzsUvPT2JBa0xYK_sAT_og1PAZd-DpDcLG1XA,1477
- wafer_core-0.1.39.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- wafer_core-0.1.39.dist-info/RECORD,,
+ wafer_core-0.1.40.dist-info/METADATA,sha256=yCfawhvfbqAmwkjDxe7GaIRD8LB6L37DR6-XlGGzevs,1477
+ wafer_core-0.1.40.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ wafer_core-0.1.40.dist-info/RECORD,,