emx-onnx-cgen 0.2.0 (py3-none-any.whl)

This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/reduce.py
@@ -0,0 +1,544 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import ReduceOp, ReshapeOp
+ from ..dtypes import scalar_type_from_onnx
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from .registry import register_lowering
+
+ REDUCE_KIND_BY_OP = {
+     "ReduceSum": "sum",
+     "ReduceMean": "mean",
+     "ReduceMax": "max",
+     "ReduceMin": "min",
+     "ReduceProd": "prod",
+     "ReduceL1": "l1",
+     "ReduceL2": "l2",
+     "ReduceLogSum": "logsum",
+     "ReduceLogSumExp": "logsumexp",
+     "ReduceSumSquare": "sumsquare",
+ }
+
+ REDUCE_OUTPUTS_FLOAT_ONLY = {
+     "ReduceMean",
+     "ReduceL1",
+     "ReduceL2",
+     "ReduceLogSum",
+     "ReduceLogSumExp",
+ }
+
+
+ @dataclass(frozen=True)
+ class _ReduceSpec:
+     axes: tuple[int, ...] | None
+     axes_input: str | None
+     axes_input_shape: tuple[int, ...] | None
+     axes_input_dtype: ScalarType | None
+     keepdims: bool
+     output_shape: tuple[int, ...]
+     reduce_count: int | None
+
+
+ @dataclass(frozen=True)
+ class _AxesInputSpec:
+     axes: tuple[int, ...] | None
+     input_name: str | None
+     input_shape: tuple[int, ...] | None
+     input_dtype: ScalarType | None
+     present: bool
+
+
+ def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+     try:
+         return graph.find_value(name).type.shape
+     except KeyError as exc:
+         raise ShapeInferenceError(
+             f"Missing shape for value '{name}' in op {node.op_type}. "
+             "Hint: run ONNX shape inference or export with static shapes."
+         ) from exc
+
+
+ def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+     try:
+         return graph.find_value(name).type.dtype
+     except KeyError as exc:
+         raise ShapeInferenceError(
+             f"Missing dtype for value '{name}' in op {node.op_type}. "
+             "Hint: run ONNX shape inference or export with static shapes."
+         ) from exc
+
+
+ def _shape_product(shape: tuple[int, ...]) -> int:
+     product = 1
+     for dim in shape:
+         if dim < 0:
+             raise ShapeInferenceError("Dynamic dims are not supported")
+         if dim == 0:
+             return 0
+         product *= dim
+     return product
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _axes_input_info(graph: Graph, node: Node) -> _AxesInputSpec:
+     if len(node.inputs) < 2:
+         return _AxesInputSpec(None, None, None, None, False)
+     if node.inputs[1] == "":
+         return _AxesInputSpec(None, None, None, None, False)
+     initializer = _find_initializer(graph, node.inputs[1])
+     if initializer is None:
+         try:
+             value = graph.find_value(node.inputs[1])
+         except KeyError as exc:
+             raise UnsupportedOpError(
+                 f"{node.op_type} axes input must be constant or inferable from shapes"
+             ) from exc
+         if value.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+             raise UnsupportedOpError(
+                 f"{node.op_type} axes input must be int64 or int32"
+             )
+         if any(dim == 0 for dim in value.type.shape):
+             return _AxesInputSpec((), None, None, None, True)
+         return _AxesInputSpec(
+             None,
+             node.inputs[1],
+             value.type.shape,
+             value.type.dtype,
+             True,
+         )
+     if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             f"{node.op_type} axes input must be int64 or int32"
+         )
+     data = np.array(initializer.data, dtype=np.int64).ravel()
+     return _AxesInputSpec(
+         tuple(int(value) for value in data),
+         None,
+         None,
+         None,
+         True,
+     )
+
+
+ def _axes_values_from_shape_ops(
+     graph: Graph, axes_input: str, node: Node
+ ) -> tuple[int, ...] | None:
+     node_by_output = {
+         output: graph_node
+         for graph_node in graph.nodes
+         for output in graph_node.outputs
+     }
+     cache: dict[str, np.ndarray] = {}
+
+     def resolve_value(name: str) -> np.ndarray | None:
+         if name in cache:
+             return cache[name]
+         initializer = _find_initializer(graph, name)
+         if initializer is not None:
+             value = np.array(initializer.data)
+             cache[name] = value
+             return value
+         producer = node_by_output.get(name)
+         if producer is None:
+             return None
+         op_type = producer.op_type
+         if op_type == "Identity":
+             if len(producer.inputs) != 1:
+                 return None
+             input_value = resolve_value(producer.inputs[0])
+             if input_value is None:
+                 return None
+             value = np.array(input_value, copy=True)
+         elif op_type == "Cast":
+             if len(producer.inputs) != 1:
+                 return None
+             input_value = resolve_value(producer.inputs[0])
+             if input_value is None:
+                 return None
+             to_attr = producer.attrs.get("to")
+             if to_attr is None:
+                 return None
+             dtype = scalar_type_from_onnx(int(to_attr))
+             if dtype is None:
+                 return None
+             value = np.array(input_value, dtype=dtype.np_dtype)
+         elif op_type == "Shape":
+             if len(producer.inputs) != 1:
+                 return None
+             input_shape = _value_shape(graph, producer.inputs[0], node)
+             value = np.array(input_shape, dtype=np.int64)
+         elif op_type == "Size":
+             if len(producer.inputs) != 1:
+                 return None
+             input_shape = _value_shape(graph, producer.inputs[0], node)
+             value = np.array(_shape_product(input_shape), dtype=np.int64)
+         elif op_type == "Range":
+             if len(producer.inputs) != 3:
+                 return None
+             start_value = resolve_value(producer.inputs[0])
+             limit_value = resolve_value(producer.inputs[1])
+             delta_value = resolve_value(producer.inputs[2])
+             if (
+                 start_value is None
+                 or limit_value is None
+                 or delta_value is None
+             ):
+                 return None
+             start = np.array(start_value).reshape(-1)[0]
+             limit = np.array(limit_value).reshape(-1)[0]
+             delta = np.array(delta_value).reshape(-1)[0]
+             if float(delta) == 0.0:
+                 raise UnsupportedOpError("Range delta must be non-zero")
+             dtype = _value_dtype(graph, producer.outputs[0], node)
+             value = np.arange(
+                 start, limit, delta, dtype=dtype.np_dtype
+             )
+         elif op_type in {"Add", "Sub"}:
+             if len(producer.inputs) != 2:
+                 return None
+             left_value = resolve_value(producer.inputs[0])
+             right_value = resolve_value(producer.inputs[1])
+             if left_value is None or right_value is None:
+                 return None
+             if op_type == "Add":
+                 value = np.array(left_value) + np.array(right_value)
+             else:
+                 value = np.array(left_value) - np.array(right_value)
+         else:
+             return None
+         cache[name] = value
+         return value
+
+     axes_value = resolve_value(axes_input)
+     if axes_value is None:
+         return None
+     if axes_value.dtype.kind not in {"i", "u"}:
+         raise UnsupportedOpError(
+             f"{node.op_type} axes input must be int64 or int32"
+         )
+     return tuple(int(axis) for axis in axes_value.ravel())
+
+
+ def _all_ones_shape(shape: tuple[int, ...]) -> bool:
+     return all(dim == 1 for dim in shape)
+
+
+ def _allow_unknown_reduce_output_shape(
+     expected_output_shape: tuple[int, ...],
+     output_shape: tuple[int, ...],
+     input_shape: tuple[int, ...],
+ ) -> bool:
+     if expected_output_shape != () or not output_shape or not input_shape:
+         return False
+     return True
+
+
+ def _infer_axes_from_shapes(
+     input_shape: tuple[int, ...],
+     output_shape: tuple[int, ...],
+     keepdims: bool,
+     node: Node,
+ ) -> tuple[int, ...] | None:
+     if keepdims:
+         if len(input_shape) != len(output_shape):
+             return None
+         axes: list[int] = []
+         for axis, (in_dim, out_dim) in enumerate(
+             zip(input_shape, output_shape)
+         ):
+             if out_dim == in_dim:
+                 if in_dim == 1:
+                     return None
+                 continue
+             if out_dim == 1 and in_dim != 1:
+                 axes.append(axis)
+                 continue
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape does not match input shape"
+             )
+         return tuple(axes)
+     if len(output_shape) > len(input_shape):
+         return None
+
+     results: list[tuple[int, ...]] = []
+
+     def backtrack(
+         input_index: int, output_index: int, reduced_axes: list[int]
+     ) -> None:
+         if output_index == len(output_shape):
+             results.append(
+                 tuple(reduced_axes + list(range(input_index, len(input_shape))))
+             )
+             return
+         if input_index == len(input_shape):
+             return
+         if input_shape[input_index] == output_shape[output_index]:
+             backtrack(input_index + 1, output_index + 1, reduced_axes)
+         backtrack(
+             input_index + 1, output_index, reduced_axes + [input_index]
+         )
+
+     backtrack(0, 0, [])
+     unique = {axes for axes in results}
+     if len(unique) == 1:
+         return tuple(sorted(next(iter(unique))))
+     if not unique:
+         raise ShapeInferenceError(
+             f"{node.op_type} output shape does not match input shape"
+         )
+     return None
+
+
+ def normalize_reduce_axes(
+     axes: tuple[int, ...], input_shape: tuple[int, ...], node: Node
+ ) -> tuple[int, ...]:
+     rank = len(input_shape)
+     normalized: list[int] = []
+     for axis in axes:
+         axis = int(axis)
+         if axis < 0:
+             axis += rank
+         if axis < 0 or axis >= rank:
+             raise ShapeInferenceError(
+                 f"{node.op_type} axis {axis} is out of range for rank {rank}"
+             )
+         normalized.append(axis)
+     if len(set(normalized)) != len(normalized):
+         raise ShapeInferenceError(f"{node.op_type} axes must be unique")
+     return tuple(sorted(normalized))
+
+
+ def resolve_reduce_axes(
+     graph: Graph, node: Node, input_shape: tuple[int, ...]
+ ) -> tuple[_ReduceSpec | None, bool]:
+     axes_attr = node.attrs.get("axes")
+     axes_input = _axes_input_info(graph, node)
+     if axes_attr is not None and axes_input.present:
+         raise UnsupportedOpError(
+             f"{node.op_type} cannot set both axes attribute and axes input"
+         )
+     keepdims = bool(int(node.attrs.get("keepdims", 1)))
+     if axes_attr is not None:
+         axes = tuple(int(value) for value in axes_attr)
+         axes_input_name = None
+         axes_input_shape = None
+         axes_input_dtype = None
+     elif axes_input.axes is not None:
+         axes = axes_input.axes
+         axes_input_name = None
+         axes_input_shape = None
+         axes_input_dtype = None
+     elif axes_input.present:
+         axes = None
+         if axes_input.input_name:
+             axes = _axes_values_from_shape_ops(
+                 graph, axes_input.input_name, node
+             )
+         if axes is None:
+             output_shape = _value_shape(graph, node.outputs[0], node)
+             axes = _infer_axes_from_shapes(
+                 input_shape, output_shape, keepdims, node
+             )
+         if axes is None:
+             axes_input_name = axes_input.input_name
+             axes_input_shape = axes_input.input_shape
+             axes_input_dtype = axes_input.input_dtype
+         else:
+             axes_input_name = None
+             axes_input_shape = None
+             axes_input_dtype = None
+     else:
+         axes = ()
+         axes_input_name = None
+         axes_input_shape = None
+         axes_input_dtype = None
+     noop_with_empty_axes = bool(int(node.attrs.get("noop_with_empty_axes", 0)))
+     if axes is not None and not axes:
+         if noop_with_empty_axes:
+             return None, True
+         axes = tuple(range(len(input_shape)))
+     if axes is None:
+         output_shape = _value_shape(graph, node.outputs[0], node)
+         if keepdims and len(output_shape) != len(input_shape):
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape rank must match input rank"
+             )
+         if len(output_shape) > len(input_shape):
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape rank must not exceed input rank"
+             )
+         return _ReduceSpec(
+             axes=None,
+             axes_input=axes_input_name,
+             axes_input_shape=axes_input_shape,
+             axes_input_dtype=axes_input_dtype,
+             keepdims=keepdims,
+             output_shape=output_shape,
+             reduce_count=None,
+         ), False
+     axes = normalize_reduce_axes(axes, input_shape, node)
+     return _ReduceSpec(
+         axes=axes,
+         axes_input=None,
+         axes_input_shape=None,
+         axes_input_dtype=None,
+         keepdims=keepdims,
+         output_shape=(),
+         reduce_count=None,
+     ), False
+
+
+ def _resolve_reduce_spec(graph: Graph, node: Node) -> _ReduceSpec | None:
+     if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} must have 1 or 2 inputs and 1 output"
+         )
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     axes_spec, noop = resolve_reduce_axes(graph, node, input_shape)
+     if noop:
+         output_shape = _value_shape(graph, node.outputs[0], node)
+         if output_shape != input_shape:
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape must be {input_shape}, got {output_shape}"
+             )
+         return None
+     if axes_spec is None:
+         raise ShapeInferenceError(f"{node.op_type} axes spec missing")
+     if axes_spec.axes is None:
+         return _ReduceSpec(
+             axes=None,
+             axes_input=axes_spec.axes_input,
+             axes_input_shape=axes_spec.axes_input_shape,
+             axes_input_dtype=axes_spec.axes_input_dtype,
+             keepdims=axes_spec.keepdims,
+             output_shape=axes_spec.output_shape,
+             reduce_count=None,
+         )
+     axes = axes_spec.axes
+     keepdims = axes_spec.keepdims
+     if keepdims:
+         output_shape = tuple(
+             1 if axis in axes else dim
+             for axis, dim in enumerate(input_shape)
+         )
+     else:
+         output_shape = tuple(
+             dim
+             for axis, dim in enumerate(input_shape)
+             if axis not in axes
+         )
+     expected_output_shape = _value_shape(graph, node.outputs[0], node)
+     if expected_output_shape != output_shape:
+         if _allow_unknown_reduce_output_shape(
+             expected_output_shape, output_shape, input_shape
+         ):
+             pass
+         elif not (
+             _all_ones_shape(expected_output_shape)
+             and _all_ones_shape(output_shape)
+             and _shape_product(expected_output_shape)
+             == _shape_product(output_shape)
+         ):
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape must be {output_shape}, got {expected_output_shape}"
+             )
+     reduce_count = _shape_product(tuple(input_shape[axis] for axis in axes))
+     return _ReduceSpec(
+         axes=axes,
+         axes_input=None,
+         axes_input_shape=None,
+         axes_input_dtype=None,
+         keepdims=keepdims,
+         output_shape=output_shape,
+         reduce_count=reduce_count,
+     )
+
+
+ def _reduce_dtype_supported(dtype: ScalarType) -> bool:
+     return dtype in {
+         ScalarType.F16,
+         ScalarType.F32,
+         ScalarType.F64,
+         ScalarType.I64,
+         ScalarType.I32,
+         ScalarType.I16,
+         ScalarType.I8,
+         ScalarType.U64,
+         ScalarType.U32,
+         ScalarType.U16,
+         ScalarType.U8,
+     }
+
+
+ def lower_reduce(graph: Graph, node: Node) -> ReduceOp | ReshapeOp:
+     if node.op_type not in REDUCE_KIND_BY_OP:
+         raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+     op_dtype = _value_dtype(graph, node.inputs[0], node)
+     output_dtype = _value_dtype(graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     if not _reduce_dtype_supported(op_dtype):
+         raise UnsupportedOpError(
+             f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
+         )
+     if node.op_type in REDUCE_OUTPUTS_FLOAT_ONLY and op_dtype not in {
+         ScalarType.F16,
+         ScalarType.F32,
+         ScalarType.F64,
+     }:
+         raise UnsupportedOpError(
+             f"{node.op_type} supports float16, float, and double inputs only"
+         )
+     spec = _resolve_reduce_spec(graph, node)
+     if spec is None:
+         input_shape = _value_shape(graph, node.inputs[0], node)
+         output_shape = _value_shape(graph, node.outputs[0], node)
+         return ReshapeOp(
+             input0=node.inputs[0],
+             output=node.outputs[0],
+             input_shape=input_shape,
+             output_shape=output_shape,
+             dtype=op_dtype,
+             input_dtype=op_dtype,
+         )
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     if spec.axes_input and (
+         spec.axes_input_shape is None or spec.axes_input_dtype is None
+     ):
+         raise ShapeInferenceError(
+             f"{node.op_type} axes input must have a static shape and dtype"
+         )
+     return ReduceOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         input_shape=input_shape,
+         output_shape=spec.output_shape,
+         axes=spec.axes or (),
+         axes_input=spec.axes_input,
+         axes_input_shape=spec.axes_input_shape,
+         axes_input_dtype=spec.axes_input_dtype,
+         keepdims=spec.keepdims,
+         noop_with_empty_axes=bool(int(node.attrs.get("noop_with_empty_axes", 0))),
+         reduce_kind=REDUCE_KIND_BY_OP[node.op_type],
+         reduce_count=spec.reduce_count,
+         dtype=op_dtype,
+     )
+
+
+ for _op_type in REDUCE_KIND_BY_OP:
+     register_lowering(_op_type)(lower_reduce)
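
Note: the axes handling above can be exercised in isolation. The sketch below mirrors the logic of normalize_reduce_axes under two stated assumptions: a plain ValueError stands in for the package-internal ShapeInferenceError, and the Graph/Node plumbing is dropped. It shows how negative ONNX axes are folded into the [0, rank) range and deduplicated:

    # Minimal standalone sketch of the normalization in normalize_reduce_axes;
    # ValueError stands in for the package-internal ShapeInferenceError.
    def normalize_axes(axes: tuple[int, ...], rank: int) -> tuple[int, ...]:
        normalized: list[int] = []
        for axis in axes:
            axis = int(axis)
            if axis < 0:
                axis += rank  # ONNX counts negative axes from the end
            if axis < 0 or axis >= rank:
                raise ValueError(f"axis {axis} is out of range for rank {rank}")
            normalized.append(axis)
        if len(set(normalized)) != len(normalized):
            raise ValueError("axes must be unique")
        return tuple(sorted(normalized))

    print(normalize_axes((-1, 0), rank=3))  # (0, 2)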
emx_onnx_cgen/lowering/registry.py
@@ -0,0 +1,51 @@
+ from __future__ import annotations
+
+ from collections.abc import Callable, Mapping
+ from typing import TypeVar
+
+ from ..ir.model import Graph, Node
+ from ..errors import UnsupportedOpError
+
+ LoweredOp = TypeVar("LoweredOp")
+ Handler = TypeVar("Handler")
+
+ _LOWERING_REGISTRY: dict[str, Callable[[Graph, Node], object]] = {}
+
+
+ def register_lowering(
+     op_type: str,
+ ) -> Callable[[Callable[[Graph, Node], LoweredOp]], Callable[[Graph, Node], LoweredOp]]:
+     def decorator(
+         func: Callable[[Graph, Node], LoweredOp],
+     ) -> Callable[[Graph, Node], LoweredOp]:
+         _LOWERING_REGISTRY[op_type] = func
+         return func
+
+     return decorator
+
+
+ def get_lowering(op_type: str) -> Callable[[Graph, Node], object] | None:
+     return _LOWERING_REGISTRY.get(op_type)
+
+
+ def get_lowering_registry() -> Mapping[str, Callable[[Graph, Node], object]]:
+     return _LOWERING_REGISTRY
+
+
+ def resolve_dispatch(
+     op_type: str,
+     registry: Mapping[str, Handler],
+     *,
+     binary_types: set[str],
+     unary_types: set[str],
+     binary_fallback: Callable[[], Handler],
+     unary_fallback: Callable[[], Handler],
+ ) -> Handler:
+     handler = registry.get(op_type)
+     if handler is not None:
+         return handler
+     if op_type in binary_types:
+         return binary_fallback()
+     if op_type in unary_types:
+         return unary_fallback()
+     raise UnsupportedOpError(f"Unsupported op {op_type}")
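
For context, handlers enter this registry at import time exactly as reduce.py does with register_lowering(_op_type)(lower_reduce). Used as a decorator it looks like the sketch below, where the "Neg" op type and lower_neg handler are illustrative only, not part of the package:

    # Illustrative only: "Neg" and lower_neg are not part of this package.
    from emx_onnx_cgen.lowering.registry import get_lowering, register_lowering

    @register_lowering("Neg")
    def lower_neg(graph, node):
        ...  # build and return the lowered op for this node

    # The decorator returns the function unchanged and records it by op type.
    assert get_lowering("Neg") is lower_neg
    assert get_lowering("NoSuchOp") is None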