emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. Click here for more details.

Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/_build_info.py CHANGED
@@ -1,3 +1,3 @@
1
1
  """Auto-generated by build backend. Do not edit."""
2
- BUILD_DATE = '2026-01-15T22:18:39Z'
2
+ BUILD_DATE = '2026-01-23T02:44:13Z'
3
3
  GIT_VERSION = 'unknown'
emx_onnx_cgen/_version.py ADDED
@@ -0,0 +1,34 @@
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
12
+
13
+ TYPE_CHECKING = False
14
+ if TYPE_CHECKING:
15
+ from typing import Tuple
16
+ from typing import Union
17
+
18
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
20
+ else:
21
+ VERSION_TUPLE = object
22
+ COMMIT_ID = object
23
+
24
+ version: str
25
+ __version__: str
26
+ __version_tuple__: VERSION_TUPLE
27
+ version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
30
+
31
+ __version__ = version = '0.3.1'
32
+ __version_tuple__ = version_tuple = (0, 3, 1)
33
+
34
+ __commit_id__ = commit_id = None
emx_onnx_cgen/cli.py CHANGED
@@ -10,18 +10,89 @@ import shutil
10
10
  import subprocess
11
11
  import sys
12
12
  import tempfile
13
+ import time
14
+ import signal
13
15
  from pathlib import Path
14
- from typing import Sequence
16
+ from dataclasses import dataclass
17
+ from typing import TYPE_CHECKING, Mapping, Sequence
15
18
 
16
19
  import onnx
20
+ from onnx import numpy_helper
17
21
 
18
22
  from ._build_info import BUILD_DATE, GIT_VERSION
19
23
  from .compiler import Compiler, CompilerOptions
20
24
  from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
21
25
  from .onnx_import import import_onnx
26
+ from .onnxruntime_utils import make_deterministic_session_options
27
+ from .testbench import decode_testbench_array
28
+ from .verification import format_success_message, max_ulp_diff
22
29
 
23
30
  LOGGER = logging.getLogger(__name__)
24
31
 
32
+ if TYPE_CHECKING:
33
+ import numpy as np
34
+
35
+
36
+ @dataclass(frozen=True)
37
+ class CliResult:
38
+ exit_code: int
39
+ command_line: str
40
+ error: str | None = None
41
+ success_message: str | None = None
42
+ generated: str | None = None
43
+ data_source: str | None = None
44
+ operators: list[str] | None = None
45
+
46
+
47
+ def run_cli_command(
48
+ argv: Sequence[str],
49
+ *,
50
+ testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
51
+ ) -> CliResult:
52
+ raw_argv = list(argv)
53
+ parse_argv = raw_argv
54
+ if raw_argv and raw_argv[0] == "emx-onnx-cgen":
55
+ parse_argv = raw_argv[1:]
56
+ parser = _build_parser()
57
+ args = parser.parse_args(parse_argv)
58
+ args.command_line = _format_command_line(raw_argv)
59
+
60
+ try:
61
+ if args.command != "compile":
62
+ success_message, error, operators = _verify_model(
63
+ args, include_build_details=False
64
+ )
65
+ return CliResult(
66
+ exit_code=0 if error is None else 1,
67
+ command_line=args.command_line,
68
+ error=error,
69
+ success_message=success_message,
70
+ operators=operators,
71
+ )
72
+ generated, data_source, error = _compile_model(
73
+ args, testbench_inputs=testbench_inputs
74
+ )
75
+ if error:
76
+ return CliResult(
77
+ exit_code=1,
78
+ command_line=args.command_line,
79
+ error=error,
80
+ )
81
+ return CliResult(
82
+ exit_code=0,
83
+ command_line=args.command_line,
84
+ success_message="",
85
+ generated=generated,
86
+ data_source=data_source,
87
+ )
88
+ except Exception as exc: # pragma: no cover - defensive reporting
89
+ LOGGER.exception("Unhandled exception while running CLI command.")
90
+ return CliResult(
91
+ exit_code=1,
92
+ command_line=args.command_line,
93
+ error=str(exc),
94
+ )
95
+
25
96
 
