cuda-engine 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. cuda_engine/__init__.py +24 -0
  2. cuda_engine/api.py +39 -0
  3. cuda_engine/cli.py +485 -0
  4. cuda_engine/config.py +32 -0
  5. cuda_engine/models/__init__.py +27 -0
  6. cuda_engine/models/artifact.py +12 -0
  7. cuda_engine/models/reports.py +106 -0
  8. cuda_engine/models/spec.py +45 -0
  9. cuda_engine/orchestrator.py +352 -0
  10. cuda_engine/prompts/__init__.py +8 -0
  11. cuda_engine/prompts/codegen.md +29 -0
  12. cuda_engine/prompts/interview.md +30 -0
  13. cuda_engine/prompts/perf_fix.md +56 -0
  14. cuda_engine/prompts/polish.md +13 -0
  15. cuda_engine/services/__init__.py +1 -0
  16. cuda_engine/services/gpu/__init__.py +3 -0
  17. cuda_engine/services/gpu/_run_kernel_child.py +305 -0
  18. cuda_engine/services/gpu/base.py +88 -0
  19. cuda_engine/services/gpu/local.py +451 -0
  20. cuda_engine/services/gpu/mocks.py +85 -0
  21. cuda_engine/services/llm/__init__.py +3 -0
  22. cuda_engine/services/llm/anthropic.py +71 -0
  23. cuda_engine/services/llm/base.py +35 -0
  24. cuda_engine/services/llm/mocks.py +38 -0
  25. cuda_engine/services/llm/tools.py +64 -0
  26. cuda_engine/services/store/__init__.py +3 -0
  27. cuda_engine/services/store/base.py +24 -0
  28. cuda_engine/services/store/local_dir.py +42 -0
  29. cuda_engine/services/store/mocks.py +27 -0
  30. cuda_engine/stages/__init__.py +1 -0
  31. cuda_engine/stages/base.py +41 -0
  32. cuda_engine/stages/codegen.py +193 -0
  33. cuda_engine/stages/correctness.py +241 -0
  34. cuda_engine/stages/interview.py +117 -0
  35. cuda_engine/stages/performance.py +424 -0
  36. cuda_engine/stages/polish.py +152 -0
  37. cuda_engine/targets/__init__.py +7 -0
  38. cuda_engine/targets/sm_100.py +2 -0
  39. cuda_engine/targets/sm_80.py +18 -0
  40. cuda_engine/targets/sm_90.py +2 -0
  41. cuda_engine-1.0.0.dist-info/METADATA +266 -0
  42. cuda_engine-1.0.0.dist-info/RECORD +45 -0
  43. cuda_engine-1.0.0.dist-info/WHEEL +4 -0
  44. cuda_engine-1.0.0.dist-info/entry_points.txt +2 -0
  45. cuda_engine-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,24 @@
