emx-onnx-cgen 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. Click here for more details.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
@@ -0,0 +1,229 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from ..codegen.c_emitter import AveragePoolOp
6
+ from ..errors import ShapeInferenceError, UnsupportedOpError
7
+ from ..ir.model import Graph, Node
8
+ from .registry import register_lowering
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class _AveragePoolSpec:
13
+ batch: int
14
+ channels: int
15
+ in_h: int
16
+ in_w: int
17
+ out_h: int
18
+ out_w: int
19
+ kernel_h: int
20
+ kernel_w: int
21
+ stride_h: int
22
+ stride_w: int
23
+ pad_top: int
24
+ pad_left: int
25
+ pad_bottom: int
26
+ pad_right: int
27
+ count_include_pad: bool
28
+
29
+
30
+ def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
31
+ try:
32
+ return graph.find_value(name).type.shape
33
+ except KeyError as exc:
34
+ raise ShapeInferenceError(
35
+ f"Missing shape for value '{name}' in op {node.op_type}. "
36
+ "Hint: run ONNX shape inference or export with static shapes."
37
+ ) from exc
38
+
39
+
40
+ def _value_dtype(graph: Graph, name: str, node: Node) -> str:
41
+ try:
42
+ return graph.find_value(name).type.dtype
43
+ except KeyError as exc:
44
+ raise ShapeInferenceError(
45
+ f"Missing dtype for value '{name}' in op {node.op_type}. "
46
+ "Hint: run ONNX shape inference or export with static shapes."
47
+ ) from exc
48
+
49
+
50
def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
    """Validate an ``AveragePool`` node and compute its 2D pooling geometry.

    Only this subset is accepted:
    - one NCHW input and one output;
    - ``auto_pad`` of ``NOTSET`` (or empty) and ``ceil_mode=0``;
    - explicit 2D ``kernel_shape``, ``strides`` and ``pads``.

    Args:
        graph: Graph used to look up static shapes.
        node: The ``AveragePool`` node being lowered.

    Returns:
        The resolved geometry, including the computed output size.

    Raises:
        UnsupportedOpError: on any of the following:
            - wrong arity;
            - unknown attributes or unsupported attribute values;
            - a non-positive kernel or stride;
            - a non-4D input.
        ShapeInferenceError: on any of the following:
            - a shape is missing;
            - the window does not fit the padded input;
            - the graph's declared output shape differs from the computed one.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("AveragePool must have 1 input and 1 output")
    supported_attrs = {
        "auto_pad",
        "ceil_mode",
        "count_include_pad",
        "kernel_shape",
        "pads",
        "strides",
    }
    if set(node.attrs) - supported_attrs:
        raise UnsupportedOpError("AveragePool has unsupported attributes")
    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
    # auto_pad may be stored as bytes; normalize to str before comparing.
    if isinstance(auto_pad, bytes):
        auto_pad = auto_pad.decode("utf-8", errors="ignore")
    if auto_pad not in ("", "NOTSET"):
        raise UnsupportedOpError("AveragePool supports auto_pad=NOTSET only")
    ceil_mode = int(node.attrs.get("ceil_mode", 0))
    if ceil_mode != 0:
        raise UnsupportedOpError("AveragePool supports ceil_mode=0 only")
    count_include_pad = int(node.attrs.get("count_include_pad", 0))
    if count_include_pad not in (0, 1):
        raise UnsupportedOpError("AveragePool supports count_include_pad 0 or 1")
    kernel_shape = node.attrs.get("kernel_shape")
    if kernel_shape is None:
        raise UnsupportedOpError("AveragePool requires kernel_shape")
    kernel_shape = tuple(int(value) for value in kernel_shape)
    if len(kernel_shape) != 2:
        raise UnsupportedOpError("AveragePool expects 2D kernel_shape")
    kernel_h, kernel_w = kernel_shape
    # A non-positive window has no elements to average.
    if kernel_h <= 0 or kernel_w <= 0:
        raise UnsupportedOpError("AveragePool kernel_shape must be positive")
    strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
    if len(strides) != 2:
        raise UnsupportedOpError("AveragePool expects 2D strides")
    stride_h, stride_w = strides
    # Reject here: a zero stride would otherwise surface as a bare
    # ZeroDivisionError in the output-size computation below.
    if stride_h <= 0 or stride_w <= 0:
        raise UnsupportedOpError("AveragePool strides must be positive")
    pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
    if len(pads) != 4:
        raise UnsupportedOpError("AveragePool expects 4D pads")
    # ONNX lists all begin pads before all end pads: [top, left, bottom, right].
    pad_top, pad_left, pad_bottom, pad_right = pads
    input_shape = _value_shape(graph, node.inputs[0], node)
    if len(input_shape) != 4:
        raise UnsupportedOpError("AveragePool supports NCHW 2D inputs only")
    batch, channels, in_h, in_w = input_shape
    # Floor-mode output size (ceil_mode=0 is enforced above).
    out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
    out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
    # With positive strides, out < 1 means the window is larger than the
    # padded input, so there is no valid pooling position.
    if out_h < 1 or out_w < 1:
        raise ShapeInferenceError(
            "AveragePool output shape must be positive, "
            f"got ({out_h}, {out_w})"
        )
    output_shape = _value_shape(graph, node.outputs[0], node)
    expected_output_shape = (batch, channels, out_h, out_w)
    if output_shape != expected_output_shape:
        raise ShapeInferenceError(
            "AveragePool output shape must be "
            f"{expected_output_shape}, got {output_shape}"
        )
    return _AveragePoolSpec(
        batch=batch,
        channels=channels,
        in_h=in_h,
        in_w=in_w,
        out_h=out_h,
        out_w=out_w,
        kernel_h=kernel_h,
        kernel_w=kernel_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_top=pad_top,
        pad_left=pad_left,
        pad_bottom=pad_bottom,
        pad_right=pad_right,
        count_include_pad=bool(count_include_pad),
    )
123
+
124
+
125
def _resolve_global_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
    """Express ``GlobalAveragePool`` as an AveragePool over the whole plane.

    The spec uses a kernel spanning the full H x W extent, unit strides, no
    padding and ``count_include_pad=False``. The output is therefore 1 x 1
    per (N, C) slice.

    Raises:
        UnsupportedOpError: on wrong arity, any attribute, or a non-4D input.
        ShapeInferenceError: if a shape is missing or the declared output is
            not ``(N, C, 1, 1)``.
    """
    has_single_io = len(node.inputs) == 1 and len(node.outputs) == 1
    if not has_single_io:
        raise UnsupportedOpError("GlobalAveragePool must have 1 input and 1 output")
    if node.attrs:
        raise UnsupportedOpError("GlobalAveragePool has unsupported attributes")
    in_shape = _value_shape(graph, node.inputs[0], node)
    if len(in_shape) != 4:
        raise UnsupportedOpError("GlobalAveragePool supports NCHW 2D inputs only")
    n, c, h, w = in_shape
    declared = _value_shape(graph, node.outputs[0], node)
    wanted = (n, c, 1, 1)
    if declared != wanted:
        raise ShapeInferenceError(
            "GlobalAveragePool output shape must be "
            f"{wanted}, got {declared}"
        )
    # The window covers the entire spatial plane in a single step.
    return _AveragePoolSpec(
        batch=n,
        channels=c,
        in_h=h,
        in_w=w,
        kernel_h=h,
        kernel_w=w,
        out_h=1,
        out_w=1,
        stride_h=1,
        stride_w=1,
        pad_top=0,
        pad_bottom=0,
        pad_left=0,
        pad_right=0,
        count_include_pad=False,
    )
158
+
159
+
160
@register_lowering("AveragePool")
def lower_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
    """Lower an ONNX ``AveragePool`` node to an ``AveragePoolOp``.

    Node structure (arity, attributes, shapes) is validated before the
    input/output names are indexed. A node with no inputs or outputs
    therefore raises ``UnsupportedOpError`` rather than ``IndexError``.

    Raises:
        UnsupportedOpError: on any of the following:
            - unsupported node structure;
            - mismatched input/output dtypes;
            - a non-float dtype.
        ShapeInferenceError: if a shape or dtype is missing, or the declared
            output shape is inconsistent.
    """
    spec = _resolve_average_pool_spec(graph, node)
    op_dtype = _value_dtype(graph, node.inputs[0], node)
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            "AveragePool expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "AveragePool supports float16, float, and double inputs only"
        )
    return AveragePoolOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        batch=spec.batch,
        channels=spec.channels,
        in_h=spec.in_h,
        in_w=spec.in_w,
        out_h=spec.out_h,
        out_w=spec.out_w,
        kernel_h=spec.kernel_h,
        kernel_w=spec.kernel_w,
        stride_h=spec.stride_h,
        stride_w=spec.stride_w,
        pad_top=spec.pad_top,
        pad_left=spec.pad_left,
        pad_bottom=spec.pad_bottom,
        pad_right=spec.pad_right,
        count_include_pad=spec.count_include_pad,
        dtype=op_dtype,
    )
194
+
195
+
196
@register_lowering("GlobalAveragePool")
def lower_global_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
    """Lower an ONNX ``GlobalAveragePool`` node to an ``AveragePoolOp``.

    Node structure is validated before the input/output names are indexed.
    A node with no inputs or outputs therefore raises ``UnsupportedOpError``
    rather than ``IndexError``.

    Raises:
        UnsupportedOpError: on any of the following:
            - unsupported node structure;
            - mismatched input/output dtypes;
            - a non-float dtype.
        ShapeInferenceError: if a shape or dtype is missing, or the declared
            output shape is not ``(N, C, 1, 1)``.
    """
    spec = _resolve_global_average_pool_spec(graph, node)
    op_dtype = _value_dtype(graph, node.inputs[0], node)
    output_dtype = _value_dtype(graph, node.outputs[0], node)
    if op_dtype != output_dtype:
        raise UnsupportedOpError(
            "GlobalAveragePool expects matching input/output dtypes, "
            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
        )
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "GlobalAveragePool supports float16, float, and double inputs only"
        )
    return AveragePoolOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        batch=spec.batch,
        channels=spec.channels,
        in_h=spec.in_h,
        in_w=spec.in_w,
        out_h=spec.out_h,
        out_w=spec.out_w,
        kernel_h=spec.kernel_h,
        kernel_w=spec.kernel_w,
        stride_h=spec.stride_h,
        stride_w=spec.stride_w,
        pad_top=spec.pad_top,
        pad_left=spec.pad_left,
        pad_bottom=spec.pad_bottom,
        pad_right=spec.pad_right,
        count_include_pad=spec.count_include_pad,
        dtype=op_dtype,
    )
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from ..codegen.c_emitter import BatchNormOp
6
+ from ..errors import ShapeInferenceError, UnsupportedOpError
7
+ from ..ir.model import Graph, Node
8
+ from .registry import register_lowering
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class _BatchNormSpec:
13
+ shape: tuple[int, ...]
14
+ channels: int
15
+ epsilon: float
16
+
17
+
18
+ def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
19
+ try:
20
+ return graph.find_value(name).type.shape
21
+ except KeyError as exc:
22
+ raise ShapeInferenceError(
23
+ f"Missing shape for value '{name}' in op {node.op_type}. "
24
+ "Hint: run ONNX shape inference or export with static shapes."
25
+ ) from exc
26
+
27
+
28
+ def _value_dtype(graph: Graph, name: str, node: Node) -> str:
29
+ try:
30
+ return graph.find_value(name).type.dtype
31
+ except KeyError as exc:
32
+ raise ShapeInferenceError(
33
+ f"Missing dtype for value '{name}' in op {node.op_type}. "
34
+ "Hint: run ONNX shape inference or export with static shapes."
35
+ ) from exc
36
+
37
+
38
def _node_dtype(graph: Graph, node: Node, *names: str) -> str:
    """Return the single dtype shared by every value in ``names``.

    The returned object is the graph's dtype object, not a plain string
    (callers read ``.is_float`` from it).

    Raises:
        ShapeInferenceError: if any value's dtype is unknown.
        UnsupportedOpError: if the values do not all share one dtype.
    """
    dtypes = {_value_dtype(graph, name, node) for name in names}
    if len(dtypes) != 1:
        # The dtype objects are not strings. Joining or sorting them
        # directly would raise TypeError and hide this error, so format
        # them via ``onnx_name`` and sort by ``str``.
        dtype_names = ", ".join(
            dtype.onnx_name for dtype in sorted(dtypes, key=str)
        )
        raise UnsupportedOpError(
            f"{node.op_type} expects matching dtypes, got {dtype_names}"
        )
    return next(iter(dtypes))
45
+
46
+
47
def _resolve_batch_norm_spec(graph: Graph, node: Node) -> _BatchNormSpec:
    """Validate an inference-mode ``BatchNormalization`` node.

    Args:
        graph: Graph used to look up static shapes.
        node: Node whose five inputs are ``(X, scale, bias, mean, variance)``
            and which has a single output.

    Returns:
        The input shape, the channel count (axis 1) and epsilon.

    Raises:
        UnsupportedOpError: on any of the following:
            - wrong arity or unknown attributes;
            - ``is_test != 1``, ``training_mode != 0`` or ``spatial != 1``;
            - an input rank below 2.
        ShapeInferenceError: on any of the following:
            - a shape is missing;
            - a parameter vector is not ``(channels,)``;
            - the output shape differs from the input.
    """
    if len(node.inputs) != 5 or len(node.outputs) != 1:
        raise UnsupportedOpError(
            "BatchNormalization must have 5 inputs and 1 output"
        )
    # ``momentum`` is accepted but not used by this lowering.
    allowed = {"epsilon", "is_test", "momentum", "spatial", "training_mode"}
    if not set(node.attrs).issubset(allowed):
        raise UnsupportedOpError("BatchNormalization has unsupported attributes")
    # For each mode flag, the default is also the only accepted value.
    mode_rules = (
        ("is_test", 1, "BatchNormalization supports is_test=1 only"),
        ("training_mode", 0, "BatchNormalization supports training_mode=0 only"),
        ("spatial", 1, "BatchNormalization supports spatial=1 only"),
    )
    for attr, required, message in mode_rules:
        if int(node.attrs.get(attr, required)) != required:
            raise UnsupportedOpError(message)
    epsilon = float(node.attrs.get("epsilon", 1e-5))
    x_shape = _value_shape(graph, node.inputs[0], node)
    if len(x_shape) < 2:
        raise UnsupportedOpError(
            "BatchNormalization expects input rank of at least 2"
        )
    channels = x_shape[1]
    for param_name in node.inputs[1:]:
        param_shape = _value_shape(graph, param_name, node)
        if param_shape != (channels,):
            raise ShapeInferenceError(
                "BatchNormalization parameter shape must be "
                f"({channels},), got {param_shape}"
            )
    y_shape = _value_shape(graph, node.outputs[0], node)
    if y_shape != x_shape:
        raise ShapeInferenceError(
            "BatchNormalization output shape must match input shape, "
            f"got {y_shape}"
        )
    return _BatchNormSpec(shape=x_shape, channels=channels, epsilon=epsilon)
95
+
96
+
97
@register_lowering("BatchNormalization")
def lower_batch_normalization(graph: Graph, node: Node) -> BatchNormOp:
    """Lower an inference-mode ``BatchNormalization`` node to ``BatchNormOp``.

    Every input and output must share one floating-point dtype.

    Raises:
        UnsupportedOpError: on mixed or non-float dtypes, or anything
            ``_resolve_batch_norm_spec`` rejects.
        ShapeInferenceError: if a dtype or shape is missing or inconsistent.
    """
    dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
    if not dtype.is_float:
        raise UnsupportedOpError(
            "BatchNormalization supports float16, float, and double inputs only"
        )
    spec = _resolve_batch_norm_spec(graph, node)
    # The resolver has already enforced 5 inputs and 1 output.
    x, scale, bias, mean, variance = node.inputs
    (y,) = node.outputs
    return BatchNormOp(
        input0=x,
        scale=scale,
        bias=bias,
        mean=mean,
        variance=variance,
        output=y,
        shape=spec.shape,
        channels=spec.channels,
        epsilon=spec.epsilon,
        dtype=dtype,
    )
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ import onnx
4
+
5
+ from ..codegen.c_emitter import CastOp
6
+ from ..dtypes import scalar_type_from_onnx
7
+ from ..errors import ShapeInferenceError, UnsupportedOpError
8
+ from ..ir.model import Graph, Node
9
+ from .common import ensure_supported_dtype, value_dtype, value_shape
10
+ from .registry import register_lowering
11
+
12
+
13
@register_lowering("Cast")
def lower_cast(graph: Graph, node: Node) -> CastOp:
    """Lower an ONNX ``Cast`` node to a ``CastOp``.

    Raises:
        UnsupportedOpError: on any of the following:
            - wrong arity;
            - a missing or unsupported ``to`` attribute;
            - an output dtype that disagrees with ``to``.
        ShapeInferenceError: if a dtype or shape is missing, or the input
            and output shapes differ.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Cast must have 1 input and 1 output")
    if "to" not in node.attrs:
        raise UnsupportedOpError("Cast requires a 'to' attribute")
    target_onnx_dtype = int(node.attrs["to"])
    target_dtype = scalar_type_from_onnx(target_onnx_dtype)
    if target_dtype is None:
        # DataType.Name raises ValueError for codes outside the enum.
        # Fall back to a placeholder so callers still get UnsupportedOpError.
        try:
            name = onnx.TensorProto.DataType.Name(target_onnx_dtype)
        except ValueError:
            name = "UNKNOWN"
        raise UnsupportedOpError(
            f"Cast 'to' dtype {target_onnx_dtype} ({name}) is not supported"
        )
    target_dtype = ensure_supported_dtype(target_dtype)
    input_dtype = value_dtype(graph, node.inputs[0], node)
    output_dtype = value_dtype(graph, node.outputs[0], node)
    if output_dtype != target_dtype:
        raise UnsupportedOpError(
            "Cast output dtype must match 'to' attribute, "
            f"got {output_dtype.onnx_name} and {target_dtype.onnx_name}"
        )
    input_shape = value_shape(graph, node.inputs[0], node)
    output_shape = value_shape(graph, node.outputs[0], node)
    if input_shape != output_shape:
        raise ShapeInferenceError("Cast input and output shapes must match")
    return CastOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        shape=output_shape,
        input_dtype=input_dtype,
        dtype=output_dtype,
    )
