emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic. Click here for more details.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +340 -59
- emx_onnx_cgen/codegen/c_emitter.py +2369 -111
- emx_onnx_cgen/compiler.py +188 -5
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/lowering/common.py +379 -2
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/gather_elements.py +1 -3
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +6 -5
- emx_onnx_cgen/lowering/logsoftmax.py +5 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/matmul.py +6 -7
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/reduce.py +5 -6
- emx_onnx_cgen/lowering/reshape.py +223 -51
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/softmax.py +5 -1
- emx_onnx_cgen/lowering/squeeze.py +5 -5
- emx_onnx_cgen/lowering/topk.py +116 -0
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +5 -5
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +460 -42
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +61 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
- shared/scalar_functions.py +49 -17
- shared/ulp.py +48 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _convert_hex_floats(value: Any) -> Any:
|
|
9
|
+
if isinstance(value, list):
|
|
10
|
+
return [_convert_hex_floats(item) for item in value]
|
|
11
|
+
if isinstance(value, str):
|
|
12
|
+
return float.fromhex(value)
|
|
13
|
+
return value
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def decode_testbench_array(data: object, dtype: np.dtype) -> np.ndarray:
    """Decode testbench JSON data into a numpy array.

    Floating-point values are expected to be hex strings (C99 %a formatting).
    """
    # Only floating targets need the hex-string pass; integer/bool payloads
    # are already plain JSON numbers.
    payload = data
    if np.issubdtype(dtype, np.floating):
        payload = _convert_hex_floats(payload)
    return np.array(payload, dtype=dtype)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _float_uint_dtype(values: np.ndarray) -> type[np.unsignedinteger]:
