cuda-engine 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. cuda_engine/__init__.py +24 -0
  2. cuda_engine/api.py +39 -0
  3. cuda_engine/cli.py +485 -0
  4. cuda_engine/config.py +32 -0
  5. cuda_engine/models/__init__.py +27 -0
  6. cuda_engine/models/artifact.py +12 -0
  7. cuda_engine/models/reports.py +106 -0
  8. cuda_engine/models/spec.py +45 -0
  9. cuda_engine/orchestrator.py +352 -0
  10. cuda_engine/prompts/__init__.py +8 -0
  11. cuda_engine/prompts/codegen.md +29 -0
  12. cuda_engine/prompts/interview.md +30 -0
  13. cuda_engine/prompts/perf_fix.md +56 -0
  14. cuda_engine/prompts/polish.md +13 -0
  15. cuda_engine/services/__init__.py +1 -0
  16. cuda_engine/services/gpu/__init__.py +3 -0
  17. cuda_engine/services/gpu/_run_kernel_child.py +305 -0
  18. cuda_engine/services/gpu/base.py +88 -0
  19. cuda_engine/services/gpu/local.py +451 -0
  20. cuda_engine/services/gpu/mocks.py +85 -0
  21. cuda_engine/services/llm/__init__.py +3 -0
  22. cuda_engine/services/llm/anthropic.py +71 -0
  23. cuda_engine/services/llm/base.py +35 -0
  24. cuda_engine/services/llm/mocks.py +38 -0
  25. cuda_engine/services/llm/tools.py +64 -0
  26. cuda_engine/services/store/__init__.py +3 -0
  27. cuda_engine/services/store/base.py +24 -0
  28. cuda_engine/services/store/local_dir.py +42 -0
  29. cuda_engine/services/store/mocks.py +27 -0
  30. cuda_engine/stages/__init__.py +1 -0
  31. cuda_engine/stages/base.py +41 -0
  32. cuda_engine/stages/codegen.py +193 -0
  33. cuda_engine/stages/correctness.py +241 -0
  34. cuda_engine/stages/interview.py +117 -0
  35. cuda_engine/stages/performance.py +424 -0
  36. cuda_engine/stages/polish.py +152 -0
  37. cuda_engine/targets/__init__.py +7 -0
  38. cuda_engine/targets/sm_100.py +2 -0
  39. cuda_engine/targets/sm_80.py +18 -0
  40. cuda_engine/targets/sm_90.py +2 -0
  41. cuda_engine-1.0.0.dist-info/METADATA +266 -0
  42. cuda_engine-1.0.0.dist-info/RECORD +45 -0
  43. cuda_engine-1.0.0.dist-info/WHEEL +4 -0
  44. cuda_engine-1.0.0.dist-info/entry_points.txt +2 -0
  45. cuda_engine-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,106 @@
1
+ from typing import Any, Self
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field
4
+
5
+
6
class CorrectnessReport(BaseModel):
    """Result record for a correctness check of a generated kernel.

    The orchestrator reads ``passed`` to decide whether to attempt a repair
    and serializes the whole report with ``model_dump(mode="json")``.
    """

    # Overall verdict; the only field the orchestrator branches on.
    passed: bool
    # presumably the largest absolute / relative error observed — confirm
    # against the stage that produces this report.
    max_abs_err: float
    max_rel_err: float
    # One integer tuple per shape that was exercised.
    shapes_tested: list[tuple[int, ...]]
    # Free-form dictionaries; their key schema is not defined in this module.
    shape_results: list[dict[str, Any]] = Field(default_factory=list)
    failing_inputs: list[dict[str, Any]] = Field(default_factory=list)
13
+
14
+
15
class PerformanceReport(BaseModel):
    """Result record for a performance measurement of a generated kernel.

    Every metric is optional or has a default, so an empty report is valid.
    """

    # Metrics default to None, i.e. "not available".
    speedup_vs_reference: float | None = None
    speedup_vs_torch_compile: float | None = None
    achieved_tflops: float | None = None
    achieved_gbps: float | None = None
    occupancy: float | None = None
    regs_per_thread: int | None = None
    spill_bytes: int = 0
    # When True the orchestrator adds a "below perf target" warning to the
    # final synthesis report.
    below_target: bool = False
    notes: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
