onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,266 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from __future__ import annotations
18
+
19
+ from collections import OrderedDict
20
+ from dataclasses import dataclass
21
+
22
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
23
+ from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
24
+ from onnxslim.third_party.onnx_graphsurgeon.util import misc
25
+
26
+
27
class Node:
    """An operation in a graph: consumes zero or more input Tensors and produces zero or more output Tensors."""

    @dataclass
    class AttributeRef:
        """
        An AttributeRef is an attribute value which references an attribute in the parent function. A node's attribute
        can only be an AttributeRef if the node lives inside a Function.

        Args:
            name (str): The name of the referenced attribute in the parent Function.
            type (type): The attribute's type.
        """

        name: str
        type: type

    def __init__(
        self,
        op: str,
        name: str | None = None,
        attrs: dict[str, object] | None = None,
        inputs: list[Tensor] | None = None,
        outputs: list[Tensor] | None = None,
        domain: str | None = None,
    ):
        """
        A node represents an operation in a graph, and consumes zero or more Tensors, and produces zero or more Tensors.

        Args:
            op (str): The operation this node performs.

            name (str): The name of this node.
            attrs (Dict[str, object]): A dictionary that maps attribute names to their values.
            inputs (List[Tensor]): A list of zero or more input Tensors.
            outputs (List[Tensor]): A list of zero or more output Tensors.
            domain (str): The domain of this node.
        """
        self.op = op
        self.name = misc.default_value(name, "")
        self.attrs = misc.default_value(attrs, OrderedDict())
        # NOTE: field_name is deliberately swapped relative to the attribute name.
        # When a tensor is added to this node's `inputs`, this node must be
        # registered in that tensor's `outputs` (the tensor feeds this node),
        # and vice versa. SynchronizedList maintains that bidirectional link.
        self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=misc.default_value(inputs, []))
        self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=misc.default_value(outputs, []))
        self.domain = domain

    def i(self, tensor_idx=0, producer_idx=0):
        """
        Convenience function to get a producer node of one of this node's input tensors. Note that the parameters are
        swapped compared to the o() function; this is because tensors are likely to have only a single producer.

        For example:
        ::

            assert node.i() == node.inputs[0].inputs[0]
            assert node.i(1, 2) == node.inputs[1].inputs[2]

        Args:
            tensor_idx (int): The index of the input tensor of this node. Defaults to 0.
            producer_idx (int): The index of the producer of the input tensor, if the tensor has multiple producers. Defaults to 0

        Returns:
            Node: The specified producer (input) node.
        """
        return self.inputs[tensor_idx].inputs[producer_idx]

    def o(self, consumer_idx=0, tensor_idx=0):
        """
        Convenience function to get a consumer node of one of this node's output tensors.

        For example:
        ::

            assert node.o() == node.outputs[0].outputs[0]
            assert node.o(2, 1) == node.outputs[1].outputs[2]

        Args:
            consumer_idx (int): The index of the consumer of the input tensor. Defaults to 0.
            tensor_idx (int): The index of the output tensor of this node, if the node has multiple outputs. Defaults to 0.

        Returns:
            Node: The specified consumer (output) node
        """
        return self.outputs[tensor_idx].outputs[consumer_idx]

    def subgraphs(self, recursive=False):
        """
        Convenience function to iterate over all subgraphs which are contained in this node. Node subgraphs are found in
        attributes of ONNX control flow nodes such as 'If' and 'Loop'.

        Args:
            recursive (bool): Whether to recurse into the subgraph nodes when looking for subgraphs. Defaults to False.

        Returns:
            A generator which iterates over this node's subgraphs.
        """
        # Imported here to avoid a circular import between node.py and graph.py.
        from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph

        visit_queue = [self]

        # This prevents infinite recursion in the (illegal) case of cyclical graphs.
        visited = set()

        while visit_queue:
            node = visit_queue.pop()
            for attr in node.attrs.values():
                if isinstance(attr, Graph) and id(attr) not in visited:
                    visited.add(id(attr))
                    if recursive:
                        visit_queue.extend(attr.nodes)
                    yield attr

    def __setattr__(self, name, value):
        """Sets the attribute 'name' to 'value', handling special cases for 'inputs' and 'outputs' attributes."""
        if name in {"inputs", "outputs"}:
            try:
                attr = getattr(self, name)
                if value is attr:
                    # This can happen when using things like +=
                    # The __iadd__ is executed followed by an assignment
                    return

                # Mutate the existing SynchronizedList in place so tensor
                # back-references stay consistent, rather than rebinding.
                attr.clear()
                attr.extend(value)
            except AttributeError:
                # First assignment (from __init__): the attribute does not exist yet.
                super().__setattr__(name, value)
        else:
            super().__setattr__(name, value)

    def copy(
        self,
        inputs: list[Tensor] | None = None,
        outputs: list[Tensor] | None = None,
        tensor_map=None,
    ):
        """
        Makes a shallow copy of this node, overriding input and output information.

        Note: Generally, you should only ever make a copy of a Graph.
        """
        from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph

        # Subgraph attributes must be deep-copied (via Graph.copy) so the copy
        # does not share mutable graph state; all other attributes are shared.
        new_attrs = OrderedDict()
        for name, attr in self.attrs.items():
            new_attrs[name] = attr.copy(tensor_map) if isinstance(attr, Graph) else attr
        return Node(
            self.op,
            self.name,
            new_attrs,
            inputs=inputs,
            outputs=outputs,
            domain=self.domain,
        )

    def __str__(self):
        """Return a string representation of the object showing its name and operation."""
        ret = f"{self.name} ({self.op})"

        def add_io(name, io):
            """Add the input or output operations and their names to the string representation of the object."""
            nonlocal ret
            ret += f"\n\t{name}: ["
            for elem in io:
                ret += f"\n\t\t{elem}"
            ret += "\n\t]"

        add_io("Inputs", self.inputs)
        add_io("Outputs", self.outputs)

        if self.attrs:
            ret += f"\nAttributes: {self.attrs}"

        if self.domain:
            ret += f"\nDomain: {self.domain}"

        return ret

    def __repr__(self):
        """Return the string representation of this Node (same as __str__)."""
        return self.__str__()

    def __eq__(self, other):
        """Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs."""
        G_LOGGER.verbose(f"Comparing node: {self.name} with {other.name}")
        # Compare cheap scalar fields first, then the (potentially expensive)
        # input/output sequences, short-circuiting on the first mismatch.
        attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs
        if not attrs_match:
            return False

        inputs_match = misc.sequences_equal(self.inputs, other.inputs)
        if not inputs_match:
            return False

        outputs_match = misc.sequences_equal(self.outputs, other.outputs)
        return self.domain == other.domain if outputs_match else False

    @property
    def users(self):
        """Retrieve the nodes (and graph-output variables) that consume this node's outputs."""
        users = []
        for output in self.outputs:  # output is a Variable
            if output.is_output:
                # NOTE: side effect — tags graph-output variables with a pseudo-op
                # so they can be treated uniformly as "users" by callers.
                output.op = "output"
                users.append(output)
            users.extend(iter(output.outputs))
        return users

    @property
    def feeds(self):
        """Retrieve the list of nodes that provide inputs to the given node."""
        feeds = []
        for input in self.inputs:
            # Constants and graph inputs (tensors with no producer) feed this
            # node directly; otherwise collect the producing nodes.
            if isinstance(input, Constant) or len(input.inputs) == 0:
                feeds.append(input)
            else:
                # For a Split producer, report the tensor itself rather than
                # the Split node (a Split output maps 1:1 to its tensor).
                feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs)
        return feeds

    def erase(self, input_var_idx=0, output_var_idx=0):
        """
        Remove this node from the graph by rewiring uses around it.

        If the selected input is a graph input, all uses of the selected output
        are redirected to that input; otherwise all uses of the input are
        redirected to the output. Either way, this node is fully detached.
        No-op when the selected input is not a Variable.
        """
        if isinstance(self.inputs[input_var_idx], Variable):
            if self.inputs[input_var_idx].is_input:
                self.outputs[output_var_idx].replace_all_uses_with(self.inputs[input_var_idx])
            else:
                self.inputs[input_var_idx].replace_all_uses_with(self.outputs[output_var_idx])
            # Detach this node from all tensors (SynchronizedList updates the
            # tensors' back-references as a side effect).
            self.inputs.clear()
            self.outputs.clear()

    def replace_all_uses_with(self, node: Node):
        """Replace all uses of this node with the given node."""
        for user in self.users:
            for inp in user.inputs:
                if inp in self.outputs:
                    # Rewire each matching user input to the corresponding
                    # output of the replacement node (matched by position).
                    for i, input in enumerate(user.inputs):
                        if input == inp:
                            user.inputs[i] = node.outputs[self.outputs.index(inp)]

        # Preserve graph-output identity: if this node produced a graph output,
        # the replacement node must now produce that same variable.
        if isinstance(self.outputs[0], Variable) and self.outputs[0].is_output:
            node.outputs[0] = self.outputs[0]

        self.inputs.clear()
        self.outputs.clear()