wafer-core 0.1.38__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/fusion_analyzer.py +2 -0
- wafer_core/rollouts/_logging/__init__.py +5 -1
- wafer_core/rollouts/_logging/logging_config.py +95 -3
- wafer_core/rollouts/_logging/sample_handler.py +66 -0
- wafer_core/rollouts/_pytui/__init__.py +114 -0
- wafer_core/rollouts/_pytui/app.py +809 -0
- wafer_core/rollouts/_pytui/console.py +291 -0
- wafer_core/rollouts/_pytui/renderer.py +210 -0
- wafer_core/rollouts/_pytui/spinner.py +73 -0
- wafer_core/rollouts/_pytui/terminal.py +489 -0
- wafer_core/rollouts/_pytui/text.py +470 -0
- wafer_core/rollouts/_pytui/theme.py +241 -0
- wafer_core/rollouts/evaluation.py +142 -177
- wafer_core/rollouts/progress_app.py +395 -0
- wafer_core/rollouts/tui/DESIGN.md +251 -115
- wafer_core/rollouts/tui/monitor.py +64 -20
- wafer_core/tools/compile/__init__.py +30 -0
- wafer_core/tools/compile/compiler.py +314 -0
- wafer_core/tools/compile/modal_compile.py +359 -0
- wafer_core/tools/compile/tests/__init__.py +1 -0
- wafer_core/tools/compile/tests/test_compiler.py +675 -0
- wafer_core/tools/compile/tests/test_data/utils.cuh +10 -0
- wafer_core/tools/compile/tests/test_data/vector_add.cu +7 -0
- wafer_core/tools/compile/tests/test_data/with_header.cu +9 -0
- wafer_core/tools/compile/tests/test_modal_integration.py +326 -0
- wafer_core/tools/compile/types.py +117 -0
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/METADATA +1 -1
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/RECORD +29 -12
- wafer_core/rollouts/events.py +0 -240
- wafer_core/rollouts/progress_display.py +0 -476
- wafer_core/utils/event_streaming.py +0 -63
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,675 @@
|
|
|
1
|
+
"""Tests for the cloud CUDA compiler.
|
|
2
|
+
|
|
3
|
+
These tests verify request validation, defaults, and immutability of the compile request/response types; they do not invoke the compiler.
|
|
4
|
+
Integration tests that actually invoke nvcc are in test_modal_integration.py.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from wafer_core.tools.compile.types import (
|
|
10
|
+
CompileRequest,
|
|
11
|
+
CompileResponse,
|
|
12
|
+
OutputFormat,
|
|
13
|
+
VALID_ARCHITECTURES,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestCompileRequest:
    """Test request defaults, validation, and immutability."""

    def test_single_file_request(self) -> None:
        """Test creating a request with a single file and default options."""
        request = CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
        )
        assert request.files == {"kernel.cu": "__global__ void test() {}"}
        assert request.arch == "sm_90a"  # default
        assert request.flags == ()
        assert request.output == (OutputFormat.PTX, OutputFormat.SASS)
        assert request.main_cu_file == "kernel.cu"

    def test_multi_file_request(self) -> None:
        """Test creating a request with a .cu file plus a header."""
        request = CompileRequest(
            files={
                "main.cu": '#include "utils.cuh"\n__global__ void test() {}',
                "utils.cuh": "__device__ float square(float x) { return x * x; }",
            },
        )
        assert len(request.files) == 2
        assert "main.cu" in request.files
        assert "utils.cuh" in request.files
        # The header must not be selected as the file to compile.
        assert request.main_cu_file == "main.cu"

    def test_custom_arch(self) -> None:
        """Test specifying a custom architecture."""
        request = CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
            arch="sm_80",
        )
        assert request.arch == "sm_80"

    def test_custom_flags(self) -> None:
        """Test specifying custom compiler flags."""
        request = CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
            flags=("-O3", "--maxrregcount=64", "-lineinfo"),
        )
        assert request.flags == ("-O3", "--maxrregcount=64", "-lineinfo")

    def test_ptx_only_output(self) -> None:
        """Test requesting only PTX output."""
        request = CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
            output=(OutputFormat.PTX,),
        )
        assert request.output == (OutputFormat.PTX,)

    def test_sass_only_output(self) -> None:
        """Test requesting only SASS output."""
        request = CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
            output=(OutputFormat.SASS,),
        )
        assert request.output == (OutputFormat.SASS,)

    def test_invalid_arch_rejected(self) -> None:
        """Test that an unknown architecture is rejected."""
        with pytest.raises(ValueError, match="Invalid architecture"):
            CompileRequest(
                files={"kernel.cu": "__global__ void test() {}"},
                arch="sm_999",
            )

    def test_empty_files_rejected(self) -> None:
        """Test that an empty files dict is rejected."""
        with pytest.raises(ValueError, match="At least one file is required"):
            CompileRequest(files={})

    def test_no_cu_file_rejected(self) -> None:
        """Test that a request containing only headers is rejected."""
        # `match` is a regex searched against the message; escape the "." so
        # only a literal ".cu" satisfies it.
        with pytest.raises(ValueError, match=r"At least one \.cu file is required"):
            CompileRequest(files={"utils.cuh": "// header only"})

    def test_empty_output_rejected(self) -> None:
        """Test that an empty output tuple is rejected."""
        with pytest.raises(ValueError, match="At least one output format is required"):
            CompileRequest(
                files={"kernel.cu": "__global__ void test() {}"},
                output=(),
            )

    def test_request_is_frozen(self) -> None:
        """Test that a request cannot be mutated after construction."""
        request = CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
        )
        with pytest.raises(AttributeError):
            request.arch = "sm_80"  # type: ignore[misc]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class TestCompileResponse:
    """Construction of ``CompileResponse`` objects and its ``error`` helper."""

    def test_success_response(self) -> None:
        """A successful response keeps its artifacts and timing; stderr stays empty."""
        resp = CompileResponse(
            success=True,
            ptx=".version 8.0\n.target sm_90a",
            sass="MOV R0, R1;",
            compilation_time_ms=150,
        )
        assert resp.success is True
        for artifact in (resp.ptx, resp.sass):
            assert artifact is not None
        assert resp.stderr == ""
        assert resp.compilation_time_ms == 150

    def test_error_response_helper(self) -> None:
        """``CompileResponse.error`` yields a failed response carrying the diagnostic."""
        diagnostic = "kernel.cu(10): error: expected a ';'"
        resp = CompileResponse.error(diagnostic, compilation_time_ms=50)
        assert resp.success is False
        for artifact in (resp.ptx, resp.sass):
            assert artifact is None
        assert "expected a ';'" in resp.stderr
        assert resp.compilation_time_ms == 50

    def test_response_is_frozen(self) -> None:
        """Assigning to a field of an existing response is refused."""
        resp = CompileResponse(success=True)
        with pytest.raises(AttributeError):
            resp.success = False  # type: ignore[misc]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class TestValidArchitectures:
    """Membership checks against ``VALID_ARCHITECTURES``."""

    def test_common_architectures_valid(self) -> None:
        """Ampere, Ada, and Hopper targets are all accepted."""
        for candidate in ("sm_80", "sm_86", "sm_89", "sm_90", "sm_90a"):
            assert candidate in VALID_ARCHITECTURES, f"{candidate} should be valid"

    def test_hopper_architectures(self) -> None:
        """Both the generic and the arch-specific Hopper targets are accepted."""
        for name in ("sm_90", "sm_90a"):
            assert name in VALID_ARCHITECTURES

    def test_blackwell_architectures(self) -> None:
        """Both the generic and the arch-specific Blackwell targets are accepted."""
        for name in ("sm_100", "sm_100a"):
            assert name in VALID_ARCHITECTURES
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class TestOutputFormat:
    """String values behind the ``OutputFormat`` enum members."""

    def test_ptx_value(self) -> None:
        """PTX equals "ptx" both through ``.value`` and direct comparison."""
        member = OutputFormat.PTX
        assert member.value == "ptx"
        assert member == "ptx"  # str enum comparison

    def test_sass_value(self) -> None:
        """SASS equals "sass" both through ``.value`` and direct comparison."""
        member = OutputFormat.SASS
        assert member.value == "sass"
        assert member == "sass"  # str enum comparison
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# =============================================================================
# Test CUDA Code Examples
# =============================================================================

# Sample CUDA sources that the tests below package into compile requests.
# Per the original note, their syntax is checked by the integration tests
# (test_modal_integration.py), which actually run nvcc.

# Element-wise c = a + b, guarded so threads past `n` do nothing.
VECTOR_ADD_KERNEL = """\
__global__ void vector_add(float* a, float* b, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        c[idx] = a[idx] + b[idx];
    }
}
"""

# Naive row-major matmul: one thread per output element of the M x N result.
MATRIX_MULTIPLY_KERNEL = """\
__global__ void matmul(float* A, float* B, float* C, int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}
"""

# Per-block tree sum over dynamically sized (`extern`) shared memory. The
# halving loop only covers every element when blockDim.x is a power of two.
REDUCTION_KERNEL = """\
__global__ void reduce_sum(float* input, float* output, int n) {
    extern __shared__ float sdata[];

    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    sdata[tid] = (i < n) ? input[i] : 0.0f;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}
