onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,102 @@
1
+ # mypy: allow-untyped-defs
2
+ """
3
+ This file contains canonical definitions for our symbol naming conventions,
4
+ across torch.fx.experimental.symbolic_shapes and torch._inductor. The
5
+ intention is:
6
+
7
+ 1. To make it easily greppable where all the sites we use a prefix are
8
+ 2. Make it possible to easily tell if we can introduce a new prefix without
9
+ introducing a conflict
10
+
11
+ You can occasionally test if prefixes have been hardcoded by renaming prefixes
12
+ in this file and seeing what breaks.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from collections.abc import Iterable
18
+ from enum import Enum, auto
19
+
20
+ import sympy
21
+
22
+
23
class SymT(Enum):
    """Canonical categories of symbols used by symbolic-shape / inductor code.

    Member order is significant: auto() assigns values in declaration order.
    """

    SIZE = auto()
    FLOAT = auto()
    UNBACKED_INT = auto()
    UNBACKED_FLOAT = auto()
    # Inductor: the intermediates in inner_fn (tmp0, ...), one generated per
    # ops call. One of these in an indexing expression means an indirect load.
    TMP = auto()
    # Inductor: placeholder variable that is later replaced with TMP
    INDIRECT = auto()
    # Inductor: some size expressions are replaced with a precomputed size
    # ps0, computed host side and reused in the kernel, so it is not
    # repeatedly recomputed on device.
    PRECOMPUTED_SIZE = auto()
    # Inductor: an indexing variable i0 in loops IR ranging over a
    # non-reduced dim in the loop
    INDEX = auto()
    # Inductor: reduction indexing variables (r0_, r1_) in loops IR ranging
    # over reduced dim(s) in the loop
    R0_INDEX = auto()
    R1_INDEX = auto()
    # Inductor: templated kernels (torch._inductor.kernel) hook the final
    # output to append epilogue fusions, which requires knowing the indexes
    # the outputs range over. NB: these also advertise as INDEX.
    TEMPLATE_INDEX = auto()
    # Inductor: iteration domain for blockIdx.x/blockIdx.y
    XBLOCK = auto()
    YBLOCK = auto()
    ZBLOCK = auto()
    # Inductor: used solely for dynamic_reshape_indexer
    VIEW = auto()
    # Alternate (non-modular) indexing used in halide kernels
    HALIDE = auto()
58
+
59
+
60
# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity
# NOTE(review): "i" (INDEX) is a string prefix of "idx" (TEMPLATE_INDEX) and
# "indirect" (INDIRECT), so symbol_is_type(sym, SymT.INDEX) also matches
# those — confirm this overlap is intended (the comment above suggests not).
prefix_str = {
    SymT.SIZE: "s",  # integer
    SymT.UNBACKED_INT: "u",  # integer
    # Prefix z here is chosen to avoid false aliasing in symbol_is_type test
    # DO NOT add a "z" type. You also need to avoid conflicts on these
    # prefixes but this is somewhat easier to manage
    SymT.FLOAT: "zf",
    SymT.UNBACKED_FLOAT: "zuf",
    SymT.TMP: "tmp",
    SymT.PRECOMPUTED_SIZE: "ps",
    SymT.INDEX: "i",
    SymT.R0_INDEX: "r0_",
    SymT.R1_INDEX: "r1_",
    SymT.TEMPLATE_INDEX: "idx",
    SymT.XBLOCK: "x",
    SymT.YBLOCK: "y",
    SymT.ZBLOCK: "z",
    SymT.INDIRECT: "indirect",  # false aliasing?
    SymT.VIEW: "view",
    SymT.HALIDE: "h",
}
83
+
84
+
85
def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
    """Create a sympy Symbol named with the canonical prefix for *prefix* followed by *idx*."""
    # TODO: maybe put the assumptions here directly
    name = f"{prefix_str[prefix]}{idx}"
    return sympy.Symbol(name, **kwargs)
88
+
89
+
90
# This type is a little wider than it should be, because free_symbols says
# that it contains Basic, rather than Symbol
def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool:
    """Return True if *sym*'s name begins with the prefix for *prefix* (or any of several prefixes)."""
    assert isinstance(sym, sympy.Symbol)
    # Lowercase so capitalized spellings like XBLOCK / RBLOCK still match.
    name_str = sym.name.lower()
    if isinstance(prefix, SymT):
        wanted = (prefix_str[prefix],)
    else:
        wanted = tuple(prefix_str[p] for p in prefix)
    # str.startswith accepts a tuple of candidates, so both cases share one call.
    return name_str.startswith(wanted)