26
97
  def _build_parser() -> argparse.ArgumentParser:
27
98
  description = (
@@ -86,6 +157,33 @@ def _build_parser() -> argparse.ArgumentParser:
86
157
  "named like the output with a _data suffix"
87
158
  ),
88
159
  )
160
+ compile_parser.add_argument(
161
+ "--truncate-weights-after",
162
+ type=int,
163
+ default=None,
164
+ help=(
165
+ "Truncate inline weight initializers after N values and insert "
166
+ "\"...\" placeholders (default: no truncation)"
167
+ ),
168
+ )
169
+ compile_parser.add_argument(
170
+ "--large-temp-threshold-bytes",
171
+ type=int,
172
+ default=1024,
173
+ help=(
174
+ "Mark temporary buffers larger than this threshold as static "
175
+ "(default: 1024)"
176
+ ),
177
+ )
178
+ compile_parser.add_argument(
179
+ "--large-weight-threshold",
180
+ type=int,
181
+ default=1024 * 1024,
182
+ help=(
183
+ "Store weights larger than this element count in a binary file "
184
+ "(default: 1048576; set to 0 to disable)"
185
+ ),
186
+ )
89
187
  add_restrict_flags(compile_parser)
90
188
 
91
189
  verify_parser = subparsers.add_parser(
@@ -111,6 +209,57 @@ def _build_parser() -> argparse.ArgumentParser:
111
209
  default=None,
112
210
  help="C compiler command to build the testbench binary",
113
211
  )
212
+ verify_parser.add_argument(
213
+ "--truncate-weights-after",
214
+ type=int,
215
+ default=None,
216
+ help=(
217
+ "Truncate inline weight initializers after N values and insert "
218
+ "\"...\" placeholders (default: no truncation)"
219
+ ),
220
+ )
221
+ verify_parser.add_argument(
222
+ "--large-temp-threshold-bytes",
223
+ type=int,
224
+ default=1024,
225
+ help=(
226
+ "Mark temporary buffers larger than this threshold as static "
227
+ "(default: 1024)"
228
+ ),
229
+ )
230
+ verify_parser.add_argument(
231
+ "--large-weight-threshold",
232
+ type=int,
233
+ default=1024,
234
+ help=(
235
+ "Store weights larger than this element count in a binary file "
236
+ "(default: 1024)"
237
+ ),
238
+ )
239
+ verify_parser.add_argument(
240
+ "--test-data-dir",
241
+ type=Path,
242
+ default=None,
243
+ help=(
244
+ "Directory containing input_*.pb files to seed verification inputs "
245
+ "(default: use random testbench inputs)"
246
+ ),
247
+ )
248
+ verify_parser.add_argument(
249
+ "--max-ulp",
250
+ type=int,
251
+ default=100,
252
+ help="Maximum allowed ULP difference for floating outputs (default: 100)",
253
+ )
254
+ verify_parser.add_argument(
255
+ "--runtime",
256
+ choices=("onnxruntime", "onnx-reference"),
257
+ default="onnx-reference",
258
+ help=(
259
+ "Runtime backend for verification (default: onnx-reference; "
260
+ "options: onnxruntime, onnx-reference)"
261
+ ),
262
+ )
114
263
  add_restrict_flags(verify_parser)
115
264
  return parser
116
265
 
@@ -132,7 +281,35 @@ def main(argv: Sequence[str] | None = None) -> int:
132
281
  def _handle_compile(args: argparse.Namespace) -> int:
133
282
  model_path: Path = args.model
134
283
  output_path: Path = args.output or model_path.with_suffix(".c")