1
+ from cuda_engine.api import synthesize
2
+ from cuda_engine.config import RetryBudgets, SynthesisConfig
3
+ from cuda_engine.models import (
4
+ CorrectnessReport,
5
+ KernelArtifact,
6
+ KernelSpec,
7
+ PerformanceReport,
8
+ SynthesisReport,
9
+ SynthesisResult,
10
+ )
11
+
12
+ __all__ = [
13
+ "CorrectnessReport",
14
+ "KernelArtifact",
15
+ "KernelSpec",
16
+ "PerformanceReport",
17
+ "RetryBudgets",
18
+ "SynthesisConfig",
19
+ "SynthesisReport",
20
+ "SynthesisResult",
21
+ "synthesize",
22
+ ]
23
+
24
+ __version__ = "0.0.1"
cuda_engine/api.py ADDED
@@ -0,0 +1,39 @@
1
+ from collections.abc import Callable
2
+ from typing import Any
3
+
4
+ from cuda_engine.config import SynthesisConfig
5
+ from cuda_engine.models import SynthesisResult
6
+ from cuda_engine.orchestrator import Orchestrator
7
+ from cuda_engine.services.gpu.base import GPURunner
8
+ from cuda_engine.services.llm.base import LLMClient
9
+ from cuda_engine.services.store.base import ArtifactStore
10
+
11
+
12
def synthesize(
    prompt: str,
    reference: Callable[..., Any],
    target: str = "sm_80",
    config: SynthesisConfig | None = None,
    *,
    _llm: LLMClient | None = None,
    _gpu: GPURunner | None = None,
    _store: ArtifactStore | None = None,
) -> SynthesisResult:
    """Synthesize a CUDA kernel from English prompt + PyTorch reference.

    Args:
        prompt: Text handed to the orchestrator as the kernel description.
        reference: Callable handed to the orchestrator as the reference.
        target: Target architecture name forwarded to the orchestrator.
        config: Optional configuration; a default ``SynthesisConfig`` is
            built when none is supplied.
        _llm: Optional LLM client override; ``AnthropicClient`` otherwise.
        _gpu: Optional GPU runner override; ``LocalGPURunner`` otherwise.
        _store: Optional artifact store override; ``LocalDirStore`` otherwise.

    Returns:
        Whatever ``Orchestrator.run`` returns for this request.
    """
    settings = config or SynthesisConfig()

    # Default services are imported lazily, and only when no override was
    # injected, so callers that supply all three never load those modules.
    llm_client = _llm
    if llm_client is None:
        from cuda_engine.services.llm.anthropic import AnthropicClient

        llm_client = AnthropicClient(cfg=settings)

    gpu_runner = _gpu
    if gpu_runner is None:
        from cuda_engine.services.gpu.local import LocalGPURunner

        gpu_runner = LocalGPURunner(cfg=settings)

    artifact_store = _store
    if artifact_store is None:
        from cuda_engine.services.store.local_dir import LocalDirStore

        artifact_store = LocalDirStore(cfg=settings)

    pipeline = Orchestrator(
        llm=llm_client,
        gpu=gpu_runner,
        store=artifact_store,
        cfg=settings,
    )
    return pipeline.run(prompt=prompt, reference=reference, target=target)
cuda_engine/cli.py ADDED
@@ -0,0 +1,485 @@
1
+ import importlib.util
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+ from typing import Annotated, Any
7
+
8
+ import typer
9
+
10
+ from cuda_engine.config import SynthesisConfig
11
+
12
+ app = typer.Typer(help="CUDA synthesis engine CLI.")
13
+
14
+
15
@app.callback()
def main() -> None:
    """Command-line entry point placeholder for M0."""
    # Intentionally empty: the callback only exists so the Typer app has a
    # root entry point; all work happens in the registered sub-commands.
    return None
18
+
19
+
20
@app.command("show-report")
def show_report(run_dir: Path) -> None:
    """Print a compact summary for a run directory containing report.json."""
    report_file = run_dir / "report.json"
    if report_file.exists():
        _print_report_summary(report_file)
        return

    # Missing report: tell the user which path was probed and exit non-zero.
    typer.echo(f"report.json not found: {report_file}")
    raise typer.Exit(code=1)
28
+
29
+
30
@app.command("latest-report")
def latest_report(runs_root: Path) -> None:
    """Print the newest report.json summary under a runs root."""
    newest = _latest_report_path(runs_root)
    if newest is not None:
        _print_report_summary(newest)
        return

    # Nothing to show: either the root is missing or it holds no reports.
    typer.echo(f"no report.json files found under: {runs_root}")
    raise typer.Exit(code=1)