26
+
27
+
28
class StageTrace(BaseModel):
    """Per-stage accounting record: outcome, model usage, tokens and latency."""

    # Name of the pipeline stage this trace describes.
    stage_name: str
    attempts: int
    succeeded: bool
    # Free-form string identifying the model(s) involved.
    model_used: str
    # Token and latency counters default to zero.
    tokens_in: int = 0
    tokens_out: int = 0
    cache_read_tokens: int = 0
    latency_seconds: float = 0.0
37
+
38
+
39
class SynthesisReport(BaseModel):
    """Run-level summary: executed stages, their traces, and aggregate totals."""

    run_id: str
    spec_name: str
    # Stage names in execution order.
    stages_executed: list[str]
    stage_traces: list[StageTrace] = Field(default_factory=list)
    total_llm_tokens_in: int = 0
    total_llm_tokens_out: int = 0
    # NOTE(review): the orchestrator's report builder does not appear to set
    # this field, so it would stay at the default — confirm.
    total_cost_usd: float = 0.0
    wall_time_seconds: float = 0.0
    warnings: list[str] = Field(default_factory=list)
49
+
50
+
51
class SynthesisResult(BaseModel):
    """Top-level return value of synthesize().

    Use the ``ok`` / ``failed`` factories rather than the raw constructor so
    that ``passed`` and the failure fields stay consistent with each other.
    """

    # kernel_callable is typed as a plain ``object``; allow arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    passed: bool
    run_id: str
    artifacts_dir: str
    report: SynthesisReport
    failed_stage: int | None = None
    failure_reason: str | None = None
    correctness: CorrectnessReport | None = None
    performance: PerformanceReport | None = None
    kernel_callable: object | None = None

    @classmethod
    def ok(
        cls,
        *,
        run_id: str,
        artifacts_dir: str,
        report: SynthesisReport,
        correctness: CorrectnessReport,
        performance: PerformanceReport,
        kernel_callable: object | None,
    ) -> Self:
        """Build a successful result (``passed=True``); failure fields stay ``None``."""
        success_fields = {
            "passed": True,
            "run_id": run_id,
            "artifacts_dir": artifacts_dir,
            "report": report,
            "correctness": correctness,
            "performance": performance,
            "kernel_callable": kernel_callable,
        }
        return cls(**success_fields)

    @classmethod
    def failed(
        cls,
        *,
        stage: int,
        reason: str,
        run_id: str,
        artifacts_dir: str,
        report: SynthesisReport,
        correctness: CorrectnessReport | None = None,
    ) -> Self:
        """Build a failed result (``passed=False``).

        ``stage`` and ``reason`` are stored as ``failed_stage`` and
        ``failure_reason``. ``performance`` and ``kernel_callable`` keep
        their ``None`` defaults.
        """
        failure_fields = {
            "passed": False,
            "failed_stage": stage,
            "failure_reason": reason,
            "run_id": run_id,
            "artifacts_dir": artifacts_dir,
            "report": report,
            "correctness": correctness,
        }
        return cls(**failure_fields)
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import StrEnum
4
+ from typing import Literal, TypeAlias
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+
8
# Closed set of dtype names accepted for a TensorArg.
DType: TypeAlias = Literal["fp32", "fp16", "bf16", "fp64", "int32", "int64", "uint8", "int8"]
# Closed set of architecture names accepted for KernelSpec.target_arch.
TargetArch: TypeAlias = Literal["sm_80", "sm_90", "sm_100", "sm_120"]
10
+
11
+
12
class OptimizationPriority(StrEnum):
    """Optimization goal recorded on a KernelSpec; members compare equal to their string values."""

    LATENCY = "latency"
    THROUGHPUT = "throughput"
    BALANCED = "balanced"
16
+
17
+
18
class TensorArg(BaseModel):
    """Immutable description of one tensor argument: name, dtype, symbolic shape, layout hint."""

    # frozen=True: attribute assignment after construction raises.
    model_config = ConfigDict(frozen=True)

    name: str
    dtype: DType
    # Dimension names rather than concrete sizes (see the Field description).
    shape: tuple[str, ...] = Field(description="Symbolic shape, e.g. ('B', 'S', 'D')")
    # "any" means no layout preference is expressed.
    layout_hint: Literal["row_major", "col_major", "any"] = "any"
