emx-onnx-cgen 0.3.0-py3-none-any.whl → 0.3.2-py3-none-any.whl

This diff shows the published contents of two versions of a publicly available package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
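
Most of the one-line "+1 -1" edits below are the same mechanical refactor: the op dataclasses (GemmOp, MatMulOp, ReduceOp, and the rest) moved out of emx_onnx_cgen/codegen/c_emitter.py into the new emx_onnx_cgen/ir/ops package, and each lowering module now imports them from there. The pattern is identical across the touched lowering files:

    # before (0.3.0)
    from ..codegen.c_emitter import GemmOp
    # after (0.3.2)
    from ..ir.ops import GemmOp
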
emx_onnx_cgen/lowering/gather_elements.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GatherElementsOp
+from ..ir.ops import GatherElementsOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import normalize_axis
emx_onnx_cgen/lowering/gather_nd.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GatherNDOp
+from ..ir.ops import GatherNDOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype as _value_dtype
emx_onnx_cgen/lowering/gemm.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GemmOp
+from ..ir.ops import GemmOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/global_max_pool.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ReduceOp
+from ..ir.ops import ReduceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype as _value_dtype
emx_onnx_cgen/lowering/grid_sample.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GridSampleOp
+from ..ir.ops import GridSampleOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape
emx_onnx_cgen/lowering/group_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import GroupNormalizationOp
+from ..ir.ops import GroupNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
emx_onnx_cgen/lowering/hardmax.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import HardmaxOp
+from ..ir.ops import HardmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/identity.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import IdentityOp
+from ..ir.ops import IdentityOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape
emx_onnx_cgen/lowering/instance_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import InstanceNormalizationOp
+from ..ir.ops import InstanceNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
emx_onnx_cgen/lowering/layer_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import LayerNormalizationOp
+from ..ir.ops import LayerNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
emx_onnx_cgen/lowering/logsoftmax.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import LogSoftmaxOp
+from ..ir.ops import LogSoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/lp_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import LpNormalizationOp
+from ..ir.ops import LpNormalizationOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
emx_onnx_cgen/lowering/lp_pool.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import LpPoolOp
+from ..ir.ops import LpPoolOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/lrn.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import LrnOp
+from ..ir.ops import LrnOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/lstm.py
@@ -323,7 +323,7 @@ def resolve_lstm_spec(graph: Graph, node: Node) -> LstmSpec:
 
 @register_lowering("LSTM")
 def lower_lstm(graph: Graph, node: Node) -> "LstmOp":
-    from ..codegen.c_emitter import LstmOp
+    from ..ir.ops import LstmOp
 
     spec = resolve_lstm_spec(graph, node)
     return LstmOp(
emx_onnx_cgen/lowering/matmul.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import MatMulOp
+from ..ir.ops import MatMulOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/maxpool.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import MaxPoolOp
+from ..ir.ops import MaxPoolOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/mean_variance_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import MeanVarianceNormalizationOp
+from ..ir.ops import MeanVarianceNormalizationOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
emx_onnx_cgen/lowering/negative_log_likelihood_loss.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import NegativeLogLikelihoodLossOp
+from ..ir.ops import NegativeLogLikelihoodLossOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .common import shape_product as _shape_product
emx_onnx_cgen/lowering/non_max_suppression.py (new file)
@@ -0,0 +1,157 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import NonMaxSuppressionOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..lowering.common import optional_name, shape_product, value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _validate_scalar_input(
+    graph: Graph,
+    name: str,
+    node: Node,
+    *,
+    allowed_dtypes: set[ScalarType],
+    label: str,
+) -> tuple[ScalarType, tuple[int, ...]]:
+    dtype = value_dtype(graph, name, node)
+    if dtype not in allowed_dtypes:
+        allowed = ", ".join(sorted(d.onnx_name for d in allowed_dtypes))
+        raise UnsupportedOpError(
+            f"{node.op_type} {label} must be {allowed}, got {dtype.onnx_name}"
+        )
+    shape = value_shape(graph, name, node)
+    if shape not in {(), (1,)}:
+        total = shape_product(shape)
+        if total != 1:
+            raise ShapeInferenceError(
+                f"{node.op_type} {label} must be a scalar tensor, got shape {shape}"
+            )
+    return dtype, shape
+
+
+@register_lowering("NonMaxSuppression")
+def lower_non_max_suppression(graph: Graph, node: Node) -> NonMaxSuppressionOp:
+    if node.op_type != "NonMaxSuppression":
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    if len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 1 output, got {len(node.outputs)}"
+        )
+    if len(node.inputs) < 2 or len(node.inputs) > 5:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 2 to 5 inputs, got {len(node.inputs)}"
+        )
+
+    boxes = node.inputs[0]
+    scores = node.inputs[1]
+    max_output_boxes_per_class = optional_name(node.inputs, 2)
+    iou_threshold = optional_name(node.inputs, 3)
+    score_threshold = optional_name(node.inputs, 4)
+    output = node.outputs[0]
+
+    boxes_shape = value_shape(graph, boxes, node)
+    scores_shape = value_shape(graph, scores, node)
+    if len(boxes_shape) != 3 or boxes_shape[2] != 4:
+        raise ShapeInferenceError(
+            f"{node.op_type} boxes input must have shape "
+            f"[num_batches, num_boxes, 4], got {boxes_shape}"
+        )
+    if len(scores_shape) != 3:
+        raise ShapeInferenceError(
+            f"{node.op_type} scores input must have shape "
+            f"[num_batches, num_classes, num_boxes], got {scores_shape}"
+        )
+    if boxes_shape[0] != scores_shape[0]:
+        raise ShapeInferenceError(
+            f"{node.op_type} boxes/scores batch dims must match, "
+            f"got {boxes_shape[0]} and {scores_shape[0]}"
+        )
+    if boxes_shape[1] != scores_shape[2]:
+        raise ShapeInferenceError(
+            f"{node.op_type} boxes num_boxes dim {boxes_shape[1]} "
+            f"must match scores num_boxes dim {scores_shape[2]}"
+        )
+
+    boxes_dtype = value_dtype(graph, boxes, node)
+    scores_dtype = value_dtype(graph, scores, node)
+    if boxes_dtype != scores_dtype or not boxes_dtype.is_float:
+        raise UnsupportedOpError(
+            f"{node.op_type} boxes and scores must be the same float dtype, "
+            f"got {boxes_dtype.onnx_name} and {scores_dtype.onnx_name}"
+        )
+
+    max_output_dtype = None
+    max_output_shape = None
+    if max_output_boxes_per_class is not None:
+        max_output_dtype, max_output_shape = _validate_scalar_input(
+            graph,
+            max_output_boxes_per_class,
+            node,
+            allowed_dtypes={ScalarType.I32, ScalarType.I64},
+            label="max_output_boxes_per_class input",
+        )
+
+    iou_threshold_dtype = None
+    iou_threshold_shape = None
+    if iou_threshold is not None:
+        iou_threshold_dtype, iou_threshold_shape = _validate_scalar_input(
+            graph,
+            iou_threshold,
+            node,
+            allowed_dtypes={ScalarType.F32, ScalarType.F64},
+            label="iou_threshold input",
+        )
+
+    score_threshold_dtype = None
+    score_threshold_shape = None
+    if score_threshold is not None:
+        score_threshold_dtype, score_threshold_shape = _validate_scalar_input(
+            graph,
+            score_threshold,
+            node,
+            allowed_dtypes={ScalarType.F32, ScalarType.F64},
+            label="score_threshold input",
+        )
+
+    output_shape = value_shape(graph, output, node)
+    if len(output_shape) != 2 or output_shape[1] != 3:
+        raise ShapeInferenceError(
+            f"{node.op_type} output must have shape [num_selected, 3], "
+            f"got {output_shape}"
+        )
+    output_dtype = value_dtype(graph, output, node)
+    if output_dtype != ScalarType.I64:
+        raise UnsupportedOpError(
+            f"{node.op_type} output dtype must be int64"
+        )
+
+    center_point_box = int(node.attrs.get("center_point_box", 0))
+    if center_point_box not in {0, 1}:
+        raise UnsupportedOpError(
+            f"{node.op_type} center_point_box must be 0 or 1, got {center_point_box}"
+        )
+
+    return NonMaxSuppressionOp(
+        boxes=boxes,
+        scores=scores,
+        max_output_boxes_per_class=max_output_boxes_per_class,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        output=output,
+        boxes_shape=boxes_shape,
+        scores_shape=scores_shape,
+        output_shape=output_shape,
+        center_point_box=center_point_box,
+        boxes_dtype=boxes_dtype,
+        output_dtype=output_dtype,
+        max_output_dtype=max_output_dtype,
+        max_output_shape=max_output_shape,
+        iou_threshold_dtype=iou_threshold_dtype,
+        iou_threshold_shape=iou_threshold_shape,
+        score_threshold_dtype=score_threshold_dtype,
+        score_threshold_shape=score_threshold_shape,
+    )
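
For orientation, the semantics this lowering validates against are those of ONNX NonMaxSuppression: per batch and per class, boxes are considered in descending score order, and a candidate is dropped when its IoU with an already-selected box exceeds iou_threshold; each selected box is emitted as an int64 row [batch_index, class_index, box_index], which is why the output above must have shape [num_selected, 3]. A rough NumPy sketch of that reference behaviour follows — an illustration of the spec, not code from this package; it assumes center_point_box=0 with ordered corner boxes [y1, x1, y2, x2], and the helper names are invented:

    import numpy as np

    def _iou(a: np.ndarray, b: np.ndarray) -> float:
        # Intersection-over-union of two corner-format boxes [y1, x1, y2, x2].
        inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
        inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
        inter = inter_h * inter_w
        area_a = (a[2] - a[0]) * (a[3] - a[1])
        area_b = (b[2] - b[0]) * (b[3] - b[1])
        union = area_a + area_b - inter
        return inter / union if union > 0.0 else 0.0

    def nms_reference(boxes, scores, max_per_class, iou_threshold, score_threshold=None):
        # boxes: (num_batches, num_boxes, 4); scores: (num_batches, num_classes, num_boxes)
        selected: list[tuple[int, int, int]] = []
        for b in range(scores.shape[0]):
            for c in range(scores.shape[1]):
                order = np.argsort(-scores[b, c], kind="stable")
                if score_threshold is not None:
                    order = [i for i in order if scores[b, c, i] > score_threshold]
                kept: list[int] = []
                for i in order:
                    if len(kept) >= max_per_class:
                        break
                    # Keep a box only if it does not overlap a kept box too much.
                    if all(_iou(boxes[b, i], boxes[b, j]) <= iou_threshold for j in kept):
                        kept.append(int(i))
                selected.extend((b, c, i) for i in kept)
        return np.asarray(selected, dtype=np.int64).reshape(-1, 3)
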
emx_onnx_cgen/lowering/nonzero.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import NonZeroOp
+from ..ir.ops import NonZeroOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape
emx_onnx_cgen/lowering/one_hot.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import OneHotOp
+from ..ir.ops import OneHotOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
emx_onnx_cgen/lowering/pad.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import PadOp
+from ..ir.ops import PadOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import optional_name, value_dtype, value_shape
emx_onnx_cgen/lowering/qlinear_matmul.py (new file)
@@ -0,0 +1,212 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import QLinearMatMulOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class QLinearMatMulSpec:
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    batch_shape: tuple[int, ...]
+    input0_batch_shape: tuple[int, ...]
+    input1_batch_shape: tuple[int, ...]
+    m: int
+    n: int
+    k: int
+    left_vector: bool
+    right_vector: bool
+
+
+def resolve_qlinear_matmul_spec(graph: Graph, node: Node) -> QLinearMatMulSpec:
+    if len(node.inputs) != 8 or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "QLinearMatMul must have 8 inputs and 1 output"
+        )
+    input0_shape = _value_shape(graph, node.inputs[0], node)
+    input1_shape = _value_shape(graph, node.inputs[3], node)
+    if len(input0_shape) < 1 or len(input1_shape) < 1:
+        raise UnsupportedOpError(
+            "QLinearMatMul inputs must be at least 1D, "
+            f"got {input0_shape} x {input1_shape}"
+        )
+    left_vector = len(input0_shape) == 1
+    right_vector = len(input1_shape) == 1
+    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
+    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
+    m, k_left = input0_effective[-2], input0_effective[-1]
+    k_right, n = input1_effective[-2], input1_effective[-1]
+    if k_left != k_right:
+        raise ShapeInferenceError(
+            "QLinearMatMul inner dimensions must match, "
+            f"got {k_left} and {k_right}"
+        )
+    batch_shape, input0_batch_shape, input1_batch_shape = (
+        _broadcast_batch_shapes(
+            input0_effective[:-2], input1_effective[:-2], node
+        )
+    )
+    if left_vector and right_vector:
+        output_shape = batch_shape
+    elif left_vector:
+        output_shape = batch_shape + (n,)
+    elif right_vector:
+        output_shape = batch_shape + (m,)
+    else:
+        output_shape = batch_shape + (m, n)
+    expected_output_shape = _value_shape(graph, node.outputs[0], node)
+    if expected_output_shape != output_shape:
+        raise ShapeInferenceError(
+            "QLinearMatMul output shape must be "
+            f"{output_shape}, got {expected_output_shape}"
+        )
+    return QLinearMatMulSpec(
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
+        output_shape=output_shape,
+        batch_shape=batch_shape,
+        input0_batch_shape=input0_batch_shape,
+        input1_batch_shape=input1_batch_shape,
+        m=m,
+        n=n,
+        k=k_left,
+        left_vector=left_vector,
+        right_vector=right_vector,
+    )
+
+
+def _broadcast_batch_shapes(
+    left: tuple[int, ...], right: tuple[int, ...], node: Node
+) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+    max_rank = max(len(left), len(right))
+    left_padded = (1,) * (max_rank - len(left)) + left
+    right_padded = (1,) * (max_rank - len(right)) + right
+    broadcast_shape = []
+    for left_dim, right_dim in zip(left_padded, right_padded):
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            raise ShapeInferenceError(
+                "QLinearMatMul batch dimensions must be broadcastable, "
+                f"got {left} x {right}"
+            )
+        broadcast_shape.append(max(left_dim, right_dim))
+    return tuple(broadcast_shape), left_padded, right_padded
+
+
+def _ensure_scalar_input(
+    graph: Graph, name: str, node: Node, label: str
+) -> tuple[int, ...]:
+    shape = _value_shape(graph, name, node)
+    if shape not in {(), (1,)}:
+        raise UnsupportedOpError(
+            f"QLinearMatMul {label} must be scalar, got shape {shape}"
+        )
+    return shape
+
+
+def _ensure_scale_dtype(dtype: ScalarType, label: str) -> None:
+    if not dtype.is_float:
+        raise UnsupportedOpError(
+            f"QLinearMatMul {label} must be float16/float/double"
+        )
+
+
+@register_lowering("QLinearMatMul")
+def lower_qlinear_matmul(graph: Graph, node: Node) -> QLinearMatMulOp:
+    spec = resolve_qlinear_matmul_spec(graph, node)
+    input0_dtype = _value_dtype(graph, node.inputs[0], node)
+    input1_dtype = _value_dtype(graph, node.inputs[3], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input0_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMatMul supports uint8/int8 inputs only"
+        )
+    if input1_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMatMul supports uint8/int8 inputs only"
+        )
+    if output_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMatMul supports uint8/int8 outputs only"
+        )
+    input0_scale_dtype = _value_dtype(graph, node.inputs[1], node)
+    input1_scale_dtype = _value_dtype(graph, node.inputs[4], node)
+    output_scale_dtype = _value_dtype(graph, node.inputs[6], node)
+    _ensure_scale_dtype(input0_scale_dtype, "a_scale")
+    _ensure_scale_dtype(input1_scale_dtype, "b_scale")
+    _ensure_scale_dtype(output_scale_dtype, "y_scale")
+    input0_zero_dtype = _value_dtype(graph, node.inputs[2], node)
+    input1_zero_dtype = _value_dtype(graph, node.inputs[5], node)
+    output_zero_dtype = _value_dtype(graph, node.inputs[7], node)
+    if input0_zero_dtype != input0_dtype:
+        raise UnsupportedOpError(
+            "QLinearMatMul a_zero_point dtype must match a"
+        )
+    if input1_zero_dtype != input1_dtype:
+        raise UnsupportedOpError(
+            "QLinearMatMul b_zero_point dtype must match b"
+        )
+    if output_zero_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "QLinearMatMul y_zero_point dtype must match y"
+        )
+    input0_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[1], node, "a_scale"
+    )
+    input1_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[4], node, "b_scale"
+    )
+    output_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[6], node, "y_scale"
+    )
+    input0_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[2], node, "a_zero_point"
+    )
+    input1_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[5], node, "b_zero_point"
+    )
+    output_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[7], node, "y_zero_point"
+    )
+    return QLinearMatMulOp(
+        input0=node.inputs[0],
+        input0_scale=node.inputs[1],
+        input0_zero_point=node.inputs[2],
+        input1=node.inputs[3],
+        input1_scale=node.inputs[4],
+        input1_zero_point=node.inputs[5],
+        output_scale=node.inputs[6],
+        output_zero_point=node.inputs[7],
+        output=node.outputs[0],
+        input0_shape=spec.input0_shape,
+        input1_shape=spec.input1_shape,
+        output_shape=spec.output_shape,
+        batch_shape=spec.batch_shape,
+        input0_batch_shape=spec.input0_batch_shape,
+        input1_batch_shape=spec.input1_batch_shape,
+        m=spec.m,
+        n=spec.n,
+        k=spec.k,
+        left_vector=spec.left_vector,
+        right_vector=spec.right_vector,
+        input0_dtype=input0_dtype,
+        input1_dtype=input1_dtype,
+        dtype=output_dtype,
+        input0_scale_dtype=input0_scale_dtype,
+        input1_scale_dtype=input1_scale_dtype,
+        output_scale_dtype=output_scale_dtype,
+        input0_scale_shape=input0_scale_shape,
+        input1_scale_shape=input1_scale_shape,
+        output_scale_shape=output_scale_shape,
+        input0_zero_shape=input0_zero_shape,
+        input1_zero_shape=input1_zero_shape,
+        output_zero_shape=output_zero_shape,
+    )
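
The numeric contract behind all of this shape/dtype plumbing is the usual ONNX quantization identity: each uint8/int8 operand represents real = (q - zero_point) * scale, so the op computes the float matmul of the dequantized inputs and requantizes the result with y_scale and y_zero_point. A compact NumPy sketch of that math — an illustration of the spec'd semantics, not this package's generated C; the function name is invented:

    import numpy as np

    def qlinear_matmul_reference(a, a_scale, a_zp, b, b_scale, b_zp,
                                 y_scale, y_zp, out_dtype=np.uint8):
        # Dequantize both operands: real = (q - zero_point) * scale.
        a_deq = (a.astype(np.int32) - int(a_zp)) * float(a_scale)
        b_deq = (b.astype(np.int32) - int(b_zp)) * float(b_scale)
        # Float matmul, then requantize: round half-to-even, saturate.
        y = a_deq @ b_deq
        q = np.rint(y / float(y_scale)) + int(y_zp)
        info = np.iinfo(out_dtype)
        return np.clip(q, info.min, info.max).astype(out_dtype)
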
emx_onnx_cgen/lowering/quantize_linear.py
@@ -10,7 +10,7 @@ from ..ir.model import Graph, Node
 from ..validation import normalize_axis
 from .common import optional_name, value_dtype as _value_dtype, value_shape as _value_shape
 from .registry import register_lowering
-from ..codegen.c_emitter import QuantizeLinearOp
+from ..ir.ops import QuantizeLinearOp
 
 
 @dataclass(frozen=True)
emx_onnx_cgen/lowering/range.py
@@ -6,7 +6,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import RangeOp
+from ..ir.ops import RangeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import node_dtype, value_shape
emx_onnx_cgen/lowering/reduce.py
@@ -6,7 +6,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ReduceOp, ReshapeOp
+from ..ir.ops import ReduceOp, ReshapeOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
emx_onnx_cgen/lowering/registry.py
@@ -3,32 +3,51 @@ from __future__ import annotations
 
 from collections.abc import Callable, Mapping
 from typing import TypeVar
 
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.op_base import OpBase
 from ..errors import UnsupportedOpError
 
 LoweredOp = TypeVar("LoweredOp")
 Handler = TypeVar("Handler")
 
-_LOWERING_REGISTRY: dict[str, Callable[[Graph, Node], object]] = {}
+_LOWERING_REGISTRY: dict[str, Callable[[Graph | GraphContext, Node], OpBase]] = {}
 
 
 def register_lowering(
     op_type: str,
 ) -> Callable[[Callable[[Graph, Node], LoweredOp]], Callable[[Graph, Node], LoweredOp]]:
     def decorator(
-        func: Callable[[Graph, Node], LoweredOp],
-    ) -> Callable[[Graph, Node], LoweredOp]:
+        func: Callable[[Graph | GraphContext, Node], LoweredOp],
+    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
         _LOWERING_REGISTRY[op_type] = func
         return func
 
     return decorator
 
 
-def get_lowering(op_type: str) -> Callable[[Graph, Node], object] | None:
+def register_lowering_if_missing(
+    op_type: str,
+) -> Callable[[Callable[[Graph | GraphContext, Node], LoweredOp]], Callable[[Graph | GraphContext, Node], LoweredOp]]:
+    def decorator(
+        func: Callable[[Graph | GraphContext, Node], LoweredOp],
+    ) -> Callable[[Graph | GraphContext, Node], LoweredOp]:
+        if op_type not in _LOWERING_REGISTRY:
+            _LOWERING_REGISTRY[op_type] = func
+        return func
+
+    return decorator
+
+
+def get_lowering(
+    op_type: str,
+) -> Callable[[Graph | GraphContext, Node], OpBase] | None:
     return _LOWERING_REGISTRY.get(op_type)
 
 
-def get_lowering_registry() -> Mapping[str, Callable[[Graph, Node], object]]:
+def get_lowering_registry() -> Mapping[
+    str, Callable[[Graph | GraphContext, Node], OpBase]
+]:
     return _LOWERING_REGISTRY
 
 
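
The substantive addition in this last hunk is register_lowering_if_missing, which installs a handler only when the op type has no lowering yet, so the first registration wins. A hypothetical usage sketch (the op type and function name here are invented for illustration):

    from emx_onnx_cgen.lowering.registry import (
        get_lowering,
        register_lowering_if_missing,
    )

    @register_lowering_if_missing("Celu")
    def lower_celu_fallback(graph, node):
        # Only reached if no other module registered a "Celu" lowering
        # first; a real handler would return an OpBase built from node.
        raise NotImplementedError

    handler = get_lowering("Celu")  # earlier registrations take precedence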