38
+
39
+
40
@app.command("synthesize")
def synthesize_cmd(
    prompt: Annotated[
        str | None,
        typer.Option(
            "--prompt",
            help="Prompt text. Mutually exclusive with --prompt-file.",
        ),
    ] = None,
    prompt_file: Annotated[
        Path | None,
        typer.Option("--prompt-file", help="Path to a text file containing the prompt."),
    ] = None,
    reference: Annotated[
        Path,
        typer.Option(
            "--reference",
            help="Python file defining the reference function (variable REFERENCE or reference()).",
        ),
    ] = ...,  # type: ignore[assignment]
    target: Annotated[str, typer.Option(help="CUDA target architecture.")] = "sm_80",
    out: Annotated[
        Path | None,
        typer.Option(
            "--out",
            help="Artifact root for the run. Defaults to ~/.cache/cuda_engine/runs/.",
        ),
    ] = None,
) -> None:
    """Synthesize a single CUDA kernel from a prompt + reference function."""
    # Exactly one prompt source must be given; both misuse cases exit with 2.
    has_inline_prompt = prompt is not None
    has_prompt_file = prompt_file is not None
    if not has_inline_prompt and not has_prompt_file:
        typer.echo("error: one of --prompt or --prompt-file is required")
        raise typer.Exit(code=2)
    if has_inline_prompt and has_prompt_file:
        typer.echo("error: --prompt and --prompt-file are mutually exclusive")
        raise typer.Exit(code=2)

    if prompt is None:
        assert prompt_file is not None  # narrowed by validation above
        prompt_text = _read_text(prompt_file)
    else:
        prompt_text = prompt
    reference_fn = _load_reference_from_path(reference)

    synthesize_fn = _resolve_synthesize_fn()
    # Only override the artifact root when --out was passed explicitly.
    if out is None:
        config = SynthesisConfig()
    else:
        config = SynthesisConfig(artifact_root=str(out))

    result = synthesize_fn(
        prompt=prompt_text,
        reference=reference_fn,
        target=target,
        config=config,
    )

    status_label = "PASS" if result.passed else "FAIL"
    typer.echo(f"Run: {result.run_id}")
    typer.echo(f"Status: {status_label}")
    typer.echo(f"Artifacts: {result.artifacts_dir}")
    if result.passed:
        return

    # A failed run reports where and why it failed, then exits with code 1.
    typer.echo(f"Failed stage: {result.failed_stage}")
    typer.echo(f"Reason: {result.failure_reason}")
    raise typer.Exit(code=1)
100
+
101
+
102
@app.command("inspect")
def inspect_run(
    run: Annotated[
        str,
        typer.Argument(
            help="Run id or path to a run directory containing report.json.",
        ),
    ],
    runs_root: Annotated[
        Path | None,
        typer.Option(
            "--runs-root",
            help="Directory containing run subdirectories. Defaults to ~/.cache/cuda_engine/runs.",
        ),
    ] = None,
) -> None:
    """Pretty-print the report for a synthesis run."""
    resolved_dir = _resolve_run_dir(run, runs_root)
    if resolved_dir is None:
        typer.echo(f"run not found: {run}")
        raise typer.Exit(code=1)

    report_file = resolved_dir / "report.json"
    if report_file.exists():
        _print_report_summary(report_file)
        return

    typer.echo(f"report.json not found: {report_file}")
    raise typer.Exit(code=1)
128
+
129
+
130
+ def _resolve_run_dir(run: str, runs_root: Path | None) -> Path | None:
131
+ direct = Path(run)
132
+ if direct.is_dir() and (direct / "report.json").exists():
133
+ return direct
134
+ root = runs_root if runs_root is not None else Path.home() / ".cache" / "cuda_engine" / "runs"
135
+ if not root.exists():
136
+ return None
137
+ candidate = root / run
138
+ if candidate.is_dir() and (candidate / "report.json").exists():
139
+ return candidate
140
+ # Tolerate truncated run_ids: pick the unique match if any.
141
+ matches = [
142
+ path for path in root.iterdir()
143
+ if path.is_dir() and path.name.startswith(run) and (path / "report.json").exists()
144
+ ]
145
+ if len(matches) == 1:
146
+ return matches[0]
147
+ return None
148
+
149
+
150
+ def _read_text(path: Path) -> str:
151
+ try:
152
+ return path.read_text(encoding="utf-8")
153
+ except OSError as exc:
154
+ typer.echo(f"could not read prompt file: {exc}")
155
+ raise typer.Exit(code=1) from exc
156
+
157
+
158
+ def _load_reference_from_path(reference_path: Path) -> Any:
159
+ if not reference_path.exists():
160
+ typer.echo(f"reference file not found: {reference_path}")
161
+ raise typer.Exit(code=1)
162
+ module_name = f"cuda_engine_cli_reference_{reference_path.stem}"
163
+ spec = importlib.util.spec_from_file_location(module_name, reference_path)
164
+ if spec is None or spec.loader is None:
165
+ typer.echo(f"could not load reference module: {reference_path}")
166
+ raise typer.Exit(code=1)
167
+ module = importlib.util.module_from_spec(spec)
168
+ try:
169
+ spec.loader.exec_module(module)
170
+ except Exception as exc:
171
+ typer.echo(f"reference module failed to import: {exc}")
172
+ raise typer.Exit(code=1) from exc
173
+ reference = getattr(module, "REFERENCE", None) or getattr(module, "reference", None)
174
+ if not callable(reference):
175
+ typer.echo(f"reference file must define REFERENCE or reference(): {reference_path}")
176
+ raise typer.Exit(code=1)
177
+ return reference
178
+
179
+
180
def _resolve_synthesize_fn() -> Any:
    """Return the package-level ``synthesize`` function, imported on demand."""
    # Imported inside the function so the lookup happens at call time rather
    # than when this module is imported.
    from cuda_engine import synthesize as package_synthesize

    return package_synthesize