25
+
26
+
27
class PrecisionTolerance(BaseModel):
    """Immutable pair of relative and absolute tolerances; both default to 1e-3."""

    model_config = ConfigDict(frozen=True)

    # presumably used in an allclose-style comparison — confirm in the
    # correctness stage.
    rtol: float = 1e-3
    atol: float = 1e-3
32
+
33
+
34
class KernelSpec(BaseModel):
    """Frozen Stage 1 contract; downstream stages must not mutate it."""

    # NOTE(review): frozen=True blocks attribute reassignment only. The
    # ``inputs`` / ``outputs`` lists below can still be mutated in place
    # (append, item assignment) — confirm no stage relies on doing that.
    model_config = ConfigDict(frozen=True)

    name: str
    target_arch: TargetArch
    # Tensor arguments; each entry is itself a frozen TensorArg.
    inputs: list[TensorArg]
    outputs: list[TensorArg]
    precision_tolerance: PrecisionTolerance
    optimization_priority: OptimizationPriority
    # Free-form text; empty by default.
    notes: str = ""
@@ -0,0 +1,352 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import time
5
+ from collections.abc import Callable
6
+ from typing import Any, TypeVar
7
+
8
+ from cuda_engine.config import SynthesisConfig
9
+ from cuda_engine.models import (
10
+ CorrectnessReport,
11
+ KernelArtifact,
12
+ StageTrace,
13
+ SynthesisReport,
14
+ SynthesisResult,
15
+ )
16
+ from cuda_engine.services.gpu.base import GPURunner
17
+ from cuda_engine.services.llm.base import LLMClient, LLMResponse, ToolSpec
18
+ from cuda_engine.services.store.base import ArtifactStore
19
+ from cuda_engine.stages.base import BudgetExhaustedError
20
+ from cuda_engine.stages.codegen import Stage2Codegen
21
+ from cuda_engine.stages.correctness import Stage3Correctness
22
+ from cuda_engine.stages.interview import Stage1Interview
23
+ from cuda_engine.stages.performance import Stage4Performance
24
+ from cuda_engine.stages.polish import Stage5Polish
25
+
26
+ T = TypeVar("T")
27
+
28
+
29
class Orchestrator:
    """Drive one kernel synthesis run through its five stages.

    Stages run in order: interview, codegen, correctness (with a
    codegen-repair loop), performance, polish. Each stage is executed via
    ``_run_traced_stage`` so that its LLM usage is recorded in a
    ``StageTrace``.
    """

    def __init__(
        self,
        *,
        llm: LLMClient,
        gpu: GPURunner,
        store: ArtifactStore,
        cfg: SynthesisConfig,
    ) -> None:
        """Store the collaborating services and configuration (no other work is done)."""
        self.llm = llm
        self.gpu = gpu
        self.store = store
        self.cfg = cfg

    def run(self, *, prompt: str, reference: Callable[..., Any], target: str) -> SynthesisResult:
        """Execute the full pipeline and return its result.

        Args:
            prompt: User request; written to ``inputs/prompt.txt`` and passed
                to the interview stage.
            reference: Reference implementation handed to the interview,
                correctness, performance and polish stages.
            target: Target architecture string passed to the interview stage
                as ``target_arch``.

        Returns:
            ``SynthesisResult.failed(stage=3, ...)`` when correctness still
            fails after the repair loop, otherwise ``SynthesisResult.ok``.
            In both cases the result is also written to ``report.json``.

        Raises:
            Any exception raised by a stage propagates unchanged; this method
            does not catch stage errors.
        """
        run_id = self.store.new_run()
        # Monotonic clock: wall_time_seconds must not be skewed (or turn
        # negative) if the system clock is adjusted mid-run.
        started_at = time.monotonic()
        # Wrap the client so every response is recorded for stage traces.
        llm = _TracingLLMClient(self.llm)
        stage_traces: list[StageTrace] = []
        self.store.write_text(run_id, "inputs/prompt.txt", prompt)
        self.store.write_json(run_id, "inputs/config.json", self.cfg)
        self.store.write_text(run_id, "inputs/reference.py", _reference_source(reference))

        spec = _run_traced_stage(
            stage_traces,
            llm,
            "interview",
            lambda: Stage1Interview(llm=llm, store=self.store).run(
                prompt=prompt,
                reference=reference,
                target_arch=target,
                run_id=run_id,
                model=self.cfg.sonnet_model,
            ),
        )
        artifact = _run_traced_stage(
            stage_traces,
            llm,
            "codegen",
            lambda: _run_codegen_with_escalation(
                llm=llm,
                gpu=self.gpu,
                store=self.store,
                cfg=self.cfg,
                run_args={
                    "spec": spec,
                    "run_id": run_id,
                    "retry_budget": self.cfg.retry_budgets.codegen,
                },
            ),
        )
        correctness = _run_traced_stage(
            stage_traces,
            llm,
            "correctness",
            lambda: Stage3Correctness(llm=llm, gpu=self.gpu, store=self.store).run(
                spec=spec,
                artifact=artifact,
                reference=reference,
                run_id=run_id,
                retry_budget=self.cfg.retry_budgets.correctness,
                correctness_shapes=self.cfg.correctness_shapes,
            ),
            succeeded=lambda report: report.passed,
        )
        # Repair loop: regenerate the kernel with the failing report as
        # context, then re-check, up to the correctness retry budget.
        for repair_attempt in range(1, self.cfg.retry_budgets.correctness + 1):
            if correctness.passed:
                break
            repair_dir = f"stage3_repair/attempt_{repair_attempt:02d}"
            self.store.write_json(
                run_id,
                f"{repair_dir}/correctness_report.json",
                correctness.model_dump(mode="json"),
            )

            # Default arguments bind this iteration's values so the closure
            # does not observe later rebinding of the loop variables.
            def repair_action(
                correctness_report: CorrectnessReport = correctness,
                repair_prefix: str = repair_dir,
            ) -> KernelArtifact:
                return _run_codegen_with_escalation(
                    llm=llm,
                    gpu=self.gpu,
                    store=self.store,
                    cfg=self.cfg,
                    run_args={
                        "spec": spec,
                        "run_id": run_id,
                        "retry_budget": self.cfg.retry_budgets.codegen,
                        "repair_context": correctness_report,
                        "artifact_prefix": f"{repair_prefix}/codegen",
                    },
                )

            artifact = _run_traced_stage(
                stage_traces,
                llm,
                "codegen_repair",
                repair_action,
            )

            def correctness_action(candidate: KernelArtifact = artifact) -> CorrectnessReport:
                return Stage3Correctness(llm=llm, gpu=self.gpu, store=self.store).run(
                    spec=spec,
                    artifact=candidate,
                    reference=reference,
                    run_id=run_id,
                    retry_budget=self.cfg.retry_budgets.correctness,
                    correctness_shapes=self.cfg.correctness_shapes,
                )

            correctness = _run_traced_stage(
                stage_traces,
                llm,
                "correctness",
                correctness_action,
                succeeded=lambda report: report.passed,
            )
        if not correctness.passed:
            # Repair budget exhausted: report a stage-3 failure and stop.
            result = SynthesisResult.failed(
                stage=3,
                reason="correctness check failed",
                run_id=run_id,
                artifacts_dir=str(self.store.run_dir(run_id)),
                report=_build_report(
                    run_id=run_id,
                    spec_name=spec.name,
                    stage_traces=stage_traces,
                    wall_time_seconds=time.monotonic() - started_at,
                ),
                correctness=correctness,
            )
            _write_result_report(self.store, result)
            return result

        # The performance stage returns both its report and the (possibly
        # revised) artifact.
        performance, artifact = _run_traced_stage(
            stage_traces,
            llm,
            "performance",
            lambda: Stage4Performance(llm=llm, gpu=self.gpu, store=self.store, cfg=self.cfg).run(
                spec=spec,
                artifact=artifact,
                run_id=run_id,
                retry_budget=self.cfg.retry_budgets.performance,
                reference=reference,
            ),
        )
        artifact = _run_traced_stage(
            stage_traces,
            llm,
            "polish",
            lambda: Stage5Polish(llm=llm, gpu=self.gpu, store=self.store).run(
                spec=spec,
                artifact=artifact,
                correctness=correctness,
                performance=performance,
                reference=reference,
                run_id=run_id,
                model=self.cfg.sonnet_model,
                correctness_shapes=self.cfg.correctness_shapes,
            ),
        )

        report = _build_report(
            run_id=run_id,
            spec_name=spec.name,
            stage_traces=stage_traces,
            wall_time_seconds=time.monotonic() - started_at,
            warnings=["below perf target"] if performance.below_target else [],
        )
        result = SynthesisResult.ok(
            run_id=run_id,
            artifacts_dir=str(self.store.run_dir(run_id)),
            report=report,
            correctness=correctness,
            performance=performance,
            kernel_callable=None,
        )
        _write_result_report(self.store, result)
        return result
