onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/sequence.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Sequence operators."""
|
|
3
|
+
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ..op_registry import register
|
|
10
|
+
from ..utils.attributes import get_attribute
|
|
11
|
+
from ..utils.op_helpers import get_optional_input
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ..graph_builder import GraphBuilder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register("SequenceConstruct")
def sequence_construct(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Build an ONNX sequence, represented as a Python list, from the node's inputs.

    Every input name is resolved through ``builder.get_value`` (in input order)
    and passed positionally to a runtime helper that gathers them into a list.
    """

    def _make_sequence(*tensors):
        # Named helper instead of a lambda so the FX node gets a readable target.
        return list(tensors)

    values = tuple(builder.get_value(input_name) for input_name in node.input)
    return builder.call_function(_make_sequence, args=values)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@register("SequenceAt")
def sequence_at(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Return the tensor stored at a given position of a sequence.

    The position (``node.input[1]``) may be a 0-d tensor or a plain number. It is
    converted to a Python int and used for ordinary list indexing, so negative
    positions count from the end and out-of-range positions raise ``IndexError``.
    """
    sequence_node = builder.get_value(node.input[0])
    index_node = builder.get_value(node.input[1])

    def _element_at(items: list, where: torch.Tensor) -> torch.Tensor:
        if hasattr(where, "item"):
            position = int(where.item())
        else:
            position = int(where)
        return items[position]

    return builder.call_function(_element_at, args=(sequence_node, index_node))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@register("SequenceLength")
def sequence_length(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Return the number of tensors in a sequence as a 0-d ``int64`` tensor."""

    def _count_items(items: list) -> torch.Tensor:
        n_items = len(items)
        return torch.tensor(n_items, dtype=torch.int64)

    sequence_node = builder.get_value(node.input[0])
    return builder.call_function(_count_items, args=(sequence_node,))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@register("SequenceEmpty")
def sequence_empty(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Emit a node that produces a new, empty sequence (an empty Python list).

    ``node`` is not inspected; any attributes it carries are ignored. The
    parameter exists only to match the registered-handler signature.
    """
    no_args: tuple = ()
    return builder.call_function(list, args=no_args)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@register("SequenceInsert")
def sequence_insert(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Insert a tensor into a sequence, returning a new list.

    Inputs:
        0: the sequence (a Python list at runtime).
        1: the tensor to insert.
        2 (optional): insertion position. When absent, the tensor is appended.

    The input list is never mutated; slicing builds a fresh list. Positions
    follow Python slice semantics: negative values count from the end, and
    values past either end clamp to the start or end.
    """
    sequence_node = builder.get_value(node.input[0])
    tensor_node = builder.get_value(node.input[1])
    position_node = get_optional_input(builder, node, 2)

    if position_node is None:

        def _append(items: list, value: torch.Tensor) -> list:
            return items + [value]

        return builder.call_function(_append, args=(sequence_node, tensor_node))

    def _insert_at(items: list, value: torch.Tensor, where: torch.Tensor) -> list:
        idx = int(where.item()) if hasattr(where, "item") else int(where)
        head, tail = items[:idx], items[idx:]
        return head + [value] + tail

    return builder.call_function(
        _insert_at, args=(sequence_node, tensor_node, position_node)
    )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@register("SequenceErase")
def sequence_erase(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Remove one tensor from a sequence, returning a new list.

    Inputs:
        0: the sequence (a Python list at runtime).
        1 (optional): position of the tensor to remove. When absent, the last
            tensor is removed.

    Positions use Python indexing semantics: negative values count from the
    end, so ``-1`` is the last tensor. An out-of-range position raises
    ``IndexError`` at runtime; ONNX defines out-of-bounds positions as an error.
    The input list is never mutated.
    """
    seq = builder.get_value(node.input[0])
    position = get_optional_input(builder, node, 1)

    if position is not None:

        def _seq_erase(s: list, p: torch.Tensor) -> list:
            idx = int(p.item()) if hasattr(p, "item") else int(p)
            # Delete from a copy rather than splicing slices. For idx == -1,
            # ``s[idx + 1:]`` would be ``s[0:]`` (the whole list). A bad index
            # must raise instead of being silently ignored.
            result = list(s)
            del result[idx]
            return result

        return builder.call_function(_seq_erase, args=(seq, position))
    else:

        def _seq_pop(s: list) -> list:
            return s[:-1]

        return builder.call_function(_seq_pop, args=(seq,))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@register("ConcatFromSequence")
def concat_from_sequence(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Join the tensors of a sequence into a single tensor.

    Attributes:
        axis (default 0): dimension along which to join.
        new_axis (default 0): when non-zero, the tensors are stacked along a
            newly inserted dimension (``torch.stack``). Otherwise they are
            concatenated along an existing one (``torch.cat``).
    """
    seq = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 0)
    stack_mode = get_attribute(node, "new_axis", 0)

    def _stack_seq(s: list, ax: int) -> torch.Tensor:
        return torch.stack(s, dim=ax)

    def _concat_seq(s: list, ax: int) -> torch.Tensor:
        return torch.cat(s, dim=ax)

    joiner = _stack_seq if stack_mode else _concat_seq
    return builder.call_function(joiner, args=(seq, axis))
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@register("SplitToSequence")
def split_to_sequence(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Split a tensor into a sequence (Python list) of tensors.

    Inputs:
        0: the tensor to split.
        1 (optional): ``split``. Either a scalar chunk size (all chunks have
            that size except possibly the last) or a 1-D tensor of chunk
            lengths along ``axis``.

    Attributes:
        axis (default 0): dimension to split along.
        keepdims (default 1): used only when ``split`` is absent. The input is
            then split into chunks of size 1, and with ``keepdims == 0`` the
            split dimension is squeezed out of each chunk. Per the ONNX spec,
            ``keepdims`` is ignored when ``split`` is provided.
    """
    x = builder.get_value(node.input[0])
    split = get_optional_input(builder, node, 1)

    axis = get_attribute(node, "axis", 0)
    keepdims = get_attribute(node, "keepdims", 1)

    if split is not None:

        def _split_seq(t: torch.Tensor, s, ax: int) -> list:
            # Tensors become Python scalars/lists. Other values (plain ints or
            # lists) are used directly rather than being wrapped in a list,
            # which would misread a scalar chunk size as a single chunk length.
            sizes = s.tolist() if hasattr(s, "tolist") else s
            if isinstance(sizes, (int, float)):
                # Scalar split: equal chunks of this size; the last may be smaller.
                sizes = int(sizes)
            else:
                # 1-D split: explicit chunk lengths along ``ax``.
                sizes = [int(v) for v in sizes]
            # keepdims is intentionally not applied here (ignored per the ONNX spec).
            return list(torch.split(t, sizes, dim=ax))

        return builder.call_function(_split_seq, args=(x, split, axis))
    else:

        def _split_ones(t: torch.Tensor, ax: int, keep: int) -> list:
            splits = list(torch.split(t, 1, dim=ax))
            if not keep:
                splits = [chunk.squeeze(ax) for chunk in splits]
            return splits

        return builder.call_function(_split_ones, args=(x, axis, keepdims))
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@register("ReverseSequence")
def reverse_sequence(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Reverse the leading ``sequence_lens[b]`` steps of each batch entry.

    Attributes:
        batch_axis (default 1): axis indexing the batch entries.
        time_axis (default 0): axis along which each prefix is reversed.

    For batch entry ``b``, the first ``sequence_lens[b]`` positions along
    ``time_axis`` are flipped. All remaining positions are copied unchanged.
    The input tensor is not modified; a clone is returned.
    """
    x = builder.get_value(node.input[0])
    sequence_lens = builder.get_value(node.input[1])

    batch_axis = get_attribute(node, "batch_axis", 1)
    time_axis = get_attribute(node, "time_axis", 0)

    def _reverse_sequence(x, sequence_lens, batch_axis, time_axis):
        out = x.clone()
        # Integer-indexing the batch axis removes that dimension, so the time
        # axis shifts down by one in the sliced view unless batch comes after it.
        flip_dim = time_axis - 1 if batch_axis <= time_axis else time_axis
        for b, length in enumerate(sequence_lens):
            if isinstance(length, torch.Tensor):
                n = int(length.item())
            else:
                n = int(length)
            if n < 2:
                # A prefix of length 0 or 1 is its own reverse.
                continue
            selector = [slice(None)] * x.dim()
            selector[batch_axis] = b
            selector[time_axis] = slice(None, n)
            key = tuple(selector)
            out[key] = torch.flip(x[key], dims=[flip_dim])
        return out

    return builder.call_function(
        _reverse_sequence, args=(x, sequence_lens, batch_axis, time_axis)
    )
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@register("Optional", since_version=15)
def optional_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Build an ONNX optional value.

    Runtime representation:
        - optional holding a tensor:   ``[tensor]``
        - optional holding a sequence: ``[[tensors...]]``
        - empty optional:              ``[]`` (``None`` may also denote empty)

    A present, non-empty first input name produces a one-element wrapper list.
    A missing or empty input name produces an empty list.
    """
    has_input = len(node.input) > 0 and bool(node.input[0])
    if not has_input:
        # Empty optional.
        return builder.call_function(list, args=())

    payload = builder.get_value(node.input[0])

    def _wrap_optional(v):
        return [v]

    return builder.call_function(_wrap_optional, args=(payload,))
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@register("OptionalHasElement", since_version=15)
def optional_has_element(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Emit a 0-d boolean tensor telling whether an optional holds a value.

    At runtime, the input counts as empty when it is ``None`` or an empty list
    (the list-wrapper form of an optional). Any other value, including a
    non-empty list or a bare tensor, counts as present. A missing input is an
    empty optional known at build time and yields a constant ``False``.
    """
    optional_input = get_optional_input(builder, node, 0)

    if optional_input is None:

        def _always_empty() -> torch.Tensor:
            return torch.tensor(False, dtype=torch.bool)

        return builder.call_function(_always_empty)

    def _has_element(opt):
        # NOTE(review): a plain (non-optional) sequence input that happens to be
        # empty also maps to False here. ONNX OptionalHasElement-18 reports True
        # for any non-optional tensor/sequence input. Confirm whether that case
        # can reach this handler.
        if isinstance(opt, list):
            present = len(opt) > 0
        else:
            present = opt is not None
        return torch.tensor(present, dtype=torch.bool)

    return builder.call_function(_has_element, args=(optional_input,))
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@register("OptionalGetElement", since_version=15)
def optional_get_element(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Extract the value held by an optional.

    The unwrapping strategy depends on ``builder.is_optional_type`` for the input:

    Declared optional:
        - ``None`` or ``[]``        -> ``ValueError`` (empty optional)
        - ``[value]``               -> ``value`` (also covers ``[[t1, ...]]``)
        - list with >1 entries      -> returned unchanged; treated as a sequence
          whose optional wrapper was already dropped, e.g. inside a Loop
        - any non-list value        -> returned unchanged

    Not declared optional (plain sequence/tensor):
        - ``None``                  -> ``ValueError``
        - anything else             -> returned unchanged
    """
    input_name = node.input[0]
    optional_input = builder.get_value(input_name)

    if not builder.is_optional_type(input_name):
        # The sequence itself acts as the element of an implicit optional.
        def _get_element_from_sequence(seq):
            if seq is None:
                raise ValueError("Cannot get element from empty optional")
            return seq

        return builder.call_function(
            _get_element_from_sequence, args=(optional_input,)
        )

    def _get_element_from_optional(opt):
        if opt is None:
            raise ValueError("Cannot get element from empty optional")
        if not isinstance(opt, list):
            return opt
        if not opt:
            raise ValueError("Cannot get element from empty optional")
        # One entry: assume the [value] wrapper and unwrap it. Several entries:
        # assume a sequence already unwrapped by loop refinement; pass it through.
        # NOTE(review): a refined plain sequence of length 1 is indistinguishable
        # from a wrapper here and would be unwrapped to its single tensor.
        return opt[0] if len(opt) == 1 else opt

    return builder.call_function(_get_element_from_optional, args=(optional_input,))
|