184
+
185
+
186
@app.command("eval")
def eval_suite(
    out: Annotated[Path, typer.Option("--out", help="Directory for aggregate eval outputs.")],
    suite: Annotated[
        str,
        typer.Option(help="Suite name ('internal') or path to a suite directory."),
    ] = "internal",
    baseline: Annotated[
        Path | None,
        typer.Option(help="Optional prior results directory."),
    ] = None,
    target: Annotated[str, typer.Option(help="CUDA target architecture.")] = "sm_80",
    only: Annotated[
        str | None,
        typer.Option(help="Comma-separated kernel names to run."),
    ] = None,
    limit: Annotated[
        int | None,
        typer.Option(help="Maximum number of selected kernels to run."),
    ] = None,
    resume: Annotated[
        bool,
        typer.Option("--resume/--no-resume", help="Skip kernels with existing JSON results."),
    ] = True,
    yes: Annotated[
        bool,
        typer.Option("--yes", "-y", help="Skip the large-run confirmation prompt."),
    ] = False,
) -> None:
    """Run an eval suite and write aggregate markdown/CSV results."""
    eval_runner = _load_eval_runner()
    suite_root = _resolve_suite_root(suite)
    selected_names = _parse_only(only)

    if not yes:
        kernel_count = _count_suite_kernels(suite_root, selected_names, limit)
        # Runs of more than five kernels need explicit confirmation; the
        # quoted range is kernel_count * 0.10 to kernel_count * 0.40.
        if kernel_count > 5:
            estimate_low = kernel_count * 0.10
            estimate_high = kernel_count * 0.40
            typer.echo(
                f"About to synthesize {kernel_count} kernels "
                f"(est. ${estimate_low:.0f}-${estimate_high:.0f} in API credits, opus escalation off by default).\n"
                f"Tip: use --only k1,k2 to target specific kernels, or --yes to skip this prompt."
            )
            if not typer.confirm("Continue?", default=False):
                raise typer.Exit(code=0)

    summary = eval_runner.run_eval_suite(
        suite_root=suite_root,
        out_dir=out,
        baseline_dir=baseline,
        target=target,
        config=SynthesisConfig(),
        only=selected_names,
        limit=limit,
        resume=resume,
        progress=typer.echo,
    )
    passed_rows = [row for row in summary.rows if row.passed]
    typer.echo(f"Eval complete: {len(passed_rows)}/{len(summary.rows)} passed")
    typer.echo(f"CSV: {summary.csv_path}")
    typer.echo(f"Summary: {summary.markdown_path}")
