emx-onnx-cgen 0.3.8 (py3-none-any.whl) → 0.4.1.dev0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/cli.py CHANGED
@@ -7,15 +7,16 @@ import logging
7
7
  import os
8
8
  import shlex
9
9
  import shutil
10
+ import signal
10
11
  import subprocess
11
12
  import sys
12
13
  import tempfile
13
14
  import time
14
- import signal
15
- from pathlib import Path
16
15
  from dataclasses import dataclass
17
- from typing import TYPE_CHECKING, Mapping, Sequence
16
+ from pathlib import Path
17
+ from typing import Any, Mapping, Sequence, TextIO
18
18
 
19
+ import numpy as np
19
20
  import onnx
20
21
  from onnx import numpy_helper
21
22
 
@@ -23,25 +24,220 @@ from ._build_info import BUILD_DATE, GIT_VERSION
23
24
  from .compiler import Compiler, CompilerOptions
24
25
  from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
25
26
  from .onnx_import import import_onnx
27
+ from .determinism import deterministic_reference_runtime
26
28
  from .onnxruntime_utils import make_deterministic_session_options
27
29
  from .testbench import decode_testbench_array
28
- from .verification import format_success_message, max_ulp_diff
30
+ from .verification import format_success_message, worst_ulp_diff
29
31
 
30
32
  LOGGER = logging.getLogger(__name__)
31
-
32
- if TYPE_CHECKING:
33
- import numpy as np
33
+ _NONDETERMINISTIC_OPERATORS = {"Bernoulli"}
34
34
 
35
35
 
36
36
  @dataclass(frozen=True)
37
37
  class CliResult:
38
38
  exit_code: int
39
39
  command_line: str
40
- error: str | None = None
41
- success_message: str | None = None
40
+ result: str | None = None
42
41
  generated: str | None = None
43
42
  data_source: str | None = None
44
43
  operators: list[str] | None = None
44
+ opset_version: int | None = None
45
+ generated_checksum: str | None = None
46
+
47
+
48
+ @dataclass(frozen=True)
49
+ class _WorstDiff:
50
+ output_name: str
51
+ node_name: str | None
52
+ index: tuple[int, ...]
53
+ got: float
54
+ reference: float
55
+ ulp: int
56
+
57
+
58
+ @dataclass(frozen=True)
59
+ class _WorstAbsDiff:
60
+ output_name: str
61
+ node_name: str | None
62
+ index: tuple[int, ...]
63
+ got: object
64
+ reference: object
65
+ abs_diff: float | int
66
+
67
+
68
+ class _VerifyReporter:
69
+ def __init__(
70
+ self,
71
+ stream: TextIO | None = None,
72
+ *,
73
+ color_mode: str = "auto",
74
+ ) -> None:
75
+ self._stream = stream or sys.stdout
76
+ self._use_color = self._should_use_color(color_mode)
77
+
78
+ def _should_use_color(self, color_mode: str) -> bool:
79
+ if color_mode == "always":
80
+ return True
81
+ if color_mode == "never":
82
+ return False
83
+ if not hasattr(self._stream, "isatty"):
84
+ return False
85
+ return bool(self._stream.isatty())
86
+
87
+ def _color(self, text: str, code: str) -> str:
88
+ if not self._use_color:
89
+ return text
90
+ return f"\x1b[{code}m{text}\x1b[0m"
91
+
92
+ def start_step(self, label: str) -> float:
93
+ print(f"{label} ...", end=" ", file=self._stream, flush=True)
94
+ return time.perf_counter()
95
+
96
+ def step_ok(self, started_at: float) -> None:
97
+ duration = time.perf_counter() - started_at
98
+ ok = self._color("OK", "32")
99
+ dim = self._color(f"({duration:.3f}s)", "90")
100
+ print(f"{ok} {dim}", file=self._stream)
101
+
102
+ def step_ok_simple(self) -> None:
103
+ ok = self._color("OK", "32")
104
+ print(ok, file=self._stream)
105
+
106
+ def step_ok_detail(self, detail: str) -> None:
107
+ ok = self._color("OK", "32")
108
+ dim = self._color(f"({detail})", "90")
109
+ print(f"{ok} {dim}", file=self._stream)
110
+
111
+ def step_fail(self, reason: str) -> None:
112
+ fail = self._color("FAIL", "31")
113
+ print(f"{fail} ({reason})", file=self._stream)
114
+
115
+ def note(self, message: str) -> None:
116
+ label = self._color("Note:", "33")
117
+ print(f"{label} {message}", file=self._stream)
118
+
119
+ def info(self, message: str) -> None:
120
+ print(message, file=self._stream)
121
+
122
+ def result(self, message: str, *, ok: bool) -> None:
123
+ colored = self._color(message, "32" if ok else "31")
124
+ print(f"Result: {colored}", file=self._stream)
125
+
126
+
127
+ class _NullVerifyReporter(_VerifyReporter):
128
+ def __init__(self) -> None:
129
+ super().__init__(stream=sys.stdout, color_mode="never")
130
+
131
+ def start_step(self, label: str) -> float:
132
+ return time.perf_counter()
133
+
134
+ def step_ok(self, started_at: float) -> None:
135
+ return None
136
+
137
+ def step_ok_simple(self) -> None:
138
+ return None
139
+
140
+ def step_ok_detail(self, detail: str) -> None:
141
+ return None
142
+
143
+ def step_fail(self, reason: str) -> None:
144
+ return None
145
+
146
+ def note(self, message: str) -> None:
147
+ return None
148
+
149
+ def info(self, message: str) -> None:
150
+ return None
151
+
152
+ def result(self, message: str, *, ok: bool) -> None:
153
+ return None
154
+
155
+
156
+ def _format_artifact_size(size_bytes: int) -> str:
157
+ if size_bytes < 1024:
158
+ return f"{size_bytes} bytes"
159
+ return f"{size_bytes / 1024:.1f} KiB"
160
+
161
+
162
+ def _report_generated_artifacts(
163
+ reporter: _VerifyReporter,
164
+ *,
165
+ artifacts: Sequence[tuple[str, int]],
166
+ ) -> None:
167
+ for name, size_bytes in artifacts:
168
+ reporter.info(f" {name} ({_format_artifact_size(size_bytes)})")
169
+
170
+
171
+ def _worst_ulp_diff(
172
+ actual: "np.ndarray", expected: "np.ndarray"
173
+ ) -> tuple[int, tuple[tuple[int, ...], float, float] | None]:
174
+ if actual.shape != expected.shape:
175
+ raise ValueError(
176
+ f"Shape mismatch for ULP calculation: {actual.shape} vs {expected.shape}"
177
+ )
178
+ if not np.issubdtype(expected.dtype, np.floating):
179
+ return 0, None
180
+ if actual.size == 0:
181
+ return 0, None
182
+ dtype = expected.dtype
183
+ actual_cast = actual.astype(dtype, copy=False)
184
+ expected_cast = expected.astype(dtype, copy=False)
185
+ max_diff = 0
186
+ worst: tuple[tuple[int, ...], float, float] | None = None
187
+ iterator = np.nditer(
188
+ [actual_cast, expected_cast], flags=["refs_ok", "multi_index"]
189
+ )
190
+ for actual_value, expected_value in iterator:
191
+ actual_scalar = float(actual_value[()])
192
+ expected_scalar = float(expected_value[()])
193
+ diff = ulp_intdiff_float(actual_value[()], expected_value[()])
194
+ if diff > max_diff:
195
+ max_diff = diff
196
+ worst = (
197
+ iterator.multi_index,
198
+ actual_scalar,
199
+ expected_scalar,
200
+ )
201
+ return max_diff, worst
202
+
203
+
204
+ def _worst_abs_diff(
205
+ actual: "np.ndarray", expected: "np.ndarray"
206
+ ) -> tuple[float | int, tuple[tuple[int, ...], object, object] | None]:
207
+ if actual.shape != expected.shape:
208
+ raise ValueError(
209
+ f"Shape mismatch for diff calculation: {actual.shape} vs {expected.shape}"
210
+ )
211
+ if actual.size == 0:
212
+ return 0, None
213
+ dtype = expected.dtype
214
+ actual_cast = actual.astype(dtype, copy=False)
215
+ expected_cast = expected.astype(dtype, copy=False)
216
+ max_diff: float | int = 0
217
+ worst: tuple[tuple[int, ...], object, object] | None = None
218
+ iterator = np.nditer(
219
+ [actual_cast, expected_cast], flags=["refs_ok", "multi_index"]
220
+ )
221
+ for actual_value, expected_value in iterator:
222
+ actual_scalar = actual_value[()]
223
+ expected_scalar = expected_value[()]
224
+ if actual_scalar == expected_scalar:
225
+ continue
226
+ try:
227
+ if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
228
+ diff: float | int = abs(int(actual_scalar) - int(expected_scalar))
229
+ else:
230
+ diff = float(abs(actual_scalar - expected_scalar))
231
+ except Exception:
232
+ diff = 1
233
+ if diff > max_diff:
234
+ max_diff = diff
235
+ worst = (
236
+ iterator.multi_index,
237
+ actual_scalar,
238
+ expected_scalar,
239
+ )
240
+ return max_diff, worst
45
241
 
