emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +340 -59
- emx_onnx_cgen/codegen/c_emitter.py +2369 -111
- emx_onnx_cgen/compiler.py +188 -5
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/lowering/common.py +379 -2
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/gather_elements.py +1 -3
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +6 -5
- emx_onnx_cgen/lowering/logsoftmax.py +5 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/matmul.py +6 -7
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/reduce.py +5 -6
- emx_onnx_cgen/lowering/reshape.py +223 -51
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/softmax.py +5 -1
- emx_onnx_cgen/lowering/squeeze.py +5 -5
- emx_onnx_cgen/lowering/topk.py +116 -0
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +5 -5
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +460 -42
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +61 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
- shared/scalar_functions.py +49 -17
- shared/ulp.py +48 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/_version.py
ADDED
@@ -0,0 +1,34 @@
+# file generated by setuptools-scm
+# don't change, don't track in version control
+
+__all__ = [
+    "__version__",
+    "__version_tuple__",
+    "version",
+    "version_tuple",
+    "__commit_id__",
+    "commit_id",
+]
+
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple
+    from typing import Union
+
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+    COMMIT_ID = Union[str, None]
+else:
+    VERSION_TUPLE = object
+    COMMIT_ID = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+commit_id: COMMIT_ID
+__commit_id__: COMMIT_ID
+
+__version__ = version = '0.3.0'
+__version_tuple__ = version_tuple = (0, 3, 0)
+
+__commit_id__ = commit_id = None
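Since the generated module exposes the standard setuptools-scm attributes, the release metadata can be read at runtime. A minimal sketch, assuming the file above is importable as emx_onnx_cgen._version:

```python
# Illustrative only: reading the setuptools-scm generated version metadata.
from emx_onnx_cgen._version import __commit_id__, __version__, __version_tuple__

print(__version__)        # '0.3.0'
print(__version_tuple__)  # (0, 3, 0)
print(__commit_id__)      # None in this release build
```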
emx_onnx_cgen/cli.py
CHANGED
@@ -10,18 +10,89 @@ import shutil
 import subprocess
 import sys
 import tempfile
+import time
+import signal
 from pathlib import Path
-from
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Mapping, Sequence
 
 import onnx
+from onnx import numpy_helper
 
 from ._build_info import BUILD_DATE, GIT_VERSION
 from .compiler import Compiler, CompilerOptions
 from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
 from .onnx_import import import_onnx
+from .onnxruntime_utils import make_deterministic_session_options
+from .testbench import decode_testbench_array
+from .verification import format_success_message, max_ulp_diff
 
 LOGGER = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    import numpy as np
+
+
+@dataclass(frozen=True)
+class CliResult:
+    exit_code: int
+    command_line: str
+    error: str | None = None
+    success_message: str | None = None
+    generated: str | None = None
+    data_source: str | None = None
+    operators: list[str] | None = None
+
+
+def run_cli_command(
+    argv: Sequence[str],
+    *,
+    testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
+) -> CliResult:
+    raw_argv = list(argv)
+    parse_argv = raw_argv
+    if raw_argv and raw_argv[0] == "emx-onnx-cgen":
+        parse_argv = raw_argv[1:]
+    parser = _build_parser()
+    args = parser.parse_args(parse_argv)
+    args.command_line = _format_command_line(raw_argv)
+
+    try:
+        if args.command != "compile":
+            success_message, error, operators = _verify_model(
+                args, include_build_details=False
+            )
+            return CliResult(
+                exit_code=0 if error is None else 1,
+                command_line=args.command_line,
+                error=error,
+                success_message=success_message,
+                operators=operators,
+            )
+        generated, data_source, error = _compile_model(
+            args, testbench_inputs=testbench_inputs
+        )
+        if error:
+            return CliResult(
+                exit_code=1,
+                command_line=args.command_line,
+                error=error,
+            )
+        return CliResult(
+            exit_code=0,
+            command_line=args.command_line,
+            success_message="",
+            generated=generated,
+            data_source=data_source,
+        )
+    except Exception as exc:  # pragma: no cover - defensive reporting
+        LOGGER.exception("Unhandled exception while running CLI command.")
+        return CliResult(
+            exit_code=1,
+            command_line=args.command_line,
+            error=str(exc),
+        )
+
 
 def _build_parser() -> argparse.ArgumentParser:
     description = (
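The hunk above adds CliResult and run_cli_command, a programmatic entry point that parses the same argv as the console script and returns a result object instead of exiting. A hedged usage sketch; the "compile" subcommand name comes from the code above, while the positional model argument and the flag value are assumptions not shown in this diff:

```python
from emx_onnx_cgen.cli import run_cli_command

# Assumed invocation shape: "compile" subcommand, a positional model path, and
# one of the new 0.3.0 options. The leading program name is stripped by
# run_cli_command before argument parsing.
result = run_cli_command(
    ["emx-onnx-cgen", "compile", "model.onnx", "--large-weight-threshold", "2048"]
)
if result.exit_code != 0:
    raise SystemExit(result.error)
c_source = result.generated  # generated C is returned in memory, not written to disk
```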
@@ -86,6 +157,33 @@ def _build_parser() -> argparse.ArgumentParser:
             "named like the output with a _data suffix"
         ),
     )
+    compile_parser.add_argument(
+        "--truncate-weights-after",
+        type=int,
+        default=None,
+        help=(
+            "Truncate inline weight initializers after N values and insert "
+            "\"...\" placeholders (default: no truncation)"
+        ),
+    )
+    compile_parser.add_argument(
+        "--large-temp-threshold-bytes",
+        type=int,
+        default=1024,
+        help=(
+            "Mark temporary buffers larger than this threshold as static "
+            "(default: 1024)"
+        ),
+    )
+    compile_parser.add_argument(
+        "--large-weight-threshold",
+        type=int,
+        default=1024,
+        help=(
+            "Store weights larger than this element count in a binary file "
+            "(default: 1024)"
+        ),
+    )
     add_restrict_flags(compile_parser)
 
     verify_parser = subparsers.add_parser(
@@ -111,6 +209,48 @@ def _build_parser() -> argparse.ArgumentParser:
         default=None,
         help="C compiler command to build the testbench binary",
     )
+    verify_parser.add_argument(
+        "--truncate-weights-after",
+        type=int,
+        default=None,
+        help=(
+            "Truncate inline weight initializers after N values and insert "
+            "\"...\" placeholders (default: no truncation)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--large-temp-threshold-bytes",
+        type=int,
+        default=1024,
+        help=(
+            "Mark temporary buffers larger than this threshold as static "
+            "(default: 1024)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--large-weight-threshold",
+        type=int,
+        default=1024,
+        help=(
+            "Store weights larger than this element count in a binary file "
+            "(default: 1024)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--test-data-dir",
+        type=Path,
+        default=None,
+        help=(
+            "Directory containing input_*.pb files to seed verification inputs "
+            "(default: use random testbench inputs)"
+        ),
+    )
+    verify_parser.add_argument(
+        "--max-ulp",
+        type=int,
+        default=100,
+        help="Maximum allowed ULP difference for floating outputs (default: 100)",
+    )
     add_restrict_flags(verify_parser)
     return parser
 
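The new --test-data-dir option points at the conventional ONNX test-data layout (input_0.pb, input_1.pb, ... serialized TensorProto files, one per model input), which a later hunk reads back through _load_test_data_inputs and numpy_helper.to_array. A sketch of producing such a directory with the public onnx API; the directory name and tensor shape are illustrative:

```python
from pathlib import Path

import numpy as np
from onnx import numpy_helper

data_dir = Path("test_data_set_0")  # illustrative name
data_dir.mkdir(exist_ok=True)

# Write one input_<index>.pb per model input, in model input order.
sample_inputs = [np.random.rand(1, 3, 224, 224).astype(np.float32)]
for index, array in enumerate(sample_inputs):
    tensor = numpy_helper.from_array(array)
    (data_dir / f"input_{index}.pb").write_bytes(tensor.SerializeToString())
```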
@@ -132,7 +272,35 @@ def main(argv: Sequence[str] | None = None) -> int:
 def _handle_compile(args: argparse.Namespace) -> int:
     model_path: Path = args.model
     output_path: Path = args.output or model_path.with_suffix(".c")
-    model_name = args.model_name or
+    model_name = args.model_name or "model"
+    generated, data_source, weight_data, error = _compile_model(args)
+    if error:
+        LOGGER.error("Failed to compile %s: %s", model_path, error)
+        return 1
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_path.write_text(generated or "", encoding="utf-8")
+    LOGGER.info("Wrote C source to %s", output_path)
+    if data_source is not None:
+        data_path = output_path.with_name(
+            f"{output_path.stem}_data{output_path.suffix}"
+        )
+        data_path.write_text(data_source, encoding="utf-8")
+        LOGGER.info("Wrote data source to %s", data_path)
+    if weight_data is not None:
+        weights_path = output_path.with_name(f"{model_name}.bin")
+        weights_path.write_bytes(weight_data)
+        LOGGER.info("Wrote weights binary to %s", weights_path)
+    return 0
+
+
+def _compile_model(
+    args: argparse.Namespace,
+    *,
+    testbench_inputs: Mapping[str, "np.ndarray"] | None = None,
+) -> tuple[str | None, str | None, bytes | None, str | None]:
+    model_path: Path = args.model
+    model_name = args.model_name or "model"
     try:
         model_checksum = _model_checksum(model_path)
         model = onnx.load_model(model_path)
@@ -143,27 +311,22 @@ def _handle_compile(args: argparse.Namespace) -> int:
             command_line=args.command_line,
             model_checksum=model_checksum,
             restrict_arrays=args.restrict_arrays,
+            truncate_weights_after=args.truncate_weights_after,
+            large_temp_threshold_bytes=args.large_temp_threshold_bytes,
+            large_weight_threshold=args.large_weight_threshold,
+            testbench_inputs=testbench_inputs,
         )
         compiler = Compiler(options)
         if args.emit_data_file:
-            generated, data_source =
+            generated, data_source, weight_data = (
+                compiler.compile_with_data_file_and_weight_data(model)
+            )
         else:
-            generated = compiler.
+            generated, weight_data = compiler.compile_with_weight_data(model)
             data_source = None
     except (OSError, CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
-
-
-
-    output_path.parent.mkdir(parents=True, exist_ok=True)
-    output_path.write_text(generated, encoding="utf-8")
-    LOGGER.info("Wrote C source to %s", output_path)
-    if data_source is not None:
-        data_path = output_path.with_name(
-            f"{output_path.stem}_data{output_path.suffix}"
-        )
-        data_path.write_text(data_source, encoding="utf-8")
-        LOGGER.info("Wrote data source to %s", data_path)
-    return 0
+        return None, None, None, str(exc)
+    return generated, data_source, weight_data, None
 
 
 def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str] | None:
@@ -201,15 +364,60 @@ def _handle_verify(args: argparse.Namespace) -> int:
     import numpy as np
     import onnxruntime as ort
 
+    success_message, error, _operators = _verify_model(
+        args, include_build_details=True
+    )
+    if error is not None:
+        LOGGER.error("Verification failed: %s", error)
+        return 1
+    if success_message:
+        LOGGER.info("%s", success_message)
+    return 0
+
+
+def _verify_model(
+    args: argparse.Namespace,
+    *,
+    include_build_details: bool,
+) -> tuple[str | None, str | None, list[str]]:
+    import numpy as np
+    import onnxruntime as ort
+
+    def log_step(step: str, started_at: float) -> None:
+        duration = time.perf_counter() - started_at
+        LOGGER.info("verify step %s: %.3fs", step, duration)
+
+    def describe_exit_code(returncode: int) -> str:
+        if returncode >= 0:
+            return f"exit code {returncode}"
+        signal_id = -returncode
+        try:
+            signal_name = signal.Signals(signal_id).name
+        except ValueError:
+            signal_name = "unknown"
+        return f"exit code {returncode} (signal {signal_id}: {signal_name})"
+
     model_path: Path = args.model
-    model_name = args.model_name or
+    model_name = args.model_name or "model"
     model_checksum = _model_checksum(model_path)
     compiler_cmd = _resolve_compiler(args.cc, prefer_ccache=False)
     if compiler_cmd is None:
-
-
+        return (
+            None,
+            "No C compiler found (set --cc or CC environment variable).",
+            [],
+        )
     try:
         model = onnx.load_model(model_path)
+    except OSError as exc:
+        return None, str(exc), []
+
+    operators = _collect_model_operators(model)
+    operators_display = ", ".join(operators) if operators else "(none)"
+    LOGGER.info("verify operators: %s", operators_display)
+
+    try:
+        testbench_inputs = _load_test_data_inputs(model, args.test_data_dir)
         options = CompilerOptions(
             template_dir=args.template_dir,
             model_name=model_name,
@@ -217,71 +425,103 @@ def _handle_verify(args: argparse.Namespace) -> int:
             command_line=args.command_line,
             model_checksum=model_checksum,
             restrict_arrays=args.restrict_arrays,
+            truncate_weights_after=args.truncate_weights_after,
+            large_temp_threshold_bytes=args.large_temp_threshold_bytes,
+            large_weight_threshold=args.large_weight_threshold,
+            testbench_inputs=testbench_inputs,
         )
         compiler = Compiler(options)
-
-
-
-
+        codegen_started = time.perf_counter()
+        generated, weight_data = compiler.compile_with_weight_data(model)
+        log_step("codegen", codegen_started)
+    except (CodegenError, ShapeInferenceError, UnsupportedOpError) as exc:
+        return None, str(exc), operators
 
     try:
         graph = import_onnx(model)
         output_dtypes = {value.name: value.type.dtype for value in graph.outputs}
         input_dtypes = {value.name: value.type.dtype for value in graph.inputs}
     except (KeyError, UnsupportedOpError, ShapeInferenceError) as exc:
-
-        return 1
+        return None, f"Failed to resolve model dtype: {exc}", operators
 
     with tempfile.TemporaryDirectory() as temp_dir:
         temp_path = Path(temp_dir)
+        LOGGER.info("verify temp dir: %s", temp_path)
         c_path = temp_path / "model.c"
+        weights_path = temp_path / f"{model_name}.bin"
         exe_path = temp_path / "model"
         c_path.write_text(generated, encoding="utf-8")
+        if weight_data is not None:
+            weights_path.write_bytes(weight_data)
         try:
+            compile_started = time.perf_counter()
+            compile_cmd = [
+                *compiler_cmd,
+                "-std=c99",
+                "-O2",
+                str(c_path),
+                "-o",
+                str(exe_path),
+                "-lm",
+            ]
+            LOGGER.info("verify compile command: %s", shlex.join(compile_cmd))
             subprocess.run(
-
-                *compiler_cmd,
-                "-std=c99",
-                "-O2",
-                str(c_path),
-                "-o",
-                str(exe_path),
-                "-lm",
-                ],
+                compile_cmd,
                 check=True,
                 capture_output=True,
                 text=True,
             )
+            log_step("compile", compile_started)
         except subprocess.CalledProcessError as exc:
-
-
+            message = "Failed to build testbench."
+            if include_build_details:
+                details = exc.stderr.strip()
+                if details:
+                    message = f"{message} {details}"
+            return None, message, operators
         try:
+            run_started = time.perf_counter()
             result = subprocess.run(
                 [str(exe_path)],
                 check=True,
                 capture_output=True,
                 text=True,
+                cwd=temp_path,
             )
+            log_step("run", run_started)
         except subprocess.CalledProcessError as exc:
-
-
+            return None, (
+                "Testbench execution failed: " + describe_exit_code(exc.returncode)
+            ), operators
 
         try:
             payload = json.loads(result.stdout)
         except json.JSONDecodeError as exc:
-
-            return 1
+            return None, f"Failed to parse testbench JSON: {exc}", operators
 
-
-
-
-
-
-
-
+        if testbench_inputs:
+            inputs = {
+                name: values.astype(input_dtypes[name].np_dtype, copy=False)
+                for name, values in testbench_inputs.items()
+            }
+        else:
+            inputs = {
+                name: decode_testbench_array(
+                    value["data"], input_dtypes[name].np_dtype
+                )
+                for name, value in payload["inputs"].items()
+            }
         try:
+            ort_started = time.perf_counter()
+            sess_options = make_deterministic_session_options(ort)
+            sess = ort.InferenceSession(
+                model.SerializeToString(),
+                sess_options=sess_options,
+                providers=["CPUExecutionProvider"],
+            )
             ort_outputs = sess.run(None, inputs)
         except Exception as exc:
+            log_step("onnx runtime", ort_started)
             message = str(exc)
             if "NOT_IMPLEMENTED" in message:
                 LOGGER.warning(
@@ -289,28 +529,57 @@ def _handle_verify(args: argparse.Namespace) -> int:
                     model_path,
                     message,
                 )
-                return
-
-
+                return "", None, operators
+            return None, f"ONNX Runtime failed to run {model_path}: {message}", operators
+        log_step("onnx runtime", ort_started)
         payload_outputs = payload.get("outputs", {})
+        max_ulp = 0
         try:
             for value, ort_out in zip(graph.outputs, ort_outputs):
                 output_payload = payload_outputs.get(value.name)
                 if output_payload is None:
                     raise AssertionError(f"Missing output {value.name} in testbench data")
                 info = output_dtypes[value.name]
-                output_data =
+                output_data = decode_testbench_array(
+                    output_payload["data"], info.np_dtype
+                ).astype(info.np_dtype, copy=False)
+                ort_out = ort_out.astype(info.np_dtype, copy=False)
+                output_data = output_data.reshape(ort_out.shape)
                 if np.issubdtype(info.np_dtype, np.floating):
-
-                        output_data, ort_out, rtol=1e-4, atol=1e-5
-                    )
+                    max_ulp = max(max_ulp, max_ulp_diff(output_data, ort_out))
                 else:
                     np.testing.assert_array_equal(output_data, ort_out)
         except AssertionError as exc:
-
-
-
-            return
+            return None, str(exc), operators
+        if max_ulp > args.max_ulp:
+            return None, f"Out of tolerance (max ULP {max_ulp})", operators
+        return format_success_message(max_ulp), None, operators
+
+
+def _load_test_data_inputs(
+    model: onnx.ModelProto, data_dir: Path | None
+) -> dict[str, "np.ndarray"] | None:
+    if data_dir is None:
+        return None
+    if not data_dir.exists():
+        raise CodegenError(f"Test data directory not found: {data_dir}")
+    input_files = sorted(
+        data_dir.glob("input_*.pb"),
+        key=lambda path: int(path.stem.split("_")[-1]),
+    )
+    if not input_files:
+        raise CodegenError(f"No input_*.pb files found in {data_dir}")
+    if len(input_files) != len(model.graph.input):
+        raise CodegenError(
+            "Test data input count does not match model inputs: "
+            f"{len(input_files)} vs {len(model.graph.input)}."
+        )
+    inputs: dict[str, np.ndarray] = {}
+    for index, path in enumerate(input_files):
+        tensor = onnx.TensorProto()
+        tensor.ParseFromString(path.read_bytes())
+        inputs[model.graph.input[index].name] = numpy_helper.to_array(tensor)
+    return inputs
 
 
 def _format_command_line(argv: Sequence[str] | None) -> str:
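Floating-point outputs are now scored by ULP distance through max_ulp_diff (imported from .verification; the file list also adds shared/ulp.py) instead of assert_allclose with rtol/atol, and --max-ulp bounds the accepted difference. The package's own implementation is not visible in this diff; the sketch below only illustrates the general idea of comparing float32 values by the gap between order-preserving integer encodings of their bit patterns:

```python
import numpy as np


def max_ulp_diff_sketch(a: np.ndarray, b: np.ndarray) -> int:
    """Largest ULP distance between two float32 arrays (illustrative only)."""
    a32 = np.asarray(a, dtype=np.float32)
    b32 = np.asarray(b, dtype=np.float32)

    def ordered_int(x: np.ndarray) -> np.ndarray:
        # Reinterpret float bits as int32, then remap negative values so that
        # integer order matches float order (NaN/inf get no special handling).
        bits = x.view(np.int32).astype(np.int64)
        return np.where(bits < 0, np.int64(-2**31) - bits, bits)

    return int(np.abs(ordered_int(a32) - ordered_int(b32)).max())


# 1.0 and the next representable float32 differ by exactly 1 ULP.
one = np.float32([1.0])
print(max_ulp_diff_sketch(one, np.nextafter(one, np.float32(2.0))))  # 1
```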
@@ -326,3 +595,15 @@ def _model_checksum(model_path: Path) -> str:
     digest = hashlib.sha256()
     digest.update(model_path.read_bytes())
     return digest.hexdigest()
+
+
+def _collect_model_operators(model: onnx.ModelProto) -> list[str]:
+    operators: list[str] = []
+    seen: set[str] = set()
+    for node in model.graph.node:
+        op_name = f"{node.domain}::{node.op_type}" if node.domain else node.op_type
+        if op_name in seen:
+            continue
+        seen.add(op_name)
+        operators.append(op_name)
+    return operators