99
+
100
+
101
def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool:
    """Return True if any free symbol of expression *e* matches *prefix* (see symbol_is_type)."""
    return any(symbol_is_type(sym, prefix) for sym in e.free_symbols)
@@ -0,0 +1,15 @@
1
+ from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import export_onnx
2
+ from onnxslim.third_party.onnx_graphsurgeon.graph_pattern import (
3
+ GraphPattern,
4
+ PatternMapping,
5
+ )
6
+ from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import import_onnx
7
+ from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
8
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
9
+ from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
10
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
11
+ from onnxslim.third_party.onnx_graphsurgeon.util.exception import (
12
+ OnnxGraphSurgeonException,
13
+ )
14
+
15
+ __version__ = "0.5.1"
@@ -0,0 +1 @@
1
+ from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter
@@ -0,0 +1,33 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
19
+
20
+
21
class BaseExporter:
    """Abstract base class for exporters that convert a Graph into another representation."""

    @staticmethod
    def export_graph(graph: Graph):
        """
        Export a graph to some destination graph.

        Args:
            graph (Graph): The source graph to export.

        Returns:
            object: The exported graph. For example, this might be an onnx.GraphProto

        Raises:
            NotImplementedError: Always; concrete subclasses must override this method.
        """
        raise NotImplementedError("BaseExporter is an abstract class")
@@ -0,0 +1,432 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from __future__ import annotations
18
+
19
+ from collections import OrderedDict
20
+ from collections.abc import Sequence
21
+
22
+ import numpy as np
23
+ import onnx
24
+ from onnx import IR_VERSION, ModelProto, defs
25
+
26
+ from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter
27
+ from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
28
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
29
+ from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
30
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import (
31
+ Constant,
32
+ LazyValues,
33
+ SparseValues,
34
+ Tensor,
35
+ Variable,
36
+ )
37
+
38
+ from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
39
+ from onnxslim.third_party.onnx_graphsurgeon.util import misc
40
+
41
+ from ml_dtypes import bfloat16, float8_e4m3fn
42
+
43
def dtype_to_onnx(dtype: np.dtype | onnx.TensorProto.DataType) -> int:
    """Return the integer ONNX TensorProto data-type code for *dtype*.

    Integers are assumed to already be ONNX data-type codes and are passed
    through unchanged; any other value is interpreted as a numpy dtype.
    """
    if isinstance(dtype, int):
        return dtype
    np_dtype = np.dtype(dtype)
    return onnx.helper.np_dtype_to_tensor_dtype(np_dtype)
48
+
49
+
50
def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING):
    """Log every pair of distinct nodes that share the same non-empty name.

    Args:
        nodes: The nodes to scan.
        level: Logger severity used when reporting a duplicate.

    Note: nodes whose name is empty or None are never reported as duplicates.
    """
    seen = {}
    for node in nodes:
        name = node.name
        if not name:
            continue
        previous = seen.get(name)
        if previous is None:
            # First occurrence of this name; remember it for later comparisons.
            seen[name] = node
        else:
            msg = f"Found distinct Nodes that share the same name:\n[id: {id(previous)}]:\n {previous}---\n[id: {id(node)}]:\n {node}\n"
            G_LOGGER.log(msg, level)
63
+
64
+
65
def update_import_domains(graph):
    """Ensure graph.import_domains covers the standard ONNX opset plus every
    non-ONNX domain used by the graph's (and its subgraphs') nodes.

    Returns the updated value of the import_domains field.
    """
    # Seed with the standard ONNX opset when nothing has been recorded yet.
    if graph.import_domains is None:
        graph.import_domains = [onnx.helper.make_opsetid("", graph.opset)]

    # Gather every domain referenced by nodes in this graph and all subgraphs.
    used_domains = {node.domain for node in graph.nodes}
    for subgraph in graph.subgraphs(recursive=True):
        used_domains.update(node.domain for node in subgraph.nodes)
    used_domains.discard(None)

    # Append any domain not already imported, at a default opset version.
    DEFAULT_CUSTOM_OPSET_VERSION = 1
    known_domains = {opsetid.domain for opsetid in graph.import_domains}
    for domain in used_domains:
        if domain not in known_domains:
            graph.import_domains.append(onnx.helper.make_opsetid(domain, DEFAULT_CUSTOM_OPSET_VERSION))
            known_domains.add(domain)
    return graph.import_domains
88
+
89
+
90
def float32_to_bfloat16_uint16(x):
    """Convert a float32 value to bfloat16 represented as uint16.

    The value is rounded to bfloat16 by ml_dtypes, then its raw 16 bits are
    reinterpreted (not numerically converted) as uint16 for serialization.
    """
    return bfloat16(x).view(np.uint16)
93
+
94
def float32_to_float8e4m3(x):
    """Convert a float32 value to float8e4m3 represented as uint8.

    The value is rounded to float8_e4m3fn by ml_dtypes, then its raw 8 bits
    are reinterpreted (not numerically converted) as uint8 for serialization.
    """
    return float8_e4m3fn(x).view(np.uint8)
97
+
98
+
99
class NumpyArrayConverter:
    """Element-wise array converter.

    Applies *scalar_converter* to every element of an input array and forces
    the output array's dtype to *container*.
    """

    def __init__(self, container, scalar_converter):
        # np.vectorize applies the scalar function element-wise; otypes pins
        # the result dtype so the raw-bits container is honored.
        self.func = np.vectorize(scalar_converter, otypes=[container])

    def __call__(self, arr):
        """Return *arr* converted element-wise to the container dtype."""
        return self.func(arr)
105
+
106
+
107
# Converters for ONNX dtypes that numpy cannot represent natively: each maps
# the ONNX dtype code to a converter producing the matching raw-bits array.
_NUMPY_ARRAY_CONVERTERS = {
    onnx.TensorProto.BFLOAT16: NumpyArrayConverter(np.uint16, float32_to_bfloat16_uint16),
    # FP8 in TensorRT supports negative zeros, no infinities
    # See https://onnx.ai/onnx/technical/float8.html#papers
    onnx.TensorProto.FLOAT8E4M3FN: NumpyArrayConverter(np.uint8, float32_to_float8e4m3),
}
113
+
114
+
115
def constant_to_onnx_tensor(tensor: Constant) -> onnx.TensorProto:
    """Serialize a Constant into an onnx.TensorProto carrying raw bytes.

    When the export dtype differs from the in-memory dtype, values are first
    converted; only float32 -> {BFLOAT16, FLOAT8E4M3FN} conversions are
    supported (enforced by assertions).
    """
    source_dtype = dtype_to_onnx(tensor.dtype)
    target_dtype = dtype_to_onnx(tensor.export_dtype)

    if source_dtype == target_dtype:
        # No conversion needed; dump the numpy buffer directly.
        tensor_raw_bytes = tensor.values.tobytes()
    else:
        source_dtype_str = onnx.helper.tensor_dtype_to_string(source_dtype)
        target_dtype_str = onnx.helper.tensor_dtype_to_string(target_dtype)
        assert source_dtype == onnx.TensorProto.FLOAT, (
            f"Cannot convert onnx dtype {source_dtype_str} to {target_dtype_str}. "
            "Source dtype must be float32 to convert to numpy unsupported dtypes."
        )
        assert target_dtype in _NUMPY_ARRAY_CONVERTERS.keys(), (
            f"Cannot convert onnx dtype {source_dtype_str} to {target_dtype_str}. "
            f"Only float32 to {_NUMPY_ARRAY_CONVERTERS.keys()} is supported."
        )
        converted = _NUMPY_ARRAY_CONVERTERS[target_dtype](tensor.values)
        tensor_raw_bytes = converted.tobytes()

    return onnx.helper.make_tensor(
        name=tensor.name,
        data_type=target_dtype,
        dims=tensor.shape,
        vals=tensor_raw_bytes,
        raw=True,
    )
142
+
143
+
144
class OnnxExporter(BaseExporter):
    """Exports onnx-graphsurgeon IR objects (Graph, Function, Node, Tensor) to ONNX protos."""

    @staticmethod
    def export_tensor_proto(tensor: Constant) -> onnx.TensorProto:
        """Export a Constant as an onnx.TensorProto, reusing lazily-loaded protos when possible."""
        # Do *not* load LazyValues into an intermediate numpy array - instead, use
        # the original onnx.TensorProto directly.
        if isinstance(tensor._values, LazyValues):
            onnx_tensor = tensor._values.tensor
            onnx_tensor.name = tensor.name
        else:
            onnx_tensor = constant_to_onnx_tensor(tensor)

        if tensor.data_location is not None:
            onnx_tensor.data_location = tensor.data_location
        return onnx_tensor

    @staticmethod
    def export_sparse_tensor_proto(tensor: Constant) -> onnx.SparseTensorProto:
        """Exports a given Constant tensor as an ONNX SparseTensorProto."""
        return tensor._values.tensor

    @staticmethod
    def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueInfoProto:
        """Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information."""
        if do_type_check and tensor.dtype is None:
            G_LOGGER.critical(
                f"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {tensor}"
            )

        if tensor.dtype is None:
            onnx_tensor = onnx.helper.make_empty_tensor_value_info(tensor.name)
        elif isinstance(tensor, Constant) or tensor.type == "tensor_type":
            onnx_tensor = onnx.helper.make_tensor_value_info(tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape)
        elif tensor.type == "sequence_type":
            onnx_tensor = onnx.helper.make_tensor_sequence_value_info(
                tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape
            )
        elif tensor.type == "sparse_tensor_type":
            onnx_tensor = onnx.helper.make_sparse_tensor_value_info(
                tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape
            )
        # NOTE(review): a tensor with a dtype but an unrecognized `type` would
        # leave `onnx_tensor` unbound here (UnboundLocalError) — confirm no
        # other type strings can reach this point.
        return onnx_tensor

    @staticmethod
    def export_attributes(attrs: dict, subgraph_tensor_map=None) -> list[onnx.AttributeProto]:
        """Convert node/function attributes to ONNX AttributeProtos for model export.

        Tensors, subgraphs, attribute references, and (numpy) types each get
        dedicated handling; everything else goes through onnx.helper.make_attribute.
        """
        onnx_attrs: list[onnx.AttributeProto] = []
        for key, val in attrs.items():
            if isinstance(val, Tensor):
                val = OnnxExporter.export_tensor_proto(val)
            elif isinstance(val, Graph):
                # Subgraphs don't need to have types specified for their tensors.
                graph = onnx.GraphProto()
                OnnxExporter.export_graph(graph, val, subgraph_tensor_map=subgraph_tensor_map, do_type_check=False)
                val = graph
            elif isinstance(val, Node.AttributeRef):
                onnx_attr = onnx.AttributeProto()
                onnx_attr.name = key
                onnx_attr.type = misc.convert_to_onnx_attr_type(val.type)

                # Netron has a bug which makes it crash if a Tensor attribute has no tensor data.
                # So provide some meaningless tensor data for Netron to read.
                if val.type == Tensor:
                    tensor_proto = OnnxExporter.export_tensor_proto(Constant("", np.array([0], dtype=np.float32)))
                    onnx_attr.t.CopyFrom(tensor_proto)

                onnx_attr.ref_attr_name = val.name
                onnx_attrs.append(onnx_attr)
                continue
            elif isinstance(val, type):
                # May be a numpy type
                try:
                    val = dtype_to_onnx(val)
                except TypeError:
                    pass
            onnx_attrs.append(onnx.helper.make_attribute(key, val))
        return onnx_attrs

    @staticmethod
    def export_node(node: Node, subgraph_tensor_map=None) -> onnx.NodeProto:
        # Cannot pass in attrs directly as make_node will change the order
        """Static method to convert an internal node to an ONNX node representation."""
        onnx_node = onnx.helper.make_node(
            node.op,
            inputs=[t.name for t in node.inputs],
            outputs=[t.name for t in node.outputs],
            name=node.name,
            domain=node.domain,
        )
        onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs, subgraph_tensor_map))
        return onnx_node

    @staticmethod
    def export_function(func: Function) -> onnx.FunctionProto:
        """
        Export an onnx-graphsurgeon Function to an ONNX FunctionProto.

        Args:
            func (Function): The function to export.

        Returns:
            onnx.FunctionProto: The exported function.
        """
        # Unlike onnx Graphs, onnx Functions don't have an 'initializer' field.
        # So we need to replace all Constant tensors with onnx Constant nodes which produce them.
        # We need to be careful to (a) preserve topological ordering and (b) not make the new nodes visible to the user.
        func_nodes = func.nodes.copy()
        new_const_nodes = [
            Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()])
            for tensor in func.tensors().values()
            if isinstance(tensor, Constant)
        ]
        # Const nodes have no inputs, so this maintains a topological ordering.
        func_nodes = new_const_nodes + func_nodes

        check_duplicate_node_names(func_nodes, level=G_LOGGER.WARNING)
        nodes = [OnnxExporter.export_node(node) for node in func_nodes]

        # Update the import_domains field to include all domains used by this function.
        opset_imports = update_import_domains(func)

        onnx_inputs = [inp.name for inp in func.inputs]
        onnx_outputs = [out.name for out in func.outputs]

        # Attributes without a default become plain attribute names; those with
        # a default are exported as AttributeProtos.
        attributes = []
        attribute_protos = {}
        for attr_name, default_val in func.attrs.items():
            if default_val is None:
                attributes.append(attr_name)
            else:
                attribute_protos[attr_name] = default_val
        attribute_protos = OnnxExporter.export_attributes(attribute_protos)

        return onnx.helper.make_function(
            func.domain or "",
            func.name,
            onnx_inputs,
            onnx_outputs,
            nodes,
            opset_imports,
            attributes=attributes,
            attribute_protos=attribute_protos,
            doc_string=func.doc_string,
        )

    @staticmethod
    def export_graph(
        graph_proto: onnx.GraphProto,
        graph: Graph,
        tensor_map: OrderedDict[str, Tensor] | None = None,
        subgraph_tensor_map: OrderedDict[str, Tensor] | None = None,
        do_type_check=True,
    ) -> onnx.GraphProto:
        """
        Export an onnx-graphsurgeon Graph into an existing ONNX GraphProto (filled in place).

        Args:
            graph_proto (onnx.GraphProto): The destination proto, populated in place.
            graph (Graph): The graph to export.
            tensor_map (OrderedDict[str, Tensor]): Pre-computed tensor map; derived from `graph` when None.
            subgraph_tensor_map (OrderedDict[str, Tensor]): Tensors owned by an enclosing graph.
            do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not.
                Defaults to True.

        Returns:
            onnx.GraphProto: The same `graph_proto` passed in.
        """
        check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING)
        nodes = [OnnxExporter.export_node(node, subgraph_tensor_map) for node in graph.nodes]
        inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs]
        outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs]
        if tensor_map is None:
            tensor_map = graph.tensors()
            tensor_map = misc.unique_dicts(tensor_map, subgraph_tensor_map)
        else:
            tensor_map = misc.combine_dicts(tensor_map, subgraph_tensor_map)
        initializer = [
            OnnxExporter.export_tensor_proto(tensor)
            for tensor in tensor_map.values()
            if isinstance(tensor, Constant) and not isinstance(tensor._values, SparseValues)
        ]

        sparse_initializer = [
            OnnxExporter.export_sparse_tensor_proto(tensor)
            for tensor in tensor_map.values()
            if isinstance(tensor, Constant) and isinstance(tensor._values, SparseValues)
        ]

        # Remove inputs and outputs to export ValueInfoProtos
        for tensor in graph.inputs + graph.outputs:
            if tensor.name in tensor_map:
                del tensor_map[tensor.name]

        # Omit tensors from value_info if we don't know their shape/dtype
        def has_value_info(tensor):
            """Check if a tensor is a Variable with either a defined dtype or shape."""
            return isinstance(tensor, Variable) and (tensor.dtype is not None or tensor.shape is not None)

        value_info = [
            OnnxExporter.export_value_info_proto(tensor, do_type_check)
            for tensor in tensor_map.values()
            if has_value_info(tensor)
        ]

        # (Removed dead `if ... is None` guards: the three lists above are
        # always lists, never None.)
        graph_proto.node.extend(nodes)
        graph_proto.name = graph.name
        graph_proto.input.extend(inputs)
        graph_proto.output.extend(outputs)
        for init in initializer:
            graph_proto.initializer.add().CopyFrom(init)
        for sparse in sparse_initializer:
            graph_proto.sparse_initializer.add().CopyFrom(sparse)
        graph_proto.value_info.extend(value_info)
        if graph.doc_string:
            # BUG FIX: was `graph.doc_string = graph.doc_string` (a no-op
            # self-assignment), which silently dropped the doc string from
            # the exported proto.
            graph_proto.doc_string = graph.doc_string

        return graph_proto
359
+
360
+
361
def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> onnx.ModelProto:
    """
    Exports an onnx-graphsurgeon Graph to an ONNX model.

    Args:
        graph (Graph): The graph to export

        do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not.
            Defaults to True.
        kwargs: Additional arguments to onnx.helper.make_model

    Returns:
        onnx.ModelProto: A corresponding ONNX model.
    """
    sub_graphs = graph.subgraphs(recursive=True)

    # Constants of each subgraph, keyed by tensor name.
    graph_constants_list = [
        {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)}
        for sub_graph in sub_graphs
    ]

    # Constants common (by name and equal value) to *all* subgraphs; these are
    # hoisted to the outer graph instead of duplicated per subgraph. None when
    # there are no subgraphs. (Simplified from a redundant double-conditional.)
    if graph_constants_list:
        first, *rest = graph_constants_list
        intersection = {
            name: const for name, const in first.items() if all(name in d and const == d[name] for d in rest)
        }
    else:
        intersection = None

    model = ModelProto()  # create in advance to avoid unnecessary copy
    OnnxExporter.export_graph(
        model.graph, graph, tensor_map=graph.tensors(), subgraph_tensor_map=intersection, do_type_check=do_type_check
    )
    # Always export functions; this intentionally overrides any caller-supplied value.
    kwargs["functions"] = [OnnxExporter.export_function(func) for func in graph.functions]

    if "opset_imports" not in kwargs:
        kwargs["opset_imports"] = update_import_domains(graph)

    if "ir_version" not in kwargs and graph.ir_version is not None:
        kwargs["ir_version"] = graph.ir_version
    else:
        # Either the caller supplied ir_version (applied via setattr below, which
        # overrides this) or the graph has none; default to onnx's IR_VERSION.
        model.ir_version = IR_VERSION

    opset_imports = kwargs.pop("opset_imports", None)  # type: ignore
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    functions = kwargs.pop("functions", None)  # type: ignore
    if functions is not None:
        model.functions.extend(functions)

    for k, v in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, k, v)

    if graph.metadata_props is not None:
        model.metadata_props.extend(graph.metadata_props)
    model.producer_name = graph.producer_name
    model.producer_version = graph.producer_version
    return model
@@ -0,0 +1,4 @@
1
+ from onnxslim.third_party.onnx_graphsurgeon.graph_pattern.graph_pattern import (
2
+ GraphPattern,
3
+ PatternMapping,
4
+ )