135
- model_name = args.model_name or output_path.stem
284
+ model_name = args.model_name or "model"
285
+ generated, data_source, weight_data, error = _compile_model(args)
286
+ if error:
287
+ LOGGER.error("Failed to compile %s: %s", model_path, error)
288
+ return 1
289
+
290
+ output_path.parent.mkdir(parents=True, exist_ok=True)
291
+ output_path.write_text(generated or "", encoding="utf-8")
292
+ LOGGER.info("Wrote C source to %s", output_path)
293
+ if data_source is not None:
294
+ data_path = output_path.with_name(
295
+ f"{output_path.stem}_data{output_path.suffix}"
296
+ )
297
+ data_path.write_text(data_source, encoding="utf-8")
298
+ LOGGER.info("Wrote data source to %s", data_path)
299
+ if weight_data is not None:
300
+ weights_path = output_path.with_name(f"{model_name}.bin")
301
+ weights_path.write_bytes(weight_data)
302
+ LOGGER.info("Wrote weights binary to %s", weights_path)
303
+ return 0
304
+
305
+
306
+ def _compile_model(
307
+ args: argparse.Namespace,
308
+ *,
309
+ testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
310
+ ) -> tuple[str | None, str | None, bytes | None, str | None]:
311
+ model_path: Path = args.model
312
+ model_name = args.model_name or "model"
136
313
  try:
137
314
  model_checksum = _model_checksum(model_path)
138
315
  model = onnx.load_model(model_path)
@@ -143,27 +320,22 @@ def _handle_compile(args: argparse.Namespace) -> int:
143
320
  command_line=args.command_line,
144
321
  model_checksum=model_checksum,
145
322
  restrict_arrays=args.restrict_arrays,
323
+ truncate_weights_after=args.truncate_weights_after,
324
+ large_temp_threshold_bytes=args.large_temp_threshold_bytes,
325
+ large_weight_threshold=args.large_weight_threshold,
326
+ testbench_inputs=testbench_inputs,
146
327
  )
147
328
  compiler = Compiler(options)
148
329
  if args.emit_data_file:
149
- generated, data_source = compiler.compile_with_data_file(model)
330
+ generated, data_source, weight_data = (
331
+ compiler.compile_with_data_file_and_weight_data(model)
332
+ )
150
333
  else:
151
- generated = compiler.compile(model)
334
+ generated, weight_data = compiler.compile_with_weight_data(model)
152
335
  data_source = None
153
336
  except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
154
- LOGGER.error("Failed to compile %s: %s", model_path, exc)
155
- return 1
156
-
157
- output_path.parent.mkdir(parents=True, exist_ok=True)
158
- output_path.write_text(generated, encoding="utf-8")
159
- LOGGER.info("Wrote C source to %s", output_path)
160
- if data_source is not None:
161
- data_path = output_path.with_name(
162
- f"{output_path.stem}_data{output_path.suffix}"
163
- )
164
- data_path.write_text(data_source, encoding="utf-8")
165
- LOGGER.info("Wrote data source to %s", data_path)
166
- return 0
337
+ return None, None, None, str(exc)
338
+ return generated, data_source, weight_data, None
167
339
 
168
340
 
169
341
  def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str] | None:
@@ -198,18 +370,59 @@ def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str]
198
370
 
199
371
 
200
372
  def _handle_verify(args: argparse.Namespace) -> int:
373
+ success_message, error, _operators = _verify_model(
374
+ args, include_build_details=True
375
+ )
376
+ if error is not None:
377
+ LOGGER.error("Verification failed: %s", error)
378
+ return 1
379
+ if success_message:
380
+ LOGGER.info("%s", success_message)
381
+ return 0
382
+
383
+
384
+ def _verify_model(
385
+ args: argparse.Namespace,
386
+ *,
387
+ include_build_details: bool,
388
+ ) -> tuple[str | None, str | None, list[str]]:
201
389
  import numpy as np
202
- import onnxruntime as ort
390
+
391
+ def log_step(step: str, started_at: float) -> None:
392
+ duration = time.perf_counter() - started_at
393
+ LOGGER.info("verify step %s: %.3fs", step, duration)
394
+
395
+ def describe_exit_code(returncode: int) -> str:
396
+ if returncode >= 0:
397
+ return f"exit code {returncode}"
398
+ signal_id = -returncode
399
+ try:
400
+ signal_name = signal.Signals(signal_id).name
401
+ except ValueError:
402
+ signal_name = "unknown"
403
+ return f"exit code {returncode} (signal {signal_id}: {signal_name})"
203
404
 
