onnx2fx-0.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
"""Convolution operators."""

from typing import TYPE_CHECKING

import onnx
import torch
import torch.nn.functional as F

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import compute_same_padding, get_optional_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Convolution operators
# =============================================================================


def _get_conv_params(node: onnx.NodeProto) -> dict:
    """Extract common convolution parameters from node attributes."""
    return {
        "dilations": get_attribute(node, "dilations"),
        "group": get_attribute(node, "group", 1),
        "kernel_shape": get_attribute(node, "kernel_shape"),
        "pads": get_attribute(node, "pads"),
        "strides": get_attribute(node, "strides"),
        "auto_pad": get_attribute(node, "auto_pad", "NOTSET"),
    }


@register("Conv")
def conv(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """N-dimensional convolution."""
    x = builder.get_value(node.input[0])
    weight = builder.get_value(node.input[1])
    bias = get_optional_input(builder, node, 2)

    params = _get_conv_params(node)
    strides = params["strides"] or [1]
    dilations = params["dilations"] or [1]
    group = params["group"]
    pads = params["pads"]
    auto_pad = params["auto_pad"]
    kernel_shape = params["kernel_shape"]

    def _conv(x, weight, bias, strides, dilations, group, pads, auto_pad, kernel_shape):
        ndim = len(weight.shape) - 2  # Exclude batch and channel dims

        # Expand strides and dilations to match ndim
        if len(strides) == 1:
            strides = strides * ndim
        if len(dilations) == 1:
            dilations = dilations * ndim

        # Handle padding
        padding = 0
        if pads is not None:
            n = len(pads) // 2
            symmetric = all(pads[i] == pads[i + n] for i in range(n))
            if symmetric:
                padding = tuple(pads[:n])
            else:
                # Asymmetric padding
                # ONNX: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
                # F.pad: [xn_begin, xn_end, ..., x1_begin, x1_end] (reverse order)
                pad_list = []
                for i in range(n - 1, -1, -1):
                    pad_list.extend([pads[i], pads[i + n]])
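                # e.g. (illustrative): 2-D pads [1, 2, 3, 4] = [top, left,
                # bottom, right] become [2, 4, 1, 3] = [left, right, top,
                # bottom] in F.pad order.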
                x = F.pad(x, pad_list)
                padding = 0

        # Handle auto_pad
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            # Compute padding for SAME
            input_shape = x.shape[2:]
            k_shape = kernel_shape or weight.shape[2:]
            pad_list = compute_same_padding(
                tuple(input_shape),
                tuple(k_shape),
                tuple(strides),
                tuple(dilations),
                auto_pad,
            )
            x = F.pad(x, pad_list)
            padding = 0
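            # SAME targets output = ceil(input / stride); e.g. (illustrative)
            # input 5, kernel 3, stride 2 needs total padding
            # (3 - 1) * 2 + (3 - 1) + 1 - 5 = 2, split 1/1. For an odd total,
            # SAME_UPPER puts the extra unit at the end, SAME_LOWER at the start.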
        strides_tuple = tuple(strides) if len(strides) > 1 else strides[0]
        dilations_tuple = tuple(dilations) if len(dilations) > 1 else dilations[0]

        if ndim == 1:
            return F.conv1d(
                x,
                weight,
                bias,
                stride=strides_tuple,
                padding=padding,
                dilation=dilations_tuple,
                groups=group,
            )
        elif ndim == 2:
            return F.conv2d(
                x,
                weight,
                bias,
                stride=strides_tuple,
                padding=padding,
                dilation=dilations_tuple,
                groups=group,
            )
        elif ndim == 3:
            return F.conv3d(
                x,
                weight,
                bias,
                stride=strides_tuple,
                padding=padding,
                dilation=dilations_tuple,
                groups=group,
            )
        else:
            raise NotImplementedError(f"Conv{ndim}D not supported")

    return builder.call_function(
        _conv,
        args=(x, weight, bias, strides, dilations, group, pads, auto_pad, kernel_shape),
    )


@register("ConvTranspose")
def conv_transpose(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """N-dimensional transposed convolution."""
    x = builder.get_value(node.input[0])
    weight = builder.get_value(node.input[1])
    bias = get_optional_input(builder, node, 2)

    strides = get_attribute(node, "strides") or [1]
    dilations = get_attribute(node, "dilations") or [1]
    group = get_attribute(node, "group", 1)
    pads = get_attribute(node, "pads")
    output_padding = get_attribute(node, "output_padding") or [0]
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")
    output_shape = get_attribute(node, "output_shape")
    kernel_shape = get_attribute(node, "kernel_shape")

    def _conv_transpose(
        x,
        weight,
        bias,
        strides,
        dilations,
        group,
        pads,
        output_padding,
        auto_pad,
        output_shape,
        kernel_shape,
    ):
        ndim = len(weight.shape) - 2

        # Expand strides, dilations, and output_padding to match ndim
        if len(strides) == 1:
            strides = strides * ndim
        if len(dilations) == 1:
            dilations = dilations * ndim
        if len(output_padding) == 1:
            output_padding = output_padding * ndim

        # Get kernel shape from weight if not provided
        k_shape = kernel_shape if kernel_shape else list(weight.shape[2:])

        # Handle auto_pad and output_shape
        # For ConvTranspose, the output shape formula is:
        # output_shape = (input_shape - 1) * stride + (kernel_shape - 1) * dilation + 1 - pad_begin - pad_end + output_padding
        padding = [0] * ndim
        adj_output_padding = list(output_padding)

        if output_shape is not None:
            # Compute pads from output_shape
            # output_shape[i] = (input_shape[i] - 1) * stride[i] + (k - 1) * dilation[i] + 1 - total_pad[i] + output_pad[i]
            # total_pad[i] = (input_shape[i] - 1) * stride[i] + (k - 1) * dilation[i] + 1 - output_shape[i] + output_pad[i]
            input_shape = x.shape[2:]
            for i in range(ndim):
                default_output = (
                    (input_shape[i] - 1) * strides[i]
                    + (k_shape[i] - 1) * dilations[i]
                    + 1
                )
                total_pad = default_output - output_shape[i]
                if total_pad >= 0:
                    padding[i] = total_pad // 2
                    # Adjust output_padding to match the exact output_shape
                    adj_output_padding[i] = total_pad - 2 * padding[i]
                else:
                    # Need additional output_padding
                    padding[i] = 0
                    adj_output_padding[i] = -total_pad
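            # e.g. (illustrative): input 3, stride 2, kernel 3 gives a
            # default output of (3 - 1) * 2 + (3 - 1) + 1 = 7; requesting
            # output_shape 8 makes total_pad -1, i.e. output_padding 1.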
        elif auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            # For SAME auto_pad in ConvTranspose:
            # target output_shape = input_shape * stride
            # We do full conv_transpose without padding and then slice the output
            input_shape = x.shape[2:]
            target_shape = [input_shape[i] * strides[i] for i in range(ndim)]

            # Default output without padding
            default_output = [
                (input_shape[i] - 1) * strides[i] + (k_shape[i] - 1) * dilations[i] + 1
                for i in range(ndim)
            ]

            # Calculate how much to trim from each dimension
            trim_total = [default_output[i] - target_shape[i] for i in range(ndim)]

            # For SAME_UPPER: extra pad at end means trim from end
            # For SAME_LOWER: extra pad at begin means trim from begin
            if auto_pad == "SAME_UPPER":
                trim_begin = [t // 2 for t in trim_total]
                trim_end = [t - t // 2 for t in trim_total]
            else:  # SAME_LOWER
                trim_end = [t // 2 for t in trim_total]
                trim_begin = [t - t // 2 for t in trim_total]
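            # e.g. (illustrative): input 3, stride 2, kernel 3: default
            # output 7 vs. target 3 * 2 = 6, so trim_total is 1; SAME_UPPER
            # trims it from the end, SAME_LOWER from the beginning.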

            # Do full conv_transpose without padding
            strides_tuple = tuple(strides) if len(strides) > 1 else strides[0]
            dilations_tuple = tuple(dilations) if len(dilations) > 1 else dilations[0]

            if ndim == 1:
                result = F.conv_transpose1d(
                    x,
                    weight,
                    bias,
                    stride=strides_tuple,
                    padding=0,
                    output_padding=0,
                    groups=group,
                    dilation=dilations_tuple,
                )
                # Slice to get target shape
                end0 = result.shape[2] - trim_end[0] if trim_end[0] > 0 else None
                return result[:, :, trim_begin[0] : end0]
            elif ndim == 2:
                result = F.conv_transpose2d(
                    x,
                    weight,
                    bias,
                    stride=strides_tuple,
                    padding=0,
                    output_padding=0,
                    groups=group,
                    dilation=dilations_tuple,
                )
                # Slice to get target shape
                end0 = result.shape[2] - trim_end[0] if trim_end[0] > 0 else None
                end1 = result.shape[3] - trim_end[1] if trim_end[1] > 0 else None
                return result[:, :, trim_begin[0] : end0, trim_begin[1] : end1]
            elif ndim == 3:
                result = F.conv_transpose3d(
                    x,
                    weight,
                    bias,
                    stride=strides_tuple,
                    padding=0,
                    output_padding=0,
                    groups=group,
                    dilation=dilations_tuple,
                )
                # Slice to get target shape
                end0 = result.shape[2] - trim_end[0] if trim_end[0] > 0 else None
                end1 = result.shape[3] - trim_end[1] if trim_end[1] > 0 else None
                end2 = result.shape[4] - trim_end[2] if trim_end[2] > 0 else None
                return result[
                    :,
                    :,
                    trim_begin[0] : end0,
                    trim_begin[1] : end1,
                    trim_begin[2] : end2,
                ]
            else:
                raise NotImplementedError(f"ConvTranspose{ndim}D not supported")
        elif pads is not None:
            n = len(pads) // 2
            padding = list(pads[:n])
            # Handle asymmetric pads via output_padding
            for i in range(n):
                if pads[i] != pads[i + n]:
                    adj_output_padding[i] = pads[i + n] - pads[i]
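            # Note: F.conv_transpose*d rejects negative output_padding, so
            # this path assumes the end pad is >= the begin pad per axis.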

        strides_tuple = tuple(strides) if len(strides) > 1 else strides[0]
        dilations_tuple = tuple(dilations) if len(dilations) > 1 else dilations[0]
        padding_tuple = tuple(padding) if len(padding) > 1 else padding[0]
        output_padding_tuple = (
            tuple(adj_output_padding)
            if len(adj_output_padding) > 1
            else adj_output_padding[0]
        )

        if ndim == 1:
            return F.conv_transpose1d(
                x,
                weight,
                bias,
                stride=strides_tuple,
                padding=padding_tuple,
                output_padding=output_padding_tuple,
                groups=group,
                dilation=dilations_tuple,
            )
        elif ndim == 2:
            return F.conv_transpose2d(
                x,
                weight,
                bias,
                stride=strides_tuple,
                padding=padding_tuple,
                output_padding=output_padding_tuple,
                groups=group,
                dilation=dilations_tuple,
            )
        elif ndim == 3:
            return F.conv_transpose3d(
                x,
                weight,
                bias,
                stride=strides_tuple,
                padding=padding_tuple,
                output_padding=output_padding_tuple,
                groups=group,
                dilation=dilations_tuple,
            )
        else:
            raise NotImplementedError(f"ConvTranspose{ndim}D not supported")

    return builder.call_function(
        _conv_transpose,
        args=(
            x,
            weight,
            bias,
            strides,
            dilations,
            group,
            pads,
            output_padding,
            auto_pad,
            output_shape,
            kernel_shape,
        ),
    )


@register("DeformConv")
def deform_conv(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Deformable convolution.

    Performs deformable convolution as described in:
    - Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211)
    - Deformable ConvNets v2 (https://arxiv.org/abs/1811.11168) when mask is provided

    Note: Only 2D deformable convolution is supported, as torchvision.ops.deform_conv2d
    only supports 2D inputs.
    """
    import torchvision.ops

    x = builder.get_value(node.input[0])
    weight = builder.get_value(node.input[1])
    offset = builder.get_value(node.input[2])

    bias = get_optional_input(builder, node, 3)
    mask = get_optional_input(builder, node, 4)
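    # torchvision expects offset of shape (N, 2 * offset_groups * kH * kW,
    # out_H, out_W) and mask of shape (N, offset_groups * kH * kW, out_H, out_W).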

    strides = get_attribute(node, "strides") or [1, 1]
    dilations = get_attribute(node, "dilations") or [1, 1]
    pads = get_attribute(node, "pads") or [0, 0, 0, 0]
    # Note: group and offset_group are inferred from tensor shapes by torchvision
    # ONNX attributes are parsed but not explicitly passed to the function

    def _deform_conv(x, weight, offset, bias, mask, strides, dilations, pads):
        # Handle padding - ONNX uses [begin0, begin1, end0, end1] format
        # torchvision.ops.deform_conv2d expects (pad_H, pad_W)
        # For simplicity, assume symmetric padding (ONNX pads should be symmetric)
        n = len(pads) // 2
        padding = tuple(pads[:n])

        stride = tuple(strides) if len(strides) > 1 else (strides[0], strides[0])
        dilation = (
            tuple(dilations) if len(dilations) > 1 else (dilations[0], dilations[0])
        )

        return torchvision.ops.deform_conv2d(
            x,
            offset,
            weight,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            mask=mask,
        )

    return builder.call_function(
        _deform_conv,
        args=(x, weight, offset, bias, mask, strides, dilations, pads),
    )
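
For reference, a minimal standalone sketch of the asymmetric-padding convention that the Conv handler above relies on, checked against plain torch. This is illustrative only: it does not import onnx2fx, and the tensor shapes are made up.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)

# Asymmetric ONNX pads [top, left, bottom, right]; F.pad wants
# [left, right, top, bottom], which is what the loop in conv produces.
onnx_pads = [1, 2, 3, 4]
n = len(onnx_pads) // 2
pad_list = []
for i in range(n - 1, -1, -1):
    pad_list.extend([onnx_pads[i], onnx_pads[i + n]])
assert pad_list == [2, 4, 1, 3]

out = F.conv2d(F.pad(x, pad_list), w)
assert out.shape == (1, 4, 10, 12)  # H: 8+1+3-2, W: 8+2+4-2

# With symmetric pads, explicit F.pad and conv2d's padding= agree.
assert torch.allclose(
    F.conv2d(F.pad(x, [1, 1, 1, 1]), w),
    F.conv2d(x, w, padding=(1, 1)),
    atol=1e-5,
)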