onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,947 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Control flow operators.
|
|
3
|
+
|
|
4
|
+
This module implements ONNX control flow operators like Loop and If.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import onnx
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
import torch.fx
|
|
13
|
+
from onnx import numpy_helper
|
|
14
|
+
|
|
15
|
+
from ..utils.names import sanitize_name
|
|
16
|
+
from ..op_registry import register
|
|
17
|
+
from ..utils.attributes import get_attribute
|
|
18
|
+
from ..utils.op_helpers import get_optional_input
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from ..graph_builder import GraphBuilder
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _collect_all_subgraph_inputs(node: onnx.NodeProto) -> set:
    """Gather every value name a node reads, looking through nested subgraphs.

    Besides the node's direct inputs, any GRAPH-typed attribute (as used by
    control flow nodes such as If and Loop) is walked recursively. A name
    read inside such a subgraph counts as an input of ``node`` unless the
    subgraph defines it itself (as an initializer, a subgraph input, or the
    output of a subgraph node visited so far).

    Parameters
    ----------
    node : onnx.NodeProto
        The ONNX node to inspect.

    Returns
    -------
    set
        Names referenced by the node and by its nested subgraphs.
    """
    referenced = set(node.input)

    for attribute in node.attribute:
        # Only graph-valued attributes can carry nested references.
        if attribute.type != onnx.AttributeProto.GRAPH:
            continue
        body = attribute.g

        # Names the subgraph provides on its own.
        defined_here = {initializer.name for initializer in body.initializer}
        defined_here.update(value.name for value in body.input)

        for inner in body.node:
            needed = _collect_all_subgraph_inputs(inner)
            # The node's outputs become local before its inputs are checked.
            defined_here.update(name for name in inner.output if name)
            # Whatever is still unresolved must come from an enclosing scope.
            referenced.update(
                name for name in needed if name and name not in defined_here
            )

    return referenced
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _build_subgraph_module(
    body_graph: onnx.GraphProto,
    parent_env: Dict[str, torch.fx.Node],
    parent_opset_versions: Dict[str, int],
    parent_type_info: Optional[Dict[str, bool]] = None,
    tensor_loader: Optional[Callable[[onnx.TensorProto], torch.Tensor]] = None,
) -> Tuple[torch.fx.GraphModule, List[str], List[str], List[str]]:
    """Build an FX GraphModule from an ONNX subgraph.

    The resulting module takes one positional argument per subgraph input
    (inputs that share a name with a subgraph initializer are skipped),
    followed by one argument per outer-scope reference, in the order given
    by the returned ``outer_refs`` list.

    Parameters
    ----------
    body_graph : onnx.GraphProto
        The ONNX subgraph to convert.
    parent_env : Dict[str, torch.fx.Node]
        Environment from parent graph (for accessing outer scope values).
        Only its keys are consulted here.
    parent_opset_versions : Dict[str, int]
        Opset versions from parent model, keyed by domain.
    parent_type_info : Optional[Dict[str, bool]]
        Mapping from value names to whether they are optional types in parent
        scope. ``None`` is treated as an empty mapping.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Callable used to turn ONNX tensors into torch tensors. When ``None``,
        tensors are converted with ``numpy_helper.to_array`` and copied.

    Returns
    -------
    Tuple[torch.fx.GraphModule, List[str], List[str], List[str]]
        The FX GraphModule, list of input names, list of output names, and outer refs.

    Raises
    ------
    ValueError
        If no handler is registered for an operator in the subgraph.
    KeyError
        If a declared subgraph output was never produced.
    """
    from ..op_registry import get_handler

    graph = torch.fx.Graph()
    env: Dict[str, torch.fx.Node] = {}
    constants: Dict[str, torch.Tensor] = {}

    # Initialize parent type info
    if parent_type_info is None:
        parent_type_info = {}

    # Get input and output names
    input_names = [inp.name for inp in body_graph.input]
    output_names = [out.name for out in body_graph.output]

    # Load initializers from subgraph
    initializer_map: Dict[str, torch.Tensor] = {}
    for initializer in body_graph.initializer:
        if tensor_loader is not None:
            initializer_map[initializer.name] = tensor_loader(initializer)
        else:
            np_array = numpy_helper.to_array(initializer)
            initializer_map[initializer.name] = torch.from_numpy(np_array.copy())

    # Register initializers as constants (they become buffers on the module)
    for name, tensor in initializer_map.items():
        safe_name = sanitize_name(name)
        constants[safe_name] = tensor
        fx_node = graph.get_attr(safe_name)
        env[name] = fx_node

    # Create placeholders for subgraph inputs
    for inp in body_graph.input:
        if inp.name in env:
            continue  # Skip if already loaded as initializer
        safe_name = sanitize_name(inp.name)
        placeholder = graph.placeholder(safe_name)
        env[inp.name] = placeholder

    # Collect all values that will be produced by nodes in this subgraph
    subgraph_outputs = set()
    for node in body_graph.node:
        for out in node.output:
            if out:
                subgraph_outputs.add(out)

    # Add references to parent scope values that are used in the subgraph
    # (including nested subgraphs). These will be passed as additional inputs.
    outer_refs: List[str] = []
    for node in body_graph.node:
        # Collect all inputs including from nested subgraphs. The helper
        # returns a set, whose iteration order for strings is not stable
        # across interpreter runs; sorting keeps the placeholder order (and
        # therefore the generated graph) reproducible.
        all_inputs = _collect_all_subgraph_inputs(node)
        for inp_name in sorted(all_inputs):
            # Skip if empty, already in env, or will be produced by a node in this subgraph
            if not inp_name or inp_name in env or inp_name in subgraph_outputs:
                continue
            if inp_name in parent_env:
                # Every recorded outer ref is also put in env, so the guard
                # above already prevents duplicates.
                outer_refs.append(inp_name)
                safe_name = sanitize_name(inp_name)
                placeholder = graph.placeholder(f"outer_{safe_name}")
                env[inp_name] = placeholder

    # Create a minimal builder-like object for handler calls
    class SubgraphBuilder:
        """Builder stand-in handed to op handlers while converting the subgraph."""

        def __init__(self):
            self.graph = graph
            self.env = env
            self._opset_versions = parent_opset_versions
            self._constants = constants
            self._submodules: Dict[str, nn.Module] = {}
            self.initializer_map = initializer_map
            self._body_graph = body_graph
            self._parent_type_info = parent_type_info
            self._tensor_loader = tensor_loader
            # Build type info for this subgraph (to pass to nested subgraphs)
            self._type_info = self._build_type_info()

        def load_tensor(self, tensor: onnx.TensorProto) -> torch.Tensor:
            """Convert an ONNX tensor, preferring the injected loader."""
            if self._tensor_loader is not None:
                return self._tensor_loader(tensor)
            np_array = numpy_helper.to_array(tensor)
            return torch.from_numpy(np_array.copy())

        def _build_type_info(self) -> Dict[str, bool]:
            """Build a mapping of value names to whether they are optional types."""
            info: Dict[str, bool] = {}
            # Include parent type info
            info.update(self._parent_type_info)
            # Add types from this subgraph (these override parent entries)
            for value_info in self._body_graph.input:
                info[value_info.name] = value_info.type.HasField("optional_type")
            for value_info in self._body_graph.value_info:
                info[value_info.name] = value_info.type.HasField("optional_type")
            for value_info in self._body_graph.output:
                info[value_info.name] = value_info.type.HasField("optional_type")
            return info

        @property
        def opset_version(self) -> int:
            """Opset version of the default (empty) domain, 1 if unknown."""
            return self._opset_versions.get("", 1)

        def is_optional_type(self, name: str) -> bool:
            """Check if a value has optional type in the subgraph or parent scope."""
            # First check the combined type info (includes parent scope)
            if name in self._type_info:
                return self._type_info[name]
            return False

        def get_opset_version(self, domain: str = "") -> int:
            """Opset version for ``domain``, 1 if the domain is unknown."""
            return self._opset_versions.get(domain, 1)

        def get_value(self, name: str) -> torch.fx.Node:
            """Return the FX node bound to ``name`` or raise KeyError."""
            if name not in self.env:
                raise KeyError(f"Value '{name}' not found in subgraph environment")
            return self.env[name]

        def has_value(self, name: str) -> bool:
            """Whether ``name`` is bound in the subgraph environment."""
            return name in self.env

        def call_function(
            self,
            func,
            args: tuple = (),
            kwargs: Optional[Dict[str, Any]] = None,
        ) -> torch.fx.Node:
            """Append a call_function node to the subgraph."""
            return self.graph.call_function(func, args=tuple(args), kwargs=kwargs or {})

        def register_submodule(self, name: str, module: nn.Module) -> str:
            """Store ``module`` under a unique sanitized name and return that name."""
            safe_name = sanitize_name(name)
            if safe_name in self._submodules:
                # Append the first free numeric suffix.
                counter = 0
                while f"{safe_name}_{counter}" in self._submodules:
                    counter += 1
                safe_name = f"{safe_name}_{counter}"
            self._submodules[safe_name] = module
            return safe_name

        def call_module(
            self,
            module_name: str,
            args: tuple = (),
            kwargs: Optional[Dict[str, Any]] = None,
        ) -> torch.fx.Node:
            """Append a call_module node to the subgraph."""
            return self.graph.call_module(
                module_name, args=tuple(args), kwargs=kwargs or {}
            )

    builder = SubgraphBuilder()

    # Convert nodes
    for node in body_graph.node:
        domain = node.domain if node.domain else ""
        opset = builder.get_opset_version(domain)
        handler = get_handler(node.op_type, domain, opset)
        if handler is None:
            raise ValueError(
                f"Unsupported operator in subgraph: {node.op_type} (domain={domain})"
            )
        fx_node = handler(builder, node)

        # Handle outputs
        if len(node.output) == 1:
            env[node.output[0]] = fx_node
        else:
            for i, output_name in enumerate(node.output):
                if output_name:
                    # The index is passed positionally as the second argument;
                    # a result that is not a tuple/list is forwarded unchanged.
                    getitem_node = graph.call_function(
                        lambda x, idx=i: x[idx] if isinstance(x, (tuple, list)) else x,
                        args=(fx_node, i),
                    )
                    env[output_name] = getitem_node

    # Create output - return tuple of output values
    output_nodes = []
    for out_name in output_names:
        if out_name in env:
            output_nodes.append(env[out_name])
        else:
            raise KeyError(f"Output '{out_name}' not found in subgraph environment")

    # A single output is returned bare rather than as a 1-tuple.
    if len(output_nodes) == 1:
        graph.output(output_nodes[0])
    else:
        graph.output(tuple(output_nodes))

    # Create the module
    root_module = nn.Module()
    for name, tensor in constants.items():
        root_module.register_buffer(name, tensor)
    for name, submod in builder._submodules.items():
        root_module.add_module(name, submod)

    module = torch.fx.GraphModule(root_module, graph)
    module.graph.lint()

    return module, input_names, output_names, outer_refs
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class LoopModule(nn.Module):
    """Runtime module that drives an ONNX Loop body."""

    def __init__(
        self,
        body_module: torch.fx.GraphModule,
        n_loop_carried: int,
        n_scan_outputs: int,
        n_loop_vars: int,
        n_outer_vars: int,
    ):
        super().__init__()
        self.body = body_module
        self.n_loop_carried = n_loop_carried
        self.n_scan_outputs = n_scan_outputs
        self.n_loop_vars = n_loop_vars
        self.n_outer_vars = n_outer_vars

    @staticmethod
    def _truthy(value) -> bool:
        """Reduce a condition (anything with ``.item()`` or a plain value) to bool."""
        if hasattr(value, "item"):
            return bool(value.item())
        return bool(value)

    def forward(self, max_iters, init_cond, *args) -> Tuple[torch.Tensor, ...]:
        """Run the loop body until the trip count or the condition stops it.

        ``args`` holds the initial loop-carried values followed by the outer
        scope values. The result is the final loop-carried values followed by
        one stacked tensor per scan output (an empty tensor when the body
        never ran).
        """
        # First n_loop_vars args are loop-carried, the rest are outer values.
        carried = list(args[: self.n_loop_vars])
        captured = list(args[self.n_loop_vars :])

        # Trip count: a missing M means "effectively unbounded".
        if max_iters is None:
            trip_limit = 2**63 - 1
        elif hasattr(max_iters, "item"):
            trip_limit = int(max_iters.item())
        else:
            trip_limit = int(max_iters)

        # A missing initial condition means "start running".
        keep_going = True if init_cond is None else self._truthy(init_cond)

        # One accumulator list per scan output.
        collected: List[List[torch.Tensor]] = [
            [] for _ in range(self.n_scan_outputs)
        ]
        scan_start = 1 + self.n_loop_carried

        step = 0
        while step < trip_limit and keep_going:
            # Body signature: iteration_num, condition, loop_carried..., outer...
            step_tensor = torch.tensor(step, dtype=torch.int64)
            flag_tensor = torch.tensor(keep_going, dtype=torch.bool)
            results = self.body(step_tensor, flag_tensor, *carried, *captured)

            # A lone result is normalised to a 1-tuple.
            if not isinstance(results, tuple):
                results = (results,)

            # Layout: condition, loop_carried..., scan_outputs...
            keep_going = self._truthy(results[0])
            carried = list(results[1:scan_start])
            for offset, bucket in enumerate(collected):
                bucket.append(results[scan_start + offset])

            step += 1

        # Stack each scan output along a new leading dimension.
        stacked = [
            torch.stack(bucket, dim=0) if bucket else torch.tensor([])
            for bucket in collected
        ]
        return tuple(carried + stacked)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
