emx-onnx-cgen 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

emx_onnx_cgen/_build_info.py CHANGED
@@ -1,3 +1,3 @@
 """Auto-generated by build backend. Do not edit."""
-BUILD_DATE = '2026-01-23T06:14:26Z'
+BUILD_DATE = '2026-01-23T07:18:06Z'
 GIT_VERSION = 'unknown'
emx_onnx_cgen/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.3'
-__version_tuple__ = version_tuple = (0, 3, 3)
+__version__ = version = '0.3.5'
+__version_tuple__ = version_tuple = (0, 3, 5)
 
 __commit_id__ = commit_id = None
emx_onnx_cgen/codegen/c_emitter.py CHANGED
@@ -954,13 +954,13 @@ class CEmitter:
             )
         if isinstance(op, MultiInputBinaryOp):
             return MultiInputBinaryOp(
+                op_type=op.op_type,
                 inputs=tuple(name_map.get(name, name) for name in op.inputs),
                 output=name_map.get(op.output, op.output),
                 function=op.function,
                 operator_kind=op.operator_kind,
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
+                min_inputs=op.min_inputs,
+                max_inputs=op.max_inputs,
             )
         if isinstance(op, WhereOp):
             return WhereOp(
@@ -1554,12 +1554,7 @@ class CEmitter:
                 data=name_map.get(op.data, op.data),
                 indices=name_map.get(op.indices, op.indices),
                 output=name_map.get(op.output, op.output),
-                data_shape=op.data_shape,
-                indices_shape=op.indices_shape,
-                output_shape=op.output_shape,
                 axis=op.axis,
-                dtype=op.dtype,
-                indices_dtype=op.indices_dtype,
             )
         if isinstance(op, GatherNDOp):
             return GatherNDOp(
@@ -1896,13 +1891,8 @@ class CEmitter:
         if isinstance(op, ExpandOp):
             return ExpandOp(
                 input0=name_map.get(op.input0, op.input0),
-                output=name_map.get(op.output, op.output),
                 input_shape=op.input_shape,
-                output_shape=op.output_shape,
-                input_shape_padded=op.input_shape_padded,
-                input_strides=op.input_strides,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
+                output=name_map.get(op.output, op.output),
             )
         if isinstance(op, CumSumOp):
             return CumSumOp(
@@ -2233,8 +2223,8 @@ class CEmitter:
         ):
             testbench_math_include.add("#include <math.h>")
         includes = self._collect_includes(
-            model,
-            resolved_ops,
+            original_model,
+            list(original_model.ops),
             emit_testbench=emit_testbench,
             extra_includes=scalar_includes | testbench_math_include,
             needs_weight_loader=bool(large_constants),
@@ -2380,8 +2370,8 @@ class CEmitter:
         ):
             testbench_math_include.add("#include <math.h>")
         includes = self._collect_includes(
-            model,
-            resolved_ops,
+            original_model,
+            list(original_model.ops),
             emit_testbench=emit_testbench,
             extra_includes=scalar_includes | testbench_math_include,
             needs_weight_loader=bool(large_constants),
@@ -2790,8 +2780,17 @@ class CEmitter:
             *(const.dtype for const in model.constants),
             *constant_of_shape_inputs,
         }
+        def _resolved_output_dtype(op: OpBase) -> ScalarType:
+            if isinstance(op, MultiInputBinaryOp):
+                return model.op_context.dtype(op.inputs[0])
+            if isinstance(op, GatherOp):
+                return model.op_context.dtype(op.data)
+            if isinstance(op, ExpandOp):
+                return model.op_context.dtype(op.input0)
+            return op.dtype
+
         model_dtypes.update(
-            op.dtype
+            _resolved_output_dtype(op)
             for op in resolved_ops
             if not isinstance(op, (ArgReduceOp, TopKOp))
         )
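
Note: the hunk above is the crux of this refactor — output dtypes for MultiInputBinaryOp, GatherOp, and ExpandOp are no longer stored on the op but looked up from the op context, falling back to op.dtype for everything else. A minimal standalone sketch of that fallback pattern (the class and context names here are illustrative, not the package's real API):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class FakeGatherOp:  # hypothetical stand-in for GatherOp
        data: str
        dtype: str = "unset"

    def resolved_output_dtype(op, ctx: dict[str, str]) -> str:
        # Ops whose output dtype follows an input consult the context;
        # all other ops keep the dtype stored on the op itself.
        if isinstance(op, FakeGatherOp):
            return ctx[op.data]
        return op.dtype

    print(resolved_output_dtype(FakeGatherOp(data="x"), {"x": "float32"}))  # float32
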
@@ -3875,13 +3874,13 @@ class CEmitter:
             )
         if isinstance(op, MultiInputBinaryOp):
             return MultiInputBinaryOp(
+                op_type=op.op_type,
                 inputs=tuple(temp_map.get(name, name) for name in op.inputs),
                 output=temp_map.get(op.output, op.output),
                 function=op.function,
                 operator_kind=op.operator_kind,
-                shape=op.shape,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
+                min_inputs=op.min_inputs,
+                max_inputs=op.max_inputs,
             )
         if isinstance(op, WhereOp):
             return WhereOp(
@@ -4545,11 +4544,6 @@ class CEmitter:
                 indices=temp_map.get(op.indices, op.indices),
                 output=temp_map.get(op.output, op.output),
                 axis=op.axis,
-                data_shape=op.data_shape,
-                indices_shape=op.indices_shape,
-                output_shape=op.output_shape,
-                dtype=op.dtype,
-                indices_dtype=op.indices_dtype,
             )
         if isinstance(op, GatherNDOp):
             return GatherNDOp(
@@ -4674,13 +4668,8 @@ class CEmitter:
         if isinstance(op, ExpandOp):
             return ExpandOp(
                 input0=temp_map.get(op.input0, op.input0),
-                output=temp_map.get(op.output, op.output),
                 input_shape=op.input_shape,
-                output_shape=op.output_shape,
-                input_shape_padded=op.input_shape_padded,
-                input_strides=op.input_strides,
-                dtype=op.dtype,
-                input_dtype=op.input_dtype,
+                output=temp_map.get(op.output, op.output),
             )
         if isinstance(op, CumSumOp):
             return CumSumOp(
@@ -7446,30 +7435,34 @@ class CEmitter:
                     ("output", op.output),
                 ]
             )
-            output_shape = CEmitter._codegen_shape(op.output_shape)
-            loop_vars = CEmitter._loop_vars(output_shape)
-            output_loop_vars = loop_vars if op.output_shape else ()
-            indices_rank = len(op.indices_shape)
+            output_shape_raw = self._ctx_shape(op.output)
+            output_shape = CEmitter._codegen_shape(output_shape_raw)
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
+            output_loop_vars = loop_vars if output_shape_raw else ()
+            indices_shape = self._ctx_shape(op.indices)
+            indices_rank = len(indices_shape)
             if indices_rank == 0:
                 indices_indices = ("0",)
             else:
-                indices_indices = output_loop_vars[
-                    op.axis : op.axis + indices_rank
-                ]
+                axis = int(self._derived(op, "axis"))
+                indices_indices = output_loop_vars[axis : axis + indices_rank]
+            axis = int(self._derived(op, "axis"))
             data_indices = [
-                *output_loop_vars[: op.axis],
+                *output_loop_vars[:axis],
                 "gather_index",
-                *output_loop_vars[op.axis + indices_rank :],
+                *output_loop_vars[axis + indices_rank :],
             ]
-            data_suffix = self._param_array_suffix(op.data_shape)
-            indices_suffix = self._param_array_suffix(op.indices_shape)
-            output_suffix = self._param_array_suffix(op.output_shape)
+            data_shape = self._ctx_shape(op.data)
+            data_suffix = self._param_array_suffix(data_shape)
+            indices_suffix = self._param_array_suffix(indices_shape)
+            output_suffix = self._param_array_suffix(output_shape_raw)
+            indices_dtype = self._ctx_dtype(op.indices)
            param_decls = self._build_param_decls(
                 [
                     (params["data"], c_type, data_suffix, True),
                     (
                         params["indices"],
-                        op.indices_dtype.c_type,
+                        indices_dtype.c_type,
                         indices_suffix,
                         True,
                     ),
@@ -7484,7 +7477,7 @@ class CEmitter:
                 output=params["output"],
                 params=param_decls,
                 c_type=c_type,
-                indices_c_type=op.indices_dtype.c_type,
+                indices_c_type=indices_dtype.c_type,
                 data_suffix=data_suffix,
                 indices_suffix=indices_suffix,
                 output_suffix=output_suffix,
@@ -7492,7 +7485,7 @@ class CEmitter:
                 loop_vars=loop_vars,
                 indices_indices=indices_indices,
                 data_indices=data_indices,
-                axis_dim=op.data_shape[op.axis],
+                axis_dim=data_shape[axis],
             ).rstrip()
             return with_node_comment(rendered)
         if isinstance(op, GatherNDOp):
@@ -9139,15 +9132,17 @@ class CEmitter:
                 [("input0", op.input0), ("output", op.output)]
             )
             output_dim_names = _dim_names_for(op.output)
+            output_shape_raw = self._ctx_shape(op.output)
             output_shape = CEmitter._shape_dim_exprs(
-                op.output_shape, output_dim_names
+                output_shape_raw, output_dim_names
             )
-            loop_vars = CEmitter._loop_vars(op.output_shape)
+            loop_vars = CEmitter._loop_vars(output_shape_raw)
+            input_shape = self._ctx_shape(op.input0)
             input_suffix = self._param_array_suffix(
-                op.input_shape, _dim_names_for(op.input0)
+                input_shape, _dim_names_for(op.input0)
             )
             output_suffix = self._param_array_suffix(
-                op.output_shape, output_dim_names
+                output_shape_raw, output_dim_names
             )
             param_decls = self._build_param_decls(
                 [
@@ -9155,10 +9150,12 @@ class CEmitter:
                     (params["output"], c_type, output_suffix, False),
                 ]
             )
+            input_shape_padded = self._derived(op, "input_shape_padded")
+            input_strides = self._derived(op, "input_strides")
             input_index_terms = [
                 f"{loop_var} * {stride}"
                 for loop_var, input_dim, stride in zip(
-                    loop_vars, op.input_shape_padded, op.input_strides
+                    loop_vars, input_shape_padded, input_strides
                 )
                 if input_dim != 1
             ]
@@ -10442,7 +10439,13 @@ class CEmitter:
             )
         if isinstance(op, NonMaxSuppressionOp):
             return ((op.output, op.output_shape, op.output_dtype),)
-        return ((op.output, self._op_output_shape(op), op.dtype),)
+        return (
+            (
+                op.output,
+                self._op_output_shape(op),
+                self._op_output_dtype(op),
+            ),
+        )
 
     def _op_output_shape(
         self,
@@ -10566,7 +10569,7 @@ class CEmitter:
         if isinstance(op, GatherElementsOp):
             return op.output_shape
         if isinstance(op, GatherOp):
-            return op.output_shape
+            return self._ctx_shape(op.output)
         if isinstance(op, GatherNDOp):
             return op.output_shape
         if isinstance(op, ScatterNDOp):
@@ -10614,7 +10617,7 @@ class CEmitter:
         if isinstance(op, NonMaxSuppressionOp):
             return op.output_shape
         if isinstance(op, ExpandOp):
-            return op.output_shape
+            return self._ctx_shape(op.output)
         if isinstance(op, CumSumOp):
             return op.input_shape
         if isinstance(op, RangeOp):
@@ -10700,10 +10703,12 @@ class CEmitter:
                 SoftmaxOp,
                 LogSoftmaxOp,
                 HardmaxOp,
+                GatherOp,
                 TransposeOp,
                 ReshapeOp,
                 IdentityOp,
                 ReduceOp,
+                ExpandOp,
             ),
         ):
             return self._ctx_dtype(op.output)
emx_onnx_cgen/compiler.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 import hashlib
 from pathlib import Path
 from typing import Mapping
@@ -24,6 +24,7 @@ from .ir.context import GraphContext
 from .ir.model import Graph, TensorType, Value
 from .ir.op_base import OpBase
 from .ir.op_context import OpContext
+from .ir.ops import ExpandOp, GatherOp, MultiInputBinaryOp
 from .lowering import load_lowering_registry
 from .lowering.common import ensure_supported_dtype, shape_product, value_dtype
 from .lowering.registry import get_lowering_registry
@@ -172,6 +173,43 @@ class Compiler:
         ) = self._collect_io_specs(graph)
         ops, node_infos = self._lower_nodes(ctx)
         op_ctx = OpContext(ctx)
+        for op, node_info in zip(ops, node_infos):
+            field_names = {field.name for field in fields(op)}
+            if "dtype" in field_names:
+                dtype = getattr(op, "dtype")
+                for field in fields(op):
+                    if not field.name.startswith("output"):
+                        continue
+                    value = getattr(op, field.name)
+                    if isinstance(value, str):
+                        op_ctx.set_dtype(value, dtype)
+                for name in node_info.outputs:
+                    op_ctx.set_dtype(name, dtype)
+            if "outputs" in field_names:
+                dtype = getattr(op, "dtype", None)
+                if dtype is not None:
+                    for name in getattr(op, "outputs"):
+                        op_ctx.set_dtype(name, dtype)
+            if "output_dtype" in field_names and "output" in field_names:
+                output_name = getattr(op, "output")
+                if isinstance(output_name, str):
+                    op_ctx.set_dtype(output_name, getattr(op, "output_dtype"))
+            if "output_values_dtype" in field_names:
+                op_ctx.set_dtype(
+                    getattr(op, "output_values"),
+                    getattr(op, "output_values_dtype"),
+                )
+            if "output_indices_dtype" in field_names:
+                op_ctx.set_dtype(
+                    getattr(op, "output_indices"),
+                    getattr(op, "output_indices_dtype"),
+                )
+            if isinstance(op, MultiInputBinaryOp) and op.inputs:
+                op_ctx.set_dtype(op.output, op_ctx.dtype(op.inputs[0]))
+            if isinstance(op, GatherOp):
+                op_ctx.set_dtype(op.output, op_ctx.dtype(op.data))
+            if isinstance(op, ExpandOp):
+                op_ctx.set_dtype(op.output, op_ctx.dtype(op.input0))
         for op in ops:
             op.validate(op_ctx)
         for op in ops:
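
The seeding loop added above discovers dtype information by introspecting each op dataclass with dataclasses.fields and a naming convention ("output", "output_dtype", "output_values_dtype", ...). A self-contained sketch of that convention-based scan, using an invented op class:

    from dataclasses import dataclass, fields

    @dataclass(frozen=True)
    class DemoOp:  # hypothetical op following the "output"/"output_dtype" convention
        output: str
        output_dtype: str

    def seed_dtypes(op, table: dict[str, str]) -> None:
        names = {f.name for f in fields(op)}
        # Pair the "output" name field with its "output_dtype" companion,
        # mirroring the scan in the compiler hunk above.
        if "output_dtype" in names and "output" in names:
            table[op.output] = op.output_dtype

    table: dict[str, str] = {}
    seed_dtypes(DemoOp(output="y", output_dtype="int64"), table)
    print(table)  # {'y': 'int64'}
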
emx_onnx_cgen/ir/op_base.py CHANGED
@@ -88,7 +88,10 @@ class ElementwiseOpBase(RenderableOpBase):
             raise UnsupportedOpError(
                 f"{self.kind} expects matching input dtypes, got {dtype_names}"
             )
-        output_dtype = ctx.dtype(self._elementwise_output())
+        try:
+            output_dtype = ctx.dtype(self._elementwise_output())
+        except ShapeInferenceError:
+            return None
         if self._elementwise_compare():
             if output_dtype != ScalarType.BOOL:
                 raise UnsupportedOpError(
@@ -107,7 +110,25 @@ class ElementwiseOpBase(RenderableOpBase):
         output_name = self._elementwise_output()
         for name in input_names:
             ctx.dtype(name)
-        ctx.dtype(output_name)
+        desired_dtype = (
+            ScalarType.BOOL if self._elementwise_compare() else None
+        )
+        if desired_dtype is None:
+            data_inputs = self._elementwise_data_inputs()
+            if data_inputs:
+                desired_dtype = ctx.dtype(data_inputs[0])
+        try:
+            output_dtype = ctx.dtype(output_name)
+        except ShapeInferenceError:
+            if desired_dtype is not None:
+                ctx.set_dtype(output_name, desired_dtype)
+                return None
+            raise
+        if desired_dtype is not None and output_dtype != desired_dtype:
+            raise UnsupportedOpError(
+                f"{self.kind} expects output dtype {desired_dtype.onnx_name}, "
+                f"got {output_dtype.onnx_name}"
+            )
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_names = self._elementwise_inputs()
@@ -121,6 +142,295 @@ class ElementwiseOpBase(RenderableOpBase):
         return None
 
 
+class GatherLikeOpBase(RenderableOpBase):
+    def _gather_data(self) -> str:
+        raise NotImplementedError
+
+    def _gather_indices(self) -> str:
+        raise NotImplementedError
+
+    def _gather_output(self) -> str:
+        raise NotImplementedError
+
+    def _gather_axis(self) -> int:
+        raise NotImplementedError
+
+    def _gather_mode(self) -> str:
+        raise NotImplementedError
+
+    def validate(self, ctx: OpContext) -> None:
+        indices_dtype = ctx.dtype(self._gather_indices())
+        if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                f"{self.kind} indices must be int32 or int64, "
+                f"got {indices_dtype.onnx_name}"
+            )
+        data_shape = ctx.shape(self._gather_data())
+        if self._gather_mode() in {"gather", "gather_elements"}:
+            if not data_shape:
+                raise ShapeInferenceError(
+                    f"{self.kind} does not support scalar inputs"
+                )
+        axis = self._gather_axis()
+        if axis < 0:
+            axis += len(data_shape)
+        if axis < 0 or axis >= len(data_shape):
+            raise ShapeInferenceError(
+                f"{self.kind} axis {axis} is out of range for rank "
+                f"{len(data_shape)}"
+            )
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        data_dtype = ctx.dtype(self._gather_data())
+        try:
+            output_dtype = ctx.dtype(self._gather_output())
+        except ShapeInferenceError:
+            ctx.set_dtype(self._gather_output(), data_dtype)
+            output_dtype = data_dtype
+        if output_dtype != data_dtype:
+            raise UnsupportedOpError(
+                f"{self.kind} expects output dtype {data_dtype.onnx_name}, "
+                f"got {output_dtype.onnx_name}"
+            )
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        data_shape = ctx.shape(self._gather_data())
+        indices_shape = ctx.shape(self._gather_indices())
+        axis = self._gather_axis()
+        if axis < 0:
+            axis += len(data_shape)
+        if axis < 0 or axis >= len(data_shape):
+            raise ShapeInferenceError(
+                f"{self.kind} axis {axis} is out of range for rank "
+                f"{len(data_shape)}"
+            )
+        if self._gather_mode() == "gather":
+            output_shape = (
+                data_shape[:axis] + indices_shape + data_shape[axis + 1 :]
+            )
+        else:
+            raise UnsupportedOpError(
+                f"{self.kind} does not support gather mode "
+                f"{self._gather_mode()}"
+            )
+        try:
+            expected = ctx.shape(self._gather_output())
+        except ShapeInferenceError:
+            expected = None
+        if expected is not None and expected != output_shape:
+            raise ShapeInferenceError(
+                f"{self.kind} output shape must be {output_shape}, got {expected}"
+            )
+        ctx.set_shape(self._gather_output(), output_shape)
+        ctx.set_derived(self, "axis", axis)
+
+
+class ShapeLikeOpBase(RenderableOpBase):
+    def _shape_data(self) -> str:
+        raise NotImplementedError
+
+    def _shape_output(self) -> str:
+        raise NotImplementedError
+
+    def _shape_spec(self, ctx: OpContext) -> tuple[int, ...]:
+        raise NotImplementedError
+
+    def _shape_mode(self) -> str:
+        raise NotImplementedError
+
+    def _shape_derived(
+        self,
+        ctx: OpContext,
+        *,
+        data_shape: tuple[int, ...],
+        target_shape: tuple[int, ...],
+        output_shape: tuple[int, ...],
+    ) -> None:
+        return None
+
+    @staticmethod
+    def _validate_static_dims(shape: tuple[int, ...], kind: str) -> None:
+        if any(dim < 0 for dim in shape):
+            raise ShapeInferenceError(
+                f"{kind} does not support dynamic dims"
+            )
+
+    @staticmethod
+    def _broadcast_shape(
+        input_shape: tuple[int, ...],
+        target_shape: tuple[int, ...],
+        *,
+        kind: str,
+    ) -> tuple[int, ...]:
+        ShapeLikeOpBase._validate_static_dims(input_shape, kind)
+        ShapeLikeOpBase._validate_static_dims(target_shape, kind)
+        output_rank = max(len(input_shape), len(target_shape))
+        input_padded = (1,) * (output_rank - len(input_shape)) + input_shape
+        target_padded = (1,) * (output_rank - len(target_shape)) + target_shape
+        result: list[int] = []
+        for input_dim, target_dim in zip(input_padded, target_padded):
+            if input_dim == 1:
+                result.append(target_dim)
+            elif target_dim == 1:
+                result.append(input_dim)
+            elif input_dim == target_dim:
+                result.append(input_dim)
+            else:
+                raise ShapeInferenceError(
+                    f"{kind} input shape {input_shape} is not "
+                    f"broadcastable to {target_shape}"
+                )
+        return tuple(result)
+
+    def validate(self, ctx: OpContext) -> None:
+        data_shape = ctx.shape(self._shape_data())
+        target_shape = self._shape_spec(ctx)
+        if self._shape_mode() == "expand":
+            self._broadcast_shape(
+                data_shape, target_shape, kind=self.kind
+            )
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_dtype = ctx.dtype(self._shape_data())
+        try:
+            output_dtype = ctx.dtype(self._shape_output())
+        except ShapeInferenceError:
+            ctx.set_dtype(self._shape_output(), input_dtype)
+            output_dtype = input_dtype
+        if output_dtype != input_dtype:
+            raise UnsupportedOpError(
+                f"{self.kind} expects output dtype {input_dtype.onnx_name}, "
+                f"got {output_dtype.onnx_name}"
+            )
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        data_shape = ctx.shape(self._shape_data())
+        target_shape = self._shape_spec(ctx)
+        if self._shape_mode() == "expand":
+            output_shape = self._broadcast_shape(
+                data_shape, target_shape, kind=self.kind
+            )
+        else:
+            output_shape = target_shape
+        try:
+            expected = ctx.shape(self._shape_output())
+        except ShapeInferenceError:
+            expected = None
+        if expected is not None and expected != output_shape:
+            raise ShapeInferenceError(
+                f"{self.kind} output shape must be {output_shape}, got {expected}"
+            )
+        ctx.set_shape(self._shape_output(), output_shape)
+        self._shape_derived(
+            ctx,
+            data_shape=data_shape,
+            target_shape=target_shape,
+            output_shape=output_shape,
+        )
+
+
+class VariadicLikeOpBase(RenderableOpBase):
+    def _variadic_inputs(self) -> tuple[str, ...]:
+        raise NotImplementedError
+
+    def _variadic_output(self) -> str:
+        raise NotImplementedError
+
+    def _variadic_kind(self) -> str:
+        return self.kind
+
+    def _variadic_min_inputs(self) -> int:
+        return 2
+
+    def _variadic_max_inputs(self) -> int | None:
+        return None
+
+    def _variadic_compare(self) -> bool:
+        return False
+
+    def _variadic_supports_dtype(self, dtype: ScalarType) -> bool:
+        return True
+
+    def validate(self, ctx: OpContext) -> None:
+        inputs = self._variadic_inputs()
+        if any(not name for name in inputs):
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} input must be provided"
+            )
+        min_inputs = self._variadic_min_inputs()
+        max_inputs = self._variadic_max_inputs()
+        if len(inputs) < min_inputs:
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} must have at least {min_inputs} inputs"
+            )
+        if max_inputs is not None and len(inputs) != max_inputs:
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} must have exactly {max_inputs} inputs"
+            )
+        input_dtypes = tuple(ctx.dtype(name) for name in inputs)
+        if any(dtype != input_dtypes[0] for dtype in input_dtypes[1:]):
+            dtype_names = ", ".join(
+                dtype.onnx_name for dtype in input_dtypes
+            )
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} expects matching input dtypes, "
+                f"got {dtype_names}"
+            )
+        try:
+            output_dtype = ctx.dtype(self._variadic_output())
+        except ShapeInferenceError:
+            output_dtype = None
+        if output_dtype is not None:
+            if self._variadic_compare():
+                if output_dtype != ScalarType.BOOL:
+                    raise UnsupportedOpError(
+                        f"{self._variadic_kind()} expects bool output, "
+                        f"got {output_dtype.onnx_name}"
+                    )
+            elif output_dtype != input_dtypes[0]:
+                raise UnsupportedOpError(
+                    f"{self._variadic_kind()} expects output dtype "
+                    f"{input_dtypes[0].onnx_name}, got {output_dtype.onnx_name}"
+                )
+        if not self._variadic_supports_dtype(input_dtypes[0]):
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} does not support dtype "
+                f"{input_dtypes[0].onnx_name}"
+            )
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        for name in self._variadic_inputs():
+            ctx.dtype(name)
+        try:
+            ctx.dtype(self._variadic_output())
+        except ShapeInferenceError:
+            ctx.set_dtype(
+                self._variadic_output(),
+                ctx.dtype(self._variadic_inputs()[0]),
+            )
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        input_shapes = tuple(ctx.shape(name) for name in self._variadic_inputs())
+        output_shape = BroadcastingOpBase.broadcast_shapes(*input_shapes)
+        for shape in input_shapes:
+            if shape != output_shape:
+                raise UnsupportedOpError(
+                    f"{self._variadic_kind()} expects identical input/output shapes"
+                )
+        try:
+            expected = ctx.shape(self._variadic_output())
+        except ShapeInferenceError:
+            expected = None
+        if expected is not None and expected != output_shape:
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} expects identical input/output shapes"
+            )
+        ctx.set_shape(self._variadic_output(), output_shape)
+
+
 class ReduceOpBase(RenderableOpBase):
     @staticmethod
     def normalize_axes(
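
ShapeLikeOpBase._broadcast_shape implements the usual numpy-style rule: right-align the two shapes, pad the shorter one with leading 1s, and per axis a dimension of 1 yields to the other dimension while unequal non-1 dimensions raise. A standalone restatement of that rule for a quick check (not an import from the package):

    def broadcast(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
        rank = max(len(a), len(b))
        a = (1,) * (rank - len(a)) + a  # pad the shorter shape with leading 1s
        b = (1,) * (rank - len(b)) + b
        out = []
        for x, y in zip(a, b):
            if x == 1 or y == 1 or x == y:
                out.append(max(x, y))
            else:
                raise ValueError(f"{a} is not broadcastable to {b}")
        return tuple(out)

    print(broadcast((3, 1), (2, 1, 4)))  # (2, 3, 4)
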
emx_onnx_cgen/ir/ops/__init__.py CHANGED
@@ -1,4 +1,12 @@
-from .elementwise import BinaryOp, ClipOp, IdentityOp, MultiInputBinaryOp, UnaryOp, WhereOp
+from .elementwise import (
+    BinaryOp,
+    ClipOp,
+    IdentityOp,
+    MultiInputBinaryOp,
+    UnaryOp,
+    VariadicOp,
+    WhereOp,
+)
 from .misc import (
     CastOp,
     ConcatOp,
@@ -126,5 +134,6 @@ __all__ = [
     "TransposeOp",
     "TriluOp",
     "UnaryOp",
+    "VariadicOp",
     "WhereOp",
 ]
emx_onnx_cgen/ir/ops/elementwise.py CHANGED
@@ -5,8 +5,8 @@ from dataclasses import dataclass
 from shared.scalar_functions import ScalarFunction
 from shared.scalar_types import ScalarType
 
-from ...ops import COMPARE_FUNCTIONS, OperatorKind
-from ..op_base import ElementwiseOpBase
+from ...ops import COMPARE_FUNCTIONS, OperatorKind, binary_op_symbol
+from ..op_base import ElementwiseOpBase, VariadicLikeOpBase
 from ..op_context import OpContext
 
 
@@ -34,24 +34,45 @@ class BinaryOp(ElementwiseOpBase):
 
 
 @dataclass(frozen=True)
-class MultiInputBinaryOp(ElementwiseOpBase):
+class VariadicOp(VariadicLikeOpBase):
+    op_type: str
     inputs: tuple[str, ...]
     output: str
     function: ScalarFunction
     operator_kind: OperatorKind
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
+    min_inputs: int = 2
+    max_inputs: int | None = None
 
-    def _elementwise_inputs(self) -> tuple[str, ...]:
+    def _variadic_inputs(self) -> tuple[str, ...]:
         return self.inputs
 
-    def _elementwise_output(self) -> str:
+    def _variadic_output(self) -> str:
         return self.output
 
-    def _elementwise_compare(self) -> bool:
+    def _variadic_kind(self) -> str:
+        return self.op_type
+
+    def _variadic_compare(self) -> bool:
         return self.function in COMPARE_FUNCTIONS
 
+    def _variadic_min_inputs(self) -> int:
+        return self.min_inputs
+
+    def _variadic_max_inputs(self) -> int | None:
+        return self.max_inputs
+
+    def _variadic_supports_dtype(self, dtype: ScalarType) -> bool:
+        return (
+            binary_op_symbol(
+                self.function, dtype=dtype, validate_attrs=False
+            )
+            is not None
+        )
+
+
+class MultiInputBinaryOp(VariadicOp):
+    pass
+
 
 
 @dataclass(frozen=True)
 class WhereOp(ElementwiseOpBase):
emx_onnx_cgen/ir/ops/misc.py CHANGED
@@ -2,13 +2,29 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
+import numpy as np
+
 from shared.scalar_types import ScalarType
 
-from ...errors import ShapeInferenceError
-from ..op_base import BroadcastingOpBase, RenderableOpBase
+from ...errors import ShapeInferenceError, UnsupportedOpError
+from ..op_base import (
+    BroadcastingOpBase,
+    GatherLikeOpBase,
+    RenderableOpBase,
+    ShapeLikeOpBase,
+)
 from ..op_context import OpContext
 
 
+def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
+    strides: list[int] = []
+    stride = 1
+    for dim in reversed(shape):
+        strides.append(stride)
+        stride *= dim
+    return tuple(reversed(strides))
+
+
 @dataclass(frozen=True)
 class CastOp(RenderableOpBase):
     input0: str
@@ -59,16 +75,26 @@ class GatherElementsOp(RenderableOpBase):
     indices_dtype: ScalarType
 
 @dataclass(frozen=True)
-class GatherOp(RenderableOpBase):
+class GatherOp(GatherLikeOpBase):
     data: str
     indices: str
     output: str
     axis: int
-    data_shape: tuple[int, ...]
-    indices_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
-    indices_dtype: ScalarType
+
+    def _gather_data(self) -> str:
+        return self.data
+
+    def _gather_indices(self) -> str:
+        return self.indices
+
+    def _gather_output(self) -> str:
+        return self.output
+
+    def _gather_axis(self) -> int:
+        return self.axis
+
+    def _gather_mode(self) -> str:
+        return "gather"
 
 @dataclass(frozen=True)
 class GatherNDOp(RenderableOpBase):
@@ -360,15 +386,72 @@ class NonMaxSuppressionOp(RenderableOpBase):
     score_threshold_shape: tuple[int, ...] | None
 
 @dataclass(frozen=True)
-class ExpandOp(BroadcastingOpBase):
+class ExpandOp(ShapeLikeOpBase):
     input0: str
+    input_shape: str
     output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    input_shape_padded: tuple[int, ...]
-    input_strides: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
+
+    def _shape_data(self) -> str:
+        return self.input0
+
+    def _shape_output(self) -> str:
+        return self.output
+
+    def _shape_mode(self) -> str:
+        return "expand"
+
+    def _shape_spec(self, ctx: OpContext) -> tuple[int, ...]:
+        initializer = ctx.initializer(self.input_shape)
+        if initializer is not None:
+            if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    f"{self.kind} shape input must be int64 or int32"
+                )
+            if len(initializer.type.shape) != 1:
+                raise UnsupportedOpError(
+                    f"{self.kind} shape input must be a 1D tensor"
+                )
+            values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+            if values.size == 0:
+                raise ShapeInferenceError(
+                    f"{self.kind} shape input cannot be empty"
+                )
+            return tuple(int(value) for value in values)
+        dtype = ctx.dtype(self.input_shape)
+        if dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                f"{self.kind} shape input must be int64 or int32"
+            )
+        shape = ctx.shape(self.input_shape)
+        if len(shape) != 1:
+            raise UnsupportedOpError(
+                f"{self.kind} shape input must be a 1D tensor"
+            )
+        if shape[0] <= 0:
+            raise ShapeInferenceError(
+                f"{self.kind} shape input cannot be empty"
+            )
+        output_shape = ctx.shape(self.output)
+        if not output_shape:
+            raise ShapeInferenceError(
+                f"{self.kind} output shape must be specified"
+            )
+        return output_shape
+
+    def _shape_derived(
+        self,
+        ctx: OpContext,
+        *,
+        data_shape: tuple[int, ...],
+        target_shape: tuple[int, ...],
+        output_shape: tuple[int, ...],
+    ) -> None:
+        input_shape_padded = (
+            (1,) * (len(output_shape) - len(data_shape)) + data_shape
+        )
+        input_strides = _compute_strides(input_shape_padded)
+        ctx.set_derived(self, "input_shape_padded", input_shape_padded)
+        ctx.set_derived(self, "input_strides", input_strides)
 
 @dataclass(frozen=True)
 class CumSumOp(RenderableOpBase):
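
The _compute_strides helper that ExpandOp._shape_derived now records as a derived value produces row-major strides: the innermost axis has stride 1 and each outer axis's stride is the product of all inner dimensions. Restated for a quick sanity check:

    def compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
        strides: list[int] = []
        stride = 1
        for dim in reversed(shape):  # walk from the innermost axis outward
            strides.append(stride)
            stride *= dim
        return tuple(reversed(strides))

    print(compute_strides((2, 3, 4)))  # (12, 4, 1)
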
emx_onnx_cgen/lowering/expand.py CHANGED
@@ -1,151 +1,17 @@
 from __future__ import annotations
 
-import numpy as np
-
-from shared.scalar_types import ScalarType
-
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
 from ..ir.ops import ExpandOp
-from ..errors import ShapeInferenceError, UnsupportedOpError
-from ..ir.model import Graph, Initializer, Node
-from ..lowering.common import value_dtype, value_shape
 from .registry import register_lowering
 
 
-def _find_initializer(graph: Graph, name: str) -> Initializer | None:
-    for initializer in graph.initializers:
-        if initializer.name == name:
-            return initializer
-    return None
-
-
-def _read_shape_values(graph: Graph, name: str, node: Node) -> list[int] | None:
-    initializer = _find_initializer(graph, name)
-    if initializer is None:
-        return None
-    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
-        raise UnsupportedOpError(
-            f"{node.op_type} shape input must be int64 or int32"
-        )
-    if len(initializer.type.shape) != 1:
-        raise UnsupportedOpError(
-            f"{node.op_type} shape input must be a 1D tensor"
-        )
-    values = np.array(initializer.data, dtype=np.int64).reshape(-1)
-    if values.size == 0:
-        raise ShapeInferenceError(
-            f"{node.op_type} shape input cannot be empty"
-        )
-    return [int(value) for value in values]
-
-
-def _validate_shape_input(graph: Graph, name: str, node: Node) -> None:
-    dtype = value_dtype(graph, name, node)
-    if dtype not in {ScalarType.I64, ScalarType.I32}:
-        raise UnsupportedOpError(
-            f"{node.op_type} shape input must be int64 or int32"
-        )
-    shape = value_shape(graph, name, node)
-    if len(shape) != 1:
-        raise UnsupportedOpError(
-            f"{node.op_type} shape input must be a 1D tensor"
-        )
-    if shape[0] <= 0:
-        raise ShapeInferenceError(
-            f"{node.op_type} shape input cannot be empty"
-        )
-
-
-def _validate_static_dims(shape: tuple[int, ...], node: Node) -> None:
-    if any(dim < 0 for dim in shape):
-        raise ShapeInferenceError(
-            f"{node.op_type} does not support dynamic dims"
-        )
-
-
-def _broadcast_shape(
-    input_shape: tuple[int, ...], shape_values: list[int], node: Node
-) -> tuple[int, ...]:
-    _validate_static_dims(input_shape, node)
-    for dim in shape_values:
-        if dim < 0:
-            raise ShapeInferenceError(
-                f"{node.op_type} does not support dynamic dims"
-            )
-    output_rank = max(len(input_shape), len(shape_values))
-    input_padded = (1,) * (output_rank - len(input_shape)) + input_shape
-    shape_padded = (1,) * (output_rank - len(shape_values)) + tuple(shape_values)
-    result: list[int] = []
-    for input_dim, shape_dim in zip(input_padded, shape_padded):
-        if input_dim == 1:
-            result.append(shape_dim)
-        elif shape_dim == 1:
-            result.append(input_dim)
-        elif input_dim == shape_dim:
-            result.append(input_dim)
-        else:
-            raise ShapeInferenceError(
-                f"{node.op_type} input shape {input_shape} is not "
-                f"broadcastable to {shape_values}"
-            )
-    return tuple(result)
-
-
-def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
-    strides: list[int] = []
-    stride = 1
-    for dim in reversed(shape):
-        strides.append(stride)
-        stride *= dim
-    return tuple(reversed(strides))
-
-
 @register_lowering("Expand")
 def lower_expand(graph: Graph, node: Node) -> ExpandOp:
     if len(node.inputs) != 2 or len(node.outputs) != 1:
         raise UnsupportedOpError("Expand must have 2 inputs and 1 output")
-    input_shape = value_shape(graph, node.inputs[0], node)
-    output_shape = value_shape(graph, node.outputs[0], node)
-    input_dtype = value_dtype(graph, node.inputs[0], node)
-    output_dtype = value_dtype(graph, node.outputs[0], node)
-    if input_dtype != output_dtype:
-        raise UnsupportedOpError(
-            f"{node.op_type} expects matching input/output dtypes, "
-            f"got {input_dtype} and {output_dtype}"
-        )
-    shape_values = _read_shape_values(graph, node.inputs[1], node)
-    if shape_values is not None:
-        expected_output_shape = _broadcast_shape(input_shape, shape_values, node)
-        _validate_static_dims(expected_output_shape, node)
-        if output_shape and output_shape != expected_output_shape:
-            raise ShapeInferenceError(
-                f"{node.op_type} output shape must be {expected_output_shape}, "
-                f"got {output_shape}"
-            )
-    else:
-        _validate_shape_input(graph, node.inputs[1], node)
-        if not output_shape:
-            raise ShapeInferenceError(
-                f"{node.op_type} output shape must be specified"
-            )
-        expected_output_shape = _broadcast_shape(
-            input_shape, list(output_shape), node
-        )
-        if expected_output_shape != output_shape:
-            raise ShapeInferenceError(
-                f"{node.op_type} output shape must be {expected_output_shape}, "
-                f"got {output_shape}"
-            )
-    input_shape_padded = (
-        (1,) * (len(expected_output_shape) - len(input_shape)) + input_shape
-    )
-    input_strides = _compute_strides(input_shape_padded)
     return ExpandOp(
         input0=node.inputs[0],
+        input_shape=node.inputs[1],
         output=node.outputs[0],
-        input_shape=input_shape,
-        output_shape=expected_output_shape,
-        input_shape_padded=input_shape_padded,
-        input_strides=input_strides,
-        dtype=input_dtype,
-        input_dtype=input_dtype,
     )
emx_onnx_cgen/lowering/gather.py CHANGED
@@ -1,13 +1,8 @@
 from __future__ import annotations
 
-from shared.scalar_types import ScalarType
-
 from ..ir.ops import GatherOp
-from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from ..validation import normalize_axis
-from .common import value_dtype as _value_dtype
-from .common import value_shape as _value_shape
 from .registry import register_lowering
 
 
@@ -16,33 +11,9 @@ def lower_gather(graph: Graph, node: Node) -> GatherOp:
     if len(node.inputs) != 2 or len(node.outputs) != 1:
         raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
     data_name, indices_name = node.inputs
-    data_shape = _value_shape(graph, data_name, node)
-    indices_shape = _value_shape(graph, indices_name, node)
-    output_shape = _value_shape(graph, node.outputs[0], node)
-    axis = normalize_axis(int(node.attrs.get("axis", 0)), data_shape, node)
-    expected_output_shape = (
-        data_shape[:axis] + indices_shape + data_shape[axis + 1 :]
-    )
-    if output_shape != expected_output_shape:
-        raise ShapeInferenceError(
-            "Gather output shape must be "
-            f"{expected_output_shape}, got {output_shape}"
-        )
-    op_dtype = _value_dtype(graph, data_name, node)
-    indices_dtype = _value_dtype(graph, indices_name, node)
-    if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
-        raise UnsupportedOpError(
-            "Gather indices must be int32 or int64, "
-            f"got {indices_dtype.onnx_name}"
-        )
     return GatherOp(
         data=data_name,
         indices=indices_name,
         output=node.outputs[0],
-        axis=axis,
-        data_shape=data_shape,
-        indices_shape=indices_shape,
-        output_shape=output_shape,
-        dtype=op_dtype,
-        indices_dtype=indices_dtype,
+        axis=int(node.attrs.get("axis", 0)),
     )
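
With the shape check moved into GatherLikeOpBase.infer_shapes, the lowering no longer recomputes the Gather output shape; the base class derives it as data_shape[:axis] + indices_shape + data_shape[axis + 1:]. A worked check of that formula:

    data_shape = (5, 4, 3)
    indices_shape = (2, 2)
    axis = 1
    output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]
    print(output_shape)  # (5, 2, 2, 3)
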
emx_onnx_cgen/lowering/variadic.py CHANGED
@@ -1,14 +1,12 @@
 from __future__ import annotations
 
 from shared.scalar_functions import ScalarFunction
-from shared.scalar_types import ScalarType
 
-from ..ir.ops import MultiInputBinaryOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from ..lowering.common import node_dtype, value_dtype, value_shape
+from ..ir.ops import MultiInputBinaryOp
 from ..lowering.registry import register_lowering
-from ..ops import binary_op_symbol
+from ..ops import OperatorKind
 
 VARIADIC_OP_FUNCTIONS: dict[str, ScalarFunction] = {
     "Sum": ScalarFunction.ADD,
@@ -32,62 +30,31 @@ BINARY_ONLY_OPS = {
     "BitwiseXor",
 }
 
-
-def _validate_inputs(
-    graph: Graph, node: Node, *, function: ScalarFunction
-) -> tuple[ScalarType, tuple[int, ...]]:
-    if len(node.outputs) != 1:
-        raise UnsupportedOpError(f"{node.op_type} must have 1 output")
-    if node.op_type in BINARY_ONLY_OPS:
-        if len(node.inputs) != 2:
-            raise UnsupportedOpError(
-                f"{node.op_type} must have exactly 2 inputs"
-            )
-    elif len(node.inputs) < 2:
-        raise UnsupportedOpError(
-            f"{node.op_type} must have at least 2 inputs"
-        )
-    for name in node.inputs:
-        if not name:
-            raise UnsupportedOpError(f"{node.op_type} input must be provided")
-    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
-    output_dtype = value_dtype(graph, node.outputs[0], node)
-    if op_dtype != output_dtype:
-        raise UnsupportedOpError(
-            f"{node.op_type} expects matching input/output dtypes, "
-            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
-        )
-    output_shape = value_shape(graph, node.outputs[0], node)
-    for name in node.inputs:
-        input_shape = value_shape(graph, name, node)
-        if input_shape != output_shape:
-            raise UnsupportedOpError(
-                f"{node.op_type} expects identical input/output shapes"
-            )
-    op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
-    if op_spec is None:
-        raise UnsupportedOpError(
-            f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
-        )
-    return op_dtype, output_shape
+VARIADIC_OP_OPERATOR_KINDS: dict[str, OperatorKind] = {
+    "Sum": OperatorKind.INFIX,
+    "Mean": OperatorKind.EXPR,
+    "Max": OperatorKind.FUNC,
+    "Min": OperatorKind.FUNC,
+    "And": OperatorKind.INFIX,
+    "Or": OperatorKind.INFIX,
+    "Xor": OperatorKind.INFIX,
+    "BitwiseAnd": OperatorKind.INFIX,
+    "BitwiseOr": OperatorKind.INFIX,
+    "BitwiseXor": OperatorKind.INFIX,
+}
 
 
 def _lower_variadic(graph: Graph, node: Node) -> MultiInputBinaryOp:
-    function = VARIADIC_OP_FUNCTIONS[node.op_type]
-    op_dtype, output_shape = _validate_inputs(graph, node, function=function)
-    op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
-    if op_spec is None:
-        raise UnsupportedOpError(
-            f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
-        )
+    if len(node.outputs) != 1:
+        raise UnsupportedOpError(f"{node.op_type} must have 1 output")
     return MultiInputBinaryOp(
+        op_type=node.op_type,
         inputs=tuple(node.inputs),
        output=node.outputs[0],
-        function=function,
-        operator_kind=op_spec.kind,
-        shape=output_shape,
-        dtype=op_dtype,
-        input_dtype=op_dtype,
+        function=VARIADIC_OP_FUNCTIONS[node.op_type],
+        operator_kind=VARIADIC_OP_OPERATOR_KINDS[node.op_type],
+        min_inputs=2,
+        max_inputs=2 if node.op_type in BINARY_ONLY_OPS else None,
     )
 
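
_lower_variadic now defers dtype and shape validation to VariadicLikeOpBase and only records arity constraints: binary-only ops get max_inputs=2 while true variadics such as Sum stay unbounded. A toy illustration of the arity rule, using stand-in names rather than the package's types:

    BINARY_ONLY = {"And", "Or", "Xor", "BitwiseAnd", "BitwiseOr", "BitwiseXor"}

    def arity_bounds(op_type: str) -> tuple[int, int | None]:
        # Mirrors the min_inputs/max_inputs values set in the hunk above:
        # always at least two inputs; binary-only ops allow exactly two.
        return 2, (2 if op_type in BINARY_ONLY else None)

    print(arity_bounds("Sum"))  # (2, None)
    print(arity_bounds("Xor"))  # (2, 2)
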
emx_onnx_cgen/runtime/evaluator.py CHANGED
@@ -7,7 +7,9 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.op_context import OpContext
 from ..lowering.attention import resolve_attention_spec
 from ..lowering.average_pool import lower_average_pool, lower_global_average_pool
 from ..lowering.adagrad import lower_adagrad
@@ -2021,8 +2023,13 @@ def _eval_nonzero(evaluator: Evaluator, node: Node) -> None:
 def _eval_expand(evaluator: Evaluator, node: Node) -> None:
     op = lower_expand(evaluator.graph, node)
     value = evaluator.values[op.input0]
+    op_ctx = OpContext(GraphContext(evaluator.graph))
+    op.validate(op_ctx)
+    op.infer_types(op_ctx)
+    op.infer_shapes(op_ctx)
+    output_shape = op_ctx.shape(op.output)
     evaluator.values[op.output] = np.broadcast_to(
-        value, op.output_shape
+        value, output_shape
     ).copy()
 
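
The evaluator now lets ExpandOp infer its own output shape through a fresh OpContext and then materializes the result with numpy.broadcast_to, which matches ONNX Expand semantics. A minimal standalone equivalent, assuming a statically known target shape:

    import numpy as np

    value = np.arange(3).reshape(3, 1)   # input of shape (3, 1)
    output_shape = (2, 3, 4)             # what shape inference would produce
    expanded = np.broadcast_to(value, output_shape).copy()  # .copy() detaches the read-only view
    print(expanded.shape)  # (2, 3, 4)
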
emx_onnx_cgen-{0.3.3 → 0.3.5}.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: emx-onnx-cgen
-Version: 0.3.3
+Version: 0.3.5
 Summary: emmtrix ONNX-to-C Code Generator
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
emx_onnx_cgen-{0.3.3 → 0.3.5}.dist-info/RECORD RENAMED
@@ -1,9 +1,9 @@
 emx_onnx_cgen/__init__.py,sha256=jUSbu1kJ0krzVTYEcph3jCprBhD7tWNtiSdL6r29KrM,221
 emx_onnx_cgen/__main__.py,sha256=iC1lLVtR6-TmpL6OxXcy3oIntExUtajn9-q627R1XyI,140
-emx_onnx_cgen/_build_info.py,sha256=W-LMY90LuY6_m3BNDTItdi_pO01vSSb1j-UK-NFUge0,112
-emx_onnx_cgen/_version.py,sha256=lemL_4Kl75FgrO6lVuFrrtw6-Dcf9wtXBalKkXuzkO4,704
+emx_onnx_cgen/_build_info.py,sha256=H1wVqVfVhPbohtf1JhkEW7F4rM6RzR1WlU9jrNMVYlE,112
+emx_onnx_cgen/_version.py,sha256=UAb2Toi6SAdScDfq1uKRRv5QpMUuRtJqqwNxTMGe5Q4,704
 emx_onnx_cgen/cli.py,sha256=7Y9JW-t1PLg25zOizuqyMqwsXbbG9ok99DsYeFSiOFQ,21685
-emx_onnx_cgen/compiler.py,sha256=qXKUQedaQY6A2jX-twte4qVA263T3UtCDlPjvoM5vYU,16513
+emx_onnx_cgen/compiler.py,sha256=v1-EzVUxZv5Kfn81kCDVuferRxvXFXEeRaNbQ4w6xss,18437
 emx_onnx_cgen/dtypes.py,sha256=jRx3BBvk0qFW14bngoL1B7L_IRasyNJ4jqhpM5YhcOM,1335
 emx_onnx_cgen/errors.py,sha256=HpOv95mTgr9ZX2gYe1RtwVMbPskh7zkqjU_FgAD-uIM,363
 emx_onnx_cgen/onnx_import.py,sha256=IF7KZGfEP9H4H1fHYjobGbB_381fqD_67KtqZYs9AZ4,9168
@@ -13,16 +13,16 @@ emx_onnx_cgen/testbench.py,sha256=-NbqD1aC7OXvFMLiLzd2IPObenQdHFH85cNxNSB1GeY,64
 emx_onnx_cgen/validation.py,sha256=KFdUdGjQbzTj1szCJcjxnTi8f5l6ywNgCB9abbBpTbM,2360
 emx_onnx_cgen/verification.py,sha256=IrhIMm29R2vEkW1Q8gtoQtscMGxfJRavNRSMJHBAJ5g,1041
 emx_onnx_cgen/codegen/__init__.py,sha256=H_kBdc_w_W-3qdUZJHwKBDns1AeP_Un3-46LW20yLV0,406
-emx_onnx_cgen/codegen/c_emitter.py,sha256=dS-vjjuWT0GHETbV3ipoYedvuvcJB0yGwMZgoQuJe-g,452931
+emx_onnx_cgen/codegen/c_emitter.py,sha256=JdDJGv1HptINaLLZxlxNfo1R7VM9v680EyiDMpeReds,453199
 emx_onnx_cgen/codegen/emitter.py,sha256=udcsqJNr46TFHiyVv5I4wdVH8ll6Bi4VqcR1VvofbnY,92
 emx_onnx_cgen/ir/__init__.py,sha256=fD2D8qxlGoCFJb0m9v6u3XTgzSxDOhB4cfLBiCLovzg,102
 emx_onnx_cgen/ir/context.py,sha256=cM3V6G3zs6VCsABP6TnZ8vvQ7VGwOF1iKtb1hq0WO3g,3356
 emx_onnx_cgen/ir/model.py,sha256=SZ3K8t4dKUqWuXWe5ozApofXx4bdcf4p0WYCdeU-mFA,1265
-emx_onnx_cgen/ir/op_base.py,sha256=mHvp0VD55JIrwQI2MFEmSILi22kuurBX085aamcjQ0g,6160
+emx_onnx_cgen/ir/op_base.py,sha256=_iPeVkLPR3jsRASrvXEWk-k3BJboPHtZY6jnB0HdLvk,17611
 emx_onnx_cgen/ir/op_context.py,sha256=9CZCUNJLsV4cJsYmJqWbaDrwQd4sr-9Ot1PmPSqGAto,2103
-emx_onnx_cgen/ir/ops/__init__.py,sha256=IcllGXB4T3TCrpBq9cy3jR_edS_IJ_qXac37K_rIZcA,2440
-emx_onnx_cgen/ir/ops/elementwise.py,sha256=sZ1S6X_fagNDevN6dXHBy75g_z-WP_dHFAVmPGnmeaU,3721
-emx_onnx_cgen/ir/ops/misc.py,sha256=1ekAgV5j6Stc1Yw8e-0EPD5t8mI1YJxmyIkAn9Zr4h8,10920
+emx_onnx_cgen/ir/ops/__init__.py,sha256=Zk7QzNiB4CHcixZlA1thA78mcudXdTvCfKlxUTRrX24,2503
+emx_onnx_cgen/ir/ops/elementwise.py,sha256=TXbyayj3UnfLe4tUYBEwBDr7ZFyFi1i8HdVdCjtvLCc,4241
+emx_onnx_cgen/ir/ops/misc.py,sha256=vN4OpW5gsryQ0aiVNBFiYlZMxwg8Z9wUOBM7w3f4ZFE,13522
 emx_onnx_cgen/ir/ops/nn.py,sha256=-4ZqDkcu7zgci3YVfMzCDzokqpZHgOYZaq_C1GclBZQ,14365
 emx_onnx_cgen/ir/ops/reduce.py,sha256=-aA4bwOMppd9pnWQwhl6hOxryh0G2xRaHqeNwQ97AdY,2756
 emx_onnx_cgen/lowering/__init__.py,sha256=AxnUfmpf5Teos1ms3zE6r0EBxxPYznGSOICDEFWH_pk,1535
@@ -42,10 +42,10 @@ emx_onnx_cgen/lowering/depth_space.py,sha256=i7INioNkofBxFlZW9y0W_qA6mp67_FAXouh
 emx_onnx_cgen/lowering/dropout.py,sha256=MZ4YrB-jvUFXpIKE5kOLyrEF5uy5dh0yjJH6Rj8KlMs,1764
 emx_onnx_cgen/lowering/einsum.py,sha256=MWAgWVOzP38RSOxJABwvYU6ykD9odmhrmddXinmFs7s,6117
 emx_onnx_cgen/lowering/elementwise.py,sha256=q9X3qTll7gLp39NTTdzuLs9RBsONssw50l1hWo8wby0,12229
-emx_onnx_cgen/lowering/expand.py,sha256=GmYJZWXXcBV42hMGUgbKKbLjeCxpbcMSoG9OU1ZkFFY,5518
+emx_onnx_cgen/lowering/expand.py,sha256=y0h1x2xh6Oqtblm6TbELB6_I4fsquU3YuZoB4mZJeTo,525
 emx_onnx_cgen/lowering/eye_like.py,sha256=QBiHWYZbgK4uiUYWuS7WHCMBGMSG0paNZM84OYmGb7c,1723
 emx_onnx_cgen/lowering/flatten.py,sha256=6h-TQNy9iq5hfXR9h2clUrc2eHmZP9gAb9KbCSJdV20,2131
-emx_onnx_cgen/lowering/gather.py,sha256=PCER36AjmpxzAM4wuL7En3XR1RKZCdSzjxualDCUHAI,1803
+emx_onnx_cgen/lowering/gather.py,sha256=3sxrld5GIS4OO3hRVp8QdbMtyLQUHbdCXL8vmZvh67c,599
 emx_onnx_cgen/lowering/gather_elements.py,sha256=cCp2UFOjktgEfS9s9npMS_BXklBkpMpD7UhIIMhQ-_Y,2318
 emx_onnx_cgen/lowering/gather_nd.py,sha256=rmr_ijeSeCrZ_R_QPwdoHPQUCe8nE0YRSv2NjUiiFjY,3090
 emx_onnx_cgen/lowering/gemm.py,sha256=qBaZ-6FZAAMEaZ4uifo58tJI8SoBsJvkZTCg7jvq288,4579
@@ -92,16 +92,16 @@ emx_onnx_cgen/lowering/topk.py,sha256=Dqx7qMr4HbXhVGN-wJf_D4dPTvYMVT6S82A2M3f9Dw
 emx_onnx_cgen/lowering/transpose.py,sha256=oNFRjkH63KqnO2Q4oJengEAUEYC1M3PW12AauWwebzI,1751
 emx_onnx_cgen/lowering/trilu.py,sha256=OjJjyo2ZRcfo9UGH8Zfq4o0PR6YDeoHSj8DzMu0w318,3266
 emx_onnx_cgen/lowering/unsqueeze.py,sha256=9y-OM-oY6ln1-R6duRRemeRrwBIpX2TZs_nRtlYQMYE,5985
-emx_onnx_cgen/lowering/variadic.py,sha256=etIWA7jVqWrWH3NkNvpF5opVGgvb0ZS4iLo4L3euWDs,3287
+emx_onnx_cgen/lowering/variadic.py,sha256=OrC3rwM3-SNewYRs7YA7DwwS8XW1ucxUobTEjZdEs4s,1823
 emx_onnx_cgen/lowering/where.py,sha256=K2RUDvLg0uTvi6Z_uTOXM5jgc3PXRj0cTZ4u58GEGko,2644
 emx_onnx_cgen/runtime/__init__.py,sha256=88xGpAs1IEBlzlWL_e9tnKUlaSRdc7pQUeVCu5LC4DY,50
-emx_onnx_cgen/runtime/evaluator.py,sha256=yqsBpAIlBky-rby7J5z7i1SvDaK6PjObxH-wQSdZ2G0,114732
+emx_onnx_cgen/runtime/evaluator.py,sha256=8d9GOzhYNs2XX5q4vjaTM-wxkf8_rE4QEf5e1USWGd8,114981
 shared/__init__.py,sha256=bmP79AVZdY_1aNULJap9pm76Q41Rabrza6X-0A8lDzw,45
 shared/scalar_functions.py,sha256=CErro1Du2Ri3uqX6Dgd18DzNbxduckAvsmLJ6oHGx9A,91123
 shared/scalar_types.py,sha256=kEpsl5T-NVFxCcTzXqPJbtpvDiCgKHfz91dphLLZxZA,4912
 shared/ulp.py,sha256=DpeovCFijmP8_M7zyTZWsNyfOtJ1AjNSdxf5jGsdfJo,1856
-emx_onnx_cgen-0.3.3.dist-info/METADATA,sha256=0joVoW9rki1TEClARZ0z235N-2ruM8R3mZ00zXGmR34,6266
-emx_onnx_cgen-0.3.3.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
-emx_onnx_cgen-0.3.3.dist-info/entry_points.txt,sha256=b7Rvmz_Bi9kWyn7QayQC_FEXiRpt4cS1RnluKh49yoo,57
-emx_onnx_cgen-0.3.3.dist-info/top_level.txt,sha256=g39fo-blEbgiVcC_GRqAnBzN234w3LXbcVdLUoItSLk,21
-emx_onnx_cgen-0.3.3.dist-info/RECORD,,
+emx_onnx_cgen-0.3.5.dist-info/METADATA,sha256=XwhvHTOcBPst7LPvgjPnR9hnVV8Jj0RtHtMITPMpAsA,6266
+emx_onnx_cgen-0.3.5.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
+emx_onnx_cgen-0.3.5.dist-info/entry_points.txt,sha256=b7Rvmz_Bi9kWyn7QayQC_FEXiRpt4cS1RnluKh49yoo,57
+emx_onnx_cgen-0.3.5.dist-info/top_level.txt,sha256=g39fo-blEbgiVcC_GRqAnBzN234w3LXbcVdLUoItSLk,21
+emx_onnx_cgen-0.3.5.dist-info/RECORD,,