emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/_build_info.py CHANGED
@@ -1,3 +1,3 @@
1
1
  """Auto-generated by build backend. Do not edit."""
2
- BUILD_DATE = '2026-01-20T17:39:52Z'
2
+ BUILD_DATE = '2026-01-23T03:11:42Z'
3
3
  GIT_VERSION = 'unknown'
emx_onnx_cgen/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.0'
32
- __version_tuple__ = version_tuple = (0, 3, 0)
31
+ __version__ = version = '0.3.2'
32
+ __version_tuple__ = version_tuple = (0, 3, 2)
33
33
 
34
34
  __commit_id__ = commit_id = None
emx_onnx_cgen/cli.py CHANGED
@@ -178,10 +178,10 @@ def _build_parser() -> argparse.ArgumentParser:
178
178
  compile_parser.add_argument(
179
179
  "--large-weight-threshold",
180
180
  type=int,
181
- default=1024,
181
+ default=1024 * 1024,
182
182
  help=(
183
183
  "Store weights larger than this element count in a binary file "
184
- "(default: 1024)"
184
+ "(default: 1048576; set to 0 to disable)"
185
185
  ),
186
186
  )
187
187
  add_restrict_flags(compile_parser)
@@ -251,6 +251,15 @@ def _build_parser() -> argparse.ArgumentParser:
251
251
  default=100,
252
252
  help="Maximum allowed ULP difference for floating outputs (default: 100)",
253
253
  )
254
+ verify_parser.add_argument(
255
+ "--runtime",
256
+ choices=("onnxruntime", "onnx-reference"),
257
+ default="onnx-reference",
258
+ help=(
259
+ "Runtime backend for verification (default: onnx-reference; "
260
+ "options: onnxruntime, onnx-reference)"
261
+ ),
262
+ )
254
263
  add_restrict_flags(verify_parser)
255
264
  return parser
256
265
 
@@ -361,9 +370,6 @@ def _resolve_compiler(cc: str | None, prefer_ccache: bool = False) -> list[str]
361
370
 
362
371
 
363
372
  def _handle_verify(args: argparse.Namespace) -> int:
364
- import numpy as np
365
- import onnxruntime as ort
366
-
367
373
  success_message, error, _operators = _verify_model(
368
374
  args, include_build_details=True
369
375
  )
@@ -381,7 +387,6 @@ def _verify_model(
381
387
  include_build_details: bool,
382
388
  ) -> tuple[str | None, str | None, list[str]]:
383
389
  import numpy as np
384
- import onnxruntime as ort
385
390
 
386
391
  def log_step(step: str, started_at: float) -> None:
387
392
  duration = time.perf_counter() - started_at
@@ -511,31 +516,44 @@ def _verify_model(
511
516
  )
512
517
  for name, value in payload["inputs"].items()
513
518
  }
519
+ runtime_name = args.runtime
520
+ runtime_started = time.perf_counter()
514
521
  try:
515
- ort_started = time.perf_counter()
516
- sess_options = make_deterministic_session_options(ort)
517
- sess = ort.InferenceSession(
518
- model.SerializeToString(),
519
- sess_options=sess_options,
520
- providers=["CPUExecutionProvider"],
521
- )
522
- ort_outputs = sess.run(None, inputs)
522
+ if runtime_name == "onnxruntime":
523
+ import onnxruntime as ort
524
+
525
+ sess_options = make_deterministic_session_options(ort)
526
+ sess = ort.InferenceSession(
527
+ model.SerializeToString(),
528
+ sess_options=sess_options,
529
+ providers=["CPUExecutionProvider"],
530
+ )
531
+ runtime_outputs = sess.run(None, inputs)
532
+ else:
533
+ from onnx.reference import ReferenceEvaluator
534
+
535
+ evaluator = ReferenceEvaluator(model)
536
+ runtime_outputs = evaluator.run(None, inputs)
523
537
  except Exception as exc:
524
- log_step("onnx runtime", ort_started)
538
+ log_step(runtime_name, runtime_started)
525
539
  message = str(exc)
526
- if "NOT_IMPLEMENTED" in message:
540
+ if runtime_name == "onnxruntime" and "NOT_IMPLEMENTED" in message:
527
541
  LOGGER.warning(
528
542
  "Skipping verification for %s: ONNX Runtime does not support the model (%s)",
529
543
  model_path,
530
544
  message,
531
545
  )
532
546
  return "", None, operators
533
- return None, f"ONNX Runtime failed to run {model_path}: {message}", operators
534
- log_step("onnx runtime", ort_started)
547
+ return (
548
+ None,
549
+ f"{runtime_name} failed to run {model_path}: {message}",
550
+ operators,
551
+ )
552
+ log_step(runtime_name, runtime_started)
535
553
  payload_outputs = payload.get("outputs", {})
536
554
  max_ulp = 0
537
555
  try:
538
- for value, ort_out in zip(graph.outputs, ort_outputs):
556
+ for value, runtime_out in zip(graph.outputs, runtime_outputs):
539
557
  output_payload = payload_outputs.get(value.name)
540
558
  if output_payload is None:
541
559
  raise AssertionError(f"Missing output {value.name} in testbench data")
@@ -543,12 +561,12 @@ def _verify_model(
543
561
  output_data = decode_testbench_array(
544
562
  output_payload["data"], info.np_dtype
545
563
  ).astype(info.np_dtype, copy=False)
546
- ort_out = ort_out.astype(info.np_dtype, copy=False)
547
- output_data = output_data.reshape(ort_out.shape)
564
+ runtime_out = runtime_out.astype(info.np_dtype, copy=False)
565
+ output_data = output_data.reshape(runtime_out.shape)
548
566
  if np.issubdtype(info.np_dtype, np.floating):
549
- max_ulp = max(max_ulp, max_ulp_diff(output_data, ort_out))
567
+ max_ulp = max(max_ulp, max_ulp_diff(output_data, runtime_out))
550
568
  else:
551
- np.testing.assert_array_equal(output_data, ort_out)
569
+ np.testing.assert_array_equal(output_data, runtime_out)
552
570
  except AssertionError as exc:
553
571
  return None, str(exc), operators
554
572
  if max_ulp > args.max_ulp:
@@ -574,6 +592,15 @@ def _load_test_data_inputs(
574
592
  "Test data input count does not match model inputs: "
575
593
  f"{len(input_files)} vs {len(model.graph.input)}."
576
594
  )
595
+ for value_info in model.graph.input:
596
+ value_kind = value_info.type.WhichOneof("value")
597
+ if value_kind != "tensor_type":
598
+ LOGGER.warning(
599
+ "Skipping test data load for non-tensor input %s (type %s).",
600
+ value_info.name,
601
+ value_kind or "unknown",
602
+ )
603
+ return None
577
604
  inputs: dict[str, np.ndarray] = {}
578
605
  for index, path in enumerate(input_files):
579
606
  tensor = onnx.TensorProto()
emx_onnx_cgen/codegen/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .c_emitter import (
7
7
  GemmOp,
8
8
  LoweredModel,
9
9
  MatMulOp,
10
+ QLinearMatMulOp,
10
11
  ShapeOp,
11
12
  UnaryOp,
12
13
  )
@@ -20,6 +21,7 @@ __all__ = [
20
21
  "GemmOp",
21
22
  "LoweredModel",
22
23
  "MatMulOp",
24
+ "QLinearMatMulOp",
23
25
  "ShapeOp",
24
26
  "UnaryOp",
25
27
  ]