emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0

emx_onnx_cgen/ir/ops/nn.py

@@ -6,7 +6,7 @@ from enum import Enum
 from shared.scalar_functions import ScalarFunction
 from shared.scalar_types import ScalarType
 
-from ...errors import ShapeInferenceError
+from ...errors import ShapeInferenceError, UnsupportedOpError
 from ..op_base import ConvLikeOpBase, GemmLikeOpBase, MatMulLikeOpBase, RenderableOpBase
 from ..op_context import OpContext
 
@@ -29,23 +29,117 @@ def _shape_product(shape: tuple[int, ...]) -> int:
     return product
 
 
+def _broadcast_batch_shapes(
+    left: tuple[int, ...], right: tuple[int, ...]
+) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+    max_rank = max(len(left), len(right))
+    left_padded = (1,) * (max_rank - len(left)) + left
+    right_padded = (1,) * (max_rank - len(right)) + right
+    broadcast_shape: list[int] = []
+    for left_dim, right_dim in zip(left_padded, right_padded):
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            raise ShapeInferenceError(
+                "MatMul batch dimensions must be broadcastable, "
+                f"got {left} x {right}"
+            )
+        broadcast_shape.append(max(left_dim, right_dim))
+    return tuple(broadcast_shape), left_padded, right_padded
+
+
+def _resolve_matmul_spec(
+    ctx: OpContext, input0: str, input1: str
+) -> dict[str, object]:
+    input0_shape = ctx.shape(input0)
+    input1_shape = ctx.shape(input1)
+    if len(input0_shape) < 1 or len(input1_shape) < 1:
+        raise UnsupportedOpError(
+            "MatMul inputs must be at least 1D, "
+            f"got {input0_shape} x {input1_shape}"
+        )
+    left_vector = len(input0_shape) == 1
+    right_vector = len(input1_shape) == 1
+    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
+    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
+    m, k_left = input0_effective[-2], input0_effective[-1]
+    k_right, n = input1_effective[-2], input1_effective[-1]
+    if k_left != k_right:
+        raise ShapeInferenceError(
+            f"MatMul inner dimensions must match, got {k_left} and {k_right}"
+        )
+    batch_shape, input0_batch_shape, input1_batch_shape = (
+        _broadcast_batch_shapes(
+            input0_effective[:-2],
+            input1_effective[:-2],
+        )
+    )
+    if left_vector and right_vector:
+        output_shape = batch_shape
+    elif left_vector:
+        output_shape = batch_shape + (n,)
+    elif right_vector:
+        output_shape = batch_shape + (m,)
+    else:
+        output_shape = batch_shape + (m, n)
+    return {
+        "input0_shape": input0_shape,
+        "input1_shape": input1_shape,
+        "output_shape": output_shape,
+        "batch_shape": batch_shape,
+        "input0_batch_shape": input0_batch_shape,
+        "input1_batch_shape": input1_batch_shape,
+        "m": m,
+        "n": n,
+        "k": k_left,
+        "left_vector": left_vector,
+        "right_vector": right_vector,
+    }
+
+
 @dataclass(frozen=True)
 class MatMulOp(MatMulLikeOpBase):
     input0: str
     input1: str
     output: str
-    input0_shape: tuple[int, ...]
-    input1_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    batch_shape: tuple[int, ...]
-    input0_batch_shape: tuple[int, ...]
-    input1_batch_shape: tuple[int, ...]
-    m: int
-    n: int
-    k: int
-    left_vector: bool
-    right_vector: bool
-    dtype: ScalarType
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input0_dtype = ctx.dtype(self.input0)
+        input1_dtype = ctx.dtype(self.input1)
+        if input0_dtype != input1_dtype:
+            raise UnsupportedOpError(
+                "MatMul expects matching input dtypes, "
+                f"got {input0_dtype.onnx_name} and {input1_dtype.onnx_name}"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input0_dtype)
+            output_dtype = input0_dtype
+        if output_dtype != input0_dtype:
+            raise UnsupportedOpError(
+                "MatMul expects output dtype to match inputs, "
+                f"got {output_dtype.onnx_name} and {input0_dtype.onnx_name}"
+            )
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        spec = _resolve_matmul_spec(ctx, self.input0, self.input1)
+        output_shape = spec["output_shape"]
+        try:
+            expected = ctx.shape(self.output)
+        except ShapeInferenceError:
+            expected = None
+        if expected is not None and expected != output_shape:
+            raise ShapeInferenceError(
+                f"MatMul output shape must be {output_shape}, got {expected}"
+            )
+        ctx.set_shape(self.output, output_shape)
+        ctx.set_derived(self, "batch_shape", spec["batch_shape"])
+        ctx.set_derived(self, "input0_batch_shape", spec["input0_batch_shape"])
+        ctx.set_derived(self, "input1_batch_shape", spec["input1_batch_shape"])
+        ctx.set_derived(self, "m", spec["m"])
+        ctx.set_derived(self, "n", spec["n"])
+        ctx.set_derived(self, "k", spec["k"])
+        ctx.set_derived(self, "left_vector", spec["left_vector"])
+        ctx.set_derived(self, "right_vector", spec["right_vector"])
 
 @dataclass(frozen=True)
 class QLinearMatMulOp(MatMulLikeOpBase):
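
The new _resolve_matmul_spec helper derives output shapes from ONNX/NumPy-style MatMul semantics: 1-D operands are promoted to matrices, batch dimensions are broadcast, and the promoted dimension is dropped again. A minimal NumPy sketch of those rules (illustrative only, not package code):

```python
import numpy as np

# Batch dims (2, 1) and (5,) broadcast to (2, 5); M=4, K=3, N=6.
a = np.zeros((2, 1, 4, 3))
b = np.zeros((5, 3, 6))
assert np.matmul(a, b).shape == (2, 5, 4, 6)

# A 1-D left operand is treated as (1, K); the prepended dim is dropped afterwards.
v = np.zeros(3)
assert np.matmul(v, b).shape == (5, 6)
```
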
@@ -98,15 +192,139 @@ class GemmOp(GemmLikeOpBase):
     input_b: str
     input_c: str | None
     output: str
-    m: int
-    n: int
-    k: int
-    trans_a: bool
-    trans_b: bool
+    trans_a: int
+    trans_b: int
     alpha: float | int
     beta: float | int
-    c_shape: tuple[int, ...] | None
-    dtype: ScalarType
+
+    @staticmethod
+    def _normalize_attrs(
+        dtype: ScalarType,
+        *,
+        alpha: float | int,
+        beta: float | int,
+        trans_a: int,
+        trans_b: int,
+    ) -> tuple[float | int, float | int, bool, bool]:
+        if trans_a not in {0, 1} or trans_b not in {0, 1}:
+            raise UnsupportedOpError(
+                "Gemm only supports transA/transB values of 0 or 1"
+            )
+        if dtype == ScalarType.BOOL:
+            raise UnsupportedOpError("Gemm supports numeric inputs only")
+        if not dtype.is_float:
+            alpha_int = int(alpha)
+            beta_int = int(beta)
+            if alpha != alpha_int or beta != beta_int:
+                raise UnsupportedOpError(
+                    "Gemm alpha and beta must be integers for non-float inputs"
+                )
+            alpha = alpha_int
+            beta = beta_int
+        return alpha, beta, bool(trans_a), bool(trans_b)
+
+    @staticmethod
+    def _validate_bias_shape(
+        output_shape: tuple[int, int], bias_shape: tuple[int, ...]
+    ) -> tuple[int, ...]:
+        if len(bias_shape) == 0:
+            return bias_shape
+        if len(bias_shape) == 1:
+            if bias_shape[0] not in {1, output_shape[1]}:
+                raise ShapeInferenceError(
+                    "Gemm bias input must be broadcastable to output shape, "
+                    f"got {bias_shape} vs {output_shape}"
+                )
+            return bias_shape
+        if len(bias_shape) == 2:
+            m, n = output_shape
+            if bias_shape[0] not in {1, m} or bias_shape[1] not in {1, n}:
+                raise ShapeInferenceError(
+                    "Gemm bias input must be broadcastable to output shape, "
+                    f"got {bias_shape} vs {output_shape}"
+                )
+            return bias_shape
+        raise ShapeInferenceError(
+            f"Gemm bias input must be rank 1 or 2, got {bias_shape}"
+        )
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_a_dtype = ctx.dtype(self.input_a)
+        input_b_dtype = ctx.dtype(self.input_b)
+        if input_a_dtype != input_b_dtype:
+            raise UnsupportedOpError(
+                "Gemm expects matching input dtypes, "
+                f"got {input_a_dtype.onnx_name} and {input_b_dtype.onnx_name}"
+            )
+        if self.input_c is not None:
+            input_c_dtype = ctx.dtype(self.input_c)
+            if input_c_dtype != input_a_dtype:
+                raise UnsupportedOpError(
+                    "Gemm expects bias dtype to match inputs, "
+                    f"got {input_c_dtype.onnx_name} and {input_a_dtype.onnx_name}"
+                )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_a_dtype)
+            output_dtype = input_a_dtype
+        if output_dtype != input_a_dtype:
+            raise UnsupportedOpError(
+                "Gemm expects output dtype to match inputs, "
+                f"got {output_dtype.onnx_name} and {input_a_dtype.onnx_name}"
+            )
+        alpha, beta, trans_a, trans_b = self._normalize_attrs(
+            output_dtype,
+            alpha=self.alpha,
+            beta=self.beta,
+            trans_a=self.trans_a,
+            trans_b=self.trans_b,
+        )
+        ctx.set_derived(self, "alpha", alpha)
+        ctx.set_derived(self, "beta", beta)
+        ctx.set_derived(self, "trans_a", trans_a)
+        ctx.set_derived(self, "trans_b", trans_b)
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        trans_a = ctx.require_derived(self, "trans_a")
+        trans_b = ctx.require_derived(self, "trans_b")
+        input_a_shape = ctx.shape(self.input_a)
+        input_b_shape = ctx.shape(self.input_b)
+        if len(input_a_shape) != 2 or len(input_b_shape) != 2:
+            raise UnsupportedOpError(
+                "Gemm supports 2D inputs only, "
+                f"got {input_a_shape} x {input_b_shape}"
+            )
+        if trans_a:
+            m, k_left = input_a_shape[1], input_a_shape[0]
+        else:
+            m, k_left = input_a_shape
+        if trans_b:
+            n, k_right = input_b_shape[0], input_b_shape[1]
+        else:
+            k_right, n = input_b_shape
+        if k_left != k_right:
+            raise ShapeInferenceError(
+                f"Gemm inner dimensions must match, got {k_left} and {k_right}"
+            )
+        output_shape = (m, n)
+        try:
+            expected = ctx.shape(self.output)
+        except ShapeInferenceError:
+            expected = None
+        if expected is not None and expected != output_shape:
+            raise ShapeInferenceError(
+                f"Gemm output shape must be {output_shape}, got {expected}"
+            )
+        ctx.set_shape(self.output, output_shape)
+        c_shape = None
+        if self.input_c is not None:
+            bias_shape = ctx.shape(self.input_c)
+            c_shape = self._validate_bias_shape(output_shape, bias_shape)
+        ctx.set_derived(self, "m", m)
+        ctx.set_derived(self, "n", n)
+        ctx.set_derived(self, "k", k_left)
+        ctx.set_derived(self, "c_shape", c_shape)
 
 @dataclass(frozen=True)
 class AttentionOp(RenderableOpBase):
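
GemmOp now validates transA/transB and derives (M, N, K) during inference instead of carrying them as constructor fields. For reference, the 2-D shape rule it encodes, sketched with NumPy (illustrative only, not package code):

```python
import numpy as np

# Y = alpha * A' @ B' + beta * C, with A' = A.T when transA=1 and B' = B.T when transB=1.
A = np.zeros((3, 4))      # transA=1 -> effective (4, 3): M=4, K=3
B = np.zeros((5, 3))      # transB=1 -> effective (3, 5): K=3, N=5
Y = A.T @ B.T
assert Y.shape == (4, 5)  # output is always (M, N); bias C must broadcast to it
```
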
@@ -205,6 +423,31 @@ class ConvOp(ConvLikeOpBase):
             raise ValueError("Conv output width is undefined for spatial_rank < 2")
         return self.out_spatial[1]
 
+@dataclass(frozen=True)
+class ConvIntegerOp(ConvLikeOpBase):
+    input0: str
+    weights: str
+    x_zero_point: str | None
+    w_zero_point: str | None
+    output: str
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    group: int
+    input_dtype: ScalarType
+    weight_dtype: ScalarType
+    dtype: ScalarType
+    x_zero_point_shape: tuple[int, ...] | None
+    w_zero_point_shape: tuple[int, ...] | None
+    w_zero_point_per_channel: bool
+
 @dataclass(frozen=True)
 class ConvTransposeOp(ConvLikeOpBase):
     input0: str
@@ -237,6 +480,8 @@ class AveragePoolOp(RenderableOpBase):
     out_w: int
     kernel_h: int
     kernel_w: int
+    dilation_h: int
+    dilation_w: int
     stride_h: int
     stride_w: int
     pad_top: int
@@ -245,6 +490,14 @@ class AveragePoolOp(RenderableOpBase):
     pad_right: int
     count_include_pad: bool
     dtype: ScalarType
+    spatial_rank: int = 2
+    in_d: int = 1
+    out_d: int = 1
+    kernel_d: int = 1
+    dilation_d: int = 1
+    stride_d: int = 1
+    pad_front: int = 0
+    pad_back: int = 0
 
 @dataclass(frozen=True)
 class LpPoolOp(RenderableOpBase):
@@ -258,6 +511,8 @@ class LpPoolOp(RenderableOpBase):
     out_w: int
     kernel_h: int
     kernel_w: int
+    dilation_h: int
+    dilation_w: int
     stride_h: int
     stride_w: int
     pad_top: int
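
AveragePoolOp and LpPoolOp gain explicit dilation fields, and AveragePoolOp adds optional depth fields for 3-D pooling. For orientation, the usual ONNX floor-mode pooling output size as a function of kernel, stride, padding, and dilation, sketched here as an assumption rather than the package's own code:

```python
def pool_out_dim(in_dim: int, kernel: int, stride: int,
                 pad_begin: int, pad_end: int, dilation: int = 1) -> int:
    # floor((in + pads - dilation*(kernel-1) - 1) / stride) + 1 with ceil_mode=0
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_dim + pad_begin + pad_end - effective_kernel) // stride + 1

# 32-wide input, 3-tap kernel, stride 2, no padding, dilation 2 -> effective kernel 5.
assert pool_out_dim(32, kernel=3, stride=2, pad_begin=0, pad_end=0, dilation=2) == 14
```
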
@@ -271,16 +526,29 @@ class LpPoolOp(RenderableOpBase):
 class SoftmaxOp(RenderableOpBase):
     input0: str
     output: str
-    outer: int
-    axis_size: int
-    inner: int
-    axis: int
-    shape: tuple[int, ...]
-    dtype: ScalarType
+    axis: int | None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_dtype = ctx.dtype(self.input0)
+        if not input_dtype.is_float:
+            raise UnsupportedOpError(
+                "Softmax supports float16, float, and double inputs only"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_dtype)
+            return None
+        if output_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Softmax expects output dtype to match input dtype"
+            )
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_shape = ctx.shape(self.input0)
         axis = self.axis
+        if axis is None:
+            axis = -1
         if axis < 0:
             axis += len(input_shape)
         if axis < 0 or axis >= len(input_shape):
@@ -295,6 +563,7 @@ class SoftmaxOp(RenderableOpBase):
             else 1
         )
         ctx.set_shape(self.output, input_shape)
+        ctx.set_derived(self, "axis", axis)
         ctx.set_derived(self, "outer", outer)
         ctx.set_derived(self, "axis_size", axis_size)
         ctx.set_derived(self, "inner", inner)
@@ -303,16 +572,29 @@ class SoftmaxOp(RenderableOpBase):
 class LogSoftmaxOp(RenderableOpBase):
     input0: str
     output: str
-    outer: int
-    axis_size: int
-    inner: int
-    axis: int
-    shape: tuple[int, ...]
-    dtype: ScalarType
+    axis: int | None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_dtype = ctx.dtype(self.input0)
+        if not input_dtype.is_float:
+            raise UnsupportedOpError(
+                "LogSoftmax supports float16, float, and double inputs only"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_dtype)
+            return None
+        if output_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "LogSoftmax expects output dtype to match input dtype"
+            )
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_shape = ctx.shape(self.input0)
         axis = self.axis
+        if axis is None:
+            axis = -1
         if axis < 0:
             axis += len(input_shape)
         if axis < 0 or axis >= len(input_shape):
@@ -327,6 +609,7 @@ class LogSoftmaxOp(RenderableOpBase):
             else 1
         )
         ctx.set_shape(self.output, input_shape)
+        ctx.set_derived(self, "axis", axis)
         ctx.set_derived(self, "outer", outer)
         ctx.set_derived(self, "axis_size", axis_size)
         ctx.set_derived(self, "inner", inner)
@@ -335,16 +618,30 @@ class LogSoftmaxOp(RenderableOpBase):
 class HardmaxOp(RenderableOpBase):
     input0: str
     output: str
-    outer: int
-    axis_size: int
-    inner: int
-    axis: int
-    shape: tuple[int, ...]
-    dtype: ScalarType
+    axis: int | None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        input_dtype = ctx.dtype(self.input0)
+        if input_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+            raise UnsupportedOpError(
+                "Hardmax supports float16, float, and double inputs only"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, input_dtype)
+            return None
+        if output_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Hardmax expects output dtype to match input dtype"
+            )
 
     def infer_shapes(self, ctx: OpContext) -> None:
         input_shape = ctx.shape(self.input0)
         axis = self.axis
+        if axis is None:
+            opset_version = ctx.opset_version()
+            axis = 1 if opset_version is not None and opset_version < 13 else -1
         if axis < 0:
             axis += len(input_shape)
         if axis < 0 or axis >= len(input_shape):
@@ -359,6 +656,7 @@ class HardmaxOp(RenderableOpBase):
             else 1
         )
         ctx.set_shape(self.output, input_shape)
+        ctx.set_derived(self, "axis", axis)
         ctx.set_derived(self, "outer", outer)
         ctx.set_derived(self, "axis_size", axis_size)
         ctx.set_derived(self, "inner", inner)
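
Softmax, LogSoftmax, and Hardmax now take an optional axis, normalise it during shape inference (defaulting to -1, or to 1 for Hardmax under opset < 13), and derive the outer/axis_size/inner factorisation instead of storing it up front. A quick illustration of that factorisation (assumed semantics, not package code):

```python
shape = (2, 3, 4)
axis = -1                      # default when the attribute is omitted
axis += len(shape)             # normalise the negative axis to 2

outer = 1
for dim in shape[:axis]:
    outer *= dim               # product of dims before the axis -> 6
axis_size = shape[axis]        # 4
inner = 1
for dim in shape[axis + 1:]:
    inner *= dim               # product of dims after the axis -> 1

assert (outer, axis_size, inner) == (6, 4, 1)
```
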
@@ -512,6 +810,31 @@ class LrnOp(RenderableOpBase):
     bias: float
     dtype: ScalarType
 
+@dataclass(frozen=True)
+class GruOp(RenderableOpBase):
+    input_x: str
+    input_w: str
+    input_r: str
+    input_b: str | None
+    input_sequence_lens: str | None
+    input_initial_h: str | None
+    output_y: str | None
+    output_y_h: str | None
+    seq_length: int
+    batch_size: int
+    input_size: int
+    hidden_size: int
+    num_directions: int
+    direction: str
+    layout: int
+    linear_before_reset: int
+    clip: float | None
+    activation_kinds: tuple[int, ...]
+    activation_alphas: tuple[float, ...]
+    activation_betas: tuple[float, ...]
+    dtype: ScalarType
+    sequence_lens_dtype: ScalarType | None
+
 @dataclass(frozen=True)
 class LstmOp(RenderableOpBase):
     input_x: str
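
The new GruOp dataclass backs the GRU lowering added in emx_onnx_cgen/lowering/gru.py and the gru_op.c.j2 template. For orientation, the tensor shapes these fields describe under the default layout=0, per the ONNX GRU specification (illustrative values, not package code):

```python
# num_directions is 2 when direction == "bidirectional", otherwise 1.
seq_length, batch_size, input_size, hidden_size, num_directions = 10, 4, 16, 32, 1

x_shape   = (seq_length, batch_size, input_size)                   # input_x
w_shape   = (num_directions, 3 * hidden_size, input_size)          # input_w (z, r, h gates)
r_shape   = (num_directions, 3 * hidden_size, hidden_size)         # input_r
y_shape   = (seq_length, num_directions, batch_size, hidden_size)  # output_y
y_h_shape = (num_directions, batch_size, hidden_size)              # output_y_h
```
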

emx_onnx_cgen/ir/ops/reduce.py

@@ -2,8 +2,6 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from shared.scalar_types import ScalarType
-
 from ..op_base import ReduceOpBase
 from ..op_context import OpContext
 
@@ -12,17 +10,12 @@ from ..op_context import OpContext
 class ReduceOp(ReduceOpBase):
     input0: str
     output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
     axes: tuple[int, ...]
     axes_input: str | None
-    axes_input_shape: tuple[int, ...] | None
-    axes_input_dtype: ScalarType | None
     keepdims: bool
     noop_with_empty_axes: bool
     reduce_kind: str
     reduce_count: int | None
-    dtype: ScalarType
 
     def infer_types(self, ctx: OpContext) -> None:
         ctx.dtype(self.output)
@@ -45,14 +38,10 @@ class ReduceOp(ReduceOpBase):
 class ArgReduceOp(ReduceOpBase):
     input0: str
     output: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
     axis: int
     keepdims: bool
     select_last_index: bool
     reduce_kind: str
-    input_dtype: ScalarType
-    output_dtype: ScalarType
 
     def infer_types(self, ctx: OpContext) -> None:
         ctx.dtype(self.input0)
@@ -71,17 +60,13 @@ class ArgReduceOp(ReduceOpBase):
 @dataclass(frozen=True)
 class TopKOp(ReduceOpBase):
     input0: str
+    k_input: str
     output_values: str
     output_indices: str
-    input_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
     axis: int
     k: int
     largest: bool
     sorted: bool
-    input_dtype: ScalarType
-    output_values_dtype: ScalarType
-    output_indices_dtype: ScalarType
 
     def infer_types(self, ctx: OpContext) -> None:
         ctx.dtype(self.input0)

emx_onnx_cgen/lowering/__init__.py

@@ -10,13 +10,16 @@ _LOWERING_MODULES = [
     "attention",
     "average_pool",
     "batch_normalization",
+    "bernoulli",
     "cast",
     "concat",
     "constant_of_shape",
     "conv",
+    "conv_integer",
     "conv_transpose",
     "cumsum",
     "depth_space",
+    "dequantize_linear",
     "dropout",
     "einsum",
     "elementwise",
@@ -29,8 +32,10 @@ _LOWERING_MODULES = [
     "gemm",
     "global_max_pool",
     "grid_sample",
+    "gru",
     "group_normalization",
     "hardmax",
+    "hamming_window",
     "identity",
     "instance_normalization",
     "layer_normalization",
@@ -45,9 +50,11 @@ _LOWERING_MODULES = [
     "negative_log_likelihood_loss",
     "non_max_suppression",
     "nonzero",
+    "optional_has_element",
     "one_hot",
     "pad",
     "qlinear_matmul",
+    "qlinear_mul",
     "quantize_linear",
     "range",
     "reduce",
@@ -64,11 +71,13 @@ _LOWERING_MODULES = [
     "split",
     "squeeze",
     "tensor_scatter",
+    "tfidf_vectorizer",
     "tile",
     "topk",
     "transpose",
     "trilu",
     "unsqueeze",
+    "upsample",
     "variadic",
     "where",
 ]

emx_onnx_cgen/lowering/arg_reduce.py

@@ -84,14 +84,10 @@ def lower_arg_reduce(graph: Graph, node: Node) -> ArgReduceOp:
     return ArgReduceOp(
         input0=input_name,
         output=output_name,
-        input_shape=input_shape,
-        output_shape=output_shape,
         axis=axis,
         keepdims=keepdims,
         select_last_index=select_last_index,
         reduce_kind=ARG_REDUCE_KIND_BY_OP[node.op_type],
-        input_dtype=input_dtype,
-        output_dtype=output_dtype,
     )
 