emx-onnx-cgen 0.2.0 (emx_onnx_cgen-0.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/mean_variance_normalization.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import MeanVarianceNormalizationOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import ensure_output_shape_matches_input
+from .common import node_dtype, shape_product, value_shape
+from .reduce import normalize_reduce_axes
+from .registry import register_lowering
+
+
+@register_lowering("MeanVarianceNormalization")
+def lower_mean_variance_normalization(
+    graph: Graph, node: Node
+) -> MeanVarianceNormalizationOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "MeanVarianceNormalization must have 1 input and 1 output"
+        )
+    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "MeanVarianceNormalization supports float16, float, and double inputs only"
+        )
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    axes_attr = node.attrs.get("axes")
+    if axes_attr is None:
+        axes = (0, 2, 3)
+    else:
+        axes = tuple(int(axis) for axis in axes_attr)
+    axes = normalize_reduce_axes(axes, input_shape, node)
+    if not axes:
+        raise UnsupportedOpError(
+            "MeanVarianceNormalization requires non-empty reduction axes"
+        )
+    non_axes = tuple(i for i in range(len(input_shape)) if i not in axes)
+    reduce_count = shape_product(tuple(input_shape[axis] for axis in axes))
+    return MeanVarianceNormalizationOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        shape=input_shape,
+        axes=axes,
+        non_axes=non_axes,
+        reduce_count=reduce_count,
+        epsilon=1e-9,
+        dtype=op_dtype,
+    )
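
Note: the lowering above only validates shapes and dtypes; the arithmetic lives in the C emitter. For orientation, MVN computes (x - mean) / sqrt(var + epsilon) over the reduction axes, defaulting to (0, 2, 3) with epsilon fixed at 1e-9 here. A minimal NumPy sketch of the semantics the generated code is expected to match; mvn_reference is an illustrative helper, not part of the package:

import numpy as np

def mvn_reference(x: np.ndarray, axes=(0, 2, 3), epsilon=1e-9) -> np.ndarray:
    # Mean and variance are reduced over `axes`, then broadcast back
    # so every element is normalized against its group statistics.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
y = mvn_reference(x)
assert y.shape == x.shape  # output shape must match input, as the lowering enforces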
emx_onnx_cgen/lowering/negative_log_likelihood_loss.py
@@ -0,0 +1,250 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import NegativeLogLikelihoodLossOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .common import shape_product as _shape_product
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _find_node_by_output(graph: Graph, name: str) -> Node | None:
+    for node in graph.nodes:
+        if name in node.outputs:
+            return node
+    return None
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _resolve_target_shape(
+    input_shape: tuple[int, ...],
+    shape_values: list[int],
+    *,
+    allowzero: int,
+    node: Node,
+) -> tuple[int, ...]:
+    if allowzero not in (0, 1):
+        raise UnsupportedOpError("Reshape allowzero must be 0 or 1")
+    output_dims: list[int] = []
+    unknown_index: int | None = None
+    known_product = 1
+    for index, dim in enumerate(shape_values):
+        if dim == -1:
+            if unknown_index is not None:
+                raise ShapeInferenceError("Reshape allows only one -1 dimension")
+            unknown_index = index
+            output_dims.append(-1)
+            continue
+        if dim == 0:
+            if allowzero == 0:
+                if index >= len(input_shape):
+                    raise ShapeInferenceError(
+                        "Reshape zero dim must index into input shape"
+                    )
+                dim = input_shape[index]
+        if dim < 0:
+            raise ShapeInferenceError("Reshape dims must be >= -1")
+        output_dims.append(dim)
+        known_product *= dim
+    input_product = _shape_product(input_shape)
+    if unknown_index is not None:
+        if known_product == 0 or input_product % known_product != 0:
+            raise ShapeInferenceError(
+                "Reshape cannot infer dimension from input shape"
+            )
+        output_dims[unknown_index] = input_product // known_product
+    output_shape = tuple(output_dims)
+    if _shape_product(output_shape) != input_product:
+        raise ShapeInferenceError(
+            "Reshape input and output element counts must match"
+        )
+    return output_shape
+
+
+def _shape_values_from_shape_node(
+    graph: Graph, name: str, node: Node
+) -> list[int] | None:
+    shape_node = _find_node_by_output(graph, name)
+    if shape_node is None or shape_node.op_type != "Shape":
+        return None
+    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    source_shape = _value_shape(graph, shape_node.inputs[0], node)
+    return list(source_shape)
+
+
+def _resolve_shape_from_reshape(
+    graph: Graph, name: str, node: Node
+) -> tuple[int, ...] | None:
+    reshape_node = _find_node_by_output(graph, name)
+    if reshape_node is None or reshape_node.op_type != "Reshape":
+        return None
+    if len(reshape_node.inputs) != 2 or len(reshape_node.outputs) != 1:
+        raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
+    input_shape = _value_shape(graph, reshape_node.inputs[0], node)
+    if not input_shape:
+        return None
+    allowzero = int(reshape_node.attrs.get("allowzero", 0))
+    shape_initializer = _find_initializer(graph, reshape_node.inputs[1])
+    if shape_initializer is not None:
+        if shape_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                "Reshape expects int64 or int32 shape input, "
+                f"got {shape_initializer.type.dtype.onnx_name}"
+            )
+        if len(shape_initializer.type.shape) != 1:
+            raise UnsupportedOpError("Reshape expects a 1D shape input")
+        shape_values = [int(value) for value in shape_initializer.data.reshape(-1)]
+        return _resolve_target_shape(
+            input_shape,
+            shape_values,
+            allowzero=allowzero,
+            node=node,
+        )
+    shape_values = _shape_values_from_shape_node(
+        graph, reshape_node.inputs[1], node
+    )
+    if shape_values is None:
+        return None
+    return _resolve_target_shape(
+        input_shape,
+        shape_values,
+        allowzero=allowzero,
+        node=node,
+    )
+
+
+def _resolve_input_shape(
+    graph: Graph,
+    input_name: str,
+    target_shape: tuple[int, ...],
+    weight_name: str | None,
+    node: Node,
+) -> tuple[int, ...]:
+    input_shape = _value_shape(graph, input_name, node)
+    if input_shape:
+        return input_shape
+    reshaped = _resolve_shape_from_reshape(graph, input_name, node)
+    if reshaped is not None:
+        return reshaped
+    if weight_name is not None and target_shape:
+        weight_shape = _value_shape(graph, weight_name, node)
+        if len(weight_shape) != 1:
+            return input_shape
+        return (target_shape[0], weight_shape[0], *target_shape[1:])
+    return input_shape
+
+
+@register_lowering("NegativeLogLikelihoodLoss")
+def lower_negative_log_likelihood_loss(
+    graph: Graph, node: Node
+) -> NegativeLogLikelihoodLossOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss must have 2 or 3 inputs and 1 output"
+        )
+    input_name = node.inputs[0]
+    target_name = node.inputs[1]
+    weight_name = node.inputs[2] if len(node.inputs) > 2 else None
+    input_dtype = _value_dtype(graph, input_name, node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss supports float16, float, and double inputs only"
+        )
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if output_dtype != input_dtype:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss output dtype must match input dtype"
+        )
+    target_dtype = _value_dtype(graph, target_name, node)
+    if target_dtype not in {ScalarType.I32, ScalarType.I64}:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss target must be int32 or int64"
+        )
+    weight_dtype = None
+    weight_shape: tuple[int, ...] | None = None
+    if weight_name is not None:
+        weight_dtype = _value_dtype(graph, weight_name, node)
+        if weight_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "NegativeLogLikelihoodLoss weight dtype must match input dtype"
+            )
+    target_shape = _value_shape(graph, target_name, node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    input_shape = _resolve_input_shape(
+        graph, input_name, target_shape, weight_name, node
+    )
+    if len(input_shape) < 2:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss input must be at least 2D"
+        )
+    if len(target_shape) != len(input_shape) - 1:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss target rank must be input rank - 1"
+        )
+    if input_shape[0] != target_shape[0]:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss target batch dimension must match input"
+        )
+    if input_shape[2:] != target_shape[1:]:
+        raise ShapeInferenceError(
+            "NegativeLogLikelihoodLoss target spatial dimensions must match input"
+        )
+    if weight_name is not None:
+        weight_shape = _value_shape(graph, weight_name, node)
+        if len(weight_shape) != 1 or weight_shape[0] != input_shape[1]:
+            raise ShapeInferenceError(
+                "NegativeLogLikelihoodLoss weight must have shape (C,)"
+            )
+    reduction = node.attrs.get("reduction", "mean")
+    if isinstance(reduction, bytes):
+        reduction = reduction.decode("utf-8")
+    if reduction not in {"none", "mean", "sum"}:
+        raise UnsupportedOpError(
+            "NegativeLogLikelihoodLoss reduction must be none, mean, or sum"
+        )
+    if reduction == "none":
+        if not output_shape:
+            output_shape = target_shape
+        if output_shape != target_shape:
+            raise ShapeInferenceError(
+                "NegativeLogLikelihoodLoss output must match target shape "
+                "when reduction is none"
+            )
+    else:
+        if output_shape and output_shape not in {(), (1,)}:
+            raise ShapeInferenceError(
+                "NegativeLogLikelihoodLoss output must be scalar when reduced"
+            )
+    n = input_shape[0]
+    c = input_shape[1]
+    d = _shape_product(input_shape[2:]) if len(input_shape) > 2 else 1
+    ignore_index = int(node.attrs.get("ignore_index", -1))
+    return NegativeLogLikelihoodLossOp(
+        input0=input_name,
+        target=target_name,
+        weight=weight_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        target_shape=target_shape,
+        output_shape=output_shape,
+        n=n,
+        c=c,
+        d=d,
+        reduction=reduction,
+        ignore_index=ignore_index,
+        input_dtype=input_dtype,
+        weight_dtype=weight_dtype,
+        weight_shape=weight_shape,
+        dtype=input_dtype,
+        target_dtype=target_dtype,
+    )
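
Note: the checks above pin down the (N, C, d1..dk) input layout, the optional (C,) class weights, ignore_index, and the three reduction modes; under ONNX semantics the "mean" reduction divides by the sum of the weights that were actually applied. A hedged NumPy sketch of that behavior, assuming standard ONNX NegativeLogLikelihoodLoss semantics; nll_reference is a hypothetical reference helper, not package code:

import numpy as np

def nll_reference(x, target, weight=None, reduction="mean", ignore_index=None):
    # x: (N, C, d1, ..., dk) log-probabilities; target: (N, d1, ..., dk) class ids.
    n, c = x.shape[0], x.shape[1]
    x2 = x.reshape(n, c, -1)                      # (N, C, D)
    t2 = target.reshape(n, -1)                    # (N, D)
    if ignore_index is not None:
        mask = t2 == ignore_index
    else:
        mask = np.zeros_like(t2, dtype=bool)
    t_safe = np.where(mask, 0, t2)                # keep gather indices in range
    picked = np.take_along_axis(x2, t_safe[:, None, :], axis=1)[:, 0, :]
    loss = -picked
    w = weight[t_safe] if weight is not None else np.ones_like(loss)
    loss = np.where(mask, 0.0, loss * w)          # ignored positions contribute 0
    w = np.where(mask, 0.0, w)
    if reduction == "none":
        return loss.reshape(target.shape)
    if reduction == "sum":
        return loss.sum()
    return loss.sum() / w.sum()                   # "mean" divides by total weight

x = np.log(np.full((2, 3), 1 / 3))                # uniform log-probs, N=2, C=3
t = np.array([0, 2])
assert np.isclose(nll_reference(x, t), np.log(3))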
emx_onnx_cgen/lowering/pad.py
@@ -0,0 +1,287 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import PadOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import optional_name, value_dtype, value_shape
+from ..validation import normalize_axis
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_int_initializer(
+    graph: Graph,
+    name: str,
+    node: Node,
+    *,
+    label: str,
+) -> tuple[int, ...] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"Pad {label} input must be int64 or int32"
+        )
+    if len(initializer.type.shape) != 1:
+        raise UnsupportedOpError(f"Pad {label} input must be a 1D tensor")
+    values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    return tuple(int(value) for value in values)
+
+
+def _read_scalar_initializer(
+    graph: Graph, name: str, node: Node, *, dtype: ScalarType
+) -> float | int | bool | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype != dtype:
+        raise UnsupportedOpError(
+            "Pad value input must match input dtype, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    values = np.array(initializer.data).reshape(-1)
+    if values.size != 1:
+        raise UnsupportedOpError("Pad value input must be a scalar")
+    return values.item()
+
+
+def _normalize_axes(
+    axes: tuple[int, ...], input_shape: tuple[int, ...], node: Node
+) -> tuple[int, ...]:
+    normalized = [normalize_axis(axis, input_shape, node) for axis in axes]
+    if len(set(normalized)) != len(normalized):
+        raise UnsupportedOpError("Pad axes must be unique")
+    return tuple(normalized)
+
+
+def _default_pad_value(dtype: ScalarType) -> float | int | bool:
+    if dtype.is_bool:
+        return False
+    if dtype.is_float:
+        return 0.0
+    return 0
+
+
+def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
+    strides: list[int] = []
+    stride = 1
+    for dim in reversed(shape):
+        strides.append(stride)
+        stride *= dim
+    return tuple(reversed(strides))
+
+
+@register_lowering("Pad")
+def lower_pad(graph: Graph, node: Node) -> PadOp:
+    if not node.inputs or len(node.outputs) != 1:
+        raise UnsupportedOpError("Pad must have at least 1 input and 1 output")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Pad input must be provided")
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Pad expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    mode = node.attrs.get("mode", "constant")
+    if isinstance(mode, bytes):
+        mode = mode.decode("utf-8")
+    if mode not in {"constant", "edge", "reflect", "wrap"}:
+        raise UnsupportedOpError(f"Pad mode '{mode}' is not supported")
+    pads_name = optional_name(node.inputs, 1)
+    pads_attr = node.attrs.get("pads")
+    if pads_name and pads_attr:
+        raise UnsupportedOpError("Pad pads must come from the input or the attribute, not both")
+    pads = None
+    pads_input = None
+    pads_shape = None
+    pads_dtype = None
+    if pads_name:
+        pads = _read_int_initializer(graph, pads_name, node, label="pads")
+        if pads is None:
+            pads_shape = value_shape(graph, pads_name, node)
+            pads_dtype = value_dtype(graph, pads_name, node)
+            if pads_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Pad pads input must be int64 or int32"
+                )
+            if len(pads_shape) != 1:
+                raise UnsupportedOpError("Pad pads input must be a 1D tensor")
+            pads_input = pads_name
+    elif pads_attr is not None:
+        pads = tuple(int(value) for value in pads_attr)
+    if pads is None and pads_input is None:
+        pads = tuple(0 for _ in range(2 * len(input_shape)))
+
+    axes_name = optional_name(node.inputs, 3)
+    axes = None
+    axes_input = None
+    axes_shape = None
+    axes_dtype = None
+    if axes_name:
+        axes = _read_int_initializer(graph, axes_name, node, label="axes")
+        if axes is None:
+            axes_shape = value_shape(graph, axes_name, node)
+            axes_dtype = value_dtype(graph, axes_name, node)
+            if axes_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Pad axes input must be int64 or int32"
+                )
+            if len(axes_shape) != 1:
+                raise UnsupportedOpError("Pad axes input must be a 1D tensor")
+            if axes_shape[0] < 0:
+                raise ShapeInferenceError(
+                    "Pad axes input must have a static length"
+                )
+            axes_input = axes_name
+        else:
+            axes = _normalize_axes(axes, input_shape, node)
+
+    pads_axis_map = None
+    pads_values = None
+    pads_begin = None
+    pads_end = None
+
+    if axes_input is None and axes is None:
+        if pads is None:
+            if pads_shape is None or pads_shape[0] != 2 * len(input_shape):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * rank of input"
+                )
+            pads_begin = None
+            pads_end = None
+        else:
+            if len(pads) != 2 * len(input_shape):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * rank of input"
+                )
+            pads_begin = list(pads[: len(input_shape)])
+            pads_end = list(pads[len(input_shape) :])
+    elif axes_input is None:
+        if pads_input is not None:
+            if pads_shape is None or pads_shape[0] != 2 * len(axes):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+            pads_axis_map = [None] * len(input_shape)
+            for index, axis in enumerate(axes):
+                pads_axis_map[axis] = index
+        else:
+            if len(pads) != 2 * len(axes):
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+            pads_begin = [0] * len(input_shape)
+            pads_end = [0] * len(input_shape)
+            for index, axis in enumerate(axes):
+                pads_begin[axis] = pads[index]
+                pads_end[axis] = pads[index + len(axes)]
+    else:
+        axes_len = axes_shape[0] if axes_shape is not None else 0
+        if pads_input is not None:
+            if pads_shape is None or pads_shape[0] != 2 * axes_len:
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+        else:
+            if len(pads) != 2 * axes_len:
+                raise ShapeInferenceError(
+                    "Pad pads must have length 2 * len(axes)"
+                )
+            pads_values = pads
+
+    if pads_begin is not None and pads_end is not None:
+        if any(value < 0 for value in pads_begin + pads_end):
+            raise UnsupportedOpError("Pad pads must be non-negative")
+
+        expected_shape = tuple(
+            dim + pad_before + pad_after
+            for dim, pad_before, pad_after in zip(
+                input_shape, pads_begin, pads_end
+            )
+        )
+        if output_shape != expected_shape:
+            raise ShapeInferenceError(
+                "Pad output shape mismatch: "
+                f"expected {expected_shape}, got {output_shape}"
+            )
+    elif pads_values is not None:
+        if any(value < 0 for value in pads_values):
+            raise UnsupportedOpError("Pad pads must be non-negative")
+
+    value_name = optional_name(node.inputs, 2)
+    pad_value = None
+    value_input = None
+    value_input_shape = None
+    if value_name:
+        pad_value = _read_scalar_initializer(
+            graph, value_name, node, dtype=input_dtype
+        )
+        if pad_value is None:
+            value_input_shape = value_shape(graph, value_name, node)
+            input_value_dtype = value_dtype(graph, value_name, node)
+            if input_value_dtype != input_dtype:
+                raise UnsupportedOpError(
+                    "Pad value input must match input dtype, "
+                    f"got {input_value_dtype.onnx_name}"
+                )
+            if value_input_shape:
+                raise UnsupportedOpError("Pad value input must be a scalar")
+            value_input = value_name
+    elif "value" in node.attrs:
+        pad_value = node.attrs["value"]
+    if pad_value is None and value_input is None:
+        pad_value = _default_pad_value(input_dtype)
+
+    return PadOp(
+        input0=input_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        pads_begin=(
+            tuple(int(value) for value in pads_begin)
+            if pads_begin is not None
+            else None
+        ),
+        pads_end=(
+            tuple(int(value) for value in pads_end)
+            if pads_end is not None
+            else None
+        ),
+        pads_input=pads_input,
+        pads_shape=pads_shape,
+        pads_dtype=pads_dtype,
+        pads_axis_map=(
+            tuple(pads_axis_map) if pads_axis_map is not None else None
+        ),
+        pads_values=(
+            tuple(int(value) for value in pads_values)
+            if pads_values is not None
+            else None
+        ),
+        axes_input=axes_input,
+        axes_shape=axes_shape,
+        axes_dtype=axes_dtype,
+        mode=mode,
+        value=pad_value,
+        value_input=value_input,
+        value_shape=value_input_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+        input_strides=_compute_strides(input_shape),
+    )
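
Note: ONNX Pad stores pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], all begin values followed by all end values, which is why the static path above splits pads at len(input_shape). A small NumPy sketch of the constant-mode output that the expected-shape check corresponds to; pad_reference is illustrative only:

import numpy as np

def pad_reference(x, pads, value=0.0):
    # pads = [b1, ..., bR, e1, ..., eR] for an input of rank R.
    rank = x.ndim
    begin, end = pads[:rank], pads[rank:]
    # np.pad wants per-axis (before, after) pairs instead.
    return np.pad(x, list(zip(begin, end)), mode="constant", constant_values=value)

x = np.arange(6, dtype=np.float32).reshape(2, 3)
y = pad_reference(x, [0, 1, 0, 2])   # pad axis 1 with 1 before and 2 after
assert y.shape == (2, 6)             # 3 + 1 + 2, matching expected_shape above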
emx_onnx_cgen/lowering/range.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+import math
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import RangeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import node_dtype, value_shape
+from .registry import register_lowering
+
+
+_SUPPORTED_RANGE_DTYPES = {
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.I16,
+    ScalarType.I32,
+    ScalarType.I64,
+}
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_scalar_initializer(
+    graph: Graph, name: str, node: Node, label: str
+) -> float | int | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    data = np.array(initializer.data)
+    if data.size != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be a scalar"
+        )
+    return data.reshape(-1)[0].item()
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+@register_lowering("Range")
+def lower_range(graph: Graph, node: Node) -> RangeOp:
+    if len(node.inputs) != 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Range must have 3 inputs and 1 output")
+    start_shape = value_shape(graph, node.inputs[0], node)
+    limit_shape = value_shape(graph, node.inputs[1], node)
+    delta_shape = value_shape(graph, node.inputs[2], node)
+    if not (
+        _is_scalar_shape(start_shape)
+        and _is_scalar_shape(limit_shape)
+        and _is_scalar_shape(delta_shape)
+    ):
+        raise UnsupportedOpError("Range inputs must be scalars")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if dtype not in _SUPPORTED_RANGE_DTYPES:
+        raise UnsupportedOpError(
+            f"Range does not support dtype {dtype.onnx_name}"
+        )
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(output_shape) != 1:
+        raise ShapeInferenceError("Range output must be 1D")
+    start_value = _read_scalar_initializer(graph, node.inputs[0], node, "start")
+    limit_value = _read_scalar_initializer(graph, node.inputs[1], node, "limit")
+    delta_value = _read_scalar_initializer(graph, node.inputs[2], node, "delta")
+    if (
+        start_value is not None
+        and limit_value is not None
+        and delta_value is not None
+    ):
+        if float(delta_value) == 0.0:
+            raise UnsupportedOpError("Range delta must be non-zero")
+        raw_count = (
+            float(limit_value) - float(start_value)
+        ) / float(delta_value)
+        length = max(int(math.ceil(raw_count)), 0)
+        if length < 0:
+            raise ShapeInferenceError("Range output length must be non-negative")
+        if output_shape[0] != length:
+            raise ShapeInferenceError(
+                f"Range output length must be {length}, got {output_shape[0]}"
+            )
+    else:
+        length = output_shape[0]
+        if length < 0:
+            raise ShapeInferenceError("Range output length must be non-negative")
+    return RangeOp(
+        start=node.inputs[0],
+        limit=node.inputs[1],
+        delta=node.inputs[2],
+        output=node.outputs[0],
+        output_shape=output_shape,
+        length=length,
+        dtype=dtype,
+        input_dtype=dtype,
+    )
+ )