onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/pooling.py
ADDED
|
@@ -0,0 +1,897 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Pooling operators."""
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
from ..op_registry import register
|
|
11
|
+
from ..utils.attributes import get_attribute
|
|
12
|
+
from ..utils.op_helpers import (
|
|
13
|
+
compute_same_padding,
|
|
14
|
+
get_optional_input,
|
|
15
|
+
pad_list_to_onnx_pads,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from ..graph_builder import GraphBuilder
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# =============================================================================
|
|
23
|
+
# Pooling operators
|
|
24
|
+
# =============================================================================
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _decompose_flat_index(flat, shape):
    """Split row-major flat indices over ``shape`` into per-axis coordinates.

    Args:
        flat: Integer tensor of flat indices (last axis varies fastest).
        shape: Sizes of the axes the indices were flattened over.

    Returns:
        A list with one coordinate tensor per axis, in axis order.
    """
    coords = []
    remaining = flat
    # Peel off the fastest-varying (last) axis first.
    for size in reversed(list(shape)):
        coords.append(remaining % size)
        remaining = remaining // size
    coords.reverse()
    return coords


def _compose_flat_index(coords, shape, column_major=False):
    """Combine per-axis coordinates into flat indices over ``shape``.

    Args:
        coords: One coordinate tensor per axis, in axis order.
        shape: Sizes of the axes to flatten over.
        column_major: If False (default) the last axis varies fastest
            (row-major); if True the first axis varies fastest.

    Returns:
        Tensor of flat indices.
    """
    axes = list(range(len(shape)))
    if column_major:
        axes.reverse()
    # Horner scheme starting from the slowest-varying axis.
    flat = coords[axes[0]]
    for axis in axes[1:]:
        flat = flat * shape[axis] + coords[axis]
    return flat


def _max_pool_indices_to_onnx(
    indices, pooled_spatial, pads_begin, input_shape, storage_order
):
    """Convert PyTorch max-pool indices into the ONNX MaxPool ``Indices`` output.

    PyTorch returns, for every (n, c) plane, a row-major index into the
    spatial dims of the tensor it pooled (which includes any padding added
    with ``F.pad``). ONNX wants indices into the whole unpadded input,
    flattened as (N, C, D1, ..., Dn): ``(n * C + c) * prod(D) + spatial``,
    where the spatial part is row-major for ``storage_order=0`` and
    column-major (first spatial axis varies fastest) for ``storage_order=1``.

    Args:
        indices: Indices returned by ``F.max_pool{1,2,3}d``.
        pooled_spatial: Spatial shape of the tensor handed to PyTorch.
        pads_begin: Leading padding added with ``F.pad`` per spatial axis.
            Zeros when PyTorch's own ``padding`` argument was used, because
            PyTorch's indices already refer to the unpadded input then.
        input_shape: Full shape (N, C, D1, ...) of the unpadded input.
        storage_order: ONNX ``storage_order`` attribute (0 or 1).

    Returns:
        Tensor of ONNX-style indices with the same shape as ``indices``.
    """
    spatial = list(input_shape[2:])
    coords = _decompose_flat_index(indices, pooled_spatial)
    # Shift coordinates back from the padded tensor into the original input.
    coords = [coord - pad for coord, pad in zip(coords, pads_begin)]
    within_plane = _compose_flat_index(
        coords, spatial, column_major=storage_order == 1
    )

    # Offset of each (n, c) plane inside the flattened input.
    n_batch, n_channel = input_shape[0], input_shape[1]
    plane_ids = torch.arange(
        n_batch * n_channel, dtype=indices.dtype, device=indices.device
    ).reshape(n_batch, n_channel, *([1] * len(spatial)))
    return plane_ids * torch.Size(spatial).numel() + within_plane


@register("MaxPool")
def max_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Max pooling.

    Handles explicit ``pads``, ``auto_pad`` SAME_UPPER/SAME_LOWER, dilations,
    ``ceil_mode`` and the optional second ``Indices`` output (honouring
    ``storage_order``) for inputs with 1 to 3 spatial dimensions.
    """
    x = builder.get_value(node.input[0])

    kernel_shape = get_attribute(node, "kernel_shape")
    # ONNX spec: strides defaults to 1 along each spatial axis (not kernel_shape)
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads")
    dilations = get_attribute(node, "dilations") or [1] * len(kernel_shape)
    ceil_mode = get_attribute(node, "ceil_mode", 0)
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")
    storage_order = get_attribute(node, "storage_order", 0)

    # Check if we need to return indices (second output requested)
    return_indices = len(node.output) > 1 and node.output[1] != ""

    def _max_pool(
        x,
        kernel_shape,
        strides,
        pads,
        dilations,
        ceil_mode,
        auto_pad,
        return_indices,
        storage_order,
    ):
        ndim = len(kernel_shape)
        pool_fns = {1: F.max_pool1d, 2: F.max_pool2d, 3: F.max_pool3d}
        if ndim not in pool_fns:
            raise NotImplementedError(f"MaxPool{ndim}D not supported")

        input_dtype = x.dtype
        input_shape = x.shape  # (N, C, D1, D2, ...) of the unpadded input

        # PyTorch max_pool doesn't support int8/uint8, need to convert
        needs_cast = input_dtype in (torch.int8, torch.uint8)
        if needs_cast:
            x = x.float()

        padding = 0
        # Leading padding added via F.pad, per spatial axis; used to map the
        # returned indices back onto the unpadded input.
        pads_begin = [0] * ndim
        # Handle auto_pad first (before explicit pads)
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            # NOTE(review): AveragePool passes use_effective_kernel=True here;
            # confirm the helper's default accounts for dilations.
            pad_list = compute_same_padding(
                tuple(x.shape[2:]),
                tuple(kernel_shape),
                tuple(strides),
                tuple(dilations),
                auto_pad,
            )
            x = F.pad(x, pad_list, value=float("-inf"))
            pads_begin = list(pad_list_to_onnx_pads(pad_list, ndim)[:ndim])
        elif pads is not None:
            n = len(pads) // 2
            symmetric = all(pads[i] == pads[i + n] for i in range(n))

            # PyTorch rejects native padding larger than half of the
            # effective kernel size, (kernel_size - 1) * dilation + 1.
            max_allowed_pad = [
                ((k - 1) * d + 1) // 2 for k, d in zip(kernel_shape, dilations)
            ]
            exceeds_limit = any(
                pads[i] > max_allowed_pad[i] or pads[i + n] > max_allowed_pad[i]
                for i in range(n)
            )

            if symmetric and not exceeds_limit:
                padding = tuple(pads[:n])
            else:
                # Use explicit F.pad (last axis first) for asymmetric or
                # large padding; -inf never wins the max.
                pad_list = []
                for i in range(n - 1, -1, -1):
                    pad_list.extend([pads[i], pads[i + n]])
                x = F.pad(x, pad_list, value=float("-inf"))
                pads_begin = list(pads[:n])

        result = pool_fns[ndim](
            x,
            tuple(kernel_shape),
            stride=tuple(strides),
            padding=padding,
            dilation=tuple(dilations),
            ceil_mode=bool(ceil_mode),
            return_indices=return_indices,
        )

        if not return_indices:
            return result.to(input_dtype) if needs_cast else result

        values, indices = result
        if needs_cast:
            values = values.to(input_dtype)
        indices = _max_pool_indices_to_onnx(
            indices, list(x.shape[2:]), pads_begin, input_shape, storage_order
        )
        return values, indices

    return builder.call_function(
        _max_pool,
        args=(
            x,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            auto_pad,
            return_indices,
            storage_order,
        ),
    )


