emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen has been flagged by the registry; consult its advisory page for details.

Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0

emx_onnx_cgen/lowering/one_hot.py (new file)
@@ -0,0 +1,120 @@
+ from __future__ import annotations
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import OneHotOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from ..lowering.common import value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _read_scalar_initializer(
+     graph: Graph, name: str, node: Node, label: str
+ ) -> int | None:
+     initializer = _find_initializer(graph, name)
+     if initializer is None:
+         return None
+     data = np.array(initializer.data)
+     if data.size != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} {label} input must be a scalar"
+         )
+     return int(data.reshape(-1)[0])
+
+
+ def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+     return shape == () or shape == (1,)
+
+
+ def _normalize_onehot_axis(axis: int, rank: int, node: Node) -> int:
+     if axis < 0:
+         axis += rank + 1
+     if axis < 0 or axis > rank:
+         raise ShapeInferenceError(
+             f"{node.op_type} axis {axis} is out of range for rank {rank}"
+         )
+     return axis
+
+
+ @register_lowering("OneHot")
+ def lower_onehot(graph: Graph, node: Node) -> OneHotOp:
+     if len(node.inputs) != 3 or len(node.outputs) != 1:
+         raise UnsupportedOpError("OneHot must have 3 inputs and 1 output")
+     indices_name, depth_name, values_name = node.inputs
+     indices_shape = value_shape(graph, indices_name, node)
+     depth_shape = value_shape(graph, depth_name, node)
+     values_shape = value_shape(graph, values_name, node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     if not _is_scalar_shape(depth_shape):
+         raise UnsupportedOpError("OneHot depth input must be a scalar")
+     if len(values_shape) != 1 or values_shape[0] != 2:
+         raise UnsupportedOpError(
+             "OneHot values input must be a 1D tensor of size 2"
+         )
+     output_rank = len(indices_shape) + 1
+     if len(output_shape) != output_rank:
+         raise ShapeInferenceError(
+             f"OneHot output rank must be {output_rank}, got {len(output_shape)}"
+         )
+     axis = _normalize_onehot_axis(
+         int(node.attrs.get("axis", -1)), len(indices_shape), node
+     )
+     depth_value = _read_scalar_initializer(graph, depth_name, node, "depth")
+     if depth_value is not None:
+         if depth_value < 0:
+             raise ShapeInferenceError("OneHot depth must be non-negative")
+         if output_shape[axis] != depth_value:
+             raise ShapeInferenceError(
+                 "OneHot output depth must be "
+                 f"{depth_value}, got {output_shape[axis]}"
+             )
+         depth_dim = depth_value
+     else:
+         depth_dim = output_shape[axis]
+         if depth_dim < 0:
+             raise ShapeInferenceError("OneHot output depth must be non-negative")
+     expected_output_shape = (
+         indices_shape[:axis] + (depth_dim,) + indices_shape[axis:]
+     )
+     if output_shape != expected_output_shape:
+         raise ShapeInferenceError(
+             "OneHot output shape must be "
+             f"{expected_output_shape}, got {output_shape}"
+         )
+     indices_dtype = value_dtype(graph, indices_name, node)
+     depth_dtype = value_dtype(graph, depth_name, node)
+     values_dtype = value_dtype(graph, values_name, node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if indices_dtype.is_bool:
+         raise UnsupportedOpError("OneHot indices must be numeric")
+     if depth_dtype.is_bool:
+         raise UnsupportedOpError("OneHot depth must be numeric")
+     if values_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "OneHot values dtype must match output dtype, "
+             f"got {values_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     return OneHotOp(
+         indices=indices_name,
+         depth=depth_name,
+         values=values_name,
+         output=node.outputs[0],
+         axis=axis,
+         indices_shape=indices_shape,
+         values_shape=values_shape,
+         output_shape=output_shape,
+         depth_dim=depth_dim,
+         dtype=values_dtype,
+         indices_dtype=indices_dtype,
+         depth_dtype=depth_dtype,
+     )
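
Note on the new OneHot lowering: the shape rule it enforces is output_shape = indices_shape[:axis] + (depth,) + indices_shape[axis:], with a negative axis normalized against the output rank (rank(indices) + 1). A minimal NumPy reference of that layout, as an illustration only (not code from the package; ONNX's negative-index wrapping is omitted):

import numpy as np

def onehot_reference(indices, depth, values, axis=-1):
    # A negative axis counts against the output rank: rank(indices) + 1.
    if axis < 0:
        axis += indices.ndim + 1
    off_value, on_value = values
    # Compare each index against 0..depth-1 along a new trailing axis.
    onehot = np.where(
        np.arange(depth) == np.expand_dims(indices, -1), on_value, off_value
    )
    # Move the new depth axis into position `axis`.
    return np.moveaxis(onehot, -1, axis)

indices = np.array([[0, 2], [1, 1]])
out = onehot_reference(indices, depth=3, values=(0.0, 1.0), axis=1)
assert out.shape == (2, 3, 2)  # indices_shape[:1] + (3,) + indices_shape[1:]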

emx_onnx_cgen/lowering/quantize_linear.py (new file)
@@ -0,0 +1,126 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from shared.scalar_types import ScalarType
+
+ from ..dtypes import scalar_type_from_onnx
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import normalize_axis
+ from .common import optional_name, value_dtype as _value_dtype, value_shape as _value_shape
+ from .registry import register_lowering
+ from ..codegen.c_emitter import QuantizeLinearOp
+
+
+ @dataclass(frozen=True)
+ class QuantizeSpec:
+     input_shape: tuple[int, ...]
+     scale_shape: tuple[int, ...]
+     axis: int | None
+     output_dtype: ScalarType
+
+
+ def _resolve_output_dtype(
+     graph: Graph, node: Node, zero_point_name: str | None
+ ) -> ScalarType:
+     output_attr = int(node.attrs.get("output_dtype", 0))
+     if output_attr:
+         output_dtype = _value_dtype(graph, node.outputs[0], node)
+         attr_dtype = scalar_type_from_onnx(output_attr)
+         if attr_dtype is None:
+             raise UnsupportedOpError(
+                 "QuantizeLinear output_dtype must map to a supported scalar type"
+             )
+         if output_dtype != attr_dtype:
+             raise UnsupportedOpError(
+                 "QuantizeLinear output_dtype must match output tensor dtype"
+             )
+         return output_dtype
+     if zero_point_name is None:
+         return ScalarType.U8
+     return _value_dtype(graph, zero_point_name, node)
+
+
+ def resolve_quantize_spec(graph: Graph, node: Node) -> QuantizeSpec:
+     if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             "QuantizeLinear must have 2 or 3 inputs and 1 output"
+         )
+     supported_attrs = {"axis", "block_size", "output_dtype", "precision", "saturate"}
+     if set(node.attrs) - supported_attrs:
+         raise UnsupportedOpError("QuantizeLinear has unsupported attributes")
+     block_size = int(node.attrs.get("block_size", 0))
+     if block_size != 0:
+         raise UnsupportedOpError("QuantizeLinear block_size is not supported")
+     precision = int(node.attrs.get("precision", 0))
+     if precision != 0:
+         raise UnsupportedOpError("QuantizeLinear precision is not supported")
+     saturate = int(node.attrs.get("saturate", 1))
+     if saturate != 1:
+         raise UnsupportedOpError("QuantizeLinear saturate must be 1")
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     scale_shape = _value_shape(graph, node.inputs[1], node)
+     zero_point_name = optional_name(node.inputs, 2)
+     output_dtype = _resolve_output_dtype(graph, node, zero_point_name)
+     if output_dtype not in {
+         ScalarType.U8,
+         ScalarType.I8,
+         ScalarType.U16,
+         ScalarType.I16,
+     }:
+         raise UnsupportedOpError(
+             "QuantizeLinear supports int8/uint8/int16/uint16 outputs only"
+         )
+     if zero_point_name is not None:
+         zero_point_dtype = _value_dtype(graph, zero_point_name, node)
+         if zero_point_dtype != output_dtype:
+             raise UnsupportedOpError(
+                 "QuantizeLinear zero_point dtype must match output dtype"
+             )
+         zero_point_shape = _value_shape(graph, zero_point_name, node)
+         if zero_point_shape != scale_shape:
+             raise ShapeInferenceError(
+                 "QuantizeLinear zero_point shape must match scale shape"
+             )
+     if scale_shape not in {(), (1,)}:
+         if len(scale_shape) != 1:
+             raise UnsupportedOpError(
+                 "QuantizeLinear supports per-tensor and per-axis scales only"
+             )
+         axis = int(node.attrs.get("axis", 1))
+         axis = normalize_axis(axis, input_shape, node)
+         if scale_shape[0] != input_shape[axis]:
+             raise ShapeInferenceError(
+                 "QuantizeLinear scale length must match input axis size"
+             )
+     else:
+         axis = None
+     return QuantizeSpec(
+         input_shape=input_shape,
+         scale_shape=scale_shape,
+         axis=axis,
+         output_dtype=output_dtype,
+     )
+
+
+ @register_lowering("QuantizeLinear")
+ def lower_quantize_linear(graph: Graph, node: Node) -> QuantizeLinearOp:
+     op_dtype = _value_dtype(graph, node.inputs[0], node)
+     scale_dtype = _value_dtype(graph, node.inputs[1], node)
+     if not op_dtype.is_float or not scale_dtype.is_float:
+         raise UnsupportedOpError(
+             "QuantizeLinear supports float16/float/double inputs only"
+         )
+     spec = resolve_quantize_spec(graph, node)
+     return QuantizeLinearOp(
+         input0=node.inputs[0],
+         scale=node.inputs[1],
+         zero_point=optional_name(node.inputs, 2),
+         output=node.outputs[0],
+         input_shape=spec.input_shape,
+         axis=spec.axis,
+         dtype=spec.output_dtype,
+         input_dtype=op_dtype,
+         scale_dtype=scale_dtype,
+     )
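
Note on the new QuantizeLinear lowering: the per-tensor case it accepts follows the usual ONNX formula y = saturate(round(x / scale) + zero_point). A minimal NumPy sketch, assuming a uint8 output and round-half-to-even (which np.rint implements); the per-axis case would broadcast scale and zero_point along the normalized axis instead:

import numpy as np

def quantize_linear_reference(x, scale, zero_point=0):
    # Round half to even, shift by the zero point, saturate to uint8.
    q = np.rint(x / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

x = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
print(quantize_linear_reference(x, scale=np.float32(0.5), zero_point=2))
# -> [0 2 3 6]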

emx_onnx_cgen/lowering/reduce.py
@@ -261,13 +261,12 @@ def _infer_axes_from_shapes(
          if out_dim == in_dim:
              if in_dim == 1:
                  return None
-             continue
-         if out_dim == 1 and in_dim != 1:
+         elif out_dim == 1 and in_dim != 1:
              axes.append(axis)
-             continue
-         raise ShapeInferenceError(
-             f"{node.op_type} output shape does not match input shape"
-         )
+         else:
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape does not match input shape"
+             )
      return tuple(axes)
  if len(output_shape) > len(input_shape):
      return None
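
Note: the rewrite above is behavior-preserving; the continue-based loop becomes a single if/elif/else chain. What the helper computes is which input axes were reduced with keepdims, given only the two shapes. A hypothetical standalone miniature (not the package's exact code):

def infer_axes(input_shape, output_shape):
    axes = []
    for axis, (out_dim, in_dim) in enumerate(zip(output_shape, input_shape)):
        if out_dim == in_dim:
            if in_dim == 1:
                return None  # ambiguous: a size-1 axis may or may not be reduced
        elif out_dim == 1 and in_dim != 1:
            axes.append(axis)
        else:
            raise ValueError("output shape does not match input shape")
    return tuple(axes)

assert infer_axes((2, 3, 4), (2, 1, 4)) == (1,)  # axis 1 reduced with keepdims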

emx_onnx_cgen/lowering/reshape.py
@@ -5,6 +5,7 @@ from shared.scalar_types import ScalarType
  from ..codegen.c_emitter import ReshapeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
+ from .common import value_shape as resolved_value_shape
  from .registry import register_lowering


@@ -37,6 +38,21 @@ def _shape_product(shape: tuple[int, ...]) -> int:
      return product


+ def _reshape_mismatch_error(
+     node: Node,
+     input_shape: tuple[int, ...],
+     output_shape: tuple[int, ...],
+ ) -> ShapeInferenceError:
+     node_name = node.name or "<unnamed>"
+     return ShapeInferenceError(
+         "Reshape input/output element counts must match for op "
+         f"{node.op_type} (node '{node_name}'): input shape {input_shape}, "
+         f"output shape {output_shape}. "
+         "Hint: ensure the reshape target has the same number of elements as "
+         "the input."
+     )
+
+
  def _find_initializer(graph: Graph, name: str) -> Initializer | None:
      for initializer in graph.initializers:
          if initializer.name == name:
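
Note: the new helper replaces a terse message with one that names the op, the node, and both shapes. Roughly what it produces, sketched with a stand-in for the IR Node type (assumption: only .op_type and .name are read):

from types import SimpleNamespace

node = SimpleNamespace(op_type="Reshape", name="")
node_name = node.name or "<unnamed>"
print(
    "Reshape input/output element counts must match for op "
    f"{node.op_type} (node '{node_name}'): input shape {(2, 3)}, "
    f"output shape {(4, 2)}. "
    "Hint: ensure the reshape target has the same number of elements as "
    "the input."
)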
@@ -52,15 +68,190 @@ def _find_node_by_output(graph: Graph, name: str) -> Node | None:


  def _shape_values_from_shape_node(
-     graph: Graph, name: str, node: Node
- ) -> list[int] | None:
-     shape_node = _find_node_by_output(graph, name)
-     if shape_node is None or shape_node.op_type != "Shape":
-         return None
+     graph: Graph, shape_node: Node, node: Node
+ ) -> list[int]:
      if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
          raise UnsupportedOpError("Shape must have 1 input and 1 output")
      source_shape = _value_shape(graph, shape_node.inputs[0], node)
-     return list(source_shape)
+     start = int(shape_node.attrs.get("start", 0))
+     end = int(shape_node.attrs.get("end", len(source_shape)))
+     if start < 0:
+         start += len(source_shape)
+     if end < 0:
+         end += len(source_shape)
+     start = max(start, 0)
+     end = min(end, len(source_shape))
+     if start > end:
+         return []
+     return list(source_shape[start:end])
+
+
+ def _shape_values_from_initializer(
+     graph: Graph,
+     name: str,
+ ) -> list[int] | None:
+     initializer = _find_initializer(graph, name)
+     if initializer is None:
+         return None
+     if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             "Reshape expects int64 or int32 shape input, "
+             f"got {initializer.type.dtype.onnx_name}"
+         )
+     return [int(value) for value in initializer.data.reshape(-1)]
+
+
+ def _shape_values_from_input(
+     graph: Graph,
+     name: str,
+     node: Node,
+     *,
+     _visited: set[str] | None = None,
+ ) -> list[int] | None:
+     if _visited is None:
+         _visited = set()
+     if name in _visited:
+         return None
+     _visited.add(name)
+     try:
+         shape_values = _shape_values_from_initializer(graph, name)
+         if shape_values is not None:
+             return shape_values
+         source_node = _find_node_by_output(graph, name)
+         if source_node is None:
+             return None
+         if source_node.op_type == "Shape":
+             return _shape_values_from_shape_node(graph, source_node, node)
+         if source_node.op_type == "Concat":
+             axis = int(source_node.attrs.get("axis", 0))
+             if axis != 0:
+                 raise UnsupportedOpError("Reshape shape concat must use axis 0")
+             values: list[int] = []
+             for input_name in source_node.inputs:
+                 input_values = _shape_values_from_input(
+                     graph,
+                     input_name,
+                     node,
+                     _visited=_visited,
+                 )
+                 if input_values is None:
+                     return None
+                 values.extend(input_values)
+             return values
+         if source_node.op_type == "Cast":
+             if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                 raise UnsupportedOpError("Cast must have 1 input and 1 output")
+             return _shape_values_from_input(
+                 graph,
+                 source_node.inputs[0],
+                 node,
+                 _visited=_visited,
+             )
+         if source_node.op_type == "Unsqueeze":
+             if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                 raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+             return _shape_values_from_input(
+                 graph,
+                 source_node.inputs[0],
+                 node,
+                 _visited=_visited,
+             )
+         if source_node.op_type == "Identity":
+             if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                 raise UnsupportedOpError("Identity must have 1 input and 1 output")
+             return _shape_values_from_input(
+                 graph,
+                 source_node.inputs[0],
+                 node,
+                 _visited=_visited,
+             )
+         if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+             if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                 raise UnsupportedOpError(
+                     f"{source_node.op_type} must have 2 inputs and 1 output"
+                 )
+             left = _shape_values_from_input(
+                 graph,
+                 source_node.inputs[0],
+                 node,
+                 _visited=_visited,
+             )
+             right = _shape_values_from_input(
+                 graph,
+                 source_node.inputs[1],
+                 node,
+                 _visited=_visited,
+             )
+             if left is None or right is None:
+                 return None
+             if len(left) == 1 and len(right) != 1:
+                 left = left * len(right)
+             if len(right) == 1 and len(left) != 1:
+                 right = right * len(left)
+             if len(left) != len(right):
+                 return None
+             if source_node.op_type == "Equal":
+                 return [1 if l == r else 0 for l, r in zip(left, right)]
+             if source_node.op_type == "And":
+                 return [1 if (l and r) else 0 for l, r in zip(left, right)]
+             if source_node.op_type == "Or":
+                 return [1 if (l or r) else 0 for l, r in zip(left, right)]
+             if source_node.op_type == "Div":
+                 return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+             if source_node.op_type == "Mod":
+                 return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+         if source_node.op_type == "Not":
+             if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                 raise UnsupportedOpError("Not must have 1 input and 1 output")
+             values = _shape_values_from_input(
+                 graph,
+                 source_node.inputs[0],
+                 node,
+                 _visited=_visited,
+             )
+             if values is None:
+                 return None
+             return [0 if value else 1 for value in values]
+         if source_node.op_type == "Where":
+             if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                 raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+             condition = _shape_values_from_input(
+                 graph,
+                 source_node.inputs[0],
+                 node,
+                 _visited=_visited,
+             )
+             if condition is None:
+                 return None
+             on_true = _shape_values_from_input(
+                 graph,
+                 source_node.inputs[1],
+                 node,
+                 _visited=_visited,
+             )
+             on_false = _shape_values_from_input(
+                 graph,
+                 source_node.inputs[2],
+                 node,
+                 _visited=_visited,
+             )
+             if on_true is None or on_false is None:
+                 return None
+             if len(condition) == 1:
+                 condition = condition * max(len(on_true), len(on_false))
+             if len(on_true) == 1 and len(condition) != 1:
+                 on_true = on_true * len(condition)
+             if len(on_false) == 1 and len(condition) != 1:
+                 on_false = on_false * len(condition)
+             if not (len(condition) == len(on_true) == len(on_false)):
+                 return None
+             return [
+                 t if cond else f
+                 for cond, t, f in zip(condition, on_true, on_false)
+             ]
+         return None
+     finally:
+         _visited.remove(name)


  def _resolve_target_shape(
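
Note: this hunk is the bulk of the reshape changes. Instead of only reading the shape from an initializer or a direct Shape node, _shape_values_from_input now walks the producer chain (Shape with start/end, Concat, Cast, Unsqueeze, Identity, Equal/And/Or/Div/Mod, Not, Where) and constant-folds it, with the _visited set guarding against cycles. A toy version of the idea over a dict-based graph, greatly simplified and hypothetical:

def fold_shape(name, producers, initializers):
    # producers maps an output name to ("Shape", static_shape) or
    # ("Concat", [input names]); anything else is treated as opaque.
    if name in initializers:
        return list(initializers[name])
    producer = producers.get(name)
    if producer is None:
        return None
    kind, payload = producer
    if kind == "Shape":
        return list(payload)
    if kind == "Concat":
        parts = [fold_shape(n, producers, initializers) for n in payload]
        if any(part is None for part in parts):
            return None
        return [dim for part in parts for dim in part]
    return None

producers = {"batch": ("Shape", (32,)), "target": ("Concat", ["batch", "tail"])}
initializers = {"tail": [-1]}
assert fold_shape("target", producers, initializers) == [32, -1]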
@@ -82,19 +273,19 @@ def _resolve_target_shape(
                  raise ShapeInferenceError("Reshape allows only one -1 dimension")
              unknown_index = index
              output_dims.append(-1)
-             continue
-         if dim == 0:
-             contains_zero = True
-             if allowzero == 0:
-                 if index >= len(input_shape):
-                     raise ShapeInferenceError(
-                         "Reshape zero dim must index into input shape"
-                     )
-                 dim = input_shape[index]
-         if dim < 0:
-             raise ShapeInferenceError("Reshape dims must be >= -1")
-         output_dims.append(dim)
-         known_product *= dim
+         else:
+             if dim == 0:
+                 contains_zero = True
+                 if allowzero == 0:
+                     if index >= len(input_shape):
+                         raise ShapeInferenceError(
+                             "Reshape zero dim must index into input shape"
+                         )
+                     dim = input_shape[index]
+             if dim < 0:
+                 raise ShapeInferenceError("Reshape dims must be >= -1")
+             output_dims.append(dim)
+             known_product *= dim
      if allowzero == 1 and contains_zero and unknown_index is not None:
          raise ShapeInferenceError(
              "Reshape allowzero cannot combine zero and -1 dimensions"
@@ -115,9 +306,7 @@ def _resolve_target_shape(
          output_dims[unknown_index] = input_product // known_product
      output_shape = tuple(output_dims)
      if _shape_product(output_shape) != input_product:
-         raise ShapeInferenceError(
-             "Reshape input and output element counts must match"
-         )
+         raise _reshape_mismatch_error(node, input_shape, output_shape)
      return output_shape


@@ -125,7 +314,7 @@ def _resolve_target_shape(
  def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
      if len(node.inputs) != 2 or len(node.outputs) != 1:
          raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
-     input_shape = _value_shape(graph, node.inputs[0], node)
+     input_shape = resolved_value_shape(graph, node.inputs[0], node)
      input_dtype = _value_dtype(graph, node.inputs[0], node)
      output_dtype = _value_dtype(graph, node.outputs[0], node)
      if input_dtype != output_dtype:
@@ -133,46 +322,29 @@ def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
              "Reshape expects matching input/output dtypes, "
              f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
          )
-     output_shape = _value_shape(graph, node.outputs[0], node)
+     output_value = graph.find_value(node.outputs[0])
+     output_shape = resolved_value_shape(graph, node.outputs[0], node)
+     output_dim_params = output_value.type.dim_params
      allowzero = int(node.attrs.get("allowzero", 0))
-     shape_initializer = _find_initializer(graph, node.inputs[1])
      resolved_shape: tuple[int, ...] | None = None
-     if shape_initializer is None:
-         shape_values = _shape_values_from_shape_node(
-             graph, node.inputs[1], node
-         )
-         if shape_values is not None:
-             resolved_shape = _resolve_target_shape(
-                 input_shape,
-                 shape_values,
-                 allowzero=allowzero,
-                 node=node,
-             )
-         else:
-             if _shape_product(output_shape) != _shape_product(input_shape):
-                 raise ShapeInferenceError(
-                     "Reshape input and output element counts must match"
-                 )
-     else:
-         if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
-             raise UnsupportedOpError(
-                 "Reshape expects int64 or int32 shape input, "
-                 f"got {shape_initializer.type.dtype.onnx_name}"
-             )
-         if len(shape_initializer.type.shape) != 1:
-             raise UnsupportedOpError("Reshape expects a 1D shape input")
-         shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
+     shape_values = _shape_values_from_input(graph, node.inputs[1], node)
+     if shape_values is not None:
          resolved_shape = _resolve_target_shape(
              input_shape,
              shape_values,
              allowzero=allowzero,
              node=node,
          )
-         if output_shape and resolved_shape != output_shape:
+         if output_shape and resolved_shape != output_shape and not any(
+             output_dim_params
+         ):
              raise ShapeInferenceError(
                  "Reshape output shape must be "
                  f"{resolved_shape}, got {output_shape}"
              )
+     else:
+         if _shape_product(output_shape) != _shape_product(input_shape):
+             raise _reshape_mismatch_error(node, input_shape, output_shape)
      if resolved_shape is not None:
          output_shape = resolved_shape
      for dim in output_shape:
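
Note: the new `not any(output_dim_params)` guard is what lets models with symbolic output dimensions through: the strict equality check between the folded shape and the declared shape only fires when every output dim is concrete. A hypothetical miniature of the predicate:

def shapes_conflict(resolved, declared, dim_params):
    # dim_params holds one symbolic name per dim ("" when concrete).
    return bool(declared) and resolved != declared and not any(dim_params)

assert not shapes_conflict((32, 12), (0, 12), ("batch", ""))  # symbolic: accept
assert shapes_conflict((32, 12), (16, 12), ("", ""))  # concrete mismatch: reject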

emx_onnx_cgen/lowering/scatter_nd.py (new file)
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import ScatterNDOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype, value_shape
+ from .registry import register_lowering
+
+ _ALLOWED_REDUCTIONS = {"none", "add", "mul", "min", "max"}
+
+
+ @register_lowering("ScatterND")
+ def lower_scatternd(graph: Graph, node: Node) -> ScatterNDOp:
+     if len(node.inputs) != 3 or len(node.outputs) != 1:
+         raise UnsupportedOpError("ScatterND must have 3 inputs and 1 output")
+     data_name, indices_name, updates_name = node.inputs
+     output_name = node.outputs[0]
+     data_shape = value_shape(graph, data_name, node)
+     indices_shape = value_shape(graph, indices_name, node)
+     updates_shape = value_shape(graph, updates_name, node)
+     output_shape = value_shape(graph, output_name, node)
+     if output_shape != data_shape:
+         raise ShapeInferenceError(
+             "ScatterND output shape must match data shape, "
+             f"got {output_shape} vs {data_shape}"
+         )
+     if len(indices_shape) < 1:
+         raise ShapeInferenceError("ScatterND indices must have rank >= 1")
+     index_depth = indices_shape[-1]
+     if index_depth <= 0:
+         raise ShapeInferenceError(
+             "ScatterND indices final dimension must be >= 1"
+         )
+     if index_depth > len(data_shape):
+         raise ShapeInferenceError(
+             "ScatterND indices final dimension must be <= data rank, "
+             f"got {index_depth} vs {len(data_shape)}"
+         )
+     expected_updates_shape = indices_shape[:-1] + data_shape[index_depth:]
+     if updates_shape != expected_updates_shape:
+         raise ShapeInferenceError(
+             "ScatterND updates shape must be "
+             f"{expected_updates_shape}, got {updates_shape}"
+         )
+     data_dtype = value_dtype(graph, data_name, node)
+     updates_dtype = value_dtype(graph, updates_name, node)
+     if updates_dtype != data_dtype:
+         raise UnsupportedOpError(
+             "ScatterND updates dtype must match data dtype, "
+             f"got {updates_dtype.onnx_name} vs {data_dtype.onnx_name}"
+         )
+     indices_dtype = value_dtype(graph, indices_name, node)
+     if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             "ScatterND indices must be int32 or int64, "
+             f"got {indices_dtype.onnx_name}"
+         )
+     reduction_attr = node.attrs.get("reduction", "none")
+     if isinstance(reduction_attr, bytes):
+         reduction = reduction_attr.decode()
+     else:
+         reduction = str(reduction_attr)
+     if reduction not in _ALLOWED_REDUCTIONS:
+         raise UnsupportedOpError(
+             "ScatterND reduction must be one of "
+             f"{sorted(_ALLOWED_REDUCTIONS)}, got {reduction}"
+         )
+     return ScatterNDOp(
+         data=data_name,
+         indices=indices_name,
+         updates=updates_name,
+         output=output_name,
+         data_shape=data_shape,
+         indices_shape=indices_shape,
+         updates_shape=updates_shape,
+         output_shape=output_shape,
+         reduction=reduction,
+         dtype=data_dtype,
+         indices_dtype=indices_dtype,
+     )
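
Note on the new ScatterND lowering: a hedged NumPy reference for the semantics it validates, illustration only, with reduction "none" shown (the add/mul/min/max variants replace the assignment with the matching accumulation):

import numpy as np

def scatter_nd_reference(data, indices, updates):
    out = data.copy()
    index_depth = indices.shape[-1]
    # Each index row addresses a slice of shape data.shape[index_depth:];
    # the matching updates slice overwrites it.
    slice_shape = data.shape[index_depth:]
    for row, update in zip(
        indices.reshape(-1, index_depth),
        updates.reshape((-1,) + slice_shape),
    ):
        out[tuple(row)] = update
    return out

data = np.zeros((4, 2))
indices = np.array([[1], [3]])
updates = np.array([[9.0, 9.0], [7.0, 7.0]])
print(scatter_nd_reference(data, indices, updates))
# row 1 becomes [9. 9.], row 3 becomes [7. 7.]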