45
+
46
+
47
@register_lowering("CastLike")
def lower_castlike(graph: Graph, node: Node) -> CastOp:
    """Lower ``CastLike`` to a ``CastOp`` targeting the dtype of input 1.

    Raises:
        UnsupportedOpError: on wrong arity, or when the output dtype differs
            from the "like" input's dtype.
        ShapeInferenceError: if a dtype or shape is missing, or the input
            and output shapes differ.
    """
    if len(node.inputs) != 2 or len(node.outputs) != 1:
        raise UnsupportedOpError("CastLike must have 2 inputs and 1 output")
    source_name, like_name = node.inputs
    result_name = node.outputs[0]
    source_dtype = value_dtype(graph, source_name, node)
    target_dtype = ensure_supported_dtype(value_dtype(graph, like_name, node))
    result_dtype = value_dtype(graph, result_name, node)
    if result_dtype != target_dtype:
        raise UnsupportedOpError(
            "CastLike output dtype must match like input dtype, "
            f"got {result_dtype.onnx_name} and {target_dtype.onnx_name}"
        )
    source_shape = value_shape(graph, source_name, node)
    result_shape = value_shape(graph, result_name, node)
    if source_shape != result_shape:
        raise ShapeInferenceError("CastLike input and output shapes must match")
    return CastOp(
        input0=source_name,
        output=result_name,
        shape=result_shape,
        input_dtype=source_dtype,
        dtype=result_dtype,
    )
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from shared.scalar_types import ScalarType
6
+
7
+ from ..errors import ShapeInferenceError, UnsupportedOpError
8
+ from ..ir.model import Graph, Node
9
+
10
+
11
def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
    """Return ``dtype`` unchanged if it is a ``ScalarType``.

    Raises:
        UnsupportedOpError: when ``dtype`` is any other kind of object.
    """
    if isinstance(dtype, ScalarType):
        return dtype
    raise UnsupportedOpError(f"Unsupported dtype {dtype}")
