emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/compiler.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass, fields
 import hashlib
 from pathlib import Path
-from typing import Mapping
+import time
+from typing import Callable, Mapping, TypeVar
 
 import numpy as np
 import onnx
@@ -29,21 +30,27 @@ from .lowering import load_lowering_registry
 from .lowering.common import ensure_supported_dtype, shape_product, value_dtype
 from .lowering.registry import get_lowering_registry
 from .onnx_import import import_onnx
-from .runtime.evaluator import Evaluator
 
 
 @dataclass(frozen=True)
 class CompilerOptions:
-    template_dir: Path
+    template_dir: Path | None = None
     model_name: str = "model"
     emit_testbench: bool = False
     command_line: str | None = None
     model_checksum: str | None = None
     restrict_arrays: bool = True
+    fp32_accumulation_strategy: str = "fp64"
+    fp16_accumulation_strategy: str = "fp32"
     testbench_inputs: Mapping[str, np.ndarray] | None = None
+    testbench_optional_inputs: Mapping[str, bool] | None = None
    truncate_weights_after: int | None = None
     large_temp_threshold_bytes: int = 1024
-    large_weight_threshold: int =
+    large_weight_threshold: int = 100 * 1024
+    timings: dict[str, float] | None = None
+
+
+_T = TypeVar("_T")
 
 
 def _onnx_elem_type(dtype: np.dtype) -> int:
@@ -53,90 +60,155 @@ def _onnx_elem_type(dtype: np.dtype) -> int:
     raise UnsupportedOpError(f"Unsupported dtype {dtype} for ONNX output")
 
 
+def _optional_flag_name(name: str) -> str:
+    return f"{name}_present"
+
+
 class Compiler:
     def __init__(self, options: CompilerOptions | None = None) -> None:
         if options is None:
-            options = CompilerOptions(
+            options = CompilerOptions()
         self._options = options
         self._emitter = CEmitter(
             options.template_dir,
             restrict_arrays=options.restrict_arrays,
+            fp32_accumulation_strategy=options.fp32_accumulation_strategy,
+            fp16_accumulation_strategy=options.fp16_accumulation_strategy,
             truncate_weights_after=options.truncate_weights_after,
             large_temp_threshold_bytes=options.large_temp_threshold_bytes,
             large_weight_threshold=options.large_weight_threshold,
         )
         load_lowering_registry()
 
+    def _time_step(self, label: str, func: Callable[[], _T]) -> _T:
+        timings = self._options.timings
+        if timings is None:
+            return func()
+        started = time.perf_counter()
+        result = func()
+        timings[label] = time.perf_counter() - started
+        return result
+
     def compile(self, model: onnx.ModelProto) -> str:
-        graph = import_onnx(model)
-        graph = self.
-
-
-
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
         )
-        lowered = self.
-
-
-
-
-
-
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
+        )
+        return self._time_step(
+            "emit_model",
+            lambda: self._emitter.emit_model(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
         )
 
     def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
-        graph = import_onnx(model)
-        graph = self.
-
-
-
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
+        )
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
         )
-
-
-
-
-
-
-
+        return self._time_step(
+            "emit_model_with_data_file",
+            lambda: self._emitter.emit_model_with_data_file(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
         )
 
     def compile_with_weight_data(
         self, model: onnx.ModelProto
     ) -> tuple[str, bytes | None]:
-        graph = import_onnx(model)
-        graph = self.
-
-
-
-
-
-
-
-
-
-
-
-        )
-
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
+        )
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
+        )
+        generated = self._time_step(
+            "emit_model",
+            lambda: self._emitter.emit_model(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
+        )
+        weight_data = self._time_step(
+            "collect_weight_data",
+            lambda: self._emitter.collect_weight_data(lowered.constants),
+        )
         return generated, weight_data
 
     def compile_with_data_file_and_weight_data(
         self, model: onnx.ModelProto
     ) -> tuple[str, str, bytes | None]:
-        graph = import_onnx(model)
-        graph = self.
-
-
-
-
-
-
-
-
-
-
-
-        )
-
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
+        )
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
+        )
+        generated, data_source = self._time_step(
+            "emit_model_with_data_file",
+            lambda: self._emitter.emit_model_with_data_file(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
+        )
+        weight_data = self._time_step(
+            "collect_weight_data",
+            lambda: self._emitter.collect_weight_data(lowered.constants),
+        )
        return generated, data_source, weight_data
 
     @staticmethod
@@ -165,9 +237,11 @@ class Compiler:
         self._validate_graph(graph)
         (
             input_names,
+            input_optional_names,
             input_shapes,
             input_dtypes,
             output_names,
+            output_optional_names,
             output_shapes,
             output_dtypes,
         ) = self._collect_io_specs(graph)
@@ -220,9 +294,11 @@ class Compiler:
         return LoweredModel(
             name=self._options.model_name,
             input_names=input_names,
+            input_optional_names=input_optional_names,
             input_shapes=input_shapes,
             input_dtypes=input_dtypes,
             output_names=output_names,
+            output_optional_names=output_optional_names,
             output_shapes=output_shapes,
             output_dtypes=output_dtypes,
             constants=constants,
@@ -248,7 +324,6 @@ class Compiler:
                 "Testbench inputs include unknown inputs: "
                 + ", ".join(unknown_inputs)
             )
-        resolved: dict[str, tuple[float | int | bool, ...]] = {}
         for name, values in self._options.testbench_inputs.items():
             if not isinstance(values, np.ndarray):
                 raise CodegenError(
@@ -265,9 +340,7 @@ class Compiler:
                     "Testbench input "
                     f"{name} has {array.size} elements, expected {expected_count}"
                 )
-
-            resolved[name] = tuple(array.ravel().tolist())
-        return resolved
+        return None
 
     def _concretize_graph_shapes(
         self, model: onnx.ModelProto, graph: Graph
@@ -337,6 +410,7 @@ class Compiler:
                     dtype=value.type.dtype,
                     shape=shape,
                     dim_params=(None,) * len(shape),
+                    is_optional=value.type.is_optional,
                 ),
             )
 
@@ -361,27 +435,39 @@
         self, graph: Graph
     ) -> tuple[
         tuple[str, ...],
+        tuple[str | None, ...],
         tuple[tuple[int, ...], ...],
         tuple[ScalarType, ...],
         tuple[str, ...],
+        tuple[str | None, ...],
         tuple[tuple[int, ...], ...],
         tuple[ScalarType, ...],
     ]:
         input_names = tuple(value.name for value in graph.inputs)
+        input_optional_names = tuple(
+            _optional_flag_name(value.name) if value.type.is_optional else None
+            for value in graph.inputs
+        )
         input_shapes = tuple(value.type.shape for value in graph.inputs)
         input_dtypes = tuple(
             value_dtype(graph, value.name) for value in graph.inputs
         )
        output_names = tuple(value.name for value in graph.outputs)
+        output_optional_names = tuple(
+            _optional_flag_name(value.name) if value.type.is_optional else None
+            for value in graph.outputs
+        )
         output_shapes = tuple(value.type.shape for value in graph.outputs)
         output_dtypes = tuple(
             value_dtype(graph, value.name) for value in graph.outputs
         )
         return (
             input_names,
+            input_optional_names,
             input_shapes,
             input_dtypes,
             output_names,
+            output_optional_names,
             output_shapes,
             output_dtypes,
         )
@@ -439,26 +525,22 @@ class Compiler:
             initializer_count=len(graph.initializers),
         )
 
-    def run(
-        self, model: onnx.ModelProto, feeds: Mapping[str, np.ndarray]
-    ) -> dict[str, np.ndarray]:
-        graph = import_onnx(model)
-        evaluator = Evaluator(graph)
-        return evaluator.run(feeds)
-
-
 def _lowered_constants(graph: Graph | GraphContext) -> tuple[ConstTensor, ...]:
+    used_initializers = {value.name for value in graph.outputs}
+    for node in graph.nodes:
+        used_initializers.update(node.inputs)
     constants: list[ConstTensor] = []
     for initializer in graph.initializers:
+        if initializer.name not in used_initializers:
+            continue
         dtype = ensure_supported_dtype(initializer.type.dtype)
+        data_array = initializer.data.astype(dtype.np_dtype, copy=False)
+        data_tuple = tuple(data_array.ravel().tolist())
         constants.append(
             ConstTensor(
                 name=initializer.name,
                 shape=initializer.type.shape,
-                data=
-                    dtype.np_dtype.type(value)
-                    for value in initializer.data.ravel()
-                ),
+                data=data_tuple,
                 dtype=dtype,
             )
        )
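The headline change in this file is the `timings` hook: when a caller supplies a dict, each pipeline stage (import, shape concretization, testbench resolution, variable-dim collection, lowering, emission) runs through `_time_step`, which records a `time.perf_counter()` delta under the stage label and otherwise leaves control flow untouched. A minimal usage sketch, assuming only the API visible in this diff (the model filename is illustrative):

import onnx
from emx_onnx_cgen.compiler import Compiler, CompilerOptions

timings: dict[str, float] = {}
compiler = Compiler(CompilerOptions(timings=timings))
source = compiler.compile(onnx.load("model.onnx"))  # hypothetical model file

# timings now maps stage labels such as "import_onnx", "lower_model",
# and "emit_model" to elapsed seconds.
for label, seconds in sorted(timings.items(), key=lambda item: -item[1]):
    print(f"{label}: {seconds:.3f}s")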
emx_onnx_cgen/determinism.py
ADDED
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+import os
+from typing import Iterator
+
+THREAD_ENV_VARS = (
+    "OMP_NUM_THREADS",
+    "OPENBLAS_NUM_THREADS",
+    "MKL_NUM_THREADS",
+    "VECLIB_MAXIMUM_THREADS",
+    "NUMEXPR_NUM_THREADS",
+    "BLIS_NUM_THREADS",
+)
+
+
+@contextmanager
+def deterministic_reference_runtime() -> Iterator[None]:
+    previous = {name: os.environ.get(name) for name in THREAD_ENV_VARS}
+    for name in THREAD_ENV_VARS:
+        os.environ[name] = "1"
+    limits_context = None
+    try:
+        try:
+            from threadpoolctl import threadpool_limits
+        except Exception:
+            threadpool_limits = None
+        if threadpool_limits is not None:
+            limits_context = threadpool_limits(limits=1)
+            limits_context.__enter__()
+        yield
+    finally:
+        if limits_context is not None:
+            limits_context.__exit__(None, None, None)
+        for name, value in previous.items():
+            if value is None:
+                os.environ.pop(name, None)
+            else:
+                os.environ[name] = value
emx_onnx_cgen/ir/context.py
CHANGED
@@ -14,9 +14,28 @@ class GraphContext:
     _shape_cache: dict[str, tuple[int, ...]] = field(default_factory=dict)
     _initializer_cache: dict[str, Initializer] = field(default_factory=dict)
     _producer_cache: dict[str, Node] = field(default_factory=dict)
+    _value_cache: dict[str, Value] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        for value in self.graph.inputs + self.graph.outputs + self.graph.values:
+            self._value_cache[value.name] = value
+        for initializer in self.graph.initializers:
+            if initializer.name not in self._value_cache:
+                self._value_cache[initializer.name] = Value(
+                    name=initializer.name,
+                    type=initializer.type,
+                )
+            self._initializer_cache[initializer.name] = initializer
+        for node in self.graph.nodes:
+            for output in node.outputs:
+                if output and output not in self._producer_cache:
+                    self._producer_cache[output] = node
 
     def find_value(self, name: str) -> Value:
-
+        value = self._value_cache.get(name)
+        if value is None:
+            raise KeyError(name)
+        return value
 
     def dtype(self, name: str, node: Node | None = None) -> ScalarType:
         if name in self._dtype_cache:
@@ -55,23 +74,14 @@ class GraphContext:
     def set_shape(self, name: str, shape: tuple[int, ...]) -> None:
         self._shape_cache[name] = shape
 
+    def has_shape(self, name: str) -> bool:
+        return name in self._shape_cache
+
     def initializer(self, name: str) -> Initializer | None:
-        if name in self._initializer_cache:
-            return self._initializer_cache[name]
-        for initializer in self.graph.initializers:
-            if initializer.name == name:
-                self._initializer_cache[name] = initializer
-                return initializer
-        return None
+        return self._initializer_cache.get(name)
 
     def producer(self, output_name: str) -> Node | None:
-        if output_name in self._producer_cache:
-            return self._producer_cache[output_name]
-        for node in self.graph.nodes:
-            if output_name in node.outputs:
-                self._producer_cache[output_name] = node
-                return node
-        return None
+        return self._producer_cache.get(output_name)
 
     def opset_version(self, domain: str = "") -> int | None:
         if domain in {"", "ai.onnx"}:
emx_onnx_cgen/ir/model.py
CHANGED
emx_onnx_cgen/ir/op_base.py
CHANGED
@@ -414,19 +414,20 @@ class VariadicLikeOpBase(RenderableOpBase):
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_shapes = tuple(ctx.shape(name) for name in self._variadic_inputs())
-
-
-
-
-
-
+        try:
+            output_shape = BroadcastingOpBase.broadcast_shapes(*input_shapes)
+        except ShapeInferenceError as exc:
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} expects broadcastable input shapes"
+            ) from exc
         try:
             expected = ctx.shape(self._variadic_output())
         except ShapeInferenceError:
             expected = None
         if expected is not None and expected != output_shape:
             raise UnsupportedOpError(
-                f"{self._variadic_kind()}
+                f"{self._variadic_kind()} output shape must be {output_shape}, "
+                f"got {expected}"
             )
         ctx.set_shape(self._variadic_output(), output_shape)
 
@@ -469,6 +470,30 @@ class ReduceOpBase(RenderableOpBase):
 
 
 class BroadcastingOpBase(RenderableOpBase):
+    @staticmethod
+    def unidirectional_broadcastable(
+        source: tuple[int, ...],
+        target: tuple[int, ...],
+    ) -> bool:
+        if len(source) > len(target):
+            return False
+        padded = (1,) * (len(target) - len(source)) + source
+        for source_dim, target_dim in zip(padded, target):
+            if source_dim not in {1, target_dim}:
+                return False
+        return True
+
+    @staticmethod
+    def prelu_channel_axis(
+        input_shape: tuple[int, ...],
+        slope_shape: tuple[int, ...],
+    ) -> int | None:
+        if len(input_shape) < 2 or len(slope_shape) != 1:
+            return None
+        if slope_shape[0] != input_shape[1]:
+            return None
+        return 1
+
     @staticmethod
     def broadcast_shapes(
         *shapes: tuple[int, ...],
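`unidirectional_broadcastable` encodes ONNX-style one-way broadcasting: the source shape is left-padded with ones to the target's rank, and each source dimension must then be 1 or equal to the corresponding target dimension; `prelu_channel_axis` applies the same idea to pin a rank-1 slope of length C to axis 1 of an (N, C, ...) input. A self-contained sketch mirroring the diff's logic:

def unidirectional_broadcastable(
    source: tuple[int, ...], target: tuple[int, ...]
) -> bool:
    # Left-pad the source with 1s, then every source dim must be
    # 1 or exactly the target dim (one-way, unlike two-way numpy rules).
    if len(source) > len(target):
        return False
    padded = (1,) * (len(target) - len(source)) + source
    return all(s in (1, t) for s, t in zip(padded, target))

assert unidirectional_broadcastable((3, 1, 5), (2, 3, 4, 5))
assert unidirectional_broadcastable((3, 1, 1), (2, 3, 4, 5))  # PReLU-style slope
assert not unidirectional_broadcastable((3,), (4,))
assert not unidirectional_broadcastable((2, 3), (3,))  # source outranks target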
emx_onnx_cgen/ir/ops/__init__.py
CHANGED
@@ -3,28 +3,35 @@ from .elementwise import (
     ClipOp,
     IdentityOp,
     MultiInputBinaryOp,
+    PowOp,
+    QLinearMulOp,
     UnaryOp,
     VariadicOp,
     WhereOp,
 )
 from .misc import (
+    BernoulliOp,
     CastOp,
     ConcatOp,
     ConstantOfShapeOp,
     CumSumOp,
     DepthToSpaceOp,
+    DequantizeLinearOp,
     ExpandOp,
     EyeLikeOp,
     GatherElementsOp,
     GatherNDOp,
     GatherOp,
     GridSampleOp,
+    HammingWindowOp,
     NonMaxSuppressionOp,
     NonZeroOp,
     OneHotOp,
+    OptionalHasElementOp,
     PadOp,
     QuantizeLinearOp,
     RangeOp,
+    HammingWindowOp,
     ReshapeOp,
     ResizeOp,
     ScatterNDOp,
@@ -34,6 +41,7 @@ from .misc import (
     SpaceToDepthOp,
     SplitOp,
     TensorScatterOp,
+    TfIdfVectorizerOp,
     TileOp,
     TransposeOp,
     TriluOp,
@@ -44,10 +52,12 @@ from .nn import (
     AveragePoolOp,
     BatchNormOp,
     ConvOp,
+    ConvIntegerOp,
     ConvTransposeOp,
     EinsumKind,
     EinsumOp,
     GemmOp,
+    GruOp,
     GroupNormalizationOp,
     HardmaxOp,
     InstanceNormalizationOp,
@@ -75,15 +85,18 @@ __all__ = [
     "AttentionOp",
     "AveragePoolOp",
     "BatchNormOp",
+    "BernoulliOp",
     "BinaryOp",
     "CastOp",
     "ClipOp",
     "ConcatOp",
     "ConstantOfShapeOp",
     "ConvOp",
+    "ConvIntegerOp",
     "ConvTransposeOp",
     "CumSumOp",
     "DepthToSpaceOp",
+    "DequantizeLinearOp",
     "EinsumKind",
     "EinsumOp",
     "ExpandOp",
@@ -93,6 +106,8 @@ __all__ = [
     "GatherOp",
     "GemmOp",
     "GridSampleOp",
+    "GruOp",
+    "HammingWindowOp",
     "GroupNormalizationOp",
     "HardmaxOp",
     "IdentityOp",
@@ -111,10 +126,14 @@ __all__ = [
     "NonMaxSuppressionOp",
     "NonZeroOp",
     "OneHotOp",
+    "OptionalHasElementOp",
     "PadOp",
+    "PowOp",
     "QuantizeLinearOp",
+    "QLinearMulOp",
     "QLinearMatMulOp",
     "RangeOp",
+    "HammingWindowOp",
     "ReduceOp",
     "ReshapeOp",
     "ResizeOp",
@@ -129,6 +148,7 @@ __all__ = [
     "SpaceToDepthOp",
     "SplitOp",
     "TensorScatterOp",
+    "TfIdfVectorizerOp",
     "TileOp",
     "TopKOp",
     "TransposeOp",