204
405
  model_path: Path = args.model
205
- model_name = args.model_name or model_path.stem
406
+ model_name = args.model_name or "model"
206
407
  model_checksum = _model_checksum(model_path)
207
408
  compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
208
409
  if compiler_cmd is None:
209
- LOGGER.error("No C compiler found (set --cc or CC environment variable).")
210
- return 1
410
+ return (
411
+ None,
412
+ "No C compiler found (set --cc or CC environment variable).",
413
+ [],
414
+ )
211
415
  try:
212
416
  model = onnx.load_model(model_path)
417
+ except OSError as exc:
418
+ return None, str(exc), []
419
+
420
+ operators = _collect_model_operators(model)
421
+ operators_display = ", ".join(operators) if operators else "(none)"
422
+ LOGGER.info("verify operators: %s", operators_display)
423
+
424
+ try:
425
+ testbench_inputs = _load_test_data_inputs(model, args.test_data_dir)
213
426
  options = CompilerOptions(
214
427
  template_dir=args.template_dir,
215
428
  model_name=model_name,
@@ -217,100 +430,183 @@ def _handle_verify(args: argparse.Namespace) -> int:
217
430
  command_line=args.command_line,
218
431
  model_checksum=model_checksum,
219
432
  restrict_arrays=args.restrict_arrays,
433
+ truncate_weights_after=args.truncate_weights_after,
434
+ large_temp_threshold_bytes=args.large_temp_threshold_bytes,
435
+ large_weight_threshold=args.large_weight_threshold,
436
+ testbench_inputs=testbench_inputs,
220
437
  )
221
438
  compiler = Compiler(options)
222
- generated = compiler.compile(model)
223
- except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
224
- LOGGER.error("Failed to compile %s: %s", model_path, exc)
225
- return 1
439
+ codegen_started = time.perf_counter()
440
+ generated, weight_data = compiler.compile_with_weight_data(model)
441
+ log_step("codegen", codegen_started)
442
+ except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
443
+ return None, str(exc), operators
226
444
 
227
445
  try:
228
446
  graph = import_onnx(model)
229
447
  output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
230
448
  input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
231
449
  except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
232
- LOGGER.error("Failed to resolve model dtype: %s", exc)
233
- return 1
450
+ return None, f"Failed to resolve model dtype: {exc}", operators
234
451
 
235
452
  with tempfile.TemporaryDirectory() as temp_dir:
236
453
  temp_path = Path(temp_dir)
454
+ LOGGER.info("verify temp dir: %s", temp_path)
237
455
  c_path = temp_path / "model.c"
456
+ weights_path = temp_path / f"{model_name}.bin"
238
457
  exe_path = temp_path / "model"
239
458
  c_path.write_text(generated, encoding="utf-8")
459
+ if weight_data is not None:
460
+ weights_path.write_bytes(weight_data)
240
461
  try:
462
+ compile_started = time.perf_counter()
463
+ compile_cmd = [
464
+ *compiler_cmd,
465
+ "-std=c99",
466
+ "-O2",
467
+ str(c_path),
468
+ "-o",
469
+ str(exe_path),
470
+ "-lm",
471
+ ]
472
+ LOGGER.info("verify compile command: %s", shlex.join(compile_cmd))
241
473
  subprocess.run(
242
- [
243
- *compiler_cmd,
244
- "-std=c99",
245
- "-O2",
246
- str(c_path),
247
- "-o",
248
- str(exe_path),
249
- "-lm",
250
- ],
474
+ compile_cmd,
251
475
  check=True,
252
476
  capture_output=True,
253
477
  text=True,
254
478
  )
479
+ log_step("compile", compile_started)
255
480
  except subprocess.CalledProcessError as exc:
