onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,947 @@
# SPDX-License-Identifier: Apache-2.0
"""Control flow operators.

This module implements the ONNX control flow operators Loop, Scan, and If.
"""

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import onnx
import torch
import torch.nn as nn
import torch.fx
from onnx import numpy_helper

from ..utils.names import sanitize_name
from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import get_optional_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


def _collect_all_subgraph_inputs(node: onnx.NodeProto) -> set:
    """Recursively collect all inputs from a node, including nested subgraph inputs.

    For control flow nodes such as If and Loop, this also collects inputs that
    are referenced by nested subgraphs.

    Parameters
    ----------
    node : onnx.NodeProto
        The ONNX node to collect inputs from.

    Returns
    -------
    set
        Set of all input names referenced by this node and its nested subgraphs.
    """
    inputs = set(node.input)

    # Collect inputs from subgraphs (for If, Loop, etc.)
    for attr in node.attribute:
        if attr.type == onnx.AttributeProto.GRAPH:
            subgraph = attr.g
            # Collect subgraph's own initializers and inputs as local values
            local_values = set()
            for init in subgraph.initializer:
                local_values.add(init.name)
            for inp in subgraph.input:
                local_values.add(inp.name)

            # Recursively collect inputs from subgraph nodes
            for sub_node in subgraph.node:
                sub_inputs = _collect_all_subgraph_inputs(sub_node)
                # Add outputs of this subgraph node to local values
                for out in sub_node.output:
                    if out:
                        local_values.add(out)
                # Inputs not satisfied locally are outer references
                for sub_inp in sub_inputs:
                    if sub_inp and sub_inp not in local_values:
                        inputs.add(sub_inp)

    return inputs
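# Usage sketch (illustrative; the names below are made up for the example):
# for an If node whose then-branch reads "x" from the enclosing scope, "x" is
# reported alongside the node's direct inputs:
#
#     import onnx.helper as oh
#     then_g = oh.make_graph(
#         [oh.make_node("Identity", ["x"], ["y"])], "then", [],
#         [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [])],
#     )
#     if_node = oh.make_node("If", ["cond"], ["out"],
#                            then_branch=then_g, else_branch=then_g)
#     assert _collect_all_subgraph_inputs(if_node) == {"cond", "x"}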


def _build_subgraph_module(
    body_graph: onnx.GraphProto,
    parent_env: Dict[str, torch.fx.Node],
    parent_opset_versions: Dict[str, int],
    parent_type_info: Optional[Dict[str, bool]] = None,
    tensor_loader: Optional[Callable[[onnx.TensorProto], torch.Tensor]] = None,
) -> Tuple[torch.fx.GraphModule, List[str], List[str], List[str]]:
    """Build an FX GraphModule from an ONNX subgraph.

    Parameters
    ----------
    body_graph : onnx.GraphProto
        The ONNX subgraph to convert.
    parent_env : Dict[str, torch.fx.Node]
        Environment from parent graph (for accessing outer scope values).
    parent_opset_versions : Dict[str, int]
        Opset versions from parent model.
    parent_type_info : Optional[Dict[str, bool]]
        Mapping from value names to whether they are optional types in parent scope.
    tensor_loader : Optional[Callable[[onnx.TensorProto], torch.Tensor]]
        Callback used to materialize initializers; defaults to ``numpy_helper.to_array``.

    Returns
    -------
    Tuple[torch.fx.GraphModule, List[str], List[str], List[str]]
        The FX GraphModule, list of input names, list of output names, and outer refs.
    """
    from ..op_registry import get_handler

    graph = torch.fx.Graph()
    env: Dict[str, torch.fx.Node] = {}
    constants: Dict[str, torch.Tensor] = {}

    # Initialize parent type info
    if parent_type_info is None:
        parent_type_info = {}

    # Get input and output names
    input_names = [inp.name for inp in body_graph.input]
    output_names = [out.name for out in body_graph.output]

    # Load initializers from subgraph
    initializer_map: Dict[str, torch.Tensor] = {}
    for initializer in body_graph.initializer:
        if tensor_loader is not None:
            initializer_map[initializer.name] = tensor_loader(initializer)
        else:
            np_array = numpy_helper.to_array(initializer)
            initializer_map[initializer.name] = torch.from_numpy(np_array.copy())

    # Register initializers as constants
    for name, tensor in initializer_map.items():
        safe_name = sanitize_name(name)
        constants[safe_name] = tensor
        fx_node = graph.get_attr(safe_name)
        env[name] = fx_node

    # Create placeholders for subgraph inputs
    for inp in body_graph.input:
        if inp.name in env:
            continue  # Skip if already loaded as initializer
        safe_name = sanitize_name(inp.name)
        placeholder = graph.placeholder(safe_name)
        env[inp.name] = placeholder

    # Collect all values that will be produced by nodes in this subgraph
    subgraph_outputs = set()
    for node in body_graph.node:
        for out in node.output:
            if out:
                subgraph_outputs.add(out)

    # Add references to parent scope values that are used in the subgraph
    # (including nested subgraphs). These will be passed as additional inputs.
    outer_refs: List[str] = []
    for node in body_graph.node:
        # Collect all inputs including from nested subgraphs
        all_inputs = _collect_all_subgraph_inputs(node)
        for inp_name in all_inputs:
            # Skip if empty, already in env, or will be produced by a node in this subgraph
            if not inp_name or inp_name in env or inp_name in subgraph_outputs:
                continue
            if inp_name in parent_env:
                if inp_name not in outer_refs:  # Avoid duplicates
                    outer_refs.append(inp_name)
                    safe_name = sanitize_name(inp_name)
                    placeholder = graph.placeholder(f"outer_{safe_name}")
                    env[inp_name] = placeholder

    # Create a minimal builder-like object for handler calls
    class SubgraphBuilder:
        def __init__(self):
            self.graph = graph
            self.env = env
            self._opset_versions = parent_opset_versions
            self._constants = constants
            self._submodules: Dict[str, nn.Module] = {}
            self.initializer_map = initializer_map
            self._body_graph = body_graph
            self._parent_type_info = parent_type_info
            self._tensor_loader = tensor_loader
            # Build type info for this subgraph (to pass to nested subgraphs)
            self._type_info = self._build_type_info()

        def load_tensor(self, tensor: onnx.TensorProto) -> torch.Tensor:
            if self._tensor_loader is not None:
                return self._tensor_loader(tensor)
            np_array = numpy_helper.to_array(tensor)
            return torch.from_numpy(np_array.copy())

        def _build_type_info(self) -> Dict[str, bool]:
            """Build a mapping of value names to whether they are optional types."""
            info: Dict[str, bool] = {}
            # Include parent type info
            info.update(self._parent_type_info)
            # Add types from this subgraph
            for value_info in self._body_graph.input:
                info[value_info.name] = value_info.type.HasField("optional_type")
            for value_info in self._body_graph.value_info:
                info[value_info.name] = value_info.type.HasField("optional_type")
            for value_info in self._body_graph.output:
                info[value_info.name] = value_info.type.HasField("optional_type")
            return info

        @property
        def opset_version(self) -> int:
            return self._opset_versions.get("", 1)

        def is_optional_type(self, name: str) -> bool:
            """Check if a value has optional type in the subgraph or parent scope."""
            # First check the combined type info (includes parent scope)
            if name in self._type_info:
                return self._type_info[name]
            return False

        def get_opset_version(self, domain: str = "") -> int:
            return self._opset_versions.get(domain, 1)

        def get_value(self, name: str) -> torch.fx.Node:
            if name not in self.env:
                raise KeyError(f"Value '{name}' not found in subgraph environment")
            return self.env[name]

        def has_value(self, name: str) -> bool:
            return name in self.env

        def call_function(
            self,
            func,
            args: tuple = (),
            kwargs: Optional[Dict[str, Any]] = None,
        ) -> torch.fx.Node:
            return self.graph.call_function(func, args=tuple(args), kwargs=kwargs or {})

        def register_submodule(self, name: str, module: nn.Module) -> str:
            safe_name = sanitize_name(name)
            if safe_name in self._submodules:
                counter = 0
                while f"{safe_name}_{counter}" in self._submodules:
                    counter += 1
                safe_name = f"{safe_name}_{counter}"
            self._submodules[safe_name] = module
            return safe_name

        def call_module(
            self,
            module_name: str,
            args: tuple = (),
            kwargs: Optional[Dict[str, Any]] = None,
        ) -> torch.fx.Node:
            return self.graph.call_module(
                module_name, args=tuple(args), kwargs=kwargs or {}
            )

    builder = SubgraphBuilder()

    # Convert nodes
    for node in body_graph.node:
        domain = node.domain if node.domain else ""
        opset = builder.get_opset_version(domain)
        handler = get_handler(node.op_type, domain, opset)
        if handler is None:
            raise ValueError(
                f"Unsupported operator in subgraph: {node.op_type} (domain={domain})"
            )
        fx_node = handler(builder, node)

        # Handle outputs
        if len(node.output) == 1:
            env[node.output[0]] = fx_node
        else:
            for i, output_name in enumerate(node.output):
                if output_name:
                    getitem_node = graph.call_function(
                        lambda x, idx: x[idx] if isinstance(x, (tuple, list)) else x,
                        args=(fx_node, i),
                    )
                    env[output_name] = getitem_node

    # Create output - return tuple of output values
    output_nodes = []
    for out_name in output_names:
        if out_name in env:
            output_nodes.append(env[out_name])
        else:
            raise KeyError(f"Output '{out_name}' not found in subgraph environment")

    if len(output_nodes) == 1:
        graph.output(output_nodes[0])
    else:
        graph.output(tuple(output_nodes))

    # Create the module
    root_module = nn.Module()
    for name, tensor in constants.items():
        root_module.register_buffer(name, tensor)
    for name, submod in builder._submodules.items():
        root_module.add_module(name, submod)

    module = torch.fx.GraphModule(root_module, graph)
    module.graph.lint()

    return module, input_names, output_names, outer_refs
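# Shape of the result (sketch): the GraphModule's positional signature is the
# subgraph's declared inputs, in order, followed by one "outer_<name>"
# placeholder per entry in the returned outer_refs list. `declared_inputs` and
# `outer_values` below stand in for the caller's tensors:
#
#     gm, in_names, out_names, outer = _build_subgraph_module(
#         body_graph, builder.env, builder._opset_versions,
#         tensor_loader=builder.load_tensor,
#     )
#     result = gm(*declared_inputs, *outer_values)  # outer_values ordered as `outer`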


class LoopModule(nn.Module):
    """Module that executes an ONNX Loop."""

    def __init__(
        self,
        body_module: torch.fx.GraphModule,
        n_loop_carried: int,
        n_scan_outputs: int,
        n_loop_vars: int,
        n_outer_vars: int,
    ):
        super().__init__()
        self.body = body_module
        self.n_loop_carried = n_loop_carried
        self.n_scan_outputs = n_scan_outputs
        self.n_loop_vars = n_loop_vars
        self.n_outer_vars = n_outer_vars

    def forward(self, max_iters, init_cond, *args) -> Tuple[torch.Tensor, ...]:
        """Execute the loop.

        Args are: loop_vars..., outer_vals...
        Returns final loop-carried values followed by stacked scan outputs.
        """
        # Split args into loop_vars and outer_vals
        loop_vars = list(args[: self.n_loop_vars])
        outer_vals = list(args[self.n_loop_vars :])

        # Determine max iterations
        if max_iters is not None:
            max_i = (
                int(max_iters.item()) if hasattr(max_iters, "item") else int(max_iters)
            )
        else:
            max_i = 2**63 - 1  # INT64_MAX: effectively unbounded when M is absent

        # Initial condition
        if init_cond is not None:
            cond = (
                bool(init_cond.item())
                if hasattr(init_cond, "item")
                else bool(init_cond)
            )
        else:
            cond = True

        # Current loop-carried values
        current_vars = list(loop_vars)

        # Scan output accumulators
        scan_outputs: List[List[torch.Tensor]] = [
            [] for _ in range(self.n_scan_outputs)
        ]

        i = 0
        while i < max_i and cond:
            # Prepare inputs for body: iteration_num, condition, loop_carried..., outer...
            iter_tensor = torch.tensor(i, dtype=torch.int64)
            cond_tensor = torch.tensor(cond, dtype=torch.bool)

            # Call body function
            body_inputs = [iter_tensor, cond_tensor] + current_vars + outer_vals
            outputs = self.body(*body_inputs)

            # Handle single vs multiple outputs
            if not isinstance(outputs, tuple):
                outputs = (outputs,)

            # First output is new condition
            new_cond = outputs[0]
            cond = (
                bool(new_cond.item()) if hasattr(new_cond, "item") else bool(new_cond)
            )

            # Next n_loop_carried outputs are updated loop-carried values
            current_vars = list(outputs[1 : 1 + self.n_loop_carried])

            # Remaining outputs are scan outputs for this iteration
            for j in range(self.n_scan_outputs):
                scan_outputs[j].append(outputs[1 + self.n_loop_carried + j])

            i += 1

        # Prepare final outputs: loop-carried values, then stacked scan outputs
        final_outputs = list(current_vars)
        for scan_list in scan_outputs:
            if scan_list:
                # Stack scan outputs along a new first dimension
                final_outputs.append(torch.stack(scan_list, dim=0))
            else:
                # Empty scan output - create empty tensor
                final_outputs.append(torch.tensor([]))

        return tuple(final_outputs)
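# Semantics sketch (hypothetical body; any callable matching the body
# signature works here, not only an FX GraphModule): sum i for i < 5.
#
#     class _Body(nn.Module):
#         def forward(self, i, cond, acc):
#             return torch.tensor(True), acc + i
#
#     loop = LoopModule(_Body(), n_loop_carried=1, n_scan_outputs=0,
#                       n_loop_vars=1, n_outer_vars=0)
#     (total,) = loop(torch.tensor(5), torch.tensor(True), torch.tensor(0))
#     # total == tensor(10), i.e. 0 + 1 + 2 + 3 + 4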


class IfModule(nn.Module):
    """Module that executes an ONNX If conditional."""

    def __init__(
        self,
        then_module: torch.fx.GraphModule,
        else_module: torch.fx.GraphModule,
        n_then_outer: int,
        n_else_outer: int,
    ):
        super().__init__()
        self.then_branch = then_module
        self.else_branch = else_module
        self.n_then_outer = n_then_outer
        self.n_else_outer = n_else_outer

    def forward(self, condition, *args) -> Any:
        """Execute the conditional.

        Args are: then_outer..., else_outer...
        Returns the outputs of the selected branch.
        """
        then_outer = list(args[: self.n_then_outer])
        else_outer = list(args[self.n_then_outer :])

        cond_val = (
            bool(condition.item()) if hasattr(condition, "item") else bool(condition)
        )

        if cond_val:
            result = self.then_branch(*then_outer)
        else:
            result = self.else_branch(*else_outer)

        return result
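# Calling-convention sketch (then_gm/else_gm are placeholder names): each
# branch sees only its own captured outer values, passed positionally after
# the condition.
#
#     if_mod = IfModule(then_gm, else_gm, n_then_outer=2, n_else_outer=1)
#     out = if_mod(cond, t0, t1, e0)  # then-branch gets (t0, t1), else gets (e0,)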


@register("Loop")
def loop_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """ONNX Loop operator.

    Loop has inputs: M (max trip count), cond (initial condition), v_initial... (loop-carried deps)
    Loop has attribute: body (GraphProto)
    Body inputs: iteration_num, condition, loop_carried_deps...
    Body outputs: condition, loop_carried_deps..., scan_outputs...
    """
    # Get body subgraph
    body_graph = get_attribute(node, "body")
    if body_graph is None:
        raise ValueError("Loop operator requires 'body' attribute")

    # Get inputs
    max_trip_count = builder.get_value(node.input[0]) if node.input[0] else None
    initial_cond = get_optional_input(builder, node, 1)
    loop_carried_inputs = [
        builder.get_value(node.input[i]) for i in range(2, len(node.input))
    ]

    # Get parent type info if available (for nested subgraphs)
    parent_type_info = getattr(builder, "_type_info", None)

    # Build subgraph module
    body_module, body_input_names, body_output_names, outer_refs = (
        _build_subgraph_module(
            body_graph,
            builder.env,
            builder._opset_versions,
            parent_type_info,
            tensor_loader=builder.load_tensor,
        )
    )

    # Get outer scope values that the body references
    outer_values = [builder.get_value(name) for name in outer_refs]

    # Number of loop-carried dependencies (excluding iteration_num and condition in body inputs)
    n_loop_carried = len(body_input_names) - 2
    # Number of scan outputs
    n_scan_outputs = len(body_output_names) - 1 - n_loop_carried

    # Create the loop module
    loop_module = LoopModule(
        body_module,
        n_loop_carried,
        n_scan_outputs,
        len(loop_carried_inputs),
        len(outer_values),
    )

    # Register the loop module
    module_name = builder.register_submodule(f"loop_{node.name or 'op'}", loop_module)

    # Build flat args for call_module: max_iters, init_cond, loop_vars..., outer_vals...
    args = [max_trip_count, initial_cond] + loop_carried_inputs + outer_values

    result = builder.call_module(module_name, args=tuple(args))

    return result
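# Worked count (sketch): a body with inputs (iter_num, cond, s) and outputs
# (cond_out, s_out, scan) gives n_loop_carried = 3 - 2 = 1 and
# n_scan_outputs = 3 - 1 - 1 = 1, so the Loop node yields (s_final, stacked_scan).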


class ScanModule(nn.Module):
    """Module that executes an ONNX Scan operation."""

    def __init__(
        self,
        body_module: torch.fx.GraphModule,
        n_state_vars: int,
        n_scan_inputs: int,
        n_scan_outputs: int,
        n_outer_vars: int,
        scan_input_axes: List[int],
        scan_output_axes: List[int],
        scan_input_directions: List[int],
        scan_output_directions: List[int],
    ):
        super().__init__()
        self.body = body_module
        self.n_state_vars = n_state_vars
        self.n_scan_inputs = n_scan_inputs
        self.n_scan_outputs = n_scan_outputs
        self.n_outer_vars = n_outer_vars
        self.scan_input_axes = scan_input_axes
        self.scan_output_axes = scan_output_axes
        self.scan_input_directions = scan_input_directions
        self.scan_output_directions = scan_output_directions

    def forward(self, *args) -> Tuple[torch.Tensor, ...]:
        """Execute the scan.

        Args are: state_vars..., scan_inputs..., outer_vals...
        Returns final state variables followed by scan outputs.
        """
        # Split args
        state_vars = list(args[: self.n_state_vars])
        scan_inputs = list(
            args[self.n_state_vars : self.n_state_vars + self.n_scan_inputs]
        )
        outer_vals = list(args[self.n_state_vars + self.n_scan_inputs :])

        # Determine sequence length from first scan input
        if self.n_scan_inputs > 0:
            first_input = scan_inputs[0]
            axis = self.scan_input_axes[0] if self.scan_input_axes else 0
            sequence_length = first_input.shape[axis]
        else:
            sequence_length = 0

        # Initialize scan output accumulators
        scan_outputs: List[List[torch.Tensor]] = [
            [] for _ in range(self.n_scan_outputs)
        ]

        # Current state
        current_state = list(state_vars)

        # Execute loop
        for t in range(sequence_length):
            # Extract scan input elements for this iteration
            scan_input_elts = []
            for i, scan_input in enumerate(scan_inputs):
                axis = self.scan_input_axes[i] if i < len(self.scan_input_axes) else 0
                direction = (
                    self.scan_input_directions[i]
                    if i < len(self.scan_input_directions)
                    else 0
                )
                # Reverse direction: 0 = forward, 1 = reverse
                idx = sequence_length - 1 - t if direction == 1 else t
                # Select along the axis
                elt = torch.select(scan_input, axis, idx)
                scan_input_elts.append(elt)

            # Call body: inputs are state_vars..., scan_input_elts..., outer_vals...
            body_inputs = current_state + scan_input_elts + outer_vals
            outputs = self.body(*body_inputs)

            # Handle single vs multiple outputs
            if not isinstance(outputs, tuple):
                outputs = (outputs,)

            # First n_state_vars outputs are updated state
            current_state = list(outputs[: self.n_state_vars])

            # Remaining outputs are scan output elements
            for j in range(self.n_scan_outputs):
                scan_outputs[j].append(outputs[self.n_state_vars + j])

        # Prepare final outputs: state variables, then stacked scan outputs
        final_outputs = list(current_state)
        for j, scan_list in enumerate(scan_outputs):
            if scan_list:
                axis = self.scan_output_axes[j] if j < len(self.scan_output_axes) else 0
                direction = (
                    self.scan_output_directions[j]
                    if j < len(self.scan_output_directions)
                    else 0
                )
                # Stack along the specified axis
                stacked = torch.stack(scan_list, dim=axis)
                # Reverse if direction is 1 (prepending = reverse order)
                if direction == 1:
                    stacked = torch.flip(stacked, dims=[axis])
                final_outputs.append(stacked)
            else:
                # Empty scan output
                final_outputs.append(torch.tensor([]))

        return tuple(final_outputs)
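# Semantics sketch (hypothetical body): a running sum over axis 0.
#
#     class _Body(nn.Module):
#         def forward(self, s, x):
#             s2 = s + x
#             return s2, s2
#
#     scan = ScanModule(_Body(), n_state_vars=1, n_scan_inputs=1,
#                       n_scan_outputs=1, n_outer_vars=0,
#                       scan_input_axes=[0], scan_output_axes=[0],
#                       scan_input_directions=[0], scan_output_directions=[0])
#     final, cumsum = scan(torch.tensor(0.), torch.tensor([1., 2., 3.]))
#     # final == tensor(6.); cumsum == tensor([1., 3., 6.])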


@register("Scan", since_version=9)
def scan_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """ONNX Scan operator (version 9+).

    Scan iterates over one or more scan_input tensors, constructing scan_output tensors.
    It combines ideas from general recurrences and functional programming constructs
    such as scan, fold, map, and zip.

    Inputs: initial_state_and_scan_inputs (variadic) - N state vars followed by M scan inputs
    Outputs: final_state_and_scan_outputs (variadic) - N final states followed by K scan outputs

    Attributes:
    - body: The graph run each iteration
    - num_scan_inputs: Number of scan inputs M
    - scan_input_axes: Axis to scan for each scan input (default: 0)
    - scan_input_directions: Direction for each scan input (0=forward, 1=reverse)
    - scan_output_axes: Axis for each scan output (default: 0)
    - scan_output_directions: Direction for each scan output (0=append, 1=prepend)
    """
    # Get body subgraph
    body_graph = get_attribute(node, "body")
    if body_graph is None:
        raise ValueError("Scan operator requires 'body' attribute")

    # Get num_scan_inputs attribute (required)
    num_scan_inputs = get_attribute(node, "num_scan_inputs")
    if num_scan_inputs is None:
        raise ValueError("Scan operator requires 'num_scan_inputs' attribute")

    # Get optional attributes
    scan_input_axes = get_attribute(node, "scan_input_axes") or []
    scan_input_directions = get_attribute(node, "scan_input_directions") or []
    scan_output_axes = get_attribute(node, "scan_output_axes") or []
    scan_output_directions = get_attribute(node, "scan_output_directions") or []

    # Parse inputs: first (len(node.input) - num_scan_inputs) are state variables
    n_state_vars = len(node.input) - num_scan_inputs

    state_inputs = [builder.get_value(node.input[i]) for i in range(n_state_vars)]
    scan_inputs = [
        builder.get_value(node.input[i]) for i in range(n_state_vars, len(node.input))
    ]

    # Get parent type info if available (for nested subgraphs)
    parent_type_info = getattr(builder, "_type_info", None)

    # Build subgraph module
    body_module, body_input_names, body_output_names, outer_refs = (
        _build_subgraph_module(
            body_graph,
            builder.env,
            builder._opset_versions,
            parent_type_info,
            tensor_loader=builder.load_tensor,
        )
    )

    # Get outer scope values that the body references
    outer_values = [builder.get_value(name) for name in outer_refs]

    # Number of scan outputs
    n_scan_outputs = len(body_output_names) - n_state_vars

    # Create the scan module
    scan_module = ScanModule(
        body_module,
        n_state_vars,
        num_scan_inputs,
        n_scan_outputs,
        len(outer_values),
        list(scan_input_axes),
        list(scan_output_axes),
        list(scan_input_directions),
        list(scan_output_directions),
    )

    # Register the scan module
    module_name = builder.register_submodule(f"scan_{node.name or 'op'}", scan_module)

    # Build args: state_vars..., scan_inputs..., outer_vals...
    args = state_inputs + scan_inputs + outer_values

    result = builder.call_module(module_name, args=tuple(args))

    return result
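# Worked count (sketch): with node inputs (s0, x0, x1) and num_scan_inputs=2,
# n_state_vars = 3 - 2 = 1; a body emitting (s_out, y0, y1) then gives
# n_scan_outputs = 3 - 1 = 2.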


@register("Scan", since_version=8)
def scan_op_v8(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """ONNX Scan operator (version 8).

    Version 8 has batching support and a different input format:
    - The first input is an optional sequence_lens
    - Tensors carry a batch dimension (axis 0) and a sequence dimension (axis 1)

    For simplicity, this implementation handles the common case where
    sequence_lens is empty (all sequences have the same length).
    """
    # Get body subgraph
    body_graph = get_attribute(node, "body")
    if body_graph is None:
        raise ValueError("Scan operator requires 'body' attribute")

    # Get num_scan_inputs attribute (required)
    num_scan_inputs = get_attribute(node, "num_scan_inputs")
    if num_scan_inputs is None:
        raise ValueError("Scan operator requires 'num_scan_inputs' attribute")

    # Get optional directions attribute (v8 uses 'directions' instead of
    # scan_input_directions)
    directions = get_attribute(node, "directions") or []

    # In v8, the first input slot is the optional sequence_lens; the rest are
    # state vars + scan inputs. Skip that slot either way: when sequence_lens
    # is provided it is currently ignored (fixed-length sequences assumed).
    start_idx = 1

    # Parse remaining inputs: state variables followed by scan inputs
    remaining_inputs = list(node.input[start_idx:])
    n_state_vars = len(remaining_inputs) - num_scan_inputs

    state_inputs = [builder.get_value(remaining_inputs[i]) for i in range(n_state_vars)]
    scan_inputs = [
        builder.get_value(remaining_inputs[i])
        for i in range(n_state_vars, len(remaining_inputs))
    ]

    # Get parent type info if available (for nested subgraphs)
    parent_type_info = getattr(builder, "_type_info", None)

    # Build subgraph module
    body_module, body_input_names, body_output_names, outer_refs = (
        _build_subgraph_module(
            body_graph,
            builder.env,
            builder._opset_versions,
            parent_type_info,
            tensor_loader=builder.load_tensor,
        )
    )

    # Get outer scope values that the body references
    outer_values = [builder.get_value(name) for name in outer_refs]

    # Number of scan outputs
    n_scan_outputs = len(body_output_names) - n_state_vars

    # Create the scan module for v8 (handles batching)
    # Note: in v8, batch axis is 0 and sequence axis is 1 (handled in ScanModuleV8)
    scan_module = ScanModuleV8(
        body_module,
        n_state_vars,
        num_scan_inputs,
        n_scan_outputs,
        len(outer_values),
        list(directions),
    )

    # Register the scan module
    module_name = builder.register_submodule(
        f"scan_v8_{node.name or 'op'}", scan_module
    )

    # Build args: state_vars..., scan_inputs..., outer_vals...
    args = state_inputs + scan_inputs + outer_values

    result = builder.call_module(module_name, args=tuple(args))

    return result


class ScanModuleV8(nn.Module):
    """Module that executes an ONNX Scan operation (version 8 with batching)."""

    def __init__(
        self,
        body_module: torch.fx.GraphModule,
        n_state_vars: int,
        n_scan_inputs: int,
        n_scan_outputs: int,
        n_outer_vars: int,
        directions: List[int],
    ):
        super().__init__()
        self.body = body_module
        self.n_state_vars = n_state_vars
        self.n_scan_inputs = n_scan_inputs
        self.n_scan_outputs = n_scan_outputs
        self.n_outer_vars = n_outer_vars
        self.directions = directions

    def forward(self, *args) -> Tuple[torch.Tensor, ...]:
        """Execute the scan with batching.

        In v8, scan tensors have shape [batch, sequence, ...] and
        state variables have shape [batch, ...].

        Args are: state_vars..., scan_inputs..., outer_vals...
        Returns final state variables followed by scan outputs.
        """
        # Split args
        state_vars = list(args[: self.n_state_vars])
        scan_inputs = list(
            args[self.n_state_vars : self.n_state_vars + self.n_scan_inputs]
        )
        outer_vals = list(args[self.n_state_vars + self.n_scan_inputs :])

        # Get batch size and sequence length from first scan input
        if self.n_scan_inputs > 0:
            first_input = scan_inputs[0]
            batch_size = first_input.shape[0]
            sequence_length = first_input.shape[1]
        else:
            batch_size = state_vars[0].shape[0] if state_vars else 1
            sequence_length = 0

        # Process each batch
        batch_final_states: List[List[torch.Tensor]] = [
            [] for _ in range(self.n_state_vars)
        ]
        batch_scan_outputs: List[List[torch.Tensor]] = [
            [] for _ in range(self.n_scan_outputs)
        ]

        for batch in range(batch_size):
            # Get batch slice of state variables
            current_state = [sv[batch] for sv in state_vars]

            # Initialize scan output accumulators for this batch
            scan_outputs: List[List[torch.Tensor]] = [
                [] for _ in range(self.n_scan_outputs)
            ]

            # Execute loop over sequence
            for t in range(sequence_length):
                # Extract scan input elements for this batch and time step
                scan_input_elts = []
                for i, scan_input in enumerate(scan_inputs):
                    direction = self.directions[i] if i < len(self.directions) else 0
                    idx = sequence_length - 1 - t if direction == 1 else t
                    # scan_input has shape [batch, sequence, ...]
                    elt = scan_input[batch, idx]
                    scan_input_elts.append(elt)

                # Call body
                body_inputs = current_state + scan_input_elts + outer_vals
                outputs = self.body(*body_inputs)

                if not isinstance(outputs, tuple):
                    outputs = (outputs,)

                # Update state
                current_state = list(outputs[: self.n_state_vars])

                # Collect scan outputs
                for j in range(self.n_scan_outputs):
                    scan_outputs[j].append(outputs[self.n_state_vars + j])

            # Store final state for this batch
            for i, state in enumerate(current_state):
                batch_final_states[i].append(state)

            # Stack scan outputs for this batch
            for j, scan_list in enumerate(scan_outputs):
                if scan_list:
                    stacked = torch.stack(scan_list, dim=0)
                    batch_scan_outputs[j].append(stacked)

        # Stack across batches
        final_outputs: List[torch.Tensor] = []

        # Final state variables: stack across batch
        for states in batch_final_states:
            final_outputs.append(torch.stack(states, dim=0))

        # Scan outputs: stack across batch. Note: empty accumulators
        # (sequence_length == 0) are skipped, so a nonzero sequence length is
        # assumed whenever the body declares scan outputs.
        for outputs in batch_scan_outputs:
            if outputs:
                final_outputs.append(torch.stack(outputs, dim=0))

        return tuple(final_outputs)
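# Semantics sketch (hypothetical body): a per-batch running sum with batch=2.
#
#     class _Body(nn.Module):
#         def forward(self, s, x):
#             s2 = s + x
#             return s2, s2
#
#     scan = ScanModuleV8(_Body(), n_state_vars=1, n_scan_inputs=1,
#                         n_scan_outputs=1, n_outer_vars=0, directions=[0])
#     s0 = torch.zeros(2)                        # [batch]
#     xs = torch.tensor([[1., 2.], [10., 20.]])  # [batch, sequence]
#     final, cumsum = scan(s0, xs)
#     # final == tensor([3., 30.]); cumsum == tensor([[1., 3.], [10., 30.]])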


@register("If")
def if_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """ONNX If operator.

    If has input: cond (boolean condition)
    If has attributes: then_branch (GraphProto), else_branch (GraphProto)
    Both branches must have the same number and types of outputs.
    """
    # Get condition
    cond = builder.get_value(node.input[0])

    # Get branch subgraphs
    then_graph = get_attribute(node, "then_branch")
    else_graph = get_attribute(node, "else_branch")

    if then_graph is None or else_graph is None:
        raise ValueError(
            "If operator requires 'then_branch' and 'else_branch' attributes"
        )

    # Get parent type info if available (for nested subgraphs)
    parent_type_info = getattr(builder, "_type_info", None)

    # Build subgraph modules for both branches
    then_module, then_input_names, then_output_names, then_outer_refs = (
        _build_subgraph_module(
            then_graph,
            builder.env,
            builder._opset_versions,
            parent_type_info,
            tensor_loader=builder.load_tensor,
        )
    )
    else_module, else_input_names, else_output_names, else_outer_refs = (
        _build_subgraph_module(
            else_graph,
            builder.env,
            builder._opset_versions,
            parent_type_info,
            tensor_loader=builder.load_tensor,
        )
    )

    # Get outer scope values for both branches
    then_outer_values = [builder.get_value(name) for name in then_outer_refs]
    else_outer_values = [builder.get_value(name) for name in else_outer_refs]

    # Create the if module
    if_module = IfModule(
        then_module,
        else_module,
        len(then_outer_values),
        len(else_outer_values),
    )

    # Register the if module
    module_name = builder.register_submodule(f"if_{node.name or 'op'}", if_module)

    # Build flat args for call_module: condition, then_outer..., else_outer...
    args = [cond] + then_outer_values + else_outer_values

    result = builder.call_module(module_name, args=tuple(args))

    return result
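# Wiring sketch: ONNX If branch graphs declare no formal inputs, so each branch
# GraphModule's signature consists solely of its "outer_<name>" placeholders;
# the call above therefore passes (cond, *then_outer, *else_outer) and IfModule
# routes the matching slice to whichever branch it runs.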