248
+
249
+
250
+ def _count_suite_kernels(suite_root: Path, only: set[str] | None, limit: int | None) -> int:
251
+ if not suite_root.exists():
252
+ return 0
253
+ names = sorted(d.name for d in suite_root.iterdir() if d.is_dir() and (d / "prompt.txt").exists())
254
+ if only is not None:
255
+ names = [n for n in names if n in only]
256
+ if limit is not None:
257
+ names = names[:limit]
258
+ return len(names)
259
+
260
+
261
+ def _resolve_suite_root(suite: str) -> Path:
262
+ """Map well-known suite names to their on-disk directories; passthrough paths."""
263
+ known: dict[str, Path] = {
264
+ "internal": Path("evals") / "internal",
265
+ "kernelbench": Path("evals") / "kernelbench" / "filtered",
266
+ }
267
+ return known.get(suite, Path(suite))
268
+
269
+
270
+ def _load_eval_runner() -> ModuleType:
271
+ try:
272
+ from evals import runner as eval_runner
273
+
274
+ return eval_runner
275
+ except ModuleNotFoundError as exc:
276
+ for root in (Path.cwd(), *Path.cwd().parents):
277
+ runner_path = root / "evals" / "runner.py"
278
+ if runner_path.exists():
279
+ spec = importlib.util.spec_from_file_location(
280
+ "cuda_engine_source_eval_runner",
281
+ runner_path,
282
+ )
283
+ if spec is None or spec.loader is None:
284
+ break
285
+ module = importlib.util.module_from_spec(spec)
286
+ sys.modules[spec.name] = module
287
+ spec.loader.exec_module(module)
288
+ return module
289
+ raise ModuleNotFoundError(
290
+ "could not import evals.runner; run this command from a cuda-engine source checkout"
291
+ ) from exc
292
+
293
+
294
+ def _parse_only(value: str | None) -> set[str] | None:
295
+ if value is None:
296
+ return None
297
+ names = {item.strip() for item in value.split(",") if item.strip()}
298
+ return names or None
299
+
300
+
301
def _print_report_summary(report_path: Path) -> None:
    """Echo a human-readable summary of the report stored at ``report_path``.

    Prints, in order: run id, status, spec name, executed stages, LLM token
    totals, stage traces, failure details for a failed run, correctness,
    performance, polish and repair artifacts, warnings, and the artifacts
    directory.
    """
    payload = _load_report(report_path)
    report = _dict(payload.get("report"))
    correctness = payload.get("correctness")
    performance = payload.get("performance")
    run_dir = report_path.parent

    # The run id falls back from the top level to the nested report.
    run_id = payload.get("run_id", report.get("run_id", "unknown"))
    typer.echo(f"Run: {run_id}")
    status_label = "PASS" if payload.get("passed") else "FAIL"
    typer.echo(f"Status: {status_label}")
    typer.echo(f"Spec: {report.get('spec_name', 'unknown')}")
    stage_chain = " -> ".join(_strings(report.get("stages_executed")))
    typer.echo(f"Stages: {stage_chain}")
    tokens_in = int(report.get("total_llm_tokens_in", 0))
    tokens_out = int(report.get("total_llm_tokens_out", 0))
    typer.echo(f"LLM tokens: {tokens_in} in / {tokens_out} out")
    _print_stage_traces(report.get("stage_traces"))

    if not payload.get("passed"):
        typer.echo(f"Failed stage: {payload.get('failed_stage')}")
        typer.echo(f"Reason: {payload.get('failure_reason')}")

    if not isinstance(correctness, dict):
        typer.echo("Correctness: not available")
    else:
        _print_correctness(correctness)

    _print_performance(performance, run_dir)
    _print_polish_artifacts(run_dir)
    _print_repair_artifacts(run_dir)

    warning_items = _strings(report.get("warnings"))
    if warning_items:
        typer.echo(f"Warnings: {', '.join(warning_items)}")

    typer.echo(f"Artifacts: {payload.get('artifacts_dir', str(report_path.parent))}")
336
+
337
+
338
+ def _latest_report_path(runs_root: Path) -> Path | None:
339
+ if not runs_root.exists():
340
+ return None
341
+ report_paths = [path for path in runs_root.rglob("report.json") if path.is_file()]
342
+ if not report_paths:
343
+ return None
344
+ return max(report_paths, key=lambda path: path.stat().st_mtime)
345
+
346
+
347
+ def _load_report(report_path: Path) -> dict[str, Any]:
348
+ try:
349
+ data = json.loads(report_path.read_text(encoding="utf-8"))
350
+ except json.JSONDecodeError as exc:
351
+ typer.echo(f"report.json could not be decoded: {exc}")
352
+ raise typer.Exit(code=1) from exc
353
+ if not isinstance(data, dict):
354
+ typer.echo("report.json must contain an object")
355
+ raise typer.Exit(code=1)
356
+ return data
357
+
358
+
359
+ def _dict(value: object) -> dict[str, Any]:
360
+ return value if isinstance(value, dict) else {}
361
+
362
+
363
+ def _strings(value: object) -> list[str]:
364
+ if not isinstance(value, list):
365
+ return []
366
+ return [str(item) for item in value]
367
+
368
+
369
+ def _print_stage_traces(value: object) -> None:
370
+ if not isinstance(value, list) or not value:
371
+ return
372
+
373
+ typer.echo("Stage traces:")
374
+ for item in value:
375
+ trace = _dict(item)
376
+ status = "ok" if trace.get("succeeded") else "failed"
377
+ typer.echo(
378
+ f"- {trace.get('stage_name', 'unknown')}: "
379
+ f"{status} "
380
+ f"attempts={int(trace.get('attempts', 0))} "
381
+ f"model={trace.get('model_used', 'unknown')} "
382
+ f"tokens={int(trace.get('tokens_in', 0))}/{int(trace.get('tokens_out', 0))} "
383
+ f"cache_read={int(trace.get('cache_read_tokens', 0))}"
384
+ )
385
+
386
+
387
def _print_correctness(correctness: dict[str, Any]) -> None:
    """Echo the correctness verdict and, on failure, the error figures.

    For a failing report, the first entry of ``failing_inputs`` (if any) is
    also printed with its ``shape`` and ``error`` values.
    """
    if correctness.get("passed"):
        typer.echo("Correctness: PASS")
        return

    abs_err = correctness.get("max_abs_err")
    rel_err = correctness.get("max_rel_err")
    typer.echo(f"Correctness: FAIL max_abs_err={abs_err} max_rel_err={rel_err}")

    failures = correctness.get("failing_inputs")
    if not isinstance(failures, list) or not failures:
        return

    # Only the first failing input is shown.
    first = _dict(failures[0])
    typer.echo(f"First failure: shape={first.get('shape')} error={first.get('error')}")