46
242
 
47
243
  def run_cli_command(
@@ -56,18 +252,26 @@ def run_cli_command(
56
252
  parser = _build_parser()
57
253
  args = parser.parse_args(parse_argv)
58
254
  args.command_line = _format_command_line(raw_argv)
255
+ _apply_base_dir(args, parser)
59
256
 
60
257
  try:
61
258
  if args.command != "compile":
62
- success_message, error, operators = _verify_model(
63
- args, include_build_details=False
259
+ (
260
+ success_message,
261
+ error,
262
+ operators,
263
+ opset_version,
264
+ generated_checksum,
265
+ ) = _verify_model(
266
+ args, include_build_details=False, reporter=_NullVerifyReporter()
64
267
  )
65
268
  return CliResult(
66
269
  exit_code=0 if error is None else 1,
67
270
  command_line=args.command_line,
68
- error=error,
69
- success_message=success_message,
271
+ result=error or success_message,
70
272
  operators=operators,
273
+ opset_version=opset_version,
274
+ generated_checksum=generated_checksum,
71
275
  )
72
276
  generated, data_source, error = _compile_model(
73
277
  args, testbench_inputs=testbench_inputs
@@ -76,12 +280,12 @@ def run_cli_command(
76
280
  return CliResult(
77
281
  exit_code=1,
78
282
  command_line=args.command_line,
79
- error=error,
283
+ result=error,
80
284
  )
81
285
  return CliResult(
82
286
  exit_code=0,
83
287
  command_line=args.command_line,
84
- success_message="",
288
+ result="",
85
289
  generated=generated,
86
290
  data_source=data_source,
87
291
  )
@@ -90,7 +294,7 @@ def run_cli_command(
90
294
  return CliResult(
91
295
  exit_code=1,
92
296
  command_line=args.command_line,
93
- error=str(exc),
297
+ result=str(exc),
94
298
  )
95
299
 
96
300
 
@@ -102,6 +306,24 @@ def _build_parser() -> argparse.ArgumentParser:
102
306
  parser = argparse.ArgumentParser(prog="emx-onnx-cgen", description=description)
103
307
  subparsers = parser.add_subparsers(dest="command", required=True)
104
308
 
309
+ def add_color_flag(subparser: argparse.ArgumentParser) -> None:
310
+ subparser.add_argument(
311
+ "--color",
312
+ choices=("auto", "always", "never"),
313
+ default="auto",
314
+ help=(
315
+ "Colorize CLI output (default: auto; options: auto, always, never)"
316
+ ),
317
+ )
318
+
319
+ def add_verbose_flag(subparser: argparse.ArgumentParser) -> None:
320
+ subparser.add_argument(
321
+ "--verbose",
322
+ "-v",
323
+ action="store_true",
324
+ help="Enable verbose logging (includes codegen timing).",
325
+ )
326
+
105
327
  def add_restrict_flags(subparser: argparse.ArgumentParser) -> None:
106
328
  restrict_group = subparser.add_mutually_exclusive_group()
107
329
  restrict_group.add_argument(
@@ -118,9 +340,47 @@ def _build_parser() -> argparse.ArgumentParser:
118
340
  )
119
341
  subparser.set_defaults(restrict_arrays=True)
120
342
 
343
+ def add_fp32_accumulation_strategy_flag(
344
+ subparser: argparse.ArgumentParser,
345
+ ) -> None:
346
+ subparser.add_argument(
347
+ "--fp32-accumulation-strategy",
348
+ choices=("simple", "fp64"),
349
+ default="simple",
350
+ help=(
351
+ "Accumulation strategy for float32 inputs "
352
+ "(simple uses float32, fp64 uses double; default: simple)"
353
+ ),
354
+ )
355
+
356
+ def add_fp16_accumulation_strategy_flag(
357
+ subparser: argparse.ArgumentParser,
358
+ ) -> None:
359
+ subparser.add_argument(
360
+ "--fp16-accumulation-strategy",
361
+ choices=("simple", "fp32"),
362
+ default="fp32",
363
+ help=(
364
+ "Accumulation strategy for float16 inputs "
365
+ "(simple uses float16, fp32 uses float; default: fp32)"
366
+ ),
367
+ )
368
+
121
369
  compile_parser = subparsers.add_parser(
122
370
  "compile", help="Compile an ONNX model into C source"
123
371
  )
372
+ add_color_flag(compile_parser)
373
+ add_verbose_flag(compile_parser)
374
+ compile_parser.add_argument(
375
+ "--model-base-dir",
376
+ "-B",
377
+ type=Path,
378
+ default=None,
379
+ help=(
380
+ "Base directory for resolving the model path "
381
+ "(example: tool --model-base-dir /data model.onnx)"
382
+ ),
383
+ )
124
384
  compile_parser.add_argument("model", type=Path, help="Path to the ONNX model")
125
385
  compile_parser.add_argument(
126
386
  "output",
@@ -132,12 +392,6 @@ def _build_parser() -> argparse.ArgumentParser:
132
392
  "e.g., model.onnx -> model.c)"
133
393
  ),
134
394
  )
135
- compile_parser.add_argument(
136
- "--template-dir",
137
- type=Path,
138
- default=Path("templates"),
139
- help="Template directory (default: templates)",
140
- )
141
395
  compile_parser.add_argument(
142
396
  "--model-name",
143
397
  type=str,
@@ -167,9 +421,10 @@ def _build_parser() -> argparse.ArgumentParser:
167
421
  ),
168
422
  )
169
423
  compile_parser.add_argument(
170
- "--large-temp-threshold-bytes",
424
+ "--large-temp-threshold",
171
425
  type=int,
172
426
  default=1024,
427
+ dest="large_temp_threshold_bytes",
173
428
  help=(
174
429
  "Mark temporary buffers larger than this threshold as static "
175
430
  "(default: 1024)"
@@ -178,25 +433,33 @@ def _build_parser() -> argparse.ArgumentParser:
178
433
  compile_parser.add_argument(
179
434
  "--large-weight-threshold",
180
435
  type=int,
181
- default=1024 * 1024,
436
+ default=100 * 1024,
182
437
  help=(
183
- "Store weights larger than this element count in a binary file "
184
- "(default: 1048576; set to 0 to disable)"
438
+ "Store weights in a binary file once the cumulative byte size "
439
+ "exceeds this threshold (default: 102400; set to 0 to disable)"
185
440
  ),
186
441
  )
187
442
  add_restrict_flags(compile_parser)
443
+ add_fp32_accumulation_strategy_flag(compile_parser)
444
+ add_fp16_accumulation_strategy_flag(compile_parser)
188
445
 
189
446
  verify_parser = subparsers.add_parser(
190
447
  "verify",
191
448
  help="Compile an ONNX model and verify outputs against ONNX Runtime",
192
449
  )
193
- verify_parser.add_argument("model", type=Path, help="Path to the ONNX model")
450
+ add_color_flag(verify_parser)
451
+ add_verbose_flag(verify_parser)
194
452
  verify_parser.add_argument(
195
- "--template-dir",
453
+ "--model-base-dir",
454
+ "-B",
196
455
  type=Path,
197
- default=Path("templates"),
198
- help="Template directory (default: templates)",
456
+ default=None,
457
+ help=(
458
+ "Base directory for resolving the model and test data paths "
459
+ "(example: tool --model-base-dir /data model.onnx --test-data-dir inputs)"
460
+ ),
199
461
  )
462
+ verify_parser.add_argument("model", type=Path, help="Path to the ONNX model")
200
463
  verify_parser.add_argument(
201
464
  "--model-name",
202
465
  type=str,
@@ -219,9 +482,10 @@ def _build_parser() -> argparse.ArgumentParser:
219
482
  ),
220
483
  )
221
484
  verify_parser.add_argument(
222
- "--large-temp-threshold-bytes",
485
+ "--large-temp-threshold",
223
486
  type=int,
224
487
  default=1024,
488
+ dest="large_temp_threshold_bytes",
225
489
  help=(
226
490
  "Mark temporary buffers larger than this threshold as static "
227
491
  "(default: 1024)"
@@ -230,10 +494,10 @@ def _build_parser() -> argparse.ArgumentParser:
230
494
  verify_parser.add_argument(
231
495
  "--large-weight-threshold",
232
496
  type=int,
233
- default=1024,
497
+ default=100 * 1024,
234
498
  help=(
235
- "Store weights larger than this element count in a binary file "
236
- "(default: 1024)"
499
+ "Store weights in a binary file once the cumulative byte size "
500
+ "exceeds this threshold (default: 102400)"
237
501
  ),
238
502
  )
239
503
  verify_parser.add_argument(
@@ -245,30 +509,100 @@ def _build_parser() -> argparse.ArgumentParser:
245
509
  "(default: use random testbench inputs)"
246
510
  ),
247
511
  )
512
+ verify_parser.add_argument(
513
+ "--temp-dir-root",
514
+ type=Path,
515
+ default=None,
516
+ help=(
517
+ "Root directory in which to create a temporary verification "
518
+ "directory (default: system temp dir)"
519
+ ),
520
+ )
521
+ verify_parser.add_argument(
522
+ "--temp-dir",
523
+ type=Path,
524
+ default=None,
525
+ help=(
526
+ "Exact directory to use for temporary verification files "
527
+ "(default: create a temporary directory)"
528
+ ),
529
+ )
530
+ verify_parser.add_argument(
531
+ "--keep-temp-dir",
532
+ action="store_true",
533
+ help="Keep the temporary verification directory (default: delete it)",
534
+ )
248
535
  verify_parser.add_argument(
249
536
  "--max-ulp",
250
537
  type=int,
251
538
  default=100,
252
539
  help="Maximum allowed ULP difference for floating outputs (default: 100)",
253
540
  )
541
+ verify_parser.add_argument(
542
+ "--atol-eps",
543
+ type=float,
544
+ default=1.0,
545
+ help=(
546
+ "Absolute tolerance as a multiple of machine epsilon for ULP checks "
547
+ "(default: 1.0)"
548
+ ),
549
+ )
254
550
  verify_parser.add_argument(
255
551
  "--runtime",
256
552
  choices=("onnxruntime", "onnx-reference"),
257
- default="onnx-reference",
553
+ default="onnxruntime",
258
554
  help=(
259
- "Runtime backend for verification (default: onnx-reference; "
555
+ "Runtime backend for verification (default: onnxruntime; "
260
556
  "options: onnxruntime, onnx-reference)"
261
557
  ),
262
558
  )
559
+ verify_parser.add_argument(
560
+ "--expected-checksum",
561
+ type=str,
562
+ default=None,
563
+ help=(
564
+ "Expected generated C checksum (sha256). When it matches the "
565
+ "computed checksum, verification exits early with CHECKSUM."
566
+ ),
567
+ )
263
568
  add_restrict_flags(verify_parser)
569
+ add_fp32_accumulation_strategy_flag(verify_parser)
570
+ add_fp16_accumulation_strategy_flag(verify_parser)
264
571
  return parser
265
572
 
266
573
 
574
+ def _resolve_with_base_dir(base_dir: Path, path: Path) -> Path:
575
+ if path.is_absolute():
576
+ return path
577
+ return Path(os.path.normpath(os.path.join(base_dir, path)))
578
+
579
+
580
+ def _apply_base_dir(
581
+ args: argparse.Namespace, parser: argparse.ArgumentParser
582
+ ) -> None:
583
+ model_base_dir: Path | None = args.model_base_dir
584
+ if model_base_dir is None:
585
+ return
586
+ if not model_base_dir.exists() or not model_base_dir.is_dir():
587
+ parser.error(
588
+ f"--model-base-dir {model_base_dir} does not exist or is not a directory"
589
+ )
590
+ path_fields = ("model", "test_data_dir")
591
+ for field in path_fields:
592
+ value = getattr(args, field, None)
593
+ if value is None:
594
+ continue
595
+ if not isinstance(value, Path):
596
+ continue
597
+ setattr(args, field, _resolve_with_base_dir(model_base_dir, value))
598
+
599
+
267
600
  def main(argv: Sequence[str] | None = None) -> int:
268
601
  logging.basicConfig(level=logging.INFO)
269
602
  parser = _build_parser()
270
603
  args = parser.parse_args(argv)
271
604
  args.command_line = _format_command_line(argv)
605
+ _apply_base_dir(args, parser)
272
606
 
273
607
  if args.command == "compile":
274
608
  return _handle_compile(args)
@@ -279,27 +613,28 @@ def main(argv: Sequence[str] | None = None) -> int:
279
613
 
280
614
 
281
615
  def _handle_compile(args: argparse.Namespace) -> int:
616
+ reporter = _VerifyReporter(color_mode=args.color)
282
617
  model_path: Path = args.model
283
618
  output_path: Path = args.output or model_path.with_suffix(".c")
284
619
  model_name = args.model_name or "model"
285
- generated, data_source, weight_data, error = _compile_model(args)
620
+ generated, data_source, weight_data, error = _compile_model(
621
+ args, reporter=reporter
622
+ )
286
623
  if error:
287
- LOGGER.error("Failed to compile %s: %s", model_path, error)
624
+ reporter.info("")
625
+ reporter.result(error, ok=False)
288
626
  return 1
289
627
 
290
628
  output_path.parent.mkdir(parents=True, exist_ok=True)
291
629
  output_path.write_text(generated or "", encoding="utf-8")
292
- LOGGER.info("Wrote C source to %s", output_path)
293
630
  if data_source is not None:
294
631
  data_path = output_path.with_name(
295
632
  f"{output_path.stem}_data{output_path.suffix}"
296
633
  )
297
634
  data_path.write_text(data_source, encoding="utf-8")
298
- LOGGER.info("Wrote data source to %s", data_path)
299
635
  if weight_data is not None:
300
636
  weights_path = output_path.with_name(f"{model_name}.bin")
301
637
  weights_path.write_bytes(weight_data)
302
- LOGGER.info("Wrote weights binary to %s", weights_path)
303
638
  return 0
304
639
 
305
640
 
@@ -307,23 +642,50 @@ def _compile_model(
307
642
  args: argparse.Namespace,
308
643
  *,
309
644
  testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
645
+ reporter: _VerifyReporter | None = None,
310
646
  ) -> tuple[str | None, str | None, bytes | None, str | None]:
311
647
  model_path: Path = args.model
312
648
  model_name = args.model_name or "model"
649
+ active_reporter = reporter or _NullVerifyReporter()
650
+ load_started = active_reporter.start_step(
651
+ f"Loading model {model_path.name}"
652
+ )
653
+ timings: dict[str, float] = {}
654
+ try:
655
+ model, model_checksum = _load_model_and_checksum(model_path)
656
+ active_reporter.step_ok(load_started)
657
+ except OSError as exc:
658
+ active_reporter.step_fail(str(exc))
659
+ return None, None, None, str(exc)
660
+ operators = _collect_model_operators(model)
661
+ opset_version = _model_opset_version(model)
662
+ _report_model_details(
663
+ active_reporter,
664
+ model_path=model_path,
665
+ model_checksum=model_checksum,
666
+ operators=operators,
667
+ opset_version=opset_version,
668
+ node_count=len(model.graph.node),
669
+ initializer_count=len(model.graph.initializer),
670
+ input_count=len(model.graph.input),
671
+ output_count=len(model.graph.output),
672
+ )
673
+ active_reporter.info("")
674
+ codegen_started = active_reporter.start_step("Generating C code")
313
675
  try:
314
- model_checksum = _model_checksum(model_path)
315
- model = onnx.load_model(model_path)
316
676
  options = CompilerOptions(
317
- template_dir=args.template_dir,
318
677
  model_name=model_name,
319
678
  emit_testbench=args.emit_testbench,
320
679
  command_line=args.command_line,
321
680
  model_checksum=model_checksum,
322
681
  restrict_arrays=args.restrict_arrays,
682
+ fp32_accumulation_strategy=args.fp32_accumulation_strategy,
683
+ fp16_accumulation_strategy=args.fp16_accumulation_strategy,
323
684
  truncate_weights_after=args.truncate_weights_after,
324
685
  large_temp_threshold_bytes=args.large_temp_threshold_bytes,
325
686
  large_weight_threshold=args.large_weight_threshold,
326
687
  testbench_inputs=testbench_inputs,
688
+ timings=timings,
327
689
  )
328
690
  compiler = Compiler(options)
329
691
  if args.emit_data_file:
@@ -333,8 +695,26 @@ def _compile_model(
333
695
  else:
334
696
  generated, weight_data = compiler.compile_with_weight_data(model)
335
697
  data_source = None
336
- except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
698
+ active_reporter.step_ok(codegen_started)
699
+ if args.verbose:
700
+ _report_codegen_timings(active_reporter, timings=timings)
701
+ except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
702
+ active_reporter.step_fail(str(exc))
337
703
  return None, None, None, str(exc)
704
+ output_path: Path = args.output or model_path.with_suffix(".c")
705
+ artifacts = [(str(output_path), len(generated.encode("utf-8")))]
706
+ if data_source is not None:
707
+ data_path = output_path.with_name(
708
+ f"{output_path.stem}_data{output_path.suffix}"
709
+ )
710
+ artifacts.append((str(data_path), len(data_source.encode("utf-8"))))
711
+ if weight_data is not None:
712
+ weights_path = output_path.with_name(f"{model_name}.bin")
713
+ artifacts.append((str(weights_path), len(weight_data)))
714
+ _report_generated_artifacts(active_reporter, artifacts=artifacts)
715
+ active_reporter.info(
716
+ f" Generated checksum (sha256): {_generated_checksum(generated)}"
717
+ )
338
718
  return generated, data_source, weight_data, None
339
719
 
340
720
 
@@ -363,21 +743,27 @@ def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str]
363
743
  if env_cc:
364
744
  return resolve_tokens(shlex.split(env_cc))
365
745
  for candidate in ("cc", "gcc", "clang"):
366
- resolved = shutil.which(candidate)
367
- if resolved:
368
- return maybe_prefix_ccache([resolved])
746
+ if shutil.which(candidate):
747
+ return maybe_prefix_ccache([candidate])
369
748
  return None
370
749
 
371
750
 
372
751
  def _handle_verify(args: argparse.Namespace) -> int:
373
- success_message, error, _operators = _verify_model(
374
- args, include_build_details=True
375
- )
752
+ reporter = _VerifyReporter(color_mode=args.color)
753
+ (
754
+ success_message,
755
+ error,
756
+ _operators,
757
+ _opset_version,
758
+ generated_checksum,
759
+ ) = _verify_model(args, include_build_details=True, reporter=reporter)
376
760
  if error is not None:
377
- LOGGER.error("Verification failed: %s", error)
761
+ reporter.info("")
762
+ reporter.result(error, ok=False)
378
763
  return 1
379
764
  if success_message:
380
- LOGGER.info("%s", success_message)
765
+ reporter.info("")
766
+ reporter.result(success_message, ok=True)
381
767
  return 0
382
768
 
383
769
 
@@ -385,12 +771,9 @@ def _verify_model(
385
771
  args: argparse.Namespace,
386
772
  *,
387
773
  include_build_details: bool,
388
- ) -> tuple[str | None, str | None, list[str]]:
389
- import numpy as np
390
-
391
- def log_step(step: str, started_at: float) -> None:
392
- duration = time.perf_counter() - started_at
393
- LOGGER.info("verify step %s: %.3fs", step, duration)
774
+ reporter: _VerifyReporter | None = None,
775
+ ) -> tuple[str | None, str | None, list[str], int | None, str | None]:
776
+ active_reporter = reporter or _NullVerifyReporter()
394
777
 
395
778
  def describe_exit_code(returncode: int) -> str:
396
779
  if returncode >= 0:
@@ -404,54 +787,176 @@ def _verify_model(
404
787
 
405
788
  model_path: Path = args.model
406
789
  model_name = args.model_name or "model"
407
- model_checksum = _model_checksum(model_path)
790
+ model, model_checksum = _load_model_and_checksum(model_path)
408
791
  compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
409
792
  if compiler_cmd is None:
410
793
  return (
411
794
  None,
412
795
  "No C compiler found (set --cc or CC environment variable).",
413
796
  [],
797
+ None,
798
+ None,
414
799
  )
800
+ temp_dir_root: Path | None = args.temp_dir_root
801
+ explicit_temp_dir: Path | None = args.temp_dir
802
+ if temp_dir_root is not None and explicit_temp_dir is not None:
803
+ return (
804
+ None,
805
+ "Cannot set both --temp-dir-root and --temp-dir.",
806
+ operators,
807
+ opset_version,
808
+ generated_checksum,
809
+ )
810
+ if temp_dir_root is not None:
811
+ if temp_dir_root.exists() and not temp_dir_root.is_dir():
812
+ return (
813
+ None,
814
+ f"Verification temp dir root is not a directory: {temp_dir_root}",
815
+ operators,
816
+ opset_version,
817
+ generated_checksum,
818
+ )
819
+ temp_dir_root.mkdir(parents=True, exist_ok=True)
820
+ if explicit_temp_dir is not None:
821
+ if explicit_temp_dir.exists() and not explicit_temp_dir.is_dir():
822
+ return (
823
+ None,
824
+ f"Verification temp dir is not a directory: {explicit_temp_dir}",
825
+ operators,
826
+ opset_version,
827
+ generated_checksum,
828
+ )
829
+ temp_dir: tempfile.TemporaryDirectory | None = None
830
+ cleanup_created_dir = False
831
+ if explicit_temp_dir is not None:
832
+ temp_path = explicit_temp_dir
833
+ if not temp_path.exists():
834
+ temp_path.mkdir(parents=True, exist_ok=True)
835
+ cleanup_created_dir = not args.keep_temp_dir
836
+ elif args.keep_temp_dir:
837
+ temp_path = Path(
838
+ tempfile.mkdtemp(
839
+ dir=str(temp_dir_root) if temp_dir_root is not None else None
840
+ )
841
+ )
842
+ else:
843
+ temp_dir = tempfile.TemporaryDirectory(
844
+ dir=str(temp_dir_root) if temp_dir_root is not None else None
845
+ )
846
+ temp_path = Path(temp_dir.name)
847
+ keep_label = (
848
+ "--keep-temp-dir set" if args.keep_temp_dir else "--keep-temp-dir not set"
849
+ )
850
+ active_reporter.note(
851
+ f"Using temporary folder [{keep_label}]: {temp_path}"
852
+ )
853
+ active_reporter.info("")
854
+ load_started = active_reporter.start_step(f"Loading model {model_path.name}")
415
855
  try:
416
- model = onnx.load_model(model_path)
856
+ model, model_checksum = _load_model_and_checksum(model_path)
417
857
  except OSError as exc:
418
- return None, str(exc), []
858
+ active_reporter.step_fail(str(exc))
859
+ return None, str(exc), [], None, None
860
+ active_reporter.step_ok(load_started)
419
861
 
420
862
  operators = _collect_model_operators(model)
421
- operators_display = ", ".join(operators) if operators else "(none)"
422
- LOGGER.info("verify operators: %s", operators_display)
863
+ opset_version = _model_opset_version(model)
864
+ _report_model_details(
865
+ active_reporter,
866
+ model_path=model_path,
867
+ model_checksum=model_checksum,
868
+ operators=operators,
869
+ opset_version=opset_version,
870
+ node_count=len(model.graph.node),
871
+ initializer_count=len(model.graph.initializer),
872
+ input_count=len(model.graph.input),
873
+ output_count=len(model.graph.output),
874
+ )
423
875
 
876
+ timings: dict[str, float] = {}
424
877
  try:
425
- testbench_inputs = _load_test_data_inputs(model, args.test_data_dir)
878
+ active_reporter.info("")
879
+ codegen_started = active_reporter.start_step("Generating C code")
880
+ testbench_inputs, testbench_optional_inputs = _load_test_data_inputs(
881
+ model, args.test_data_dir
882
+ )
883
+ testbench_outputs = _load_test_data_outputs(model, args.test_data_dir)
426
884
  options = CompilerOptions(
427
- template_dir=args.template_dir,
428
885
  model_name=model_name,
429
886
  emit_testbench=True,
430
- command_line=args.command_line,
887
+ command_line=None,
431
888
  model_checksum=model_checksum,
432
889
  restrict_arrays=args.restrict_arrays,
890
+ fp32_accumulation_strategy=args.fp32_accumulation_strategy,
891
+ fp16_accumulation_strategy=args.fp16_accumulation_strategy,
433
892
  truncate_weights_after=args.truncate_weights_after,
434
893
  large_temp_threshold_bytes=args.large_temp_threshold_bytes,
435
894
  large_weight_threshold=args.large_weight_threshold,
436
895
  testbench_inputs=testbench_inputs,
896
+ testbench_optional_inputs=testbench_optional_inputs,
897
+ timings=timings,
437
898
  )
438
899
  compiler = Compiler(options)
439
- codegen_started = time.perf_counter()
440
900
  generated, weight_data = compiler.compile_with_weight_data(model)
441
- log_step("codegen", codegen_started)
901
+ active_reporter.step_ok(codegen_started)
902
+ if args.verbose:
903
+ _report_codegen_timings(active_reporter, timings=timings)
904
+ artifacts = [("model.c", len(generated.encode("utf-8")))]
905
+ if weight_data is not None:
906
+ artifacts.append((f"{model_name}.bin", len(weight_data)))
907
+ _report_generated_artifacts(active_reporter, artifacts=artifacts)
442
908
  except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
443
- return None, str(exc), operators
909
+ active_reporter.step_fail(str(exc))
910
+ return None, str(exc), operators, opset_version, None
911
+ generated_checksum = _generated_checksum(generated)
912
+ active_reporter.info(f" Generated checksum (sha256): {generated_checksum}")
913
+ expected_checksum = args.expected_checksum
914
+ if expected_checksum and expected_checksum == generated_checksum:
915
+ return "CHECKSUM", None, operators, opset_version, generated_checksum
444
916
 
445
917
  try:
446
918
  graph = import_onnx(model)
447
919
  output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
448
920
  input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
449
921
  except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
450
- return None, f"Failed to resolve model dtype: {exc}", operators
922
+ return (
923
+ None,
924
+ f"Failed to resolve model dtype: {exc}",
925
+ operators,
926
+ opset_version,
927
+ None,
928
+ )
451
929
 
452
- with tempfile.TemporaryDirectory() as temp_dir:
453
- temp_path = Path(temp_dir)
454
- LOGGER.info("verify temp dir: %s", temp_path)
930
+ def _cleanup_temp() -> None:
931
+ if temp_dir is None and not cleanup_created_dir:
932
+ return
933
+ if temp_dir is None:
934
+ shutil.rmtree(temp_path)
935
+ else:
936
+ temp_dir.cleanup()
937
+
938
+ try:
939
+ payload: dict[str, Any] | None = None
940
+ testbench_input_path: Path | None = None
941
+ if testbench_inputs:
942
+ input_order = [value.name for value in graph.inputs]
943
+ testbench_input_path = temp_path / "testbench_inputs.bin"
944
+ with testbench_input_path.open("wb") as handle:
945
+ for name in input_order:
946
+ array = testbench_inputs.get(name)
947
+ if array is None:
948
+ return (
949
+ None,
950
+ f"Missing testbench input data for {name}.",
951
+ operators,
952
+ opset_version,
953
+ generated_checksum,
954
+ )
955
+ dtype = input_dtypes[name].np_dtype
956
+ blob = np.ascontiguousarray(
957
+ array.astype(dtype, copy=False)
958
+ ).tobytes(order="C")
959
+ handle.write(blob)
455
960
  c_path = temp_path / "model.c"
456
961
  weights_path = temp_path / f"{model_name}.bin"
457
962
  exe_path = temp_path / "model"
@@ -459,126 +964,302 @@ def _verify_model(
459
964
  if weight_data is not None:
460
965
  weights_path.write_bytes(weight_data)
461
966
  try:
462
- compile_started = time.perf_counter()
463
967
  compile_cmd = [
464
968
  *compiler_cmd,
465
969
  "-std=c99",
466
- "-O2",
467
- str(c_path),
970
+ "-O1",
971
+ "-fsanitize=address,undefined",
972
+ "-Wall",
973
+ "-Werror",
974
+ str(c_path.name),
468
975
  "-o",
469
- str(exe_path),
976
+ str(exe_path.name),
470
977
  "-lm",
471
978
  ]
472
- LOGGER.info("verify compile command: %s", shlex.join(compile_cmd))
979
+ active_reporter.info("")
980
+ compile_started = active_reporter.start_step("Compiling C code")
473
981
  subprocess.run(
474
982
  compile_cmd,
475
983
  check=True,
476
984
  capture_output=True,
477
985
  text=True,
986
+ cwd=temp_path,
987
+ )
988
+ active_reporter.step_ok(compile_started)
989
+ active_reporter.info(
990
+ f" Compile command: {shlex.join(compile_cmd)}"
478
991
  )
479
- log_step("compile", compile_started)
992
+ active_reporter.info("")
993
+ if args.test_data_dir is not None:
994
+ active_reporter.info(
995
+ f"Verifying using test data set: {args.test_data_dir.name}"
996
+ )
997
+ else:
998
+ active_reporter.info(
999
+ "Verifying using generated random inputs"
1000
+ )
480
1001
  except subprocess.CalledProcessError as exc:
481
1002
  message = "Failed to build testbench."
482
1003
  if include_build_details:
483
1004
  details = exc.stderr.strip()
484
1005
  if details:
485
1006
  message = f"{message} {details}"
486
- return None, message, operators
1007
+ active_reporter.step_fail(message)
1008
+ return None, message, operators, opset_version, generated_checksum
487
1009
  try:
488
- run_started = time.perf_counter()
1010
+ run_started = active_reporter.start_step(
1011
+ " Running generated binary"
1012
+ )
1013
+ run_cmd = [str(exe_path)]
1014
+ if testbench_input_path is not None:
1015
+ run_cmd.append(str(testbench_input_path))
489
1016
  result = subprocess.run(
490
- [str(exe_path)],
1017
+ run_cmd,
491
1018
  check=True,
492
1019
  capture_output=True,
493
1020
  text=True,
494
1021
  cwd=temp_path,
495
1022
  )
496
- log_step("run", run_started)
1023
+ active_reporter.step_ok(run_started)
1024
+ result_json_path = temp_path / "testbench.json"
1025
+ result_json_path.write_text(result.stdout, encoding="utf-8")
1026
+ try:
1027
+ payload = json.loads(result_json_path.read_text(encoding="utf-8"))
1028
+ except json.JSONDecodeError as exc:
1029
+ return (
1030
+ None,
1031
+ f"Failed to parse testbench JSON: {exc}",
1032
+ operators,
1033
+ opset_version,
1034
+ generated_checksum,
1035
+ )
497
1036
  except subprocess.CalledProcessError as exc:
1037
+ active_reporter.step_fail(describe_exit_code(exc.returncode))
498
1038
  return None, (
499
1039
  "Testbench execution failed: " + describe_exit_code(exc.returncode)
500
- ), operators
1040
+ ), operators, opset_version, generated_checksum
1041
+ if payload is None:
1042
+ return (
1043
+ None,
1044
+ "Failed to parse testbench JSON: missing output.",
1045
+ operators,
1046
+ opset_version,
1047
+ generated_checksum,
1048
+ )
501
1049
 
502
- try:
503
- payload = json.loads(result.stdout)
504
- except json.JSONDecodeError as exc:
505
- return None, f"Failed to parse testbench JSON: {exc}", operators
506
-
507
- if testbench_inputs:
508
- inputs = {
509
- name: values.astype(input_dtypes[name].np_dtype, copy=False)
510
- for name, values in testbench_inputs.items()
511
- }
512
- else:
513
- inputs = {
514
- name: decode_testbench_array(
515
- value["data"], input_dtypes[name].np_dtype
1050
+ if testbench_inputs:
1051
+ inputs = {
1052
+ name: values.astype(input_dtypes[name].np_dtype, copy=False)
1053
+ for name, values in testbench_inputs.items()
1054
+ }
1055
+ else:
1056
+ inputs = {
1057
+ name: decode_testbench_array(
1058
+ value["data"], input_dtypes[name].np_dtype
1059
+ )
1060
+ for name, value in payload["inputs"].items()
1061
+ }
1062
+ runtime_outputs: dict[str, np.ndarray] | None = None
1063
+ if testbench_outputs is not None:
1064
+ runtime_outputs = {
1065
+ name: output.astype(output_dtypes[name].np_dtype, copy=False)
1066
+ for name, output in testbench_outputs.items()
1067
+ }
1068
+ else:
1069
+ runtime_name = args.runtime
1070
+ custom_domains = sorted(
1071
+ {
1072
+ opset.domain
1073
+ for opset in model.opset_import
1074
+ if opset.domain not in {"", "ai.onnx"}
1075
+ }
1076
+ )
1077
+ if runtime_name == "onnx-reference" and custom_domains:
1078
+ active_reporter.note(
1079
+ "Runtime: switching to onnxruntime for custom domains "
1080
+ f"{', '.join(custom_domains)}"
1081
+ )
1082
+ runtime_name = "onnxruntime"
1083
+ runtime_started = active_reporter.start_step(
1084
+ f" Running {runtime_name} [--runtime={args.runtime}]"
516
1085
  )
517
- for name, value in payload["inputs"].items()
1086
+ try:
1087
+ if runtime_name == "onnxruntime":
1088
+ import onnxruntime as ort
1089
+
1090
+ sess_options = make_deterministic_session_options(ort)
1091
+ sess = ort.InferenceSession(
1092
+ model.SerializeToString(),
1093
+ sess_options=sess_options,
1094
+ providers=["CPUExecutionProvider"],
1095
+ )
1096
+ runtime_outputs_list = sess.run(None, inputs)
1097
+ else:
1098
+ from onnx.reference import ReferenceEvaluator
1099
+
1100
+ with deterministic_reference_runtime():
1101
+ evaluator = ReferenceEvaluator(model)
1102
+ runtime_outputs_list = evaluator.run(None, inputs)
1103
+ except Exception as exc:
1104
+ active_reporter.step_fail(str(exc))
1105
+ message = str(exc)
1106
+ if runtime_name == "onnxruntime" and "NOT_IMPLEMENTED" in message:
1107
+ active_reporter.note(
1108
+ f"Skipping verification for {model_path}: "
1109
+ "ONNX Runtime does not support the model "
1110
+ f"({message})"
1111
+ )
1112
+ return "", None, operators, opset_version, generated_checksum
1113
+ return (
1114
+ None,
1115
+ f"{runtime_name} failed to run {model_path}: {message}",
1116
+ operators,
1117
+ opset_version,
1118
+ generated_checksum,
1119
+ )
1120
+ active_reporter.step_ok(runtime_started)
1121
+ runtime_outputs = {
1122
+ value.name: output
1123
+ for value, output in zip(graph.outputs, runtime_outputs_list)
1124
+ }
1125
+ nondeterministic_ops = sorted(
1126
+ set(operators).intersection(_NONDETERMINISTIC_OPERATORS)
1127
+ )
1128
+ if nondeterministic_ops:
1129
+ active_reporter.note(
1130
+ "Skipping output comparison for non-deterministic operator(s): "
1131
+ f"{', '.join(nondeterministic_ops)}"
1132
+ )
1133
+ return (
1134
+ "OK (non-deterministic output)",
1135
+ None,
1136
+ operators,
1137
+ opset_version,
1138
+ generated_checksum,
1139
+ )
1140
+ payload_outputs = payload.get("outputs", {})
1141
+ max_ulp = 0
1142
+ worst_diff: _WorstDiff | None = None
1143
+ max_abs_diff: float | int = 0
1144
+ worst_abs_diff: _WorstAbsDiff | None = None
1145
+ output_nodes = {
1146
+ output_name: node
1147
+ for node in graph.nodes
1148
+ for output_name in node.outputs
518
1149
  }
519
- runtime_name = args.runtime
520
- runtime_started = time.perf_counter()
521
- try:
522
- if runtime_name == "onnxruntime":
523
- import onnxruntime as ort
524
-
525
- sess_options = make_deterministic_session_options(ort)
526
- sess = ort.InferenceSession(
527
- model.SerializeToString(),
528
- sess_options=sess_options,
529
- providers=["CPUExecutionProvider"],
1150
+ active_reporter.start_step(
1151
+ f" Comparing outputs [--max-ulp={args.max_ulp}]"
1152
+ )
1153
+ try:
1154
+ for value in graph.outputs:
1155
+ runtime_out = runtime_outputs[value.name]
1156
+ output_payload = payload_outputs.get(value.name)
1157
+ if output_payload is None:
1158
+ raise AssertionError(
1159
+ f"Missing output {value.name} in testbench data"
1160
+ )
1161
+ info = output_dtypes[value.name]
1162
+ output_data = decode_testbench_array(
1163
+ output_payload["data"], info.np_dtype
1164
+ ).astype(info.np_dtype, copy=False)
1165
+ runtime_out = runtime_out.astype(info.np_dtype, copy=False)
1166
+ output_data = output_data.reshape(runtime_out.shape)
1167
+ if np.issubdtype(info.np_dtype, np.floating):
1168
+ output_max, output_worst = worst_ulp_diff(
1169
+ output_data,
1170
+ runtime_out,
1171
+ atol_eps=args.atol_eps,
1172
+ )
1173
+ if output_max > max_ulp:
1174
+ max_ulp = output_max
1175
+ if output_worst is not None:
1176
+ node = output_nodes.get(value.name)
1177
+ worst_diff = _WorstDiff(
1178
+ output_name=value.name,
1179
+ node_name=node.name if node else None,
1180
+ index=output_worst[0],
1181
+ got=float(output_worst[1]),
1182
+ reference=float(output_worst[2]),
1183
+ ulp=output_max,
1184
+ )
1185
+ else:
1186
+ output_max, output_worst = _worst_abs_diff(
1187
+ output_data, runtime_out
1188
+ )
1189
+ if output_max > max_abs_diff:
1190
+ max_abs_diff = output_max
1191
+ if output_worst is not None:
1192
+ node = output_nodes.get(value.name)
1193
+ worst_abs_diff = _WorstAbsDiff(
1194
+ output_name=value.name,
1195
+ node_name=node.name if node else None,
1196
+ index=output_worst[0],
1197
+ got=output_worst[1],
1198
+ reference=output_worst[2],
1199
+ abs_diff=output_max,
1200
+ )
1201
+ except AssertionError as exc:
1202
+ active_reporter.step_fail(str(exc))
1203
+ return None, str(exc), operators, opset_version, generated_checksum
1204
+ if max_abs_diff > 0:
1205
+ active_reporter.step_fail(f"max abs diff {max_abs_diff}")
1206
+ if worst_abs_diff is not None:
1207
+ node_label = worst_abs_diff.node_name or "(unknown)"
1208
+ index_display = ", ".join(str(dim) for dim in worst_abs_diff.index)
1209
+ active_reporter.info(
1210
+ " Worst diff: output="
1211
+ f"{worst_abs_diff.output_name} node={node_label} "
1212
+ f"index=[{index_display}] "
1213
+ f"got={worst_abs_diff.got} "
1214
+ f"ref={worst_abs_diff.reference} "
1215
+ f"abs_diff={worst_abs_diff.abs_diff}"
1216
+ )
1217
+ return (
1218
+ None,
1219
+ f"Arrays are not equal (max abs diff {max_abs_diff})",
1220
+ operators,
1221
+ opset_version,
1222
+ generated_checksum,
530
1223
  )
531
- runtime_outputs = sess.run(None, inputs)
532
- else:
533
- from onnx.reference import ReferenceEvaluator
534
-
535
- evaluator = ReferenceEvaluator(model)
536
- runtime_outputs = evaluator.run(None, inputs)
537
- except Exception as exc:
538
- log_step(runtime_name, runtime_started)
539
- message = str(exc)
540
- if runtime_name == "onnxruntime" and "NOT_IMPLEMENTED" in message:
541
- LOGGER.warning(
542
- "Skipping verification for %s: ONNX Runtime does not support the model (%s)",
543
- model_path,
544
- message,
1224
+ if max_ulp > args.max_ulp:
1225
+ active_reporter.step_fail(f"max ULP {max_ulp}")
1226
+ if worst_diff is not None:
1227
+ node_label = worst_diff.node_name or "(unknown)"
1228
+ index_display = ", ".join(str(dim) for dim in worst_diff.index)
1229
+ active_reporter.info(
1230
+ " Worst diff: output="
1231
+ f"{worst_diff.output_name} node={node_label} "
1232
+ f"index=[{index_display}] "
1233
+ f"got={worst_diff.got:.8g} "
1234
+ f"ref={worst_diff.reference:.8g} "
1235
+ f"ulp={worst_diff.ulp}"
1236
+ )
1237
+ return (
1238
+ None,
1239
+ f"Out of tolerance (max ULP {max_ulp})",
1240
+ operators,
1241
+ opset_version,
1242
+ generated_checksum,
545
1243
  )
546
- return "", None, operators
1244
+ active_reporter.step_ok_simple()
1245
+ active_reporter.info(f" Maximum ULP: {max_ulp}")
547
1246
  return (
1247
+ format_success_message(max_ulp),
548
1248
  None,
549
- f"{runtime_name} failed to run {model_path}: {message}",
550
1249
  operators,
1250
+ opset_version,
1251
+ generated_checksum,
551
1252
  )
552
- log_step(runtime_name, runtime_started)
553
- payload_outputs = payload.get("outputs", {})
554
- max_ulp = 0
555
- try:
556
- for value, runtime_out in zip(graph.outputs, runtime_outputs):
557
- output_payload = payload_outputs.get(value.name)
558
- if output_payload is None:
559
- raise AssertionError(f"Missing output {value.name} in testbench data")
560
- info = output_dtypes[value.name]
561
- output_data = decode_testbench_array(
562
- output_payload["data"], info.np_dtype
563
- ).astype(info.np_dtype, copy=False)
564
- runtime_out = runtime_out.astype(info.np_dtype, copy=False)
565
- output_data = output_data.reshape(runtime_out.shape)
566
- if np.issubdtype(info.np_dtype, np.floating):
567
- max_ulp = max(max_ulp, max_ulp_diff(output_data, runtime_out))
568
- else:
569
- np.testing.assert_array_equal(output_data, runtime_out)
570
- except AssertionError as exc:
571
- return None, str(exc), operators
572
- if max_ulp > args.max_ulp:
573
- return None, f"Out of tolerance (max ULP {max_ulp})", operators
574
- return format_success_message(max_ulp), None, operators
1253
+ finally:
1254
+ active_reporter.info("")
1255
+ _cleanup_temp()
575
1256
 
576
1257
 
577
1258
  def _load_test_data_inputs(
578
1259
  model: onnx.ModelProto, data_dir: Path | None
579
- ) -> dict[str, "np.ndarray"] | None:
1260
+ ) -> tuple[dict[str, "np.ndarray"] | None, dict[str, bool] | None]:
580
1261
  if data_dir is None:
581
- return None
1262
+ return None, None
582
1263
  if not data_dir.exists():
583
1264
  raise CodegenError(f"Test data directory not found: {data_dir}")
584
1265
  input_files = sorted(
@@ -587,26 +1268,115 @@ def _load_test_data_inputs(
587
1268
  )
588
1269
  if not input_files:
589
1270
  raise CodegenError(f"No input_*.pb files found in {data_dir}")
590
- if len(input_files) != len(model.graph.input):
1271
+ initializer_names = {init.name for init in model.graph.initializer}
1272
+ initializer_names.update(
1273
+ sparse_init.name for sparse_init in model.graph.sparse_initializer
1274
+ )
1275
+ model_inputs = [
1276
+ value_info
1277
+ for value_info in model.graph.input
1278
+ if value_info.name not in initializer_names
1279
+ ]
1280
+ if len(input_files) != len(model_inputs):
591
1281
  raise CodegenError(
592
1282
  "Test data input count does not match model inputs: "
593
- f"{len(input_files)} vs {len(model.graph.input)}."
1283
+ f"{len(input_files)} vs {len(model_inputs)}."
594
1284
  )
595
- for value_info in model.graph.input:
1285
+ for value_info in model_inputs:
596
1286
  value_kind = value_info.type.WhichOneof("value")
597
- if value_kind != "tensor_type":
1287
+ if value_kind not in {"tensor_type", "optional_type"}:
598
1288
  LOGGER.warning(
599
1289
  "Skipping test data load for non-tensor input %s (type %s).",
600
1290
  value_info.name,
601
1291
  value_kind or "unknown",
602
1292
  )
603
- return None
1293
+ return None, None
604
1294
  inputs: dict[str, np.ndarray] = {}
1295
+ optional_flags: dict[str, bool] = {}
605
1296
  for index, path in enumerate(input_files):
1297
+ value_info = model_inputs[index]
1298
+ value_kind = value_info.type.WhichOneof("value")
1299
+ if value_kind == "tensor_type":
1300
+ tensor = onnx.TensorProto()
1301
+ tensor.ParseFromString(path.read_bytes())
1302
+ inputs[value_info.name] = numpy_helper.to_array(tensor)
1303
+ continue
1304
+ optional = onnx.OptionalProto()
1305
+ optional.ParseFromString(path.read_bytes())
1306
+ elem_type = value_info.type.optional_type.elem_type
1307
+ if elem_type.WhichOneof("value") != "tensor_type":
1308
+ LOGGER.warning(
1309
+ "Skipping test data load for non-tensor optional input %s.",
1310
+ value_info.name,
1311
+ )
1312
+ return None, None
1313
+ tensor_type = elem_type.tensor_type
1314
+ if optional.HasField("tensor_value"):
1315
+ inputs[value_info.name] = numpy_helper.to_array(
1316
+ optional.tensor_value
1317
+ )
1318
+ optional_flags[value_info.name] = True
1319
+ continue
1320
+ if not tensor_type.HasField("elem_type"):
1321
+ raise CodegenError(
1322
+ f"Optional input {value_info.name} is missing elem_type."
1323
+ )
1324
+ dtype_info = onnx._mapping.TENSOR_TYPE_MAP.get(tensor_type.elem_type)
1325
+ if dtype_info is None:
1326
+ raise CodegenError(
1327
+ f"Optional input {value_info.name} has unsupported elem_type."
1328
+ )
1329
+ shape: list[int] = []
1330
+ for dim in tensor_type.shape.dim:
1331
+ if dim.HasField("dim_value"):
1332
+ shape.append(dim.dim_value)
1333
+ elif dim.HasField("dim_param"):
1334
+ shape.append(1)
1335
+ else:
1336
+ raise CodegenError(
1337
+ f"Optional input {value_info.name} has unknown shape."
1338
+ )
1339
+ inputs[value_info.name] = np.zeros(
1340
+ tuple(shape), dtype=dtype_info.np_dtype
1341
+ )
1342
+ optional_flags[value_info.name] = False
1343
+ return inputs, optional_flags
1344
+
1345
+
1346
+ def _load_test_data_outputs(
1347
+ model: onnx.ModelProto, data_dir: Path | None
1348
+ ) -> dict[str, "np.ndarray"] | None:
1349
+ if data_dir is None:
1350
+ return None
1351
+ if not data_dir.exists():
1352
+ raise CodegenError(f"Test data directory not found: {data_dir}")
1353
+ output_files = sorted(
1354
+ data_dir.glob("output_*.pb"),
1355
+ key=lambda path: int(path.stem.split("_")[-1]),
1356
+ )
1357
+ if not output_files:
1358
+ return None
1359
+ model_outputs = list(model.graph.output)
1360
+ if len(output_files) != len(model_outputs):
1361
+ raise CodegenError(
1362
+ "Test data output count does not match model outputs: "
1363
+ f"{len(output_files)} vs {len(model_outputs)}."
1364
+ )
1365
+ for value_info in model_outputs:
1366
+ value_kind = value_info.type.WhichOneof("value")
1367
+ if value_kind != "tensor_type":
1368
+ LOGGER.warning(
1369
+ "Skipping test data load for non-tensor output %s (type %s).",
1370
+ value_info.name,
1371
+ value_kind or "unknown",
1372
+ )
1373
+ return None
1374
+ outputs: dict[str, np.ndarray] = {}
1375
+ for index, path in enumerate(output_files):
606
1376
  tensor = onnx.TensorProto()
607
1377
  tensor.ParseFromString(path.read_bytes())
608
- inputs[model.graph.input[index].name] = numpy_helper.to_array(tensor)
609
- return inputs
1378
+ outputs[model_outputs[index].name] = numpy_helper.to_array(tensor)
1379
+ return outputs
610
1380
 
611
1381
 
612
1382
  def _format_command_line(argv: Sequence[str] | None) -> str:
@@ -615,15 +1385,97 @@ def _format_command_line(argv: Sequence[str] | None) -> str:
615
1385
  args = [str(arg) for arg in argv[1:]]
616
1386
  if not args:
617
1387
  return ""
618
- return shlex.join(args)
1388
+ filtered: list[str] = []
1389
+ skip_next = False
1390
+ for arg in args:
1391
+ if skip_next:
1392
+ skip_next = False
1393
+ continue
1394
+ if arg == "--expected-checksum":
1395
+ skip_next = True
1396
+ continue
1397
+ if arg.startswith("--expected-checksum="):
1398
+ continue
1399
+ filtered.append(arg)
1400
+ if not filtered:
1401
+ return ""
1402
+ return shlex.join(filtered)
1403
+
1404
+
1405
+ def _load_model_and_checksum(
1406
+ model_path: Path,
1407
+ ) -> tuple[onnx.ModelProto, str]:
1408
+ model_bytes = model_path.read_bytes()
1409
+ digest = hashlib.sha256()
1410
+ digest.update(model_bytes)
1411
+ model = onnx.load_model_from_string(model_bytes)
1412
+ return model, digest.hexdigest()
619
1413
 
620
1414
 
621
- def _model_checksum(model_path: Path) -> str:
1415
+ def _generated_checksum(generated: str) -> str:
622
1416
  digest = hashlib.sha256()
623
- digest.update(model_path.read_bytes())
1417
+ digest.update(generated.encode("utf-8"))
624
1418
  return digest.hexdigest()
625
1419
 
626
1420
 
1421
+ def _report_model_details(
1422
+ reporter: _VerifyReporter,
1423
+ *,
1424
+ model_path: Path,
1425
+ model_checksum: str,
1426
+ operators: Sequence[str],
1427
+ opset_version: int | None,
1428
+ node_count: int,
1429
+ initializer_count: int,
1430
+ input_count: int,
1431
+ output_count: int,
1432
+ ) -> None:
1433
+ operators_display = ", ".join(operators) if operators else "(none)"
1434
+ reporter.info(
1435
+ f" Model operators ({len(operators)}): {operators_display}"
1436
+ )
1437
+ reporter.info(
1438
+ f" Model file size: {_format_artifact_size(model_path.stat().st_size)}"
1439
+ )
1440
+ reporter.info(f" Model checksum (sha256): {model_checksum}")
1441
+ if opset_version is not None:
1442
+ reporter.info(f" Opset version: {opset_version}")
1443
+ reporter.info(
1444
+ " Counts: "
1445
+ f"nodes={node_count}, "
1446
+ f"initializers={initializer_count}, "
1447
+ f"inputs={input_count}, "
1448
+ f"outputs={output_count}"
1449
+ )
1450
+
1451
+
1452
+ def _report_codegen_timings(
1453
+ reporter: _VerifyReporter, *, timings: Mapping[str, float]
1454
+ ) -> None:
1455
+ if not timings:
1456
+ return
1457
+ order = [
1458
+ ("import_onnx", "import"),
1459
+ ("concretize_shapes", "concretize"),
1460
+ ("resolve_testbench_inputs", "testbench"),
1461
+ ("collect_variable_dims", "var_dims"),
1462
+ ("lower_model", "lower"),
1463
+ ("emit_model", "emit"),
1464
+ ("emit_model_with_data_file", "emit_data"),
1465
+ ("collect_weight_data", "weights"),
1466
+ ]
1467
+ seen = set()
1468
+ parts: list[str] = []
1469
+ for key, label in order:
1470
+ if key not in timings:
1471
+ continue
1472
+ parts.append(f"{label}={timings[key]:.3f}s")
1473
+ seen.add(key)
1474
+ for key in sorted(k for k in timings if k not in seen):
1475
+ parts.append(f"{key}={timings[key]:.3f}s")
1476
+ reporter.info(f" Codegen timing: {', '.join(parts)}")
1477
+
1478
+
627
1479
  def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
628
1480
  operators: list[str] = []
629
1481
  seen: set[str] = set()
@@ -634,3 +1486,14 @@ def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
634
1486
  seen.add(op_name)
635
1487
  operators.append(op_name)
636
1488
  return operators
1489
+
1490
+
1491
+ def _model_opset_version(model: onnx.ModelProto, *, domain: str = "") -> int | None:
1492
+ if not model.opset_import:
1493
+ return None
1494
+ domains = (domain,) if domain else ("", "ai.onnx")
1495
+ for target_domain in domains:
1496
+ for opset in model.opset_import:
1497
+ if opset.domain == target_domain:
1498
+ return opset.version
1499
+ return None