256
- LOGGER.error("Failed to build testbench: %s", exc.stderr.strip())
257
- return 1
481
+ message = "Failed to build testbench."
482
+ if include_build_details:
483
+ details = exc.stderr.strip()
484
+ if details:
485
+ message = f"{message} {details}"
486
+ return None, message, operators
258
487
  try:
488
+ run_started = time.perf_counter()
259
489
  result = subprocess.run(
260
490
  [str(exe_path)],
261
491
  check=True,
262
492
  capture_output=True,
263
493
  text=True,
494
+ cwd=temp_path,
264
495
  )
496
+ log_step("run", run_started)
265
497
  except subprocess.CalledProcessError as exc:
266
- LOGGER.error("Testbench execution failed: %s", exc.stderr.strip())
267
- return 1
498
+ return None, (
499
+ "Testbench execution failed: " + describe_exit_code(exc.returncode)
500
+ ), operators
268
501
 
269
502
  try:
270
503
  payload = json.loads(result.stdout)
271
504
  except json.JSONDecodeError as exc:
272
- LOGGER.error("Failed to parse testbench JSON: %s", exc)
273
- return 1
505
+ return None, f"Failed to parse testbench JSON: {exc}", operators
274
506
 
275
- inputs = {
276
- name: np.array(value["data"], dtype=input_dtypes[name].np_dtype)
277
- for name, value in payload["inputs"].items()
278
- }
279
- sess = ort.InferenceSession(
280
- model.SerializeToString(), providers=["CPUExecutionProvider"]
281
- )
507
+ if testbench_inputs:
508
+ inputs = {
509
+ name: values.astype(input_dtypes[name].np_dtype, copy=False)
510
+ for name, values in testbench_inputs.items()
511
+ }
512
+ else:
513
+ inputs = {
514
+ name: decode_testbench_array(
515
+ value["data"], input_dtypes[name].np_dtype
516
+ )
517
+ for name, value in payload["inputs"].items()
518
+ }
519
+ runtime_name = args.runtime
520
+ runtime_started = time.perf_counter()
282
521
  try:
283
- ort_outputs = sess.run(None, inputs)
522
+ if runtime_name == "onnxruntime":
523
+ import onnxruntime as ort
524
+
525
+ sess_options = make_deterministic_session_options(ort)
526
+ sess = ort.InferenceSession(
527
+ model.SerializeToString(),
528
+ sess_options=sess_options,
529
+ providers=["CPUExecutionProvider"],
530
+ )
531
+ runtime_outputs = sess.run(None, inputs)
532
+ else:
533
+ from onnx.reference import ReferenceEvaluator
534
+
535
+ evaluator = ReferenceEvaluator(model)
536
+ runtime_outputs = evaluator.run(None, inputs)
284
537
  except Exception as exc:
538
+ log_step(runtime_name, runtime_started)
285
539
  message = str(exc)
286
- if "NOT_IMPLEMENTED" in message:
540
+ if runtime_name == "onnxruntime" and "NOT_IMPLEMENTED" in message:
287
541
  LOGGER.warning(
288
542
  "Skipping verification for %s: ONNX Runtime does not support the model (%s)",
289
543
  model_path,
290
544
  message,
291
545
  )
292
- return 0
293
- LOGGER.error("ONNX Runtime failed to run %s: %s", model_path, message)
294
- return 1
546
+ return "", None, operators
547
+ return (
548
+ None,
549
+ f"{runtime_name} failed to run {model_path}: {message}",
550
+ operators,
551
+ )
552
+ log_step(runtime_name, runtime_started)
295
553
  payload_outputs = payload.get("outputs", {})
554
+ max_ulp = 0
296
555
  try:
297
- for value, ort_out in zip(graph.outputs, ort_outputs):
556
+ for value, runtime_out in zip(graph.outputs, runtime_outputs):
298
557
  output_payload = payload_outputs.get(value.name)
299
558
  if output_payload is None:
300
559
  raise AssertionError(f"Missing output {value.name} in testbench data")
301
560
  info = output_dtypes[value.name]