"""

# Square N x N matmul staging TILE_SIZE x TILE_SIZE tiles of A and B in
# statically sized shared memory (presumably launched with TILE_SIZE x
# TILE_SIZE thread blocks, since threadIdx indexes the tiles directly).
SHARED_MEMORY_KERNEL = """\
#define TILE_SIZE 16

__global__ void tiled_matmul(float* A, float* B, float* C, int N) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];

    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;

    int row = by * TILE_SIZE + ty;
    int col = bx * TILE_SIZE + tx;

    float sum = 0.0f;

    for (int t = 0; t < (N + TILE_SIZE - 1) / TILE_SIZE; t++) {
        if (row < N && t * TILE_SIZE + tx < N)
            As[ty][tx] = A[row * N + t * TILE_SIZE + tx];
        else
            As[ty][tx] = 0.0f;

        if (t * TILE_SIZE + ty < N && col < N)
            Bs[ty][tx] = B[(t * TILE_SIZE + ty) * N + col];
        else
            Bs[ty][tx] = 0.0f;

        __syncthreads();

        for (int k = 0; k < TILE_SIZE; k++) {
            sum += As[ty][k] * Bs[k][tx];
        }
        __syncthreads();
    }

    if (row < N && col < N) {
        C[row * N + col] = sum;
    }
}
"""
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class TestSimpleKernels:
    """Package each single-file example kernel into a request."""

    @staticmethod
    def _packaged_source(filename: str, source: str) -> str:
        """Build a request holding *source* under *filename*; return the stored text."""
        return CompileRequest(files={filename: source}).files[filename]

    def test_vector_add_kernel(self) -> None:
        """The vector-add kernel survives packaging."""
        stored = self._packaged_source("vector_add.cu", VECTOR_ADD_KERNEL)
        assert "vector_add" in stored

    def test_matrix_multiply_kernel(self) -> None:
        """The naive matmul kernel survives packaging."""
        stored = self._packaged_source("matmul.cu", MATRIX_MULTIPLY_KERNEL)
        assert "matmul" in stored

    def test_reduction_kernel(self) -> None:
        """The shared-memory reduction kernel survives packaging."""
        stored = self._packaged_source("reduce.cu", REDUCTION_KERNEL)
        assert "__shared__" in stored

    def test_shared_memory_kernel(self) -> None:
        """The tiled matmul kernel survives packaging."""
        stored = self._packaged_source("tiled_matmul.cu", SHARED_MEMORY_KERNEL)
        assert "TILE_SIZE" in stored
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# Multi-file test examples

# Contents of "utils.cuh": device helpers pulled in by MAIN_FILE_WITH_HEADER.
HEADER_FILE = """\
#pragma once

