cuda-engine 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuda_engine/__init__.py +24 -0
- cuda_engine/api.py +39 -0
- cuda_engine/cli.py +485 -0
- cuda_engine/config.py +32 -0
- cuda_engine/models/__init__.py +27 -0
- cuda_engine/models/artifact.py +12 -0
- cuda_engine/models/reports.py +106 -0
- cuda_engine/models/spec.py +45 -0
- cuda_engine/orchestrator.py +352 -0
- cuda_engine/prompts/__init__.py +8 -0
- cuda_engine/prompts/codegen.md +29 -0
- cuda_engine/prompts/interview.md +30 -0
- cuda_engine/prompts/perf_fix.md +56 -0
- cuda_engine/prompts/polish.md +13 -0
- cuda_engine/services/__init__.py +1 -0
- cuda_engine/services/gpu/__init__.py +3 -0
- cuda_engine/services/gpu/_run_kernel_child.py +305 -0
- cuda_engine/services/gpu/base.py +88 -0
- cuda_engine/services/gpu/local.py +451 -0
- cuda_engine/services/gpu/mocks.py +85 -0
- cuda_engine/services/llm/__init__.py +3 -0
- cuda_engine/services/llm/anthropic.py +71 -0
- cuda_engine/services/llm/base.py +35 -0
- cuda_engine/services/llm/mocks.py +38 -0
- cuda_engine/services/llm/tools.py +64 -0
- cuda_engine/services/store/__init__.py +3 -0
- cuda_engine/services/store/base.py +24 -0
- cuda_engine/services/store/local_dir.py +42 -0
- cuda_engine/services/store/mocks.py +27 -0
- cuda_engine/stages/__init__.py +1 -0
- cuda_engine/stages/base.py +41 -0
- cuda_engine/stages/codegen.py +193 -0
- cuda_engine/stages/correctness.py +241 -0
- cuda_engine/stages/interview.py +117 -0
- cuda_engine/stages/performance.py +424 -0
- cuda_engine/stages/polish.py +152 -0
- cuda_engine/targets/__init__.py +7 -0
- cuda_engine/targets/sm_100.py +2 -0
- cuda_engine/targets/sm_80.py +18 -0
- cuda_engine/targets/sm_90.py +2 -0
- cuda_engine-1.0.0.dist-info/METADATA +266 -0
- cuda_engine-1.0.0.dist-info/RECORD +45 -0
- cuda_engine-1.0.0.dist-info/WHEEL +4 -0
- cuda_engine-1.0.0.dist-info/entry_points.txt +2 -0
- cuda_engine-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from typing import Any, Self
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Result of the correctness stage for one kernel candidate.
class CorrectnessReport(BaseModel):
    # Overall verdict; the orchestrator keeps repairing while this is False.
    passed: bool
    # Error magnitudes reported by the correctness stage
    # (presumably the worst case over all tested shapes -- TODO confirm).
    max_abs_err: float
    max_rel_err: float
    # Concrete shapes that were exercised.
    shapes_tested: list[tuple[int, ...]]
    # Free-form per-shape details and failing-input records; both default to
    # a fresh empty list per instance.
    shape_results: list[dict[str, Any]] = Field(default_factory=list)
    failing_inputs: list[dict[str, Any]] = Field(default_factory=list)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Result of the performance stage; every measurement is optional.
class PerformanceReport(BaseModel):
    # Speedup ratios; None when the corresponding measurement is absent.
    speedup_vs_reference: float | None = None
    speedup_vs_torch_compile: float | None = None
    # Achieved throughput figures.
    achieved_tflops: float | None = None
    achieved_gbps: float | None = None
    # Profiler-style metrics (presumably sourced from Nsight -- TODO confirm).
    occupancy: float | None = None
    regs_per_thread: int | None = None
    spill_bytes: int = 0
    # When True the orchestrator adds a "below perf target" warning to the
    # final report instead of failing the run.
    below_target: bool = False
    # Free-form annotations; each defaults to a fresh empty list.
    notes: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Accounting record for one traced stage invocation.
