emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/_version.py
ADDED
@@ -0,0 +1,34 @@
+# file generated by setuptools-scm
+# don't change, don't track in version control
+
+__all__ = [
+    "__version__",
+    "__version_tuple__",
+    "version",
+    "version_tuple",
+    "__commit_id__",
+    "commit_id",
+]
+
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple
+    from typing import Union
+
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+    COMMIT_ID = Union[str, None]
+else:
+    VERSION_TUPLE = object
+    COMMIT_ID = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+commit_id: COMMIT_ID
+__commit_id__: COMMIT_ID
+
+__version__ = version = '0.3.1'
+__version_tuple__ = version_tuple = (0, 3, 1)
+
+__commit_id__ = commit_id = None
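For reference, the generated metadata above is read like any other module attribute. A minimal sketch, assuming the file ships as emx_onnx_cgen/_version.py:

# Minimal sketch: reading the setuptools-scm metadata shown in the hunk above.
from emx_onnx_cgen._version import __commit_id__, __version__, __version_tuple__

assert __version__ == "0.3.1"
assert __version_tuple__ == (0, 3, 1)
assert __commit_id__ is None  # no commit id is recorded for this release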
emx_onnx_cgen/cli.py
CHANGED
@@ -10,18 +10,89 @@ import shutil
 import subprocess
 import sys
 import tempfile
+import time
+import signal
 from pathlib import Path
-from
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Mapping, Sequence
 
 import onnx
+from onnx import numpy_helper
 
 from ._build_info import BUILD_DATE, GIT_VERSION
 from .compiler import Compiler, CompilerOptions
 from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
 from .onnx_import import import_onnx
+from .onnxruntime_utils import make_deterministic_session_options
+from .testbench import decode_testbench_array
+from .verification import format_success_message, max_ulp_diff
 
 LOGGER = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    import numpy as np
+
+
+@dataclass(frozen=True)
+class CliResult:
+    exit_code: int
+    command_line: str
+    error: str | None = None
+    success_message: str | None = None
+    generated: str | None = None
+    data_source: str | None = None
+    operators: list[str] | None = None
+
+
+def run_cli_command(
+    argv: Sequence[str],
+    *,
+    testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
+) -> CliResult:
+    raw_argv = list(argv)
+    parse_argv = raw_argv
+    if raw_argv and raw_argv[0] == "emx-onnx-cgen":
+        parse_argv = raw_argv[1:]
+    parser = _build_parser()
+    args = parser.parse_args(parse_argv)
+    args.command_line = _format_command_line(raw_argv)
+
+    try:
+        if args.command != "compile":
+            success_message, error, operators = _verify_model(
+                args, include_build_details=False
+            )
+            return CliResult(
+                exit_code=0 if error is None else 1,
+                command_line=args.command_line,
+                error=error,
+                success_message=success_message,
+                operators=operators,
+            )
+        generated, data_source, error = _compile_model(
+            args, testbench_inputs=testbench_inputs
+        )
+        if error:
+            return CliResult(
+                exit_code=1,
+                command_line=args.command_line,
+                error=error,
+            )
+        return CliResult(
+            exit_code=0,
+            command_line=args.command_line,
+            success_message="",
+            generated=generated,
+            data_source=data_source,
+        )
+    except Exception as exc:  # pragma: no cover - defensive reporting
+        LOGGER.exception("Unhandled exception while running CLI command.")
+        return CliResult(
+            exit_code=1,
+            command_line=args.command_line,
+            error=str(exc),
+        )
+
 
 def _build_parser() -> argparse.ArgumentParser:
     description = (
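The run_cli_command entry point above can be driven programmatically instead of through the console script. A hedged usage sketch; the model path and the positional argument layout are illustrative assumptions, only the signature and the CliResult fields come from this hunk:

# Hedged usage sketch for run_cli_command / CliResult as added above.
# "model.onnx" and the positional model argument are assumptions for illustration.
from emx_onnx_cgen.cli import run_cli_command

result = run_cli_command(["emx-onnx-cgen", "verify", "model.onnx"])
if result.exit_code == 0:
    print(result.success_message)  # verification summary, e.g. the max ULP message
    print(result.operators)        # unique operator names collected from the model
else:
    print(result.error)            # first error reported by _verify_model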
@@ -86,6 +157,33 @@ def _build_parser() -> argparse.ArgumentParser:
             "named like the output with a _data suffix"
         ),
     )
+    compile_parser.add_argument(
+        "--truncate-weights-after",
+        type=int,
+        default=None,
+        help=(
+            "Truncate inline weight initializers after N values and insert "
+            "\"...\" placeholders (default: no truncation)"
+        ),
+    )
+    compile_parser.add_argument(
+        "--large-temp-threshold-bytes",
+        type=int,
+        default=1024,
+        help=(
+            "Mark temporary buffers larger than this threshold as static "
+            "(default: 1024)"
+        ),
+    )
+    compile_parser.add_argument(
+        "--large-weight-threshold",
+        type=int,
+        default=1024 * 1024,
+        help=(
+            "Store weights larger than this element count in a binary file "
+            "(default: 1048576; set to 0 to disable)"
+        ),
+    )
     add_restrict_flags(compile_parser)
 
     verify_parser = subparsers.add_parser(
@@ -111,6 +209,57 @@
         default=None,
         help="C compiler command to build the testbench binary",
     )
+    verify_parser.add_argument(
+        "--truncate-weights-after",
+        type=int,
+        default=None,
+        help=(
+            "Truncate inline weight initializers after N values and insert "
+            "\"...\" placeholders (default: no truncation)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--large-temp-threshold-bytes",
+        type=int,
+        default=1024,
+        help=(
+            "Mark temporary buffers larger than this threshold as static "
+            "(default: 1024)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--large-weight-threshold",
+        type=int,
+        default=1024,
+        help=(
+            "Store weights larger than this element count in a binary file "
+            "(default: 1024)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--test-data-dir",
+        type=Path,
+        default=None,
+        help=(
+            "Directory containing input_*.pb files to seed verification inputs "
+            "(default: use random testbench inputs)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--max-ulp",
+        type=int,
+        default=100,
+        help="Maximum allowed ULP difference for floating outputs (default: 100)",
+    )
+    verify_parser.add_argument(
+        "--runtime",
+        choices=("onnxruntime", "onnx-reference"),
+        default="onnx-reference",
+        help=(
+            "Runtime backend for verification (default: onnx-reference; "
+            "options: onnxruntime, onnx-reference)"
+        ),
+    )
     add_restrict_flags(verify_parser)
     return parser
 
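The new verify options above can be combined in one invocation. A hedged example; the flag names come from the hunk, while the model path and values are illustrative:

# Hedged example combining the verify flags added above.
from emx_onnx_cgen.cli import run_cli_command

result = run_cli_command(
    [
        "emx-onnx-cgen",
        "verify",
        "model.onnx",                          # illustrative model path
        "--runtime", "onnxruntime",            # default backend is onnx-reference
        "--max-ulp", "50",                     # tighten the default tolerance of 100
        "--test-data-dir", "test_data_set_0",  # seed inputs from input_*.pb files
    ]
)
print(result.exit_code, result.error or result.success_message)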
@@ -132,7 +281,35 @@ def main(argv: Sequence[str] | None = None) -> int:
 def _handle_compile(args: argparse.Namespace) -> int:
     model_path: Path = args.model
     output_path: Path = args.output or model_path.with_suffix(".c")
-    model_name = args.model_name or
+    model_name = args.model_name or "model"
+    generated, data_source, weight_data, error = _compile_model(args)
+    if error:
+        LOGGER.error("Failed to compile %s: %s", model_path, error)
+        return 1
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_path.write_text(generated or "", encoding="utf-8")
+    LOGGER.info("Wrote C source to %s", output_path)
+    if data_source is not None:
+        data_path = output_path.with_name(
+            f"{output_path.stem}_data{output_path.suffix}"
+        )
+        data_path.write_text(data_source, encoding="utf-8")
+        LOGGER.info("Wrote data source to %s", data_path)
+    if weight_data is not None:
+        weights_path = output_path.with_name(f"{model_name}.bin")
+        weights_path.write_bytes(weight_data)
+        LOGGER.info("Wrote weights binary to %s", weights_path)
+    return 0
+
+
+def _compile_model(
+    args: argparse.Namespace,
+    *,
+    testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
+) -> tuple[str | None, str | None, bytes | None, str | None]:
+    model_path: Path = args.model
+    model_name = args.model_name or "model"
     try:
         model_checksum = _model_checksum(model_path)
         model = onnx.load_model(model_path)
@@ -143,27 +320,22 @@
             command_line=args.command_line,
             model_checksum=model_checksum,
             restrict_arrays=args.restrict_arrays,
+            truncate_weights_after=args.truncate_weights_after,
+            large_temp_threshold_bytes=args.large_temp_threshold_bytes,
+            large_weight_threshold=args.large_weight_threshold,
+            testbench_inputs=testbench_inputs,
         )
         compiler = Compiler(options)
         if args.emit_data_file:
-            generated, data_source =
+            generated, data_source, weight_data = (
+                compiler.compile_with_data_file_and_weight_data(model)
+            )
         else:
-            generated = compiler.
+            generated, weight_data = compiler.compile_with_weight_data(model)
             data_source = None
     except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
-
-
-
-    output_path.parent.mkdir(parents=True, exist_ok=True)
-    output_path.write_text(generated, encoding="utf-8")
-    LOGGER.info("Wrote C source to %s", output_path)
-    if data_source is not None:
-        data_path = output_path.with_name(
-            f"{output_path.stem}_data{output_path.suffix}"
-        )
-        data_path.write_text(data_source, encoding="utf-8")
-        LOGGER.info("Wrote data source to %s", data_path)
-    return 0
+        return None, None, None, str(exc)
+    return generated, data_source, weight_data, None
 
 
 def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str] | None:
@@ -198,18 +370,59 @@ def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str]
 
 
 def _handle_verify(args: argparse.Namespace) -> int:
+    success_message, error, _operators = _verify_model(
+        args, include_build_details=True
+    )
+    if error is not None:
+        LOGGER.error("Verification failed: %s", error)
+        return 1
+    if success_message:
+        LOGGER.info("%s", success_message)
+    return 0
+
+
+def _verify_model(
+    args: argparse.Namespace,
+    *,
+    include_build_details: bool,
+) -> tuple[str | None, str | None, list[str]]:
     import numpy as np
-
+
+    def log_step(step: str, started_at: float) -> None:
+        duration = time.perf_counter() - started_at
+        LOGGER.info("verify step %s: %.3fs", step, duration)
+
+    def describe_exit_code(returncode: int) -> str:
+        if returncode >= 0:
+            return f"exit code {returncode}"
+        signal_id = -returncode
+        try:
+            signal_name = signal.Signals(signal_id).name
+        except ValueError:
+            signal_name = "unknown"
+        return f"exit code {returncode} (signal {signal_id}: {signal_name})"
 
     model_path: Path = args.model
-    model_name = args.model_name or
+    model_name = args.model_name or "model"
     model_checksum = _model_checksum(model_path)
     compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
    if compiler_cmd is None:
-
-
+        return (
+            None,
+            "No C compiler found (set --cc or CC environment variable).",
+            [],
+        )
     try:
         model = onnx.load_model(model_path)
+    except OSError as exc:
+        return None, str(exc), []
+
+    operators = _collect_model_operators(model)
+    operators_display = ", ".join(operators) if operators else "(none)"
+    LOGGER.info("verify operators: %s", operators_display)
+
+    try:
+        testbench_inputs = _load_test_data_inputs(model, args.test_data_dir)
         options = CompilerOptions(
             template_dir=args.template_dir,
             model_name=model_name,
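For illustration, the describe_exit_code helper above relies on the subprocess convention that a negative return code means the testbench was killed by a signal, and on the standard signal table for the name. A short POSIX-only example of the lookup it performs:

# Illustration of the signal lookup used by describe_exit_code above (POSIX signal numbers).
import signal

returncode = -11  # subprocess convention: negative return code means "killed by signal 11"
signal_id = -returncode
print(f"exit code {returncode} (signal {signal_id}: {signal.Signals(signal_id).name})")
# -> exit code -11 (signal 11: SIGSEGV) on Linux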
@@ -217,100 +430,183 @@
             command_line=args.command_line,
             model_checksum=model_checksum,
             restrict_arrays=args.restrict_arrays,
+            truncate_weights_after=args.truncate_weights_after,
+            large_temp_threshold_bytes=args.large_temp_threshold_bytes,
+            large_weight_threshold=args.large_weight_threshold,
+            testbench_inputs=testbench_inputs,
         )
         compiler = Compiler(options)
-
-
-
-
+        codegen_started = time.perf_counter()
+        generated, weight_data = compiler.compile_with_weight_data(model)
+        log_step("codegen", codegen_started)
+    except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
+        return None, str(exc), operators
 
     try:
         graph = import_onnx(model)
         output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
         input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
     except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
-
-        return 1
+        return None, f"Failed to resolve model dtype: {exc}", operators
 
     with tempfile.TemporaryDirectory() as temp_dir:
         temp_path = Path(temp_dir)
+        LOGGER.info("verify temp dir: %s", temp_path)
         c_path = temp_path / "model.c"
+        weights_path = temp_path / f"{model_name}.bin"
         exe_path = temp_path / "model"
         c_path.write_text(generated, encoding="utf-8")
+        if weight_data is not None:
+            weights_path.write_bytes(weight_data)
         try:
+            compile_started = time.perf_counter()
+            compile_cmd = [
+                *compiler_cmd,
+                "-std=c99",
+                "-O2",
+                str(c_path),
+                "-o",
+                str(exe_path),
+                "-lm",
+            ]
+            LOGGER.info("verify compile command: %s", shlex.join(compile_cmd))
             subprocess.run(
-
-                    *compiler_cmd,
-                    "-std=c99",
-                    "-O2",
-                    str(c_path),
-                    "-o",
-                    str(exe_path),
-                    "-lm",
-                ],
+                compile_cmd,
                 check=True,
                 capture_output=True,
                 text=True,
             )
+            log_step("compile", compile_started)
         except subprocess.CalledProcessError as exc:
-
-
+            message = "Failed to build testbench."
+            if include_build_details:
+                details = exc.stderr.strip()
+                if details:
+                    message = f"{message} {details}"
+            return None, message, operators
         try:
+            run_started = time.perf_counter()
             result = subprocess.run(
                 [str(exe_path)],
                 check=True,
                 capture_output=True,
                 text=True,
+                cwd=temp_path,
             )
+            log_step("run", run_started)
         except subprocess.CalledProcessError as exc:
-
-
+            return None, (
+                "Testbench execution failed: " + describe_exit_code(exc.returncode)
+            ), operators
 
         try:
             payload = json.loads(result.stdout)
         except json.JSONDecodeError as exc:
-
-            return 1
+            return None, f"Failed to parse testbench JSON: {exc}", operators
 
-
-
-
-
-
-
-
+        if testbench_inputs:
+            inputs = {
+                name: values.astype(input_dtypes[name].np_dtype, copy=False)
+                for name, values in testbench_inputs.items()
+            }
+        else:
+            inputs = {
+                name: decode_testbench_array(
+                    value["data"], input_dtypes[name].np_dtype
+                )
+                for name, value in payload["inputs"].items()
+            }
+        runtime_name = args.runtime
+        runtime_started = time.perf_counter()
         try:
-
+            if runtime_name == "onnxruntime":
+                import onnxruntime as ort
+
+                sess_options = make_deterministic_session_options(ort)
+                sess = ort.InferenceSession(
+                    model.SerializeToString(),
+                    sess_options=sess_options,
+                    providers=["CPUExecutionProvider"],
+                )
+                runtime_outputs = sess.run(None, inputs)
+            else:
+                from onnx.reference import ReferenceEvaluator
+
+                evaluator = ReferenceEvaluator(model)
+                runtime_outputs = evaluator.run(None, inputs)
         except Exception as exc:
+            log_step(runtime_name, runtime_started)
             message = str(exc)
-            if "NOT_IMPLEMENTED" in message:
+            if runtime_name == "onnxruntime" and "NOT_IMPLEMENTED" in message:
                 LOGGER.warning(
                     "Skipping verification for %s: ONNX Runtime does not support the model (%s)",
                     model_path,
                     message,
                 )
-                return
-
-
+                return "", None, operators
+            return (
+                None,
+                f"{runtime_name} failed to run {model_path}: {message}",
+                operators,
+            )
+        log_step(runtime_name, runtime_started)
         payload_outputs = payload.get("outputs", {})
+        max_ulp = 0
         try:
-            for value,
+            for value, runtime_out in zip(graph.outputs, runtime_outputs):
                 output_payload = payload_outputs.get(value.name)
                 if output_payload is None:
                     raise AssertionError(f"Missing output {value.name} in testbench data")
                 info = output_dtypes[value.name]
-                output_data =
+                output_data = decode_testbench_array(
+                    output_payload["data"], info.np_dtype
+                ).astype(info.np_dtype, copy=False)
+                runtime_out = runtime_out.astype(info.np_dtype, copy=False)
+                output_data = output_data.reshape(runtime_out.shape)
                 if np.issubdtype(info.np_dtype, np.floating):
-
-                        output_data, ort_out, rtol=1e-4, atol=1e-5
-                    )
+                    max_ulp = max(max_ulp, max_ulp_diff(output_data, runtime_out))
                 else:
-                    np.testing.assert_array_equal(output_data,
+                    np.testing.assert_array_equal(output_data, runtime_out)
         except AssertionError as exc:
-
-
-
-            return
+            return None, str(exc), operators
+        if max_ulp > args.max_ulp:
+            return None, f"Out of tolerance (max ULP {max_ulp})", operators
+        return format_success_message(max_ulp), None, operators
+
+
+def _load_test_data_inputs(
+    model: onnx.ModelProto, data_dir: Path | None
+) -> dict[str, "np.ndarray"] | None:
+    if data_dir is None:
+        return None
+    if not data_dir.exists():
+        raise CodegenError(f"Test data directory not found: {data_dir}")
+    input_files = sorted(
+        data_dir.glob("input_*.pb"),
+        key=lambda path: int(path.stem.split("_")[-1]),
+    )
+    if not input_files:
+        raise CodegenError(f"No input_*.pb files found in {data_dir}")
+    if len(input_files) != len(model.graph.input):
+        raise CodegenError(
+            "Test data input count does not match model inputs: "
+            f"{len(input_files)} vs {len(model.graph.input)}."
+        )
+    for value_info in model.graph.input:
+        value_kind = value_info.type.WhichOneof("value")
+        if value_kind != "tensor_type":
+            LOGGER.warning(
+                "Skipping test data load for non-tensor input %s (type %s).",
+                value_info.name,
+                value_kind or "unknown",
+            )
+            return None
+    inputs: dict[str, np.ndarray] = {}
+    for index, path in enumerate(input_files):
+        tensor = onnx.TensorProto()
+        tensor.ParseFromString(path.read_bytes())
+        inputs[model.graph.input[index].name] = numpy_helper.to_array(tensor)
+    return inputs
 
 
 def _format_command_line(argv: Sequence[str] | None) -> str:
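The comparison loop above folds every floating-point output into a single max_ulp value via max_ulp_diff, whose implementation lives in shared/ulp.py and is not shown in this diff. As a hedged sketch of the general technique (not the package's actual code), an ULP distance for float32 arrays can be computed by reinterpreting the bit patterns as ordered integers:

# Hedged sketch of a units-in-the-last-place (ULP) distance for float32 arrays.
# Illustrates the general technique only; shared/ulp.py's real implementation is not shown here.
import numpy as np

def max_ulp_diff_sketch(a: np.ndarray, b: np.ndarray) -> int:
    a32 = np.asarray(a, dtype=np.float32)
    b32 = np.asarray(b, dtype=np.float32)
    # Reinterpret the float bit patterns as integers, remapped so that integer
    # ordering matches float ordering (negative floats are mirrored).
    ia = a32.view(np.int32).astype(np.int64)
    ib = b32.view(np.int32).astype(np.int64)
    ia = np.where(ia < 0, np.int64(-(2**31)) - ia, ia)
    ib = np.where(ib < 0, np.int64(-(2**31)) - ib, ib)
    return int(np.max(np.abs(ia - ib), initial=0))

one = np.array([1.0], dtype=np.float32)
next_up = np.nextafter(one, np.float32(2.0))
print(max_ulp_diff_sketch(one, next_up))  # 1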
@@ -326,3 +622,15 @@ def _model_checksum(model_path: Path) -> str:
     digest = hashlib.sha256()
     digest.update(model_path.read_bytes())
     return digest.hexdigest()
+
+
+def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
+    operators: list[str] = []
+    seen: set[str] = set()
+    for node in model.graph.node:
+        op_name = f"{node.domain}::{node.op_type}" if node.domain else node.op_type
+        if op_name in seen:
+            continue
+        seen.add(op_name)
+        operators.append(op_name)
+    return operators