emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/cli.py
CHANGED
|
@@ -7,15 +7,16 @@ import logging
|
|
|
7
7
|
import os
|
|
8
8
|
import shlex
|
|
9
9
|
import shutil
|
|
10
|
+
import signal
|
|
10
11
|
import subprocess
|
|
11
12
|
import sys
|
|
12
13
|
import tempfile
|
|
13
14
|
import time
|
|
14
|
-
import signal
|
|
15
|
-
from pathlib import Path
|
|
16
15
|
from dataclasses import dataclass
|
|
17
|
-
from
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any, Mapping, Sequence, TextIO
|
|
18
18
|
|
|
19
|
+
import numpy as np
|
|
19
20
|
import onnx
|
|
20
21
|
from onnx import numpy_helper
|
|
21
22
|
|
|
@@ -23,25 +24,220 @@ from ._build_info import BUILD_DATE, GIT_VERSION
|
|
|
23
24
|
from .compiler import Compiler, CompilerOptions
|
|
24
25
|
from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
|
|
25
26
|
from .onnx_import import import_onnx
|
|
27
|
+
from .determinism import deterministic_reference_runtime
|
|
26
28
|
from .onnxruntime_utils import make_deterministic_session_options
|
|
27
29
|
from .testbench import decode_testbench_array
|
|
28
|
-
from .verification import format_success_message,
|
|
30
|
+
from .verification import format_success_message, worst_ulp_diff
|
|
29
31
|
|
|
30
32
|
LOGGER = logging.getLogger(__name__)
|
|
31
|
-
|
|
32
|
-
if TYPE_CHECKING:
|
|
33
|
-
import numpy as np
|
|
33
|
+
_NONDETERMINISTIC_OPERATORS = {"Bernoulli"}
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
@dataclass(frozen=True)
|
|
37
37
|
class CliResult:
|
|
38
38
|
exit_code: int
|
|
39
39
|
command_line: str
|
|
40
|
-
|
|
41
|
-
success_message: str | None = None
|
|
40
|
+
result: str | None = None
|
|
42
41
|
generated: str | None = None
|
|
43
42
|
data_source: str | None = None
|
|
44
43
|
operators: list[str] | None = None
|
|
44
|
+
opset_version: int | None = None
|
|
45
|
+
generated_checksum: str | None = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
|
|
49
|
+
class _WorstDiff:
|
|
50
|
+
output_name: str
|
|
51
|
+
node_name: str | None
|
|
52
|
+
index: tuple[int, ...]
|
|
53
|
+
got: float
|
|
54
|
+
reference: float
|
|
55
|
+
ulp: int
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass(frozen=True)
|
|
59
|
+
class _WorstAbsDiff:
|
|
60
|
+
output_name: str
|
|
61
|
+
node_name: str | None
|
|
62
|
+
index: tuple[int, ...]
|
|
63
|
+
got: object
|
|
64
|
+
reference: object
|
|
65
|
+
abs_diff: float | int
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _VerifyReporter:
|
|
69
|
+
def __init__(
|
|
70
|
+
self,
|
|
71
|
+
stream: TextIO | None = None,
|
|
72
|
+
*,
|
|
73
|
+
color_mode: str = "auto",
|
|
74
|
+
) -> None:
|
|
75
|
+
self._stream = stream or sys.stdout
|
|
76
|
+
self._use_color = self._should_use_color(color_mode)
|
|
77
|
+
|
|
78
|
+
def _should_use_color(self, color_mode: str) -> bool:
|
|
79
|
+
if color_mode == "always":
|
|
80
|
+
return True
|
|
81
|
+
if color_mode == "never":
|
|
82
|
+
return False
|
|
83
|
+
if not hasattr(self._stream, "isatty"):
|
|
84
|
+
return False
|
|
85
|
+
return bool(self._stream.isatty())
|
|
86
|
+
|
|
87
|
+
def _color(self, text: str, code: str) -> str:
|
|
88
|
+
if not self._use_color:
|
|
89
|
+
return text
|
|
90
|
+
return f"\x1b[{code}m{text}\x1b[0m"
|
|
91
|
+
|
|
92
|
+
def start_step(self, label: str) -> float:
|
|
93
|
+
print(f"{label} ...", end=" ", file=self._stream, flush=True)
|
|
94
|
+
return time.perf_counter()
|
|
95
|
+
|
|
96
|
+
def step_ok(self, started_at: float) -> None:
|
|
97
|
+
duration = time.perf_counter() - started_at
|
|
98
|
+
ok = self._color("OK", "32")
|
|
99
|
+
dim = self._color(f"({duration:.3f}s)", "90")
|
|
100
|
+
print(f"{ok} {dim}", file=self._stream)
|
|
101
|
+
|
|
102
|
+
def step_ok_simple(self) -> None:
|
|
103
|
+
ok = self._color("OK", "32")
|
|
104
|
+
print(ok, file=self._stream)
|
|
105
|
+
|
|
106
|
+
def step_ok_detail(self, detail: str) -> None:
|
|
107
|
+
ok = self._color("OK", "32")
|
|
108
|
+
dim = self._color(f"({detail})", "90")
|
|
109
|
+
print(f"{ok} {dim}", file=self._stream)
|
|
110
|
+
|
|
111
|
+
def step_fail(self, reason: str) -> None:
|
|
112
|
+
fail = self._color("FAIL", "31")
|
|
113
|
+
print(f"{fail} ({reason})", file=self._stream)
|
|
114
|
+
|
|
115
|
+
def note(self, message: str) -> None:
|
|
116
|
+
label = self._color("Note:", "33")
|
|
117
|
+
print(f"{label} {message}", file=self._stream)
|
|
118
|
+
|
|
119
|
+
def info(self, message: str) -> None:
|
|
120
|
+
print(message, file=self._stream)
|
|
121
|
+
|
|
122
|
+
def result(self, message: str, *, ok: bool) -> None:
|
|
123
|
+
colored = self._color(message, "32" if ok else "31")
|
|
124
|
+
print(f"Result: {colored}", file=self._stream)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class _NullVerifyReporter(_VerifyReporter):
|
|
128
|
+
def __init__(self) -> None:
|
|
129
|
+
super().__init__(stream=sys.stdout, color_mode="never")
|
|
130
|
+
|
|
131
|
+
def start_step(self, label: str) -> float:
|
|
132
|
+
return time.perf_counter()
|
|
133
|
+
|
|
134
|
+
def step_ok(self, started_at: float) -> None:
|
|
135
|
+
return None
|
|
136
|
+
|
|
137
|
+
def step_ok_simple(self) -> None:
|
|
138
|
+
return None
|
|
139
|
+
|
|
140
|
+
def step_ok_detail(self, detail: str) -> None:
|
|
141
|
+
return None
|
|
142
|
+
|
|
143
|
+
def step_fail(self, reason: str) -> None:
|
|
144
|
+
return None
|
|
145
|
+
|
|
146
|
+
def note(self, message: str) -> None:
|
|
147
|
+
return None
|
|
148
|
+
|
|
149
|
+
def info(self, message: str) -> None:
|
|
150
|
+
return None
|
|
151
|
+
|
|
152
|
+
def result(self, message: str, *, ok: bool) -> None:
|
|
153
|
+
return None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _format_artifact_size(size_bytes: int) -> str:
|
|
157
|
+
if size_bytes < 1024:
|
|
158
|
+
return f"{size_bytes} bytes"
|
|
159
|
+
return f"{size_bytes / 1024:.1f} KiB"
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _report_generated_artifacts(
|
|
163
|
+
reporter: _VerifyReporter,
|
|
164
|
+
*,
|
|
165
|
+
artifacts: Sequence[tuple[str, int]],
|
|
166
|
+
) -> None:
|
|
167
|
+
for name, size_bytes in artifacts:
|
|
168
|
+
reporter.info(f" {name} ({_format_artifact_size(size_bytes)})")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _worst_ulp_diff(
|
|
172
|
+
actual: "np.ndarray", expected: "np.ndarray"
|
|
173
|
+
) -> tuple[int, tuple[tuple[int, ...], float, float] | None]:
|
|
174
|
+
if actual.shape != expected.shape:
|
|
175
|
+
raise ValueError(
|
|
176
|
+
f"Shape mismatch for ULP calculation: {actual.shape} vs {expected.shape}"
|
|
177
|
+
)
|
|
178
|
+
if not np.issubdtype(expected.dtype, np.floating):
|
|
179
|
+
return 0, None
|
|
180
|
+
if actual.size == 0:
|
|
181
|
+
return 0, None
|
|
182
|
+
dtype = expected.dtype
|
|
183
|
+
actual_cast = actual.astype(dtype, copy=False)
|
|
184
|
+
expected_cast = expected.astype(dtype, copy=False)
|
|
185
|
+
max_diff = 0
|
|
186
|
+
worst: tuple[tuple[int, ...], float, float] | None = None
|
|
187
|
+
iterator = np.nditer(
|
|
188
|
+
[actual_cast, expected_cast], flags=["refs_ok", "multi_index"]
|
|
189
|
+
)
|
|
190
|
+
for actual_value, expected_value in iterator:
|
|
191
|
+
actual_scalar = float(actual_value[()])
|
|
192
|
+
expected_scalar = float(expected_value[()])
|
|
193
|
+
diff = ulp_intdiff_float(actual_value[()], expected_value[()])
|
|
194
|
+
if diff > max_diff:
|
|
195
|
+
max_diff = diff
|
|
196
|
+
worst = (
|
|
197
|
+
iterator.multi_index,
|
|
198
|
+
actual_scalar,
|
|
199
|
+
expected_scalar,
|
|
200
|
+
)
|
|
201
|
+
return max_diff, worst
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _worst_abs_diff(
|
|
205
|
+
actual: "np.ndarray", expected: "np.ndarray"
|
|
206
|
+
) -> tuple[float | int, tuple[tuple[int, ...], object, object] | None]:
|
|
207
|
+
if actual.shape != expected.shape:
|
|
208
|
+
raise ValueError(
|
|
209
|
+
f"Shape mismatch for diff calculation: {actual.shape} vs {expected.shape}"
|
|
210
|
+
)
|
|
211
|
+
if actual.size == 0:
|
|
212
|
+
return 0, None
|
|
213
|
+
dtype = expected.dtype
|
|
214
|
+
actual_cast = actual.astype(dtype, copy=False)
|
|
215
|
+
expected_cast = expected.astype(dtype, copy=False)
|
|
216
|
+
max_diff: float | int = 0
|
|
217
|
+
worst: tuple[tuple[int, ...], object, object] | None = None
|
|
218
|
+
iterator = np.nditer(
|
|
219
|
+
[actual_cast, expected_cast], flags=["refs_ok", "multi_index"]
|
|
220
|
+
)
|
|
221
|
+
for actual_value, expected_value in iterator:
|
|
222
|
+
actual_scalar = actual_value[()]
|
|
223
|
+
expected_scalar = expected_value[()]
|
|
224
|
+
if actual_scalar == expected_scalar:
|
|
225
|
+
continue
|
|
226
|
+
try:
|
|
227
|
+
if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
|
|
228
|
+
diff: float | int = abs(int(actual_scalar) - int(expected_scalar))
|
|
229
|
+
else:
|
|
230
|
+
diff = float(abs(actual_scalar - expected_scalar))
|
|
231
|
+
except Exception:
|
|
232
|
+
diff = 1
|
|
233
|
+
if diff > max_diff:
|
|
234
|
+
max_diff = diff
|
|
235
|
+
worst = (
|
|
236
|
+
iterator.multi_index,
|
|
237
|
+
actual_scalar,
|
|
238
|
+
expected_scalar,
|
|
239
|
+
)
|
|
240
|
+
return max_diff, worst
|
|
45
241
|
|
|
46
242
|
|
|
47
243
|
def run_cli_command(
|
|
@@ -56,18 +252,26 @@ def run_cli_command(
|
|
|
56
252
|
parser = _build_parser()
|
|
57
253
|
args = parser.parse_args(parse_argv)
|
|
58
254
|
args.command_line = _format_command_line(raw_argv)
|
|
255
|
+
_apply_base_dir(args, parser)
|
|
59
256
|
|
|
60
257
|
try:
|
|
61
258
|
if args.command != "compile":
|
|
62
|
-
|
|
63
|
-
|
|
259
|
+
(
|
|
260
|
+
success_message,
|
|
261
|
+
error,
|
|
262
|
+
operators,
|
|
263
|
+
opset_version,
|
|
264
|
+
generated_checksum,
|
|
265
|
+
) = _verify_model(
|
|
266
|
+
args, include_build_details=False, reporter=_NullVerifyReporter()
|
|
64
267
|
)
|
|
65
268
|
return CliResult(
|
|
66
269
|
exit_code=0 if error is None else 1,
|
|
67
270
|
command_line=args.command_line,
|
|
68
|
-
|
|
69
|
-
success_message=success_message,
|
|
271
|
+
result=error or success_message,
|
|
70
272
|
operators=operators,
|
|
273
|
+
opset_version=opset_version,
|
|
274
|
+
generated_checksum=generated_checksum,
|
|
71
275
|
)
|
|
72
276
|
generated, data_source, error = _compile_model(
|
|
73
277
|
args, testbench_inputs=testbench_inputs
|
|
@@ -76,12 +280,12 @@ def run_cli_command(
|
|
|
76
280
|
return CliResult(
|
|
77
281
|
exit_code=1,
|
|
78
282
|
command_line=args.command_line,
|
|
79
|
-
|
|
283
|
+
result=error,
|
|
80
284
|
)
|
|
81
285
|
return CliResult(
|
|
82
286
|
exit_code=0,
|
|
83
287
|
command_line=args.command_line,
|
|
84
|
-
|
|
288
|
+
result="",
|
|
85
289
|
generated=generated,
|
|
86
290
|
data_source=data_source,
|
|
87
291
|
)
|
|
@@ -90,7 +294,7 @@ def run_cli_command(
|
|
|
90
294
|
return CliResult(
|
|
91
295
|
exit_code=1,
|
|
92
296
|
command_line=args.command_line,
|
|
93
|
-
|
|
297
|
+
result=str(exc),
|
|
94
298
|
)
|
|
95
299
|
|
|
96
300
|
|
|
@@ -102,6 +306,24 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
102
306
|
parser = argparse.ArgumentParser(prog="emx-onnx-cgen", description=description)
|
|
103
307
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
104
308
|
|
|
309
|
+
def add_color_flag(subparser: argparse.ArgumentParser) -> None:
|
|
310
|
+
subparser.add_argument(
|
|
311
|
+
"--color",
|
|
312
|
+
choices=("auto", "always", "never"),
|
|
313
|
+
default="auto",
|
|
314
|
+
help=(
|
|
315
|
+
"Colorize CLI output (default: auto; options: auto, always, never)"
|
|
316
|
+
),
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
def add_verbose_flag(subparser: argparse.ArgumentParser) -> None:
|
|
320
|
+
subparser.add_argument(
|
|
321
|
+
"--verbose",
|
|
322
|
+
"-v",
|
|
323
|
+
action="store_true",
|
|
324
|
+
help="Enable verbose logging (includes codegen timing).",
|
|
325
|
+
)
|
|
326
|
+
|
|
105
327
|
def add_restrict_flags(subparser: argparse.ArgumentParser) -> None:
|
|
106
328
|
restrict_group = subparser.add_mutually_exclusive_group()
|
|
107
329
|
restrict_group.add_argument(
|
|
@@ -118,9 +340,47 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
118
340
|
)
|
|
119
341
|
subparser.set_defaults(restrict_arrays=True)
|
|
120
342
|
|
|
343
|
+
def add_fp32_accumulation_strategy_flag(
|
|
344
|
+
subparser: argparse.ArgumentParser,
|
|
345
|
+
) -> None:
|
|
346
|
+
subparser.add_argument(
|
|
347
|
+
"--fp32-accumulation-strategy",
|
|
348
|
+
choices=("simple", "fp64"),
|
|
349
|
+
default="simple",
|
|
350
|
+
help=(
|
|
351
|
+
"Accumulation strategy for float32 inputs "
|
|
352
|
+
"(simple uses float32, fp64 uses double; default: simple)"
|
|
353
|
+
),
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
def add_fp16_accumulation_strategy_flag(
|
|
357
|
+
subparser: argparse.ArgumentParser,
|
|
358
|
+
) -> None:
|
|
359
|
+
subparser.add_argument(
|
|
360
|
+
"--fp16-accumulation-strategy",
|
|
361
|
+
choices=("simple", "fp32"),
|
|
362
|
+
default="fp32",
|
|
363
|
+
help=(
|
|
364
|
+
"Accumulation strategy for float16 inputs "
|
|
365
|
+
"(simple uses float16, fp32 uses float; default: fp32)"
|
|
366
|
+
),
|
|
367
|
+
)
|
|
368
|
+
|
|
121
369
|
compile_parser = subparsers.add_parser(
|
|
122
370
|
"compile", help="Compile an ONNX model into C source"
|
|
123
371
|
)
|
|
372
|
+
add_color_flag(compile_parser)
|
|
373
|
+
add_verbose_flag(compile_parser)
|
|
374
|
+
compile_parser.add_argument(
|
|
375
|
+
"--model-base-dir",
|
|
376
|
+
"-B",
|
|
377
|
+
type=Path,
|
|
378
|
+
default=None,
|
|
379
|
+
help=(
|
|
380
|
+
"Base directory for resolving the model path "
|
|
381
|
+
"(example: tool --model-base-dir /data model.onnx)"
|
|
382
|
+
),
|
|
383
|
+
)
|
|
124
384
|
compile_parser.add_argument("model", type=Path, help="Path to the ONNX model")
|
|
125
385
|
compile_parser.add_argument(
|
|
126
386
|
"output",
|
|
@@ -132,12 +392,6 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
132
392
|
"e.g., model.onnx -> model.c)"
|
|
133
393
|
),
|
|
134
394
|
)
|
|
135
|
-
compile_parser.add_argument(
|
|
136
|
-
"--template-dir",
|
|
137
|
-
type=Path,
|
|
138
|
-
default=Path("templates"),
|
|
139
|
-
help="Template directory (default: templates)",
|
|
140
|
-
)
|
|
141
395
|
compile_parser.add_argument(
|
|
142
396
|
"--model-name",
|
|
143
397
|
type=str,
|
|
@@ -167,9 +421,10 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
167
421
|
),
|
|
168
422
|
)
|
|
169
423
|
compile_parser.add_argument(
|
|
170
|
-
"--large-temp-threshold
|
|
424
|
+
"--large-temp-threshold",
|
|
171
425
|
type=int,
|
|
172
426
|
default=1024,
|
|
427
|
+
dest="large_temp_threshold_bytes",
|
|
173
428
|
help=(
|
|
174
429
|
"Mark temporary buffers larger than this threshold as static "
|
|
175
430
|
"(default: 1024)"
|
|
@@ -178,25 +433,33 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
178
433
|
compile_parser.add_argument(
|
|
179
434
|
"--large-weight-threshold",
|
|
180
435
|
type=int,
|
|
181
|
-
default=
|
|
436
|
+
default=100 * 1024,
|
|
182
437
|
help=(
|
|
183
|
-
"Store weights
|
|
184
|
-
"(default:
|
|
438
|
+
"Store weights in a binary file once the cumulative byte size "
|
|
439
|
+
"exceeds this threshold (default: 102400; set to 0 to disable)"
|
|
185
440
|
),
|
|
186
441
|
)
|
|
187
442
|
add_restrict_flags(compile_parser)
|
|
443
|
+
add_fp32_accumulation_strategy_flag(compile_parser)
|
|
444
|
+
add_fp16_accumulation_strategy_flag(compile_parser)
|
|
188
445
|
|
|
189
446
|
verify_parser = subparsers.add_parser(
|
|
190
447
|
"verify",
|
|
191
448
|
help="Compile an ONNX model and verify outputs against ONNX Runtime",
|
|
192
449
|
)
|
|
193
|
-
verify_parser
|
|
450
|
+
add_color_flag(verify_parser)
|
|
451
|
+
add_verbose_flag(verify_parser)
|
|
194
452
|
verify_parser.add_argument(
|
|
195
|
-
"--
|
|
453
|
+
"--model-base-dir",
|
|
454
|
+
"-B",
|
|
196
455
|
type=Path,
|
|
197
|
-
default=
|
|
198
|
-
help=
|
|
456
|
+
default=None,
|
|
457
|
+
help=(
|
|
458
|
+
"Base directory for resolving the model and test data paths "
|
|
459
|
+
"(example: tool --model-base-dir /data model.onnx --test-data-dir inputs)"
|
|
460
|
+
),
|
|
199
461
|
)
|
|
462
|
+
verify_parser.add_argument("model", type=Path, help="Path to the ONNX model")
|
|
200
463
|
verify_parser.add_argument(
|
|
201
464
|
"--model-name",
|
|
202
465
|
type=str,
|
|
@@ -219,9 +482,10 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
219
482
|
),
|
|
220
483
|
)
|
|
221
484
|
verify_parser.add_argument(
|
|
222
|
-
"--large-temp-threshold
|
|
485
|
+
"--large-temp-threshold",
|
|
223
486
|
type=int,
|
|
224
487
|
default=1024,
|
|
488
|
+
dest="large_temp_threshold_bytes",
|
|
225
489
|
help=(
|
|
226
490
|
"Mark temporary buffers larger than this threshold as static "
|
|
227
491
|
"(default: 1024)"
|
|
@@ -230,10 +494,10 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
230
494
|
verify_parser.add_argument(
|
|
231
495
|
"--large-weight-threshold",
|
|
232
496
|
type=int,
|
|
233
|
-
default=1024,
|
|
497
|
+
default=100 * 1024,
|
|
234
498
|
help=(
|
|
235
|
-
"Store weights
|
|
236
|
-
"(default:
|
|
499
|
+
"Store weights in a binary file once the cumulative byte size "
|
|
500
|
+
"exceeds this threshold (default: 102400)"
|
|
237
501
|
),
|
|
238
502
|
)
|
|
239
503
|
verify_parser.add_argument(
|
|
@@ -245,30 +509,100 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
245
509
|
"(default: use random testbench inputs)"
|
|
246
510
|
),
|
|
247
511
|
)
|
|
512
|
+
verify_parser.add_argument(
|
|
513
|
+
"--temp-dir-root",
|
|
514
|
+
type=Path,
|
|
515
|
+
default=None,
|
|
516
|
+
help=(
|
|
517
|
+
"Root directory in which to create a temporary verification "
|
|
518
|
+
"directory (default: system temp dir)"
|
|
519
|
+
),
|
|
520
|
+
)
|
|
521
|
+
verify_parser.add_argument(
|
|
522
|
+
"--temp-dir",
|
|
523
|
+
type=Path,
|
|
524
|
+
default=None,
|
|
525
|
+
help=(
|
|
526
|
+
"Exact directory to use for temporary verification files "
|
|
527
|
+
"(default: create a temporary directory)"
|
|
528
|
+
),
|
|
529
|
+
)
|
|
530
|
+
verify_parser.add_argument(
|
|
531
|
+
"--keep-temp-dir",
|
|
532
|
+
action="store_true",
|
|
533
|
+
help="Keep the temporary verification directory (default: delete it)",
|
|
534
|
+
)
|
|
248
535
|
verify_parser.add_argument(
|
|
249
536
|
"--max-ulp",
|
|
250
537
|
type=int,
|
|
251
538
|
default=100,
|
|
252
539
|
help="Maximum allowed ULP difference for floating outputs (default: 100)",
|
|
253
540
|
)
|
|
541
|
+
verify_parser.add_argument(
|
|
542
|
+
"--atol-eps",
|
|
543
|
+
type=float,
|
|
544
|
+
default=1.0,
|
|
545
|
+
help=(
|
|
546
|
+
"Absolute tolerance as a multiple of machine epsilon for ULP checks "
|
|
547
|
+
"(default: 1.0)"
|
|
548
|
+
),
|
|
549
|
+
)
|
|
254
550
|
verify_parser.add_argument(
|
|
255
551
|
"--runtime",
|
|
256
552
|
choices=("onnxruntime", "onnx-reference"),
|
|
257
|
-
default="
|
|
553
|
+
default="onnxruntime",
|
|
258
554
|
help=(
|
|
259
|
-
"Runtime backend for verification (default:
|
|
555
|
+
"Runtime backend for verification (default: onnxruntime; "
|
|
260
556
|
"options: onnxruntime, onnx-reference)"
|
|
261
557
|
),
|
|
262
558
|
)
|
|
559
|
+
verify_parser.add_argument(
|
|
560
|
+
"--expected-checksum",
|
|
561
|
+
type=str,
|
|
562
|
+
default=None,
|
|
563
|
+
help=(
|
|
564
|
+
"Expected generated C checksum (sha256). When it matches the "
|
|
565
|
+
"computed checksum, verification exits early with CHECKSUM."
|
|
566
|
+
),
|
|
567
|
+
)
|
|
263
568
|
add_restrict_flags(verify_parser)
|
|
569
|
+
add_fp32_accumulation_strategy_flag(verify_parser)
|
|
570
|
+
add_fp16_accumulation_strategy_flag(verify_parser)
|
|
264
571
|
return parser
|
|
265
572
|
|
|
266
573
|
|
|
574
|
+
def _resolve_with_base_dir(base_dir: Path, path: Path) -> Path:
|
|
575
|
+
if path.is_absolute():
|
|
576
|
+
return path
|
|
577
|
+
return Path(os.path.normpath(os.path.join(base_dir, path)))
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def _apply_base_dir(
|
|
581
|
+
args: argparse.Namespace, parser: argparse.ArgumentParser
|
|
582
|
+
) -> None:
|
|
583
|
+
model_base_dir: Path | None = args.model_base_dir
|
|
584
|
+
if model_base_dir is None:
|
|
585
|
+
return
|
|
586
|
+
if not model_base_dir.exists() or not model_base_dir.is_dir():
|
|
587
|
+
parser.error(
|
|
588
|
+
f"--model-base-dir {model_base_dir} does not exist or is not a directory"
|
|
589
|
+
)
|
|
590
|
+
path_fields = ("model", "test_data_dir")
|
|
591
|
+
for field in path_fields:
|
|
592
|
+
value = getattr(args, field, None)
|
|
593
|
+
if value is None:
|
|
594
|
+
continue
|
|
595
|
+
if not isinstance(value, Path):
|
|
596
|
+
continue
|
|
597
|
+
setattr(args, field, _resolve_with_base_dir(model_base_dir, value))
|
|
598
|
+
|
|
599
|
+
|
|
267
600
|
def main(argv: Sequence[str] | None = None) -> int:
|
|
268
601
|
logging.basicConfig(level=logging.INFO)
|
|
269
602
|
parser = _build_parser()
|
|
270
603
|
args = parser.parse_args(argv)
|
|
271
604
|
args.command_line = _format_command_line(argv)
|
|
605
|
+
_apply_base_dir(args, parser)
|
|
272
606
|
|
|
273
607
|
if args.command == "compile":
|
|
274
608
|
return _handle_compile(args)
|
|
@@ -279,27 +613,28 @@ def main(argv: Sequence[str] | None = None) -> int:
|
|
|
279
613
|
|
|
280
614
|
|
|
281
615
|
def _handle_compile(args: argparse.Namespace) -> int:
|
|
616
|
+
reporter = _VerifyReporter(color_mode=args.color)
|
|
282
617
|
model_path: Path = args.model
|
|
283
618
|
output_path: Path = args.output or model_path.with_suffix(".c")
|
|
284
619
|
model_name = args.model_name or "model"
|
|
285
|
-
generated, data_source, weight_data, error = _compile_model(
|
|
620
|
+
generated, data_source, weight_data, error = _compile_model(
|
|
621
|
+
args, reporter=reporter
|
|
622
|
+
)
|
|
286
623
|
if error:
|
|
287
|
-
|
|
624
|
+
reporter.info("")
|
|
625
|
+
reporter.result(error, ok=False)
|
|
288
626
|
return 1
|
|
289
627
|
|
|
290
628
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
291
629
|
output_path.write_text(generated or "", encoding="utf-8")
|
|
292
|
-
LOGGER.info("Wrote C source to %s", output_path)
|
|
293
630
|
if data_source is not None:
|
|
294
631
|
data_path = output_path.with_name(
|
|
295
632
|
f"{output_path.stem}_data{output_path.suffix}"
|
|
296
633
|
)
|
|
297
634
|
data_path.write_text(data_source, encoding="utf-8")
|
|
298
|
-
LOGGER.info("Wrote data source to %s", data_path)
|
|
299
635
|
if weight_data is not None:
|
|
300
636
|
weights_path = output_path.with_name(f"{model_name}.bin")
|
|
301
637
|
weights_path.write_bytes(weight_data)
|
|
302
|
-
LOGGER.info("Wrote weights binary to %s", weights_path)
|
|
303
638
|
return 0
|
|
304
639
|
|
|
305
640
|
|
|
@@ -307,23 +642,50 @@ def _compile_model(
|
|
|
307
642
|
args: argparse.Namespace,
|
|
308
643
|
*,
|
|
309
644
|
testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
|
|
645
|
+
reporter: _VerifyReporter | None = None,
|
|
310
646
|
) -> tuple[str | None, str | None, bytes | None, str | None]:
|
|
311
647
|
model_path: Path = args.model
|
|
312
648
|
model_name = args.model_name or "model"
|
|
649
|
+
active_reporter = reporter or _NullVerifyReporter()
|
|
650
|
+
load_started = active_reporter.start_step(
|
|
651
|
+
f"Loading model {model_path.name}"
|
|
652
|
+
)
|
|
653
|
+
timings: dict[str, float] = {}
|
|
654
|
+
try:
|
|
655
|
+
model, model_checksum = _load_model_and_checksum(model_path)
|
|
656
|
+
active_reporter.step_ok(load_started)
|
|
657
|
+
except OSError as exc:
|
|
658
|
+
active_reporter.step_fail(str(exc))
|
|
659
|
+
return None, None, None, str(exc)
|
|
660
|
+
operators = _collect_model_operators(model)
|
|
661
|
+
opset_version = _model_opset_version(model)
|
|
662
|
+
_report_model_details(
|
|
663
|
+
active_reporter,
|
|
664
|
+
model_path=model_path,
|
|
665
|
+
model_checksum=model_checksum,
|
|
666
|
+
operators=operators,
|
|
667
|
+
opset_version=opset_version,
|
|
668
|
+
node_count=len(model.graph.node),
|
|
669
|
+
initializer_count=len(model.graph.initializer),
|
|
670
|
+
input_count=len(model.graph.input),
|
|
671
|
+
output_count=len(model.graph.output),
|
|
672
|
+
)
|
|
673
|
+
active_reporter.info("")
|
|
674
|
+
codegen_started = active_reporter.start_step("Generating C code")
|
|
313
675
|
try:
|
|
314
|
-
model_checksum = _model_checksum(model_path)
|
|
315
|
-
model = onnx.load_model(model_path)
|
|
316
676
|
options = CompilerOptions(
|
|
317
|
-
template_dir=args.template_dir,
|
|
318
677
|
model_name=model_name,
|
|
319
678
|
emit_testbench=args.emit_testbench,
|
|
320
679
|
command_line=args.command_line,
|
|
321
680
|
model_checksum=model_checksum,
|
|
322
681
|
restrict_arrays=args.restrict_arrays,
|
|
682
|
+
fp32_accumulation_strategy=args.fp32_accumulation_strategy,
|
|
683
|
+
fp16_accumulation_strategy=args.fp16_accumulation_strategy,
|
|
323
684
|
truncate_weights_after=args.truncate_weights_after,
|
|
324
685
|
large_temp_threshold_bytes=args.large_temp_threshold_bytes,
|
|
325
686
|
large_weight_threshold=args.large_weight_threshold,
|
|
326
687
|
testbench_inputs=testbench_inputs,
|
|
688
|
+
timings=timings,
|
|
327
689
|
)
|
|
328
690
|
compiler = Compiler(options)
|
|
329
691
|
if args.emit_data_file:
|
|
@@ -333,8 +695,26 @@ def _compile_model(
|
|
|
333
695
|
else:
|
|
334
696
|
generated, weight_data = compiler.compile_with_weight_data(model)
|
|
335
697
|
data_source = None
|
|
336
|
-
|
|
698
|
+
active_reporter.step_ok(codegen_started)
|
|
699
|
+
if args.verbose:
|
|
700
|
+
_report_codegen_timings(active_reporter, timings=timings)
|
|
701
|
+
except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
|
|
702
|
+
active_reporter.step_fail(str(exc))
|
|
337
703
|
return None, None, None, str(exc)
|
|
704
|
+
output_path: Path = args.output or model_path.with_suffix(".c")
|
|
705
|
+
artifacts = [(str(output_path), len(generated.encode("utf-8")))]
|
|
706
|
+
if data_source is not None:
|
|
707
|
+
data_path = output_path.with_name(
|
|
708
|
+
f"{output_path.stem}_data{output_path.suffix}"
|
|
709
|
+
)
|
|
710
|
+
artifacts.append((str(data_path), len(data_source.encode("utf-8"))))
|
|
711
|
+
if weight_data is not None:
|
|
712
|
+
weights_path = output_path.with_name(f"{model_name}.bin")
|
|
713
|
+
artifacts.append((str(weights_path), len(weight_data)))
|
|
714
|
+
_report_generated_artifacts(active_reporter, artifacts=artifacts)
|
|
715
|
+
active_reporter.info(
|
|
716
|
+
f" Generated checksum (sha256): {_generated_checksum(generated)}"
|
|
717
|
+
)
|
|
338
718
|
return generated, data_source, weight_data, None
|
|
339
719
|
|
|
340
720
|
|
|
@@ -363,21 +743,27 @@ def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str]
|
|
|
363
743
|
if env_cc:
|
|
364
744
|
return resolve_tokens(shlex.split(env_cc))
|
|
365
745
|
for candidate in ("cc", "gcc", "clang"):
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
return maybe_prefix_ccache([resolved])
|
|
746
|
+
if shutil.which(candidate):
|
|
747
|
+
return maybe_prefix_ccache([candidate])
|
|
369
748
|
return None
|
|
370
749
|
|
|
371
750
|
|
|
372
751
|
def _handle_verify(args: argparse.Namespace) -> int:
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
752
|
+
reporter = _VerifyReporter(color_mode=args.color)
|
|
753
|
+
(
|
|
754
|
+
success_message,
|
|
755
|
+
error,
|
|
756
|
+
_operators,
|
|
757
|
+
_opset_version,
|
|
758
|
+
generated_checksum,
|
|
759
|
+
) = _verify_model(args, include_build_details=True, reporter=reporter)
|
|
376
760
|
if error is not None:
|
|
377
|
-
|
|
761
|
+
reporter.info("")
|
|
762
|
+
reporter.result(error, ok=False)
|
|
378
763
|
return 1
|
|
379
764
|
if success_message:
|
|
380
|
-
|
|
765
|
+
reporter.info("")
|
|
766
|
+
reporter.result(success_message, ok=True)
|
|
381
767
|
return 0
|
|
382
768
|
|
|
383
769
|
|
|
@@ -385,12 +771,9 @@ def _verify_model(
|
|
|
385
771
|
args: argparse.Namespace,
|
|
386
772
|
*,
|
|
387
773
|
include_build_details: bool,
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
def log_step(step: str, started_at: float) -> None:
|
|
392
|
-
duration = time.perf_counter() - started_at
|
|
393
|
-
LOGGER.info("verify step %s: %.3fs", step, duration)
|
|
774
|
+
reporter: _VerifyReporter | None = None,
|
|
775
|
+
) -> tuple[str | None, str | None, list[str], int | None, str | None]:
|
|
776
|
+
active_reporter = reporter or _NullVerifyReporter()
|
|
394
777
|
|
|
395
778
|
def describe_exit_code(returncode: int) -> str:
|
|
396
779
|
if returncode >= 0:
|
|
@@ -404,54 +787,176 @@ def _verify_model(
|
|
|
404
787
|
|
|
405
788
|
model_path: Path = args.model
|
|
406
789
|
model_name = args.model_name or "model"
|
|
407
|
-
model_checksum =
|
|
790
|
+
model, model_checksum = _load_model_and_checksum(model_path)
|
|
408
791
|
compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
|
|
409
792
|
if compiler_cmd is None:
|
|
410
793
|
return (
|
|
411
794
|
None,
|
|
412
795
|
"No C compiler found (set --cc or CC environment variable).",
|
|
413
796
|
[],
|
|
797
|
+
None,
|
|
798
|
+
None,
|
|
414
799
|
)
|
|
800
|
+
temp_dir_root: Path | None = args.temp_dir_root
|
|
801
|
+
explicit_temp_dir: Path | None = args.temp_dir
|
|
802
|
+
if temp_dir_root is not None and explicit_temp_dir is not None:
|
|
803
|
+
return (
|
|
804
|
+
None,
|
|
805
|
+
"Cannot set both --temp-dir-root and --temp-dir.",
|
|
806
|
+
operators,
|
|
807
|
+
opset_version,
|
|
808
|
+
generated_checksum,
|
|
809
|
+
)
|
|
810
|
+
if temp_dir_root is not None:
|
|
811
|
+
if temp_dir_root.exists() and not temp_dir_root.is_dir():
|
|
812
|
+
return (
|
|
813
|
+
None,
|
|
814
|
+
f"Verification temp dir root is not a directory: {temp_dir_root}",
|
|
815
|
+
operators,
|
|
816
|
+
opset_version,
|
|
817
|
+
generated_checksum,
|
|
818
|
+
)
|
|
819
|
+
temp_dir_root.mkdir(parents=True, exist_ok=True)
|
|
820
|
+
if explicit_temp_dir is not None:
|
|
821
|
+
if explicit_temp_dir.exists() and not explicit_temp_dir.is_dir():
|
|
822
|
+
return (
|
|
823
|
+
None,
|
|
824
|
+
f"Verification temp dir is not a directory: {explicit_temp_dir}",
|
|
825
|
+
operators,
|
|
826
|
+
opset_version,
|
|
827
|
+
generated_checksum,
|
|
828
|
+
)
|
|
829
|
+
temp_dir: tempfile.TemporaryDirectory | None = None
|
|
830
|
+
cleanup_created_dir = False
|
|
831
|
+
if explicit_temp_dir is not None:
|
|
832
|
+
temp_path = explicit_temp_dir
|
|
833
|
+
if not temp_path.exists():
|
|
834
|
+
temp_path.mkdir(parents=True, exist_ok=True)
|
|
835
|
+
cleanup_created_dir = not args.keep_temp_dir
|
|
836
|
+
elif args.keep_temp_dir:
|
|
837
|
+
temp_path = Path(
|
|
838
|
+
tempfile.mkdtemp(
|
|
839
|
+
dir=str(temp_dir_root) if temp_dir_root is not None else None
|
|
840
|
+
)
|
|
841
|
+
)
|
|
842
|
+
else:
|
|
843
|
+
temp_dir = tempfile.TemporaryDirectory(
|
|
844
|
+
dir=str(temp_dir_root) if temp_dir_root is not None else None
|
|
845
|
+
)
|
|
846
|
+
temp_path = Path(temp_dir.name)
|
|
847
|
+
keep_label = (
|
|
848
|
+
"--keep-temp-dir set" if args.keep_temp_dir else "--keep-temp-dir not set"
|
|
849
|
+
)
|
|
850
|
+
active_reporter.note(
|
|
851
|
+
f"Using temporary folder [{keep_label}]: {temp_path}"
|
|
852
|
+
)
|
|
853
|
+
active_reporter.info("")
|
|
854
|
+
load_started = active_reporter.start_step(f"Loading model {model_path.name}")
|
|
415
855
|
try:
|
|
416
|
-
model =
|
|
856
|
+
model, model_checksum = _load_model_and_checksum(model_path)
|
|
417
857
|
except OSError as exc:
|
|
418
|
-
|
|
858
|
+
active_reporter.step_fail(str(exc))
|
|
859
|
+
return None, str(exc), [], None, None
|
|
860
|
+
active_reporter.step_ok(load_started)
|
|
419
861
|
|
|
420
862
|
operators = _collect_model_operators(model)
|
|
421
|
-
|
|
422
|
-
|
|
863
|
+
opset_version = _model_opset_version(model)
|
|
864
|
+
_report_model_details(
|
|
865
|
+
active_reporter,
|
|
866
|
+
model_path=model_path,
|
|
867
|
+
model_checksum=model_checksum,
|
|
868
|
+
operators=operators,
|
|
869
|
+
opset_version=opset_version,
|
|
870
|
+
node_count=len(model.graph.node),
|
|
871
|
+
initializer_count=len(model.graph.initializer),
|
|
872
|
+
input_count=len(model.graph.input),
|
|
873
|
+
output_count=len(model.graph.output),
|
|
874
|
+
)
|
|
423
875
|
|
|
876
|
+
timings: dict[str, float] = {}
|
|
424
877
|
try:
|
|
425
|
-
|
|
878
|
+
active_reporter.info("")
|
|
879
|
+
codegen_started = active_reporter.start_step("Generating C code")
|
|
880
|
+
testbench_inputs, testbench_optional_inputs = _load_test_data_inputs(
|
|
881
|
+
model, args.test_data_dir
|
|
882
|
+
)
|
|
883
|
+
testbench_outputs = _load_test_data_outputs(model, args.test_data_dir)
|
|
426
884
|
options = CompilerOptions(
|
|
427
|
-
template_dir=args.template_dir,
|
|
428
885
|
model_name=model_name,
|
|
429
886
|
emit_testbench=True,
|
|
430
|
-
command_line=
|
|
887
|
+
command_line=None,
|
|
431
888
|
model_checksum=model_checksum,
|
|
432
889
|
restrict_arrays=args.restrict_arrays,
|
|
890
|
+
fp32_accumulation_strategy=args.fp32_accumulation_strategy,
|
|
891
|
+
fp16_accumulation_strategy=args.fp16_accumulation_strategy,
|
|
433
892
|
truncate_weights_after=args.truncate_weights_after,
|
|
434
893
|
large_temp_threshold_bytes=args.large_temp_threshold_bytes,
|
|
435
894
|
large_weight_threshold=args.large_weight_threshold,
|
|
436
895
|
testbench_inputs=testbench_inputs,
|
|
896
|
+
testbench_optional_inputs=testbench_optional_inputs,
|
|
897
|
+
timings=timings,
|
|
437
898
|
)
|
|
438
899
|
compiler = Compiler(options)
|
|
439
|
-
codegen_started = time.perf_counter()
|
|
440
900
|
generated, weight_data = compiler.compile_with_weight_data(model)
|
|
441
|
-
|
|
901
|
+
active_reporter.step_ok(codegen_started)
|
|
902
|
+
if args.verbose:
|
|
903
|
+
_report_codegen_timings(active_reporter, timings=timings)
|
|
904
|
+
artifacts = [("model.c", len(generated.encode("utf-8")))]
|
|
905
|
+
if weight_data is not None:
|
|
906
|
+
artifacts.append((f"{model_name}.bin", len(weight_data)))
|
|
907
|
+
_report_generated_artifacts(active_reporter, artifacts=artifacts)
|
|
442
908
|
except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
|
|
443
|
-
|
|
909
|
+
active_reporter.step_fail(str(exc))
|
|
910
|
+
return None, str(exc), operators, opset_version, None
|
|
911
|
+
generated_checksum = _generated_checksum(generated)
|
|
912
|
+
active_reporter.info(f" Generated checksum (sha256): {generated_checksum}")
|
|
913
|
+
expected_checksum = args.expected_checksum
|
|
914
|
+
if expected_checksum and expected_checksum == generated_checksum:
|
|
915
|
+
return "CHECKSUM", None, operators, opset_version, generated_checksum
|
|
444
916
|
|
|
445
917
|
try:
|
|
446
918
|
graph = import_onnx(model)
|
|
447
919
|
output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
|
|
448
920
|
input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
|
|
449
921
|
except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
|
|
450
|
-
return
|
|
922
|
+
return (
|
|
923
|
+
None,
|
|
924
|
+
f"Failed to resolve model dtype: {exc}",
|
|
925
|
+
operators,
|
|
926
|
+
opset_version,
|
|
927
|
+
None,
|
|
928
|
+
)
|
|
451
929
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
930
|
+
def _cleanup_temp() -> None:
|
|
931
|
+
if temp_dir is None and not cleanup_created_dir:
|
|
932
|
+
return
|
|
933
|
+
if temp_dir is None:
|
|
934
|
+
shutil.rmtree(temp_path)
|
|
935
|
+
else:
|
|
936
|
+
temp_dir.cleanup()
|
|
937
|
+
|
|
938
|
+
try:
|
|
939
|
+
payload: dict[str, Any] | None = None
|
|
940
|
+
testbench_input_path: Path | None = None
|
|
941
|
+
if testbench_inputs:
|
|
942
|
+
input_order = [value.name for value in graph.inputs]
|
|
943
|
+
testbench_input_path = temp_path / "testbench_inputs.bin"
|
|
944
|
+
with testbench_input_path.open("wb") as handle:
|
|
945
|
+
for name in input_order:
|
|
946
|
+
array = testbench_inputs.get(name)
|
|
947
|
+
if array is None:
|
|
948
|
+
return (
|
|
949
|
+
None,
|
|
950
|
+
f"Missing testbench input data for {name}.",
|
|
951
|
+
operators,
|
|
952
|
+
opset_version,
|
|
953
|
+
generated_checksum,
|
|
954
|
+
)
|
|
955
|
+
dtype = input_dtypes[name].np_dtype
|
|
956
|
+
blob = np.ascontiguousarray(
|
|
957
|
+
array.astype(dtype, copy=False)
|
|
958
|
+
).tobytes(order="C")
|
|
959
|
+
handle.write(blob)
|
|
455
960
|
c_path = temp_path / "model.c"
|
|
456
961
|
weights_path = temp_path / f"{model_name}.bin"
|
|
457
962
|
exe_path = temp_path / "model"
|
|
@@ -459,126 +964,302 @@ def _verify_model(
|
|
|
459
964
|
if weight_data is not None:
|
|
460
965
|
weights_path.write_bytes(weight_data)
|
|
461
966
|
try:
|
|
462
|
-
compile_started = time.perf_counter()
|
|
463
967
|
compile_cmd = [
|
|
464
968
|
*compiler_cmd,
|
|
465
969
|
"-std=c99",
|
|
466
|
-
"-
|
|
467
|
-
|
|
970
|
+
"-O1",
|
|
971
|
+
"-fsanitize=address,undefined",
|
|
972
|
+
"-Wall",
|
|
973
|
+
"-Werror",
|
|
974
|
+
str(c_path.name),
|
|
468
975
|
"-o",
|
|
469
|
-
str(exe_path),
|
|
976
|
+
str(exe_path.name),
|
|
470
977
|
"-lm",
|
|
471
978
|
]
|
|
472
|
-
|
|
979
|
+
active_reporter.info("")
|
|
980
|
+
compile_started = active_reporter.start_step("Compiling C code")
|
|
473
981
|
subprocess.run(
|
|
474
982
|
compile_cmd,
|
|
475
983
|
check=True,
|
|
476
984
|
capture_output=True,
|
|
477
985
|
text=True,
|
|
986
|
+
cwd=temp_path,
|
|
987
|
+
)
|
|
988
|
+
active_reporter.step_ok(compile_started)
|
|
989
|
+
active_reporter.info(
|
|
990
|
+
f" Compile command: {shlex.join(compile_cmd)}"
|
|
478
991
|
)
|
|
479
|
-
|
|
992
|
+
active_reporter.info("")
|
|
993
|
+
if args.test_data_dir is not None:
|
|
994
|
+
active_reporter.info(
|
|
995
|
+
f"Verifying using test data set: {args.test_data_dir.name}"
|
|
996
|
+
)
|
|
997
|
+
else:
|
|
998
|
+
active_reporter.info(
|
|
999
|
+
"Verifying using generated random inputs"
|
|
1000
|
+
)
|
|
480
1001
|
except subprocess.CalledProcessError as exc:
|
|
481
1002
|
message = "Failed to build testbench."
|
|
482
1003
|
if include_build_details:
|
|
483
1004
|
details = exc.stderr.strip()
|
|
484
1005
|
if details:
|
|
485
1006
|
message = f"{message} {details}"
|
|
486
|
-
|
|
1007
|
+
active_reporter.step_fail(message)
|
|
1008
|
+
return None, message, operators, opset_version, generated_checksum
|
|
487
1009
|
try:
|
|
488
|
-
run_started =
|
|
1010
|
+
run_started = active_reporter.start_step(
|
|
1011
|
+
" Running generated binary"
|
|
1012
|
+
)
|
|
1013
|
+
run_cmd = [str(exe_path)]
|
|
1014
|
+
if testbench_input_path is not None:
|
|
1015
|
+
run_cmd.append(str(testbench_input_path))
|
|
489
1016
|
result = subprocess.run(
|
|
490
|
-
|
|
1017
|
+
run_cmd,
|
|
491
1018
|
check=True,
|
|
492
1019
|
capture_output=True,
|
|
493
1020
|
text=True,
|
|
494
1021
|
cwd=temp_path,
|
|
495
1022
|
)
|
|
496
|
-
|
|
1023
|
+
active_reporter.step_ok(run_started)
|
|
1024
|
+
result_json_path = temp_path / "testbench.json"
|
|
1025
|
+
result_json_path.write_text(result.stdout, encoding="utf-8")
|
|
1026
|
+
try:
|
|
1027
|
+
payload = json.loads(result_json_path.read_text(encoding="utf-8"))
|
|
1028
|
+
except json.JSONDecodeError as exc:
|
|
1029
|
+
return (
|
|
1030
|
+
None,
|
|
1031
|
+
f"Failed to parse testbench JSON: {exc}",
|
|
1032
|
+
operators,
|
|
1033
|
+
opset_version,
|
|
1034
|
+
generated_checksum,
|
|
1035
|
+
)
|
|
497
1036
|
except subprocess.CalledProcessError as exc:
|
|
1037
|
+
active_reporter.step_fail(describe_exit_code(exc.returncode))
|
|
498
1038
|
return None, (
|
|
499
1039
|
"Testbench execution failed: " + describe_exit_code(exc.returncode)
|
|
500
|
-
), operators
|
|
1040
|
+
), operators, opset_version, generated_checksum
|
|
1041
|
+
if payload is None:
|
|
1042
|
+
return (
|
|
1043
|
+
None,
|
|
1044
|
+
"Failed to parse testbench JSON: missing output.",
|
|
1045
|
+
operators,
|
|
1046
|
+
opset_version,
|
|
1047
|
+
generated_checksum,
|
|
1048
|
+
)
|
|
501
1049
|
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
1050
|
+
if testbench_inputs:
|
|
1051
|
+
inputs = {
|
|
1052
|
+
name: values.astype(input_dtypes[name].np_dtype, copy=False)
|
|
1053
|
+
for name, values in testbench_inputs.items()
|
|
1054
|
+
}
|
|
1055
|
+
else:
|
|
1056
|
+
inputs = {
|
|
1057
|
+
name: decode_testbench_array(
|
|
1058
|
+
value["data"], input_dtypes[name].np_dtype
|
|
1059
|
+
)
|
|
1060
|
+
for name, value in payload["inputs"].items()
|
|
1061
|
+
}
|
|
1062
|
+
runtime_outputs: dict[str, np.ndarray] | None = None
|
|
1063
|
+
if testbench_outputs is not None:
|
|
1064
|
+
runtime_outputs = {
|
|
1065
|
+
name: output.astype(output_dtypes[name].np_dtype, copy=False)
|
|
1066
|
+
for name, output in testbench_outputs.items()
|
|
1067
|
+
}
|
|
1068
|
+
else:
|
|
1069
|
+
runtime_name = args.runtime
|
|
1070
|
+
custom_domains = sorted(
|
|
1071
|
+
{
|
|
1072
|
+
opset.domain
|
|
1073
|
+
for opset in model.opset_import
|
|
1074
|
+
if opset.domain not in {"", "ai.onnx"}
|
|
1075
|
+
}
|
|
1076
|
+
)
|
|
1077
|
+
if runtime_name == "onnx-reference" and custom_domains:
|
|
1078
|
+
active_reporter.note(
|
|
1079
|
+
"Runtime: switching to onnxruntime for custom domains "
|
|
1080
|
+
f"{', '.join(custom_domains)}"
|
|
1081
|
+
)
|
|
1082
|
+
runtime_name = "onnxruntime"
|
|
1083
|
+
runtime_started = active_reporter.start_step(
|
|
1084
|
+
f" Running {runtime_name} [--runtime={args.runtime}]"
|
|
516
1085
|
)
|
|
517
|
-
|
|
1086
|
+
try:
|
|
1087
|
+
if runtime_name == "onnxruntime":
|
|
1088
|
+
import onnxruntime as ort
|
|
1089
|
+
|
|
1090
|
+
sess_options = make_deterministic_session_options(ort)
|
|
1091
|
+
sess = ort.InferenceSession(
|
|
1092
|
+
model.SerializeToString(),
|
|
1093
|
+
sess_options=sess_options,
|
|
1094
|
+
providers=["CPUExecutionProvider"],
|
|
1095
|
+
)
|
|
1096
|
+
runtime_outputs_list = sess.run(None, inputs)
|
|
1097
|
+
else:
|
|
1098
|
+
from onnx.reference import ReferenceEvaluator
|
|
1099
|
+
|
|
1100
|
+
with deterministic_reference_runtime():
|
|
1101
|
+
evaluator = ReferenceEvaluator(model)
|
|
1102
|
+
runtime_outputs_list = evaluator.run(None, inputs)
|
|
1103
|
+
except Exception as exc:
|
|
1104
|
+
active_reporter.step_fail(str(exc))
|
|
1105
|
+
message = str(exc)
|
|
1106
|
+
if runtime_name == "onnxruntime" and "NOT_IMPLEMENTED" in message:
|
|
1107
|
+
active_reporter.note(
|
|
1108
|
+
f"Skipping verification for {model_path}: "
|
|
1109
|
+
"ONNX Runtime does not support the model "
|
|
1110
|
+
f"({message})"
|
|
1111
|
+
)
|
|
1112
|
+
return "", None, operators, opset_version, generated_checksum
|
|
1113
|
+
return (
|
|
1114
|
+
None,
|
|
1115
|
+
f"{runtime_name} failed to run {model_path}: {message}",
|
|
1116
|
+
operators,
|
|
1117
|
+
opset_version,
|
|
1118
|
+
generated_checksum,
|
|
1119
|
+
)
|
|
1120
|
+
active_reporter.step_ok(runtime_started)
|
|
1121
|
+
runtime_outputs = {
|
|
1122
|
+
value.name: output
|
|
1123
|
+
for value, output in zip(graph.outputs, runtime_outputs_list)
|
|
1124
|
+
}
|
|
1125
|
+
nondeterministic_ops = sorted(
|
|
1126
|
+
set(operators).intersection(_NONDETERMINISTIC_OPERATORS)
|
|
1127
|
+
)
|
|
1128
|
+
if nondeterministic_ops:
|
|
1129
|
+
active_reporter.note(
|
|
1130
|
+
"Skipping output comparison for non-deterministic operator(s): "
|
|
1131
|
+
f"{', '.join(nondeterministic_ops)}"
|
|
1132
|
+
)
|
|
1133
|
+
return (
|
|
1134
|
+
"OK (non-deterministic output)",
|
|
1135
|
+
None,
|
|
1136
|
+
operators,
|
|
1137
|
+
opset_version,
|
|
1138
|
+
generated_checksum,
|
|
1139
|
+
)
|
|
1140
|
+
payload_outputs = payload.get("outputs", {})
|
|
1141
|
+
max_ulp = 0
|
|
1142
|
+
worst_diff: _WorstDiff | None = None
|
|
1143
|
+
max_abs_diff: float | int = 0
|
|
1144
|
+
worst_abs_diff: _WorstAbsDiff | None = None
|
|
1145
|
+
output_nodes = {
|
|
1146
|
+
output_name: node
|
|
1147
|
+
for node in graph.nodes
|
|
1148
|
+
for output_name in node.outputs
|
|
518
1149
|
}
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
1150
|
+
active_reporter.start_step(
|
|
1151
|
+
f" Comparing outputs [--max-ulp={args.max_ulp}]"
|
|
1152
|
+
)
|
|
1153
|
+
try:
|
|
1154
|
+
for value in graph.outputs:
|
|
1155
|
+
runtime_out = runtime_outputs[value.name]
|
|
1156
|
+
output_payload = payload_outputs.get(value.name)
|
|
1157
|
+
if output_payload is None:
|
|
1158
|
+
raise AssertionError(
|
|
1159
|
+
f"Missing output {value.name} in testbench data"
|
|
1160
|
+
)
|
|
1161
|
+
info = output_dtypes[value.name]
|
|
1162
|
+
output_data = decode_testbench_array(
|
|
1163
|
+
output_payload["data"], info.np_dtype
|
|
1164
|
+
).astype(info.np_dtype, copy=False)
|
|
1165
|
+
runtime_out = runtime_out.astype(info.np_dtype, copy=False)
|
|
1166
|
+
output_data = output_data.reshape(runtime_out.shape)
|
|
1167
|
+
if np.issubdtype(info.np_dtype, np.floating):
|
|
1168
|
+
output_max, output_worst = worst_ulp_diff(
|
|
1169
|
+
output_data,
|
|
1170
|
+
runtime_out,
|
|
1171
|
+
atol_eps=args.atol_eps,
|
|
1172
|
+
)
|
|
1173
|
+
if output_max > max_ulp:
|
|
1174
|
+
max_ulp = output_max
|
|
1175
|
+
if output_worst is not None:
|
|
1176
|
+
node = output_nodes.get(value.name)
|
|
1177
|
+
worst_diff = _WorstDiff(
|
|
1178
|
+
output_name=value.name,
|
|
1179
|
+
node_name=node.name if node else None,
|
|
1180
|
+
index=output_worst[0],
|
|
1181
|
+
got=float(output_worst[1]),
|
|
1182
|
+
reference=float(output_worst[2]),
|
|
1183
|
+
ulp=output_max,
|
|
1184
|
+
)
|
|
1185
|
+
else:
|
|
1186
|
+
output_max, output_worst = _worst_abs_diff(
|
|
1187
|
+
output_data, runtime_out
|
|
1188
|
+
)
|
|
1189
|
+
if output_max > max_abs_diff:
|
|
1190
|
+
max_abs_diff = output_max
|
|
1191
|
+
if output_worst is not None:
|
|
1192
|
+
node = output_nodes.get(value.name)
|
|
1193
|
+
worst_abs_diff = _WorstAbsDiff(
|
|
1194
|
+
output_name=value.name,
|
|
1195
|
+
node_name=node.name if node else None,
|
|
1196
|
+
index=output_worst[0],
|
|
1197
|
+
got=output_worst[1],
|
|
1198
|
+
reference=output_worst[2],
|
|
1199
|
+
abs_diff=output_max,
|
|
1200
|
+
)
|
|
1201
|
+
except AssertionError as exc:
|
|
1202
|
+
active_reporter.step_fail(str(exc))
|
|
1203
|
+
return None, str(exc), operators, opset_version, generated_checksum
|
|
1204
|
+
if max_abs_diff > 0:
|
|
1205
|
+
active_reporter.step_fail(f"max abs diff {max_abs_diff}")
|
|
1206
|
+
if worst_abs_diff is not None:
|
|
1207
|
+
node_label = worst_abs_diff.node_name or "(unknown)"
|
|
1208
|
+
index_display = ", ".join(str(dim) for dim in worst_abs_diff.index)
|
|
1209
|
+
active_reporter.info(
|
|
1210
|
+
" Worst diff: output="
|
|
1211
|
+
f"{worst_abs_diff.output_name} node={node_label} "
|
|
1212
|
+
f"index=[{index_display}] "
|
|
1213
|
+
f"got={worst_abs_diff.got} "
|
|
1214
|
+
f"ref={worst_abs_diff.reference} "
|
|
1215
|
+
f"abs_diff={worst_abs_diff.abs_diff}"
|
|
1216
|
+
)
|
|
1217
|
+
return (
|
|
1218
|
+
None,
|
|
1219
|
+
f"Arrays are not equal (max abs diff {max_abs_diff})",
|
|
1220
|
+
operators,
|
|
1221
|
+
opset_version,
|
|
1222
|
+
generated_checksum,
|
|
530
1223
|
)
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
1224
|
+
if max_ulp > args.max_ulp:
|
|
1225
|
+
active_reporter.step_fail(f"max ULP {max_ulp}")
|
|
1226
|
+
if worst_diff is not None:
|
|
1227
|
+
node_label = worst_diff.node_name or "(unknown)"
|
|
1228
|
+
index_display = ", ".join(str(dim) for dim in worst_diff.index)
|
|
1229
|
+
active_reporter.info(
|
|
1230
|
+
" Worst diff: output="
|
|
1231
|
+
f"{worst_diff.output_name} node={node_label} "
|
|
1232
|
+
f"index=[{index_display}] "
|
|
1233
|
+
f"got={worst_diff.got:.8g} "
|
|
1234
|
+
f"ref={worst_diff.reference:.8g} "
|
|
1235
|
+
f"ulp={worst_diff.ulp}"
|
|
1236
|
+
)
|
|
1237
|
+
return (
|
|
1238
|
+
None,
|
|
1239
|
+
f"Out of tolerance (max ULP {max_ulp})",
|
|
1240
|
+
operators,
|
|
1241
|
+
opset_version,
|
|
1242
|
+
generated_checksum,
|
|
545
1243
|
)
|
|
546
|
-
|
|
1244
|
+
active_reporter.step_ok_simple()
|
|
1245
|
+
active_reporter.info(f" Maximum ULP: {max_ulp}")
|
|
547
1246
|
return (
|
|
1247
|
+
format_success_message(max_ulp),
|
|
548
1248
|
None,
|
|
549
|
-
f"{runtime_name} failed to run {model_path}: {message}",
|
|
550
1249
|
operators,
|
|
1250
|
+
opset_version,
|
|
1251
|
+
generated_checksum,
|
|
551
1252
|
)
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
try:
|
|
556
|
-
for value, runtime_out in zip(graph.outputs, runtime_outputs):
|
|
557
|
-
output_payload = payload_outputs.get(value.name)
|
|
558
|
-
if output_payload is None:
|
|
559
|
-
raise AssertionError(f"Missing output {value.name} in testbench data")
|
|
560
|
-
info = output_dtypes[value.name]
|
|
561
|
-
output_data = decode_testbench_array(
|
|
562
|
-
output_payload["data"], info.np_dtype
|
|
563
|
-
).astype(info.np_dtype, copy=False)
|
|
564
|
-
runtime_out = runtime_out.astype(info.np_dtype, copy=False)
|
|
565
|
-
output_data = output_data.reshape(runtime_out.shape)
|
|
566
|
-
if np.issubdtype(info.np_dtype, np.floating):
|
|
567
|
-
max_ulp = max(max_ulp, max_ulp_diff(output_data, runtime_out))
|
|
568
|
-
else:
|
|
569
|
-
np.testing.assert_array_equal(output_data, runtime_out)
|
|
570
|
-
except AssertionError as exc:
|
|
571
|
-
return None, str(exc), operators
|
|
572
|
-
if max_ulp > args.max_ulp:
|
|
573
|
-
return None, f"Out of tolerance (max ULP {max_ulp})", operators
|
|
574
|
-
return format_success_message(max_ulp), None, operators
|
|
1253
|
+
finally:
|
|
1254
|
+
active_reporter.info("")
|
|
1255
|
+
_cleanup_temp()
|
|
575
1256
|
|
|
576
1257
|
|
|
 def _load_test_data_inputs(
     model: onnx.ModelProto, data_dir: Path | None
-) -> dict[str, "np.ndarray"] | None:
+) -> tuple[dict[str, "np.ndarray"] | None, dict[str, bool] | None]:
     if data_dir is None:
-        return None
+        return None, None
     if not data_dir.exists():
         raise CodegenError(f"Test data directory not found: {data_dir}")
     input_files = sorted(
@@ -587,26 +1268,115 @@ def _load_test_data_inputs(
     )
     if not input_files:
         raise CodegenError(f"No input_*.pb files found in {data_dir}")
-
+    initializer_names = {init.name for init in model.graph.initializer}
+    initializer_names.update(
+        sparse_init.name for sparse_init in model.graph.sparse_initializer
+    )
+    model_inputs = [
+        value_info
+        for value_info in model.graph.input
+        if value_info.name not in initializer_names
+    ]
+    if len(input_files) != len(model_inputs):
         raise CodegenError(
             "Test data input count does not match model inputs: "
-            f"{len(input_files)} vs {len(
+            f"{len(input_files)} vs {len(model_inputs)}."
         )
-    for value_info in
+    for value_info in model_inputs:
         value_kind = value_info.type.WhichOneof("value")
-        if value_kind
+        if value_kind not in {"tensor_type", "optional_type"}:
             LOGGER.warning(
                 "Skipping test data load for non-tensor input %s (type %s).",
                 value_info.name,
                 value_kind or "unknown",
             )
-            return None
+            return None, None
     inputs: dict[str, np.ndarray] = {}
+    optional_flags: dict[str, bool] = {}
     for index, path in enumerate(input_files):
+        value_info = model_inputs[index]
+        value_kind = value_info.type.WhichOneof("value")
+        if value_kind == "tensor_type":
+            tensor = onnx.TensorProto()
+            tensor.ParseFromString(path.read_bytes())
+            inputs[value_info.name] = numpy_helper.to_array(tensor)
+            continue
+        optional = onnx.OptionalProto()
+        optional.ParseFromString(path.read_bytes())
+        elem_type = value_info.type.optional_type.elem_type
+        if elem_type.WhichOneof("value") != "tensor_type":
+            LOGGER.warning(
+                "Skipping test data load for non-tensor optional input %s.",
+                value_info.name,
+            )
+            return None, None
+        tensor_type = elem_type.tensor_type
+        if optional.HasField("tensor_value"):
+            inputs[value_info.name] = numpy_helper.to_array(
+                optional.tensor_value
+            )
+            optional_flags[value_info.name] = True
+            continue
+        if not tensor_type.HasField("elem_type"):
+            raise CodegenError(
+                f"Optional input {value_info.name} is missing elem_type."
+            )
+        dtype_info = onnx._mapping.TENSOR_TYPE_MAP.get(tensor_type.elem_type)
+        if dtype_info is None:
+            raise CodegenError(
+                f"Optional input {value_info.name} has unsupported elem_type."
+            )
+        shape: list[int] = []
+        for dim in tensor_type.shape.dim:
+            if dim.HasField("dim_value"):
+                shape.append(dim.dim_value)
+            elif dim.HasField("dim_param"):
+                shape.append(1)
+            else:
+                raise CodegenError(
+                    f"Optional input {value_info.name} has unknown shape."
+                )
+        inputs[value_info.name] = np.zeros(
+            tuple(shape), dtype=dtype_info.np_dtype
+        )
+        optional_flags[value_info.name] = False
+    return inputs, optional_flags
+
+
+def _load_test_data_outputs(
+    model: onnx.ModelProto, data_dir: Path | None
+) -> dict[str, "np.ndarray"] | None:
+    if data_dir is None:
+        return None
+    if not data_dir.exists():
+        raise CodegenError(f"Test data directory not found: {data_dir}")
+    output_files = sorted(
+        data_dir.glob("output_*.pb"),
+        key=lambda path: int(path.stem.split("_")[-1]),
+    )
+    if not output_files:
+        return None
+    model_outputs = list(model.graph.output)
+    if len(output_files) != len(model_outputs):
+        raise CodegenError(
+            "Test data output count does not match model outputs: "
+            f"{len(output_files)} vs {len(model_outputs)}."
+        )
+    for value_info in model_outputs:
+        value_kind = value_info.type.WhichOneof("value")
+        if value_kind != "tensor_type":
+            LOGGER.warning(
+                "Skipping test data load for non-tensor output %s (type %s).",
+                value_info.name,
+                value_kind or "unknown",
+            )
+            return None
+    outputs: dict[str, np.ndarray] = {}
+    for index, path in enumerate(output_files):
         tensor = onnx.TensorProto()
         tensor.ParseFromString(path.read_bytes())
-
-    return
+        outputs[model_outputs[index].name] = numpy_helper.to_array(tensor)
+    return outputs


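These two loaders follow the ONNX backend-test layout, in which a model ships with a data directory of protobuf-serialized tensors named input_0.pb, input_1.pb, ... and output_0.pb, ...; optional inputs are stored as OptionalProto and surfaced through the second element of the returned tuple. A hedged usage sketch (the model path and directory name are placeholders, not files shipped with the package):

    from pathlib import Path

    import onnx

    model = onnx.load("model.onnx")         # placeholder model path
    data_dir = Path("test_data_set_0")      # conventional ONNX test-data directory name

    inputs, optional_present = _load_test_data_inputs(model, data_dir)
    expected = _load_test_data_outputs(model, data_dir)
    # inputs maps graph input names to NumPy arrays; optional_present records
    # whether an OptionalProto input actually carried a tensor value.
    # expected maps graph output names to reference arrays, or is None when
    # no output_*.pb files are present.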
 def _format_command_line(argv: Sequence[str] | None) -> str:
@@ -615,15 +1385,97 @@ def _format_command_line(argv: Sequence[str] | None) -> str:
     args = [str(arg) for arg in argv[1:]]
     if not args:
         return ""
-
+    filtered: list[str] = []
+    skip_next = False
+    for arg in args:
+        if skip_next:
+            skip_next = False
+            continue
+        if arg == "--expected-checksum":
+            skip_next = True
+            continue
+        if arg.startswith("--expected-checksum="):
+            continue
+        filtered.append(arg)
+    if not filtered:
+        return ""
+    return shlex.join(filtered)
+
+
+def _load_model_and_checksum(
+    model_path: Path,
+) -> tuple[onnx.ModelProto, str]:
+    model_bytes = model_path.read_bytes()
+    digest = hashlib.sha256()
+    digest.update(model_bytes)
+    model = onnx.load_model_from_string(model_bytes)
+    return model, digest.hexdigest()


-def
+def _generated_checksum(generated: str) -> str:
     digest = hashlib.sha256()
-    digest.update(
+    digest.update(generated.encode("utf-8"))
     return digest.hexdigest()


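_format_command_line now strips the --expected-checksum option (in both the "--flag value" and "--flag=value" forms) before recording the invocation, presumably so that a replayed command line does not pin a stale checksum. A small sketch of the intended behaviour; the verify subcommand name and the --max-ulp flag are inferred from args.max_ulp above and may differ from the actual CLI:

    argv = [
        "emx-onnx-cgen",                     # argv[0] is dropped
        "verify", "model.onnx",
        "--expected-checksum", "0123abcd",   # removed together with its value
        "--max-ulp", "2",
    ]
    assert _format_command_line(argv) == "verify model.onnx --max-ulp 2"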
+def _report_model_details(
+    reporter: _VerifyReporter,
+    *,
+    model_path: Path,
+    model_checksum: str,
+    operators: Sequence[str],
+    opset_version: int | None,
+    node_count: int,
+    initializer_count: int,
+    input_count: int,
+    output_count: int,
+) -> None:
+    operators_display = ", ".join(operators) if operators else "(none)"
+    reporter.info(
+        f" Model operators ({len(operators)}): {operators_display}"
+    )
+    reporter.info(
+        f" Model file size: {_format_artifact_size(model_path.stat().st_size)}"
+    )
+    reporter.info(f" Model checksum (sha256): {model_checksum}")
+    if opset_version is not None:
+        reporter.info(f" Opset version: {opset_version}")
+    reporter.info(
+        " Counts: "
+        f"nodes={node_count}, "
+        f"initializers={initializer_count}, "
+        f"inputs={input_count}, "
+        f"outputs={output_count}"
+    )
+
+
+def _report_codegen_timings(
+    reporter: _VerifyReporter, *, timings: Mapping[str, float]
+) -> None:
+    if not timings:
+        return
+    order = [
+        ("import_onnx", "import"),
+        ("concretize_shapes", "concretize"),
+        ("resolve_testbench_inputs", "testbench"),
+        ("collect_variable_dims", "var_dims"),
+        ("lower_model", "lower"),
+        ("emit_model", "emit"),
+        ("emit_model_with_data_file", "emit_data"),
+        ("collect_weight_data", "weights"),
+    ]
+    seen = set()
+    parts: list[str] = []
+    for key, label in order:
+        if key not in timings:
+            continue
+        parts.append(f"{label}={timings[key]:.3f}s")
+        seen.add(key)
+    for key in sorted(k for k in timings if k not in seen):
+        parts.append(f"{key}={timings[key]:.3f}s")
+    reporter.info(f" Codegen timing: {', '.join(parts)}")
+
+
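_report_codegen_timings prints the known pipeline phases under short labels in a fixed order and appends any remaining timing keys in sorted order. Given a _VerifyReporter instance named reporter, a sketch of the resulting line (the values and the "optimize" key are made up for illustration):

    timings = {
        "import_onnx": 0.012,
        "lower_model": 0.034,
        "emit_model": 0.201,
        "optimize": 0.005,   # not in the fixed order, so appended last
    }
    _report_codegen_timings(reporter, timings=timings)
    # reporter.info receives roughly:
    #   Codegen timing: import=0.012s, lower=0.034s, emit=0.201s, optimize=0.005s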
 def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
     operators: list[str] = []
     seen: set[str] = set()
@@ -634,3 +1486,14 @@ def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
             seen.add(op_name)
             operators.append(op_name)
     return operators
+
+
+def _model_opset_version(model: onnx.ModelProto, *, domain: str = "") -> int | None:
+    if not model.opset_import:
+        return None
+    domains = (domain,) if domain else ("", "ai.onnx")
+    for target_domain in domains:
+        for opset in model.opset_import:
+            if opset.domain == target_domain:
+                return opset.version
+    return None
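Taken together, these helpers let the verify flow describe a model before compiling it: _load_model_and_checksum hashes the raw model bytes, _collect_model_operators lists the unique op types in first-seen order, and _model_opset_version picks the default-domain ("" or "ai.onnx") opset. A brief usage sketch with a placeholder model path:

    from pathlib import Path

    model, model_checksum = _load_model_and_checksum(Path("model.onnx"))
    print(model_checksum)                    # sha256 hex digest of the model file bytes
    print(_collect_model_operators(model))   # e.g. ["Conv", "Relu", "Gemm"], first-seen order
    print(_model_opset_version(model))       # default-domain opset version, or None
    print(_model_opset_version(model, domain="com.microsoft"))  # None unless that domain is imported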