onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/reduction.py
ADDED
@@ -0,0 +1,534 @@
# SPDX-License-Identifier: Apache-2.0
"""Reduction operators."""

from typing import TYPE_CHECKING

import onnx
import torch

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import get_attribute_or_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


def _get_reduction_axes(
    node: onnx.NodeProto, builder: "GraphBuilder"
) -> list[int] | torch.fx.Node | None:
    """Get axes for reduction, handling both attribute and input formats.

    In opset < 13, axes is an attribute.
    In opset 13-17, axes can be an attribute or an optional input.
    In opset 18+, axes is an optional input only.
    """
    return get_attribute_or_input(
        builder,
        node,
        attr_name="axes",
        input_index=1,
        opset_version=builder.opset_version,
        attr_allowed_until=17,
        input_allowed_since=13,
        default=None,
    )


@register("ReduceSum")
|
|
39
|
+
def reduce_sum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
40
|
+
"""Sum reduction."""
|
|
41
|
+
x = builder.get_value(node.input[0])
|
|
42
|
+
axes = _get_reduction_axes(node, builder)
|
|
43
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
44
|
+
noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
|
|
45
|
+
|
|
46
|
+
def _reduce_sum(t, axes, keepdims, noop_with_empty_axes):
|
|
47
|
+
# Handle empty axes list
|
|
48
|
+
if isinstance(axes, (list, tuple)) and len(axes) == 0:
|
|
49
|
+
if noop_with_empty_axes:
|
|
50
|
+
return t
|
|
51
|
+
# Empty axes with noop=False means reduce all dimensions
|
|
52
|
+
axes = None
|
|
53
|
+
if isinstance(axes, torch.Tensor) and axes.numel() == 0:
|
|
54
|
+
if noop_with_empty_axes:
|
|
55
|
+
return t
|
|
56
|
+
axes = None
|
|
57
|
+
if axes is None:
|
|
58
|
+
result = torch.sum(t)
|
|
59
|
+
if keepdims:
|
|
60
|
+
# Reshape to have all dimensions as 1
|
|
61
|
+
result = result.reshape([1] * t.ndim)
|
|
62
|
+
return result
|
|
63
|
+
if isinstance(axes, torch.Tensor):
|
|
64
|
+
axes = tuple(axes.tolist())
|
|
65
|
+
elif isinstance(axes, (list, tuple)):
|
|
66
|
+
axes = tuple(axes)
|
|
67
|
+
return torch.sum(t, dim=axes, keepdim=keepdims)
|
|
68
|
+
|
|
69
|
+
return builder.call_function(
|
|
70
|
+
_reduce_sum, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@register("ReduceMean")
|
|
75
|
+
def reduce_mean(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
76
|
+
"""Mean reduction."""
|
|
77
|
+
x = builder.get_value(node.input[0])
|
|
78
|
+
axes = _get_reduction_axes(node, builder)
|
|
79
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
80
|
+
noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
|
|
81
|
+
|
|
82
|
+
def _reduce_mean(t, axes, keepdims, noop_with_empty_axes):
|
|
83
|
+
# Handle empty axes list
|
|
84
|
+
if isinstance(axes, (list, tuple)) and len(axes) == 0:
|
|
85
|
+
if noop_with_empty_axes:
|
|
86
|
+
return t
|
|
87
|
+
axes = None
|
|
88
|
+
if isinstance(axes, torch.Tensor) and axes.numel() == 0:
|
|
89
|
+
if noop_with_empty_axes:
|
|
90
|
+
return t
|
|
91
|
+
axes = None
|
|
92
|
+
if axes is None:
|
|
93
|
+
result = torch.mean(t)
|
|
94
|
+
if keepdims:
|
|
95
|
+
result = result.reshape([1] * t.ndim)
|
|
96
|
+
return result
|
|
97
|
+
if isinstance(axes, torch.Tensor):
|
|
98
|
+
axes = tuple(axes.tolist())
|
|
99
|
+
elif isinstance(axes, (list, tuple)):
|
|
100
|
+
axes = tuple(axes)
|
|
101
|
+
return torch.mean(t, dim=axes, keepdim=keepdims)
|
|
102
|
+
|
|
103
|
+
return builder.call_function(
|
|
104
|
+
_reduce_mean, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@register("ReduceMax")
|
|
109
|
+
def reduce_max(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
110
|
+
"""Max reduction."""
|
|
111
|
+
x = builder.get_value(node.input[0])
|
|
112
|
+
axes = _get_reduction_axes(node, builder)
|
|
113
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
114
|
+
noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
|
|
115
|
+
|
|
116
|
+
def _reduce_max(t, axes, keepdims, noop_with_empty_axes):
|
|
117
|
+
# Handle empty axes list
|
|
118
|
+
if isinstance(axes, (list, tuple)) and len(axes) == 0:
|
|
119
|
+
if noop_with_empty_axes:
|
|
120
|
+
return t
|
|
121
|
+
axes = None
|
|
122
|
+
if isinstance(axes, torch.Tensor) and axes.numel() == 0:
|
|
123
|
+
if noop_with_empty_axes:
|
|
124
|
+
return t
|
|
125
|
+
axes = None
|
|
126
|
+
|
|
127
|
+
if axes is None:
|
|
128
|
+
# Reduce over all dimensions
|
|
129
|
+
if t.numel() == 0:
|
|
130
|
+
# Empty tensor: return -inf with proper shape
|
|
131
|
+
if keepdims:
|
|
132
|
+
return torch.full(
|
|
133
|
+
[1] * t.ndim, float("-inf"), dtype=t.dtype, device=t.device
|
|
134
|
+
)
|
|
135
|
+
return torch.tensor(float("-inf"), dtype=t.dtype, device=t.device)
|
|
136
|
+
result = t.max()
|
|
137
|
+
if keepdims:
|
|
138
|
+
result = result.reshape([1] * t.ndim)
|
|
139
|
+
return result
|
|
140
|
+
|
|
141
|
+
if isinstance(axes, torch.Tensor):
|
|
142
|
+
axes = axes.tolist()
|
|
143
|
+
if isinstance(axes, list) and len(axes) == 1:
|
|
144
|
+
axes = axes[0]
|
|
145
|
+
if isinstance(axes, int):
|
|
146
|
+
# Check for empty dimension
|
|
147
|
+
if t.shape[axes] == 0:
|
|
148
|
+
new_shape = list(t.shape)
|
|
149
|
+
if keepdims:
|
|
150
|
+
new_shape[axes] = 1
|
|
151
|
+
else:
|
|
152
|
+
new_shape.pop(axes)
|
|
153
|
+
return torch.full(
|
|
154
|
+
new_shape, float("-inf"), dtype=t.dtype, device=t.device
|
|
155
|
+
)
|
|
156
|
+
return t.max(dim=axes, keepdim=keepdims).values
|
|
157
|
+
# Multiple axes: reduce sequentially
|
|
158
|
+
result = t
|
|
159
|
+
for axis in sorted(axes, reverse=True):
|
|
160
|
+
if result.shape[axis] == 0:
|
|
161
|
+
new_shape = list(result.shape)
|
|
162
|
+
if keepdims:
|
|
163
|
+
new_shape[axis] = 1
|
|
164
|
+
else:
|
|
165
|
+
new_shape.pop(axis)
|
|
166
|
+
result = torch.full(
|
|
167
|
+
new_shape, float("-inf"), dtype=result.dtype, device=result.device
|
|
168
|
+
)
|
|
169
|
+
else:
|
|
170
|
+
result = result.max(dim=axis, keepdim=keepdims).values
|
|
171
|
+
return result
|
|
172
|
+
|
|
173
|
+
return builder.call_function(
|
|
174
|
+
_reduce_max, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@register("ReduceMin")
|
|
179
|
+
def reduce_min(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
180
|
+
"""Min reduction."""
|
|
181
|
+
x = builder.get_value(node.input[0])
|
|
182
|
+
axes = _get_reduction_axes(node, builder)
|
|
183
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
184
|
+
noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
|
|
185
|
+
|
|
186
|
+
def _reduce_min(t, axes, keepdims, noop_with_empty_axes):
|
|
187
|
+
# Handle empty axes list
|
|
188
|
+
if isinstance(axes, (list, tuple)) and len(axes) == 0:
|
|
189
|
+
if noop_with_empty_axes:
|
|
190
|
+
return t
|
|
191
|
+
axes = None
|
|
192
|
+
if isinstance(axes, torch.Tensor) and axes.numel() == 0:
|
|
193
|
+
if noop_with_empty_axes:
|
|
194
|
+
return t
|
|
195
|
+
axes = None
|
|
196
|
+
|
|
197
|
+
if axes is None:
|
|
198
|
+
# Reduce over all dimensions
|
|
199
|
+
if t.numel() == 0:
|
|
200
|
+
# Empty tensor: return inf with proper shape
|
|
201
|
+
if keepdims:
|
|
202
|
+
return torch.full(
|
|
203
|
+
[1] * t.ndim, float("inf"), dtype=t.dtype, device=t.device
|
|
204
|
+
)
|
|
205
|
+
return torch.tensor(float("inf"), dtype=t.dtype, device=t.device)
|
|
206
|
+
result = t.min()
|
|
207
|
+
if keepdims:
|
|
208
|
+
result = result.reshape([1] * t.ndim)
|
|
209
|
+
return result
|
|
210
|
+
|
|
211
|
+
if isinstance(axes, torch.Tensor):
|
|
212
|
+
axes = axes.tolist()
|
|
213
|
+
if isinstance(axes, list) and len(axes) == 1:
|
|
214
|
+
axes = axes[0]
|
|
215
|
+
if isinstance(axes, int):
|
|
216
|
+
# Check for empty dimension
|
|
217
|
+
if t.shape[axes] == 0:
|
|
218
|
+
new_shape = list(t.shape)
|
|
219
|
+
if keepdims:
|
|
220
|
+
new_shape[axes] = 1
|
|
221
|
+
else:
|
|
222
|
+
new_shape.pop(axes)
|
|
223
|
+
return torch.full(
|
|
224
|
+
new_shape, float("inf"), dtype=t.dtype, device=t.device
|
|
225
|
+
)
|
|
226
|
+
return t.min(dim=axes, keepdim=keepdims).values
|
|
227
|
+
# Multiple axes: reduce sequentially
|
|
228
|
+
result = t
|
|
229
|
+
for axis in sorted(axes, reverse=True):
|
|
230
|
+
if result.shape[axis] == 0:
|
|
231
|
+
new_shape = list(result.shape)
|
|
232
|
+
if keepdims:
|
|
233
|
+
new_shape[axis] = 1
|
|
234
|
+
else:
|
|
235
|
+
new_shape.pop(axis)
|
|
236
|
+
result = torch.full(
|
|
237
|
+
new_shape, float("inf"), dtype=result.dtype, device=result.device
|
|
238
|
+
)
|
|
239
|
+
else:
|
|
240
|
+
result = result.min(dim=axis, keepdim=keepdims).values
|
|
241
|
+
return result
|
|
242
|
+
|
|
243
|
+
return builder.call_function(
|
|
244
|
+
_reduce_min, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@register("ReduceProd")
|
|
249
|
+
def reduce_prod(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
250
|
+
"""Product reduction."""
|
|
251
|
+
x = builder.get_value(node.input[0])
|
|
252
|
+
axes = _get_reduction_axes(node, builder)
|
|
253
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
254
|
+
noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0)
|
|
255
|
+
|
|
256
|
+
def _reduce_prod(t, axes, keepdims, noop_with_empty_axes):
|
|
257
|
+
# Handle empty axes list
|
|
258
|
+
if isinstance(axes, (list, tuple)) and len(axes) == 0:
|
|
259
|
+
if noop_with_empty_axes:
|
|
260
|
+
return t
|
|
261
|
+
axes = None
|
|
262
|
+
if isinstance(axes, torch.Tensor) and axes.numel() == 0:
|
|
263
|
+
if noop_with_empty_axes:
|
|
264
|
+
return t
|
|
265
|
+
axes = None
|
|
266
|
+
|
|
267
|
+
if axes is None:
|
|
268
|
+
result = torch.prod(t)
|
|
269
|
+
if keepdims:
|
|
270
|
+
result = result.reshape([1] * t.ndim)
|
|
271
|
+
return result
|
|
272
|
+
|
|
273
|
+
if isinstance(axes, torch.Tensor):
|
|
274
|
+
axes = axes.tolist()
|
|
275
|
+
if isinstance(axes, list) and len(axes) == 1:
|
|
276
|
+
axes = axes[0]
|
|
277
|
+
if isinstance(axes, int):
|
|
278
|
+
return torch.prod(t, dim=axes, keepdim=keepdims)
|
|
279
|
+
# Multiple axes: reduce sequentially
|
|
280
|
+
result = t
|
|
281
|
+
for axis in sorted(axes, reverse=True):
|
|
282
|
+
result = torch.prod(result, dim=axis, keepdim=keepdims)
|
|
283
|
+
return result
|
|
284
|
+
|
|
285
|
+
return builder.call_function(
|
|
286
|
+
_reduce_prod, args=(x, axes, bool(keepdims), bool(noop_with_empty_axes))
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@register("ReduceL1")
|
|
291
|
+
def reduce_l1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
292
|
+
"""L1 norm reduction."""
|
|
293
|
+
x = builder.get_value(node.input[0])
|
|
294
|
+
axes = _get_reduction_axes(node, builder)
|
|
295
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
296
|
+
|
|
297
|
+
def _reduce_l1(t, axes, keepdims):
|
|
298
|
+
abs_t = torch.abs(t)
|
|
299
|
+
if axes is None:
|
|
300
|
+
return torch.sum(abs_t)
|
|
301
|
+
if isinstance(axes, torch.Tensor):
|
|
302
|
+
axes = tuple(axes.tolist())
|
|
303
|
+
elif isinstance(axes, (list, tuple)):
|
|
304
|
+
axes = tuple(axes)
|
|
305
|
+
return torch.sum(abs_t, dim=axes, keepdim=keepdims)
|
|
306
|
+
|
|
307
|
+
return builder.call_function(_reduce_l1, args=(x, axes, bool(keepdims)))
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@register("ReduceL2")
|
|
311
|
+
def reduce_l2(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
312
|
+
"""L2 norm reduction."""
|
|
313
|
+
x = builder.get_value(node.input[0])
|
|
314
|
+
axes = _get_reduction_axes(node, builder)
|
|
315
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
316
|
+
|
|
317
|
+
def _reduce_l2(t, axes, keepdims):
|
|
318
|
+
if axes is None:
|
|
319
|
+
return torch.norm(t)
|
|
320
|
+
if isinstance(axes, torch.Tensor):
|
|
321
|
+
axes = tuple(axes.tolist())
|
|
322
|
+
elif isinstance(axes, (list, tuple)):
|
|
323
|
+
axes = tuple(axes)
|
|
324
|
+
return torch.norm(t, dim=axes, keepdim=keepdims)
|
|
325
|
+
|
|
326
|
+
return builder.call_function(_reduce_l2, args=(x, axes, bool(keepdims)))
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
@register("ReduceLogSum")
|
|
330
|
+
def reduce_log_sum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
331
|
+
"""Log of sum reduction."""
|
|
332
|
+
x = builder.get_value(node.input[0])
|
|
333
|
+
axes = _get_reduction_axes(node, builder)
|
|
334
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
335
|
+
|
|
336
|
+
def _reduce_log_sum(t, axes, keepdims):
|
|
337
|
+
if axes is None:
|
|
338
|
+
return torch.log(torch.sum(t))
|
|
339
|
+
if isinstance(axes, torch.Tensor):
|
|
340
|
+
axes = tuple(axes.tolist())
|
|
341
|
+
elif isinstance(axes, (list, tuple)):
|
|
342
|
+
axes = tuple(axes)
|
|
343
|
+
return torch.log(torch.sum(t, dim=axes, keepdim=keepdims))
|
|
344
|
+
|
|
345
|
+
return builder.call_function(_reduce_log_sum, args=(x, axes, bool(keepdims)))
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
@register("ReduceLogSumExp")
|
|
349
|
+
def reduce_log_sum_exp(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
350
|
+
"""LogSumExp reduction."""
|
|
351
|
+
x = builder.get_value(node.input[0])
|
|
352
|
+
axes = _get_reduction_axes(node, builder)
|
|
353
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
354
|
+
|
|
355
|
+
def _reduce_log_sum_exp(t, axes, keepdims):
|
|
356
|
+
if axes is None:
|
|
357
|
+
return torch.logsumexp(t, dim=tuple(range(t.dim())))
|
|
358
|
+
if isinstance(axes, torch.Tensor):
|
|
359
|
+
axes = tuple(axes.tolist())
|
|
360
|
+
elif isinstance(axes, (list, tuple)):
|
|
361
|
+
axes = tuple(axes)
|
|
362
|
+
return torch.logsumexp(t, dim=axes, keepdim=keepdims)
|
|
363
|
+
|
|
364
|
+
return builder.call_function(_reduce_log_sum_exp, args=(x, axes, bool(keepdims)))
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
@register("ReduceSumSquare")
|
|
368
|
+
def reduce_sum_square(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
369
|
+
"""Sum of squares reduction."""
|
|
370
|
+
x = builder.get_value(node.input[0])
|
|
371
|
+
axes = _get_reduction_axes(node, builder)
|
|
372
|
+
keepdims = get_attribute(node, "keepdims", 1)
|
|
373
|
+
|
|
374
|
+
def _reduce_sum_square(t, axes, keepdims):
|
|
375
|
+
sq = torch.square(t)
|
|
376
|
+
if axes is None:
|
|
377
|
+
return torch.sum(sq)
|
|
378
|
+
if isinstance(axes, torch.Tensor):
|
|
379
|
+
axes = tuple(axes.tolist())
|
|
380
|
+
elif isinstance(axes, (list, tuple)):
|
|
381
|
+
axes = tuple(axes)
|
|
382
|
+
return torch.sum(sq, dim=axes, keepdim=keepdims)
|
|
383
|
+
|
|
384
|
+
return builder.call_function(_reduce_sum_square, args=(x, axes, bool(keepdims)))
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# =============================================================================
# ArgMax/ArgMin operators
# =============================================================================


def _make_arg_extremum_handler(torch_fn):
    """Factory for ArgMax/ArgMin operator handlers."""

    def _arg_extremum(t, axis, keepdims, select_last_index):
        if select_last_index:
            flipped = torch.flip(t, [axis])
            idx = torch_fn(flipped, dim=axis, keepdim=keepdims)
            return t.size(axis) - 1 - idx
        return torch_fn(t, dim=axis, keepdim=keepdims)

    def handler(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
        x = builder.get_value(node.input[0])
        axis = get_attribute(node, "axis", 0)
        keepdims = get_attribute(node, "keepdims", 1)
        select_last_index = get_attribute(node, "select_last_index", 0)
        return builder.call_function(
            _arg_extremum, args=(x, axis, bool(keepdims), bool(select_last_index))
        )

    return handler


@register("ArgMax")
|
|
415
|
+
def argmax(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
416
|
+
"""Index of maximum value."""
|
|
417
|
+
return _make_arg_extremum_handler(torch.argmax)(builder, node)
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
@register("ArgMin")
|
|
421
|
+
def argmin(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
422
|
+
"""Index of minimum value."""
|
|
423
|
+
return _make_arg_extremum_handler(torch.argmin)(builder, node)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
# =============================================================================
# Cumulative and TopK operators
# =============================================================================


@register("CumSum")
|
|
432
|
+
def cumsum(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
433
|
+
"""Cumulative sum."""
|
|
434
|
+
x = builder.get_value(node.input[0])
|
|
435
|
+
axis = builder.get_value(node.input[1])
|
|
436
|
+
|
|
437
|
+
exclusive = get_attribute(node, "exclusive", 0)
|
|
438
|
+
reverse = get_attribute(node, "reverse", 0)
|
|
439
|
+
|
|
440
|
+
def _cumsum(x, axis, exclusive, reverse):
|
|
441
|
+
ax = axis.item() if isinstance(axis, torch.Tensor) else axis
|
|
442
|
+
|
|
443
|
+
if reverse:
|
|
444
|
+
x = torch.flip(x, [int(ax)])
|
|
445
|
+
|
|
446
|
+
result = torch.cumsum(x, dim=int(ax))
|
|
447
|
+
|
|
448
|
+
if exclusive:
|
|
449
|
+
# Shift by one and pad with zero
|
|
450
|
+
pad_shape = list(x.shape)
|
|
451
|
+
pad_shape[int(ax)] = 1
|
|
452
|
+
zero_pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
|
|
453
|
+
result = torch.cat(
|
|
454
|
+
[zero_pad, result.narrow(int(ax), 0, x.shape[int(ax)] - 1)], dim=int(ax)
|
|
455
|
+
)
|
|
456
|
+
|
|
457
|
+
if reverse:
|
|
458
|
+
result = torch.flip(result, [int(ax)])
|
|
459
|
+
|
|
460
|
+
return result
|
|
461
|
+
|
|
462
|
+
return builder.call_function(_cumsum, args=(x, axis, exclusive, reverse))
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
@register("TopK")
|
|
466
|
+
def topk(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
|
|
467
|
+
"""Find top K values and indices."""
|
|
468
|
+
x = builder.get_value(node.input[0])
|
|
469
|
+
k = builder.get_value(node.input[1])
|
|
470
|
+
|
|
471
|
+
axis = get_attribute(node, "axis", -1)
|
|
472
|
+
largest = get_attribute(node, "largest", 1)
|
|
473
|
+
sorted_ = get_attribute(node, "sorted", 1)
|
|
474
|
+
|
|
475
|
+
def _topk(x, k, axis, largest, sorted_):
|
|
476
|
+
k_val = k.item() if isinstance(k, torch.Tensor) else k
|
|
477
|
+
k_val = int(k_val)
|
|
478
|
+
|
|
479
|
+
# Handle unsupported dtypes (e.g., uint64) by converting to int64
|
|
480
|
+
original_dtype = x.dtype
|
|
481
|
+
needs_conversion = original_dtype == torch.uint64
|
|
482
|
+
if needs_conversion:
|
|
483
|
+
x = x.to(torch.int64)
|
|
484
|
+
|
|
485
|
+
# ONNX TopK requires stable sorting: for equal values, the element
|
|
486
|
+
# with lower index appears first. PyTorch's topk is not stable.
|
|
487
|
+
# We achieve stability by using argsort on a composite key.
|
|
488
|
+
# Create indices tensor for tie-breaking
|
|
489
|
+
size = x.shape[axis]
|
|
490
|
+
# Create indices [0, 1, 2, ..., size-1] along the specified axis
|
|
491
|
+
indices_shape = [1] * x.ndim
|
|
492
|
+
indices_shape[axis] = size
|
|
493
|
+
idx = torch.arange(size, device=x.device, dtype=x.dtype).view(indices_shape)
|
|
494
|
+
idx = idx.expand_as(x)
|
|
495
|
+
|
|
496
|
+
# Scale values so that the index becomes the tiebreaker
|
|
497
|
+
# For largest=True: negate values, sort ascending, lower index wins
|
|
498
|
+
# For largest=False: use values directly, sort ascending, lower index wins
|
|
499
|
+
if bool(largest):
|
|
500
|
+
# Negate so that larger values become smaller (for ascending sort)
|
|
501
|
+
# Add small offset based on index to break ties (lower index = smaller offset)
|
|
502
|
+
sort_values = -x
|
|
503
|
+
else:
|
|
504
|
+
sort_values = x
|
|
505
|
+
|
|
506
|
+
# Use argsort with stable=True for stable sorting
|
|
507
|
+
sorted_indices = torch.argsort(sort_values, dim=axis, stable=True)
|
|
508
|
+
|
|
509
|
+
# Take top k indices
|
|
510
|
+
# Narrow to first k elements along axis
|
|
511
|
+
top_k_indices = torch.narrow(sorted_indices, axis, 0, k_val)
|
|
512
|
+
|
|
513
|
+
# Gather values using the indices
|
|
514
|
+
values = torch.gather(x, axis, top_k_indices)
|
|
515
|
+
|
|
516
|
+
# If sorted=False, the order is undefined, but we still use stable order
|
|
517
|
+
# The indices should be the original indices
|
|
518
|
+
indices = top_k_indices
|
|
519
|
+
|
|
520
|
+
# Convert values back to original dtype if needed
|
|
521
|
+
if needs_conversion:
|
|
522
|
+
values = values.to(original_dtype)
|
|
523
|
+
|
|
524
|
+
return values, indices
|
|
525
|
+
|
|
526
|
+
result = builder.call_function(_topk, args=(x, k, axis, largest, sorted_))
|
|
527
|
+
|
|
528
|
+
# Handle multiple outputs
|
|
529
|
+
for i, output_name in enumerate(node.output):
|
|
530
|
+
if output_name:
|
|
531
|
+
idx_node = builder.call_function(lambda t, idx: t[idx], args=(result, i))
|
|
532
|
+
builder.env[output_name] = idx_node
|
|
533
|
+
|
|
534
|
+
return result
|