emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +340 -59
- emx_onnx_cgen/codegen/c_emitter.py +2369 -111
- emx_onnx_cgen/compiler.py +188 -5
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/lowering/common.py +379 -2
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/gather_elements.py +1 -3
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +6 -5
- emx_onnx_cgen/lowering/logsoftmax.py +5 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/matmul.py +6 -7
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/reduce.py +5 -6
- emx_onnx_cgen/lowering/reshape.py +223 -51
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/softmax.py +5 -1
- emx_onnx_cgen/lowering/squeeze.py +5 -5
- emx_onnx_cgen/lowering/topk.py +116 -0
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +5 -5
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +460 -42
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +61 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
- shared/scalar_functions.py +49 -17
- shared/ulp.py +48 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/common.py
CHANGED

@@ -5,7 +5,7 @@ from collections.abc import Sequence
 from shared.scalar_types import ScalarType
 
 from ..errors import ShapeInferenceError, UnsupportedOpError
-from ..ir.model import Graph, Node
+from ..ir.model import Graph, Initializer, Node
 
 
 def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
@@ -14,6 +14,17 @@ def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
     return dtype
 
 
+def onnx_opset_version(graph: Graph, domain: str = "") -> int | None:
+    if domain in {"", "ai.onnx"}:
+        domains = {"", "ai.onnx"}
+    else:
+        domains = {domain}
+    for opset_domain, version in graph.opset_imports:
+        if opset_domain in domains:
+            return int(version)
+    return None
+
+
 def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType:
     try:
         value = graph.find_value(name)
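The new onnx_opset_version helper just scans the graph's declared opset imports, treating the default domain "" and "ai.onnx" as aliases. A minimal sketch of the lookup semantics (the stub graph below is illustrative only, not the package's real Graph type):

class _StubGraph:
    # the only attribute onnx_opset_version() reads: (domain, version) pairs
    opset_imports = [("", 17), ("com.microsoft", 1)]

assert onnx_opset_version(_StubGraph()) == 17                  # default domain
assert onnx_opset_version(_StubGraph(), "com.microsoft") == 1  # custom domain
assert onnx_opset_version(_StubGraph(), "ai.onnx.ml") is None  # not imported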
@@ -28,13 +39,379 @@ def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType
 
 def value_shape(graph: Graph, name: str, node: Node | None = None) -> tuple[int, ...]:
     try:
-
+        value = graph.find_value(name)
     except KeyError as exc:
         op_type = node.op_type if node is not None else "unknown"
         raise ShapeInferenceError(
             f"Missing shape for value '{name}' in op {op_type}. "
             "Hint: run ONNX shape inference or export with static shapes."
         ) from exc
+    if any(value.type.dim_params):
+        resolved = _resolve_value_shape(graph, name, node)
+        if resolved is not None:
+            return resolved
+        return value.type.shape
+    return value.type.shape
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _find_node_by_output(graph: Graph, name: str) -> Node | None:
+    for node in graph.nodes:
+        if name in node.outputs:
+            return node
+    return None
+
+
+def _shape_values_from_shape_node(
+    graph: Graph, shape_node: Node, node: Node | None
+) -> list[int]:
+    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    source_shape = value_shape(graph, shape_node.inputs[0], node)
+    start = int(shape_node.attrs.get("start", 0))
+    end = int(shape_node.attrs.get("end", len(source_shape)))
+    if start < 0:
+        start += len(source_shape)
+    if end < 0:
+        end += len(source_shape)
+    start = max(start, 0)
+    end = min(end, len(source_shape))
+    if start > end:
+        return []
+    return list(source_shape[start:end])
+
+
+def _shape_values_from_initializer(
+    graph: Graph,
+    name: str,
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "Reshape expects int64 or int32 shape input, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    return [int(value) for value in initializer.data.reshape(-1)]
+
+
+def _shape_values_from_input(
+    graph: Graph,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> list[int] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        shape_values = _shape_values_from_initializer(graph, name)
+        if shape_values is not None:
+            return shape_values
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Shape":
+            return _shape_values_from_shape_node(graph, source_node, node)
+        if source_node.op_type == "Concat":
+            axis = int(source_node.attrs.get("axis", 0))
+            if axis != 0:
+                raise UnsupportedOpError("Reshape shape concat must use axis 0")
+            values: list[int] = []
+            for input_name in source_node.inputs:
+                input_values = _shape_values_from_input(
+                    graph,
+                    input_name,
+                    node,
+                    _visited=_visited,
+                )
+                if input_values is None:
+                    return None
+                values.extend(input_values)
+            return values
+        if source_node.op_type == "Cast":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Cast must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Unsqueeze":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Identity":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Identity must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            if len(left) == 1 and len(right) != 1:
+                left = left * len(right)
+            if len(right) == 1 and len(left) != 1:
+                right = right * len(left)
+            if len(left) != len(right):
+                return None
+            if source_node.op_type == "Equal":
+                return [1 if l == r else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "And":
+                return [1 if (l and r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Or":
+                return [1 if (l or r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Div":
+                return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Mod":
+                return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+        if source_node.op_type == "Not":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Not must have 1 input and 1 output")
+            values = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if values is None:
+                return None
+            return [0 if value else 1 for value in values]
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            condition = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if condition is None:
+                return None
+            on_true = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _shape_values_from_input(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            if len(condition) == 1:
+                condition = condition * max(len(on_true), len(on_false))
+            if len(on_true) == 1 and len(condition) != 1:
+                on_true = on_true * len(condition)
+            if len(on_false) == 1 and len(condition) != 1:
+                on_false = on_false * len(condition)
+            if not (len(condition) == len(on_true) == len(on_false)):
+                return None
+            return [
+                t if cond else f
+                for cond, t, f in zip(condition, on_true, on_false)
+            ]
+        return None
+    finally:
+        _visited.remove(name)
+
+
+def _broadcast_shapes(
+    left: tuple[int, ...],
+    right: tuple[int, ...],
+) -> tuple[int, ...] | None:
+    result = []
+    left_rev = list(reversed(left))
+    right_rev = list(reversed(right))
+    for index in range(max(len(left_rev), len(right_rev))):
+        left_dim = left_rev[index] if index < len(left_rev) else 1
+        right_dim = right_rev[index] if index < len(right_rev) else 1
+        if left_dim == right_dim:
+            result.append(left_dim)
+        elif left_dim == 1:
+            result.append(right_dim)
+        elif right_dim == 1:
+            result.append(left_dim)
+        else:
+            return None
+    return tuple(reversed(result))
+
+
+def _resolve_value_shape(
+    graph: Graph,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> tuple[int, ...] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        value = graph.find_value(name)
+        shape = value.type.shape
+        if not any(value.type.dim_params):
+            return shape
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Expand":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Expand must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is not None and all(dim >= 0 for dim in shape_values):
+                return tuple(shape_values)
+            return None
+        if source_node.op_type == "Reshape":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is None:
+                return None
+            allowzero = int(source_node.attrs.get("allowzero", 0))
+            input_shape = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if input_shape is None:
+                return None
+            output_dims: list[int] = []
+            unknown_index: int | None = None
+            known_product = 1
+            contains_zero = False
+            for index, dim in enumerate(shape_values):
+                if dim == -1:
+                    if unknown_index is not None:
+                        return None
+                    unknown_index = len(output_dims)
+                    output_dims.append(-1)
+                else:
+                    if dim == 0:
+                        contains_zero = True
+                        if allowzero == 0:
+                            if index >= len(input_shape):
+                                return None
+                            dim = input_shape[index]
+                    if dim < 0:
+                        return None
+                    output_dims.append(dim)
+                    known_product *= dim
+            if allowzero == 1 and contains_zero and unknown_index is not None:
+                return None
+            input_product = shape_product(input_shape)
+            if unknown_index is not None:
+                if known_product == 0:
+                    if input_product != 0:
+                        return None
+                    output_dims[unknown_index] = 0
+                else:
+                    if input_product % known_product != 0:
+                        return None
+                    output_dims[unknown_index] = input_product // known_product
+            return tuple(output_dims)
+        if source_node.op_type in {
+            "Add",
+            "Sub",
+            "Mul",
+            "Div",
+            "Pow",
+            "Mod",
+            "And",
+            "Or",
+            "Xor",
+            "Equal",
+            "Greater",
+            "Less",
+            "GreaterOrEqual",
+            "LessOrEqual",
+        }:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            return _broadcast_shapes(left, right)
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            on_true = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _resolve_value_shape(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            return _broadcast_shapes(on_true, on_false)
+        return None
+    finally:
+        _visited.remove(name)
 
 
 def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
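Taken together, these helpers constant-fold shape values through small subgraphs (Shape, Concat, Cast, Unsqueeze, Identity, and simple element-wise chains feeding a Reshape or Expand), so graphs with symbolic dim_params can still compile without a prior ONNX shape-inference pass. Two worked data points under the semantics above, illustrative only. _broadcast_shapes follows NumPy-style trailing-dimension broadcasting:

assert _broadcast_shapes((2, 3, 4), (3, 4)) == (2, 3, 4)   # missing leading dims act as 1
assert _broadcast_shapes((2, 1, 4), (2, 3, 1)) == (2, 3, 4)
assert _broadcast_shapes((2, 3), (4, 3)) is None           # mismatched non-1 dims

And the Reshape branch applies the usual ONNX 0/-1 rules: with allowzero=0, a target of [0, -1] against an input of shape (2, 3, 4) copies dim 0 from the input, then resolves -1 to 24 // 2 = 12, yielding (2, 12).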
emx_onnx_cgen/lowering/conv_transpose.py
ADDED

@@ -0,0 +1,301 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+from ..codegen.c_emitter import ConvTransposeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class ConvTransposeSpec:
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    output_padding: tuple[int, ...]
+    group: int
+
+
+def _split_padding(
+    total_padding: int, auto_pad: str, *, dim: int
+) -> tuple[int, int]:
+    if total_padding < 0:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be fully defined and non-negative"
+        )
+    pad_end = total_padding // 2
+    pad_begin = total_padding - pad_end
+    if auto_pad == "SAME_UPPER":
+        pad_begin, pad_end = pad_end, pad_begin
+    elif auto_pad not in {"SAME_LOWER", "NOTSET", ""}:
+        raise UnsupportedOpError(
+            f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+        )
+    if pad_begin < 0 or pad_end < 0:
+        raise ShapeInferenceError(
+            f"ConvTranspose pads must be non-negative for dim {dim}"
+        )
+    return pad_begin, pad_end
+
+
+def resolve_conv_transpose_spec(graph: Graph, node: Node) -> ConvTransposeSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    supported_attrs = {
+        "auto_pad",
+        "dilations",
+        "group",
+        "kernel_shape",
+        "output_padding",
+        "output_shape",
+        "pads",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("ConvTranspose has unsupported attributes")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    weight_shape = _value_shape(graph, node.inputs[1], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("ConvTranspose expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("ConvTranspose supports 1D/2D/3D inputs only")
+    if len(weight_shape) != spatial_rank + 2:
+        raise UnsupportedOpError(
+            "ConvTranspose weight rank must match spatial rank"
+        )
+    batch, in_channels = input_shape[0], input_shape[1]
+    in_spatial = input_shape[2:]
+    weight_in_channels, weight_out_channels, *kernel_shape = weight_shape
+    kernel_attr = node.attrs.get("kernel_shape")
+    if kernel_attr is not None:
+        kernel_attr = tuple(int(value) for value in kernel_attr)
+        if len(kernel_attr) != spatial_rank:
+            raise UnsupportedOpError(
+                "ConvTranspose kernel_shape rank must match input spatial rank"
+            )
+        if kernel_attr != tuple(kernel_shape):
+            raise ShapeInferenceError(
+                "ConvTranspose kernel_shape must match weights, "
+                f"got {kernel_attr} and {tuple(kernel_shape)}"
+            )
+        kernel_shape = list(kernel_attr)
+    else:
+        kernel_shape = list(kernel_shape)
+    group = int(node.attrs.get("group", 1))
+    if group <= 0:
+        raise UnsupportedOpError("ConvTranspose expects group >= 1")
+    if in_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide in channels, "
+            f"got group={group}, in_channels={in_channels}"
+        )
+    if weight_in_channels != in_channels:
+        raise ShapeInferenceError(
+            "ConvTranspose input channels must match weight channels, "
+            f"got {in_channels} and {weight_in_channels}"
+        )
+    out_channels = weight_out_channels * group
+    if out_channels % group != 0:
+        raise ShapeInferenceError(
+            "ConvTranspose expects group to evenly divide out channels, "
+            f"got group={group}, out_channels={out_channels}"
+        )
+    if len(node.inputs) == 3:
+        bias_shape = _value_shape(graph, node.inputs[2], node)
+        if bias_shape != (out_channels,):
+            raise ShapeInferenceError(
+                f"ConvTranspose bias shape must be {(out_channels,)}, got {bias_shape}"
+            )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose stride rank mismatch")
+    dilations = tuple(
+        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose dilation rank mismatch")
+    output_padding = tuple(
+        int(value)
+        for value in node.attrs.get("output_padding", (0,) * spatial_rank)
+    )
+    if len(output_padding) != spatial_rank:
+        raise UnsupportedOpError("ConvTranspose output_padding rank mismatch")
+    for dim, (padding, stride) in enumerate(zip(output_padding, strides)):
+        if padding < 0:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be non-negative"
+            )
+        if padding >= stride:
+            raise UnsupportedOpError(
+                "ConvTranspose output_padding must be smaller than stride"
+            )
+    pads = tuple(
+        int(value)
+        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("ConvTranspose pads rank mismatch")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad == "":
+        auto_pad = "NOTSET"
+    output_shape_attr = node.attrs.get("output_shape")
+    output_shape: list[int] | None = None
+    if output_shape_attr is not None:
+        output_shape = [int(value) for value in output_shape_attr]
+        if len(output_shape) != spatial_rank:
+            raise UnsupportedOpError("ConvTranspose output_shape rank mismatch")
+    if output_shape is not None:
+        if auto_pad == "VALID":
+            auto_pad = "NOTSET"
+        pad_begin = []
+        pad_end = []
+        for dim, (in_dim, stride, dilation, kernel, out_dim, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                output_shape,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            total_padding = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - out_dim
+            )
+            pad_start, pad_finish = _split_padding(
+                total_padding, auto_pad, dim=dim
+            )
+            pad_begin.append(pad_start)
+            pad_end.append(pad_finish)
+        out_spatial = output_shape
+    else:
+        if auto_pad == "VALID":
+            pad_begin = [0] * spatial_rank
+            pad_end = [0] * spatial_rank
+        elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+            pad_begin = []
+            pad_end = []
+            for dim, (in_dim, stride, dilation, kernel, out_pad) in enumerate(
+                zip(in_spatial, strides, dilations, kernel_shape, output_padding)
+            ):
+                effective_kernel = dilation * (kernel - 1) + 1
+                out_dim = in_dim * stride
+                total_padding = (
+                    stride * (in_dim - 1)
+                    + out_pad
+                    + effective_kernel
+                    - out_dim
+                )
+                pad_start, pad_finish = _split_padding(
+                    total_padding, auto_pad, dim=dim
+                )
+                pad_begin.append(pad_start)
+                pad_end.append(pad_finish)
+        elif auto_pad in {"NOTSET"}:
+            pad_begin = list(pads[:spatial_rank])
+            pad_end = list(pads[spatial_rank:])
+        else:
+            raise UnsupportedOpError(
+                f"ConvTranspose has unsupported auto_pad mode '{auto_pad}'"
+            )
+        out_spatial = []
+        for dim, (in_dim, stride, dilation, kernel, pad_start, pad_finish, out_pad) in enumerate(
+            zip(
+                in_spatial,
+                strides,
+                dilations,
+                kernel_shape,
+                pad_begin,
+                pad_end,
+                output_padding,
+            )
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = (
+                stride * (in_dim - 1)
+                + out_pad
+                + effective_kernel
+                - pad_start
+                - pad_finish
+            )
+            if out_dim < 0:
+                raise ShapeInferenceError(
+                    "ConvTranspose output shape must be non-negative"
+                )
+            out_spatial.append(out_dim)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, out_channels, *out_spatial)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "ConvTranspose output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return ConvTransposeSpec(
+        batch=batch,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        spatial_rank=spatial_rank,
+        in_spatial=in_spatial,
+        out_spatial=tuple(out_spatial),
+        kernel_shape=tuple(kernel_shape),
+        strides=strides,
+        pads=(*pad_begin, *pad_end),
+        dilations=dilations,
+        output_padding=output_padding,
+        group=group,
+    )
+
+
+@register_lowering("ConvTranspose")
+def lower_conv_transpose(graph: Graph, node: Node) -> ConvTransposeOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvTranspose must have 2 or 3 inputs and 1 output"
+        )
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "ConvTranspose supports float16, float, and double inputs only"
+        )
+    spec = resolve_conv_transpose_spec(graph, node)
+    return ConvTransposeOp(
+        input0=node.inputs[0],
+        weights=node.inputs[1],
+        bias=node.inputs[2] if len(node.inputs) == 3 else None,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        output_padding=spec.output_padding,
+        group=spec.group,
+        dtype=op_dtype,
+    )
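When no output_shape attribute is present, each spatial dimension above follows the standard transposed-convolution size formula. Restated as a standalone sketch (the helper name is ours, for illustration only):

def expected_out_dim(in_dim, stride, dilation, kernel, pad_begin, pad_end, out_pad):
    # out = stride * (in - 1) + output_padding + dilation * (kernel - 1) + 1 - pads
    effective_kernel = dilation * (kernel - 1) + 1
    return stride * (in_dim - 1) + out_pad + effective_kernel - pad_begin - pad_end

# e.g. a stride-2, kernel-3 deconvolution over a length-5 input with symmetric
# padding 1 and output_padding 1 produces 2*(5-1) + 1 + 3 - 1 - 1 = 10 outputs
assert expected_out_dim(5, stride=2, dilation=1, kernel=3,
                        pad_begin=1, pad_end=1, out_pad=1) == 10

When output_shape is given, resolve_conv_transpose_spec solves the same identity for the pads instead (total_padding = stride * (in - 1) + output_padding + effective_kernel - out_dim) and splits it between begin and end according to auto_pad.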