emx_onnx_cgen-0.2.0-py3-none-any.whl

This diff shows the content of package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/shape.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ShapeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .registry import register_lowering
+
+
+def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+    try:
+        return graph.find_value(name).type.shape
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing shape for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+    try:
+        return graph.find_value(name).type.dtype
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing dtype for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _normalize_slice_bounds(
+    rank: int, *, start: int | None, end: int | None
+) -> tuple[int, int]:
+    normalized_start = 0 if start is None else int(start)
+    normalized_end = rank if end is None else int(end)
+    if normalized_start < 0:
+        normalized_start += rank
+    if normalized_end < 0:
+        normalized_end += rank
+    normalized_start = max(0, min(normalized_start, rank))
+    normalized_end = max(0, min(normalized_end, rank))
+    return normalized_start, normalized_end
+
+
+@register_lowering("Shape")
+def lower_shape(graph: Graph, node: Node) -> ShapeOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    if len(output_shape) != 1:
+        raise ShapeInferenceError("Shape output must be 1D")
+    if output_shape[0] < 0:
+        raise ShapeInferenceError("Shape output length must be non-negative")
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if output_dtype != ScalarType.I64:
+        raise UnsupportedOpError("Shape output dtype must be int64")
+    start = node.attrs.get("start")
+    end = node.attrs.get("end")
+    start_index, end_index = _normalize_slice_bounds(
+        len(input_shape), start=start, end=end
+    )
+    expected_shape = (max(0, end_index - start_index),)
+    if expected_shape != output_shape:
+        raise ShapeInferenceError(
+            "Shape output shape must be "
+            f"{expected_shape}, got {output_shape}"
+        )
+    return ShapeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        values=input_shape[start_index:end_index],
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
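
The start/end handling in lower_shape mirrors the optional slicing attributes ONNX added to Shape in opset 15: both bounds are optional, negative values count from the back, and out-of-range values clamp to [0, rank]. A minimal standalone sketch of the same normalization (illustrative values, no package imports):

    def normalize_bounds(rank, start=None, end=None):
        # Same arithmetic as _normalize_slice_bounds above.
        start = 0 if start is None else int(start)
        end = rank if end is None else int(end)
        if start < 0:
            start += rank
        if end < 0:
            end += rank
        return max(0, min(start, rank)), max(0, min(end, rank))

    assert normalize_bounds(4) == (0, 4)         # default: the full shape
    assert normalize_bounds(4, 1, -1) == (1, 3)  # drops first and last dim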
emx_onnx_cgen/lowering/size.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import SizeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import shape_product, value_dtype, value_shape
+from .registry import register_lowering
+
+
+@register_lowering("Size")
+def lower_size(graph: Graph, node: Node) -> SizeOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Size must have 1 input and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(output_shape) != 0:
+        raise ShapeInferenceError("Size output must be a scalar")
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if output_dtype != ScalarType.I64:
+        raise UnsupportedOpError("Size output dtype must be int64")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    element_count = shape_product(input_shape)
+    return SizeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        value=element_count,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
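
Size folds the whole element count into a compile-time scalar. Judging by its use here, shape_product (imported from .common) is a plain product over the dims; a standalone sketch with math.prod as a stand-in for it:

    import math

    # Assumes shape_product(dims) == math.prod(dims); not confirmed by this diff.
    assert math.prod((2, 3, 4)) == 24  # Size of a (2, 3, 4) tensor
    assert math.prod(()) == 1          # a scalar input has one element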
emx_onnx_cgen/lowering/slice.py
@@ -0,0 +1,425 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import SliceOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import value_dtype, value_shape
+from ..validation import normalize_axis
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class SliceSpec:
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    starts: tuple[int, ...]
+    steps: tuple[int, ...]
+
+
+@dataclass(frozen=True)
+class SliceInputs:
+    starts: list[int] | None
+    ends: list[int] | None
+    axes: list[int] | None
+    steps: list[int] | None
+    starts_input: str | None
+    ends_input: str | None
+    axes_input: str | None
+    steps_input: str | None
+    starts_shape: tuple[int, ...] | None
+    ends_shape: tuple[int, ...] | None
+    axes_shape: tuple[int, ...] | None
+    steps_shape: tuple[int, ...] | None
+    starts_dtype: ScalarType | None
+    ends_dtype: ScalarType | None
+    axes_dtype: ScalarType | None
+    steps_dtype: ScalarType | None
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_int_list(
+    graph: Graph, name: str, node: Node, *, label: str
+) -> list[int]:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be a constant initializer"
+        )
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be int64 or int32"
+        )
+    data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    return [int(value) for value in data]
+
+
+def _maybe_read_int_list(
+    graph: Graph, name: str, node: Node, *, label: str
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    return _read_int_list(graph, name, node, label=label)
+
+
+def _validate_int_input(
+    graph: Graph, name: str, node: Node, *, label: str
+) -> tuple[tuple[int, ...], ScalarType]:
+    dtype = value_dtype(graph, name, node)
+    if dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be int64 or int32"
+        )
+    shape = value_shape(graph, name, node)
+    if len(shape) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} input must be a 1D tensor"
+        )
+    return shape, dtype
+
+
+def _resolve_inputs(
+    graph: Graph, node: Node
+) -> SliceInputs:
+    if "starts" in node.attrs or "ends" in node.attrs:
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError(
+                f"{node.op_type} with starts/ends attributes expects 1 input"
+            )
+        if "starts" not in node.attrs or "ends" not in node.attrs:
+            raise UnsupportedOpError(
+                f"{node.op_type} must specify both starts and ends"
+            )
+        starts = [int(value) for value in node.attrs.get("starts", [])]
+        ends = [int(value) for value in node.attrs.get("ends", [])]
+        axes_attr = node.attrs.get("axes")
+        axes = [int(value) for value in axes_attr] if axes_attr else None
+        steps = None
+        return SliceInputs(
+            starts=starts,
+            ends=ends,
+            axes=axes,
+            steps=steps,
+            starts_input=None,
+            ends_input=None,
+            axes_input=None,
+            steps_input=None,
+            starts_shape=None,
+            ends_shape=None,
+            axes_shape=None,
+            steps_shape=None,
+            starts_dtype=None,
+            ends_dtype=None,
+            axes_dtype=None,
+            steps_dtype=None,
+        )
+    if len(node.inputs) < 3:
+        raise UnsupportedOpError(
+            f"{node.op_type} expects at least 3 inputs"
+        )
+    starts_name = node.inputs[1]
+    ends_name = node.inputs[2]
+    axes_name = node.inputs[3] if len(node.inputs) >= 4 else ""
+    steps_name = node.inputs[4] if len(node.inputs) >= 5 else ""
+    starts = _maybe_read_int_list(graph, starts_name, node, label="starts")
+    ends = _maybe_read_int_list(graph, ends_name, node, label="ends")
+    axes = (
+        _maybe_read_int_list(graph, axes_name, node, label="axes")
+        if axes_name
+        else None
+    )
+    steps = (
+        _maybe_read_int_list(graph, steps_name, node, label="steps")
+        if steps_name
+        else None
+    )
+    if starts is not None and ends is not None:
+        return SliceInputs(
+            starts=starts,
+            ends=ends,
+            axes=axes,
+            steps=steps,
+            starts_input=None,
+            ends_input=None,
+            axes_input=None,
+            steps_input=None,
+            starts_shape=None,
+            ends_shape=None,
+            axes_shape=None,
+            steps_shape=None,
+            starts_dtype=None,
+            ends_dtype=None,
+            axes_dtype=None,
+            steps_dtype=None,
+        )
+    if starts is None or ends is None:
+        starts_shape, starts_dtype = _validate_int_input(
+            graph, starts_name, node, label="starts"
+        )
+        ends_shape, ends_dtype = _validate_int_input(
+            graph, ends_name, node, label="ends"
+        )
+        if starts_shape != ends_shape:
+            raise ShapeInferenceError(
+                f"{node.op_type} starts and ends must have matching shapes"
+            )
+        axes_shape = None
+        axes_dtype = None
+        steps_shape = None
+        steps_dtype = None
+        axes_input = None
+        steps_input = None
+        if axes_name:
+            axes_shape, axes_dtype = _validate_int_input(
+                graph, axes_name, node, label="axes"
+            )
+            if axes_shape != starts_shape:
+                raise ShapeInferenceError(
+                    f"{node.op_type} axes must match starts length"
+                )
+            axes_input = axes_name
+        if steps_name:
+            steps_shape, steps_dtype = _validate_int_input(
+                graph, steps_name, node, label="steps"
+            )
+            if steps_shape != starts_shape:
+                raise ShapeInferenceError(
+                    f"{node.op_type} steps must match starts length"
+                )
+            steps_input = steps_name
+        return SliceInputs(
+            starts=None,
+            ends=None,
+            axes=None,
+            steps=None,
+            starts_input=starts_name,
+            ends_input=ends_name,
+            axes_input=axes_input,
+            steps_input=steps_input,
+            starts_shape=starts_shape,
+            ends_shape=ends_shape,
+            axes_shape=axes_shape,
+            steps_shape=steps_shape,
+            starts_dtype=starts_dtype,
+            ends_dtype=ends_dtype,
+            axes_dtype=axes_dtype,
+            steps_dtype=steps_dtype,
+        )
+    raise UnsupportedOpError(
+        f"{node.op_type} starts and ends inputs must both be constant initializers"
+    )
+
+
+def _normalize_slices(
+    input_shape: tuple[int, ...],
+    starts: list[int],
+    ends: list[int],
+    axes: list[int] | None,
+    steps: list[int] | None,
+    node: Node,
+) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+    rank = len(input_shape)
+    if rank == 0:
+        raise ShapeInferenceError(
+            f"{node.op_type} does not support scalar inputs"
+        )
+    if len(starts) != len(ends):
+        raise ShapeInferenceError(
+            f"{node.op_type} starts and ends must have matching lengths"
+        )
+    if axes is None:
+        axes = list(range(len(starts)))
+    if steps is None:
+        steps = [1] * len(starts)
+    if len(axes) != len(starts) or len(steps) != len(starts):
+        raise ShapeInferenceError(
+            f"{node.op_type} axes and steps must match starts length"
+        )
+    normalized_starts = [0] * rank
+    normalized_steps = [1] * rank
+    output_shape = list(input_shape)
+    seen_axes: set[int] = set()
+    for index, axis in enumerate(axes):
+        normalized_axis = normalize_axis(int(axis), input_shape, node)
+        if normalized_axis in seen_axes:
+            raise ShapeInferenceError(
+                f"{node.op_type} axes must be unique"
+            )
+        seen_axes.add(normalized_axis)
+        dim = input_shape[normalized_axis]
+        if dim < 0:
+            raise ShapeInferenceError("Dynamic dims are not supported")
+        step = int(steps[index])
+        if step == 0:
+            raise UnsupportedOpError(
+                f"{node.op_type} steps must be non-zero"
+            )
+        if step < 0:
+            raise UnsupportedOpError(
+                f"{node.op_type} only supports positive steps"
+            )
+        start = int(starts[index])
+        end = int(ends[index])
+        if start < 0:
+            start += dim
+        if end < 0:
+            end += dim
+        start = max(0, min(start, dim))
+        end = max(0, min(end, dim))
+        length = max(0, (end - start + step - 1) // step)
+        normalized_starts[normalized_axis] = start
+        normalized_steps[normalized_axis] = step
+        output_shape[normalized_axis] = length
+    return (
+        tuple(normalized_starts),
+        tuple(normalized_steps),
+        tuple(output_shape),
+    )
+
+
+def resolve_slice_spec(graph: Graph, node: Node) -> SliceSpec:
+    if len(node.inputs) < 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Slice must have 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            f"{node.op_type} expects matching input/output dtypes, "
+            f"got {input_dtype} and {output_dtype}"
+        )
+    if any(dim < 0 for dim in input_shape):
+        raise ShapeInferenceError("Dynamic dims are not supported")
+    if any(dim < 0 for dim in output_shape):
+        raise ShapeInferenceError("Dynamic dims are not supported")
+    inputs = _resolve_inputs(graph, node)
+    if inputs.starts is None or inputs.ends is None:
+        raise UnsupportedOpError(
+            f"{node.op_type} starts/ends inputs must be constant for shape "
+            "inference"
+        )
+    starts = inputs.starts
+    ends = inputs.ends
+    axes = inputs.axes
+    steps = inputs.steps
+    normalized_starts, normalized_steps, computed_output_shape = _normalize_slices(
+        input_shape, starts, ends, axes, steps, node
+    )
+    if output_shape and computed_output_shape != output_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} output shape must be "
+            f"{computed_output_shape}, got {output_shape}"
+        )
+    return SliceSpec(
+        input_shape=input_shape,
+        output_shape=computed_output_shape,
+        starts=normalized_starts,
+        steps=normalized_steps,
+    )
+
+
+@register_lowering("Slice")
+def lower_slice(graph: Graph, node: Node) -> SliceOp:
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            f"{node.op_type} expects matching input/output dtypes, "
+            f"got {input_dtype} and {output_dtype}"
+        )
+    if any(dim < 0 for dim in input_shape):
+        raise ShapeInferenceError("Dynamic dims are not supported")
+    if any(dim < 0 for dim in output_shape):
+        raise ShapeInferenceError("Dynamic dims are not supported")
+    inputs = _resolve_inputs(graph, node)
+    if inputs.starts is not None and inputs.ends is not None:
+        normalized_starts, normalized_steps, computed_output_shape = _normalize_slices(
+            input_shape, inputs.starts, inputs.ends, inputs.axes, inputs.steps, node
+        )
+        if output_shape and computed_output_shape != output_shape:
+            raise ShapeInferenceError(
+                f"{node.op_type} output shape must be "
+                f"{computed_output_shape}, got {output_shape}"
+            )
+        return SliceOp(
+            input0=node.inputs[0],
+            output=node.outputs[0],
+            input_shape=input_shape,
+            output_shape=computed_output_shape,
+            starts=normalized_starts,
+            steps=normalized_steps,
+            axes=None,
+            starts_input=None,
+            ends_input=None,
+            axes_input=None,
+            steps_input=None,
+            starts_shape=None,
+            ends_shape=None,
+            axes_shape=None,
+            steps_shape=None,
+            starts_dtype=None,
+            ends_dtype=None,
+            axes_dtype=None,
+            steps_dtype=None,
+            dtype=input_dtype,
+            input_dtype=input_dtype,
+        )
+    if len(output_shape) != len(input_shape):
+        raise ShapeInferenceError(
+            f"{node.op_type} output rank must match input rank"
+        )
+    if inputs.starts_shape is None or inputs.ends_shape is None:
+        raise UnsupportedOpError(
+            f"{node.op_type} starts and ends inputs must be provided"
+        )
+    if inputs.starts_shape != inputs.ends_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} starts and ends must have matching shapes"
+        )
+    starts_len = inputs.starts_shape[0]
+    if starts_len > len(input_shape):
+        raise ShapeInferenceError(
+            f"{node.op_type} starts length exceeds input rank"
+        )
+    if starts_len == 0 and output_shape != input_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} empty starts expects output shape to match input"
+        )
+    return SliceOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        starts=None,
+        steps=None,
+        axes=None,
+        starts_input=inputs.starts_input,
+        ends_input=inputs.ends_input,
+        axes_input=inputs.axes_input,
+        steps_input=inputs.steps_input,
+        starts_shape=inputs.starts_shape,
+        ends_shape=inputs.ends_shape,
+        axes_shape=inputs.axes_shape,
+        steps_shape=inputs.steps_shape,
+        starts_dtype=inputs.starts_dtype,
+        ends_dtype=inputs.ends_dtype,
+        axes_dtype=inputs.axes_dtype,
+        steps_dtype=inputs.steps_dtype,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
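
When starts/ends are constant, _normalize_slices resolves the whole slice at compile time; otherwise the runtime tensors are validated and passed through to SliceOp. The clamping and ceil-division length computation matches the ONNX Slice spec for positive steps. One axis of that arithmetic, spelled out with illustrative values:

    # dim=10, start=-3, end=1000, step=2: the negative start counts from
    # the back, the oversized end clamps to dim, and ceil division gives
    # the output length.
    dim, start, end, step = 10, -3, 1000, 2
    if start < 0:
        start += dim
    if end < 0:
        end += dim
    start = max(0, min(start, dim))
    end = max(0, min(end, dim))
    length = max(0, (end - start + step - 1) // step)
    assert (start, end, length) == (7, 10, 2)  # keeps indices 7 and 9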
emx_onnx_cgen/lowering/softmax.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import SoftmaxOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import shape_product as _shape_product
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+from ..validation import ensure_output_shape_matches_input
+from ..validation import normalize_axis as _normalize_axis
+
+
+@register_lowering("Softmax")
+def lower_softmax(graph: Graph, node: Node) -> SoftmaxOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Softmax must have 1 input and 1 output")
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "Softmax supports float16, float, and double inputs only"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    axis = _normalize_axis(
+        int(node.attrs.get("axis", -1)),
+        input_shape,
+        node,
+    )
+    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+    axis_size = input_shape[axis]
+    inner = (
+        _shape_product(input_shape[axis + 1 :])
+        if axis + 1 < len(input_shape)
+        else 1
+    )
+    return SoftmaxOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        outer=outer,
+        axis_size=axis_size,
+        inner=inner,
+        axis=axis,
+        shape=input_shape,
+        dtype=op_dtype,
+    )
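
The outer/axis_size/inner split flattens the tensor into independent softmax rows around the normalized axis, presumably the loop structure the C emitter generates. A standalone sketch for an illustrative shape:

    import math

    shape, axis = (2, 3, 4, 5), 2          # softmax over axis 2
    outer = math.prod(shape[:axis])        # 6 leading blocks
    axis_size = shape[axis]                # 4 elements reduced per row
    inner = math.prod(shape[axis + 1:])    # 5 trailing positions per block
    assert (outer, axis_size, inner) == (6, 4, 5)
    assert outer * axis_size * inner == math.prod(shape)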