208
+
209
+
210
def _run_codegen_with_escalation(
    *,
    llm: _TracingLLMClient,
    gpu: GPURunner,
    store: ArtifactStore,
    cfg: SynthesisConfig,
    run_args: dict[str, Any],
) -> KernelArtifact:
    """Run Stage2Codegen with Sonnet, escalating to Opus on BudgetExhaustedError.

    The first attempt uses ``cfg.sonnet_model`` with ``run_args`` as given.
    If it raises ``BudgetExhaustedError`` and escalation is enabled with a
    positive Opus budget, a second attempt runs with ``cfg.opus_model``. It
    uses the Opus retry budget, an ``/escalated`` artifact prefix, and the
    first failure's summary as ``escalation_context``. Otherwise the
    original error is re-raised.
    """
    try:
        first_pass = Stage2Codegen(llm=llm, gpu=gpu, store=store)
        return first_pass.run(**run_args, model=cfg.sonnet_model)
    except BudgetExhaustedError as bust:
        may_escalate = cfg.escalate_to_opus_on_bust and cfg.opus_retry_budget_codegen > 0
        if not may_escalate:
            raise
        # Copy so the caller's mapping is left untouched.
        escalated_args = dict(run_args)
        escalated_args["retry_budget"] = cfg.opus_retry_budget_codegen
        escalated_args["artifact_prefix"] = (
            f"{run_args.get('artifact_prefix', 'stage2_codegen')}/escalated"
        )
        escalated_args["escalation_context"] = bust.summary
        second_pass = Stage2Codegen(llm=llm, gpu=gpu, store=store)
        return second_pass.run(**escalated_args, model=cfg.opus_model)