302
- output_data = np.array(output_payload["data"], dtype=info.np_dtype)
561
+ output_data = decode_testbench_array(
562
+ output_payload["data"], info.np_dtype
563
+ ).astype(info.np_dtype, copy=False)
564
+ runtime_out = runtime_out.astype(info.np_dtype, copy=False)
565
+ output_data = output_data.reshape(runtime_out.shape)
303
566
  if np.issubdtype(info.np_dtype, np.floating):
304
- np.testing.assert_allclose(
305
- output_data, ort_out, rtol=1e-4, atol=1e-5
306
- )
567
+ max_ulp = max(max_ulp, max_ulp_diff(output_data, runtime_out))
307
568
  else:
308
- np.testing.assert_array_equal(output_data, ort_out)
569
+ np.testing.assert_array_equal(output_data, runtime_out)
309
570
  except AssertionError as exc:
310
- LOGGER.error("Verification failed: %s", exc)
311
- return 1
312
- LOGGER.info("Verification succeeded for %s", model_path)
313
- return 0
571
+ return None, str(exc), operators
572
+ if max_ulp > args.max_ulp:
573
+ return None, f"Out of tolerance (max ULP {max_ulp})", operators
574
+ return format_success_message(max_ulp), None, operators
575
+
576
+
577
+ def _load_test_data_inputs(
578
+ model: onnx.ModelProto, data_dir: Path | None
579
+ ) -> dict[str, "np.ndarray"] | None:
580
+ if data_dir is None:
581
+ return None
582
+ if not data_dir.exists():
583
+ raise CodegenError(f"Test data directory not found: {data_dir}")
584
+ input_files = sorted(
585
+ data_dir.glob("input_*.pb"),
586
+ key=lambda path: int(path.stem.split("_")[-1]),
587
+ )
588
+ if not input_files:
589
+ raise CodegenError(f"No input_*.pb files found in {data_dir}")
590
+ if len(input_files) != len(model.graph.input):
591
+ raise CodegenError(
592
+ "Test data input count does not match model inputs: "
593
+ f"{len(input_files)} vs {len(model.graph.input)}."
594
+ )
595
+ for value_info in model.graph.input:
596
+ value_kind = value_info.type.WhichOneof("value")
597
+ if value_kind != "tensor_type":
598
+ LOGGER.warning(
599
+ "Skipping test data load for non-tensor input %s (type %s).",
600
+ value_info.name,
601
+ value_kind or "unknown",
602
+ )
603
+ return None
604
+ inputs: dict[str, np.ndarray] = {}
605
+ for index, path in enumerate(input_files):
606
+ tensor = onnx.TensorProto()
607
+ tensor.ParseFromString(path.read_bytes())
608
+ inputs[model.graph.input[index].name] = numpy_helper.to_array(tensor)
609
+ return inputs
314
610
 
315
611
 
316
612
  def _format_command_line(argv: Sequence[str] | None) -> str:
@@ -326,3 +622,15 @@ def _model_checksum(model_path: Path) -> str:
326
622
  digest = hashlib.sha256()
327
623
  digest.update(model_path.read_bytes())
328
624
  return digest.hexdigest()
625
+
626
+
627
+ def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
628
+ operators: list[str] = []
629
+ seen: set[str] = set()
630
+ for node in model.graph.node:
631
+ op_name = f"{node.domain}::{node.op_type}" if node.domain else node.op_type
632
+ if op_name in seen:
633
+ continue
634
+ seen.add(op_name)
635
+ operators.append(op_name)
636
+ return operators
emx_onnx_cgen/codegen/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .c_emitter import (
7
7
  GemmOp,
8
8
  LoweredModel,
9
9
  MatMulOp,
10
+ QLinearMatMulOp,
10
11
  ShapeOp,
11
12
  UnaryOp,
12
13
  )
@@ -20,6 +21,7 @@ __all__ = [
20
21
  "GemmOp",
21
22
  "LoweredModel",
22
23
  "MatMulOp",
24
+ "QLinearMatMulOp",
23
25
  "ShapeOp",
24
26
  "UnaryOp",
25
27
  ]