class StageTrace(BaseModel):
    # Label passed to the tracing wrapper, e.g. "interview" or "codegen".
    stage_name: str
    # The orchestrator fills this with the number of LLM responses seen
    # during the stage, with a floor of 1.
    attempts: int
    succeeded: bool
    # Comma-separated distinct model names, or "none" if no LLM call was made.
    model_used: str
    # Token counters summed over the stage's LLM responses.
    tokens_in: int = 0
    tokens_out: int = 0
    cache_read_tokens: int = 0
    # Summed LLM latency, or elapsed stage time when no latency was reported.
    latency_seconds: float = 0.0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Run-level summary aggregated from the per-stage traces.
class SynthesisReport(BaseModel):
    run_id: str
    spec_name: str
    # Stage names in execution order; repeated stages appear once per run.
    stages_executed: list[str]
    stage_traces: list[StageTrace] = Field(default_factory=list)
    # Totals summed over stage_traces.
    total_llm_tokens_in: int = 0
    total_llm_tokens_out: int = 0
    # NOTE(review): the orchestrator's report builder never sets this field,
    # so it stays at the 0.0 default there -- confirm whether cost accounting
    # happens elsewhere.
    total_cost_usd: float = 0.0
    wall_time_seconds: float = 0.0
    warnings: list[str] = Field(default_factory=list)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SynthesisResult(BaseModel):
    """Top-level return value of synthesize()."""

    # kernel_callable is typed as a plain `object`, hence arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    passed: bool
    run_id: str
    artifacts_dir: str
    report: SynthesisReport
    # Set only by `failed()`: numeric stage index and a human-readable reason.
    failed_stage: int | None = None
    failure_reason: str | None = None
    correctness: CorrectnessReport | None = None
    performance: PerformanceReport | None = None
    kernel_callable: object | None = None

    @classmethod
    def ok(
        cls,
        *,
        run_id: str,
        artifacts_dir: str,
        report: SynthesisReport,
        correctness: CorrectnessReport,
        performance: PerformanceReport,
        kernel_callable: object | None,
    ) -> Self:
        """Build a passing result; both stage reports are mandatory here."""
        success_fields: dict[str, Any] = {
            "passed": True,
            "run_id": run_id,
            "artifacts_dir": artifacts_dir,
            "report": report,
            "correctness": correctness,
            "performance": performance,
            "kernel_callable": kernel_callable,
        }
        return cls(**success_fields)

    @classmethod
    def failed(
        cls,
        *,
        stage: int,
        reason: str,
        run_id: str,
        artifacts_dir: str,
        report: SynthesisReport,
        correctness: CorrectnessReport | None = None,
    ) -> Self:
        """Build a failing result tagged with the failing stage and reason.

        `performance` and `kernel_callable` are left at their None defaults.
        """
        failure_fields: dict[str, Any] = {
            "passed": False,
            "failed_stage": stage,
            "failure_reason": reason,
            "run_id": run_id,
            "artifacts_dir": artifacts_dir,
            "report": report,
            "correctness": correctness,
        }
        return cls(**failure_fields)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import StrEnum
