emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +50 -23
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +30 -387
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +36 -18
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +1 -1
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +1 -1
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +1 -1
- emx_onnx_cgen/lowering/gather_nd.py +1 -1
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +1 -1
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +1 -1
- emx_onnx_cgen/lowering/identity.py +1 -1
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +1 -1
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +1 -1
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +1 -1
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +1 -1
- emx_onnx_cgen/lowering/one_hot.py +1 -1
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +1 -1
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +1 -1
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +1 -1
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +1 -1
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +1 -1
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +1 -1
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +25 -7
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +1 -1
- emx_onnx_cgen/lowering/unsqueeze.py +1 -1
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/runtime/evaluator.py +325 -1
- emx_onnx_cgen/verification.py +9 -39
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
- emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +11 -0
- shared/ulp.py +17 -0
- emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/_build_info.py
CHANGED
emx_onnx_cgen/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.3.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 3,
|
|
31
|
+
__version__ = version = '0.3.2'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 2)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
emx_onnx_cgen/cli.py
CHANGED
|
@@ -178,10 +178,10 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
178
178
|
compile_parser.add_argument(
|
|
179
179
|
"--large-weight-threshold",
|
|
180
180
|
type=int,
|
|
181
|
-
default=1024,
|
|
181
|
+
default=1024 * 1024,
|
|
182
182
|
help=(
|
|
183
183
|
"Store weights larger than this element count in a binary file "
|
|
184
|
-
"(default:
|
|
184
|
+
"(default: 1048576; set to 0 to disable)"
|
|
185
185
|
),
|
|
186
186
|
)
|
|
187
187
|
add_restrict_flags(compile_parser)
|
|
@@ -251,6 +251,15 @@ def _build_parser() -> argparse.ArgumentParser:
|
|
|
251
251
|
default=100,
|
|
252
252
|
help="Maximum allowed ULP difference for floating outputs (default: 100)",
|
|
253
253
|
)
|
|
254
|
+
verify_parser.add_argument(
|
|
255
|
+
"--runtime",
|
|
256
|
+
choices=("onnxruntime", "onnx-reference"),
|
|
257
|
+
default="onnx-reference",
|
|
258
|
+
help=(
|
|
259
|
+
"Runtime backend for verification (default: onnx-reference; "
|
|
260
|
+
"options: onnxruntime, onnx-reference)"
|
|
261
|
+
),
|
|
262
|
+
)
|
|
254
263
|
add_restrict_flags(verify_parser)
|
|
255
264
|
return parser
|
|
256
265
|
|
|
@@ -361,9 +370,6 @@ def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str]
|
|
|
361
370
|
|
|
362
371
|
|
|
363
372
|
def _handle_verify(args: argparse.Namespace) -> int:
|
|
364
|
-
import numpy as np
|
|
365
|
-
import onnxruntime as ort
|
|
366
|
-
|
|
367
373
|
success_message, error, _operators = _verify_model(
|
|
368
374
|
args, include_build_details=True
|
|
369
375
|
)
|
|
@@ -381,7 +387,6 @@ def _verify_model(
|
|
|
381
387
|
include_build_details: bool,
|
|
382
388
|
) -> tuple[str | None, str | None, list[str]]:
|
|
383
389
|
import numpy as np
|
|
384
|
-
import onnxruntime as ort
|
|
385
390
|
|
|
386
391
|
def log_step(step: str, started_at: float) -> None:
|
|
387
392
|
duration = time.perf_counter() - started_at
|
|
@@ -511,31 +516,44 @@ def _verify_model(
|
|
|
511
516
|
)
|
|
512
517
|
for name, value in payload["inputs"].items()
|
|
513
518
|
}
|
|
519
|
+
runtime_name = args.runtime
|
|
520
|
+
runtime_started = time.perf_counter()
|
|
514
521
|
try:
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
522
|
+
if runtime_name == "onnxruntime":
|
|
523
|
+
import onnxruntime as ort
|
|
524
|
+
|
|
525
|
+
sess_options = make_deterministic_session_options(ort)
|
|
526
|
+
sess = ort.InferenceSession(
|
|
527
|
+
model.SerializeToString(),
|
|
528
|
+
sess_options=sess_options,
|
|
529
|
+
providers=["CPUExecutionProvider"],
|
|
530
|
+
)
|
|
531
|
+
runtime_outputs = sess.run(None, inputs)
|
|
532
|
+
else:
|
|
533
|
+
from onnx.reference import ReferenceEvaluator
|
|
534
|
+
|
|
535
|
+
evaluator = ReferenceEvaluator(model)
|
|
536
|
+
runtime_outputs = evaluator.run(None, inputs)
|
|
523
537
|
except Exception as exc:
|
|
524
|
-
log_step(
|
|
538
|
+
log_step(runtime_name, runtime_started)
|
|
525
539
|
message = str(exc)
|
|
526
|
-
if "NOT_IMPLEMENTED" in message:
|
|
540
|
+
if runtime_name == "onnxruntime" and "NOT_IMPLEMENTED" in message:
|
|
527
541
|
LOGGER.warning(
|
|
528
542
|
"Skipping verification for %s: ONNX Runtime does not support the model (%s)",
|
|
529
543
|
model_path,
|
|
530
544
|
message,
|
|
531
545
|
)
|
|
532
546
|
return "", None, operators
|
|
533
|
-
return
|
|
534
|
-
|
|
547
|
+
return (
|
|
548
|
+
None,
|
|
549
|
+
f"{runtime_name} failed to run {model_path}: {message}",
|
|
550
|
+
operators,
|
|
551
|
+
)
|
|
552
|
+
log_step(runtime_name, runtime_started)
|
|
535
553
|
payload_outputs = payload.get("outputs", {})
|
|
536
554
|
max_ulp = 0
|
|
537
555
|
try:
|
|
538
|
-
for value,
|
|
556
|
+
for value, runtime_out in zip(graph.outputs, runtime_outputs):
|
|
539
557
|
output_payload = payload_outputs.get(value.name)
|
|
540
558
|
if output_payload is None:
|
|
541
559
|
raise AssertionError(f"Missing output {value.name} in testbench data")
|
|
@@ -543,12 +561,12 @@ def _verify_model(
|
|
|
543
561
|
output_data = decode_testbench_array(
|
|
544
562
|
output_payload["data"], info.np_dtype
|
|
545
563
|
).astype(info.np_dtype, copy=False)
|
|
546
|
-
|
|
547
|
-
output_data = output_data.reshape(
|
|
564
|
+
runtime_out = runtime_out.astype(info.np_dtype, copy=False)
|
|
565
|
+
output_data = output_data.reshape(runtime_out.shape)
|
|
548
566
|
if np.issubdtype(info.np_dtype, np.floating):
|
|
549
|
-
max_ulp = max(max_ulp, max_ulp_diff(output_data,
|
|
567
|
+
max_ulp = max(max_ulp, max_ulp_diff(output_data, runtime_out))
|
|
550
568
|
else:
|
|
551
|
-
np.testing.assert_array_equal(output_data,
|
|
569
|
+
np.testing.assert_array_equal(output_data, runtime_out)
|
|
552
570
|
except AssertionError as exc:
|
|
553
571
|
return None, str(exc), operators
|
|
554
572
|
if max_ulp > args.max_ulp:
|
|
@@ -574,6 +592,15 @@ def _load_test_data_inputs(
|
|
|
574
592
|
"Test data input count does not match model inputs: "
|
|
575
593
|
f"{len(input_files)} vs {len(model.graph.input)}."
|
|
576
594
|
)
|
|
595
|
+
for value_info in model.graph.input:
|
|
596
|
+
value_kind = value_info.type.WhichOneof("value")
|
|
597
|
+
if value_kind != "tensor_type":
|
|
598
|
+
LOGGER.warning(
|
|
599
|
+
"Skipping test data load for non-tensor input %s (type %s).",
|
|
600
|
+
value_info.name,
|
|
601
|
+
value_kind or "unknown",
|
|
602
|
+
)
|
|
603
|
+
return None
|
|
577
604
|
inputs: dict[str, np.ndarray] = {}
|
|
578
605
|
for index, path in enumerate(input_files):
|
|
579
606
|
tensor = onnx.TensorProto()
|