onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/tensor.py
ADDED
@@ -0,0 +1,1161 @@
# SPDX-License-Identifier: Apache-2.0
"""Tensor manipulation operators."""

from typing import TYPE_CHECKING

import onnx
import torch

from ..exceptions import ConversionError
from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.names import sanitize_name
from ..utils.op_helpers import get_optional_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Constant and Identity operators
# =============================================================================


@register("Constant")
def constant(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Create a constant tensor."""
    value = get_attribute(node, "value", tensor_loader=builder.load_tensor)
    if value is None:
        value_float = get_attribute(node, "value_float")
        if value_float is not None:
            value = torch.tensor(value_float, dtype=torch.float32)
        value_int = get_attribute(node, "value_int")
        if value_int is not None:
            value = torch.tensor(value_int, dtype=torch.int64)
        value_floats = get_attribute(node, "value_floats")
        if value_floats is not None:
            value = torch.tensor(value_floats, dtype=torch.float32)
        value_ints = get_attribute(node, "value_ints")
        if value_ints is not None:
            value = torch.tensor(value_ints, dtype=torch.int64)

    if value is None:
        raise ConversionError(
            "Constant node has no value attribute",
            node_name=node.name,
            op_type="Constant",
        )

    output_name = node.output[0]
    safe_name = sanitize_name(output_name)
    builder._constants[safe_name] = value

    fx_node = builder.graph.get_attr(safe_name)
    fx_node.meta["onnx_op_type"] = "Constant"
    fx_node.meta["onnx_name"] = output_name
    fx_node.meta["onnx_shape"] = list(value.shape) if hasattr(value, "shape") else []
    fx_node.meta["onnx_dtype"] = value.dtype if hasattr(value, "dtype") else None
    return fx_node


@register("Identity")
def identity(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Identity operator - returns input unchanged."""
    return builder.get_value(node.input[0])


@register("Cast")
def cast(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Cast tensor to a different data type."""
    from ..utils.dtype import onnx_dtype_to_torch

    x = builder.get_value(node.input[0])
    to_dtype = get_attribute(node, "to")
    torch_dtype = onnx_dtype_to_torch(to_dtype)

    if torch_dtype is None:
        raise ConversionError(
            f"Unsupported cast target dtype: {to_dtype}",
            node_name=node.name,
            op_type="Cast",
        )

    return builder.call_function(lambda t, dtype: t.to(dtype), args=(x, torch_dtype))


@register("CastLike")
def cast_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Cast tensor to the same data type as the target tensor."""
    x = builder.get_value(node.input[0])
    target = builder.get_value(node.input[1])

    def _cast_like(t, target):
        return t.to(target.dtype)

    return builder.call_function(_cast_like, args=(x, target))


# =============================================================================
# Shape manipulation operators
# =============================================================================


@register("Reshape")
def reshape(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Reshape tensor to a new shape.

    ONNX Reshape semantics:
    - A value of 0 means the dimension is unchanged from the input shape
    - A value of -1 means the dimension is inferred from the remaining elements
    """
    x = builder.get_value(node.input[0])
    shape = builder.get_value(node.input[1])

    # Check allowzero attribute (default is 0, meaning 0 copies from input)
    allowzero = get_attribute(node, "allowzero", 0)

    def _reshape(t, shape, allowzero):
        if isinstance(shape, torch.Tensor):
            shape = shape.tolist()
        else:
            shape = list(shape)

        # Convert to integers (shape may contain floats from tensor operations)
        shape = [int(d) for d in shape]

        # ONNX: if allowzero=0, a value of 0 in shape means copy from input
        if not allowzero:
            for i, dim in enumerate(shape):
                if dim == 0:
                    if i < t.dim():
                        shape[i] = t.shape[i]

        return torch.reshape(t, tuple(shape))

    return builder.call_function(_reshape, args=(x, shape, allowzero))
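
# Example (illustrative sketch, not part of the package source): with the
# default allowzero=0, a 0 in the target shape copies the matching input
# dimension and -1 is inferred from the remaining elements, so _reshape above
# behaves like:
#   t = torch.zeros(2, 3, 4)
#   _reshape(t, torch.tensor([0, -1]), 0).shape  ->  torch.Size([2, 12])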


@register("Transpose")
def transpose(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Transpose tensor dimensions."""
    x = builder.get_value(node.input[0])
    perm = get_attribute(node, "perm")
    if perm is None:
        # Default: reverse all dimensions (explicit permute avoids the
        # deprecated Tensor.T behaviour on tensors whose rank is not 2)
        return builder.call_function(
            lambda t: t.permute(tuple(range(t.dim() - 1, -1, -1))), args=(x,)
        )
    return builder.call_function(torch.permute, args=(x, perm))


@register("Squeeze", since_version=1)
def squeeze_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Remove dimensions of size 1 for opset 1-12.

    In opset < 13, axes is an attribute.
    """
    x = builder.get_value(node.input[0])

    axes = get_attribute(node, "axes")
    if axes is not None:
        # Squeeze specific dimensions
        result = x
        # Sort in reverse to maintain correct indices after each squeeze
        for axis in sorted(axes, reverse=True):
            result = builder.call_function(
                torch.squeeze, args=(result,), kwargs={"dim": axis}
            )
        return result
    return builder.call_function(torch.squeeze, args=(x,))


@register("Squeeze", since_version=13)
def squeeze_v13(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Remove dimensions of size 1 for opset 13+.

    In opset 13+, axes is an optional input (not attribute).
    """
    x = builder.get_value(node.input[0])

    # axes is an optional input in opset 13+
    axes = get_optional_input(builder, node, 1)
    if axes is not None:

        def _squeeze_dynamic(t, axes):
            if isinstance(axes, torch.Tensor):
                axes = axes.tolist()
            if isinstance(axes, list):
                if len(axes) == 1:
                    return torch.squeeze(t, dim=axes[0])
                # Multiple axes - squeeze in reverse order
                result = t
                for axis in sorted(axes, reverse=True):
                    result = torch.squeeze(result, dim=int(axis))
                return result
            return torch.squeeze(t, dim=int(axes))

        return builder.call_function(_squeeze_dynamic, args=(x, axes))

    # No axes input - squeeze all dimensions of size 1
    return builder.call_function(torch.squeeze, args=(x,))


@register("Unsqueeze", since_version=1)
def unsqueeze_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Insert dimensions of size 1 for opset 1-12.

    In opset < 13, axes is a required attribute.
    """
    x = builder.get_value(node.input[0])

    axes = get_attribute(node, "axes")
    if axes is None:
        raise ConversionError(
            "Unsqueeze requires axes attribute in opset < 13",
            node_name=node.name,
            op_type="Unsqueeze",
        )

    # Handle single axis
    if isinstance(axes, int):
        return builder.call_function(torch.unsqueeze, args=(x, axes))

    # Handle multiple axes - unsqueeze in sorted order
    result = x
    for axis in sorted(axes):
        result = builder.call_function(torch.unsqueeze, args=(result, axis))
    return result


@register("Unsqueeze", since_version=13)
def unsqueeze_v13(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Insert dimensions of size 1 for opset 13+.

    In opset 13+, axes is a required input (not attribute).
    """
    x = builder.get_value(node.input[0])

    if len(node.input) < 2 or not node.input[1]:
        raise ConversionError(
            "Unsqueeze requires axes input in opset 13+",
            node_name=node.name,
            op_type="Unsqueeze",
        )

    axes = builder.get_value(node.input[1])

    def _unsqueeze_dynamic(t, axes):
        if isinstance(axes, torch.Tensor):
            axes = axes.tolist()
        if isinstance(axes, int):
            return torch.unsqueeze(t, axes)
        # Handle multiple axes - unsqueeze in sorted order
        result = t
        for axis in sorted(axes):
            result = torch.unsqueeze(result, int(axis))
        return result

    return builder.call_function(_unsqueeze_dynamic, args=(x, axes))


@register("Flatten")
def flatten(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Flatten tensor to 2D.

    ONNX Flatten reshapes the input tensor to a 2D tensor:
    - First dimension = product of dimensions from 0 to axis-1
    - Second dimension = product of dimensions from axis to end
    """
    x = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 1)

    def _flatten_to_2d(t, axis):
        shape = t.shape
        # Handle negative axis
        if axis < 0:
            axis = len(shape) + axis
        # Compute dimensions
        dim0 = 1
        for i in range(axis):
            dim0 *= shape[i]
        dim1 = 1
        for i in range(axis, len(shape)):
            dim1 *= shape[i]
        return t.reshape(dim0, dim1)

    return builder.call_function(_flatten_to_2d, args=(x, axis))
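
# Example (illustrative sketch, not part of the package source): _flatten_to_2d
# above on an input of shape (2, 3, 4) gives
#   axis=1 -> shape (2, 12)     axis=2 -> shape (6, 4)     axis=0 -> shape (1, 24)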


@register("Expand")
def expand(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Broadcast tensor to a new shape.

    ONNX Expand uses bidirectional broadcasting, which means:
    - If target dim is 1, keep the original dimension
    - The output shape is max(input_dim, target_dim) for each dimension
    """
    x = builder.get_value(node.input[0])
    shape = builder.get_value(node.input[1])

    def _expand(t, shape):
        if isinstance(shape, torch.Tensor):
            shape = tuple(int(s) for s in shape.tolist())
        # Use broadcast_shapes to compute the actual broadcast shape
        # This handles cases where target_dim=1 should preserve input_dim
        broadcast_shape = torch.broadcast_shapes(t.shape, shape)
        return t.expand(broadcast_shape)

    return builder.call_function(_expand, args=(x, shape))


# =============================================================================
# Concatenation and splitting operators
# =============================================================================


@register("Concat")
def concat(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Concatenate tensors along an axis."""
    inputs = [builder.get_value(name) for name in node.input]
    axis = get_attribute(node, "axis", 0)
    return builder.call_function(torch.cat, args=(inputs,), kwargs={"dim": axis})


@register("Split", since_version=1)
def split_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Split tensor into chunks for opset 1-12.

    In opset < 13, split sizes is an optional attribute.
    """
    x = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 0)

    split_attr = get_attribute(node, "split")
    if split_attr is not None:
        result = builder.call_function(torch.split, args=(x, list(split_attr), axis))
    else:
        # Default: split into equal parts based on number of outputs
        result = builder.call_function(torch.chunk, args=(x, len(node.output), axis))

    # Handle multiple outputs
    for i, output_name in enumerate(node.output):
        if output_name:
            idx_node = builder.call_function(lambda t, idx: t[idx], args=(result, i))
            builder.env[output_name] = idx_node

    return result


@register("Split", since_version=13)
def split_v13(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Split tensor into chunks for opset 13+.

    In opset 13+, split sizes is an optional input.
    In opset 18+, num_outputs attribute was added.
    """
    x = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 0)
    num_outputs = get_attribute(node, "num_outputs")  # Added in opset 18

    # split sizes is an optional input in opset 13+
    split_sizes = get_optional_input(builder, node, 1)
    if split_sizes is not None:

        def _split_with_sizes(t, sizes, dim):
            if hasattr(sizes, "tolist"):
                sizes = sizes.tolist()
            return torch.split(t, sizes, dim)

        result = builder.call_function(_split_with_sizes, args=(x, split_sizes, axis))
    elif num_outputs is not None:
        # Split into equal parts using num_outputs (opset 18+)
        result = builder.call_function(torch.chunk, args=(x, num_outputs, axis))
    else:
        # Default: split into equal parts based on number of outputs
        result = builder.call_function(torch.chunk, args=(x, len(node.output), axis))

    # Handle multiple outputs
    for i, output_name in enumerate(node.output):
        if output_name:
            idx_node = builder.call_function(lambda t, idx: t[idx], args=(result, i))
            builder.env[output_name] = idx_node

    return result
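
# Example (illustrative sketch, not part of the package source): the two Split
# paths above map onto the standard PyTorch calls:
#   torch.split(torch.arange(6), [2, 4], 0)  ->  (tensor([0, 1]), tensor([2, 3, 4, 5]))
#   torch.chunk(torch.arange(6), 3, 0)       ->  three tensors of length 2
# Each named ONNX output is then bound to result[i] through builder.env.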


# =============================================================================
# Slicing and indexing operators
# =============================================================================


@register("Slice", since_version=1)
def slice_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Slice tensor along axes (opset 1-9).

    In opset < 10, starts, ends, and axes are attributes.
    """
    x = builder.get_value(node.input[0])
    starts = get_attribute(node, "starts")
    ends = get_attribute(node, "ends")
    axes = get_attribute(node, "axes")
    # Note: steps attribute doesn't exist in opset < 10

    return builder.call_function(
        _dynamic_slice,
        args=(x, list(starts), list(ends), list(axes) if axes else None, None),
    )


@register("Slice", since_version=10)
def slice_v10(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Slice tensor along axes (opset 10+).

    In opset 10+, starts, ends, axes, and steps are inputs.
    """
    x = builder.get_value(node.input[0])
    starts = builder.get_value(node.input[1])
    ends = builder.get_value(node.input[2])

    axes = get_optional_input(builder, node, 3)
    steps = get_optional_input(builder, node, 4)

    # Use torch.narrow for simple cases, or dynamic slicing
    return builder.call_function(
        _dynamic_slice,
        args=(x, starts, ends, axes, steps),
    )


def _dynamic_slice(x, starts, ends, axes=None, steps=None):
    """Helper function for dynamic slicing with support for negative steps."""
    import torch

    # Convert to lists if tensors
    if isinstance(starts, torch.Tensor):
        starts = starts.tolist()
    if isinstance(ends, torch.Tensor):
        ends = ends.tolist()
    if axes is not None and isinstance(axes, torch.Tensor):
        axes = axes.tolist()
    if steps is not None and isinstance(steps, torch.Tensor):
        steps = steps.tolist()

    if axes is None:
        axes = list(range(len(starts)))
    if steps is None:
        steps = [1] * len(starts)

    # Handle negative steps by flipping, slicing with positive step, then flipping back
    # We process each axis separately to handle this correctly
    result = x
    for start, end, axis, step in zip(starts, ends, axes, steps):
        dim_size = result.size(axis)

        if step < 0:
            # For negative steps, ONNX semantics:
            # start defaults to dim_size - 1, end defaults to -dim_size - 1
            # We iterate from start down to end (exclusive) with abs(step)

            # Handle special ONNX sentinel values and negative indices
            if start >= dim_size:
                start = dim_size - 1
            elif start < 0:
                start = max(-1, dim_size + start)

            if end < -dim_size:
                end = -1  # Sentinel for "before the beginning"
            elif end < 0:
                end = dim_size + end

            # For negative step: we go from start down to end (exclusive)
            # Example: start=20, end=0, step=-1 means indices [20, 19, ..., 1]
            # Flip the axis, compute equivalent positive slice, then flip back

            # Compute the actual range of elements we want
            # start > end for negative step, so we want indices from end+1 to start (inclusive)
            actual_start = end + 1 if end >= 0 else 0
            actual_end = start + 1 if start >= 0 else dim_size

            # Clamp to valid range
            actual_start = max(0, min(actual_start, dim_size))
            actual_end = max(0, min(actual_end, dim_size))

            if actual_start >= actual_end:
                # Empty slice
                slices = [slice(None)] * result.dim()
                slices[axis] = slice(0, 0)
                result = result[tuple(slices)]
            else:
                # First slice to get the range
                slices = [slice(None)] * result.dim()
                slices[axis] = slice(int(actual_start), int(actual_end))
                result = result[tuple(slices)]

                # Then flip to reverse the order
                result = torch.flip(result, dims=[axis])

                # Apply striding if step < -1
                if step < -1:
                    abs_step = -step
                    slices = [slice(None)] * result.dim()
                    slices[axis] = slice(None, None, int(abs_step))
                    result = result[tuple(slices)]
        else:
            # Positive step - original logic
            if start < 0:
                start = max(0, dim_size + start)
            if end < 0:
                end = max(0, dim_size + end)
            # Clamp to valid range
            start = min(start, dim_size)
            end = min(end, dim_size)
            slices = [slice(None)] * result.dim()
            slices[axis] = slice(int(start), int(end), int(step))
            result = result[tuple(slices)]

    return result
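
# Example (illustrative sketch, not part of the package source): a negative-step
# ONNX Slice handled by _dynamic_slice above:
#   _dynamic_slice(torch.arange(10), starts=[8], ends=[2], axes=[0], steps=[-2])
#   ->  tensor([8, 6, 4])   (from index 8 down to, but excluding, index 2)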


@register("Gather")
def gather(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Gather elements along an axis.

    ONNX Gather behavior:
    - output shape = data.shape[:axis] + indices.shape + data.shape[axis+1:]
    - If indices is a scalar, the axis dimension is removed from the output
    - If indices is a multi-dimensional tensor, indices.shape replaces the axis dimension
    """
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    axis = get_attribute(node, "axis", 0)

    def _gather(data, indices, axis):
        indices = indices.long()

        if axis < 0:
            axis = data.dim() + axis

        # Handle scalar indices - need to squeeze the dimension after gather
        if indices.ndim == 0:
            # Scalar index: select single element along axis, removing that dimension
            return torch.index_select(data, axis, indices.unsqueeze(0)).squeeze(axis)

        # For multi-dimensional indices, we need proper ONNX Gather semantics
        # Move the gather axis to position 0
        if axis != 0:
            data = data.movedim(axis, 0)

        # Flatten indices for indexing
        indices_flat = indices.flatten()
        gathered = data[indices_flat]  # [num_indices, ...]

        # Reshape to restore indices dimensions
        new_shape = list(indices.shape) + list(data.shape[1:])
        gathered = gathered.view(new_shape)

        # Move the original leading dimensions back
        if axis != 0:
            # Permute dimensions to restore original order
            # Current: [idx..., prefix..., suffix...]
            # Target: [prefix..., idx..., suffix...]
            num_idx_dims = indices.ndim
            num_prefix_dims = axis

            perm = (
                list(range(num_idx_dims, num_idx_dims + num_prefix_dims))
                + list(range(num_idx_dims))
                + list(range(num_idx_dims + num_prefix_dims, gathered.ndim))
            )
            gathered = gathered.permute(perm)

        return gathered

    return builder.call_function(_gather, args=(x, indices, axis))
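
# Example (illustrative sketch, not part of the package source): ONNX Gather
# output shape with data of shape (3, 4, 5), indices of shape (2, 2), axis=1:
#   output shape = data.shape[:1] + indices.shape + data.shape[2:] = (3, 2, 2, 5)
# A scalar index instead drops the axis dimension entirely, giving (3, 5).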


@register("GatherElements")
def gather_elements(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Gather elements using indices with same rank as input."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    axis = get_attribute(node, "axis", 0)
    return builder.call_function(torch.gather, args=(x, axis, indices))


@register("GatherND")
def gather_nd(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Gather slices using n-dimensional indices."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    batch_dims = get_attribute(node, "batch_dims", 0)

    def _gather_nd(data, indices, batch_dims=0):
        # Simplified GatherND implementation
        indices = indices.long()
        if batch_dims == 0:
            # Flatten indices to list of coordinate tuples
            idx_shape = indices.shape
            indices_flat = indices.reshape(-1, idx_shape[-1])
            result = torch.stack([data[tuple(idx)] for idx in indices_flat])
            return result.reshape(idx_shape[:-1] + data.shape[indices.shape[-1] :])
        else:
            raise NotImplementedError("batch_dims > 0 not yet supported for GatherND")

    return builder.call_function(_gather_nd, args=(x, indices, batch_dims))


@register("ScatterElements")
def scatter_elements(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Scatter elements using indices."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    updates = builder.get_value(node.input[2])
    axis = get_attribute(node, "axis", 0)
    reduction = get_attribute(node, "reduction", "none")

    def _scatter_elements(data, axis, idx, upd, reduction):
        # Handle negative axis
        if axis < 0:
            axis = data.ndim + axis

        # Handle negative indices by converting to positive
        dim_size = data.shape[axis]
        idx = torch.where(idx < 0, idx + dim_size, idx)

        # Map ONNX reduction to PyTorch reduce argument
        if reduction == "none":
            return data.scatter(axis, idx, upd)
        elif reduction == "add":
            return data.scatter_add(axis, idx, upd)
        elif reduction == "mul":
            return data.scatter_reduce(axis, idx, upd, reduce="prod")
        elif reduction == "max":
            return data.scatter_reduce(axis, idx, upd, reduce="amax")
        elif reduction == "min":
            return data.scatter_reduce(axis, idx, upd, reduce="amin")
        else:
            raise ValueError(f"Unsupported reduction: {reduction}")

    return builder.call_function(
        _scatter_elements, args=(x, axis, indices, updates, reduction)
    )


@register("Scatter")
def scatter(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Scatter (deprecated, replaced by ScatterElements in opset 11)."""
    x = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    updates = builder.get_value(node.input[2])
    axis = get_attribute(node, "axis", 0)

    def _scatter(data, axis, idx, upd):
        # Handle negative axis
        if axis < 0:
            axis = data.ndim + axis

        # Handle negative indices by converting to positive
        dim_size = data.shape[axis]
        idx = torch.where(idx < 0, idx + dim_size, idx)

        return data.scatter(axis, idx, upd)

    return builder.call_function(_scatter, args=(x, axis, indices, updates))


# =============================================================================
# Tiling and padding operators
# =============================================================================


@register("Tile")
def tile(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Tile tensor by repeating."""
    x = builder.get_value(node.input[0])
    repeats = builder.get_value(node.input[1])

    def _tile(t, reps):
        if isinstance(reps, torch.Tensor):
            reps = tuple(int(r) for r in reps.tolist())
        return torch.tile(t, reps)

    return builder.call_function(_tile, args=(x, repeats))


def _pad_impl(x, pads, mode, constant_value):
    """Helper function for Pad operator.

    Converts ONNX pad format to PyTorch format and applies padding.
    ONNX: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    PyTorch: [xn_begin, xn_end, ..., x1_begin, x1_end]
    """
    import torch
    import torch.nn.functional as F

    if isinstance(pads, torch.Tensor):
        pads = pads.tolist()

    n = len(pads) // 2
    # Reverse and interleave
    torch_pads = []
    for i in range(n - 1, -1, -1):
        torch_pads.extend([int(pads[i]), int(pads[i + n])])

    mode_map = {"constant": "constant", "reflect": "reflect", "edge": "replicate"}
    torch_mode = mode_map.get(mode, "constant")

    if torch_mode == "constant":
        return F.pad(x, torch_pads, mode=torch_mode, value=float(constant_value))

    # For non-constant modes (reflect, replicate), PyTorch only supports padding
    # the last N dimensions. Trim leading zero-padding pairs.
    # torch_pads is ordered as [last_dim_begin, last_dim_end, ..., first_dim_begin, first_dim_end]
    # We need to trim trailing zero pairs (which correspond to first dimensions).
    while len(torch_pads) > 2 and torch_pads[-1] == 0 and torch_pads[-2] == 0:
        torch_pads = torch_pads[:-2]

    return F.pad(x, torch_pads, mode=torch_mode)
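
# Example (illustrative sketch, not part of the package source): converting
# ONNX pads to the torch.nn.functional.pad layout with _pad_impl above:
#   x = torch.ones(2, 3)
#   _pad_impl(x, [0, 1, 0, 2], "constant", 0.0).shape  ->  torch.Size([2, 6])
#   (ONNX pads [0, 1, 0, 2] become torch pads [1, 2, 0, 0])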


@register("Pad", since_version=1)
def pad_v1(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Pad tensor (opset 1-10).

    In opset < 11, pads and value are attributes.
    """
    x = builder.get_value(node.input[0])
    pads = list(get_attribute(node, "pads"))
    mode = get_attribute(node, "mode", "constant")
    constant_value = get_attribute(node, "value", 0.0)

    return builder.call_function(_pad_impl, args=(x, pads, mode, constant_value))


@register("Pad", since_version=11)
def pad_v11(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Pad tensor (opset 11+).

    In opset 11+, pads, constant_value, and axes are inputs.
    """
    x = builder.get_value(node.input[0])
    pads = builder.get_value(node.input[1])
    mode = get_attribute(node, "mode", "constant")

    constant_value = get_optional_input(builder, node, 2, default=0.0)

    # Note: axes input (opset 18+) is not yet supported
    # If needed, would require reordering pads based on axes

    return builder.call_function(_pad_impl, args=(x, pads, mode, constant_value))


# =============================================================================
# Shape operators
# =============================================================================


@register("Shape")
def shape(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Get tensor shape."""
    x = builder.get_value(node.input[0])
    start = get_attribute(node, "start", 0)
    end = get_attribute(node, "end")

    def _get_shape(t, start, end):
        shape = torch.tensor(t.shape, dtype=torch.int64)
        if end is None:
            return shape[start:]
        return shape[start:end]

    return builder.call_function(_get_shape, args=(x, start, end))


@register("Size")
def size(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Get total number of elements."""
    x = builder.get_value(node.input[0])
    return builder.call_function(
        lambda t: torch.tensor(t.numel(), dtype=torch.int64), args=(x,)
    )


@register("ConstantOfShape")
def constant_of_shape(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Create tensor filled with constant value."""
    shape = builder.get_value(node.input[0])
    value = get_attribute(node, "value", tensor_loader=builder.load_tensor)

    if value is not None:
        fill_value = (
            value.item() if hasattr(value, "item") else float(value.flatten()[0])
        )
        dtype = value.dtype
    else:
        fill_value = 0.0
        dtype = torch.float32

    def _constant_of_shape(shape, fill_value, dtype):
        if isinstance(shape, torch.Tensor):
            shape = shape.tolist()
        return torch.full(shape, fill_value, dtype=dtype)

    return builder.call_function(_constant_of_shape, args=(shape, fill_value, dtype))


# =============================================================================
# Tensor generation operators
# =============================================================================


@register("Range")
def range_(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Generate a range of values."""
    start = builder.get_value(node.input[0])
    limit = builder.get_value(node.input[1])
    delta = builder.get_value(node.input[2])

    def _range(start, limit, delta):
        # Extract scalar values
        st = start.item() if isinstance(start, torch.Tensor) else start
        lim = limit.item() if isinstance(limit, torch.Tensor) else limit
        dlt = delta.item() if isinstance(delta, torch.Tensor) else delta
        dtype = start.dtype if isinstance(start, torch.Tensor) else torch.float32
        return torch.arange(st, lim, dlt, dtype=dtype)

    return builder.call_function(_range, args=(start, limit, delta))


@register("OneHot")
def one_hot(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """One-hot encoding."""
    indices = builder.get_value(node.input[0])
    depth = builder.get_value(node.input[1])
    values = builder.get_value(node.input[2])

    axis = get_attribute(node, "axis", -1)

    def _one_hot(indices, depth, values, axis):
        d = depth.item() if isinstance(depth, torch.Tensor) else depth
        off_value = values[0]
        on_value = values[1]

        # Create one-hot tensor
        result = torch.nn.functional.one_hot(indices.long(), int(d))
        result = result.to(values.dtype)

        # Apply on/off values
        result = result * (on_value - off_value) + off_value

        # Move axis if needed
        if axis != -1 and axis != indices.dim():
            # Permute to move the one-hot dimension to the correct axis
            ndim = result.dim()
            if axis < 0:
                axis = ndim + axis
            perm = list(range(ndim - 1))
            perm.insert(axis, ndim - 1)
            result = result.permute(perm)

        return result

    return builder.call_function(_one_hot, args=(indices, depth, values, axis))
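
# Example (illustrative sketch, not part of the package source): _one_hot above
# with indices=[1, 3], depth=4, values=[0., 10.] (off, on) and the default
# axis=-1 produces
#   [[0., 10., 0., 0.],
#    [0., 0., 0., 10.]]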


@register("NonZero")
def non_zero(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Find indices of non-zero elements."""
    x = builder.get_value(node.input[0])

    def _non_zero(x):
        # ONNX returns shape (rank, num_nonzero), PyTorch returns tuple
        result = torch.nonzero(x, as_tuple=False).T
        return result.to(torch.int64)

    return builder.call_function(_non_zero, args=(x,))


@register("Unique")
def unique(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Find unique elements."""
    x = builder.get_value(node.input[0])

    axis = get_attribute(node, "axis")
    sorted_ = get_attribute(node, "sorted", 1)

    def _unique(x, axis, sorted_):
        if axis is not None:
            return torch.unique(
                x,
                sorted=bool(sorted_),
                return_inverse=True,
                return_counts=True,
                dim=axis,
            )
        return torch.unique(
            x, sorted=bool(sorted_), return_inverse=True, return_counts=True
        )

    return builder.call_function(_unique, args=(x, axis, sorted_))


@register("Trilu")
def trilu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Triangular part of matrix."""
    x = builder.get_value(node.input[0])

    k = get_optional_input(builder, node, 1, default=0)

    upper = get_attribute(node, "upper", 1)

    def _trilu(x, k, upper):
        k_val = k.item() if isinstance(k, torch.Tensor) else k
        if upper:
            return torch.triu(x, diagonal=int(k_val))
        return torch.tril(x, diagonal=int(k_val))

    return builder.call_function(_trilu, args=(x, k, upper))


@register("EyeLike")
def eye_like(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Create an identity matrix with the same shape as input.

    Note: The dtype attribute is ignored; output uses input tensor's dtype.
    """
    x = builder.get_value(node.input[0])
    k = get_attribute(node, "k", 0)

    def _eye_like(t: torch.Tensor, diag: int) -> torch.Tensor:
        n, m = t.shape[-2], t.shape[-1]
        # Build an (n, m) matrix with ones on the diag-th diagonal.
        # (torch.diagonal would extract a 1-D diagonal, not build a matrix.)
        eye = torch.zeros(n, m, dtype=t.dtype, device=t.device)
        rows = torch.arange(n, device=t.device)
        cols = rows + diag
        mask = (cols >= 0) & (cols < m)
        eye[rows[mask], cols[mask]] = 1
        return eye

    return builder.call_function(_eye_like, args=(x, k))


# =============================================================================
# Scatter ND operators
# =============================================================================


@register("ScatterND")
def scatter_nd(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Scatter updates into data at indices."""
    data = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])
    updates = builder.get_value(node.input[2])

    reduction = get_attribute(node, "reduction", "none")

    def _scatter_nd_none(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = flat_upd[i]

        return output

    def _scatter_nd_add(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = output[data_idx] + flat_upd[i]

        return output

    def _scatter_nd_mul(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = output[data_idx] * flat_upd[i]

        return output

    def _scatter_nd_max(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = torch.maximum(output[data_idx], flat_upd[i])

        return output

    def _scatter_nd_min(
        d: torch.Tensor, idx: torch.Tensor, upd: torch.Tensor
    ) -> torch.Tensor:
        output = d.clone()
        idx = idx.long()

        idx_shape = idx.shape[:-1]
        last_dim = idx.shape[-1]

        flat_idx = idx.reshape(-1, last_dim)
        flat_upd = upd.reshape(-1, *upd.shape[len(idx_shape) :])

        for i in range(flat_idx.shape[0]):
            data_idx = tuple(flat_idx[i].tolist())
            output[data_idx] = torch.minimum(output[data_idx], flat_upd[i])

        return output

    if reduction == "add":
        return builder.call_function(_scatter_nd_add, args=(data, indices, updates))
    elif reduction == "mul":
        return builder.call_function(_scatter_nd_mul, args=(data, indices, updates))
    elif reduction == "max":
        return builder.call_function(_scatter_nd_max, args=(data, indices, updates))
    elif reduction == "min":
        return builder.call_function(_scatter_nd_min, args=(data, indices, updates))
    else:
        return builder.call_function(_scatter_nd_none, args=(data, indices, updates))
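
# Example (illustrative sketch, not part of the package source): ScatterND with
# data of shape (4, 3), indices [[0], [2]] (shape (2, 1)) and updates of shape
# (2, 3) replaces rows 0 and 2 of data with updates[0] and updates[1]; the
# "add"/"mul"/"max"/"min" variants combine the update with the existing row.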


# =============================================================================
# Select and Compress operators
# =============================================================================


@register("Select")
def select_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Select elements based on indices (like advanced indexing)."""
    data = builder.get_value(node.input[0])
    indices = builder.get_value(node.input[1])

    return builder.call_function(torch.index_select, args=(data, 0, indices))


@register("Compress")
def compress_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Select elements based on a boolean condition tensor."""
    data = builder.get_value(node.input[0])
    condition = builder.get_value(node.input[1])

    axis = get_attribute(node, "axis", None)

    if axis is not None:

        def _compress_axis(d: torch.Tensor, c: torch.Tensor, ax: int) -> torch.Tensor:
            # Get indices where condition is True
            indices = torch.nonzero(c, as_tuple=True)[0]
            return torch.index_select(d, ax, indices)

        return builder.call_function(_compress_axis, args=(data, condition, axis))
    else:
        # Flatten and compress
        def _compress_flat(d: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
            return d.flatten()[c.flatten().bool()]

        return builder.call_function(_compress_flat, args=(data, condition))


# =============================================================================
# TensorScatter operator (for KV cache updates in LLMs)
# =============================================================================


@register("TensorScatter", since_version=24)
def tensor_scatter(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """TensorScatter for KV cache updates.

    Updates a cache tensor at specified indices along a given axis.
    Commonly used for key/value cache updates in LLM attention.

    Inputs:
        past_cache: Cache tensor (batch_size, D1, ..., max_sequence_length, ..., Dn)
        update: Update tensor (batch_size, D1, ..., sequence_length, ..., Dn)
        write_indices (optional): Start indices per batch sample (batch_size,)

    Attributes:
        axis: Sequence dimension (default -2)
        mode: 'linear' or 'circular' (default 'linear')
    """
    past_cache = builder.get_value(node.input[0])
    update = builder.get_value(node.input[1])
    write_indices = get_optional_input(builder, node, 2)

    axis = get_attribute(node, "axis", -2)
    mode = get_attribute(node, "mode", "linear")

    def _tensor_scatter(
        cache: torch.Tensor,
        upd: torch.Tensor,
        write_idx: torch.Tensor | None,
        ax: int,
        scatter_mode: str,
    ) -> torch.Tensor:
        output = cache.clone()

        # Handle negative axis
        if ax < 0:
            ax = cache.ndim + ax

        batch_size = cache.shape[0]
        max_seq_len = cache.shape[ax]
        seq_len = upd.shape[ax]

        # Default write_indices to zeros if not provided
        if write_idx is None:
            write_idx = torch.zeros(batch_size, dtype=torch.int64, device=cache.device)

        # For each batch element, copy the update into the cache at the specified position
        for b in range(batch_size):
            start_idx = int(write_idx[b].item())

            for s in range(seq_len):
                if scatter_mode == "circular":
                    cache_idx = (start_idx + s) % max_seq_len
                else:
                    cache_idx = start_idx + s

                # Build the index tuple for the cache and update tensors
                # For cache: (b, D1, ..., cache_idx, ..., Dn)
                # For update: (b, D1, ..., s, ..., Dn)
                cache_slices = [b] + [slice(None)] * (ax - 1) + [cache_idx]
                update_slices = [b] + [slice(None)] * (ax - 1) + [s]

                output[tuple(cache_slices)] = upd[tuple(update_slices)]

        return output

    return builder.call_function(
        _tensor_scatter, args=(past_cache, update, write_indices, axis, mode)
    )
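
# Example (illustrative sketch, not part of the package source): with a cache
# of shape (1, 8, 64) (batch, max_seq, head_dim), axis=-2, an update of shape
# (1, 1, 64) and write_indices=[3], _tensor_scatter overwrites sequence
# position 3 and leaves the other 7 positions of the cache unchanged; in
# "circular" mode the write position wraps around modulo max_seq.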