emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ from onnx import numpy_helper
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ConstantOfShapeOp
+from ..ir.ops import ConstantOfShapeOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
@@ -3,7 +3,7 @@ from __future__ import annotations
 import math
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import ConvOp
+from ..ir.ops import ConvOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
@@ -0,0 +1,301 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+from ..ir.ops import ConvTransposeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class ConvTransposeSpec:
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    output_padding: tuple[int, ...]
+    group: int
+
+
+def _split_padding(
+    total_padding: int, auto_pad: str, *, dim: int
+) -> tuple[int, int]:
+    if total_padding < 0:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be fully defined and non-negative"
+        )
+    pad_end = total_padding // 2
+    pad_begin = total_padding - pad_end
+    if auto_pad == "SAME_UPPER":
+        pad_begin, pad_end = pad_end, pad_begin
+    elif auto_pad not in {"SAME_LOWER", "NOTSET", ""}:
+        raise UnsupportedOpError(
+            f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+        )
+    if pad_begin < 0 or pad_end < 0:
+        raise ShapeInferenceError(
+            f"ConvTranspose pads must be non-negative for dim {dim}"
+        )
+    return pad_begin, pad_end
+
+
+def resolve_conv_transpose_spec(graph: Graph, node: Node) -> ConvTransposeSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    supported_attrs = {
+        "auto_pad",
+        "dilations",
+        "group",
+        "kernel_shape",
+        "output_padding",
+        "output_shape",
+        "pads",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("ConvTranspose has unsupported attributes")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    weight_shape = _value_shape(graph, node.inputs[1], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("ConvTranspose expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("ConvTranspose supports 1D/2D/3D inputs only")
+    if len(weight_shape) != spatial_rank + 2:
+        raise UnsupportedOpError(
+            "ConvTranspose weight rank must match spatial rank"
+        )
+    batch, in_channels = input_shape[0], input_shape[1]
+    in_spatial = input_shape[2:]
+    weight_in_channels, weight_out_channels, *kernel_shape = weight_shape
+    kernel_attr = node.attrs.get("kernel_shape")
+    if kernel_attr is not None:
+        kernel_attr = tuple(int(value) for value in kernel_attr)
+        if len(kernel_attr) != spatial_rank:
+            raise UnsupportedOpError(
+                "ConvTranspose kernel_shape rank must match input spatial rank"
+            )
+        if kernel_attr != tuple(kernel_shape):
+            raise ShapeInferenceError(
+                "ConvTranspose kernel_shape must match weights, "
+                f"got {kernel_attr} and {tuple(kernel_shape)}"
+            )
+        kernel_shape = list(kernel_attr)
+    else:
+        kernel_shape = list(kernel_shape)
+    group = int(node.attrs.get("group", 1))
+    if group <= 0:
+        raise UnsupportedOpError("ConvTranspose expects group >= 1")
+    if in_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide in channels, "
+            f"got group={group}, in_channels={in_channels}"
+        )
+    if weight_in_channels != in_channels:
+        raise ShapeInferenceError(
+            "ConvTranspose input channels must match weight channels, "
+            f"got {in_channels} and {weight_in_channels}"
+        )
+    out_channels = weight_out_channels * group
+    if out_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide out channels, "
+            f"got group={group}, out_channels={out_channels}"
+        )
+    if len(node.inputs) == 3:
+        bias_shape = _value_shape(graph, node.inputs[2], node)
+        if bias_shape != (out_channels,):
+            raise ShapeInferenceError(
+                f"ConvTranspose bias shape must be {(out_channels,)}, got {bias_shape}"
+            )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose stride rank mismatch")
+    dilations = tuple(
+        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose dilation rank mismatch")
+    output_padding = tuple(
+        int(value)
+        for value in node.attrs.get("output_padding", (0,) * spatial_rank)
+    )
+    if len(output_padding) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose output_padding rank mismatch")
+    for dim, (padding, stride) in enumerate(zip(output_padding, strides)):
+        if padding < 0:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be non-negative"
+            )
+        if padding >= stride:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be smaller than stride"
+            )
+    pads = tuple(
+        int(value)
+        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("ConvTranspose pads rank mismatch")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad == "":
+        auto_pad = "NOTSET"
+    output_shape_attr = node.attrs.get("output_shape")
+    output_shape: list[int] | None = None
+    if output_shape_attr is not None:
+        output_shape = [int(value) for value in output_shape_attr]
+        if len(output_shape) != spatial_rank:
+            raise UnsupportedOpError("ConvTranspose output_shape rank mismatch")
+    if output_shape is not None:
+        if auto_pad == "VALID":
+            auto_pad = "NOTSET"
+        pad_begin = []
+        pad_end = []
+        for dim, (in_dim, stride, dilation, kernel, out_dim, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                output_shape,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            total_padding = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - out_dim
+            )
+            pad_start, pad_finish = _split_padding(
+                total_padding, auto_pad, dim=dim
+            )
+            pad_begin.append(pad_start)
+            pad_end.append(pad_finish)
+        out_spatial = output_shape
+    else:
+        if auto_pad == "VALID":
+            pad_begin = [0] * spatial_rank
+            pad_end = [0] * spatial_rank
+        elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+            pad_begin = []
+            pad_end = []
+            for dim, (in_dim, stride, dilation, kernel, out_pad) in enumerate(
+                zip(in_spatial, strides, dilations, kernel_shape, output_padding)
+            ):
+                effective_kernel = dilation * (kernel - 1) + 1
+                out_dim = in_dim * stride
+                total_padding = (
+                    stride * (in_dim - 1)
+                    + out_pad
+                    + effective_kernel
+                    - out_dim
+                )
+                pad_start, pad_finish = _split_padding(
+                    total_padding, auto_pad, dim=dim
+                )
+                pad_begin.append(pad_start)
+                pad_end.append(pad_finish)
+        elif auto_pad in {"NOTSET"}:
+            pad_begin = list(pads[:spatial_rank])
+            pad_end = list(pads[spatial_rank:])
+        else:
+            raise UnsupportedOpError(
+                f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+            )
+        out_spatial = []
+        for dim, (in_dim, stride, dilation, kernel, pad_start, pad_finish, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                pad_begin,
+                pad_end,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - pad_start
+                - pad_finish
+            )
+            if out_dim < 0:
+                raise ShapeInferenceError(
+                    "ConvTranspose output shape must be non-negative"
+                )
+            out_spatial.append(out_dim)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, out_channels, *out_spatial)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return ConvTransposeSpec(
+        batch=batch,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        spatial_rank=spatial_rank,
+        in_spatial=in_spatial,
+        out_spatial=tuple(out_spatial),
+        kernel_shape=tuple(kernel_shape),
+        strides=strides,
+        pads=(*pad_begin, *pad_end),
+        dilations=dilations,
+        output_padding=output_padding,
+        group=group,
+    )
+
+
+@register_lowering("ConvTranspose")
+def lower_conv_transpose(graph: Graph, node: Node) -> ConvTransposeOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "ConvTranspose supports float16, float, and double inputs only"
+        )
+    spec = resolve_conv_transpose_spec(graph, node)
+    return ConvTransposeOp(
+        input0=node.inputs[0],
+        weights=node.inputs[1],
+        bias=node.inputs[2] if len(node.inputs) == 3 else None,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        output_padding=spec.output_padding,
+        group=spec.group,
+        dtype=op_dtype,
+    )
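
Note on the shape math above: when no explicit output_shape attribute is given, resolve_conv_transpose_spec derives each output spatial dimension from the standard ConvTranspose relation. The snippet below is only a standalone restatement of that arithmetic for a concrete case (plain Python, independent of the package):

    # Output size used by resolve_conv_transpose_spec when output_shape is absent:
    #   out = stride * (in - 1) + output_padding + dilation * (kernel - 1) + 1 - pad_begin - pad_end
    def conv_transpose_out_dim(in_dim, stride, dilation, kernel, pad_begin, pad_end, out_pad):
        effective_kernel = dilation * (kernel - 1) + 1
        return stride * (in_dim - 1) + out_pad + effective_kernel - pad_begin - pad_end

    # Example: a 4-wide input, 3-wide kernel, stride 2, no padding -> 9-wide output.
    assert conv_transpose_out_dim(4, 2, 1, 3, 0, 0, 0) == 9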
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import CumSumOp
+from ..ir.ops import CumSumOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import DepthToSpaceOp, SpaceToDepthOp
+from ..ir.ops import DepthToSpaceOp, SpaceToDepthOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..lowering.common import value_dtype, value_shape
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import ReshapeOp
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype as _value_dtype
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+from ..ir.ops import EinsumKind, EinsumOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _normalize_equation(equation: str) -> str:
+    return equation.replace(" ", "")
+
+
+@register_lowering("Einsum")
+def lower_einsum(graph: Graph, node: Node) -> EinsumOp:
+    if not node.inputs or len(node.outputs) != 1:
+        raise UnsupportedOpError("Einsum must have 1 output and at least 1 input")
+    equation_value = node.attrs.get("equation")
+    if equation_value is None:
+        raise UnsupportedOpError("Einsum equation attribute is required")
+    equation = (
+        equation_value.decode()
+        if isinstance(equation_value, (bytes, bytearray))
+        else str(equation_value)
+    )
+    normalized = _normalize_equation(equation)
+    input_shapes = tuple(
+        _value_shape(graph, name, node) for name in node.inputs
+    )
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if normalized == "->":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum '->' must have 1 input")
+        if output_shape:
+            raise ShapeInferenceError(
+                "Einsum '->' output must be scalar, "
+                f"got shape {output_shape}"
+            )
+        kind = EinsumKind.REDUCE_ALL
+    elif normalized == "ij->i":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum 'ij->i' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) != 2:
+            raise ShapeInferenceError(
+                "Einsum 'ij->i' input must be 2D, "
+                f"got shape {input_shape}"
+            )
+        expected = (input_shape[0],)
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'ij->i' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.SUM_J
+    elif normalized == "ij->ji":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum 'ij->ji' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) != 2:
+            raise ShapeInferenceError(
+                "Einsum 'ij->ji' input must be 2D, "
+                f"got shape {input_shape}"
+            )
+        expected = (input_shape[1], input_shape[0])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'ij->ji' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.TRANSPOSE
+    elif normalized in {"i,i", "i,i->"}:
+        if len(node.inputs) != 2:
+            raise UnsupportedOpError("Einsum 'i,i' must have 2 inputs")
+        left_shape, right_shape = input_shapes
+        if len(left_shape) != 1 or len(right_shape) != 1:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' inputs must be vectors, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[0] != right_shape[0]:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' inputs must have the same length, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if output_shape:
+            raise ShapeInferenceError(
+                "Einsum 'i,i' output must be scalar, "
+                f"got shape {output_shape}"
+            )
+        kind = EinsumKind.DOT
+    elif normalized == "bij,bjk->bik":
+        if len(node.inputs) != 2:
+            raise UnsupportedOpError("Einsum 'bij,bjk->bik' must have 2 inputs")
+        left_shape, right_shape = input_shapes
+        if len(left_shape) != 3 or len(right_shape) != 3:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' inputs must be 3D, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[0] != right_shape[0]:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' batch dimensions must match, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        if left_shape[2] != right_shape[1]:
+            raise ShapeInferenceError(
+                "Einsum 'bij,bjk->bik' contraction dimensions must match, "
+                f"got shapes {left_shape} and {right_shape}"
+            )
+        expected = (left_shape[0], left_shape[1], right_shape[2])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum 'bij,bjk->bik' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.BATCH_MATMUL
+    elif normalized == "...ii->...i":
+        if len(node.inputs) != 1:
+            raise UnsupportedOpError("Einsum '...ii->...i' must have 1 input")
+        input_shape = input_shapes[0]
+        if len(input_shape) < 2:
+            raise ShapeInferenceError(
+                "Einsum '...ii->...i' input must be at least 2D, "
+                f"got shape {input_shape}"
+            )
+        if input_shape[-1] != input_shape[-2]:
+            raise ShapeInferenceError(
+                "Einsum '...ii->...i' requires last two dims to match, "
+                f"got shape {input_shape}"
+            )
+        expected = (*input_shape[:-2], input_shape[-1])
+        if output_shape != expected:
+            raise ShapeInferenceError(
+                f"Einsum '...ii->...i' output must match shape {expected}, "
+                f"got {output_shape}"
+            )
+        kind = EinsumKind.BATCH_DIAGONAL
+    else:
+        raise UnsupportedOpError(
+            f"Unsupported Einsum equation '{equation}'"
+        )
+    return EinsumOp(
+        inputs=tuple(node.inputs),
+        output=node.outputs[0],
+        kind=kind,
+        input_shapes=input_shapes,
+        output_shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+    )
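
The Einsum lowering above supports only a fixed whitelist of equations rather than arbitrary einsum notation; anything else raises UnsupportedOpError. The mapping below simply restates the branches of lower_einsum for quick reference (SUPPORTED_EINSUM_EQUATIONS is a hypothetical name, and the string values stand in for the EinsumKind members):

    # Equations accepted after whitespace stripping, and the kind they lower to:
    SUPPORTED_EINSUM_EQUATIONS = {
        "->": "REDUCE_ALL",             # single input reduced to a scalar
        "ij->i": "SUM_J",               # row sums of a 2D input
        "ij->ji": "TRANSPOSE",          # 2D transpose
        "i,i": "DOT",                   # vector dot product ("i,i->" also accepted)
        "bij,bjk->bik": "BATCH_MATMUL",
        "...ii->...i": "BATCH_DIAGONAL",
    }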
@@ -1,13 +1,23 @@
 from __future__ import annotations
 
-from shared.scalar_functions import ScalarFunction
+from shared.scalar_functions import ScalarFunction, ScalarFunctionError
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ClipOp, UnaryOp
+from ..ir.ops import BinaryOp, ClipOp, UnaryOp
 from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
 from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
-from ..lowering.registry import register_lowering
+from ..lowering.registry import register_lowering, register_lowering_if_missing
+from ..ops import (
+    BINARY_OP_TYPES,
+    COMPARE_FUNCTIONS,
+    UNARY_OP_TYPES,
+    binary_op_symbol,
+    unary_op_symbol,
+    validate_unary_attrs,
+)
+from ..lowering.variadic import VARIADIC_OP_FUNCTIONS
 
 
 @register_lowering("Clip")
@@ -120,6 +130,138 @@ def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
     )
 
 
+def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
+    if node.op_type == "BitShift":
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
+        direction_attr = node.attrs.get("direction", "LEFT")
+        if isinstance(direction_attr, bytes):
+            direction = direction_attr.decode()
+        else:
+            direction = str(direction_attr)
+        if direction not in {"LEFT", "RIGHT"}:
+            raise UnsupportedOpError(
+                "BitShift direction must be LEFT or RIGHT"
+            )
+        op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+        if not op_dtype.is_integer:
+            raise UnsupportedOpError("BitShift expects integer inputs")
+        function = (
+            ScalarFunction.BITWISE_LEFT_SHIFT
+            if direction == "LEFT"
+            else ScalarFunction.BITWISE_RIGHT_SHIFT
+        )
+        op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+        if op_spec is None:
+            raise UnsupportedOpError("Unsupported op BitShift")
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=op_dtype,
+            input_dtype=op_dtype,
+        )
+    if node.op_type == "Mod":
+        fmod = int(node.attrs.get("fmod", 0))
+        if fmod not in {0, 1}:
+            raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
+        function = (
+            ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
+        )
+    else:
+        try:
+            function = ScalarFunction.from_onnx_op(node.op_type)
+        except ScalarFunctionError as exc:
+            raise UnsupportedOpError(
+                f"Unsupported op {node.op_type}"
+            ) from exc
+    validate_unary_attrs(node.op_type, node.attrs)
+    if function in COMPARE_FUNCTIONS:
+        input_dtype = node_dtype(graph, node, *node.inputs)
+        output_dtype = value_dtype(graph, node.outputs[0], node)
+        op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
+        if op_spec is None:
+            raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{node.op_type} must have 2 inputs and 1 output"
+            )
+        if output_dtype != ScalarType.BOOL:
+            raise UnsupportedOpError(
+                f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
+            )
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=output_dtype,
+            input_dtype=input_dtype,
+        )
+    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+    unary_symbol = unary_op_symbol(function, dtype=op_dtype)
+    if op_spec is None and unary_symbol is None:
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    if op_spec is not None:
+        if len(node.inputs) != 2 or len(node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{node.op_type} must have 2 inputs and 1 output"
+            )
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
+        output_shape = value_shape(graph, node.outputs[0], node)
+        return BinaryOp(
+            input0=node.inputs[0],
+            input1=node.inputs[1],
+            output=node.outputs[0],
+            function=function,
+            operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
+            shape=output_shape,
+            dtype=op_dtype,
+            input_dtype=op_dtype,
+        )
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 1 input and 1 output"
+        )
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=function,
+        shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+        params=(),
+    )
+
+
+_DEFAULT_ELEMENTWISE_TYPES = (
+    BINARY_OP_TYPES.union(UNARY_OP_TYPES) - set(VARIADIC_OP_FUNCTIONS.keys())
+)
+
+for _op_type in _DEFAULT_ELEMENTWISE_TYPES:
+    register_lowering_if_missing(_op_type)(_lower_binary_unary)
+
+
 @register_lowering("IsInf")
 def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
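
The registration loop at the end of the hunk above makes _lower_binary_unary the default lowering for every scalar binary and unary op type that is not covered by the variadic handlers. Assuming register_lowering_if_missing skips op types that already have a dedicated @register_lowering entry (as its name suggests), the effective precedence can be sketched as follows (illustrative only, not the package's actual registry code):

    def pick_lowering(op_type, dedicated, variadic, binary_types, unary_types):
        if op_type in dedicated:                  # e.g. Clip, IsInf, Shrink keep their own lowering
            return dedicated[op_type]
        if op_type in variadic:                   # handlers from VARIADIC_OP_FUNCTIONS
            return variadic[op_type]
        if op_type in binary_types or op_type in unary_types:
            return "_lower_binary_unary"          # shared fallback registered above
        raise KeyError(f"no lowering registered for {op_type}")

    # Example: "Clip" keeps its dedicated lowering, while "Add" falls through to the fallback.
    assert pick_lowering("Add", {"Clip": "lower_clip"}, {}, {"Add", "Sub"}, {"Neg"}) == "_lower_binary_unary"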
@@ -130,6 +272,12 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
         raise UnsupportedOpError("IsInf only supports floating-point inputs")
     if output_dtype != ScalarType.BOOL:
         raise UnsupportedOpError("IsInf output must be bool")
+    detect_negative = int(node.attrs.get("detect_negative", 1))
+    detect_positive = int(node.attrs.get("detect_positive", 1))
+    if detect_negative not in {0, 1} or detect_positive not in {0, 1}:
+        raise UnsupportedOpError(
+            "IsInf detect_negative and detect_positive must be 0 or 1"
+        )
     output_shape = value_shape(graph, node.outputs[0], node)
     return UnaryOp(
         input0=node.inputs[0],
@@ -138,7 +286,7 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
         shape=output_shape,
         dtype=output_dtype,
         input_dtype=input_dtype,
-        params=(),
+        params=(float(detect_negative), float(detect_positive)),
     )
 
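
With the two hunks above, the IsInf lowering now honors the ONNX detect_negative/detect_positive attributes by threading them through UnaryOp.params as floats. As a reference for the intended element-wise behaviour (assuming the emitter consumes the two params in that order), a small NumPy sketch:

    import numpy as np

    # ONNX IsInf semantics: detect_positive gates +inf matches, detect_negative gates -inf matches.
    def isinf_reference(x, detect_negative=1, detect_positive=1):
        pos = np.isposinf(x) if detect_positive else np.zeros(x.shape, dtype=bool)
        neg = np.isneginf(x) if detect_negative else np.zeros(x.shape, dtype=bool)
        return pos | neg

    x = np.array([-np.inf, 0.0, np.inf], dtype=np.float32)
    assert isinf_reference(x, detect_negative=0).tolist() == [False, False, True]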