onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2fx/ops/image.py ADDED
@@ -0,0 +1,748 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Image and spatial transformation operators.
+
+ This module implements ONNX operators for image resizing and
+ spatial dimension rearrangement.
+ """
+
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.op_helpers import get_optional_input
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ # =============================================================================
+ # Resize operator
+ # =============================================================================
+
+
+ @register("Resize")
+ def resize(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Resize tensor using interpolation."""
+     x = builder.get_value(node.input[0])
+
+     # roi, scales, sizes are optional inputs
+     roi = get_optional_input(builder, node, 1)
+     scales = get_optional_input(builder, node, 2)
+     sizes = get_optional_input(builder, node, 3)
+
+     mode = get_attribute(node, "mode", "nearest")
+     coordinate_transformation_mode = get_attribute(
+         node, "coordinate_transformation_mode", "half_pixel"
+     )
+
+     def _resize(x, roi, scales, sizes, mode, coord_mode):
+         import torch.nn.functional as F
+
+         # Map ONNX mode to the PyTorch mode matching the input rank:
+         # 3-D input -> linear, 5-D -> trilinear, otherwise bilinear
+         mode_map = {
+             "nearest": "nearest",
+             "linear": {3: "linear", 5: "trilinear"}.get(x.dim(), "bilinear"),
+             "cubic": "bicubic",
+         }
+         torch_mode = mode_map.get(mode, "nearest")
+
+         # Determine align_corners based on coordinate transformation mode
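+         # Note: only "align_corners" is mapped explicitly here; other ONNX
+         # values ("asymmetric", "pytorch_half_pixel", ...) fall through to
+         # F.interpolate's default half-pixel behavior, which may differ
+         # slightly from the ONNX specification for those modes.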
+         align_corners = coord_mode == "align_corners"
+         if torch_mode == "nearest":
+             align_corners = None
+
+         if sizes is not None:
+             # Use explicit sizes
+             size_list = sizes.tolist() if isinstance(sizes, torch.Tensor) else sizes
+             # Skip batch and channel dimensions
+             output_size = [int(s) for s in size_list[2:]]
+         elif scales is not None:
+             # Use scales
+             scale_list = scales.tolist() if isinstance(scales, torch.Tensor) else scales
+             input_shape = x.shape[2:]
+             output_size = [int(s * sc) for s, sc in zip(input_shape, scale_list[2:])]
+         else:
+             return x
+
+         kwargs = {"size": output_size, "mode": torch_mode}
+         if align_corners is not None and torch_mode not in ["nearest", "area"]:
+             kwargs["align_corners"] = align_corners
+
+         return F.interpolate(x, **kwargs)
+
+     return builder.call_function(
+         _resize, args=(x, roi, scales, sizes, mode, coordinate_transformation_mode)
+     )
+
+
+ # =============================================================================
+ # Upsample operator (deprecated in opset 10, replaced by Resize)
+ # =============================================================================
+
+
+ @register("Upsample", since_version=7)
+ def upsample(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Upsample tensor using interpolation.
+
+     Deprecated: This operator is deprecated since opset 10.
+     Use Resize operator instead.
+
+     Opset 7-8: scales is an attribute
+     Opset 9: scales is an input
+     """
+     x = builder.get_value(node.input[0])
+
+     # In opset 9, scales is an input; in opset 7-8, it's an attribute
+     opset = builder.opset_version
+     if opset >= 9 and len(node.input) > 1 and node.input[1]:
+         scales = builder.get_value(node.input[1])
+     else:
+         scales = get_attribute(node, "scales")
+
+     mode = get_attribute(node, "mode", "nearest")
+
+     def _upsample(x, scales, mode):
+         import torch.nn.functional as F
+
+         # Map ONNX mode to the PyTorch mode matching the input rank:
+         # 3-D input -> linear, 5-D -> trilinear, otherwise bilinear
+         mode_map = {
+             "nearest": "nearest",
+             "linear": {3: "linear", 5: "trilinear"}.get(x.dim(), "bilinear"),
+             "cubic": "bicubic",
+         }
+         torch_mode = mode_map.get(mode, "nearest")
+
+         # Use scales to compute output size
+         scale_list = scales.tolist() if isinstance(scales, torch.Tensor) else scales
+         input_shape = x.shape[2:]
+         output_size = [int(s * sc) for s, sc in zip(input_shape, scale_list[2:])]
+
+         kwargs = {"size": output_size, "mode": torch_mode}
+         # align_corners is not used for nearest mode
+         if torch_mode not in ["nearest", "area"]:
+             kwargs["align_corners"] = False
+
+         return F.interpolate(x, **kwargs)
+
+     return builder.call_function(_upsample, args=(x, scales, mode))
+
+
+ # =============================================================================
+ # Depth/Space rearrangement operators
+ # =============================================================================
+
+
+ @register("DepthToSpace")
+ def depth_to_space(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Rearrange depth to spatial dimensions."""
+     x = builder.get_value(node.input[0])
+     blocksize = get_attribute(node, "blocksize")
+     mode = get_attribute(node, "mode", "DCR")
+
+     def _depth_to_space(x, blocksize, mode):
+         b, c, h, w = x.shape
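+         # DCR splits the channel dim as (blocksize, blocksize, C_out);
+         # CRD splits it as (C_out, blocksize, blocksize). The permutes below
+         # interleave the block offsets into the spatial dims accordingly.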
+         if mode == "DCR":
+             # Depth-Column-Row
+             x = x.reshape(b, blocksize, blocksize, c // (blocksize**2), h, w)
+             x = x.permute(0, 3, 4, 1, 5, 2)
+             x = x.reshape(b, c // (blocksize**2), h * blocksize, w * blocksize)
+         else:
+             # CRD mode
+             x = x.reshape(b, c // (blocksize**2), blocksize, blocksize, h, w)
+             x = x.permute(0, 1, 4, 2, 5, 3)
+             x = x.reshape(b, c // (blocksize**2), h * blocksize, w * blocksize)
+         return x
+
+     return builder.call_function(_depth_to_space, args=(x, blocksize, mode))
+
+
+ @register("SpaceToDepth")
+ def space_to_depth(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Rearrange spatial dimensions to depth."""
+     x = builder.get_value(node.input[0])
+     blocksize = get_attribute(node, "blocksize")
+
+     def _space_to_depth(x, blocksize):
+         b, c, h, w = x.shape
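+         # Split each spatial dim into (dim // blocksize, blocksize) and move
+         # the two block offsets in front of the channel dim, matching the
+         # ONNX SpaceToDepth layout (output channels are block-major).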
+         x = x.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize)
+         x = x.permute(0, 3, 5, 1, 2, 4)
+         x = x.reshape(b, c * blocksize * blocksize, h // blocksize, w // blocksize)
+         return x
+
+     return builder.call_function(_space_to_depth, args=(x, blocksize))
+
+
+ # =============================================================================
+ # Col2Im operator
+ # =============================================================================
+
+
+ @register("Col2Im", since_version=18)
+ def col2im(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Rearrange column blocks back into a multidimensional image.
+
+     ONNX Col2Im is the inverse of Im2Col. It combines sliding local blocks
+     (columns) back into a larger image tensor.
+
+     Inputs:
+         input: [N, C * prod(block_shape), L] - batched column data
+         image_shape: spatial dimensions of the output image
+         block_shape: shape of the sliding block
+
+     Attributes:
+         strides: stride along each spatial axis (default: 1)
+         pads: padding for each spatial axis in ONNX format (default: 0)
+         dilations: dilation for each spatial axis (default: 1)
+
+     Output:
+         [N, C, *image_shape] - reconstructed image
+     """
+     x = builder.get_value(node.input[0])
+     image_shape = builder.get_value(node.input[1])
+     block_shape = builder.get_value(node.input[2])
+
+     strides = get_attribute(node, "strides")
+     pads = get_attribute(node, "pads")
+     dilations = get_attribute(node, "dilations")
+
+     def _col2im(x, image_shape, block_shape, strides, pads, dilations):
+         import torch.nn.functional as F
+         from functools import reduce
+         from itertools import product
+         from operator import mul
+
+         # Convert to lists if tensors
+         if isinstance(image_shape, torch.Tensor):
+             image_shape = image_shape.tolist()
+         if isinstance(block_shape, torch.Tensor):
+             block_shape = block_shape.tolist()
+
+         n_dims = len(block_shape)
+
+         # Default values
+         if strides is None:
+             strides = [1] * n_dims
+         if pads is None:
+             pads = [0] * (2 * n_dims)
+         if dilations is None:
+             dilations = [1] * n_dims
+
+         # For 2D, use PyTorch's optimized fold
+         if n_dims == 2:
+             # PyTorch fold uses symmetric padding per dimension
+             padding = (pads[0], pads[1])
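+             # Note: F.fold pads both sides of each spatial dim equally, so this
+             # path assumes symmetric ONNX pads (pads[i] == pads[n_dims + i]).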
+             return F.fold(
+                 x,
+                 output_size=tuple(image_shape),
+                 kernel_size=tuple(block_shape),
+                 stride=tuple(strides),
+                 padding=padding,
+                 dilation=tuple(dilations),
+             )
+
+         # For N-D, implement manually
+         N = x.shape[0]
+         L = x.shape[2]
+         block_size = reduce(mul, block_shape, 1)
+         C = x.shape[1] // block_size
+
+         # Reshape input: [N, C * prod(block_shape), L] -> [N, C, *block_shape, L]
+         input_reshaped = x.reshape(N, C, *block_shape, L)
+
+         # Initialize output: [N, C, *image_shape]
+         output = torch.zeros(N, C, *image_shape, dtype=x.dtype, device=x.device)
+
+         # Compute effective kernel size after dilation
+         effective_block = [(b - 1) * d + 1 for b, d in zip(block_shape, dilations)]
+
+         # Compute number of blocks in each dimension
+         n_blocks = []
+         for i, (img_dim, eff_block, s) in enumerate(
+             zip(image_shape, effective_block, strides)
+         ):
+             p_begin = pads[i]
+             p_end = pads[n_dims + i]
+             n_block = (img_dim + p_begin + p_end - eff_block) // s + 1
+             n_blocks.append(n_block)
+
+         # Iterate over all block positions
+         block_indices = list(product(*[range(nb) for nb in n_blocks]))
+
+         for l_idx, block_idx in enumerate(block_indices):
+             # Compute starting position for this block
+             starts = [
+                 bi * s - pads[i] for i, (bi, s) in enumerate(zip(block_idx, strides))
+             ]
+
+             # For each position in the block
+             block_positions = list(product(*[range(b) for b in block_shape]))
+             for block_pos in block_positions:
+                 # Compute actual output position with dilation
+                 output_pos = [
+                     starts[i] + block_pos[i] * dilations[i] for i in range(n_dims)
+                 ]
+
+                 # Check bounds
+                 valid = all(0 <= output_pos[i] < image_shape[i] for i in range(n_dims))
+                 if valid:
+                     # Get value from input_reshaped: [N, C, *block_shape, L]
+                     idx = (slice(None), slice(None)) + tuple(block_pos) + (l_idx,)
+                     value = input_reshaped[idx]
+
+                     # Add to output
+                     out_idx = (slice(None), slice(None)) + tuple(output_pos)
+                     output[out_idx] += value
+
+         return output
+
+     return builder.call_function(
+         _col2im, args=(x, image_shape, block_shape, strides, pads, dilations)
+     )
+
+
+ # =============================================================================
+ # GridSample operator
+ # =============================================================================
+
+
+ @register("GridSample", since_version=16)
+ def grid_sample(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Sample input using grid of sampling locations.
+
+     Given an input X and a flow-field grid, computes the output Y using X values
+     and pixel locations from the grid. This is equivalent to PyTorch's
+     torch.nn.functional.grid_sample.
+
+     Inputs:
+         X: Input tensor of shape (N, C, H, W) for 4D or (N, C, D, H, W) for 5D
+         grid: Grid tensor of shape (N, H_out, W_out, 2) for 4D or
+             (N, D_out, H_out, W_out, 3) for 5D
+
+     Attributes:
+         align_corners: If 1, extrema (-1 and 1) refer to center of corner pixels.
+             If 0, they refer to corner points. Default: 0
+         mode: Interpolation mode - 'linear'/'bilinear' (default), 'nearest',
+             'cubic'/'bicubic'. Opset 16 uses bilinear/bicubic, opset 20+ uses
+             linear/cubic.
+         padding_mode: Padding mode for outside grid values - 'zeros' (default),
+             'border', 'reflection'
+
+     Output:
+         Y: Output tensor of shape (N, C, H_out, W_out) or (N, C, D_out, H_out, W_out)
+     """
+     x = builder.get_value(node.input[0])
+     grid = builder.get_value(node.input[1])
+
+     align_corners = get_attribute(node, "align_corners", 0)
+     # Handle different mode names across opset versions
+     # Opset 16: bilinear (default), nearest, bicubic
+     # Opset 20+: linear (default), nearest, cubic
+     mode = get_attribute(node, "mode", "linear")
+     padding_mode = get_attribute(node, "padding_mode", "zeros")
+
+     def _grid_sample(x, grid, mode, padding_mode, align_corners):
+         import torch.nn.functional as F
+
+         # Map ONNX mode names to PyTorch mode names
+         # PyTorch expects: 'bilinear', 'nearest', 'bicubic' for 4D input
+         # PyTorch expects: 'bilinear', 'nearest' for 5D input (no bicubic)
+         mode_map = {
+             "linear": "bilinear",
+             "bilinear": "bilinear",
+             "nearest": "nearest",
+             "cubic": "bicubic",
+             "bicubic": "bicubic",
+         }
+         torch_mode = mode_map.get(mode, "bilinear")
+
+         # Convert align_corners from int to bool
+         align_corners_bool = bool(align_corners)
+
+         return F.grid_sample(
+             x,
+             grid,
+             mode=torch_mode,
+             padding_mode=padding_mode,
+             align_corners=align_corners_bool,
+         )
+
+     return builder.call_function(
+         _grid_sample, args=(x, grid, mode, padding_mode, align_corners)
+     )
+
+
+ # =============================================================================
+ # AffineGrid operator
+ # =============================================================================
+
+
+ @register("AffineGrid", since_version=20)
+ def affine_grid(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Generate 2D or 3D flow field (sampling grid) from affine matrices.
+
+     Given a batch of affine matrices theta, generates a grid of sampling
+     locations. This is typically used with GridSample to build Spatial
+     Transformer Networks.
+
+     Inputs:
+         theta: Input batch of affine matrices with shape (N, 2, 3) for 2D
+             or (N, 3, 4) for 3D
+         size: Target output image size (N, C, H, W) for 2D or (N, C, D, H, W)
+             for 3D, as a 1-D tensor
+
+     Attributes:
+         align_corners: If 1, consider -1 and 1 to refer to the centers of the
+             corner pixels. If 0, consider -1 and 1 to refer to the
+             outer edge of corner pixels. Default: 0
+
+     Output:
+         grid: Output tensor of shape (N, H, W, 2) for 2D sample coordinates
+             or (N, D, H, W, 3) for 3D sample coordinates
+     """
+     theta = builder.get_value(node.input[0])
+     size = builder.get_value(node.input[1])
+
+     align_corners = get_attribute(node, "align_corners", 0)
+
+     def _affine_grid(theta, size, align_corners):
+         import torch.nn.functional as F
+
+         # Convert size tensor to a list of integers for torch.Size
+         if isinstance(size, torch.Tensor):
+             size_list = size.tolist()
+         else:
+             size_list = list(size)
+         size_tuple = torch.Size([int(s) for s in size_list])
+
+         # Convert align_corners from int to bool
+         align_corners_bool = bool(align_corners)
+
+         return F.affine_grid(theta, size_tuple, align_corners=align_corners_bool)
+
+     return builder.call_function(_affine_grid, args=(theta, size, align_corners))
+
+
+ # =============================================================================
+ # CenterCropPad operator
+ # =============================================================================
+
+
+ @register("CenterCropPad", since_version=18)
+ def center_crop_pad(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Center crop or pad an input to given dimensions.
+
+     The crop/pad dimensions can be specified for a subset of the axes.
+     Unspecified dimensions will remain unchanged.
+
+     For cropping (input > target): centered window, start position rounded down.
+     For padding (input < target): centered padding, extra pixel on right if odd.
+
+     Inputs:
+         input_data: Input tensor to crop/pad
+         shape: 1-D tensor of target dimensions for specified axes
+
+     Attributes:
+         axes: Subset of axes that shape refers to (default: all axes)
+
+     Output:
+         Output tensor with specified dimensions
+     """
+     x = builder.get_value(node.input[0])
+     shape = builder.get_value(node.input[1])
+
+     axes = get_attribute(node, "axes")
+
+     def _center_crop_pad(x, shape, axes):
+         # Convert shape to list if tensor
+         if isinstance(shape, torch.Tensor):
+             target_shape = shape.tolist()
+         else:
+             target_shape = list(shape)
+
+         ndim = x.dim()
+
+         # If axes is not provided, use all axes
+         if axes is None:
+             axes_list = list(range(ndim))
+         else:
+             axes_list = list(axes)
+
+         # Normalize negative axes
+         axes_list = [(a + ndim) if a < 0 else a for a in axes_list]
+
+         # Build slices for cropping and padding amounts
+         result = x
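+         # Axes are handled one at a time, so result.shape reflects any
+         # crops/pads already applied to earlier axes in the list.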
+         for i, axis in enumerate(axes_list):
+             current_size = result.shape[axis]
+             target_size = int(target_shape[i])
+
+             if current_size == target_size:
+                 # No change needed for this axis
+                 continue
+             elif current_size > target_size:
+                 # Crop: extract centered window
+                 # Start position is rounded down (floor division)
+                 diff = current_size - target_size
+                 start = diff // 2
+                 end = start + target_size
+
+                 # Build slice for this axis
+                 slices = [slice(None)] * result.dim()
+                 slices[axis] = slice(start, end)
+                 result = result[tuple(slices)]
+             else:
+                 # Pad: add zeros centered
+                 # Extra pixel goes to the right side
+                 diff = target_size - current_size
+                 pad_before = diff // 2
+                 pad_after = diff - pad_before
+
+                 # torch.nn.functional.pad uses reverse order: last dim first
+                 # and pairs are (before, after) for each dim from last to first
+                 # We need to construct padding for just this one axis
+
+                 # Number of dimensions from the end
+                 dims_from_end = result.dim() - 1 - axis
+
+                 # Build pad tuple: pairs for each dim from last to first
+                 # We only pad the current axis
+                 pad = [0] * (2 * result.dim())
+                 # Index in pad list: dims_from_end * 2 for before, +1 for after
+                 pad[dims_from_end * 2] = pad_before
+                 pad[dims_from_end * 2 + 1] = pad_after
+
+                 result = torch.nn.functional.pad(result, pad, mode="constant", value=0)
+
+         return result
+
+     return builder.call_function(_center_crop_pad, args=(x, shape, axes))
+
+
+ # =============================================================================
+ # RoiAlign operator
+ # =============================================================================
+
+
+ @register("RoiAlign")
+ def roi_align(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Region of Interest (RoI) Align operator.
+
+     Performs RoI pooling with bilinear interpolation for sub-pixel accuracy.
+     """
+     x = builder.get_value(node.input[0])
+     rois = builder.get_value(node.input[1])
+     batch_indices = builder.get_value(node.input[2])
+
+     # Get attributes with defaults
+     mode = get_attribute(node, "mode", "avg")
+     output_height = get_attribute(node, "output_height", 1)
+     output_width = get_attribute(node, "output_width", 1)
+     sampling_ratio = get_attribute(node, "sampling_ratio", 0)
+     spatial_scale = get_attribute(node, "spatial_scale", 1.0)
+     coordinate_transformation_mode = get_attribute(
+         node, "coordinate_transformation_mode", "half_pixel"
+     )
+
+     # ONNX coordinate_transformation_mode:
+     # - "half_pixel": pixel shift by -0.5, corresponds to aligned=True in PyTorch
+     # - "output_half_pixel": no pixel shift (legacy), corresponds to aligned=False
+     aligned = coordinate_transformation_mode == "half_pixel"
+
+     def _roi_align(
+         x,
+         rois,
+         batch_indices,
+         mode,
+         output_height,
+         output_width,
+         sampling_ratio,
+         spatial_scale,
+         aligned,
+     ):
+         from torchvision.ops import roi_align as tv_roi_align
+
+         # PyTorch expects boxes in format [batch_idx, x1, y1, x2, y2]
+         boxes = torch.cat([batch_indices.unsqueeze(1).float(), rois.float()], dim=1)
+         output_size = (output_height, output_width)
+
+         if mode == "avg":
+             # Use torchvision's roi_align directly for average mode
+             # sampling_ratio: ONNX uses 0 for adaptive, PyTorch uses -1
+             torch_sampling = sampling_ratio if sampling_ratio > 0 else -1
+             return tv_roi_align(
+                 x,
+                 boxes,
+                 output_size,
+                 spatial_scale=spatial_scale,
+                 sampling_ratio=torch_sampling,
+                 aligned=aligned,
+             )
+         else:
+             # Max mode: ONNX defines max pooling differently from standard
+             # bilinear interpolation. For each sample point, it takes the
+             # max of the 4 weighted corner values (not the sum).
+             return _roi_align_max_mode(
+                 x,
+                 rois,
+                 batch_indices,
+                 output_height,
+                 output_width,
+                 sampling_ratio,
+                 spatial_scale,
+                 aligned,
+             )
+
+     return builder.call_function(
+         _roi_align,
+         args=(
+             x,
+             rois,
+             batch_indices,
+             mode,
+             output_height,
+             output_width,
+             sampling_ratio,
+             spatial_scale,
+             aligned,
+         ),
+     )
+
+
+ def _roi_align_max_mode(
+     x,
+     rois,
+     batch_indices,
+     output_height,
+     output_width,
+     sampling_ratio,
+     spatial_scale,
+     half_pixel,
+ ):
+     """ONNX RoiAlign with max pooling mode.
+
+     For each output bin, samples at grid points and takes the MAX of the
+     weighted corner values at each sampling point, then MAX across all
+     sampling points in the bin.
+     """
+     num_rois = rois.shape[0]
+     channels = x.shape[1]
+     height = x.shape[2]
+     width = x.shape[3]
+
+     output = torch.zeros(
+         num_rois, channels, output_height, output_width, dtype=x.dtype, device=x.device
+     )
+
+     for n in range(num_rois):
+         roi_batch_ind = int(batch_indices[n].item())
+         roi = rois[n]
+
+         # Apply spatial scale and offset
+         offset = 0.5 if half_pixel else 0.0
+         roi_start_w = float(roi[0]) * spatial_scale - offset
+         roi_start_h = float(roi[1]) * spatial_scale - offset
+         roi_end_w = float(roi[2]) * spatial_scale - offset
+         roi_end_h = float(roi[3]) * spatial_scale - offset
+
+         roi_width = roi_end_w - roi_start_w
+         roi_height = roi_end_h - roi_start_h
+
+         if not half_pixel:
+             # Force malformed ROIs to be 1x1
+             roi_width = max(roi_width, 1.0)
+             roi_height = max(roi_height, 1.0)
+
+         bin_size_h = roi_height / output_height
+         bin_size_w = roi_width / output_width
+
+         # Determine sampling grid size
+         if sampling_ratio > 0:
+             roi_bin_grid_h = sampling_ratio
+             roi_bin_grid_w = sampling_ratio
+         else:
+             roi_bin_grid_h = int(
+                 torch.ceil(torch.tensor(roi_height / output_height)).item()
+             )
+             roi_bin_grid_w = int(
+                 torch.ceil(torch.tensor(roi_width / output_width)).item()
+             )
+         roi_bin_grid_h = max(1, roi_bin_grid_h)
+         roi_bin_grid_w = max(1, roi_bin_grid_w)
+
+         for c in range(channels):
+             for ph in range(output_height):
+                 for pw in range(output_width):
+                     output_val = None
+
+                     for iy in range(roi_bin_grid_h):
+                         yy = (
+                             roi_start_h
+                             + ph * bin_size_h
+                             + (iy + 0.5) * bin_size_h / roi_bin_grid_h
+                         )
+                         for ix in range(roi_bin_grid_w):
+                             xx = (
+                                 roi_start_w
+                                 + pw * bin_size_w
+                                 + (ix + 0.5) * bin_size_w / roi_bin_grid_w
+                             )
+
+                             # Check bounds
+                             if yy < -1.0 or yy > height or xx < -1.0 or xx > width:
+                                 continue
+
+                             y = max(yy, 0.0)
+                             xc = max(xx, 0.0)
+
+                             y_low = int(y)
+                             x_low = int(xc)
+
+                             if y_low >= height - 1:
+                                 y_high = y_low = height - 1
+                                 y = float(y_low)
+                             else:
+                                 y_high = y_low + 1
+
+                             if x_low >= width - 1:
+                                 x_high = x_low = width - 1
+                                 xc = float(x_low)
+                             else:
+                                 x_high = x_low + 1
+
+                             ly = y - y_low
+                             lx = xc - x_low
+                             hy = 1.0 - ly
+                             hx = 1.0 - lx
+
+                             # Weights
+                             w1 = hy * hx
+                             w2 = hy * lx
+                             w3 = ly * hx
+                             w4 = ly * lx
+
+                             # Get corner values
+                             v1 = x[roi_batch_ind, c, y_low, x_low].item()
+                             v2 = x[roi_batch_ind, c, y_low, x_high].item()
+                             v3 = x[roi_batch_ind, c, y_high, x_low].item()
+                             v4 = x[roi_batch_ind, c, y_high, x_high].item()
+
+                             # ONNX max mode: max of weighted corners
+                             val = max(w1 * v1, w2 * v2, w3 * v3, w4 * v4)
+
+                             if output_val is None:
+                                 output_val = val
+                             else:
+                                 output_val = max(output_val, val)
+
+                     if output_val is not None:
+                         output[n, c, ph, pw] = output_val
+
+     return output