235
+
236
+
237
+ def _reference_source(reference: Callable[..., Any]) -> str:
238
+ try:
239
+ return inspect.getsource(reference)
240
+ except OSError:
241
+ return repr(reference)
242
+
243
+
244
class _TracingLLMClient(LLMClient):
    """LLMClient wrapper that records every response from the wrapped client.

    Responses are appended to ``responses`` in call order, so callers can
    slice the list to find the responses produced during a given interval.
    """

    def __init__(self, inner: LLMClient) -> None:
        """Wrap *inner* and start with an empty response log."""
        self.responses: list[LLMResponse] = []
        self._inner = inner

    def complete(
        self,
        *,
        system: list[dict[str, Any]],
        messages: list[dict[str, Any]],
        tools: list[ToolSpec] | None = None,
        model: str,
        max_tokens: int = 4096,
        temperature: float | None = None,
    ) -> LLMResponse:
        """Forward the call unchanged to the wrapped client and log its response.

        Nothing is recorded if the wrapped client raises.
        """
        forwarded = {
            "system": system,
            "messages": messages,
            "tools": tools,
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        reply = self._inner.complete(**forwarded)
        self.responses.append(reply)
        return reply
269
+
270
+
271
def _run_traced_stage(
    stage_traces: list[StageTrace],
    llm: _TracingLLMClient,
    stage_name: str,
    action: Callable[[], T],
    *,
    succeeded: Callable[[T], bool] | None = None,
) -> T:
    """Run *action*, append a StageTrace for it, and return its result.

    The trace covers only the LLM responses recorded while *action* ran.
    If *action* raises, a trace with ``succeeded=False`` is appended and the
    exception is re-raised. Otherwise the success flag is
    ``succeeded(result)`` when a predicate is given, else ``True``.
    """
    first_new_response = len(llm.responses)
    began_at = time.time()
    try:
        outcome = action()
    except Exception:
        seen = llm.responses[first_new_response:]
        failure_trace = _build_stage_trace(stage_name, seen, began_at, succeeded=False)
        stage_traces.append(failure_trace)
        raise
    else:
        seen = llm.responses[first_new_response:]
        # No predicate means "returning without raising counts as success".
        stage_ok = True if succeeded is None else succeeded(outcome)
        stage_traces.append(
            _build_stage_trace(stage_name, seen, began_at, succeeded=stage_ok)
        )
        return outcome
298
+
299
+
300
def _build_stage_trace(
    stage_name: str,
    responses: list[LLMResponse],
    started_at: float,
    *,
    succeeded: bool,
) -> StageTrace:
    """Summarize the LLM responses of one stage into a StageTrace.

    ``attempts`` is the number of responses, with a minimum of 1. Token
    counters are summed over the responses. Latency is the sum of the
    responses' reported latencies; if that sum is not positive, it falls
    back to the time elapsed since *started_at*.
    """
    attempt_count = len(responses)
    if attempt_count < 1:
        attempt_count = 1

    llm_latency = sum(item.latency_seconds for item in responses)
    models = _model_summary(responses)
    prompt_tokens = sum(item.tokens_in for item in responses)
    completion_tokens = sum(item.tokens_out for item in responses)
    cached_tokens = sum(item.cache_read_tokens for item in responses)

    if llm_latency > 0:
        latency = llm_latency
    else:
        # No usable latency reported by the client; measure elapsed time.
        latency = time.time() - started_at

    return StageTrace(
        stage_name=stage_name,
        attempts=attempt_count,
        succeeded=succeeded,
        model_used=models,
        tokens_in=prompt_tokens,
        tokens_out=completion_tokens,
        cache_read_tokens=cached_tokens,
        latency_seconds=latency,
    )
318
+
319
+
320
+ def _model_summary(responses: list[LLMResponse]) -> str:
321
+ if not responses:
322
+ return "none"
323
+ models: list[str] = []
324
+ for response in responses:
325
+ if response.model not in models:
326
+ models.append(response.model)
327
+ return ", ".join(models)
328
+
329
+
330
def _build_report(
    *,
    run_id: str,
    spec_name: str,
    stage_traces: list[StageTrace],
    wall_time_seconds: float,
    warnings: list[str] | None = None,
) -> SynthesisReport:
    """Assemble a SynthesisReport from the collected stage traces.

    Stage names and token totals are derived from *stage_traces*. A missing
    or empty *warnings* becomes an empty list. ``total_cost_usd`` is not set
    here and keeps the model's default.
    """
    executed: list[str] = []
    tokens_in_total = 0
    tokens_out_total = 0
    for entry in stage_traces:
        executed.append(entry.stage_name)
        tokens_in_total += entry.tokens_in
        tokens_out_total += entry.tokens_out

    return SynthesisReport(
        run_id=run_id,
        spec_name=spec_name,
        stages_executed=executed,
        stage_traces=stage_traces,
        total_llm_tokens_in=tokens_in_total,
        total_llm_tokens_out=tokens_out_total,
        wall_time_seconds=wall_time_seconds,
        warnings=warnings if warnings else [],
    )
348
+
349
+
350
def _write_result_report(store: ArtifactStore, result: SynthesisResult) -> None:
    """Write *result* to ``report.json`` under the result's own run id.

    ``kernel_callable`` is excluded from the JSON-mode dump.
    """
    store.write_json(
        result.run_id,
        "report.json",
        result.model_dump(mode="json", exclude={"kernel_callable"}),
    )
@@ -0,0 +1,8 @@
1
+ from importlib.resources import files
2
+
3
+
4
def load_prompt(name: str) -> str:
    """Return the text of the packaged prompt ``<name>.md``.

    Raises:
        FileNotFoundError: if this package has no such prompt resource.
    """
    resource = files(__package__).joinpath(f"{name}.md")
    if resource.is_file():
        return resource.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Prompt not found: {name}")
@@ -0,0 +1,29 @@
1
+ # CUDA Codegen Stage
2
+
3
+ You generate a single CUDA `.cu` file for the frozen `KernelSpec`.
4
+
5
+ Required runnable ABI:
6
+ - The generated source must be a Torch-loadable C++/CUDA extension, not a raw CUDA-only library.
7
+ - Include the needed Torch headers, normally `#include <torch/extension.h>` and `#include <ATen/cuda/CUDAContext.h>`.
8
+ - Expose exactly one user-callable op: `cuda_engine::forward`.
9
+ - Register the schema with `TORCH_LIBRARY(cuda_engine, m)`.
10
+ - Register the CUDA implementation with `TORCH_LIBRARY_IMPL(cuda_engine, CUDA, m)`.
11
+ - The Python runner will call `torch.ops.cuda_engine.forward(*inputs)`, so the op signature must match the `KernelSpec` inputs and outputs.
12
+ - Return a single `torch::Tensor` for one output, or a tuple/list-compatible Torch return type for multiple outputs.
13
+
14
+ Rules:
15
+ - Honor the target architecture and the frozen input/output contract.
16
+ - Treat KernelSpec inputs with `shape: []` as scalar/0-D Torch tensors, not vectors.
17
+ - For reduction outputs, return tensors with the exact reduced shape in the KernelSpec.
18
+ Example: input `["B", "D"]` and output `["B"]` means one output element per row.
19
+ - For argmax kernels, return `int64` indices when the KernelSpec output dtype is `int64`.
20
+ - For RMSNorm fp16 kernels, use fp32 accumulation for the mean square and reciprocal square root,
21
+ do not add gamma unless the KernelSpec includes it, and cast the final output to fp16.
22
+ - For `sm_80`, prefer straightforward CUDA C++ suitable for A100.
23
+ - Make memory hierarchy choices explicit in comments when they affect performance.
24
+ - Use 256 threads per block as the default elementwise baseline unless the spec suggests otherwise.
25
+ - Output complete CUDA source as one fenced `cuda` code block.
26
+ - After generating the source, call `compile_kernel(src, target_arch)` using the exact source.
27
+ - If compilation fails, use the compiler errors to revise the source and call `compile_kernel` again.
28
+
29
+ Do not change dtypes, shapes, argument ordering, or precision tolerance.
@@ -0,0 +1,30 @@
1
+ # CUDA Kernel Interview Stage
2
+
3
+ You convert a user prompt plus Python reference metadata into a frozen `KernelSpec`.
4
+
5
+ Return only structured JSON, preferably in a fenced `json` code block. The JSON must match:
6
+
7
+ ```json
8
+ {
9
+ "name": "snake_case_kernel_name",
10
+ "target_arch": "sm_80",
11
+ "inputs": [{"name": "x", "dtype": "fp32", "shape": ["N"], "layout_hint": "any"}],
12
+ "outputs": [{"name": "out", "dtype": "fp32", "shape": ["N"], "layout_hint": "any"}],
13
+ "precision_tolerance": {"rtol": 0.001, "atol": 0.001},
14
+ "optimization_priority": "balanced",
15
+ "notes": "brief clarification notes"
16
+ }
17
+ ```
18
+
19
+ Rules:
20
+ - Do not invent unsupported target architectures.
21
+ - Use symbolic shapes when concrete shapes are unknown.
22
+ - Represent scalar or 0-D tensor inputs with an empty shape list: `"shape": []`.
23
+ - For reductions, shrink only the reduced dimensions. Example: last-dimension sum of `x[B,D]`
24
+ should use input shape `["B", "D"]` and output shape `["B"]`.
25
+ - For `argmax`, use an integer output dtype, normally `int64`, with the reduced output shape.
26
+ - For fp16 RMSNorm without gamma, use fp16 input/output shapes that match, fp32 accumulation
27
+ semantics in `notes`, and a practical fp16 tolerance such as `rtol=0.01`, `atol=0.01`.
28
+ - Preserve the user's requested operation; do not broaden scope.
29
+ - Prefer `throughput` for large elementwise/reduction prompts and `latency` only when the prompt explicitly prioritizes small inputs.
30
+ - Use the reference metadata only to infer names and arity; if uncertain, choose conservative defaults and explain in `notes`.
@@ -0,0 +1,56 @@
1
+ # CUDA Performance Repair
2
+
3
+ You revise a CUDA kernel that compiles and is correct but runs below the
4
+ performance target. Your job is to improve throughput without breaking
5
+ correctness, then call `compile_kernel(src, target_arch)` with the revised
6
+ source.
7
+
8
+ Required runnable ABI (unchanged from the previous kernel):
9
+ - Keep `cuda_engine::forward` as the only user-callable op.
10
+ - Keep the same `TORCH_LIBRARY(cuda_engine, m)` namespace, op signature,
11
+ argument order, dtypes, shapes, and return type.
12
+ - Keep correctness: outputs must remain within the KernelSpec precision
13
+ tolerance compared to the reference.
14
+
15
+ Inputs you will receive:
16
+ - The current `kernel.cu` source.
17
+ - The frozen `KernelSpec`.
18
+ - The latest `BenchmarkResult` (`custom_ms`, `baseline_ms`, achieved GB/s).
19
+ - A `NsightMetrics` snapshot (achieved occupancy, registers per thread,
20
+ spill bytes when available).
21
+ - Suggested optimization hints derived from those metrics.
22
+
23
+ Optimization themes to consider:
24
+ - **Register pressure**: high regs/thread reduces occupancy on A100
25
+ (at most 32 regs/thread for full occupancy; 64 regs/thread caps occupancy at about 50%). Split
26
+ work into more, smaller blocks; reduce live registers; only spill to
27
+ shared memory when necessary.
28
+ - **Occupancy**: low achieved occupancy means few warps are resident.
29
+ Investigate register, shared memory, or block-size limits.
30
+ - **Memory coalescing**: ensure 32 consecutive threads in a warp read
31
+ 128 consecutive bytes. Avoid strided global loads/stores; use
32
+ `__ldg` for read-only cached loads where appropriate.
33
+ - **Grid wave alignment**: A100 has 108 SMs. Choose grid sizes that
34
+ fill full waves; a partial-wave tail can waste up to 20% of runtime.
35
+ - **Shared-memory tiling**: for reductions, use 256-thread blocks with
36
+ `__shfl_down_sync` for warp-level reduction; store partial results
37
+ to shared memory only when the reduction crosses warp boundaries.
38
+ - **Vectorized loads**: `float4`/`__half2` loads can double effective
39
+ bandwidth for elementwise ops on aligned, contiguous data.
40
+ - **Simple fused elementwise kernels**: for one-pass pointwise or fused
41
+ pointwise work, prefer one coalesced read/compute/write pass with enough
42
+ blocks to cover the tensor. Do not add multi-pass reductions, shared-memory
43
+ staging, or complicated synchronization unless the KernelSpec actually
44
+ requires cross-element communication.
45
+
46
+ Matching torch.compile is acceptable but not the goal. To strictly beat it on A100:
47
+ - Prefer vectorized memory ops: `float4` for fp32, `__half2` for fp16. They double effective bandwidth on aligned contiguous data.
48
+ - Align grid to A100's 108 SMs. A full wave is a multiple of 108 blocks; a partial tail wave wastes runtime. For tensors that don't divide evenly, prefer fewer, larger blocks over more, smaller ones.
49
+ - Maximize instruction-level parallelism: `#pragma unroll` inner loops with small bounded trip count. Keep enough independent work per thread to hide arithmetic and memory latency.
50
+ - Fuse passes when the KernelSpec permits. Reductions followed by elementwise can often be one-pass with `__shfl_down_sync` warp reductions.
51
+ - Inspect register pressure first if Nsight shows occupancy < 50%. If regs/thread > 64 on a 256-thread block, work-split or block-size reduction frees waves.
52
+
53
+ Output the complete revised CUDA source as one fenced `cuda` code block,
54
+ then call `compile_kernel(src, target_arch)` with the exact source.
55
+
56
+ Do not change dtypes, shapes, argument ordering, or precision tolerance.
@@ -0,0 +1,13 @@
1
+ # CUDA Kernel Polish Stage
2
+
3
+ You annotate an already-correct CUDA kernel for maintainability.
4
+
5
+ Return only the complete annotated CUDA source in a fenced `cuda` code block.
6
+
7
+ Annotations should explain:
8
+ - tile size and launch configuration choices
9
+ - memory layout and coalescing assumptions
10
+ - precision tolerance and correctness summary
11
+ - performance summary, including speedups and any occupancy/register notes when available
12
+
13
+ Do not change behavior, signatures, namespace registration, or the `cuda_engine::forward` ABI.
@@ -0,0 +1 @@
1
+ """Service interfaces and implementations."""
@@ -0,0 +1,3 @@
1
+ from cuda_engine.services.gpu.base import CompileResult, GPURunner, NsightMetrics, RunResult
2
+
3
+ __all__ = ["CompileResult", "GPURunner", "NsightMetrics", "RunResult"]