emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between the publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/shape.py
CHANGED
@@ -2,32 +2,13 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ShapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
 from .registry import register_lowering
 
 
-def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
-    try:
-        return graph.find_value(name).type.shape
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing shape for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
-
-
-def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
-    try:
-        return graph.find_value(name).type.dtype
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing dtype for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
-
-
 def _normalize_slice_bounds(
     rank: int, *, start: int | None, end: int | None
 ) -> tuple[int, int]:
@@ -46,14 +27,14 @@ def _normalize_slice_bounds(
 def lower_shape(graph: Graph, node: Node) -> ShapeOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
-    input_shape =
-    output_shape =
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
     if len(output_shape) != 1:
         raise ShapeInferenceError("Shape output must be 1D")
     if output_shape[0] < 0:
         raise ShapeInferenceError("Shape output length must be non-negative")
-    input_dtype =
-    output_dtype =
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
     if output_dtype != ScalarType.I64:
         raise UnsupportedOpError("Shape output dtype must be int64")
     start = node.attrs.get("start")
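For context, the start attribute read at the end of the hunk above (together with end, clamped by _normalize_slice_bounds) follows ONNX Shape semantics from opset 15 onward: the reported shape is sliced to shape[start:end]. A minimal illustration; the values below are ours, not from the package:

import numpy as np

x = np.zeros((2, 3, 4, 5))
full = np.array(x.shape, dtype=np.int64)           # Shape                 -> [2, 3, 4, 5]
window = np.array(x.shape[1:3], dtype=np.int64)    # Shape(start=1, end=3) -> [3, 4]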
emx_onnx_cgen/lowering/size.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import SizeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape
emx_onnx_cgen/lowering/slice.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import SliceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/softmax.py
CHANGED
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import SoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
 from .common import shape_product as _shape_product
 from .common import value_shape as _value_shape
 from .registry import register_lowering
@@ -23,8 +24,11 @@ def lower_softmax(graph: Graph, node: Node) -> SoftmaxOp:
     input_shape = _value_shape(graph, node.inputs[0], node)
     output_shape = _value_shape(graph, node.outputs[0], node)
     ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
     axis = _normalize_axis(
-        int(
+        int(axis_attr),
         input_shape,
         node,
     )
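The opset-dependent default added here matches the ONNX spec: Softmax defaulted to axis=1 before opset 13 and to axis=-1 from opset 13 onward. A minimal sketch of the selection logic, assuming opset_version is the model's default-domain opset (None when unknown):

def default_softmax_axis(opset_version: int | None) -> int:
    # Pre-opset-13 Softmax coerced the input to 2D and defaulted to axis=1;
    # opset >= 13 (and unknown opsets) default to the last axis.
    return 1 if opset_version is not None and opset_version < 13 else -1

assert default_softmax_axis(12) == 1
assert default_softmax_axis(13) == -1
assert default_softmax_axis(None) == -1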
emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import SoftmaxCrossEntropyLossOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product as _shape_product
emx_onnx_cgen/lowering/split.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import SplitOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import optional_name, value_dtype, value_shape
emx_onnx_cgen/lowering/squeeze.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .registry import register_lowering
@@ -95,11 +95,11 @@ def _validate_output_shape_for_unknown_axes(
     for dim in input_shape:
         if output_index < len(output_shape) and dim == output_shape[output_index]:
             output_index += 1
-
-
-
-
-
+        else:
+            if dim != 1:
+                raise ShapeInferenceError(
+                    "Squeeze output shape must remove only dimensions of size 1"
+                )
     if output_index != len(output_shape):
         raise ShapeInferenceError(
             "Squeeze output shape must preserve input order while removing size-1 axes"
emx_onnx_cgen/lowering/tensor_scatter.py
ADDED
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import TensorScatterOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import normalize_axis
+from .common import optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+_ALLOWED_MODES = {"linear", "circular"}
+
+
+@register_lowering("TensorScatter")
+def lower_tensor_scatter(graph: Graph, node: Node) -> TensorScatterOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "TensorScatter must have 2 or 3 inputs and 1 output"
+        )
+    past_cache_name = node.inputs[0]
+    update_name = node.inputs[1]
+    write_indices_name = optional_name(node.inputs, 2)
+    output_name = node.outputs[0]
+    past_cache_shape = value_shape(graph, past_cache_name, node)
+    update_shape = value_shape(graph, update_name, node)
+    output_shape = value_shape(graph, output_name, node)
+    if output_shape != past_cache_shape:
+        raise ShapeInferenceError(
+            "TensorScatter output shape must match past_cache shape, "
+            f"got {output_shape} vs {past_cache_shape}"
+        )
+    if len(update_shape) != len(past_cache_shape):
+        raise ShapeInferenceError(
+            "TensorScatter update shape rank must match past_cache rank, "
+            f"got {len(update_shape)} vs {len(past_cache_shape)}"
+        )
+    axis = normalize_axis(int(node.attrs.get("axis", -2)), past_cache_shape, node)
+    if axis == 0:
+        raise UnsupportedOpError(
+            "TensorScatter axis cannot be 0 (batch dimension)"
+        )
+    for dim_index, (past_dim, update_dim) in enumerate(
+        zip(past_cache_shape, update_shape)
+    ):
+        if dim_index == axis:
+            if update_dim > past_dim:
+                raise ShapeInferenceError(
+                    "TensorScatter update sequence length must be <= "
+                    "past_cache sequence length, "
+                    f"got {update_dim} vs {past_dim}"
+                )
+        elif update_dim != past_dim:
+            raise ShapeInferenceError(
+                "TensorScatter update shape must match past_cache shape "
+                f"outside axis {axis}, got {update_shape} vs {past_cache_shape}"
+            )
+    mode = node.attrs.get("mode", "linear")
+    if isinstance(mode, bytes):
+        mode = mode.decode("utf-8")
+    if mode not in _ALLOWED_MODES:
+        raise UnsupportedOpError(
+            "TensorScatter mode must be one of "
+            f"{sorted(_ALLOWED_MODES)}, got {mode}"
+        )
+    dtype = value_dtype(graph, past_cache_name, node)
+    update_dtype = value_dtype(graph, update_name, node)
+    output_dtype = value_dtype(graph, output_name, node)
+    if update_dtype != dtype or output_dtype != dtype:
+        raise UnsupportedOpError(
+            "TensorScatter expects past_cache, update, and output "
+            "to share the same dtype, "
+            f"got {dtype.onnx_name}, {update_dtype.onnx_name}, "
+            f"{output_dtype.onnx_name}"
+        )
+    write_indices_shape = None
+    write_indices_dtype = None
+    if write_indices_name is not None:
+        write_indices_shape = value_shape(graph, write_indices_name, node)
+        if len(write_indices_shape) != 1:
+            raise ShapeInferenceError(
+                "TensorScatter write_indices must be a 1D tensor"
+            )
+        if write_indices_shape[0] != past_cache_shape[0]:
+            raise ShapeInferenceError(
+                "TensorScatter write_indices length must match batch size, "
+                f"got {write_indices_shape[0]} vs {past_cache_shape[0]}"
+            )
+        write_indices_dtype = value_dtype(
+            graph, write_indices_name, node
+        )
+        if write_indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                "TensorScatter write_indices must be int32 or int64, "
+                f"got {write_indices_dtype.onnx_name}"
+            )
+    return TensorScatterOp(
+        past_cache=past_cache_name,
+        update=update_name,
+        write_indices=write_indices_name,
+        output=output_name,
+        past_cache_shape=past_cache_shape,
+        update_shape=update_shape,
+        output_shape=output_shape,
+        write_indices_shape=write_indices_shape,
+        axis=axis,
+        mode=mode,
+        dtype=dtype,
+        write_indices_dtype=write_indices_dtype,
+    )
emx_onnx_cgen/lowering/tile.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import TileOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/topk.py
ADDED
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import TopKOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import shape_product, value_dtype, value_shape
+from ..validation import normalize_axis
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_k(graph: Graph, name: str, node: Node) -> int | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} k input must be int64 or int32"
+        )
+    data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if data.size != 1:
+        raise ShapeInferenceError(
+            f"{node.op_type} k input must contain a single value"
+        )
+    k = int(data[0])
+    if k <= 0:
+        raise ShapeInferenceError(
+            f"{node.op_type} k must be a positive value, got {k}"
+        )
+    return k
+
+
+def _topk_dtype_supported(dtype: ScalarType) -> bool:
+    return not dtype.is_bool
+
+
+def lower_topk(graph: Graph, node: Node) -> TopKOp:
+    if node.op_type != "TopK":
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    if len(node.inputs) != 2 or len(node.outputs) != 2:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 2 inputs and 2 outputs"
+        )
+    input_name = node.inputs[0]
+    k_name = node.inputs[1]
+    output_values = node.outputs[0]
+    output_indices = node.outputs[1]
+    input_shape = value_shape(graph, input_name, node)
+    shape_product(input_shape)
+    axis = int(node.attrs.get("axis", -1))
+    axis = normalize_axis(axis, input_shape, node)
+    k = _read_k(graph, k_name, node)
+    axis_dim = input_shape[axis]
+    values_shape = value_shape(graph, output_values, node)
+    indices_shape = value_shape(graph, output_indices, node)
+    if values_shape != indices_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} values and indices output shapes must match, "
+            f"got {values_shape} and {indices_shape}"
+        )
+    if k is None:
+        k_shape = value_shape(graph, k_name, node)
+        if len(k_shape) != 1 or k_shape[0] != 1:
+            raise ShapeInferenceError(
+                f"{node.op_type} k input must be a 1-element tensor"
+            )
+        if axis >= len(values_shape):
+            raise ShapeInferenceError(
+                f"{node.op_type} axis {axis} exceeds output rank {len(values_shape)}"
+            )
+        k = values_shape[axis]
+    if k <= 0:
+        raise ShapeInferenceError(
+            f"{node.op_type} k must be a positive value, got {k}"
+        )
+    if k > axis_dim:
+        raise ShapeInferenceError(
+            f"{node.op_type} k {k} exceeds axis dimension {axis_dim}"
+        )
+    output_shape_expected = list(input_shape)
+    output_shape_expected[axis] = k
+    output_shape = tuple(output_shape_expected)
+    if values_shape != output_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} values output shape must be {output_shape}, got {values_shape}"
+        )
+    if indices_shape != output_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} indices output shape must be {output_shape}, got {indices_shape}"
+        )
+    input_dtype = value_dtype(graph, input_name, node)
+    if not _topk_dtype_supported(input_dtype):
+        raise UnsupportedOpError(
+            f"{node.op_type} does not support dtype {input_dtype.onnx_name}"
+        )
+    values_dtype = value_dtype(graph, output_values, node)
+    if values_dtype != input_dtype:
+        raise UnsupportedOpError(
+            f"{node.op_type} values output dtype must be {input_dtype.onnx_name}"
+        )
+    indices_dtype = value_dtype(graph, output_indices, node)
+    if indices_dtype != ScalarType.I64:
+        raise UnsupportedOpError(
+            f"{node.op_type} indices output dtype must be int64"
+        )
+    largest = bool(int(node.attrs.get("largest", 1)))
+    sorted_output = bool(int(node.attrs.get("sorted", 1)))
+    return TopKOp(
+        input0=input_name,
+        output_values=output_values,
+        output_indices=output_indices,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        axis=axis,
+        k=k,
+        largest=largest,
+        sorted=sorted_output,
+        input_dtype=input_dtype,
+        output_values_dtype=values_dtype,
+        output_indices_dtype=indices_dtype,
+    )
+
+
+register_lowering("TopK")(lower_topk)
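For readers unfamiliar with the op being validated above: ONNX TopK returns the k largest (largest=1) or smallest values along axis together with their int64 indices. A rough NumPy model of those semantics; the helper name is ours, and it ignores the sorted=0 case and ONNX tie-breaking rules:

import numpy as np

def topk_reference(x, k, axis=-1, largest=True):
    # Stable ascending sort, flipped for "largest"; always returns sorted output.
    order = np.argsort(x, axis=axis, kind="stable")
    if largest:
        order = np.flip(order, axis=axis)
    idx = np.take(order, range(k), axis=axis)
    values = np.take_along_axis(x, idx, axis=axis)
    return values, idx.astype(np.int64)

vals, idx = topk_reference(np.array([[1.0, 9.0, 3.0, 7.0]]), k=2)
# vals -> [[9., 7.]], idx -> [[1, 3]]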
emx_onnx_cgen/lowering/trilu.py
ADDED
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import TriluOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+def _read_k_initializer(initializer: Initializer, node: Node) -> int:
+    if initializer.type.dtype != ScalarType.I64:
+        raise UnsupportedOpError(
+            f"{node.op_type} k input must be int64"
+        )
+    data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if data.size != 1:
+        raise UnsupportedOpError(f"{node.op_type} k input must be scalar")
+    return int(data[0])
+
+
+@register_lowering("Trilu")
+def lower_trilu(graph: Graph, node: Node) -> TriluOp:
+    if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Trilu must have 1 or 2 inputs and 1 output")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Trilu input must be provided")
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise ShapeInferenceError("Trilu input and output shapes must match")
+    if len(output_shape) < 2:
+        raise UnsupportedOpError("Trilu expects input rank >= 2")
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Trilu expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    upper_attr = node.attrs.get("upper", 1)
+    upper = bool(int(upper_attr))
+    k_input = optional_name(node.inputs, 1)
+    k_value = 0
+    k_input_name = None
+    k_input_shape = None
+    k_input_dtype = None
+    if k_input:
+        k_initializer = _find_initializer(graph, k_input)
+        if k_initializer is not None:
+            k_value = _read_k_initializer(k_initializer, node)
+        else:
+            k_shape = value_shape(graph, k_input, node)
+            if not _is_scalar_shape(k_shape):
+                raise UnsupportedOpError("Trilu k input must be scalar")
+            k_dtype = value_dtype(graph, k_input, node)
+            if k_dtype != ScalarType.I64:
+                raise UnsupportedOpError("Trilu k input must be int64")
+            k_input_name = k_input
+            k_input_shape = k_shape
+            k_input_dtype = k_dtype
+    return TriluOp(
+        input0=input_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        upper=upper,
+        k_value=k_value,
+        k_input=k_input_name,
+        k_input_shape=k_input_shape,
+        k_input_dtype=k_input_dtype,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
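For reference, Trilu keeps the upper (upper=1) or lower (upper=0) triangle of the trailing two dimensions, shifted by the diagonal offset k, which on recent NumPy maps directly onto triu/tril; a minimal sketch (the helper name is ours):

import numpy as np

def trilu_reference(x, k=0, upper=True):
    # np.triu / np.tril apply to the last two axes, so batched inputs work too.
    return np.triu(x, k=k) if upper else np.tril(x, k=k)

a = np.arange(9, dtype=np.float32).reshape(3, 3)
print(trilu_reference(a, k=1, upper=True))  # keeps elements strictly above the main diagonal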
emx_onnx_cgen/lowering/unsqueeze.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .registry import register_lowering
@@ -131,11 +131,11 @@ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
     for dim in output_shape:
         if input_index < len(input_shape) and dim == input_shape[input_index]:
             input_index += 1
-
-
-
-
-
+        else:
+            if dim != 1:
+                raise ShapeInferenceError(
+                    "Unsqueeze output shape must insert ones only"
+                )
     if input_index != len(input_shape):
         raise ShapeInferenceError(
             "Unsqueeze output shape must contain input shape in order"
emx_onnx_cgen/lowering/variadic.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 from shared.scalar_functions import ScalarFunction
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import MultiInputBinaryOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..lowering.common import node_dtype, value_dtype, value_shape
emx_onnx_cgen/lowering/where.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import WhereOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype as _value_dtype
emx_onnx_cgen/onnx_import.py
CHANGED
@@ -212,6 +212,9 @@ def import_onnx(model: onnx.ModelProto) -> Graph:
     dim_param_by_name = _collect_dim_params(
         tuple(model.graph.input) + tuple(model.graph.output)
     )
+    opset_imports = tuple(
+        (opset.domain, opset.version) for opset in model.opset_import
+    )
     try:
         model = shape_inference.infer_shapes(model, data_prop=True)
     except Exception as exc:  # pragma: no cover - onnx inference errors
@@ -258,4 +261,5 @@ def import_onnx(model: onnx.ModelProto) -> Graph:
         nodes=nodes,
         initializers=initializers,
         values=values,
+        opset_imports=opset_imports,
     )
emx_onnx_cgen/onnxruntime_utils.py
ADDED
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from typing import Any
+
+
+def make_deterministic_session_options(ort: Any) -> Any:
+    options = ort.SessionOptions()
+    options.intra_op_num_threads = 1
+    options.inter_op_num_threads = 1
+    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+    return options
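One plausible way the new helper gets used when checking generated code against onnxruntime; the session construction is standard onnxruntime API, while the model path and input feed below are purely illustrative:

import numpy as np
import onnxruntime as ort

from emx_onnx_cgen.onnxruntime_utils import make_deterministic_session_options

options = make_deterministic_session_options(ort)
session = ort.InferenceSession("model.onnx", sess_options=options)  # illustrative model path
feed = {"input": np.zeros((1, 4), dtype=np.float32)}                # illustrative input name/shape
outputs = session.run(None, feed)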
emx_onnx_cgen/ops.py
CHANGED
@@ -87,6 +87,7 @@ UNARY_OP_TYPES = {
     "Identity",
     "LeakyRelu",
     "Log",
+    "Mish",
     "Neg",
     "Not",
     "Reciprocal",
@@ -177,6 +178,7 @@ UNARY_SYMBOLS_DOUBLE = {
     ScalarFunction.LEAKY_RELU: "leaky_relu",
     ScalarFunction.POSITIVE: "identity",
     ScalarFunction.LOG: "log",
+    ScalarFunction.MISH: "mish",
     ScalarFunction.NEG: "neg",
     ScalarFunction.RECIPROCAL: "reciprocal",
     ScalarFunction.RELU: "relu",
@@ -215,6 +217,7 @@ UNARY_SYMBOLS_FLOAT = {
     ScalarFunction.LEAKY_RELU: "leaky_relu",
     ScalarFunction.POSITIVE: "identity",
     ScalarFunction.LOG: "logf",
+    ScalarFunction.MISH: "mish",
     ScalarFunction.NEG: "neg",
     ScalarFunction.RECIPROCAL: "reciprocal",
     ScalarFunction.RELU: "relu",
@@ -457,6 +460,7 @@ UNARY_APPLY_FUNCS = {
     "thresholded_relu": lambda value: np.where(
         value > 1.0, value, 0.0
     ),
+    "mish": lambda value: value * np.tanh(np.log1p(np.exp(value))),
     "atanhf": np.arctanh,
     "atanh": np.arctanh,
 }
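The new "mish" entry computes Mish(x) = x * tanh(softplus(x)) with softplus(x) = ln(1 + exp(x)), which is exactly the log1p/exp/tanh composition registered above; a quick sanity check:

import numpy as np

def mish(value):
    # Same composition as the UNARY_APPLY_FUNCS entry above.
    return value * np.tanh(np.log1p(np.exp(value)))

print(mish(np.array([-1.0, 0.0, 1.0])))  # approx [-0.3034, 0.0, 0.8651]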