emx-onnx-cgen 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen has been flagged as possibly problematic by the registry.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/cli.py
ADDED
@@ -0,0 +1,328 @@
from __future__ import annotations

import argparse
import hashlib
import json
import logging
import os
import shlex
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Sequence

import onnx

from ._build_info import BUILD_DATE, GIT_VERSION
from .compiler import Compiler, CompilerOptions
from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
from .onnx_import import import_onnx

LOGGER = logging.getLogger(__name__)


def _build_parser() -> argparse.ArgumentParser:
    description = (
        "emmtrix ONNX-to-C Code Generator "
        f"(build date: {BUILD_DATE}, git: {GIT_VERSION})"
    )
    parser = argparse.ArgumentParser(prog="emx-onnx-cgen", description=description)
    subparsers = parser.add_subparsers(dest="command", required=True)

    def add_restrict_flags(subparser: argparse.ArgumentParser) -> None:
        restrict_group = subparser.add_mutually_exclusive_group()
        restrict_group.add_argument(
            "--restrict-arrays",
            dest="restrict_arrays",
            action="store_true",
            help="Enable restrict qualifiers on generated array parameters",
        )
        restrict_group.add_argument(
            "--no-restrict-arrays",
            dest="restrict_arrays",
            action="store_false",
            help="Disable restrict qualifiers on generated array parameters",
        )
        subparser.set_defaults(restrict_arrays=True)

    compile_parser = subparsers.add_parser(
        "compile", help="Compile an ONNX model into C source"
    )
    compile_parser.add_argument("model", type=Path, help="Path to the ONNX model")
    compile_parser.add_argument(
        "output",
        type=Path,
        nargs="?",
        default=None,
        help=(
            "Output C file path (default: use model filename with .c suffix, "
            "e.g., model.onnx -> model.c)"
        ),
    )
    compile_parser.add_argument(
        "--template-dir",
        type=Path,
        default=Path("templates"),
        help="Template directory (default: templates)",
    )
    compile_parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help="Override the generated model name (default: output file stem)",
    )
    compile_parser.add_argument(
        "--emit-testbench",
        action="store_true",
        help="Emit a JSON-producing testbench main() for validation",
    )
    compile_parser.add_argument(
        "--emit-data-file",
        action="store_true",
        help=(
            "Emit constant data arrays to a separate C file "
            "named like the output with a _data suffix"
        ),
    )
    add_restrict_flags(compile_parser)

    verify_parser = subparsers.add_parser(
        "verify",
        help="Compile an ONNX model and verify outputs against ONNX Runtime",
    )
    verify_parser.add_argument("model", type=Path, help="Path to the ONNX model")
    verify_parser.add_argument(
        "--template-dir",
        type=Path,
        default=Path("templates"),
        help="Template directory (default: templates)",
    )
    verify_parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help="Override the generated model name (default: model file stem)",
    )
    verify_parser.add_argument(
        "--cc",
        type=str,
        default=None,
        help="C compiler command to build the testbench binary",
    )
    add_restrict_flags(verify_parser)
    return parser


def main(argv: Sequence[str] | None = None) -> int:
    logging.basicConfig(level=logging.INFO)
    parser = _build_parser()
    args = parser.parse_args(argv)
    args.command_line = _format_command_line(argv)

    if args.command == "compile":
        return _handle_compile(args)
    if args.command == "verify":
        return _handle_verify(args)
    parser.error(f"Unknown command {args.command}")
    return 1


def _handle_compile(args: argparse.Namespace) -> int:
    model_path: Path = args.model
    output_path: Path = args.output or model_path.with_suffix(".c")
    model_name = args.model_name or output_path.stem
    try:
        model_checksum = _model_checksum(model_path)
        model = onnx.load_model(model_path)
        options = CompilerOptions(
            template_dir=args.template_dir,
            model_name=model_name,
            emit_testbench=args.emit_testbench,
            command_line=args.command_line,
            model_checksum=model_checksum,
            restrict_arrays=args.restrict_arrays,
        )
        compiler = Compiler(options)
        if args.emit_data_file:
            generated, data_source = compiler.compile_with_data_file(model)
        else:
            generated = compiler.compile(model)
            data_source = None
    except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
        LOGGER.error("Failed to compile %s: %s", model_path, exc)
        return 1

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(generated, encoding="utf-8")
    LOGGER.info("Wrote C source to %s", output_path)
    if data_source is not None:
        data_path = output_path.with_name(
            f"{output_path.stem}_data{output_path.suffix}"
        )
        data_path.write_text(data_source, encoding="utf-8")
        LOGGER.info("Wrote data source to %s", data_path)
    return 0


def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str] | None:
    def maybe_prefix_ccache(tokens: list[str]) -> list[str]:
        if not prefer_ccache:
            return tokens
        ccache = shutil.which("ccache")
        if not ccache:
            return tokens
        return [ccache, *tokens]

    def resolve_tokens(tokens: list[str]) -> list[str] | None:
        if not tokens:
            return None
        if shutil.which(tokens[0]):
            return tokens
        for token in reversed(tokens):
            if shutil.which(token):
                return [token]
        return None

    if cc:
        return resolve_tokens(shlex.split(cc))
    env_cc = os.environ.get("CC")
    if env_cc:
        return resolve_tokens(shlex.split(env_cc))
    for candidate in ("cc", "gcc", "clang"):
        resolved = shutil.which(candidate)
        if resolved:
            return maybe_prefix_ccache([resolved])
    return None


def _handle_verify(args: argparse.Namespace) -> int:
    import numpy as np
    import onnxruntime as ort

    model_path: Path = args.model
    model_name = args.model_name or model_path.stem
    model_checksum = _model_checksum(model_path)
    compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
    if compiler_cmd is None:
        LOGGER.error("No C compiler found (set --cc or CC environment variable).")
        return 1
    try:
        model = onnx.load_model(model_path)
        options = CompilerOptions(
            template_dir=args.template_dir,
            model_name=model_name,
            emit_testbench=True,
            command_line=args.command_line,
            model_checksum=model_checksum,
            restrict_arrays=args.restrict_arrays,
        )
        compiler = Compiler(options)
        generated = compiler.compile(model)
    except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
        LOGGER.error("Failed to compile %s: %s", model_path, exc)
        return 1

    try:
        graph = import_onnx(model)
        output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
        input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
    except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
        LOGGER.error("Failed to resolve model dtype: %s", exc)
        return 1

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        c_path = temp_path / "model.c"
        exe_path = temp_path / "model"
        c_path.write_text(generated, encoding="utf-8")
        try:
            subprocess.run(
                [
                    *compiler_cmd,
                    "-std=c99",
                    "-O2",
                    str(c_path),
                    "-o",
                    str(exe_path),
                    "-lm",
                ],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as exc:
            LOGGER.error("Failed to build testbench: %s", exc.stderr.strip())
            return 1
        try:
            result = subprocess.run(
                [str(exe_path)],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as exc:
            LOGGER.error("Testbench execution failed: %s", exc.stderr.strip())
            return 1

    try:
        payload = json.loads(result.stdout)
    except json.JSONDecodeError as exc:
        LOGGER.error("Failed to parse testbench JSON: %s", exc)
        return 1

    inputs = {
        name: np.array(value["data"], dtype=input_dtypes[name].np_dtype)
        for name, value in payload["inputs"].items()
    }
    sess = ort.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    try:
        ort_outputs = sess.run(None, inputs)
    except Exception as exc:
        message = str(exc)
        if "NOT_IMPLEMENTED" in message:
            LOGGER.warning(
                "Skipping verification for %s: ONNX Runtime does not support the model (%s)",
                model_path,
                message,
            )
            return 0
        LOGGER.error("ONNX Runtime failed to run %s: %s", model_path, message)
        return 1
    payload_outputs = payload.get("outputs", {})
    try:
        for value, ort_out in zip(graph.outputs, ort_outputs):
            output_payload = payload_outputs.get(value.name)
            if output_payload is None:
                raise AssertionError(f"Missing output {value.name} in testbench data")
            info = output_dtypes[value.name]
            output_data = np.array(output_payload["data"], dtype=info.np_dtype)
            if np.issubdtype(info.np_dtype, np.floating):
                np.testing.assert_allclose(
                    output_data, ort_out, rtol=1e-4, atol=1e-5
                )
            else:
                np.testing.assert_array_equal(output_data, ort_out)
    except AssertionError as exc:
        LOGGER.error("Verification failed: %s", exc)
        return 1
    LOGGER.info("Verification succeeded for %s", model_path)
    return 0


def _format_command_line(argv: Sequence[str] | None) -> str:
    if argv is None:
        argv = sys.argv
    args = [str(arg) for arg in argv[1:]]
    if not args:
        return ""
    return shlex.join(args)


def _model_checksum(model_path: Path) -> str:
    digest = hashlib.sha256()
    digest.update(model_path.read_bytes())
    return digest.hexdigest()
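For orientation (not part of the diff itself): the compile and verify subcommands defined above can be driven programmatically through main(), which takes the argument list directly; the console-script name emx-onnx-cgen is inferred from prog= in the parser and from entry_points.txt. A minimal sketch, assuming the wheel is installed and "model.onnx" is a placeholder path:

# Hypothetical usage sketch; "model.onnx" / "model.c" are placeholders.
# The verify call additionally needs numpy, onnxruntime, and a C compiler
# discoverable via --cc, $CC, or PATH (cc/gcc/clang).
from emx_onnx_cgen.cli import main

# Roughly equivalent to: emx-onnx-cgen compile model.onnx model.c --emit-testbench
rc = main(["compile", "model.onnx", "model.c", "--emit-testbench"])

# Roughly equivalent to: emx-onnx-cgen verify model.onnx --cc gcc
rc = main(["verify", "model.onnx", "--cc", "gcc"])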
emx_onnx_cgen/codegen/__init__.py
ADDED
@@ -0,0 +1,25 @@
from .c_emitter import (
    BinaryOp,
    CEmitter,
    CastOp,
    ConstTensor,
    ConstantOfShapeOp,
    GemmOp,
    LoweredModel,
    MatMulOp,
    ShapeOp,
    UnaryOp,
)

__all__ = [
    "BinaryOp",
    "CEmitter",
    "CastOp",
    "ConstTensor",
    "ConstantOfShapeOp",
    "GemmOp",
    "LoweredModel",
    "MatMulOp",
    "ShapeOp",
    "UnaryOp",
]
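As a small illustrative note (an assumption about intended usage, not stated in the diff): because of these re-exports, downstream code can import the emitter API from the codegen subpackage rather than reaching into c_emitter directly.

# Illustrative only: the re-exported names are available at the subpackage level.
from emx_onnx_cgen.codegen import CEmitter, GemmOp, LoweredModel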