onnx2fx 0.0.0__py3-none-any.whl

@@ -0,0 +1,534 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Reduction operators."""
+
+ import operator
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.op_helpers import get_attribute_or_input
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
+ def _get_reduction_axes(
+     node: onnx.NodeProto, builder: "GraphBuilder"
+ ) -> list[int] | torch.fx.Node | None:
+     """Get axes for reduction, handling both attribute and input formats.
+
+     In opset < 13, axes is an attribute.
+     In opset 13-17, axes can be an attribute or an optional input.
+     In opset 18+, axes is an optional input only.
+     """
+     return get_attribute_or_input(
+         builder,
+         node,
+         attr_name="axes",
+         input_index=1,
+         opset_version=builder.opset_version,
+         attr_allowed_until=17,
+         input_allowed_since=13,
+         default=None,
+     )
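+
+ # Dispatch sketch (hypothetical node, assuming get_attribute_or_input
+ # resolves as its name suggests): a ReduceSum node at opset 18 with inputs
+ # ("data", "axes") yields the traced "axes" value; an opset-11 node with an
+ # axes=[0, 1] attribute yields [0, 1]; a node with neither yields None.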
+
+
+ @register("ReduceSum")
+ def reduce_sum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Sum reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+     noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
+
+     def _reduce_sum(t, axes, keepdims, noop_with_empty_axes):
+         # Handle empty axes list
+         if isinstance(axes, (list, tuple)) and len(axes) == 0:
+             if noop_with_empty_axes:
+                 return t
+             # Empty axes with noop=False means reduce all dimensions
+             axes = None
+         if isinstance(axes, torch.Tensor) and axes.numel() == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+         if axes is None:
+             result = torch.sum(t)
+             if keepdims:
+                 # Reshape to have all dimensions as 1
+                 result = result.reshape([1] * t.ndim)
+             return result
+         if isinstance(axes, torch.Tensor):
+             axes = tuple(axes.tolist())
+         elif isinstance(axes, (list, tuple)):
+             axes = tuple(axes)
+         return torch.sum(t, dim=axes, keepdim=keepdims)
+
+     return builder.call_function(
+         _reduce_sum, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
+     )
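+
+ # Semantics sketch for the traced closure (illustrative values):
+ #   t = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
+ #   torch.sum(t, dim=(1,), keepdim=True)  # -> [[6.], [15.]]
+ #   torch.sum(t)                          # axes=None: full reduction -> 21.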
+
+
+ @register("ReduceMean")
+ def reduce_mean(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Mean reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+     noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
+
+     def _reduce_mean(t, axes, keepdims, noop_with_empty_axes):
+         # Handle empty axes list
+         if isinstance(axes, (list, tuple)) and len(axes) == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+         if isinstance(axes, torch.Tensor) and axes.numel() == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+         if axes is None:
+             result = torch.mean(t)
+             if keepdims:
+                 result = result.reshape([1] * t.ndim)
+             return result
+         if isinstance(axes, torch.Tensor):
+             axes = tuple(axes.tolist())
+         elif isinstance(axes, (list, tuple)):
+             axes = tuple(axes)
+         return torch.mean(t, dim=axes, keepdim=keepdims)
+
+     return builder.call_function(
+         _reduce_mean, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
+     )
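+
+ # Caveat: torch.mean requires a floating-point or complex input, while ONNX
+ # ReduceMean also admits integer tensors; such models would need an explicit
+ # cast (e.g. t.to(torch.float32)) that this handler does not attempt.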
+
+
+ @register("ReduceMax")
+ def reduce_max(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Max reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+     noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
+
+     def _reduce_max(t, axes, keepdims, noop_with_empty_axes):
+         # Handle empty axes list
+         if isinstance(axes, (list, tuple)) and len(axes) == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+         if isinstance(axes, torch.Tensor) and axes.numel() == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+
+         if axes is None:
+             # Reduce over all dimensions
+             if t.numel() == 0:
+                 # Empty tensor: return -inf with proper shape
+                 if keepdims:
+                     return torch.full(
+                         [1] * t.ndim, float("-inf"), dtype=t.dtype, device=t.device
+                     )
+                 return torch.tensor(float("-inf"), dtype=t.dtype, device=t.device)
+             result = t.max()
+             if keepdims:
+                 result = result.reshape([1] * t.ndim)
+             return result
+
+         if isinstance(axes, torch.Tensor):
+             axes = axes.tolist()
+         if isinstance(axes, list) and len(axes) == 1:
+             axes = axes[0]
+         if isinstance(axes, int):
+             # Check for empty dimension
+             if t.shape[axes] == 0:
+                 new_shape = list(t.shape)
+                 if keepdims:
+                     new_shape[axes] = 1
+                 else:
+                     new_shape.pop(axes)
+                 return torch.full(
+                     new_shape, float("-inf"), dtype=t.dtype, device=t.device
+                 )
+             return t.max(dim=axes, keepdim=keepdims).values
+         # Multiple axes: reduce sequentially; normalize negative axes first so
+         # the descending sort keeps the remaining indices valid as dims drop
+         axes = [a % t.ndim if a < 0 else a for a in axes]
+         result = t
+         for axis in sorted(axes, reverse=True):
+             if result.shape[axis] == 0:
+                 new_shape = list(result.shape)
+                 if keepdims:
+                     new_shape[axis] = 1
+                 else:
+                     new_shape.pop(axis)
+                 result = torch.full(
+                     new_shape, float("-inf"), dtype=result.dtype, device=result.device
+                 )
+             else:
+                 result = result.max(dim=axis, keepdim=keepdims).values
+         return result
+
+     return builder.call_function(
+         _reduce_max, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
+     )
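+
+ # Note: torch.max(dim=...) accepts a single dim only, hence the sequential
+ # loop; torch.amax accepts a tuple of dims and could serve the non-empty
+ # case, e.g. torch.amax(torch.ones(2, 3), dim=(0, 1)) -> tensor(1.).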
+
+
+ @register("ReduceMin")
+ def reduce_min(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Min reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+     noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
+
+     def _reduce_min(t, axes, keepdims, noop_with_empty_axes):
+         # Handle empty axes list
+         if isinstance(axes, (list, tuple)) and len(axes) == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+         if isinstance(axes, torch.Tensor) and axes.numel() == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+
+         if axes is None:
+             # Reduce over all dimensions
+             if t.numel() == 0:
+                 # Empty tensor: return inf with proper shape
+                 if keepdims:
+                     return torch.full(
+                         [1] * t.ndim, float("inf"), dtype=t.dtype, device=t.device
+                     )
+                 return torch.tensor(float("inf"), dtype=t.dtype, device=t.device)
+             result = t.min()
+             if keepdims:
+                 result = result.reshape([1] * t.ndim)
+             return result
+
+         if isinstance(axes, torch.Tensor):
+             axes = axes.tolist()
+         if isinstance(axes, list) and len(axes) == 1:
+             axes = axes[0]
+         if isinstance(axes, int):
+             # Check for empty dimension
+             if t.shape[axes] == 0:
+                 new_shape = list(t.shape)
+                 if keepdims:
+                     new_shape[axes] = 1
+                 else:
+                     new_shape.pop(axes)
+                 return torch.full(
+                     new_shape, float("inf"), dtype=t.dtype, device=t.device
+                 )
+             return t.min(dim=axes, keepdim=keepdims).values
+         # Multiple axes: reduce sequentially; normalize negative axes first so
+         # the descending sort keeps the remaining indices valid as dims drop
+         axes = [a % t.ndim if a < 0 else a for a in axes]
+         result = t
+         for axis in sorted(axes, reverse=True):
+             if result.shape[axis] == 0:
+                 new_shape = list(result.shape)
+                 if keepdims:
+                     new_shape[axis] = 1
+                 else:
+                     new_shape.pop(axis)
+                 result = torch.full(
+                     new_shape, float("inf"), dtype=result.dtype, device=result.device
+                 )
+             else:
+                 result = result.min(dim=axis, keepdim=keepdims).values
+         return result
+
+     return builder.call_function(
+         _reduce_min, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
+     )
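+
+ # Caveat: the float("inf") / float("-inf") identity fills above assume a
+ # floating-point dtype; torch.full rejects infinities for integer dtypes, so
+ # empty integer reductions would need the dtype extremum instead, e.g.
+ # torch.iinfo(t.dtype).max for ReduceMin.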
+
+
+ @register("ReduceProd")
+ def reduce_prod(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Product reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+     noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
+
+     def _reduce_prod(t, axes, keepdims, noop_with_empty_axes):
+         # Handle empty axes list
+         if isinstance(axes, (list, tuple)) and len(axes) == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+         if isinstance(axes, torch.Tensor) and axes.numel() == 0:
+             if noop_with_empty_axes:
+                 return t
+             axes = None
+
+         if axes is None:
+             result = torch.prod(t)
+             if keepdims:
+                 result = result.reshape([1] * t.ndim)
+             return result
+
+         if isinstance(axes, torch.Tensor):
+             axes = axes.tolist()
+         if isinstance(axes, list) and len(axes) == 1:
+             axes = axes[0]
+         if isinstance(axes, int):
+             return torch.prod(t, dim=axes, keepdim=keepdims)
+         # Multiple axes: torch.prod takes a single dim, so reduce sequentially;
+         # normalize negative axes so the descending sort stays valid
+         axes = [a % t.ndim if a < 0 else a for a in axes]
+         result = t
+         for axis in sorted(axes, reverse=True):
+             result = torch.prod(result, dim=axis, keepdim=keepdims)
+         return result
+
+     return builder.call_function(
+         _reduce_prod, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
+     )
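+
+ # Equivalence sketch: for t = torch.arange(1., 5.).reshape(2, 2), reducing
+ # over axes [0, 1] with the closure above gives tensor(24.), matching
+ # torch.prod(t).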
+
+
+ @register("ReduceL1")
+ def reduce_l1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """L1 norm reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+
+     def _reduce_l1(t, axes, keepdims):
+         abs_t = torch.abs(t)
+         if axes is None:
+             result = torch.sum(abs_t)
+             if keepdims:
+                 # ONNX keepdims=1 keeps a rank-preserving all-ones shape
+                 result = result.reshape([1] * t.ndim)
+             return result
+         if isinstance(axes, torch.Tensor):
+             axes = tuple(axes.tolist())
+         elif isinstance(axes, (list, tuple)):
+             axes = tuple(axes)
+         return torch.sum(abs_t, dim=axes, keepdim=keepdims)
+
+     return builder.call_function(_reduce_l1, args=(x, axes, bool(keepdims)))
308
+
309
+
310
+ @register("ReduceL2")
311
+ def reduce_l2(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
312
+ """L2 norm reduction."""
313
+ x = builder.get_value(node.input[0])
314
+ axes = _get_reduction_axes(node, builder)
315
+ keepdims = get_attribute(node, "keepdims", 1)
316
+
317
+ def _reduce_l2(t, axes, keepdims):
318
+ if axes is None:
319
+ return torch.norm(t)
320
+ if isinstance(axes, torch.Tensor):
321
+ axes = tuple(axes.tolist())
322
+ elif isinstance(axes, (list, tuple)):
323
+ axes = tuple(axes)
324
+ return torch.norm(t, dim=axes, keepdim=keepdims)
325
+
326
+ return builder.call_function(_reduce_l2, args=(x, axes, bool(keepdims)))
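+
+ # torch.norm is soft-deprecated; torch.linalg.vector_norm would be an
+ # equivalent modern spelling here, e.g.
+ #   torch.linalg.vector_norm(t, ord=2, dim=axes, keepdim=keepdims)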
+
+
+ @register("ReduceLogSum")
+ def reduce_log_sum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Log of sum reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+
+     def _reduce_log_sum(t, axes, keepdims):
+         if axes is None:
+             result = torch.log(torch.sum(t))
+             if keepdims:
+                 result = result.reshape([1] * t.ndim)
+             return result
+         if isinstance(axes, torch.Tensor):
+             axes = tuple(axes.tolist())
+         elif isinstance(axes, (list, tuple)):
+             axes = tuple(axes)
+         return torch.log(torch.sum(t, dim=axes, keepdim=keepdims))
+
+     return builder.call_function(_reduce_log_sum, args=(x, axes, bool(keepdims)))
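+
+ # Caveat: if the sum is zero or negative, torch.log yields -inf or nan; no
+ # clamping is applied here.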
+
+
+ @register("ReduceLogSumExp")
+ def reduce_log_sum_exp(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """LogSumExp reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+
+     def _reduce_log_sum_exp(t, axes, keepdims):
+         if axes is None:
+             return torch.logsumexp(t, dim=tuple(range(t.dim())), keepdim=keepdims)
+         if isinstance(axes, torch.Tensor):
+             axes = tuple(axes.tolist())
+         elif isinstance(axes, (list, tuple)):
+             axes = tuple(axes)
+         return torch.logsumexp(t, dim=axes, keepdim=keepdims)
+
+     return builder.call_function(_reduce_log_sum_exp, args=(x, axes, bool(keepdims)))
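+
+ # torch.logsumexp applies the max-subtraction trick, so e.g.
+ # torch.logsumexp(torch.tensor([1000., 1000.]), dim=0) -> 1000.6931...,
+ # where a naive log(sum(exp(t))) would overflow to inf.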
+
+
+ @register("ReduceSumSquare")
+ def reduce_sum_square(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Sum of squares reduction."""
+     x = builder.get_value(node.input[0])
+     axes = _get_reduction_axes(node, builder)
+     keepdims = get_attribute(node, "keepdims", 1)
+
+     def _reduce_sum_square(t, axes, keepdims):
+         sq = torch.square(t)
+         if axes is None:
+             result = torch.sum(sq)
+             if keepdims:
+                 result = result.reshape([1] * t.ndim)
+             return result
+         if isinstance(axes, torch.Tensor):
+             axes = tuple(axes.tolist())
+         elif isinstance(axes, (list, tuple)):
+             axes = tuple(axes)
+         return torch.sum(sq, dim=axes, keepdim=keepdims)
+
+     return builder.call_function(_reduce_sum_square, args=(x, axes, bool(keepdims)))
+
+
+ # =============================================================================
+ # ArgMax/ArgMin operators
+ # =============================================================================
+
+
+ def _make_arg_extremum_handler(torch_fn):
+     """Factory for ArgMax/ArgMin operator handlers."""
+
+     def _arg_extremum(t, axis, keepdims, select_last_index):
+         if select_last_index:
+             flipped = torch.flip(t, [axis])
+             idx = torch_fn(flipped, dim=axis, keepdim=keepdims)
+             return t.size(axis) - 1 - idx
+         return torch_fn(t, dim=axis, keepdim=keepdims)
+
+     def handler(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+         x = builder.get_value(node.input[0])
+         axis = get_attribute(node, "axis", 0)
+         keepdims = get_attribute(node, "keepdims", 1)
+         select_last_index = get_attribute(node, "select_last_index", 0)
+         return builder.call_function(
+             _arg_extremum, args=(x, axis, bool(keepdims), bool(select_last_index))
+         )
+
+     return handler
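+
+ # select_last_index sketch for t = torch.tensor([1, 3, 3]):
+ #   torch.argmax(t)                      -> 1 (first occurrence)
+ #   flip -> [3, 3, 1], argmax -> 0, then 3 - 1 - 0 -> 2 (last occurrence)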
+
+
+ @register("ArgMax")
+ def argmax(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Index of maximum value."""
+     return _make_arg_extremum_handler(torch.argmax)(builder, node)
+
+
+ @register("ArgMin")
+ def argmin(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Index of minimum value."""
+     return _make_arg_extremum_handler(torch.argmin)(builder, node)
+
+
+ # =============================================================================
+ # Cumulative and TopK operators
+ # =============================================================================
+
+
+ @register("CumSum")
+ def cumsum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Cumulative sum."""
+     x = builder.get_value(node.input[0])
+     axis = builder.get_value(node.input[1])
+
+     exclusive = get_attribute(node, "exclusive", 0)
+     reverse = get_attribute(node, "reverse", 0)
+
+     def _cumsum(x, axis, exclusive, reverse):
+         ax = int(axis.item()) if isinstance(axis, torch.Tensor) else int(axis)
+
+         if reverse:
+             x = torch.flip(x, [ax])
+
+         result = torch.cumsum(x, dim=ax)
+
+         if exclusive:
+             # Shift by one along ax and pad the front with a zero slice
+             pad_shape = list(x.shape)
+             pad_shape[ax] = 1
+             zero_pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
+             result = torch.cat(
+                 [zero_pad, result.narrow(ax, 0, x.shape[ax] - 1)], dim=ax
+             )
+
+         if reverse:
+             result = torch.flip(result, [ax])
+
+         return result
+
+     return builder.call_function(_cumsum, args=(x, axis, exclusive, reverse))
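+
+ # Semantics sketch for x = [1, 2, 3] (axis 0):
+ #   exclusive=0, reverse=0 -> [1, 3, 6]
+ #   exclusive=1, reverse=0 -> [0, 1, 3]
+ #   exclusive=0, reverse=1 -> [6, 5, 3]
+ #   exclusive=1, reverse=1 -> [5, 3, 0]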
+
+
+ @register("TopK")
+ def topk(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Find top K values and indices."""
+     x = builder.get_value(node.input[0])
+     k = builder.get_value(node.input[1])
+
+     axis = get_attribute(node, "axis", -1)
+     largest = get_attribute(node, "largest", 1)
+     sorted_ = get_attribute(node, "sorted", 1)
+
+     def _topk(x, k, axis, largest, sorted_):
+         k_val = int(k.item() if isinstance(k, torch.Tensor) else k)
+
+         # Handle unsupported dtypes (e.g., uint64) by converting to int64
+         original_dtype = x.dtype
+         needs_conversion = original_dtype == torch.uint64
+         if needs_conversion:
+             x = x.to(torch.int64)
+
+         # ONNX TopK requires stable ordering: for equal values, the element
+         # with the lower index appears first. torch.topk does not guarantee
+         # this, so use a stable argsort instead. Negating the values turns
+         # largest-first into an ascending sort.
+         sort_values = -x if bool(largest) else x
+         sorted_indices = torch.argsort(sort_values, dim=axis, stable=True)
+
+         # Keep the first k entries along the axis
+         indices = torch.narrow(sorted_indices, axis, 0, k_val)
+
+         # Gather the corresponding values
+         values = torch.gather(x, axis, indices)
+
+         # sorted=False leaves the output order unspecified; the stable sorted
+         # order used here is an acceptable instance of it.
+
+         # Convert values back to the original dtype if needed
+         if needs_conversion:
+             values = values.to(original_dtype)
+
+         return values, indices
+
+     result = builder.call_function(_topk, args=(x, k, axis, largest, sorted_))
+
+     # Bind each ONNX output name to an element of the returned tuple
+     for i, output_name in enumerate(node.output):
+         if output_name:
+             idx_node = builder.call_function(operator.getitem, args=(result, i))
+             builder.env[output_name] = idx_node
+
+     return result
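+
+ # Stability sketch for x = [3, 1, 3], k=2, largest=1: torch.topk may return
+ # the tied maxima in either index order, while the stable argsort above
+ # guarantees indices (0, 2), matching ONNX's lower-index-first tie-breaking.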