emx_onnx_cgen-0.2.0-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/reshape.py
@@ -0,0 +1,188 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ReshapeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .registry import register_lowering
+
+
+def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+    try:
+        return graph.find_value(name).type.shape
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing shape for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+    try:
+        return graph.find_value(name).type.dtype
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing dtype for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _shape_product(shape: tuple[int, ...]) -> int:
+    product = 1
+    for dim in shape:
+        if dim < 0:
+            raise ShapeInferenceError("Dynamic dims are not supported")
+        product *= dim
+    return product
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _find_node_by_output(graph: Graph, name: str) -> Node | None:
+    for node in graph.nodes:
+        if name in node.outputs:
+            return node
+    return None
+
+
+def _shape_values_from_shape_node(
+    graph: Graph, name: str, node: Node
+) -> list[int] | None:
+    shape_node = _find_node_by_output(graph, name)
+    if shape_node is None or shape_node.op_type != "Shape":
+        return None
+    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    source_shape = _value_shape(graph, shape_node.inputs[0], node)
+    return list(source_shape)
+
+
+def _resolve_target_shape(
+    input_shape: tuple[int, ...],
+    shape_values: list[int],
+    *,
+    allowzero: int,
+    node: Node,
+) -> tuple[int, ...]:
+    if allowzero not in (0, 1):
+        raise UnsupportedOpError("Reshape allowzero must be 0 or 1")
+    output_dims: list[int] = []
+    unknown_index: int | None = None
+    known_product = 1
+    contains_zero = False
+    for index, dim in enumerate(shape_values):
+        if dim == -1:
+            if unknown_index is not None:
+                raise ShapeInferenceError("Reshape allows only one -1 dimension")
+            unknown_index = index
+            output_dims.append(-1)
+            continue
+        if dim == 0:
+            contains_zero = True
+            if allowzero == 0:
+                if index >= len(input_shape):
+                    raise ShapeInferenceError(
+                        "Reshape zero dim must index into input shape"
+                    )
+                dim = input_shape[index]
+        if dim < 0:
+            raise ShapeInferenceError("Reshape dims must be >= -1")
+        output_dims.append(dim)
+        known_product *= dim
+    if allowzero == 1 and contains_zero and unknown_index is not None:
+        raise ShapeInferenceError(
+            "Reshape allowzero cannot combine zero and -1 dimensions"
+        )
+    input_product = _shape_product(input_shape)
+    if unknown_index is not None:
+        if known_product == 0:
+            if input_product != 0:
+                raise ShapeInferenceError(
+                    "Reshape cannot infer dimension from input shape"
+                )
+            output_dims[unknown_index] = 0
+        else:
+            if input_product % known_product != 0:
+                raise ShapeInferenceError(
+                    "Reshape cannot infer dimension from input shape"
+                )
+            output_dims[unknown_index] = input_product // known_product
+    output_shape = tuple(output_dims)
+    if _shape_product(output_shape) != input_product:
+        raise ShapeInferenceError(
+            "Reshape input and output element counts must match"
+        )
+    return output_shape
+
+
+@register_lowering("Reshape")
+def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Reshape expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    allowzero = int(node.attrs.get("allowzero", 0))
+    shape_initializer = _find_initializer(graph, node.inputs[1])
+    resolved_shape: tuple[int, ...] | None = None
+    if shape_initializer is None:
+        shape_values = _shape_values_from_shape_node(
+            graph, node.inputs[1], node
+        )
+        if shape_values is not None:
+            resolved_shape = _resolve_target_shape(
+                input_shape,
+                shape_values,
+                allowzero=allowzero,
+                node=node,
+            )
+        else:
+            if _shape_product(output_shape) != _shape_product(input_shape):
+                raise ShapeInferenceError(
+                    "Reshape input and output element counts must match"
+                )
+    else:
+        if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                "Reshape expects int64 or int32 shape input, "
+                f"got {shape_initializer.type.dtype.onnx_name}"
+            )
+        if len(shape_initializer.type.shape) != 1:
+            raise UnsupportedOpError("Reshape expects a 1D shape input")
+        shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
+        resolved_shape = _resolve_target_shape(
+            input_shape,
+            shape_values,
+            allowzero=allowzero,
+            node=node,
+        )
+    if resolved_shape is not None and output_shape and resolved_shape != output_shape:
+        raise ShapeInferenceError(
+            "Reshape output shape must be "
+            f"{resolved_shape}, got {output_shape}"
+        )
+    if resolved_shape is not None:
+        output_shape = resolved_shape
+    for dim in output_shape:
+        if dim < 0:
+            raise ShapeInferenceError("Dynamic dims are not supported")
+    return ReshapeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
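
The heart of this file is _resolve_target_shape, which implements ONNX Reshape's special dimension values: 0 copies the corresponding input dimension (unless allowzero=1), and a single -1 is inferred from the remaining element count. A minimal standalone sketch of those rules follows; resolve_reshape is an illustrative helper, not part of the package, and omits the error handling the lowering performs above.

def resolve_reshape(input_shape, shape_values, allowzero=0):
    # 0 copies the input dimension at the same index when allowzero=0;
    # with allowzero=1 a literal zero is kept.
    dims = []
    for index, dim in enumerate(shape_values):
        if dim == 0 and allowzero == 0:
            dim = input_shape[index]
        dims.append(dim)
    total = 1
    for dim in input_shape:
        total *= dim
    # A single -1 absorbs whatever element count is left over.
    if -1 in dims:
        known = 1
        for dim in dims:
            if dim != -1:
                known *= dim
        dims[dims.index(-1)] = total // known
    return tuple(dims)

assert resolve_reshape((2, 3, 4), [0, -1]) == (2, 12)  # 0 -> 2, -1 -> 24 // 2
assert resolve_reshape((2, 3, 4), [4, 6]) == (4, 6)    # plain reshape, 24 elements
assert resolve_reshape((0, 3), [0, 3], allowzero=1) == (0, 3)  # literal zero kept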
emx_onnx_cgen/lowering/resize.py
@@ -0,0 +1,445 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ResizeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .registry import register_lowering
+
+_SUPPORTED_COORD_MODES = {
+    "half_pixel",
+    "half_pixel_symmetric",
+    "asymmetric",
+    "align_corners",
+    "pytorch_half_pixel",
+    "tf_crop_and_resize",
+}
+_SUPPORTED_MODES = {"nearest", "linear", "cubic"}
+_SUPPORTED_NEAREST_MODES = {
+    "round_prefer_floor",
+    "round_prefer_ceil",
+    "floor",
+    "ceil",
+}
+_SUPPORTED_KEEP_ASPECT = {"stretch", "not_larger", "not_smaller"}
+
+
+@dataclass(frozen=True)
+class _ResizeInputs:
+    roi: str | None
+    scales: str | None
+    sizes: str | None
+
+
+@dataclass(frozen=True)
+class _ResolvedScales:
+    scales: tuple[float, ...]
+    output_shape: tuple[int, ...]
+    axes: tuple[int, ...]
+
+
+@dataclass(frozen=True)
+class _InputConfig:
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    input_dtype: ScalarType
+    output_dtype: ScalarType
+
+
+def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+    try:
+        return graph.find_value(name).type.shape
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing shape for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+    try:
+        return graph.find_value(name).type.dtype
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing dtype for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _decode_attr(value: object, default: str) -> str:
+    if value is None:
+        return default
+    if isinstance(value, bytes):
+        return value.decode("utf-8", errors="ignore")
+    if isinstance(value, str):
+        return value
+    return str(value)
+
+
+def _normalize_axes(
+    axes: tuple[int, ...], rank: int, node: Node
+) -> tuple[int, ...]:
+    normalized: list[int] = []
+    for axis in axes:
+        axis = int(axis)
+        if axis < 0:
+            axis += rank
+        if axis < 0 or axis >= rank:
+            raise ShapeInferenceError(
+                f"Resize axis {axis} is out of range for rank {rank}"
+            )
+        normalized.append(axis)
+    if len(set(normalized)) != len(normalized):
+        raise ShapeInferenceError("Resize axes must be unique")
+    return tuple(normalized)
+
+
+def _round_half_up(value: float) -> int:
+    return int(value + 0.5)
+
+
+def _parse_input_names(node: Node) -> _ResizeInputs:
+    inputs = list(node.inputs)
+    if len(inputs) > 4:
+        raise UnsupportedOpError("Resize expects at most 4 inputs")
+    while len(inputs) < 4:
+        inputs.append("")
+    _, roi, scales, sizes = inputs[:4]
+    return _ResizeInputs(
+        roi=roi or None,
+        scales=scales or None,
+        sizes=sizes or None,
+    )
+
+
+def _parse_axes(node: Node, rank: int) -> tuple[int, ...]:
+    axes_attr = node.attrs.get("axes")
+    if axes_attr is None:
+        return tuple(range(rank))
+    axes = tuple(int(value) for value in axes_attr)
+    return _normalize_axes(axes, rank, node)
+
+
+def _resolve_input_shapes(
+    graph: Graph, node: Node, input_name: str
+) -> _InputConfig:
+    input_shape = _value_shape(graph, input_name, node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    input_dtype = _value_dtype(graph, input_name, node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    return _InputConfig(
+        input_shape=input_shape,
+        output_shape=output_shape,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+    )
+
+
+def _resolve_scales_from_sizes(
+    sizes: tuple[int, ...],
+    input_shape: tuple[int, ...],
+    axes: tuple[int, ...],
+    keep_aspect_ratio_policy: str,
+) -> _ResolvedScales:
+    rank = len(input_shape)
+    full_sizes = list(input_shape)
+    for index, axis in enumerate(axes):
+        full_sizes[axis] = sizes[index]
+    if keep_aspect_ratio_policy != "stretch":
+        scales = [full_sizes[axis] / input_shape[axis] for axis in axes]
+        if keep_aspect_ratio_policy == "not_larger":
+            scale = min(scales)
+        else:
+            scale = max(scales)
+        for axis in axes:
+            full_sizes[axis] = _round_half_up(scale * input_shape[axis])
+        return _ResolvedScales(
+            scales=tuple(
+                scale if axis in axes else 1.0
+                for axis in range(rank)
+            ),
+            output_shape=tuple(full_sizes),
+            axes=axes,
+        )
+    scales = tuple(
+        full_sizes[axis] / input_shape[axis] if axis in axes else 1.0
+        for axis in range(rank)
+    )
+    return _ResolvedScales(
+        scales=scales,
+        output_shape=tuple(full_sizes),
+        axes=axes,
+    )
+
+
+def _resolve_scales_from_values(
+    scales: tuple[float, ...],
+    input_shape: tuple[int, ...],
+    axes: tuple[int, ...],
+) -> _ResolvedScales:
+    rank = len(input_shape)
+    full_scales = [1.0] * rank
+    for index, axis in enumerate(axes):
+        full_scales[axis] = scales[index]
+    output_shape = tuple(
+        int(input_shape[axis] * full_scales[axis])
+        if axis in axes
+        else input_shape[axis]
+        for axis in range(rank)
+    )
+    return _ResolvedScales(
+        scales=tuple(full_scales),
+        output_shape=output_shape,
+        axes=axes,
+    )
+
+
+def _load_initializer_values(
+    graph: Graph, name: str, node: Node
+) -> tuple[float | int, ...] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    data = initializer.data.reshape(-1)
+    return tuple(data.tolist())
+
+
+def _validate_tensor_1d(
+    graph: Graph,
+    name: str,
+    node: Node,
+    dtype_options: set[ScalarType],
+) -> tuple[int, ScalarType]:
+    shape = _value_shape(graph, name, node)
+    if len(shape) != 1:
+        raise UnsupportedOpError("Resize expects 1D auxiliary inputs")
+    dtype = _value_dtype(graph, name, node)
+    if dtype not in dtype_options:
+        raise UnsupportedOpError(
+            "Resize expects "
+            f"{name} to have dtype in {[option.onnx_name for option in sorted(dtype_options, key=str)]}"
+        )
+    return shape[0], dtype
+
+
+def _resolve_scales(
+    graph: Graph,
+    node: Node,
+    config: _InputConfig,
+    inputs: _ResizeInputs,
+    axes: tuple[int, ...],
+    keep_aspect_ratio_policy: str,
+) -> tuple[tuple[float, ...], tuple[int, ...]]:
+    rank = len(config.input_shape)
+    if inputs.scales:
+        scale_len, _ = _validate_tensor_1d(
+            graph,
+            inputs.scales,
+            node,
+            {ScalarType.F16, ScalarType.F32, ScalarType.F64},
+        )
+        if scale_len not in {len(axes), rank}:
+            raise UnsupportedOpError("Resize scales length mismatch")
+        if scale_len == rank and axes != tuple(range(rank)):
+            raise UnsupportedOpError(
+                "Resize scales length conflicts with axes configuration"
+            )
+        scale_axes = axes if scale_len == len(axes) else tuple(range(rank))
+        values = _load_initializer_values(graph, inputs.scales, node)
+        if values is None:
+            scales = tuple(
+                config.output_shape[axis] / config.input_shape[axis]
+                if axis in scale_axes
+                else 1.0
+                for axis in range(rank)
+            )
+            return scales, config.output_shape
+        resolved = _resolve_scales_from_values(
+            tuple(float(value) for value in values),
+            config.input_shape,
+            scale_axes,
+        )
+        return resolved.scales, resolved.output_shape
+    if inputs.sizes:
+        size_len, _ = _validate_tensor_1d(
+            graph, inputs.sizes, node, {ScalarType.I64, ScalarType.I32}
+        )
+        if size_len not in {len(axes), rank}:
+            raise UnsupportedOpError("Resize sizes length mismatch")
+        if size_len == rank and axes != tuple(range(rank)):
+            raise UnsupportedOpError(
+                "Resize sizes length conflicts with axes configuration"
+            )
+        size_axes = axes if size_len == len(axes) else tuple(range(rank))
+        values = _load_initializer_values(graph, inputs.sizes, node)
+        if values is None:
+            scales = tuple(
+                config.output_shape[axis] / config.input_shape[axis]
+                if axis in size_axes
+                else 1.0
+                for axis in range(rank)
+            )
+            return scales, config.output_shape
+        resolved = _resolve_scales_from_sizes(
+            tuple(int(value) for value in values),
+            config.input_shape,
+            size_axes,
+            keep_aspect_ratio_policy,
+        )
+        return resolved.scales, resolved.output_shape
+    raise UnsupportedOpError("Resize expects scales or sizes input")
+
+
+def _validate_output_shape(
+    expected: tuple[int, ...],
+    actual: tuple[int, ...],
+) -> None:
+    if expected != actual:
+        raise ShapeInferenceError(
+            f"Resize output shape must be {expected}, got {actual}"
+        )
+    if any(dim < 0 for dim in actual):
+        raise ShapeInferenceError("Resize output shape must be non-negative")
+
+
+@register_lowering("Resize")
+def lower_resize(graph: Graph, node: Node) -> ResizeOp:
+    if len(node.outputs) != 1:
+        raise UnsupportedOpError("Resize expects one output")
+    inputs = _parse_input_names(node)
+    if inputs.scales and inputs.sizes:
+        raise UnsupportedOpError("Resize cannot set both scales and sizes")
+    if not inputs.scales and not inputs.sizes:
+        raise UnsupportedOpError("Resize expects scales or sizes input")
+    mode = _decode_attr(node.attrs.get("mode"), "nearest")
+    coordinate_mode = _decode_attr(
+        node.attrs.get("coordinate_transformation_mode"), "half_pixel"
+    )
+    nearest_mode = _decode_attr(
+        node.attrs.get("nearest_mode"), "round_prefer_floor"
+    )
+    keep_aspect_ratio_policy = _decode_attr(
+        node.attrs.get("keep_aspect_ratio_policy"), "stretch"
+    )
+    antialias = bool(int(node.attrs.get("antialias", 0)))
+    cubic_coeff_a = float(node.attrs.get("cubic_coeff_a", -0.75))
+    exclude_outside = bool(int(node.attrs.get("exclude_outside", 0)))
+    extrapolation_value = float(node.attrs.get("extrapolation_value", 0.0))
+    if mode not in _SUPPORTED_MODES:
+        raise UnsupportedOpError(f"Resize mode {mode!r} is not supported")
+    if coordinate_mode not in _SUPPORTED_COORD_MODES:
+        raise UnsupportedOpError(
+            "Resize coordinate_transformation_mode "
+            f"{coordinate_mode!r} is not supported"
+        )
+    if nearest_mode not in _SUPPORTED_NEAREST_MODES:
+        raise UnsupportedOpError(
+            f"Resize nearest_mode {nearest_mode!r} is not supported"
+        )
+    if keep_aspect_ratio_policy not in _SUPPORTED_KEEP_ASPECT:
+        raise UnsupportedOpError(
+            "Resize keep_aspect_ratio_policy "
+            f"{keep_aspect_ratio_policy!r} is not supported"
+        )
+    if antialias and mode == "nearest":
+        raise UnsupportedOpError("Resize antialias is not supported for nearest")
+    config = _resolve_input_shapes(graph, node, node.inputs[0])
+    if config.input_dtype != config.output_dtype:
+        raise UnsupportedOpError(
+            "Resize expects matching input/output dtypes, "
+            f"got {config.input_dtype.onnx_name} and {config.output_dtype.onnx_name}"
+        )
+    rank = len(config.input_shape)
+    axes = _parse_axes(node, rank)
+    scales, expected_output = _resolve_scales(
+        graph,
+        node,
+        config,
+        inputs,
+        axes,
+        keep_aspect_ratio_policy,
+    )
+    _validate_output_shape(expected_output, config.output_shape)
+    roi_shape = None
+    roi_axes = None
+    roi_dtype = None
+    if inputs.roi:
+        roi_len, roi_dtype = _validate_tensor_1d(
+            graph,
+            inputs.roi,
+            node,
+            {ScalarType.F16, ScalarType.F32, ScalarType.F64},
+        )
+        if roi_len == 2 * rank:
+            roi_shape = (roi_len,)
+        elif roi_len == 2 * len(axes):
+            roi_shape = (roi_len,)
+            roi_axes = axes
+        else:
+            raise UnsupportedOpError("Resize roi length mismatch")
+        if coordinate_mode != "tf_crop_and_resize" and roi_len != 0:
+            roi_axes = roi_axes if roi_len == 2 * len(axes) else None
+    if coordinate_mode == "tf_crop_and_resize" and not inputs.roi:
+        raise UnsupportedOpError("Resize requires roi for tf_crop_and_resize")
+    scales_shape = None
+    sizes_shape = None
+    scales_dtype = None
+    sizes_dtype = None
+    scales_axes = None
+    sizes_axes = None
+    if inputs.scales:
+        scale_len, scales_dtype = _validate_tensor_1d(
+            graph,
+            inputs.scales,
+            node,
+            {ScalarType.F16, ScalarType.F32, ScalarType.F64},
+        )
+        scales_shape = (scale_len,)
+        if scale_len == len(axes) and len(axes) != rank:
+            scales_axes = axes
+    if inputs.sizes:
+        size_len, sizes_dtype = _validate_tensor_1d(
+            graph, inputs.sizes, node, {ScalarType.I64, ScalarType.I32}
+        )
+        sizes_shape = (size_len,)
+        if size_len == len(axes) and len(axes) != rank:
+            sizes_axes = axes
+    return ResizeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=config.input_shape,
+        output_shape=config.output_shape,
+        scales=scales,
+        scales_input=inputs.scales,
+        sizes_input=inputs.sizes,
+        roi_input=inputs.roi,
+        axes=axes,
+        scales_shape=scales_shape,
+        sizes_shape=sizes_shape,
+        roi_shape=roi_shape,
+        scales_dtype=scales_dtype,
+        sizes_dtype=sizes_dtype,
+        roi_dtype=roi_dtype,
+        scales_axes=scales_axes,
+        sizes_axes=sizes_axes,
+        roi_axes=roi_axes,
+        mode=mode,
+        coordinate_transformation_mode=coordinate_mode,
+        nearest_mode=nearest_mode,
+        cubic_coeff_a=cubic_coeff_a,
+        exclude_outside=exclude_outside,
+        extrapolation_value=extrapolation_value,
+        antialias=antialias,
+        keep_aspect_ratio_policy=keep_aspect_ratio_policy,
+        dtype=config.input_dtype,
+    )
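
The subtle part of this file is _resolve_scales_from_sizes: with keep_aspect_ratio_policy set to not_larger or not_smaller, ONNX Resize collapses the per-axis ratios to a single uniform scale (the min or max, respectively) and re-derives every resized dimension with round-half-up. A small sketch of that policy follows; resolve_sizes is an illustrative helper, not part of the package, and skips the axes handling and validation above.

def resolve_sizes(input_shape, sizes, policy="stretch"):
    # "stretch" keeps one independent scale per axis; "not_larger" /
    # "not_smaller" pick min/max of the per-axis ratios and re-round
    # every output size with round-half-up, preserving aspect ratio.
    scales = [size / dim for size, dim in zip(sizes, input_shape)]
    if policy == "stretch":
        return scales, list(sizes)
    uniform = min(scales) if policy == "not_larger" else max(scales)
    resized = [int(uniform * dim + 0.5) for dim in input_shape]
    return [uniform] * len(input_shape), resized

# Resizing a 100x60 input toward target sizes 50x50:
print(resolve_sizes((100, 60), (50, 50), "not_larger"))   # ([0.5, 0.5], [50, 30])
print(resolve_sizes((100, 60), (50, 50), "not_smaller"))  # (~[0.833, 0.833], [83, 50])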
emx_onnx_cgen/lowering/rms_normalization.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import RMSNormalizationOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import ensure_output_shape_matches_input
+from ..validation import normalize_axis
+from .common import node_dtype, shape_product, value_shape
+from .registry import register_lowering
+
+
+def _ensure_broadcastable(
+    name: str,
+    shape: tuple[int, ...],
+    normalized_shape: tuple[int, ...],
+) -> None:
+    if len(shape) != len(normalized_shape):
+        raise ShapeInferenceError(
+            f"RMSNormalization {name} rank must match normalized rank"
+        )
+    for dim, expected in zip(shape, normalized_shape):
+        if dim not in {1, expected}:
+            raise ShapeInferenceError(
+                f"RMSNormalization {name} shape {shape} must be broadcastable "
+                f"to {normalized_shape}"
+            )
+
+
+@register_lowering("RMSNormalization")
+def lower_rms_normalization(
+    graph: Graph, node: Node
+) -> RMSNormalizationOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("RMSNormalization must have 2 inputs and 1 output")
+    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "RMSNormalization supports float16, float, and double inputs only"
+        )
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    axis = normalize_axis(int(node.attrs.get("axis", -1)), input_shape, node)
+    normalized_shape = input_shape[axis:]
+    scale_shape = value_shape(graph, node.inputs[1], node)
+    _ensure_broadcastable("scale", scale_shape, normalized_shape)
+    epsilon = float(node.attrs.get("epsilon", 1e-5))
+    stash_type = int(node.attrs.get("stash_type", 1))
+    if stash_type != 1:
+        raise UnsupportedOpError(
+            "RMSNormalization supports stash_type=1 only"
+        )
+    outer = shape_product(input_shape[:axis]) if axis > 0 else 1
+    inner = shape_product(normalized_shape)
+    return RMSNormalizationOp(
+        input0=node.inputs[0],
+        scale=node.inputs[1],
+        output=node.outputs[0],
+        shape=input_shape,
+        normalized_shape=normalized_shape,
+        scale_shape=scale_shape,
+        outer=outer,
+        inner=inner,
+        axis=axis,
+        epsilon=epsilon,
+        dtype=op_dtype,
+    )
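
Here, outer and inner precompute iteration bounds: every inner-sized slice at and after axis is normalized by its own root-mean-square, so Y = X / sqrt(mean(X^2 over the normalized dims) + epsilon) * scale. A NumPy sketch of that reference semantics, using the same outer/inner split; rms_norm is an illustrative helper, not the package's evaluator, and assumes the default stash_type.

import numpy as np
from math import prod

def rms_norm(x, scale, axis=-1, epsilon=1e-5):
    # Flatten to (outer, inner) exactly as the lowering's shape_product
    # split does, normalize each inner row by its RMS, then apply scale.
    axis = axis % x.ndim
    outer = prod(x.shape[:axis])
    inner = prod(x.shape[axis:])
    flat = x.reshape(outer, inner)
    rms = np.sqrt(np.mean(flat * flat, axis=1, keepdims=True) + epsilon)
    y = (flat / rms).reshape(x.shape)
    return y * scale  # scale broadcasts over the trailing normalized dims

x = np.random.rand(2, 4, 8).astype(np.float32)
y = rms_norm(x, np.ones((4, 8), dtype=np.float32), axis=1)
print(y.shape)  # (2, 4, 8)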