wafer-core 0.1.38__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- wafer_core/lib/trace_compare/fusion_analyzer.py +2 -0
- wafer_core/rollouts/_logging/__init__.py +5 -1
- wafer_core/rollouts/_logging/logging_config.py +95 -3
- wafer_core/rollouts/_logging/sample_handler.py +66 -0
- wafer_core/rollouts/_pytui/__init__.py +114 -0
- wafer_core/rollouts/_pytui/app.py +809 -0
- wafer_core/rollouts/_pytui/console.py +291 -0
- wafer_core/rollouts/_pytui/renderer.py +210 -0
- wafer_core/rollouts/_pytui/spinner.py +73 -0
- wafer_core/rollouts/_pytui/terminal.py +489 -0
- wafer_core/rollouts/_pytui/text.py +470 -0
- wafer_core/rollouts/_pytui/theme.py +241 -0
- wafer_core/rollouts/evaluation.py +142 -177
- wafer_core/rollouts/progress_app.py +395 -0
- wafer_core/rollouts/tui/DESIGN.md +251 -115
- wafer_core/rollouts/tui/monitor.py +64 -20
- wafer_core/tools/compile/__init__.py +30 -0
- wafer_core/tools/compile/compiler.py +314 -0
- wafer_core/tools/compile/modal_compile.py +359 -0
- wafer_core/tools/compile/tests/__init__.py +1 -0
- wafer_core/tools/compile/tests/test_compiler.py +675 -0
- wafer_core/tools/compile/tests/test_data/utils.cuh +10 -0
- wafer_core/tools/compile/tests/test_data/vector_add.cu +7 -0
- wafer_core/tools/compile/tests/test_data/with_header.cu +9 -0
- wafer_core/tools/compile/tests/test_modal_integration.py +326 -0
- wafer_core/tools/compile/types.py +117 -0
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/METADATA +1 -1
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/RECORD +29 -12
- wafer_core/rollouts/events.py +0 -240
- wafer_core/rollouts/progress_display.py +0 -476
- wafer_core/utils/event_streaming.py +0 -63
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/WHEEL +0 -0
wafer_core/tools/compile/compiler.py (new file, +314 lines)
@@ -0,0 +1,314 @@
"""CUDA compiler wrapper.

This module provides the core compilation logic that can be used locally
or through Modal. It handles request/response conversion and provides
a clean interface for the API and CLI.
"""

import subprocess
from dataclasses import asdict

from wafer_core.tools.compile.types import (
    CompileRequest,
    CompileResponse,
    OutputFormat,
)


def request_to_dict(request: CompileRequest) -> dict:
    """Convert a CompileRequest to a dict for Modal invocation.

    Args:
        request: The compile request

    Returns:
        Dict suitable for passing to Modal function
    """
    return {
        "files": dict(request.files),
        "arch": request.arch,
        "flags": list(request.flags),
        "output": [fmt.value for fmt in request.output],
    }


def response_from_dict(data: dict) -> CompileResponse:
    """Convert a Modal response dict to a CompileResponse.

    Args:
        data: Dict returned from Modal function

    Returns:
        CompileResponse object
    """
    return CompileResponse(
        success=data.get("success", False),
        ptx=data.get("ptx"),
        sass=data.get("sass"),
        stderr=data.get("stderr", ""),
        compilation_time_ms=data.get("compilation_time_ms", 0),
    )


async def compile_cuda_remote(
    request: CompileRequest,
    *,
    modal_token_id: str | None = None,
    modal_token_secret: str | None = None,
) -> CompileResponse:
    """Compile CUDA code using Modal (remote execution).

    This function spawns a subprocess to call Modal, avoiding event loop
    conflicts between the caller's event loop and Modal's asyncio.

    Args:
        request: The compile request
        modal_token_id: Optional Modal token ID (uses env var if not provided)
        modal_token_secret: Optional Modal token secret

    Returns:
        CompileResponse with PTX/SASS or error
    """
    import asyncio
    import json
    import os
    import tempfile
    import time
    from pathlib import Path

    start_time = time.time()

    # Write request to temp file
    request_dict = request_to_dict(request)

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as request_file:
        json.dump(request_dict, request_file)
        request_path = request_file.name

    try:
        # Create a Python script that calls Modal using Function.from_name
        # This calls the deployed function without needing to rebuild the image
        script = f'''
import json
import modal

# Load request
with open("{request_path}") as f:
    request = json.load(f)

# Look up the deployed function
compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")

# Call the function remotely
result = compile_fn.remote(request)

# Output result as JSON
print(json.dumps(result))
'''

        # Run in subprocess to avoid event loop conflicts
        env = os.environ.copy()
        if modal_token_id:
            env["MODAL_TOKEN_ID"] = modal_token_id
        if modal_token_secret:
            env["MODAL_TOKEN_SECRET"] = modal_token_secret

        # Use the same Python interpreter that's running this code
        import sys
        python_executable = sys.executable

        # Use asyncio.create_subprocess_exec for async subprocess execution
        proc = await asyncio.create_subprocess_exec(
            python_executable, "-c", script,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=env,
        )
        stdout_bytes, stderr_bytes = await proc.communicate()

        if proc.returncode != 0:
            stderr = stderr_bytes.decode() if stderr_bytes else "Unknown error"
            # Check for common Modal auth errors
            if "MODAL_TOKEN" in stderr or "AuthError" in stderr or "not authenticated" in stderr.lower():
                return CompileResponse.error(
                    "Modal not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables, "
                    "or run 'modal token new' to authenticate.",
                    compilation_time_ms=int((time.time() - start_time) * 1000),
                )
            return CompileResponse.error(
                f"Compilation failed: {stderr}",
                compilation_time_ms=int((time.time() - start_time) * 1000),
            )

        # Parse result
        stdout = stdout_bytes.decode() if stdout_bytes else "{}"
        try:
            response_dict = json.loads(stdout)
            return response_from_dict(response_dict)
        except json.JSONDecodeError as e:
            return CompileResponse.error(
                f"Failed to parse Modal response: {e}\nOutput: {stdout[:500]}",
                compilation_time_ms=int((time.time() - start_time) * 1000),
            )

    finally:
        # Clean up temp file
        Path(request_path).unlink(missing_ok=True)


def compile_cuda_local(request: CompileRequest) -> CompileResponse:
    """Compile CUDA code locally using nvcc.

    This function requires nvcc to be installed locally.
    Primarily useful for testing without Modal.

    Args:
        request: The compile request

    Returns:
        CompileResponse with PTX/SASS or error
    """
    import os
    import tempfile
    import time
    from pathlib import Path

    start_time = time.time()

    # Check if nvcc is available
    try:
        subprocess.run(
            ["nvcc", "--version"],
            capture_output=True,
            check=True,
            timeout=10,
        )
    except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired):
        return CompileResponse.error(
            "nvcc not found. Install CUDA toolkit or use Modal for remote compilation.",
            compilation_time_ms=int((time.time() - start_time) * 1000),
        )

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir)

            # Write all files to temp directory
            for filename, content in request.files.items():
                file_path = tmp_path / filename
                file_path.parent.mkdir(parents=True, exist_ok=True)
                file_path.write_text(content)

            main_cu_path = tmp_path / request.main_cu_file
            include_dir = main_cu_path.parent

            results: dict[str, str | None] = {"ptx": None, "sass": None}

            # Build base nvcc command
            base_cmd = [
                "nvcc",
                "-arch",
                request.arch,
                f"-I{include_dir}",
            ]
            base_cmd.extend(request.flags)

            # Generate PTX if requested
            if OutputFormat.PTX in request.output:
                ptx_output = tmp_path / "output.ptx"
                ptx_cmd = base_cmd + [
                    "--ptx",
                    "-o",
                    str(ptx_output),
                    str(main_cu_path),
                ]

                ptx_result = subprocess.run(
                    ptx_cmd,
                    capture_output=True,
                    text=True,
                    timeout=60,
                    cwd=tmpdir,
                )

                if ptx_result.returncode != 0:
                    return CompileResponse.error(
                        ptx_result.stderr or ptx_result.stdout,
                        compilation_time_ms=int((time.time() - start_time) * 1000),
                    )

                if ptx_output.exists():
                    results["ptx"] = ptx_output.read_text()

            # Generate SASS if requested
            if OutputFormat.SASS in request.output:
                cubin_output = tmp_path / "output.cubin"
                cubin_cmd = base_cmd + [
                    "--cubin",
                    "-o",
                    str(cubin_output),
                    str(main_cu_path),
                ]

                cubin_result = subprocess.run(
                    cubin_cmd,
                    capture_output=True,
                    text=True,
                    timeout=60,
                    cwd=tmpdir,
                )

                if cubin_result.returncode != 0:
                    if results["ptx"]:
                        return CompileResponse(
                            success=True,
                            ptx=results["ptx"],
                            sass=None,
                            stderr=f"SASS generation failed: {cubin_result.stderr}",
                            compilation_time_ms=int((time.time() - start_time) * 1000),
                        )
                    return CompileResponse.error(
                        cubin_result.stderr or cubin_result.stdout,
                        compilation_time_ms=int((time.time() - start_time) * 1000),
                    )

                if cubin_output.exists():
                    sass_result = subprocess.run(
                        ["cuobjdump", "--dump-sass", str(cubin_output)],
                        capture_output=True,
                        text=True,
                        timeout=30,
                        cwd=tmpdir,
                    )

                    if sass_result.returncode == 0:
                        results["sass"] = sass_result.stdout

            if not results["ptx"] and not results["sass"]:
                return CompileResponse.error(
                    "No output generated",
                    compilation_time_ms=int((time.time() - start_time) * 1000),
                )

            return CompileResponse(
                success=True,
                ptx=results["ptx"],
                sass=results["sass"],
                stderr="",
                compilation_time_ms=int((time.time() - start_time) * 1000),
            )

    except subprocess.TimeoutExpired:
        return CompileResponse.error(
            "Compilation timed out",
            compilation_time_ms=int((time.time() - start_time) * 1000),
        )
    except Exception as e:
        import traceback

        return CompileResponse.error(
            f"Internal error: {e}\n{traceback.format_exc()}",
            compilation_time_ms=int((time.time() - start_time) * 1000),
        )
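For orientation, here is a minimal sketch of driving the local path above. The real CompileRequest constructor lives in wafer_core/tools/compile/types.py (+117 lines, not expanded on this page), so the keyword arguments below are assumptions inferred from the attributes compile_cuda_local reads (files, main_cu_file, arch, flags, output):

# Hypothetical usage sketch. CompileRequest's actual signature is defined in
# types.py (not shown in this diff); the fields below mirror the attributes
# that compile_cuda_local reads and may not match the real constructor.
from wafer_core.tools.compile.compiler import compile_cuda_local
from wafer_core.tools.compile.types import CompileRequest, OutputFormat

request = CompileRequest(
    files={"vector_add.cu": "__global__ void vadd(float* a) { a[0] += 1.0f; }"},
    main_cu_file="vector_add.cu",  # assumed field name (read by compile_cuda_local)
    arch="sm_90a",
    flags=["-lineinfo"],
    output=[OutputFormat.PTX, OutputFormat.SASS],
)

response = compile_cuda_local(request)  # requires a local nvcc on PATH
if response.success:
    print((response.ptx or "")[:200])
else:
    print(f"compile failed: {response.stderr}")

The remote path takes the same request: `await compile_cuda_remote(request)` runs the Modal call in a subprocess, so it can be awaited from inside an existing event loop.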
wafer_core/tools/compile/modal_compile.py (new file, +359 lines)
@@ -0,0 +1,359 @@
"""Modal function for cloud CUDA compilation.

This module provides a CPU-only Modal function that compiles CUDA code
and returns PTX/SASS assembly. No GPU is required - nvcc can generate
PTX/SASS for any architecture without a physical GPU.

Usage:
    # From Python
    from wafer_core.tools.compile.modal_compile import compile_cuda
    result = compile_cuda.remote(request_dict)

    # Deploy Modal app
    modal deploy modal_compile.py
"""

from enum import Enum

import modal


# Inline OutputFormat to avoid import chain issues with Modal deployment
class OutputFormat(str, Enum):
    """Supported output formats."""

    PTX = "ptx"
    SASS = "sass"

# ══════════════════════════════════════════════════════════════════════════════
# Modal Image Configuration
# ══════════════════════════════════════════════════════════════════════════════

# Build a CPU-only image with nvcc for CUDA compilation
# No GPU driver needed - nvcc can generate PTX/SASS without a GPU
compile_image = (
    # Start with NVIDIA's CUDA development image (includes nvcc)
    # (a debian-slim base with a manual CUDA install could shrink the image)
    modal.Image.from_registry(
        "nvidia/cuda:13.0.1-devel-ubuntu22.04",
        add_python="3.12",
    )
    # System dependencies
    .apt_install("git", "build-essential", "wget")
    # Install PyTorch headers (CPU-only, we just need the headers)
    .pip_install(
        "torch>=2.5.0",
        index_url="https://download.pytorch.org/whl/cpu",
        extra_index_url="https://pypi.org/simple",
    )
    # Install CUTLASS headers (v4.3.5)
    .run_commands(
        "git clone --depth 1 --branch v4.3.5 https://github.com/NVIDIA/cutlass.git /usr/local/cutlass",
    )
    # Set environment variables for CUDA and libraries
    .env(
        {
            "CUDA_HOME": "/usr/local/cuda",
            "PATH": "/usr/local/cuda/bin:$PATH",
            "CUTLASS_PATH": "/usr/local/cutlass/include",
        }
    )
    # Verify nvcc is working
    .run_commands(
        "nvcc --version",
        "ls -la /usr/local/cuda/bin/nvcc",
    )
)

# Create Modal app
app = modal.App(name="cuda-compile", image=compile_image)


# ══════════════════════════════════════════════════════════════════════════════
# Modal Function - CUDA Compilation
# ══════════════════════════════════════════════════════════════════════════════


@app.function(
    # CPU only - no GPU needed for compilation!
    cpu=4,
    memory=8192,  # 8GB RAM
    timeout=120,  # 2 minute timeout
)
def compile_cuda(request: dict) -> dict:
    """Compile CUDA code and return PTX/SASS.

    This function runs on a CPU-only container with nvcc installed.
    No GPU is required because nvcc can cross-compile for any target architecture.

    Args:
        request: Dict with the following fields:
            - files: dict[str, str] - Mapping of filename to content
            - arch: str - Target architecture (e.g., "sm_90a")
            - flags: list[str] - Additional nvcc flags
            - output: list[str] - Output formats ("ptx", "sass")

    Returns:
        Dict with the following fields:
            - success: bool - Whether compilation succeeded
            - ptx: str | None - Generated PTX code
            - sass: str | None - Generated SASS code
            - stderr: str - Compiler warnings/errors
            - compilation_time_ms: int - Time taken in ms
    """
    import os
    import subprocess
    import tempfile
    import time
    from pathlib import Path

    start_time = time.time()

    # Extract request fields
    files: dict[str, str] = request.get("files", {})
    arch: str = request.get("arch", "sm_90a")
    flags: list[str] = request.get("flags", [])
    output_formats: list[str] = request.get("output", ["ptx", "sass"])

    # Validate request
    if not files:
        return {
            "success": False,
            "ptx": None,
            "sass": None,
            "stderr": "No files provided",
            "compilation_time_ms": int((time.time() - start_time) * 1000),
        }

    # Find the main .cu file
    cu_files = [f for f in files.keys() if f.endswith(".cu")]
    if not cu_files:
        return {
            "success": False,
            "ptx": None,
            "sass": None,
            "stderr": "No .cu file found",
            "compilation_time_ms": int((time.time() - start_time) * 1000),
        }

    main_cu_file = cu_files[0]

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir)

            # Write all files to temp directory, preserving subdirectory structure
            for filename, content in files.items():
                file_path = tmp_path / filename
                file_path.parent.mkdir(parents=True, exist_ok=True)
                file_path.write_text(content)

            # Determine the directory of the main .cu file for include paths
            main_cu_path = tmp_path / main_cu_file
            include_dir = main_cu_path.parent

            results: dict[str, str | None] = {"ptx": None, "sass": None}

            # Build base nvcc command with common flags
            base_cmd = [
                "nvcc",
                "-arch",
                arch,
                # Include the temp directory for user headers
                f"-I{include_dir}",
                # Include PyTorch headers
                "-I/usr/local/lib/python3.12/site-packages/torch/include",
                "-I/usr/local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include",
                # Include CUTLASS headers
                "-I/usr/local/cutlass/include",
                # Standard CUDA headers are already in the default path
            ]

            # Add user-specified flags
            base_cmd.extend(flags)

            # Generate PTX if requested
            if OutputFormat.PTX.value in output_formats:
                ptx_output = tmp_path / "output.ptx"
                ptx_cmd = base_cmd + [
                    "--ptx",  # Generate PTX
                    "-o",
                    str(ptx_output),
                    str(main_cu_path),
                ]

                ptx_result = subprocess.run(
                    ptx_cmd,
                    capture_output=True,
                    text=True,
                    timeout=60,
                    cwd=tmpdir,
                    env={
                        **os.environ,
                        "CUDA_HOME": "/usr/local/cuda",
                        "PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
                    },
                )

                if ptx_result.returncode != 0:
                    return {
                        "success": False,
                        "ptx": None,
                        "sass": None,
                        "stderr": ptx_result.stderr or ptx_result.stdout,
                        "compilation_time_ms": int((time.time() - start_time) * 1000),
                    }

                if ptx_output.exists():
                    results["ptx"] = ptx_output.read_text()

            # Generate SASS if requested
            if OutputFormat.SASS.value in output_formats:
                # First compile to cubin, then disassemble to SASS
                cubin_output = tmp_path / "output.cubin"
                cubin_cmd = base_cmd + [
                    "--cubin",  # Generate cubin (binary)
                    "-o",
                    str(cubin_output),
                    str(main_cu_path),
                ]

                cubin_result = subprocess.run(
                    cubin_cmd,
                    capture_output=True,
                    text=True,
                    timeout=60,
                    cwd=tmpdir,
                    env={
                        **os.environ,
                        "CUDA_HOME": "/usr/local/cuda",
                        "PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
                    },
                )

                if cubin_result.returncode != 0:
                    # If we already have PTX, that's a partial success
                    if results["ptx"]:
                        return {
                            "success": True,
                            "ptx": results["ptx"],
                            "sass": None,
                            "stderr": f"SASS generation failed: {cubin_result.stderr}",
                            "compilation_time_ms": int(
                                (time.time() - start_time) * 1000
                            ),
                        }
                    return {
                        "success": False,
                        "ptx": None,
                        "sass": None,
                        "stderr": cubin_result.stderr or cubin_result.stdout,
                        "compilation_time_ms": int((time.time() - start_time) * 1000),
                    }

                # Disassemble cubin to SASS using cuobjdump
                if cubin_output.exists():
                    sass_cmd = [
                        "cuobjdump",
                        "--dump-sass",
                        str(cubin_output),
                    ]

                    sass_result = subprocess.run(
                        sass_cmd,
                        capture_output=True,
                        text=True,
                        timeout=30,
                        cwd=tmpdir,
                    )

                    if sass_result.returncode == 0:
                        results["sass"] = sass_result.stdout
                    else:
                        # SASS disassembly failed but we might have PTX
                        if results["ptx"]:
                            return {
                                "success": True,
                                "ptx": results["ptx"],
                                "sass": None,
                                "stderr": f"SASS disassembly failed: {sass_result.stderr}",
                                "compilation_time_ms": int(
                                    (time.time() - start_time) * 1000
                                ),
                            }

            # Check if we got any output
            if not results["ptx"] and not results["sass"]:
                return {
                    "success": False,
                    "ptx": None,
                    "sass": None,
                    "stderr": "No output generated",
                    "compilation_time_ms": int((time.time() - start_time) * 1000),
                }

            return {
                "success": True,
                "ptx": results["ptx"],
                "sass": results["sass"],
                "stderr": "",
                "compilation_time_ms": int((time.time() - start_time) * 1000),
            }

    except subprocess.TimeoutExpired:
        return {
            "success": False,
            "ptx": None,
            "sass": None,
            "stderr": "Compilation timed out",
            "compilation_time_ms": int((time.time() - start_time) * 1000),
        }
    except Exception as e:
        import traceback

        return {
            "success": False,
            "ptx": None,
            "sass": None,
            "stderr": f"Internal error: {e}\n{traceback.format_exc()}",
            "compilation_time_ms": int((time.time() - start_time) * 1000),
        }


# ══════════════════════════════════════════════════════════════════════════════
# Health Check
# ══════════════════════════════════════════════════════════════════════════════


@app.function(cpu=1, memory=1024, timeout=30)
def health_check() -> dict:
    """Verify the compilation environment is working.

    Returns:
        Dict with status and version info
    """
    import subprocess

    try:
        # Get nvcc version
        nvcc_result = subprocess.run(
            ["nvcc", "--version"],
            capture_output=True,
            text=True,
            timeout=10,
        )

        # Parse version from output
        version_line = nvcc_result.stdout.strip().split("\n")[-1]

        return {
            "status": "ok",
            "nvcc_version": version_line,
            "nvcc_available": nvcc_result.returncode == 0,
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
            "nvcc_available": False,
        }
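Once deployed with `modal deploy modal_compile.py`, the functions can be looked up and called directly, mirroring the subprocess script embedded in compiler.py. A short sketch, assuming Modal credentials are already configured:

import modal

# Look up the functions deployed under the "cuda-compile" app
compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
health_fn = modal.Function.from_name("cuda-compile", "health_check")

# health_check returns {"status": ..., "nvcc_version": ..., "nvcc_available": ...}
print(health_fn.remote())

# compile_cuda takes and returns plain dicts, so no wafer_core import is needed
result = compile_fn.remote({
    "files": {"kernel.cu": "__global__ void noop() {}"},
    "arch": "sm_90a",
    "flags": [],
    "output": ["ptx"],
})
print(result["success"], f'{result["compilation_time_ms"]} ms')

Keeping the request and response as plain dicts (and inlining OutputFormat) is what lets compiler.py call the deployed function from a bare `python -c` subprocess without shipping the package into the Modal image.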
wafer_core/tools/compile/tests/__init__.py (new file, +1 line)
@@ -0,0 +1 @@
"""Tests for the cloud CUDA compiler."""
|