class IfModule(nn.Module):
    """Runtime module that dispatches to one of two ONNX If branches."""

    def __init__(
        self,
        then_module: torch.fx.GraphModule,
        else_module: torch.fx.GraphModule,
        n_then_outer: int,
        n_else_outer: int,
    ):
        super().__init__()
        self.then_branch = then_module
        self.else_branch = else_module
        self.n_then_outer = n_then_outer
        self.n_else_outer = n_else_outer

    def forward(self, condition, *args) -> Any:
        """Evaluate the condition and run the matching branch.

        ``args`` carries the outer values for the then-branch first
        (``n_then_outer`` of them), followed by those for the else-branch.
        Whatever the chosen branch returns is passed through.
        """
        boundary = self.n_then_outer

        # Accept anything exposing .item() as well as plain Python values.
        if hasattr(condition, "item"):
            take_then = bool(condition.item())
        else:
            take_then = bool(condition)

        if take_then:
            return self.then_branch(*args[:boundary])
        return self.else_branch(*args[boundary:])
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@register("Loop")
def loop_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """ONNX Loop operator.

    Loop has inputs: M (max trip count), cond (initial condition), v_initial... (loop-carried deps)
    Loop has attribute: body (GraphProto)
    Body inputs: iteration_num, condition, loop_carried_deps...
    Body outputs: condition, loop_carried_deps..., scan_outputs...

    The body is converted to an FX GraphModule, wrapped in a ``LoopModule``,
    registered on the builder, and invoked through a ``call_module`` node
    whose arguments are ``M, cond, loop-carried values..., outer values...``.

    Parameters
    ----------
    builder : GraphBuilder
        Builder for the graph that contains the Loop node.
    node : onnx.NodeProto
        The Loop node.

    Returns
    -------
    torch.fx.Node
        The ``call_module`` node for the registered loop module.

    Raises
    ------
    ValueError
        If the node has no ``body`` attribute.
    """
    # Get body subgraph
    body_graph = get_attribute(node, "body")
    if body_graph is None:
        raise ValueError("Loop operator requires 'body' attribute")

    # Get inputs. M is optional: it may be an empty name or, when all inputs
    # are omitted, missing altogether, so the length is checked before indexing.
    has_trip_count = len(node.input) > 0 and node.input[0] != ""
    max_trip_count = builder.get_value(node.input[0]) if has_trip_count else None
    initial_cond = get_optional_input(builder, node, 1)
    loop_carried_inputs = [builder.get_value(name) for name in node.input[2:]]

    # Get parent type info if available (for nested subgraphs)
    parent_type_info = getattr(builder, "_type_info", None)

    # Build subgraph module
    body_module, body_input_names, body_output_names, outer_refs = (
        _build_subgraph_module(
            body_graph,
            builder.env,
            builder._opset_versions,
            parent_type_info,
            tensor_loader=builder.load_tensor,
        )
    )

    # Get outer scope values that the body references
    outer_values = [builder.get_value(name) for name in outer_refs]

    # Number of loop-carried dependencies (excluding iteration_num and condition in body inputs)
    n_loop_carried = len(body_input_names) - 2
    # Number of scan outputs (body outputs minus condition and loop-carried values)
    n_scan_outputs = len(body_output_names) - 1 - n_loop_carried

    # Create the loop module
    loop_module = LoopModule(
        body_module,
        n_loop_carried,
        n_scan_outputs,
        len(loop_carried_inputs),
        len(outer_values),
    )

    # Register the loop module
    module_name = builder.register_submodule(f"loop_{node.name or 'op'}", loop_module)

    # Build flat args for call_module: max_iters, init_cond, loop_vars..., outer_vals...
    args = [max_trip_count, initial_cond] + loop_carried_inputs + outer_values

    result = builder.call_module(module_name, args=tuple(args))

    return result
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
class ScanModule(nn.Module):
    """Runtime module that drives an ONNX Scan body."""

    def __init__(
        self,
        body_module: torch.fx.GraphModule,
        n_state_vars: int,
        n_scan_inputs: int,
        n_scan_outputs: int,
        n_outer_vars: int,
        scan_input_axes: List[int],
        scan_output_axes: List[int],
        scan_input_directions: List[int],
        scan_output_directions: List[int],
    ):
        super().__init__()
        self.body = body_module
        self.n_state_vars = n_state_vars
        self.n_scan_inputs = n_scan_inputs
        self.n_scan_outputs = n_scan_outputs
        self.n_outer_vars = n_outer_vars
        self.scan_input_axes = scan_input_axes
        self.scan_output_axes = scan_output_axes
        self.scan_input_directions = scan_input_directions
        self.scan_output_directions = scan_output_directions

    @staticmethod
    def _setting(values: List[int], position: int) -> int:
        """Return ``values[position]``, or 0 when the list is too short."""
        return values[position] if position < len(values) else 0

    def forward(self, *args) -> Tuple[torch.Tensor, ...]:
        """Run the scan body once per element along the scan axis.

        ``args`` is laid out as state variables, then scan inputs, then outer
        scope values. The result is the final state variables followed by one
        tensor per scan output (an empty tensor when no iteration ran).
        """
        # Carve the flat argument tuple into its three groups.
        scan_begin = self.n_state_vars
        outer_begin = scan_begin + self.n_scan_inputs
        state = list(args[:scan_begin])
        sequences = list(args[scan_begin:outer_begin])
        captured = list(args[outer_begin:])

        # The first scan input fixes the number of iterations.
        if self.n_scan_inputs > 0:
            leading = sequences[0]
            length_axis = self.scan_input_axes[0] if self.scan_input_axes else 0
            steps = leading.shape[length_axis]
        else:
            steps = 0

        # One accumulator list per scan output.
        collected: List[List[torch.Tensor]] = [
            [] for _ in range(self.n_scan_outputs)
        ]

        for step in range(steps):
            # Slice the current element out of every scan input.
            slices = []
            for position, sequence in enumerate(sequences):
                axis = self._setting(self.scan_input_axes, position)
                direction = self._setting(self.scan_input_directions, position)
                # Direction 1 walks the axis back to front.
                index = steps - 1 - step if direction == 1 else step
                slices.append(torch.select(sequence, axis, index))

            # Body signature: state_vars..., scan_input_elts..., outer_vals...
            results = self.body(*state, *slices, *captured)
            if not isinstance(results, tuple):
                results = (results,)

            # Leading results replace the state; the rest feed the scan outputs.
            state = list(results[: self.n_state_vars])
            for offset, bucket in enumerate(collected):
                bucket.append(results[self.n_state_vars + offset])

        final_outputs = list(state)
        for position, bucket in enumerate(collected):
            if not bucket:
                # Nothing was produced for this scan output.
                final_outputs.append(torch.tensor([]))
                continue
            axis = self._setting(self.scan_output_axes, position)
            direction = self._setting(self.scan_output_directions, position)
            merged = torch.stack(bucket, dim=axis)
            # Direction 1 means prepending, i.e. reversed order along the axis.
            if direction == 1:
                merged = torch.flip(merged, dims=[axis])
            final_outputs.append(merged)

        return tuple(final_outputs)
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
@register("Scan", since_version=9)
def scan_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Convert an ONNX Scan node (opset 9 and later).

    Scan walks over one or more scan_input tensors and builds scan_output
    tensors, threading state variables from one iteration to the next.

    Inputs: N state variables followed by M scan inputs (variadic).
    Outputs: N final states followed by K scan outputs (variadic).

    Attributes read here:
    - body: graph executed on every iteration (required)
    - num_scan_inputs: the number M of scan inputs (required)
    - scan_input_axes / scan_output_axes: per-tensor axis (default 0)
    - scan_input_directions: 0 = forward, 1 = reverse
    - scan_output_directions: 0 = append, 1 = prepend
    """
    body_graph = get_attribute(node, "body")
    if body_graph is None:
        raise ValueError("Scan operator requires 'body' attribute")

    num_scan_inputs = get_attribute(node, "num_scan_inputs")
    if num_scan_inputs is None:
        raise ValueError("Scan operator requires 'num_scan_inputs' attribute")

    # Optional per-tensor settings; a missing attribute becomes an empty list.
    input_axes = get_attribute(node, "scan_input_axes") or []
    input_directions = get_attribute(node, "scan_input_directions") or []
    output_axes = get_attribute(node, "scan_output_axes") or []
    output_directions = get_attribute(node, "scan_output_directions") or []

    # Everything before the trailing num_scan_inputs inputs is state.
    total_inputs = len(node.input)
    n_state_vars = total_inputs - num_scan_inputs
    state_nodes = [
        builder.get_value(node.input[position]) for position in range(n_state_vars)
    ]
    sequence_nodes = [
        builder.get_value(node.input[position])
        for position in range(n_state_vars, total_inputs)
    ]

    # Nested subgraphs inherit the builder's type info when it has any.
    inherited_types = getattr(builder, "_type_info", None)

    body_module, _, body_output_names, outer_refs = _build_subgraph_module(
        body_graph,
        builder.env,
        builder._opset_versions,
        inherited_types,
        tensor_loader=builder.load_tensor,
    )

    # Values the body captures from the enclosing scope.
    captured_nodes = [builder.get_value(ref) for ref in outer_refs]

    scan_module = ScanModule(
        body_module,
        n_state_vars,
        num_scan_inputs,
        # Body outputs beyond the state variables are the scan outputs.
        len(body_output_names) - n_state_vars,
        len(captured_nodes),
        list(input_axes),
        list(output_axes),
        list(input_directions),
        list(output_directions),
    )

    module_name = builder.register_submodule(f"scan_{node.name or 'op'}", scan_module)

    # Argument layout: state_vars..., scan_inputs..., outer_vals...
    call_args = tuple(state_nodes + sequence_nodes + captured_nodes)
    return builder.call_module(module_name, args=call_args)
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
@register("Scan", since_version=8)
def scan_op_v8(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """ONNX Scan operator (version 8).

    Version 8 has batching support and different input format:
    - First input is optional sequence_lens
    - Requires batch dimension (axis 0) and sequence dimension (axis 1)

    For simplicity, this implementation handles the common case where
    sequence_lens is empty (all sequences have same length). When
    sequence_lens is supplied it is ignored: the first input slot is skipped
    either way and every sequence is treated as full length.

    Parameters
    ----------
    builder : GraphBuilder
        Builder for the graph that contains the Scan node.
    node : onnx.NodeProto
        The Scan node.

    Returns
    -------
    torch.fx.Node
        The ``call_module`` node for the registered scan module.

    Raises
    ------
    ValueError
        If the ``body`` or ``num_scan_inputs`` attribute is missing.
    """
    # Get body subgraph
    body_graph = get_attribute(node, "body")
    if body_graph is None:
        raise ValueError("Scan operator requires 'body' attribute")

    # Get num_scan_inputs attribute (required)
    num_scan_inputs = get_attribute(node, "num_scan_inputs")
    if num_scan_inputs is None:
        raise ValueError("Scan operator requires 'num_scan_inputs' attribute")

    # Get optional directions attribute (v8 uses 'directions' instead of scan_input_directions)
    directions = get_attribute(node, "directions") or []

    # In v8, input 0 is the optional sequence_lens slot. Whether it is empty
    # or populated, it is not a state variable or scan input, so it is always
    # skipped; a populated sequence_lens is not honoured (fixed length assumed).
    remaining_inputs = list(node.input[1:])

    # Remaining inputs: state variables followed by scan inputs
    n_state_vars = len(remaining_inputs) - num_scan_inputs

    state_inputs = [builder.get_value(remaining_inputs[i]) for i in range(n_state_vars)]
    scan_inputs = [
        builder.get_value(remaining_inputs[i])
        for i in range(n_state_vars, len(remaining_inputs))
    ]

    # Get parent type info if available (for nested subgraphs)
    parent_type_info = getattr(builder, "_type_info", None)

    # Build subgraph module
    body_module, body_input_names, body_output_names, outer_refs = (
        _build_subgraph_module(
            body_graph,
            builder.env,
            builder._opset_versions,
            parent_type_info,
            tensor_loader=builder.load_tensor,
        )
    )

    # Get outer scope values that the body references
    outer_values = [builder.get_value(name) for name in outer_refs]

    # Number of scan outputs (body outputs beyond the state variables)
    n_scan_outputs = len(body_output_names) - n_state_vars

    # Create the scan module for v8 (handles batching)
    # Note: In v8, batch axis is 0 and sequence axis is 1 (handled in ScanModuleV8)
    scan_module = ScanModuleV8(
        body_module,
        n_state_vars,
        num_scan_inputs,
        n_scan_outputs,
        len(outer_values),
        list(directions),
    )

    # Register the scan module
    module_name = builder.register_submodule(
        f"scan_v8_{node.name or 'op'}", scan_module
    )

    # Build args: state_vars..., scan_inputs..., outer_vals...
    args = state_inputs + scan_inputs + outer_values

    result = builder.call_module(module_name, args=tuple(args))

    return result
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
class ScanModuleV8(nn.Module):
    """Eager executor for the batched (opset 8) form of the ONNX Scan operator.

    The wrapped body is called once per (batch entry, sequence step) pair and
    the per-step results are reassembled into batched tensors afterwards.
    """

    def __init__(
        self,
        body_module: torch.fx.GraphModule,
        n_state_vars: int,
        n_scan_inputs: int,
        n_scan_outputs: int,
        n_outer_vars: int,
        directions: List[int],
    ):
        """Remember the loop body and how the flat argument list is laid out.

        Args:
            body_module: Callable executed for every step; it receives the
                current state values, one slice per scan input, and the outer
                values, in that order.
            n_state_vars: Number of leading arguments that are loop state.
            n_scan_inputs: Number of arguments following the state that are
                scanned along their second axis.
            n_scan_outputs: Number of body outputs, after the state outputs,
                that are collected per step.
            n_outer_vars: Number of trailing outer-scope values (stored only;
                ``forward`` treats every remaining argument as an outer value).
            directions: Per scan input flag; ``1`` walks the sequence
                backwards, any other value (or a missing entry) walks forwards.
        """
        super().__init__()
        self.body = body_module
        self.n_state_vars = n_state_vars
        self.n_scan_inputs = n_scan_inputs
        self.n_scan_outputs = n_scan_outputs
        self.n_outer_vars = n_outer_vars
        self.directions = directions

    def _time_index(self, input_pos: int, step: int, length: int) -> int:
        """Translate a loop step into a sequence position for one scan input."""
        flag = self.directions[input_pos] if input_pos < len(self.directions) else 0
        if flag == 1:
            return length - 1 - step
        return step

    def forward(self, *args) -> Tuple[torch.Tensor, ...]:
        """Run the scan independently for every batch entry.

        Scan inputs are indexed as ``[batch, sequence, ...]`` and state
        variables as ``[batch, ...]``.

        Args:
            *args: ``state_vars..., scan_inputs..., outer_vals...`` as one
                flat positional list.

        Returns:
            The final state tensors (stacked over the batch) followed by the
            scan outputs (stacked over steps, then over the batch). A scan
            output for which no step produced a value is left out of the tuple.
        """
        n_state = self.n_state_vars
        scan_end = n_state + self.n_scan_inputs

        # Carve the flat argument list into its three groups.
        initial_states = list(args[:n_state])
        sequences = list(args[n_state:scan_end])
        captured = list(args[scan_end:])

        # Loop extents come from the first scan input; without scan inputs
        # there is nothing to iterate over, so the sequence length is zero.
        if self.n_scan_inputs > 0:
            leading = sequences[0]
            batch_size = leading.shape[0]
            sequence_length = leading.shape[1]
        elif initial_states:
            batch_size = initial_states[0].shape[0]
            sequence_length = 0
        else:
            batch_size = 1
            sequence_length = 0

        # One bucket per state variable / scan output, filled batch by batch.
        per_batch_states: List[List[torch.Tensor]] = [[] for _ in range(n_state)]
        per_batch_scans: List[List[torch.Tensor]] = [
            [] for _ in range(self.n_scan_outputs)
        ]

        for b in range(batch_size):
            carried = [state[b] for state in initial_states]
            collected: List[List[torch.Tensor]] = [
                [] for _ in range(self.n_scan_outputs)
            ]

            for step in range(sequence_length):
                # One element per scan input for this batch entry and step.
                slices = [
                    seq[b, self._time_index(pos, step, sequence_length)]
                    for pos, seq in enumerate(sequences)
                ]

                result = self.body(*carried, *slices, *captured)
                if not isinstance(result, tuple):
                    result = (result,)

                # Leading outputs become the next state; the rest are emitted.
                carried = list(result[:n_state])
                for slot, bucket in enumerate(collected):
                    bucket.append(result[n_state + slot])

            for slot, value in enumerate(carried):
                per_batch_states[slot].append(value)

            for slot, frames in enumerate(collected):
                if frames:
                    per_batch_scans[slot].append(torch.stack(frames, dim=0))

        # Re-introduce the batch axis in front of everything gathered above.
        results: List[torch.Tensor] = [
            torch.stack(states, dim=0) for states in per_batch_states
        ]
        results.extend(
            torch.stack(chunks, dim=0) for chunks in per_batch_scans if chunks
        )
        return tuple(results)
|
|
882
|
+
|
|
883
|
+
|
|
884
|
+
@register("If")
def if_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Lower an ONNX ``If`` node to a call on a registered ``IfModule``.

    The node's first input is the condition; its ``then_branch`` and
    ``else_branch`` attributes hold the subgraphs for the two outcomes.
    Each subgraph is converted to its own module, and the outer-scope values
    each one references are passed to the ``IfModule`` as flat arguments.

    Args:
        builder: Graph builder that resolves values and registers submodules.
        node: The ONNX ``If`` node being converted.

    Returns:
        The result of ``builder.call_module`` for the registered if-module.

    Raises:
        ValueError: If either branch attribute is missing.
    """
    predicate = builder.get_value(node.input[0])

    then_graph = get_attribute(node, "then_branch")
    else_graph = get_attribute(node, "else_branch")
    for branch_graph in (then_graph, else_graph):
        if branch_graph is None:
            raise ValueError(
                "If operator requires 'then_branch' and 'else_branch' attributes"
            )

    # Type info from the enclosing builder, when it has any (nested subgraphs).
    inherited_types = getattr(builder, "_type_info", None)

    def _lower_branch(graph):
        """Convert one branch subgraph; return its module and outer references."""
        module, _input_names, _output_names, outer_refs = _build_subgraph_module(
            graph,
            builder.env,
            builder._opset_versions,
            inherited_types,
            tensor_loader=builder.load_tensor,
        )
        return module, outer_refs

    # Both branches are converted before any outer value is looked up.
    then_module, then_refs = _lower_branch(then_graph)
    else_module, else_refs = _lower_branch(else_graph)

    then_captured = [builder.get_value(ref) for ref in then_refs]
    else_captured = [builder.get_value(ref) for ref in else_refs]

    dispatcher = IfModule(
        then_module,
        else_module,
        len(then_captured),
        len(else_captured),
    )
    target = builder.register_submodule(f"if_{node.name or 'op'}", dispatcher)

    # Flat call arguments: condition, then-branch outers, else-branch outers.
    return builder.call_module(
        target, args=(predicate, *then_captured, *else_captured)
    )
|