emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen has been flagged as potentially problematic.
Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/common.py

@@ -5,7 +5,7 @@ from collections.abc import Sequence
 from shared.scalar_types import ScalarType
 
 from ..errors import ShapeInferenceError, UnsupportedOpError
-from ..ir.model import Graph, Node
+from ..ir.model import Graph, Initializer, Node
 
 
 def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
@@ -14,6 +14,17 @@ def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
     return dtype
 
 
+def onnx_opset_version(graph: Graph, domain: str = "") -> int | None:
+    if domain in {"", "ai.onnx"}:
+        domains = {"", "ai.onnx"}
+    else:
+        domains = {domain}
+    for opset_domain, version in graph.opset_imports:
+        if opset_domain in domains:
+            return int(version)
+    return None
+
+
 def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType:
     try:
         value = graph.find_value(name)
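The `onnx_opset_version` helper treats the empty string and "ai.onnx" as aliases for the default ONNX domain and returns the first matching opset import. For readers who want the same check outside this package, a minimal standalone sketch against a plain onnx.ModelProto; the `opset_version` name and the model path are ours, not part of emx-onnx-cgen:

    import onnx

    def opset_version(model: onnx.ModelProto, domain: str = "") -> int | None:
        # Same lookup rule as onnx_opset_version above: "" and "ai.onnx"
        # both name the default ONNX operator domain.
        domains = {"", "ai.onnx"} if domain in {"", "ai.onnx"} else {domain}
        for opset in model.opset_import:
            if opset.domain in domains:
                return int(opset.version)
        return None

    model = onnx.load("model.onnx")  # hypothetical input file
    print(opset_version(model))      # e.g. 17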
@@ -28,13 +39,379 @@ def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType
 
 def value_shape(graph: Graph, name: str, node: Node | None = None) -> tuple[int, ...]:
     try:
-        return graph.find_value(name).type.shape
+        value = graph.find_value(name)
     except KeyError as exc:
         op_type = node.op_type if node is not None else "unknown"
         raise ShapeInferenceError(
             f"Missing shape for value '{name}' in op {op_type}. "
             "Hint: run ONNX shape inference or export with static shapes."
         ) from exc
+    if any(value.type.dim_params):
+        resolved = _resolve_value_shape(graph, name, node)
+        if resolved is not None:
+            return resolved
+        return value.type.shape
+    return value.type.shape
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _find_node_by_output(graph: Graph, name: str) -> Node | None:
+    for node in graph.nodes:
+        if name in node.outputs:
+            return node
+    return None
+
+
+def _shape_values_from_shape_node(
+    graph: Graph, shape_node: Node, node: Node | None
+) -> list[int]:
+    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    source_shape = value_shape(graph, shape_node.inputs[0], node)
+    start = int(shape_node.attrs.get("start", 0))
+    end = int(shape_node.attrs.get("end", len(source_shape)))
+    if start < 0:
+        start += len(source_shape)
+    if end < 0:
+        end += len(source_shape)
+    start = max(start, 0)
+    end = min(end, len(source_shape))
+    if start > end:
+        return []
+    return list(source_shape[start:end])
+
+
+def _shape_values_from_initializer(
+    graph: Graph,
+    name: str,
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "Reshape expects int64 or int32 shape input, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    return [int(value) for value in initializer.data.reshape(-1)]
+
+
+def _shape_values_from_input(
+    graph: Graph,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> list[int] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        shape_values = _shape_values_from_initializer(graph, name)
+        if shape_values is not None:
+            return shape_values
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Shape":
+            return _shape_values_from_shape_node(graph, source_node, node)
+        if source_node.op_type == "Concat":
+            axis = int(source_node.attrs.get("axis", 0))
+            if axis != 0:
+                raise UnsupportedOpError("Reshape shape concat must use axis 0")
+            values: list[int] = []
+            for input_name in source_node.inputs:
+                input_values = _shape_values_from_input(
+                    graph,
+                    input_name,
+                    node,
+                    _visited=_visited,
+                )
+                if input_values is None:
+                    return None
+                values.extend(input_values)
+            return values
+        if source_node.op_type == "Cast":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Cast must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Unsqueeze":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Identity":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Identity must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            if len(left) == 1 and len(right) != 1:
+                left = left * len(right)
+            if len(right) == 1 and len(left) != 1:
+                right = right * len(left)
+            if len(left) != len(right):
+                return None
+            if source_node.op_type == "Equal":
+                return [1 if l == r else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "And":
+                return [1 if (l and r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Or":
+                return [1 if (l or r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Div":
+                return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Mod":
+                return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+        if source_node.op_type == "Not":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Not must have 1 input and 1 output")
+            values = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if values is None:
+                return None
+            return [0 if value else 1 for value in values]
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            condition = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if condition is None:
+                return None
+            on_true = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _shape_values_from_input(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            if len(condition) == 1:
+                condition = condition * max(len(on_true), len(on_false))
+            if len(on_true) == 1 and len(condition) != 1:
+                on_true = on_true * len(condition)
+            if len(on_false) == 1 and len(condition) != 1:
+                on_false = on_false * len(condition)
+            if not (len(condition) == len(on_true) == len(on_false)):
+                return None
+            return [
+                t if cond else f
+                for cond, t, f in zip(condition, on_true, on_false)
+            ]
+        return None
+    finally:
+        _visited.remove(name)
+
+
+def _broadcast_shapes(
+    left: tuple[int, ...],
+    right: tuple[int, ...],
+) -> tuple[int, ...] | None:
+    result = []
+    left_rev = list(reversed(left))
+    right_rev = list(reversed(right))
+    for index in range(max(len(left_rev), len(right_rev))):
+        left_dim = left_rev[index] if index < len(left_rev) else 1
+        right_dim = right_rev[index] if index < len(right_rev) else 1
+        if left_dim == right_dim:
+            result.append(left_dim)
+        elif left_dim == 1:
+            result.append(right_dim)
+        elif right_dim == 1:
+            result.append(left_dim)
+        else:
+            return None
+    return tuple(reversed(result))
+
+
+def _resolve_value_shape(
+    graph: Graph,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> tuple[int, ...] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        value = graph.find_value(name)
+        shape = value.type.shape
+        if not any(value.type.dim_params):
+            return shape
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Expand":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Expand must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is not None and all(dim >= 0 for dim in shape_values):
+                return tuple(shape_values)
+            return None
+        if source_node.op_type == "Reshape":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is None:
+                return None
+            allowzero = int(source_node.attrs.get("allowzero", 0))
+            input_shape = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if input_shape is None:
+                return None
+            output_dims: list[int] = []
+            unknown_index: int | None = None
+            known_product = 1
+            contains_zero = False
+            for index, dim in enumerate(shape_values):
+                if dim == -1:
+                    if unknown_index is not None:
+                        return None
+                    unknown_index = len(output_dims)
+                    output_dims.append(-1)
+                else:
+                    if dim == 0:
+                        contains_zero = True
+                        if allowzero == 0:
+                            if index >= len(input_shape):
+                                return None
+                            dim = input_shape[index]
+                    if dim < 0:
+                        return None
+                    output_dims.append(dim)
+                    known_product *= dim
+            if allowzero == 1 and contains_zero and unknown_index is not None:
+                return None
+            input_product = shape_product(input_shape)
+            if unknown_index is not None:
+                if known_product == 0:
+                    if input_product != 0:
+                        return None
+                    output_dims[unknown_index] = 0
+                else:
+                    if input_product % known_product != 0:
+                        return None
+                    output_dims[unknown_index] = input_product // known_product
+            return tuple(output_dims)
+        if source_node.op_type in {
+            "Add",
+            "Sub",
+            "Mul",
+            "Div",
+            "Pow",
+            "Mod",
+            "And",
+            "Or",
+            "Xor",
+            "Equal",
+            "Greater",
+            "Less",
+            "GreaterOrEqual",
+            "LessOrEqual",
+        }:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            return _broadcast_shapes(left, right)
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            on_true = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _resolve_value_shape(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            return _broadcast_shapes(on_true, on_false)
+        return None
+    finally:
+        _visited.remove(name)
 
 
 def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
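The bulk of this common.py change teaches `value_shape` to resolve symbolic dimensions by walking the producing nodes (Shape, Concat, Cast, Unsqueeze, Identity, Expand, Reshape, elementwise arithmetic and comparisons, Where) and folding them to concrete integers. Two rules do most of the work: trailing-aligned broadcasting (`_broadcast_shapes`) and Reshape's -1/0 dimension semantics. A self-contained sketch of just those two rules, independent of the package's Graph/Node types; the `broadcast` and `reshape_dims` names are ours:

    def broadcast(left: tuple[int, ...], right: tuple[int, ...]) -> tuple[int, ...] | None:
        # Align from the trailing dimension; a dim of 1 stretches to match,
        # anything else must agree exactly (mirrors _broadcast_shapes).
        out = []
        for i in range(max(len(left), len(right))):
            l = left[len(left) - 1 - i] if i < len(left) else 1
            r = right[len(right) - 1 - i] if i < len(right) else 1
            if l == r or r == 1:
                out.append(l)
            elif l == 1:
                out.append(r)
            else:
                return None  # incompatible shapes
        return tuple(reversed(out))

    def reshape_dims(input_shape: tuple[int, ...], target: list[int]) -> list[int]:
        # Reshape semantics with allowzero=0: a 0 copies the input dim at the
        # same index, and a single -1 is inferred from the remaining product.
        dims = [input_shape[i] if d == 0 else d for i, d in enumerate(target)]
        if -1 in dims:
            known = 1
            for d in dims:
                if d != -1:
                    known *= d
            total = 1
            for d in input_shape:
                total *= d
            dims[dims.index(-1)] = total // known
        return dims

    assert broadcast((2, 1, 4), (3, 1)) == (2, 3, 4)
    assert reshape_dims((2, 3, 4), [0, -1]) == [2, 12]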
emx_onnx_cgen/lowering/conv_transpose.py (new file)

@@ -0,0 +1,301 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+from ..codegen.c_emitter import ConvTransposeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class ConvTransposeSpec:
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    output_padding: tuple[int, ...]
+    group: int
+
+
+def _split_padding(
+    total_padding: int, auto_pad: str, *, dim: int
+) -> tuple[int, int]:
+    if total_padding < 0:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be fully defined and non-negative"
+        )
+    pad_end = total_padding // 2
+    pad_begin = total_padding - pad_end
+    if auto_pad == "SAME_UPPER":
+        pad_begin, pad_end = pad_end, pad_begin
+    elif auto_pad not in {"SAME_LOWER", "NOTSET", ""}:
+        raise UnsupportedOpError(
+            f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+        )
+    if pad_begin < 0 or pad_end < 0:
+        raise ShapeInferenceError(
+            f"ConvTranspose pads must be non-negative for dim {dim}"
+        )
+    return pad_begin, pad_end
+
+
+def resolve_conv_transpose_spec(graph: Graph, node: Node) -> ConvTransposeSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    supported_attrs = {
+        "auto_pad",
+        "dilations",
+        "group",
+        "kernel_shape",
+        "output_padding",
+        "output_shape",
+        "pads",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("ConvTranspose has unsupported attributes")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    weight_shape = _value_shape(graph, node.inputs[1], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("ConvTranspose expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("ConvTranspose supports 1D/2D/3D inputs only")
+    if len(weight_shape) != spatial_rank + 2:
+        raise UnsupportedOpError(
+            "ConvTranspose weight rank must match spatial rank"
+        )
+    batch, in_channels = input_shape[0], input_shape[1]
+    in_spatial = input_shape[2:]
+    weight_in_channels, weight_out_channels, *kernel_shape = weight_shape
+    kernel_attr = node.attrs.get("kernel_shape")
+    if kernel_attr is not None:
+        kernel_attr = tuple(int(value) for value in kernel_attr)
+        if len(kernel_attr) != spatial_rank:
+            raise UnsupportedOpError(
+                "ConvTranspose kernel_shape rank must match input spatial rank"
+            )
+        if kernel_attr != tuple(kernel_shape):
+            raise ShapeInferenceError(
+                "ConvTranspose kernel_shape must match weights, "
+                f"got {kernel_attr} and {tuple(kernel_shape)}"
+            )
+        kernel_shape = list(kernel_attr)
+    else:
+        kernel_shape = list(kernel_shape)
+    group = int(node.attrs.get("group", 1))
+    if group <= 0:
+        raise UnsupportedOpError("ConvTranspose expects group >= 1")
+    if in_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide in channels, "
+            f"got group={group}, in_channels={in_channels}"
+        )
+    if weight_in_channels != in_channels:
+        raise ShapeInferenceError(
+            "ConvTranspose input channels must match weight channels, "
+            f"got {in_channels} and {weight_in_channels}"
+        )
+    out_channels = weight_out_channels * group
+    if out_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide out channels, "
+            f"got group={group}, out_channels={out_channels}"
+        )
+    if len(node.inputs) == 3:
+        bias_shape = _value_shape(graph, node.inputs[2], node)
+        if bias_shape != (out_channels,):
+            raise ShapeInferenceError(
+                f"ConvTranspose bias shape must be {(out_channels,)}, got {bias_shape}"
+            )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose stride rank mismatch")
+    dilations = tuple(
+        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose dilation rank mismatch")
+    output_padding = tuple(
+        int(value)
+        for value in node.attrs.get("output_padding", (0,) * spatial_rank)
+    )
+    if len(output_padding) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose output_padding rank mismatch")
+    for dim, (padding, stride) in enumerate(zip(output_padding, strides)):
+        if padding < 0:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be non-negative"
+            )
+        if padding >= stride:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be smaller than stride"
+            )
+    pads = tuple(
+        int(value)
+        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("ConvTranspose pads rank mismatch")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad == "":
+        auto_pad = "NOTSET"
+    output_shape_attr = node.attrs.get("output_shape")
+    output_shape: list[int] | None = None
+    if output_shape_attr is not None:
+        output_shape = [int(value) for value in output_shape_attr]
+        if len(output_shape) != spatial_rank:
+            raise UnsupportedOpError("ConvTranspose output_shape rank mismatch")
+    if output_shape is not None:
+        if auto_pad == "VALID":
+            auto_pad = "NOTSET"
+        pad_begin = []
+        pad_end = []
+        for dim, (in_dim, stride, dilation, kernel, out_dim, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                output_shape,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            total_padding = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - out_dim
+            )
+            pad_start, pad_finish = _split_padding(
+                total_padding, auto_pad, dim=dim
+            )
+            pad_begin.append(pad_start)
+            pad_end.append(pad_finish)
+        out_spatial = output_shape
+    else:
+        if auto_pad == "VALID":
+            pad_begin = [0] * spatial_rank
+            pad_end = [0] * spatial_rank
+        elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+            pad_begin = []
+            pad_end = []
+            for dim, (in_dim, stride, dilation, kernel, out_pad) in enumerate(
+                zip(in_spatial, strides, dilations, kernel_shape, output_padding)
+            ):
+                effective_kernel = dilation * (kernel - 1) + 1
+                out_dim = in_dim * stride
+                total_padding = (
+                    stride * (in_dim - 1)
+                    + out_pad
+                    + effective_kernel
+                    - out_dim
+                )
+                pad_start, pad_finish = _split_padding(
+                    total_padding, auto_pad, dim=dim
+                )
+                pad_begin.append(pad_start)
+                pad_end.append(pad_finish)
+        elif auto_pad in {"NOTSET"}:
+            pad_begin = list(pads[:spatial_rank])
+            pad_end = list(pads[spatial_rank:])
+        else:
+            raise UnsupportedOpError(
+                f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+            )
+        out_spatial = []
+        for dim, (in_dim, stride, dilation, kernel, pad_start, pad_finish, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                pad_begin,
+                pad_end,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - pad_start
+                - pad_finish
+            )
+            if out_dim < 0:
+                raise ShapeInferenceError(
+                    "ConvTranspose output shape must be non-negative"
+                )
+            out_spatial.append(out_dim)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, out_channels, *out_spatial)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return ConvTransposeSpec(
+        batch=batch,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        spatial_rank=spatial_rank,
+        in_spatial=in_spatial,
+        out_spatial=tuple(out_spatial),
+        kernel_shape=tuple(kernel_shape),
+        strides=strides,
+        pads=(*pad_begin, *pad_end),
+        dilations=dilations,
+        output_padding=output_padding,
+        group=group,
+    )
+
+
+@register_lowering("ConvTranspose")
+def lower_conv_transpose(graph: Graph, node: Node) -> ConvTransposeOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "ConvTranspose supports float16, float, and double inputs only"
+        )
+    spec = resolve_conv_transpose_spec(graph, node)
+    return ConvTransposeOp(
+        input0=node.inputs[0],
+        weights=node.inputs[1],
+        bias=node.inputs[2] if len(node.inputs) == 3 else None,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        output_padding=spec.output_padding,
+        group=spec.group,
+        dtype=op_dtype,
+    )
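The `out_spatial` computation above is the standard ConvTranspose size relation: out = stride * (in - 1) + output_padding + dilation * (kernel - 1) + 1 - pad_begin - pad_end. A standalone sketch with concrete numbers; the `conv_transpose_out_dim` helper name is ours, not part of the package:

    def conv_transpose_out_dim(
        in_dim: int, stride: int, dilation: int, kernel: int,
        pad_begin: int, pad_end: int, output_padding: int,
    ) -> int:
        # Mirrors the out_dim expression in resolve_conv_transpose_spec.
        effective_kernel = dilation * (kernel - 1) + 1
        return (stride * (in_dim - 1) + output_padding
                + effective_kernel - pad_begin - pad_end)

    # A typical 2x-upsampling deconvolution: 4 -> 8 with kernel 4, stride 2, pad 1.
    assert conv_transpose_out_dim(4, stride=2, dilation=1, kernel=4,
                                  pad_begin=1, pad_end=1, output_padding=0) == 8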