405
+
406
+
407
def _print_performance(performance: object, run_dir: Path) -> None:
    """Echo the performance section of a report.

    Args:
        performance: The ``performance`` value from the report payload; any
            value that is not a dict is reported as "not available".
        run_dir: Run directory, used to point at
            ``stage4_performance/benchmark.json`` when that file exists.
    """
    if not isinstance(performance, dict):
        typer.echo("Performance: not available")
        return

    def _fmt(value: object) -> str:
        # Render a number with two decimals; missing or non-numeric -> "n/a".
        if value is None:
            return "n/a"
        try:
            return f"{float(value):.2f}"  # type: ignore[arg-type]
        except (TypeError, ValueError):
            return "n/a"

    typer.echo(
        "Performance: "
        f"speedup_vs_reference={_fmt(performance.get('speedup_vs_reference'))}, "
        f"speedup_vs_torch_compile={_fmt(performance.get('speedup_vs_torch_compile'))}"
    )

    achieved_gbps = performance.get("achieved_gbps")
    if achieved_gbps is not None:
        # Go through the same tolerant formatter as the speedups so that a
        # non-numeric value prints "n/a" instead of aborting the summary.
        typer.echo(f"Bandwidth: achieved_gbps={_fmt(achieved_gbps)}")

    typer.echo(f"Below target: {str(bool(performance.get('below_target', False))).lower()}")

    notes = _strings(performance.get("notes"))
    if notes:
        typer.echo(f"Performance notes: {', '.join(notes)}")

    benchmark_path = run_dir / "stage4_performance" / "benchmark.json"
    if benchmark_path.exists():
        typer.echo(f"Benchmark: {benchmark_path}")
439
+
440
+
441
+ def _print_repair_artifacts(run_dir: Path) -> None:
442
+ repair_root = run_dir / "stage3_repair"
443
+ if not repair_root.exists():
444
+ return
445
+
446
+ attempt_dirs = sorted(path for path in repair_root.glob("attempt_*") if path.is_dir())
447
+ if not attempt_dirs:
448
+ return
449
+
450
+ typer.echo(f"Correctness repairs: {len(attempt_dirs)}")
451
+ for attempt_dir in attempt_dirs:
452
+ report_path = attempt_dir / "correctness_report.json"
453
+ kernel_path = attempt_dir / "codegen" / "final" / "kernel.cu"
454
+ if report_path.exists():
455
+ typer.echo(f"- correctness_report: {report_path}")
456
+ if kernel_path.exists():
457
+ typer.echo(f"- repaired_kernel: {kernel_path}")
458
+
459
+
460
+ def _print_polish_artifacts(run_dir: Path) -> None:
461
+ status_path = run_dir / "stage5_polish" / "status.json"
462
+ if not status_path.exists():
463
+ return
464
+
465
+ try:
466
+ status = json.loads(status_path.read_text(encoding="utf-8"))
467
+ except json.JSONDecodeError:
468
+ typer.echo(f"Polish: status unreadable at {status_path}")
469
+ return
470
+ if not isinstance(status, dict):
471
+ typer.echo(f"Polish: status unreadable at {status_path}")
472
+ return
473
+
474
+ accepted = bool(status.get("accepted", False))
475
+ typer.echo(f"Polish: {'accepted' if accepted else 'rejected'}")
476
+ reason = status.get("reason")
477
+ if reason:
478
+ typer.echo(f"Polish reason: {reason}")
479
+
480
+ kernel_path = run_dir / "stage5_polish" / "final" / "kernel.cu"
481
+ if not kernel_path.exists():
482
+ raw_path = status.get("kernel_cu_path")
483
+ kernel_path = Path(str(raw_path)) if raw_path else kernel_path
484
+ if kernel_path.exists() or status.get("kernel_cu_path"):
485
+ typer.echo(f"Polished kernel: {kernel_path}")
cuda_engine/config.py ADDED
@@ -0,0 +1,32 @@
1
+ from pydantic import BaseModel, ConfigDict, Field
2
+
3
+
4
class RetryBudgets(BaseModel):
    # Immutable group of integer budgets, one per pipeline stage name.
    # NOTE(review): presumably the number of attempts each stage may make;
    # confirm the exact semantics against the stage implementations.
    model_config = ConfigDict(frozen=True)

    interview: int = 1
    codegen: int = 3
    correctness: int = 3
    performance: int = 3
    polish: int = 1
