emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0

emx_onnx_cgen/lowering/reduce.py
@@ -6,7 +6,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ReduceOp, ReshapeOp
+from ..ir.ops import ReduceOp, ReshapeOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
@@ -261,13 +261,12 @@ def _infer_axes_from_shapes(
             if out_dim == in_dim:
                 if in_dim == 1:
                     return None
-                continue
-            if out_dim == 1 and in_dim != 1:
+            elif out_dim == 1 and in_dim != 1:
                 axes.append(axis)
-                continue
-            raise ShapeInferenceError(
-                f"{node.op_type} output shape does not match input shape"
-            )
+            else:
+                raise ShapeInferenceError(
+                    f"{node.op_type} output shape does not match input shape"
+                )
         return tuple(axes)
     if len(output_shape) > len(input_shape):
         return None

emx_onnx_cgen/lowering/registry.py
@@ -3,32 +3,51 @@ from __future__ import annotations
 from collections.abc import Callable, Mapping
 from typing import TypeVar
 
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.op_base import OpBase
 from ..errors import UnsupportedOpError
 
 LoweredOp = TypeVar("LoweredOp")
 Handler = TypeVar("Handler")
 
-_LOWERING_REGISTRY: dict[str, Callable[[Graph, Node], object]] = {}
+_LOWERING_REGISTRY: dict[str, Callable[[Graph | GraphContext, Node], OpBase]] = {}
 
 
 def register_lowering(
     op_type: str,
 ) -> Callable[[Callable[[Graph, Node], LoweredOp]], Callable[[Graph, Node], LoweredOp]]:
     def decorator(
-        func: Callable[[Graph, Node], LoweredOp],
-    ) -> Callable[[Graph, Node], LoweredOp]:
+        func: Callable[[Graph | GraphContext, Node], LoweredOp],
+    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
         _LOWERING_REGISTRY[op_type] = func
         return func
 
     return decorator
 
 
-def get_lowering(op_type: str) -> Callable[[Graph, Node], object] | None:
+def register_lowering_if_missing(
+    op_type: str,
+) -> Callable[[Callable[[Graph | GraphContext, Node], LoweredOp]], Callable[[Graph | GraphContext, Node], LoweredOp]]:
+    def decorator(
+        func: Callable[[Graph | GraphContext, Node], LoweredOp],
+    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
+        if op_type not in _LOWERING_REGISTRY:
+            _LOWERING_REGISTRY[op_type] = func
+        return func
+
+    return decorator
+
+
+def get_lowering(
+    op_type: str,
+) -> Callable[[Graph | GraphContext, Node], OpBase] | None:
     return _LOWERING_REGISTRY.get(op_type)
 
 
-def get_lowering_registry() -> Mapping[str, Callable[[Graph, Node], object]]:
+def get_lowering_registry() -> Mapping[
+    str, Callable[[Graph | GraphContext, Node], OpBase]
+]:
     return _LOWERING_REGISTRY
 
 
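
Note: the new register_lowering_if_missing decorator only fills a slot when no handler is registered yet, so explicit registrations keep priority. A minimal sketch of that behavior (the Relu handlers below are hypothetical):

    from emx_onnx_cgen.lowering.registry import (
        get_lowering,
        register_lowering,
        register_lowering_if_missing,
    )

    @register_lowering("Relu")
    def lower_relu(graph, node):  # hypothetical explicit handler
        ...

    # A no-op for an already-registered op type, so the explicit
    # handler above keeps winning.
    @register_lowering_if_missing("Relu")
    def lower_relu_fallback(graph, node):  # hypothetical fallback
        ...

    assert get_lowering("Relu") is lower_relu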

emx_onnx_cgen/lowering/reshape.py
@@ -2,9 +2,10 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ReshapeOp
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
+from .common import value_shape as resolved_value_shape
 from .registry import register_lowering
 
 
@@ -37,6 +38,21 @@ def _shape_product(shape: tuple[int, ...]) -> int:
     return product
 
 
+def _reshape_mismatch_error(
+    node: Node,
+    input_shape: tuple[int, ...],
+    output_shape: tuple[int, ...],
+) -> ShapeInferenceError:
+    node_name = node.name or "<unnamed>"
+    return ShapeInferenceError(
+        "Reshape input/output element counts must match for op "
+        f"{node.op_type} (node '{node_name}'): input shape {input_shape}, "
+        f"output shape {output_shape}. "
+        "Hint: ensure the reshape target has the same number of elements as "
+        "the input."
+    )
+
+
 def _find_initializer(graph: Graph, name: str) -> Initializer | None:
     for initializer in graph.initializers:
         if initializer.name == name:
@@ -52,15 +68,190 @@ def _find_node_by_output(graph: Graph, name: str) -> Node | None:
 
 
 def _shape_values_from_shape_node(
-    graph: Graph, name: str, node: Node
-) -> list[int] | None:
-    shape_node = _find_node_by_output(graph, name)
-    if shape_node is None or shape_node.op_type != "Shape":
-        return None
+    graph: Graph, shape_node: Node, node: Node
+) -> list[int]:
     if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
     source_shape = _value_shape(graph, shape_node.inputs[0], node)
-    return list(source_shape)
+    start = int(shape_node.attrs.get("start", 0))
+    end = int(shape_node.attrs.get("end", len(source_shape)))
+    if start < 0:
+        start += len(source_shape)
+    if end < 0:
+        end += len(source_shape)
+    start = max(start, 0)
+    end = min(end, len(source_shape))
+    if start > end:
+        return []
+    return list(source_shape[start:end])
+
+
+def _shape_values_from_initializer(
+    graph: Graph,
+    name: str,
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "Reshape expects int64 or int32 shape input, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    return [int(value) for value in initializer.data.reshape(-1)]
+
+
+def _shape_values_from_input(
+    graph: Graph,
+    name: str,
+    node: Node,
+    *,
+    _visited: set[str] | None = None,
+) -> list[int] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        shape_values = _shape_values_from_initializer(graph, name)
+        if shape_values is not None:
+            return shape_values
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Shape":
+            return _shape_values_from_shape_node(graph, source_node, node)
+        if source_node.op_type == "Concat":
+            axis = int(source_node.attrs.get("axis", 0))
+            if axis != 0:
+                raise UnsupportedOpError("Reshape shape concat must use axis 0")
+            values: list[int] = []
+            for input_name in source_node.inputs:
+                input_values = _shape_values_from_input(
+                    graph,
+                    input_name,
+                    node,
+                    _visited=_visited,
+                )
+                if input_values is None:
+                    return None
+                values.extend(input_values)
+            return values
+        if source_node.op_type == "Cast":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Cast must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Unsqueeze":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Identity":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Identity must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            if len(left) == 1 and len(right) != 1:
+                left = left * len(right)
+            if len(right) == 1 and len(left) != 1:
+                right = right * len(left)
+            if len(left) != len(right):
+                return None
+            if source_node.op_type == "Equal":
+                return [1 if l == r else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "And":
+                return [1 if (l and r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Or":
+                return [1 if (l or r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Div":
+                return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Mod":
+                return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+        if source_node.op_type == "Not":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Not must have 1 input and 1 output")
+            values = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if values is None:
+                return None
+            return [0 if value else 1 for value in values]
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            condition = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if condition is None:
+                return None
+            on_true = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _shape_values_from_input(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            if len(condition) == 1:
+                condition = condition * max(len(on_true), len(on_false))
+            if len(on_true) == 1 and len(condition) != 1:
+                on_true = on_true * len(condition)
+            if len(on_false) == 1 and len(condition) != 1:
+                on_false = on_false * len(condition)
+            if not (len(condition) == len(on_true) == len(on_false)):
+                return None
+            return [
+                t if cond else f
+                for cond, t, f in zip(condition, on_true, on_false)
+            ]
+        return None
+    finally:
+        _visited.remove(name)
 
 
 def _resolve_target_shape(
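
Note: _shape_values_from_input turns the Reshape lowering into a small constant folder over shape-producing subgraphs (Shape slices, Concat, Cast, Unsqueeze, Identity, and integer Equal/And/Or/Div/Mod/Not/Where), with the _visited set guarding against cycles. A standalone sketch, in plain Python rather than the package API, of the kind of subgraph it can now resolve:

    # Reshape(data, Concat(Shape(data)[0:1], [-1], axis=0)) on data of
    # shape (2, 3, 4): the folder reduces the shape input to [2, -1].
    input_shape = (2, 3, 4)
    folded = list(input_shape[0:1]) + [-1]  # Shape slice, then Concat
    assert folded == [2, -1]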
@@ -82,19 +273,19 @@ def _resolve_target_shape(
                 raise ShapeInferenceError("Reshape allows only one -1 dimension")
             unknown_index = index
             output_dims.append(-1)
-            continue
-        if dim == 0:
-            contains_zero = True
-            if allowzero == 0:
-                if index >= len(input_shape):
-                    raise ShapeInferenceError(
-                        "Reshape zero dim must index into input shape"
-                    )
-                dim = input_shape[index]
-        if dim < 0:
-            raise ShapeInferenceError("Reshape dims must be >= -1")
-        output_dims.append(dim)
-        known_product *= dim
+        else:
+            if dim == 0:
+                contains_zero = True
+                if allowzero == 0:
+                    if index >= len(input_shape):
+                        raise ShapeInferenceError(
+                            "Reshape zero dim must index into input shape"
+                        )
+                    dim = input_shape[index]
+            if dim < 0:
+                raise ShapeInferenceError("Reshape dims must be >= -1")
+            output_dims.append(dim)
+            known_product *= dim
     if allowzero == 1 and contains_zero and unknown_index is not None:
         raise ShapeInferenceError(
             "Reshape allowzero cannot combine zero and -1 dimensions"
@@ -115,9 +306,7 @@ def _resolve_target_shape(
         output_dims[unknown_index] = input_product // known_product
     output_shape = tuple(output_dims)
     if _shape_product(output_shape) != input_product:
-        raise ShapeInferenceError(
-            "Reshape input and output element counts must match"
-        )
+        raise _reshape_mismatch_error(node, input_shape, output_shape)
     return output_shape
 
 
@@ -125,7 +314,7 @@ def _resolve_target_shape(
 def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
     if len(node.inputs) != 2 or len(node.outputs) != 1:
         raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
-    input_shape = _value_shape(graph, node.inputs[0], node)
+    input_shape = resolved_value_shape(graph, node.inputs[0], node)
     input_dtype = _value_dtype(graph, node.inputs[0], node)
     output_dtype = _value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:
@@ -133,46 +322,29 @@ def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
             "Reshape expects matching input/output dtypes, "
             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
         )
-    output_shape = _value_shape(graph, node.outputs[0], node)
+    output_value = graph.find_value(node.outputs[0])
+    output_shape = resolved_value_shape(graph, node.outputs[0], node)
+    output_dim_params = output_value.type.dim_params
     allowzero = int(node.attrs.get("allowzero", 0))
-    shape_initializer = _find_initializer(graph, node.inputs[1])
     resolved_shape: tuple[int, ...] | None = None
-    if shape_initializer is None:
-        shape_values = _shape_values_from_shape_node(
-            graph, node.inputs[1], node
-        )
-        if shape_values is not None:
-            resolved_shape = _resolve_target_shape(
-                input_shape,
-                shape_values,
-                allowzero=allowzero,
-                node=node,
-            )
-        else:
-            if _shape_product(output_shape) != _shape_product(input_shape):
-                raise ShapeInferenceError(
-                    "Reshape input and output element counts must match"
-                )
-    else:
-        if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
-            raise UnsupportedOpError(
-                "Reshape expects int64 or int32 shape input, "
-                f"got {shape_initializer.type.dtype.onnx_name}"
-            )
-        if len(shape_initializer.type.shape) != 1:
-            raise UnsupportedOpError("Reshape expects a 1D shape input")
-        shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
+    shape_values = _shape_values_from_input(graph, node.inputs[1], node)
+    if shape_values is not None:
         resolved_shape = _resolve_target_shape(
             input_shape,
             shape_values,
             allowzero=allowzero,
             node=node,
         )
-        if output_shape and resolved_shape != output_shape:
+        if output_shape and resolved_shape != output_shape and not any(
+            output_dim_params
+        ):
             raise ShapeInferenceError(
                 "Reshape output shape must be "
                 f"{resolved_shape}, got {output_shape}"
             )
+    else:
+        if _shape_product(output_shape) != _shape_product(input_shape):
+            raise _reshape_mismatch_error(node, input_shape, output_shape)
     if resolved_shape is not None:
         output_shape = resolved_shape
     for dim in output_shape:
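
Note: a folded target then flows through _resolve_target_shape, where -1 absorbs the remaining element count and, with the default allowzero=0, a literal 0 copies the matching input dimension. A standalone sketch of those two rules, continuing the toy shapes above:

    input_shape = (2, 3, 4)  # 24 elements

    # -1 absorbs what is left: [2, -1] resolves to (2, 12).
    known = 1
    for d in [2, -1]:
        if d != -1:
            known *= d
    assert (2, 24 // known) == (2, 12)

    # allowzero=0: a literal 0 copies the input dim at that index,
    # so [0, 12] also resolves to (2, 12).
    target = [0, 12]
    resolved = tuple(
        input_shape[i] if d == 0 else d for i, d in enumerate(target)
    )
    assert resolved == (2, 12)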

emx_onnx_cgen/lowering/resize.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ResizeOp
+from ..ir.ops import ResizeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .registry import register_lowering

emx_onnx_cgen/lowering/rms_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import RMSNormalizationOp
+from ..ir.ops import RMSNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input

emx_onnx_cgen/lowering/rotary_embedding.py (new file)
@@ -0,0 +1,165 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import RotaryEmbeddingOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class RotaryEmbeddingSpec:
+    batch: int
+    seq_len: int
+    num_heads: int
+    head_size: int
+    rotary_dim: int
+    rotary_dim_half: int
+    input_rank: int
+
+
+def _resolve_rotary_spec(
+    graph: Graph, node: Node, dtype: ScalarType
+) -> RotaryEmbeddingSpec:
+    if not dtype.is_float:
+        raise UnsupportedOpError("Unsupported op RotaryEmbedding")
+    if len(node.inputs) < 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Unsupported op RotaryEmbedding")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    input_rank = len(input_shape)
+    if input_rank not in {3, 4}:
+        raise ShapeInferenceError("RotaryEmbedding expects 3D or 4D input")
+    if input_rank == 3:
+        num_heads_attr = node.attrs.get("num_heads")
+        if num_heads_attr is None:
+            raise UnsupportedOpError(
+                "RotaryEmbedding num_heads attribute is required for 3D inputs"
+            )
+        num_heads = int(num_heads_attr)
+        if num_heads <= 0:
+            raise ShapeInferenceError("RotaryEmbedding num_heads must be > 0")
+        batch, seq_len, hidden_size = input_shape
+        if hidden_size % num_heads != 0:
+            raise ShapeInferenceError(
+                "RotaryEmbedding hidden size must be divisible by num_heads"
+            )
+        head_size = hidden_size // num_heads
+    else:
+        batch, num_heads, seq_len, head_size = input_shape
+        num_heads_attr = node.attrs.get("num_heads")
+        if num_heads_attr is not None and int(num_heads_attr) != num_heads:
+            raise ShapeInferenceError(
+                "RotaryEmbedding num_heads must match input head dimension"
+            )
+    if head_size % 2 != 0:
+        raise ShapeInferenceError("RotaryEmbedding head size must be even")
+    rotary_dim = int(node.attrs.get("rotary_embedding_dim", 0))
+    if rotary_dim == 0:
+        rotary_dim = head_size
+    if rotary_dim < 0 or rotary_dim > head_size:
+        raise ShapeInferenceError(
+            "RotaryEmbedding rotary_embedding_dim must be in [0, head_size]"
+        )
+    if rotary_dim % 2 != 0:
+        raise ShapeInferenceError(
+            "RotaryEmbedding rotary_embedding_dim must be even"
+        )
+    rotary_dim_half = rotary_dim // 2
+    return RotaryEmbeddingSpec(
+        batch=batch,
+        seq_len=seq_len,
+        num_heads=num_heads,
+        head_size=head_size,
+        rotary_dim=rotary_dim,
+        rotary_dim_half=rotary_dim_half,
+        input_rank=input_rank,
+    )
+
+
+@register_lowering("RotaryEmbedding")
+def lower_rotary_embedding(graph: Graph, node: Node) -> RotaryEmbeddingOp:
+    input_name = node.inputs[0]
+    cos_name = node.inputs[1]
+    sin_name = node.inputs[2]
+    position_ids = optional_name(node.inputs, 3)
+    dtype = value_dtype(graph, input_name, node)
+    cos_dtype = value_dtype(graph, cos_name, node)
+    sin_dtype = value_dtype(graph, sin_name, node)
+    if cos_dtype != dtype or sin_dtype != dtype:
+        raise ShapeInferenceError(
+            "RotaryEmbedding inputs must share the same dtype"
+        )
+    spec = _resolve_rotary_spec(graph, node, dtype)
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if output_shape != input_shape:
+        raise ShapeInferenceError(
+            "RotaryEmbedding output shape must match input shape"
+        )
+    cos_shape = value_shape(graph, cos_name, node)
+    sin_shape = value_shape(graph, sin_name, node)
+    if cos_shape != sin_shape:
+        raise ShapeInferenceError(
+            "RotaryEmbedding cos/sin cache shapes must match"
+        )
+    position_shape = None
+    position_dtype = None
+    if position_ids is not None:
+        position_shape = value_shape(graph, position_ids, node)
+        if position_shape != (spec.batch, spec.seq_len):
+            raise ShapeInferenceError(
+                "RotaryEmbedding position_ids must match [batch, seq_len]"
+            )
+        position_dtype = value_dtype(graph, position_ids, node)
+        if not position_dtype.is_integer:
+            raise ShapeInferenceError(
+                "RotaryEmbedding position_ids must be an integer tensor"
+            )
+        if len(cos_shape) != 2:
+            raise ShapeInferenceError(
+                "RotaryEmbedding expects 2D sin/cos caches with position_ids"
+            )
+        if cos_shape[1] != spec.rotary_dim_half:
+            raise ShapeInferenceError(
+                "RotaryEmbedding cos/sin cache last dim must match rotary_dim/2"
+            )
+    else:
+        if len(cos_shape) != 3:
+            raise ShapeInferenceError(
+                "RotaryEmbedding expects 3D sin/cos caches without position_ids"
+            )
+        if cos_shape != (
+            spec.batch,
+            spec.seq_len,
+            spec.rotary_dim_half,
+        ):
+            raise ShapeInferenceError(
+                "RotaryEmbedding sin/cos cache shape must be "
+                "[batch, seq_len, rotary_dim/2]"
+            )
+    interleaved = bool(int(node.attrs.get("interleaved", 0)))
+    return RotaryEmbeddingOp(
+        input0=input_name,
+        cos_cache=cos_name,
+        sin_cache=sin_name,
+        position_ids=position_ids,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        cos_shape=cos_shape,
+        sin_shape=sin_shape,
+        position_ids_shape=position_shape,
+        dtype=dtype,
+        position_ids_dtype=position_dtype,
+        rotary_dim=spec.rotary_dim,
+        rotary_dim_half=spec.rotary_dim_half,
+        head_size=spec.head_size,
+        num_heads=spec.num_heads,
+        seq_len=spec.seq_len,
+        batch=spec.batch,
+        input_rank=spec.input_rank,
+        interleaved=interleaved,
+    )
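
Note: this lowering only validates shapes and attributes; the rotation kernel itself comes from the C emitter and is not shown in this diff. For orientation, a hedged numpy reference of the standard non-interleaved rotation that a 4D RotaryEmbeddingOp with these shapes describes (layout assumed from the checks above):

    import numpy as np

    def rope_reference(x, cos, sin, rotary_dim):
        # x: [batch, num_heads, seq_len, head_size]
        # cos, sin: [batch, seq_len, rotary_dim // 2]
        half = rotary_dim // 2
        c = cos[:, None, :, :]  # broadcast over the heads axis
        s = sin[:, None, :, :]
        x1, x2 = x[..., :half], x[..., half:rotary_dim]
        out = x.copy()
        out[..., :half] = x1 * c - x2 * s
        out[..., half:rotary_dim] = x2 * c + x1 * s
        return out  # dims past rotary_dim pass through unchanged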

emx_onnx_cgen/lowering/scatter_nd.py (new file)
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import ScatterNDOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+_ALLOWED_REDUCTIONS = {"none", "add", "mul", "min", "max"}
+
+
+@register_lowering("ScatterND")
+def lower_scatternd(graph: Graph, node: Node) -> ScatterNDOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("ScatterND must have 3 inputs and 1 output")
+    data_name, indices_name, updates_name = node.inputs
+    output_name = node.outputs[0]
+    data_shape = value_shape(graph, data_name, node)
+    indices_shape = value_shape(graph, indices_name, node)
+    updates_shape = value_shape(graph, updates_name, node)
+    output_shape = value_shape(graph, output_name, node)
+    if output_shape != data_shape:
+        raise ShapeInferenceError(
+            "ScatterND output shape must match data shape, "
+            f"got {output_shape} vs {data_shape}"
+        )
+    if len(indices_shape) < 1:
+        raise ShapeInferenceError("ScatterND indices must have rank >= 1")
+    index_depth = indices_shape[-1]
+    if index_depth <= 0:
+        raise ShapeInferenceError(
+            "ScatterND indices final dimension must be >= 1"
+        )
+    if index_depth > len(data_shape):
+        raise ShapeInferenceError(
+            "ScatterND indices final dimension must be <= data rank, "
+            f"got {index_depth} vs {len(data_shape)}"
+        )
+    expected_updates_shape = indices_shape[:-1] + data_shape[index_depth:]
+    if updates_shape != expected_updates_shape:
+        raise ShapeInferenceError(
+            "ScatterND updates shape must be "
+            f"{expected_updates_shape}, got {updates_shape}"
+        )
+    data_dtype = value_dtype(graph, data_name, node)
+    updates_dtype = value_dtype(graph, updates_name, node)
+    if updates_dtype != data_dtype:
+        raise UnsupportedOpError(
+            "ScatterND updates dtype must match data dtype, "
+            f"got {updates_dtype.onnx_name} vs {data_dtype.onnx_name}"
+        )
+    indices_dtype = value_dtype(graph, indices_name, node)
+    if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "ScatterND indices must be int32 or int64, "
+            f"got {indices_dtype.onnx_name}"
+        )
+    reduction_attr = node.attrs.get("reduction", "none")
+    if isinstance(reduction_attr, bytes):
+        reduction = reduction_attr.decode()
+    else:
+        reduction = str(reduction_attr)
+    if reduction not in _ALLOWED_REDUCTIONS:
+        raise UnsupportedOpError(
+            "ScatterND reduction must be one of "
+            f"{sorted(_ALLOWED_REDUCTIONS)}, got {reduction}"
+        )
+    return ScatterNDOp(
+        data=data_name,
+        indices=indices_name,
+        updates=updates_name,
+        output=output_name,
+        data_shape=data_shape,
+        indices_shape=indices_shape,
+        updates_shape=updates_shape,
+        output_shape=output_shape,
+        reduction=reduction,
+        dtype=data_dtype,
+        indices_dtype=indices_dtype,
+    )
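
Note: the updates-shape check encodes the ONNX ScatterND contract: indices_shape[:-1] addresses update slots, and each slot replaces the trailing data dims past index_depth. A quick worked instance of the rule:

    # index_depth = indices_shape[-1] = 1: each index picks a whole
    # row of data, so updates carries one (5, 6) slice per index.
    data_shape = (4, 5, 6)
    indices_shape = (2, 1)
    expected_updates = indices_shape[:-1] + data_shape[1:]
    assert expected_updates == (2, 5, 6)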