onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/image.py
ADDED
@@ -0,0 +1,748 @@
# SPDX-License-Identifier: Apache-2.0
"""Image and spatial transformation operators.

This module implements ONNX operators for image resizing and
spatial dimension rearrangement.
"""

from typing import TYPE_CHECKING

import onnx
import torch

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import get_optional_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Resize operator
# =============================================================================


@register("Resize")
def resize(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Resize tensor using interpolation."""
    x = builder.get_value(node.input[0])

    # roi, scales, sizes are optional inputs
    roi = get_optional_input(builder, node, 1)
    scales = get_optional_input(builder, node, 2)
    sizes = get_optional_input(builder, node, 3)

    mode = get_attribute(node, "mode", "nearest")
    coordinate_transformation_mode = get_attribute(
        node, "coordinate_transformation_mode", "half_pixel"
    )

    def _resize(x, roi, scales, sizes, mode, coord_mode):
        import torch.nn.functional as F

        # Map ONNX mode to PyTorch mode
        mode_map = {
            "nearest": "nearest",
            "linear": "bilinear" if x.dim() == 4 else "linear",
            "cubic": "bicubic",
        }
        torch_mode = mode_map.get(mode, "nearest")

        # Determine align_corners based on coordinate transformation mode
        align_corners = coord_mode == "align_corners"
        if torch_mode == "nearest":
            align_corners = None

        if sizes is not None:
            # Use explicit sizes
            size_list = sizes.tolist() if isinstance(sizes, torch.Tensor) else sizes
            # Skip batch and channel dimensions
            output_size = [int(s) for s in size_list[2:]]
        elif scales is not None:
            # Use scales
            scale_list = scales.tolist() if isinstance(scales, torch.Tensor) else scales
            input_shape = x.shape[2:]
            output_size = [int(s * sc) for s, sc in zip(input_shape, scale_list[2:])]
        else:
            return x

        kwargs = {"size": output_size, "mode": torch_mode}
        if align_corners is not None and torch_mode not in ["nearest", "area"]:
            kwargs["align_corners"] = align_corners

        return F.interpolate(x, **kwargs)

    return builder.call_function(
        _resize, args=(x, roi, scales, sizes, mode, coordinate_transformation_mode)
    )
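

# Illustrative sketch (not part of the original module): how the Resize mapping
# above plays out for a 4-D input. With ONNX mode="linear" the wrapper picks
# "bilinear", and with explicit `sizes` the batch/channel entries are skipped.
# Note that F.interpolate expects "trilinear" for 5-D volumetric inputs, which
# the map above does not produce. Function name and values are hypothetical.
def _example_resize_sketch() -> None:
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    sizes = [1, 3, 16, 16]                        # ONNX sizes include N and C
    output_size = [int(s) for s in sizes[2:]]     # -> [16, 16]
    y = F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
    assert y.shape == (1, 3, 16, 16)

    # scales-based path: each output dim is int(input_dim * scale)
    scales = [1.0, 1.0, 2.0, 2.0]
    out = [int(d * s) for d, s in zip(x.shape[2:], scales[2:])]
    assert out == [16, 16]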


# =============================================================================
# Upsample operator (deprecated in opset 10, replaced by Resize)
# =============================================================================


@register("Upsample", since_version=7)
def upsample(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Upsample tensor using interpolation.

    Deprecated: This operator is deprecated since opset 10.
    Use Resize operator instead.

    Opset 7-8: scales is an attribute
    Opset 9: scales is an input
    """
    x = builder.get_value(node.input[0])

    # In opset 9, scales is an input; in opset 7-8, it's an attribute
    opset = builder.opset_version
    if opset >= 9 and len(node.input) > 1 and node.input[1]:
        scales = builder.get_value(node.input[1])
    else:
        scales = get_attribute(node, "scales")

    mode = get_attribute(node, "mode", "nearest")

    def _upsample(x, scales, mode):
        import torch.nn.functional as F

        # Map ONNX mode to PyTorch mode
        mode_map = {
            "nearest": "nearest",
            "linear": "bilinear" if x.dim() == 4 else "linear",
            "cubic": "bicubic",
        }
        torch_mode = mode_map.get(mode, "nearest")

        # Use scales to compute output size
        scale_list = scales.tolist() if isinstance(scales, torch.Tensor) else scales
        input_shape = x.shape[2:]
        output_size = [int(s * sc) for s, sc in zip(input_shape, scale_list[2:])]

        kwargs = {"size": output_size, "mode": torch_mode}
        # align_corners is not used for nearest mode
        if torch_mode not in ["nearest", "area"]:
            kwargs["align_corners"] = False

        return F.interpolate(x, **kwargs)

    return builder.call_function(_upsample, args=(x, scales, mode))
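

# Illustrative sketch (not part of the original module): Upsample carries the
# full per-axis scales (including batch and channel), and only the spatial
# entries feed the output size, as in the wrapper above. Hypothetical values.
def _example_upsample_scales() -> None:
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 5, 7)
    scales = [1.0, 1.0, 2.0, 3.0]                 # N, C, H, W
    out = [int(d * s) for d, s in zip(x.shape[2:], scales[2:])]
    y = F.interpolate(x, size=out, mode="nearest")
    assert y.shape == (2, 4, 10, 21)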


# =============================================================================
# Depth/Space rearrangement operators
# =============================================================================


@register("DepthToSpace")
def depth_to_space(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Rearrange depth to spatial dimensions."""
    x = builder.get_value(node.input[0])
    blocksize = get_attribute(node, "blocksize")
    mode = get_attribute(node, "mode", "DCR")

    def _depth_to_space(x, blocksize, mode):
        b, c, h, w = x.shape
        if mode == "DCR":
            # Depth-Column-Row
            x = x.reshape(b, blocksize, blocksize, c // (blocksize**2), h, w)
            x = x.permute(0, 3, 4, 1, 5, 2)
            x = x.reshape(b, c // (blocksize**2), h * blocksize, w * blocksize)
        else:
            # CRD mode
            x = x.reshape(b, c // (blocksize**2), blocksize, blocksize, h, w)
            x = x.permute(0, 1, 4, 2, 5, 3)
            x = x.reshape(b, c // (blocksize**2), h * blocksize, w * blocksize)
        return x

    return builder.call_function(_depth_to_space, args=(x, blocksize, mode))
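

# Illustrative sketch (not part of the original module): for mode="CRD",
# DepthToSpace matches torch.nn.functional.pixel_shuffle. The helper below
# repeats the CRD reshape/permute from above for comparison; the function name
# and test values are hypothetical.
def _example_depth_to_space_crd() -> None:
    import torch
    import torch.nn.functional as F

    r = 2
    x = torch.randn(1, 4 * r * r, 3, 3)
    b, c, h, w = x.shape
    crd = x.reshape(b, c // (r * r), r, r, h, w)
    crd = crd.permute(0, 1, 4, 2, 5, 3).reshape(b, c // (r * r), h * r, w * r)
    assert torch.equal(crd, F.pixel_shuffle(x, r))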


@register("SpaceToDepth")
def space_to_depth(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Rearrange spatial dimensions to depth."""
    x = builder.get_value(node.input[0])
    blocksize = get_attribute(node, "blocksize")

    def _space_to_depth(x, blocksize):
        b, c, h, w = x.shape
        x = x.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize)
        x = x.permute(0, 3, 5, 1, 2, 4)
        x = x.reshape(b, c * blocksize * blocksize, h // blocksize, w // blocksize)
        return x

    return builder.call_function(_space_to_depth, args=(x, blocksize))
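

# Illustrative sketch (not part of the original module): SpaceToDepth followed
# by DepthToSpace in DCR mode is an identity round trip, because SpaceToDepth
# places the block offsets ahead of the channel axis. Hypothetical values.
def _example_space_depth_round_trip() -> None:
    import torch

    r = 2
    x = torch.randn(1, 3, 4, 6)
    b, c, h, w = x.shape

    # SpaceToDepth (same rearrangement as above)
    s2d = x.reshape(b, c, h // r, r, w // r, r).permute(0, 3, 5, 1, 2, 4)
    s2d = s2d.reshape(b, c * r * r, h // r, w // r)

    # DepthToSpace, DCR mode (same rearrangement as above)
    d2s = s2d.reshape(b, r, r, c, h // r, w // r).permute(0, 3, 4, 1, 5, 2)
    d2s = d2s.reshape(b, c, h, w)
    assert torch.equal(d2s, x)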


# =============================================================================
# Col2Im operator
# =============================================================================


@register("Col2Im", since_version=18)
def col2im(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Rearrange column blocks back into a multidimensional image.

    ONNX Col2Im is the inverse of Im2Col. It combines sliding local blocks
    (columns) back into a larger image tensor.

    Inputs:
        input: [N, C * prod(block_shape), L] - batched column data
        image_shape: spatial dimensions of the output image
        block_shape: shape of the sliding block

    Attributes:
        strides: stride along each spatial axis (default: 1)
        pads: padding for each spatial axis in ONNX format (default: 0)
        dilations: dilation for each spatial axis (default: 1)

    Output:
        [N, C, *image_shape] - reconstructed image
    """
    x = builder.get_value(node.input[0])
    image_shape = builder.get_value(node.input[1])
    block_shape = builder.get_value(node.input[2])

    strides = get_attribute(node, "strides")
    pads = get_attribute(node, "pads")
    dilations = get_attribute(node, "dilations")

    def _col2im(x, image_shape, block_shape, strides, pads, dilations):
        import torch.nn.functional as F
        from functools import reduce
        from itertools import product
        from operator import mul

        # Convert to lists if tensors
        if isinstance(image_shape, torch.Tensor):
            image_shape = image_shape.tolist()
        if isinstance(block_shape, torch.Tensor):
            block_shape = block_shape.tolist()

        n_dims = len(block_shape)

        # Default values
        if strides is None:
            strides = [1] * n_dims
        if pads is None:
            pads = [0] * (2 * n_dims)
        if dilations is None:
            dilations = [1] * n_dims

        # For 2D, use PyTorch's optimized fold
        if n_dims == 2:
            # PyTorch fold uses symmetric padding per dimension
            padding = (pads[0], pads[1])
            return F.fold(
                x,
                output_size=tuple(image_shape),
                kernel_size=tuple(block_shape),
                stride=tuple(strides),
                padding=padding,
                dilation=tuple(dilations),
            )

        # For N-D, implement manually
        N = x.shape[0]
        L = x.shape[2]
        block_size = reduce(mul, block_shape, 1)
        C = x.shape[1] // block_size

        # Reshape input: [N, C * prod(block_shape), L] -> [N, C, *block_shape, L]
        input_reshaped = x.reshape(N, C, *block_shape, L)

        # Initialize output: [N, C, *image_shape]
        output = torch.zeros(N, C, *image_shape, dtype=x.dtype, device=x.device)

        # Compute effective kernel size after dilation
        effective_block = [(b - 1) * d + 1 for b, d in zip(block_shape, dilations)]

        # Compute number of blocks in each dimension
        n_blocks = []
        for i, (img_dim, eff_block, s) in enumerate(
            zip(image_shape, effective_block, strides)
        ):
            p_begin = pads[i]
            p_end = pads[n_dims + i]
            n_block = (img_dim + p_begin + p_end - eff_block) // s + 1
            n_blocks.append(n_block)

        # Iterate over all block positions
        block_indices = list(product(*[range(nb) for nb in n_blocks]))

        for l_idx, block_idx in enumerate(block_indices):
            # Compute starting position for this block
            starts = [
                bi * s - pads[i] for i, (bi, s) in enumerate(zip(block_idx, strides))
            ]

            # For each position in the block
            block_positions = list(product(*[range(b) for b in block_shape]))
            for block_pos in block_positions:
                # Compute actual output position with dilation
                output_pos = [
                    starts[i] + block_pos[i] * dilations[i] for i in range(n_dims)
                ]

                # Check bounds
                valid = all(0 <= output_pos[i] < image_shape[i] for i in range(n_dims))
                if valid:
                    # Get value from input_reshaped: [N, C, *block_shape, L]
                    idx = (slice(None), slice(None)) + tuple(block_pos) + (l_idx,)
                    value = input_reshaped[idx]

                    # Add to output
                    out_idx = (slice(None), slice(None)) + tuple(output_pos)
                    output[out_idx] += value

        return output

    return builder.call_function(
        _col2im, args=(x, image_shape, block_shape, strides, pads, dilations)
    )
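

# Illustrative sketch (not part of the original module): the 2-D fast path uses
# torch.nn.functional.fold, which is the inverse layout of unfold, so an
# unfold -> fold round trip with non-overlapping blocks reconstructs the input.
# Note that F.fold pads symmetrically, so the fast path above matches ONNX pads
# only when begin and end padding agree per axis. Hypothetical values below.
def _example_col2im_fold() -> None:
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 4, 4)
    cols = F.unfold(x, kernel_size=2, stride=2)              # [1, 3*2*2, 4]
    img = F.fold(cols, output_size=(4, 4), kernel_size=2, stride=2)
    # Non-overlapping 2x2 blocks with stride 2 reconstruct the input exactly.
    assert torch.equal(img, x)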


# =============================================================================
# CenterCropPad operator
# =============================================================================


# =============================================================================
# GridSample operator
# =============================================================================


@register("GridSample", since_version=16)
def grid_sample(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Sample input using grid of sampling locations.

    Given an input X and a flow-field grid, computes the output Y using X values
    and pixel locations from the grid. This is equivalent to PyTorch's
    torch.nn.functional.grid_sample.

    Inputs:
        X: Input tensor of shape (N, C, H, W) for 4D or (N, C, D, H, W) for 5D
        grid: Grid tensor of shape (N, H_out, W_out, 2) for 4D or
              (N, D_out, H_out, W_out, 3) for 5D

    Attributes:
        align_corners: If 1, extrema (-1 and 1) refer to center of corner pixels.
                       If 0, they refer to corner points. Default: 0
        mode: Interpolation mode - 'linear'/'bilinear' (default), 'nearest',
              'cubic'/'bicubic'. Opset 16 uses bilinear/bicubic, opset 20+ uses
              linear/cubic.
        padding_mode: Padding mode for outside grid values - 'zeros' (default),
                      'border', 'reflection'

    Output:
        Y: Output tensor of shape (N, C, H_out, W_out) or (N, C, D_out, H_out, W_out)
    """
    x = builder.get_value(node.input[0])
    grid = builder.get_value(node.input[1])

    align_corners = get_attribute(node, "align_corners", 0)
    # Handle different mode names across opset versions
    # Opset 16: bilinear (default), nearest, bicubic
    # Opset 20+: linear (default), nearest, cubic
    mode = get_attribute(node, "mode", "linear")
    padding_mode = get_attribute(node, "padding_mode", "zeros")

    def _grid_sample(x, grid, mode, padding_mode, align_corners):
        import torch.nn.functional as F

        # Map ONNX mode names to PyTorch mode names
        # PyTorch expects: 'bilinear', 'nearest', 'bicubic' for 4D input
        # PyTorch expects: 'bilinear', 'nearest' for 5D input (no bicubic)
        mode_map = {
            "linear": "bilinear",
            "bilinear": "bilinear",
            "nearest": "nearest",
            "cubic": "bicubic",
            "bicubic": "bicubic",
        }
        torch_mode = mode_map.get(mode, "bilinear")

        # Convert align_corners from int to bool
        align_corners_bool = bool(align_corners)

        return F.grid_sample(
            x,
            grid,
            mode=torch_mode,
            padding_mode=padding_mode,
            align_corners=align_corners_bool,
        )

    return builder.call_function(
        _grid_sample, args=(x, grid, mode, padding_mode, align_corners)
    )
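

# Illustrative sketch (not part of the original module): an identity sampling
# grid fed to F.grid_sample reproduces the input, which is a quick way to
# sanity-check the mode/padding mapping above. Uses align_corners=True so the
# grid extremes land on corner pixel centers. Hypothetical values.
def _example_grid_sample_identity() -> None:
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 5, 5)
    # Identity grid in normalized [-1, 1] coordinates, last dim is (x, y).
    ys, xs = torch.meshgrid(
        torch.linspace(-1, 1, 5), torch.linspace(-1, 1, 5), indexing="ij"
    )
    grid = torch.stack((xs, ys), dim=-1).unsqueeze(0)         # (1, 5, 5, 2)
    y = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
    assert torch.allclose(y, x, atol=1e-6)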


# =============================================================================
# AffineGrid operator
# =============================================================================


@register("AffineGrid", since_version=20)
def affine_grid(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate 2D or 3D flow field (sampling grid) from affine matrices.

    Given a batch of affine matrices theta, generates a grid of sampling
    locations. This is typically used with GridSample to build Spatial
    Transformer Networks.

    Inputs:
        theta: Input batch of affine matrices with shape (N, 2, 3) for 2D
               or (N, 3, 4) for 3D
        size: Target output image size (N, C, H, W) for 2D or (N, C, D, H, W)
              for 3D, as a 1-D tensor

    Attributes:
        align_corners: If 1, consider -1 and 1 to refer to the centers of the
                       corner pixels. If 0, consider -1 and 1 to refer to the
                       outer edge of corner pixels. Default: 0

    Output:
        grid: Output tensor of shape (N, H, W, 2) for 2D sample coordinates
              or (N, D, H, W, 3) for 3D sample coordinates
    """
    theta = builder.get_value(node.input[0])
    size = builder.get_value(node.input[1])

    align_corners = get_attribute(node, "align_corners", 0)

    def _affine_grid(theta, size, align_corners):
        import torch.nn.functional as F

        # Convert size tensor to a list of integers for torch.Size
        if isinstance(size, torch.Tensor):
            size_list = size.tolist()
        else:
            size_list = list(size)
        size_tuple = torch.Size([int(s) for s in size_list])

        # Convert align_corners from int to bool
        align_corners_bool = bool(align_corners)

        return F.affine_grid(theta, size_tuple, align_corners=align_corners_bool)

    return builder.call_function(_affine_grid, args=(theta, size, align_corners))
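

# Illustrative sketch (not part of the original module): AffineGrid paired with
# GridSample, as in a Spatial Transformer block. An identity theta yields a
# grid that resamples the input unchanged when both calls share the same
# align_corners setting. Hypothetical values.
def _example_affine_grid_with_sample() -> None:
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 8, 8)
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]).repeat(2, 1, 1)
    grid = F.affine_grid(theta, torch.Size([2, 3, 8, 8]), align_corners=False)
    assert grid.shape == (2, 8, 8, 2)
    y = F.grid_sample(x, grid, mode="bilinear", align_corners=False)
    assert torch.allclose(y, x, atol=1e-5)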


# =============================================================================
# CenterCropPad operator
# =============================================================================


@register("CenterCropPad", since_version=18)
def center_crop_pad(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Center crop or pad an input to given dimensions.

    The crop/pad dimensions can be specified for a subset of the axes.
    Unspecified dimensions will remain unchanged.

    For cropping (input > target): centered window, start position rounded down.
    For padding (input < target): centered padding, extra pixel on right if odd.

    Inputs:
        input_data: Input tensor to crop/pad
        shape: 1-D tensor of target dimensions for specified axes

    Attributes:
        axes: Subset of axes that shape refers to (default: all axes)

    Output:
        Output tensor with specified dimensions
    """
    x = builder.get_value(node.input[0])
    shape = builder.get_value(node.input[1])

    axes = get_attribute(node, "axes")

    def _center_crop_pad(x, shape, axes):
        # Convert shape to list if tensor
        if isinstance(shape, torch.Tensor):
            target_shape = shape.tolist()
        else:
            target_shape = list(shape)

        ndim = x.dim()

        # If axes is not provided, use all axes
        if axes is None:
            axes_list = list(range(ndim))
        else:
            axes_list = list(axes)

        # Normalize negative axes
        axes_list = [(a + ndim) if a < 0 else a for a in axes_list]

        # Build slices for cropping and padding amounts
        result = x
        for i, axis in enumerate(axes_list):
            current_size = result.shape[axis]
            target_size = int(target_shape[i])

            if current_size == target_size:
                # No change needed for this axis
                continue
            elif current_size > target_size:
                # Crop: extract centered window
                # Start position is rounded down (floor division)
                diff = current_size - target_size
                start = diff // 2
                end = start + target_size

                # Build slice for this axis
                slices = [slice(None)] * result.dim()
                slices[axis] = slice(start, end)
                result = result[tuple(slices)]
            else:
                # Pad: add zeros centered
                # Extra pixel goes to the right side
                diff = target_size - current_size
                pad_before = diff // 2
                pad_after = diff - pad_before

                # torch.nn.functional.pad uses reverse order: last dim first
                # and pairs are (before, after) for each dim from last to first
                # We need to construct padding for just this one axis

                # Number of dimensions from the end
                dims_from_end = result.dim() - 1 - axis

                # Build pad tuple: pairs for each dim from last to first
                # We only pad the current axis
                pad = [0] * (2 * result.dim())
                # Index in pad list: dims_from_end * 2 for before, +1 for after
                pad[dims_from_end * 2] = pad_before
                pad[dims_from_end * 2 + 1] = pad_after

                result = torch.nn.functional.pad(result, pad, mode="constant", value=0)

        return result

    return builder.call_function(_center_crop_pad, args=(x, shape, axes))
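

# Illustrative sketch (not part of the original module): a worked crop/pad on a
# single axis. Cropping 7 -> 4 keeps indices 1..4 (start = (7 - 4) // 2 = 1);
# padding 3 -> 6 adds 1 zero before and 2 after (the extra goes to the right).
# Hypothetical values.
def _example_center_crop_pad() -> None:
    import torch
    import torch.nn.functional as F

    v = torch.arange(7.0)                             # [0, 1, 2, 3, 4, 5, 6]
    cropped = v[(7 - 4) // 2 : (7 - 4) // 2 + 4]
    assert torch.equal(cropped, torch.tensor([1.0, 2.0, 3.0, 4.0]))

    w = torch.ones(3)
    padded = F.pad(w, (3 // 2, 3 - 3 // 2))           # pad_before=1, pad_after=2
    assert torch.equal(padded, torch.tensor([0.0, 1.0, 1.0, 1.0, 0.0, 0.0]))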


# =============================================================================
# RoiAlign operator
# =============================================================================


@register("RoiAlign")
def roi_align(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Region of Interest (RoI) Align operator.

    Performs RoI pooling with bilinear interpolation for sub-pixel accuracy.
    """
    x = builder.get_value(node.input[0])
    rois = builder.get_value(node.input[1])
    batch_indices = builder.get_value(node.input[2])

    # Get attributes with defaults
    mode = get_attribute(node, "mode", "avg")
    output_height = get_attribute(node, "output_height", 1)
    output_width = get_attribute(node, "output_width", 1)
    sampling_ratio = get_attribute(node, "sampling_ratio", 0)
    spatial_scale = get_attribute(node, "spatial_scale", 1.0)
    coordinate_transformation_mode = get_attribute(
        node, "coordinate_transformation_mode", "half_pixel"
    )

    # ONNX coordinate_transformation_mode:
    # - "half_pixel": pixel shift by -0.5, corresponds to aligned=True in PyTorch
    # - "output_half_pixel": no pixel shift (legacy), corresponds to aligned=False
    aligned = coordinate_transformation_mode == "half_pixel"

    def _roi_align(
        x,
        rois,
        batch_indices,
        mode,
        output_height,
        output_width,
        sampling_ratio,
        spatial_scale,
        aligned,
    ):
        from torchvision.ops import roi_align as tv_roi_align

        # PyTorch expects boxes in format [batch_idx, x1, y1, x2, y2]
        boxes = torch.cat([batch_indices.unsqueeze(1).float(), rois.float()], dim=1)
        output_size = (output_height, output_width)

        if mode == "avg":
            # Use torchvision's roi_align directly for average mode
            # sampling_ratio: ONNX uses 0 for adaptive, PyTorch uses -1
            torch_sampling = sampling_ratio if sampling_ratio > 0 else -1
            return tv_roi_align(
                x,
                boxes,
                output_size,
                spatial_scale=spatial_scale,
                sampling_ratio=torch_sampling,
                aligned=aligned,
            )
        else:
            # Max mode: ONNX defines max pooling differently from standard
            # bilinear interpolation. For each sample point, it takes the
            # max of the 4 weighted corner values (not the sum).
            return _roi_align_max_mode(
                x,
                rois,
                batch_indices,
                output_height,
                output_width,
                sampling_ratio,
                spatial_scale,
                aligned,
            )

    return builder.call_function(
        _roi_align,
        args=(
            x,
            rois,
            batch_indices,
            mode,
            output_height,
            output_width,
            sampling_ratio,
            spatial_scale,
            aligned,
        ),
    )
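

# Illustrative sketch (not part of the original module): average-mode RoiAlign
# via torchvision, using the [batch_idx, x1, y1, x2, y2] box layout assembled
# above. Requires torchvision; values are hypothetical.
def _example_roi_align_avg() -> None:
    import torch
    from torchvision.ops import roi_align as tv_roi_align

    x = torch.randn(1, 4, 16, 16)
    rois = torch.tensor([[0.0, 0.0, 8.0, 8.0]])       # x1, y1, x2, y2
    batch_indices = torch.tensor([0])
    boxes = torch.cat([batch_indices.unsqueeze(1).float(), rois], dim=1)
    out = tv_roi_align(
        x, boxes, (2, 2), spatial_scale=1.0, sampling_ratio=-1, aligned=True
    )
    assert out.shape == (1, 4, 2, 2)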


def _roi_align_max_mode(
    x,
    rois,
    batch_indices,
    output_height,
    output_width,
    sampling_ratio,
    spatial_scale,
    half_pixel,
):
    """ONNX RoiAlign with max pooling mode.

    For each output bin, samples at grid points and takes the MAX of the
    weighted corner values at each sampling point, then MAX across all
    sampling points in the bin.
    """
    num_rois = rois.shape[0]
    channels = x.shape[1]
    height = x.shape[2]
    width = x.shape[3]

    output = torch.zeros(
        num_rois, channels, output_height, output_width, dtype=x.dtype, device=x.device
    )

    for n in range(num_rois):
        roi_batch_ind = int(batch_indices[n].item())
        roi = rois[n]

        # Apply spatial scale and offset
        offset = 0.5 if half_pixel else 0.0
        roi_start_w = float(roi[0]) * spatial_scale - offset
        roi_start_h = float(roi[1]) * spatial_scale - offset
        roi_end_w = float(roi[2]) * spatial_scale - offset
        roi_end_h = float(roi[3]) * spatial_scale - offset

        roi_width = roi_end_w - roi_start_w
        roi_height = roi_end_h - roi_start_h

        if not half_pixel:
            # Force malformed ROIs to be 1x1
            roi_width = max(roi_width, 1.0)
            roi_height = max(roi_height, 1.0)

        bin_size_h = roi_height / output_height
        bin_size_w = roi_width / output_width

        # Determine sampling grid size
        if sampling_ratio > 0:
            roi_bin_grid_h = sampling_ratio
            roi_bin_grid_w = sampling_ratio
        else:
            roi_bin_grid_h = int(
                torch.ceil(torch.tensor(roi_height / output_height)).item()
            )
            roi_bin_grid_w = int(
                torch.ceil(torch.tensor(roi_width / output_width)).item()
            )
            roi_bin_grid_h = max(1, roi_bin_grid_h)
            roi_bin_grid_w = max(1, roi_bin_grid_w)

        for c in range(channels):
            for ph in range(output_height):
                for pw in range(output_width):
                    output_val = None

                    for iy in range(roi_bin_grid_h):
                        yy = (
                            roi_start_h
                            + ph * bin_size_h
                            + (iy + 0.5) * bin_size_h / roi_bin_grid_h
                        )
                        for ix in range(roi_bin_grid_w):
                            xx = (
                                roi_start_w
                                + pw * bin_size_w
                                + (ix + 0.5) * bin_size_w / roi_bin_grid_w
                            )

                            # Check bounds
                            if yy < -1.0 or yy > height or xx < -1.0 or xx > width:
                                continue

                            y = max(yy, 0.0)
                            xc = max(xx, 0.0)

                            y_low = int(y)
                            x_low = int(xc)

                            if y_low >= height - 1:
                                y_high = y_low = height - 1
                                y = float(y_low)
                            else:
                                y_high = y_low + 1

                            if x_low >= width - 1:
                                x_high = x_low = width - 1
                                xc = float(x_low)
                            else:
                                x_high = x_low + 1

                            ly = y - y_low
                            lx = xc - x_low
                            hy = 1.0 - ly
                            hx = 1.0 - lx

                            # Weights
                            w1 = hy * hx
                            w2 = hy * lx
                            w3 = ly * hx
                            w4 = ly * lx

                            # Get corner values
                            v1 = x[roi_batch_ind, c, y_low, x_low].item()
                            v2 = x[roi_batch_ind, c, y_low, x_high].item()
                            v3 = x[roi_batch_ind, c, y_high, x_low].item()
                            v4 = x[roi_batch_ind, c, y_high, x_high].item()

                            # ONNX max mode: max of weighted corners
                            val = max(w1 * v1, w2 * v2, w3 * v3, w4 * v4)

                            if output_val is None:
                                output_val = val
                            else:
                                output_val = max(output_val, val)

                    if output_val is not None:
                        output[n, c, ph, pw] = output_val

    return output
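

# Illustrative sketch (not part of the original module): the difference between
# standard bilinear interpolation and the ONNX RoiAlign max mode at a single
# sample point. At the exact midpoint of four pixels every corner weight is
# 0.25; average mode sums the weighted corners, while max mode keeps only the
# largest weighted corner. Hypothetical numbers.
def _example_max_mode_weighting() -> None:
    v1, v2, v3, v4 = 1.0, 2.0, 3.0, 4.0
    w1 = w2 = w3 = w4 = 0.25
    bilinear = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4      # 2.5
    onnx_max = max(w1 * v1, w2 * v2, w3 * v3, w4 * v4)    # 1.0
    assert bilinear == 2.5 and onnx_max == 1.0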