emx-onnx-cgen 0.2.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/reduce.py
@@ -0,0 +1,544 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ReduceOp, ReshapeOp
+from ..dtypes import scalar_type_from_onnx
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .registry import register_lowering
+
+REDUCE_KIND_BY_OP = {
+    "ReduceSum": "sum",
+    "ReduceMean": "mean",
+    "ReduceMax": "max",
+    "ReduceMin": "min",
+    "ReduceProd": "prod",
+    "ReduceL1": "l1",
+    "ReduceL2": "l2",
+    "ReduceLogSum": "logsum",
+    "ReduceLogSumExp": "logsumexp",
+    "ReduceSumSquare": "sumsquare",
+}
+
+REDUCE_OUTPUTS_FLOAT_ONLY = {
+    "ReduceMean",
+    "ReduceL1",
+    "ReduceL2",
+    "ReduceLogSum",
+    "ReduceLogSumExp",
+}
+
+
+@dataclass(frozen=True)
+class _ReduceSpec:
+    axes: tuple[int, ...] | None
+    axes_input: str | None
+    axes_input_shape: tuple[int, ...] | None
+    axes_input_dtype: ScalarType | None
+    keepdims: bool
+    output_shape: tuple[int, ...]
+    reduce_count: int | None
+
+
+@dataclass(frozen=True)
+class _AxesInputSpec:
+    axes: tuple[int, ...] | None
+    input_name: str | None
+    input_shape: tuple[int, ...] | None
+    input_dtype: ScalarType | None
+    present: bool
+
+
+def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+    try:
+        return graph.find_value(name).type.shape
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing shape for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+    try:
+        return graph.find_value(name).type.dtype
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing dtype for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _shape_product(shape: tuple[int, ...]) -> int:
+    product = 1
+    for dim in shape:
+        if dim < 0:
+            raise ShapeInferenceError("Dynamic dims are not supported")
+        if dim == 0:
+            return 0
+        product *= dim
+    return product
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _axes_input_info(graph: Graph, node: Node) -> _AxesInputSpec:
+    if len(node.inputs) < 2:
+        return _AxesInputSpec(None, None, None, None, False)
+    if node.inputs[1] == "":
+        return _AxesInputSpec(None, None, None, None, False)
+    initializer = _find_initializer(graph, node.inputs[1])
+    if initializer is None:
+        try:
+            value = graph.find_value(node.inputs[1])
+        except KeyError as exc:
+            raise UnsupportedOpError(
+                f"{node.op_type} axes input must be constant or inferable from shapes"
+            ) from exc
+        if value.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                f"{node.op_type} axes input must be int64 or int32"
+            )
+        if any(dim == 0 for dim in value.type.shape):
+            return _AxesInputSpec((), None, None, None, True)
+        return _AxesInputSpec(
+            None,
+            node.inputs[1],
+            value.type.shape,
+            value.type.dtype,
+            True,
+        )
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} axes input must be int64 or int32"
+        )
+    data = np.array(initializer.data, dtype=np.int64).ravel()
+    return _AxesInputSpec(
+        tuple(int(value) for value in data),
+        None,
+        None,
+        None,
+        True,
+    )
+
+
+def _axes_values_from_shape_ops(
+    graph: Graph, axes_input: str, node: Node
+) -> tuple[int, ...] | None:
+    node_by_output = {
+        output: graph_node
+        for graph_node in graph.nodes
+        for output in graph_node.outputs
+    }
+    cache: dict[str, np.ndarray] = {}
+
+    def resolve_value(name: str) -> np.ndarray | None:
+        if name in cache:
+            return cache[name]
+        initializer = _find_initializer(graph, name)
+        if initializer is not None:
+            value = np.array(initializer.data)
+            cache[name] = value
+            return value
+        producer = node_by_output.get(name)
+        if producer is None:
+            return None
+        op_type = producer.op_type
+        if op_type == "Identity":
+            if len(producer.inputs) != 1:
+                return None
+            input_value = resolve_value(producer.inputs[0])
+            if input_value is None:
+                return None
+            value = np.array(input_value, copy=True)
+        elif op_type == "Cast":
+            if len(producer.inputs) != 1:
+                return None
+            input_value = resolve_value(producer.inputs[0])
+            if input_value is None:
+                return None
+            to_attr = producer.attrs.get("to")
+            if to_attr is None:
+                return None
+            dtype = scalar_type_from_onnx(int(to_attr))
+            if dtype is None:
+                return None
+            value = np.array(input_value, dtype=dtype.np_dtype)
+        elif op_type == "Shape":
+            if len(producer.inputs) != 1:
+                return None
+            input_shape = _value_shape(graph, producer.inputs[0], node)
+            value = np.array(input_shape, dtype=np.int64)
+        elif op_type == "Size":
+            if len(producer.inputs) != 1:
+                return None
+            input_shape = _value_shape(graph, producer.inputs[0], node)
+            value = np.array(_shape_product(input_shape), dtype=np.int64)
+        elif op_type == "Range":
+            if len(producer.inputs) != 3:
+                return None
+            start_value = resolve_value(producer.inputs[0])
+            limit_value = resolve_value(producer.inputs[1])
+            delta_value = resolve_value(producer.inputs[2])
+            if (
+                start_value is None
+                or limit_value is None
+                or delta_value is None
+            ):
+                return None
+            start = np.array(start_value).reshape(-1)[0]
+            limit = np.array(limit_value).reshape(-1)[0]
+            delta = np.array(delta_value).reshape(-1)[0]
+            if float(delta) == 0.0:
+                raise UnsupportedOpError("Range delta must be non-zero")
+            dtype = _value_dtype(graph, producer.outputs[0], node)
+            value = np.arange(
+                start, limit, delta, dtype=dtype.np_dtype
+            )
+        elif op_type in {"Add", "Sub"}:
+            if len(producer.inputs) != 2:
+                return None
+            left_value = resolve_value(producer.inputs[0])
+            right_value = resolve_value(producer.inputs[1])
+            if left_value is None or right_value is None:
+                return None
+            if op_type == "Add":
+                value = np.array(left_value) + np.array(right_value)
+            else:
+                value = np.array(left_value) - np.array(right_value)
+        else:
+            return None
+        cache[name] = value
+        return value
+
+    axes_value = resolve_value(axes_input)
+    if axes_value is None:
+        return None
+    if axes_value.dtype.kind not in {"i", "u"}:
+        raise UnsupportedOpError(
+            f"{node.op_type} axes input must be int64 or int32"
+        )
+    return tuple(int(axis) for axis in axes_value.ravel())
+
+
+def _all_ones_shape(shape: tuple[int, ...]) -> bool:
+    return all(dim == 1 for dim in shape)
+
+
+def _allow_unknown_reduce_output_shape(
+    expected_output_shape: tuple[int, ...],
+    output_shape: tuple[int, ...],
+    input_shape: tuple[int, ...],
+) -> bool:
+    if expected_output_shape != () or not output_shape or not input_shape:
+        return False
+    return True
+
+
+def _infer_axes_from_shapes(
+    input_shape: tuple[int, ...],
+    output_shape: tuple[int, ...],
+    keepdims: bool,
+    node: Node,
+) -> tuple[int, ...] | None:
+    if keepdims:
+        if len(input_shape) != len(output_shape):
+            return None
+        axes: list[int] = []
+        for axis, (in_dim, out_dim) in enumerate(
+            zip(input_shape, output_shape)
+        ):
+            if out_dim == in_dim:
+                if in_dim == 1:
+                    return None
+                continue
+            if out_dim == 1 and in_dim != 1:
+                axes.append(axis)
+                continue
+            raise ShapeInferenceError(
+                f"{node.op_type} output shape does not match input shape"
+            )
+        return tuple(axes)
+    if len(output_shape) > len(input_shape):
+        return None
+
+    results: list[tuple[int, ...]] = []
+
+    def backtrack(
+        input_index: int, output_index: int, reduced_axes: list[int]
+    ) -> None:
+        if output_index == len(output_shape):
+            results.append(
+                tuple(reduced_axes + list(range(input_index, len(input_shape))))
+            )
+            return
+        if input_index == len(input_shape):
+            return
+        if input_shape[input_index] == output_shape[output_index]:
+            backtrack(input_index + 1, output_index + 1, reduced_axes)
+        backtrack(
+            input_index + 1, output_index, reduced_axes + [input_index]
+        )
+
+    backtrack(0, 0, [])
+    unique = {axes for axes in results}
+    if len(unique) == 1:
+        return tuple(sorted(next(iter(unique))))
+    if not unique:
+        raise ShapeInferenceError(
+            f"{node.op_type} output shape does not match input shape"
+        )
+    return None
+
+
+def normalize_reduce_axes(
+    axes: tuple[int, ...], input_shape: tuple[int, ...], node: Node
+) -> tuple[int, ...]:
+    rank = len(input_shape)
+    normalized: list[int] = []
+    for axis in axes:
+        axis = int(axis)
+        if axis < 0:
+            axis += rank
+        if axis < 0 or axis >= rank:
+            raise ShapeInferenceError(
+                f"{node.op_type} axis {axis} is out of range for rank {rank}"
+            )
+        normalized.append(axis)
+    if len(set(normalized)) != len(normalized):
+        raise ShapeInferenceError(f"{node.op_type} axes must be unique")
+    return tuple(sorted(normalized))
+
+
+def resolve_reduce_axes(
+    graph: Graph, node: Node, input_shape: tuple[int, ...]
+) -> tuple[_ReduceSpec | None, bool]:
+    axes_attr = node.attrs.get("axes")
+    axes_input = _axes_input_info(graph, node)
+    if axes_attr is not None and axes_input.present:
+        raise UnsupportedOpError(
+            f"{node.op_type} cannot set both axes attribute and axes input"
+        )
+    keepdims = bool(int(node.attrs.get("keepdims", 1)))
+    if axes_attr is not None:
+        axes = tuple(int(value) for value in axes_attr)
+        axes_input_name = None
+        axes_input_shape = None
+        axes_input_dtype = None
+    elif axes_input.axes is not None:
+        axes = axes_input.axes
+        axes_input_name = None
+        axes_input_shape = None
+        axes_input_dtype = None
+    elif axes_input.present:
+        axes = None
+        if axes_input.input_name:
+            axes = _axes_values_from_shape_ops(
+                graph, axes_input.input_name, node
+            )
+        if axes is None:
+            output_shape = _value_shape(graph, node.outputs[0], node)
+            axes = _infer_axes_from_shapes(
+                input_shape, output_shape, keepdims, node
+            )
+        if axes is None:
+            axes_input_name = axes_input.input_name
+            axes_input_shape = axes_input.input_shape
+            axes_input_dtype = axes_input.input_dtype
+        else:
+            axes_input_name = None
+            axes_input_shape = None
+            axes_input_dtype = None
+    else:
+        axes = ()
+        axes_input_name = None
+        axes_input_shape = None
+        axes_input_dtype = None
+    noop_with_empty_axes = bool(int(node.attrs.get("noop_with_empty_axes", 0)))
+    if axes is not None and not axes:
+        if noop_with_empty_axes:
+            return None, True
+        axes = tuple(range(len(input_shape)))
+    if axes is None:
+        output_shape = _value_shape(graph, node.outputs[0], node)
+        if keepdims and len(output_shape) != len(input_shape):
+            raise ShapeInferenceError(
+                f"{node.op_type} output shape rank must match input rank"
+            )
+        if len(output_shape) > len(input_shape):
+            raise ShapeInferenceError(
+                f"{node.op_type} output shape rank must not exceed input rank"
+            )
+        return _ReduceSpec(
+            axes=None,
+            axes_input=axes_input_name,
+            axes_input_shape=axes_input_shape,
+            axes_input_dtype=axes_input_dtype,
+            keepdims=keepdims,
+            output_shape=output_shape,
+            reduce_count=None,
+        ), False
+    axes = normalize_reduce_axes(axes, input_shape, node)
+    return _ReduceSpec(
+        axes=axes,
+        axes_input=None,
+        axes_input_shape=None,
+        axes_input_dtype=None,
+        keepdims=keepdims,
+        output_shape=(),
+        reduce_count=None,
+    ), False
+
+
+def _resolve_reduce_spec(graph: Graph, node: Node) -> _ReduceSpec | None:
+    if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 1 or 2 inputs and 1 output"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    axes_spec, noop = resolve_reduce_axes(graph, node, input_shape)
+    if noop:
+        output_shape = _value_shape(graph, node.outputs[0], node)
+        if output_shape != input_shape:
+            raise ShapeInferenceError(
+                f"{node.op_type} output shape must be {input_shape}, got {output_shape}"
+            )
+        return None
+    if axes_spec is None:
+        raise ShapeInferenceError(f"{node.op_type} axes spec missing")
+    if axes_spec.axes is None:
+        return _ReduceSpec(
+            axes=None,
+            axes_input=axes_spec.axes_input,
+            axes_input_shape=axes_spec.axes_input_shape,
+            axes_input_dtype=axes_spec.axes_input_dtype,
+            keepdims=axes_spec.keepdims,
+            output_shape=axes_spec.output_shape,
+            reduce_count=None,
+        )
+    axes = axes_spec.axes
+    keepdims = axes_spec.keepdims
+    if keepdims:
+        output_shape = tuple(
+            1 if axis in axes else dim
+            for axis, dim in enumerate(input_shape)
+        )
+    else:
+        output_shape = tuple(
+            dim
+            for axis, dim in enumerate(input_shape)
+            if axis not in axes
+        )
+    expected_output_shape = _value_shape(graph, node.outputs[0], node)
+    if expected_output_shape != output_shape:
+        if _allow_unknown_reduce_output_shape(
+            expected_output_shape, output_shape, input_shape
+        ):
+            pass
+        elif not (
+            _all_ones_shape(expected_output_shape)
+            and _all_ones_shape(output_shape)
+            and _shape_product(expected_output_shape)
+            == _shape_product(output_shape)
+        ):
+            raise ShapeInferenceError(
+                f"{node.op_type} output shape must be {output_shape}, got {expected_output_shape}"
+            )
+    reduce_count = _shape_product(tuple(input_shape[axis] for axis in axes))
+    return _ReduceSpec(
+        axes=axes,
+        axes_input=None,
+        axes_input_shape=None,
+        axes_input_dtype=None,
+        keepdims=keepdims,
+        output_shape=output_shape,
+        reduce_count=reduce_count,
+    )
+
+
+def _reduce_dtype_supported(dtype: ScalarType) -> bool:
+    return dtype in {
+        ScalarType.F16,
+        ScalarType.F32,
+        ScalarType.F64,
+        ScalarType.I64,
+        ScalarType.I32,
+        ScalarType.I16,
+        ScalarType.I8,
+        ScalarType.U64,
+        ScalarType.U32,
+        ScalarType.U16,
+        ScalarType.U8,
+    }
+
+
+def lower_reduce(graph: Graph, node: Node) -> ReduceOp | ReshapeOp:
+    if node.op_type not in REDUCE_KIND_BY_OP:
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            f"{node.op_type} expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if not _reduce_dtype_supported(op_dtype):
+        raise UnsupportedOpError(
+            f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
+        )
+    if node.op_type in REDUCE_OUTPUTS_FLOAT_ONLY and op_dtype not in {
+        ScalarType.F16,
+        ScalarType.F32,
+        ScalarType.F64,
+    }:
+        raise UnsupportedOpError(
+            f"{node.op_type} supports float16, float, and double inputs only"
+        )
+    spec = _resolve_reduce_spec(graph, node)
+    if spec is None:
+        input_shape = _value_shape(graph, node.inputs[0], node)
+        output_shape = _value_shape(graph, node.outputs[0], node)
+        return ReshapeOp(
+            input0=node.inputs[0],
+            output=node.outputs[0],
+            input_shape=input_shape,
+            output_shape=output_shape,
+            dtype=op_dtype,
+            input_dtype=op_dtype,
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if spec.axes_input and (
+        spec.axes_input_shape is None or spec.axes_input_dtype is None
+    ):
+        raise ShapeInferenceError(
+            f"{node.op_type} axes input must have a static shape and dtype"
+        )
+    return ReduceOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=spec.output_shape,
+        axes=spec.axes or (),
+        axes_input=spec.axes_input,
+        axes_input_shape=spec.axes_input_shape,
+        axes_input_dtype=spec.axes_input_dtype,
+        keepdims=spec.keepdims,
+        noop_with_empty_axes=bool(int(node.attrs.get("noop_with_empty_axes", 0))),
+        reduce_kind=REDUCE_KIND_BY_OP[node.op_type],
+        reduce_count=spec.reduce_count,
+        dtype=op_dtype,
+    )
+
+
+for _op_type in REDUCE_KIND_BY_OP:
+    register_lowering(_op_type)(lower_reduce)
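Note on the hunk above: the subtle piece is `_infer_axes_from_shapes`, which recovers the reduced axes from the input/output shapes alone when the axes tensor is neither an attribute nor a constant initializer. Below is a minimal standalone sketch of its keepdims rule; it is illustrative only and not part of the wheel (the name `infer_axes_keepdims` is invented for the sketch).

# Illustrative restatement of the keepdims branch of _infer_axes_from_shapes:
# with keepdims=1 the ranks match, and a dim collapsing from n>1 to 1 must
# have been reduced; a dim that is already 1 is ambiguous, so give up (None).
def infer_axes_keepdims(
    input_shape: tuple[int, ...], output_shape: tuple[int, ...]
) -> tuple[int, ...] | None:
    if len(input_shape) != len(output_shape):
        return None
    axes: list[int] = []
    for axis, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)):
        if out_dim == in_dim:
            if in_dim == 1:
                return None  # size-1 dim: reduced or untouched? can't tell
            continue
        if out_dim == 1 and in_dim != 1:
            axes.append(axis)
            continue
        raise ValueError("output shape does not match input shape")
    return tuple(axes)

assert infer_axes_keepdims((2, 3, 4), (2, 1, 4)) == (1,)
assert infer_axes_keepdims((2, 1, 4), (2, 1, 4)) is None  # ambiguous

Also visible in the hunk: when `noop_with_empty_axes=1` and the axes list is empty, `lower_reduce` emits no reduction at all; it returns a `ReshapeOp` that forwards the input unchanged.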
emx_onnx_cgen/lowering/registry.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+from collections.abc import Callable, Mapping
+from typing import TypeVar
+
+from ..ir.model import Graph, Node
+from ..errors import UnsupportedOpError
+
+LoweredOp = TypeVar("LoweredOp")
+Handler = TypeVar("Handler")
+
+_LOWERING_REGISTRY: dict[str, Callable[[Graph, Node], object]] = {}
+
+
+def register_lowering(
+    op_type: str,
+) -> Callable[[Callable[[Graph, Node], LoweredOp]], Callable[[Graph, Node], LoweredOp]]:
+    def decorator(
+        func: Callable[[Graph, Node], LoweredOp],
+    ) -> Callable[[Graph, Node], LoweredOp]:
+        _LOWERING_REGISTRY[op_type] = func
+        return func
+
+    return decorator
+
+
+def get_lowering(op_type: str) -> Callable[[Graph, Node], object] | None:
+    return _LOWERING_REGISTRY.get(op_type)
+
+
+def get_lowering_registry() -> Mapping[str, Callable[[Graph, Node], object]]:
+    return _LOWERING_REGISTRY
+
+
+def resolve_dispatch(
+    op_type: str,
+    registry: Mapping[str, Handler],
+    *,
+    binary_types: set[str],
+    unary_types: set[str],
+    binary_fallback: Callable[[], Handler],
+    unary_fallback: Callable[[], Handler],
+) -> Handler:
+    handler = registry.get(op_type)
+    if handler is not None:
+        return handler
+    if op_type in binary_types:
+        return binary_fallback()
+    if op_type in unary_types:
+        return unary_fallback()
+    raise UnsupportedOpError(f"Unsupported op {op_type}")