emx_onnx_cgen-0.3.8-py3-none-any.whl → emx_onnx_cgen-0.4.1.dev0-py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the contents of those versions as they appear in the public registry.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0

emx_onnx_cgen/lowering/dequantize_linear.py
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from shared.scalar_types import ScalarType
+
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import normalize_axis
+ from .common import (
+     optional_name,
+     value_dtype as _value_dtype,
+     value_shape as _value_shape,
+ )
+ from .registry import register_lowering
+ from ..ir.ops import DequantizeLinearOp
+
+
+ @dataclass(frozen=True)
+ class DequantizeSpec:
+     input_shape: tuple[int, ...]
+     scale_shape: tuple[int, ...]
+     axis: int | None
+     block_size: int | None
+
+
+ def resolve_dequantize_spec(graph: Graph, node: Node) -> DequantizeSpec:
+     if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             "DequantizeLinear must have 2 or 3 inputs and 1 output"
+         )
+     supported_attrs = {"axis", "block_size"}
+     if set(node.attrs) - supported_attrs:
+         raise UnsupportedOpError("DequantizeLinear has unsupported attributes")
+     block_size = int(node.attrs.get("block_size", 0))
+     if block_size < 0:
+         raise UnsupportedOpError("DequantizeLinear block_size must be >= 0")
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     scale_shape = _value_shape(graph, node.inputs[1], node)
+     zero_point_name = optional_name(node.inputs, 2)
+     if zero_point_name is not None:
+         zero_point_shape = _value_shape(graph, zero_point_name, node)
+         if zero_point_shape != scale_shape:
+             raise ShapeInferenceError(
+                 "DequantizeLinear zero_point shape must match scale shape"
+             )
+     if scale_shape not in {(), (1,)}:
+         axis = int(node.attrs.get("axis", 1))
+         axis = normalize_axis(axis, input_shape, node)
+         if block_size > 0:
+             if len(scale_shape) != len(input_shape):
+                 raise UnsupportedOpError(
+                     "DequantizeLinear blocked scales must match input rank"
+                 )
+             if input_shape[axis] % block_size != 0:
+                 raise ShapeInferenceError(
+                     "DequantizeLinear block_size must evenly divide axis length"
+                 )
+             expected = list(input_shape)
+             expected[axis] = input_shape[axis] // block_size
+             if scale_shape != tuple(expected):
+                 raise ShapeInferenceError(
+                     "DequantizeLinear blocked scale shape must match "
+                     "input shape with a reduced axis"
+                 )
+         else:
+             if len(scale_shape) != 1:
+                 raise UnsupportedOpError(
+                     "DequantizeLinear supports per-tensor, per-axis, "
+                     "and blocked scales only"
+                 )
+             if scale_shape[0] != input_shape[axis]:
+                 raise ShapeInferenceError(
+                     "DequantizeLinear scale length must match input axis size"
+                 )
+     else:
+         axis = None
+         block_size = 0
+     return DequantizeSpec(
+         input_shape=input_shape,
+         scale_shape=scale_shape,
+         axis=axis,
+         block_size=block_size or None,
+     )
+
+
+ @register_lowering("DequantizeLinear")
+ def lower_dequantize_linear(graph: Graph, node: Node) -> DequantizeLinearOp:
+     input_dtype = _value_dtype(graph, node.inputs[0], node)
+     scale_dtype = _value_dtype(graph, node.inputs[1], node)
+     output_dtype = _value_dtype(graph, node.outputs[0], node)
+     if input_dtype not in {
+         ScalarType.U8,
+         ScalarType.I8,
+         ScalarType.U16,
+         ScalarType.I16,
+     }:
+         raise UnsupportedOpError(
+             "DequantizeLinear supports int8/uint8/int16/uint16 inputs only"
+         )
+     if not scale_dtype.is_float or not output_dtype.is_float:
+         raise UnsupportedOpError(
+             "DequantizeLinear supports float16/float/double scales and outputs only"
+         )
+     if output_dtype != scale_dtype:
+         raise UnsupportedOpError(
+             "DequantizeLinear output dtype must match scale dtype"
+         )
+     zero_point_name = optional_name(node.inputs, 2)
+     if zero_point_name is not None:
+         zero_point_dtype = _value_dtype(graph, zero_point_name, node)
+         if zero_point_dtype != input_dtype:
+             raise UnsupportedOpError(
+                 "DequantizeLinear zero_point dtype must match input dtype"
+             )
+     spec = resolve_dequantize_spec(graph, node)
+     return DequantizeLinearOp(
+         input0=node.inputs[0],
+         scale=node.inputs[1],
+         zero_point=zero_point_name,
+         output=node.outputs[0],
+         input_shape=spec.input_shape,
+         axis=spec.axis,
+         block_size=spec.block_size,
+         dtype=output_dtype,
+         input_dtype=input_dtype,
+         scale_dtype=scale_dtype,
+     )
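
The blocked branch of resolve_dequantize_spec above requires the scale tensor to keep the input's rank, with the quantized axis divided by block_size. A standalone sketch of that shape rule, for illustration only (the helper name is hypothetical and not part of emx_onnx_cgen):

# Illustrative sketch of the blocked-scale shape rule enforced above.
def expected_blocked_scale_shape(
    input_shape: tuple[int, ...], axis: int, block_size: int
) -> tuple[int, ...]:
    if input_shape[axis] % block_size != 0:
        raise ValueError("block_size must evenly divide the axis length")
    expected = list(input_shape)
    expected[axis] = input_shape[axis] // block_size
    return tuple(expected)

# Example: an (8, 64) int8 tensor quantized along axis 1 in blocks of 16
# needs a scale tensor of shape (8, 4).
assert expected_blocked_scale_shape((8, 64), axis=1, block_size=16) == (8, 4)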

emx_onnx_cgen/lowering/elementwise.py
@@ -3,11 +3,18 @@ from __future__ import annotations
  from shared.scalar_functions import ScalarFunction, ScalarFunctionError
  from shared.scalar_types import ScalarType

- from ..ir.ops import BinaryOp, ClipOp, UnaryOp
+ from ..ir.op_base import BroadcastingOpBase
+ from ..ir.ops import BinaryOp, ClipOp, PowOp, UnaryOp
  from ..errors import UnsupportedOpError
  from ..ir.context import GraphContext
  from ..ir.model import Graph, Node
- from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
+ from ..lowering.common import (
+     node_dtype,
+     onnx_opset_version,
+     optional_name,
+     value_dtype,
+     value_shape,
+ )
  from ..lowering.registry import register_lowering, register_lowering_if_missing
  from ..ops import (
      BINARY_OP_TYPES,
@@ -29,6 +36,24 @@ def lower_clip(graph: Graph, node: Node) -> ClipOp:
          raise UnsupportedOpError("Clip input must be provided")
      min_name = optional_name(node.inputs, 1)
      max_name = optional_name(node.inputs, 2)
+     min_value = None
+     max_value = None
+     opset_version = onnx_opset_version(graph)
+     if opset_version is None or opset_version < 11:
+         if min_name is None and "min" in node.attrs:
+             try:
+                 min_value = float(node.attrs["min"])
+             except (TypeError, ValueError) as exc:
+                 raise UnsupportedOpError(
+                     "Clip min attribute must be numeric"
+                 ) from exc
+         if max_name is None and "max" in node.attrs:
+             try:
+                 max_value = float(node.attrs["max"])
+             except (TypeError, ValueError) as exc:
+                 raise UnsupportedOpError(
+                     "Clip max attribute must be numeric"
+                 ) from exc
      input_dtype = value_dtype(graph, input_name, node)
      output_dtype = value_dtype(graph, node.outputs[0], node)
      if input_dtype != output_dtype:
@@ -61,11 +86,8 @@ def lower_clip(graph: Graph, node: Node) -> ClipOp:
          input_min=min_name,
          input_max=max_name,
          output=node.outputs[0],
-         input_shape=input_shape,
-         min_shape=min_shape,
-         max_shape=max_shape,
-         output_shape=output_shape,
-         dtype=input_dtype,
+         min_value=min_value,
+         max_value=max_value,
      )


@@ -82,9 +104,54 @@ def lower_celu(graph: Graph, node: Node) -> UnaryOp:
          input0=node.inputs[0],
          output=node.outputs[0],
          function=ScalarFunction.CELU,
-         shape=output_shape,
-         dtype=dtype,
-         input_dtype=dtype,
+         params=(alpha,),
+     )
+
+
+ @register_lowering("Elu")
+ def lower_elu(graph: Graph, node: Node) -> UnaryOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Elu must have 1 input and 1 output")
+     dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     if not dtype.is_float:
+         raise UnsupportedOpError("Elu only supports floating-point inputs")
+     for key in node.attrs:
+         if key != "alpha":
+             raise UnsupportedOpError(f"Elu does not support attribute {key}")
+     try:
+         alpha = float(node.attrs.get("alpha", 1.0))
+     except (TypeError, ValueError) as exc:
+         raise UnsupportedOpError("Elu alpha must be numeric") from exc
+     output_shape = value_shape(graph, node.outputs[0], node)
+     return UnaryOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         function=ScalarFunction.ELU,
+         params=(alpha,),
+     )
+
+
+ @register_lowering("LeakyRelu")
+ def lower_leaky_relu(graph: Graph, node: Node) -> UnaryOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("LeakyRelu must have 1 input and 1 output")
+     dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     if not dtype.is_float:
+         raise UnsupportedOpError("LeakyRelu only supports floating-point inputs")
+     for key in node.attrs:
+         if key != "alpha":
+             raise UnsupportedOpError(
+                 f"LeakyRelu does not support attribute {key}"
+             )
+     try:
+         alpha = float(node.attrs.get("alpha", 0.01))
+     except (TypeError, ValueError) as exc:
+         raise UnsupportedOpError("LeakyRelu alpha must be numeric") from exc
+     output_shape = value_shape(graph, node.outputs[0], node)
+     return UnaryOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         function=ScalarFunction.LEAKY_RELU,
          params=(alpha,),
      )

@@ -102,9 +169,6 @@ def lower_swish(graph: Graph, node: Node) -> UnaryOp:
          input0=node.inputs[0],
          output=node.outputs[0],
          function=ScalarFunction.SWISH,
-         shape=output_shape,
-         dtype=dtype,
-         input_dtype=dtype,
          params=(alpha,),
      )

@@ -123,13 +187,50 @@ def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
          input0=node.inputs[0],
          output=node.outputs[0],
          function=ScalarFunction.SHRINK,
-         shape=output_shape,
-         dtype=dtype,
-         input_dtype=dtype,
          params=(bias, lambd),
      )


+ @register_lowering("Pow")
+ def lower_pow(graph: Graph, node: Node) -> PowOp:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Pow must have 2 inputs and 1 output")
+     op_dtype = value_dtype(graph, node.inputs[0], node)
+     op_spec = binary_op_symbol(ScalarFunction.POW, node.attrs, dtype=op_dtype)
+     if op_spec is None:
+         raise UnsupportedOpError("Unsupported op Pow")
+     return PowOp(
+         input0=node.inputs[0],
+         input1=node.inputs[1],
+         output=node.outputs[0],
+         function=ScalarFunction.POW,
+         operator_kind=op_spec.kind,
+     )
+
+
+ def _infer_binary_output_shape(
+     *,
+     function: ScalarFunction,
+     input0_shape: tuple[int, ...],
+     input1_shape: tuple[int, ...],
+ ) -> tuple[int, ...]:
+     if function != ScalarFunction.PRELU:
+         return BroadcastingOpBase.broadcast_shapes(input0_shape, input1_shape)
+     if BroadcastingOpBase.unidirectional_broadcastable(
+         input1_shape, input0_shape
+     ):
+         return input0_shape
+     channel_axis = BroadcastingOpBase.prelu_channel_axis(
+         input0_shape, input1_shape
+     )
+     if channel_axis is None:
+         raise ShapeInferenceError(
+             "Broadcasting mismatch for shapes: "
+             + ", ".join(str(shape) for shape in (input0_shape, input1_shape))
+         )
+     return input0_shape
+
+
  def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | UnaryOp:
      if node.op_type == "BitShift":
          if len(node.inputs) != 2 or len(node.outputs) != 1:
@@ -163,11 +264,6 @@ def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | U
              output=node.outputs[0],
              function=function,
              operator_kind=op_spec.kind,
-             input0_shape=input0_shape,
-             input1_shape=input1_shape,
-             shape=output_shape,
-             dtype=op_dtype,
-             input_dtype=op_dtype,
          )
      if node.op_type == "Mod":
          fmod = int(node.attrs.get("fmod", 0))
@@ -201,18 +297,21 @@ def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | U
          input0_shape = value_shape(graph, node.inputs[0], node)
          input1_shape = value_shape(graph, node.inputs[1], node)
          output_shape = value_shape(graph, node.outputs[0], node)
-         return BinaryOp(
+         op = BinaryOp(
              input0=node.inputs[0],
              input1=node.inputs[1],
              output=node.outputs[0],
              function=function,
              operator_kind=op_spec.kind,
-             input0_shape=input0_shape,
-             input1_shape=input1_shape,
-             shape=output_shape,
-             dtype=output_dtype,
-             input_dtype=input_dtype,
          )
+         if isinstance(graph, GraphContext):
+             inferred_shape = _infer_binary_output_shape(
+                 function=function,
+                 input0_shape=input0_shape,
+                 input1_shape=input1_shape,
+             )
+             graph.set_shape(node.outputs[0], inferred_shape)
+         return op
      op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
      op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
      unary_symbol = unary_op_symbol(function, dtype=op_dtype)
@@ -226,32 +325,36 @@ def _lower_binary_unary(graph: Graph | GraphContext, node: Node) -> BinaryOp | U
          input0_shape = value_shape(graph, node.inputs[0], node)
          input1_shape = value_shape(graph, node.inputs[1], node)
          output_shape = value_shape(graph, node.outputs[0], node)
-         return BinaryOp(
+         op = BinaryOp(
              input0=node.inputs[0],
              input1=node.inputs[1],
              output=node.outputs[0],
              function=function,
              operator_kind=op_spec.kind,
-             input0_shape=input0_shape,
-             input1_shape=input1_shape,
-             shape=output_shape,
-             dtype=op_dtype,
-             input_dtype=op_dtype,
          )
+         if isinstance(graph, GraphContext):
+             inferred_shape = _infer_binary_output_shape(
+                 function=function,
+                 input0_shape=input0_shape,
+                 input1_shape=input1_shape,
+             )
+             graph.set_shape(node.outputs[0], inferred_shape)
+         return op
      if len(node.inputs) != 1 or len(node.outputs) != 1:
          raise UnsupportedOpError(
              f"{node.op_type} must have 1 input and 1 output"
          )
      output_shape = value_shape(graph, node.outputs[0], node)
-     return UnaryOp(
+     op = UnaryOp(
          input0=node.inputs[0],
          output=node.outputs[0],
          function=function,
-         shape=output_shape,
-         dtype=op_dtype,
-         input_dtype=op_dtype,
          params=(),
      )
+     if isinstance(graph, GraphContext):
+         inferred_shape = value_shape(graph, node.inputs[0], node)
+         graph.set_shape(node.outputs[0], inferred_shape)
+     return op


  _DEFAULT_ELEMENTWISE_TYPES = (
@@ -283,9 +386,6 @@ def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
          input0=node.inputs[0],
          output=node.outputs[0],
          function=ScalarFunction.ISINF,
-         shape=output_shape,
-         dtype=output_dtype,
-         input_dtype=input_dtype,
          params=(float(detect_negative), float(detect_positive)),
      )

@@ -305,8 +405,5 @@ def lower_isnan(graph: Graph, node: Node) -> UnaryOp:
          input0=node.inputs[0],
          output=node.outputs[0],
          function=ScalarFunction.ISNAN,
-         shape=output_shape,
-         dtype=output_dtype,
-         input_dtype=input_dtype,
          params=(),
      )
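
For every binary op other than PRelu, the new _infer_binary_output_shape helper defers to BroadcastingOpBase.broadcast_shapes, which presumably follows the multidirectional (numpy-style) broadcasting rule the ONNX spec prescribes for these ops. A minimal standalone sketch of that rule (not the package's BroadcastingOpBase):

# Numpy-style multidirectional broadcasting: align shapes from the trailing
# dimension, treat missing dimensions as 1, and require each pair to be
# equal or contain a 1.
def broadcast_shapes(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    a_padded = (1,) * (len(b) - len(a)) + a
    b_padded = (1,) * (len(a) - len(b)) + b
    result: list[int] = []
    for da, db in zip(a_padded, b_padded):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {a} with {b}")
        result.append(max(da, db))
    return tuple(result)

assert broadcast_shapes((2, 1, 4), (3, 1)) == (2, 3, 4)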

emx_onnx_cgen/lowering/gather.py
@@ -1,8 +1,11 @@
  from __future__ import annotations

- from ..ir.ops import GatherOp
  from ..errors import UnsupportedOpError
+ from ..ir.context import GraphContext
  from ..ir.model import Graph, Node
+ from ..ir.ops import GatherOp
+ from ..lowering.common import value_shape
+ from ..validation import normalize_axis
  from .registry import register_lowering


@@ -11,9 +14,15 @@ def lower_gather(graph: Graph, node: Node) -> GatherOp:
      if len(node.inputs) != 2 or len(node.outputs) != 1:
          raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
      data_name, indices_name = node.inputs
+     data_shape = value_shape(graph, data_name, node)
+     indices_shape = value_shape(graph, indices_name, node)
+     axis = normalize_axis(int(node.attrs.get("axis", 0)), data_shape, node)
+     output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1 :]
+     if isinstance(graph, GraphContext):
+         graph.set_shape(node.outputs[0], output_shape)
      return GatherOp(
          data=data_name,
          indices=indices_name,
          output=node.outputs[0],
-         axis=int(node.attrs.get("axis", 0)),
+         axis=axis,
      )
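
The output shape registered here follows the ONNX Gather rule: the gathered axis of the data shape is replaced by the full indices shape. A quick standalone check of that rule with numpy (hypothetical shapes; numpy is only used for the comparison):

# output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]
# numpy.take implements the same semantics, so the shapes should agree.
import numpy as np

data = np.zeros((5, 4, 3))
indices = np.zeros((2, 6), dtype=np.int64)
axis = 1
expected = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
assert np.take(data, indices, axis=axis).shape == expected == (5, 2, 6, 3)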

emx_onnx_cgen/lowering/gemm.py
@@ -1,139 +1,22 @@
  from __future__ import annotations

- from dataclasses import dataclass
-
- from shared.scalar_types import ScalarType
-
  from ..ir.ops import GemmOp
- from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..errors import UnsupportedOpError
  from ..ir.model import Graph, Node
- from .common import node_dtype as _node_dtype
- from .common import value_shape as _value_shape
  from .registry import register_lowering


- @dataclass(frozen=True)
- class GemmSpec:
-     m: int
-     n: int
-     k: int
-     alpha: float | int
-     beta: float | int
-     trans_a: bool
-     trans_b: bool
-     c_shape: tuple[int, ...] | None
-
-
- def resolve_gemm_spec(graph: Graph, node: Node, dtype: ScalarType) -> GemmSpec:
-     if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
-         raise UnsupportedOpError("Gemm must have 2 or 3 inputs and 1 output")
-     alpha, beta, trans_a, trans_b = _resolve_gemm_attrs(node, dtype)
-     input0_shape = _value_shape(graph, node.inputs[0], node)
-     input1_shape = _value_shape(graph, node.inputs[1], node)
-     if len(input0_shape) != 2 or len(input1_shape) != 2:
-         raise UnsupportedOpError(
-             "Gemm supports 2D inputs only, "
-             f"got {input0_shape} x {input1_shape}"
-         )
-     if trans_a:
-         m, k_left = input0_shape[1], input0_shape[0]
-     else:
-         m, k_left = input0_shape
-     if trans_b:
-         n, k_right = input1_shape[0], input1_shape[1]
-     else:
-         k_right, n = input1_shape
-     if k_left != k_right:
-         raise ShapeInferenceError(
-             f"Gemm inner dimensions must match, got {k_left} and {k_right}"
-         )
-     output_shape = _value_shape(graph, node.outputs[0], node)
-     if output_shape != (m, n):
-         raise ShapeInferenceError(
-             f"Gemm output shape must be {(m, n)}, got {output_shape}"
-         )
-     c_shape = None
-     if len(node.inputs) == 3:
-         bias_shape = _value_shape(graph, node.inputs[2], node)
-         c_shape = validate_gemm_bias_shape((m, n), bias_shape, node)
-     return GemmSpec(
-         m=m,
-         n=n,
-         k=k_left,
-         alpha=alpha,
-         beta=beta,
-         trans_a=trans_a,
-         trans_b=trans_b,
-         c_shape=c_shape,
-     )
-
-
- def _resolve_gemm_attrs(
-     node: Node, dtype: ScalarType
- ) -> tuple[float | int, float | int, bool, bool]:
-     alpha = float(node.attrs.get("alpha", 1.0))
-     beta = float(node.attrs.get("beta", 1.0))
-     trans_a = int(node.attrs.get("transA", 0))
-     trans_b = int(node.attrs.get("transB", 0))
-     if trans_a not in {0, 1} or trans_b not in {0, 1}:
-         raise UnsupportedOpError(
-             "Gemm only supports transA/transB values of 0 or 1"
-         )
-     if dtype == ScalarType.BOOL:
-         raise UnsupportedOpError("Gemm supports numeric inputs only")
-     if not dtype.is_float:
-         alpha_int = int(alpha)
-         beta_int = int(beta)
-         if alpha != alpha_int or beta != beta_int:
-             raise UnsupportedOpError(
-                 "Gemm alpha and beta must be integers for non-float inputs"
-             )
-         alpha = alpha_int
-         beta = beta_int
-     return alpha, beta, bool(trans_a), bool(trans_b)
-
-
- def validate_gemm_bias_shape(
-     output_shape: tuple[int, int], bias_shape: tuple[int, ...], node: Node
- ) -> tuple[int, ...]:
-     if len(bias_shape) == 0:
-         return bias_shape
-     if len(bias_shape) == 1:
-         if bias_shape[0] not in {1, output_shape[1]}:
-             raise ShapeInferenceError(
-                 "Gemm bias input must be broadcastable to output shape, "
-                 f"got {bias_shape} vs {output_shape}"
-             )
-         return bias_shape
-     if len(bias_shape) == 2:
-         m, n = output_shape
-         if bias_shape[0] not in {1, m} or bias_shape[1] not in {1, n}:
-             raise ShapeInferenceError(
-                 "Gemm bias input must be broadcastable to output shape, "
-                 f"got {bias_shape} vs {output_shape}"
-             )
-         return bias_shape
-     raise ShapeInferenceError(
-         f"Gemm bias input must be rank 1 or 2, got {bias_shape}"
-     )
-
-
  @register_lowering("Gemm")
  def lower_gemm(graph: Graph, node: Node) -> GemmOp:
-     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-     spec = resolve_gemm_spec(graph, node, op_dtype)
+     if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+         raise UnsupportedOpError("Gemm must have 2 or 3 inputs and 1 output")
      return GemmOp(
          input_a=node.inputs[0],
          input_b=node.inputs[1],
          input_c=node.inputs[2] if len(node.inputs) == 3 else None,
          output=node.outputs[0],
-         m=spec.m,
-         n=spec.n,
-         k=spec.k,
-         trans_a=spec.trans_a,
-         trans_b=spec.trans_b,
-         alpha=spec.alpha,
-         beta=spec.beta,
-         c_shape=spec.c_shape,
-         dtype=op_dtype,
+         alpha=float(node.attrs.get("alpha", 1.0)),
+         beta=float(node.attrs.get("beta", 1.0)),
+         trans_a=int(node.attrs.get("transA", 0)),
+         trans_b=int(node.attrs.get("transB", 0)),
      )
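
The lowering now forwards the raw Gemm attributes instead of resolving a GemmSpec up front; the ONNX Gemm semantics stay Y = alpha * A' @ B' + beta * C, where A' and B' are the optionally transposed inputs. A numpy reference sketch of those semantics (illustrative only, not the generated C kernel; the helper name is hypothetical):

# Reference behaviour of the alpha/beta/transA/transB attributes the
# lowering now passes through unchanged.
import numpy as np

def gemm_reference(a, b, c=None, *, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    a = a.T if trans_a else a
    b = b.T if trans_b else b
    y = alpha * (a @ b)
    if c is not None:
        y = y + beta * c  # C broadcasts against the (M, N) product
    return y

a = np.ones((3, 2))   # transA=1 makes the effective shape (2, 3)
b = np.ones((3, 4))
c = np.arange(4.0)
assert gemm_reference(a, b, c, alpha=2.0, beta=0.5, trans_a=1).shape == (2, 4)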

emx_onnx_cgen/lowering/global_max_pool.py
@@ -45,15 +45,10 @@ def lower_global_max_pool(graph: Graph, node: Node) -> ReduceOp:
      return ReduceOp(
          input0=node.inputs[0],
          output=node.outputs[0],
-         input_shape=input_shape,
-         output_shape=output_shape,
          axes=axes,
          axes_input=None,
-         axes_input_shape=None,
-         axes_input_dtype=None,
          keepdims=True,
          noop_with_empty_axes=False,
          reduce_kind="max",
          reduce_count=None,
-         dtype=op_dtype,
      )
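
GlobalMaxPool continues to be lowered as a keepdims max-reduction; this hunk only removes shape and dtype arguments from the ReduceOp construction. A small numpy sketch of the equivalent reduction (illustrative only):

# GlobalMaxPool over an NCHW tensor is a max over the trailing spatial axes
# with keepdims=True, matching the ReduceOp fields kept above.
import numpy as np

x = np.random.rand(1, 3, 8, 8)           # NCHW input
spatial_axes = tuple(range(2, x.ndim))    # (2, 3)
y = x.max(axis=spatial_axes, keepdims=True)
assert y.shape == (1, 3, 1, 1)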