onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,304 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Sequence operators."""
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ import onnx
7
+ import torch
8
+
9
+ from ..op_registry import register
10
+ from ..utils.attributes import get_attribute
11
+ from ..utils.op_helpers import get_optional_input
12
+
13
+ if TYPE_CHECKING:
14
+ from ..graph_builder import GraphBuilder
15
+
16
+
17
@register("SequenceConstruct")
def sequence_construct(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Build a sequence (Python list) out of every input tensor of the node.

    All inputs are gathered in declaration order; the emitted FX call packs
    its runtime arguments into a fresh list.
    """
    gathered = tuple(builder.get_value(name) for name in node.input)
    return builder.call_function(lambda *args: list(args), args=gathered)
22
+
23
+
24
@register("SequenceAt")
def sequence_at(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Fetch the element of a sequence at a given (possibly negative) position."""
    sequence = builder.get_value(node.input[0])
    pos = builder.get_value(node.input[1])

    def _pick(items: list, where) -> torch.Tensor:
        # Positions usually arrive as 0-d tensors at runtime; plain ints
        # are accepted too. Negative indices count from the back, which
        # Python list indexing already implements.
        if hasattr(where, "item"):
            return items[int(where.item())]
        return items[int(where)]

    return builder.call_function(_pick, args=(sequence, pos))
35
+
36
+
37
@register("SequenceLength")
def sequence_length(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Return the number of tensors in a sequence as a 0-d int64 tensor."""
    sequence = builder.get_value(node.input[0])

    def _length(items: list) -> torch.Tensor:
        count = len(items)
        return torch.tensor(count, dtype=torch.int64)

    return builder.call_function(_length, args=(sequence,))
46
+
47
+
48
@register("SequenceEmpty")
def sequence_empty(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Produce a sequence with no elements (an empty Python list).

    NOTE(review): the ONNX ``dtype`` attribute of SequenceEmpty is not read
    here; an empty list carries no element type, so this looks intentional —
    confirm if downstream ops ever need the declared element dtype.
    """
    return builder.call_function(list, args=tuple())
52
+
53
+
54
@register("SequenceInsert")
def sequence_insert(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Insert a tensor into a sequence at a position, or append when absent.

    The optional third input selects the insertion point; without it the
    ONNX default is to append at the end. The input sequence is never
    mutated — a new list is returned either way.
    """
    sequence = builder.get_value(node.input[0])
    value = builder.get_value(node.input[1])
    pos = get_optional_input(builder, node, 2)

    if pos is None:
        # No position input: append at the end (ONNX default).
        def _append(items: list, v: torch.Tensor) -> list:
            return items + [v]

        return builder.call_function(_append, args=(sequence, value))

    def _insert(items: list, v: torch.Tensor, where) -> list:
        # list.insert matches the original slice-splice for the whole
        # accepted ONNX range [-n, n], including negative positions.
        at = int(where.item()) if hasattr(where, "item") else int(where)
        out = list(items)
        out.insert(at, v)
        return out

    return builder.call_function(_insert, args=(sequence, value, pos))
74
+
75
+
76
@register("SequenceErase")
def sequence_erase(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Remove one element from a sequence (the last one if no position given).

    Fix: the previous implementation built the result as
    ``s[:idx] + s[idx + 1:]``, which is wrong for negative positions —
    for idx == -1 that evaluates to ``s[:-1] + s[0:]`` and duplicates the
    whole sequence. ONNX SequenceErase accepts positions in [-n, n-1], so
    deletion now goes through ``del`` on a copy, which handles negative
    indices correctly.
    """
    seq = builder.get_value(node.input[0])
    position = get_optional_input(builder, node, 1)

    if position is not None:

        def _seq_erase(s: list, p) -> list:
            # Position may be a 0-d tensor or a plain int.
            idx = int(p.item()) if hasattr(p, "item") else int(p)
            out = list(s)
            # del handles both non-negative and negative indices per the
            # ONNX accepted range [-n, n-1].
            del out[idx]
            return out

        return builder.call_function(_seq_erase, args=(seq, position))
    else:
        # ONNX default when no position input is provided: drop the last
        # element of the sequence.
        def _seq_pop(s: list) -> list:
            return s[:-1]

        return builder.call_function(_seq_pop, args=(seq,))
95
+
96
+
97
@register("ConcatFromSequence")
def concat_from_sequence(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Join the tensors of a sequence along ``axis``.

    With ``new_axis`` set, a fresh dimension is inserted at ``axis``
    (torch.stack); otherwise existing dimensions are concatenated
    (torch.cat).
    """
    sequence = builder.get_value(node.input[0])
    axis = get_attribute(node, "axis", 0)
    # Pick the combining primitive once, based on the new_axis attribute.
    combine = torch.stack if get_attribute(node, "new_axis", 0) else torch.cat

    def _combine(items: list, ax: int) -> torch.Tensor:
        return combine(items, dim=ax)

    return builder.call_function(_combine, args=(sequence, axis))
119
+
120
+
121
@register("SplitToSequence")
def split_to_sequence(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Split a tensor into a sequence (list) of tensors along ``axis``.

    The optional second input ``split`` is either a scalar (equal chunks of
    that size) or a 1-D tensor of per-chunk sizes. When ``split`` is absent,
    the tensor is cut into size-1 slices, squeezed when ``keepdims`` is 0.

    Fix: a non-tensor scalar ``split`` value previously fell into the
    ``[s]`` branch, which torch.split interprets as a single chunk of size
    ``s`` rather than equal-size chunks. The scalar is now kept as a scalar
    so the existing int path applies.
    """
    x = builder.get_value(node.input[0])
    split = get_optional_input(builder, node, 1)

    axis = get_attribute(node, "axis", 0)
    keepdims = get_attribute(node, "keepdims", 1)

    if split is not None:

        def _split_seq(t: torch.Tensor, s, ax: int, keep: int) -> list:
            # A 0-d tensor's tolist() yields a plain scalar; a 1-D tensor's
            # tolist() yields a list of sizes. A non-tensor s is passed
            # through so the scalar branch below normalizes it.
            sizes = s.tolist() if hasattr(s, "tolist") else s
            if isinstance(sizes, (int, float)):
                # Scalar split value: equal splits of size `sizes`.
                sizes = int(sizes)
            splits = list(torch.split(t, sizes, dim=ax))
            if not keep:
                # Squeeze only chunks whose split dimension is exactly 1.
                splits = [
                    chunk.squeeze(ax) if chunk.shape[ax] == 1 else chunk
                    for chunk in splits
                ]
            return splits

        return builder.call_function(_split_seq, args=(x, split, axis, keepdims))
    else:

        def _split_ones(t: torch.Tensor, ax: int, keep: int) -> list:
            # Default: size-1 slices along the axis.
            splits = list(torch.split(t, 1, dim=ax))
            if not keep:
                splits = [chunk.squeeze(ax) for chunk in splits]
            return splits

        return builder.call_function(_split_ones, args=(x, axis, keepdims))
156
+
157
+
158
@register("ReverseSequence")
def reverse_sequence(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Reverse sequences in a tensor.

    For each batch entry ``i``, the first ``sequence_lens[i]`` steps along
    ``time_axis`` are reversed in place; the remaining (padding) steps are
    left untouched. Per the ONNX spec, ``batch_axis`` and ``time_axis`` are
    expected to be 0 or 1 and must differ.
    """
    x = builder.get_value(node.input[0])
    sequence_lens = builder.get_value(node.input[1])

    batch_axis = get_attribute(node, "batch_axis", 1)
    time_axis = get_attribute(node, "time_axis", 0)

    def _reverse_sequence(x, sequence_lens, batch_axis, time_axis):
        # Work on a copy so the input tensor is never mutated.
        result = x.clone()
        for i, seq_len in enumerate(sequence_lens):
            # sequence_lens may be a tensor (elements are 0-d tensors) or a
            # plain iterable of ints.
            seq_len_val = int(
                seq_len.item() if isinstance(seq_len, torch.Tensor) else seq_len
            )
            if seq_len_val <= 1:
                # Nothing to reverse for length 0 or 1
                continue

            # Create indices for this batch: integer index on batch_axis,
            # a length-limited slice on time_axis, full slices elsewhere.
            idx = [slice(None)] * x.dim()
            idx[batch_axis] = i
            idx[time_axis] = slice(None, seq_len_val)

            # Extract the subsequence, reverse it along time_axis, and put it back
            subsequence = x[tuple(idx)]
            # Use torch.flip to reverse along the time_axis dimension
            # Since we've already indexed by batch_axis, the time_axis in the
            # subsequence is shifted if batch_axis < time_axis
            # (integer indexing drops the batch dimension).
            effective_time_axis = time_axis if batch_axis > time_axis else time_axis - 1
            reversed_subseq = torch.flip(subsequence, dims=[effective_time_axis])
            result[tuple(idx)] = reversed_subseq

        return result

    return builder.call_function(
        _reverse_sequence, args=(x, sequence_lens, batch_axis, time_axis)
    )
196
+
197
+
198
@register("Optional", since_version=15)
def optional_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Create an optional value, wrapping the input when one is present.

    Representation used throughout this converter:
      - optional(tensor):   [tensor]
      - optional(sequence): [[tensors...]]
      - empty optional:     [] or None
    """
    has_input = len(node.input) > 0 and bool(node.input[0])
    if not has_input:
        # No (or empty-named) input: emit an empty optional as an empty list.
        return builder.call_function(list, args=())

    # Wrap the value in a single-element list to mark "optional with value".
    wrapped = builder.get_value(node.input[0])
    return builder.call_function(lambda v: [v], args=(wrapped,))
221
+
222
+
223
@register("OptionalHasElement", since_version=15)
def optional_has_element(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Emit a bool tensor saying whether an optional carries a value.

    An optional counts as empty when it is None or when it is the
    empty-list encoding of an empty optional.
    """
    maybe = get_optional_input(builder, node, 0)

    if maybe is None:
        # The node has no usable input at all: statically empty optional.
        return builder.call_function(lambda: torch.tensor(False, dtype=torch.bool))

    def _check(opt):
        # List encoding (used in ONNX test data): non-empty <=> has value.
        if isinstance(opt, list):
            return torch.tensor(len(opt) > 0, dtype=torch.bool)
        # Otherwise: present unless it is literally None.
        return torch.tensor(opt is not None, dtype=torch.bool)

    return builder.call_function(_check, args=(maybe,))
248
+
249
+
250
@register("OptionalGetElement", since_version=15)
def optional_get_element(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Unwrap the value held by an optional.

    Raises at graph-run time if the optional is empty. Handles both None
    and list encodings of optionals:
      - optional(tensor):   [tensor]        -> tensor
      - optional(sequence): [[t1, t2, ...]] -> [t1, t2, ...]
      - plain sequence:     [t1, t2, ...]   -> returned unchanged
    """
    input_name = node.input[0]
    maybe = builder.get_value(input_name)

    # The model's declared type decides whether this input is a true
    # optional that needs unwrapping, or a plain sequence used directly
    # as the element of an implicit optional.
    if not builder.is_optional_type(input_name):

        def _passthrough(seq):
            # Sequence-typed input: the sequence itself is the element.
            if seq is None:
                raise ValueError("Cannot get element from empty optional")
            return seq

        return builder.call_function(_passthrough, args=(maybe,))

    def _unwrap(opt):
        # Optional-typed input. Inside ONNX Loops the optional type can be
        # refined to its element type after the first iteration, so a value
        # that is not a length-1 wrapper must pass through unchanged.
        if opt is None:
            raise ValueError("Cannot get element from empty optional")
        if not isinstance(opt, list):
            # Already a bare value (tensor, etc.) - nothing to unwrap.
            return opt
        if len(opt) == 0:
            raise ValueError("Cannot get element from empty optional")
        if len(opt) == 1:
            # Canonical wrapper [value]: return the wrapped value.
            return opt[0]
        # Length > 1: a plain sequence after loop refinement; already
        # unwrapped, so return as-is.
        return opt

    return builder.call_function(_unwrap, args=(maybe,))