wafer-core 0.1.38-py3-none-any.whl → 0.1.40-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- wafer_core/lib/trace_compare/fusion_analyzer.py +2 -0
- wafer_core/rollouts/_logging/__init__.py +5 -1
- wafer_core/rollouts/_logging/logging_config.py +95 -3
- wafer_core/rollouts/_logging/sample_handler.py +66 -0
- wafer_core/rollouts/_pytui/__init__.py +114 -0
- wafer_core/rollouts/_pytui/app.py +809 -0
- wafer_core/rollouts/_pytui/console.py +291 -0
- wafer_core/rollouts/_pytui/renderer.py +210 -0
- wafer_core/rollouts/_pytui/spinner.py +73 -0
- wafer_core/rollouts/_pytui/terminal.py +489 -0
- wafer_core/rollouts/_pytui/text.py +470 -0
- wafer_core/rollouts/_pytui/theme.py +241 -0
- wafer_core/rollouts/evaluation.py +142 -177
- wafer_core/rollouts/progress_app.py +395 -0
- wafer_core/rollouts/tui/DESIGN.md +251 -115
- wafer_core/rollouts/tui/monitor.py +64 -20
- wafer_core/tools/compile/__init__.py +30 -0
- wafer_core/tools/compile/benchmark.py +636 -0
- wafer_core/tools/compile/compiler.py +301 -0
- wafer_core/tools/compile/modal_compile.py +369 -0
- wafer_core/tools/compile/tests/__init__.py +1 -0
- wafer_core/tools/compile/tests/test_compiler.py +675 -0
- wafer_core/tools/compile/tests/test_data/utils.cuh +10 -0
- wafer_core/tools/compile/tests/test_data/vector_add.cu +7 -0
- wafer_core/tools/compile/tests/test_data/with_header.cu +9 -0
- wafer_core/tools/compile/tests/test_modal_integration.py +326 -0
- wafer_core/tools/compile/types.py +117 -0
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.40.dist-info}/METADATA +1 -1
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.40.dist-info}/RECORD +30 -12
- wafer_core/rollouts/events.py +0 -240
- wafer_core/rollouts/progress_display.py +0 -476
- wafer_core/utils/event_streaming.py +0 -63
- {wafer_core-0.1.38.dist-info → wafer_core-0.1.40.dist-info}/WHEEL +0 -0
wafer_core/tools/compile/compiler.py
@@ -0,0 +1,301 @@
+"""CUDA compiler wrapper.
+
+This module provides the core compilation logic that can be used locally
+or through Modal. It handles request/response conversion and provides
+a clean interface for the API and CLI.
+"""
+
+import subprocess
+from dataclasses import asdict
+
+from wafer_core.tools.compile.types import (
+    CompileRequest,
+    CompileResponse,
+    OutputFormat,
+)
+
+
+def request_to_dict(request: CompileRequest) -> dict:
+    """Convert a CompileRequest to a dict for Modal invocation.
+
+    Args:
+        request: The compile request
+
+    Returns:
+        Dict suitable for passing to Modal function
+    """
+    return {
+        "files": dict(request.files),
+        "arch": request.arch,
+        "flags": list(request.flags),
+        "output": [fmt.value for fmt in request.output],
+    }
+
+
+def response_from_dict(data: dict) -> CompileResponse:
+    """Convert a Modal response dict to a CompileResponse.
+
+    Args:
+        data: Dict returned from Modal function
+
+    Returns:
+        CompileResponse object
+    """
+    return CompileResponse(
+        success=data.get("success", False),
+        ptx=data.get("ptx"),
+        sass=data.get("sass"),
+        stderr=data.get("stderr", ""),
+        compilation_time_ms=data.get("compilation_time_ms", 0),
+    )
+
+
+async def compile_cuda_remote(
+    request: CompileRequest,
+    *,
+    modal_token_id: str | None = None,
+    modal_token_secret: str | None = None,
+) -> CompileResponse:
+    """Compile CUDA code using Modal (remote execution).
+
+    This function calls the deployed Modal function directly using asyncio.to_thread
+    to avoid blocking the event loop.
+
+    Args:
+        request: The compile request
+        modal_token_id: Optional Modal token ID (uses env var if not provided)
+        modal_token_secret: Optional Modal token secret
+
+    Returns:
+        CompileResponse with PTX/SASS or error
+    """
+    import asyncio
+    import os
+    import time
+    from contextlib import contextmanager
+
+    @contextmanager
+    def temporary_env_vars(env_updates: dict[str, str]):
+        """Context manager to temporarily set environment variables.
+
+        Saves original values, sets new values, yields, then restores originals.
+        This ensures we don't leak credentials between concurrent requests.
+        """
+        original_values: dict[str, str | None] = {}
+        for key, value in env_updates.items():
+            original_values[key] = os.environ.get(key)
+            os.environ[key] = value
+
+        try:
+            yield
+        finally:
+            for key, original in original_values.items():
+                if original is None:
+                    os.environ.pop(key, None)
+                else:
+                    os.environ[key] = original
+
+    start_time = time.time()
+    request_dict = request_to_dict(request)
+
+    # Build env updates for credentials (only if provided)
+    env_updates: dict[str, str] = {}
+    if modal_token_id:
+        env_updates["MODAL_TOKEN_ID"] = modal_token_id
+    if modal_token_secret:
+        env_updates["MODAL_TOKEN_SECRET"] = modal_token_secret
+
+    def call_modal() -> dict:
+        """Call Modal function synchronously (runs in thread pool)."""
+        import modal
+
+        # Look up the deployed function
+        compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
+
+        # Call the function remotely
+        return compile_fn.remote(request_dict)
+
+    try:
+        # Run Modal call in thread pool with temporary credentials
+        # The context manager ensures env vars are restored after the call
+        def call_modal_with_env() -> dict:
+            with temporary_env_vars(env_updates):
+                return call_modal()
+
+        result = await asyncio.to_thread(call_modal_with_env)
+        return response_from_dict(result)
+
+    except Exception as e:
+        error_str = str(e)
+        # Check for common Modal auth errors
+        if "MODAL_TOKEN" in error_str or "AuthError" in error_str or "not authenticated" in error_str.lower():
+            return CompileResponse.error(
+                "Modal not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables, "
+                "or run 'modal token new' to authenticate.",
+                compilation_time_ms=int((time.time() - start_time) * 1000),
+            )
+        return CompileResponse.error(
+            f"Compilation failed: {error_str}",
+            compilation_time_ms=int((time.time() - start_time) * 1000),
+        )
+
+
+def compile_cuda_local(request: CompileRequest) -> CompileResponse:
+    """Compile CUDA code locally using nvcc.
+
+    This function requires nvcc to be installed locally.
+    Primarily useful for testing without Modal.
+
+    Args:
+        request: The compile request
+
+    Returns:
+        CompileResponse with PTX/SASS or error
+    """
+    import os
+    import tempfile
+    import time
+    from pathlib import Path
+
+    start_time = time.time()
+
+    # Check if nvcc is available
+    try:
+        subprocess.run(
+            ["nvcc", "--version"],
+            capture_output=True,
+            check=True,
+            timeout=10,
+        )
+    except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired):
+        return CompileResponse.error(
+            "nvcc not found. Install CUDA toolkit or use Modal for remote compilation.",
+            compilation_time_ms=int((time.time() - start_time) * 1000),
+        )
+
+    try:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmp_path = Path(tmpdir)
+
+            # Write all files to temp directory
+            for filename, content in request.files.items():
+                file_path = (tmp_path / filename).resolve()
+                if not file_path.is_relative_to(tmp_path):
+                    return CompileResponse.error(
+                        f"Invalid filename: {filename}",
+                        compilation_time_ms=int((time.time() - start_time) * 1000),
+                    )
+                file_path.parent.mkdir(parents=True, exist_ok=True)
+                file_path.write_text(content)
+
+            main_cu_path = tmp_path / request.main_cu_file
+            include_dir = main_cu_path.parent
+
+            results: dict[str, str | None] = {"ptx": None, "sass": None}
+
+            # Build base nvcc command
+            base_cmd = [
+                "nvcc",
+                "-arch",
+                request.arch,
+                f"-I{include_dir}",
+            ]
+            base_cmd.extend(request.flags)
+
+            # Generate PTX if requested
+            if OutputFormat.PTX in request.output:
+                ptx_output = tmp_path / "output.ptx"
+                ptx_cmd = base_cmd + [
+                    "--ptx",
+                    "-o",
+                    str(ptx_output),
+                    str(main_cu_path),
+                ]
+
+                ptx_result = subprocess.run(
+                    ptx_cmd,
+                    capture_output=True,
+                    text=True,
+                    timeout=60,
+                    cwd=tmpdir,
+                )
+
+                if ptx_result.returncode != 0:
+                    return CompileResponse.error(
+                        ptx_result.stderr or ptx_result.stdout,
+                        compilation_time_ms=int((time.time() - start_time) * 1000),
+                    )
+
+                if ptx_output.exists():
+                    results["ptx"] = ptx_output.read_text()
+
+            # Generate SASS if requested
+            if OutputFormat.SASS in request.output:
+                cubin_output = tmp_path / "output.cubin"
+                cubin_cmd = base_cmd + [
+                    "--cubin",
+                    "-o",
+                    str(cubin_output),
+                    str(main_cu_path),
+                ]
+
+                cubin_result = subprocess.run(
+                    cubin_cmd,
+                    capture_output=True,
+                    text=True,
+                    timeout=60,
+                    cwd=tmpdir,
+                )
+
+                if cubin_result.returncode != 0:
+                    if results["ptx"]:
+                        return CompileResponse(
+                            success=True,
+                            ptx=results["ptx"],
+                            sass=None,
+                            stderr=f"SASS generation failed: {cubin_result.stderr}",
+                            compilation_time_ms=int((time.time() - start_time) * 1000),
+                        )
+                    return CompileResponse.error(
+                        cubin_result.stderr or cubin_result.stdout,
+                        compilation_time_ms=int((time.time() - start_time) * 1000),
+                    )
+
+                if cubin_output.exists():
+                    sass_result = subprocess.run(
+                        ["cuobjdump", "--dump-sass", str(cubin_output)],
+                        capture_output=True,
+                        text=True,
+                        timeout=30,
+                        cwd=tmpdir,
+                    )
+
+                    if sass_result.returncode == 0:
+                        results["sass"] = sass_result.stdout
+
+            if not results["ptx"] and not results["sass"]:
+                return CompileResponse.error(
+                    "No output generated",
+                    compilation_time_ms=int((time.time() - start_time) * 1000),
+                )
+
+            return CompileResponse(
+                success=True,
+                ptx=results["ptx"],
+                sass=results["sass"],
+                stderr="",
+                compilation_time_ms=int((time.time() - start_time) * 1000),
+            )
+
+    except subprocess.TimeoutExpired:
+        return CompileResponse.error(
+            "Compilation timed out",
+            compilation_time_ms=int((time.time() - start_time) * 1000),
+        )
+    except Exception as e:
+        import traceback
+
+        return CompileResponse.error(
+            f"Internal error: {e}\n{traceback.format_exc()}",
+            compilation_time_ms=int((time.time() - start_time) * 1000),
+        )
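
The new compiler.py gives both a local and a remote entry point. Below is a minimal sketch of driving it; CompileRequest lives in types.py, which this diff lists but does not show, so the keyword-argument constructor and the example values are assumptions inferred from how compiler.py reads those fields.

import asyncio

from wafer_core.tools.compile.compiler import compile_cuda_local, compile_cuda_remote
from wafer_core.tools.compile.types import CompileRequest, OutputFormat

# Hypothetical request; field names come from compiler.py's attribute accesses.
request = CompileRequest(
    files={"vector_add.cu": "__global__ void add(float* a) { a[0] += 1.0f; }"},
    arch="sm_90a",                                 # passed to nvcc -arch
    flags=["-O3"],                                 # extra nvcc flags
    output=[OutputFormat.PTX, OutputFormat.SASS],  # which artifacts to return
    main_cu_file="vector_add.cu",                  # entry file used by compile_cuda_local
)

# Local path: requires nvcc on PATH.
resp = compile_cuda_local(request)
if resp.success:
    print(resp.ptx)

# Remote path: offloads to the deployed Modal app without blocking the event loop.
resp = asyncio.run(compile_cuda_remote(request))
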
wafer_core/tools/compile/modal_compile.py
@@ -0,0 +1,369 @@
+"""Modal function for cloud CUDA compilation.
+
+This module provides a CPU-only Modal function that compiles CUDA code
+and returns PTX/SASS assembly. No GPU is required - nvcc can generate
+PTX/SASS for any architecture without a physical GPU.
+
+Usage:
+    # From Python
+    from wafer_core.tools.compile.modal_compile import compile_cuda
+    result = compile_cuda.remote(request_dict)
+
+    # Deploy Modal app
+    modal deploy modal_compile.py
+"""
+
+from enum import Enum
+
+import modal
+
+
+# Inline OutputFormat to avoid import chain issues with Modal deployment
+class OutputFormat(str, Enum):
+    """Supported output formats."""
+
+    PTX = "ptx"
+    SASS = "sass"
+
+# ══════════════════════════════════════════════════════════════════════════════
+# Modal Image Configuration
+# ══════════════════════════════════════════════════════════════════════════════
+
+# Build a CPU-only image with nvcc for CUDA compilation
+# No GPU driver needed - nvcc can generate PTX/SASS without a GPU
+compile_image = (
+    # Start with NVIDIA's CUDA development image (includes nvcc)
+    # Use debian-slim + manual CUDA install for smaller image
+    modal.Image.from_registry(
+        "nvidia/cuda:13.0.1-devel-ubuntu22.04",
+        add_python="3.12",
+    )
+    # System dependencies
+    .apt_install("git", "build-essential", "wget")
+    # Install PyTorch headers (CPU-only, we just need the headers)
+    .pip_install(
+        "torch>=2.5.0",
+        index_url="https://download.pytorch.org/whl/cpu",
+        extra_index_url="https://pypi.org/simple",
+    )
+    # Install CUTLASS headers (v4.3.5)
+    .run_commands(
+        "git clone --depth 1 --branch v4.3.5 https://github.com/NVIDIA/cutlass.git /usr/local/cutlass",
+    )
+    # Set environment variables for CUDA and libraries
+    .env(
+        {
+            "CUDA_HOME": "/usr/local/cuda",
+            "PATH": "/usr/local/cuda/bin:$PATH",
+            "CUTLASS_PATH": "/usr/local/cutlass/include",
+        }
+    )
+    # Verify nvcc is working
+    .run_commands(
+        "nvcc --version",
+        "ls -la /usr/local/cuda/bin/nvcc",
+    )
+)
+
+# Create Modal app
+app = modal.App(name="cuda-compile", image=compile_image)
+
+
+# ══════════════════════════════════════════════════════════════════════════════
+# Modal Function - CUDA Compilation
+# ══════════════════════════════════════════════════════════════════════════════
+
+
+@app.function(
+    # CPU only - no GPU needed for compilation!
+    cpu=4,
+    memory=8192,  # 8GB RAM
+    timeout=120,  # 2 minute timeout
+    # Keep one container warm to avoid cold starts (~5-10s savings)
+    min_containers=1,
+)
+@modal.concurrent(max_inputs=4)  # Allow concurrent compilations for better throughput
+def compile_cuda(request: dict) -> dict:
+    """Compile CUDA code and return PTX/SASS.
+
+    This function runs on a CPU-only container with nvcc installed.
+    No GPU is required because nvcc can cross-compile for any target architecture.
+
+    Args:
+        request: Dict with the following fields:
+            - files: dict[str, str] - Mapping of filename to content
+            - arch: str - Target architecture (e.g., "sm_90a")
+            - flags: list[str] - Additional nvcc flags
+            - output: list[str] - Output formats ("ptx", "sass")
+
+    Returns:
+        Dict with the following fields:
+            - success: bool - Whether compilation succeeded
+            - ptx: str | None - Generated PTX code
+            - sass: str | None - Generated SASS code
+            - stderr: str - Compiler warnings/errors
+            - compilation_time_ms: int - Time taken in ms
+    """
+    import os
+    import subprocess
+    import tempfile
+    import time
+    from concurrent.futures import ThreadPoolExecutor, as_completed
+    from pathlib import Path
+
+    start_time = time.time()
+
+    # Extract request fields
+    files: dict[str, str] = request.get("files", {})
+    arch: str = request.get("arch", "sm_90a")
+    flags: list[str] = request.get("flags", [])
+    output_formats: list[str] = request.get("output", ["ptx", "sass"])
+
+    # Validate request
+    if not files:
+        return {
+            "success": False,
+            "ptx": None,
+            "sass": None,
+            "stderr": "No files provided",
+            "compilation_time_ms": int((time.time() - start_time) * 1000),
+        }
+
+    # Find the main .cu file
+    cu_files = [f for f in files.keys() if f.endswith(".cu")]
+    if not cu_files:
+        return {
+            "success": False,
+            "ptx": None,
+            "sass": None,
+            "stderr": "No .cu file found",
+            "compilation_time_ms": int((time.time() - start_time) * 1000),
+        }
+
+    main_cu_file = cu_files[0]
+
+    # Build environment for nvcc
+    nvcc_env = {
+        **os.environ,
+        "CUDA_HOME": "/usr/local/cuda",
+        "PATH": f"/usr/local/cuda/bin:{os.environ.get('PATH', '')}",
+    }
+
+    def compile_ptx(tmpdir: str, base_cmd: list[str], main_cu_path: Path) -> tuple[str | None, str | None]:
+        """Compile to PTX. Returns (ptx_content, error_message)."""
+        ptx_output = Path(tmpdir) / "output.ptx"
+        ptx_cmd = base_cmd + [
+            "--ptx",
+            "-o",
+            str(ptx_output),
+            str(main_cu_path),
+        ]
+
+        ptx_result = subprocess.run(
+            ptx_cmd,
+            capture_output=True,
+            text=True,
+            timeout=60,
+            cwd=tmpdir,
+            env=nvcc_env,
+        )
+
+        if ptx_result.returncode != 0:
+            return None, ptx_result.stderr or ptx_result.stdout
+
+        if ptx_output.exists():
+            return ptx_output.read_text(), None
+        return None, "PTX output file not created"
+
+    def compile_sass(tmpdir: str, base_cmd: list[str], main_cu_path: Path) -> tuple[str | None, str | None]:
+        """Compile to SASS (via cubin). Returns (sass_content, error_message)."""
+        cubin_output = Path(tmpdir) / "output.cubin"
+        cubin_cmd = base_cmd + [
+            "--cubin",
+            "-o",
+            str(cubin_output),
+            str(main_cu_path),
+        ]
+
+        cubin_result = subprocess.run(
+            cubin_cmd,
+            capture_output=True,
+            text=True,
+            timeout=60,
+            cwd=tmpdir,
+            env=nvcc_env,
+        )
+
+        if cubin_result.returncode != 0:
+            return None, cubin_result.stderr or cubin_result.stdout
+
+        if not cubin_output.exists():
+            return None, "cubin output file not created"
+
+        # Disassemble cubin to SASS
+        sass_result = subprocess.run(
+            ["cuobjdump", "--dump-sass", str(cubin_output)],
+            capture_output=True,
+            text=True,
+            timeout=30,
+            cwd=tmpdir,
+        )
+
+        if sass_result.returncode == 0:
+            return sass_result.stdout, None
+        return None, f"SASS disassembly failed: {sass_result.stderr}"
+
+    try:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmp_path = Path(tmpdir)
+
+            # Write all files to temp directory, preserving subdirectory structure
+            for filename, content in files.items():
+                file_path = (tmp_path / filename).resolve()
+                if not file_path.is_relative_to(tmp_path):
+                    return {
+                        "success": False,
+                        "ptx": None,
+                        "sass": None,
+                        "stderr": f"Invalid filename: {filename}",
+                        "compilation_time_ms": int((time.time() - start_time) * 1000),
+                    }
+                file_path.parent.mkdir(parents=True, exist_ok=True)
+                file_path.write_text(content)
+
+            # Determine the directory of the main .cu file for include paths
+            main_cu_path = tmp_path / main_cu_file
+            include_dir = main_cu_path.parent
+
+            # Build base nvcc command with common flags
+            base_cmd = [
+                "nvcc",
+                "-arch",
+                arch,
+                f"-I{include_dir}",
+                "-I/usr/local/lib/python3.12/site-packages/torch/include",
+                "-I/usr/local/lib/python3.12/site-packages/torch/include/torch/csrc/api/include",
+                "-I/usr/local/cutlass/include",
+            ]
+            base_cmd.extend(flags)
+
+            # Determine what to compile
+            want_ptx = OutputFormat.PTX.value in output_formats
+            want_sass = OutputFormat.SASS.value in output_formats
+
+            results: dict[str, str | None] = {"ptx": None, "sass": None}
+            errors: list[str] = []
+
+            # Run compilations in parallel if both are requested
+            if want_ptx and want_sass:
+                with ThreadPoolExecutor(max_workers=2) as executor:
+                    futures = {
+                        executor.submit(compile_ptx, tmpdir, base_cmd, main_cu_path): "ptx",
+                        executor.submit(compile_sass, tmpdir, base_cmd, main_cu_path): "sass",
+                    }
+
+                    for future in as_completed(futures):
+                        output_type = futures[future]
+                        try:
+                            content, error = future.result()
+                            if content:
+                                results[output_type] = content
+                            if error:
+                                errors.append(f"{output_type.upper()}: {error}")
+                        except Exception as e:
+                            errors.append(f"{output_type.upper()} compilation error: {e}")
+
+            elif want_ptx:
+                content, error = compile_ptx(tmpdir, base_cmd, main_cu_path)
+                if content:
+                    results["ptx"] = content
+                if error:
+                    errors.append(error)
+
+            elif want_sass:
+                content, error = compile_sass(tmpdir, base_cmd, main_cu_path)
+                if content:
+                    results["sass"] = content
+                if error:
+                    errors.append(error)
+
+            # Check results
+            if not results["ptx"] and not results["sass"]:
+                return {
+                    "success": False,
+                    "ptx": None,
+                    "sass": None,
+                    "stderr": "\n".join(errors) if errors else "No output generated",
+                    "compilation_time_ms": int((time.time() - start_time) * 1000),
+                }
+
+            # Partial success if we got at least one output
+            stderr = ""
+            if errors and (results["ptx"] or results["sass"]):
+                stderr = "\n".join(errors)
+
+            return {
+                "success": True,
+                "ptx": results["ptx"],
+                "sass": results["sass"],
+                "stderr": stderr,
+                "compilation_time_ms": int((time.time() - start_time) * 1000),
+            }
+
+    except subprocess.TimeoutExpired:
+        return {
+            "success": False,
+            "ptx": None,
+            "sass": None,
+            "stderr": "Compilation timed out",
+            "compilation_time_ms": int((time.time() - start_time) * 1000),
+        }
+    except Exception as e:
+        import traceback
+
+        return {
+            "success": False,
+            "ptx": None,
+            "sass": None,
+            "stderr": f"Internal error: {e}\n{traceback.format_exc()}",
+            "compilation_time_ms": int((time.time() - start_time) * 1000),
+        }
+
+
+# ══════════════════════════════════════════════════════════════════════════════
+# Health Check
+# ══════════════════════════════════════════════════════════════════════════════
+
+
+@app.function(cpu=1, memory=1024, timeout=30)
+def health_check() -> dict:
+    """Verify the compilation environment is working.
+
+    Returns:
+        Dict with status and version info
+    """
+    import subprocess
+
+    try:
+        # Get nvcc version
+        nvcc_result = subprocess.run(
+            ["nvcc", "--version"],
+            capture_output=True,
+            text=True,
+            timeout=10,
+        )
+
+        # Parse version from output
+        version_line = nvcc_result.stdout.strip().split("\n")[-1]
+
+        return {
+            "status": "ok",
+            "nvcc_version": version_line,
+            "nvcc_available": nvcc_result.returncode == 0,
+        }
+    except Exception as e:
+        return {
+            "status": "error",
+            "error": str(e),
+            "nvcc_available": False,
+        }
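
For callers that bypass compile_cuda_remote, the docstring's usage pattern amounts to the following round trip, mirroring the lookup compiler.py performs internally. This sketch assumes Modal credentials are already configured (e.g., via `modal token new`) and the cuda-compile app is deployed; the request dict shape follows the compile_cuda docstring above.

import modal

# Look up the deployed function by app and function name, as compiler.py does.
compile_fn = modal.Function.from_name("cuda-compile", "compile_cuda")
result = compile_fn.remote(
    {
        "files": {"kernel.cu": "__global__ void k() {}"},  # hypothetical input
        "arch": "sm_90a",
        "flags": [],
        "output": ["ptx", "sass"],
    }
)
print(result["success"], result["compilation_time_ms"])

# The companion health_check function verifies the container's toolchain.
health_fn = modal.Function.from_name("cuda-compile", "health_check")
print(health_fn.remote())
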
wafer_core/tools/compile/tests/__init__.py
@@ -0,0 +1 @@
+"""Tests for the cloud CUDA compiler."""
|