emx-onnx-cgen 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/lstm.py
@@ -0,0 +1,355 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Sequence

from shared.scalar_types import ScalarType

from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .common import node_dtype, optional_name, value_dtype, value_shape
from .registry import register_lowering

ACTIVATION_KIND_BY_NAME = {
    "Relu": 0,
    "Tanh": 1,
    "Sigmoid": 2,
    "Affine": 3,
    "LeakyRelu": 4,
    "ThresholdedRelu": 5,
    "ScaledTanh": 6,
    "HardSigmoid": 7,
    "Elu": 8,
    "Softsign": 9,
    "Softplus": 10,
}

DEFAULT_ACTIVATIONS = ("Sigmoid", "Tanh", "Tanh")

DEFAULT_ALPHA_BY_NAME = {
    "Affine": 1.0,
    "LeakyRelu": 0.01,
    "ThresholdedRelu": 1.0,
    "ScaledTanh": 1.0,
    "HardSigmoid": 0.2,
    "Elu": 1.0,
}

DEFAULT_BETA_BY_NAME = {
    "Affine": 0.0,
    "ScaledTanh": 1.0,
    "HardSigmoid": 0.5,
}


@dataclass(frozen=True)
class LstmSpec:
    input_x: str
    input_w: str
    input_r: str
    input_b: str | None
    input_sequence_lens: str | None
    input_initial_h: str | None
    input_initial_c: str | None
    input_p: str | None
    output_y: str | None
    output_y_h: str | None
    output_y_c: str | None
    seq_length: int
    batch_size: int
    input_size: int
    hidden_size: int
    num_directions: int
    direction: str
    layout: int
    input_forget: int
    clip: float | None
    activation_kinds: tuple[int, ...]
    activation_alphas: tuple[float, ...]
    activation_betas: tuple[float, ...]
    dtype: ScalarType
    sequence_lens_dtype: ScalarType | None


def _normalize_activation_names(values: Iterable[object]) -> list[str]:
    names: list[str] = []
    for value in values:
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        if not isinstance(value, str):
            raise UnsupportedOpError("LSTM activations must be strings")
        names.append(value)
    return names


def _resolve_activation_params(
    activations: Sequence[str],
    activation_alpha: Sequence[float] | None,
    activation_beta: Sequence[float] | None,
) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
    if activation_alpha is None:
        activation_alpha = []
    if activation_beta is None:
        activation_beta = []
    if activation_alpha and len(activation_alpha) != len(activations):
        raise UnsupportedOpError("LSTM activation_alpha must match activations")
    if activation_beta and len(activation_beta) != len(activations):
        raise UnsupportedOpError("LSTM activation_beta must match activations")
    activation_kinds: list[int] = []
    alphas: list[float] = []
    betas: list[float] = []
    for idx, name in enumerate(activations):
        kind = ACTIVATION_KIND_BY_NAME.get(name)
        if kind is None:
            raise UnsupportedOpError(f"Unsupported LSTM activation {name}")
        activation_kinds.append(kind)
        if activation_alpha:
            alpha = float(activation_alpha[idx])
        else:
            alpha = DEFAULT_ALPHA_BY_NAME.get(name, 1.0)
        if activation_beta:
            beta = float(activation_beta[idx])
        else:
            beta = DEFAULT_BETA_BY_NAME.get(name, 0.0)
        alphas.append(alpha)
        betas.append(beta)
    return tuple(activation_kinds), tuple(alphas), tuple(betas)


def _resolve_activations(
    direction: str, num_directions: int, attrs: dict[str, object]
) -> tuple[tuple[int, ...], tuple[float, ...], tuple[float, ...]]:
    activations_attr = attrs.get("activations")
    if activations_attr is None:
        activations = list(DEFAULT_ACTIVATIONS)
    else:
        activations = _normalize_activation_names(activations_attr)
    if num_directions == 1:
        if len(activations) != 3:
            raise UnsupportedOpError("LSTM activations must have length 3")
    else:
        if len(activations) == 3:
            activations = activations * 2
        elif len(activations) != 6:
            raise UnsupportedOpError("Bidirectional LSTM activations must be length 6")
    activation_alpha = attrs.get("activation_alpha")
    activation_beta = attrs.get("activation_beta")
    return _resolve_activation_params(
        activations,
        activation_alpha,
        activation_beta,
    )


def _expect_shape(
    name: str, shape: tuple[int, ...], expected: tuple[int, ...]
) -> None:
    if shape != expected:
        raise UnsupportedOpError(
            f"LSTM input {name} must have shape {expected}, got {shape}"
        )


def _validate_direction(direction: str, num_directions: int) -> None:
    if direction == "bidirectional" and num_directions != 2:
        raise UnsupportedOpError(
            "LSTM expects num_directions=2 for bidirectional models"
        )
    if direction in {"forward", "reverse"} and num_directions != 1:
        raise UnsupportedOpError(
            "LSTM expects num_directions=1 for forward/reverse models"
        )
    if direction not in {"forward", "reverse", "bidirectional"}:
        raise UnsupportedOpError(f"Unsupported LSTM direction {direction}")


def resolve_lstm_spec(graph: Graph, node: Node) -> LstmSpec:
    if len(node.inputs) < 3 or len(node.inputs) > 8:
        raise UnsupportedOpError("LSTM expects between 3 and 8 inputs")
    if len(node.outputs) < 1 or len(node.outputs) > 3:
        raise UnsupportedOpError("LSTM expects between 1 and 3 outputs")
    input_x = node.inputs[0]
    input_w = node.inputs[1]
    input_r = node.inputs[2]
    input_b = optional_name(node.inputs, 3)
    input_sequence_lens = optional_name(node.inputs, 4)
    input_initial_h = optional_name(node.inputs, 5)
    input_initial_c = optional_name(node.inputs, 6)
    input_p = optional_name(node.inputs, 7)
    output_y = optional_name(node.outputs, 0)
    output_y_h = optional_name(node.outputs, 1)
    output_y_c = optional_name(node.outputs, 2)
    if output_y is None and output_y_h is None and output_y_c is None:
        raise UnsupportedOpError("LSTM expects at least one output")
    op_dtype = node_dtype(
        graph,
        node,
        input_x,
        input_w,
        input_r,
        *(name for name in (input_b, input_initial_h, input_initial_c, input_p) if name),
        *(name for name in (output_y, output_y_h, output_y_c) if name),
    )
    if not op_dtype.is_float:
        raise UnsupportedOpError(
            "LSTM supports float16, float, and double inputs only"
        )
    x_shape = value_shape(graph, input_x, node)
    if len(x_shape) != 3:
        raise UnsupportedOpError("LSTM input X must be rank 3")
    layout = int(node.attrs.get("layout", 0))
    if layout not in {0, 1}:
        raise UnsupportedOpError("LSTM layout must be 0 or 1")
    if layout == 0:
        seq_length, batch_size, input_size = x_shape
    else:
        batch_size, seq_length, input_size = x_shape
    w_shape = value_shape(graph, input_w, node)
    if len(w_shape) != 3:
        raise UnsupportedOpError("LSTM input W must be rank 3")
    num_directions = w_shape[0]
    hidden_size_attr = node.attrs.get("hidden_size")
    if hidden_size_attr is None:
        if w_shape[1] % 4 != 0:
            raise UnsupportedOpError("LSTM W shape is not divisible by 4")
        hidden_size = w_shape[1] // 4
    else:
        hidden_size = int(hidden_size_attr)
    _validate_direction(str(node.attrs.get("direction", "forward")), num_directions)
    direction = str(node.attrs.get("direction", "forward"))
    expected_w_shape = (num_directions, 4 * hidden_size, input_size)
    _expect_shape(input_w, w_shape, expected_w_shape)
    r_shape = value_shape(graph, input_r, node)
    expected_r_shape = (num_directions, 4 * hidden_size, hidden_size)
    _expect_shape(input_r, r_shape, expected_r_shape)
    if input_b is not None:
        b_shape = value_shape(graph, input_b, node)
        _expect_shape(input_b, b_shape, (num_directions, 8 * hidden_size))
    if input_sequence_lens is not None:
        seq_dtype = value_dtype(graph, input_sequence_lens, node)
        if seq_dtype not in {ScalarType.I32, ScalarType.I64}:
            raise UnsupportedOpError("LSTM sequence_lens must be int32 or int64")
        seq_shape = value_shape(graph, input_sequence_lens, node)
        if seq_shape != (batch_size,):
            raise UnsupportedOpError(
                "LSTM sequence_lens must match batch size"
            )
    state_shape = (
        (num_directions, batch_size, hidden_size)
        if layout == 0
        else (batch_size, num_directions, hidden_size)
    )
    if input_initial_h is not None:
        _expect_shape(
            input_initial_h,
            value_shape(graph, input_initial_h, node),
            state_shape,
        )
    if input_initial_c is not None:
        _expect_shape(
            input_initial_c,
            value_shape(graph, input_initial_c, node),
            state_shape,
        )
    if input_p is not None:
        _expect_shape(
            input_p,
            value_shape(graph, input_p, node),
            (num_directions, 3 * hidden_size),
        )
    if output_y is not None:
        expected_y_shape = (
            (seq_length, num_directions, batch_size, hidden_size)
            if layout == 0
            else (batch_size, seq_length, num_directions, hidden_size)
        )
        _expect_shape(output_y, value_shape(graph, output_y, node), expected_y_shape)
    if output_y_h is not None:
        _expect_shape(
            output_y_h,
            value_shape(graph, output_y_h, node),
            state_shape,
        )
    if output_y_c is not None:
        _expect_shape(
            output_y_c,
            value_shape(graph, output_y_c, node),
            state_shape,
        )
    input_forget = int(node.attrs.get("input_forget", 0))
    if input_forget not in {0, 1}:
        raise UnsupportedOpError("LSTM input_forget must be 0 or 1")
    clip = node.attrs.get("clip")
    if clip is not None:
        clip = float(clip)
        if clip < 0:
            raise UnsupportedOpError("LSTM clip must be non-negative")
    activation_kinds, activation_alphas, activation_betas = _resolve_activations(
        direction, num_directions, node.attrs
    )
    sequence_lens_dtype = (
        value_dtype(graph, input_sequence_lens, node)
        if input_sequence_lens is not None
        else None
    )
    return LstmSpec(
        input_x=input_x,
        input_w=input_w,
        input_r=input_r,
        input_b=input_b,
        input_sequence_lens=input_sequence_lens,
        input_initial_h=input_initial_h,
        input_initial_c=input_initial_c,
        input_p=input_p,
        output_y=output_y,
        output_y_h=output_y_h,
        output_y_c=output_y_c,
        seq_length=seq_length,
        batch_size=batch_size,
        input_size=input_size,
        hidden_size=hidden_size,
        num_directions=num_directions,
        direction=direction,
        layout=layout,
        input_forget=input_forget,
        clip=clip,
        activation_kinds=activation_kinds,
        activation_alphas=activation_alphas,
        activation_betas=activation_betas,
        dtype=op_dtype,
        sequence_lens_dtype=sequence_lens_dtype,
    )


@register_lowering("LSTM")
def lower_lstm(graph: Graph, node: Node) -> "LstmOp":
    from ..codegen.c_emitter import LstmOp

    spec = resolve_lstm_spec(graph, node)
    return LstmOp(
        input_x=spec.input_x,
        input_w=spec.input_w,
        input_r=spec.input_r,
        input_b=spec.input_b,
        input_sequence_lens=spec.input_sequence_lens,
        input_initial_h=spec.input_initial_h,
        input_initial_c=spec.input_initial_c,
        input_p=spec.input_p,
        output_y=spec.output_y,
        output_y_h=spec.output_y_h,
        output_y_c=spec.output_y_c,
        seq_length=spec.seq_length,
        batch_size=spec.batch_size,
        input_size=spec.input_size,
        hidden_size=spec.hidden_size,
        num_directions=spec.num_directions,
        direction=spec.direction,
        layout=spec.layout,
        input_forget=spec.input_forget,
        clip=spec.clip,
        activation_kinds=spec.activation_kinds,
        activation_alphas=spec.activation_alphas,
        activation_betas=spec.activation_betas,
        dtype=spec.dtype,
        sequence_lens_dtype=spec.sequence_lens_dtype,
    )
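
The table-driven activation handling above is easiest to read with a concrete case. The snippet below is a reading aid, not part of the release: it calls `_resolve_activation_params` as defined in this diff on the default activation triple, assuming the wheel is installed so the module imports as `emx_onnx_cgen.lowering.lstm`. The expected tuples follow directly from `ACTIVATION_KIND_BY_NAME`, `DEFAULT_ALPHA_BY_NAME`, and `DEFAULT_BETA_BY_NAME`.

# Illustrative only: exercises helpers defined in the diff above.
from emx_onnx_cgen.lowering.lstm import (
    DEFAULT_ACTIVATIONS,
    _resolve_activation_params,
)

# Default single-direction LSTM gate activations: f=Sigmoid, g=Tanh, h=Tanh.
kinds, alphas, betas = _resolve_activation_params(DEFAULT_ACTIVATIONS, None, None)
assert kinds == (2, 1, 1)          # Sigmoid -> 2, Tanh -> 1 per ACTIVATION_KIND_BY_NAME
assert alphas == (1.0, 1.0, 1.0)   # not in DEFAULT_ALPHA_BY_NAME, so the 1.0 fallback applies
assert betas == (0.0, 0.0, 0.0)    # not in DEFAULT_BETA_BY_NAME, so the 0.0 fallback applies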

emx_onnx_cgen/lowering/matmul.py
@@ -0,0 +1,120 @@
from __future__ import annotations

from dataclasses import dataclass

from ..codegen.c_emitter import MatMulOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .common import node_dtype as _node_dtype
from .common import value_shape as _value_shape
from .registry import register_lowering


@dataclass(frozen=True)
class MatMulSpec:
    input0_shape: tuple[int, ...]
    input1_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    batch_shape: tuple[int, ...]
    input0_batch_shape: tuple[int, ...]
    input1_batch_shape: tuple[int, ...]
    m: int
    n: int
    k: int
    left_vector: bool
    right_vector: bool


def resolve_matmul_spec(graph: Graph, node: Node) -> MatMulSpec:
    if len(node.inputs) != 2 or len(node.outputs) != 1:
        raise UnsupportedOpError("MatMul must have 2 inputs and 1 output")
    input0_shape = _value_shape(graph, node.inputs[0], node)
    input1_shape = _value_shape(graph, node.inputs[1], node)
    if len(input0_shape) < 1 or len(input1_shape) < 1:
        raise UnsupportedOpError(
            "MatMul inputs must be at least 1D, "
            f"got {input0_shape} x {input1_shape}"
        )
    left_vector = len(input0_shape) == 1
    right_vector = len(input1_shape) == 1
    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
    m, k_left = input0_effective[-2], input0_effective[-1]
    k_right, n = input1_effective[-2], input1_effective[-1]
    if k_left != k_right:
        raise ShapeInferenceError(
            f"MatMul inner dimensions must match, got {k_left} and {k_right}"
        )
    batch_shape, input0_batch_shape, input1_batch_shape = (
        _broadcast_batch_shapes(
            input0_effective[:-2], input1_effective[:-2], node
        )
    )
    if left_vector and right_vector:
        output_shape = batch_shape
    elif left_vector:
        output_shape = batch_shape + (n,)
    elif right_vector:
        output_shape = batch_shape + (m,)
    else:
        output_shape = batch_shape + (m, n)
    expected_output_shape = _value_shape(graph, node.outputs[0], node)
    if expected_output_shape != output_shape:
        raise ShapeInferenceError(
            "MatMul output shape must be "
            f"{output_shape}, got {expected_output_shape}"
        )
    return MatMulSpec(
        input0_shape=input0_shape,
        input1_shape=input1_shape,
        output_shape=output_shape,
        batch_shape=batch_shape,
        input0_batch_shape=input0_batch_shape,
        input1_batch_shape=input1_batch_shape,
        m=m,
        n=n,
        k=k_left,
        left_vector=left_vector,
        right_vector=right_vector,
    )


def _broadcast_batch_shapes(
    left: tuple[int, ...], right: tuple[int, ...], node: Node
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    max_rank = max(len(left), len(right))
    left_padded = (1,) * (max_rank - len(left)) + left
    right_padded = (1,) * (max_rank - len(right)) + right
    broadcast_shape = []
    for left_dim, right_dim in zip(left_padded, right_padded):
        if left_dim == right_dim or left_dim == 1 or right_dim == 1:
            broadcast_shape.append(max(left_dim, right_dim))
            continue
        raise ShapeInferenceError(
            "MatMul batch dimensions must be broadcastable, "
            f"got {left} x {right}"
        )
    return tuple(broadcast_shape), left_padded, right_padded


@register_lowering("MatMul")
def lower_matmul(graph: Graph, node: Node) -> MatMulOp:
    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
    spec = resolve_matmul_spec(graph, node)
    return MatMulOp(
        input0=node.inputs[0],
        input1=node.inputs[1],
        output=node.outputs[0],
        input0_shape=spec.input0_shape,
        input1_shape=spec.input1_shape,
        output_shape=spec.output_shape,
        batch_shape=spec.batch_shape,
        input0_batch_shape=spec.input0_batch_shape,
        input1_batch_shape=spec.input1_batch_shape,
        m=spec.m,
        n=spec.n,
        k=spec.k,
        left_vector=spec.left_vector,
        right_vector=spec.right_vector,
        dtype=op_dtype,
    )
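
resolve_matmul_spec promotes 1-D operands to matrices, checks the inner dimensions, and broadcasts the leading batch dimensions NumPy-style via `_broadcast_batch_shapes`. A small sketch of that helper's behaviour follows (illustrative only; the helper accepts a `node` argument but never dereferences it in the code above, so passing None here is an assumption that holds for this release):

# Illustrative only: batch-dimension broadcasting as used by resolve_matmul_spec.
from emx_onnx_cgen.lowering.matmul import _broadcast_batch_shapes

# For A of shape (3, 1, 5, 4) and B of shape (7, 4, 6), the batch dims are
# (3, 1) and (7,); they broadcast to (3, 7), giving an output of (3, 7, 5, 6).
batch, left_padded, right_padded = _broadcast_batch_shapes((3, 1), (7,), None)
assert batch == (3, 7)
assert left_padded == (3, 1)
assert right_padded == (1, 7)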

emx_onnx_cgen/lowering/maxpool.py
@@ -0,0 +1,195 @@
from __future__ import annotations

import math
from dataclasses import dataclass

from shared.scalar_types import ScalarType

from ..codegen.c_emitter import MaxPoolOp
from ..errors import ShapeInferenceError, UnsupportedOpError
from ..ir.model import Graph, Node
from .common import node_dtype as _node_dtype
from .common import value_dtype as _value_dtype
from .common import value_shape as _value_shape
from .registry import register_lowering


@dataclass(frozen=True)
class MaxPoolSpec:
    batch: int
    channels: int
    spatial_rank: int
    in_spatial: tuple[int, ...]
    out_spatial: tuple[int, ...]
    kernel_shape: tuple[int, ...]
    strides: tuple[int, ...]
    pads: tuple[int, ...]
    dilations: tuple[int, ...]
    ceil_mode: bool
    storage_order: int


def resolve_maxpool_spec(graph: Graph, node: Node) -> MaxPoolSpec:
    if len(node.inputs) != 1 or len(node.outputs) not in {1, 2}:
        raise UnsupportedOpError("MaxPool must have 1 input and 1 or 2 outputs")
    supported_attrs = {
        "auto_pad",
        "ceil_mode",
        "dilations",
        "kernel_shape",
        "pads",
        "storage_order",
        "strides",
    }
    if set(node.attrs) - supported_attrs:
        raise UnsupportedOpError("MaxPool has unsupported attributes")
    storage_order = int(node.attrs.get("storage_order", 0))
    if storage_order not in (0, 1):
        raise UnsupportedOpError("MaxPool supports storage_order=0 or 1 only")
    kernel_shape = node.attrs.get("kernel_shape")
    if kernel_shape is None:
        raise UnsupportedOpError("MaxPool requires kernel_shape")
    kernel_shape = tuple(int(value) for value in kernel_shape)
    input_shape = _value_shape(graph, node.inputs[0], node)
    if len(input_shape) < 3:
        raise UnsupportedOpError("MaxPool expects NCHW inputs with spatial dims")
    spatial_rank = len(input_shape) - 2
    if spatial_rank not in {1, 2, 3}:
        raise UnsupportedOpError("MaxPool supports 1D/2D/3D inputs only")
    if len(kernel_shape) != spatial_rank:
        raise ShapeInferenceError(
            f"MaxPool kernel_shape must have {spatial_rank} dims, got {kernel_shape}"
        )
    strides = tuple(
        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
    )
    if len(strides) != spatial_rank:
        raise UnsupportedOpError("MaxPool stride rank mismatch")
    dilations = tuple(
        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
    )
    if len(dilations) != spatial_rank:
        raise UnsupportedOpError("MaxPool dilation rank mismatch")
    pads = tuple(
        int(value)
        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
    )
    if len(pads) != 2 * spatial_rank:
        raise UnsupportedOpError("MaxPool pads rank mismatch")
    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
    if isinstance(auto_pad, bytes):
        auto_pad = auto_pad.decode("utf-8", errors="ignore")
    if auto_pad in ("", "NOTSET"):
        pad_begin = pads[:spatial_rank]
        pad_end = pads[spatial_rank:]
    elif auto_pad == "VALID":
        pad_begin = (0,) * spatial_rank
        pad_end = (0,) * spatial_rank
    elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
        pad_begin = []
        pad_end = []
        for dim, stride, dilation, kernel in zip(
            input_shape[2:], strides, dilations, kernel_shape
        ):
            effective_kernel = dilation * (kernel - 1) + 1
            out_dim = math.ceil(dim / stride)
            pad_needed = max(
                0, (out_dim - 1) * stride + effective_kernel - dim
            )
            if auto_pad == "SAME_UPPER":
                pad_start = pad_needed // 2
            else:
                pad_start = (pad_needed + 1) // 2
            pad_begin.append(pad_start)
            pad_end.append(pad_needed - pad_start)
        pad_begin = tuple(pad_begin)
        pad_end = tuple(pad_end)
    else:
        raise UnsupportedOpError("MaxPool has unsupported auto_pad mode")
    ceil_mode = int(node.attrs.get("ceil_mode", 0))
    if ceil_mode not in (0, 1):
        raise UnsupportedOpError("MaxPool supports ceil_mode=0 or 1 only")
    batch, channels = input_shape[0], input_shape[1]
    in_spatial = input_shape[2:]
    out_spatial = []
    for dim, stride, dilation, kernel, pad_start, pad_finish in zip(
        in_spatial, strides, dilations, kernel_shape, pad_begin, pad_end
    ):
        effective_kernel = dilation * (kernel - 1) + 1
        numerator = dim + pad_start + pad_finish - effective_kernel
        if ceil_mode:
            out_dim = (numerator + stride - 1) // stride + 1
            if (out_dim - 1) * stride >= dim + pad_start:
                out_dim -= 1
        else:
            out_dim = numerator // stride + 1
        if out_dim < 0:
            raise ShapeInferenceError(
                "MaxPool output shape must be non-negative"
            )
        out_spatial.append(out_dim)
    expected_output_shape = (batch, channels, *out_spatial)
    output_shape = _value_shape(graph, node.outputs[0], node)
    if output_shape != expected_output_shape:
        raise ShapeInferenceError(
            "MaxPool output shape must be "
            f"{expected_output_shape}, got {output_shape}"
        )
    if len(node.outputs) == 2:
        indices_shape = _value_shape(graph, node.outputs[1], node)
        if indices_shape != expected_output_shape:
            raise ShapeInferenceError(
                "MaxPool indices output shape must be "
                f"{expected_output_shape}, got {indices_shape}"
            )
        indices_dtype = _value_dtype(graph, node.outputs[1], node)
        if indices_dtype != ScalarType.I64:
            raise UnsupportedOpError("MaxPool indices output must be int64")
    pads = (*pad_begin, *pad_end)
    return MaxPoolSpec(
        batch=batch,
        channels=channels,
        spatial_rank=spatial_rank,
        in_spatial=in_spatial,
        out_spatial=tuple(out_spatial),
        kernel_shape=kernel_shape,
        strides=strides,
        pads=pads,
        dilations=dilations,
        ceil_mode=bool(ceil_mode),
        storage_order=storage_order,
    )


@register_lowering("MaxPool")
def lower_maxpool(graph: Graph, node: Node) -> MaxPoolOp:
    if len(node.inputs) != 1 or len(node.outputs) not in {1, 2}:
        raise UnsupportedOpError("MaxPool must have 1 input and 1 or 2 outputs")
    op_dtype = _node_dtype(graph, node, node.inputs[0], node.outputs[0])
    if op_dtype == ScalarType.BOOL:
        raise UnsupportedOpError("MaxPool supports numeric inputs only")
    spec = resolve_maxpool_spec(graph, node)
    indices = node.outputs[1] if len(node.outputs) == 2 else None
    indices_dtype = (
        _value_dtype(graph, indices, node) if indices is not None else None
    )
    if indices_dtype is not None and indices_dtype != ScalarType.I64:
        raise UnsupportedOpError("MaxPool indices output must be int64")
    return MaxPoolOp(
        input0=node.inputs[0],
        output=node.outputs[0],
        indices=indices,
        batch=spec.batch,
        channels=spec.channels,
        spatial_rank=spec.spatial_rank,
        in_spatial=spec.in_spatial,
        out_spatial=spec.out_spatial,
        kernel_shape=spec.kernel_shape,
        strides=spec.strides,
        pads=spec.pads,
        dilations=spec.dilations,
        ceil_mode=spec.ceil_mode,
        storage_order=spec.storage_order,
        dtype=op_dtype,
        indices_dtype=indices_dtype,
    )
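
The per-dimension output-size arithmetic in resolve_maxpool_spec follows the ONNX MaxPool rules for floor and ceil modes. Below is a minimal standalone sketch of that calculation for one spatial dimension; the helper name `_maxpool_out_dim` is introduced here purely for illustration and does not exist in the package.

# Illustrative only: mirrors the out_spatial loop in resolve_maxpool_spec.
def _maxpool_out_dim(dim, kernel, stride, dilation, pad_begin, pad_end, ceil_mode):
    effective_kernel = dilation * (kernel - 1) + 1
    numerator = dim + pad_begin + pad_end - effective_kernel
    if ceil_mode:
        out_dim = (numerator + stride - 1) // stride + 1
        # Drop the last window if it would start beyond the input plus begin padding.
        if (out_dim - 1) * stride >= dim + pad_begin:
            out_dim -= 1
    else:
        out_dim = numerator // stride + 1
    return out_dim

# kernel=3, stride=2, no dilation or padding:
assert _maxpool_out_dim(7, 3, 2, 1, 0, 0, ceil_mode=False) == 3
assert _maxpool_out_dim(7, 3, 2, 1, 0, 0, ceil_mode=True) == 3
assert _maxpool_out_dim(6, 3, 2, 1, 0, 0, ceil_mode=False) == 2
assert _maxpool_out_dim(6, 3, 2, 1, 0, 0, ceil_mode=True) == 3  # ceil mode keeps the partial window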