emx-onnx-cgen 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic. Click here for more details.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/codegen/c_emitter.py +61 -56
- emx_onnx_cgen/compiler.py +39 -1
- emx_onnx_cgen/ir/op_base.py +312 -2
- emx_onnx_cgen/ir/ops/__init__.py +10 -1
- emx_onnx_cgen/ir/ops/elementwise.py +30 -9
- emx_onnx_cgen/ir/ops/misc.py +98 -15
- emx_onnx_cgen/lowering/expand.py +3 -137
- emx_onnx_cgen/lowering/gather.py +2 -31
- emx_onnx_cgen/lowering/variadic.py +21 -54
- emx_onnx_cgen/runtime/evaluator.py +8 -1
- {emx_onnx_cgen-0.3.3.dist-info → emx_onnx_cgen-0.3.5.dist-info}/METADATA +1 -1
- {emx_onnx_cgen-0.3.3.dist-info → emx_onnx_cgen-0.3.5.dist-info}/RECORD +17 -17
- {emx_onnx_cgen-0.3.3.dist-info → emx_onnx_cgen-0.3.5.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.3.3.dist-info → emx_onnx_cgen-0.3.5.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.3.dist-info → emx_onnx_cgen-0.3.5.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/_build_info.py
CHANGED
emx_onnx_cgen/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.3.3'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 3, 3)
|
|
31
|
+
__version__ = version = '0.3.5'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 5)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -954,13 +954,13 @@ class CEmitter:
|
|
|
954
954
|
)
|
|
955
955
|
if isinstance(op, MultiInputBinaryOp):
|
|
956
956
|
return MultiInputBinaryOp(
|
|
957
|
+
op_type=op.op_type,
|
|
957
958
|
inputs=tuple(name_map.get(name, name) for name in op.inputs),
|
|
958
959
|
output=name_map.get(op.output, op.output),
|
|
959
960
|
function=op.function,
|
|
960
961
|
operator_kind=op.operator_kind,
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
input_dtype=op.input_dtype,
|
|
962
|
+
min_inputs=op.min_inputs,
|
|
963
|
+
max_inputs=op.max_inputs,
|
|
964
964
|
)
|
|
965
965
|
if isinstance(op, WhereOp):
|
|
966
966
|
return WhereOp(
|
|
@@ -1554,12 +1554,7 @@ class CEmitter:
|
|
|
1554
1554
|
data=name_map.get(op.data, op.data),
|
|
1555
1555
|
indices=name_map.get(op.indices, op.indices),
|
|
1556
1556
|
output=name_map.get(op.output, op.output),
|
|
1557
|
-
data_shape=op.data_shape,
|
|
1558
|
-
indices_shape=op.indices_shape,
|
|
1559
|
-
output_shape=op.output_shape,
|
|
1560
1557
|
axis=op.axis,
|
|
1561
|
-
dtype=op.dtype,
|
|
1562
|
-
indices_dtype=op.indices_dtype,
|
|
1563
1558
|
)
|
|
1564
1559
|
if isinstance(op, GatherNDOp):
|
|
1565
1560
|
return GatherNDOp(
|
|
@@ -1896,13 +1891,8 @@ class CEmitter:
|
|
|
1896
1891
|
if isinstance(op, ExpandOp):
|
|
1897
1892
|
return ExpandOp(
|
|
1898
1893
|
input0=name_map.get(op.input0, op.input0),
|
|
1899
|
-
output=name_map.get(op.output, op.output),
|
|
1900
1894
|
input_shape=op.input_shape,
|
|
1901
|
-
|
|
1902
|
-
input_shape_padded=op.input_shape_padded,
|
|
1903
|
-
input_strides=op.input_strides,
|
|
1904
|
-
dtype=op.dtype,
|
|
1905
|
-
input_dtype=op.input_dtype,
|
|
1895
|
+
output=name_map.get(op.output, op.output),
|
|
1906
1896
|
)
|
|
1907
1897
|
if isinstance(op, CumSumOp):
|
|
1908
1898
|
return CumSumOp(
|
|
@@ -2233,8 +2223,8 @@ class CEmitter:
|
|
|
2233
2223
|
):
|
|
2234
2224
|
testbench_math_include.add("#include <math.h>")
|
|
2235
2225
|
includes = self._collect_includes(
|
|
2236
|
-
|
|
2237
|
-
|
|
2226
|
+
original_model,
|
|
2227
|
+
list(original_model.ops),
|
|
2238
2228
|
emit_testbench=emit_testbench,
|
|
2239
2229
|
extra_includes=scalar_includes | testbench_math_include,
|
|
2240
2230
|
needs_weight_loader=bool(large_constants),
|
|
@@ -2380,8 +2370,8 @@ class CEmitter:
|
|
|
2380
2370
|
):
|
|
2381
2371
|
testbench_math_include.add("#include <math.h>")
|
|
2382
2372
|
includes = self._collect_includes(
|
|
2383
|
-
|
|
2384
|
-
|
|
2373
|
+
original_model,
|
|
2374
|
+
list(original_model.ops),
|
|
2385
2375
|
emit_testbench=emit_testbench,
|
|
2386
2376
|
extra_includes=scalar_includes | testbench_math_include,
|
|
2387
2377
|
needs_weight_loader=bool(large_constants),
|
|
@@ -2790,8 +2780,17 @@ class CEmitter:
|
|
|
2790
2780
|
*(const.dtype for const in model.constants),
|
|
2791
2781
|
*constant_of_shape_inputs,
|
|
2792
2782
|
}
|
|
2783
|
+
def _resolved_output_dtype(op: OpBase) -> ScalarType:
|
|
2784
|
+
if isinstance(op, MultiInputBinaryOp):
|
|
2785
|
+
return model.op_context.dtype(op.inputs[0])
|
|
2786
|
+
if isinstance(op, GatherOp):
|
|
2787
|
+
return model.op_context.dtype(op.data)
|
|
2788
|
+
if isinstance(op, ExpandOp):
|
|
2789
|
+
return model.op_context.dtype(op.input0)
|
|
2790
|
+
return op.dtype
|
|
2791
|
+
|
|
2793
2792
|
model_dtypes.update(
|
|
2794
|
-
op
|
|
2793
|
+
_resolved_output_dtype(op)
|
|
2795
2794
|
for op in resolved_ops
|
|
2796
2795
|
if not isinstance(op, (ArgReduceOp, TopKOp))
|
|
2797
2796
|
)
|
|
@@ -3875,13 +3874,13 @@ class CEmitter:
|
|
|
3875
3874
|
)
|
|
3876
3875
|
if isinstance(op, MultiInputBinaryOp):
|
|
3877
3876
|
return MultiInputBinaryOp(
|
|
3877
|
+
op_type=op.op_type,
|
|
3878
3878
|
inputs=tuple(temp_map.get(name, name) for name in op.inputs),
|
|
3879
3879
|
output=temp_map.get(op.output, op.output),
|
|
3880
3880
|
function=op.function,
|
|
3881
3881
|
operator_kind=op.operator_kind,
|
|
3882
|
-
|
|
3883
|
-
|
|
3884
|
-
input_dtype=op.input_dtype,
|
|
3882
|
+
min_inputs=op.min_inputs,
|
|
3883
|
+
max_inputs=op.max_inputs,
|
|
3885
3884
|
)
|
|
3886
3885
|
if isinstance(op, WhereOp):
|
|
3887
3886
|
return WhereOp(
|
|
@@ -4545,11 +4544,6 @@ class CEmitter:
|
|
|
4545
4544
|
indices=temp_map.get(op.indices, op.indices),
|
|
4546
4545
|
output=temp_map.get(op.output, op.output),
|
|
4547
4546
|
axis=op.axis,
|
|
4548
|
-
data_shape=op.data_shape,
|
|
4549
|
-
indices_shape=op.indices_shape,
|
|
4550
|
-
output_shape=op.output_shape,
|
|
4551
|
-
dtype=op.dtype,
|
|
4552
|
-
indices_dtype=op.indices_dtype,
|
|
4553
4547
|
)
|
|
4554
4548
|
if isinstance(op, GatherNDOp):
|
|
4555
4549
|
return GatherNDOp(
|
|
@@ -4674,13 +4668,8 @@ class CEmitter:
|
|
|
4674
4668
|
if isinstance(op, ExpandOp):
|
|
4675
4669
|
return ExpandOp(
|
|
4676
4670
|
input0=temp_map.get(op.input0, op.input0),
|
|
4677
|
-
output=temp_map.get(op.output, op.output),
|
|
4678
4671
|
input_shape=op.input_shape,
|
|
4679
|
-
|
|
4680
|
-
input_shape_padded=op.input_shape_padded,
|
|
4681
|
-
input_strides=op.input_strides,
|
|
4682
|
-
dtype=op.dtype,
|
|
4683
|
-
input_dtype=op.input_dtype,
|
|
4672
|
+
output=temp_map.get(op.output, op.output),
|
|
4684
4673
|
)
|
|
4685
4674
|
if isinstance(op, CumSumOp):
|
|
4686
4675
|
return CumSumOp(
|
|
@@ -7446,30 +7435,34 @@ class CEmitter:
|
|
|
7446
7435
|
("output", op.output),
|
|
7447
7436
|
]
|
|
7448
7437
|
)
|
|
7449
|
-
|
|
7450
|
-
|
|
7451
|
-
|
|
7452
|
-
|
|
7438
|
+
output_shape_raw = self._ctx_shape(op.output)
|
|
7439
|
+
output_shape = CEmitter._codegen_shape(output_shape_raw)
|
|
7440
|
+
loop_vars = CEmitter._loop_vars(output_shape_raw)
|
|
7441
|
+
output_loop_vars = loop_vars if output_shape_raw else ()
|
|
7442
|
+
indices_shape = self._ctx_shape(op.indices)
|
|
7443
|
+
indices_rank = len(indices_shape)
|
|
7453
7444
|
if indices_rank == 0:
|
|
7454
7445
|
indices_indices = ("0",)
|
|
7455
7446
|
else:
|
|
7456
|
-
|
|
7457
|
-
|
|
7458
|
-
|
|
7447
|
+
axis = int(self._derived(op, "axis"))
|
|
7448
|
+
indices_indices = output_loop_vars[axis : axis + indices_rank]
|
|
7449
|
+
axis = int(self._derived(op, "axis"))
|
|
7459
7450
|
data_indices = [
|
|
7460
|
-
*output_loop_vars[:
|
|
7451
|
+
*output_loop_vars[:axis],
|
|
7461
7452
|
"gather_index",
|
|
7462
|
-
*output_loop_vars[
|
|
7453
|
+
*output_loop_vars[axis + indices_rank :],
|
|
7463
7454
|
]
|
|
7464
|
-
|
|
7465
|
-
|
|
7466
|
-
|
|
7455
|
+
data_shape = self._ctx_shape(op.data)
|
|
7456
|
+
data_suffix = self._param_array_suffix(data_shape)
|
|
7457
|
+
indices_suffix = self._param_array_suffix(indices_shape)
|
|
7458
|
+
output_suffix = self._param_array_suffix(output_shape_raw)
|
|
7459
|
+
indices_dtype = self._ctx_dtype(op.indices)
|
|
7467
7460
|
param_decls = self._build_param_decls(
|
|
7468
7461
|
[
|
|
7469
7462
|
(params["data"], c_type, data_suffix, True),
|
|
7470
7463
|
(
|
|
7471
7464
|
params["indices"],
|
|
7472
|
-
|
|
7465
|
+
indices_dtype.c_type,
|
|
7473
7466
|
indices_suffix,
|
|
7474
7467
|
True,
|
|
7475
7468
|
),
|
|
@@ -7484,7 +7477,7 @@ class CEmitter:
|
|
|
7484
7477
|
output=params["output"],
|
|
7485
7478
|
params=param_decls,
|
|
7486
7479
|
c_type=c_type,
|
|
7487
|
-
indices_c_type=
|
|
7480
|
+
indices_c_type=indices_dtype.c_type,
|
|
7488
7481
|
data_suffix=data_suffix,
|
|
7489
7482
|
indices_suffix=indices_suffix,
|
|
7490
7483
|
output_suffix=output_suffix,
|
|
@@ -7492,7 +7485,7 @@ class CEmitter:
|
|
|
7492
7485
|
loop_vars=loop_vars,
|
|
7493
7486
|
indices_indices=indices_indices,
|
|
7494
7487
|
data_indices=data_indices,
|
|
7495
|
-
axis_dim=
|
|
7488
|
+
axis_dim=data_shape[axis],
|
|
7496
7489
|
).rstrip()
|
|
7497
7490
|
return with_node_comment(rendered)
|
|
7498
7491
|
if isinstance(op, GatherNDOp):
|
|
@@ -9139,15 +9132,17 @@ class CEmitter:
|
|
|
9139
9132
|
[("input0", op.input0), ("output", op.output)]
|
|
9140
9133
|
)
|
|
9141
9134
|
output_dim_names = _dim_names_for(op.output)
|
|
9135
|
+
output_shape_raw = self._ctx_shape(op.output)
|
|
9142
9136
|
output_shape = CEmitter._shape_dim_exprs(
|
|
9143
|
-
|
|
9137
|
+
output_shape_raw, output_dim_names
|
|
9144
9138
|
)
|
|
9145
|
-
loop_vars = CEmitter._loop_vars(
|
|
9139
|
+
loop_vars = CEmitter._loop_vars(output_shape_raw)
|
|
9140
|
+
input_shape = self._ctx_shape(op.input0)
|
|
9146
9141
|
input_suffix = self._param_array_suffix(
|
|
9147
|
-
|
|
9142
|
+
input_shape, _dim_names_for(op.input0)
|
|
9148
9143
|
)
|
|
9149
9144
|
output_suffix = self._param_array_suffix(
|
|
9150
|
-
|
|
9145
|
+
output_shape_raw, output_dim_names
|
|
9151
9146
|
)
|
|
9152
9147
|
param_decls = self._build_param_decls(
|
|
9153
9148
|
[
|
|
@@ -9155,10 +9150,12 @@ class CEmitter:
|
|
|
9155
9150
|
(params["output"], c_type, output_suffix, False),
|
|
9156
9151
|
]
|
|
9157
9152
|
)
|
|
9153
|
+
input_shape_padded = self._derived(op, "input_shape_padded")
|
|
9154
|
+
input_strides = self._derived(op, "input_strides")
|
|
9158
9155
|
input_index_terms = [
|
|
9159
9156
|
f"{loop_var} * {stride}"
|
|
9160
9157
|
for loop_var, input_dim, stride in zip(
|
|
9161
|
-
loop_vars,
|
|
9158
|
+
loop_vars, input_shape_padded, input_strides
|
|
9162
9159
|
)
|
|
9163
9160
|
if input_dim != 1
|
|
9164
9161
|
]
|
|
@@ -10442,7 +10439,13 @@ class CEmitter:
|
|
|
10442
10439
|
)
|
|
10443
10440
|
if isinstance(op, NonMaxSuppressionOp):
|
|
10444
10441
|
return ((op.output, op.output_shape, op.output_dtype),)
|
|
10445
|
-
return (
|
|
10442
|
+
return (
|
|
10443
|
+
(
|
|
10444
|
+
op.output,
|
|
10445
|
+
self._op_output_shape(op),
|
|
10446
|
+
self._op_output_dtype(op),
|
|
10447
|
+
),
|
|
10448
|
+
)
|
|
10446
10449
|
|
|
10447
10450
|
def _op_output_shape(
|
|
10448
10451
|
self,
|
|
@@ -10566,7 +10569,7 @@ class CEmitter:
|
|
|
10566
10569
|
if isinstance(op, GatherElementsOp):
|
|
10567
10570
|
return op.output_shape
|
|
10568
10571
|
if isinstance(op, GatherOp):
|
|
10569
|
-
return op.output_shape
|
|
10572
|
+
return self._ctx_shape(op.output)
|
|
10570
10573
|
if isinstance(op, GatherNDOp):
|
|
10571
10574
|
return op.output_shape
|
|
10572
10575
|
if isinstance(op, ScatterNDOp):
|
|
@@ -10614,7 +10617,7 @@ class CEmitter:
|
|
|
10614
10617
|
if isinstance(op, NonMaxSuppressionOp):
|
|
10615
10618
|
return op.output_shape
|
|
10616
10619
|
if isinstance(op, ExpandOp):
|
|
10617
|
-
return op.
|
|
10620
|
+
return self._ctx_shape(op.output)
|
|
10618
10621
|
if isinstance(op, CumSumOp):
|
|
10619
10622
|
return op.input_shape
|
|
10620
10623
|
if isinstance(op, RangeOp):
|
|
@@ -10700,10 +10703,12 @@ class CEmitter:
|
|
|
10700
10703
|
SoftmaxOp,
|
|
10701
10704
|
LogSoftmaxOp,
|
|
10702
10705
|
HardmaxOp,
|
|
10706
|
+
GatherOp,
|
|
10703
10707
|
TransposeOp,
|
|
10704
10708
|
ReshapeOp,
|
|
10705
10709
|
IdentityOp,
|
|
10706
10710
|
ReduceOp,
|
|
10711
|
+
ExpandOp,
|
|
10707
10712
|
),
|
|
10708
10713
|
):
|
|
10709
10714
|
return self._ctx_dtype(op.output)
|
emx_onnx_cgen/compiler.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from dataclasses import dataclass
|
|
3
|
+
from dataclasses import dataclass, fields
|
|
4
4
|
import hashlib
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Mapping
|
|
@@ -24,6 +24,7 @@ from .ir.context import GraphContext
|
|
|
24
24
|
from .ir.model import Graph, TensorType, Value
|
|
25
25
|
from .ir.op_base import OpBase
|
|
26
26
|
from .ir.op_context import OpContext
|
|
27
|
+
from .ir.ops import ExpandOp, GatherOp, MultiInputBinaryOp
|
|
27
28
|
from .lowering import load_lowering_registry
|
|
28
29
|
from .lowering.common import ensure_supported_dtype, shape_product, value_dtype
|
|
29
30
|
from .lowering.registry import get_lowering_registry
|
|
@@ -172,6 +173,43 @@ class Compiler:
|
|
|
172
173
|
) = self._collect_io_specs(graph)
|
|
173
174
|
ops, node_infos = self._lower_nodes(ctx)
|
|
174
175
|
op_ctx = OpContext(ctx)
|
|
176
|
+
for op, node_info in zip(ops, node_infos):
|
|
177
|
+
field_names = {field.name for field in fields(op)}
|
|
178
|
+
if "dtype" in field_names:
|
|
179
|
+
dtype = getattr(op, "dtype")
|
|
180
|
+
for field in fields(op):
|
|
181
|
+
if not field.name.startswith("output"):
|
|
182
|
+
continue
|
|
183
|
+
value = getattr(op, field.name)
|
|
184
|
+
if isinstance(value, str):
|
|
185
|
+
op_ctx.set_dtype(value, dtype)
|
|
186
|
+
for name in node_info.outputs:
|
|
187
|
+
op_ctx.set_dtype(name, dtype)
|
|
188
|
+
if "outputs" in field_names:
|
|
189
|
+
dtype = getattr(op, "dtype", None)
|
|
190
|
+
if dtype is not None:
|
|
191
|
+
for name in getattr(op, "outputs"):
|
|
192
|
+
op_ctx.set_dtype(name, dtype)
|
|
193
|
+
if "output_dtype" in field_names and "output" in field_names:
|
|
194
|
+
output_name = getattr(op, "output")
|
|
195
|
+
if isinstance(output_name, str):
|
|
196
|
+
op_ctx.set_dtype(output_name, getattr(op, "output_dtype"))
|
|
197
|
+
if "output_values_dtype" in field_names:
|
|
198
|
+
op_ctx.set_dtype(
|
|
199
|
+
getattr(op, "output_values"),
|
|
200
|
+
getattr(op, "output_values_dtype"),
|
|
201
|
+
)
|
|
202
|
+
if "output_indices_dtype" in field_names:
|
|
203
|
+
op_ctx.set_dtype(
|
|
204
|
+
getattr(op, "output_indices"),
|
|
205
|
+
getattr(op, "output_indices_dtype"),
|
|
206
|
+
)
|
|
207
|
+
if isinstance(op, MultiInputBinaryOp) and op.inputs:
|
|
208
|
+
op_ctx.set_dtype(op.output, op_ctx.dtype(op.inputs[0]))
|
|
209
|
+
if isinstance(op, GatherOp):
|
|
210
|
+
op_ctx.set_dtype(op.output, op_ctx.dtype(op.data))
|
|
211
|
+
if isinstance(op, ExpandOp):
|
|
212
|
+
op_ctx.set_dtype(op.output, op_ctx.dtype(op.input0))
|
|
175
213
|
for op in ops:
|
|
176
214
|
op.validate(op_ctx)
|
|
177
215
|
for op in ops:
|
emx_onnx_cgen/ir/op_base.py
CHANGED
|
@@ -88,7 +88,10 @@ class ElementwiseOpBase(RenderableOpBase):
|
|
|
88
88
|
raise UnsupportedOpError(
|
|
89
89
|
f"{self.kind} expects matching input dtypes, got {dtype_names}"
|
|
90
90
|
)
|
|
91
|
-
|
|
91
|
+
try:
|
|
92
|
+
output_dtype = ctx.dtype(self._elementwise_output())
|
|
93
|
+
except ShapeInferenceError:
|
|
94
|
+
return None
|
|
92
95
|
if self._elementwise_compare():
|
|
93
96
|
if output_dtype != ScalarType.BOOL:
|
|
94
97
|
raise UnsupportedOpError(
|
|
@@ -107,7 +110,25 @@ class ElementwiseOpBase(RenderableOpBase):
|
|
|
107
110
|
output_name = self._elementwise_output()
|
|
108
111
|
for name in input_names:
|
|
109
112
|
ctx.dtype(name)
|
|
110
|
-
|
|
113
|
+
desired_dtype = (
|
|
114
|
+
ScalarType.BOOL if self._elementwise_compare() else None
|
|
115
|
+
)
|
|
116
|
+
if desired_dtype is None:
|
|
117
|
+
data_inputs = self._elementwise_data_inputs()
|
|
118
|
+
if data_inputs:
|
|
119
|
+
desired_dtype = ctx.dtype(data_inputs[0])
|
|
120
|
+
try:
|
|
121
|
+
output_dtype = ctx.dtype(output_name)
|
|
122
|
+
except ShapeInferenceError:
|
|
123
|
+
if desired_dtype is not None:
|
|
124
|
+
ctx.set_dtype(output_name, desired_dtype)
|
|
125
|
+
return None
|
|
126
|
+
raise
|
|
127
|
+
if desired_dtype is not None and output_dtype != desired_dtype:
|
|
128
|
+
raise UnsupportedOpError(
|
|
129
|
+
f"{self.kind} expects output dtype {desired_dtype.onnx_name}, "
|
|
130
|
+
f"got {output_dtype.onnx_name}"
|
|
131
|
+
)
|
|
111
132
|
|
|
112
133
|
def infer_shapes(self, ctx: OpContext) -> None:
|
|
113
134
|
input_names = self._elementwise_inputs()
|
|
@@ -121,6 +142,295 @@ class ElementwiseOpBase(RenderableOpBase):
|
|
|
121
142
|
return None
|
|
122
143
|
|
|
123
144
|
|
|
145
|
+
class GatherLikeOpBase(RenderableOpBase):
|
|
146
|
+
def _gather_data(self) -> str:
|
|
147
|
+
raise NotImplementedError
|
|
148
|
+
|
|
149
|
+
def _gather_indices(self) -> str:
|
|
150
|
+
raise NotImplementedError
|
|
151
|
+
|
|
152
|
+
def _gather_output(self) -> str:
|
|
153
|
+
raise NotImplementedError
|
|
154
|
+
|
|
155
|
+
def _gather_axis(self) -> int:
|
|
156
|
+
raise NotImplementedError
|
|
157
|
+
|
|
158
|
+
def _gather_mode(self) -> str:
|
|
159
|
+
raise NotImplementedError
|
|
160
|
+
|
|
161
|
+
def validate(self, ctx: OpContext) -> None:
|
|
162
|
+
indices_dtype = ctx.dtype(self._gather_indices())
|
|
163
|
+
if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
|
|
164
|
+
raise UnsupportedOpError(
|
|
165
|
+
f"{self.kind} indices must be int32 or int64, "
|
|
166
|
+
f"got {indices_dtype.onnx_name}"
|
|
167
|
+
)
|
|
168
|
+
data_shape = ctx.shape(self._gather_data())
|
|
169
|
+
if self._gather_mode() in {"gather", "gather_elements"}:
|
|
170
|
+
if not data_shape:
|
|
171
|
+
raise ShapeInferenceError(
|
|
172
|
+
f"{self.kind} does not support scalar inputs"
|
|
173
|
+
)
|
|
174
|
+
axis = self._gather_axis()
|
|
175
|
+
if axis < 0:
|
|
176
|
+
axis += len(data_shape)
|
|
177
|
+
if axis < 0 or axis >= len(data_shape):
|
|
178
|
+
raise ShapeInferenceError(
|
|
179
|
+
f"{self.kind} axis {axis} is out of range for rank "
|
|
180
|
+
f"{len(data_shape)}"
|
|
181
|
+
)
|
|
182
|
+
return None
|
|
183
|
+
|
|
184
|
+
def infer_types(self, ctx: OpContext) -> None:
|
|
185
|
+
data_dtype = ctx.dtype(self._gather_data())
|
|
186
|
+
try:
|
|
187
|
+
output_dtype = ctx.dtype(self._gather_output())
|
|
188
|
+
except ShapeInferenceError:
|
|
189
|
+
ctx.set_dtype(self._gather_output(), data_dtype)
|
|
190
|
+
output_dtype = data_dtype
|
|
191
|
+
if output_dtype != data_dtype:
|
|
192
|
+
raise UnsupportedOpError(
|
|
193
|
+
f"{self.kind} expects output dtype {data_dtype.onnx_name}, "
|
|
194
|
+
f"got {output_dtype.onnx_name}"
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
def infer_shapes(self, ctx: OpContext) -> None:
|
|
198
|
+
data_shape = ctx.shape(self._gather_data())
|
|
199
|
+
indices_shape = ctx.shape(self._gather_indices())
|
|
200
|
+
axis = self._gather_axis()
|
|
201
|
+
if axis < 0:
|
|
202
|
+
axis += len(data_shape)
|
|
203
|
+
if axis < 0 or axis >= len(data_shape):
|
|
204
|
+
raise ShapeInferenceError(
|
|
205
|
+
f"{self.kind} axis {axis} is out of range for rank "
|
|
206
|
+
f"{len(data_shape)}"
|
|
207
|
+
)
|
|
208
|
+
if self._gather_mode() == "gather":
|
|
209
|
+
output_shape = (
|
|
210
|
+
data_shape[:axis] + indices_shape + data_shape[axis + 1 :]
|
|
211
|
+
)
|
|
212
|
+
else:
|
|
213
|
+
raise UnsupportedOpError(
|
|
214
|
+
f"{self.kind} does not support gather mode "
|
|
215
|
+
f"{self._gather_mode()}"
|
|
216
|
+
)
|
|
217
|
+
try:
|
|
218
|
+
expected = ctx.shape(self._gather_output())
|
|
219
|
+
except ShapeInferenceError:
|
|
220
|
+
expected = None
|
|
221
|
+
if expected is not None and expected != output_shape:
|
|
222
|
+
raise ShapeInferenceError(
|
|
223
|
+
f"{self.kind} output shape must be {output_shape}, got {expected}"
|
|
224
|
+
)
|
|
225
|
+
ctx.set_shape(self._gather_output(), output_shape)
|
|
226
|
+
ctx.set_derived(self, "axis", axis)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class ShapeLikeOpBase(RenderableOpBase):
|
|
230
|
+
def _shape_data(self) -> str:
|
|
231
|
+
raise NotImplementedError
|
|
232
|
+
|
|
233
|
+
def _shape_output(self) -> str:
|
|
234
|
+
raise NotImplementedError
|
|
235
|
+
|
|
236
|
+
def _shape_spec(self, ctx: OpContext) -> tuple[int, ...]:
|
|
237
|
+
raise NotImplementedError
|
|
238
|
+
|
|
239
|
+
def _shape_mode(self) -> str:
|
|
240
|
+
raise NotImplementedError
|
|
241
|
+
|
|
242
|
+
def _shape_derived(
|
|
243
|
+
self,
|
|
244
|
+
ctx: OpContext,
|
|
245
|
+
*,
|
|
246
|
+
data_shape: tuple[int, ...],
|
|
247
|
+
target_shape: tuple[int, ...],
|
|
248
|
+
output_shape: tuple[int, ...],
|
|
249
|
+
) -> None:
|
|
250
|
+
return None
|
|
251
|
+
|
|
252
|
+
@staticmethod
|
|
253
|
+
def _validate_static_dims(shape: tuple[int, ...], kind: str) -> None:
|
|
254
|
+
if any(dim < 0 for dim in shape):
|
|
255
|
+
raise ShapeInferenceError(
|
|
256
|
+
f"{kind} does not support dynamic dims"
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
@staticmethod
|
|
260
|
+
def _broadcast_shape(
|
|
261
|
+
input_shape: tuple[int, ...],
|
|
262
|
+
target_shape: tuple[int, ...],
|
|
263
|
+
*,
|
|
264
|
+
kind: str,
|
|
265
|
+
) -> tuple[int, ...]:
|
|
266
|
+
ShapeLikeOpBase._validate_static_dims(input_shape, kind)
|
|
267
|
+
ShapeLikeOpBase._validate_static_dims(target_shape, kind)
|
|
268
|
+
output_rank = max(len(input_shape), len(target_shape))
|
|
269
|
+
input_padded = (1,) * (output_rank - len(input_shape)) + input_shape
|
|
270
|
+
target_padded = (1,) * (output_rank - len(target_shape)) + target_shape
|
|
271
|
+
result: list[int] = []
|
|
272
|
+
for input_dim, target_dim in zip(input_padded, target_padded):
|
|
273
|
+
if input_dim == 1:
|
|
274
|
+
result.append(target_dim)
|
|
275
|
+
elif target_dim == 1:
|
|
276
|
+
result.append(input_dim)
|
|
277
|
+
elif input_dim == target_dim:
|
|
278
|
+
result.append(input_dim)
|
|
279
|
+
else:
|
|
280
|
+
raise ShapeInferenceError(
|
|
281
|
+
f"{kind} input shape {input_shape} is not "
|
|
282
|
+
f"broadcastable to {target_shape}"
|
|
283
|
+
)
|
|
284
|
+
return tuple(result)
|
|
285
|
+
|
|
286
|
+
def validate(self, ctx: OpContext) -> None:
|
|
287
|
+
data_shape = ctx.shape(self._shape_data())
|
|
288
|
+
target_shape = self._shape_spec(ctx)
|
|
289
|
+
if self._shape_mode() == "expand":
|
|
290
|
+
self._broadcast_shape(
|
|
291
|
+
data_shape, target_shape, kind=self.kind
|
|
292
|
+
)
|
|
293
|
+
return None
|
|
294
|
+
|
|
295
|
+
def infer_types(self, ctx: OpContext) -> None:
|
|
296
|
+
input_dtype = ctx.dtype(self._shape_data())
|
|
297
|
+
try:
|
|
298
|
+
output_dtype = ctx.dtype(self._shape_output())
|
|
299
|
+
except ShapeInferenceError:
|
|
300
|
+
ctx.set_dtype(self._shape_output(), input_dtype)
|
|
301
|
+
output_dtype = input_dtype
|
|
302
|
+
if output_dtype != input_dtype:
|
|
303
|
+
raise UnsupportedOpError(
|
|
304
|
+
f"{self.kind} expects output dtype {input_dtype.onnx_name}, "
|
|
305
|
+
f"got {output_dtype.onnx_name}"
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
def infer_shapes(self, ctx: OpContext) -> None:
|
|
309
|
+
data_shape = ctx.shape(self._shape_data())
|
|
310
|
+
target_shape = self._shape_spec(ctx)
|
|
311
|
+
if self._shape_mode() == "expand":
|
|
312
|
+
output_shape = self._broadcast_shape(
|
|
313
|
+
data_shape, target_shape, kind=self.kind
|
|
314
|
+
)
|
|
315
|
+
else:
|
|
316
|
+
output_shape = target_shape
|
|
317
|
+
try:
|
|
318
|
+
expected = ctx.shape(self._shape_output())
|
|
319
|
+
except ShapeInferenceError:
|
|
320
|
+
expected = None
|
|
321
|
+
if expected is not None and expected != output_shape:
|
|
322
|
+
raise ShapeInferenceError(
|
|
323
|
+
f"{self.kind} output shape must be {output_shape}, got {expected}"
|
|
324
|
+
)
|
|
325
|
+
ctx.set_shape(self._shape_output(), output_shape)
|
|
326
|
+
self._shape_derived(
|
|
327
|
+
ctx,
|
|
328
|
+
data_shape=data_shape,
|
|
329
|
+
target_shape=target_shape,
|
|
330
|
+
output_shape=output_shape,
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
class VariadicLikeOpBase(RenderableOpBase):
|
|
335
|
+
def _variadic_inputs(self) -> tuple[str, ...]:
|
|
336
|
+
raise NotImplementedError
|
|
337
|
+
|
|
338
|
+
def _variadic_output(self) -> str:
|
|
339
|
+
raise NotImplementedError
|
|
340
|
+
|
|
341
|
+
def _variadic_kind(self) -> str:
|
|
342
|
+
return self.kind
|
|
343
|
+
|
|
344
|
+
def _variadic_min_inputs(self) -> int:
|
|
345
|
+
return 2
|
|
346
|
+
|
|
347
|
+
def _variadic_max_inputs(self) -> int | None:
|
|
348
|
+
return None
|
|
349
|
+
|
|
350
|
+
def _variadic_compare(self) -> bool:
|
|
351
|
+
return False
|
|
352
|
+
|
|
353
|
+
def _variadic_supports_dtype(self, dtype: ScalarType) -> bool:
|
|
354
|
+
return True
|
|
355
|
+
|
|
356
|
+
def validate(self, ctx: OpContext) -> None:
|
|
357
|
+
inputs = self._variadic_inputs()
|
|
358
|
+
if any(not name for name in inputs):
|
|
359
|
+
raise UnsupportedOpError(
|
|
360
|
+
f"{self._variadic_kind()} input must be provided"
|
|
361
|
+
)
|
|
362
|
+
min_inputs = self._variadic_min_inputs()
|
|
363
|
+
max_inputs = self._variadic_max_inputs()
|
|
364
|
+
if len(inputs) < min_inputs:
|
|
365
|
+
raise UnsupportedOpError(
|
|
366
|
+
f"{self._variadic_kind()} must have at least {min_inputs} inputs"
|
|
367
|
+
)
|
|
368
|
+
if max_inputs is not None and len(inputs) != max_inputs:
|
|
369
|
+
raise UnsupportedOpError(
|
|
370
|
+
f"{self._variadic_kind()} must have exactly {max_inputs} inputs"
|
|
371
|
+
)
|
|
372
|
+
input_dtypes = tuple(ctx.dtype(name) for name in inputs)
|
|
373
|
+
if any(dtype != input_dtypes[0] for dtype in input_dtypes[1:]):
|
|
374
|
+
dtype_names = ", ".join(
|
|
375
|
+
dtype.onnx_name for dtype in input_dtypes
|
|
376
|
+
)
|
|
377
|
+
raise UnsupportedOpError(
|
|
378
|
+
f"{self._variadic_kind()} expects matching input dtypes, "
|
|
379
|
+
f"got {dtype_names}"
|
|
380
|
+
)
|
|
381
|
+
try:
|
|
382
|
+
output_dtype = ctx.dtype(self._variadic_output())
|
|
383
|
+
except ShapeInferenceError:
|
|
384
|
+
output_dtype = None
|
|
385
|
+
if output_dtype is not None:
|
|
386
|
+
if self._variadic_compare():
|
|
387
|
+
if output_dtype != ScalarType.BOOL:
|
|
388
|
+
raise UnsupportedOpError(
|
|
389
|
+
f"{self._variadic_kind()} expects bool output, "
|
|
390
|
+
f"got {output_dtype.onnx_name}"
|
|
391
|
+
)
|
|
392
|
+
elif output_dtype != input_dtypes[0]:
|
|
393
|
+
raise UnsupportedOpError(
|
|
394
|
+
f"{self._variadic_kind()} expects output dtype "
|
|
395
|
+
f"{input_dtypes[0].onnx_name}, got {output_dtype.onnx_name}"
|
|
396
|
+
)
|
|
397
|
+
if not self._variadic_supports_dtype(input_dtypes[0]):
|
|
398
|
+
raise UnsupportedOpError(
|
|
399
|
+
f"{self._variadic_kind()} does not support dtype "
|
|
400
|
+
f"{input_dtypes[0].onnx_name}"
|
|
401
|
+
)
|
|
402
|
+
return None
|
|
403
|
+
|
|
404
|
+
def infer_types(self, ctx: OpContext) -> None:
|
|
405
|
+
for name in self._variadic_inputs():
|
|
406
|
+
ctx.dtype(name)
|
|
407
|
+
try:
|
|
408
|
+
ctx.dtype(self._variadic_output())
|
|
409
|
+
except ShapeInferenceError:
|
|
410
|
+
ctx.set_dtype(
|
|
411
|
+
self._variadic_output(),
|
|
412
|
+
ctx.dtype(self._variadic_inputs()[0]),
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
def infer_shapes(self, ctx: OpContext) -> None:
|
|
416
|
+
input_shapes = tuple(ctx.shape(name) for name in self._variadic_inputs())
|
|
417
|
+
output_shape = BroadcastingOpBase.broadcast_shapes(*input_shapes)
|
|
418
|
+
for shape in input_shapes:
|
|
419
|
+
if shape != output_shape:
|
|
420
|
+
raise UnsupportedOpError(
|
|
421
|
+
f"{self._variadic_kind()} expects identical input/output shapes"
|
|
422
|
+
)
|
|
423
|
+
try:
|
|
424
|
+
expected = ctx.shape(self._variadic_output())
|
|
425
|
+
except ShapeInferenceError:
|
|
426
|
+
expected = None
|
|
427
|
+
if expected is not None and expected != output_shape:
|
|
428
|
+
raise UnsupportedOpError(
|
|
429
|
+
f"{self._variadic_kind()} expects identical input/output shapes"
|
|
430
|
+
)
|
|
431
|
+
ctx.set_shape(self._variadic_output(), output_shape)
|
|
432
|
+
|
|
433
|
+
|
|
124
434
|
class ReduceOpBase(RenderableOpBase):
|
|
125
435
|
@staticmethod
|
|
126
436
|
def normalize_axes(
|
emx_onnx_cgen/ir/ops/__init__.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
|
1
|
-
from .elementwise import
|
|
1
|
+
from .elementwise import (
|
|
2
|
+
BinaryOp,
|
|
3
|
+
ClipOp,
|
|
4
|
+
IdentityOp,
|
|
5
|
+
MultiInputBinaryOp,
|
|
6
|
+
UnaryOp,
|
|
7
|
+
VariadicOp,
|
|
8
|
+
WhereOp,
|
|
9
|
+
)
|
|
2
10
|
from .misc import (
|
|
3
11
|
CastOp,
|
|
4
12
|
ConcatOp,
|
|
@@ -126,5 +134,6 @@ __all__ = [
|
|
|
126
134
|
"TransposeOp",
|
|
127
135
|
"TriluOp",
|
|
128
136
|
"UnaryOp",
|
|
137
|
+
"VariadicOp",
|
|
129
138
|
"WhereOp",
|
|
130
139
|
]
|
emx_onnx_cgen/ir/ops/elementwise.py
CHANGED
|
@@ -5,8 +5,8 @@ from dataclasses import dataclass
|
|
|
5
5
|
from shared.scalar_functions import ScalarFunction
|
|
6
6
|
from shared.scalar_types import ScalarType
|
|
7
7
|
|
|
8
|
-
from ...ops import COMPARE_FUNCTIONS, OperatorKind
|
|
9
|
-
from ..op_base import ElementwiseOpBase
|
|
8
|
+
from ...ops import COMPARE_FUNCTIONS, OperatorKind, binary_op_symbol
|
|
9
|
+
from ..op_base import ElementwiseOpBase, VariadicLikeOpBase
|
|
10
10
|
from ..op_context import OpContext
|
|
11
11
|
|
|
12
12
|
|
|
@@ -34,24 +34,45 @@ class BinaryOp(ElementwiseOpBase):
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
@dataclass(frozen=True)
|
|
37
|
-
class
|
|
37
|
+
class VariadicOp(VariadicLikeOpBase):
|
|
38
|
+
op_type: str
|
|
38
39
|
inputs: tuple[str, ...]
|
|
39
40
|
output: str
|
|
40
41
|
function: ScalarFunction
|
|
41
42
|
operator_kind: OperatorKind
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
input_dtype: ScalarType
|
|
43
|
+
min_inputs: int = 2
|
|
44
|
+
max_inputs: int | None = None
|
|
45
45
|
|
|
46
|
-
def
|
|
46
|
+
def _variadic_inputs(self) -> tuple[str, ...]:
|
|
47
47
|
return self.inputs
|
|
48
48
|
|
|
49
|
-
def
|
|
49
|
+
def _variadic_output(self) -> str:
|
|
50
50
|
return self.output
|
|
51
51
|
|
|
52
|
-
def
|
|
52
|
+
def _variadic_kind(self) -> str:
|
|
53
|
+
return self.op_type
|
|
54
|
+
|
|
55
|
+
def _variadic_compare(self) -> bool:
|
|
53
56
|
return self.function in COMPARE_FUNCTIONS
|
|
54
57
|
|
|
58
|
+
def _variadic_min_inputs(self) -> int:
|
|
59
|
+
return self.min_inputs
|
|
60
|
+
|
|
61
|
+
def _variadic_max_inputs(self) -> int | None:
|
|
62
|
+
return self.max_inputs
|
|
63
|
+
|
|
64
|
+
def _variadic_supports_dtype(self, dtype: ScalarType) -> bool:
|
|
65
|
+
return (
|
|
66
|
+
binary_op_symbol(
|
|
67
|
+
self.function, dtype=dtype, validate_attrs=False
|
|
68
|
+
)
|
|
69
|
+
is not None
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class MultiInputBinaryOp(VariadicOp):
|
|
74
|
+
pass
|
|
75
|
+
|
|
55
76
|
|
|
56
77
|
@dataclass(frozen=True)
|
|
57
78
|
class WhereOp(ElementwiseOpBase):
|
emx_onnx_cgen/ir/ops/misc.py
CHANGED
|
@@ -2,13 +2,29 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
5
7
|
from shared.scalar_types import ScalarType
|
|
6
8
|
|
|
7
|
-
from ...errors import ShapeInferenceError
|
|
8
|
-
from ..op_base import
|
|
9
|
+
from ...errors import ShapeInferenceError, UnsupportedOpError
|
|
10
|
+
from ..op_base import (
|
|
11
|
+
BroadcastingOpBase,
|
|
12
|
+
GatherLikeOpBase,
|
|
13
|
+
RenderableOpBase,
|
|
14
|
+
ShapeLikeOpBase,
|
|
15
|
+
)
|
|
9
16
|
from ..op_context import OpContext
|
|
10
17
|
|
|
11
18
|
|
|
19
|
+
def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
|
|
20
|
+
strides: list[int] = []
|
|
21
|
+
stride = 1
|
|
22
|
+
for dim in reversed(shape):
|
|
23
|
+
strides.append(stride)
|
|
24
|
+
stride *= dim
|
|
25
|
+
return tuple(reversed(strides))
|
|
26
|
+
|
|
27
|
+
|
|
12
28
|
@dataclass(frozen=True)
|
|
13
29
|
class CastOp(RenderableOpBase):
|
|
14
30
|
input0: str
|
|
@@ -59,16 +75,26 @@ class GatherElementsOp(RenderableOpBase):
|
|
|
59
75
|
indices_dtype: ScalarType
|
|
60
76
|
|
|
61
77
|
@dataclass(frozen=True)
|
|
62
|
-
class GatherOp(
|
|
78
|
+
class GatherOp(GatherLikeOpBase):
|
|
63
79
|
data: str
|
|
64
80
|
indices: str
|
|
65
81
|
output: str
|
|
66
82
|
axis: int
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
83
|
+
|
|
84
|
+
def _gather_data(self) -> str:
|
|
85
|
+
return self.data
|
|
86
|
+
|
|
87
|
+
def _gather_indices(self) -> str:
|
|
88
|
+
return self.indices
|
|
89
|
+
|
|
90
|
+
def _gather_output(self) -> str:
|
|
91
|
+
return self.output
|
|
92
|
+
|
|
93
|
+
def _gather_axis(self) -> int:
|
|
94
|
+
return self.axis
|
|
95
|
+
|
|
96
|
+
def _gather_mode(self) -> str:
|
|
97
|
+
return "gather"
|
|
72
98
|
|
|
73
99
|
@dataclass(frozen=True)
|
|
74
100
|
class GatherNDOp(RenderableOpBase):
|
|
@@ -360,15 +386,72 @@ class NonMaxSuppressionOp(RenderableOpBase):
|
|
|
360
386
|
score_threshold_shape: tuple[int, ...] | None
|
|
361
387
|
|
|
362
388
|
@dataclass(frozen=True)
|
|
363
|
-
class ExpandOp(
|
|
389
|
+
class ExpandOp(ShapeLikeOpBase):
|
|
364
390
|
input0: str
|
|
391
|
+
input_shape: str
|
|
365
392
|
output: str
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
393
|
+
|
|
394
|
+
def _shape_data(self) -> str:
|
|
395
|
+
return self.input0
|
|
396
|
+
|
|
397
|
+
def _shape_output(self) -> str:
|
|
398
|
+
return self.output
|
|
399
|
+
|
|
400
|
+
def _shape_mode(self) -> str:
|
|
401
|
+
return "expand"
|
|
402
|
+
|
|
403
|
+
def _shape_spec(self, ctx: OpContext) -> tuple[int, ...]:
|
|
404
|
+
initializer = ctx.initializer(self.input_shape)
|
|
405
|
+
if initializer is not None:
|
|
406
|
+
if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
|
|
407
|
+
raise UnsupportedOpError(
|
|
408
|
+
f"{self.kind} shape input must be int64 or int32"
|
|
409
|
+
)
|
|
410
|
+
if len(initializer.type.shape) != 1:
|
|
411
|
+
raise UnsupportedOpError(
|
|
412
|
+
f"{self.kind} shape input must be a 1D tensor"
|
|
413
|
+
)
|
|
414
|
+
values = np.array(initializer.data, dtype=np.int64).reshape(-1)
|
|
415
|
+
if values.size == 0:
|
|
416
|
+
raise ShapeInferenceError(
|
|
417
|
+
f"{self.kind} shape input cannot be empty"
|
|
418
|
+
)
|
|
419
|
+
return tuple(int(value) for value in values)
|
|
420
|
+
dtype = ctx.dtype(self.input_shape)
|
|
421
|
+
if dtype not in {ScalarType.I64, ScalarType.I32}:
|
|
422
|
+
raise UnsupportedOpError(
|
|
423
|
+
f"{self.kind} shape input must be int64 or int32"
|
|
424
|
+
)
|
|
425
|
+
shape = ctx.shape(self.input_shape)
|
|
426
|
+
if len(shape) != 1:
|
|
427
|
+
raise UnsupportedOpError(
|
|
428
|
+
f"{self.kind} shape input must be a 1D tensor"
|
|
429
|
+
)
|
|
430
|
+
if shape[0] <= 0:
|
|
431
|
+
raise ShapeInferenceError(
|
|
432
|
+
f"{self.kind} shape input cannot be empty"
|
|
433
|
+
)
|
|
434
|
+
output_shape = ctx.shape(self.output)
|
|
435
|
+
if not output_shape:
|
|
436
|
+
raise ShapeInferenceError(
|
|
437
|
+
f"{self.kind} output shape must be specified"
|
|
438
|
+
)
|
|
439
|
+
return output_shape
|
|
440
|
+
|
|
441
|
+
def _shape_derived(
|
|
442
|
+
self,
|
|
443
|
+
ctx: OpContext,
|
|
444
|
+
*,
|
|
445
|
+
data_shape: tuple[int, ...],
|
|
446
|
+
target_shape: tuple[int, ...],
|
|
447
|
+
output_shape: tuple[int, ...],
|
|
448
|
+
) -> None:
|
|
449
|
+
input_shape_padded = (
|
|
450
|
+
(1,) * (len(output_shape) - len(data_shape)) + data_shape
|
|
451
|
+
)
|
|
452
|
+
input_strides = _compute_strides(input_shape_padded)
|
|
453
|
+
ctx.set_derived(self, "input_shape_padded", input_shape_padded)
|
|
454
|
+
ctx.set_derived(self, "input_strides", input_strides)
|
|
372
455
|
|
|
373
456
|
@dataclass(frozen=True)
|
|
374
457
|
class CumSumOp(RenderableOpBase):
|
emx_onnx_cgen/lowering/expand.py
CHANGED
|
@@ -1,151 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
from shared.scalar_types import ScalarType
|
|
6
|
-
|
|
3
|
+
from ..errors import UnsupportedOpError
|
|
4
|
+
from ..ir.model import Graph, Node
|
|
7
5
|
from ..ir.ops import ExpandOp
|
|
8
|
-
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
9
|
-
from ..ir.model import Graph, Initializer, Node
|
|
10
|
-
from ..lowering.common import value_dtype, value_shape
|
|
11
6
|
from .registry import register_lowering
|
|
12
7
|
|
|
13
8
|
|
|
14
|
-
def _find_initializer(graph: Graph, name: str) -> Initializer | None:
|
|
15
|
-
for initializer in graph.initializers:
|
|
16
|
-
if initializer.name == name:
|
|
17
|
-
return initializer
|
|
18
|
-
return None
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def _read_shape_values(graph: Graph, name: str, node: Node) -> list[int] | None:
|
|
22
|
-
initializer = _find_initializer(graph, name)
|
|
23
|
-
if initializer is None:
|
|
24
|
-
return None
|
|
25
|
-
if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
|
|
26
|
-
raise UnsupportedOpError(
|
|
27
|
-
f"{node.op_type} shape input must be int64 or int32"
|
|
28
|
-
)
|
|
29
|
-
if len(initializer.type.shape) != 1:
|
|
30
|
-
raise UnsupportedOpError(
|
|
31
|
-
f"{node.op_type} shape input must be a 1D tensor"
|
|
32
|
-
)
|
|
33
|
-
values = np.array(initializer.data, dtype=np.int64).reshape(-1)
|
|
34
|
-
if values.size == 0:
|
|
35
|
-
raise ShapeInferenceError(
|
|
36
|
-
f"{node.op_type} shape input cannot be empty"
|
|
37
|
-
)
|
|
38
|
-
return [int(value) for value in values]
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def _validate_shape_input(graph: Graph, name: str, node: Node) -> None:
|
|
42
|
-
dtype = value_dtype(graph, name, node)
|
|
43
|
-
if dtype not in {ScalarType.I64, ScalarType.I32}:
|
|
44
|
-
raise UnsupportedOpError(
|
|
45
|
-
f"{node.op_type} shape input must be int64 or int32"
|
|
46
|
-
)
|
|
47
|
-
shape = value_shape(graph, name, node)
|
|
48
|
-
if len(shape) != 1:
|
|
49
|
-
raise UnsupportedOpError(
|
|
50
|
-
f"{node.op_type} shape input must be a 1D tensor"
|
|
51
|
-
)
|
|
52
|
-
if shape[0] <= 0:
|
|
53
|
-
raise ShapeInferenceError(
|
|
54
|
-
f"{node.op_type} shape input cannot be empty"
|
|
55
|
-
)
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def _validate_static_dims(shape: tuple[int, ...], node: Node) -> None:
|
|
59
|
-
if any(dim < 0 for dim in shape):
|
|
60
|
-
raise ShapeInferenceError(
|
|
61
|
-
f"{node.op_type} does not support dynamic dims"
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
def _broadcast_shape(
|
|
66
|
-
input_shape: tuple[int, ...], shape_values: list[int], node: Node
|
|
67
|
-
) -> tuple[int, ...]:
|
|
68
|
-
_validate_static_dims(input_shape, node)
|
|
69
|
-
for dim in shape_values:
|
|
70
|
-
if dim < 0:
|
|
71
|
-
raise ShapeInferenceError(
|
|
72
|
-
f"{node.op_type} does not support dynamic dims"
|
|
73
|
-
)
|
|
74
|
-
output_rank = max(len(input_shape), len(shape_values))
|
|
75
|
-
input_padded = (1,) * (output_rank - len(input_shape)) + input_shape
|
|
76
|
-
shape_padded = (1,) * (output_rank - len(shape_values)) + tuple(shape_values)
|
|
77
|
-
result: list[int] = []
|
|
78
|
-
for input_dim, shape_dim in zip(input_padded, shape_padded):
|
|
79
|
-
if input_dim == 1:
|
|
80
|
-
result.append(shape_dim)
|
|
81
|
-
elif shape_dim == 1:
|
|
82
|
-
result.append(input_dim)
|
|
83
|
-
elif input_dim == shape_dim:
|
|
84
|
-
result.append(input_dim)
|
|
85
|
-
else:
|
|
86
|
-
raise ShapeInferenceError(
|
|
87
|
-
f"{node.op_type} input shape {input_shape} is not "
|
|
88
|
-
f"broadcastable to {shape_values}"
|
|
89
|
-
)
|
|
90
|
-
return tuple(result)
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
|
|
94
|
-
strides: list[int] = []
|
|
95
|
-
stride = 1
|
|
96
|
-
for dim in reversed(shape):
|
|
97
|
-
strides.append(stride)
|
|
98
|
-
stride *= dim
|
|
99
|
-
return tuple(reversed(strides))
|
|
100
|
-
|
|
101
|
-
|
|
102
9
|
@register_lowering("Expand")
|
|
103
10
|
def lower_expand(graph: Graph, node: Node) -> ExpandOp:
|
|
104
11
|
if len(node.inputs) != 2 or len(node.outputs) != 1:
|
|
105
12
|
raise UnsupportedOpError("Expand must have 2 inputs and 1 output")
|
|
106
|
-
input_shape = value_shape(graph, node.inputs[0], node)
|
|
107
|
-
output_shape = value_shape(graph, node.outputs[0], node)
|
|
108
|
-
input_dtype = value_dtype(graph, node.inputs[0], node)
|
|
109
|
-
output_dtype = value_dtype(graph, node.outputs[0], node)
|
|
110
|
-
if input_dtype != output_dtype:
|
|
111
|
-
raise UnsupportedOpError(
|
|
112
|
-
f"{node.op_type} expects matching input/output dtypes, "
|
|
113
|
-
f"got {input_dtype} and {output_dtype}"
|
|
114
|
-
)
|
|
115
|
-
shape_values = _read_shape_values(graph, node.inputs[1], node)
|
|
116
|
-
if shape_values is not None:
|
|
117
|
-
expected_output_shape = _broadcast_shape(input_shape, shape_values, node)
|
|
118
|
-
_validate_static_dims(expected_output_shape, node)
|
|
119
|
-
if output_shape and output_shape != expected_output_shape:
|
|
120
|
-
raise ShapeInferenceError(
|
|
121
|
-
f"{node.op_type} output shape must be {expected_output_shape}, "
|
|
122
|
-
f"got {output_shape}"
|
|
123
|
-
)
|
|
124
|
-
else:
|
|
125
|
-
_validate_shape_input(graph, node.inputs[1], node)
|
|
126
|
-
if not output_shape:
|
|
127
|
-
raise ShapeInferenceError(
|
|
128
|
-
f"{node.op_type} output shape must be specified"
|
|
129
|
-
)
|
|
130
|
-
expected_output_shape = _broadcast_shape(
|
|
131
|
-
input_shape, list(output_shape), node
|
|
132
|
-
)
|
|
133
|
-
if expected_output_shape != output_shape:
|
|
134
|
-
raise ShapeInferenceError(
|
|
135
|
-
f"{node.op_type} output shape must be {expected_output_shape}, "
|
|
136
|
-
f"got {output_shape}"
|
|
137
|
-
)
|
|
138
|
-
input_shape_padded = (
|
|
139
|
-
(1,) * (len(expected_output_shape) - len(input_shape)) + input_shape
|
|
140
|
-
)
|
|
141
|
-
input_strides = _compute_strides(input_shape_padded)
|
|
142
13
|
return ExpandOp(
|
|
143
14
|
input0=node.inputs[0],
|
|
15
|
+
input_shape=node.inputs[1],
|
|
144
16
|
output=node.outputs[0],
|
|
145
|
-
input_shape=input_shape,
|
|
146
|
-
output_shape=expected_output_shape,
|
|
147
|
-
input_shape_padded=input_shape_padded,
|
|
148
|
-
input_strides=input_strides,
|
|
149
|
-
dtype=input_dtype,
|
|
150
|
-
input_dtype=input_dtype,
|
|
151
17
|
)
|
emx_onnx_cgen/lowering/gather.py
CHANGED
|
@@ -1,13 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from shared.scalar_types import ScalarType
|
|
4
|
-
|
|
5
3
|
from ..ir.ops import GatherOp
|
|
6
|
-
from ..errors import
|
|
4
|
+
from ..errors import UnsupportedOpError
|
|
7
5
|
from ..ir.model import Graph, Node
|
|
8
|
-
from ..validation import normalize_axis
|
|
9
|
-
from .common import value_dtype as _value_dtype
|
|
10
|
-
from .common import value_shape as _value_shape
|
|
11
6
|
from .registry import register_lowering
|
|
12
7
|
|
|
13
8
|
|
|
@@ -16,33 +11,9 @@ def lower_gather(graph: Graph, node: Node) -> GatherOp:
|
|
|
16
11
|
if len(node.inputs) != 2 or len(node.outputs) != 1:
|
|
17
12
|
raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
|
|
18
13
|
data_name, indices_name = node.inputs
|
|
19
|
-
data_shape = _value_shape(graph, data_name, node)
|
|
20
|
-
indices_shape = _value_shape(graph, indices_name, node)
|
|
21
|
-
output_shape = _value_shape(graph, node.outputs[0], node)
|
|
22
|
-
axis = normalize_axis(int(node.attrs.get("axis", 0)), data_shape, node)
|
|
23
|
-
expected_output_shape = (
|
|
24
|
-
data_shape[:axis] + indices_shape + data_shape[axis + 1 :]
|
|
25
|
-
)
|
|
26
|
-
if output_shape != expected_output_shape:
|
|
27
|
-
raise ShapeInferenceError(
|
|
28
|
-
"Gather output shape must be "
|
|
29
|
-
f"{expected_output_shape}, got {output_shape}"
|
|
30
|
-
)
|
|
31
|
-
op_dtype = _value_dtype(graph, data_name, node)
|
|
32
|
-
indices_dtype = _value_dtype(graph, indices_name, node)
|
|
33
|
-
if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
|
|
34
|
-
raise UnsupportedOpError(
|
|
35
|
-
"Gather indices must be int32 or int64, "
|
|
36
|
-
f"got {indices_dtype.onnx_name}"
|
|
37
|
-
)
|
|
38
14
|
return GatherOp(
|
|
39
15
|
data=data_name,
|
|
40
16
|
indices=indices_name,
|
|
41
17
|
output=node.outputs[0],
|
|
42
|
-
axis=axis,
|
|
43
|
-
data_shape=data_shape,
|
|
44
|
-
indices_shape=indices_shape,
|
|
45
|
-
output_shape=output_shape,
|
|
46
|
-
dtype=op_dtype,
|
|
47
|
-
indices_dtype=indices_dtype,
|
|
18
|
+
axis=int(node.attrs.get("axis", 0)),
|
|
48
19
|
)
|
|
@@ -1,14 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from shared.scalar_functions import ScalarFunction
|
|
4
|
-
from shared.scalar_types import ScalarType
|
|
5
4
|
|
|
6
|
-
from ..ir.ops import MultiInputBinaryOp
|
|
7
5
|
from ..errors import UnsupportedOpError
|
|
8
6
|
from ..ir.model import Graph, Node
|
|
9
|
-
from ..
|
|
7
|
+
from ..ir.ops import MultiInputBinaryOp
|
|
10
8
|
from ..lowering.registry import register_lowering
|
|
11
|
-
from ..ops import
|
|
9
|
+
from ..ops import OperatorKind
|
|
12
10
|
|
|
13
11
|
VARIADIC_OP_FUNCTIONS: dict[str, ScalarFunction] = {
|
|
14
12
|
"Sum": ScalarFunction.ADD,
|
|
@@ -32,62 +30,31 @@ BINARY_ONLY_OPS = {
|
|
|
32
30
|
"BitwiseXor",
|
|
33
31
|
}
|
|
34
32
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
raise UnsupportedOpError(
|
|
48
|
-
f"{node.op_type} must have at least 2 inputs"
|
|
49
|
-
)
|
|
50
|
-
for name in node.inputs:
|
|
51
|
-
if not name:
|
|
52
|
-
raise UnsupportedOpError(f"{node.op_type} input must be provided")
|
|
53
|
-
op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
|
|
54
|
-
output_dtype = value_dtype(graph, node.outputs[0], node)
|
|
55
|
-
if op_dtype != output_dtype:
|
|
56
|
-
raise UnsupportedOpError(
|
|
57
|
-
f"{node.op_type} expects matching input/output dtypes, "
|
|
58
|
-
f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
|
|
59
|
-
)
|
|
60
|
-
output_shape = value_shape(graph, node.outputs[0], node)
|
|
61
|
-
for name in node.inputs:
|
|
62
|
-
input_shape = value_shape(graph, name, node)
|
|
63
|
-
if input_shape != output_shape:
|
|
64
|
-
raise UnsupportedOpError(
|
|
65
|
-
f"{node.op_type} expects identical input/output shapes"
|
|
66
|
-
)
|
|
67
|
-
op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
|
|
68
|
-
if op_spec is None:
|
|
69
|
-
raise UnsupportedOpError(
|
|
70
|
-
f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
|
|
71
|
-
)
|
|
72
|
-
return op_dtype, output_shape
|
|
33
|
+
VARIADIC_OP_OPERATOR_KINDS: dict[str, OperatorKind] = {
|
|
34
|
+
"Sum": OperatorKind.INFIX,
|
|
35
|
+
"Mean": OperatorKind.EXPR,
|
|
36
|
+
"Max": OperatorKind.FUNC,
|
|
37
|
+
"Min": OperatorKind.FUNC,
|
|
38
|
+
"And": OperatorKind.INFIX,
|
|
39
|
+
"Or": OperatorKind.INFIX,
|
|
40
|
+
"Xor": OperatorKind.INFIX,
|
|
41
|
+
"BitwiseAnd": OperatorKind.INFIX,
|
|
42
|
+
"BitwiseOr": OperatorKind.INFIX,
|
|
43
|
+
"BitwiseXor": OperatorKind.INFIX,
|
|
44
|
+
}
|
|
73
45
|
|
|
74
46
|
|
|
75
47
|
def _lower_variadic(graph: Graph, node: Node) -> MultiInputBinaryOp:
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
|
|
79
|
-
if op_spec is None:
|
|
80
|
-
raise UnsupportedOpError(
|
|
81
|
-
f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
|
|
82
|
-
)
|
|
48
|
+
if len(node.outputs) != 1:
|
|
49
|
+
raise UnsupportedOpError(f"{node.op_type} must have 1 output")
|
|
83
50
|
return MultiInputBinaryOp(
|
|
51
|
+
op_type=node.op_type,
|
|
84
52
|
inputs=tuple(node.inputs),
|
|
85
53
|
output=node.outputs[0],
|
|
86
|
-
function=
|
|
87
|
-
operator_kind=
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
input_dtype=op_dtype,
|
|
54
|
+
function=VARIADIC_OP_FUNCTIONS[node.op_type],
|
|
55
|
+
operator_kind=VARIADIC_OP_OPERATOR_KINDS[node.op_type],
|
|
56
|
+
min_inputs=2,
|
|
57
|
+
max_inputs=2 if node.op_type in BINARY_ONLY_OPS else None,
|
|
91
58
|
)
|
|
92
59
|
|
|
93
60
|
|
|
@@ -7,7 +7,9 @@ import numpy as np
|
|
|
7
7
|
|
|
8
8
|
from shared.scalar_types import ScalarType
|
|
9
9
|
from ..errors import ShapeInferenceError, UnsupportedOpError
|
|
10
|
+
from ..ir.context import GraphContext
|
|
10
11
|
from ..ir.model import Graph, Node
|
|
12
|
+
from ..ir.op_context import OpContext
|
|
11
13
|
from ..lowering.attention import resolve_attention_spec
|
|
12
14
|
from ..lowering.average_pool import lower_average_pool, lower_global_average_pool
|
|
13
15
|
from ..lowering.adagrad import lower_adagrad
|
|
@@ -2021,8 +2023,13 @@ def _eval_nonzero(evaluator: Evaluator, node: Node) -> None:
|
|
|
2021
2023
|
def _eval_expand(evaluator: Evaluator, node: Node) -> None:
|
|
2022
2024
|
op = lower_expand(evaluator.graph, node)
|
|
2023
2025
|
value = evaluator.values[op.input0]
|
|
2026
|
+
op_ctx = OpContext(GraphContext(evaluator.graph))
|
|
2027
|
+
op.validate(op_ctx)
|
|
2028
|
+
op.infer_types(op_ctx)
|
|
2029
|
+
op.infer_shapes(op_ctx)
|
|
2030
|
+
output_shape = op_ctx.shape(op.output)
|
|
2024
2031
|
evaluator.values[op.output] = np.broadcast_to(
|
|
2025
|
-
value,
|
|
2032
|
+
value, output_shape
|
|
2026
2033
|
).copy()
|
|
2027
2034
|
|
|
2028
2035
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
emx_onnx_cgen/__init__.py,sha256=jUSbu1kJ0krzVTYEcph3jCprBhD7tWNtiSdL6r29KrM,221
|
|
2
2
|
emx_onnx_cgen/__main__.py,sha256=iC1lLVtR6-TmpL6OxXcy3oIntExUtajn9-q627R1XyI,140
|
|
3
|
-
emx_onnx_cgen/_build_info.py,sha256=
|
|
4
|
-
emx_onnx_cgen/_version.py,sha256=
|
|
3
|
+
emx_onnx_cgen/_build_info.py,sha256=H1wVqVfVhPbohtf1JhkEW7F4rM6RzR1WlU9jrNMVYlE,112
|
|
4
|
+
emx_onnx_cgen/_version.py,sha256=UAb2Toi6SAdScDfq1uKRRv5QpMUuRtJqqwNxTMGe5Q4,704
|
|
5
5
|
emx_onnx_cgen/cli.py,sha256=7Y9JW-t1PLg25zOizuqyMqwsXbbG9ok99DsYeFSiOFQ,21685
|
|
6
|
-
emx_onnx_cgen/compiler.py,sha256=
|
|
6
|
+
emx_onnx_cgen/compiler.py,sha256=v1-EzVUxZv5Kfn81kCDVuferRxvXFXEeRaNbQ4w6xss,18437
|
|
7
7
|
emx_onnx_cgen/dtypes.py,sha256=jRx3BBvk0qFW14bngoL1B7L_IRasyNJ4jqhpM5YhcOM,1335
|
|
8
8
|
emx_onnx_cgen/errors.py,sha256=HpOv95mTgr9ZX2gYe1RtwVMbPskh7zkqjU_FgAD-uIM,363
|
|
9
9
|
emx_onnx_cgen/onnx_import.py,sha256=IF7KZGfEP9H4H1fHYjobGbB_381fqD_67KtqZYs9AZ4,9168
|
|
@@ -13,16 +13,16 @@ emx_onnx_cgen/testbench.py,sha256=-NbqD1aC7OXvFMLiLzd2IPObenQdHFH85cNxNSB1GeY,64
|
|
|
13
13
|
emx_onnx_cgen/validation.py,sha256=KFdUdGjQbzTj1szCJcjxnTi8f5l6ywNgCB9abbBpTbM,2360
|
|
14
14
|
emx_onnx_cgen/verification.py,sha256=IrhIMm29R2vEkW1Q8gtoQtscMGxfJRavNRSMJHBAJ5g,1041
|
|
15
15
|
emx_onnx_cgen/codegen/__init__.py,sha256=H_kBdc_w_W-3qdUZJHwKBDns1AeP_Un3-46LW20yLV0,406
|
|
16
|
-
emx_onnx_cgen/codegen/c_emitter.py,sha256=
|
|
16
|
+
emx_onnx_cgen/codegen/c_emitter.py,sha256=JdDJGv1HptINaLLZxlxNfo1R7VM9v680EyiDMpeReds,453199
|
|
17
17
|
emx_onnx_cgen/codegen/emitter.py,sha256=udcsqJNr46TFHiyVv5I4wdVH8ll6Bi4VqcR1VvofbnY,92
|
|
18
18
|
emx_onnx_cgen/ir/__init__.py,sha256=fD2D8qxlGoCFJb0m9v6u3XTgzSxDOhB4cfLBiCLovzg,102
|
|
19
19
|
emx_onnx_cgen/ir/context.py,sha256=cM3V6G3zs6VCsABP6TnZ8vvQ7VGwOF1iKtb1hq0WO3g,3356
|
|
20
20
|
emx_onnx_cgen/ir/model.py,sha256=SZ3K8t4dKUqWuXWe5ozApofXx4bdcf4p0WYCdeU-mFA,1265
|
|
21
|
-
emx_onnx_cgen/ir/op_base.py,sha256=
|
|
21
|
+
emx_onnx_cgen/ir/op_base.py,sha256=_iPeVkLPR3jsRASrvXEWk-k3BJboPHtZY6jnB0HdLvk,17611
|
|
22
22
|
emx_onnx_cgen/ir/op_context.py,sha256=9CZCUNJLsV4cJsYmJqWbaDrwQd4sr-9Ot1PmPSqGAto,2103
|
|
23
|
-
emx_onnx_cgen/ir/ops/__init__.py,sha256=
|
|
24
|
-
emx_onnx_cgen/ir/ops/elementwise.py,sha256=
|
|
25
|
-
emx_onnx_cgen/ir/ops/misc.py,sha256=
|
|
23
|
+
emx_onnx_cgen/ir/ops/__init__.py,sha256=Zk7QzNiB4CHcixZlA1thA78mcudXdTvCfKlxUTRrX24,2503
|
|
24
|
+
emx_onnx_cgen/ir/ops/elementwise.py,sha256=TXbyayj3UnfLe4tUYBEwBDr7ZFyFi1i8HdVdCjtvLCc,4241
|
|
25
|
+
emx_onnx_cgen/ir/ops/misc.py,sha256=vN4OpW5gsryQ0aiVNBFiYlZMxwg8Z9wUOBM7w3f4ZFE,13522
|
|
26
26
|
emx_onnx_cgen/ir/ops/nn.py,sha256=-4ZqDkcu7zgci3YVfMzCDzokqpZHgOYZaq_C1GclBZQ,14365
|
|
27
27
|
emx_onnx_cgen/ir/ops/reduce.py,sha256=-aA4bwOMppd9pnWQwhl6hOxryh0G2xRaHqeNwQ97AdY,2756
|
|
28
28
|
emx_onnx_cgen/lowering/__init__.py,sha256=AxnUfmpf5Teos1ms3zE6r0EBxxPYznGSOICDEFWH_pk,1535
|
|
@@ -42,10 +42,10 @@ emx_onnx_cgen/lowering/depth_space.py,sha256=i7INioNkofBxFlZW9y0W_qA6mp67_FAXouh
|
|
|
42
42
|
emx_onnx_cgen/lowering/dropout.py,sha256=MZ4YrB-jvUFXpIKE5kOLyrEF5uy5dh0yjJH6Rj8KlMs,1764
|
|
43
43
|
emx_onnx_cgen/lowering/einsum.py,sha256=MWAgWVOzP38RSOxJABwvYU6ykD9odmhrmddXinmFs7s,6117
|
|
44
44
|
emx_onnx_cgen/lowering/elementwise.py,sha256=q9X3qTll7gLp39NTTdzuLs9RBsONssw50l1hWo8wby0,12229
|
|
45
|
-
emx_onnx_cgen/lowering/expand.py,sha256=
|
|
45
|
+
emx_onnx_cgen/lowering/expand.py,sha256=y0h1x2xh6Oqtblm6TbELB6_I4fsquU3YuZoB4mZJeTo,525
|
|
46
46
|
emx_onnx_cgen/lowering/eye_like.py,sha256=QBiHWYZbgK4uiUYWuS7WHCMBGMSG0paNZM84OYmGb7c,1723
|
|
47
47
|
emx_onnx_cgen/lowering/flatten.py,sha256=6h-TQNy9iq5hfXR9h2clUrc2eHmZP9gAb9KbCSJdV20,2131
|
|
48
|
-
emx_onnx_cgen/lowering/gather.py,sha256=
|
|
48
|
+
emx_onnx_cgen/lowering/gather.py,sha256=3sxrld5GIS4OO3hRVp8QdbMtyLQUHbdCXL8vmZvh67c,599
|
|
49
49
|
emx_onnx_cgen/lowering/gather_elements.py,sha256=cCp2UFOjktgEfS9s9npMS_BXklBkpMpD7UhIIMhQ-_Y,2318
|
|
50
50
|
emx_onnx_cgen/lowering/gather_nd.py,sha256=rmr_ijeSeCrZ_R_QPwdoHPQUCe8nE0YRSv2NjUiiFjY,3090
|
|
51
51
|
emx_onnx_cgen/lowering/gemm.py,sha256=qBaZ-6FZAAMEaZ4uifo58tJI8SoBsJvkZTCg7jvq288,4579
|
|
@@ -92,16 +92,16 @@ emx_onnx_cgen/lowering/topk.py,sha256=Dqx7qMr4HbXhVGN-wJf_D4dPTvYMVT6S82A2M3f9Dw
|
|
|
92
92
|
emx_onnx_cgen/lowering/transpose.py,sha256=oNFRjkH63KqnO2Q4oJengEAUEYC1M3PW12AauWwebzI,1751
|
|
93
93
|
emx_onnx_cgen/lowering/trilu.py,sha256=OjJjyo2ZRcfo9UGH8Zfq4o0PR6YDeoHSj8DzMu0w318,3266
|
|
94
94
|
emx_onnx_cgen/lowering/unsqueeze.py,sha256=9y-OM-oY6ln1-R6duRRemeRrwBIpX2TZs_nRtlYQMYE,5985
|
|
95
|
-
emx_onnx_cgen/lowering/variadic.py,sha256=
|
|
95
|
+
emx_onnx_cgen/lowering/variadic.py,sha256=OrC3rwM3-SNewYRs7YA7DwwS8XW1ucxUobTEjZdEs4s,1823
|
|
96
96
|
emx_onnx_cgen/lowering/where.py,sha256=K2RUDvLg0uTvi6Z_uTOXM5jgc3PXRj0cTZ4u58GEGko,2644
|
|
97
97
|
emx_onnx_cgen/runtime/__init__.py,sha256=88xGpAs1IEBlzlWL_e9tnKUlaSRdc7pQUeVCu5LC4DY,50
|
|
98
|
-
emx_onnx_cgen/runtime/evaluator.py,sha256=
|
|
98
|
+
emx_onnx_cgen/runtime/evaluator.py,sha256=8d9GOzhYNs2XX5q4vjaTM-wxkf8_rE4QEf5e1USWGd8,114981
|
|
99
99
|
shared/__init__.py,sha256=bmP79AVZdY_1aNULJap9pm76Q41Rabrza6X-0A8lDzw,45
|
|
100
100
|
shared/scalar_functions.py,sha256=CErro1Du2Ri3uqX6Dgd18DzNbxduckAvsmLJ6oHGx9A,91123
|
|
101
101
|
shared/scalar_types.py,sha256=kEpsl5T-NVFxCcTzXqPJbtpvDiCgKHfz91dphLLZxZA,4912
|
|
102
102
|
shared/ulp.py,sha256=DpeovCFijmP8_M7zyTZWsNyfOtJ1AjNSdxf5jGsdfJo,1856
|
|
103
|
-
emx_onnx_cgen-0.3.
|
|
104
|
-
emx_onnx_cgen-0.3.
|
|
105
|
-
emx_onnx_cgen-0.3.
|
|
106
|
-
emx_onnx_cgen-0.3.
|
|
107
|
-
emx_onnx_cgen-0.3.
|
|
103
|
+
emx_onnx_cgen-0.3.5.dist-info/METADATA,sha256=XwhvHTOcBPst7LPvgjPnR9hnVV8Jj0RtHtMITPMpAsA,6266
|
|
104
|
+
emx_onnx_cgen-0.3.5.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
105
|
+
emx_onnx_cgen-0.3.5.dist-info/entry_points.txt,sha256=b7Rvmz_Bi9kWyn7QayQC_FEXiRpt4cS1RnluKh49yoo,57
|
|
106
|
+
emx_onnx_cgen-0.3.5.dist-info/top_level.txt,sha256=g39fo-blEbgiVcC_GRqAnBzN234w3LXbcVdLUoItSLk,21
|
|
107
|
+
emx_onnx_cgen-0.3.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|