@register("MaxUnpool")
def max_unpool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """MaxUnpool - partial inverse of MaxPool.

    Scatters every element of the input into a zero tensor at the position
    given by the matching entry of the indices input. Any number of spatial
    dimensions is supported.
    """
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])

    # Optional output_shape input
    output_shape = get_optional_input(builder, node, 2)

    kernel_shape = get_attribute(node, "kernel_shape")
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads") or [0] * (2 * len(kernel_shape))

    def _max_unpool(x, indices, kernel_shape, strides, pads, output_shape):
        ndim = len(kernel_shape)
        n = len(pads) // 2
        batch, channels = x.shape[0], x.shape[1]
        in_spatial = x.shape[2:]

        # ONNX default output size:
        #   out_i = (in_i - 1) * stride_i + kernel_i - pad_begin_i - pad_end_i
        # (PyTorch's max_unpool default uses 2 * pad_begin_i instead, which
        # is wrong for asymmetric pads, so the size is always computed here.)
        default_spatial = [
            (in_spatial[i] - 1) * strides[i] + kernel_shape[i] - pads[i] - pads[i + n]
            for i in range(ndim)
        ]

        if output_shape is None:
            target_spatial = default_spatial
        else:
            # output_shape is the full shape including batch and channel dims
            if isinstance(output_shape, torch.Tensor):
                full_shape = [int(s) for s in output_shape.tolist()]
            else:
                full_shape = [int(s) for s in output_shape]
            target_spatial = full_shape[2:]

        # ONNX indices are flattened over the whole (N, C, *default_spatial)
        # tensor; keep only the offset inside each (n, c) plane. This is a
        # no-op for indices that are already per-plane.
        local = indices.long() % torch.Size(default_spatial).numel()

        # Indices address the default-sized output; re-express them for an
        # explicitly requested output_shape.
        if list(target_spatial) != list(default_spatial):
            local = _compose_flat_index(
                _decompose_flat_index(local, default_spatial), target_spatial
            )

        out = x.new_zeros((batch, channels, torch.Size(target_spatial).numel()))
        out = out.scatter(
            2, local.reshape(batch, channels, -1), x.reshape(batch, channels, -1)
        )
        return out.reshape(batch, channels, *target_spatial)

    return builder.call_function(
        _max_unpool,
        args=(x, indices, kernel_shape, strides, pads, output_shape),
    )
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
@register("AveragePool")
def average_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Average pooling.

    Uses PyTorch's native ``avg_pool{1,2,3}d`` when padding is symmetric and
    within PyTorch's limits. Otherwise it falls back to an unfold-based
    implementation. The fallback also covers dilation, ``auto_pad`` SAME
    modes and inputs with more than 3 spatial dims.
    """
    x = builder.get_value(node.input[0])

    kernel_shape = get_attribute(node, "kernel_shape")
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads")
    dilations = get_attribute(node, "dilations") or [1] * len(kernel_shape)
    ceil_mode = get_attribute(node, "ceil_mode", 0)
    count_include_pad = get_attribute(node, "count_include_pad", 0)
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")

    def _avg_pool_dilated(
        x, kernel_shape, strides, dilations, pads, ceil_mode, count_include_pad
    ):
        """Compute average pooling for any number of spatial dims via unfold.

        PyTorch's avg_pool doesn't support dilation (or asymmetric/large
        padding), so it is implemented manually here.
        """
        ndim = len(kernel_shape)
        batch_size = x.shape[0]
        channels = x.shape[1]
        spatial_shape = list(x.shape[2:])

        # effective_k = (k - 1) * d + 1
        effective_kernel = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]

        if pads is not None:
            n = len(pads) // 2
            pads_begin = list(pads[:n])
            pads_end = list(pads[n : 2 * n])
        else:
            pads_begin = [0] * ndim
            pads_end = [0] * ndim

        # Original end pads (before the ceil_mode adjustment); these count
        # towards the divisor when count_include_pad is set.
        orig_pads_end = list(pads_end)

        # ceil_mode: get ceil((L - ek) / s) + 1 outputs out of floor-based
        # unfold by adding (stride - 1) extra end padding. As in the ONNX
        # spec (and PyTorch), a window that would start in the right padding
        # is dropped.
        ceil_extra_pad = [0] * ndim
        if ceil_mode:
            for i in range(ndim):
                padded_size = spatial_shape[i] + pads_begin[i] + pads_end[i]
                out_floor = (padded_size - effective_kernel[i]) // strides[i] + 1
                out_ceil = (
                    padded_size - effective_kernel[i] + strides[i] - 1
                ) // strides[i] + 1
                if (out_ceil - 1) * strides[i] >= spatial_shape[i] + pads_begin[i]:
                    out_ceil -= 1
                if out_ceil > out_floor:
                    # Need extra padding to get one more output element
                    ceil_extra_pad[i] = strides[i] - 1
                    pads_end[i] += ceil_extra_pad[i]

        # Build pad_list for F.pad (reversed order: last dim first)
        pad_list = []
        for i in range(ndim - 1, -1, -1):
            pad_list.extend([pads_begin[i], pads_end[i]])

        has_padding = any(p > 0 for p in pad_list)
        has_ceil_extra = any(p > 0 for p in ceil_extra_pad)

        mask = None
        if has_padding:
            x = F.pad(x, pad_list, value=0)

            # Mask of positions that count towards the divisor:
            # - count_include_pad=False: only the original input area.
            # - count_include_pad=True with ceil extra padding: the input plus
            #   the original pads, but not the ceil extra pads.
            # - count_include_pad=True otherwise: every tap counts (no mask).
            if not count_include_pad:
                mask = torch.ones(
                    [batch_size, channels] + spatial_shape,
                    dtype=x.dtype,
                    device=x.device,
                )
                mask = F.pad(mask, pad_list, value=0)
            elif has_ceil_extra:
                orig_padded_shape = [batch_size, channels] + [
                    spatial_shape[i] + pads_begin[i] + orig_pads_end[i]
                    for i in range(ndim)
                ]
                mask = torch.ones(orig_padded_shape, dtype=x.dtype, device=x.device)
                ceil_pad_list = []
                for i in range(ndim - 1, -1, -1):
                    ceil_pad_list.extend([0, ceil_extra_pad[i]])
                mask = F.pad(mask, ceil_pad_list, value=0)

        def _windows(t):
            """(N, C, *spatial) -> (N, C, *out_spatial, *kernel_shape) windows."""
            # Unfold each spatial axis with the effective (dilated) kernel...
            for axis in range(ndim):
                t = t.unfold(2 + axis, effective_kernel[axis], strides[axis])
            # ...then keep only every d-th tap: 0, d, 2d, ..., (k - 1) * d.
            for axis in range(ndim):
                taps = torch.arange(
                    0, effective_kernel[axis], dilations[axis], device=t.device
                )
                t = t.index_select(2 + ndim + axis, taps)
            return t

        window_dims = tuple(range(2 + ndim, 2 + 2 * ndim))
        patches = _windows(x)
        if mask is None:
            return patches.mean(dim=window_dims)
        count = _windows(mask).sum(dim=window_dims)
        return patches.sum(dim=window_dims) / count.clamp(min=1)

    def _avg_pool(
        x,
        kernel_shape,
        strides,
        pads,
        dilations,
        ceil_mode,
        count_include_pad,
        auto_pad,
    ):
        ndim = len(kernel_shape)
        pool_fns = {1: F.avg_pool1d, 2: F.avg_pool2d, 3: F.avg_pool3d}

        # Handle auto_pad first (before explicit pads)
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            pad_list = compute_same_padding(
                tuple(x.shape[2:]),
                tuple(kernel_shape),
                tuple(strides),
                tuple(dilations),
                auto_pad,
                use_effective_kernel=True,
            )
            pads_onnx = pad_list_to_onnx_pads(pad_list, ndim)
            # count_include_pad applies to the SAME padding as well.
            return _avg_pool_dilated(
                x,
                kernel_shape,
                strides,
                dilations,
                pads_onnx,
                ceil_mode,
                count_include_pad,
            )

        # Dilation and >3 spatial dims are only handled by the unfold path.
        if any(d != 1 for d in dilations) or ndim not in pool_fns:
            return _avg_pool_dilated(
                x, kernel_shape, strides, dilations, pads, ceil_mode, count_include_pad
            )

        padding = 0
        if pads is not None:
            n = len(pads) // 2
            symmetric = all(pads[i] == pads[i + n] for i in range(n))
            # PyTorch: pad should be at most half of kernel size
            max_allowed_pad = [k // 2 for k in kernel_shape]
            exceeds_limit = any(
                pads[i] > max_allowed_pad[i] or pads[i + n] > max_allowed_pad[i]
                for i in range(n)
            )
            if not symmetric or exceeds_limit:
                return _avg_pool_dilated(
                    x,
                    kernel_shape,
                    strides,
                    dilations,
                    pads,
                    ceil_mode,
                    count_include_pad,
                )
            padding = tuple(pads[:n])

        return pool_fns[ndim](
            x,
            tuple(kernel_shape),
            stride=tuple(strides),
            padding=padding,
            ceil_mode=bool(ceil_mode),
            count_include_pad=bool(count_include_pad),
        )

    return builder.call_function(
        _avg_pool,
        args=(
            x,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            count_include_pad,
            auto_pad,
        ),
    )
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
@register("GlobalAveragePool")
def global_average_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Global average pooling.

    Averages over every axis after batch and channel, keeping those axes as
    size 1. An input without spatial axes is returned unchanged, matching
    GlobalMaxPool.
    """
    x = builder.get_value(node.input[0])

    def _global_avg_pool(x):
        # Average over all spatial dimensions (keep batch and channel)
        dims = tuple(range(2, x.dim()))
        if not dims:
            # Nothing to pool; an empty ``dim`` tuple can make PyTorch
            # reduce over every axis instead.
            return x
        return x.mean(dim=dims, keepdim=True)

    return builder.call_function(_global_avg_pool, args=(x,))
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
@register("GlobalMaxPool")
def global_max_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Global max pooling.

    Takes the maximum over every axis after batch and channel, keeping those
    axes as size 1.
    """
    source = builder.get_value(node.input[0])

    def _global_max_pool(x):
        pooled = x
        # Reduce the trailing axes one at a time, from the last back to axis 2.
        for axis in reversed(range(2, x.dim())):
            pooled = torch.max(pooled, dim=axis, keepdim=True)[0]
        return pooled

    return builder.call_function(_global_max_pool, args=(source,))
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
@register("LpPool")
def lp_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lp pooling.

    Computes the Lp norm over a sliding window:
        output = (sum(|x|^p))^(1/p)

    Attributes read from ``node``: ``kernel_shape``, ``strides`` and
    ``dilations`` (both default to all ones), ``pads`` (default: no
    padding), ``ceil_mode`` (default 0), ``auto_pad`` (default
    ``"NOTSET"``) and ``p`` (default 2). Emits one FX ``call_function``
    node. Any number of spatial dimensions >= 1 is supported.
    """
    x = builder.get_value(node.input[0])

    kernel_shape = get_attribute(node, "kernel_shape")
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads")
    dilations = get_attribute(node, "dilations") or [1] * len(kernel_shape)
    ceil_mode = get_attribute(node, "ceil_mode", 0)
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")
    p = get_attribute(node, "p", 2)

    def _lp_pool_dilated(x, kernel_shape, strides, dilations, pads, ceil_mode, p):
        """Compute Lp pooling with dilation and padding support using unfold.

        PyTorch's lp_pool doesn't support dilation or padding, so we implement
        it manually. ``x`` is indexed as (N, C, *spatial) with one spatial axis
        per entry of ``kernel_shape``; ``pads`` uses the ONNX layout
        ``[begin_0, ..., begin_{n-1}, end_0, ..., end_{n-1}]`` or is None.

        Raises:
            NotImplementedError: if ``kernel_shape`` is empty.
        """
        ndim = len(kernel_shape)
        if ndim < 1:
            raise NotImplementedError(f"LpPool{ndim}D not supported")
        spatial_shape = list(x.shape[2:])

        # Effective kernel size with dilation: (k - 1) * d + 1
        effective_kernel = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]

        if pads is not None:
            n = len(pads) // 2
            pads_begin = list(pads[:n])
            pads_end = list(pads[n : 2 * n])
        else:
            pads_begin = [0] * ndim
            pads_end = [0] * ndim

        # For ceil_mode, add extra end padding when the ceil output count
        # exceeds the floor output count. Padding with zeros is harmless here
        # because |0|^p contributes nothing to the sum.
        if ceil_mode:
            for i in range(ndim):
                padded_size = spatial_shape[i] + pads_begin[i] + pads_end[i]
                out_floor = (padded_size - effective_kernel[i]) // strides[i] + 1
                out_ceil = (
                    padded_size - effective_kernel[i] + strides[i] - 1
                ) // strides[i] + 1
                if out_ceil > out_floor:
                    pads_end[i] += strides[i] - 1

        # F.pad expects (begin, end) pairs starting from the LAST spatial dim.
        pad_list = []
        for i in range(ndim - 1, -1, -1):
            pad_list.extend([pads_begin[i], pads_end[i]])

        if any(p_val > 0 for p_val in pad_list):
            x = F.pad(x, pad_list, value=0)

        # Unfold each spatial axis with its effective (dilated) window. Every
        # unfold appends a new trailing window axis, so spatial axis i stays
        # at index 2 + i throughout the loop.
        patches = x
        for i in range(ndim):
            patches = patches.unfold(2 + i, effective_kernel[i], strides[i])
        # patches: (N, C, out_0, ..., out_{ndim-1}, ek_0, ..., ek_{ndim-1})

        # Keep only the dilated taps 0, d, 2d, ..., (k-1)*d of each window axis.
        for i in range(ndim):
            taps = torch.arange(0, effective_kernel[i], dilations[i], device=x.device)
            patches = patches.index_select(i - ndim, taps)
        # patches: (N, C, out_0, ..., out_{ndim-1}, k_0, ..., k_{ndim-1})

        # Lp norm over the window axes: (sum(|x|^p))^(1/p)
        window_axes = tuple(range(-ndim, 0))
        return (patches.abs().pow(p).sum(dim=window_axes)).pow(1.0 / p)

    def _lp_pool(x, kernel_shape, strides, pads, dilations, ceil_mode, auto_pad, p):
        # SAME_* auto_pad: derive explicit ONNX-style pads from the runtime
        # input shape; they replace any explicit ``pads``.
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            pad_list = compute_same_padding(
                tuple(x.shape[2:]),
                tuple(kernel_shape),
                tuple(strides),
                tuple(dilations),
                auto_pad,
                use_effective_kernel=True,
            )
            pads = pad_list_to_onnx_pads(pad_list, len(kernel_shape))

        # PyTorch's lp_pool functions use sign(f(x)) * |f(x)|^(1/p) where
        # f(x) = sum(x^p), but ONNX's LpPool uses (sum(|x|^p))^(1/p). ONNX takes
        # the absolute value FIRST, which matters for negative inputs with odd
        # p. PyTorch's lp_pool also lacks padding and dilation support. The
        # manual implementation is therefore used unconditionally.
        return _lp_pool_dilated(x, kernel_shape, strides, dilations, pads, ceil_mode, p)

    return builder.call_function(
        _lp_pool,
        args=(
            x,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            auto_pad,
            p,
        ),
    )
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
@register("GlobalLpPool")
def global_lp_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower an ONNX ``GlobalLpPool`` node to a single FX call.

    Computes ``(sum(|x|^p)) ** (1/p)`` over every axis from index 2 onward
    (all spatial axes), keeping each reduced axis with size 1. ``p`` comes
    from the node's ``p`` attribute and defaults to 2.
    """
    data = builder.get_value(node.input[0])
    order = get_attribute(node, "p", 2)

    def _global_lp_pool(x, p):
        # Reduce everything except batch and channel.
        spatial_axes = tuple(range(x.dim()))[2:]
        powered = x.abs().pow(p)
        total = powered.sum(dim=spatial_axes, keepdim=True)
        return total.pow(1.0 / p)

    return builder.call_function(_global_lp_pool, args=(data, order))
|