__device__ float square(float x) {
    return x * x;
}

__device__ float cube(float x) {
    return x * x * x;
}
"""

# Contents of "main.cu". The #include must stay on the first line:
# TestMultiFile strips that line to splice in a different include list.
MAIN_FILE_WITH_HEADER = """\
#include "utils.cuh"

__global__ void apply_square(float* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        data[idx] = square(data[idx]);
    }
}
"""

# Two-level include chain: main.cu -> helper.cuh -> constants.cuh.
# Contents of "constants.cuh" (innermost header).
NESTED_HEADER_A = """\
#pragma once

#define MY_CONSTANT 42
"""

# Contents of "helper.cuh", which itself includes "constants.cuh".
NESTED_HEADER_B = """\
#pragma once

#include "constants.cuh"

__device__ int get_constant() {
    return MY_CONSTANT;
}
"""

# Contents of "main.cu" for the nested-include case.
MAIN_WITH_NESTED = """\
#include "helper.cuh"

__global__ void use_constant(int* out) {
    *out = get_constant();
}
"""
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class TestMultiFile:
    """Requests that bundle a .cu source together with header files."""

    def test_kernel_with_header(self) -> None:
        """A kernel plus the single header it includes."""
        sources = {
            "main.cu": MAIN_FILE_WITH_HEADER,
            "utils.cuh": HEADER_FILE,
        }
        req = CompileRequest(files=sources)
        assert len(req.files) == 2
        assert req.main_cu_file == "main.cu"

    def test_kernel_with_multiple_headers(self) -> None:
        """A kernel that includes two independent headers."""
        math_header = "__device__ float add(float a, float b) { return a + b; }"
        # Drop the template's lone #include line and prepend both includes.
        kernel_body = MAIN_FILE_WITH_HEADER.split("\n", 1)[1]
        main_source = '#include "utils.cuh"\n#include "math.cuh"\n' + kernel_body
        req = CompileRequest(
            files={
                "main.cu": main_source,
                "utils.cuh": HEADER_FILE,
                "math.cuh": math_header,
            },
        )
        assert len(req.files) == 3

    def test_nested_includes(self) -> None:
        """A header that itself includes another header."""
        sources = {
            "main.cu": MAIN_WITH_NESTED,
            "helper.cuh": NESTED_HEADER_B,
            "constants.cuh": NESTED_HEADER_A,
        }
        req = CompileRequest(files=sources)
        assert len(req.files) == 3
        assert "MY_CONSTANT" in req.files["constants.cuh"]

    def test_relative_include_paths(self) -> None:
        """Directory-qualified keys are kept verbatim in the files mapping."""
        sources = {
            "src/main.cu": '#include "../include/utils.cuh"\n__global__ void test() {}',
            "include/utils.cuh": "// utils",
        }
        req = CompileRequest(files=sources)
        for key in ("src/main.cu", "include/utils.cuh"):
            assert key in req.files
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# PyTorch extension examples

# Raw-pointer add kernel built against the PyTorch headers. `size` is int64_t,
# so the flat index is widened to 64 bits BEFORE multiplying: blockIdx.x and
# blockDim.x are unsigned 32-bit, and their product would otherwise wrap for
# tensors with 2^32 or more elements.
PYTORCH_ACCESSOR_KERNEL = """\
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void fused_add_kernel(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ out,
    int64_t size
) {
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < size) {
        out[idx] = a[idx] + b[idx];
    }
}
"""

# Templated kernel over 2-D PackedTensorAccessor32 views. No instantiation of
# the template appears in this source, so compiling it mainly exercises
# header resolution for <torch/types.h> / <ATen/ATen.h> and template parsing.
PYTORCH_PACKED_ACCESSOR_KERNEL = """\
#include <torch/types.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
__global__ void accessor_kernel(
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> output
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int col = blockIdx.y * blockDim.y + threadIdx.y;

    if (row < input.size(0) && col < input.size(1)) {
        output[row][col] = input[row][col] * 2;
    }
}
"""
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class TestPyTorchHeaders:
    """Kernels that depend on PyTorch's CUDA headers."""

    @staticmethod
    def _stored(filename: str, source: str) -> str:
        """Package *source* as *filename* and hand back what the request kept."""
        req = CompileRequest(files={filename: source})
        return req.files[filename]

    def test_pytorch_accessor(self) -> None:
        """The raw-pointer PyTorch kernel keeps its torch include."""
        text = self._stored("pytorch_kernel.cu", PYTORCH_ACCESSOR_KERNEL)
        assert "torch/types.h" in text

    def test_pytorch_packed_accessor(self) -> None:
        """The packed-accessor kernel keeps its accessor usage."""
        text = self._stored("accessor_kernel.cu", PYTORCH_PACKED_ACCESSOR_KERNEL)
        assert "PackedTensorAccessor" in text
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
# CUTLASS examples

