onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,274 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from __future__ import annotations
18
+
19
+ import copy
20
+ from collections.abc import Sequence
21
+
22
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
23
+ from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
24
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor, Variable
25
+ from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
26
+ from onnxslim.third_party.onnx_graphsurgeon.util import misc
27
+
28
+
29
class Function(Graph):
    """
    Represents a local function, which is a default implementation of a Custom Op. This default implementation is
    represented as a Graph of other Ops.

    Functions are used in a model by creating a Node with the same name and domain as the function. This can be done
    using the __call__() method of a Function, which creates this new node and appends it to a Graph. A Function is not
    a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate from the main Graph.

    Functions can be composed of other functions, but cyclical or recursive definitions are not allowed in ONNX.
    """

    # Domain used when the caller does not supply one.
    DEFAULT_DOMAIN = "onnx_graphsurgeon"

    def __init__(
        self,
        name: str,
        domain: str | None = None,
        nodes: Sequence[Node] | None = None,
        inputs: Sequence[Tensor] | None = None,
        outputs: Sequence[Tensor] | None = None,
        doc_string: str | None = None,
        opset: int | None = None,
        # NOTE(review): `onnx` is not imported in this module; this annotation is only
        # safe because `from __future__ import annotations` keeps it a lazy string.
        # Calling typing.get_type_hints() on this signature would raise NameError.
        import_domains: Sequence[onnx.OperatorSetIdProto] | None = None,
        functions: Sequence[Function] | None = None,
        attrs: dict | None = None,
    ):
        """
        Args:
            name (str): The name of the function.
            domain (str): The domain/namespace of this function.
            nodes (Sequence[Node]): A list of the nodes in this function.
            inputs (Sequence[Tensor]): A list of graph input Tensors.
            outputs (Sequence[Tensor]): A list of graph output Tensors.
            doc_string (str): A doc_string for the function. Defaults to "".
            opset (int): The ONNX opset used by nodes in this function.
            import_domains (Sequence[onnx.OperatorSetIdProto]): The list of domains used by nodes in this function.
            functions (Sequence[Function]): The list of functions in this model.
            attrs (dict): A mapping of attribute names to their default values.
                Nodes within this function can have attributes which take on the values of the Function attributes.
                When a Function is instantiated into a Node, providing attributes to that Node will override the Function's
                default attribute values. A default value of `None` means that the instantiated Node must provide the value
                of that attribute (in other words, it is a required attribute).
        """
        self.domain = misc.default_value(domain, Function.DEFAULT_DOMAIN)
        self.attrs = misc.default_value(attrs, {})

        super().__init__(
            nodes,
            inputs,
            outputs,
            name=name,
            doc_string=doc_string,
            opset=opset,
            import_domains=import_domains,
            functions=functions,
        )

        # Properties of Graph that Function doesn't have.
        del self.producer_name
        del self.producer_version

    @property
    def unique_id(self):
        """Returns a tuple which uniquely identifies this function."""
        return (self.domain, self.name)

    def cleanup(
        self,
        remove_unused_node_outputs=False,
        recurse_subgraphs=True,
        remove_unused_graph_inputs=False,
        recurse_functions=False,
    ):
        """See Graph.cleanup() The only difference is that 'recurse_functions' defaults to False, so that only this
        Function is cleaned up.
        """
        if recurse_functions:
            G_LOGGER.warning(
                "Function.cleanup() called with recurse_functions=True, meaning that other functions will also be cleaned up."
            )
        return super().cleanup(
            remove_unused_node_outputs=remove_unused_node_outputs,
            recurse_subgraphs=recurse_subgraphs,
            remove_unused_graph_inputs=remove_unused_graph_inputs,
            recurse_functions=recurse_functions,
        )

    def fold_constants(self, recurse_functions=False, **kwargs):
        """See Graph.fold_constants() The only difference is that 'recurse_functions' defaults to False, so that only
        this Function's constants are folded.
        """
        if recurse_functions:
            G_LOGGER.warning(
                "Function.fold_constants() called with recurse_functions=True, meaning that other functions will also be const-folded."
            )
        return super().fold_constants(recurse_functions=recurse_functions, **kwargs)

    def toposort(
        self,
        recurse_subgraphs=True,
        recurse_functions=False,
        mode="nodes",
    ):
        """See Graph.toposort() The only difference is that 'recurse_functions' defaults to False and mode defaults to
        "nodes", so that by default only this function's nodes will be sorted.
        """
        if recurse_functions:
            G_LOGGER.warning(
                "Function.toposort() called with recurse_functions=True, meaning that other functions will be sorted."
            )
        return super().toposort(
            recurse_subgraphs=recurse_subgraphs,
            recurse_functions=recurse_functions,
            mode=mode,
        )

    def __call__(self, graph, inputs=None, outputs=None, *args, **kwargs) -> list[Tensor]:
        """
        Creates a Node which is an instance of this function. The created node can be used in a Graph or another
        Function.

        The provided inputs are processed the same way as in Graph.layer().
        If outputs are not provided, they are created based on the Function's outputs.

        Args:
            graph (Union[Graph, Function]): The Graph or Function to add the new node to.
            inputs (List[Union[Tensor, str, numpy.ndarray]]): The list of inputs.
            outputs (List[Union[Tensor, str, numpy.ndarray]]): The list of outputs.
            attrs (Dict[str, Any]): A dict of attributes for the node.
                The attribute names should be a subset of this Function's attribute names.
            args/kwargs: These are passed directly to the constructor of Node.

        Returns:
            List[Tensor]: The output tensors of the node.
        """
        # Arity mismatches are reported but not fatal, matching Graph.layer()'s lenient behavior.
        if inputs is not None and len(inputs) != len(self.inputs):
            msg_template = "Function {} expects {} inputs, but was called with {} inputs."
            G_LOGGER.warning(msg_template.format(self.name, len(self.inputs), len(inputs)))

        # Track which outputs are freshly created here, so their shape/dtype can be
        # copied from the Function definition after graph.layer() materializes them.
        new_output_indices = []
        if outputs is None:
            # Graph.layer() will create Tensors and make sure the names do not conflict.
            outputs = [out.name for out in self.outputs]
            new_output_indices = list(range(len(outputs)))
        elif len(outputs) != len(self.outputs):
            msg_template = "Function {} expects {} outputs, but was called with {} outputs."
            G_LOGGER.warning(msg_template.format(self.name, len(self.outputs), len(outputs)))
        else:
            new_output_indices = [i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor)]

        # Warn about required attributes (default value of None) that the caller omitted.
        attrs = kwargs.get("attrs", None)
        if attrs is not None:
            for attr_name, default_val in self.attrs.items():
                if default_val is None and attr_name not in attrs:
                    msg_template = "Function {} called without required attribute: {}"
                    G_LOGGER.warning(msg_template.format(self.name, attr_name))

        inputs = misc.default_value(inputs, [])
        outputs = misc.default_value(outputs, [])
        outputs = graph.layer(
            *args,
            **kwargs,
            op=self.name,
            domain=self.domain,
            inputs=inputs,
            outputs=outputs,
        )

        # For newly created output tensors, set their shape and dtype to match the Function definition.
        for i in new_output_indices:
            outputs[i].dtype = self.outputs[i].dtype
            outputs[i].shape = self.outputs[i].shape

        return outputs

    def copy(self):
        """
        Copy the function.

        This makes copies of all nodes and tensors in the function, but will not
        do a deep-copy of weights or attributes (with the exception of ``Graph``
        attributes, which will be copied using their ``copy`` method).

        Returns:
            Function: A copy of the function.
        """
        # Copy tensors first so nodes and graph inputs/outputs can share the same copies.
        local_tensor_copies = {n: t.copy() for n, t in self.tensors().items()}

        def get_tensor(name):
            """Retrieve a tensor by name from a deep-copied dictionary of tensors."""
            return local_tensor_copies[name] if name else Variable.empty()

        # Next, copy nodes, and update inputs/outputs
        new_nodes = []
        for node in self.nodes:
            new_node = node.copy(
                inputs=[get_tensor(inp.name) for inp in node.inputs],
                outputs=[get_tensor(out.name) for out in node.outputs],
                tensor_map=local_tensor_copies,
            )
            new_nodes.append(new_node)
        new_func_inputs = [get_tensor(inp.name) for inp in self.inputs]
        new_func_outputs = [get_tensor(out.name) for out in self.outputs]

        # Shallow-copy attribute defaults; Graph-valued defaults are not deep-copied here.
        new_attrs = {name: copy.copy(val) for name, val in self.attrs.items()}

        return Function(
            self.name,
            self.domain,
            nodes=new_nodes,
            inputs=new_func_inputs,
            outputs=new_func_outputs,
            doc_string=self.doc_string,
            opset=self.opset,
            import_domains=self.import_domains,
            functions=self.functions,
            attrs=new_attrs,
        )

    def __eq__(self, other: object):
        """Checks equality of self with another Function object based on their attributes."""
        # Fix: the previous implementation assumed `other` was a Function and raised
        # AttributeError for e.g. `func == None`. Returning NotImplemented lets Python
        # fall back to its default comparison (identity) for foreign types.
        if not isinstance(other, Function):
            return NotImplemented

        def sequences_equal(seq1, seq2):
            """Checks if two sequences are equal in length and elements."""
            return len(seq1) == len(seq2) and all(elem1 == elem2 for elem1, elem2 in zip(seq1, seq2))

        return (
            self.unique_id == other.unique_id
            and self.opset == other.opset
            and self.import_domains == other.import_domains
            and sequences_equal(self.inputs, other.inputs)
            and sequences_equal(self.outputs, other.outputs)
            and sequences_equal(self.nodes, other.nodes)
        )

    def __str__(self):
        """Returns a string representation of the function including its name, domain, opset, inputs, nodes, and
        outputs.
        """
        nodes_str = "\n".join([str(node) for node in self.nodes])
        out = f"Function {self.name}, Domain {self.domain}, Opset {self.opset}"
        out += f"\nInputs: {self.inputs}"
        out += f"\nNodes: {nodes_str}"
        out += f"\nOutputs: {self.outputs}"
        return out
+ return out