onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,634 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from collections import deque
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence, Union
4
+
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.fx
9
+ import onnx
10
+ from onnx import numpy_helper
11
+
12
+ from .exceptions import UnsupportedDTypeError, UnsupportedOpError, ValueNotFoundError
13
+ from .op_registry import get_handler
14
+ from .utils.dtype import DTYPE_MAP
15
+ from .utils.external_data import resolve_external_data
16
+ from .utils.names import sanitize_name
17
+
18
+ # Import ops module to register all operators
19
+ from . import ops # noqa: F401
20
+
21
+
22
+ def _collect_all_inputs(node: onnx.NodeProto) -> set:
23
+ """Recursively collect all inputs from a node including subgraph inputs.
24
+
25
+ For control flow nodes like If and Loop, this also collects inputs that
26
+ are referenced by the subgraphs (then_branch, else_branch, body).
27
+
28
+ Parameters
29
+ ----------
30
+ node : onnx.NodeProto
31
+ The ONNX node to collect inputs from.
32
+
33
+ Returns
34
+ -------
35
+ set
36
+ Set of all input names referenced by this node and its subgraphs.
37
+ """
38
+ inputs = set(node.input)
39
+
40
+ # Collect inputs from subgraphs (for If, Loop, etc.)
41
+ for attr in node.attribute:
42
+ if attr.type == onnx.AttributeProto.GRAPH:
43
+ subgraph = attr.g
44
+ # Collect subgraph's own initializers and inputs as local values
45
+ local_values = set()
46
+ for init in subgraph.initializer:
47
+ local_values.add(init.name)
48
+ for inp in subgraph.input:
49
+ local_values.add(inp.name)
50
+
51
+ # Recursively collect inputs from subgraph nodes
52
+ for sub_node in subgraph.node:
53
+ sub_inputs = _collect_all_inputs(sub_node)
54
+ # Add outputs of this subgraph node to local values
55
+ for out in sub_node.output:
56
+ if out:
57
+ local_values.add(out)
58
+ # Inputs not satisfied locally are outer references
59
+ for sub_inp in sub_inputs:
60
+ if sub_inp and sub_inp not in local_values:
61
+ inputs.add(sub_inp)
62
+
63
+ return inputs
64
+
65
+
66
def _topological_sort(
    nodes: List["onnx.NodeProto"],
    graph_inputs: set,
    initializers: set,
) -> List["onnx.NodeProto"]:
    """Topologically sort ONNX graph nodes using Kahn's algorithm.

    Some ONNX models have nodes in non-topological order (e.g., Cast nodes
    at the end of the graph but their outputs used earlier). This function
    reorders nodes so dependencies are processed before their consumers.

    This function also considers inputs referenced by subgraphs (for nodes
    like If and Loop) to ensure proper ordering.

    The sort runs in O(V + E) by pre-building an input-name -> consumers
    map instead of rescanning the whole node list after each step; newly
    satisfied consumers are processed in original node order, so the
    output matches what a full rescan would produce.

    Parameters
    ----------
    nodes : List[onnx.NodeProto]
        The list of ONNX nodes to sort.
    graph_inputs : set
        Set of graph input names.
    initializers : set
        Set of initializer names.

    Returns
    -------
    List[onnx.NodeProto]
        Topologically sorted list of nodes. Falls back to the original
        order if a cycle or missing dependency prevents a complete sort.
    """
    if not nodes:
        return []

    # Pre-compute all inputs for each node (including subgraph inputs)
    node_all_inputs: Dict[int, set] = {
        id(node): _collect_all_inputs(node) for node in nodes
    }

    # id -> node, and id -> original position (used for deterministic,
    # rescan-equivalent ordering of newly ready nodes).
    node_for_id: Dict[int, "onnx.NodeProto"] = {id(node): node for node in nodes}
    position: Dict[int, int] = {id(node): i for i, node in enumerate(nodes)}

    # Map each non-empty input name to the nodes consuming it.
    consumers: Dict[str, List["onnx.NodeProto"]] = {}
    for node in nodes:
        for inp in node_all_inputs[id(node)]:
            if inp:
                consumers.setdefault(inp, []).append(node)

    # Available values: graph inputs + initializers
    available = graph_inputs | initializers

    # In-degree = number of non-empty inputs not yet available.
    in_degree: Dict[int, int] = {}
    for node in nodes:
        in_degree[id(node)] = sum(
            1 for inp in node_all_inputs[id(node)] if inp and inp not in available
        )

    # Initialize queue with nodes that have no dependencies
    queue = deque(node for node in nodes if in_degree[id(node)] == 0)

    sorted_nodes: List["onnx.NodeProto"] = []
    while queue:
        node = queue.popleft()
        sorted_nodes.append(node)

        # Mark outputs available and tally how many of each consumer's
        # inputs this node satisfies (a consumer may use several outputs).
        satisfied: Dict[int, int] = {}
        for output in node.output:
            if output:
                available.add(output)
                for candidate in consumers.get(output, ()):
                    cand_id = id(candidate)
                    if in_degree[cand_id] > 0:
                        satisfied[cand_id] = satisfied.get(cand_id, 0) + 1

        # Apply decrements in original node order so the resulting sort
        # is identical to scanning the full node list.
        for cand_id in sorted(satisfied, key=position.__getitem__):
            in_degree[cand_id] -= satisfied[cand_id]
            if in_degree[cand_id] == 0:
                queue.append(node_for_id[cand_id])

    # If we couldn't sort all nodes, there's a cycle or missing dependency;
    # fall back to the original order.
    if len(sorted_nodes) != len(nodes):
        return list(nodes)

    return sorted_nodes
158
+
159
+
160
class GraphBuilder:
    """Builds a PyTorch FX GraphModule from an ONNX model.

    This class handles the conversion of ONNX graph structure to PyTorch FX,
    including initializer loading, placeholder creation, node conversion,
    and output creation.

    Parameters
    ----------
    model : onnx.ModelProto
        The ONNX model to convert.

    Attributes
    ----------
    model : onnx.ModelProto
        The input ONNX model (with shape inference if successful).
    graph : torch.fx.Graph
        The FX graph being constructed.
    env : Dict[str, torch.fx.Node]
        Mapping from ONNX tensor names to FX nodes.
    opset_version : int
        The opset version for the default ONNX domain.
    """

    def __init__(
        self,
        model: onnx.ModelProto,
        *,
        base_dir: Optional[str] = None,
        memmap_external_data: bool = False,
    ) -> None:
        """Initialize the builder and precompute value-info/initializer maps.

        Parameters
        ----------
        model : onnx.ModelProto
            The ONNX model to convert.
        base_dir : Optional[str], optional
            Base directory passed to ``resolve_external_data`` to locate
            external-data tensor files.
        memmap_external_data : bool, optional
            If True, external-data tensors are memory-mapped read-only via
            ``np.memmap`` instead of copied, and ``build()`` tags the
            resulting module with ``_onnx2fx_inference_only = True``.
        """
        # Try shape inference but preserve original model if it fails
        # (shape_inference may drop graph contents for large models with external data)
        try:
            inferred_model = onnx.shape_inference.infer_shapes(model)
            # Check if shape inference preserved the model structure
            if len(inferred_model.graph.node) > 0:
                model = inferred_model
            # If nodes were lost, keep original model
        except Exception:
            pass
        self.model: onnx.ModelProto = model
        self.graph: torch.fx.Graph = torch.fx.Graph()
        self._base_dir = base_dir
        self._memmap_external_data = memmap_external_data
        self.value_info_map = self._create_value_info_map()
        self.initializer_map = self._create_initializer_map()
        self.input_names: List[str] = []
        self.env: Dict[str, torch.fx.Node] = {}
        self._constants: Dict[str, torch.Tensor] = {}
        self._submodules: Dict[str, torch.nn.Module] = {}
        self._opset_versions: Dict[str, int] = self._extract_opset_versions()

    def _extract_opset_versions(self) -> Dict[str, int]:
        """Extract opset versions for all domains from the model.

        Returns
        -------
        Dict[str, int]
            Dictionary mapping domain names to their opset versions.
            Empty string "" represents the default ONNX domain.
        """
        versions: Dict[str, int] = {}
        for opset in self.model.opset_import:
            # Normalize a missing/empty domain to "" (default ONNX domain).
            domain = opset.domain if opset.domain else ""
            versions[domain] = opset.version
        return versions

    def _resolve_handler(
        self, node: onnx.NodeProto
    ) -> tuple[Callable[["GraphBuilder", onnx.NodeProto], Any], str, int]:
        """Resolve the handler for an ONNX node and return handler, domain, opset.

        Raises
        ------
        UnsupportedOpError
            If no handler is registered for (op_type, domain, opset).
        """
        domain = node.domain if node.domain else ""
        opset = self.get_opset_version(domain)
        handler = get_handler(node.op_type, domain, opset)
        if handler is None:
            raise UnsupportedOpError(node.op_type, domain=domain, opset_version=opset)
        return handler, domain, opset

    def _tag_operator_node(
        self, node: onnx.NodeProto, fx_node: torch.fx.Node, domain: str
    ) -> None:
        """Attach ONNX metadata to an operator node."""
        # Guard with hasattr: handlers may return objects other than
        # torch.fx.Node (the hasattr check skips anything without .meta).
        if fx_node is not None and hasattr(fx_node, "meta"):
            self._set_onnx_metadata(
                fx_node,
                op_type=node.op_type,
                name=node.name,
                domain=domain,
            )

    def _register_outputs(
        self, node: onnx.NodeProto, fx_node: torch.fx.Node, domain: str
    ) -> None:
        """Register node outputs in the environment.

        Single-output nodes map their name directly to ``fx_node``;
        multi-output nodes get one indexing node per non-empty output name.
        """
        if len(node.output) == 1:
            self.env[node.output[0]] = fx_node
            return

        for i, output_name in enumerate(node.output):
            if output_name:  # Skip empty output names
                # The lambda indexes into tuple/list results but passes any
                # other value through unchanged, so handlers that return a
                # single value for a multi-output op still work.
                getitem_node = self.graph.call_function(
                    lambda x, idx=i: x[idx] if isinstance(x, (tuple, list)) else x,
                    args=(fx_node, i),
                )
                self._set_onnx_metadata(
                    getitem_node,
                    op_type=node.op_type,
                    name=node.name,
                    domain=domain,
                    output_index=i,
                )
                self.env[output_name] = getitem_node

    @property
    def opset_version(self) -> int:
        """Get the opset version for the default ONNX domain.

        Returns
        -------
        int
            The opset version number. Defaults to 1 if not specified.
        """
        return self._opset_versions.get("", 1)

    def get_opset_version(self, domain: str = "") -> int:
        """Get the opset version for a specific domain.

        Parameters
        ----------
        domain : str, optional
            The ONNX domain. Default is "" (standard ONNX domain).

        Returns
        -------
        int
            The opset version number. Defaults to 1 if not specified.
        """
        return self._opset_versions.get(domain, 1)

    def build(self) -> torch.fx.GraphModule:
        """Convert the ONNX model into a torch.fx.GraphModule.

        Runs the full pipeline: load initializers, create placeholders,
        convert nodes, create outputs, then wrap the graph in a GraphModule
        whose root module carries initializers as buffers and any
        registered submodules.

        Returns
        -------
        torch.fx.GraphModule
            The constructed (and linted) FX module.
        """
        self._load_initializers()
        self._create_placeholders()
        self._convert_nodes()
        self._create_outputs()
        root_module = torch.nn.Module()
        # Register constants as buffers
        # NOTE(review): _constants keys were already sanitized in
        # _load_initializers; sanitize_name is applied again here, which
        # presumably is idempotent — confirm.
        for name, tensor in self._constants.items():
            root_module.register_buffer(sanitize_name(name), tensor)
        # Register submodules
        for name, submod in self._submodules.items():
            root_module.add_module(name, submod)
        module = torch.fx.GraphModule(root_module, self.graph)
        if self._memmap_external_data:
            # Memory-mapped weights are read-only, so mark the module as
            # inference-only for downstream consumers.
            module._onnx2fx_inference_only = True
        module.graph.lint()
        return module

    @staticmethod
    def _set_onnx_metadata(
        node: torch.fx.Node,
        *,
        op_type: str,
        name: Optional[str] = None,
        domain: Optional[str] = None,
        shape: Optional[List[Optional[int]]] = None,
        dtype: Optional[torch.dtype] = None,
        output_index: Optional[int] = None,
    ) -> None:
        """Populate standard ONNX metadata on an FX node.

        Only non-None fields are written, so callers can set any subset.
        """
        node.meta["onnx_op_type"] = op_type
        if name is not None:
            node.meta["onnx_name"] = name
        if domain is not None:
            node.meta["onnx_domain"] = domain
        if shape is not None:
            node.meta["onnx_shape"] = shape
        if dtype is not None:
            node.meta["onnx_dtype"] = dtype
        if output_index is not None:
            node.meta["onnx_output_index"] = output_index

    def get_value(self, name: str) -> torch.fx.Node:
        """Get a value (node) by name from the environment.

        Parameters
        ----------
        name : str
            The name of the value.

        Returns
        -------
        torch.fx.Node
            The corresponding FX node.

        Raises
        ------
        ValueNotFoundError
            If the name is not found in the environment.
        """
        if name not in self.env:
            raise ValueNotFoundError(name, available=list(self.env.keys()))
        return self.env[name]

    def has_value(self, name: str) -> bool:
        """Check if a value exists in the environment."""
        return name in self.env

    def call_function(
        self,
        func: Callable[..., Any],
        args: Sequence[Union[torch.fx.Node, Any]] = (),
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.fx.Node:
        """Create a function call node in the FX graph.

        Parameters
        ----------
        func : Callable[..., Any]
            The function to call. Can be a PyTorch function, lambda, or any callable.
        args : Sequence[Union[torch.fx.Node, Any]], optional
            Positional arguments to the function. Can include FX nodes or constants.
        kwargs : Optional[Dict[str, Any]], optional
            Keyword arguments to the function.

        Returns
        -------
        torch.fx.Node
            The FX node representing this function call.
        """
        fx_node = self.graph.call_function(func, args=tuple(args), kwargs=kwargs or {})
        return fx_node

    def register_submodule(self, name: str, module: torch.nn.Module) -> str:
        """Register a submodule for use in the graph.

        Parameters
        ----------
        name : str
            Base name for the submodule.
        module : torch.nn.Module
            The submodule to register.

        Returns
        -------
        str
            The actual name used (may be modified to avoid conflicts).
        """
        # Sanitize name
        safe_name = sanitize_name(name)
        # Ensure unique name by appending the first free numeric suffix.
        if safe_name in self._submodules:
            counter = 0
            while f"{safe_name}_{counter}" in self._submodules:
                counter += 1
            safe_name = f"{safe_name}_{counter}"
        self._submodules[safe_name] = module
        return safe_name

    def call_module(
        self,
        module_name: str,
        args: Sequence[Union[torch.fx.Node, Any]] = (),
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.fx.Node:
        """Create a module call node in the FX graph.

        Parameters
        ----------
        module_name : str
            The name of a registered submodule.
        args : Sequence[Union[torch.fx.Node, Any]], optional
            Positional arguments to the module.
        kwargs : Optional[Dict[str, Any]], optional
            Keyword arguments to the module.

        Returns
        -------
        torch.fx.Node
            The FX node representing this module call.
        """
        fx_node = self.graph.call_module(
            module_name, args=tuple(args), kwargs=kwargs or {}
        )
        return fx_node

    def _create_value_info_map(
        self,
    ) -> Dict[str, Tuple[Optional[List[Optional[int]]], Optional[torch.dtype]]]:
        """Build a mapping from value names to their shape and dtype info.

        Covers graph inputs, intermediate value_info entries, and outputs.
        """

        def extract_tensor_shape(
            value: onnx.ValueInfoProto,
        ) -> Optional[List[Optional[int]]]:
            """Extract a list-based representation of a tensor shape from a value info.

            Symbolic (dim_param) and unknown dimensions both map to None.
            """

            tensor_type = value.type.tensor_type
            if not tensor_type.HasField("shape"):
                return None
            dims: List[Optional[int]] = []
            for dim in tensor_type.shape.dim:
                if dim.HasField("dim_value"):
                    dims.append(int(dim.dim_value))
                elif dim.HasField("dim_param"):
                    dims.append(None)
                else:
                    dims.append(None)
            return dims

        def extract_tensor_dtype(value: onnx.ValueInfoProto) -> Optional[torch.dtype]:
            """Extract the Torch dtype that corresponds to a value info.

            Returns None for unset (0) and STRING element types; raises for
            any other dtype missing from DTYPE_MAP.
            """

            onnx_dtype = value.type.tensor_type.elem_type
            if onnx_dtype == 0:
                return None
            torch_dtype = DTYPE_MAP.get(onnx_dtype)
            if torch_dtype is None:
                if onnx_dtype == onnx.TensorProto.STRING:
                    return None
                raise UnsupportedDTypeError(
                    onnx_dtype=onnx_dtype,
                    tensor_name=value.name,
                    details="value_info dtype not supported",
                )
            return torch_dtype

        info_map = {}
        for value_info in (
            list(self.model.graph.input)
            + list(self.model.graph.value_info)
            + list(self.model.graph.output)
        ):
            info_map[value_info.name] = (
                extract_tensor_shape(value_info),
                extract_tensor_dtype(value_info),
            )
        return info_map

    def is_optional_type(self, name: str) -> bool:
        """Check if a value has optional type in the ONNX model.

        Parameters
        ----------
        name : str
            The name of the value to check.

        Returns
        -------
        bool
            True if the value has optional type, False otherwise.
        """
        # Search in graph inputs
        for value_info in self.model.graph.input:
            if value_info.name == name:
                return value_info.type.HasField("optional_type")
        # Search in value_info
        for value_info in self.model.graph.value_info:
            if value_info.name == name:
                return value_info.type.HasField("optional_type")
        # Search in outputs
        for value_info in self.model.graph.output:
            if value_info.name == name:
                return value_info.type.HasField("optional_type")
        return False

    def _create_initializer_map(self) -> Dict[str, torch.Tensor]:
        """Build a mapping from initializer names to PyTorch tensors."""
        init_map = {}
        for initializer in self.model.graph.initializer:
            init_map[initializer.name] = self.load_tensor(initializer)
        return init_map

    def load_tensor(self, tensor: onnx.TensorProto) -> torch.Tensor:
        """Load an ONNX TensorProto into a Torch tensor.

        Raises
        ------
        UnsupportedDTypeError
            If the tensor's data type has no torch equivalent in DTYPE_MAP.
        """
        onnx_dtype = tensor.data_type
        if DTYPE_MAP.get(onnx_dtype) is None:
            raise UnsupportedDTypeError(
                onnx_dtype=onnx_dtype,
                tensor_name=tensor.name or "<unnamed>",
                details="initializer dtype not supported",
            )

        if self._memmap_external_data and (
            tensor.data_location == onnx.TensorProto.EXTERNAL or tensor.external_data
        ):
            info = resolve_external_data(
                tensor,
                base_dir=self._base_dir,
                strict=True,
            )
            memmap_array = np.memmap(
                info.path,
                dtype=info.numpy_dtype,
                mode="r",
                offset=info.offset,
                shape=info.shape,
            )
            # NOTE(review): the memmap is opened read-only, so the tensor
            # returned by from_numpy must not be written to — presumably
            # acceptable since the module is flagged inference-only.
            return torch.from_numpy(memmap_array)

        # Copy so the tensor owns its memory independently of the proto.
        np_array = numpy_helper.to_array(tensor)
        return torch.from_numpy(np_array.copy())

    def _load_initializers(self) -> None:
        """Load ONNX initializers as constant nodes in the FX graph."""
        for name, tensor in self.initializer_map.items():
            # Store in constants dict for later registration as buffers
            safe_name = sanitize_name(name)
            self._constants[safe_name] = tensor

            # Create a get_attr node to access the buffer; env keeps the
            # original ONNX name, the buffer uses the sanitized one.
            fx_node = self.graph.get_attr(safe_name)
            self._set_onnx_metadata(
                fx_node,
                op_type="Initializer",
                name=name,
                shape=list(tensor.shape),
                dtype=tensor.dtype,
            )
            self.env[name] = fx_node

    def _create_placeholders(self) -> None:
        """Create FX placeholder nodes for graph inputs.

        Note: Inputs that are already loaded as initializers are skipped.
        """
        for value in self.model.graph.input:
            # Skip if already loaded as initializer
            if value.name in self.env:
                continue

            # Sanitize name for valid Python identifier
            safe_name = sanitize_name(value.name)
            placeholder = self.graph.placeholder(safe_name)
            info = self.value_info_map.get(value.name)
            self._set_onnx_metadata(
                placeholder,
                op_type="Input",
                name=value.name,
                shape=info[0] if info else None,
                dtype=info[1] if info else None,
            )
            self.env[value.name] = placeholder
            self.input_names.append(value.name)

    def _convert_nodes(self) -> None:
        """Convert every ONNX node to FX, in dependency order."""
        # Get graph inputs and initializers for topological sort
        graph_inputs = {inp.name for inp in self.model.graph.input}
        initializers = set(self.initializer_map.keys())

        # Topologically sort nodes to handle out-of-order dependencies
        sorted_nodes = _topological_sort(
            list(self.model.graph.node),
            graph_inputs,
            initializers,
        )

        for node in sorted_nodes:
            # Get handler with domain and opset version support
            handler, domain, _opset = self._resolve_handler(node)
            fx_node = handler(self, node)

            # Add ONNX metadata to the operator node
            # Some handlers return a list of nodes (e.g., gradient ops)
            self._tag_operator_node(node, fx_node, domain)
            self._register_outputs(node, fx_node, domain)

    def _create_outputs(self) -> None:
        """Emit the FX output node: a single value or a tuple of values."""
        output_nodes = [self.get_value(value.name) for value in self.model.graph.output]
        if len(output_nodes) == 1:
            self.graph.output(output_nodes[0])
        else:
            self.graph.output(tuple(output_nodes))
632
+
633
+
634
+ __all__ = ["GraphBuilder"]