onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2fx/ops/pooling.py ADDED
@@ -0,0 +1,897 @@
# SPDX-License-Identifier: Apache-2.0
"""Pooling operators."""

from typing import TYPE_CHECKING

import onnx
import torch
import torch.nn.functional as F

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import (
    compute_same_padding,
    get_optional_input,
    pad_list_to_onnx_pads,
)

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Pooling operators
# =============================================================================


@register("MaxPool")
def max_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Max pooling."""
    x = builder.get_value(node.input[0])

    kernel_shape = get_attribute(node, "kernel_shape")
    # ONNX spec: strides defaults to 1 along each spatial axis (not kernel_shape)
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads")
    dilations = get_attribute(node, "dilations") or [1] * len(kernel_shape)
    ceil_mode = get_attribute(node, "ceil_mode", 0)
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")
    storage_order = get_attribute(node, "storage_order", 0)

    # Check if we need to return indices (second output requested)
    return_indices = len(node.output) > 1 and node.output[1] != ""

    def _max_pool(
        x,
        kernel_shape,
        strides,
        pads,
        dilations,
        ceil_mode,
        auto_pad,
        return_indices,
        storage_order,
    ):
        ndim = len(kernel_shape)
        input_dtype = x.dtype
        input_shape = x.shape  # (N, C, D1, D2, ...)

        # PyTorch max_pool doesn't support int8/uint8, need to convert
        needs_cast = input_dtype in (torch.int8, torch.uint8)
        if needs_cast:
            x = x.float()

        padding = 0
        # Handle auto_pad first (before explicit pads)
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            spatial_shape = x.shape[2:]
            pad_list = compute_same_padding(
                tuple(spatial_shape),
                tuple(kernel_shape),
                tuple(strides),
                tuple(dilations),
                auto_pad,
            )
            x = F.pad(x, pad_list, value=float("-inf"))
            padding = 0
        elif pads is not None:
            n = len(pads) // 2
            symmetric = all(pads[i] == pads[i + n] for i in range(n))

            # Check if padding exceeds PyTorch's limit
            # PyTorch: pad should be at most half of effective kernel size
            # effective_kernel = (kernel_size - 1) * dilation + 1
            # max_pad = effective_kernel // 2
            max_allowed_pad = [
                ((k - 1) * d + 1) // 2 for k, d in zip(kernel_shape, dilations)
            ]
            exceeds_limit = any(
                pads[i] > max_allowed_pad[i] or pads[i + n] > max_allowed_pad[i]
                for i in range(n)
            )

            if symmetric and not exceeds_limit:
                padding = tuple(pads[:n])
            else:
                # Use explicit F.pad for asymmetric or large padding
                pad_list = []
                for i in range(n - 1, -1, -1):
                    pad_list.extend([pads[i], pads[i + n]])
                x = F.pad(x, pad_list, value=float("-inf"))
                padding = 0

        kernel = tuple(kernel_shape)
        stride = tuple(strides)
        dilation = tuple(dilations)

        if ndim == 1:
            result = F.max_pool1d(
                x,
                kernel[0],
                stride=stride[0],
                padding=padding if isinstance(padding, int) else padding[0],
                dilation=dilation[0],
                ceil_mode=bool(ceil_mode),
                return_indices=return_indices,
            )
        elif ndim == 2:
            result = F.max_pool2d(
                x,
                kernel,
                stride=stride,
                padding=padding,
                dilation=dilation,
                ceil_mode=bool(ceil_mode),
                return_indices=return_indices,
            )
        elif ndim == 3:
            result = F.max_pool3d(
                x,
                kernel,
                stride=stride,
                padding=padding,
                dilation=dilation,
                ceil_mode=bool(ceil_mode),
                return_indices=return_indices,
            )
        else:
            raise NotImplementedError(f"MaxPool{ndim}D not supported")

        if return_indices:
            values, indices = result
            if needs_cast:
                values = values.to(input_dtype)

            # Handle storage_order for indices
            # PyTorch returns row-major indices (last dim varies fastest)
            # ONNX storage_order=0 means row-major (default)
            # ONNX storage_order=1 means column-major (first spatial dim varies fastest)
            if storage_order == 1:
                # Convert row-major indices to column-major
                # For input shape (N, C, D1, D2, ...), we need to convert indices
                # Row-major: idx = n*C*D1*D2*... + c*D1*D2*... + d1*D2*... + d2*... + ...
                # Column-major: idx = n + c*N + d1*N*C + d2*N*C*D1 + ...
                # Compute the multi-index from row-major flat index
                flat_indices = indices
                # Spatial dims of original input (before any padding)
                spatial_dims = list(input_shape[2:])
                n_batch = input_shape[0]
                n_channel = input_shape[1]

                # Decompose row-major index to (n, c, d1, d2, ...)
                remaining = flat_indices
                coords = []
                # First extract spatial coords in reverse order (last dim first)
                for dim_size in reversed(spatial_dims):
                    coords.append(remaining % dim_size)
                    remaining = remaining // dim_size
                # Now remaining = n * C + c
                c_coord = remaining % n_channel
                n_coord = remaining // n_channel

                # Reverse coords to get (d1, d2, ...) order
                spatial_coords = list(reversed(coords))

                # Compute column-major index
                # col_idx = n + c*N + d1*N*C + d2*N*C*D1 + ...
                col_idx = n_coord
                stride_factor = n_batch
                col_idx = col_idx + c_coord * stride_factor
                stride_factor = stride_factor * n_channel
                for i, d_coord in enumerate(spatial_coords):
                    col_idx = col_idx + d_coord * stride_factor
                    stride_factor = stride_factor * spatial_dims[i]

                indices = col_idx

            return values, indices
        else:
            if needs_cast:
                result = result.to(input_dtype)
            return result

    return builder.call_function(
        _max_pool,
        args=(
            x,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            auto_pad,
            return_indices,
            storage_order,
        ),
    )


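# Illustrative sketch (hypothetical helper, not used by the converter): the same
# row-major -> column-major flat-index conversion that ``_max_pool`` applies for
# ``storage_order=1``, written for a single integer index so the arithmetic is easy
# to follow. The shape and index below are made up for the example.
def _row_major_to_col_major_example(flat_idx: int, shape: tuple) -> int:
    """Convert a row-major flat index into a column-major flat index.

    >>> # shape (N, C, D1, D2) = (1, 2, 3, 4); row-major index 7 is (0, 0, 1, 3),
    >>> # so column-major gives 0 + 0*1 + 1*(1*2) + 3*(1*2*3) = 20.
    >>> _row_major_to_col_major_example(7, (1, 2, 3, 4))
    20
    """
    # Decompose the row-major index (last dimension varies fastest).
    coords = []
    remaining = flat_idx
    for dim_size in reversed(shape):
        coords.append(remaining % dim_size)
        remaining //= dim_size
    coords.reverse()  # (n, c, d1, d2, ...)
    # Recompose column-major (first dimension varies fastest).
    col_idx = 0
    stride = 1
    for coord, dim_size in zip(coords, shape):
        col_idx += coord * stride
        stride *= dim_size
    return col_idx

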
@register("MaxUnpool")
def max_unpool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """MaxUnpool - partial inverse of MaxPool.

    Unpools the input tensor using indices from MaxPool.
    """
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])

    # Optional output_shape input
    output_shape = get_optional_input(builder, node, 2)

    kernel_shape = get_attribute(node, "kernel_shape")
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads") or [0] * (2 * len(kernel_shape))

    def _max_unpool(x, indices, kernel_shape, strides, pads, output_shape):
        ndim = len(kernel_shape)

        # Convert ONNX pads format to PyTorch padding
        # ONNX: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
        # PyTorch: symmetric padding per dimension
        n = len(pads) // 2
        padding = tuple(pads[:n])

        kernel = tuple(kernel_shape)
        stride = tuple(strides)

        # Calculate default output size (without explicit output_shape)
        # Default output: out_i = (in_i - 1) * stride_i + kernel_i - pad_begin_i - pad_end_i
        input_spatial_shape = x.shape[2:]
        default_spatial = []
        for i in range(ndim):
            out_dim = (
                (input_spatial_shape[i] - 1) * stride[i]
                + kernel[i]
                - pads[i]
                - pads[i + n]
            )
            default_spatial.append(out_dim)

        # Determine output size
        out_size = None
        if output_shape is not None:
            # output_shape is the full shape including batch and channel dims
            if isinstance(output_shape, torch.Tensor):
                out_size = tuple(int(s) for s in output_shape.tolist())
            else:
                out_size = tuple(int(s) for s in output_shape)

            # Get spatial dimensions from output_shape
            target_spatial = out_size[2:]

            # Check if we need to convert indices
            # ONNX indices are computed for the original (default) tensor size
            # PyTorch expects indices relative to the output_size
            if list(target_spatial) != list(default_spatial):
                # Convert indices from default spatial shape to target spatial shape
                # Indices are flattened over (N, C, D1, D2, ...) dimensions
                # We need to extract (d1, d2, ...) coords from default shape
                # and recompute indices for target shape

                # For efficiency, work with the spatial dimensions only
                # The batch and channel dimensions affect the flat index calculation
                channels = x.shape[1]

                # Compute the total size for default spatial dimensions
                default_spatial_size = 1
                for d in default_spatial:
                    default_spatial_size *= d

                # Decompose flat indices to (n, c, spatial_coords) in default shape
                remaining = indices

                # Extract spatial coordinates in reverse order (last spatial dim first)
                spatial_coords = []
                for dim_size in reversed(default_spatial):
                    spatial_coords.append(remaining % dim_size)
                    remaining = remaining // dim_size
                spatial_coords = list(reversed(spatial_coords))

                # remaining now contains (n * channels + c)
                c_coord = remaining % channels
                n_coord = remaining // channels

                # Recompute flat indices for target spatial shape
                # new_idx = n * (C * prod(target_spatial)) + c * prod(target_spatial) + spatial_flat
                target_spatial_size = 1
                for d in target_spatial:
                    target_spatial_size *= d

                # Compute spatial flat index for target shape
                spatial_flat = spatial_coords[0]
                for i in range(1, ndim):
                    spatial_flat = spatial_flat * target_spatial[i] + spatial_coords[i]

                # Compute full flat index
                indices = (
                    n_coord * (channels * target_spatial_size)
                    + c_coord * target_spatial_size
                    + spatial_flat
                )

        if ndim == 1:
            return F.max_unpool1d(
                x,
                indices,
                kernel[0],
                stride=stride[0],
                padding=padding[0],
                output_size=out_size,
            )
        elif ndim == 2:
            return F.max_unpool2d(
                x,
                indices,
                kernel,
                stride=stride,
                padding=padding,
                output_size=out_size,
            )
        elif ndim == 3:
            return F.max_unpool3d(
                x,
                indices,
                kernel,
                stride=stride,
                padding=padding,
                output_size=out_size,
            )
        else:
            raise NotImplementedError(f"MaxUnpool{ndim}D not supported")

    return builder.call_function(
        _max_unpool,
        args=(x, indices, kernel_shape, strides, pads, output_shape),
    )


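# Illustrative sketch (hypothetical helper, not used by the converter): the default
# MaxUnpool output size computed in ``_max_unpool`` when no explicit output_shape
# input is given, i.e. out_i = (in_i - 1) * stride_i + kernel_i - pad_begin_i - pad_end_i.
def _unpool_default_size_example(in_size, kernel, stride, pad_begin, pad_end):
    """
    >>> _unpool_default_size_example([4, 4], [2, 2], [2, 2], [0, 0], [0, 0])
    [8, 8]
    """
    return [
        (i - 1) * s + k - pb - pe
        for i, k, s, pb, pe in zip(in_size, kernel, stride, pad_begin, pad_end)
    ]

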
@register("AveragePool")
def average_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Average pooling."""
    x = builder.get_value(node.input[0])

    kernel_shape = get_attribute(node, "kernel_shape")
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads")
    dilations = get_attribute(node, "dilations") or [1] * len(kernel_shape)
    ceil_mode = get_attribute(node, "ceil_mode", 0)
    count_include_pad = get_attribute(node, "count_include_pad", 0)
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")

    def _avg_pool_dilated(
        x, kernel_shape, strides, dilations, pads, ceil_mode, count_include_pad
    ):
        """Compute average pooling with dilation support using unfold.

        PyTorch's avg_pool doesn't support dilation, so we implement it manually.
        """
        ndim = len(kernel_shape)
        batch_size = x.shape[0]
        channels = x.shape[1]
        spatial_shape = list(x.shape[2:])

        # Compute effective kernel size with dilation
        # effective_k = (k - 1) * d + 1
        effective_kernel = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]

        # Apply padding if specified
        if pads is not None:
            n = len(pads) // 2
            pads_begin = [pads[i] for i in range(n)]
            pads_end = [pads[i + n] for i in range(n)]
        else:
            n = ndim
            pads_begin = [0] * n
            pads_end = [0] * n

        # Track original pads (before ceil_mode adjustment) for count_include_pad
        orig_pads_end = pads_end.copy()

        # For ceil_mode, add extra end padding if needed to get ceil behavior
        # ceil_mode output: ceil((input + pad_begin + pad_end - ek) / stride) + 1
        # floor_mode output: floor((input + pad_begin + pad_end - ek) / stride) + 1
        # To get ceil behavior with floor, add padding: (stride - 1)
        ceil_extra_pad = [0] * ndim
        if ceil_mode:
            for i in range(ndim):
                padded_size = spatial_shape[i] + pads_begin[i] + pads_end[i]
                # Compute output with floor
                out_floor = (padded_size - effective_kernel[i]) // strides[i] + 1
                # Compute output with ceil
                out_ceil = (
                    padded_size - effective_kernel[i] + strides[i] - 1
                ) // strides[i] + 1
                if out_ceil > out_floor:
                    # Need extra padding to get one more output element
                    ceil_extra_pad[i] = strides[i] - 1
                    pads_end[i] += ceil_extra_pad[i]

        # Build pad_list for F.pad (reversed order: last dim first)
        pad_list = []
        for i in range(ndim - 1, -1, -1):
            pad_list.extend([pads_begin[i], pads_end[i]])

        has_padding = any(p > 0 for p in pad_list)
        has_ceil_extra = any(p > 0 for p in ceil_extra_pad)

        if has_padding:
            x = F.pad(x, pad_list, value=0)
            spatial_shape_padded = list(x.shape[2:])

            # Create a mask for computing the correct count
            # Case 1: count_include_pad=False -> mask marks original (non-padded) area
            # Case 2: count_include_pad=True with ceil_extra_pad -> mask marks area
            #         up to original pads (but not ceil extra pads)
            # Case 3: count_include_pad=True without ceil_extra_pad -> no mask needed
            if not count_include_pad:
                # Original shape before any padding
                orig_shape = [batch_size, channels] + [
                    spatial_shape_padded[i] - pads_begin[i] - pads_end[i]
                    for i in range(ndim)
                ]
                mask = torch.ones(orig_shape, dtype=x.dtype, device=x.device)
                mask = F.pad(mask, pad_list, value=0)
            elif has_ceil_extra:
                # count_include_pad=True but with ceil extra padding
                # Create mask that includes original padding but not ceil extra
                orig_pad_list = []
                for i in range(ndim - 1, -1, -1):
                    orig_pad_list.extend([pads_begin[i], orig_pads_end[i]])
                # Shape after original padding only
                orig_padded_shape = [batch_size, channels] + [
                    spatial_shape[i] + pads_begin[i] + orig_pads_end[i]
                    for i in range(ndim)
                ]
                mask = torch.ones(orig_padded_shape, dtype=x.dtype, device=x.device)
                # Pad with ceil extra padding (but these should be 0 in mask)
                ceil_pad_list = []
                for i in range(ndim - 1, -1, -1):
                    ceil_pad_list.extend([0, ceil_extra_pad[i]])
                mask = F.pad(mask, ceil_pad_list, value=0)
            else:
                mask = None
        else:
            mask = None

        # Use unfold to extract patches with dilation
        # For each spatial dimension, unfold with size=kernel and step=stride
        # We need to account for dilation by selecting every d-th element

        if ndim == 1:
            # Use unfold for 1D
            # unfold(dimension, size, step)
            _, d, s = kernel_shape[0], dilations[0], strides[0]
            ek = effective_kernel[0]

            # Unfold with effective kernel size and stride
            # Then select every d-th element within each patch
            patches = x.unfold(2, ek, s)  # (N, C, out_L, ek)
            # Select dilated elements: indices 0, d, 2d, ..., (k-1)*d
            indices = torch.arange(0, ek, d, device=x.device)
            patches = patches.index_select(-1, indices)  # (N, C, out_L, k)

            if mask is not None:
                mask_patches = mask.unfold(2, ek, s)
                mask_patches = mask_patches.index_select(-1, indices)
                count = mask_patches.sum(dim=-1)
                sum_val = patches.sum(dim=-1)
                return sum_val / count.clamp(min=1)
            else:
                return patches.mean(dim=-1)

        elif ndim == 2:
            k0, k1 = kernel_shape
            d0, d1 = dilations
            s0, s1 = strides
            ek0, ek1 = effective_kernel

            # Unfold along height (dim 2), then width (dim 3)
            patches = x.unfold(2, ek0, s0).unfold(3, ek1, s1)
            # patches shape: (N, C, out_H, out_W, ek0, ek1)

            # Select dilated elements
            indices0 = torch.arange(0, ek0, d0, device=x.device)
            indices1 = torch.arange(0, ek1, d1, device=x.device)
            patches = patches.index_select(-2, indices0).index_select(-1, indices1)
            # patches shape: (N, C, out_H, out_W, k0, k1)

            if mask is not None:
                mask_patches = mask.unfold(2, ek0, s0).unfold(3, ek1, s1)
                mask_patches = mask_patches.index_select(-2, indices0).index_select(
                    -1, indices1
                )
                count = mask_patches.sum(dim=(-2, -1))
                sum_val = patches.sum(dim=(-2, -1))
                return sum_val / count.clamp(min=1)
            else:
                return patches.mean(dim=(-2, -1))

        elif ndim == 3:
            k0, k1, k2 = kernel_shape
            d0, d1, d2 = dilations
            s0, s1, s2 = strides
            ek0, ek1, ek2 = effective_kernel

            # Unfold along each spatial dimension
            patches = x.unfold(2, ek0, s0).unfold(3, ek1, s1).unfold(4, ek2, s2)
            # patches shape: (N, C, out_D, out_H, out_W, ek0, ek1, ek2)

            # Select dilated elements
            indices0 = torch.arange(0, ek0, d0, device=x.device)
            indices1 = torch.arange(0, ek1, d1, device=x.device)
            indices2 = torch.arange(0, ek2, d2, device=x.device)
            patches = (
                patches.index_select(-3, indices0)
                .index_select(-2, indices1)
                .index_select(-1, indices2)
            )
            # patches shape: (N, C, out_D, out_H, out_W, k0, k1, k2)

            if mask is not None:
                mask_patches = (
                    mask.unfold(2, ek0, s0).unfold(3, ek1, s1).unfold(4, ek2, s2)
                )
                mask_patches = (
                    mask_patches.index_select(-3, indices0)
                    .index_select(-2, indices1)
                    .index_select(-1, indices2)
                )
                count = mask_patches.sum(dim=(-3, -2, -1))
                sum_val = patches.sum(dim=(-3, -2, -1))
                return sum_val / count.clamp(min=1)
            else:
                return patches.mean(dim=(-3, -2, -1))

        else:
            raise NotImplementedError(f"AveragePool{ndim}D not supported")

    def _avg_pool(
        x,
        kernel_shape,
        strides,
        pads,
        dilations,
        ceil_mode,
        count_include_pad,
        auto_pad,
    ):
        ndim = len(kernel_shape)

        # Check if we have non-trivial dilation
        has_dilation = any(d != 1 for d in dilations)

        # Handle auto_pad first (before explicit pads)
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            # For SAME padding with count_include_pad=0, we need to compute
            # the average only over valid (non-padded) input positions.
            # We do this by:
            # 1. Sum pooling on padded input (pad with 0s, so they don't affect sum)
            # 2. Count pooling on a mask (to count valid positions per output)
            # 3. Divide sum by count
            input_shape = x.shape[2:]
            pad_list = compute_same_padding(
                tuple(input_shape),
                tuple(kernel_shape),
                tuple(strides),
                tuple(dilations),
                auto_pad,
                use_effective_kernel=True,
            )

            # Convert pad_list to pads format for dilated implementation
            pads_onnx = pad_list_to_onnx_pads(pad_list, ndim)

            # Use dilated implementation which handles padding correctly
            return _avg_pool_dilated(
                x, kernel_shape, strides, dilations, pads_onnx, ceil_mode, 0
            )

        # If we have dilation, use the dilated implementation
        if has_dilation:
            return _avg_pool_dilated(
                x, kernel_shape, strides, dilations, pads, ceil_mode, count_include_pad
            )

        # Check if we need to use manual padding (asymmetric or exceeds limit)
        padding = 0
        use_manual_pad = False
        if pads is not None:
            n = len(pads) // 2
            symmetric = all(pads[i] == pads[i + n] for i in range(n))

            # Check if padding exceeds PyTorch's limit
            # PyTorch: pad should be at most half of kernel size
            max_allowed_pad = [k // 2 for k in kernel_shape]
            exceeds_limit = any(
                pads[i] > max_allowed_pad[i] or pads[i + n] > max_allowed_pad[i]
                for i in range(n)
            )

            if symmetric and not exceeds_limit:
                padding = tuple(pads[:n])
            else:
                use_manual_pad = True

        if use_manual_pad:
            # Use dilated implementation which handles asymmetric/large padding
            return _avg_pool_dilated(
                x, kernel_shape, strides, dilations, pads, ceil_mode, count_include_pad
            )

        kernel = tuple(kernel_shape)
        stride = tuple(strides)

        if ndim == 1:
            return F.avg_pool1d(
                x,
                kernel[0],
                stride=stride[0],
                padding=padding if isinstance(padding, int) else padding[0],
                ceil_mode=bool(ceil_mode),
                count_include_pad=bool(count_include_pad),
            )
        elif ndim == 2:
            return F.avg_pool2d(
                x,
                kernel,
                stride=stride,
                padding=padding,
                ceil_mode=bool(ceil_mode),
                count_include_pad=bool(count_include_pad),
            )
        elif ndim == 3:
            return F.avg_pool3d(
                x,
                kernel,
                stride=stride,
                padding=padding,
                ceil_mode=bool(ceil_mode),
                count_include_pad=bool(count_include_pad),
            )
        else:
            raise NotImplementedError(f"AveragePool{ndim}D not supported")

    return builder.call_function(
        _avg_pool,
        args=(
            x,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            count_include_pad,
            auto_pad,
        ),
    )


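# Illustrative sketch (hypothetical helper, not used by the converter): the
# unfold-based patch extraction that ``_avg_pool_dilated`` relies on, reduced to the
# 2-D, no-padding case. With dilation (1, 1) it is expected to match F.avg_pool2d
# for the same kernel and stride.
def _unfold_avg_pool2d_example(x, kernel, stride, dilation):
    """
    >>> t = torch.arange(16.0).reshape(1, 1, 4, 4)
    >>> torch.allclose(
    ...     _unfold_avg_pool2d_example(t, (2, 2), (2, 2), (1, 1)),
    ...     F.avg_pool2d(t, (2, 2), stride=(2, 2)),
    ... )
    True
    """
    # Effective kernel size: (k - 1) * d + 1, as in the implementation above.
    ek = [(k - 1) * d + 1 for k, d in zip(kernel, dilation)]
    patches = x.unfold(2, ek[0], stride[0]).unfold(3, ek[1], stride[1])
    idx0 = torch.arange(0, ek[0], dilation[0], device=x.device)
    idx1 = torch.arange(0, ek[1], dilation[1], device=x.device)
    # Keep every d-th element of each patch, then average the remaining k elements.
    patches = patches.index_select(-2, idx0).index_select(-1, idx1)
    return patches.mean(dim=(-2, -1))

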
@register("GlobalAveragePool")
def global_average_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Global average pooling."""
    x = builder.get_value(node.input[0])

    def _global_avg_pool(x):
        # Average over all spatial dimensions (keep batch and channel)
        dims = tuple(range(2, x.dim()))
        return x.mean(dim=dims, keepdim=True)

    return builder.call_function(_global_avg_pool, args=(x,))


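# Illustrative sketch (hypothetical helper, not used by the converter): for 4-D
# inputs, the spatial mean in ``_global_avg_pool`` should agree with
# F.adaptive_avg_pool2d(x, 1).
def _global_avg_pool_reference_example(x):
    """
    >>> t = torch.arange(16.0).reshape(1, 1, 4, 4)
    >>> torch.allclose(_global_avg_pool_reference_example(t), F.adaptive_avg_pool2d(t, 1))
    True
    """
    return x.mean(dim=tuple(range(2, x.dim())), keepdim=True)

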
@register("GlobalMaxPool")
def global_max_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Global max pooling."""
    x = builder.get_value(node.input[0])

    def _global_max_pool(x):
        # Max over all spatial dimensions (keep batch and channel)
        result = x
        for dim in range(x.dim() - 1, 1, -1):
            result = result.max(dim=dim, keepdim=True).values
        return result

    return builder.call_function(_global_max_pool, args=(x,))


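# Illustrative sketch (hypothetical helper, not used by the converter): the
# dimension-by-dimension reduction loop in ``_global_max_pool`` is equivalent to a
# single ``amax`` over all spatial dimensions.
def _global_max_pool_reference_example(x):
    """
    >>> t = torch.arange(16.0).reshape(1, 1, 4, 4)
    >>> torch.allclose(_global_max_pool_reference_example(t), torch.tensor([[[[15.0]]]]))
    True
    """
    return x.amax(dim=tuple(range(2, x.dim())), keepdim=True)

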
@register("LpPool")
def lp_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lp pooling.

    Computes the Lp norm over a sliding window:
        output = (sum(|x|^p))^(1/p)
    """
    x = builder.get_value(node.input[0])

    kernel_shape = get_attribute(node, "kernel_shape")
    strides = get_attribute(node, "strides") or [1] * len(kernel_shape)
    pads = get_attribute(node, "pads")
    dilations = get_attribute(node, "dilations") or [1] * len(kernel_shape)
    ceil_mode = get_attribute(node, "ceil_mode", 0)
    auto_pad = get_attribute(node, "auto_pad", "NOTSET")
    p = get_attribute(node, "p", 2)

    def _lp_pool_dilated(x, kernel_shape, strides, dilations, pads, ceil_mode, p):
        """Compute Lp pooling with dilation support using unfold.

        PyTorch's lp_pool doesn't support dilation or padding, so we implement
        it manually.
        """
        ndim = len(kernel_shape)
        spatial_shape = list(x.shape[2:])

        # Compute effective kernel size with dilation
        # effective_k = (k - 1) * d + 1
        effective_kernel = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]

        # Apply padding if specified
        if pads is not None:
            n = len(pads) // 2
            pads_begin = [pads[i] for i in range(n)]
            pads_end = [pads[i + n] for i in range(n)]
        else:
            n = ndim
            pads_begin = [0] * n
            pads_end = [0] * n

        # For ceil_mode, add extra end padding if needed to get ceil behavior
        if ceil_mode:
            for i in range(ndim):
                padded_size = spatial_shape[i] + pads_begin[i] + pads_end[i]
                # Compute output with floor
                out_floor = (padded_size - effective_kernel[i]) // strides[i] + 1
                # Compute output with ceil
                out_ceil = (
                    padded_size - effective_kernel[i] + strides[i] - 1
                ) // strides[i] + 1
                if out_ceil > out_floor:
                    # Need extra padding to get one more output element
                    pads_end[i] += strides[i] - 1

        # Build pad_list for F.pad (reversed order: last dim first)
        pad_list = []
        for i in range(ndim - 1, -1, -1):
            pad_list.extend([pads_begin[i], pads_end[i]])

        has_padding = any(p_val > 0 for p_val in pad_list)

        if has_padding:
            x = F.pad(x, pad_list, value=0)

        # Use unfold to extract patches with dilation
        if ndim == 1:
            _, d, s = kernel_shape[0], dilations[0], strides[0]
            ek = effective_kernel[0]

            # Unfold with effective kernel size and stride
            patches = x.unfold(2, ek, s)  # (N, C, out_L, ek)
            # Select dilated elements: indices 0, d, 2d, ..., (k-1)*d
            indices = torch.arange(0, ek, d, device=x.device)
            patches = patches.index_select(-1, indices)  # (N, C, out_L, k)

            # Compute Lp norm: (sum(|x|^p))^(1/p)
            return (patches.abs().pow(p).sum(dim=-1)).pow(1.0 / p)

        elif ndim == 2:
            k0, k1 = kernel_shape
            d0, d1 = dilations
            s0, s1 = strides
            ek0, ek1 = effective_kernel

            # Unfold along height (dim 2), then width (dim 3)
            patches = x.unfold(2, ek0, s0).unfold(3, ek1, s1)
            # patches shape: (N, C, out_H, out_W, ek0, ek1)

            # Select dilated elements
            indices0 = torch.arange(0, ek0, d0, device=x.device)
            indices1 = torch.arange(0, ek1, d1, device=x.device)
            patches = patches.index_select(-2, indices0).index_select(-1, indices1)
            # patches shape: (N, C, out_H, out_W, k0, k1)

            # Compute Lp norm: (sum(|x|^p))^(1/p)
            return (patches.abs().pow(p).sum(dim=(-2, -1))).pow(1.0 / p)

        elif ndim == 3:
            k0, k1, k2 = kernel_shape
            d0, d1, d2 = dilations
            s0, s1, s2 = strides
            ek0, ek1, ek2 = effective_kernel

            # Unfold along each spatial dimension
            patches = x.unfold(2, ek0, s0).unfold(3, ek1, s1).unfold(4, ek2, s2)
            # patches shape: (N, C, out_D, out_H, out_W, ek0, ek1, ek2)

            # Select dilated elements
            indices0 = torch.arange(0, ek0, d0, device=x.device)
            indices1 = torch.arange(0, ek1, d1, device=x.device)
            indices2 = torch.arange(0, ek2, d2, device=x.device)
            patches = (
                patches.index_select(-3, indices0)
                .index_select(-2, indices1)
                .index_select(-1, indices2)
            )
            # patches shape: (N, C, out_D, out_H, out_W, k0, k1, k2)

            # Compute Lp norm: (sum(|x|^p))^(1/p)
            return (patches.abs().pow(p).sum(dim=(-3, -2, -1))).pow(1.0 / p)

        else:
            raise NotImplementedError(f"LpPool{ndim}D not supported")

    def _lp_pool(x, kernel_shape, strides, pads, dilations, ceil_mode, auto_pad, p):
        ndim = len(kernel_shape)

        # Check if we have non-trivial dilation
        has_dilation = any(d != 1 for d in dilations)

        # Handle auto_pad first (before explicit pads)
        if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            input_shape = x.shape[2:]
            pad_list = compute_same_padding(
                tuple(input_shape),
                tuple(kernel_shape),
                tuple(strides),
                tuple(dilations),
                auto_pad,
                use_effective_kernel=True,
            )

            # Convert pad_list to pads format for dilated implementation
            pads_onnx = pad_list_to_onnx_pads(pad_list, ndim)

            # Use dilated implementation which handles padding correctly
            return _lp_pool_dilated(
                x, kernel_shape, strides, dilations, pads_onnx, ceil_mode, p
            )

        # If we have dilation, use the dilated implementation
        if has_dilation:
            return _lp_pool_dilated(
                x, kernel_shape, strides, dilations, pads, ceil_mode, p
            )

        # Check if we need to use manual padding (asymmetric or any padding)
        # PyTorch's lp_pool doesn't support padding at all
        if pads is not None and any(pad_val > 0 for pad_val in pads):
            return _lp_pool_dilated(
                x, kernel_shape, strides, dilations, pads, ceil_mode, p
            )

        # PyTorch's lp_pool functions use sign(f(x)) * |f(x)|^(1/p) where f(x) = sum(x^p),
        # but ONNX's LpPool uses (sum(|x|^p))^(1/p). The difference is that ONNX
        # takes absolute value FIRST before raising to power p. This matters when
        # x contains negative values and p is odd (like p=3), as PyTorch's version
        # can produce NaN while ONNX's version is always well-defined.
        # Therefore, we always use our manual implementation which correctly applies abs() first.
        return _lp_pool_dilated(x, kernel_shape, strides, dilations, pads, ceil_mode, p)

    return builder.call_function(
        _lp_pool,
        args=(
            x,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            auto_pad,
            p,
        ),
    )


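# Illustrative sketch (hypothetical helper, not used by the converter): the ONNX-style
# Lp reduction (abs() applied before raising to the power p) that ``_lp_pool_dilated``
# computes per window, shown here over a whole tensor. With odd p and negative inputs
# the result stays real-valued.
def _onnx_lp_norm_example(x, p):
    """
    >>> round(float(_onnx_lp_norm_example(torch.tensor([-2.0, 0.0, 2.0]), 3)), 4)
    2.5198
    """
    return x.abs().pow(p).sum().pow(1.0 / p)

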
@register("GlobalLpPool")
def global_lp_pool(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Global Lp pooling.

    Computes the Lp norm over all spatial dimensions.
    """
    x = builder.get_value(node.input[0])
    p = get_attribute(node, "p", 2)

    def _global_lp_pool(x, p):
        # Lp norm over all spatial dimensions (keep batch and channel)
        dims = tuple(range(2, x.dim()))
        return (x.abs().pow(p).sum(dim=dims, keepdim=True)).pow(1.0 / p)

    return builder.call_function(_global_lp_pool, args=(x, p))
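

# Illustrative sketch (hypothetical helper, not used by the converter): for p >= 1,
# the reduction in ``_global_lp_pool`` should agree with torch.linalg.vector_norm
# taken over the spatial dimensions.
def _global_lp_pool_reference_example(x, p):
    """
    >>> t = torch.arange(8.0).reshape(1, 2, 2, 2)
    >>> torch.allclose(
    ...     _global_lp_pool_reference_example(t, 2),
    ...     t.abs().pow(2).sum(dim=(2, 3), keepdim=True).pow(0.5),
    ... )
    True
    """
    dims = tuple(range(2, x.dim()))
    return torch.linalg.vector_norm(x, ord=p, dim=dims, keepdim=True)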