emx-onnx-cgen 0.2.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """ONNX to C compiler MVP."""
+
+ from .compiler import Compiler
+ from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
+
+ __all__ = ["Compiler", "CodegenError", "ShapeInferenceError", "UnsupportedOpError"]
emx_onnx_cgen/__main__.py ADDED
@@ -0,0 +1,9 @@
+ from __future__ import annotations
+
+ import sys
+
+ from .cli import main
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main(sys.argv[1:]))
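The __main__.py shim above only forwards sys.argv[1:] to cli.main() and turns its integer result into the process exit code, so the same entry point can also be driven from Python. A minimal sketch, assuming an illustrative model.onnx in the working directory:

    from emx_onnx_cgen.cli import main

    # Equivalent to: python -m emx_onnx_cgen compile model.onnx model.c
    exit_code = main(["compile", "model.onnx", "model.c"])
    print("compile exit code:", exit_code)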
emx_onnx_cgen/_build_info.py ADDED
@@ -0,0 +1,3 @@
+ """Auto-generated by build backend. Do not edit."""
+ BUILD_DATE = '2026-01-15T22:18:39Z'
+ GIT_VERSION = 'unknown'
emx_onnx_cgen/cli.py ADDED
@@ -0,0 +1,328 @@
+ from __future__ import annotations
+
+ import argparse
+ import hashlib
+ import json
+ import logging
+ import os
+ import shlex
+ import shutil
+ import subprocess
+ import sys
+ import tempfile
+ from pathlib import Path
+ from typing import Sequence
+
+ import onnx
+
+ from ._build_info import BUILD_DATE, GIT_VERSION
+ from .compiler import Compiler, CompilerOptions
+ from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
+ from .onnx_import import import_onnx
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ def _build_parser() -> argparse.ArgumentParser:
+     description = (
+         "emmtrix ONNX-to-C Code Generator "
+         f"(build date: {BUILD_DATE}, git: {GIT_VERSION})"
+     )
+     parser = argparse.ArgumentParser(prog="emx-onnx-cgen", description=description)
+     subparsers = parser.add_subparsers(dest="command", required=True)
+
+     def add_restrict_flags(subparser: argparse.ArgumentParser) -> None:
+         restrict_group = subparser.add_mutually_exclusive_group()
+         restrict_group.add_argument(
+             "--restrict-arrays",
+             dest="restrict_arrays",
+             action="store_true",
+             help="Enable restrict qualifiers on generated array parameters",
+         )
+         restrict_group.add_argument(
+             "--no-restrict-arrays",
+             dest="restrict_arrays",
+             action="store_false",
+             help="Disable restrict qualifiers on generated array parameters",
+         )
+         subparser.set_defaults(restrict_arrays=True)
+
+     compile_parser = subparsers.add_parser(
+         "compile", help="Compile an ONNX model into C source"
+     )
+     compile_parser.add_argument("model", type=Path, help="Path to the ONNX model")
+     compile_parser.add_argument(
+         "output",
+         type=Path,
+         nargs="?",
+         default=None,
+         help=(
+             "Output C file path (default: use model filename with .c suffix, "
+             "e.g., model.onnx -> model.c)"
+         ),
+     )
+     compile_parser.add_argument(
+         "--template-dir",
+         type=Path,
+         default=Path("templates"),
+         help="Template directory (default: templates)",
+     )
+     compile_parser.add_argument(
+         "--model-name",
+         type=str,
+         default=None,
+         help="Override the generated model name (default: output file stem)",
+     )
+     compile_parser.add_argument(
+         "--emit-testbench",
+         action="store_true",
+         help="Emit a JSON-producing testbench main() for validation",
+     )
+     compile_parser.add_argument(
+         "--emit-data-file",
+         action="store_true",
+         help=(
+             "Emit constant data arrays to a separate C file "
+             "named like the output with a _data suffix"
+         ),
+     )
+     add_restrict_flags(compile_parser)
+
+     verify_parser = subparsers.add_parser(
+         "verify",
+         help="Compile an ONNX model and verify outputs against ONNX Runtime",
+     )
+     verify_parser.add_argument("model", type=Path, help="Path to the ONNX model")
+     verify_parser.add_argument(
+         "--template-dir",
+         type=Path,
+         default=Path("templates"),
+         help="Template directory (default: templates)",
+     )
+     verify_parser.add_argument(
+         "--model-name",
+         type=str,
+         default=None,
+         help="Override the generated model name (default: model file stem)",
+     )
+     verify_parser.add_argument(
+         "--cc",
+         type=str,
+         default=None,
+         help="C compiler command to build the testbench binary",
+     )
+     add_restrict_flags(verify_parser)
+     return parser
+
+
+ def main(argv: Sequence[str] | None = None) -> int:
+     logging.basicConfig(level=logging.INFO)
+     parser = _build_parser()
+     args = parser.parse_args(argv)
+     args.command_line = _format_command_line(argv)
+
+     if args.command == "compile":
+         return _handle_compile(args)
+     if args.command == "verify":
+         return _handle_verify(args)
+     parser.error(f"Unknown command {args.command}")
+     return 1
+
+
+ def _handle_compile(args: argparse.Namespace) -> int:
+     model_path: Path = args.model
+     output_path: Path = args.output or model_path.with_suffix(".c")
+     model_name = args.model_name or output_path.stem
+     try:
+         model_checksum = _model_checksum(model_path)
+         model = onnx.load_model(model_path)
+         options = CompilerOptions(
+             template_dir=args.template_dir,
+             model_name=model_name,
+             emit_testbench=args.emit_testbench,
+             command_line=args.command_line,
+             model_checksum=model_checksum,
+             restrict_arrays=args.restrict_arrays,
+         )
+         compiler = Compiler(options)
+         if args.emit_data_file:
+             generated, data_source = compiler.compile_with_data_file(model)
+         else:
+             generated = compiler.compile(model)
+             data_source = None
+     except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
+         LOGGER.error("Failed to compile %s: %s", model_path, exc)
+         return 1
+
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+     output_path.write_text(generated, encoding="utf-8")
+     LOGGER.info("Wrote C source to %s", output_path)
+     if data_source is not None:
+         data_path = output_path.with_name(
+             f"{output_path.stem}_data{output_path.suffix}"
+         )
+         data_path.write_text(data_source, encoding="utf-8")
+         LOGGER.info("Wrote data source to %s", data_path)
+     return 0
+
+
+ def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str] | None:
+     def maybe_prefix_ccache(tokens: list[str]) -> list[str]:
+         if not prefer_ccache:
+             return tokens
+         ccache = shutil.which("ccache")
+         if not ccache:
+             return tokens
+         return [ccache, *tokens]
+
+     def resolve_tokens(tokens: list[str]) -> list[str] | None:
+         if not tokens:
+             return None
+         if shutil.which(tokens[0]):
+             return tokens
+         for token in reversed(tokens):
+             if shutil.which(token):
+                 return [token]
+         return None
+
+     if cc:
+         return resolve_tokens(shlex.split(cc))
+     env_cc = os.environ.get("CC")
+     if env_cc:
+         return resolve_tokens(shlex.split(env_cc))
+     for candidate in ("cc", "gcc", "clang"):
+         resolved = shutil.which(candidate)
+         if resolved:
+             return maybe_prefix_ccache([resolved])
+     return None
+
+
+ def _handle_verify(args: argparse.Namespace) -> int:
+     import numpy as np
+     import onnxruntime as ort
+
+     model_path: Path = args.model
+     model_name = args.model_name or model_path.stem
+     model_checksum = _model_checksum(model_path)
+     compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
+     if compiler_cmd is None:
+         LOGGER.error("No C compiler found (set --cc or CC environment variable).")
+         return 1
+     try:
+         model = onnx.load_model(model_path)
+         options = CompilerOptions(
+             template_dir=args.template_dir,
+             model_name=model_name,
+             emit_testbench=True,
+             command_line=args.command_line,
+             model_checksum=model_checksum,
+             restrict_arrays=args.restrict_arrays,
+         )
+         compiler = Compiler(options)
+         generated = compiler.compile(model)
+     except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
+         LOGGER.error("Failed to compile %s: %s", model_path, exc)
+         return 1
+
+     try:
+         graph = import_onnx(model)
+         output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
+         input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
+     except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
+         LOGGER.error("Failed to resolve model dtype: %s", exc)
+         return 1
+
+     with tempfile.TemporaryDirectory() as temp_dir:
+         temp_path = Path(temp_dir)
+         c_path = temp_path / "model.c"
+         exe_path = temp_path / "model"
+         c_path.write_text(generated, encoding="utf-8")
+         try:
+             subprocess.run(
+                 [
+                     *compiler_cmd,
+                     "-std=c99",
+                     "-O2",
+                     str(c_path),
+                     "-o",
+                     str(exe_path),
+                     "-lm",
+                 ],
+                 check=True,
+                 capture_output=True,
+                 text=True,
+             )
+         except subprocess.CalledProcessError as exc:
+             LOGGER.error("Failed to build testbench: %s", exc.stderr.strip())
+             return 1
+         try:
+             result = subprocess.run(
+                 [str(exe_path)],
+                 check=True,
+                 capture_output=True,
+                 text=True,
+             )
+         except subprocess.CalledProcessError as exc:
+             LOGGER.error("Testbench execution failed: %s", exc.stderr.strip())
+             return 1
+
+     try:
+         payload = json.loads(result.stdout)
+     except json.JSONDecodeError as exc:
+         LOGGER.error("Failed to parse testbench JSON: %s", exc)
+         return 1
+
+     inputs = {
+         name: np.array(value["data"], dtype=input_dtypes[name].np_dtype)
+         for name, value in payload["inputs"].items()
+     }
+     sess = ort.InferenceSession(
+         model.SerializeToString(), providers=["CPUExecutionProvider"]
+     )
+     try:
+         ort_outputs = sess.run(None, inputs)
+     except Exception as exc:
+         message = str(exc)
+         if "NOT_IMPLEMENTED" in message:
+             LOGGER.warning(
+                 "Skipping verification for %s: ONNX Runtime does not support the model (%s)",
+                 model_path,
+                 message,
+             )
+             return 0
+         LOGGER.error("ONNX Runtime failed to run %s: %s", model_path, message)
+         return 1
+     payload_outputs = payload.get("outputs", {})
+     try:
+         for value, ort_out in zip(graph.outputs, ort_outputs):
+             output_payload = payload_outputs.get(value.name)
+             if output_payload is None:
+                 raise AssertionError(f"Missing output {value.name} in testbench data")
+             info = output_dtypes[value.name]
+             output_data = np.array(output_payload["data"], dtype=info.np_dtype)
+             if np.issubdtype(info.np_dtype, np.floating):
+                 np.testing.assert_allclose(
+                     output_data, ort_out, rtol=1e-4, atol=1e-5
+                 )
+             else:
+                 np.testing.assert_array_equal(output_data, ort_out)
+     except AssertionError as exc:
+         LOGGER.error("Verification failed: %s", exc)
+         return 1
+     LOGGER.info("Verification succeeded for %s", model_path)
+     return 0
+
+
+ def _format_command_line(argv: Sequence[str] | None) -> str:
+     if argv is None:
+         argv = sys.argv
+     args = [str(arg) for arg in argv[1:]]
+     if not args:
+         return ""
+     return shlex.join(args)
+
+
+ def _model_checksum(model_path: Path) -> str:
+     digest = hashlib.sha256()
+     digest.update(model_path.read_bytes())
+     return digest.hexdigest()
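Taken together, _handle_compile shows the programmatic path through the package: hash the model file, load it with onnx.load_model, build CompilerOptions, and write the string returned by Compiler.compile to disk. A minimal sketch of that flow, with illustrative file names and the same keyword arguments the CLI passes:

    import hashlib
    from pathlib import Path

    import onnx

    from emx_onnx_cgen import Compiler
    from emx_onnx_cgen.compiler import CompilerOptions

    model_path = Path("model.onnx")  # illustrative input model
    options = CompilerOptions(
        template_dir=Path("templates"),
        model_name=model_path.stem,
        emit_testbench=False,
        command_line="",
        model_checksum=hashlib.sha256(model_path.read_bytes()).hexdigest(),
        restrict_arrays=True,
    )
    generated = Compiler(options).compile(onnx.load_model(model_path))
    Path("model.c").write_text(generated, encoding="utf-8")

Failures surface as the package's own exception types (CodegenError, ShapeInferenceError, UnsupportedOpError), which is what the CLI catches alongside OSError.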
emx_onnx_cgen/codegen/__init__.py ADDED
@@ -0,0 +1,25 @@
+ from .c_emitter import (
+     BinaryOp,
+     CEmitter,
+     CastOp,
+     ConstTensor,
+     ConstantOfShapeOp,
+     GemmOp,
+     LoweredModel,
+     MatMulOp,
+     ShapeOp,
+     UnaryOp,
+ )
+
+ __all__ = [
+     "BinaryOp",
+     "CEmitter",
+     "CastOp",
+     "ConstTensor",
+     "ConstantOfShapeOp",
+     "GemmOp",
+     "LoweredModel",
+     "MatMulOp",
+     "ShapeOp",
+     "UnaryOp",
+ ]
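The codegen package simply re-exports the emitter and its lowered-op classes from c_emitter, so downstream code can import them from one place rather than from the large c_emitter module, e.g.:

    from emx_onnx_cgen.codegen import CEmitter, LoweredModel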