emx_onnx_cgen-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/compiler.py
ADDED
@@ -0,0 +1,601 @@
from __future__ import annotations

from dataclasses import dataclass
import hashlib
from pathlib import Path
from typing import Mapping

import numpy as np
import onnx

from shared.scalar_types import ScalarType

from .codegen.c_emitter import (
    AttentionOp,
    AveragePoolOp,
    BatchNormOp,
    LpNormalizationOp,
    InstanceNormalizationOp,
    GroupNormalizationOp,
    LayerNormalizationOp,
    MeanVarianceNormalizationOp,
    RMSNormalizationOp,
    BinaryOp,
    MultiInputBinaryOp,
    CastOp,
    ClipOp,
    CEmitter,
    ConstTensor,
    ConvOp,
    ConcatOp,
    ConstantOfShapeOp,
    CumSumOp,
    GemmOp,
    GatherOp,
    GatherElementsOp,
    ExpandOp,
    RangeOp,
    LrnOp,
    LstmOp,
    LogSoftmaxOp,
    NegativeLogLikelihoodLossOp,
    NodeInfo,
    PadOp,
    SplitOp,
    SoftmaxCrossEntropyLossOp,
    LoweredModel,
    ModelHeader,
    MatMulOp,
    MaxPoolOp,
    ReduceOp,
    ArgReduceOp,
    ReshapeOp,
    ResizeOp,
    GridSampleOp,
    SoftmaxOp,
    ShapeOp,
    SliceOp,
    TransposeOp,
    UnaryOp,
    WhereOp,
)
from .dtypes import dtype_info
from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
from .ir.model import Graph, Value
from .lowering.attention import AttentionSpec, resolve_attention_spec
from .lowering.average_pool import (
    lower_average_pool,
    lower_global_average_pool,
)
from .lowering.batch_normalization import lower_batch_normalization
from .lowering.cast import lower_cast
from .lowering.concat import lower_concat
from .lowering.common import (
    ensure_supported_dtype,
    node_dtype,
    shape_product,
    value_dtype,
    value_shape,
)
from .lowering.conv import ConvSpec, resolve_conv_spec
from .lowering.constant_of_shape import lower_constant_of_shape
from .lowering.dropout import lower_dropout
from .lowering import cumsum as _cumsum  # noqa: F401
from .lowering.flatten import lower_flatten
from .lowering.gather import lower_gather
from .lowering.gather_elements import lower_gather_elements
from .lowering.gemm import resolve_gemm_spec, validate_gemm_bias_shape
from .lowering.lrn import LrnSpec, resolve_lrn_spec
from .lowering.logsoftmax import lower_logsoftmax
from .lowering import group_normalization as _group_normalization  # noqa: F401
from .lowering import instance_normalization as _instance_normalization  # noqa: F401
from .lowering import layer_normalization as _layer_normalization  # noqa: F401
from .lowering import lp_normalization as _lp_normalization  # noqa: F401
from .lowering import mean_variance_normalization as _mean_variance_normalization  # noqa: F401
from .lowering.negative_log_likelihood_loss import (
    lower_negative_log_likelihood_loss,
)
from .lowering.expand import lower_expand
from .lowering.range import lower_range
from .lowering.split import lower_split
from .lowering.softmax_cross_entropy_loss import (
    lower_softmax_cross_entropy_loss,
)
from .lowering.matmul import lower_matmul
from .lowering.maxpool import MaxPoolSpec, resolve_maxpool_spec
from .lowering import pad as _pad  # noqa: F401
from .lowering.reduce import (
    REDUCE_KIND_BY_OP,
    REDUCE_OUTPUTS_FLOAT_ONLY,
)
from .lowering import arg_reduce as _arg_reduce  # noqa: F401
from .lowering.reshape import lower_reshape
from .lowering.resize import lower_resize
from .lowering.grid_sample import lower_grid_sample
from .lowering.slice import lower_slice
from .lowering.squeeze import lower_squeeze
from .lowering import depth_space as _depth_space  # noqa: F401
from .lowering import eye_like as _eye_like  # noqa: F401
from .lowering import identity as _identity  # noqa: F401
from .lowering import tile as _tile  # noqa: F401
from .lowering.shape import lower_shape
from .lowering.size import lower_size
from .lowering.softmax import lower_softmax
from .lowering.transpose import lower_transpose
from .lowering.unsqueeze import lower_unsqueeze
from .lowering.where import lower_where
from .lowering.elementwise import (
    lower_celu,
    lower_clip,
    lower_isinf,
    lower_isnan,
    lower_shrink,
    lower_swish,
)
from .lowering import variadic as _variadic  # noqa: F401
from .lowering import rms_normalization as _rms_normalization  # noqa: F401
from .lowering.registry import get_lowering_registry, resolve_dispatch
from .onnx_import import import_onnx
from .ops import (
    BINARY_OP_TYPES,
    COMPARE_FUNCTIONS,
    UNARY_OP_TYPES,
    binary_op_symbol,
    unary_op_symbol,
    validate_unary_attrs,
)
from shared.scalar_functions import ScalarFunction, ScalarFunctionError
from .runtime.evaluator import Evaluator


@dataclass(frozen=True)
class CompilerOptions:
    template_dir: Path
    model_name: str = "model"
    emit_testbench: bool = False
    command_line: str | None = None
    model_checksum: str | None = None
    restrict_arrays: bool = True
    testbench_inputs: Mapping[str, np.ndarray] | None = None


class Compiler:
    def __init__(self, options: CompilerOptions | None = None) -> None:
        if options is None:
            options = CompilerOptions(template_dir=Path("templates"))
        self._options = options
        self._emitter = CEmitter(
            options.template_dir, restrict_arrays=options.restrict_arrays
        )

    def compile(self, model: onnx.ModelProto) -> str:
        graph = import_onnx(model)
        testbench_inputs = self._resolve_testbench_inputs(graph)
        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
            graph
        )
        lowered = self._lower_model(model, graph)
        return self._emitter.emit_model(
            lowered,
            emit_testbench=self._options.emit_testbench,
            testbench_inputs=testbench_inputs,
            variable_dim_inputs=variable_dim_inputs,
            variable_dim_outputs=variable_dim_outputs,
        )

    def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
        graph = import_onnx(model)
        testbench_inputs = self._resolve_testbench_inputs(graph)
        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
            graph
        )
        lowered = self._lower_model(model, graph)
        return self._emitter.emit_model_with_data_file(
            lowered,
            emit_testbench=self._options.emit_testbench,
            testbench_inputs=testbench_inputs,
            variable_dim_inputs=variable_dim_inputs,
            variable_dim_outputs=variable_dim_outputs,
        )

    @staticmethod
    def _collect_variable_dims(
        graph: Graph,
    ) -> tuple[dict[int, dict[int, str]], dict[int, dict[int, str]]]:
        def collect(values: tuple[Value, ...]) -> dict[int, dict[int, str]]:
            dim_map: dict[int, dict[int, str]] = {}
            for index, value in enumerate(values):
                dims = {
                    dim_index: dim_param
                    for dim_index, dim_param in enumerate(
                        value.type.dim_params
                    )
                    if dim_param
                }
                if dims:
                    dim_map[index] = dims
            return dim_map

        return collect(graph.inputs), collect(graph.outputs)

    def _lower_model(self, model: onnx.ModelProto, graph: Graph) -> LoweredModel:
        constants = _lowered_constants(graph)
        self._validate_graph(graph)
        (
            input_names,
            input_shapes,
            input_dtypes,
            output_names,
            output_shapes,
            output_dtypes,
        ) = self._collect_io_specs(graph)
        ops, node_infos = self._lower_nodes(graph)
        header = self._build_header(model, graph)
        return LoweredModel(
            name=self._options.model_name,
            input_names=input_names,
            input_shapes=input_shapes,
            input_dtypes=input_dtypes,
            output_names=output_names,
            output_shapes=output_shapes,
            output_dtypes=output_dtypes,
            constants=constants,
            ops=tuple(ops),
            node_infos=tuple(node_infos),
            header=header,
        )

    def _resolve_testbench_inputs(
        self, graph: Graph
    ) -> Mapping[str, tuple[float | int | bool, ...]] | None:
        if not self._options.testbench_inputs:
            return None
        input_specs = {value.name: value for value in graph.inputs}
        unknown_inputs = sorted(
            name
            for name in self._options.testbench_inputs
            if name not in input_specs
        )
        if unknown_inputs:
            raise CodegenError(
                "Testbench inputs include unknown inputs: "
                + ", ".join(unknown_inputs)
            )
        resolved: dict[str, tuple[float | int | bool, ...]] = {}
        for name, values in self._options.testbench_inputs.items():
            if not isinstance(values, np.ndarray):
                raise CodegenError(
                    f"Testbench input {name} must be a numpy array"
                )
            input_value = input_specs[name]
            dtype = value_dtype(graph, name)
            info = dtype_info(dtype)
            expected_shape = input_value.type.shape
            expected_count = shape_product(expected_shape)
            array = values.astype(info.np_dtype, copy=False)
            if array.size != expected_count:
                raise CodegenError(
                    "Testbench input "
                    f"{name} has {array.size} elements, expected {expected_count}"
                )
            array = array.reshape(expected_shape)
            resolved[name] = tuple(array.ravel().tolist())
        return resolved

    def _validate_graph(self, graph: Graph) -> None:
        if not graph.outputs:
            raise UnsupportedOpError("Graph must have at least one output")
        if not graph.nodes:
            raise UnsupportedOpError("Graph must contain at least one node")
        for value in graph.outputs:
            element_count = shape_product(value.type.shape)
            if element_count <= 0:
                raise ShapeInferenceError("Output shape must be fully defined")

    def _collect_io_specs(
        self, graph: Graph
    ) -> tuple[
        tuple[str, ...],
        tuple[tuple[int, ...], ...],
        tuple[ScalarType, ...],
        tuple[str, ...],
        tuple[tuple[int, ...], ...],
        tuple[ScalarType, ...],
    ]:
        input_names = tuple(value.name for value in graph.inputs)
        input_shapes = tuple(value.type.shape for value in graph.inputs)
        input_dtypes = tuple(
            value_dtype(graph, value.name) for value in graph.inputs
        )
        output_names = tuple(value.name for value in graph.outputs)
        output_shapes = tuple(value.type.shape for value in graph.outputs)
        output_dtypes = tuple(
            value_dtype(graph, value.name) for value in graph.outputs
        )
        return (
            input_names,
            input_shapes,
            input_dtypes,
            output_names,
            output_shapes,
            output_dtypes,
        )

    def _lower_nodes(
        self, graph: Graph
    ) -> tuple[
        list[
            BinaryOp
            | MultiInputBinaryOp
            | UnaryOp
            | ClipOp
            | CastOp
            | MatMulOp
            | GemmOp
            | AttentionOp
            | ConvOp
            | AveragePoolOp
            | BatchNormOp
            | LpNormalizationOp
            | InstanceNormalizationOp
            | GroupNormalizationOp
            | LayerNormalizationOp
            | MeanVarianceNormalizationOp
            | RMSNormalizationOp
            | LrnOp
            | LstmOp
            | SoftmaxOp
            | LogSoftmaxOp
            | NegativeLogLikelihoodLossOp
            | SoftmaxCrossEntropyLossOp
            | MaxPoolOp
            | ConcatOp
            | GatherElementsOp
            | GatherOp
            | TransposeOp
            | ConstantOfShapeOp
            | ReshapeOp
            | SliceOp
            | ResizeOp
            | GridSampleOp
            | ReduceOp
            | ArgReduceOp
            | ShapeOp
            | PadOp
            | ExpandOp
            | CumSumOp
            | RangeOp
            | SplitOp
        ],
        list[NodeInfo],
    ]:
        ops: list[
            BinaryOp
            | MultiInputBinaryOp
            | UnaryOp
            | ClipOp
            | CastOp
            | MatMulOp
            | GemmOp
            | AttentionOp
            | ConvOp
            | AveragePoolOp
            | BatchNormOp
            | LpNormalizationOp
            | InstanceNormalizationOp
            | GroupNormalizationOp
            | LayerNormalizationOp
            | MeanVarianceNormalizationOp
            | RMSNormalizationOp
            | LrnOp
            | LstmOp
            | SoftmaxOp
            | LogSoftmaxOp
            | NegativeLogLikelihoodLossOp
            | SoftmaxCrossEntropyLossOp
            | MaxPoolOp
            | ConcatOp
            | GatherElementsOp
            | GatherOp
            | TransposeOp
            | ConstantOfShapeOp
            | ReshapeOp
            | SliceOp
            | ResizeOp
            | ReduceOp
            | ArgReduceOp
            | ShapeOp
            | PadOp
            | ExpandOp
            | CumSumOp
            | RangeOp
            | SplitOp
            | WhereOp
        ] = []
        node_infos: list[NodeInfo] = []
        for node in graph.nodes:
            lowering = resolve_dispatch(
                node.op_type,
                get_lowering_registry(),
                binary_types=BINARY_OP_TYPES,
                unary_types=UNARY_OP_TYPES,
                binary_fallback=lambda: _lower_binary_unary,
                unary_fallback=lambda: _lower_binary_unary,
            )
            ops.append(lowering(graph, node))
            node_infos.append(
                NodeInfo(
                    op_type=node.op_type,
                    name=node.name,
                    inputs=tuple(node.inputs),
                    outputs=tuple(node.outputs),
                    attrs=dict(node.attrs),
                )
            )
        return ops, node_infos

    def _build_header(self, model: onnx.ModelProto, graph: Graph) -> ModelHeader:
        metadata_props = tuple(
            (prop.key, prop.value) for prop in model.metadata_props
        )
        opset_imports = tuple(
            (opset.domain, opset.version) for opset in model.opset_import
        )
        checksum = self._options.model_checksum
        if checksum is None:
            checksum = hashlib.sha256(model.SerializeToString()).hexdigest()
        return ModelHeader(
            generator="Generated by emmtrix ONNX-to-C Code Generator (emx-onnx-cgen)",
            command_line=self._options.command_line,
            model_checksum=checksum,
            model_name=self._options.model_name,
            graph_name=model.graph.name or None,
            description=model.doc_string or None,
            graph_description=model.graph.doc_string or None,
            producer_name=model.producer_name or None,
            producer_version=model.producer_version or None,
            domain=model.domain or None,
            model_version=model.model_version or None,
            ir_version=model.ir_version or None,
            opset_imports=opset_imports,
            metadata_props=metadata_props,
            input_count=len(graph.inputs),
            output_count=len(graph.outputs),
            node_count=len(graph.nodes),
            initializer_count=len(graph.initializers),
        )

    def run(
        self, model: onnx.ModelProto, feeds: Mapping[str, np.ndarray]
    ) -> dict[str, np.ndarray]:
        graph = import_onnx(model)
        evaluator = Evaluator(graph)
        return evaluator.run(feeds)


def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
    constants: list[ConstTensor] = []
    for initializer in graph.initializers:
        dtype = ensure_supported_dtype(initializer.type.dtype)
        constants.append(
            ConstTensor(
                name=initializer.name,
                shape=initializer.type.shape,
                data=tuple(
                    dtype.np_dtype.type(value)
                    for value in initializer.data.ravel()
                ),
                dtype=dtype,
            )
        )
    return tuple(constants)


def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
    if node.op_type == "BitShift":
        if len(node.inputs) != 2 or len(node.outputs) != 1:
            raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
        direction_attr = node.attrs.get("direction", "LEFT")
        if isinstance(direction_attr, bytes):
            direction = direction_attr.decode()
        else:
            direction = str(direction_attr)
        if direction not in {"LEFT", "RIGHT"}:
            raise UnsupportedOpError(
                "BitShift direction must be LEFT or RIGHT"
            )
        op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
        if not op_dtype.is_integer:
            raise UnsupportedOpError("BitShift expects integer inputs")
        function = (
            ScalarFunction.BITWISE_LEFT_SHIFT
            if direction == "LEFT"
            else ScalarFunction.BITWISE_RIGHT_SHIFT
        )
        op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
        if op_spec is None:
            raise UnsupportedOpError("Unsupported op BitShift")
        output_shape = value_shape(graph, node.outputs[0], node)
        return BinaryOp(
            input0=node.inputs[0],
            input1=node.inputs[1],
            output=node.outputs[0],
            function=function,
            operator_kind=op_spec.kind,
            shape=output_shape,
            dtype=op_dtype,
            input_dtype=op_dtype,
        )
    if node.op_type == "Mod":
        fmod = int(node.attrs.get("fmod", 0))
        if fmod not in {0, 1}:
            raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
        function = (
            ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
        )
    else:
        try:
            function = ScalarFunction.from_onnx_op(node.op_type)
        except ScalarFunctionError as exc:
            raise UnsupportedOpError(
                f"Unsupported op {node.op_type}"
            ) from exc
        validate_unary_attrs(node.op_type, node.attrs)
    if function in COMPARE_FUNCTIONS:
        input_dtype = node_dtype(graph, node, *node.inputs)
        output_dtype = value_dtype(graph, node.outputs[0], node)
        op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
        if op_spec is None:
            raise UnsupportedOpError(f"Unsupported op {node.op_type}")
        if len(node.inputs) != 2 or len(node.outputs) != 1:
            raise UnsupportedOpError(
                f"{node.op_type} must have 2 inputs and 1 output"
            )
        if output_dtype != ScalarType.BOOL:
            raise UnsupportedOpError(
                f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
            )
        output_shape = value_shape(graph, node.outputs[0], node)
        return BinaryOp(
            input0=node.inputs[0],
            input1=node.inputs[1],
            output=node.outputs[0],
            function=function,
            operator_kind=op_spec.kind,
            shape=output_shape,
            dtype=output_dtype,
            input_dtype=input_dtype,
        )
    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
    op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
    unary_symbol = unary_op_symbol(function, dtype=op_dtype)
    if op_spec is None and unary_symbol is None:
        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
    if op_spec is not None:
        if len(node.inputs) != 2 or len(node.outputs) != 1:
            raise UnsupportedOpError(
                f"{node.op_type} must have 2 inputs and 1 output"
            )
        output_shape = value_shape(graph, node.outputs[0], node)
        return BinaryOp(
            input0=node.inputs[0],
            input1=node.inputs[1],
            output=node.outputs[0],
            function=function,
            operator_kind=op_spec.kind,
            shape=output_shape,
            dtype=op_dtype,
            input_dtype=op_dtype,
        )
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError(f"{node.op_type} must have 1 input and 1 output")
    output_shape = value_shape(graph, node.outputs[0], node)
    return UnaryOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        function=function,
        shape=output_shape,
        dtype=op_dtype,
        input_dtype=op_dtype,
        params=(),
    )
emx_onnx_cgen/dtypes.py
ADDED
@@ -0,0 +1,40 @@
from __future__ import annotations

import onnx

from shared.scalar_types import ScalarFunctionError, ScalarType

from .errors import UnsupportedOpError

ONNX_TO_SCALAR_TYPE: dict[int, ScalarType] = {
    onnx.TensorProto.FLOAT16: ScalarType.F16,
    onnx.TensorProto.FLOAT: ScalarType.F32,
    onnx.TensorProto.DOUBLE: ScalarType.F64,
    onnx.TensorProto.BOOL: ScalarType.BOOL,
    onnx.TensorProto.UINT8: ScalarType.U8,
    onnx.TensorProto.UINT16: ScalarType.U16,
    onnx.TensorProto.UINT32: ScalarType.U32,
    onnx.TensorProto.UINT64: ScalarType.U64,
    onnx.TensorProto.INT8: ScalarType.I8,
    onnx.TensorProto.INT16: ScalarType.I16,
    onnx.TensorProto.INT32: ScalarType.I32,
    onnx.TensorProto.INT64: ScalarType.I64,
}


def scalar_type_from_onnx(elem_type: int) -> ScalarType | None:
    return ONNX_TO_SCALAR_TYPE.get(elem_type)


def dtype_info(dtype: ScalarType | int | str) -> ScalarType:
    if isinstance(dtype, ScalarType):
        return dtype
    if isinstance(dtype, int):
        scalar = scalar_type_from_onnx(dtype)
        if scalar is None:
            raise UnsupportedOpError(f"Unsupported ONNX dtype enum: {dtype}")
        return scalar
    try:
        return ScalarType.from_onnx_name(dtype)
    except ScalarFunctionError:
        raise UnsupportedOpError(f"Unsupported dtype: {dtype}") from None
emx_onnx_cgen/errors.py
ADDED
@@ -0,0 +1,14 @@
class CompilerError(RuntimeError):
    """Base error for compiler failures."""


class UnsupportedOpError(CompilerError):
    """Raised when an ONNX operator is not supported."""


class ShapeInferenceError(CompilerError):
    """Raised when tensor shapes cannot be resolved."""


class CodegenError(CompilerError):
    """Raised when C code generation fails."""
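Since the three specific errors all derive from CompilerError (itself a RuntimeError), a caller can handle any compilation failure with a single except clause; the message below is a hypothetical example, not output from the tool.

    # Hedged sketch: catching the common base class from errors.py.
    from emx_onnx_cgen.errors import CompilerError, UnsupportedOpError

    try:
        raise UnsupportedOpError("example: operator not supported")  # placeholder failure
    except CompilerError as exc:
        print(f"compilation failed: {exc}")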