12
+
13
+
14
class SynthesisConfig(BaseModel):
    # Immutable settings object handed to the orchestrator and the default
    # LLM / GPU / store services.
    # NOTE(review): the per-field notes below are inferred from field names;
    # confirm against the code that reads each setting.
    model_config = ConfigDict(frozen=True)

    # Per-stage budgets; a fresh RetryBudgets instance is built by default.
    retry_budgets: RetryBudgets = Field(default_factory=RetryBudgets)
    escalate_to_opus_on_bust: bool = False
    perf_target_speedup_vs_torch_compile: float = 1.0

    # Presumably the tolerances and input shapes used by correctness checks.
    correctness_rtol: float = 1e-3
    correctness_atol: float = 1e-3
    correctness_shapes: tuple[tuple[int, ...], ...] = ((0,), (1,), (127,), (128,), (1024,), (4097,))

    nvcc_flags: tuple[str, ...] = ("-O3", "--use_fast_math")

    # None means no explicit root was chosen; the CLI sets this from --out.
    artifact_root: str | None = None

    # Presumably the benchmark problem size and iteration counts.
    performance_shape_n: int = 16_777_216
    benchmark_warmup_iterations: int = 10
    benchmark_timed_iterations: int = 100

    # Model identifiers and the budgets that apply to the opus model.
    sonnet_model: str = "claude-sonnet-4-6"
    opus_model: str = "claude-opus-4-7"
    opus_retry_budget_codegen: int = 1
    opus_retry_budget_performance: int = 1

    request_timeout_seconds: int = 120
@@ -0,0 +1,27 @@
1
+ from cuda_engine.models.artifact import KernelArtifact
2
+ from cuda_engine.models.reports import (
3
+ CorrectnessReport,
4
+ PerformanceReport,
5
+ StageTrace,
6
+ SynthesisReport,
7
+ SynthesisResult,
8
+ )
9
+ from cuda_engine.models.spec import (
10
+ KernelSpec,
11
+ OptimizationPriority,
12
+ PrecisionTolerance,
13
+ TensorArg,
14
+ )
15
+
16
+ __all__ = [
17
+ "CorrectnessReport",
18
+ "KernelArtifact",
19
+ "KernelSpec",
20
+ "OptimizationPriority",
21
+ "PerformanceReport",
22
+ "PrecisionTolerance",
23
+ "StageTrace",
24
+ "SynthesisReport",
25
+ "SynthesisResult",
26
+ "TensorArg",
27
+ ]
@@ -0,0 +1,12 @@
1
+ from pathlib import Path
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+
6
class KernelArtifact(BaseModel):
    # Immutable record describing one kernel and its build outputs.
    # NOTE(review): the field notes are inferred from names; confirm against
    # the code that produces these records.
    model_config = ConfigDict(frozen=True)

    # Required path to the kernel source file.
    kernel_cu_path: Path
    # Optional path to a built shared object; None when there is none.
    kernel_so_path: Path | None = None
    # Compiler output text; empty by default.
    compile_log: str = ""
    # Size in bytes, going by the field name; 0 by default.
    ptx_size_bytes: int = 0