onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,406 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Convolution operators."""
+
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+ import torch.nn.functional as F
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.op_helpers import compute_same_padding, get_optional_input
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ # =============================================================================
+ # Convolution operators
+ # =============================================================================
+
+
+ def _get_conv_params(node: onnx.NodeProto) -> dict:
+     """Extract common convolution parameters from node attributes."""
+     return {
+         "dilations": get_attribute(node, "dilations"),
+         "group": get_attribute(node, "group", 1),
+         "kernel_shape": get_attribute(node, "kernel_shape"),
+         "pads": get_attribute(node, "pads"),
+         "strides": get_attribute(node, "strides"),
+         "auto_pad": get_attribute(node, "auto_pad", "NOTSET"),
+     }
+
+
+ @register("Conv")
+ def conv(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """N-dimensional convolution."""
+     x = builder.get_value(node.input[0])
+     weight = builder.get_value(node.input[1])
+     bias = get_optional_input(builder, node, 2)
+
+     params = _get_conv_params(node)
+     strides = params["strides"] or [1]
+     dilations = params["dilations"] or [1]
+     group = params["group"]
+     pads = params["pads"]
+     auto_pad = params["auto_pad"]
+     kernel_shape = params["kernel_shape"]
+
+     def _conv(x, weight, bias, strides, dilations, group, pads, auto_pad, kernel_shape):
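+         # ONNX convolution weights are laid out as (M, C / group, k1, ..., kn),
+         # so the number of spatial dimensions is the weight rank minus 2.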
+         ndim = len(weight.shape) - 2  # Exclude batch and channel dims
+
+         # Expand strides and dilations to match ndim
+         if len(strides) == 1:
+             strides = strides * ndim
+         if len(dilations) == 1:
+             dilations = dilations * ndim
+
+         # Handle padding
+         padding = 0
+         if pads is not None:
+             n = len(pads) // 2
+             symmetric = all(pads[i] == pads[i + n] for i in range(n))
+             if symmetric:
+                 padding = tuple(pads[:n])
+             else:
+                 # Asymmetric padding
+                 # ONNX: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
+                 # F.pad: [xn_begin, xn_end, ..., x1_begin, x1_end] (reverse order)
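+                 # e.g. 2-D pads [1, 2, 3, 4] (H_begin, W_begin, H_end, W_end)
+                 # become pad_list [2, 4, 1, 3] below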
+                 pad_list = []
+                 for i in range(n - 1, -1, -1):
+                     pad_list.extend([pads[i], pads[i + n]])
+                 x = F.pad(x, pad_list)
+                 padding = 0
+
+         # Handle auto_pad
+         if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
+             # Compute padding for SAME
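+             # ONNX SAME_* pads each spatial axis so that
+             #     output_size = ceil(input_size / stride)
+             # i.e. total_pad = max((output_size - 1) * stride
+             #                      + (kernel - 1) * dilation + 1 - input_size, 0),
+             # with any odd element going to the end for SAME_UPPER and to the
+             # beginning for SAME_LOWER.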
+             input_shape = x.shape[2:]
+             k_shape = kernel_shape or weight.shape[2:]
+             pad_list = compute_same_padding(
+                 tuple(input_shape),
+                 tuple(k_shape),
+                 tuple(strides),
+                 tuple(dilations),
+                 auto_pad,
+             )
+             x = F.pad(x, pad_list)
+             padding = 0
+
+         strides_tuple = tuple(strides) if len(strides) > 1 else strides[0]
+         dilations_tuple = tuple(dilations) if len(dilations) > 1 else dilations[0]
+
+         if ndim == 1:
+             return F.conv1d(
+                 x,
+                 weight,
+                 bias,
+                 stride=strides_tuple,
+                 padding=padding,
+                 dilation=dilations_tuple,
+                 groups=group,
+             )
+         elif ndim == 2:
+             return F.conv2d(
+                 x,
+                 weight,
+                 bias,
+                 stride=strides_tuple,
+                 padding=padding,
+                 dilation=dilations_tuple,
+                 groups=group,
+             )
+         elif ndim == 3:
+             return F.conv3d(
+                 x,
+                 weight,
+                 bias,
+                 stride=strides_tuple,
+                 padding=padding,
+                 dilation=dilations_tuple,
+                 groups=group,
+             )
+         else:
+             raise NotImplementedError(f"Conv{ndim}D not supported")
+
+     return builder.call_function(
+         _conv,
+         args=(x, weight, bias, strides, dilations, group, pads, auto_pad, kernel_shape),
+     )
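+ # Example mapping for the Conv handler above (illustrative): an ONNX Conv with
+ # kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], dilations=[1, 1] and
+ # group=1 ends up as
+ #     F.conv2d(x, weight, bias, stride=(2, 2), padding=(1, 1),
+ #              dilation=(1, 1), groups=1)
+ # since the pads are symmetric and no auto_pad is set.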
+
+
+ @register("ConvTranspose")
+ def conv_transpose(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """N-dimensional transposed convolution."""
+     x = builder.get_value(node.input[0])
+     weight = builder.get_value(node.input[1])
+     bias = get_optional_input(builder, node, 2)
+
+     strides = get_attribute(node, "strides") or [1]
+     dilations = get_attribute(node, "dilations") or [1]
+     group = get_attribute(node, "group", 1)
+     pads = get_attribute(node, "pads")
+     output_padding = get_attribute(node, "output_padding") or [0]
+     auto_pad = get_attribute(node, "auto_pad", "NOTSET")
+     output_shape = get_attribute(node, "output_shape")
+     kernel_shape = get_attribute(node, "kernel_shape")
+
+     def _conv_transpose(
+         x,
+         weight,
+         bias,
+         strides,
+         dilations,
+         group,
+         pads,
+         output_padding,
+         auto_pad,
+         output_shape,
+         kernel_shape,
+     ):
+         ndim = len(weight.shape) - 2
+
+         # Expand strides, dilations, and output_padding to match ndim
+         if len(strides) == 1:
+             strides = strides * ndim
+         if len(dilations) == 1:
+             dilations = dilations * ndim
+         if len(output_padding) == 1:
+             output_padding = output_padding * ndim
+
+         # Get kernel shape from weight if not provided
+         k_shape = kernel_shape if kernel_shape else list(weight.shape[2:])
+
+         # Handle auto_pad and output_shape
+         # For ConvTranspose, the output shape formula is:
+         #   output_shape = (input_shape - 1) * stride + (kernel_shape - 1) * dilation
+         #                  + 1 - pad_begin - pad_end + output_padding
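+         # e.g. input 4, stride 2, kernel 3, dilation 1, no pads and no output
+         # padding gives (4 - 1) * 2 + (3 - 1) * 1 + 1 = 9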
+         padding = [0] * ndim
+         adj_output_padding = list(output_padding)
+
+         if output_shape is not None:
+             # Compute pads from output_shape:
+             #   total_pad[i] = default_output[i] - output_shape[i]
+             # and split the total between padding and output_padding so that the
+             # conv_transpose output matches output_shape exactly (the explicit
+             # output_padding attribute is superseded in this case).
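+             # e.g. with the numbers above and output_shape = 8, total_pad = 1 is
+             # split as padding = 1 and output_padding = 1: 9 - 2 * 1 + 1 = 8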
+             input_shape = x.shape[2:]
+             for i in range(ndim):
+                 default_output = (
+                     (input_shape[i] - 1) * strides[i]
+                     + (k_shape[i] - 1) * dilations[i]
+                     + 1
+                 )
+                 total_pad = default_output - output_shape[i]
+                 if total_pad >= 0:
+                     # Round the symmetric padding up and let output_padding
+                     # absorb the remainder so the exact output_shape is produced
+                     padding[i] = (total_pad + 1) // 2
+                     adj_output_padding[i] = 2 * padding[i] - total_pad
+                 else:
+                     # Need additional output_padding
+                     padding[i] = 0
+                     adj_output_padding[i] = -total_pad
+         elif auto_pad in ("SAME_UPPER", "SAME_LOWER"):
+             # For SAME auto_pad in ConvTranspose:
+             # target output_shape = input_shape * stride
+             # We do full conv_transpose without padding and then slice the output
+             input_shape = x.shape[2:]
+             target_shape = [input_shape[i] * strides[i] for i in range(ndim)]
+
+             # Default output without padding
+             default_output = [
+                 (input_shape[i] - 1) * strides[i] + (k_shape[i] - 1) * dilations[i] + 1
+                 for i in range(ndim)
+             ]
+
+             # Calculate how much to trim from each dimension
+             trim_total = [default_output[i] - target_shape[i] for i in range(ndim)]
+
+             # For SAME_UPPER: extra pad at end means trim from end
+             # For SAME_LOWER: extra pad at begin means trim from begin
+             if auto_pad == "SAME_UPPER":
+                 trim_begin = [t // 2 for t in trim_total]
+                 trim_end = [t - t // 2 for t in trim_total]
+             else:  # SAME_LOWER
+                 trim_end = [t // 2 for t in trim_total]
+                 trim_begin = [t - t // 2 for t in trim_total]
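+             # e.g. input 4, stride 2, kernel 3, dilation 1: the unpadded output
+             # is 9 while the target is 4 * 2 = 8, so one element is trimmed
+             # (from the end for SAME_UPPER, from the beginning for SAME_LOWER)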
+
+             # Do full conv_transpose without padding
+             strides_tuple = tuple(strides) if len(strides) > 1 else strides[0]
+             dilations_tuple = tuple(dilations) if len(dilations) > 1 else dilations[0]
+
+             if ndim == 1:
+                 result = F.conv_transpose1d(
+                     x,
+                     weight,
+                     bias,
+                     stride=strides_tuple,
+                     padding=0,
+                     output_padding=0,
+                     groups=group,
+                     dilation=dilations_tuple,
+                 )
+                 # Slice to get target shape
+                 end0 = result.shape[2] - trim_end[0] if trim_end[0] > 0 else None
+                 return result[:, :, trim_begin[0] : end0]
+             elif ndim == 2:
+                 result = F.conv_transpose2d(
+                     x,
+                     weight,
+                     bias,
+                     stride=strides_tuple,
+                     padding=0,
+                     output_padding=0,
+                     groups=group,
+                     dilation=dilations_tuple,
+                 )
+                 # Slice to get target shape
+                 end0 = result.shape[2] - trim_end[0] if trim_end[0] > 0 else None
+                 end1 = result.shape[3] - trim_end[1] if trim_end[1] > 0 else None
+                 return result[:, :, trim_begin[0] : end0, trim_begin[1] : end1]
+             elif ndim == 3:
+                 result = F.conv_transpose3d(
+                     x,
+                     weight,
+                     bias,
+                     stride=strides_tuple,
+                     padding=0,
+                     output_padding=0,
+                     groups=group,
+                     dilation=dilations_tuple,
+                 )
+                 # Slice to get target shape
+                 end0 = result.shape[2] - trim_end[0] if trim_end[0] > 0 else None
+                 end1 = result.shape[3] - trim_end[1] if trim_end[1] > 0 else None
+                 end2 = result.shape[4] - trim_end[2] if trim_end[2] > 0 else None
+                 return result[
+                     :,
+                     :,
+                     trim_begin[0] : end0,
+                     trim_begin[1] : end1,
+                     trim_begin[2] : end2,
+                 ]
+             else:
+                 raise NotImplementedError(f"ConvTranspose{ndim}D not supported")
+         elif pads is not None:
+             n = len(pads) // 2
+             padding = list(pads[:n])
+             # Handle asymmetric pads via output_padding
+             for i in range(n):
+                 if pads[i] != pads[i + n]:
+                     adj_output_padding[i] = pads[i + n] - pads[i]
+
+         strides_tuple = tuple(strides) if len(strides) > 1 else strides[0]
+         dilations_tuple = tuple(dilations) if len(dilations) > 1 else dilations[0]
+         padding_tuple = tuple(padding) if len(padding) > 1 else padding[0]
+         output_padding_tuple = (
+             tuple(adj_output_padding)
+             if len(adj_output_padding) > 1
+             else adj_output_padding[0]
+         )
+
+         if ndim == 1:
+             return F.conv_transpose1d(
+                 x,
+                 weight,
+                 bias,
+                 stride=strides_tuple,
+                 padding=padding_tuple,
+                 output_padding=output_padding_tuple,
+                 groups=group,
+                 dilation=dilations_tuple,
+             )
+         elif ndim == 2:
+             return F.conv_transpose2d(
+                 x,
+                 weight,
+                 bias,
+                 stride=strides_tuple,
+                 padding=padding_tuple,
+                 output_padding=output_padding_tuple,
+                 groups=group,
+                 dilation=dilations_tuple,
+             )
+         elif ndim == 3:
+             return F.conv_transpose3d(
+                 x,
+                 weight,
+                 bias,
+                 stride=strides_tuple,
+                 padding=padding_tuple,
+                 output_padding=output_padding_tuple,
+                 groups=group,
+                 dilation=dilations_tuple,
+             )
+         else:
+             raise NotImplementedError(f"ConvTranspose{ndim}D not supported")
+
+     return builder.call_function(
+         _conv_transpose,
+         args=(
+             x,
+             weight,
+             bias,
+             strides,
+             dilations,
+             group,
+             pads,
+             output_padding,
+             auto_pad,
+             output_shape,
+             kernel_shape,
+         ),
+     )
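+ # Example mapping for the ConvTranspose handler above (illustrative): with
+ # strides=[2, 2], pads=[1, 1, 1, 1] and output_padding=[1, 1] the call becomes
+ #     F.conv_transpose2d(x, weight, bias, stride=(2, 2), padding=(1, 1),
+ #                        output_padding=(1, 1), groups=1, dilation=(1, 1))
+ # because the pads are symmetric and neither output_shape nor auto_pad is set.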
+
+
+ @register("DeformConv")
+ def deform_conv(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Deformable convolution.
+
+     Performs deformable convolution as described in:
+     - Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211)
+     - Deformable ConvNets v2 (https://arxiv.org/abs/1811.11168) when a mask is provided
+
+     Note: only 2D deformable convolution is supported, since
+     torchvision.ops.deform_conv2d only handles 2D inputs.
+     """
+     import torchvision.ops
+
+     x = builder.get_value(node.input[0])
+     weight = builder.get_value(node.input[1])
+     offset = builder.get_value(node.input[2])
+
+     bias = get_optional_input(builder, node, 3)
+     mask = get_optional_input(builder, node, 4)
+
+     strides = get_attribute(node, "strides") or [1, 1]
+     dilations = get_attribute(node, "dilations") or [1, 1]
+     pads = get_attribute(node, "pads") or [0, 0, 0, 0]
+     # Note: the ONNX group and offset_group attributes are not read here;
+     # torchvision infers both from the weight and offset tensor shapes.
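+     # torchvision.ops.deform_conv2d expects offset of shape
+     # (N, 2 * offset_group * kH * kW, H_out, W_out) and mask of shape
+     # (N, offset_group * kH * kW, H_out, W_out), which is how offset_group is
+     # recovered; group follows from the input and weight channel counts.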
+
+     def _deform_conv(x, weight, offset, bias, mask, strides, dilations, pads):
+         # Handle padding: ONNX uses [H_begin, W_begin, H_end, W_end], while
+         # torchvision.ops.deform_conv2d expects (pad_H, pad_W). Symmetric
+         # padding is assumed here; only the begin values are used.
+         n = len(pads) // 2
+         padding = tuple(pads[:n])
+
+         stride = tuple(strides) if len(strides) > 1 else (strides[0], strides[0])
+         dilation = (
+             tuple(dilations) if len(dilations) > 1 else (dilations[0], dilations[0])
+         )
+
+         return torchvision.ops.deform_conv2d(
+             x,
+             offset,
+             weight,
+             bias=bias,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             mask=mask,
+         )
+
+     return builder.call_function(
+         _deform_conv,
+         args=(x, weight, offset, bias, mask, strides, dilations, pads),
+     )