emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,81 @@
+ from __future__ import annotations
+
+ import importlib
+
  from .registry import get_lowering, register_lowering

- __all__ = ["get_lowering", "register_lowering"]
+ _LOWERING_MODULES = [
+     "adagrad",
+     "arg_reduce",
+     "attention",
+     "average_pool",
+     "batch_normalization",
+     "cast",
+     "concat",
+     "constant_of_shape",
+     "conv",
+     "conv_transpose",
+     "cumsum",
+     "depth_space",
+     "dropout",
+     "einsum",
+     "elementwise",
+     "expand",
+     "eye_like",
+     "flatten",
+     "gather",
+     "gather_elements",
+     "gather_nd",
+     "gemm",
+     "global_max_pool",
+     "grid_sample",
+     "group_normalization",
+     "hardmax",
+     "identity",
+     "instance_normalization",
+     "layer_normalization",
+     "logsoftmax",
+     "lp_normalization",
+     "lp_pool",
+     "lrn",
+     "lstm",
+     "matmul",
+     "maxpool",
+     "mean_variance_normalization",
+     "negative_log_likelihood_loss",
+     "non_max_suppression",
+     "nonzero",
+     "one_hot",
+     "pad",
+     "qlinear_matmul",
+     "quantize_linear",
+     "range",
+     "reduce",
+     "reshape",
+     "resize",
+     "rms_normalization",
+     "rotary_embedding",
+     "scatter_nd",
+     "shape",
+     "size",
+     "slice",
+     "softmax",
+     "softmax_cross_entropy_loss",
+     "split",
+     "squeeze",
+     "tensor_scatter",
+     "tile",
+     "topk",
+     "transpose",
+     "trilu",
+     "unsqueeze",
+     "variadic",
+     "where",
+ ]
+
+
+ def load_lowering_registry() -> None:
+     for module_name in _LOWERING_MODULES:
+         importlib.import_module(f"{__name__}.{module_name}")
+
+ __all__ = ["get_lowering", "register_lowering", "load_lowering_registry"]
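For orientation: load_lowering_registry simply imports every module listed in _LOWERING_MODULES, and each module registers its handlers as a side effect of being imported, via the register_lowering decorator. Below is a minimal sketch of that register-on-import pattern; the registry internals and the lower_relu handler are illustrative assumptions, and only the register_lowering / get_lowering / load_lowering_registry names mirror the API visible in this diff.

    # Sketch of a register-on-import lowering registry (hypothetical internals).
    from typing import Callable, Dict

    _REGISTRY: Dict[str, Callable] = {}

    def register_lowering(op_type: str):
        def decorator(fn: Callable) -> Callable:
            _REGISTRY[op_type] = fn          # registration happens at import time
            return fn
        return decorator

    def get_lowering(op_type: str) -> Callable:
        return _REGISTRY[op_type]

    @register_lowering("Relu")               # hypothetical handler, not from the package
    def lower_relu(graph, node):
        return ("UnaryOp", node)

    assert get_lowering("Relu") is lower_relu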
@@ -0,0 +1,114 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..ir.ops import AdagradOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+     return shape == () or shape == (1,)
+
+
+ @register_lowering("Adagrad")
+ def lower_adagrad(graph: Graph, node: Node) -> AdagradOp:
+     if len(node.inputs) < 5:
+         raise UnsupportedOpError("Adagrad must have at least 5 inputs")
+     if len(node.outputs) < 2:
+         raise UnsupportedOpError("Adagrad must have at least 2 outputs")
+     if (len(node.inputs) - 2) % 3 != 0:
+         raise UnsupportedOpError(
+             "Adagrad inputs must be R, T, Xs, Gs, Hs with matching counts"
+         )
+     tensor_count = (len(node.inputs) - 2) // 3
+     if len(node.outputs) != tensor_count * 2:
+         raise UnsupportedOpError(
+             "Adagrad outputs must be X_news followed by H_news"
+         )
+     rate_name = node.inputs[0]
+     timestep_name = node.inputs[1]
+     rate_shape = value_shape(graph, rate_name, node)
+     timestep_shape = value_shape(graph, timestep_name, node)
+     if not _is_scalar_shape(rate_shape):
+         raise UnsupportedOpError("Adagrad R input must be a scalar")
+     if not _is_scalar_shape(timestep_shape):
+         raise UnsupportedOpError("Adagrad T input must be a scalar")
+     rate_dtype = value_dtype(graph, rate_name, node)
+     if rate_dtype not in {ScalarType.F32, ScalarType.F64}:
+         raise UnsupportedOpError(
+             "Adagrad R input must be float or double"
+         )
+     timestep_dtype = value_dtype(graph, timestep_name, node)
+     if timestep_dtype != ScalarType.I64:
+         raise UnsupportedOpError("Adagrad T input must be int64")
+
+     inputs = node.inputs[2 : 2 + tensor_count]
+     gradients = node.inputs[2 + tensor_count : 2 + tensor_count * 2]
+     accumulators = node.inputs[2 + tensor_count * 2 : 2 + tensor_count * 3]
+     outputs = node.outputs[:tensor_count]
+     accumulator_outputs = node.outputs[tensor_count:]
+     if not inputs or not gradients or not accumulators:
+         raise UnsupportedOpError("Adagrad requires X, G, H inputs")
+     dtype = value_dtype(graph, inputs[0], node)
+     if dtype not in {ScalarType.F32, ScalarType.F64}:
+         raise UnsupportedOpError("Adagrad supports float and double tensors only")
+     if rate_dtype != dtype:
+         raise UnsupportedOpError(
+             "Adagrad R input dtype must match tensor dtype"
+         )
+     input_shapes: list[tuple[int, ...]] = []
+     output_shapes: list[tuple[int, ...]] = []
+     for index, (x_name, g_name, h_name, out_name, h_out_name) in enumerate(
+         zip(inputs, gradients, accumulators, outputs, accumulator_outputs)
+     ):
+         x_dtype = value_dtype(graph, x_name, node)
+         g_dtype = value_dtype(graph, g_name, node)
+         h_dtype = value_dtype(graph, h_name, node)
+         out_dtype = value_dtype(graph, out_name, node)
+         h_out_dtype = value_dtype(graph, h_out_name, node)
+         if {x_dtype, g_dtype, h_dtype, out_dtype, h_out_dtype} != {dtype}:
+             raise UnsupportedOpError(
+                 "Adagrad inputs and outputs must share the same dtype"
+             )
+         x_shape = value_shape(graph, x_name, node)
+         g_shape = value_shape(graph, g_name, node)
+         h_shape = value_shape(graph, h_name, node)
+         out_shape = value_shape(graph, out_name, node)
+         h_out_shape = value_shape(graph, h_out_name, node)
+         if x_shape != g_shape or x_shape != h_shape:
+             raise ShapeInferenceError(
+                 f"Adagrad inputs X/G/H shapes must match for tensor {index}"
+             )
+         if out_shape != x_shape or h_out_shape != x_shape:
+             raise ShapeInferenceError(
+                 f"Adagrad outputs must match X shape for tensor {index}"
+             )
+         input_shapes.append(x_shape)
+         output_shapes.append(out_shape)
+
+     norm_coefficient = float(node.attrs.get("norm_coefficient", 0.0))
+     epsilon = float(node.attrs.get("epsilon", 0.0))
+     decay_factor = float(node.attrs.get("decay_factor", 0.0))
+
+     return AdagradOp(
+         rate=rate_name,
+         timestep=timestep_name,
+         inputs=tuple(inputs),
+         gradients=tuple(gradients),
+         accumulators=tuple(accumulators),
+         outputs=tuple(outputs),
+         accumulator_outputs=tuple(accumulator_outputs),
+         rate_shape=rate_shape,
+         timestep_shape=timestep_shape,
+         tensor_shapes=tuple(input_shapes),
+         output_shapes=tuple(output_shapes),
+         dtype=dtype,
+         rate_dtype=rate_dtype,
+         timestep_dtype=timestep_dtype,
+         norm_coefficient=norm_coefficient,
+         epsilon=epsilon,
+         decay_factor=decay_factor,
+     )
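A quick worked check of the input/output layout the Adagrad lowering above enforces: for N optimized tensors the node carries R and T followed by the X, G and H groups (2 + 3*N inputs), and produces the updated X's followed by the updated H's (2*N outputs). The tensor names below are hypothetical; the slicing mirrors the code in the hunk.

    # Worked example of the Adagrad layout arithmetic for N = 2 (names are made up).
    inputs = ["R", "T", "X1", "X2", "G1", "G2", "H1", "H2"]   # 2 + 3*N
    outputs = ["X1_new", "X2_new", "H1_new", "H2_new"]        # 2*N

    tensor_count = (len(inputs) - 2) // 3
    assert (len(inputs) - 2) % 3 == 0 and tensor_count == 2
    assert len(outputs) == tensor_count * 2

    xs = inputs[2 : 2 + tensor_count]
    gs = inputs[2 + tensor_count : 2 + tensor_count * 2]
    hs = inputs[2 + tensor_count * 2 :]
    assert xs == ["X1", "X2"] and gs == ["G1", "G2"] and hs == ["H1", "H2"]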
@@ -2,7 +2,7 @@ from __future__ import annotations

  from shared.scalar_types import ScalarType

- from ..codegen.c_emitter import ArgReduceOp
+ from ..ir.ops import ArgReduceOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import shape_product, value_dtype, value_shape
@@ -5,7 +5,7 @@ from dataclasses import dataclass

  from shared.scalar_types import ScalarType

- from ..codegen.c_emitter import AttentionOp
+ from ..ir.ops import AttentionOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
@@ -2,7 +2,7 @@ from __future__ import annotations

  from dataclasses import dataclass

- from ..codegen.c_emitter import AveragePoolOp
+ from ..ir.ops import AveragePoolOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .registry import register_lowering
@@ -2,7 +2,7 @@ from __future__ import annotations

  from dataclasses import dataclass

- from ..codegen.c_emitter import BatchNormOp
+ from ..ir.ops import BatchNormOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .registry import register_lowering
@@ -2,7 +2,7 @@ from __future__ import annotations

  import onnx

- from ..codegen.c_emitter import CastOp
+ from ..ir.ops import CastOp
  from ..dtypes import scalar_type_from_onnx
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
@@ -5,6 +5,7 @@ from collections.abc import Sequence
  from shared.scalar_types import ScalarType

  from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.context import GraphContext
  from ..ir.model import Graph, Initializer, Node


@@ -14,7 +15,9 @@ def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
      return dtype


- def onnx_opset_version(graph: Graph, domain: str = "") -> int | None:
+ def onnx_opset_version(graph: Graph | GraphContext, domain: str = "") -> int | None:
+     if isinstance(graph, GraphContext):
+         return graph.opset_version(domain)
      if domain in {"", "ai.onnx"}:
          domains = {"", "ai.onnx"}
      else:
@@ -25,7 +28,11 @@ def onnx_opset_version(graph: Graph, domain: str = "") -> int | None:
      return None


- def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType:
+ def value_dtype(
+     graph: Graph | GraphContext, name: str, node: Node | None = None
+ ) -> ScalarType:
+     if isinstance(graph, GraphContext):
+         return graph.dtype(name, node)
      try:
          value = graph.find_value(name)
      except KeyError as exc:
@@ -37,31 +44,42 @@ def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType
      return ensure_supported_dtype(value.type.dtype)


- def value_shape(graph: Graph, name: str, node: Node | None = None) -> tuple[int, ...]:
-     try:
+ def value_shape(
+     graph: Graph | GraphContext, name: str, node: Node | None = None
+ ) -> tuple[int, ...]:
+     if isinstance(graph, GraphContext):
+         shape = graph.shape(name, node)
          value = graph.find_value(name)
-     except KeyError as exc:
-         op_type = node.op_type if node is not None else "unknown"
-         raise ShapeInferenceError(
-             f"Missing shape for value '{name}' in op {op_type}. "
-             "Hint: run ONNX shape inference or export with static shapes."
-         ) from exc
+     else:
+         try:
+             value = graph.find_value(name)
+         except KeyError as exc:
+             op_type = node.op_type if node is not None else "unknown"
+             raise ShapeInferenceError(
+                 f"Missing shape for value '{name}' in op {op_type}. "
+                 "Hint: run ONNX shape inference or export with static shapes."
+             ) from exc
+         shape = value.type.shape
      if any(value.type.dim_params):
          resolved = _resolve_value_shape(graph, name, node)
          if resolved is not None:
              return resolved
          return value.type.shape
-     return value.type.shape
+     return shape


- def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+ def _find_initializer(graph: Graph | GraphContext, name: str) -> Initializer | None:
+     if isinstance(graph, GraphContext):
+         return graph.initializer(name)
      for initializer in graph.initializers:
          if initializer.name == name:
              return initializer
      return None


- def _find_node_by_output(graph: Graph, name: str) -> Node | None:
+ def _find_node_by_output(graph: Graph | GraphContext, name: str) -> Node | None:
+     if isinstance(graph, GraphContext):
+         return graph.producer(name)
      for node in graph.nodes:
          if name in node.outputs:
              return node
@@ -69,7 +87,7 @@ def _find_node_by_output(graph: Graph, name: str) -> Node | None:


  def _shape_values_from_shape_node(
-     graph: Graph, shape_node: Node, node: Node | None
+     graph: Graph | GraphContext, shape_node: Node, node: Node | None
  ) -> list[int]:
      if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
          raise UnsupportedOpError("Shape must have 1 input and 1 output")
@@ -88,7 +106,7 @@


  def _shape_values_from_initializer(
-     graph: Graph,
+     graph: Graph | GraphContext,
      name: str,
  ) -> list[int] | None:
      initializer = _find_initializer(graph, name)
@@ -103,7 +121,7 @@


  def _shape_values_from_input(
-     graph: Graph,
+     graph: Graph | GraphContext,
      name: str,
      node: Node | None,
      *,
@@ -277,7 +295,7 @@


  def _resolve_value_shape(
-     graph: Graph,
+     graph: Graph | GraphContext,
      name: str,
      node: Node | None,
      *,
@@ -414,7 +432,7 @@
      _visited.remove(name)


- def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
+ def node_dtype(graph: Graph | GraphContext, node: Node, *names: str) -> ScalarType:
      filtered = [name for name in names if name]
      if not filtered:
          raise UnsupportedOpError(
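The common.py changes above let the shared shape/dtype helpers accept either the plain Graph or the new GraphContext and delegate to the context when one is passed. A rough sketch of that dual dispatch for the shape lookup follows; the GraphContext here is a stand-in with only a shape method (the real class is added in emx_onnx_cgen/ir/context.py, which this diff lists but does not show).

    # Sketch of the Graph-or-GraphContext dispatch used by value_shape (stand-in types).
    from dataclasses import dataclass, field

    @dataclass
    class GraphContext:                           # stand-in for ir/context.py
        shapes: dict = field(default_factory=dict)

        def shape(self, name, node=None):
            return self.shapes[name]

    def value_shape(graph, name, node=None):
        if isinstance(graph, GraphContext):
            return graph.shape(name, node)        # contextual lookup
        return graph.find_value(name).type.shape  # fall back to the raw Graph

    ctx = GraphContext(shapes={"x": (1, 3, 224, 224)})
    assert value_shape(ctx, "x") == (1, 3, 224, 224)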
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from ..codegen.c_emitter import ConcatOp
+ from ..ir.ops import ConcatOp
  from ..errors import UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
@@ -4,7 +4,7 @@ from onnx import numpy_helper

  from shared.scalar_types import ScalarType

- from ..codegen.c_emitter import ConstantOfShapeOp
+ from ..ir.ops import ConstantOfShapeOp
  from ..dtypes import scalar_type_from_onnx
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
@@ -3,7 +3,7 @@ from __future__ import annotations
  import math
  from dataclasses import dataclass

- from ..codegen.c_emitter import ConvOp
+ from ..ir.ops import ConvOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
@@ -3,7 +3,7 @@ from __future__ import annotations
  import math
  from dataclasses import dataclass

- from ..codegen.c_emitter import ConvTransposeOp
+ from ..ir.ops import ConvTransposeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
@@ -4,7 +4,7 @@ import numpy as np

  from shared.scalar_types import ScalarType

- from ..codegen.c_emitter import CumSumOp
+ from ..ir.ops import CumSumOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
  from ..lowering.common import value_dtype, value_shape
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from ..codegen.c_emitter import DepthToSpaceOp, SpaceToDepthOp
+ from ..ir.ops import DepthToSpaceOp, SpaceToDepthOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from ..lowering.common import value_dtype, value_shape
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from ..codegen.c_emitter import ReshapeOp
+ from ..ir.ops import ReshapeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import value_dtype as _value_dtype
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from ..codegen.c_emitter import EinsumKind, EinsumOp
+ from ..ir.ops import EinsumKind, EinsumOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
@@ -1,13 +1,23 @@
  from __future__ import annotations

- from shared.scalar_functions import ScalarFunction
+ from shared.scalar_functions import ScalarFunction, ScalarFunctionError
  from shared.scalar_types import ScalarType

- from ..codegen.c_emitter import ClipOp, UnaryOp
+ from ..ir.ops import BinaryOp, ClipOp, UnaryOp
  from ..errors import UnsupportedOpError
+ from ..ir.context import GraphContext
  from ..ir.model import Graph, Node
  from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
- from ..lowering.registry import register_lowering
+ from ..lowering.registry import register_lowering, register_lowering_if_missing
+ from ..ops import (
+     BINARY_OP_TYPES,
+     COMPARE_FUNCTIONS,
+     UNARY_OP_TYPES,
+     binary_op_symbol,
+     unary_op_symbol,
+     validate_unary_attrs,
+ )
+ from ..lowering.variadic import VARIADIC_OP_FUNCTIONS


  @register_lowering("Clip")
@@ -120,6 +130,138 @@ def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
      )


+ def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
+     if node.op_type == "BitShift":
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
+         direction_attr = node.attrs.get("direction", "LEFT")
+         if isinstance(direction_attr, bytes):
+             direction = direction_attr.decode()
+         else:
+             direction = str(direction_attr)
+         if direction not in {"LEFT", "RIGHT"}:
+             raise UnsupportedOpError(
+                 "BitShift direction must be LEFT or RIGHT"
+             )
+         op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+         if not op_dtype.is_integer:
+             raise UnsupportedOpError("BitShift expects integer inputs")
+         function = (
+             ScalarFunction.BITWISE_LEFT_SHIFT
+             if direction == "LEFT"
+             else ScalarFunction.BITWISE_RIGHT_SHIFT
+         )
+         op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+         if op_spec is None:
+             raise UnsupportedOpError("Unsupported op BitShift")
+         input0_shape = value_shape(graph, node.inputs[0], node)
+         input1_shape = value_shape(graph, node.inputs[1], node)
+         output_shape = value_shape(graph, node.outputs[0], node)
+         return BinaryOp(
+             input0=node.inputs[0],
+             input1=node.inputs[1],
+             output=node.outputs[0],
+             function=function,
+             operator_kind=op_spec.kind,
+             input0_shape=input0_shape,
+             input1_shape=input1_shape,
+             shape=output_shape,
+             dtype=op_dtype,
+             input_dtype=op_dtype,
+         )
+     if node.op_type == "Mod":
+         fmod = int(node.attrs.get("fmod", 0))
+         if fmod not in {0, 1}:
+             raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
+         function = (
+             ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
+         )
+     else:
+         try:
+             function = ScalarFunction.from_onnx_op(node.op_type)
+         except ScalarFunctionError as exc:
+             raise UnsupportedOpError(
+                 f"Unsupported op {node.op_type}"
+             ) from exc
+     validate_unary_attrs(node.op_type, node.attrs)
+     if function in COMPARE_FUNCTIONS:
+         input_dtype = node_dtype(graph, node, *node.inputs)
+         output_dtype = value_dtype(graph, node.outputs[0], node)
+         op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
+         if op_spec is None:
+             raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have 2 inputs and 1 output"
+             )
+         if output_dtype != ScalarType.BOOL:
+             raise UnsupportedOpError(
+                 f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
+             )
+         input0_shape = value_shape(graph, node.inputs[0], node)
+         input1_shape = value_shape(graph, node.inputs[1], node)
+         output_shape = value_shape(graph, node.outputs[0], node)
+         return BinaryOp(
+             input0=node.inputs[0],
+             input1=node.inputs[1],
+             output=node.outputs[0],
+             function=function,
+             operator_kind=op_spec.kind,
+             input0_shape=input0_shape,
+             input1_shape=input1_shape,
+             shape=output_shape,
+             dtype=output_dtype,
+             input_dtype=input_dtype,
+         )
+     op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+     unary_symbol = unary_op_symbol(function, dtype=op_dtype)
+     if op_spec is None and unary_symbol is None:
+         raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+     if op_spec is not None:
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have 2 inputs and 1 output"
+             )
+         input0_shape = value_shape(graph, node.inputs[0], node)
+         input1_shape = value_shape(graph, node.inputs[1], node)
+         output_shape = value_shape(graph, node.outputs[0], node)
+         return BinaryOp(
+             input0=node.inputs[0],
+             input1=node.inputs[1],
+             output=node.outputs[0],
+             function=function,
+             operator_kind=op_spec.kind,
+             input0_shape=input0_shape,
+             input1_shape=input1_shape,
+             shape=output_shape,
+             dtype=op_dtype,
+             input_dtype=op_dtype,
+         )
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} must have 1 input and 1 output"
+         )
+     output_shape = value_shape(graph, node.outputs[0], node)
+     return UnaryOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         function=function,
+         shape=output_shape,
+         dtype=op_dtype,
+         input_dtype=op_dtype,
+         params=(),
+     )
+
+
+ _DEFAULT_ELEMENTWISE_TYPES = (
+     BINARY_OP_TYPES.union(UNARY_OP_TYPES) - set(VARIADIC_OP_FUNCTIONS.keys())
+ )
+
+ for _op_type in _DEFAULT_ELEMENTWISE_TYPES:
+     register_lowering_if_missing(_op_type)(_lower_binary_unary)
+
+
  @register_lowering("IsInf")
  def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
      if len(node.inputs) != 1 or len(node.outputs) != 1:
@@ -130,6 +272,12 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
          raise UnsupportedOpError("IsInf only supports floating-point inputs")
      if output_dtype != ScalarType.BOOL:
          raise UnsupportedOpError("IsInf output must be bool")
+     detect_negative = int(node.attrs.get("detect_negative", 1))
+     detect_positive = int(node.attrs.get("detect_positive", 1))
+     if detect_negative not in {0, 1} or detect_positive not in {0, 1}:
+         raise UnsupportedOpError(
+             "IsInf detect_negative and detect_positive must be 0 or 1"
+         )
      output_shape = value_shape(graph, node.outputs[0], node)
      return UnaryOp(
          input0=node.inputs[0],
@@ -138,7 +286,7 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
          shape=output_shape,
          dtype=output_dtype,
          input_dtype=input_dtype,
-         params=(),
+         params=(float(detect_negative), float(detect_positive)),
      )

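The new _lower_binary_unary fallback above is attached to every elementwise op type that lacks a dedicated lowering (via register_lowering_if_missing). It chooses the scalar function from the node itself: BitShift maps its direction attribute to a left or right shift, Mod maps fmod to FMOD or REMAINDER, and every other op goes through ScalarFunction.from_onnx_op. A compressed sketch of that selection logic, with plain strings standing in for the package's ScalarFunction enum:

    # Sketch of the scalar-function selection in _lower_binary_unary (strings stand in
    # for ScalarFunction members; this is not the package's actual enum or API).
    def pick_function(op_type: str, attrs: dict) -> str:
        if op_type == "BitShift":
            direction = attrs.get("direction", "LEFT")
            if isinstance(direction, bytes):
                direction = direction.decode()
            if direction not in {"LEFT", "RIGHT"}:
                raise ValueError("BitShift direction must be LEFT or RIGHT")
            return "BITWISE_LEFT_SHIFT" if direction == "LEFT" else "BITWISE_RIGHT_SHIFT"
        if op_type == "Mod":
            fmod = int(attrs.get("fmod", 0))
            if fmod not in {0, 1}:
                raise ValueError("Mod only supports fmod=0 or fmod=1")
            return "FMOD" if fmod == 1 else "REMAINDER"
        return op_type.upper()                # stand-in for ScalarFunction.from_onnx_op

    assert pick_function("BitShift", {"direction": b"RIGHT"}) == "BITWISE_RIGHT_SHIFT"
    assert pick_function("Mod", {"fmod": 1}) == "FMOD"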
@@ -4,7 +4,7 @@ import numpy as np

  from shared.scalar_types import ScalarType

- from ..codegen.c_emitter import ExpandOp
+ from ..ir.ops import ExpandOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
  from ..lowering.common import value_dtype, value_shape
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from ..codegen.c_emitter import EyeLikeOp
+ from ..ir.ops import EyeLikeOp
  from ..dtypes import scalar_type_from_onnx
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from ..codegen.c_emitter import ReshapeOp
+ from ..ir.ops import ReshapeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import shape_product, value_dtype, value_shape
@@ -2,7 +2,7 @@ from __future__ import annotations

  from shared.scalar_types import ScalarType

- from ..codegen.c_emitter import GatherOp
+ from ..ir.ops import GatherOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from ..validation import normalize_axis