|
|
4
|
+
from typing import Literal, TypeAlias
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
7
|
+
|
|
8
|
+
DType: TypeAlias = Literal["fp32", "fp16", "bf16", "fp64", "int32", "int64", "uint8", "int8"]
|
|
9
|
+
TargetArch: TypeAlias = Literal["sm_80", "sm_90", "sm_100", "sm_120"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OptimizationPriority(StrEnum):
|
|
13
|
+
LATENCY = "latency"
|
|
14
|
+
THROUGHPUT = "throughput"
|
|
15
|
+
BALANCED = "balanced"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# One named tensor in a kernel signature; immutable after construction.
class TensorArg(BaseModel):
    model_config = ConfigDict(frozen=True)

    name: str
    dtype: DType
    # Dimension names rather than sizes; the prompts in this package use an
    # empty shape for scalar / 0-D tensors.
    shape: tuple[str, ...] = Field(description="Symbolic shape, e.g. ('B', 'S', 'D')")
    # "any" leaves the memory layout unconstrained.
    layout_hint: Literal["row_major", "col_major", "any"] = "any"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Relative / absolute error bounds for comparing kernel output with the
# reference; immutable after construction.
class PrecisionTolerance(BaseModel):
    model_config = ConfigDict(frozen=True)

    # NOTE(review): no lower bound is enforced, so negative values validate.
    rtol: float = 1e-3
    atol: float = 1e-3
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class KernelSpec(BaseModel):
    """Frozen Stage 1 contract; downstream stages must not mutate it."""

    # frozen=True rejects attribute reassignment.
    # NOTE(review): `inputs` / `outputs` are plain lists, so in-place list
    # mutation is presumably still possible -- confirm no stage relies on it.
    model_config = ConfigDict(frozen=True)

    name: str
    target_arch: TargetArch
    # Argument order is part of the contract: the codegen prompt requires the
    # op signature to match these lists.
    inputs: list[TensorArg]
    outputs: list[TensorArg]
    precision_tolerance: PrecisionTolerance
    optimization_priority: OptimizationPriority
    # Free-form clarifications produced alongside the spec.
    notes: str = ""
|
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import time
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, TypeVar
|
|
7
|
+
|
|
8
|
+
from cuda_engine.config import SynthesisConfig
|
|
9
|
+
from cuda_engine.models import (
|
|
10
|
+
CorrectnessReport,
|
|
11
|
+
KernelArtifact,
|
|
12
|
+
StageTrace,
|
|
13
|
+
SynthesisReport,
|
|
14
|
+
SynthesisResult,
|
|
15
|
+
)
|
|
16
|
+
from cuda_engine.services.gpu.base import GPURunner
|
|
17
|
+
from cuda_engine.services.llm.base import LLMClient, LLMResponse, ToolSpec
|
|
18
|
+
from cuda_engine.services.store.base import ArtifactStore
|
|
19
|
+
from cuda_engine.stages.base import BudgetExhaustedError
|
|
20
|
+
from cuda_engine.stages.codegen import Stage2Codegen
|
|
21
|
+
from cuda_engine.stages.correctness import Stage3Correctness
|
|
22
|
+
from cuda_engine.stages.interview import Stage1Interview
|
|
23
|
+
from cuda_engine.stages.performance import Stage4Performance
|
|
24
|
+
from cuda_engine.stages.polish import Stage5Polish
|
|
25
|
+
|
|
26
|
+
# Generic result type: lets _run_traced_stage return whatever its action returns.
T = TypeVar("T")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Orchestrator:
    """Drive one synthesis run through its stages.

    Stage order: interview, codegen, correctness (with codegen repair
    rounds), performance, polish. Every stage call goes through
    `_run_traced_stage`, so a `StageTrace` is recorded per invocation.
    """

    def __init__(
        self,
        *,
        llm: LLMClient,
        gpu: GPURunner,
        store: ArtifactStore,
        cfg: SynthesisConfig,
    ) -> None:
        """Store the injected collaborators; no work happens here."""
        self.llm = llm
        self.gpu = gpu
        self.store = store
        self.cfg = cfg

    def run(self, *, prompt: str, reference: Callable[..., Any], target: str) -> SynthesisResult:
        """Execute the full pipeline and persist `report.json`.

        Args:
            prompt: User request; stored as `inputs/prompt.txt`.
            reference: Reference callable the kernel is checked against.
            target: Target architecture string handed to the interview stage.

        Returns:
            A passing `SynthesisResult`, or a failed one tagged `stage=3`
            when correctness still fails after the repair budget is spent.

        Raises:
            Any exception raised by a stage propagates unchanged (after its
            failed trace has been recorded); no report is written then.
        """
        run_id = self.store.new_run()
        # Monotonic clock: the reported duration must not jump or go negative
        # when the system clock is adjusted during a long run.
        started_at = time.monotonic()
        llm = _TracingLLMClient(self.llm)
        stage_traces: list[StageTrace] = []

        # Persist the raw inputs before any stage runs.
        self.store.write_text(run_id, "inputs/prompt.txt", prompt)
        self.store.write_json(run_id, "inputs/config.json", self.cfg)
        self.store.write_text(run_id, "inputs/reference.py", _reference_source(reference))

        spec = _run_traced_stage(
            stage_traces,
            llm,
            "interview",
            lambda: Stage1Interview(llm=llm, store=self.store).run(
                prompt=prompt,
                reference=reference,
                target_arch=target,
                run_id=run_id,
                model=self.cfg.sonnet_model,
            ),
        )
        artifact = _run_traced_stage(
            stage_traces,
            llm,
            "codegen",
            lambda: self._codegen(llm, spec=spec, run_id=run_id),
        )
        correctness = self._check_correctness(
            llm, stage_traces, spec=spec, artifact=artifact, reference=reference, run_id=run_id
        )

        # Repair rounds: regenerate with the failing report as context, then
        # re-check, until the check passes or the budget is used up.
        for repair_attempt in range(1, self.cfg.retry_budgets.correctness + 1):
            if correctness.passed:
                break
            repair_dir = f"stage3_repair/attempt_{repair_attempt:02d}"
            self.store.write_json(
                run_id,
                f"{repair_dir}/correctness_report.json",
                correctness.model_dump(mode="json"),
            )
            artifact = self._repair_codegen(
                llm,
                stage_traces,
                spec=spec,
                run_id=run_id,
                correctness=correctness,
                repair_dir=repair_dir,
            )
            correctness = self._check_correctness(
                llm, stage_traces, spec=spec, artifact=artifact, reference=reference, run_id=run_id
            )

        if not correctness.passed:
            result = SynthesisResult.failed(
                stage=3,
                reason="correctness check failed",
                run_id=run_id,
                artifacts_dir=str(self.store.run_dir(run_id)),
                report=_build_report(
                    run_id=run_id,
                    spec_name=spec.name,
                    stage_traces=stage_traces,
                    wall_time_seconds=time.monotonic() - started_at,
                ),
                correctness=correctness,
            )
            _write_result_report(self.store, result)
            return result

        # The performance stage may hand back a revised artifact.
        performance, artifact = _run_traced_stage(
            stage_traces,
            llm,
            "performance",
            lambda: Stage4Performance(llm=llm, gpu=self.gpu, store=self.store, cfg=self.cfg).run(
                spec=spec,
                artifact=artifact,
                run_id=run_id,
                retry_budget=self.cfg.retry_budgets.performance,
                reference=reference,
            ),
        )
        artifact = _run_traced_stage(
            stage_traces,
            llm,
            "polish",
            lambda: Stage5Polish(llm=llm, gpu=self.gpu, store=self.store).run(
                spec=spec,
                artifact=artifact,
                correctness=correctness,
                performance=performance,
                reference=reference,
                run_id=run_id,
                model=self.cfg.sonnet_model,
                correctness_shapes=self.cfg.correctness_shapes,
            ),
        )

        report = _build_report(
            run_id=run_id,
            spec_name=spec.name,
            stage_traces=stage_traces,
            wall_time_seconds=time.monotonic() - started_at,
            # Missing the perf target is a warning, not a failure.
            warnings=["below perf target"] if performance.below_target else [],
        )
        result = SynthesisResult.ok(
            run_id=run_id,
            artifacts_dir=str(self.store.run_dir(run_id)),
            report=report,
            correctness=correctness,
            performance=performance,
            kernel_callable=None,
        )
        _write_result_report(self.store, result)
        return result

    def _codegen(
        self, llm: _TracingLLMClient, *, spec: Any, run_id: str, **extra: Any
    ) -> KernelArtifact:
        """Run codegen (with escalation); `extra` is appended to the run args."""
        run_args: dict[str, Any] = {
            "spec": spec,
            "run_id": run_id,
            "retry_budget": self.cfg.retry_budgets.codegen,
            **extra,
        }
        return _run_codegen_with_escalation(
            llm=llm,
            gpu=self.gpu,
            store=self.store,
            cfg=self.cfg,
            run_args=run_args,
        )

    def _repair_codegen(
        self,
        llm: _TracingLLMClient,
        stage_traces: list[StageTrace],
        *,
        spec: Any,
        run_id: str,
        correctness: CorrectnessReport,
        repair_dir: str,
    ) -> KernelArtifact:
        """Run one traced "codegen_repair" round driven by a failing report."""
        return _run_traced_stage(
            stage_traces,
            llm,
            "codegen_repair",
            lambda: self._codegen(
                llm,
                spec=spec,
                run_id=run_id,
                repair_context=correctness,
                artifact_prefix=f"{repair_dir}/codegen",
            ),
        )

    def _check_correctness(
        self,
        llm: _TracingLLMClient,
        stage_traces: list[StageTrace],
        *,
        spec: Any,
        artifact: KernelArtifact,
        reference: Callable[..., Any],
        run_id: str,
    ) -> CorrectnessReport:
        """Run one traced "correctness" check; the trace mirrors `report.passed`."""
        return _run_traced_stage(
            stage_traces,
            llm,
            "correctness",
            lambda: Stage3Correctness(llm=llm, gpu=self.gpu, store=self.store).run(
                spec=spec,
                artifact=artifact,
                reference=reference,
                run_id=run_id,
                retry_budget=self.cfg.retry_budgets.correctness,
                correctness_shapes=self.cfg.correctness_shapes,
            ),
            succeeded=lambda report: report.passed,
        )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _run_codegen_with_escalation(
    *,
    llm: _TracingLLMClient,
    gpu: GPURunner,
    store: ArtifactStore,
    cfg: SynthesisConfig,
    run_args: dict[str, Any],
) -> KernelArtifact:
    """Run codegen on the Sonnet model, retrying once on Opus if the budget busts.

    Args:
        llm, gpu, store: Collaborators handed to each `Stage2Codegen`.
        cfg: Supplies the model names and the escalation switches.
        run_args: Keyword arguments forwarded to `Stage2Codegen.run`.

    Returns:
        The artifact produced by whichever attempt succeeded.

    Raises:
        BudgetExhaustedError: Re-raised when escalation is disabled or the
            Opus retry budget is not positive.
    """
    try:
        first_try = Stage2Codegen(llm=llm, gpu=gpu, store=store)
        return first_try.run(**run_args, model=cfg.sonnet_model)
    except BudgetExhaustedError as bust:
        may_escalate = cfg.escalate_to_opus_on_bust and cfg.opus_retry_budget_codegen > 0
        if not may_escalate:
            raise
        # Escalated artifacts go under ".../escalated"; the summary of the
        # exhausted attempt is passed along as context.
        escalated_args = run_args | {
            "retry_budget": cfg.opus_retry_budget_codegen,
            "artifact_prefix": f"{run_args.get('artifact_prefix', 'stage2_codegen')}/escalated",
            "escalation_context": bust.summary,
        }
        second_try = Stage2Codegen(llm=llm, gpu=gpu, store=store)
        return second_try.run(**escalated_args, model=cfg.opus_model)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _reference_source(reference: Callable[..., Any]) -> str:
|
|
238
|
+
try:
|
|
239
|
+
return inspect.getsource(reference)
|
|
240
|
+
except OSError:
|
|
241
|
+
return repr(reference)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class _TracingLLMClient(LLMClient):
    """Pass-through LLM client that keeps every response it returns.

    `responses` grows in call order; a call that raises adds nothing.
    """

    def __init__(self, inner: LLMClient) -> None:
        """Wrap `inner` and start with an empty response log."""
        self._inner = inner
        self.responses: list[LLMResponse] = []

    def complete(
        self,
        *,
        system: list[dict[str, Any]],
        messages: list[dict[str, Any]],
        tools: list[ToolSpec] | None = None,
        model: str,
        max_tokens: int = 4096,
        temperature: float | None = None,
    ) -> LLMResponse:
        """Forward the call unchanged, log the reply, and return it."""
        reply = self._inner.complete(
            system=system,
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        self.responses.append(reply)
        return reply
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _run_traced_stage(
    stage_traces: list[StageTrace],
    llm: _TracingLLMClient,
    stage_name: str,
    action: Callable[[], T],
    *,
    succeeded: Callable[[T], bool] | None = None,
) -> T:
    """Run `action` and append a trace of its LLM usage to `stage_traces`.

    Args:
        stage_traces: List that receives exactly one new trace per call.
        llm: Tracing client; responses logged during `action` belong to
            this stage.
        stage_name: Label stored on the trace.
        action: Zero-argument callable performing the stage.
        succeeded: Optional predicate on the result; without it a normal
            return counts as success.

    Returns:
        Whatever `action` returned.

    Raises:
        Any exception from `action`, after a failed trace was recorded.
    """
    first_new = len(llm.responses)
    began = time.time()
    try:
        outcome = action()
    except Exception:
        stage_traces.append(
            _build_stage_trace(stage_name, llm.responses[first_new:], began, succeeded=False)
        )
        raise

    stage_responses = llm.responses[first_new:]
    verdict = True if succeeded is None else succeeded(outcome)
    stage_traces.append(
        _build_stage_trace(stage_name, stage_responses, began, succeeded=verdict)
    )
    return outcome
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _build_stage_trace(
    stage_name: str,
    responses: list[LLMResponse],
    started_at: float,
    *,
    succeeded: bool,
) -> StageTrace:
    """Condense the LLM responses of one stage into a `StageTrace`.

    Args:
        stage_name: Label for the trace.
        responses: LLM responses collected while the stage ran.
        started_at: `time.time()` value taken when the stage began.
        succeeded: Outcome flag to record.
    """
    llm_latency = sum(item.latency_seconds for item in responses)
    if llm_latency > 0:
        latency = llm_latency
    else:
        # No latency reported (e.g. no LLM calls): use elapsed stage time.
        latency = time.time() - started_at

    return StageTrace(
        stage_name=stage_name,
        # A stage without LLM calls still counts as one attempt.
        attempts=len(responses) or 1,
        succeeded=succeeded,
        model_used=_model_summary(responses),
        tokens_in=sum(item.tokens_in for item in responses),
        tokens_out=sum(item.tokens_out for item in responses),
        cache_read_tokens=sum(item.cache_read_tokens for item in responses),
        latency_seconds=latency,
    )
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _model_summary(responses: list[LLMResponse]) -> str:
|
|
321
|
+
if not responses:
|
|
322
|
+
return "none"
|
|
323
|
+
models: list[str] = []
|
|
324
|
+
for response in responses:
|
|
325
|
+
if response.model not in models:
|
|
326
|
+
models.append(response.model)
|
|
327
|
+
return ", ".join(models)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _build_report(
    *,
    run_id: str,
    spec_name: str,
    stage_traces: list[StageTrace],
    wall_time_seconds: float,
    warnings: list[str] | None = None,
) -> SynthesisReport:
    """Aggregate the stage traces into a run-level `SynthesisReport`.

    Token totals are summed over `stage_traces`; `warnings` defaults to an
    empty list.
    """
    executed = [entry.stage_name for entry in stage_traces]
    tokens_in = sum(entry.tokens_in for entry in stage_traces)
    tokens_out = sum(entry.tokens_out for entry in stage_traces)
    return SynthesisReport(
        run_id=run_id,
        spec_name=spec_name,
        stages_executed=executed,
        stage_traces=stage_traces,
        total_llm_tokens_in=tokens_in,
        total_llm_tokens_out=tokens_out,
        wall_time_seconds=wall_time_seconds,
        warnings=warnings if warnings else [],
    )
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _write_result_report(store: ArtifactStore, result: SynthesisResult) -> None:
|
|
351
|
+
payload = result.model_dump(mode="json", exclude={"kernel_callable"})
|
|
352
|
+
store.write_json(result.run_id, "report.json", payload)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from importlib.resources import files
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def load_prompt(name: str) -> str:
    """Return the text of the packaged prompt `<name>.md`.

    Raises:
        FileNotFoundError: If this package has no such resource file.
    """
    resource = files(__package__).joinpath(f"{name}.md")
    if resource.is_file():
        return resource.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Prompt not found: {name}")
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# CUDA Codegen Stage
|
|
2
|
+
|
|
3
|
+
You generate a single CUDA `.cu` file for the frozen `KernelSpec`.
|
|
4
|
+
|
|
5
|
+
Required runnable ABI:
|
|
6
|
+
- The generated source must be a Torch-loadable C++/CUDA extension, not a raw CUDA-only library.
|
|
7
|
+
- Include the needed Torch headers, normally `#include <torch/extension.h>` and `#include <ATen/cuda/CUDAContext.h>`.
|
|
8
|
+
- Expose exactly one user-callable op: `cuda_engine::forward`.
|
|
9
|
+
- Register the schema with `TORCH_LIBRARY(cuda_engine, m)`.
|
|
10
|
+
- Register the CUDA implementation with `TORCH_LIBRARY_IMPL(cuda_engine, CUDA, m)`.
|
|
11
|
+
- The Python runner will call `torch.ops.cuda_engine.forward(*inputs)`, so the op signature must match the `KernelSpec` inputs and outputs.
|
|
12
|
+
- Return a single `torch::Tensor` for one output, or a tuple/list-compatible Torch return type for multiple outputs.
|
|
13
|
+
|
|
14
|
+
Rules:
|
|
15
|
+
- Honor the target architecture and the frozen input/output contract.
|
|
16
|
+
- Treat KernelSpec inputs with `shape: []` as scalar/0-D Torch tensors, not vectors.
|
|
17
|
+
- For reduction outputs, return tensors with the exact reduced shape in the KernelSpec.
|
|
18
|
+
Example: input `["B", "D"]` and output `["B"]` means one output element per row.
|
|
19
|
+
- For argmax kernels, return `int64` indices when the KernelSpec output dtype is `int64`.
|
|
20
|
+
- For RMSNorm fp16 kernels, use fp32 accumulation for the mean square and reciprocal square root,
|
|
21
|
+
do not add gamma unless the KernelSpec includes it, and cast the final output to fp16.
|
|
22
|
+
- For `sm_80`, prefer straightforward CUDA C++ suitable for A100.
|
|
23
|
+
- Make memory hierarchy choices explicit in comments when they affect performance.
|
|
24
|
+
- Use 256 threads per block as the default elementwise baseline unless the spec suggests otherwise.
|
|
25
|
+
- Output complete CUDA source as one fenced `cuda` code block.
|
|
26
|
+
- After generating the source, call `compile_kernel(src, target_arch)` using the exact source.
|
|
27
|
+
- If compilation fails, use the compiler errors to revise the source and call `compile_kernel` again.
|
|
28
|
+
|
|
29
|
+
Do not change dtypes, shapes, argument ordering, or precision tolerance.
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# CUDA Kernel Interview Stage
|
|
2
|
+
|
|
3
|
+
You convert a user prompt plus Python reference metadata into a frozen `KernelSpec`.
|
|
4
|
+
|
|
5
|
+
Return only structured JSON, preferably in a fenced `json` code block. The JSON must match:
|
|
6
|
+
|
|
7
|
+
```json
|
|
8
|
+
{
|
|
9
|
+
"name": "snake_case_kernel_name",
|
|
10
|
+
"target_arch": "sm_80",
|
|
11
|
+
"inputs": [{"name": "x", "dtype": "fp32", "shape": ["N"], "layout_hint": "any"}],
|
|
12
|
+
"outputs": [{"name": "out", "dtype": "fp32", "shape": ["N"], "layout_hint": "any"}],
|
|
13
|
+
"precision_tolerance": {"rtol": 0.001, "atol": 0.001},
|
|
14
|
+
"optimization_priority": "balanced",
|
|
15
|
+
"notes": "brief clarification notes"
|
|
16
|
+
}
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
Rules:
|
|
20
|
+
- Do not invent unsupported target architectures.
|
|
21
|
+
- Use symbolic shapes when concrete shapes are unknown.
|
|
22
|
+
- Represent scalar or 0-D tensor inputs with an empty shape list: `"shape": []`.
|
|
23
|
+
- For reductions, shrink only the reduced dimensions. Example: last-dimension sum of `x[B,D]`
|
|
24
|
+
should use input shape `["B", "D"]` and output shape `["B"]`.
|
|
25
|
+
- For `argmax`, use an integer output dtype, normally `int64`, with the reduced output shape.
|
|
26
|
+
- For fp16 RMSNorm without gamma, use fp16 input/output shapes that match, fp32 accumulation
|
|
27
|
+
semantics in `notes`, and a practical fp16 tolerance such as `rtol=0.01`, `atol=0.01`.
|
|
28
|
+
- Preserve the user's requested operation; do not broaden scope.
|
|
29
|
+
- Prefer `throughput` for large elementwise/reduction prompts and `latency` only when the prompt explicitly prioritizes small inputs.
|
|
30
|
+
- Use the reference metadata only to infer names and arity; if uncertain, choose conservative defaults and explain in `notes`.
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# CUDA Performance Repair
|
|
2
|
+
|
|
3
|
+
You revise a CUDA kernel that compiles and is correct but runs below the
|
|
4
|
+
performance target. Your job is to improve throughput without breaking
|
|
5
|
+
correctness, then call `compile_kernel(src, target_arch)` with the revised
|
|
6
|
+
source.
|
|
7
|
+
|
|
8
|
+
Required runnable ABI (unchanged from the previous kernel):
|
|
9
|
+
- Keep `cuda_engine::forward` as the only user-callable op.
|
|
10
|
+
- Keep the same `TORCH_LIBRARY(cuda_engine, m)` namespace, op signature,
|
|
11
|
+
argument order, dtypes, shapes, and return type.
|
|
12
|
+
- Keep correctness: outputs must remain within the KernelSpec precision
|
|
13
|
+
tolerance compared to the reference.
|
|
14
|
+
|
|
15
|
+
Inputs you will receive:
|
|
16
|
+
- The current `kernel.cu` source.
|
|
17
|
+
- The frozen `KernelSpec`.
|
|
18
|
+
- The latest `BenchmarkResult` (`custom_ms`, `baseline_ms`, achieved GB/s).
|
|
19
|
+
- A `NsightMetrics` snapshot (achieved occupancy, registers per thread,
|
|
20
|
+
spill bytes when available).
|
|
21
|
+
- Suggested optimization hints derived from those metrics.
|
|
22
|
+
|
|
23
|
+
Optimization themes to consider:
|
|
24
|
+
- **Register pressure**: high regs/thread reduces occupancy on A100
|
|
25
|
+
(about 32 regs/thread for full occupancy; at 64 regs/thread only half the warps fit). Split
|
|
26
|
+
work into more, smaller blocks; reduce live registers; only spill to
|
|
27
|
+
shared memory when necessary.
|
|
28
|
+
- **Occupancy**: low achieved occupancy means few warps are resident.
|
|
29
|
+
Investigate register, shared memory, or block-size limits.
|
|
30
|
+
- **Memory coalescing**: ensure 32 consecutive threads in a warp read
|
|
31
|
+
128 consecutive bytes. Avoid strided global loads/stores; use
|
|
32
|
+
`__ldg` for read-only cached loads where appropriate.
|
|
33
|
+
- **Grid wave alignment**: A100 has 108 SMs. Choose grid sizes that
|
|
34
|
+
fill full waves; a partial-wave tail can waste up to 20% of runtime.
|
|
35
|
+
- **Shared-memory tiling**: for reductions, use 256-thread blocks with
|
|
36
|
+
`__shfl_down_sync` for warp-level reduction; store partial results
|
|
37
|
+
to shared memory only when the reduction crosses warp boundaries.
|
|
38
|
+
- **Vectorized loads**: `float4`/`__half2` loads can double effective
|
|
39
|
+
bandwidth for elementwise ops on aligned, contiguous data.
|
|
40
|
+
- **Simple fused elementwise kernels**: for one-pass pointwise or fused
|
|
41
|
+
pointwise work, prefer one coalesced read/compute/write pass with enough
|
|
42
|
+
blocks to cover the tensor. Do not add multi-pass reductions, shared-memory
|
|
43
|
+
staging, or complicated synchronization unless the KernelSpec actually
|
|
44
|
+
requires cross-element communication.
|
|
45
|
+
|
|
46
|
+
Matching torch.compile is acceptable but not the goal. To strictly beat it on A100:
|
|
47
|
+
- Prefer vectorized memory ops: `float4` for fp32, `__half2` for fp16. They double effective bandwidth on aligned contiguous data.
|
|
48
|
+
- Align grid to A100's 108 SMs. A full wave is a multiple of 108 blocks; a partial tail wave wastes runtime. For tensors that don't divide evenly, prefer fewer, larger blocks over more, smaller ones.
|
|
49
|
+
- Maximize instruction-level parallelism: `#pragma unroll` inner loops with small bounded trip count. Keep enough independent work per thread to hide arithmetic and memory latency.
|
|
50
|
+
- Fuse passes when the KernelSpec permits. Reductions followed by elementwise can often be one-pass with `__shfl_down_sync` warp reductions.
|
|
51
|
+
- Inspect register pressure first if Nsight shows occupancy < 50%. If regs/thread > 64 on a 256-thread block, work-split or block-size reduction frees waves.
|
|
52
|
+
|
|
53
|
+
Output the complete revised CUDA source as one fenced `cuda` code block,
|
|
54
|
+
then call `compile_kernel(src, target_arch)` with the exact source.
|
|
55
|
+
|
|
56
|
+
Do not change dtypes, shapes, argument ordering, or precision tolerance.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# CUDA Kernel Polish Stage
|
|
2
|
+
|
|
3
|
+
You annotate an already-correct CUDA kernel for maintainability.
|
|
4
|
+
|
|
5
|
+
Return only the complete annotated CUDA source in a fenced `cuda` code block.
|
|
6
|
+
|
|
7
|
+
Annotations should explain:
|
|
8
|
+
- tile size and launch configuration choices
|
|
9
|
+
- memory layout and coalescing assumptions
|
|
10
|
+
- precision tolerance and correctness summary
|
|
11
|
+
- performance summary, including speedups and any occupancy/register notes when available
|
|
12
|
+
|
|
13
|
+
Do not change behavior, signatures, namespace registration, or the `cuda_engine::forward` ABI.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Service interfaces and implementations."""
|