15
+
16
+
17
def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType:
    """Look up value ``name`` in ``graph`` and return its validated dtype.

    ``node`` (optional) only supplies the op type for the error message.

    Raises:
        ShapeInferenceError: if ``graph.find_value`` raises ``KeyError``.
        UnsupportedOpError: if the stored dtype is not a ``ScalarType``.
    """
    try:
        info = graph.find_value(name)
    except KeyError as exc:
        where = "unknown" if node is None else node.op_type
        raise ShapeInferenceError(
            f"Missing dtype for value '{name}' in op {where}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc
    return ensure_supported_dtype(info.type.dtype)
27
+
28
+
29
def value_shape(graph: Graph, name: str, node: Node | None = None) -> tuple[int, ...]:
    """Look up value ``name`` in ``graph`` and return its static shape.

    ``node`` (optional) only supplies the op type for the error message.

    Raises:
        ShapeInferenceError: if the lookup raises ``KeyError``.
    """
    try:
        shape = graph.find_value(name).type.shape
    except KeyError as exc:
        where = "unknown" if node is None else node.op_type
        raise ShapeInferenceError(
            f"Missing shape for value '{name}' in op {where}. "
            "Hint: run ONNX shape inference or export with static shapes."
        ) from exc
    return shape
38
+
39
+
40
def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
    """Return the one dtype shared by every non-empty name in ``names``.

    Empty names (omitted optional inputs/outputs) are skipped.

    Raises:
        UnsupportedOpError: if no non-empty name is given, or the dtypes differ.
        ShapeInferenceError: if a dtype is missing from the graph.
    """
    present = list(filter(None, names))
    if not present:
        raise UnsupportedOpError(
            f"{node.op_type} expects at least one typed input or output"
        )
    seen = set()
    for value_name in present:
        seen.add(value_dtype(graph, value_name, node))
    if len(seen) == 1:
        return seen.pop()
    # Sorting by ``str`` keeps the message deterministic.
    listing = ", ".join(dtype.onnx_name for dtype in sorted(seen, key=str))
    raise UnsupportedOpError(
        f"{node.op_type} expects matching dtypes, got {listing}"
    )
53
+
54
+
55
def shape_product(shape: tuple[int, ...]) -> int:
    """Return the element count of a tensor with static ``shape``.

    An empty shape (a scalar) has one element.

    Raises:
        ShapeInferenceError: if any dimension is negative (dynamic). Every
            dimension is checked, even when a zero dimension already forces
            the product to 0, so dynamic shapes are rejected consistently.
    """
    if any(dim < 0 for dim in shape):
        raise ShapeInferenceError("Dynamic dims are not supported")
    product = 1
    for dim in shape:
        product *= dim
    return product
66
+
67
+
68
def optional_name(names: Sequence[str], index: int) -> str | None:
    """Return ``names[index]``, or ``None`` if that slot is absent or empty.

    An empty string at ``index`` is treated the same as a missing entry.
    """
    if index < len(names):
        return names[index] or None
    return None
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from ..codegen.c_emitter import ConcatOp
4
+ from ..errors import UnsupportedOpError
5
+ from ..ir.model import Graph, Node
6
+ from .common import node_dtype as _node_dtype
7
+ from .common import value_shape as _value_shape
8
+ from .registry import register_lowering
9
+ from ..validation import validate_concat_shapes
10
+
11
+
12
@register_lowering("Concat")
def lower_concat(graph: Graph, node: Node) -> ConcatOp:
    """Lower an ONNX ``Concat`` node to a ``ConcatOp``.

    All inputs and the output must share one dtype. Shape checking is
    delegated to ``validate_concat_shapes``, whose returned axis is stored
    on the op. A missing ``axis`` attribute is treated as 0.

    Raises:
        UnsupportedOpError: on no inputs, not exactly one output, or mixed
            dtypes; other errors come from the shape helpers.
    """
    if len(node.inputs) == 0 or len(node.outputs) != 1:
        raise UnsupportedOpError("Concat must have at least 1 input and 1 output")
    dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
    out_name = node.outputs[0]
    out_shape = _value_shape(graph, out_name, node)
    in_shapes = tuple(_value_shape(graph, name, node) for name in node.inputs)
    requested_axis = int(node.attrs.get("axis", 0))
    axis = validate_concat_shapes(in_shapes, out_shape, requested_axis)
    return ConcatOp(
        inputs=node.inputs,
        output=out_name,
        axis=axis,
        input_shapes=in_shapes,
        output_shape=out_shape,
        dtype=dtype,
    )
@@ -0,0 +1,85 @@
1
+ from __future__ import annotations
2
+
3
+ from onnx import numpy_helper
4
+
5
+ from shared.scalar_types import ScalarType
6
+
7
+ from ..codegen.c_emitter import ConstantOfShapeOp
8
+ from ..dtypes import scalar_type_from_onnx
9
+ from ..errors import ShapeInferenceError, UnsupportedOpError
10
+ from ..ir.model import Graph, Node
11
+ from .registry import register_lowering
12
+
13
+
14
+ def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
15
+ try:
16
+ return graph.find_value(name).type.shape
17
+ except KeyError as exc:
18
+ raise ShapeInferenceError(
19
+ f"Missing shape for value '{name}' in op {node.op_type}. "
20
+ "Hint: run ONNX shape inference or export with static shapes."
21
+ ) from exc
22
+
23
+
24
+ def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
25
+ try:
26
+ return graph.find_value(name).type.dtype
27
+ except KeyError as exc:
28
+ raise ShapeInferenceError(
29
+ f"Missing dtype for value '{name}' in op {node.op_type}. "
30
+ "Hint: run ONNX shape inference or export with static shapes."
31
+ ) from exc
32
+
33
+
34
def _parse_value_attr(node: Node) -> tuple[ScalarType, float | int | bool]:
    """Decode the ``value`` attribute of a ``ConstantOfShape`` node.

    Returns:
        ``(dtype, fill)``, where ``fill`` is a Python scalar. Without a
        ``value`` attribute the fill is float32 zero.

    Raises:
        UnsupportedOpError: if the tensor's dtype is not mappable or the
            tensor holds more than one element.
    """
    tensor = node.attrs.get("value")
    if tensor is None:
        return ScalarType.F32, 0.0
    scalar_type = scalar_type_from_onnx(tensor.data_type)
    if scalar_type is None:
        raise UnsupportedOpError(
            f"ConstantOfShape has unsupported value dtype {tensor.data_type}"
        )
    array = numpy_helper.to_array(tensor)
    if array.size != 1:
        raise UnsupportedOpError("ConstantOfShape value must be a scalar")
    # A one-element array of any rank converts straight to a Python scalar.
    return scalar_type, array.item()
47
+
48
+
49
@register_lowering("ConstantOfShape")
def lower_constant_of_shape(graph: Graph, node: Node) -> ConstantOfShapeOp:
    """Lower an ONNX ``ConstantOfShape`` node to a ``ConstantOfShapeOp``.

    The shape input must be a 1D int64 tensor with one entry per output
    dimension. The output must be fully static and match the dtype of the
    ``value`` attribute.

    Raises:
        UnsupportedOpError: on any of the following:
            - wrong arity;
            - a non-1D or non-int64 shape input;
            - an output dtype that differs from the fill value's dtype.
        ShapeInferenceError: on any of the following:
            - a shape or dtype is missing;
            - the shape length does not match the output rank;
            - an output dim is dynamic.
    """
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("ConstantOfShape must have 1 input and 1 output")
    shape_name = node.inputs[0]
    out_name = node.outputs[0]
    shape_tensor_shape = _value_shape(graph, shape_name, node)
    if len(shape_tensor_shape) != 1:
        raise UnsupportedOpError("ConstantOfShape expects a 1D shape input")
    out_shape = _value_shape(graph, out_name, node)
    if shape_tensor_shape[0] != len(out_shape):
        raise ShapeInferenceError(
            "ConstantOfShape input length must match output rank"
        )
    if any(dim < 0 for dim in out_shape):
        raise ShapeInferenceError("Dynamic dims are not supported")
    shape_dtype = _value_dtype(graph, shape_name, node)
    if shape_dtype != ScalarType.I64:
        raise UnsupportedOpError(
            "ConstantOfShape expects int64 shape input, "
            f"got {shape_dtype.onnx_name}"
        )
    out_dtype = _value_dtype(graph, out_name, node)
    fill_dtype, fill_value = _parse_value_attr(node)
    if out_dtype != fill_dtype:
        raise UnsupportedOpError(
            "ConstantOfShape output dtype must match value dtype, "
            f"got {out_dtype.onnx_name} and {fill_dtype.onnx_name}"
        )
    return ConstantOfShapeOp(
        input0=shape_name,
        output=out_name,
        input_shape=shape_tensor_shape,
        shape=out_shape,
        value=fill_value,
        dtype=out_dtype,
        input_dtype=shape_dtype,
    )