emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/reduce.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ReduceOp, ReshapeOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node

@@ -261,13 +261,12 @@ def _infer_axes_from_shapes(
             if out_dim == in_dim:
                 if in_dim == 1:
                     return None
-
-            if out_dim == 1 and in_dim != 1:
+            elif out_dim == 1 and in_dim != 1:
                 axes.append(axis)
-
-
-
-
+            else:
+                raise ShapeInferenceError(
+                    f"{node.op_type} output shape does not match input shape"
+                )
         return tuple(axes)
     if len(output_shape) > len(input_shape):
         return None
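For context on the reduce.py change above: a toy stand-in (not the package's `_infer_axes_from_shapes`) showing the keepdims-style axis inference that the new `elif`/`else` branches tighten. The shapes below are made up for illustration.

# Toy illustration: with keepdims, reduced axes are exactly those where
# the output dim collapsed to 1. Simplified stand-in, not the package code.
def infer_axes(input_shape, output_shape):
    if len(output_shape) != len(input_shape):
        return None  # rank changed; handled elsewhere
    axes = []
    for axis, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)):
        if out_dim == in_dim:
            if in_dim == 1:
                return None  # ambiguous: a size-1 dim may or may not be reduced
        elif out_dim == 1 and in_dim != 1:
            axes.append(axis)
        else:
            raise ValueError("output shape does not match input shape")
    return tuple(axes)


print(infer_axes((2, 3, 4), (2, 1, 4)))  # (1,)
print(infer_axes((2, 3, 4), (1, 1, 1)))  # (0, 1, 2)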
emx_onnx_cgen/lowering/registry.py
CHANGED

@@ -3,32 +3,51 @@ from __future__ import annotations
 from collections.abc import Callable, Mapping
 from typing import TypeVar
 
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.op_base import OpBase
 from ..errors import UnsupportedOpError
 
 LoweredOp = TypeVar("LoweredOp")
 Handler = TypeVar("Handler")
 
-_LOWERING_REGISTRY: dict[str, Callable[[Graph, Node],
+_LOWERING_REGISTRY: dict[str, Callable[[Graph | GraphContext, Node], OpBase]] = {}
 
 
 def register_lowering(
     op_type: str,
 ) -> Callable[[Callable[[Graph, Node], LoweredOp]], Callable[[Graph, Node], LoweredOp]]:
     def decorator(
-        func: Callable[[Graph, Node], LoweredOp],
-    ) -> Callable[[Graph, Node], LoweredOp]:
+        func: Callable[[Graph | GraphContext, Node], LoweredOp],
+    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
         _LOWERING_REGISTRY[op_type] = func
         return func
 
     return decorator
 
 
-def
+def register_lowering_if_missing(
+    op_type: str,
+) -> Callable[[Callable[[Graph | GraphContext, Node], LoweredOp]], Callable[[Graph | GraphContext, Node], LoweredOp]]:
+    def decorator(
+        func: Callable[[Graph | GraphContext, Node], LoweredOp],
+    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
+        if op_type not in _LOWERING_REGISTRY:
+            _LOWERING_REGISTRY[op_type] = func
+        return func
+
+    return decorator
+
+
+def get_lowering(
+    op_type: str,
+) -> Callable[[Graph | GraphContext, Node], OpBase] | None:
     return _LOWERING_REGISTRY.get(op_type)
 
 
-def get_lowering_registry() -> Mapping[
+def get_lowering_registry() -> Mapping[
+    str, Callable[[Graph | GraphContext, Node], OpBase]
+]:
     return _LOWERING_REGISTRY
 
 
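For context on the registry.py change: a minimal usage sketch of how the decorators above compose, assuming only the semantics visible in the hunk. The op type "DemoOp" and the `lower_demo*` handler names are hypothetical, chosen to avoid colliding with real registrations.

# Hypothetical usage sketch (not part of the diff) of the registry hooks.
from emx_onnx_cgen.lowering.registry import (
    get_lowering,
    register_lowering,
    register_lowering_if_missing,
)


@register_lowering("DemoOp")
def lower_demo(graph, node):
    ...  # the primary handler always takes the registry slot


@register_lowering_if_missing("DemoOp")
def lower_demo_fallback(graph, node):
    ...  # skipped: "DemoOp" is already registered


assert get_lowering("DemoOp") is lower_demo
assert get_lowering("NoSuchOp") is None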
emx_onnx_cgen/lowering/reshape.py
CHANGED

@@ -2,9 +2,10 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
+from .common import value_shape as resolved_value_shape
 from .registry import register_lowering
 
 

@@ -37,6 +38,21 @@ def _shape_product(shape: tuple[int, ...]) -> int:
     return product
 
 
+def _reshape_mismatch_error(
+    node: Node,
+    input_shape: tuple[int, ...],
+    output_shape: tuple[int, ...],
+) -> ShapeInferenceError:
+    node_name = node.name or "<unnamed>"
+    return ShapeInferenceError(
+        "Reshape input/output element counts must match for op "
+        f"{node.op_type} (node '{node_name}'): input shape {input_shape}, "
+        f"output shape {output_shape}. "
+        "Hint: ensure the reshape target has the same number of elements as "
+        "the input."
+    )
+
+
 def _find_initializer(graph: Graph, name: str) -> Initializer | None:
     for initializer in graph.initializers:
         if initializer.name == name:

@@ -52,15 +68,190 @@ def _find_node_by_output(graph: Graph, name: str) -> Node | None:
 
 
 def _shape_values_from_shape_node(
-    graph: Graph,
-) -> list[int]
-    shape_node = _find_node_by_output(graph, name)
-    if shape_node is None or shape_node.op_type != "Shape":
-        return None
+    graph: Graph, shape_node: Node, node: Node
+) -> list[int]:
     if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
     source_shape = _value_shape(graph, shape_node.inputs[0], node)
-
+    start = int(shape_node.attrs.get("start", 0))
+    end = int(shape_node.attrs.get("end", len(source_shape)))
+    if start < 0:
+        start += len(source_shape)
+    if end < 0:
+        end += len(source_shape)
+    start = max(start, 0)
+    end = min(end, len(source_shape))
+    if start > end:
+        return []
+    return list(source_shape[start:end])
+
+
+def _shape_values_from_initializer(
+    graph: Graph,
+    name: str,
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "Reshape expects int64 or int32 shape input, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    return [int(value) for value in initializer.data.reshape(-1)]
+
+
+def _shape_values_from_input(
+    graph: Graph,
+    name: str,
+    node: Node,
+    *,
+    _visited: set[str] | None = None,
+) -> list[int] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        shape_values = _shape_values_from_initializer(graph, name)
+        if shape_values is not None:
+            return shape_values
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Shape":
+            return _shape_values_from_shape_node(graph, source_node, node)
+        if source_node.op_type == "Concat":
+            axis = int(source_node.attrs.get("axis", 0))
+            if axis != 0:
+                raise UnsupportedOpError("Reshape shape concat must use axis 0")
+            values: list[int] = []
+            for input_name in source_node.inputs:
+                input_values = _shape_values_from_input(
+                    graph,
+                    input_name,
+                    node,
+                    _visited=_visited,
+                )
+                if input_values is None:
+                    return None
+                values.extend(input_values)
+            return values
+        if source_node.op_type == "Cast":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Cast must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Unsqueeze":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Identity":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Identity must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            if len(left) == 1 and len(right) != 1:
+                left = left * len(right)
+            if len(right) == 1 and len(left) != 1:
+                right = right * len(left)
+            if len(left) != len(right):
+                return None
+            if source_node.op_type == "Equal":
+                return [1 if l == r else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "And":
+                return [1 if (l and r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Or":
+                return [1 if (l or r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Div":
+                return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Mod":
+                return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+        if source_node.op_type == "Not":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Not must have 1 input and 1 output")
+            values = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if values is None:
+                return None
+            return [0 if value else 1 for value in values]
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            condition = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if condition is None:
+                return None
+            on_true = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _shape_values_from_input(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            if len(condition) == 1:
+                condition = condition * max(len(on_true), len(on_false))
+            if len(on_true) == 1 and len(condition) != 1:
+                on_true = on_true * len(condition)
+            if len(on_false) == 1 and len(condition) != 1:
+                on_false = on_false * len(condition)
+            if not (len(condition) == len(on_true) == len(on_false)):
+                return None
+            return [
+                t if cond else f
+                for cond, t, f in zip(condition, on_true, on_false)
+            ]
+        return None
+    finally:
+        _visited.remove(name)
 
 
 def _resolve_target_shape(
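For context on the new `_shape_values_from_input` helper: a self-contained toy version of the same constant-folding idea over Shape/Concat producers, using a simplified `Node` stand-in rather than the package's IR classes.

# Illustrative sketch only: folding a Shape -> Concat chain into concrete
# dimension values, the way _shape_values_from_input does for Reshape targets.
from dataclasses import dataclass, field


@dataclass
class Node:
    op_type: str
    inputs: list[str]
    outputs: list[str]
    attrs: dict = field(default_factory=dict)


def fold_shape(nodes, shapes, name):
    producer = next((n for n in nodes if name in n.outputs), None)
    if producer is None:
        return None  # not computable from the graph statically
    if producer.op_type == "Shape":
        return list(shapes[producer.inputs[0]])
    if producer.op_type == "Concat":
        parts = [fold_shape(nodes, shapes, i) for i in producer.inputs]
        if any(p is None for p in parts):
            return None
        return [d for p in parts for d in p]
    return None  # unsupported producer in this toy version


nodes = [
    Node("Shape", ["x"], ["x_shape"]),
    Node("Concat", ["x_shape", "x_shape"], ["target"], {"axis": 0}),
]
print(fold_shape(nodes, {"x": (2, 3)}, "target"))  # [2, 3, 2, 3]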
@@ -82,19 +273,19 @@ def _resolve_target_shape(
                raise ShapeInferenceError("Reshape allows only one -1 dimension")
            unknown_index = index
            output_dims.append(-1)
-
-
-
-
-
-
-
-
-
-
-
-
-
+        else:
+            if dim == 0:
+                contains_zero = True
+                if allowzero == 0:
+                    if index >= len(input_shape):
+                        raise ShapeInferenceError(
+                            "Reshape zero dim must index into input shape"
+                        )
+                    dim = input_shape[index]
+            if dim < 0:
+                raise ShapeInferenceError("Reshape dims must be >= -1")
+            output_dims.append(dim)
+            known_product *= dim
     if allowzero == 1 and contains_zero and unknown_index is not None:
         raise ShapeInferenceError(
             "Reshape allowzero cannot combine zero and -1 dimensions"

@@ -115,9 +306,7 @@ def _resolve_target_shape(
         output_dims[unknown_index] = input_product // known_product
     output_shape = tuple(output_dims)
     if _shape_product(output_shape) != input_product:
-        raise ShapeInferenceError(
-            "Reshape input and output element counts must match"
-        )
+        raise _reshape_mismatch_error(node, input_shape, output_shape)
     return output_shape
 
 
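Worked example of the ONNX Reshape target-shape rules this hunk enforces, as a toy re-implementation of the happy path only (no allowzero=1 handling and no error checks): dim == 0 copies the matching input dimension, and dim == -1 is inferred from the remaining element count.

# Toy version of _resolve_target_shape, happy path only.
def resolve_target_shape(input_shape, dims, allowzero=0):
    total = 1
    for d in input_shape:
        total *= d
    out, known, unknown = [], 1, None
    for i, d in enumerate(dims):
        if d == -1:
            unknown = i
            out.append(-1)
            continue
        if d == 0 and allowzero == 0:
            d = input_shape[i]  # 0 means "keep this input dim"
        out.append(d)
        known *= d
    if unknown is not None:
        out[unknown] = total // known  # infer the -1 dim from the rest
    return tuple(out)


print(resolve_target_shape((2, 3, 4), [0, -1]))  # (2, 12)
print(resolve_target_shape((2, 3, 4), [6, -1]))  # (6, 4)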
@@ -125,7 +314,7 @@ def _resolve_target_shape(
 def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
     if len(node.inputs) != 2 or len(node.outputs) != 1:
         raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
-    input_shape =
+    input_shape = resolved_value_shape(graph, node.inputs[0], node)
     input_dtype = _value_dtype(graph, node.inputs[0], node)
     output_dtype = _value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:

@@ -133,46 +322,29 @@ def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
             "Reshape expects matching input/output dtypes, "
             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
         )
-
+    output_value = graph.find_value(node.outputs[0])
+    output_shape = resolved_value_shape(graph, node.outputs[0], node)
+    output_dim_params = output_value.type.dim_params
     allowzero = int(node.attrs.get("allowzero", 0))
-    shape_initializer = _find_initializer(graph, node.inputs[1])
     resolved_shape: tuple[int, ...] | None = None
-
-
-            graph, node.inputs[1], node
-        )
-        if shape_values is not None:
-            resolved_shape = _resolve_target_shape(
-                input_shape,
-                shape_values,
-                allowzero=allowzero,
-                node=node,
-            )
-        else:
-            if _shape_product(output_shape) != _shape_product(input_shape):
-                raise ShapeInferenceError(
-                    "Reshape input and output element counts must match"
-                )
-    else:
-        if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
-            raise UnsupportedOpError(
-                "Reshape expects int64 or int32 shape input, "
-                f"got {shape_initializer.type.dtype.onnx_name}"
-            )
-        if len(shape_initializer.type.shape) != 1:
-            raise UnsupportedOpError("Reshape expects a 1D shape input")
-        shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
+    shape_values = _shape_values_from_input(graph, node.inputs[1], node)
+    if shape_values is not None:
         resolved_shape = _resolve_target_shape(
             input_shape,
             shape_values,
             allowzero=allowzero,
             node=node,
         )
-        if output_shape and resolved_shape != output_shape
+        if output_shape and resolved_shape != output_shape and not any(
+            output_dim_params
+        ):
             raise ShapeInferenceError(
                 "Reshape output shape must be "
                 f"{resolved_shape}, got {output_shape}"
             )
+    else:
+        if _shape_product(output_shape) != _shape_product(input_shape):
+            raise _reshape_mismatch_error(node, input_shape, output_shape)
     if resolved_shape is not None:
         output_shape = resolved_shape
     for dim in output_shape:
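For context on the new `output_dim_params` guard in lower_reshape: a hypothetical illustration with made-up values (not the package API in action) of how a symbolic output dimension suppresses the strict resolved-vs-declared shape comparison.

# Hypothetical values: a resolved shape that disagrees with the declared
# static shape, where one output dim is a named ONNX dim_param.
resolved_shape = (2, 12)
output_shape = (2, 6)           # stale static shape declared in the model
output_dim_params = ("", "N")   # second output dim is symbolic

raises = bool(output_shape) and resolved_shape != output_shape and not any(
    output_dim_params
)
print(raises)  # False: the symbolic dim suppresses the ShapeInferenceError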
emx_onnx_cgen/lowering/resize.py
CHANGED

@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..
+from ..ir.ops import ResizeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/rms_normalization.py
CHANGED

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..
+from ..ir.ops import RMSNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
emx_onnx_cgen/lowering/rotary_embedding.py
ADDED

@@ -0,0 +1,165 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import RotaryEmbeddingOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class RotaryEmbeddingSpec:
+    batch: int
+    seq_len: int
+    num_heads: int
+    head_size: int
+    rotary_dim: int
+    rotary_dim_half: int
+    input_rank: int
+
+
+def _resolve_rotary_spec(
+    graph: Graph, node: Node, dtype: ScalarType
+) -> RotaryEmbeddingSpec:
+    if not dtype.is_float:
+        raise UnsupportedOpError("Unsupported op RotaryEmbedding")
+    if len(node.inputs) < 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Unsupported op RotaryEmbedding")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    input_rank = len(input_shape)
+    if input_rank not in {3, 4}:
+        raise ShapeInferenceError("RotaryEmbedding expects 3D or 4D input")
+    if input_rank == 3:
+        num_heads_attr = node.attrs.get("num_heads")
+        if num_heads_attr is None:
+            raise UnsupportedOpError(
+                "RotaryEmbedding num_heads attribute is required for 3D inputs"
+            )
+        num_heads = int(num_heads_attr)
+        if num_heads <= 0:
+            raise ShapeInferenceError("RotaryEmbedding num_heads must be > 0")
+        batch, seq_len, hidden_size = input_shape
+        if hidden_size % num_heads != 0:
+            raise ShapeInferenceError(
+                "RotaryEmbedding hidden size must be divisible by num_heads"
+            )
+        head_size = hidden_size // num_heads
+    else:
+        batch, num_heads, seq_len, head_size = input_shape
+        num_heads_attr = node.attrs.get("num_heads")
+        if num_heads_attr is not None and int(num_heads_attr) != num_heads:
+            raise ShapeInferenceError(
+                "RotaryEmbedding num_heads must match input head dimension"
+            )
+    if head_size % 2 != 0:
+        raise ShapeInferenceError("RotaryEmbedding head size must be even")
+    rotary_dim = int(node.attrs.get("rotary_embedding_dim", 0))
+    if rotary_dim == 0:
+        rotary_dim = head_size
+    if rotary_dim < 0 or rotary_dim > head_size:
+        raise ShapeInferenceError(
+            "RotaryEmbedding rotary_embedding_dim must be in [0, head_size]"
+        )
+    if rotary_dim % 2 != 0:
+        raise ShapeInferenceError(
+            "RotaryEmbedding rotary_embedding_dim must be even"
+        )
+    rotary_dim_half = rotary_dim // 2
+    return RotaryEmbeddingSpec(
+        batch=batch,
+        seq_len=seq_len,
+        num_heads=num_heads,
+        head_size=head_size,
+        rotary_dim=rotary_dim,
+        rotary_dim_half=rotary_dim_half,
+        input_rank=input_rank,
+    )
+
+
+@register_lowering("RotaryEmbedding")
+def lower_rotary_embedding(graph: Graph, node: Node) -> RotaryEmbeddingOp:
+    input_name = node.inputs[0]
+    cos_name = node.inputs[1]
+    sin_name = node.inputs[2]
+    position_ids = optional_name(node.inputs, 3)
+    dtype = value_dtype(graph, input_name, node)
+    cos_dtype = value_dtype(graph, cos_name, node)
+    sin_dtype = value_dtype(graph, sin_name, node)
+    if cos_dtype != dtype or sin_dtype != dtype:
+        raise ShapeInferenceError(
+            "RotaryEmbedding inputs must share the same dtype"
+        )
+    spec = _resolve_rotary_spec(graph, node, dtype)
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if output_shape != input_shape:
+        raise ShapeInferenceError(
+            "RotaryEmbedding output shape must match input shape"
+        )
+    cos_shape = value_shape(graph, cos_name, node)
+    sin_shape = value_shape(graph, sin_name, node)
+    if cos_shape != sin_shape:
+        raise ShapeInferenceError(
+            "RotaryEmbedding cos/sin cache shapes must match"
+        )
+    position_shape = None
+    position_dtype = None
+    if position_ids is not None:
+        position_shape = value_shape(graph, position_ids, node)
+        if position_shape != (spec.batch, spec.seq_len):
+            raise ShapeInferenceError(
+                "RotaryEmbedding position_ids must match [batch, seq_len]"
+            )
+        position_dtype = value_dtype(graph, position_ids, node)
+        if not position_dtype.is_integer:
+            raise ShapeInferenceError(
+                "RotaryEmbedding position_ids must be an integer tensor"
+            )
+        if len(cos_shape) != 2:
+            raise ShapeInferenceError(
+                "RotaryEmbedding expects 2D sin/cos caches with position_ids"
+            )
+        if cos_shape[1] != spec.rotary_dim_half:
+            raise ShapeInferenceError(
+                "RotaryEmbedding cos/sin cache last dim must match rotary_dim/2"
+            )
+    else:
+        if len(cos_shape) != 3:
+            raise ShapeInferenceError(
+                "RotaryEmbedding expects 3D sin/cos caches without position_ids"
+            )
+        if cos_shape != (
+            spec.batch,
+            spec.seq_len,
+            spec.rotary_dim_half,
+        ):
+            raise ShapeInferenceError(
+                "RotaryEmbedding sin/cos cache shape must be "
+                "[batch, seq_len, rotary_dim/2]"
+            )
+    interleaved = bool(int(node.attrs.get("interleaved", 0)))
+    return RotaryEmbeddingOp(
+        input0=input_name,
+        cos_cache=cos_name,
+        sin_cache=sin_name,
+        position_ids=position_ids,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        cos_shape=cos_shape,
+        sin_shape=sin_shape,
+        position_ids_shape=position_shape,
+        dtype=dtype,
+        position_ids_dtype=position_dtype,
+        rotary_dim=spec.rotary_dim,
+        rotary_dim_half=spec.rotary_dim_half,
+        head_size=spec.head_size,
+        num_heads=spec.num_heads,
+        seq_len=spec.seq_len,
+        batch=spec.batch,
+        input_rank=spec.input_rank,
+        interleaved=interleaved,
+    )
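For context on the new RotaryEmbedding lowering: a NumPy reference of the standard non-interleaved rotary math for the 4D layout it validates. This is a sketch of the conventional formula under that assumption, not the C kernel the compiler emits.

# Reference sketch (assumed standard non-interleaved rotary math) for input
# [batch, num_heads, seq_len, head_size] with 3D caches
# [batch, seq_len, rotary_dim/2].
import numpy as np


def rotary_reference(x, cos, sin, rotary_dim):
    half = rotary_dim // 2
    out = x.copy()
    x1 = x[..., :half]              # first half of the rotated slice
    x2 = x[..., half:rotary_dim]    # second half of the rotated slice
    c = cos[:, None, :, :]          # broadcast caches over the head axis
    s = sin[:, None, :, :]
    out[..., :half] = x1 * c - x2 * s
    out[..., half:rotary_dim] = x2 * c + x1 * s
    return out                      # dims past rotary_dim pass through


batch, heads, seq, head_size = 1, 2, 4, 8
x = np.random.randn(batch, heads, seq, head_size).astype(np.float32)
angles = np.random.randn(batch, seq, head_size // 2).astype(np.float32)
y = rotary_reference(x, np.cos(angles), np.sin(angles), rotary_dim=head_size)
assert y.shape == x.shape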
emx_onnx_cgen/lowering/scatter_nd.py
ADDED

@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import ScatterNDOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+_ALLOWED_REDUCTIONS = {"none", "add", "mul", "min", "max"}
+
+
+@register_lowering("ScatterND")
+def lower_scatternd(graph: Graph, node: Node) -> ScatterNDOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("ScatterND must have 3 inputs and 1 output")
+    data_name, indices_name, updates_name = node.inputs
+    output_name = node.outputs[0]
+    data_shape = value_shape(graph, data_name, node)
+    indices_shape = value_shape(graph, indices_name, node)
+    updates_shape = value_shape(graph, updates_name, node)
+    output_shape = value_shape(graph, output_name, node)
+    if output_shape != data_shape:
+        raise ShapeInferenceError(
+            "ScatterND output shape must match data shape, "
+            f"got {output_shape} vs {data_shape}"
+        )
+    if len(indices_shape) < 1:
+        raise ShapeInferenceError("ScatterND indices must have rank >= 1")
+    index_depth = indices_shape[-1]
+    if index_depth <= 0:
+        raise ShapeInferenceError(
+            "ScatterND indices final dimension must be >= 1"
+        )
+    if index_depth > len(data_shape):
+        raise ShapeInferenceError(
+            "ScatterND indices final dimension must be <= data rank, "
+            f"got {index_depth} vs {len(data_shape)}"
+        )
+    expected_updates_shape = indices_shape[:-1] + data_shape[index_depth:]
+    if updates_shape != expected_updates_shape:
+        raise ShapeInferenceError(
+            "ScatterND updates shape must be "
+            f"{expected_updates_shape}, got {updates_shape}"
+        )
+    data_dtype = value_dtype(graph, data_name, node)
+    updates_dtype = value_dtype(graph, updates_name, node)
+    if updates_dtype != data_dtype:
+        raise UnsupportedOpError(
+            "ScatterND updates dtype must match data dtype, "
+            f"got {updates_dtype.onnx_name} vs {data_dtype.onnx_name}"
+        )
+    indices_dtype = value_dtype(graph, indices_name, node)
+    if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "ScatterND indices must be int32 or int64, "
+            f"got {indices_dtype.onnx_name}"
+        )
+    reduction_attr = node.attrs.get("reduction", "none")
+    if isinstance(reduction_attr, bytes):
+        reduction = reduction_attr.decode()
+    else:
+        reduction = str(reduction_attr)
+    if reduction not in _ALLOWED_REDUCTIONS:
+        raise UnsupportedOpError(
+            "ScatterND reduction must be one of "
+            f"{sorted(_ALLOWED_REDUCTIONS)}, got {reduction}"
+        )
+    return ScatterNDOp(
+        data=data_name,
+        indices=indices_name,
+        updates=updates_name,
+        output=output_name,
+        data_shape=data_shape,
+        indices_shape=indices_shape,
+        updates_shape=updates_shape,
+        output_shape=output_shape,
+        reduction=reduction,
+        dtype=data_dtype,
+        indices_dtype=indices_dtype,
+    )
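For context on the new ScatterND lowering: a NumPy sketch of the ONNX ScatterND semantics, including the reductions the code above validates. Illustrative only, not the generated kernel; the sequential loop also matches ONNX's behavior for duplicate indices under a reduction.

# Semantics sketch of ONNX ScatterND with the validated reductions.
import numpy as np


def scatter_nd(data, indices, updates, reduction="none"):
    out = data.copy()
    k = indices.shape[-1]  # index depth into data's leading dims
    flat_updates = updates.reshape(-1, *data.shape[k:])
    for idx, upd in zip(indices.reshape(-1, k), flat_updates):
        key = tuple(idx)
        if reduction == "none":
            out[key] = upd
        elif reduction == "add":
            out[key] += upd
        elif reduction == "mul":
            out[key] *= upd
        elif reduction == "min":
            out[key] = np.minimum(out[key], upd)
        elif reduction == "max":
            out[key] = np.maximum(out[key], upd)
    return out


data = np.zeros((4, 3))
indices = np.array([[1], [3]])  # indices_shape[-1] == 1 -> whole-row updates
updates = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(scatter_nd(data, indices, updates, reduction="add"))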