emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +340 -59
- emx_onnx_cgen/codegen/c_emitter.py +2369 -111
- emx_onnx_cgen/compiler.py +188 -5
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/lowering/common.py +379 -2
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/gather_elements.py +1 -3
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +6 -5
- emx_onnx_cgen/lowering/logsoftmax.py +5 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/matmul.py +6 -7
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/reduce.py +5 -6
- emx_onnx_cgen/lowering/reshape.py +223 -51
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/softmax.py +5 -1
- emx_onnx_cgen/lowering/squeeze.py +5 -5
- emx_onnx_cgen/lowering/topk.py +116 -0
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +5 -5
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +460 -42
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +61 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
- shared/scalar_functions.py +49 -17
- shared/ulp.py +48 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/compiler.py
CHANGED
@@ -10,6 +10,7 @@ import onnx
 
 from shared.scalar_types import ScalarType
 
+from .onnxruntime_utils import make_deterministic_session_options
 from .codegen.c_emitter import (
     AttentionOp,
     AveragePoolOp,
@@ -27,18 +28,26 @@ from .codegen.c_emitter import (
     CEmitter,
     ConstTensor,
     ConvOp,
+    ConvTransposeOp,
     ConcatOp,
     ConstantOfShapeOp,
     CumSumOp,
     GemmOp,
     GatherOp,
     GatherElementsOp,
+    GatherNDOp,
+    ScatterNDOp,
     ExpandOp,
     RangeOp,
+    OneHotOp,
+    LpPoolOp,
+    QuantizeLinearOp,
     LrnOp,
     LstmOp,
     LogSoftmaxOp,
+    HardmaxOp,
     NegativeLogLikelihoodLossOp,
+    NonZeroOp,
     NodeInfo,
     PadOp,
     SplitOp,
@@ -52,6 +61,7 @@ from .codegen.c_emitter import (
     ReshapeOp,
     ResizeOp,
     GridSampleOp,
+    HardmaxOp,
     SoftmaxOp,
     ShapeOp,
     SliceOp,
@@ -61,12 +71,13 @@ from .codegen.c_emitter import (
 )
 from .dtypes import dtype_info
 from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
-from .ir.model import Graph, Value
+from .ir.model import Graph, TensorType, Value
 from .lowering.attention import AttentionSpec, resolve_attention_spec
 from .lowering.average_pool import (
     lower_average_pool,
     lower_global_average_pool,
 )
+from .lowering import global_max_pool as _global_max_pool  # noqa: F401
 from .lowering.batch_normalization import lower_batch_normalization
 from .lowering.cast import lower_cast
 from .lowering.concat import lower_concat
@@ -78,25 +89,33 @@ from .lowering.common import (
     value_shape,
 )
 from .lowering.conv import ConvSpec, resolve_conv_spec
+from .lowering import conv_transpose as _conv_transpose  # noqa: F401
 from .lowering.constant_of_shape import lower_constant_of_shape
 from .lowering.dropout import lower_dropout
 from .lowering import cumsum as _cumsum  # noqa: F401
+from .lowering import einsum as _einsum  # noqa: F401
 from .lowering.flatten import lower_flatten
 from .lowering.gather import lower_gather
 from .lowering.gather_elements import lower_gather_elements
+from .lowering.gather_nd import lower_gather_nd
+from .lowering import scatter_nd as _scatter_nd  # noqa: F401
 from .lowering.gemm import resolve_gemm_spec, validate_gemm_bias_shape
 from .lowering.lrn import LrnSpec, resolve_lrn_spec
 from .lowering.logsoftmax import lower_logsoftmax
+from .lowering import hardmax as _hardmax  # noqa: F401
 from .lowering import group_normalization as _group_normalization  # noqa: F401
 from .lowering import instance_normalization as _instance_normalization  # noqa: F401
 from .lowering import layer_normalization as _layer_normalization  # noqa: F401
 from .lowering import lp_normalization as _lp_normalization  # noqa: F401
+from .lowering import lp_pool as _lp_pool  # noqa: F401
 from .lowering import mean_variance_normalization as _mean_variance_normalization  # noqa: F401
 from .lowering.negative_log_likelihood_loss import (
     lower_negative_log_likelihood_loss,
 )
+from .lowering import nonzero as _nonzero  # noqa: F401
 from .lowering.expand import lower_expand
 from .lowering.range import lower_range
+from .lowering import one_hot as _one_hot  # noqa: F401
 from .lowering.split import lower_split
 from .lowering.softmax_cross_entropy_loss import (
     lower_softmax_cross_entropy_loss,
@@ -109,15 +128,18 @@ from .lowering.reduce import (
     REDUCE_OUTPUTS_FLOAT_ONLY,
 )
 from .lowering import arg_reduce as _arg_reduce  # noqa: F401
+from .lowering import topk as _topk  # noqa: F401
 from .lowering.reshape import lower_reshape
 from .lowering.resize import lower_resize
 from .lowering.grid_sample import lower_grid_sample
+from .lowering import quantize_linear as _quantize_linear  # noqa: F401
 from .lowering.slice import lower_slice
 from .lowering.squeeze import lower_squeeze
 from .lowering import depth_space as _depth_space  # noqa: F401
 from .lowering import eye_like as _eye_like  # noqa: F401
 from .lowering import identity as _identity  # noqa: F401
 from .lowering import tile as _tile  # noqa: F401
+from .lowering import trilu as _trilu  # noqa: F401
 from .lowering.shape import lower_shape
 from .lowering.size import lower_size
 from .lowering.softmax import lower_softmax
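A note on the many `# noqa: F401` imports added above: they suggest a register-on-import pattern, where each lowering module hooks itself into an op registry as a side effect of being imported, so the compiler only needs the bare import to enable the op. A minimal sketch of that pattern, with hypothetical names (`_LOWERING_REGISTRY`, `register_lowering`) that are not taken from the package itself:

```python
from typing import Callable, Dict

# Hypothetical registry; emx_onnx_cgen's actual mechanism is not shown in this diff.
_LOWERING_REGISTRY: Dict[str, Callable] = {}


def register_lowering(op_type: str):
    """Record a lowering function under an ONNX op type when its module is imported."""
    def decorator(func: Callable) -> Callable:
        _LOWERING_REGISTRY[op_type] = func
        return func
    return decorator


# In a lowering module such as trilu.py, the handler registers itself at import time,
# which is why `from .lowering import trilu as _trilu  # noqa: F401` is enough.
@register_lowering("Trilu")
def lower_trilu(graph, node):
    ...
```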
@@ -157,6 +179,16 @@ class CompilerOptions:
     model_checksum: str | None = None
     restrict_arrays: bool = True
     testbench_inputs: Mapping[str, np.ndarray] | None = None
+    truncate_weights_after: int | None = None
+    large_temp_threshold_bytes: int = 1024
+    large_weight_threshold: int = 1024
+
+
+def _onnx_elem_type(dtype: np.dtype) -> int:
+    for elem_type, info in onnx._mapping.TENSOR_TYPE_MAP.items():
+        if info.np_dtype == dtype:
+            return elem_type
+    raise UnsupportedOpError(f"Unsupported dtype {dtype} for ONNX output")
 
 
 class Compiler:
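The new `_onnx_elem_type` helper reverse-maps a NumPy dtype to an ONNX `TensorProto` element type by scanning the private `onnx._mapping.TENSOR_TYPE_MAP`. For illustration only, a roughly equivalent lookup through the public helper (assuming an onnx release that ships `np_dtype_to_tensor_dtype`):

```python
import numpy as np
import onnx

# Same mapping via the public API; _onnx_elem_type uses onnx._mapping instead.
elem_type = onnx.helper.np_dtype_to_tensor_dtype(np.dtype("float32"))
assert elem_type == onnx.TensorProto.FLOAT  # 1
```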
@@ -165,11 +197,16 @@ class Compiler:
             options = CompilerOptions(template_dir=Path("templates"))
         self._options = options
         self._emitter = CEmitter(
-            options.template_dir,
+            options.template_dir,
+            restrict_arrays=options.restrict_arrays,
+            truncate_weights_after=options.truncate_weights_after,
+            large_temp_threshold_bytes=options.large_temp_threshold_bytes,
+            large_weight_threshold=options.large_weight_threshold,
         )
 
     def compile(self, model: onnx.ModelProto) -> str:
         graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
         testbench_inputs = self._resolve_testbench_inputs(graph)
         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
             graph
@@ -185,6 +222,7 @@ class Compiler:
 
     def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
         graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
         testbench_inputs = self._resolve_testbench_inputs(graph)
         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
             graph
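For context, a hedged sketch of how a caller might exercise the new `CompilerOptions` fields. It assumes `Compiler` accepts a `CompilerOptions` instance, as the constructor excerpt above suggests; the threshold semantics in the comments are assumptions, not documented behaviour:

```python
from pathlib import Path

import onnx

from emx_onnx_cgen.compiler import Compiler, CompilerOptions

model = onnx.load("model.onnx")  # placeholder path
options = CompilerOptions(
    template_dir=Path("templates"),
    truncate_weights_after=4096,       # assumed: cap on emitted weight elements per constant
    large_temp_threshold_bytes=1024,   # assumed: size above which temporaries get special handling
    large_weight_threshold=1024,       # assumed: size above which constants get special handling
)
compiler = Compiler(options)
c_source = compiler.compile(model)
```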
@@ -198,6 +236,46 @@ class Compiler:
             variable_dim_outputs=variable_dim_outputs,
         )
 
+    def compile_with_weight_data(
+        self, model: onnx.ModelProto
+    ) -> tuple[str, bytes | None]:
+        graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
+        testbench_inputs = self._resolve_testbench_inputs(graph)
+        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+            graph
+        )
+        lowered = self._lower_model(model, graph)
+        generated = self._emitter.emit_model(
+            lowered,
+            emit_testbench=self._options.emit_testbench,
+            testbench_inputs=testbench_inputs,
+            variable_dim_inputs=variable_dim_inputs,
+            variable_dim_outputs=variable_dim_outputs,
+        )
+        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        return generated, weight_data
+
+    def compile_with_data_file_and_weight_data(
+        self, model: onnx.ModelProto
+    ) -> tuple[str, str, bytes | None]:
+        graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
+        testbench_inputs = self._resolve_testbench_inputs(graph)
+        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+            graph
+        )
+        lowered = self._lower_model(model, graph)
+        generated, data_source = self._emitter.emit_model_with_data_file(
+            lowered,
+            emit_testbench=self._options.emit_testbench,
+            testbench_inputs=testbench_inputs,
+            variable_dim_inputs=variable_dim_inputs,
+            variable_dim_outputs=variable_dim_outputs,
+        )
+        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        return generated, data_source, weight_data
+
     @staticmethod
     def _collect_variable_dims(
         graph: Graph,
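The new `compile_with_weight_data` entry point returns the generated C source together with an optional raw weight blob collected from the lowered constants. A hedged usage sketch; output file names are illustrative, and default construction of `Compiler()` is assumed to work as before:

```python
import onnx

from emx_onnx_cgen.compiler import Compiler

model = onnx.load("model.onnx")  # placeholder path
compiler = Compiler()
c_source, weight_data = compiler.compile_with_weight_data(model)

with open("model.c", "w", encoding="utf-8") as source_file:
    source_file.write(c_source)
if weight_data is not None:
    with open("model_weights.bin", "wb") as blob_file:
        blob_file.write(weight_data)
```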
@@ -282,15 +360,93 @@ class Compiler:
             resolved[name] = tuple(array.ravel().tolist())
         return resolved
 
+    def _concretize_graph_shapes(
+        self, model: onnx.ModelProto, graph: Graph
+    ) -> Graph:
+        if not self._options.testbench_inputs:
+            return graph
+        if not any(value.type.dim_params for value in graph.values):
+            if not any(value.type.dim_params for value in graph.inputs):
+                if not any(value.type.dim_params for value in graph.outputs):
+                    return graph
+        try:
+            import onnxruntime as ort
+        except Exception:
+            return graph
+        try:
+            model_with_outputs = onnx.ModelProto()
+            model_with_outputs.CopyFrom(model)
+            existing_outputs = {
+                output.name for output in model_with_outputs.graph.output
+            }
+            value_info_by_name = {
+                value_info.name: value_info
+                for value_info in model_with_outputs.graph.value_info
+            }
+            for value in graph.values:
+                if value.name in existing_outputs:
+                    continue
+                value_info = value_info_by_name.get(value.name)
+                if value_info is None:
+                    dims: list[int | str | None] = []
+                    for index, dim in enumerate(value.type.shape):
+                        dim_param = None
+                        if index < len(value.type.dim_params):
+                            dim_param = value.type.dim_params[index]
+                        dims.append(dim_param if dim_param else None)
+                    elem_type = _onnx_elem_type(value.type.dtype.np_dtype)
+                    value_info = onnx.helper.make_tensor_value_info(
+                        value.name, elem_type, dims
+                    )
+                model_with_outputs.graph.output.append(value_info)
+                existing_outputs.add(value.name)
+            output_names = [output.name for output in model_with_outputs.graph.output]
+            sess_options = make_deterministic_session_options(ort)
+            sess = ort.InferenceSession(
+                model_with_outputs.SerializeToString(),
+                sess_options=sess_options,
+                providers=["CPUExecutionProvider"],
+            )
+            output_arrays = sess.run(None, self._options.testbench_inputs)
+        except Exception:
+            return graph
+
+        shapes_by_name: dict[str, tuple[int, ...]] = {
+            name: tuple(int(dim) for dim in array.shape)
+            for name, array in zip(output_names, output_arrays)
+        }
+        for name, array in self._options.testbench_inputs.items():
+            shapes_by_name[name] = tuple(int(dim) for dim in array.shape)
+
+        def concretize_value(value: Value) -> Value:
+            shape = shapes_by_name.get(value.name)
+            if shape is None:
+                return value
+            return Value(
+                name=value.name,
+                type=TensorType(
+                    dtype=value.type.dtype,
+                    shape=shape,
+                    dim_params=(None,) * len(shape),
+                ),
+            )
+
+        return Graph(
+            inputs=tuple(concretize_value(value) for value in graph.inputs),
+            outputs=tuple(concretize_value(value) for value in graph.outputs),
+            nodes=graph.nodes,
+            initializers=graph.initializers,
+            values=tuple(concretize_value(value) for value in graph.values),
+            opset_imports=graph.opset_imports,
+        )
+
     def _validate_graph(self, graph: Graph) -> None:
         if not graph.outputs:
             raise UnsupportedOpError("Graph must have at least one output")
         if not graph.nodes:
             raise UnsupportedOpError("Graph must contain at least one node")
         for value in graph.outputs:
-
-            if element_count <= 0:
-                raise ShapeInferenceError("Output shape must be fully defined")
+            shape_product(value.type.shape)
 
     def _collect_io_specs(
         self, graph: Graph
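`_concretize_graph_shapes` pins down symbolic dimensions by promoting every intermediate value to a graph output, running the model once through onnxruntime on the testbench inputs, and replacing the symbolic shapes with the observed ones (falling back to the original graph whenever onnxruntime or the run is unavailable). A standalone illustration of the same trick, with placeholder model, value, and input names:

```python
import numpy as np
import onnx
import onnxruntime as ort

model = onnx.load("model.onnx")  # placeholder
# Expose an internal value ("hidden") as an extra output; its shape may be left unset.
model.graph.output.append(
    onnx.helper.make_tensor_value_info("hidden", onnx.TensorProto.FLOAT, None)
)
sess = ort.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
outputs = sess.run(None, {"input": np.zeros((1, 3, 224, 224), dtype=np.float32)})
hidden_shape = outputs[-1].shape  # concrete shape observed at run time
```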
@@ -330,11 +486,14 @@ class Compiler:
             | UnaryOp
             | ClipOp
             | CastOp
+            | QuantizeLinearOp
             | MatMulOp
             | GemmOp
             | AttentionOp
             | ConvOp
+            | ConvTransposeOp
             | AveragePoolOp
+            | LpPoolOp
             | BatchNormOp
             | LpNormalizationOp
             | InstanceNormalizationOp
@@ -346,12 +505,15 @@ class Compiler:
             | LstmOp
             | SoftmaxOp
             | LogSoftmaxOp
+            | HardmaxOp
             | NegativeLogLikelihoodLossOp
             | SoftmaxCrossEntropyLossOp
             | MaxPoolOp
             | ConcatOp
             | GatherElementsOp
             | GatherOp
+            | GatherNDOp
+            | ScatterNDOp
             | TransposeOp
             | ConstantOfShapeOp
             | ReshapeOp
@@ -362,9 +524,11 @@ class Compiler:
             | ArgReduceOp
             | ShapeOp
             | PadOp
+            | NonZeroOp
             | ExpandOp
             | CumSumOp
             | RangeOp
+            | OneHotOp
             | SplitOp
         ],
         list[NodeInfo],
@@ -375,11 +539,14 @@ class Compiler:
             | UnaryOp
             | ClipOp
             | CastOp
+            | QuantizeLinearOp
             | MatMulOp
             | GemmOp
             | AttentionOp
             | ConvOp
+            | ConvTransposeOp
             | AveragePoolOp
+            | LpPoolOp
             | BatchNormOp
             | LpNormalizationOp
             | InstanceNormalizationOp
@@ -391,12 +558,14 @@ class Compiler:
             | LstmOp
             | SoftmaxOp
             | LogSoftmaxOp
+            | HardmaxOp
             | NegativeLogLikelihoodLossOp
             | SoftmaxCrossEntropyLossOp
             | MaxPoolOp
             | ConcatOp
             | GatherElementsOp
             | GatherOp
+            | GatherNDOp
             | TransposeOp
             | ConstantOfShapeOp
             | ReshapeOp
@@ -406,9 +575,11 @@ class Compiler:
             | ArgReduceOp
             | ShapeOp
             | PadOp
+            | NonZeroOp
             | ExpandOp
             | CumSumOp
             | RangeOp
+            | OneHotOp
             | SplitOp
             | WhereOp
         ] = []
@@ -515,6 +686,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
     op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
     if op_spec is None:
         raise UnsupportedOpError("Unsupported op BitShift")
+    input0_shape = value_shape(graph, node.inputs[0], node)
+    input1_shape = value_shape(graph, node.inputs[1], node)
     output_shape = value_shape(graph, node.outputs[0], node)
     return BinaryOp(
         input0=node.inputs[0],
@@ -522,6 +695,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         output=node.outputs[0],
         function=function,
         operator_kind=op_spec.kind,
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
         shape=output_shape,
         dtype=op_dtype,
         input_dtype=op_dtype,
@@ -555,6 +730,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         raise UnsupportedOpError(
             f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
         )
+    input0_shape = value_shape(graph, node.inputs[0], node)
+    input1_shape = value_shape(graph, node.inputs[1], node)
     output_shape = value_shape(graph, node.outputs[0], node)
     return BinaryOp(
         input0=node.inputs[0],
@@ -562,6 +739,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         output=node.outputs[0],
         function=function,
         operator_kind=op_spec.kind,
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
         shape=output_shape,
         dtype=output_dtype,
         input_dtype=input_dtype,
@@ -576,6 +755,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         raise UnsupportedOpError(
             f"{node.op_type} must have 2 inputs and 1 output"
         )
+    input0_shape = value_shape(graph, node.inputs[0], node)
+    input1_shape = value_shape(graph, node.inputs[1], node)
     output_shape = value_shape(graph, node.outputs[0], node)
     return BinaryOp(
         input0=node.inputs[0],
@@ -583,6 +764,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         output=node.outputs[0],
         function=function,
         operator_kind=op_spec.kind,
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
         shape=output_shape,
         dtype=op_dtype,
         input_dtype=op_dtype,
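`_lower_binary_unary` now records `input0_shape` and `input1_shape` alongside the output shape, presumably so the emitter can generate the index arithmetic for multidirectional (NumPy-style) broadcasting. A sketch of that broadcast rule for reference only; this is not code from the package:

```python
def broadcast_shapes(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    """NumPy-style broadcasting: pad the shorter shape with 1s, then match or broadcast."""
    padded_a = (1,) * (len(b) - len(a)) + a
    padded_b = (1,) * (len(a) - len(b)) + b
    result = []
    for da, db in zip(padded_a, padded_b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible dimensions {da} and {db}")
        result.append(max(da, db))
    return tuple(result)


assert broadcast_shapes((3, 1, 5), (4, 5)) == (3, 4, 5)
```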
emx_onnx_cgen/ir/model.py
CHANGED
@@ -44,6 +44,7 @@ class Graph:
     nodes: tuple[Node, ...]
     initializers: tuple[Initializer, ...]
     values: tuple[Value, ...] = ()
+    opset_imports: tuple[tuple[str, int], ...] = ()
 
     def find_value(self, name: str) -> Value:
         for value in self.inputs + self.outputs + self.values: