onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/graph_builder.py
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
from collections import deque
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.fx
|
|
9
|
+
import onnx
|
|
10
|
+
from onnx import numpy_helper
|
|
11
|
+
|
|
12
|
+
from .exceptions import UnsupportedDTypeError, UnsupportedOpError, ValueNotFoundError
|
|
13
|
+
from .op_registry import get_handler
|
|
14
|
+
from .utils.dtype import DTYPE_MAP
|
|
15
|
+
from .utils.external_data import resolve_external_data
|
|
16
|
+
from .utils.names import sanitize_name
|
|
17
|
+
|
|
18
|
+
# Import ops module to register all operators
|
|
19
|
+
from . import ops # noqa: F401
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _collect_all_inputs(node: onnx.NodeProto) -> set:
    """Recursively collect all inputs from a node including subgraph inputs.

    For control flow nodes like If and Loop, this also collects inputs that
    are referenced by the subgraphs (then_branch, else_branch, body) but not
    produced inside them — i.e. outer-scope captures.

    Parameters
    ----------
    node : onnx.NodeProto
        The ONNX node to collect inputs from.

    Returns
    -------
    set
        Set of all input names referenced by this node and its subgraphs.
    """
    inputs = set(node.input)

    # Collect inputs from subgraphs (for If, Loop, etc.)
    for attr in node.attribute:
        if attr.type != onnx.AttributeProto.GRAPH:
            continue
        subgraph = attr.g

        # Gather every value defined inside the subgraph up front:
        # its initializers, formal inputs, and ALL node outputs.
        # Collecting outputs before filtering (rather than incrementally
        # while walking the nodes) means subgraph nodes that appear in
        # non-topological order — which this module already tolerates for
        # the main graph — are not misclassified as outer-scope references.
        local_values = {init.name for init in subgraph.initializer}
        local_values.update(inp.name for inp in subgraph.input)
        for sub_node in subgraph.node:
            for out in sub_node.output:
                if out:
                    local_values.add(out)

        # Any input of a subgraph node that is not produced locally must
        # come from the enclosing scope, so the outer node depends on it.
        for sub_node in subgraph.node:
            for sub_inp in _collect_all_inputs(sub_node):
                if sub_inp and sub_inp not in local_values:
                    inputs.add(sub_inp)

    return inputs
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _topological_sort(
    nodes: List[onnx.NodeProto],
    graph_inputs: set,
    initializers: set,
) -> List[onnx.NodeProto]:
    """Topologically sort ONNX graph nodes using Kahn's algorithm.

    Some ONNX models have nodes in non-topological order (e.g., Cast nodes
    at the end of the graph but their outputs used earlier). This function
    reorders nodes so dependencies are processed before their consumers.

    This function also considers inputs referenced by subgraphs (for nodes
    like If and Loop) to ensure proper ordering.

    Parameters
    ----------
    nodes : List[onnx.NodeProto]
        The list of ONNX nodes to sort.
    graph_inputs : set
        Set of graph input names.
    initializers : set
        Set of initializer names.

    Returns
    -------
    List[onnx.NodeProto]
        Topologically sorted list of nodes. If a cycle or unresolvable
        dependency is detected, the original node order is returned.
    """
    if not nodes:
        return []

    # Pre-compute all inputs for each node (including subgraph inputs).
    # Keyed by id() because NodeProto is not hashable.
    node_all_inputs: Dict[int, set] = {
        id(node): _collect_all_inputs(node) for node in nodes
    }

    # Available values: graph inputs + initializers
    available = graph_inputs | initializers

    # For each not-yet-available input name, record which nodes consume it.
    # Releasing an output then touches only its actual consumers instead of
    # rescanning every node, turning the former O(V^2) pass into O(V + E).
    # Only unsatisfied inputs are registered, so a node is never decremented
    # for an input that was already satisfied by `available`.
    consumers: Dict[str, List[onnx.NodeProto]] = {}
    in_degree: Dict[int, int] = {}
    for node in nodes:
        deps = 0
        for inp in node_all_inputs[id(node)]:
            if inp and inp not in available:
                deps += 1
                consumers.setdefault(inp, []).append(node)
        in_degree[id(node)] = deps

    # Seed the queue with nodes whose dependencies are all satisfied.
    queue = deque(node for node in nodes if in_degree[id(node)] == 0)

    sorted_nodes: List[onnx.NodeProto] = []
    while queue:
        node = queue.popleft()
        sorted_nodes.append(node)

        for output in node.output:
            # Skip empty outputs; skip names already available so each
            # consumer is decremented at most once per input name.
            if not output or output in available:
                continue
            available.add(output)
            for candidate in consumers.get(output, ()):
                in_degree[id(candidate)] -= 1
                if in_degree[id(candidate)] == 0:
                    queue.append(candidate)

    # If we couldn't sort all nodes, there's a cycle or missing dependency.
    # Fall back to original order.
    if len(sorted_nodes) != len(nodes):
        return list(nodes)

    return sorted_nodes
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class GraphBuilder:
    """Builds a PyTorch FX GraphModule from an ONNX model.

    This class handles the conversion of ONNX graph structure to PyTorch FX,
    including initializer loading, placeholder creation, node conversion,
    and output creation.

    Parameters
    ----------
    model : onnx.ModelProto
        The ONNX model to convert.

    Attributes
    ----------
    model : onnx.ModelProto
        The input ONNX model (with shape inference if successful).
    graph : torch.fx.Graph
        The FX graph being constructed.
    env : Dict[str, torch.fx.Node]
        Mapping from ONNX tensor names to FX nodes.
    opset_version : int
        The opset version for the default ONNX domain.
    """

    def __init__(
        self,
        model: onnx.ModelProto,
        *,
        base_dir: Optional[str] = None,
        memmap_external_data: bool = False,
    ) -> None:
        # Try shape inference but preserve original model if it fails
        # (shape_inference may drop graph contents for large models with external data)
        try:
            inferred_model = onnx.shape_inference.infer_shapes(model)
            # Check if shape inference preserved the model structure
            if len(inferred_model.graph.node) > 0:
                model = inferred_model
            # If nodes were lost, keep original model
        except Exception:
            # Best-effort: shape inference is an enrichment step only; the
            # build can proceed without inferred value_info entries.
            pass
        self.model: onnx.ModelProto = model
        self.graph: torch.fx.Graph = torch.fx.Graph()
        self._base_dir = base_dir
        self._memmap_external_data = memmap_external_data
        self.value_info_map = self._create_value_info_map()
        self.initializer_map = self._create_initializer_map()
        self.input_names: List[str] = []
        self.env: Dict[str, torch.fx.Node] = {}
        self._constants: Dict[str, torch.Tensor] = {}
        self._submodules: Dict[str, torch.nn.Module] = {}
        self._opset_versions: Dict[str, int] = self._extract_opset_versions()

    def _extract_opset_versions(self) -> Dict[str, int]:
        """Extract opset versions for all domains from the model.

        Returns
        -------
        Dict[str, int]
            Dictionary mapping domain names to their opset versions.
            Empty string "" represents the default ONNX domain.
        """
        versions: Dict[str, int] = {}
        for opset in self.model.opset_import:
            domain = opset.domain if opset.domain else ""
            versions[domain] = opset.version
        return versions

    def _resolve_handler(
        self, node: onnx.NodeProto
    ) -> tuple[Callable[["GraphBuilder", onnx.NodeProto], Any], str, int]:
        """Resolve the handler for an ONNX node and return handler, domain, opset.

        Raises
        ------
        UnsupportedOpError
            If no handler is registered for the node's op_type at the
            model's opset version for that domain.
        """
        domain = node.domain if node.domain else ""
        opset = self.get_opset_version(domain)
        handler = get_handler(node.op_type, domain, opset)
        if handler is None:
            raise UnsupportedOpError(node.op_type, domain=domain, opset_version=opset)
        return handler, domain, opset

    def _tag_operator_node(
        self, node: onnx.NodeProto, fx_node: torch.fx.Node, domain: str
    ) -> None:
        """Attach ONNX metadata to an operator node."""
        # Handlers may return something other than a single fx.Node
        # (e.g. a tuple); only tag objects that carry FX metadata.
        if fx_node is not None and hasattr(fx_node, "meta"):
            self._set_onnx_metadata(
                fx_node,
                op_type=node.op_type,
                name=node.name,
                domain=domain,
            )

    def _register_outputs(
        self, node: onnx.NodeProto, fx_node: torch.fx.Node, domain: str
    ) -> None:
        """Register node outputs in the environment.

        Single-output nodes map directly to the handler's FX node;
        multi-output nodes get one indexing node per named output.
        """
        if len(node.output) == 1:
            self.env[node.output[0]] = fx_node
            return

        for i, output_name in enumerate(node.output):
            if output_name:  # Skip empty output names
                # The lambda tolerates handlers that return a single value
                # instead of a tuple/list; idx=i pins the index at
                # definition time (classic late-binding-closure guard).
                getitem_node = self.graph.call_function(
                    lambda x, idx=i: x[idx] if isinstance(x, (tuple, list)) else x,
                    args=(fx_node, i),
                )
                self._set_onnx_metadata(
                    getitem_node,
                    op_type=node.op_type,
                    name=node.name,
                    domain=domain,
                    output_index=i,
                )
                self.env[output_name] = getitem_node

    @property
    def opset_version(self) -> int:
        """Get the opset version for the default ONNX domain.

        Returns
        -------
        int
            The opset version number. Defaults to 1 if not specified.
        """
        return self._opset_versions.get("", 1)

    def get_opset_version(self, domain: str = "") -> int:
        """Get the opset version for a specific domain.

        Parameters
        ----------
        domain : str, optional
            The ONNX domain. Default is "" (standard ONNX domain).

        Returns
        -------
        int
            The opset version number. Defaults to 1 if not specified.
        """
        return self._opset_versions.get(domain, 1)

    def build(self) -> torch.fx.GraphModule:
        """Run the full conversion pipeline and return the GraphModule.

        Order matters: initializers must be in ``env`` before placeholders
        (so initializer-backed inputs are skipped), and all nodes must be
        converted before outputs are wired up.

        Returns
        -------
        torch.fx.GraphModule
            The constructed module, with initializers registered as buffers
            and any handler-created submodules attached.
        """
        self._load_initializers()
        self._create_placeholders()
        self._convert_nodes()
        self._create_outputs()
        root_module = torch.nn.Module()
        # Register constants as buffers
        for name, tensor in self._constants.items():
            root_module.register_buffer(sanitize_name(name), tensor)
        # Register submodules
        for name, submod in self._submodules.items():
            root_module.add_module(name, submod)
        module = torch.fx.GraphModule(root_module, self.graph)
        if self._memmap_external_data:
            # Memory-mapped buffers are read-only views; mark the module so
            # downstream code can avoid mutating/training it.
            module._onnx2fx_inference_only = True
        module.graph.lint()
        return module

    @staticmethod
    def _set_onnx_metadata(
        node: torch.fx.Node,
        *,
        op_type: str,
        name: Optional[str] = None,
        domain: Optional[str] = None,
        shape: Optional[List[Optional[int]]] = None,
        dtype: Optional[torch.dtype] = None,
        output_index: Optional[int] = None,
    ) -> None:
        """Populate standard ONNX metadata on an FX node.

        Only keys whose value is provided (not None) are written, so
        existing meta entries are never overwritten with None.
        """
        node.meta["onnx_op_type"] = op_type
        if name is not None:
            node.meta["onnx_name"] = name
        if domain is not None:
            node.meta["onnx_domain"] = domain
        if shape is not None:
            node.meta["onnx_shape"] = shape
        if dtype is not None:
            node.meta["onnx_dtype"] = dtype
        if output_index is not None:
            node.meta["onnx_output_index"] = output_index

    def get_value(self, name: str) -> torch.fx.Node:
        """Get a value (node) by name from the environment.

        Parameters
        ----------
        name : str
            The name of the value.

        Returns
        -------
        torch.fx.Node
            The corresponding FX node.

        Raises
        ------
        ValueNotFoundError
            If the name is not found in the environment.
        """
        if name not in self.env:
            raise ValueNotFoundError(name, available=list(self.env.keys()))
        return self.env[name]

    def has_value(self, name: str) -> bool:
        """Check if a value exists in the environment."""
        return name in self.env

    def call_function(
        self,
        func: Callable[..., Any],
        args: Sequence[Union[torch.fx.Node, Any]] = (),
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.fx.Node:
        """Create a function call node in the FX graph.

        Parameters
        ----------
        func : Callable[..., Any]
            The function to call. Can be a PyTorch function, lambda, or any callable.
        args : Sequence[Union[torch.fx.Node, Any]], optional
            Positional arguments to the function. Can include FX nodes or constants.
        kwargs : Optional[Dict[str, Any]], optional
            Keyword arguments to the function.

        Returns
        -------
        torch.fx.Node
            The FX node representing this function call.
        """
        fx_node = self.graph.call_function(func, args=tuple(args), kwargs=kwargs or {})
        return fx_node

    def register_submodule(self, name: str, module: torch.nn.Module) -> str:
        """Register a submodule for use in the graph.

        Parameters
        ----------
        name : str
            Base name for the submodule.
        module : torch.nn.Module
            The submodule to register.

        Returns
        -------
        str
            The actual name used (may be modified to avoid conflicts).
        """
        # Sanitize name
        safe_name = sanitize_name(name)
        # Ensure unique name
        if safe_name in self._submodules:
            counter = 0
            while f"{safe_name}_{counter}" in self._submodules:
                counter += 1
            safe_name = f"{safe_name}_{counter}"
        self._submodules[safe_name] = module
        return safe_name

    def call_module(
        self,
        module_name: str,
        args: Sequence[Union[torch.fx.Node, Any]] = (),
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.fx.Node:
        """Create a module call node in the FX graph.

        Parameters
        ----------
        module_name : str
            The name of a registered submodule.
        args : Sequence[Union[torch.fx.Node, Any]], optional
            Positional arguments to the module.
        kwargs : Optional[Dict[str, Any]], optional
            Keyword arguments to the module.

        Returns
        -------
        torch.fx.Node
            The FX node representing this module call.
        """
        fx_node = self.graph.call_module(
            module_name, args=tuple(args), kwargs=kwargs or {}
        )
        return fx_node

    def _create_value_info_map(
        self,
    ) -> Dict[str, Tuple[Optional[List[Optional[int]]], Optional[torch.dtype]]]:
        """Build a mapping from value names to their shape and dtype info."""

        def extract_tensor_shape(
            value: onnx.ValueInfoProto,
        ) -> Optional[List[Optional[int]]]:
            """Extract a list-based representation of a tensor shape from a value info."""

            tensor_type = value.type.tensor_type
            if not tensor_type.HasField("shape"):
                return None
            dims: List[Optional[int]] = []
            for dim in tensor_type.shape.dim:
                if dim.HasField("dim_value"):
                    dims.append(int(dim.dim_value))
                elif dim.HasField("dim_param"):
                    # Symbolic (named) dimension -> unknown size
                    dims.append(None)
                else:
                    dims.append(None)
            return dims

        def extract_tensor_dtype(value: onnx.ValueInfoProto) -> Optional[torch.dtype]:
            """Extract the Torch dtype that corresponds to a value info."""

            onnx_dtype = value.type.tensor_type.elem_type
            # elem_type 0 == UNDEFINED in the ONNX TensorProto.DataType enum
            if onnx_dtype == 0:
                return None
            torch_dtype = DTYPE_MAP.get(onnx_dtype)
            if torch_dtype is None:
                # Strings have no torch dtype; treat as "unknown" rather
                # than failing, since string values may still flow through
                # ops that never touch torch tensors.
                if onnx_dtype == onnx.TensorProto.STRING:
                    return None
                raise UnsupportedDTypeError(
                    onnx_dtype=onnx_dtype,
                    tensor_name=value.name,
                    details="value_info dtype not supported",
                )
            return torch_dtype

        info_map = {}
        for value_info in (
            list(self.model.graph.input)
            + list(self.model.graph.value_info)
            + list(self.model.graph.output)
        ):
            info_map[value_info.name] = (
                extract_tensor_shape(value_info),
                extract_tensor_dtype(value_info),
            )
        return info_map

    def is_optional_type(self, name: str) -> bool:
        """Check if a value has optional type in the ONNX model.

        Parameters
        ----------
        name : str
            The name of the value to check.

        Returns
        -------
        bool
            True if the value has optional type, False otherwise.
        """
        # Search in graph inputs
        for value_info in self.model.graph.input:
            if value_info.name == name:
                return value_info.type.HasField("optional_type")
        # Search in value_info
        for value_info in self.model.graph.value_info:
            if value_info.name == name:
                return value_info.type.HasField("optional_type")
        # Search in outputs
        for value_info in self.model.graph.output:
            if value_info.name == name:
                return value_info.type.HasField("optional_type")
        return False

    def _create_initializer_map(self) -> Dict[str, torch.Tensor]:
        """Build a mapping from initializer names to PyTorch tensors."""
        init_map = {}
        for initializer in self.model.graph.initializer:
            init_map[initializer.name] = self.load_tensor(initializer)
        return init_map

    def load_tensor(self, tensor: onnx.TensorProto) -> torch.Tensor:
        """Load an ONNX TensorProto into a Torch tensor.

        Raises
        ------
        UnsupportedDTypeError
            If the tensor's ONNX dtype has no torch equivalent.
        """
        onnx_dtype = tensor.data_type
        if DTYPE_MAP.get(onnx_dtype) is None:
            raise UnsupportedDTypeError(
                onnx_dtype=onnx_dtype,
                tensor_name=tensor.name or "<unnamed>",
                details="initializer dtype not supported",
            )

        if self._memmap_external_data and (
            tensor.data_location == onnx.TensorProto.EXTERNAL or tensor.external_data
        ):
            # Memory-map external weight files instead of copying them into
            # RAM. The resulting tensor is a read-only view over the file.
            info = resolve_external_data(
                tensor,
                base_dir=self._base_dir,
                strict=True,
            )
            memmap_array = np.memmap(
                info.path,
                dtype=info.numpy_dtype,
                mode="r",
                offset=info.offset,
                shape=info.shape,
            )
            return torch.from_numpy(memmap_array)

        np_array = numpy_helper.to_array(tensor)
        # .copy() detaches from the protobuf-backed (possibly read-only)
        # array so the torch tensor owns writable memory.
        return torch.from_numpy(np_array.copy())

    def _load_initializers(self) -> None:
        """Load ONNX initializers as constant nodes in the FX graph."""
        for name, tensor in self.initializer_map.items():
            # Store in constants dict for later registration as buffers
            safe_name = sanitize_name(name)
            self._constants[safe_name] = tensor

            # Create a get_attr node to access the buffer
            fx_node = self.graph.get_attr(safe_name)
            self._set_onnx_metadata(
                fx_node,
                op_type="Initializer",
                name=name,
                shape=list(tensor.shape),
                dtype=tensor.dtype,
            )
            # env is keyed by the ORIGINAL ONNX name; only the FX-side
            # attribute uses the sanitized name.
            self.env[name] = fx_node

    def _create_placeholders(self) -> None:
        """Create FX placeholder nodes for graph inputs.

        Note: Inputs that are already loaded as initializers are skipped.
        """
        for value in self.model.graph.input:
            # Skip if already loaded as initializer
            if value.name in self.env:
                continue

            # Sanitize name for valid Python identifier
            safe_name = sanitize_name(value.name)
            placeholder = self.graph.placeholder(safe_name)
            info = self.value_info_map.get(value.name)
            self._set_onnx_metadata(
                placeholder,
                op_type="Input",
                name=value.name,
                shape=info[0] if info else None,
                dtype=info[1] if info else None,
            )
            self.env[value.name] = placeholder
            self.input_names.append(value.name)

    def _convert_nodes(self) -> None:
        """Convert every ONNX node into FX node(s) via registered handlers."""
        # Get graph inputs and initializers for topological sort
        graph_inputs = {inp.name for inp in self.model.graph.input}
        initializers = set(self.initializer_map.keys())

        # Topologically sort nodes to handle out-of-order dependencies
        sorted_nodes = _topological_sort(
            list(self.model.graph.node),
            graph_inputs,
            initializers,
        )

        for node in sorted_nodes:
            # Get handler with domain and opset version support
            handler, domain, _opset = self._resolve_handler(node)
            fx_node = handler(self, node)

            # Add ONNX metadata to the operator node
            # Some handlers return a list of nodes (e.g., gradient ops)
            self._tag_operator_node(node, fx_node, domain)
            self._register_outputs(node, fx_node, domain)

    def _create_outputs(self) -> None:
        """Wire the FX graph's output node to the ONNX graph outputs.

        A single output is emitted directly; multiple outputs are grouped
        into a tuple, matching ONNX's ordered output list.
        """
        output_nodes = [self.get_value(value.name) for value in self.model.graph.output]
        if len(output_nodes) == 1:
            self.graph.output(output_nodes[0])
        else:
            self.graph.output(tuple(output_nodes))
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
__all__ = ["GraphBuilder"]
|