|
|
7
|
+
if values.dtype == np.float16:
|
|
8
|
+
return np.uint16
|
|
9
|
+
if values.dtype == np.float32:
|
|
10
|
+
return np.uint32
|
|
11
|
+
if values.dtype == np.float64:
|
|
12
|
+
return np.uint64
|
|
13
|
+
raise ValueError(f"Unsupported floating dtype for ULP calculation: {values.dtype}")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _float_to_ordered_int(values: np.ndarray) -> np.ndarray:
|
|
17
|
+
uint_dtype = _float_uint_dtype(values)
|
|
18
|
+
bits = np.dtype(uint_dtype).itemsize * 8
|
|
19
|
+
sign_mask = np.array(1 << (bits - 1), dtype=uint_dtype)
|
|
20
|
+
as_uint = values.view(uint_dtype)
|
|
21
|
+
ordered = np.where(as_uint & sign_mask, ~as_uint, as_uint | sign_mask)
|
|
22
|
+
return ordered.astype(np.uint64, copy=False)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def max_ulp_diff(actual: np.ndarray, expected: np.ndarray) -> int:
    """Return the maximum elementwise ULP distance between two arrays.

    Non-floating dtypes are considered exact and yield 0. NaNs must match
    positionally (any unmatched NaN yields the maximum representable
    distance); values with magnitude below machine epsilon are both clamped
    to zero before comparison.

    Raises:
        ValueError: on shape mismatch or an unsupported floating dtype.
    """
    if actual.shape != expected.shape:
        raise ValueError(
            f"Shape mismatch for ULP calculation: {actual.shape} vs {expected.shape}"
        )
    if not np.issubdtype(expected.dtype, np.floating):
        return 0
    dtype = expected.dtype
    if dtype not in (np.float16, np.float32, np.float64):
        raise ValueError(f"Unsupported floating dtype for ULP calculation: {dtype}")
    actual_cast = actual.astype(dtype, copy=False)
    expected_cast = expected.astype(dtype, copy=False)
    nan_mask = np.isnan(actual_cast) | np.isnan(expected_cast)
    if nan_mask.any():
        both_nan = np.isnan(actual_cast) & np.isnan(expected_cast)
        if not np.all(both_nan):
            # An unmatched NaN is an unconditional failure: report the
            # largest distance expressible for this dtype's bit width.
            uint_dtype = _float_uint_dtype(expected_cast)
            return int(np.iinfo(uint_dtype).max)
        actual_cast = actual_cast[~nan_mask]
        expected_cast = expected_cast[~nan_mask]
    if actual_cast.size == 0:
        return 0
    eps = np.finfo(dtype).eps
    near_zero = (np.abs(actual_cast) < eps) & (np.abs(expected_cast) < eps)
    if np.any(near_zero):
        # Copy before clamping so callers' arrays are never mutated.
        actual_cast = actual_cast.copy()
        expected_cast = expected_cast.copy()
        actual_cast[near_zero] = 0
        expected_cast[near_zero] = 0
    ordered_actual = _float_to_ordered_int(actual_cast)
    ordered_expected = _float_to_ordered_int(expected_cast)
    # Subtract in uint64 (larger minus smaller) rather than casting to
    # int64: float64 ordered values can exceed the int64 range (the sign
    # bit is OR-ed in for non-negative floats), so an int64 cast would wrap
    # and could report a wrong distance for widely separated values.
    hi = np.maximum(ordered_actual, ordered_expected)
    lo = np.minimum(ordered_actual, ordered_expected)
    return int(np.max(hi - lo))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def format_success_message(max_ulp: int) -> str:
    """Render the verification success line for a given maximum ULP value."""
    return "OK (max ULP {})".format(max_ulp)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: emx-onnx-cgen
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: emmtrix ONNX-to-C Code Generator
|
|
5
5
|
Requires-Python: >=3.10
|
|
6
6
|
Description-Content-Type: text/markdown
|
|
@@ -9,14 +9,14 @@ Description-Content-Type: text/markdown
|
|
|
9
9
|
|
|
10
10
|
[](https://pypi.org/project/emx-onnx-cgen)
|
|
11
11
|
|
|
12
|
-
`emx-onnx-cgen` compiles ONNX models to portable, deterministic C code for deeply embedded systems. The generated code is designed to run without dynamic memory allocation, operating
|
|
12
|
+
`emx-onnx-cgen` compiles ONNX models to portable, deterministic C code for deeply embedded systems. The generated code is designed to run without dynamic memory allocation, operating-system services, or external runtimes, making it suitable for safety-critical and resource-constrained targets.
|
|
13
13
|
|
|
14
14
|
Key characteristics:
|
|
15
15
|
|
|
16
16
|
- **No dynamic memory allocation** (`malloc`, `free`, heap usage)
|
|
17
17
|
- **Static, compile-time known memory layout** for parameters, activations, and temporaries
|
|
18
18
|
- **Deterministic control flow** (explicit loops, no hidden dispatch or callbacks)
|
|
19
|
-
- **No OS
|
|
19
|
+
- **No OS dependencies**, using only standard C headers (for example, `stdint.h` and `stddef.h`)
|
|
20
20
|
- **Single-threaded execution model**
|
|
21
21
|
- **Bitwise-stable code generation** for reproducible builds
|
|
22
22
|
- **Readable, auditable C code** suitable for certification and code reviews
|
|
@@ -47,7 +47,7 @@ Key characteristics:
|
|
|
47
47
|
- `float`, `double`, `float16`
|
|
48
48
|
- `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `int32_t`, `uint32_t`, `int64_t`, `uint64_t`
|
|
49
49
|
- `bool`
|
|
50
|
-
-
|
|
50
|
+
- Optional support for dynamic dimensions using C99 variable-length arrays (VLAs), when the target compiler supports them.
|
|
51
51
|
|
|
52
52
|
## Installation
|
|
53
53
|
|
|
@@ -93,6 +93,8 @@ Options:
|
|
|
93
93
|
- `--model-name`: Override the generated model name (default: output file stem).
|
|
94
94
|
- `--emit-testbench`: Emit a JSON-producing `main()` testbench for validation.
|
|
95
95
|
- `--emit-data-file`: Emit constant data arrays into a companion `_data` C file.
|
|
96
|
+
- `--large-weight-threshold`: Store weights larger than this element count in a binary file (default: `1024`).
|
|
97
|
+
- `--large-temp-threshold-bytes`: Mark temporary buffers larger than this threshold as static (default: `1024`).
|
|
96
98
|
- `--no-restrict-arrays`: Disable `restrict` qualifiers on generated array parameters.
|
|
97
99
|
|
|
98
100
|
### `verify`
|
|
@@ -106,6 +108,25 @@ Options:
|
|
|
106
108
|
- `--template-dir`: Directory containing the C templates (default: `templates`).
|
|
107
109
|
- `--model-name`: Override the generated model name (default: model file stem).
|
|
108
110
|
- `--cc`: Explicit C compiler command for building the testbench binary.
|
|
111
|
+
- `--large-weight-threshold`: Store weights larger than this element count in a binary file (default: `1024`).
|
|
112
|
+
- `--large-temp-threshold-bytes`: Mark temporary buffers larger than this threshold as static (default: `1024`).
|
|
113
|
+
- `--max-ulp`: Maximum allowed ULP distance for floating outputs (default: `100`).
|
|
114
|
+
|
|
115
|
+
How verification works:
|
|
116
|
+
|
|
117
|
+
1. **Compile with a testbench**: the compiler is invoked with `--emit-testbench`,
|
|
118
|
+
generating a C program that runs the model and prints inputs/outputs as JSON.
|
|
119
|
+
2. **Build and execute**: the testbench is compiled with the selected C compiler
|
|
120
|
+
(`--cc`, `CC`, or a detected `cc/gcc/clang`) and executed in a temporary
|
|
121
|
+
directory.
|
|
122
|
+
3. **Run ONNX Runtime**: the JSON inputs from the testbench are fed to ORT using
|
|
123
|
+
the same model.
|
|
124
|
+
4. **Compare outputs**: floating outputs are compared by maximum ULP distance
|
|
125
|
+
(see https://www.emmtrix.com/wiki/ULP_Difference_of_Float_Numbers for the
|
|
126
|
+
ULP definition and algorithm); non-floating outputs must match exactly.
|
|
127
|
+
Missing outputs or mismatches are treated as failures.
|
|
128
|
+
5. **ORT unsupported models**: if ORT reports `NOT_IMPLEMENTED`, verification is
|
|
129
|
+
skipped with a warning (exit code 0).
|
|
109
130
|
|
|
110
131
|
## Output
|
|
111
132
|
|
|
@@ -113,15 +134,20 @@ By default, the compiler emits a single C source file that includes:
|
|
|
113
134
|
|
|
114
135
|
- A generated entry point that mirrors the ONNX graph inputs/outputs.
|
|
115
136
|
- Tensor buffers for constants and temporaries.
|
|
116
|
-
- A lightweight runtime implemented via templates in `templates/`.
|
|
117
137
|
|
|
118
138
|
When `--emit-data-file` is enabled, the main C source declares constant arrays
|
|
119
139
|
as `extern`, and a second file named like the output with a `_data` suffix
|
|
120
140
|
contains the constant definitions.
|
|
121
141
|
|
|
142
|
+
When `--large-weight-threshold` is set and a weight exceeds the threshold, the
|
|
143
|
+
compiler emits a `<model>.bin` file with weights packed contiguously and
|
|
144
|
+
generates a `<model>_load` helper that loads weights from the binary file at
|
|
145
|
+
runtime.
|
|
146
|
+
|
|
122
147
|
## Official ONNX test coverage
|
|
123
148
|
|
|
124
149
|
See [`OFFICIAL_ONNX_FILE_SUPPORT.md`](OFFICIAL_ONNX_FILE_SUPPORT.md) for the generated support matrix.
|
|
150
|
+
See [`SUPPORT_OPS.md`](SUPPORT_OPS.md) for operator-level support derived from the expectation JSON files.
|
|
125
151
|
|
|
126
152
|
## Maintained by
|
|
127
153
|
|
|
@@ -1,76 +1,93 @@
|
|
|
1
1
|
emx_onnx_cgen/__init__.py,sha256=jUSbu1kJ0krzVTYEcph3jCprBhD7tWNtiSdL6r29KrM,221
|
|
2
2
|
emx_onnx_cgen/__main__.py,sha256=iC1lLVtR6-TmpL6OxXcy3oIntExUtajn9-q627R1XyI,140
|
|
3
|
-
emx_onnx_cgen/_build_info.py,sha256=
|
|
4
|
-
emx_onnx_cgen/
|
|
5
|
-
emx_onnx_cgen/
|
|
3
|
+
emx_onnx_cgen/_build_info.py,sha256=zS8xdzMihYIqmbd58Pfku76TrBsTlBdkLVrWqTCiUs4,112
|
|
4
|
+
emx_onnx_cgen/_version.py,sha256=5zTqm8rgXsWYBpB2M3Zw_K1D-aV8wP7NsBLrmMKkrAQ,704
|
|
5
|
+
emx_onnx_cgen/cli.py,sha256=hRF2xG6t2YUkkTYrAUVsOyz1lTAdjokE-1pxFffsG2c,20643
|
|
6
|
+
emx_onnx_cgen/compiler.py,sha256=Q4a4_a1DkGmbrRJaRgpk0uyOqrJQExqDQD_BNw3AUcw,28585
|
|
6
7
|
emx_onnx_cgen/dtypes.py,sha256=jRx3BBvk0qFW14bngoL1B7L_IRasyNJ4jqhpM5YhcOM,1335
|
|
7
8
|
emx_onnx_cgen/errors.py,sha256=HpOv95mTgr9ZX2gYe1RtwVMbPskh7zkqjU_FgAD-uIM,363
|
|
8
|
-
emx_onnx_cgen/onnx_import.py,sha256=
|
|
9
|
-
emx_onnx_cgen/
|
|
9
|
+
emx_onnx_cgen/onnx_import.py,sha256=IF7KZGfEP9H4H1fHYjobGbB_381fqD_67KtqZYs9AZ4,9168
|
|
10
|
+
emx_onnx_cgen/onnxruntime_utils.py,sha256=mEsC1x00M1jyBgVBKqnKoqx6H1tdgsFFUy7rbITs3bs,308
|
|
11
|
+
emx_onnx_cgen/ops.py,sha256=qpPOaqsYprlJrhCNLVBZ3XnREBRDdmkXbd1zaAkywOI,16732
|
|
12
|
+
emx_onnx_cgen/testbench.py,sha256=-NbqD1aC7OXvFMLiLzd2IPObenQdHFH85cNxNSB1GeY,640
|
|
10
13
|
emx_onnx_cgen/validation.py,sha256=KFdUdGjQbzTj1szCJcjxnTi8f5l6ywNgCB9abbBpTbM,2360
|
|
14
|
+
emx_onnx_cgen/verification.py,sha256=eTnhl_9YObyvs0fqJAw8796TlRzp-IoFM4JuMkQ8XOc,2403
|
|
11
15
|
emx_onnx_cgen/codegen/__init__.py,sha256=-_sxL87uyAIunaetjUvIUo2bc46ugVlaNtSsidegMRM,362
|
|
12
|
-
emx_onnx_cgen/codegen/c_emitter.py,sha256=
|
|
16
|
+
emx_onnx_cgen/codegen/c_emitter.py,sha256=vYrRx3UQvve_s4ElLiuh25lsbEt7mDTCXjvr-kdkggM,422455
|
|
13
17
|
emx_onnx_cgen/ir/__init__.py,sha256=fD2D8qxlGoCFJb0m9v6u3XTgzSxDOhB4cfLBiCLovzg,102
|
|
14
|
-
emx_onnx_cgen/ir/model.py,sha256=
|
|
18
|
+
emx_onnx_cgen/ir/model.py,sha256=SZ3K8t4dKUqWuXWe5ozApofXx4bdcf4p0WYCdeU-mFA,1265
|
|
15
19
|
emx_onnx_cgen/lowering/__init__.py,sha256=wrxLMWcPUH1RbPJOs0Tsdb12FhXjAAeZVDYwKqcIuzw,103
|
|
16
20
|
emx_onnx_cgen/lowering/arg_reduce.py,sha256=2AowDRCJRkIvrVBphbA0rM18oCWEpCDEV5Y4K9wSDII,3388
|
|
17
21
|
emx_onnx_cgen/lowering/attention.py,sha256=19Jq_k0DXwH71a3pmLTWCNMttmw5uuiNK6Jhln5HC4A,16488
|
|
18
22
|
emx_onnx_cgen/lowering/average_pool.py,sha256=9kg3pYHG7QLid_M2dbleC1VoNlVlGsKdOrsWp3pt7sc,8085
|
|
19
23
|
emx_onnx_cgen/lowering/batch_normalization.py,sha256=_aFCm4QaC5jH-JNEvqDFYOyAMdzgUFS_3Gmo1vdPyKE,3987
|
|
20
24
|
emx_onnx_cgen/lowering/cast.py,sha256=zKiE4wI7oWP_TjxBV4fY3-FXvZxK2zy58O6tWJ2dODQ,2852
|
|
21
|
-
emx_onnx_cgen/lowering/common.py,sha256=
|
|
25
|
+
emx_onnx_cgen/lowering/common.py,sha256=OF5UTin4teEFSp-rbiUArYCJogZ636Rujhkgrm2vj_w,16083
|
|
22
26
|
emx_onnx_cgen/lowering/concat.py,sha256=TefckPfuaIHVHxGExJiO9wlkjyRO1TGg-QAMeoW8hW0,1097
|
|
23
27
|
emx_onnx_cgen/lowering/constant_of_shape.py,sha256=btQflQFMP_y22sK7RrhkbGdaeSSLPC_DWhLjxY7CAgk,3208
|
|
24
28
|
emx_onnx_cgen/lowering/conv.py,sha256=I1_tssw_ySf4beKV0sCVe8DRhNxL58PqC0wxtWjD79s,7309
|
|
29
|
+
emx_onnx_cgen/lowering/conv_transpose.py,sha256=vMbH7g3V9o68BjsW-FurNp1G8Dgr3NrV7JPLLfopHG0,11164
|
|
25
30
|
emx_onnx_cgen/lowering/cumsum.py,sha256=eX0bDtwY-qevz0KXNHtJaDiKUUHIOhDX0uDiSxcC0ZU,4125
|
|
26
31
|
emx_onnx_cgen/lowering/depth_space.py,sha256=M4md379jiumGWmg7EgR-CinoPzwof2RdfOiNqOzxd9o,4217
|
|
27
32
|
emx_onnx_cgen/lowering/dropout.py,sha256=oBKPMN-J9Gnw8dRXvf-bN15L1-5W7-qKhR72Z6AgLXQ,1775
|
|
33
|
+
emx_onnx_cgen/lowering/einsum.py,sha256=g0KEZNJb87SzH-TqDDcfNPTcAaRioq455eN6HHLZNNo,6128
|
|
28
34
|
emx_onnx_cgen/lowering/elementwise.py,sha256=HN6vEW58lceYECp-7QWLCWOBo1ImyY66aZIg06nA5g8,6231
|
|
29
35
|
emx_onnx_cgen/lowering/expand.py,sha256=4msnYM-6RnzMplQqde2ovOLsjmWQ4bnXEoUiEM6CT6k,5529
|
|
30
36
|
emx_onnx_cgen/lowering/eye_like.py,sha256=76HEdT-EofDCCy7DewjIpILJdIJyJ-YVCbLXO54SX5E,1734
|
|
31
37
|
emx_onnx_cgen/lowering/flatten.py,sha256=sGol05FDN0xoNgSl_DlVbjYvBHCHWjQC2KB15ytYfPs,2142
|
|
32
38
|
emx_onnx_cgen/lowering/gather.py,sha256=9zMB9fcdJi1fkTmDs_-L6FvQi1fnhdk0h7RmeN5MP2M,1814
|
|
33
|
-
emx_onnx_cgen/lowering/gather_elements.py,sha256=
|
|
39
|
+
emx_onnx_cgen/lowering/gather_elements.py,sha256=K-3w__F_I_gq3Kykk7LydTR5syH_Zpi-0-rdShLumbo,2329
|
|
40
|
+
emx_onnx_cgen/lowering/gather_nd.py,sha256=_0IW93RMRa9VtXSu4KMpBBA18ovLBGTmH90Y-ANOk1M,3101
|
|
34
41
|
emx_onnx_cgen/lowering/gemm.py,sha256=Ps2T4tZgXr5FObz5figwbLZq-Njzg44iBQ9cFmvH78k,4590
|
|
42
|
+
emx_onnx_cgen/lowering/global_max_pool.py,sha256=xyoqQyRFpDKCXBO8bqp7JstVxVfbj9pMd06-848ix5o,2223
|
|
35
43
|
emx_onnx_cgen/lowering/grid_sample.py,sha256=Ne-97ljxSdqfjBJtVHp2AQnEeXGQ5HE-HegCoxcNCm0,5228
|
|
36
44
|
emx_onnx_cgen/lowering/group_normalization.py,sha256=RqXud5_xNxMu8DP5EsPr4L2b6CZghQWCcG6Bh1x2gMA,2664
|
|
37
|
-
emx_onnx_cgen/lowering/
|
|
45
|
+
emx_onnx_cgen/lowering/hardmax.py,sha256=Lqnt9g48-Kpcklq_uJpA93FYdhTI8SjUu4Cnh9_BhYY,1961
|
|
46
|
+
emx_onnx_cgen/lowering/identity.py,sha256=h8cy6T9n10IWEnOKgIGOca9sNVaw5vU1fdt_p8AAqq8,1859
|
|
38
47
|
emx_onnx_cgen/lowering/instance_normalization.py,sha256=1Yx2KPKq_BHberCBTrGQXQswAS0FfDle9NpyeD41ypU,1950
|
|
39
48
|
emx_onnx_cgen/lowering/layer_normalization.py,sha256=ZvqGZOhuoYh8ZPyzb-PV0kIc2bbunWTYj12wmrGu9YY,4529
|
|
40
|
-
emx_onnx_cgen/lowering/logsoftmax.py,sha256=
|
|
49
|
+
emx_onnx_cgen/lowering/logsoftmax.py,sha256=gdPYJdRNjwRDRXozzKOkKHV7HeOw48Zl4guiAN5pgKs,1895
|
|
41
50
|
emx_onnx_cgen/lowering/lp_normalization.py,sha256=61CGS-2yN0bf5dby5b7Ug1PH3CStZN1xZmYWa5TysTI,1712
|
|
51
|
+
emx_onnx_cgen/lowering/lp_pool.py,sha256=96M-CeIqOkPstVr2BEoASImG6-Z4_S7ngB8bmPQlo7M,4873
|
|
42
52
|
emx_onnx_cgen/lowering/lrn.py,sha256=zGw1Jk7iBk1jHdjdDqfAREsV5VcSdOG3LcAmEllIB08,3370
|
|
43
53
|
emx_onnx_cgen/lowering/lstm.py,sha256=JhGxiF3bTSY3flkw_u9mil2esRxvIjr5Tc4vSPULDr4,12305
|
|
44
|
-
emx_onnx_cgen/lowering/matmul.py,sha256=
|
|
54
|
+
emx_onnx_cgen/lowering/matmul.py,sha256=QJ2DfMk6g5hNNpfUUfHH732cfSLL0LFzZxw-4GOuRYQ,4259
|
|
45
55
|
emx_onnx_cgen/lowering/maxpool.py,sha256=MRLeoCEdIwO8JNWOi7iKoeIsJvukqpx_w6GCHaDaYHU,7494
|
|
46
56
|
emx_onnx_cgen/lowering/mean_variance_normalization.py,sha256=L_6ECH9wPEnNX2mL6yroZRexZM8JV5ZnJvoPQS6IAuc,1875
|
|
47
|
-
emx_onnx_cgen/lowering/negative_log_likelihood_loss.py,sha256=
|
|
57
|
+
emx_onnx_cgen/lowering/negative_log_likelihood_loss.py,sha256=nqNETCqZP9MZhBU_Wcgaxu1w0uNKHa6VR5KTw--CWP0,9344
|
|
58
|
+
emx_onnx_cgen/lowering/nonzero.py,sha256=d_81vdF2BkYi_Z7jbypQ3qwSe-ctNoSxaJRYO8qDMjc,1637
|
|
59
|
+
emx_onnx_cgen/lowering/one_hot.py,sha256=twxMlNosGfm3iAeGv5LDGPliMCJ36o7Dxusn4gVJE00,4352
|
|
48
60
|
emx_onnx_cgen/lowering/pad.py,sha256=cNlh-rA3CRPfO-u8gvJ1MeF1j-vdBMXLuJBpp2DkFTc,10416
|
|
61
|
+
emx_onnx_cgen/lowering/quantize_linear.py,sha256=3IvZTQbWAYh61nVNAbd-CeHTpYaULUSEjSdnu-nIOI8,4771
|
|
49
62
|
emx_onnx_cgen/lowering/range.py,sha256=xLbG3SGvQiboPqSIh5qZyw3Krbxjk0nd3YvRlUI_q64,3463
|
|
50
|
-
emx_onnx_cgen/lowering/reduce.py,sha256=
|
|
63
|
+
emx_onnx_cgen/lowering/reduce.py,sha256=xetOc7mfZgKsD57O00K8NKmb0sEcHbMTpdDlqi4U6b4,18413
|
|
51
64
|
emx_onnx_cgen/lowering/registry.py,sha256=rKyWnLDBFHJbHw-iyOtXv2Qc5LEBCwgopUXewvQpEpg,1392
|
|
52
|
-
emx_onnx_cgen/lowering/reshape.py,sha256=
|
|
65
|
+
emx_onnx_cgen/lowering/reshape.py,sha256=dXZx3tTklPHSHAi-ZroiS6niCmHfHb7yGioJUYUDkLY,13452
|
|
53
66
|
emx_onnx_cgen/lowering/resize.py,sha256=J_x53hVHlfJemLwEhq5n_11Pe1TlF9nRMEpkw6IpzN8,14644
|
|
54
67
|
emx_onnx_cgen/lowering/rms_normalization.py,sha256=_H56Pf9T80FYbmy1m3oc7_D5TbNxRrVeJScD5VmLZRo,2536
|
|
68
|
+
emx_onnx_cgen/lowering/scatter_nd.py,sha256=q7rBQ0AdPrwsHnsN5qpCTcnwCWyKzSaG86wpAhNiDzE,3221
|
|
55
69
|
emx_onnx_cgen/lowering/shape.py,sha256=Vvd2zQB06wZcEe4mW5WBRrQuVF8f_tXSM9fpGxe9PEo,2913
|
|
56
70
|
emx_onnx_cgen/lowering/size.py,sha256=Z_DTevdpx2W_3k0GoyQ2uWE3ms_PN1d_Ti7hh6HhB1Q,1261
|
|
57
71
|
emx_onnx_cgen/lowering/slice.py,sha256=yHm_mXeHcLufDmVNvj_kv08zMdbvI39ViHcE-tVPKa0,14816
|
|
58
|
-
emx_onnx_cgen/lowering/softmax.py,sha256=
|
|
72
|
+
emx_onnx_cgen/lowering/softmax.py,sha256=qmg9AcxFgYZcz98VyxxgyaFmSEcUCxhGtfSX6zLTRgE,1874
|
|
59
73
|
emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py,sha256=I0pbWyJdnf-9vAuX8-xsnovDKDGxynlBhfj5k_IVIa4,5230
|
|
60
74
|
emx_onnx_cgen/lowering/split.py,sha256=ImGnsqrl7IdWbPTPazfXYjcsoRoziqsqtJBum12xTXY,5894
|
|
61
|
-
emx_onnx_cgen/lowering/squeeze.py,sha256=
|
|
75
|
+
emx_onnx_cgen/lowering/squeeze.py,sha256=ihLbOmUW3PIrV9_do78Qal-J-Ten7sxTnd4a2tm5aEo,6095
|
|
62
76
|
emx_onnx_cgen/lowering/tile.py,sha256=fT-ybiBZfb4bqBAPrCORZCNm3KWeu4rRW2BJ_UVIVZU,3041
|
|
77
|
+
emx_onnx_cgen/lowering/topk.py,sha256=u2FGoCt1RLI8GDCP-833Raxw5yMlXQY3ElJ-vJj-FPY,4113
|
|
63
78
|
emx_onnx_cgen/lowering/transpose.py,sha256=TrRXUt-4UFNHZWaOpS3N5zEz5-OCK6-twZdlrnw7Pqg,1762
|
|
64
|
-
emx_onnx_cgen/lowering/
|
|
79
|
+
emx_onnx_cgen/lowering/trilu.py,sha256=irA0fZV_OzKRYMhbJGuAZQPcDgxztNE0a-fMw_seU6E,3277
|
|
80
|
+
emx_onnx_cgen/lowering/unsqueeze.py,sha256=tlmdF8OMS9u-aT3jSwmuBo_VHzgMu3QiCnf45UNMNuY,5996
|
|
65
81
|
emx_onnx_cgen/lowering/variadic.py,sha256=hmPzRIj0kcZriGRTR2ma1YMH9g21K_4f-3FXw6qO3jE,3298
|
|
66
82
|
emx_onnx_cgen/lowering/where.py,sha256=uiaWU9RM6o-n38N0AEINIkXS33yVK3-ohkfKIApJOoA,2655
|
|
67
83
|
emx_onnx_cgen/runtime/__init__.py,sha256=88xGpAs1IEBlzlWL_e9tnKUlaSRdc7pQUeVCu5LC4DY,50
|
|
68
|
-
emx_onnx_cgen/runtime/evaluator.py,sha256=
|
|
84
|
+
emx_onnx_cgen/runtime/evaluator.py,sha256=Zd0RRwn0c7Lr3eW4OSvJEX9oBEX9p4cr_fmfvJK4LHY,102372
|
|
69
85
|
shared/__init__.py,sha256=bmP79AVZdY_1aNULJap9pm76Q41Rabrza6X-0A8lDzw,45
|
|
70
|
-
shared/scalar_functions.py,sha256=
|
|
86
|
+
shared/scalar_functions.py,sha256=KawY6sleIcVf5FdffFABQWeh4P_I8Oz7IMPcXjMyRfw,90843
|
|
71
87
|
shared/scalar_types.py,sha256=kEpsl5T-NVFxCcTzXqPJbtpvDiCgKHfz91dphLLZxZA,4912
|
|
72
|
-
|
|
73
|
-
emx_onnx_cgen-0.
|
|
74
|
-
emx_onnx_cgen-0.
|
|
75
|
-
emx_onnx_cgen-0.
|
|
76
|
-
emx_onnx_cgen-0.
|
|
88
|
+
shared/ulp.py,sha256=o_JQ0pyeu1BD5Jx5tWuKnFQYWm1Q8zA8i8EHecf6Ys8,1371
|
|
89
|
+
emx_onnx_cgen-0.3.0.dist-info/METADATA,sha256=Q4l5q1s5a6pLWYmAujI9CzU-kMm40d4JffGtQQwVCjw,6036
|
|
90
|
+
emx_onnx_cgen-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
91
|
+
emx_onnx_cgen-0.3.0.dist-info/entry_points.txt,sha256=b7Rvmz_Bi9kWyn7QayQC_FEXiRpt4cS1RnluKh49yoo,57
|
|
92
|
+
emx_onnx_cgen-0.3.0.dist-info/top_level.txt,sha256=g39fo-blEbgiVcC_GRqAnBzN234w3LXbcVdLUoItSLk,21
|
|
93
|
+
emx_onnx_cgen-0.3.0.dist-info/RECORD,,
|
shared/scalar_functions.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from dataclasses import dataclass
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
4
|
from enum import Enum
|
|
5
5
|
import math
|
|
6
6
|
from typing import Callable, Dict, List, Mapping, Set
|
|
@@ -26,6 +26,7 @@ class _GeneratedScalar:
|
|
|
26
26
|
lines: List[str]
|
|
27
27
|
deps: Set[ScalarFunctionKey]
|
|
28
28
|
includes: Set[str]
|
|
29
|
+
constants: Set[str] = field(default_factory=set)
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
def _scalar_function_spec(
|
|
@@ -396,6 +397,7 @@ _ONNX_OP_TO_SCALAR_FUNCTION = {
|
|
|
396
397
|
"Max": ScalarFunction.MAXIMUM,
|
|
397
398
|
"Mean": ScalarFunction.MEAN,
|
|
398
399
|
"Min": ScalarFunction.MINIMUM,
|
|
400
|
+
"Mish": ScalarFunction.MISH,
|
|
399
401
|
"Mod": ScalarFunction.FMOD,
|
|
400
402
|
"Mul": ScalarFunction.MUL,
|
|
401
403
|
"Neg": ScalarFunction.NEG,
|
|
@@ -1071,7 +1073,7 @@ def _float_sign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
|
1071
1073
|
|
|
1072
1074
|
|
|
1073
1075
|
def _float_round(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
1074
|
-
return _float_unary_math(dtype_info, "round", "
|
|
1076
|
+
return _float_unary_math(dtype_info, "round", "rint")
|
|
1075
1077
|
|
|
1076
1078
|
|
|
1077
1079
|
def _float_trunc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
@@ -1089,7 +1091,7 @@ def _float_angle(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
|
1089
1091
|
f" return a < {zero} ? {pi} : {zero};",
|
|
1090
1092
|
"}",
|
|
1091
1093
|
]
|
|
1092
|
-
return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1094
|
+
return _GeneratedScalar(lines=lines, deps=set(), includes=set(), constants={pi})
|
|
1093
1095
|
|
|
1094
1096
|
|
|
1095
1097
|
def _float_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
|
|
@@ -1099,13 +1101,25 @@ def _float_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
|
|
|
1099
1101
|
def _float_deg2rad(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
1100
1102
|
pi = "REF_PI_F" if dtype_info.suffix == "f32" else "REF_PI_D"
|
|
1101
1103
|
one_eighty = _float_literal(180.0, dtype_info)
|
|
1102
|
-
|
|
1104
|
+
generated = _simple_unary(dtype_info, "deg2rad", f"a * ({pi} / {one_eighty})")
|
|
1105
|
+
return _GeneratedScalar(
|
|
1106
|
+
lines=generated.lines,
|
|
1107
|
+
deps=generated.deps,
|
|
1108
|
+
includes=generated.includes,
|
|
1109
|
+
constants={pi},
|
|
1110
|
+
)
|
|
1103
1111
|
|
|
1104
1112
|
|
|
1105
1113
|
def _float_rad2deg(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
1106
1114
|
pi = "REF_PI_F" if dtype_info.suffix == "f32" else "REF_PI_D"
|
|
1107
1115
|
one_eighty = _float_literal(180.0, dtype_info)
|
|
1108
|
-
|
|
1116
|
+
generated = _simple_unary(dtype_info, "rad2deg", f"a * ({one_eighty} / {pi})")
|
|
1117
|
+
return _GeneratedScalar(
|
|
1118
|
+
lines=generated.lines,
|
|
1119
|
+
deps=generated.deps,
|
|
1120
|
+
includes=generated.includes,
|
|
1121
|
+
constants={pi},
|
|
1122
|
+
)
|
|
1109
1123
|
|
|
1110
1124
|
|
|
1111
1125
|
def _float_digamma_f64() -> _GeneratedScalar:
|
|
@@ -1135,7 +1149,9 @@ def _float_digamma_f64() -> _GeneratedScalar:
|
|
|
1135
1149
|
" return result;",
|
|
1136
1150
|
"}",
|
|
1137
1151
|
]
|
|
1138
|
-
return _GeneratedScalar(
|
|
1152
|
+
return _GeneratedScalar(
|
|
1153
|
+
lines=lines, deps=set(), includes=set(), constants={"REF_PI_D"}
|
|
1154
|
+
)
|
|
1139
1155
|
|
|
1140
1156
|
|
|
1141
1157
|
def _float_digamma(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
@@ -1186,7 +1202,7 @@ def _float_erfinv(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
|
1186
1202
|
" return approx;",
|
|
1187
1203
|
"}",
|
|
1188
1204
|
]
|
|
1189
|
-
return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1205
|
+
return _GeneratedScalar(lines=lines, deps=set(), includes=set(), constants={pi})
|
|
1190
1206
|
|
|
1191
1207
|
|
|
1192
1208
|
def _float_frac(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
@@ -1288,7 +1304,7 @@ def _float_sinc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
|
1288
1304
|
f" return {_math_fn('sin', dtype_info)}(x) / x;",
|
|
1289
1305
|
"}",
|
|
1290
1306
|
]
|
|
1291
|
-
return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1307
|
+
return _GeneratedScalar(lines=lines, deps=set(), includes=set(), constants={pi})
|
|
1292
1308
|
|
|
1293
1309
|
|
|
1294
1310
|
def _float_square(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
@@ -2285,7 +2301,12 @@ def _generate_scalar(key: ScalarFunctionKey) -> _GeneratedScalar:
|
|
|
2285
2301
|
includes.add("#include <limits.h>")
|
|
2286
2302
|
if dtype_info.is_bool:
|
|
2287
2303
|
includes.add("#include <stdbool.h>")
|
|
2288
|
-
return _GeneratedScalar(
|
|
2304
|
+
return _GeneratedScalar(
|
|
2305
|
+
lines=generated.lines,
|
|
2306
|
+
deps=generated.deps,
|
|
2307
|
+
includes=includes,
|
|
2308
|
+
constants=generated.constants,
|
|
2309
|
+
)
|
|
2289
2310
|
|
|
2290
2311
|
|
|
2291
2312
|
def _function_name_for_key(key: ScalarFunctionKey) -> str:
|
|
@@ -2352,6 +2373,7 @@ class ScalarFunctionRegistry:
|
|
|
2352
2373
|
def include_lines(self) -> List[str]:
|
|
2353
2374
|
includes: Set[str] = set()
|
|
2354
2375
|
visited: Set[ScalarFunctionKey] = set()
|
|
2376
|
+
constants: Set[str] = set()
|
|
2355
2377
|
|
|
2356
2378
|
def collect(key: ScalarFunctionKey) -> None:
|
|
2357
2379
|
if key in visited:
|
|
@@ -2362,18 +2384,28 @@ class ScalarFunctionRegistry:
|
|
|
2362
2384
|
for dep in entry.deps:
|
|
2363
2385
|
collect(dep)
|
|
2364
2386
|
includes.update(entry.includes)
|
|
2387
|
+
constants.update(entry.constants)
|
|
2365
2388
|
|
|
2366
2389
|
for key in self._requested:
|
|
2367
2390
|
collect(key)
|
|
2368
2391
|
ordered = sorted(includes)
|
|
2369
|
-
preamble = [
|
|
2370
|
-
|
|
2371
|
-
|
|
2372
|
-
|
|
2373
|
-
|
|
2374
|
-
|
|
2375
|
-
|
|
2376
|
-
|
|
2392
|
+
preamble: List[str] = []
|
|
2393
|
+
if "REF_PI_F" in constants:
|
|
2394
|
+
preamble.extend(
|
|
2395
|
+
[
|
|
2396
|
+
"#ifndef REF_PI_F",
|
|
2397
|
+
"#define REF_PI_F 3.14159265358979323846f",
|
|
2398
|
+
"#endif",
|
|
2399
|
+
]
|
|
2400
|
+
)
|
|
2401
|
+
if "REF_PI_D" in constants:
|
|
2402
|
+
preamble.extend(
|
|
2403
|
+
[
|
|
2404
|
+
"#ifndef REF_PI_D",
|
|
2405
|
+
"#define REF_PI_D 3.14159265358979323846",
|
|
2406
|
+
"#endif",
|
|
2407
|
+
]
|
|
2408
|
+
)
|
|
2377
2409
|
return ordered + preamble
|
|
2378
2410
|
|
|
2379
2411
|
def render(self) -> List[str]:
|
shared/ulp.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from shared.scalar_types import ScalarFunctionError
|
|
8
|
+
|
|
9
|
+
# Supported floating dtypes mapped to the same-width unsigned dtypes used to
# reinterpret their bit patterns for ULP arithmetic.
_FLOAT_TO_UINT: Dict[np.dtype, np.dtype] = {
    np.dtype("float16"): np.dtype("uint16"),
    np.dtype("float32"): np.dtype("uint32"),
    np.dtype("float64"): np.dtype("uint64"),
}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _coerce_float_scalar(value: object, dtype: np.dtype) -> np.ndarray:
|
|
17
|
+
return np.asarray(value, dtype=dtype).reshape(())
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _ulp_intdiff_same_sign(
|
|
21
|
+
f1: np.ndarray, f2: np.ndarray, uint_dtype: np.dtype
|
|
22
|
+
) -> int:
|
|
23
|
+
i1 = f1.view(uint_dtype).item()
|
|
24
|
+
i2 = f2.view(uint_dtype).item()
|
|
25
|
+
return int(i1 - i2) if i1 > i2 else int(i2 - i1)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def ulp_intdiff_float(f1: object, f2: object) -> int:
    """Return the ULP distance between two float-like scalar values.

    Operands with opposite sign bits are measured through zero: the result
    is the distance of each magnitude to +0.0 plus one extra step for the
    sign crossing itself.

    Raises:
        ScalarFunctionError: when the promoted dtype is not a supported
            floating type.
    """
    dtype = np.result_type(f1, f2)
    try:
        uint_dtype = _FLOAT_TO_UINT[dtype]
    except KeyError as exc:
        raise ScalarFunctionError(
            f"unsupported dtype for ULP diff: {dtype}"
        ) from exc

    lhs = _coerce_float_scalar(f1, dtype)
    rhs = _coerce_float_scalar(f2, dtype)

    if np.signbit(lhs) != np.signbit(rhs):
        zero = _coerce_float_scalar(0.0, dtype)
        # Route the distance through zero and add one step for crossing
        # the sign boundary.
        lhs_to_zero = ulp_intdiff_float(zero, np.abs(lhs))
        rhs_to_zero = ulp_intdiff_float(zero, np.abs(rhs))
        return lhs_to_zero + rhs_to_zero + 1

    return _ulp_intdiff_same_sign(lhs, rhs, uint_dtype)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|