emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/_build_info.py CHANGED
@@ -1,3 +1,3 @@
  """Auto-generated by build backend. Do not edit."""
- BUILD_DATE = '2026-01-15T22:18:39Z'
+ BUILD_DATE = '2026-01-20T17:39:52Z'
  GIT_VERSION = 'unknown'
emx_onnx_cgen/_version.py ADDED
@@ -0,0 +1,34 @@
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = [
+     "__version__",
+     "__version_tuple__",
+     "version",
+     "version_tuple",
+     "__commit_id__",
+     "commit_id",
+ ]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+     COMMIT_ID = Union[str, None]
+ else:
+     VERSION_TUPLE = object
+     COMMIT_ID = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+ commit_id: COMMIT_ID
+ __commit_id__: COMMIT_ID
+
+ __version__ = version = '0.3.0'
+ __version_tuple__ = version_tuple = (0, 3, 0)
+
+ __commit_id__ = commit_id = None
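The new emx_onnx_cgen/_version.py is the standard setuptools-scm version shim. Purely for orientation (hypothetical caller, not part of the package), it is consumed like any other version attribute:

    from emx_onnx_cgen._version import __version__, __version_tuple__

    # Names exactly as declared in the generated file above.
    assert __version__ == "0.3.0"
    assert __version_tuple__ == (0, 3, 0)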
emx_onnx_cgen/cli.py CHANGED
@@ -10,18 +10,89 @@ import shutil
  import subprocess
  import sys
  import tempfile
+ import time
+ import signal
  from pathlib import Path
- from typing import Sequence
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Mapping, Sequence

  import onnx
+ from onnx import numpy_helper

  from ._build_info import BUILD_DATE, GIT_VERSION
  from .compiler import Compiler, CompilerOptions
  from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
  from .onnx_import import import_onnx
+ from .onnxruntime_utils import make_deterministic_session_options
+ from .testbench import decode_testbench_array
+ from .verification import format_success_message, max_ulp_diff

  LOGGER = logging.getLogger(__name__)

+ if TYPE_CHECKING:
+     import numpy as np
+
+
+ @dataclass(frozen=True)
+ class CliResult:
+     exit_code: int
+     command_line: str
+     error: str | None = None
+     success_message: str | None = None
+     generated: str | None = None
+     data_source: str | None = None
+     operators: list[str] | None = None
+
+
+ def run_cli_command(
+     argv: Sequence[str],
+     *,
+     testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
+ ) -> CliResult:
+     raw_argv = list(argv)
+     parse_argv = raw_argv
+     if raw_argv and raw_argv[0] == "emx-onnx-cgen":
+         parse_argv = raw_argv[1:]
+     parser = _build_parser()
+     args = parser.parse_args(parse_argv)
+     args.command_line = _format_command_line(raw_argv)
+
+     try:
+         if args.command != "compile":
+             success_message, error, operators = _verify_model(
+                 args, include_build_details=False
+             )
+             return CliResult(
+                 exit_code=0 if error is None else 1,
+                 command_line=args.command_line,
+                 error=error,
+                 success_message=success_message,
+                 operators=operators,
+             )
+         generated, data_source, error = _compile_model(
+             args, testbench_inputs=testbench_inputs
+         )
+         if error:
+             return CliResult(
+                 exit_code=1,
+                 command_line=args.command_line,
+                 error=error,
+             )
+         return CliResult(
+             exit_code=0,
+             command_line=args.command_line,
+             success_message="",
+             generated=generated,
+             data_source=data_source,
+         )
+     except Exception as exc:  # pragma: no cover - defensive reporting
+         LOGGER.exception("Unhandled exception while running CLI command.")
+         return CliResult(
+             exit_code=1,
+             command_line=args.command_line,
+             error=str(exc),
+         )
+

  def _build_parser() -> argparse.ArgumentParser:
      description = (
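The hunk above adds an in-process entry point, run_cli_command, which returns a structured CliResult instead of logging and exiting. A minimal usage sketch, assuming the model path is the positional argument of the verify subcommand (the subcommand name and result fields come from this diff; the path is illustrative):

    from emx_onnx_cgen.cli import run_cli_command

    # The leading program name is optional; run_cli_command strips "emx-onnx-cgen".
    result = run_cli_command(["emx-onnx-cgen", "verify", "model.onnx"])
    if result.exit_code != 0:
        print("verification failed:", result.error)
    else:
        print(result.success_message)
        print("operators:", result.operators)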
@@ -86,6 +157,33 @@ def _build_parser() -> argparse.ArgumentParser:
              "named like the output with a _data suffix"
          ),
      )
+     compile_parser.add_argument(
+         "--truncate-weights-after",
+         type=int,
+         default=None,
+         help=(
+             "Truncate inline weight initializers after N values and insert "
+             "\"...\" placeholders (default: no truncation)"
+         ),
+     )
+     compile_parser.add_argument(
+         "--large-temp-threshold-bytes",
+         type=int,
+         default=1024,
+         help=(
+             "Mark temporary buffers larger than this threshold as static "
+             "(default: 1024)"
+         ),
+     )
+     compile_parser.add_argument(
+         "--large-weight-threshold",
+         type=int,
+         default=1024,
+         help=(
+             "Store weights larger than this element count in a binary file "
+             "(default: 1024)"
+         ),
+     )
      add_restrict_flags(compile_parser)

      verify_parser = subparsers.add_parser(
@@ -111,6 +209,48 @@
          default=None,
          help="C compiler command to build the testbench binary",
      )
+     verify_parser.add_argument(
+         "--truncate-weights-after",
+         type=int,
+         default=None,
+         help=(
+             "Truncate inline weight initializers after N values and insert "
+             "\"...\" placeholders (default: no truncation)"
+         ),
+     )
+     verify_parser.add_argument(
+         "--large-temp-threshold-bytes",
+         type=int,
+         default=1024,
+         help=(
+             "Mark temporary buffers larger than this threshold as static "
+             "(default: 1024)"
+         ),
+     )
+     verify_parser.add_argument(
+         "--large-weight-threshold",
+         type=int,
+         default=1024,
+         help=(
+             "Store weights larger than this element count in a binary file "
+             "(default: 1024)"
+         ),
+     )
+     verify_parser.add_argument(
+         "--test-data-dir",
+         type=Path,
+         default=None,
+         help=(
+             "Directory containing input_*.pb files to seed verification inputs "
+             "(default: use random testbench inputs)"
+         ),
+     )
+     verify_parser.add_argument(
+         "--max-ulp",
+         type=int,
+         default=100,
+         help="Maximum allowed ULP difference for floating outputs (default: 100)",
+     )
      add_restrict_flags(verify_parser)
      return parser

@@ -132,7 +272,35 @@ def main(argv: Sequence[str] | None = None) -> int:
  def _handle_compile(args: argparse.Namespace) -> int:
      model_path: Path = args.model
      output_path: Path = args.output or model_path.with_suffix(".c")
-     model_name = args.model_name or output_path.stem
+     model_name = args.model_name or "model"
+     generated, data_source, weight_data, error = _compile_model(args)
+     if error:
+         LOGGER.error("Failed to compile %s: %s", model_path, error)
+         return 1
+
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+     output_path.write_text(generated or "", encoding="utf-8")
+     LOGGER.info("Wrote C source to %s", output_path)
+     if data_source is not None:
+         data_path = output_path.with_name(
+             f"{output_path.stem}_data{output_path.suffix}"
+         )
+         data_path.write_text(data_source, encoding="utf-8")
+         LOGGER.info("Wrote data source to %s", data_path)
+     if weight_data is not None:
+         weights_path = output_path.with_name(f"{model_name}.bin")
+         weights_path.write_bytes(weight_data)
+         LOGGER.info("Wrote weights binary to %s", weights_path)
+     return 0
+
+
+ def _compile_model(
+     args: argparse.Namespace,
+     *,
+     testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
+ ) -> tuple[str | None, str | None, bytes | None, str | None]:
+     model_path: Path = args.model
+     model_name = args.model_name or "model"
      try:
          model_checksum = _model_checksum(model_path)
          model = onnx.load_model(model_path)
@@ -143,27 +311,22 @@ def _handle_compile(args: argparse.Namespace) -> int:
              command_line=args.command_line,
              model_checksum=model_checksum,
              restrict_arrays=args.restrict_arrays,
+             truncate_weights_after=args.truncate_weights_after,
+             large_temp_threshold_bytes=args.large_temp_threshold_bytes,
+             large_weight_threshold=args.large_weight_threshold,
+             testbench_inputs=testbench_inputs,
          )
          compiler = Compiler(options)
          if args.emit_data_file:
-             generated, data_source = compiler.compile_with_data_file(model)
+             generated, data_source, weight_data = (
+                 compiler.compile_with_data_file_and_weight_data(model)
+             )
          else:
-             generated = compiler.compile(model)
+             generated, weight_data = compiler.compile_with_weight_data(model)
              data_source = None
      except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
-         LOGGER.error("Failed to compile %s: %s", model_path, exc)
-         return 1
-
-     output_path.parent.mkdir(parents=True, exist_ok=True)
-     output_path.write_text(generated, encoding="utf-8")
-     LOGGER.info("Wrote C source to %s", output_path)
-     if data_source is not None:
-         data_path = output_path.with_name(
-             f"{output_path.stem}_data{output_path.suffix}"
-         )
-         data_path.write_text(data_source, encoding="utf-8")
-         LOGGER.info("Wrote data source to %s", data_path)
-     return 0
+         return None, None, None, str(exc)
+     return generated, data_source, weight_data, None


  def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str] | None:
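For context, _compile_model above now drives two new Compiler methods, compile_with_weight_data and compile_with_data_file_and_weight_data, which additionally return the raw weight bytes destined for <model_name>.bin. A sketch of the same call outside the CLI, assuming the CompilerOptions keywords shown in this hunk are optional (other fields such as template_dir, command_line, and model_checksum are omitted here, so this is not a drop-in snippet):

    import onnx
    from emx_onnx_cgen.compiler import Compiler, CompilerOptions

    model = onnx.load_model("model.onnx")  # illustrative path
    options = CompilerOptions(
        model_name="model",
        truncate_weights_after=None,        # keep inline weights untruncated
        large_temp_threshold_bytes=1024,
        large_weight_threshold=1024,
    )
    generated, weight_data = Compiler(options).compile_with_weight_data(model)
    # generated: C source text; weight_data: bytes for model.bin, or None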
@@ -201,15 +364,60 @@ def _handle_verify(args: argparse.Namespace) -> int:
      import numpy as np
      import onnxruntime as ort

+     success_message, error, _operators = _verify_model(
+         args, include_build_details=True
+     )
+     if error is not None:
+         LOGGER.error("Verification failed: %s", error)
+         return 1
+     if success_message:
+         LOGGER.info("%s", success_message)
+     return 0
+
+
+ def _verify_model(
+     args: argparse.Namespace,
+     *,
+     include_build_details: bool,
+ ) -> tuple[str | None, str | None, list[str]]:
+     import numpy as np
+     import onnxruntime as ort
+
+     def log_step(step: str, started_at: float) -> None:
+         duration = time.perf_counter() - started_at
+         LOGGER.info("verify step %s: %.3fs", step, duration)
+
+     def describe_exit_code(returncode: int) -> str:
+         if returncode >= 0:
+             return f"exit code {returncode}"
+         signal_id = -returncode
+         try:
+             signal_name = signal.Signals(signal_id).name
+         except ValueError:
+             signal_name = "unknown"
+         return f"exit code {returncode} (signal {signal_id}: {signal_name})"
+
      model_path: Path = args.model
-     model_name = args.model_name or model_path.stem
+     model_name = args.model_name or "model"
      model_checksum = _model_checksum(model_path)
      compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
      if compiler_cmd is None:
-         LOGGER.error("No C compiler found (set --cc or CC environment variable).")
-         return 1
+         return (
+             None,
+             "No C compiler found (set --cc or CC environment variable).",
+             [],
+         )
      try:
          model = onnx.load_model(model_path)
+     except OSError as exc:
+         return None, str(exc), []
+
+     operators = _collect_model_operators(model)
+     operators_display = ", ".join(operators) if operators else "(none)"
+     LOGGER.info("verify operators: %s", operators_display)
+
+     try:
+         testbench_inputs = _load_test_data_inputs(model, args.test_data_dir)
          options = CompilerOptions(
              template_dir=args.template_dir,
              model_name=model_name,
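The new --test-data-dir flag feeds _load_test_data_inputs (defined further down in this diff), which expects ONNX-serialized input_<index>.pb tensors, one per model input, matched to graph inputs in index order. A hedged sketch of producing such a directory with the stock onnx.numpy_helper API (directory name, shape, and tensor name are illustrative):

    from pathlib import Path

    import numpy as np
    from onnx import numpy_helper

    data_dir = Path("test_data_set_0")
    data_dir.mkdir(exist_ok=True)
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)
    tensor = numpy_helper.from_array(x, name="input0")
    (data_dir / "input_0.pb").write_bytes(tensor.SerializeToString())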
@@ -217,71 +425,103 @@
              command_line=args.command_line,
              model_checksum=model_checksum,
              restrict_arrays=args.restrict_arrays,
+             truncate_weights_after=args.truncate_weights_after,
+             large_temp_threshold_bytes=args.large_temp_threshold_bytes,
+             large_weight_threshold=args.large_weight_threshold,
+             testbench_inputs=testbench_inputs,
          )
          compiler = Compiler(options)
-         generated = compiler.compile(model)
-     except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
-         LOGGER.error("Failed to compile %s: %s", model_path, exc)
-         return 1
+         codegen_started = time.perf_counter()
+         generated, weight_data = compiler.compile_with_weight_data(model)
+         log_step("codegen", codegen_started)
+     except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
+         return None, str(exc), operators

      try:
          graph = import_onnx(model)
          output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
          input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
      except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
-         LOGGER.error("Failed to resolve model dtype: %s", exc)
-         return 1
+         return None, f"Failed to resolve model dtype: {exc}", operators

      with tempfile.TemporaryDirectory() as temp_dir:
          temp_path = Path(temp_dir)
+         LOGGER.info("verify temp dir: %s", temp_path)
          c_path = temp_path / "model.c"
+         weights_path = temp_path / f"{model_name}.bin"
          exe_path = temp_path / "model"
          c_path.write_text(generated, encoding="utf-8")
+         if weight_data is not None:
+             weights_path.write_bytes(weight_data)
          try:
+             compile_started = time.perf_counter()
+             compile_cmd = [
+                 *compiler_cmd,
+                 "-std=c99",
+                 "-O2",
+                 str(c_path),
+                 "-o",
+                 str(exe_path),
+                 "-lm",
+             ]
+             LOGGER.info("verify compile command: %s", shlex.join(compile_cmd))
              subprocess.run(
-                 [
-                     *compiler_cmd,
-                     "-std=c99",
-                     "-O2",
-                     str(c_path),
-                     "-o",
-                     str(exe_path),
-                     "-lm",
-                 ],
+                 compile_cmd,
                  check=True,
                  capture_output=True,
                  text=True,
              )
+             log_step("compile", compile_started)
          except subprocess.CalledProcessError as exc:
-             LOGGER.error("Failed to build testbench: %s", exc.stderr.strip())
-             return 1
+             message = "Failed to build testbench."
+             if include_build_details:
+                 details = exc.stderr.strip()
+                 if details:
+                     message = f"{message} {details}"
+             return None, message, operators
          try:
+             run_started = time.perf_counter()
              result = subprocess.run(
                  [str(exe_path)],
                  check=True,
                  capture_output=True,
                  text=True,
+                 cwd=temp_path,
              )
+             log_step("run", run_started)
          except subprocess.CalledProcessError as exc:
-             LOGGER.error("Testbench execution failed: %s", exc.stderr.strip())
-             return 1
+             return None, (
+                 "Testbench execution failed: " + describe_exit_code(exc.returncode)
+             ), operators

          try:
              payload = json.loads(result.stdout)
          except json.JSONDecodeError as exc:
-             LOGGER.error("Failed to parse testbench JSON: %s", exc)
-             return 1
+             return None, f"Failed to parse testbench JSON: {exc}", operators

-         inputs = {
-             name: np.array(value["data"], dtype=input_dtypes[name].np_dtype)
-             for name, value in payload["inputs"].items()
-         }
-         sess = ort.InferenceSession(
-             model.SerializeToString(), providers=["CPUExecutionProvider"]
-         )
+         if testbench_inputs:
+             inputs = {
+                 name: values.astype(input_dtypes[name].np_dtype, copy=False)
+                 for name, values in testbench_inputs.items()
+             }
+         else:
+             inputs = {
+                 name: decode_testbench_array(
+                     value["data"], input_dtypes[name].np_dtype
+                 )
+                 for name, value in payload["inputs"].items()
+             }
          try:
+             ort_started = time.perf_counter()
+             sess_options = make_deterministic_session_options(ort)
+             sess = ort.InferenceSession(
+                 model.SerializeToString(),
+                 sess_options=sess_options,
+                 providers=["CPUExecutionProvider"],
+             )
              ort_outputs = sess.run(None, inputs)
          except Exception as exc:
+             log_step("onnx runtime", ort_started)
              message = str(exc)
              if "NOT_IMPLEMENTED" in message:
                  LOGGER.warning(
@@ -289,28 +529,57 @@ model_path,
                      model_path,
                      message,
                  )
-                 return 0
-             LOGGER.error("ONNX Runtime failed to run %s: %s", model_path, message)
-             return 1
+                 return "", None, operators
+             return None, f"ONNX Runtime failed to run {model_path}: {message}", operators
+         log_step("onnx runtime", ort_started)
          payload_outputs = payload.get("outputs", {})
+         max_ulp = 0
          try:
              for value, ort_out in zip(graph.outputs, ort_outputs):
                  output_payload = payload_outputs.get(value.name)
                  if output_payload is None:
                      raise AssertionError(f"Missing output {value.name} in testbench data")
                  info = output_dtypes[value.name]
-                 output_data = np.array(output_payload["data"], dtype=info.np_dtype)
+                 output_data = decode_testbench_array(
+                     output_payload["data"], info.np_dtype
+                 ).astype(info.np_dtype, copy=False)
+                 ort_out = ort_out.astype(info.np_dtype, copy=False)
+                 output_data = output_data.reshape(ort_out.shape)
                  if np.issubdtype(info.np_dtype, np.floating):
-                     np.testing.assert_allclose(
-                         output_data, ort_out, rtol=1e-4, atol=1e-5
-                     )
+                     max_ulp = max(max_ulp, max_ulp_diff(output_data, ort_out))
                  else:
                      np.testing.assert_array_equal(output_data, ort_out)
          except AssertionError as exc:
-             LOGGER.error("Verification failed: %s", exc)
-             return 1
-     LOGGER.info("Verification succeeded for %s", model_path)
-     return 0
+             return None, str(exc), operators
+     if max_ulp > args.max_ulp:
+         return None, f"Out of tolerance (max ULP {max_ulp})", operators
+     return format_success_message(max_ulp), None, operators
+
+
+ def _load_test_data_inputs(
+     model: onnx.ModelProto, data_dir: Path | None
+ ) -> dict[str, "np.ndarray"] | None:
+     if data_dir is None:
+         return None
+     if not data_dir.exists():
+         raise CodegenError(f"Test data directory not found: {data_dir}")
+     input_files = sorted(
+         data_dir.glob("input_*.pb"),
+         key=lambda path: int(path.stem.split("_")[-1]),
+     )
+     if not input_files:
+         raise CodegenError(f"No input_*.pb files found in {data_dir}")
+     if len(input_files) != len(model.graph.input):
+         raise CodegenError(
+             "Test data input count does not match model inputs: "
+             f"{len(input_files)} vs {len(model.graph.input)}."
+         )
+     inputs: dict[str, np.ndarray] = {}
+     for index, path in enumerate(input_files):
+         tensor = onnx.TensorProto()
+         tensor.ParseFromString(path.read_bytes())
+         inputs[model.graph.input[index].name] = numpy_helper.to_array(tensor)
+     return inputs


  def _format_command_line(argv: Sequence[str] | None) -> str:
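Floating-point outputs are now scored by ULP distance via max_ulp_diff (imported from emx_onnx_cgen.verification; a new shared/ulp.py helper also appears in this release) instead of np.testing.assert_allclose with rtol/atol. Neither helper's body is part of this diff; purely for orientation, a common float32 ULP-distance measure looks like the sketch below, an assumption rather than the packaged implementation:

    import numpy as np

    def float32_ulp_diff(a: np.ndarray, b: np.ndarray) -> int:
        # Map float32 bit patterns onto a monotonic integer scale, then take the
        # largest absolute difference; NaN/inf handling is deliberately omitted.
        def ordered(x: np.ndarray) -> np.ndarray:
            bits = np.ascontiguousarray(x, dtype=np.float32).view(np.int32).astype(np.int64)
            return np.where(bits < 0, -(bits & 0x7FFFFFFF), bits)

        if a.size == 0:
            return 0
        return int(np.max(np.abs(ordered(a) - ordered(b))))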
@@ -326,3 +595,15 @@ def _model_checksum(model_path: Path) -> str:
      digest = hashlib.sha256()
      digest.update(model_path.read_bytes())
      return digest.hexdigest()
+
+
+ def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
+     operators: list[str] = []
+     seen: set[str] = set()
+     for node in model.graph.node:
+         op_name = f"{node.domain}::{node.op_type}" if node.domain else node.op_type
+         if op_name in seen:
+             continue
+         seen.add(op_name)
+         operators.append(op_name)
+     return operators
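_collect_model_operators reports each operator once, in first-seen order, qualifying non-default domains as "<domain>::<op_type>". A small demonstration against a throwaway helper-built graph (the model itself is illustrative):

    from onnx import TensorProto, helper
    from emx_onnx_cgen.cli import _collect_model_operators

    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])
    nodes = [
        helper.make_node("Relu", ["x"], ["t"]),
        helper.make_node("Relu", ["t"], ["u"]),  # duplicate op type, reported once
        helper.make_node("Identity", ["u"], ["y"]),
    ]
    model = helper.make_model(helper.make_graph(nodes, "demo", [x], [y]))
    print(_collect_model_operators(model))  # ['Relu', 'Identity']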