# Kernel using CUTLASS's half-precision numeric type from numeric_types.h.
CUTLASS_TYPES_KERNEL = """\
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>

__global__ void use_cutlass_types() {
    cutlass::half_t a = cutlass::half_t(1.0f);
    cutlass::half_t b = cutlass::half_t(2.0f);
    cutlass::half_t c = a + b;
}
"""

# Kernel-free translation unit: it only needs the device-GEMM header to
# resolve and a layout alias to name-check.
CUTLASS_GEMM_INCLUDE = """\
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>

// Just testing include paths work
using ColumnMajor = cutlass::layout::ColumnMajor;
"""
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
class TestCUTLASSHeaders:
    """Kernels that depend on CUTLASS headers."""

    @staticmethod
    def _stored(filename: str, source: str) -> str:
        """Package *source* as *filename* and hand back what the request kept."""
        req = CompileRequest(files={filename: source})
        return req.files[filename]

    def test_cutlass_types(self) -> None:
        """The half_t kernel keeps its core CUTLASS include."""
        text = self._stored("cutlass_kernel.cu", CUTLASS_TYPES_KERNEL)
        assert "cutlass/cutlass.h" in text

    def test_cutlass_gemm_include(self) -> None:
        """The GEMM-include source keeps its device GEMM include."""
        text = self._stored("gemm_kernel.cu", CUTLASS_GEMM_INCLUDE)
        assert "cutlass/gemm" in text
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
class TestOutputFormats:
    """Test that responses carry PTX/SASS text with its expected markers."""

    def test_ptx_output_structure(self) -> None:
        """Test PTX output keeps its version, target, and entry directives."""
        # PTX typically starts with version/target info
        mock_ptx = """\
.version 8.0
.target sm_90a
.address_size 64

.visible .entry vector_add(
    .param .u64 a,
    .param .u64 b,
    .param .u64 c,
    .param .u32 n
)
{
    .reg .pred %p<2>;
    .reg .f32 %f<4>;
    // ...
}
"""
        response = CompileResponse(
            success=True,
            ptx=mock_ptx,
            compilation_time_ms=100,
        )
        assert response.ptx is not None
        assert ".version" in response.ptx
        assert ".target" in response.ptx
        assert ".entry" in response.ptx

    def test_sass_output_structure(self) -> None:
        """Test SASS output keeps its disassembly header and instructions."""
        # SASS is assembly with register operations
        mock_sass = """\
code for sm_90a
    Function : vector_add
    .headerflags @"EF_CUDA_SM90 EF_CUDA_PTX_SM(8.0)"
        /*0000*/ MOV R1, c[0x0][0x28] ;
        /*0010*/ S2R R0, SR_CTAID.X ;
        /*0020*/ IMAD.U32 R0, R0, c[0x0][0x0], R0 ;
"""
        response = CompileResponse(
            success=True,
            sass=mock_sass,
            compilation_time_ms=100,
        )
        assert response.sass is not None
        # Check each part separately: combining them with `or` would let the
        # test pass even if the header or the instruction listing were lost.
        assert "code for sm_90a" in response.sass
        assert "Function : vector_add" in response.sass
        assert "MOV" in response.sass

    def test_both_outputs(self) -> None:
        """Test a successful response can carry both PTX and SASS."""
        response = CompileResponse(
            success=True,
            ptx=".version 8.0",
            sass="MOV R0, R1",
            compilation_time_ms=150,
        )
        assert response.success is True
        assert response.ptx is not None
        assert response.sass is not None
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
class TestArchitectures:
    """Requests targeting specific GPU generations."""

    @staticmethod
    def _request_for(arch: str) -> "CompileRequest":
        """Build a trivial single-kernel request targeting *arch*."""
        return CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
            arch=arch,
        )

    def test_sm_90a_hopper(self) -> None:
        """A Hopper (sm_90a) target is kept on the request."""
        assert self._request_for("sm_90a").arch == "sm_90a"

    def test_sm_89_ada(self) -> None:
        """An Ada Lovelace (sm_89) target is kept on the request."""
        assert self._request_for("sm_89").arch == "sm_89"

    def test_sm_80_ampere(self) -> None:
        """An Ampere (sm_80) target is kept on the request."""
        assert self._request_for("sm_80").arch == "sm_80"

    def test_sm_100_blackwell(self) -> None:
        """A Blackwell (sm_100) target is kept on the request."""
        assert self._request_for("sm_100").arch == "sm_100"
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
class TestCompilerFlags:
    """Custom nvcc flags carried on a request."""

    @staticmethod
    def _request_with(*flags: str) -> "CompileRequest":
        """Build a trivial single-kernel request with *flags* as a tuple."""
        return CompileRequest(
            files={"kernel.cu": "__global__ void test() {}"},
            flags=flags,
        )

    def test_optimization_O3(self) -> None:
        """The -O3 optimization flag is retained."""
        assert "-O3" in self._request_with("-O3").flags

    def test_lineinfo_flag(self) -> None:
        """The -lineinfo debug flag is retained."""
        assert "-lineinfo" in self._request_with("-lineinfo").flags

    def test_maxrregcount(self) -> None:
        """The register-limit flag is retained verbatim."""
        assert "--maxrregcount=64" in self._request_with("--maxrregcount=64").flags

    def test_multiple_flags(self) -> None:
        """Several flags passed together are all retained."""
        req = self._request_with("-O3", "-lineinfo", "--maxrregcount=64")
        assert len(req.flags) == 3
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
class TestErrors:
    """Test compilation error reporting through ``CompileResponse.error``."""

    def test_syntax_error_reported(self) -> None:
        """Test a syntax error is reported as a failure with its message."""
        response = CompileResponse.error(
            "kernel.cu(10): error: expected a ';'"
        )
        assert response.success is False
        assert "expected a ';'" in response.stderr

    def test_missing_include_error(self) -> None:
        """Test a missing-include error is reported as a failure."""
        response = CompileResponse.error(
            "kernel.cu(1): fatal error: cannot open include file: missing.h"
        )
        assert response.success is False
        assert "cannot open include file" in response.stderr

    def test_undefined_symbol_error(self) -> None:
        """Test an undefined-identifier error is reported as a failure."""
        response = CompileResponse.error(
            "kernel.cu(5): error: identifier 'undefined_func' is undefined"
        )
        assert response.success is False
        # Match the diagnostic phrase; a bare "undefined" would also match the
        # identifier name "undefined_func" alone.
        assert "is undefined" in response.stderr

    def test_error_line_number_included(self) -> None:
        """Test that the reported error keeps its source line number."""
        response = CompileResponse.error(
            "kernel.cu(42): error: some error message"
        )
        assert response.success is False
        assert "(42)" in response.stderr
|