emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. Click here for more details.

Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/testbench.py ADDED
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+
7
+
8
+ def _convert_hex_floats(value: Any) -> Any:
9
+ if isinstance(value, list):
10
+ return [_convert_hex_floats(item) for item in value]
11
+ if isinstance(value, str):
12
+ return float.fromhex(value)
13
+ return value
14
+
15
+
16
def decode_testbench_array(data: object, dtype: np.dtype) -> np.ndarray:
    """Build a numpy array of ``dtype`` from decoded testbench JSON ``data``.

    For floating-point dtypes the values are expected to be hex strings
    (C99 ``%a`` formatting) and are converted to floats first; data for any
    other dtype is handed to numpy unchanged.
    """
    is_floating = np.issubdtype(dtype, np.floating)
    payload = _convert_hex_floats(data) if is_floating else data
    return np.array(payload, dtype=dtype)
emx_onnx_cgen/verification.py ADDED
@@ -0,0 +1,61 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+
6
def _float_uint_dtype(values: np.ndarray) -> type[np.unsignedinteger]:
    """Return the unsigned integer type as wide as the float dtype of ``values``.

    Raises:
        ValueError: If ``values`` is not float16, float32 or float64.
    """
    if values.dtype == np.float16:
        return np.uint16
    if values.dtype == np.float32:
        return np.uint32
    if values.dtype == np.float64:
        return np.uint64
    raise ValueError(f"Unsupported floating dtype for ULP calculation: {values.dtype}")


def _float_to_ordered_int(values: np.ndarray) -> np.ndarray:
    """Map floats to uint64 keys that sort like the floats themselves.

    The raw bits are reinterpreted as an unsigned integer of the same width.
    Values with the sign bit set get every bit inverted, all others get the
    sign bit set, so unsigned comparison of the keys follows the numeric
    order of non-NaN floats (with -0.0 one step below +0.0).
    """
    uint_dtype = _float_uint_dtype(values)
    bits = np.dtype(uint_dtype).itemsize * 8
    sign_mask = np.array(1 << (bits - 1), dtype=uint_dtype)
    as_uint = values.view(uint_dtype)
    ordered = np.where(as_uint & sign_mask, ~as_uint, as_uint | sign_mask)
    return ordered.astype(np.uint64, copy=False)


def max_ulp_diff(actual: np.ndarray, expected: np.ndarray) -> int:
    """Return the largest ULP distance between ``actual`` and ``expected``.

    ``actual`` is cast to the dtype of ``expected`` before comparison.

    Returns:
        * ``0`` when ``expected`` is not a floating dtype, when the arrays are
          empty, or when every element is NaN in both arrays.
        * The maximum value of the matching unsigned integer type when a NaN
          appears in one array but not at the same position in the other.
        * Otherwise the maximum element-wise distance between the ordered
          integer keys of the non-NaN elements. Elements whose magnitude is
          below the dtype's machine epsilon in *both* arrays are treated as
          zero first.

    Raises:
        ValueError: If the shapes differ or the floating dtype of ``expected``
            is not float16, float32 or float64.
    """
    if actual.shape != expected.shape:
        raise ValueError(
            f"Shape mismatch for ULP calculation: {actual.shape} vs {expected.shape}"
        )
    if not np.issubdtype(expected.dtype, np.floating):
        return 0
    dtype = expected.dtype
    if dtype not in (np.float16, np.float32, np.float64):
        raise ValueError(f"Unsupported floating dtype for ULP calculation: {dtype}")
    actual_cast = actual.astype(dtype, copy=False)
    expected_cast = expected.astype(dtype, copy=False)

    actual_nan = np.isnan(actual_cast)
    expected_nan = np.isnan(expected_cast)
    # Only a NaN that is not mirrored at the same position is a mismatch;
    # matching NaNs are dropped and the remaining elements still compared.
    if np.any(actual_nan != expected_nan):
        return int(np.iinfo(_float_uint_dtype(expected_cast)).max)
    if actual_nan.any():
        keep = ~actual_nan
        actual_cast = actual_cast[keep]
        expected_cast = expected_cast[keep]
    # Guard every path (not just the NaN one): np.max rejects empty input.
    if actual_cast.size == 0:
        return 0

    eps = np.finfo(dtype).eps
    near_zero = (np.abs(actual_cast) < eps) & (np.abs(expected_cast) < eps)
    if np.any(near_zero):
        # astype(copy=False) may have returned the caller's arrays; copy
        # before writing so the inputs are never modified.
        actual_cast = actual_cast.copy()
        expected_cast = expected_cast.copy()
        actual_cast[near_zero] = 0
        expected_cast[near_zero] = 0

    ordered_actual = _float_to_ordered_int(actual_cast)
    ordered_expected = _float_to_ordered_int(expected_cast)
    # Subtract in uint64 (larger minus smaller). A signed int64 difference
    # can wrap for float64 values of opposite sign and even come out negative.
    upper = np.maximum(ordered_actual, ordered_expected)
    lower = np.minimum(ordered_actual, ordered_expected)
    return int(np.max(upper - lower))
