wafer-core 0.1.39-py3-none-any.whl → 0.1.41-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/tools/compile/benchmark.py +636 -0
- wafer_core/tools/compile/compiler.py +63 -76
- wafer_core/tools/compile/modal_compile.py +129 -119
- {wafer_core-0.1.39.dist-info → wafer_core-0.1.41.dist-info}/METADATA +1 -1
- {wafer_core-0.1.39.dist-info → wafer_core-0.1.41.dist-info}/RECORD +6 -5
- {wafer_core-0.1.39.dist-info → wafer_core-0.1.41.dist-info}/WHEEL +0 -0
wafer_core/tools/compile/benchmark.py
@@ -0,0 +1,636 @@
+"""Benchmark script for CUDA compilation performance.
+
+This script measures compilation time for kernels of different sizes
+to track performance improvements from optimizations.
+
+Usage:
+    python -m wafer_core.tools.compile.benchmark
+
+    # Or with specific test:
+    python -m wafer_core.tools.compile.benchmark --kernel simple
+    python -m wafer_core.tools.compile.benchmark --kernel medium
+    python -m wafer_core.tools.compile.benchmark --kernel complex
+"""
+
+import argparse
+import statistics
+import time
+from typing import NamedTuple
+
+# ============================================================================
+# Test Kernels
+# ============================================================================
+
+SIMPLE_KERNEL = """\
+// Simple vector addition kernel (~20 lines)
+__global__ void vector_add(float* a, float* b, float* c, int n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        c[idx] = a[idx] + b[idx];
+    }
+}
+"""
+
+MEDIUM_KERNEL = """\
+// Medium complexity kernel with shared memory (~100 lines)
+#include <cuda_runtime.h>
+
+#define TILE_SIZE 16
+
+__global__ void tiled_matmul(
+    const float* __restrict__ A,
+    const float* __restrict__ B,
+    float* __restrict__ C,
+    int M, int N, int K
+) {
+    __shared__ float As[TILE_SIZE][TILE_SIZE];
+    __shared__ float Bs[TILE_SIZE][TILE_SIZE];
+
+    int bx = blockIdx.x, by = blockIdx.y;
+    int tx = threadIdx.x, ty = threadIdx.y;
+
+    int row = by * TILE_SIZE + ty;
+    int col = bx * TILE_SIZE + tx;
+
+    float sum = 0.0f;
+
+    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
+        // Load tile from A
+        if (row < M && t * TILE_SIZE + tx < K) {
+            As[ty][tx] = A[row * K + t * TILE_SIZE + tx];
+        } else {
+            As[ty][tx] = 0.0f;
+        }
+
+        // Load tile from B
+        if (t * TILE_SIZE + ty < K && col < N) {
+            Bs[ty][tx] = B[(t * TILE_SIZE + ty) * N + col];
+        } else {
+            Bs[ty][tx] = 0.0f;
+        }
+
+        __syncthreads();
+
+        // Compute partial dot product
+        #pragma unroll
+        for (int k = 0; k < TILE_SIZE; k++) {
+            sum = fmaf(As[ty][k], Bs[k][tx], sum);
+        }
+
+        __syncthreads();
+    }
+
+    if (row < M && col < N) {
+        C[row * N + col] = sum;
+    }
+}
+
+// Reduction kernel
+__global__ void reduce_sum(const float* input, float* output, int n) {
+    extern __shared__ float sdata[];
+
+    unsigned int tid = threadIdx.x;
+    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
+
+    float mySum = (i < n) ? input[i] : 0.0f;
+    if (i + blockDim.x < n) {
+        mySum += input[i + blockDim.x];
+    }
+    sdata[tid] = mySum;
+    __syncthreads();
+
+    for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] = mySum = mySum + sdata[tid + s];
+        }
+        __syncthreads();
+    }
+
+    if (tid < 32) {
+        volatile float* smem = sdata;
+        smem[tid] = mySum = mySum + smem[tid + 32];
+        smem[tid] = mySum = mySum + smem[tid + 16];
+        smem[tid] = mySum = mySum + smem[tid + 8];
+        smem[tid] = mySum = mySum + smem[tid + 4];
+        smem[tid] = mySum = mySum + smem[tid + 2];
+        smem[tid] = mySum = mySum + smem[tid + 1];
+    }
+
+    if (tid == 0) {
+        output[blockIdx.x] = sdata[0];
+    }
+}
+"""
+
+COMPLEX_KERNEL = """\
+// Complex kernel with multiple features (~500 lines)
+#include <cuda_runtime.h>
+#include <cooperative_groups.h>
+
+namespace cg = cooperative_groups;
+
+// Constants
+constexpr int BLOCK_SIZE = 256;
+constexpr int TILE_SIZE = 16;
+
+// ============================================================================
+// Kernel 1: Vector operations with shared memory and reduction
+// ============================================================================
+template <typename T, int BlockSize>
+__global__ void reduceSum(const T* __restrict__ input, T* __restrict__ output, int N) {
+    __shared__ T sdata[BlockSize];
+
+    unsigned int tid = threadIdx.x;
+    unsigned int i = blockIdx.x * (blockDim.x * 2) + threadIdx.x;
+
+    T mySum = (i < N) ? input[i] : T(0);
+    if (i + blockDim.x < N) {
+        mySum += input[i + blockDim.x];
+    }
+    sdata[tid] = mySum;
+    __syncthreads();
+
+    #pragma unroll
+    for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] = mySum = mySum + sdata[tid + s];
+        }
+        __syncthreads();
+    }
+
+    if (tid < 32) {
+        volatile T* smem = sdata;
+        if (BlockSize >= 64) mySum += smem[tid + 32];
+        smem[tid] = mySum;
+        if (BlockSize >= 32) mySum += smem[tid + 16];
+        smem[tid] = mySum;
+        if (BlockSize >= 16) mySum += smem[tid + 8];
+        smem[tid] = mySum;
+        if (BlockSize >= 8) mySum += smem[tid + 4];
+        smem[tid] = mySum;
+        if (BlockSize >= 4) mySum += smem[tid + 2];
+        smem[tid] = mySum;
+        if (BlockSize >= 2) mySum += smem[tid + 1];
+        smem[tid] = mySum;
+    }
+
+    if (tid == 0) {
+        output[blockIdx.x] = sdata[0];
+    }
+}
+
+// ============================================================================
+// Kernel 2: Matrix transpose with shared memory
+// ============================================================================
+__global__ void matrixTranspose(const float* __restrict__ input,
+                                float* __restrict__ output,
+                                int width, int height) {
+    __shared__ float tile[TILE_SIZE][TILE_SIZE + 1];
+
+    int xIndex = blockIdx.x * TILE_SIZE + threadIdx.x;
+    int yIndex = blockIdx.y * TILE_SIZE + threadIdx.y;
+
+    if (xIndex < width && yIndex < height) {
+        tile[threadIdx.y][threadIdx.x] = input[yIndex * width + xIndex];
+    }
+
+    __syncthreads();
+
+    xIndex = blockIdx.y * TILE_SIZE + threadIdx.x;
+    yIndex = blockIdx.x * TILE_SIZE + threadIdx.y;
+
+    if (xIndex < height && yIndex < width) {
+        output[yIndex * height + xIndex] = tile[threadIdx.x][threadIdx.y];
+    }
+}
+
+// ============================================================================
+// Kernel 3: Softmax with cooperative groups
+// ============================================================================
+__global__ void softmaxKernel(const float* __restrict__ input,
+                              float* __restrict__ output,
+                              int N, int stride) {
+    cg::thread_block block = cg::this_thread_block();
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+    extern __shared__ float shared[];
+
+    int row = blockIdx.x;
+    const float* rowInput = input + row * stride;
+    float* rowOutput = output + row * stride;
+
+    float maxVal = -INFINITY;
+    for (int i = threadIdx.x; i < N; i += blockDim.x) {
+        maxVal = fmaxf(maxVal, rowInput[i]);
+    }
+
+    for (int offset = warp.size() / 2; offset > 0; offset /= 2) {
+        maxVal = fmaxf(maxVal, warp.shfl_down(maxVal, offset));
+    }
+
+    if (warp.thread_rank() == 0) {
+        shared[threadIdx.x / 32] = maxVal;
+    }
+    block.sync();
+
+    if (threadIdx.x < blockDim.x / 32) {
+        maxVal = shared[threadIdx.x];
+    } else {
+        maxVal = -INFINITY;
+    }
+
+    for (int offset = 16; offset > 0; offset /= 2) {
+        maxVal = fmaxf(maxVal, __shfl_down_sync(0xffffffff, maxVal, offset));
+    }
+    maxVal = __shfl_sync(0xffffffff, maxVal, 0);
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < N; i += blockDim.x) {
+        float val = expf(rowInput[i] - maxVal);
+        rowOutput[i] = val;
+        sum += val;
+    }
+
+    for (int offset = warp.size() / 2; offset > 0; offset /= 2) {
+        sum += warp.shfl_down(sum, offset);
+    }
+
+    if (warp.thread_rank() == 0) {
+        shared[threadIdx.x / 32] = sum;
+    }
+    block.sync();
+
+    if (threadIdx.x < blockDim.x / 32) {
+        sum = shared[threadIdx.x];
+    } else {
+        sum = 0.0f;
+    }
+
+    for (int offset = 16; offset > 0; offset /= 2) {
+        sum += __shfl_down_sync(0xffffffff, sum, offset);
+    }
+    sum = __shfl_sync(0xffffffff, sum, 0);
+
+    float invSum = 1.0f / sum;
+    for (int i = threadIdx.x; i < N; i += blockDim.x) {
+        rowOutput[i] *= invSum;
+    }
+}
+
+// ============================================================================
+// Kernel 4: Fused multiply-add with vectorized loads
+// ============================================================================
+__global__ void fusedMulAddVec4(const float4* __restrict__ A,
+                                const float4* __restrict__ B,
+                                float4* __restrict__ C,
+                                float alpha, float beta, int N) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (idx < N) {
+        float4 a = A[idx];
+        float4 b = B[idx];
+        float4 c;
+
+        c.x = fmaf(alpha, a.x, beta * b.x);
+        c.y = fmaf(alpha, a.y, beta * b.y);
+        c.z = fmaf(alpha, a.z, beta * b.z);
+        c.w = fmaf(alpha, a.w, beta * b.w);
+
+        C[idx] = c;
+    }
+}
+
+// ============================================================================
+// Kernel 5: Simple GEMM with shared memory tiling
+// ============================================================================
+__global__ void matmulTiled(const float* __restrict__ A,
+                            const float* __restrict__ B,
+                            float* __restrict__ C,
+                            int M, int N, int K) {
+    __shared__ float As[TILE_SIZE][TILE_SIZE];
+    __shared__ float Bs[TILE_SIZE][TILE_SIZE];
+
+    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
+    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
+
+    float sum = 0.0f;
+
+    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
+        int tiledCol = t * TILE_SIZE + threadIdx.x;
+        int tiledRow = t * TILE_SIZE + threadIdx.y;
+
+        As[threadIdx.y][threadIdx.x] = (row < M && tiledCol < K) ?
+            A[row * K + tiledCol] : 0.0f;
+        Bs[threadIdx.y][threadIdx.x] = (tiledRow < K && col < N) ?
+            B[tiledRow * N + col] : 0.0f;
+
+        __syncthreads();
+
+        #pragma unroll
+        for (int k = 0; k < TILE_SIZE; k++) {
+            sum = fmaf(As[threadIdx.y][k], Bs[k][threadIdx.x], sum);
+        }
+
+        __syncthreads();
+    }
+
+    if (row < M && col < N) {
+        C[row * N + col] = sum;
+    }
+}
+
+// ============================================================================
+// Device helper functions
+// ============================================================================
+__device__ __forceinline__ float warpReduceSum(float val) {
+    for (int offset = 16; offset > 0; offset /= 2) {
+        val += __shfl_down_sync(0xffffffff, val, offset);
+    }
+    return val;
+}
+
+__device__ __forceinline__ float blockReduceSum(float val) {
+    __shared__ float shared[32];
+
+    int lane = threadIdx.x % 32;
+    int wid = threadIdx.x / 32;
+
+    val = warpReduceSum(val);
+
+    if (lane == 0) shared[wid] = val;
+    __syncthreads();
+
+    val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
+
+    if (wid == 0) val = warpReduceSum(val);
+
+    return val;
+}
+
+// ============================================================================
+// Kernel 6: Layer normalization
+// ============================================================================
+__global__ void layerNorm(const float* __restrict__ input,
+                          const float* __restrict__ gamma,
+                          const float* __restrict__ beta,
+                          float* __restrict__ output,
+                          int N, float eps) {
+    int row = blockIdx.x;
+    const float* rowInput = input + row * N;
+    float* rowOutput = output + row * N;
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < N; i += blockDim.x) {
+        sum += rowInput[i];
+    }
+    sum = blockReduceSum(sum);
+
+    __shared__ float s_mean, s_var;
+    if (threadIdx.x == 0) {
+        s_mean = sum / N;
+    }
+    __syncthreads();
+
+    float var = 0.0f;
+    for (int i = threadIdx.x; i < N; i += blockDim.x) {
+        float diff = rowInput[i] - s_mean;
+        var += diff * diff;
+    }
+    var = blockReduceSum(var);
+
+    if (threadIdx.x == 0) {
+        s_var = rsqrtf(var / N + eps);
+    }
+    __syncthreads();
+
+    for (int i = threadIdx.x; i < N; i += blockDim.x) {
+        float normalized = (rowInput[i] - s_mean) * s_var;
+        rowOutput[i] = fmaf(normalized, gamma[i], beta[i]);
+    }
+}
+"""
+
+
+# ============================================================================
+# Benchmark Results
+# ============================================================================
+
+class BenchmarkResult(NamedTuple):
+    """Result of a single benchmark run."""
+    kernel_name: str
+    kernel_lines: int
+    compile_time_ms: int
+    success: bool
+    ptx_lines: int | None
+    sass_lines: int | None
+    error: str | None
+
+
+def count_lines(code: str) -> int:
+    """Count non-empty lines in code."""
+    return len([line for line in code.split('\n') if line.strip()])
+
+
+def run_benchmark(
+    kernel_name: str,
+    kernel_code: str,
+    arch: str = "sm_90a",
+    output_formats: list[str] | None = None,
+    num_runs: int = 3,
+) -> list[BenchmarkResult]:
+    """Run benchmark for a kernel.
+
+    Args:
+        kernel_name: Name of the kernel for reporting
+        kernel_code: CUDA source code
+        arch: Target architecture
+        output_formats: Output formats to request (default: ["ptx", "sass"])
+        num_runs: Number of benchmark runs
+
+    Returns:
+        List of BenchmarkResult for each run
+    """
+    import modal
+
+    if output_formats is None:
+        output_formats = ["ptx", "sass"]
+
+    # Get the deployed function
+    compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
+
+    kernel_lines = count_lines(kernel_code)
+    results: list[BenchmarkResult] = []
+
+    for run in range(num_runs):
+        print(f" Run {run + 1}/{num_runs}...", end=" ", flush=True)
+
+        start_time = time.time()
+
+        try:
+            result = compile_fn.remote({
+                "files": {"kernel.cu": kernel_code},
+                "arch": arch,
+                "flags": ["-O3", "-lineinfo"],
+                "output": output_formats,
+            })
+
+            elapsed_ms = int((time.time() - start_time) * 1000)
+
+            if result["success"]:
+                ptx_lines = count_lines(result["ptx"]) if result.get("ptx") else None
+                sass_lines = count_lines(result["sass"]) if result.get("sass") else None
+
+                results.append(BenchmarkResult(
+                    kernel_name=kernel_name,
+                    kernel_lines=kernel_lines,
+                    compile_time_ms=elapsed_ms,
+                    success=True,
+                    ptx_lines=ptx_lines,
+                    sass_lines=sass_lines,
+                    error=None,
+                ))
+                print(f"{elapsed_ms}ms")
+            else:
+                results.append(BenchmarkResult(
+                    kernel_name=kernel_name,
+                    kernel_lines=kernel_lines,
+                    compile_time_ms=elapsed_ms,
+                    success=False,
+                    ptx_lines=None,
+                    sass_lines=None,
+                    error=result.get("stderr", "Unknown error"),
+                ))
+                print(f"FAILED ({elapsed_ms}ms)")
+
+        except Exception as e:
+            elapsed_ms = int((time.time() - start_time) * 1000)
+            results.append(BenchmarkResult(
+                kernel_name=kernel_name,
+                kernel_lines=kernel_lines,
+                compile_time_ms=elapsed_ms,
+                success=False,
+                ptx_lines=None,
+                sass_lines=None,
+                error=str(e),
+            ))
+            print(f"ERROR: {e}")
+
+    return results
+
+
+def print_summary(results: list[BenchmarkResult]) -> None:
+    """Print benchmark summary."""
+    successful = [r for r in results if r.success]
+
+    if not successful:
+        print("\n No successful runs!")
+        if results:
+            print(f" Error: {results[0].error}")
+        return
+
+    times = [r.compile_time_ms for r in successful]
+    mean_time = statistics.mean(times)
+
+    if len(times) > 1:
+        stdev = statistics.stdev(times)
+        min_time = min(times)
+        max_time = max(times)
+        print(f"\n Results: {mean_time:.0f}ms avg (min: {min_time}ms, max: {max_time}ms, stdev: {stdev:.0f}ms)")
+    else:
+        print(f"\n Results: {mean_time:.0f}ms")
+
+    # Show output sizes
+    if successful[0].ptx_lines:
+        print(f" PTX output: {successful[0].ptx_lines} lines")
+    if successful[0].sass_lines:
+        print(f" SASS output: {successful[0].sass_lines} lines")
+
+
+def run_all_benchmarks(num_runs: int = 3) -> dict[str, list[BenchmarkResult]]:
+    """Run benchmarks for all kernel sizes."""
+    print("=" * 60)
+    print("CUDA Compilation Benchmark")
+    print("=" * 60)
+
+    kernels = [
+        ("simple", SIMPLE_KERNEL),
+        ("medium", MEDIUM_KERNEL),
+        ("complex", COMPLEX_KERNEL),
+    ]
+
+    all_results: dict[str, list[BenchmarkResult]] = {}
+
+    for name, code in kernels:
+        lines = count_lines(code)
+        print(f"\n{name.upper()} KERNEL ({lines} lines)")
+        print("-" * 40)
+
+        results = run_benchmark(name, code, num_runs=num_runs)
+        all_results[name] = results
+        print_summary(results)
+
+    # Print final summary
+    print("\n" + "=" * 60)
+    print("SUMMARY")
+    print("=" * 60)
+
+    for name in ["simple", "medium", "complex"]:
+        results = all_results.get(name, [])
+        successful = [r for r in results if r.success]
+        if successful:
+            avg_time = statistics.mean([r.compile_time_ms for r in successful])
+            print(f" {name:10s}: {avg_time:6.0f}ms ({results[0].kernel_lines} lines)")
+        else:
+            print(f" {name:10s}: FAILED")
+
+    return all_results
+
+
+def main() -> None:
+    """Main entry point."""
+    parser = argparse.ArgumentParser(description="Benchmark CUDA compilation")
+    parser.add_argument(
+        "--kernel",
+        choices=["simple", "medium", "complex", "all"],
+        default="all",
+        help="Which kernel to benchmark",
+    )
+    parser.add_argument(
+        "--runs",
+        type=int,
+        default=3,
+        help="Number of benchmark runs per kernel",
+    )
+    parser.add_argument(
+        "--arch",
+        default="sm_90a",
+        help="Target GPU architecture",
+    )
+
+    args = parser.parse_args()
+
+    if args.kernel == "all":
+        run_all_benchmarks(num_runs=args.runs)
+    else:
+        kernel_map = {
+            "simple": SIMPLE_KERNEL,
+            "medium": MEDIUM_KERNEL,
+            "complex": COMPLEX_KERNEL,
+        }
+        code = kernel_map[args.kernel]
+        lines = count_lines(code)
+
+        print(f"\n{args.kernel.upper()} KERNEL ({lines} lines)")
+        print("-" * 40)
+
+        results = run_benchmark(
+            args.kernel,
+            code,
+            arch=args.arch,
+            num_runs=args.runs,
+        )
+        print_summary(results)
+
+
+if __name__ == "__main__":
+    main()
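The new module is primarily a CLI (see its Usage docstring), but it can also be driven programmatically. A minimal sketch, assuming Modal credentials are configured and the `cuda-compile` app is deployed; everything imported here is defined in the file above:

```python
# Sketch only: drive the new benchmark module from Python instead of the CLI.
import statistics

from wafer_core.tools.compile.benchmark import (
    SIMPLE_KERNEL,
    MEDIUM_KERNEL,
    run_benchmark,
)

for name, code in [("simple", SIMPLE_KERNEL), ("medium", MEDIUM_KERNEL)]:
    results = run_benchmark(name, code, arch="sm_90a", num_runs=5)
    ok = [r.compile_time_ms for r in results if r.success]
    if ok:
        print(f"{name}: {statistics.mean(ok):.0f}ms mean over {len(ok)} runs")
    else:
        print(f"{name}: all runs failed ({results[0].error})")
```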
wafer_core/tools/compile/compiler.py
@@ -58,8 +58,8 @@ async def compile_cuda_remote(
 ) -> CompileResponse:
     """Compile CUDA code using Modal (remote execution).
 
-    This function
-
+    This function calls the deployed Modal function directly using asyncio.to_thread
+    to avoid blocking the event loop.
 
     Args:
         request: The compile request
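The updated docstring describes the key change: the blocking Modal call is pushed onto a worker thread with `asyncio.to_thread` so the event loop stays responsive. A standalone sketch of that pattern; `blocking_compile` is a hypothetical stand-in for the synchronous SDK call, not a wafer_core function:

```python
# Illustrative sketch of offloading a blocking call with asyncio.to_thread.
import asyncio
import time


def blocking_compile(source: str) -> dict:
    time.sleep(1.0)  # pretend this is a slow, synchronous remote call
    return {"success": True, "ptx": f"// compiled {len(source)} bytes"}


async def main() -> None:
    # The event loop stays free while the blocking call runs in a worker thread.
    result = await asyncio.to_thread(blocking_compile, "__global__ void k() {}")
    print(result["success"])


asyncio.run(main())
```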
@@ -70,92 +70,74 @@ async def compile_cuda_remote(
         CompileResponse with PTX/SASS or error
     """
     import asyncio
-    import json
     import os
-    import tempfile
     import time
-    from
+    from contextlib import contextmanager
 
-
+    @contextmanager
+    def temporary_env_vars(env_updates: dict[str, str]):
+        """Context manager to temporarily set environment variables.
+
+        Saves original values, sets new values, yields, then restores originals.
+        This ensures we don't leak credentials between concurrent requests.
+        """
+        original_values: dict[str, str | None] = {}
+        for key, value in env_updates.items():
+            original_values[key] = os.environ.get(key)
+            os.environ[key] = value
 
-
+        try:
+            yield
+        finally:
+            for key, original in original_values.items():
+                if original is None:
+                    os.environ.pop(key, None)
+                else:
+                    os.environ[key] = original
+
+    start_time = time.time()
     request_dict = request_to_dict(request)
 
-
-
-
-
-
+    # Build env updates for credentials (only if provided)
+    env_updates: dict[str, str] = {}
+    if modal_token_id:
+        env_updates["MODAL_TOKEN_ID"] = modal_token_id
+    if modal_token_secret:
+        env_updates["MODAL_TOKEN_SECRET"] = modal_token_secret
+
+    def call_modal() -> dict:
+        """Call Modal function synchronously (runs in thread pool)."""
+        import modal
+
+        # Look up the deployed function
+        compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
+
+        # Call the function remotely
+        return compile_fn.remote(request_dict)
 
     try:
-        #
-        #
-
-
-
-
-        # Load request
-        with open("{request_path}") as f:
-            request = json.load(f)
-
-        # Look up the deployed function
-        compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
-
-        # Call the function remotely
-        result = compile_fn.remote(request)
-
-        # Output result as JSON
-        print(json.dumps(result))
-        '''
-
-        # Run in subprocess to avoid event loop conflicts
-        env = os.environ.copy()
-        if modal_token_id:
-            env["MODAL_TOKEN_ID"] = modal_token_id
-        if modal_token_secret:
-            env["MODAL_TOKEN_SECRET"] = modal_token_secret
-
-        # Use the same Python interpreter that's running this code
-        import sys
-        python_executable = sys.executable
-
-        # Use asyncio.create_subprocess_exec for async subprocess execution
-        proc = await asyncio.create_subprocess_exec(
-            python_executable, "-c", script,
-            stdout=asyncio.subprocess.PIPE,
-            stderr=asyncio.subprocess.PIPE,
-            env=env,
-        )
-        stdout_bytes, stderr_bytes = await proc.communicate()
+        # Run Modal call in thread pool with temporary credentials
+        # The context manager ensures env vars are restored after the call
+        def call_modal_with_env() -> dict:
+            with temporary_env_vars(env_updates):
+                return call_modal()
 
-
-
-        # Check for common Modal auth errors
-        if "MODAL_TOKEN" in stderr or "AuthError" in stderr or "not authenticated" in stderr.lower():
-            return CompileResponse.error(
-                "Modal not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables, "
-                "or run 'modal token new' to authenticate.",
-                compilation_time_ms=int((time.time() - start_time) * 1000),
-            )
-        return CompileResponse.error(
-            f"Compilation failed: {stderr}",
-            compilation_time_ms=int((time.time() - start_time) * 1000),
-        )
+        result = await asyncio.to_thread(call_modal_with_env)
+        return response_from_dict(result)
 
-
-
-
-
-        return response_from_dict(response_dict)
-    except json.JSONDecodeError as e:
+    except Exception as e:
+        error_str = str(e)
+        # Check for common Modal auth errors
+        if "MODAL_TOKEN" in error_str or "AuthError" in error_str or "not authenticated" in error_str.lower():
            return CompileResponse.error(
-
+                "Modal not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables, "
+                "or run 'modal token new' to authenticate.",
                compilation_time_ms=int((time.time() - start_time) * 1000),
            )
-
-
-
-
+        return CompileResponse.error(
+            f"Compilation failed: {error_str}",
+            compilation_time_ms=int((time.time() - start_time) * 1000),
+        )
 
 
 def compile_cuda_local(request: CompileRequest) -> CompileResponse:
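The hunk above adds `temporary_env_vars`, a context manager that sets credentials for the duration of one call and then restores the previous values. A self-contained sketch of the same save-and-restore behavior; `scoped_env` is an illustrative name, not part of wafer_core:

```python
# Sketch of the save-and-restore environment pattern: values are restored
# (or removed) even if the body raises.
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(updates: dict[str, str]):
    saved = {key: os.environ.get(key) for key in updates}
    os.environ.update(updates)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old


before = os.environ.get("MODAL_TOKEN_ID")
with scoped_env({"MODAL_TOKEN_ID": "ak-demo"}):
    assert os.environ["MODAL_TOKEN_ID"] == "ak-demo"
assert os.environ.get("MODAL_TOKEN_ID") == before  # original value restored
```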
@@ -197,7 +179,12 @@ def compile_cuda_local(request: CompileRequest) -> CompileResponse:
 
         # Write all files to temp directory
         for filename, content in request.files.items():
-            file_path = tmp_path / filename
+            file_path = (tmp_path / filename).resolve()
+            if not file_path.is_relative_to(tmp_path):
+                return CompileResponse.error(
+                    f"Invalid filename: {filename}",
+                    compilation_time_ms=int((time.time() - start_time) * 1000),
+                )
             file_path.parent.mkdir(parents=True, exist_ok=True)
             file_path.write_text(content)
 
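`compile_cuda_local` now resolves each requested filename and rejects anything that escapes the temporary directory. The same containment check in isolation; `is_safe` is an illustrative helper, not from the package, and resolving the base directory here is an assumption rather than something the diff does:

```python
# Sketch of the path-containment check: "../outside.cu" resolves to a path
# outside the sandbox directory and is rejected; plain or nested names pass.
from pathlib import Path


def is_safe(base: Path, filename: str) -> bool:
    candidate = (base / filename).resolve()
    return candidate.is_relative_to(base.resolve())


base = Path("/tmp/compile-sandbox")
print(is_safe(base, "kernel.cu"))        # True
print(is_safe(base, "inc/helpers.cuh"))  # True
print(is_safe(base, "../escape.cu"))     # False
```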
@@ -79,7 +79,10 @@ app = modal.App(name="cuda-compile", image=compile_image)
|
|
|
79
79
|
cpu=4,
|
|
80
80
|
memory=8192, # 8GB RAM
|
|
81
81
|
timeout=120, # 2 minute timeout
|
|
82
|
+
# Keep one container warm to avoid cold starts (~5-10s savings)
|
|
83
|
+
min_containers=1,
|
|
82
84
|
)
|
|
85
|
+
@modal.concurrent(max_inputs=4) # Allow concurrent compilations for better throughput
|
|
83
86
|
def compile_cuda(request: dict) -> dict:
|
|
84
87
|
"""Compile CUDA code and return PTX/SASS.
|
|
85
88
|
|
|
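The hunk above keeps one container warm (`min_containers=1`) and lets a container accept up to four inputs at once via `@modal.concurrent`. A minimal sketch of the same settings applied to a trivial Modal function; the decorator parameters mirror the diff, while the app name and function body are illustrative only:

```python
# Sketch of the warm-container / concurrency configuration shown above.
import modal

app = modal.App(name="warm-demo")


@app.function(
    cpu=1,
    memory=512,
    timeout=60,
    min_containers=1,  # keep one container warm to avoid cold starts
)
@modal.concurrent(max_inputs=4)  # let one container serve several inputs at once
def echo(payload: dict) -> dict:
    return payload
```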
@@ -105,6 +108,7 @@ def compile_cuda(request: dict) -> dict:
|
|
|
105
108
|
import subprocess
|
|
106
109
|
import tempfile
|
|
107
110
|
import time
|
|
111
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
108
112
|
from pathlib import Path
|
|
109
113
|
|
|
110
114
|
start_time = time.time()
|
|
@@ -138,13 +142,92 @@ def compile_cuda(request: dict) -> dict:
|
|
|
138
142
|
|
|
139
143
|
main_cu_file = cu_files[0]
|
|
140
144
|
|
|
145
|
+
# Build environment for nvcc
|
|
146
|
+
nvcc_env = {
|
|
147
|
+
**os.environ,
|
|
148
|
+
"CUDA_HOME": "/usr/local/cuda",
|
|
149
|
+
"PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
def compile_ptx(tmpdir: str, base_cmd: list[str], main_cu_path: Path) -> tuple[str | None, str | None]:
|
|
153
|
+
"""Compile to PTX. Returns (ptx_content, error_message)."""
|
|
154
|
+
ptx_output = Path(tmpdir) / "output.ptx"
|
|
155
|
+
ptx_cmd = base_cmd + [
|
|
156
|
+
"--ptx",
|
|
157
|
+
"-o",
|
|
158
|
+
str(ptx_output),
|
|
159
|
+
str(main_cu_path),
|
|
160
|
+
]
|
|
161
|
+
|
|
162
|
+
ptx_result = subprocess.run(
|
|
163
|
+
ptx_cmd,
|
|
164
|
+
capture_output=True,
|
|
165
|
+
text=True,
|
|
166
|
+
timeout=60,
|
|
167
|
+
cwd=tmpdir,
|
|
168
|
+
env=nvcc_env,
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
if ptx_result.returncode != 0:
|
|
172
|
+
return None, ptx_result.stderr or ptx_result.stdout
|
|
173
|
+
|
|
174
|
+
if ptx_output.exists():
|
|
175
|
+
return ptx_output.read_text(), None
|
|
176
|
+
return None, "PTX output file not created"
|
|
177
|
+
|
|
178
|
+
def compile_sass(tmpdir: str, base_cmd: list[str], main_cu_path: Path) -> tuple[str | None, str | None]:
|
|
179
|
+
"""Compile to SASS (via cubin). Returns (sass_content, error_message)."""
|
|
180
|
+
cubin_output = Path(tmpdir) / "output.cubin"
|
|
181
|
+
cubin_cmd = base_cmd + [
|
|
182
|
+
"--cubin",
|
|
183
|
+
"-o",
|
|
184
|
+
str(cubin_output),
|
|
185
|
+
str(main_cu_path),
|
|
186
|
+
]
|
|
187
|
+
|
|
188
|
+
cubin_result = subprocess.run(
|
|
189
|
+
cubin_cmd,
|
|
190
|
+
capture_output=True,
|
|
191
|
+
text=True,
|
|
192
|
+
timeout=60,
|
|
193
|
+
cwd=tmpdir,
|
|
194
|
+
env=nvcc_env,
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
if cubin_result.returncode != 0:
|
|
198
|
+
return None, cubin_result.stderr or cubin_result.stdout
|
|
199
|
+
|
|
200
|
+
if not cubin_output.exists():
|
|
201
|
+
return None, "cubin output file not created"
|
|
202
|
+
|
|
203
|
+
# Disassemble cubin to SASS
|
|
204
|
+
sass_result = subprocess.run(
|
|
205
|
+
["cuobjdump", "--dump-sass", str(cubin_output)],
|
|
206
|
+
capture_output=True,
|
|
207
|
+
text=True,
|
|
208
|
+
timeout=30,
|
|
209
|
+
cwd=tmpdir,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
if sass_result.returncode == 0:
|
|
213
|
+
return sass_result.stdout, None
|
|
214
|
+
return None, f"SASS disassembly failed: {sass_result.stderr}"
|
|
215
|
+
|
|
141
216
|
try:
|
|
142
217
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
143
218
|
tmp_path = Path(tmpdir)
|
|
144
219
|
|
|
145
220
|
# Write all files to temp directory, preserving subdirectory structure
|
|
146
221
|
for filename, content in files.items():
|
|
147
|
-
file_path = tmp_path / filename
|
|
222
|
+
file_path = (tmp_path / filename).resolve()
|
|
223
|
+
if not file_path.is_relative_to(tmp_path):
|
|
224
|
+
return {
|
|
225
|
+
"success": False,
|
|
226
|
+
"ptx": None,
|
|
227
|
+
"sass": None,
|
|
228
|
+
"stderr": f"Invalid filename: {filename}",
|
|
229
|
+
"compilation_time_ms": int((time.time() - start_time) * 1000),
|
|
230
|
+
}
|
|
148
231
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
149
232
|
file_path.write_text(content)
|
|
150
233
|
|
|
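`compile_ptx` and `compile_sass` both follow a small convention: run one tool invocation and return a `(content, error)` tuple instead of raising. A generic sketch of that convention; `run_tool` and the `cc --version` command are illustrative stand-ins, not wafer_core code, and `cc` is assumed to be on PATH:

```python
# Sketch of the "(content, error)" helper convention with env override and
# timeout handling similar to the helpers above.
import os
import subprocess


def run_tool(cmd: list[str], timeout: int = 60) -> tuple[str | None, str | None]:
    proc = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        timeout=timeout,
        env={**os.environ, "LC_ALL": "C"},  # example of a per-call env override
    )
    if proc.returncode != 0:
        return None, proc.stderr or proc.stdout
    return proc.stdout, None


output, error = run_tool(["cc", "--version"])
print(output if output else f"failed: {error}")
```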
@@ -152,151 +235,78 @@ def compile_cuda(request: dict) -> dict:
|
|
|
152
235
|
main_cu_path = tmp_path / main_cu_file
|
|
153
236
|
include_dir = main_cu_path.parent
|
|
154
237
|
|
|
155
|
-
results: dict[str, str | None] = {"ptx": None, "sass": None}
|
|
156
|
-
|
|
157
238
|
# Build base nvcc command with common flags
|
|
158
239
|
base_cmd = [
|
|
159
240
|
"nvcc",
|
|
160
241
|
"-arch",
|
|
161
242
|
arch,
|
|
162
|
-
# Include the temp directory for user headers
|
|
163
243
|
f"-I{include_dir}",
|
|
164
|
-
# Include PyTorch headers
|
|
165
244
|
"-I/usr/local/lib/python3.12/site-packages/torch/include",
|
|
166
245
|
"-I/usr/local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include",
|
|
167
|
-
# Include CUTLASS headers
|
|
168
246
|
"-I/usr/local/cutlass/include",
|
|
169
|
-
# Standard CUDA headers are already in the default path
|
|
170
247
|
]
|
|
171
|
-
|
|
172
|
-
# Add user-specified flags
|
|
173
248
|
base_cmd.extend(flags)
|
|
174
249
|
|
|
175
|
-
#
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
ptx_cmd = base_cmd + [
|
|
179
|
-
"--ptx", # Generate PTX
|
|
180
|
-
"-o",
|
|
181
|
-
str(ptx_output),
|
|
182
|
-
str(main_cu_path),
|
|
183
|
-
]
|
|
184
|
-
|
|
185
|
-
ptx_result = subprocess.run(
|
|
186
|
-
ptx_cmd,
|
|
187
|
-
capture_output=True,
|
|
188
|
-
text=True,
|
|
189
|
-
timeout=60,
|
|
190
|
-
cwd=tmpdir,
|
|
191
|
-
env={
|
|
192
|
-
**os.environ,
|
|
193
|
-
"CUDA_HOME": "/usr/local/cuda",
|
|
194
|
-
"PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
|
|
195
|
-
},
|
|
196
|
-
)
|
|
197
|
-
|
|
198
|
-
if ptx_result.returncode != 0:
|
|
199
|
-
return {
|
|
200
|
-
"success": False,
|
|
201
|
-
"ptx": None,
|
|
202
|
-
"sass": None,
|
|
203
|
-
"stderr": ptx_result.stderr or ptx_result.stdout,
|
|
204
|
-
"compilation_time_ms": int((time.time() - start_time) * 1000),
|
|
205
|
-
}
|
|
250
|
+
# Determine what to compile
|
|
251
|
+
want_ptx = OutputFormat.PTX.value in output_formats
|
|
252
|
+
want_sass = OutputFormat.SASS.value in output_formats
|
|
206
253
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
#
|
|
211
|
-
if
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
"-o",
|
|
217
|
-
str(cubin_output),
|
|
218
|
-
str(main_cu_path),
|
|
219
|
-
]
|
|
220
|
-
|
|
221
|
-
cubin_result = subprocess.run(
|
|
222
|
-
cubin_cmd,
|
|
223
|
-
capture_output=True,
|
|
224
|
-
text=True,
|
|
225
|
-
timeout=60,
|
|
226
|
-
cwd=tmpdir,
|
|
227
|
-
env={
|
|
228
|
-
**os.environ,
|
|
229
|
-
"CUDA_HOME": "/usr/local/cuda",
|
|
230
|
-
"PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
|
|
231
|
-
},
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
if cubin_result.returncode != 0:
|
|
235
|
-
# If we already have PTX, that's a partial success
|
|
236
|
-
if results["ptx"]:
|
|
237
|
-
return {
|
|
238
|
-
"success": True,
|
|
239
|
-
"ptx": results["ptx"],
|
|
240
|
-
"sass": None,
|
|
241
|
-
"stderr": f"SASS generation failed: {cubin_result.stderr}",
|
|
242
|
-
"compilation_time_ms": int(
|
|
243
|
-
(time.time() - start_time) * 1000
|
|
244
|
-
),
|
|
245
|
-
}
|
|
246
|
-
return {
|
|
247
|
-
"success": False,
|
|
248
|
-
"ptx": None,
|
|
249
|
-
"sass": None,
|
|
250
|
-
"stderr": cubin_result.stderr or cubin_result.stdout,
|
|
251
|
-
"compilation_time_ms": int((time.time() - start_time) * 1000),
|
|
254
|
+
results: dict[str, str | None] = {"ptx": None, "sass": None}
|
|
255
|
+
errors: list[str] = []
|
|
256
|
+
|
|
257
|
+
# Run compilations in parallel if both are requested
|
|
258
|
+
if want_ptx and want_sass:
|
|
259
|
+
with ThreadPoolExecutor(max_workers=2) as executor:
|
|
260
|
+
futures = {
|
|
261
|
+
executor.submit(compile_ptx, tmpdir, base_cmd, main_cu_path): "ptx",
|
|
262
|
+
executor.submit(compile_sass, tmpdir, base_cmd, main_cu_path): "sass",
|
|
252
263
|
}
|
|
253
264
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
"compilation_time_ms": int(
|
|
281
|
-
(time.time() - start_time) * 1000
|
|
282
|
-
),
|
|
283
|
-
}
|
|
284
|
-
|
|
285
|
-
# Check if we got any output
|
|
265
|
+
for future in as_completed(futures):
|
|
266
|
+
output_type = futures[future]
|
|
267
|
+
try:
|
|
268
|
+
content, error = future.result()
|
|
269
|
+
if content:
|
|
270
|
+
results[output_type] = content
|
|
271
|
+
if error:
|
|
272
|
+
errors.append(f"{output_type.upper()}: {error}")
|
|
273
|
+
except Exception as e:
|
|
274
|
+
errors.append(f"{output_type.upper()} compilation error: {e}")
|
|
275
|
+
|
|
276
|
+
elif want_ptx:
|
|
277
|
+
content, error = compile_ptx(tmpdir, base_cmd, main_cu_path)
|
|
278
|
+
if content:
|
|
279
|
+
results["ptx"] = content
|
|
280
|
+
if error:
|
|
281
|
+
errors.append(error)
|
|
282
|
+
|
|
283
|
+
elif want_sass:
|
|
284
|
+
content, error = compile_sass(tmpdir, base_cmd, main_cu_path)
|
|
285
|
+
if content:
|
|
286
|
+
results["sass"] = content
|
|
287
|
+
if error:
|
|
288
|
+
errors.append(error)
|
|
289
|
+
|
|
290
|
+
# Check results
|
|
286
291
|
if not results["ptx"] and not results["sass"]:
|
|
287
292
|
return {
|
|
288
293
|
"success": False,
|
|
289
294
|
"ptx": None,
|
|
290
295
|
"sass": None,
|
|
291
|
-
"stderr": "No output generated",
|
|
296
|
+
"stderr": "\n".join(errors) if errors else "No output generated",
|
|
292
297
|
"compilation_time_ms": int((time.time() - start_time) * 1000),
|
|
293
298
|
}
|
|
294
299
|
|
|
300
|
+
# Partial success if we got at least one output
|
|
301
|
+
stderr = ""
|
|
302
|
+
if errors and (results["ptx"] or results["sass"]):
|
|
303
|
+
stderr = "\n".join(errors)
|
|
304
|
+
|
|
295
305
|
return {
|
|
296
306
|
"success": True,
|
|
297
307
|
"ptx": results["ptx"],
|
|
298
308
|
"sass": results["sass"],
|
|
299
|
-
"stderr":
|
|
309
|
+
"stderr": stderr,
|
|
300
310
|
"compilation_time_ms": int((time.time() - start_time) * 1000),
|
|
301
311
|
}
|
|
302
312
|
|
|
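When both PTX and SASS are requested, the two nvcc invocations now run in parallel on a two-worker `ThreadPoolExecutor`, with results keyed by output type and errors accumulated instead of raised. The same pattern reduced to a runnable sketch; `job` is an illustrative stand-in for `compile_ptx` / `compile_sass`:

```python
# Sketch of the parallel-compilation pattern: submit two independent jobs,
# collect them with as_completed, and accumulate per-job errors.
from concurrent.futures import ThreadPoolExecutor, as_completed


def job(name: str, fail: bool = False) -> tuple[str | None, str | None]:
    if fail:
        return None, f"{name} failed"
    return f"{name} output", None


results: dict[str, str | None] = {}
errors: list[str] = []

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = {
        executor.submit(job, "ptx"): "ptx",
        executor.submit(job, "sass", True): "sass",
    }
    for future in as_completed(futures):
        kind = futures[future]
        content, error = future.result()
        if content:
            results[kind] = content
        if error:
            errors.append(f"{kind.upper()}: {error}")

print(results, errors)
```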
{wafer_core-0.1.39.dist-info → wafer_core-0.1.41.dist-info}/RECORD
@@ -645,8 +645,9 @@ wafer_core/tools/capture_tool/dtypes.py,sha256=1Vm5obOCYc-Njuwkp7uqh_W4lqtYurT3b
 wafer_core/tools/capture_tool/executor.py,sha256=n1DVfbsP60yJAazx9C9Kwed9LB7AcKXJcoDnhno7ydU,1495
 wafer_core/tools/capture_tool/metrics.py,sha256=BFZNmdE-kh3LneYdWXTNZmlLuo-DCrP5aEBHxEQYJDU,10890
 wafer_core/tools/compile/__init__.py,sha256=8VyaMDDPxg4DcT-rwMf9lcNhAanWnmsqijUJYsuzJNg,615
-wafer_core/tools/compile/
-wafer_core/tools/compile/
+wafer_core/tools/compile/benchmark.py,sha256=6_nfhl24vTWt59EwGievbyMHZK2l4wfslP77BHWsoQ4,19408
+wafer_core/tools/compile/compiler.py,sha256=Y7iwfQkSBc4fmKXpv97ce1grw5L4tJ_VqWFFyYolRAg,10054
+wafer_core/tools/compile/modal_compile.py,sha256=lYMxdrvEQctA1Om6yESetjUAsSyv0W0evNVb8WOY2Ps,13384
 wafer_core/tools/compile/types.py,sha256=8Hjh6Mz2a7s2JjtKYQq-l3X41gmywnbKk3tc1wvbMLM,3277
 wafer_core/tools/compile/tests/__init__.py,sha256=gSuBMN-7VayQ9HgyNuUXRumenwk7jtq86ZxdCgFjeYE,41
 wafer_core/tools/compile/tests/test_compiler.py,sha256=kQ-YTLY8ETnS83nQ8xVSygKY532epxqRTsGx311SG7w,20795

@@ -722,6 +723,6 @@ wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdh
 wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
 wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
 wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
-wafer_core-0.1.
-wafer_core-0.1.
-wafer_core-0.1.
+wafer_core-0.1.41.dist-info/METADATA,sha256=98wUlC8ReP0ZSLp7E_CFAefo7E3W2AVirN9RaKb6Urg,1477
+wafer_core-0.1.41.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+wafer_core-0.1.41.dist-info/RECORD,,

{wafer_core-0.1.39.dist-info → wafer_core-0.1.41.dist-info}/WHEEL
File without changes