emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/gru.py (new file)
@@ -0,0 +1,323 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Iterable, Sequence
+
+ from shared.scalar_types import ScalarType
+
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import node_dtype, optional_name, value_dtype, value_shape
+ from .registry import register_lowering
+
+ ACTIVATION_KIND_BY_NAME = {
+     "Relu": 0,
+     "Tanh": 1,
+     "Sigmoid": 2,
+     "Affine": 3,
+     "LeakyRelu": 4,
+     "ThresholdedRelu": 5,
+     "ScaledTanh": 6,
+     "HardSigmoid": 7,
+     "Elu": 8,
+     "Softsign": 9,
+     "Softplus": 10,
+ }
+
+ DEFAULT_ACTIVATIONS = ("Sigmoid", "Tanh")
+
+ DEFAULT_ALPHA_BY_NAME = {
+     "Affine": 1.0,
+     "LeakyRelu": 0.01,
+     "ThresholdedRelu": 1.0,
+     "ScaledTanh": 1.0,
+     "HardSigmoid": 0.2,
+     "Elu": 1.0,
+ }
+
+ DEFAULT_BETA_BY_NAME = {
+     "Affine": 0.0,
+     "ScaledTanh": 1.0,
+     "HardSigmoid": 0.5,
+ }
+
+
+ @dataclass(frozen=True)
+ class GruSpec:
+     input_x: str
+     input_w: str
+     input_r: str
+     input_b: str | None
+     input_sequence_lens: str | None
+     input_initial_h: str | None
+     output_y: str | None
+     output_y_h: str | None
+     seq_length: int
+     batch_size: int
+     input_size: int
+     hidden_size: int
+     num_directions: int
+     direction: str
+     layout: int
+     linear_before_reset: int
+     clip: float | None
+     activation_kinds: tuple[int, ...]
+     activation_alphas: tuple[float, ...]
+     activation_betas: tuple[float, ...]
+     dtype: ScalarType
+     sequence_lens_dtype: ScalarType | None
+
+
+ def _normalize_activation_names(values: Iterable[object]) -> list[str]:
+     names: list[str] = []
+     for value in values:
+         if isinstance(value, bytes):
+             value = value.decode("utf-8")
+         if not isinstance(value, str):
+             raise UnsupportedOpError("GRU activations must be strings")
+         names.append(value)
+     return names
+
+
+ def _resolve_activation_params(
+     activations: Sequence[str],
+     activation_alpha: Sequence[float] | None,
+     activation_beta: Sequence[float] | None,
+ ) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
+     if activation_alpha is None:
+         activation_alpha = []
+     if activation_beta is None:
+         activation_beta = []
+     if activation_alpha and len(activation_alpha) != len(activations):
+         raise UnsupportedOpError("GRU activation_alpha must match activations")
+     if activation_beta and len(activation_beta) != len(activations):
+         raise UnsupportedOpError("GRU activation_beta must match activations")
+     activation_kinds: list[int] = []
+     alphas: list[float] = []
+     betas: list[float] = []
+     for idx, name in enumerate(activations):
+         kind = ACTIVATION_KIND_BY_NAME.get(name)
+         if kind is None:
+             raise UnsupportedOpError(f"Unsupported GRU activation {name}")
+         activation_kinds.append(kind)
+         if activation_alpha:
+             alpha = float(activation_alpha[idx])
+         else:
+             alpha = DEFAULT_ALPHA_BY_NAME.get(name, 1.0)
+         if activation_beta:
+             beta = float(activation_beta[idx])
+         else:
+             beta = DEFAULT_BETA_BY_NAME.get(name, 0.0)
+         alphas.append(alpha)
+         betas.append(beta)
+     return tuple(activation_kinds), tuple(alphas), tuple(betas)
+
+
+ def _resolve_activations(
+     num_directions: int, attrs: dict[str, object]
+ ) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
+     activations_attr = attrs.get("activations")
+     if activations_attr is None:
+         activations = list(DEFAULT_ACTIVATIONS)
+     else:
+         activations = _normalize_activation_names(activations_attr)
+     if num_directions == 1:
+         if len(activations) != 2:
+             raise UnsupportedOpError("GRU activations must have length 2")
+     else:
+         if len(activations) == 2:
+             activations = activations * 2
+         elif len(activations) != 4:
+             raise UnsupportedOpError("Bidirectional GRU activations must be length 4")
+     activation_alpha = attrs.get("activation_alpha")
+     activation_beta = attrs.get("activation_beta")
+     return _resolve_activation_params(
+         activations,
+         activation_alpha,
+         activation_beta,
+     )
+
+
+ def _expect_shape(
+     name: str, shape: tuple[int, ...], expected: tuple[int, ...]
+ ) -> None:
+     if shape != expected:
+         raise UnsupportedOpError(
+             f"GRU input {name} must have shape {expected}, got {shape}"
+         )
+
+
+ def _validate_direction(direction: str, num_directions: int) -> None:
+     if direction == "bidirectional" and num_directions != 2:
+         raise UnsupportedOpError(
+             "GRU expects num_directions=2 for bidirectional models"
+         )
+     if direction in {"forward", "reverse"} and num_directions != 1:
+         raise UnsupportedOpError(
+             "GRU expects num_directions=1 for forward/reverse models"
+         )
+     if direction not in {"forward", "reverse", "bidirectional"}:
+         raise UnsupportedOpError(f"Unsupported GRU direction {direction}")
+
+
+ def resolve_gru_spec(graph: Graph, node: Node) -> GruSpec:
+     if len(node.inputs) < 3 or len(node.inputs) > 6:
+         raise UnsupportedOpError("GRU expects between 3 and 6 inputs")
+     if len(node.outputs) < 1 or len(node.outputs) > 2:
+         raise UnsupportedOpError("GRU expects between 1 and 2 outputs")
+     input_x = node.inputs[0]
+     input_w = node.inputs[1]
+     input_r = node.inputs[2]
+     input_b = optional_name(node.inputs, 3)
+     input_sequence_lens = optional_name(node.inputs, 4)
+     input_initial_h = optional_name(node.inputs, 5)
+     output_y = optional_name(node.outputs, 0)
+     output_y_h = optional_name(node.outputs, 1)
+     if output_y is None and output_y_h is None:
+         raise UnsupportedOpError("GRU expects at least one output")
+     op_dtype = node_dtype(
+         graph,
+         node,
+         input_x,
+         input_w,
+         input_r,
+         *(name for name in (input_b, input_initial_h) if name),
+         *(name for name in (output_y, output_y_h) if name),
+     )
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "GRU supports float16, float, and double inputs only"
+         )
+     x_shape = value_shape(graph, input_x, node)
+     if len(x_shape) != 3:
+         raise UnsupportedOpError("GRU input X must be rank 3")
+     layout = int(node.attrs.get("layout", 0))
+     if layout not in {0, 1}:
+         raise UnsupportedOpError("GRU layout must be 0 or 1")
+     if layout == 0:
+         seq_length, batch_size, input_size = x_shape
+     else:
+         batch_size, seq_length, input_size = x_shape
+     w_shape = value_shape(graph, input_w, node)
+     if len(w_shape) != 3:
+         raise UnsupportedOpError("GRU input W must be rank 3")
+     num_directions = w_shape[0]
+     hidden_size_attr = node.attrs.get("hidden_size")
+     if hidden_size_attr is None:
+         if w_shape[1] % 3 != 0:
+             raise UnsupportedOpError("GRU W shape is not divisible by 3")
+         hidden_size = w_shape[1] // 3
+     else:
+         hidden_size = int(hidden_size_attr)
+     direction = str(node.attrs.get("direction", "forward"))
+     _validate_direction(direction, num_directions)
+     expected_w_shape = (num_directions, 3 * hidden_size, input_size)
+     _expect_shape(input_w, w_shape, expected_w_shape)
+     r_shape = value_shape(graph, input_r, node)
+     expected_r_shape = (num_directions, 3 * hidden_size, hidden_size)
+     _expect_shape(input_r, r_shape, expected_r_shape)
+     if input_b is not None:
+         b_shape = value_shape(graph, input_b, node)
+         _expect_shape(input_b, b_shape, (num_directions, 6 * hidden_size))
+     if input_sequence_lens is not None:
+         seq_dtype = value_dtype(graph, input_sequence_lens, node)
+         if seq_dtype not in {ScalarType.I32, ScalarType.I64}:
+             raise UnsupportedOpError("GRU sequence_lens must be int32 or int64")
+         seq_shape = value_shape(graph, input_sequence_lens, node)
+         if seq_shape != (batch_size,):
+             raise UnsupportedOpError("GRU sequence_lens must match batch size")
+     state_shape = (
+         (num_directions, batch_size, hidden_size)
+         if layout == 0
+         else (batch_size, num_directions, hidden_size)
+     )
+     if input_initial_h is not None:
+         _expect_shape(
+             input_initial_h,
+             value_shape(graph, input_initial_h, node),
+             state_shape,
+         )
+     if output_y is not None:
+         expected_y_shape = (
+             (seq_length, num_directions, batch_size, hidden_size)
+             if layout == 0
+             else (batch_size, seq_length, num_directions, hidden_size)
+         )
+         _expect_shape(output_y, value_shape(graph, output_y, node), expected_y_shape)
+     if output_y_h is not None:
+         _expect_shape(
+             output_y_h,
+             value_shape(graph, output_y_h, node),
+             state_shape,
+         )
+     linear_before_reset = int(node.attrs.get("linear_before_reset", 0))
+     if linear_before_reset not in {0, 1}:
+         raise UnsupportedOpError("GRU linear_before_reset must be 0 or 1")
+     clip = node.attrs.get("clip")
+     if clip is not None:
+         clip = float(clip)
+         if clip < 0:
+             raise UnsupportedOpError("GRU clip must be non-negative")
+     activation_kinds, activation_alphas, activation_betas = _resolve_activations(
+         num_directions, node.attrs
+     )
+     sequence_lens_dtype = (
+         value_dtype(graph, input_sequence_lens, node)
+         if input_sequence_lens is not None
+         else None
+     )
+     return GruSpec(
+         input_x=input_x,
+         input_w=input_w,
+         input_r=input_r,
+         input_b=input_b,
+         input_sequence_lens=input_sequence_lens,
+         input_initial_h=input_initial_h,
+         output_y=output_y,
+         output_y_h=output_y_h,
+         seq_length=seq_length,
+         batch_size=batch_size,
+         input_size=input_size,
+         hidden_size=hidden_size,
+         num_directions=num_directions,
+         direction=direction,
+         layout=layout,
+         linear_before_reset=linear_before_reset,
+         clip=clip,
+         activation_kinds=activation_kinds,
+         activation_alphas=activation_alphas,
+         activation_betas=activation_betas,
+         dtype=op_dtype,
+         sequence_lens_dtype=sequence_lens_dtype,
+     )
+
+
+ @register_lowering("GRU")
+ def lower_gru(graph: Graph, node: Node) -> "GruOp":
+     from ..ir.ops import GruOp
+
+     spec = resolve_gru_spec(graph, node)
+     return GruOp(
+         input_x=spec.input_x,
+         input_w=spec.input_w,
+         input_r=spec.input_r,
+         input_b=spec.input_b,
+         input_sequence_lens=spec.input_sequence_lens,
+         input_initial_h=spec.input_initial_h,
+         output_y=spec.output_y,
+         output_y_h=spec.output_y_h,
+         seq_length=spec.seq_length,
+         batch_size=spec.batch_size,
+         input_size=spec.input_size,
+         hidden_size=spec.hidden_size,
+         num_directions=spec.num_directions,
+         direction=spec.direction,
+         layout=spec.layout,
+         linear_before_reset=spec.linear_before_reset,
+         clip=spec.clip,
+         activation_kinds=spec.activation_kinds,
+         activation_alphas=spec.activation_alphas,
+         activation_betas=spec.activation_betas,
+         dtype=spec.dtype,
+         sequence_lens_dtype=spec.sequence_lens_dtype,
+     )
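
When the `activations` attribute is absent, the new GRU lowering falls back to DEFAULT_ACTIVATIONS and, for bidirectional models, duplicates the pair before mapping names to activation kinds and default alpha/beta values. A minimal sketch of that behaviour (illustrative only; it simply exercises the helper shown above):

    from emx_onnx_cgen.lowering.gru import _resolve_activations

    # Bidirectional GRU (num_directions=2) with no explicit activations:
    # ("Sigmoid", "Tanh") is duplicated to length 4, then mapped through
    # ACTIVATION_KIND_BY_NAME with the default alpha/beta fallbacks.
    kinds, alphas, betas = _resolve_activations(2, {})
    assert kinds == (2, 1, 2, 1)            # Sigmoid -> 2, Tanh -> 1
    assert alphas == (1.0, 1.0, 1.0, 1.0)   # neither name has a DEFAULT_ALPHA entry
    assert betas == (0.0, 0.0, 0.0, 0.0)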
emx_onnx_cgen/lowering/hamming_window.py (new file)
@@ -0,0 +1,104 @@
+ from __future__ import annotations
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+
+ from ..dtypes import dtype_info
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from ..ir.ops import HammingWindowOp
+ from ..lowering.common import value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ _SUPPORTED_INPUT_DTYPES = {ScalarType.I32, ScalarType.I64}
+ _SUPPORTED_OUTPUT_DTYPES = {
+     ScalarType.U8,
+     ScalarType.U16,
+     ScalarType.U32,
+     ScalarType.U64,
+     ScalarType.I8,
+     ScalarType.I16,
+     ScalarType.I32,
+     ScalarType.I64,
+     ScalarType.F16,
+     ScalarType.F32,
+     ScalarType.F64,
+ }
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _read_scalar_initializer(
+     graph: Graph, name: str, node: Node
+ ) -> int | None:
+     initializer = _find_initializer(graph, name)
+     if initializer is None:
+         return None
+     data = np.array(initializer.data)
+     if data.size != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} size input must be a scalar"
+         )
+     return int(data.reshape(-1)[0].item())
+
+
+ def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+     return shape == () or shape == (1,)
+
+
+ @register_lowering("HammingWindow")
+ def lower_hamming_window(graph: Graph, node: Node) -> HammingWindowOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("HammingWindow must have 1 input and 1 output")
+     size_shape = value_shape(graph, node.inputs[0], node)
+     if not _is_scalar_shape(size_shape):
+         raise UnsupportedOpError("HammingWindow size input must be a scalar")
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     if input_dtype not in _SUPPORTED_INPUT_DTYPES:
+         raise UnsupportedOpError(
+             f"HammingWindow size input must be int32 or int64, got {input_dtype.onnx_name}"
+         )
+     output_shape = value_shape(graph, node.outputs[0], node)
+     if len(output_shape) != 1:
+         raise ShapeInferenceError("HammingWindow output must be 1D")
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
+         raise UnsupportedOpError(
+             "HammingWindow output dtype must be numeric, "
+             f"got {output_dtype.onnx_name}"
+         )
+     output_datatype = node.attrs.get("output_datatype")
+     if output_datatype is not None:
+         attr_dtype = dtype_info(int(output_datatype))
+         if attr_dtype != output_dtype:
+             raise UnsupportedOpError(
+                 "HammingWindow output_datatype does not match output dtype"
+             )
+     periodic = int(node.attrs.get("periodic", 1))
+     if periodic not in {0, 1}:
+         raise UnsupportedOpError("HammingWindow periodic must be 0 or 1")
+     size_value = _read_scalar_initializer(graph, node.inputs[0], node)
+     if size_value is not None:
+         if size_value < 0:
+             raise ShapeInferenceError(
+                 "HammingWindow size must be non-negative"
+             )
+         if output_shape[0] != size_value:
+             raise ShapeInferenceError(
+                 "HammingWindow output length does not match size input"
+             )
+     return HammingWindowOp(
+         size=node.inputs[0],
+         output=node.outputs[0],
+         output_shape=output_shape,
+         periodic=periodic == 1,
+         dtype=output_dtype,
+         input_dtype=input_dtype,
+     )
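
The lowering above only validates shapes and dtypes; the window values themselves come from templates/hamming_window_op.c.j2. For orientation, a Hamming window with the classic 0.54/0.46 coefficients looks like the NumPy sketch below, where `periodic` switches the cosine denominator between N and N - 1 (the generated C may use the exact spec-defined constants rather than these rounded ones):

    import numpy as np

    def hamming_window(size: int, periodic: bool = True) -> np.ndarray:
        # Classic Hamming coefficients, used here purely for illustration.
        a0, a1 = 0.54, 0.46
        n = np.arange(size, dtype=np.float64)
        denom = float(size if periodic else size - 1)
        return a0 - a1 * np.cos(2.0 * np.pi * n / denom)

    w = hamming_window(16)  # shape (16,), matching output_shape[0] == size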
emx_onnx_cgen/lowering/hardmax.py
@@ -1,53 +1,17 @@
  from __future__ import annotations

- from shared.scalar_types import ScalarType
-
  from ..ir.ops import HardmaxOp
  from ..errors import UnsupportedOpError
  from ..ir.model import Graph, Node
- from .common import node_dtype as _node_dtype
- from .common import onnx_opset_version as _onnx_opset_version
- from .common import shape_product as _shape_product
- from .common import value_shape as _value_shape
  from .registry import register_lowering
- from ..validation import ensure_output_shape_matches_input
- from ..validation import normalize_axis as _normalize_axis


  @register_lowering("Hardmax")
  def lower_hardmax(graph: Graph, node: Node) -> HardmaxOp:
      if len(node.inputs) != 1 or len(node.outputs) != 1:
          raise UnsupportedOpError("Hardmax must have 1 input and 1 output")
-     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-     if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
-         raise UnsupportedOpError(
-             "Hardmax supports float16, float, and double inputs only"
-         )
-     input_shape = _value_shape(graph, node.inputs[0], node)
-     output_shape = _value_shape(graph, node.outputs[0], node)
-     ensure_output_shape_matches_input(node, input_shape, output_shape)
-     opset_version = _onnx_opset_version(graph)
-     default_axis = 1 if opset_version is not None and opset_version < 13 else -1
-     axis_attr = node.attrs.get("axis", default_axis)
-     axis = _normalize_axis(
-         int(axis_attr),
-         input_shape,
-         node,
-     )
-     outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
-     axis_size = input_shape[axis]
-     inner = (
-         _shape_product(input_shape[axis + 1 :])
-         if axis + 1 < len(input_shape)
-         else 1
-     )
      return HardmaxOp(
          input0=node.inputs[0],
          output=node.outputs[0],
-         outer=outer,
-         axis_size=axis_size,
-         inner=inner,
-         axis=axis,
-         shape=input_shape,
-         dtype=op_dtype,
+         axis=int(node.attrs["axis"]) if "axis" in node.attrs else None,
      )
emx_onnx_cgen/lowering/identity.py
@@ -1,9 +1,10 @@
  from __future__ import annotations

- from ..ir.ops import IdentityOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.context import GraphContext
  from ..ir.model import Graph, Node
- from .common import value_dtype, value_shape
+ from ..ir.ops import IdentityOp
+ from .common import value_dtype, value_has_dim_params, value_shape
  from .registry import register_lowering


@@ -13,9 +14,10 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
          raise UnsupportedOpError("Identity must have 1 input and 1 output")
      input_shape = value_shape(graph, node.inputs[0], node)
      output_shape = value_shape(graph, node.outputs[0], node)
+     if value_has_dim_params(graph, node.outputs[0]) or not output_shape:
+         output_shape = ()
      input_dim_params = graph.find_value(node.inputs[0]).type.dim_params
      output_dim_params = graph.find_value(node.outputs[0]).type.dim_params
-     resolved_shape = output_shape or input_shape
      if input_shape and output_shape:
          if len(input_shape) != len(output_shape):
              raise ShapeInferenceError("Identity input and output shapes must match")
@@ -35,10 +37,9 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
              "Identity expects matching input/output dtypes, "
              f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
          )
+     if isinstance(graph, GraphContext):
+         graph.set_shape(node.outputs[0], input_shape)
      return IdentityOp(
          input0=node.inputs[0],
          output=node.outputs[0],
-         shape=resolved_shape,
-         dtype=output_dtype,
-         input_dtype=input_dtype,
      )
emx_onnx_cgen/lowering/logsoftmax.py
@@ -3,49 +3,15 @@ from __future__ import annotations
  from ..ir.ops import LogSoftmaxOp
  from ..errors import UnsupportedOpError
  from ..ir.model import Graph, Node
- from .common import node_dtype as _node_dtype
- from .common import onnx_opset_version as _onnx_opset_version
- from .common import shape_product as _shape_product
- from .common import value_shape as _value_shape
  from .registry import register_lowering
- from ..validation import ensure_output_shape_matches_input
- from ..validation import normalize_axis as _normalize_axis


  @register_lowering("LogSoftmax")
  def lower_logsoftmax(graph: Graph, node: Node) -> LogSoftmaxOp:
      if len(node.inputs) != 1 or len(node.outputs) != 1:
          raise UnsupportedOpError("LogSoftmax must have 1 input and 1 output")
-     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-     if not op_dtype.is_float:
-         raise UnsupportedOpError(
-             "LogSoftmax supports float16, float, and double inputs only"
-         )
-     input_shape = _value_shape(graph, node.inputs[0], node)
-     output_shape = _value_shape(graph, node.outputs[0], node)
-     ensure_output_shape_matches_input(node, input_shape, output_shape)
-     opset_version = _onnx_opset_version(graph)
-     default_axis = 1 if opset_version is not None and opset_version < 13 else -1
-     axis_attr = node.attrs.get("axis", default_axis)
-     axis = _normalize_axis(
-         int(axis_attr),
-         input_shape,
-         node,
-     )
-     outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
-     axis_size = input_shape[axis]
-     inner = (
-         _shape_product(input_shape[axis + 1 :])
-         if axis + 1 < len(input_shape)
-         else 1
-     )
      return LogSoftmaxOp(
          input0=node.inputs[0],
          output=node.outputs[0],
-         outer=outer,
-         axis_size=axis_size,
-         inner=inner,
-         axis=axis,
-         shape=input_shape,
-         dtype=op_dtype,
+         axis=int(node.attrs["axis"]) if "axis" in node.attrs else None,
      )
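
Both the Hardmax and LogSoftmax lowerings now pass the raw `axis` attribute (or None) straight through instead of pre-computing outer/axis/inner extents, so that bookkeeping presumably moves into the op or emitter layer. For reference, the numerically stable log-softmax the operator computes along an axis is the usual shift-by-max form:

    import numpy as np

    def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
        # Subtract the per-slice maximum before exponentiating to avoid overflow.
        shifted = x - np.max(x, axis=axis, keepdims=True)
        return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))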
emx_onnx_cgen/lowering/lp_pool.py
@@ -19,6 +19,8 @@ class LpPoolSpec:
      out_w: int
      kernel_h: int
      kernel_w: int
+     dilation_h: int
+     dilation_w: int
      stride_h: int
      stride_w: int
      pad_top: int
@@ -51,8 +53,10 @@ def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
      if ceil_mode != 0:
          raise UnsupportedOpError("LpPool supports ceil_mode=0 only")
      dilations = tuple(int(value) for value in node.attrs.get("dilations", (1, 1)))
-     if any(value != 1 for value in dilations):
-         raise UnsupportedOpError("LpPool supports dilations=1 only")
+     if len(dilations) != 2:
+         raise UnsupportedOpError("LpPool expects 2D dilations")
+     if any(value < 1 for value in dilations):
+         raise UnsupportedOpError("LpPool requires dilations >= 1")
      kernel_shape = node.attrs.get("kernel_shape")
      if kernel_shape is None:
          raise UnsupportedOpError("LpPool requires kernel_shape")
@@ -75,8 +79,11 @@ def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
          raise UnsupportedOpError("LpPool supports NCHW 2D inputs only")
      batch, channels, in_h, in_w = input_shape
      stride_h, stride_w = strides
-     out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
-     out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
+     dilation_h, dilation_w = dilations
+     effective_kernel_h = dilation_h * (kernel_h - 1) + 1
+     effective_kernel_w = dilation_w * (kernel_w - 1) + 1
+     out_h = (in_h + pad_top + pad_bottom - effective_kernel_h) // stride_h + 1
+     out_w = (in_w + pad_left + pad_right - effective_kernel_w) // stride_w + 1
      if out_h < 0 or out_w < 0:
          raise ShapeInferenceError("LpPool output shape must be non-negative")
      output_shape = _value_shape(graph, node.outputs[0], node)
@@ -95,6 +102,8 @@ def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
          out_w=out_w,
          kernel_h=kernel_h,
          kernel_w=kernel_w,
+         dilation_h=dilation_h,
+         dilation_w=dilation_w,
          stride_h=stride_h,
          stride_w=stride_w,
          pad_top=pad_top,
@@ -130,6 +139,8 @@ def lower_lp_pool(graph: Graph, node: Node) -> LpPoolOp:
          out_w=spec.out_w,
          kernel_h=spec.kernel_h,
          kernel_w=spec.kernel_w,
+         dilation_h=spec.dilation_h,
+         dilation_w=spec.dilation_w,
          stride_h=spec.stride_h,
          stride_w=spec.stride_w,
          pad_top=spec.pad_top,
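
The LpPool change lifts the previous dilations=1 restriction by using the standard dilated-kernel extent, dilation * (kernel - 1) + 1, in the output-size formula. A small worked example (values chosen purely for illustration):

    # in_h = 10, kernel_h = 3, dilation_h = 2, stride_h = 1, no padding
    effective_kernel_h = 2 * (3 - 1) + 1     # = 5
    out_h = (10 + 0 + 0 - 5) // 1 + 1        # = 6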