58
+
59
+
60
def format_success_message(max_ulp: int) -> str:
    """Return the success line that reports the given maximum ULP distance."""
    message = f"OK (max ULP {max_ulp})"
    return message
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: emx-onnx-cgen
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: emmtrix ONNX-to-C Code Generator
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -9,14 +9,14 @@ Description-Content-Type: text/markdown
9
9
 
10
10
  [![PyPI - Version](https://img.shields.io/pypi/v/emx-onnx-cgen.svg)](https://pypi.org/project/emx-onnx-cgen)
11
11
 
12
- `emx-onnx-cgen` compiles ONNX models to portable, deterministic C code for deeply embedded systems. The generated code is designed to run without dynamic memory allocation, operating systems, or external runtimes, making it suitable for safety-critical and resource-constrained targets.
12
+ `emx-onnx-cgen` compiles ONNX models to portable, deterministic C code for deeply embedded systems. The generated code is designed to run without dynamic memory allocation, operating-system services, or external runtimes, making it suitable for safety-critical and resource-constrained targets.
13
13
 
14
14
  Key characteristics:
15
15
 
16
16
  - **No dynamic memory allocation** (`malloc`, `free`, heap usage)
17
17
  - **Static, compile-time known memory layout** for parameters, activations, and temporaries
18
18
  - **Deterministic control flow** (explicit loops, no hidden dispatch or callbacks)
19
- - **No OS or libc dependencies** beyond basic C
19
+ - **No OS dependencies**, using only standard C headers (for example, `stdint.h` and `stddef.h`)
20
20
  - **Single-threaded execution model**
21
21
  - **Bitwise-stable code generation** for reproducible builds
22
22
  - **Readable, auditable C code** suitable for certification and code reviews
@@ -47,7 +47,7 @@ Key characteristics:
47
47
  - `float`, `double`, `float16`
48
48
  - `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `int32_t`, `uint32_t`, `int64_t`, `uint64_t`
49
49
  - `bool`
50
- - Supporting dynamic dimensions by utilizing C99 variable-length arrays (VLAs).
50
+ - Optional support for dynamic dimensions using C99 variable-length arrays (VLAs), when the target compiler supports them.
51
51
 
52
52
  ## Installation
53
53
 
@@ -93,6 +93,8 @@ Options:
93
93
  - `--model-name`: Override the generated model name (default: output file stem).
94
94
  - `--emit-testbench`: Emit a JSON-producing `main()` testbench for validation.
95
95
  - `--emit-data-file`: Emit constant data arrays into a companion `_data` C file.
96
+ - `--large-weight-threshold`: Store weights larger than this element count in a binary file (default: `1024`).
97
+ - `--large-temp-threshold-bytes`: Mark temporary buffers larger than this threshold as static (default: `1024`).
96
98
  - `--no-restrict-arrays`: Disable `restrict` qualifiers on generated array parameters.
97
99
 
98
100
  ### `verify`
@@ -106,6 +108,25 @@ Options:
106
108
  - `--template-dir`: Directory containing the C templates (default: `templates`).
107
109
  - `--model-name`: Override the generated model name (default: model file stem).
108
110
  - `--cc`: Explicit C compiler command for building the testbench binary.
111
+ - `--large-weight-threshold`: Store weights larger than this element count in a binary file (default: `1024`).
112
+ - `--large-temp-threshold-bytes`: Mark temporary buffers larger than this threshold as static (default: `1024`).
113
+ - `--max-ulp`: Maximum allowed ULP distance for floating outputs (default: `100`).
114
+
115
+ How verification works:
116
+
117
+ 1. **Compile with a testbench**: the compiler is invoked with `--emit-testbench`,
118
+ generating a C program that runs the model and prints inputs/outputs as JSON.
119
+ 2. **Build and execute**: the testbench is compiled with the selected C compiler
120
+ (`--cc`, `CC`, or a detected `cc/gcc/clang`) and executed in a temporary
121
+ directory.
122
+ 3. **Run ONNX Runtime**: the JSON inputs from the testbench are fed to ORT using
123
+ the same model.
124
+ 4. **Compare outputs**: floating outputs are compared by maximum ULP distance
125
+ (see https://www.emmtrix.com/wiki/ULP_Difference_of_Float_Numbers for the
126
+ ULP definition and algorithm); non-floating outputs must match exactly.
127
+ Missing outputs or mismatches are treated as failures.
128
+ 5. **ORT unsupported models**: if ORT reports `NOT_IMPLEMENTED`, verification is
129
+ skipped with a warning (exit code 0).
109
130
 
110
131
  ## Output
111
132
 
@@ -113,15 +134,20 @@ By default, the compiler emits a single C source file that includes:
113
134
 
114
135
  - A generated entry point that mirrors the ONNX graph inputs/outputs.
115
136
  - Tensor buffers for constants and temporaries.
116
- - A lightweight runtime implemented via templates in `templates/`.
117
137
 
118
138
  When `--emit-data-file` is enabled, the main C source declares constant arrays
119
139
  as `extern`, and a second file named like the output with a `_data` suffix
120
140
  contains the constant definitions.
121
141
 
142
+ When `--large-weight-threshold` is set and a weight exceeds the threshold, the
143
+ compiler emits a `<model>.bin` file with weights packed contiguously and
144
+ generates a `<model>_load` helper that loads weights from the binary file at
145
+ runtime.
146
+
122
147
  ## Official ONNX test coverage
123
148
 
124
149
  See [`OFFICIAL_ONNX_FILE_SUPPORT.md`](OFFICIAL_ONNX_FILE_SUPPORT.md) for the generated support matrix.
150
+ See [`SUPPORT_OPS.md`](SUPPORT_OPS.md) for operator-level support derived from the expectation JSON files.
125
151
 
126
152
  ## Maintained by
127
153
 
@@ -1,76 +1,93 @@
1
1
  emx_onnx_cgen/__init__.py,sha256=jUSbu1kJ0krzVTYEcph3jCprBhD7tWNtiSdL6r29KrM,221
2
2
  emx_onnx_cgen/__main__.py,sha256=iC1lLVtR6-TmpL6OxXcy3oIntExUtajn9-q627R1XyI,140
3
- emx_onnx_cgen/_build_info.py,sha256=tp1Kzo4PNvj3ydRSCdr84YUd2wEUVUopl7AoaJfOaHk,112
4
- emx_onnx_cgen/cli.py,sha256=2eQ6JxfdamiLDqo-ZfZNnnb1a1I7NTEvecfUJaAA_3M,11024
5
- emx_onnx_cgen/compiler.py,sha256=foB9JK1Z2NRcRg_Fn-kL9K7L7tKGWxI9-kaMbYg3dnM,20877
3
+ emx_onnx_cgen/_build_info.py,sha256=zS8xdzMihYIqmbd58Pfku76TrBsTlBdkLVrWqTCiUs4,112
4
+ emx_onnx_cgen/_version.py,sha256=5zTqm8rgXsWYBpB2M3Zw_K1D-aV8wP7NsBLrmMKkrAQ,704
5
+ emx_onnx_cgen/cli.py,sha256=hRF2xG6t2YUkkTYrAUVsOyz1lTAdjokE-1pxFffsG2c,20643
6
+ emx_onnx_cgen/compiler.py,sha256=Q4a4_a1DkGmbrRJaRgpk0uyOqrJQExqDQD_BNw3AUcw,28585
6
7
  emx_onnx_cgen/dtypes.py,sha256=jRx3BBvk0qFW14bngoL1B7L_IRasyNJ4jqhpM5YhcOM,1335
7
8
  emx_onnx_cgen/errors.py,sha256=HpOv95mTgr9ZX2gYe1RtwVMbPskh7zkqjU_FgAD-uIM,363
8
- emx_onnx_cgen/onnx_import.py,sha256=aMvSxT3ycg4UmnutWYvsQpzGt2m_KpNnDIiddlg-vDA,9028
9
- emx_onnx_cgen/ops.py,sha256=pW3ks2EJITiJxvThfU58KjQZE7AFUiPmMwKZRCNb1mY,16586
9
+ emx_onnx_cgen/onnx_import.py,sha256=IF7KZGfEP9H4H1fHYjobGbB_381fqD_67KtqZYs9AZ4,9168
10
+ emx_onnx_cgen/onnxruntime_utils.py,sha256=mEsC1x00M1jyBgVBKqnKoqx6H1tdgsFFUy7rbITs3bs,308
11
+ emx_onnx_cgen/ops.py,sha256=qpPOaqsYprlJrhCNLVBZ3XnREBRDdmkXbd1zaAkywOI,16732
12
+ emx_onnx_cgen/testbench.py,sha256=-NbqD1aC7OXvFMLiLzd2IPObenQdHFH85cNxNSB1GeY,640
10
13
  emx_onnx_cgen/validation.py,sha256=KFdUdGjQbzTj1szCJcjxnTi8f5l6ywNgCB9abbBpTbM,2360
14
+ emx_onnx_cgen/verification.py,sha256=eTnhl_9YObyvs0fqJAw8796TlRzp-IoFM4JuMkQ8XOc,2403
11
15
  emx_onnx_cgen/codegen/__init__.py,sha256=-_sxL87uyAIunaetjUvIUo2bc46ugVlaNtSsidegMRM,362
12
- emx_onnx_cgen/codegen/c_emitter.py,sha256=gCgbqTgDWGAmp7H8TCFPBKp7DCNb_nskkMCAzKwDG0Y,337438
16
+ emx_onnx_cgen/codegen/c_emitter.py,sha256=vYrRx3UQvve_s4ElLiuh25lsbEt7mDTCXjvr-kdkggM,422455
13
17
  emx_onnx_cgen/ir/__init__.py,sha256=fD2D8qxlGoCFJb0m9v6u3XTgzSxDOhB4cfLBiCLovzg,102
14
- emx_onnx_cgen/ir/model.py,sha256=e8vRA0RNDU8Ioz3TXQKpdUhDtUK6Hm71KouUPMhCcpg,1213
18
+ emx_onnx_cgen/ir/model.py,sha256=SZ3K8t4dKUqWuXWe5ozApofXx4bdcf4p0WYCdeU-mFA,1265
15
19
  emx_onnx_cgen/lowering/__init__.py,sha256=wrxLMWcPUH1RbPJOs0Tsdb12FhXjAAeZVDYwKqcIuzw,103
16
20
  emx_onnx_cgen/lowering/arg_reduce.py,sha256=2AowDRCJRkIvrVBphbA0rM18oCWEpCDEV5Y4K9wSDII,3388
17
21
  emx_onnx_cgen/lowering/attention.py,sha256=19Jq_k0DXwH71a3pmLTWCNMttmw5uuiNK6Jhln5HC4A,16488
18
22
  emx_onnx_cgen/lowering/average_pool.py,sha256=9kg3pYHG7QLid_M2dbleC1VoNlVlGsKdOrsWp3pt7sc,8085
19
23
  emx_onnx_cgen/lowering/batch_normalization.py,sha256=_aFCm4QaC5jH-JNEvqDFYOyAMdzgUFS_3Gmo1vdPyKE,3987
20
24
  emx_onnx_cgen/lowering/cast.py,sha256=zKiE4wI7oWP_TjxBV4fY3-FXvZxK2zy58O6tWJ2dODQ,2852
21
- emx_onnx_cgen/lowering/common.py,sha256=4w9kjKW3_LilOGgmXYcUGg5lohbYsaLudoL4ALoDUkk,2356
25
+ emx_onnx_cgen/lowering/common.py,sha256=OF5UTin4teEFSp-rbiUArYCJogZ636Rujhkgrm2vj_w,16083
22
26
  emx_onnx_cgen/lowering/concat.py,sha256=TefckPfuaIHVHxGExJiO9wlkjyRO1TGg-QAMeoW8hW0,1097
23
27
  emx_onnx_cgen/lowering/constant_of_shape.py,sha256=btQflQFMP_y22sK7RrhkbGdaeSSLPC_DWhLjxY7CAgk,3208
24
28
  emx_onnx_cgen/lowering/conv.py,sha256=I1_tssw_ySf4beKV0sCVe8DRhNxL58PqC0wxtWjD79s,7309
29
+ emx_onnx_cgen/lowering/conv_transpose.py,sha256=vMbH7g3V9o68BjsW-FurNp1G8Dgr3NrV7JPLLfopHG0,11164
25
30
  emx_onnx_cgen/lowering/cumsum.py,sha256=eX0bDtwY-qevz0KXNHtJaDiKUUHIOhDX0uDiSxcC0ZU,4125
26
31
  emx_onnx_cgen/lowering/depth_space.py,sha256=M4md379jiumGWmg7EgR-CinoPzwof2RdfOiNqOzxd9o,4217
27
32
  emx_onnx_cgen/lowering/dropout.py,sha256=oBKPMN-J9Gnw8dRXvf-bN15L1-5W7-qKhR72Z6AgLXQ,1775
33
+ emx_onnx_cgen/lowering/einsum.py,sha256=g0KEZNJb87SzH-TqDDcfNPTcAaRioq455eN6HHLZNNo,6128
28
34
  emx_onnx_cgen/lowering/elementwise.py,sha256=HN6vEW58lceYECp-7QWLCWOBo1ImyY66aZIg06nA5g8,6231
29
35
  emx_onnx_cgen/lowering/expand.py,sha256=4msnYM-6RnzMplQqde2ovOLsjmWQ4bnXEoUiEM6CT6k,5529
30
36
  emx_onnx_cgen/lowering/eye_like.py,sha256=76HEdT-EofDCCy7DewjIpILJdIJyJ-YVCbLXO54SX5E,1734
31
37
  emx_onnx_cgen/lowering/flatten.py,sha256=sGol05FDN0xoNgSl_DlVbjYvBHCHWjQC2KB15ytYfPs,2142
32
38
  emx_onnx_cgen/lowering/gather.py,sha256=9zMB9fcdJi1fkTmDs_-L6FvQi1fnhdk0h7RmeN5MP2M,1814
33
- emx_onnx_cgen/lowering/gather_elements.py,sha256=0E-WAge15HhGeWPRC_ZE94fb9C6LIoef8p5D1usWvBQ,2358
39
+ emx_onnx_cgen/lowering/gather_elements.py,sha256=K-3w__F_I_gq3Kykk7LydTR5syH_Zpi-0-rdShLumbo,2329
40
+ emx_onnx_cgen/lowering/gather_nd.py,sha256=_0IW93RMRa9VtXSu4KMpBBA18ovLBGTmH90Y-ANOk1M,3101
34
41
  emx_onnx_cgen/lowering/gemm.py,sha256=Ps2T4tZgXr5FObz5figwbLZq-Njzg44iBQ9cFmvH78k,4590
42
+ emx_onnx_cgen/lowering/global_max_pool.py,sha256=xyoqQyRFpDKCXBO8bqp7JstVxVfbj9pMd06-848ix5o,2223
35
43
  emx_onnx_cgen/lowering/grid_sample.py,sha256=Ne-97ljxSdqfjBJtVHp2AQnEeXGQ5HE-HegCoxcNCm0,5228
36
44
  emx_onnx_cgen/lowering/group_normalization.py,sha256=RqXud5_xNxMu8DP5EsPr4L2b6CZghQWCcG6Bh1x2gMA,2664
37
- emx_onnx_cgen/lowering/identity.py,sha256=fn1Tg56xACwAjhesy0wyr9TJjCmmddnd8QrQ4-uCdO0,1843
45
+ emx_onnx_cgen/lowering/hardmax.py,sha256=Lqnt9g48-Kpcklq_uJpA93FYdhTI8SjUu4Cnh9_BhYY,1961
46
+ emx_onnx_cgen/lowering/identity.py,sha256=h8cy6T9n10IWEnOKgIGOca9sNVaw5vU1fdt_p8AAqq8,1859
38
47
  emx_onnx_cgen/lowering/instance_normalization.py,sha256=1Yx2KPKq_BHberCBTrGQXQswAS0FfDle9NpyeD41ypU,1950
39
48
  emx_onnx_cgen/lowering/layer_normalization.py,sha256=ZvqGZOhuoYh8ZPyzb-PV0kIc2bbunWTYj12wmrGu9YY,4529
40
- emx_onnx_cgen/lowering/logsoftmax.py,sha256=1FEaX45GdDr6jIdS_sOwXOy_DdVDruZem4yZ9XA4a38,1669
49
+ emx_onnx_cgen/lowering/logsoftmax.py,sha256=gdPYJdRNjwRDRXozzKOkKHV7HeOw48Zl4guiAN5pgKs,1895
41
50
  emx_onnx_cgen/lowering/lp_normalization.py,sha256=61CGS-2yN0bf5dby5b7Ug1PH3CStZN1xZmYWa5TysTI,1712
51
+ emx_onnx_cgen/lowering/lp_pool.py,sha256=96M-CeIqOkPstVr2BEoASImG6-Z4_S7ngB8bmPQlo7M,4873
42
52
  emx_onnx_cgen/lowering/lrn.py,sha256=zGw1Jk7iBk1jHdjdDqfAREsV5VcSdOG3LcAmEllIB08,3370
43
53
  emx_onnx_cgen/lowering/lstm.py,sha256=JhGxiF3bTSY3flkw_u9mil2esRxvIjr5Tc4vSPULDr4,12305
44
- emx_onnx_cgen/lowering/matmul.py,sha256=NEfBa140ofpgm9xnqUBulMSA-yQlb29F2NqhCJpmKSY,4262
54
+ emx_onnx_cgen/lowering/matmul.py,sha256=QJ2DfMk6g5hNNpfUUfHH732cfSLL0LFzZxw-4GOuRYQ,4259
45
55
  emx_onnx_cgen/lowering/maxpool.py,sha256=MRLeoCEdIwO8JNWOi7iKoeIsJvukqpx_w6GCHaDaYHU,7494
46
56
  emx_onnx_cgen/lowering/mean_variance_normalization.py,sha256=L_6ECH9wPEnNX2mL6yroZRexZM8JV5ZnJvoPQS6IAuc,1875
47
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py,sha256=wMWvJ9ymkA-ptFgqx0VmHAJHS5NVvDeo2GAlrECztJ8,9307
57
+ emx_onnx_cgen/lowering/negative_log_likelihood_loss.py,sha256=nqNETCqZP9MZhBU_Wcgaxu1w0uNKHa6VR5KTw--CWP0,9344
58
+ emx_onnx_cgen/lowering/nonzero.py,sha256=d_81vdF2BkYi_Z7jbypQ3qwSe-ctNoSxaJRYO8qDMjc,1637
59
+ emx_onnx_cgen/lowering/one_hot.py,sha256=twxMlNosGfm3iAeGv5LDGPliMCJ36o7Dxusn4gVJE00,4352
48
60
  emx_onnx_cgen/lowering/pad.py,sha256=cNlh-rA3CRPfO-u8gvJ1MeF1j-vdBMXLuJBpp2DkFTc,10416
61
+ emx_onnx_cgen/lowering/quantize_linear.py,sha256=3IvZTQbWAYh61nVNAbd-CeHTpYaULUSEjSdnu-nIOI8,4771
49
62
  emx_onnx_cgen/lowering/range.py,sha256=xLbG3SGvQiboPqSIh5qZyw3Krbxjk0nd3YvRlUI_q64,3463
50
- emx_onnx_cgen/lowering/reduce.py,sha256=XRxPopJCU9FGd1XmdxCZickmSTEvmkpAgLUPDFfjRm0,18431
63
+ emx_onnx_cgen/lowering/reduce.py,sha256=xetOc7mfZgKsD57O00K8NKmb0sEcHbMTpdDlqi4U6b4,18413
51
64
  emx_onnx_cgen/lowering/registry.py,sha256=rKyWnLDBFHJbHw-iyOtXv2Qc5LEBCwgopUXewvQpEpg,1392
52
- emx_onnx_cgen/lowering/reshape.py,sha256=SF46eP95Z19PT67ayJDKW1fwllBWEZmGfXAL5q9wy-I,6881
65
+ emx_onnx_cgen/lowering/reshape.py,sha256=dXZx3tTklPHSHAi-ZroiS6niCmHfHb7yGioJUYUDkLY,13452
53
66
  emx_onnx_cgen/lowering/resize.py,sha256=J_x53hVHlfJemLwEhq5n_11Pe1TlF9nRMEpkw6IpzN8,14644
54
67
  emx_onnx_cgen/lowering/rms_normalization.py,sha256=_H56Pf9T80FYbmy1m3oc7_D5TbNxRrVeJScD5VmLZRo,2536
68
+ emx_onnx_cgen/lowering/scatter_nd.py,sha256=q7rBQ0AdPrwsHnsN5qpCTcnwCWyKzSaG86wpAhNiDzE,3221
55
69
  emx_onnx_cgen/lowering/shape.py,sha256=Vvd2zQB06wZcEe4mW5WBRrQuVF8f_tXSM9fpGxe9PEo,2913
56
70
  emx_onnx_cgen/lowering/size.py,sha256=Z_DTevdpx2W_3k0GoyQ2uWE3ms_PN1d_Ti7hh6HhB1Q,1261
57
71
  emx_onnx_cgen/lowering/slice.py,sha256=yHm_mXeHcLufDmVNvj_kv08zMdbvI39ViHcE-tVPKa0,14816
58
- emx_onnx_cgen/lowering/softmax.py,sha256=ZaOZf00f5PNHRjSki08Fv-iod6UgqL7cmblfpE_OQRU,1648
72
+ emx_onnx_cgen/lowering/softmax.py,sha256=qmg9AcxFgYZcz98VyxxgyaFmSEcUCxhGtfSX6zLTRgE,1874
59
73
  emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py,sha256=I0pbWyJdnf-9vAuX8-xsnovDKDGxynlBhfj5k_IVIa4,5230
60
74
  emx_onnx_cgen/lowering/split.py,sha256=ImGnsqrl7IdWbPTPazfXYjcsoRoziqsqtJBum12xTXY,5894
61
- emx_onnx_cgen/lowering/squeeze.py,sha256=rgICFprcWhC03h-GXZNaIQsdFjsqyxybJYtPklTfaYM,6086
75
+ emx_onnx_cgen/lowering/squeeze.py,sha256=ihLbOmUW3PIrV9_do78Qal-J-Ten7sxTnd4a2tm5aEo,6095
62
76
  emx_onnx_cgen/lowering/tile.py,sha256=fT-ybiBZfb4bqBAPrCORZCNm3KWeu4rRW2BJ_UVIVZU,3041
77
+ emx_onnx_cgen/lowering/topk.py,sha256=u2FGoCt1RLI8GDCP-833Raxw5yMlXQY3ElJ-vJj-FPY,4113
63
78
  emx_onnx_cgen/lowering/transpose.py,sha256=TrRXUt-4UFNHZWaOpS3N5zEz5-OCK6-twZdlrnw7Pqg,1762
64
- emx_onnx_cgen/lowering/unsqueeze.py,sha256=sE3vribz8EyHqDG8lEcreKII7rQDElnHf1OpoM5HiAo,5987
79
+ emx_onnx_cgen/lowering/trilu.py,sha256=irA0fZV_OzKRYMhbJGuAZQPcDgxztNE0a-fMw_seU6E,3277
80
+ emx_onnx_cgen/lowering/unsqueeze.py,sha256=tlmdF8OMS9u-aT3jSwmuBo_VHzgMu3QiCnf45UNMNuY,5996
65
81
  emx_onnx_cgen/lowering/variadic.py,sha256=hmPzRIj0kcZriGRTR2ma1YMH9g21K_4f-3FXw6qO3jE,3298
66
82
  emx_onnx_cgen/lowering/where.py,sha256=uiaWU9RM6o-n38N0AEINIkXS33yVK3-ohkfKIApJOoA,2655
67
83
  emx_onnx_cgen/runtime/__init__.py,sha256=88xGpAs1IEBlzlWL_e9tnKUlaSRdc7pQUeVCu5LC4DY,50
68
- emx_onnx_cgen/runtime/evaluator.py,sha256=GFxrBXcKuQkZ0HY46twOTrNc955UqW3cRKAu5AYVJzQ,84910
84
+ emx_onnx_cgen/runtime/evaluator.py,sha256=Zd0RRwn0c7Lr3eW4OSvJEX9oBEX9p4cr_fmfvJK4LHY,102372
69
85
  shared/__init__.py,sha256=bmP79AVZdY_1aNULJap9pm76Q41Rabrza6X-0A8lDzw,45
70
- shared/scalar_functions.py,sha256=OAFO6kT6Gtcv5jp7UBLRhifhGmAbWhDKAmapTvqQruc,89911
86
+ shared/scalar_functions.py,sha256=KawY6sleIcVf5FdffFABQWeh4P_I8Oz7IMPcXjMyRfw,90843
71
87
  shared/scalar_types.py,sha256=kEpsl5T-NVFxCcTzXqPJbtpvDiCgKHfz91dphLLZxZA,4912
72
- emx_onnx_cgen-0.2.0.dist-info/METADATA,sha256=xTbPSAdUMfyXwvkxGGVYhsGhKjpxgPSzV1rimuB8zn0,4256
73
- emx_onnx_cgen-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
74
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt,sha256=b7Rvmz_Bi9kWyn7QayQC_FEXiRpt4cS1RnluKh49yoo,57
75
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt,sha256=g39fo-blEbgiVcC_GRqAnBzN234w3LXbcVdLUoItSLk,21
76
- emx_onnx_cgen-0.2.0.dist-info/RECORD,,
88
+ shared/ulp.py,sha256=o_JQ0pyeu1BD5Jx5tWuKnFQYWm1Q8zA8i8EHecf6Ys8,1371
89
+ emx_onnx_cgen-0.3.0.dist-info/METADATA,sha256=Q4l5q1s5a6pLWYmAujI9CzU-kMm40d4JffGtQQwVCjw,6036
90
+ emx_onnx_cgen-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
91
+ emx_onnx_cgen-0.3.0.dist-info/entry_points.txt,sha256=b7Rvmz_Bi9kWyn7QayQC_FEXiRpt4cS1RnluKh49yoo,57
92
+ emx_onnx_cgen-0.3.0.dist-info/top_level.txt,sha256=g39fo-blEbgiVcC_GRqAnBzN234w3LXbcVdLUoItSLk,21
93
+ emx_onnx_cgen-0.3.0.dist-info/RECORD,,
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataclasses import dataclass
3
+ from dataclasses import dataclass, field
4
4
  from enum import Enum
5
5
  import math
6
6
  from typing import Callable, Dict, List, Mapping, Set
@@ -26,6 +26,7 @@ class _GeneratedScalar:
26
26
  lines: List[str]
27
27
  deps: Set[ScalarFunctionKey]
28
28
  includes: Set[str]
29
+ constants: Set[str] = field(default_factory=set)
29
30
 
30
31
 
31
32
  def _scalar_function_spec(
@@ -396,6 +397,7 @@ _ONNX_OP_TO_SCALAR_FUNCTION = {
396
397
  "Max": ScalarFunction.MAXIMUM,
397
398
  "Mean": ScalarFunction.MEAN,
398
399
  "Min": ScalarFunction.MINIMUM,
400
+ "Mish": ScalarFunction.MISH,
399
401
  "Mod": ScalarFunction.FMOD,
400
402
  "Mul": ScalarFunction.MUL,
401
403
  "Neg": ScalarFunction.NEG,
@@ -1071,7 +1073,7 @@ def _float_sign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
1071
1073
 
1072
1074
 
1073
1075
  def _float_round(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
1074
- return _float_unary_math(dtype_info, "round", "round")
1076
+ return _float_unary_math(dtype_info, "round", "rint")
1075
1077
 
1076
1078
 
1077
1079
  def _float_trunc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
@@ -1089,7 +1091,7 @@ def _float_angle(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
1089
1091
  f" return a < {zero} ? {pi} : {zero};",
1090
1092
  "}",
1091
1093
  ]
1092
- return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1094
+ return _GeneratedScalar(lines=lines, deps=set(), includes=set(), constants={pi})
1093
1095
 
1094
1096
 
1095
1097
  def _float_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
@@ -1099,13 +1101,25 @@ def _float_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
1099
1101
  def _float_deg2rad(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
1100
1102
  pi = "REF_PI_F" if dtype_info.suffix == "f32" else "REF_PI_D"
1101
1103
  one_eighty = _float_literal(180.0, dtype_info)
1102
- return _simple_unary(dtype_info, "deg2rad", f"a * ({pi} / {one_eighty})")
1104
+ generated = _simple_unary(dtype_info, "deg2rad", f"a * ({pi} / {one_eighty})")
1105
+ return _GeneratedScalar(
1106
+ lines=generated.lines,
1107
+ deps=generated.deps,
1108
+ includes=generated.includes,
1109
+ constants={pi},
1110
+ )
1103
1111
 
1104
1112
 
1105
1113
  def _float_rad2deg(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
1106
1114
  pi = "REF_PI_F" if dtype_info.suffix == "f32" else "REF_PI_D"
1107
1115
  one_eighty = _float_literal(180.0, dtype_info)
1108
- return _simple_unary(dtype_info, "rad2deg", f"a * ({one_eighty} / {pi})")
1116
+ generated = _simple_unary(dtype_info, "rad2deg", f"a * ({one_eighty} / {pi})")
1117
+ return _GeneratedScalar(
1118
+ lines=generated.lines,
1119
+ deps=generated.deps,
1120
+ includes=generated.includes,
1121
+ constants={pi},
1122
+ )
1109
1123
 
1110
1124
 
1111
1125
  def _float_digamma_f64() -> _GeneratedScalar:
@@ -1135,7 +1149,9 @@ def _float_digamma_f64() -> _GeneratedScalar:
1135
1149
  " return result;",
1136
1150
  "}",
1137
1151
  ]
1138
- return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1152
+ return _GeneratedScalar(
1153
+ lines=lines, deps=set(), includes=set(), constants={"REF_PI_D"}
1154
+ )
1139
1155
 
1140
1156
 
1141
1157
  def _float_digamma(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
@@ -1186,7 +1202,7 @@ def _float_erfinv(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
1186
1202
  " return approx;",
1187
1203
  "}",
1188
1204
  ]
1189
- return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1205
+ return _GeneratedScalar(lines=lines, deps=set(), includes=set(), constants={pi})
1190
1206
 
1191
1207
 
1192
1208
  def _float_frac(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
@@ -1288,7 +1304,7 @@ def _float_sinc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
1288
1304
  f" return {_math_fn('sin', dtype_info)}(x) / x;",
1289
1305
  "}",
1290
1306
  ]
1291
- return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1307
+ return _GeneratedScalar(lines=lines, deps=set(), includes=set(), constants={pi})
1292
1308
 
1293
1309
 
1294
1310
  def _float_square(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
@@ -2285,7 +2301,12 @@ def _generate_scalar(key: ScalarFunctionKey) -> _GeneratedScalar:
2285
2301
  includes.add("#include <limits.h>")
2286
2302
  if dtype_info.is_bool:
2287
2303
  includes.add("#include <stdbool.h>")
2288
- return _GeneratedScalar(lines=generated.lines, deps=generated.deps, includes=includes)
2304
+ return _GeneratedScalar(
2305
+ lines=generated.lines,
2306
+ deps=generated.deps,
2307
+ includes=includes,
2308
+ constants=generated.constants,
2309
+ )
2289
2310
 
2290
2311
 
2291
2312
  def _function_name_for_key(key: ScalarFunctionKey) -> str:
@@ -2352,6 +2373,7 @@ class ScalarFunctionRegistry:
2352
2373
  def include_lines(self) -> List[str]:
2353
2374
  includes: Set[str] = set()
2354
2375
  visited: Set[ScalarFunctionKey] = set()
2376
+ constants: Set[str] = set()
2355
2377
 
2356
2378
  def collect(key: ScalarFunctionKey) -> None:
2357
2379
  if key in visited:
@@ -2362,18 +2384,28 @@ class ScalarFunctionRegistry:
2362
2384
  for dep in entry.deps:
2363
2385
  collect(dep)
2364
2386
  includes.update(entry.includes)
2387
+ constants.update(entry.constants)
2365
2388
 
2366
2389
  for key in self._requested:
2367
2390
  collect(key)
2368
2391
  ordered = sorted(includes)
2369
- preamble = [
2370
- "#ifndef REF_PI_F",
2371
- "#define REF_PI_F 3.14159265358979323846f",
2372
- "#endif",
2373
- "#ifndef REF_PI_D",
2374
- "#define REF_PI_D 3.14159265358979323846",
2375
- "#endif",
2376
- ]
2392
+ preamble: List[str] = []
2393
+ if "REF_PI_F" in constants:
2394
+ preamble.extend(
2395
+ [
2396
+ "#ifndef REF_PI_F",
2397
+ "#define REF_PI_F 3.14159265358979323846f",
2398
+ "#endif",
2399
+ ]
2400
+ )
2401
+ if "REF_PI_D" in constants:
2402
+ preamble.extend(
2403
+ [
2404
+ "#ifndef REF_PI_D",
2405
+ "#define REF_PI_D 3.14159265358979323846",
2406
+ "#endif",
2407
+ ]
2408
+ )
2377
2409
  return ordered + preamble
2378
2410
 
2379
2411
  def render(self) -> List[str]:
shared/ulp.py ADDED
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict
4
+
5
+ import numpy as np
6
+
7
+ from shared.scalar_types import ScalarFunctionError
8
+
9
+ _FLOAT_TO_UINT: Dict[np.dtype, np.dtype] = {
10
+ np.dtype("float16"): np.dtype("uint16"),
11
+ np.dtype("float32"): np.dtype("uint32"),
12
+ np.dtype("float64"): np.dtype("uint64"),
13
+ }
14
+
15
+
16
+ def _coerce_float_scalar(value: object, dtype: np.dtype) -> np.ndarray:
17
+ return np.asarray(value, dtype=dtype).reshape(())
18
+
19
+
20
+ def _ulp_intdiff_same_sign(
21
+ f1: np.ndarray, f2: np.ndarray, uint_dtype: np.dtype
22
+ ) -> int:
23
+ i1 = f1.view(uint_dtype).item()
24
+ i2 = f2.view(uint_dtype).item()
25
+ return int(i1 - i2) if i1 > i2 else int(i2 - i1)
26
+
27
+
28
def ulp_intdiff_float(f1: object, f2: object) -> int:
    """Return the integer ULP distance between ``f1`` and ``f2``.

    Both operands are coerced to their common numpy result dtype. When their
    sign bits agree, the distance is the difference of their bit patterns.
    When they differ, it is the distance of each magnitude from zero, summed,
    plus one.

    Raises:
        ScalarFunctionError: If the common dtype is not float16, float32 or
            float64.
    """
    dtype = np.result_type(f1, f2)
    try:
        uint_dtype = _FLOAT_TO_UINT[dtype]
    except KeyError as exc:
        message = f"unsupported dtype for ULP diff: {dtype}"
        raise ScalarFunctionError(message) from exc

    lhs = _coerce_float_scalar(f1, dtype)
    rhs = _coerce_float_scalar(f2, dtype)

    if np.signbit(lhs) == np.signbit(rhs):
        return _ulp_intdiff_same_sign(lhs, rhs, uint_dtype)

    # Opposite signs: measure each magnitude against zero and add them.
    # NOTE(review): the extra 1 presumably counts the step between -0.0 and
    # +0.0 - confirm against the ULP definition this module follows.
    zero = _coerce_float_scalar(0.0, dtype)
    first_to_zero = ulp_intdiff_float(zero, np.abs(lhs))
    second_to_zero = ulp_intdiff_float(zero, np.abs(rhs))
    return first_to_zero + second_to_zero + 1