onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,1575 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from __future__ import annotations
18
+
19
+ import copy
20
+ import numbers
21
+ from collections import OrderedDict, defaultdict
22
+ from collections.abc import Sequence
23
+
24
+ import numpy as np
25
+
26
+ from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
27
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
28
+ from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER, LogMode
29
+ from onnxslim.third_party.onnx_graphsurgeon.util import misc
30
+
31
+
32
class NodeIDAdder:
    """Context manager that temporarily tags the nodes of a graph with an integer ``id``.

    On entry every node in ``graph.nodes`` receives its list index as ``node.id``; on exit the
    attribute is removed again from every node that is in ``graph.nodes`` at that moment.
    """

    def __init__(self, graph):
        """Initializes NodeIDAdder with a specified graph.

        Args:
            graph: An object exposing an iterable ``nodes`` attribute.
        """
        self.graph = graph

    def __enter__(self):
        """Assigns an `id` attribute to each node in the graph upon entering the context."""
        # IDs are list positions, so every slot gets a distinct value. An object that is listed
        # more than once simply ends up carrying the index of its last slot.
        for index, node in enumerate(self.graph.nodes):
            node.id = index

    def __exit__(self, exc_type, exc_value, traceback):
        """Removes `id` attributes from each node in the graph upon exiting the context."""
        for node in self.graph.nodes:
            # The attribute can already be gone: the same object may be listed twice (its `id`
            # was deleted on the first visit), or the node may have been appended to the graph
            # while the context was active. An unconditional `del` would raise AttributeError
            # here and mask whatever happened inside the `with` body.
            if hasattr(node, "id"):
                del node.id
47
+
48
+
49
+ class Graph:
50
+ """Represents a graph containing nodes and tensors."""
51
+
52
+ DEFAULT_OPSET = 11
53
+ OPSET_FUNC_MAP = defaultdict(dict) # Ops registered for specific opsets.
54
+ GLOBAL_FUNC_MAP = {} # Ops registered for ALL opsets.
55
+
56
+ @staticmethod
57
+ def register(opsets=None):
58
+ """
59
+ Registers a function with the Graph class for the specified group of opsets. After registering the function, it
60
+ can be accessed like a normal member function.
61
+
62
+ For example:
63
+ ::
64
+
65
+ @Graph.register()
66
+ def add(self, a, b):
67
+ '''Registers a function with the Graph class for the specified group of opsets for dynamic access as a member function.'''
68
+ return self.layer(op="Add", inputs=[a, b], outputs=["add_out_gs"])
69
+
70
+ graph.add(a, b)
71
+
72
+ Args:
73
+ opsets (Sequence[int]):
74
+ A group of opsets for which to register the function. Multiple functions with the same
75
+ name may be registered simultaneously if they are registered for different opsets.
76
+ Registering a function with a duplicate name for the same opsets will overwrite any
77
+ function previously registered for those opsets. By default, the function is
78
+ registered for all opsets.
79
+ """
80
+
81
+ def register_func(func):
82
+ """Registers a function for different opsets, overwriting any previously registered function with the same
83
+ name.
84
+ """
85
+ if hasattr(Graph, func.__name__):
86
+ G_LOGGER.warning(
87
+ f"Registered function: {func.__name__} is hidden by a Graph attribute or function with the same name. "
88
+ "This function will never be called!"
89
+ )
90
+
91
+ # Default behavior is to register functions for all opsets.
92
+ if opsets is None:
93
+ Graph.GLOBAL_FUNC_MAP[func.__name__] = func
94
+ else:
95
+ for opset in opsets:
96
+ Graph.OPSET_FUNC_MAP[opset][func.__name__] = func
97
+ return func
98
+
99
+ return register_func
100
+
101
def __init__(
    self,
    nodes: Sequence[Node] | None = None,
    inputs: Sequence[Tensor] | None = None,
    outputs: Sequence[Tensor] | None = None,
    name=None,
    doc_string=None,
    opset=None,
    import_domains=None,
    ir_version=None,
    producer_name: str | None = None,
    producer_version: str | None = None,
    functions: Sequence[Function] | None = None,
    metadata_props=None,
):
    """
    Create a graph from its nodes, I/O tensors and model-level metadata.

    Args:
        nodes (Sequence[Node]): A list of the nodes in this graph. Defaults to an empty list.
        inputs (Sequence[Tensor]): A list of graph input Tensors. Stored as a new list.
        outputs (Sequence[Tensor]): A list of graph output Tensors. Stored as a new list.
        name (str): The name of the graph. Defaults to "onnx_graphsurgeon_graph".
        doc_string (str): A doc_string for the graph. Defaults to "".
        opset (int): The ONNX opset to use when exporting this graph. Defaults to ``Graph.DEFAULT_OPSET``.
        import_domains: Stored unchanged on the graph.
        ir_version: Stored unchanged on the graph.
        producer_name (str): The name of the tool used to generate the model. Defaults to "".
        producer_version (str): The version of the generating tool. Defaults to "".
        functions (Sequence[Function]): Local functions of the model. Copied into a new list.
        metadata_props: Stored unchanged on the graph.
    """
    fallback = misc.default_value

    # Structure: nodes plus graph-level I/O.
    self.nodes = fallback(nodes, [])
    self.inputs = list(fallback(inputs, []))
    self.outputs = list(fallback(outputs, []))

    # Identity: the graph name is mirrored into __name__.
    graph_name = fallback(name, "onnx_graphsurgeon_graph")
    self.name = graph_name
    self.__name__ = graph_name

    # Model-level metadata.
    self.doc_string = fallback(doc_string, "")
    self.opset = fallback(opset, Graph.DEFAULT_OPSET)
    self.producer_name = fallback(producer_name, "")
    self.producer_version = fallback(producer_version, "")
    self.metadata_props = metadata_props
    self.import_domains = import_domains
    self.ir_version = ir_version

    # Counter used by the layer() function.
    self.name_idx = 0

    # In ONNX, one list of Functions is shared between all Graphs & Functions in a model.
    # The list object itself lives in self._functions; users should read and replace it
    # through graph.functions instead.
    self._functions = list(fallback(functions, []))
    self._merge_subgraph_functions()

    # Printing graphs can be very expensive, hence the lazily evaluated message.
    G_LOGGER.ultra_verbose(lambda: f"Created Graph: {self}")
152
+
153
+ def __getattr__(self, name):
154
+ """Dynamically handles attribute access, falling back to superclass attribute retrieval if not found."""
155
+ try:
156
+ return super().__getattribute__(name)
157
+ except AttributeError as err:
158
+ # Warn user if the name matches multiple registered functions.
159
+ methods = []
160
+ method_descs = []
161
+
162
+ # Opset specific ops always take priority over global ops.
163
+ if self.opset in Graph.OPSET_FUNC_MAP and name in Graph.OPSET_FUNC_MAP[self.opset]:
164
+ methods.append(Graph.OPSET_FUNC_MAP[self.opset][name])
165
+ method_descs.append(f'GraphSurgeon-registered function "{name}" with opset {self.opset}')
166
+
167
+ # Registered ops take priority over Local Functions.
168
+ if name in Graph.GLOBAL_FUNC_MAP:
169
+ methods.append(Graph.GLOBAL_FUNC_MAP[name])
170
+ method_descs.append(f'GraphSurgeon-registered function "{name}"')
171
+
172
+ for func in self.functions:
173
+ if func.name == name:
174
+ methods.append(func.__call__)
175
+ method_descs.append(f'Local Function "{func.name}" with domain "{func.domain}"')
176
+
177
+ if methods:
178
+ if len(methods) > 1:
179
+ msg_template = (
180
+ "Method name {} is overloaded with the following candidates: {}. " + "Choosing candidate {}"
181
+ )
182
+ G_LOGGER.warning(
183
+ message=msg_template.format(name, method_descs, method_descs[0]),
184
+ mode=LogMode.ONCE,
185
+ )
186
+ return lambda *args, **kwargs: methods[0](self, *args, **kwargs)
187
+
188
+ found_in_other_opsets = {opset for opset, opset_map in Graph.OPSET_FUNC_MAP.items() if name in opset_map}
189
+
190
+ G_LOGGER.error(
191
+ f"Function: '{name}' was not registered for opset {self.opset}. "
192
+ + (
193
+ f"Note: '{name}' was registered for opsets: {found_in_other_opsets}."
194
+ if found_in_other_opsets
195
+ else ""
196
+ )
197
+ )
198
+ raise err
199
+
200
+ def __setattr__(self, name, value):
201
+ """Sets an attribute to the given value, converting 'inputs' and 'outputs' to lists if applicable."""
202
+ if name in {"inputs", "outputs"}:
203
+ value = list(value)
204
+ return super().__setattr__(name, value)
205
+
206
+ @property
207
+ def functions(self) -> list[Function]:
208
+ """Returns the list of subgraph functions associated with this graph."""
209
+ return self._functions
210
+
211
+ @functions.setter
212
+ def functions(self, new_fns: Sequence[Function]):
213
+ """Get or set the list of functions, ensuring changes propagate to all associated subgraphs and functions."""
214
+ # this graph, its subgraphs, and its functions.
215
+ # If the user sets a new value for self.functions,
216
+ # all subgraphs and functions should also see this new value.
217
+ self._functions.clear()
218
+ self._functions += list(new_fns)
219
+
220
+ def __eq__(self, other: Graph):
221
+ """Check for equality between two Graph objects by comparing their nodes, inputs, and outputs."""
222
+ nodes_match = misc.sequences_equal(self.nodes, other.nodes)
223
+ if not nodes_match:
224
+ return False
225
+ inputs_match = misc.sequences_equal(self.inputs, other.inputs)
226
+ if not inputs_match:
227
+ return False
228
+ outputs_match = misc.sequences_equal(self.outputs, other.outputs)
229
+ if not outputs_match:
230
+ return False
231
+
232
+ return self.opset == other.opset and self.import_domains == other.import_domains
233
+
234
def node_ids(self):
    """
    Create a context manager that gives the Nodes of this Graph integer IDs while it is active.

    For example:
    ::

        with graph.node_ids():
            assert graph.nodes[0].id != graph.nodes[1].id

    Returns:
        NodeIDAdder: A context manager that supplies unique integer IDs for Nodes.
    """
    id_context = NodeIDAdder(self)
    return id_context
248
+
249
+ # Gets the node ID for a node. All internal code should use this instead of accessing `node.id` directly.
250
+ def _get_node_id(self, node):
251
+ """Gets the node ID for a node, ensuring all internal code uses this instead of directly accessing `node.id`."""
252
+ try:
253
+ return node.id
254
+ except AttributeError:
255
+ G_LOGGER.critical(
256
+ f"Encountered a node not in the graph:\n{node}.\n\n"
257
+ "To fix this, please append the node to this graph's `nodes` attribute."
258
+ )
259
+
260
+ # A tensor is local if it is produced in this graph, or is explicitly a graph input.
261
def _local_tensors(self):
    """Collect the tensors that are local to this graph, keyed by name.

    Local tensors are the non-empty outputs of this graph's nodes, the graph inputs, and every
    Constant found by `self.tensors()`. Later sources overwrite earlier ones on a name clash.
    """
    found = {}

    # Tensors produced by nodes of this graph.
    for node in self.nodes:
        for produced in node.outputs:
            if not produced.is_empty():
                found[produced.name] = produced

    # Explicit graph inputs.
    for graph_input in self.inputs:
        found[graph_input.name] = graph_input

    # Constants used anywhere in this graph.
    for tensor in self.tensors().values():
        if isinstance(tensor, Constant):
            found[tensor.name] = tensor

    return found
269
+
270
+ # Returns tensors used by this graph which are not present in the graph.
271
+ # These may come from an outer graph for example.
272
def _foreign_tensors(self):
    """Collect, keyed by name, the tensors this graph uses but does not hold locally.

    Such tensors may, for example, come from an outer graph. Subgraphs are searched
    recursively; whatever they report as foreign is kept only if it is foreign here as well.
    """
    local_names = self._local_tensors()
    collected = {}

    for node in self.nodes:
        # Direct inputs of the node that are not local to this graph.
        for tensor in node.inputs:
            if tensor.name not in local_names:
                collected[tensor.name] = tensor

        # Some of the foreign tensors from a subgraph may come from this graph; drop those.
        for subgraph in node.subgraphs():
            for tensor in subgraph._foreign_tensors().values():
                if tensor.name not in local_names:
                    collected[tensor.name] = tensor

    return collected
293
+
294
def _get_used_node_ids(self):
    """Walk backwards from the graph outputs and collect everything they depend on.

    Returns:
        A tuple ``(used_node_ids, used_tensors)``: the set of IDs of the visited producer nodes
        and the list of tensors visited, in discovery order.
    """
    local_names = self._local_tensors()
    seen_names = set()

    def keep(tensor):
        """Return True for empty tensors and for local tensors encountered for the first time."""
        if tensor.is_empty():
            return True
        tensor_name = tensor.name
        if tensor_name not in local_names or tensor_name in seen_names:
            # Foreign tensors and duplicates are filtered out.
            return False
        seen_names.add(tensor_name)
        return True

    # Traverse backwards from outputs to find all used nodes.
    used_tensors = [tensor for tensor in self.outputs if keep(tensor)]
    used_node_ids = set()

    # `used_tensors` doubles as the work queue: it grows while `cursor` advances through it.
    cursor = 0
    while cursor < len(used_tensors):
        current = used_tensors[cursor]
        cursor += 1
        for producer in current.inputs:
            # Must copy into a plain list here, otherwise this would be a SynchronizedList!
            candidates = list(producer.inputs)

            # If a node includes a subgraph, get any tensors that it uses from the outer graph.
            for subgraph in producer.subgraphs():
                candidates.extend(subgraph._foreign_tensors().values())

            used_node_ids.add(self._get_node_id(producer))
            used_tensors.extend(tensor for tensor in candidates if keep(tensor))

    return used_node_ids, used_tensors
336
+
337
def _merge_subgraph_functions(self):
    """Merge function lists of subgraphs into the parent graph's function list.

    Every function found on a subgraph or on one of this graph's functions, whose `unique_id`
    is not yet known, is appended to `self.functions`. Afterwards each visited subgraph and
    function has its `_functions` rebound to this graph's list object, so they all share it.
    """
    # A subgraph or function may carry a different function list than the parent graph.
    # This function merges those lists.
    func_ids = {func.unique_id for func in self.functions}

    def absorb_function_list(func_list):
        """Append the not-yet-known functions of `func_list` to the parent's list and return that list."""
        for func in func_list:
            if func.unique_id not in func_ids:
                self.functions.append(func)
                func_ids.add(func.unique_id)
        return self.functions

    # The list literal is a snapshot: functions appended while absorbing are not visited by
    # this first loop.
    for graph in [*self.functions, self]:
        for subgraph in graph.subgraphs(recursive=True):
            new_list = absorb_function_list(subgraph.functions)
            subgraph._functions = new_list

    # This loop iterates the live list, so functions appended during absorption are visited too.
    for func in self.functions:
        func._functions = absorb_function_list(func.functions)
357
+
358
+ def subgraphs(self, recursive=False):
359
+ """
360
+ Convenience function to iterate over all subgraphs which are contained in this graph. Subgraphs are found in the
361
+ attributes of ONNX control flow nodes such as 'If' and 'Loop'.
362
+
363
+ Args:
364
+ recursive (bool): Whether to recursively search this graph's subgraphs for more subgraphs. Defaults to False.
365
+
366
+ Returns:
367
+ A generator which iterates over the subgraphs contained in this graph.
368
+ """
369
+ for node in self.nodes:
370
+ yield from node.subgraphs(recursive=recursive)
371
+
372
def cleanup(
    self,
    remove_unused_node_outputs=False,
    recurse_subgraphs=True,
    remove_unused_graph_inputs=False,
    recurse_functions=True,
):
    """
    Removes unused nodes and tensors from the graph. A node or tensor is considered unused if it does not contribute
    to any of the graph outputs.

    Additionally, any producer nodes of graph input tensors, as well as consumer nodes of graph output
    tensors that are not in the graph, are removed from the graph.

    *Note: This function will never modify graph output tensors.*

    Args:
        remove_unused_node_outputs (bool): Whether to remove unused output tensors of nodes. This will never remove
                empty-tensor (i.e. optional, but omitted) outputs. Defaults to False.
        recurse_subgraphs (bool):
                Whether to recursively cleanup subgraphs.
                Defaults to True.
        remove_unused_graph_inputs (bool):
                Whether to remove unused graph inputs.
                Defaults to False.
        recurse_functions (bool):
                Whether to also clean up this graph's local functions.
                Defaults to True.

    Returns:
        self
    """

    def cleanup_subgraphs():
        """Clean up the direct subgraphs of this graph, forwarding the node-output and recursion options.

        Graph inputs of subgraphs are always kept, and functions are not cleaned from here.
        """
        for subgraph in self.subgraphs():
            subgraph.cleanup(
                remove_unused_node_outputs=remove_unused_node_outputs,
                recurse_subgraphs=recurse_subgraphs,
                remove_unused_graph_inputs=False,
                recurse_functions=False,  # Only cleanup functions once
            )

    if recurse_subgraphs:
        cleanup_subgraphs()

    if recurse_functions:
        for func in self.functions:
            func.cleanup(
                remove_unused_node_outputs=remove_unused_node_outputs,
                recurse_subgraphs=recurse_subgraphs,
                remove_unused_graph_inputs=remove_unused_graph_inputs,
                recurse_functions=False,  # No infinite recursion
            )

    G_LOGGER.verbose(f"Cleaning up {self.name}")

    # Node IDs are needed by _get_used_node_ids() and _get_node_id() below.
    with self.node_ids():
        # Graph input producers must be removed first so used_node_ids is correct.
        for inp in self.inputs:
            inp.inputs.clear()

        used_node_ids, used_tensors = self._get_used_node_ids()

        # Keep every input unless the caller asked for unused ones to be dropped.
        inputs = []
        for inp in self.inputs:
            if inp in used_tensors or not remove_unused_graph_inputs:
                inputs.append(inp)
            else:
                G_LOGGER.debug(f"Removing unused input: {inp}")
        self.inputs = inputs

        nodes = []

        for node in self.nodes:
            if self._get_node_id(node) in used_node_ids:
                nodes.append(node)
            else:
                # Detach the unused node from its tensors before it is dropped from the graph.
                node.inputs.clear()
                node.outputs.clear()
                G_LOGGER.ultra_verbose(f"Removing unused node: {node}")

        # Remove any hanging tensors - tensors without outputs
        if remove_unused_node_outputs:
            graph_output_names = {tensor.name for tensor in self.outputs}
            for node in nodes:

                def is_hanging_tensor(tensor):
                    """Checks if a tensor is hanging by verifying it is non-empty, has no outputs, and is not a
                    graph output.
                    """
                    return (
                        not tensor.is_empty() and len(tensor.outputs) == 0 and tensor.name not in graph_output_names
                    )

                # Collect first, then remove: node.outputs must not change while it is iterated.
                to_remove = [out for out in node.outputs if is_hanging_tensor(out)]
                for out in to_remove:
                    if out in node.outputs:
                        node.outputs.remove(out)

        # NOTE(review): the flattened source does not show whether this assignment and the
        # return sit inside the `with` block - confirm against the original file.
        self.nodes = nodes

        return self
477
+
478
def toposort(
    self,
    recurse_subgraphs=True,
    recurse_functions=True,
    mode="full",
):
    """
    Topologically sort the graph in place.

    Args:
        recurse_subgraphs (bool):
                Whether to recursively topologically sort subgraphs.
                Only applicable when mode="full" or mode="nodes".
                Defaults to True.
        recurse_functions (bool):
                Whether to topologically sort the nodes of this graph's functions.
                Only applicable when mode="full" or mode="nodes".
                Defaults to True.
        mode (str):
                Whether to reorder this graph's list of nodes, list of functions, or both.
                Possible values:
                - "full": Topologically sort the list of nodes and the list of functions.
                - "nodes": Only sort the list of nodes.
                - "functions": Only sort the list of functions.
                Defaults to "full".

    Returns:
        self
    """
    ALLOWED_MODES = ["full", "nodes", "functions"]
    if mode not in ALLOWED_MODES:
        G_LOGGER.critical(f'Mode "{mode}" not in {ALLOWED_MODES}')

    sort_nodes = mode in {"full", "nodes"}
    sort_functions = mode in {"full", "functions"}

    # Sort the node lists of functions and subgraphs first; only their nodes are reordered.
    if sort_nodes and recurse_functions:
        for func in self.functions:
            func.toposort(recurse_subgraphs=recurse_subgraphs, mode="nodes")

    if sort_nodes and recurse_subgraphs:
        for subgraph in self.subgraphs():
            subgraph.toposort(recurse_subgraphs=True, recurse_functions=False, mode="nodes")

    G_LOGGER.debug(f"Topologically sorting {self.name}")

    # Keeps track of a node and its level in the graph hierarchy.
    # 0 corresponds to an input node, N corresponds to a node with N layers of inputs.
    class HierarchyDescriptor:
        def __init__(self, node_or_func, level=None):
            """Initializes a HierarchyDescriptor with a node or function and an optional level in the graph
            hierarchy.
            """
            self.node_or_func = node_or_func
            self.level = level

        def __lt__(self, other):
            """Defines less-than comparison behavior based on hierarchy levels."""
            return self.level < other.level

    # Memo of already computed levels, keyed by node ID or function unique_id.
    hierarchy_levels = {}  # Dict[int, HierarchyDescriptor]

    local_tensors = self._local_tensors()
    # Filled only in the function-sorting phase below; empty while nodes are being sorted.
    func_id_to_func = {}

    def get_id(node_or_func):
        """Returns the unique ID for a Node object or a function."""
        if isinstance(node_or_func, Node):
            return self._get_node_id(node_or_func)
        return node_or_func.unique_id

    def get_hierarchy_level(node_or_func, visited=None):
        """Returns the hierarchy level of a node or function, memoising the result in `hierarchy_levels`."""
        # NOTE(review): `visited` is added to, removed from, and passed down the recursion, but
        # nothing in this function ever tests membership in it - presumably a cycle check was
        # intended here; confirm. As written a cyclic graph would recurse without bound.
        visited = misc.default_value(visited, set())
        visited.add(get_id(node_or_func))

        def get_inputs(node_or_func):
            """Find all nodes used by a given node, or all functions used by a given function."""

            def get_used_nodes(node):
                """Find all nodes that are used as inputs by a given node."""
                inputs = {}

                def add_local_producers(tensor):
                    """Record the producer nodes of `tensor`, keyed by node ID, if the tensor is local."""
                    nonlocal inputs
                    if tensor.name in local_tensors:
                        for inp_node in tensor.inputs:
                            inputs[self._get_node_id(inp_node)] = inp_node

                for tensor in node.inputs:
                    add_local_producers(tensor)

                # If a node includes a subgraph, get any tensors that it uses from the outer graph.
                for subgraph in node.subgraphs():
                    for tensor in subgraph._foreign_tensors().values():
                        add_local_producers(tensor)

                return inputs.values()

            # Find all functions used in this list of nodes.
            def get_used_funcs(nodes):
                """Return a dictionary of functions used in the provided list of nodes."""
                inputs = {}
                # NOTE(review): this loop iterates `self.subgraphs()` regardless of `nodes` and
                # recurses on each one, so it looks like it cannot terminate whenever this graph
                # has at least one subgraph - verify whether the subgraphs of `nodes` were meant.
                for subgraph in self.subgraphs():
                    inputs.update(get_used_funcs(subgraph.nodes))
                for node in nodes:
                    # Functions are looked up by their (domain, op) pair.
                    func_id = (node.domain, node.op)
                    if func_id in func_id_to_func:
                        inputs[func_id] = func_id_to_func[func_id]
                return inputs

            if isinstance(node_or_func, Node):
                inputs = get_used_nodes(node_or_func)
            else:
                inputs = get_used_funcs(node_or_func.nodes).values()
            return inputs

        if get_id(node_or_func) in hierarchy_levels:
            return hierarchy_levels[get_id(node_or_func)].level

        # The level of a node is the level of its highest input + 1.
        # The trailing [-1] gives nodes without inputs level 0.
        max_input_level = max(
            [get_hierarchy_level(inp, visited=visited) for inp in get_inputs(node_or_func)] + [-1]
        )
        visited.remove(get_id(node_or_func))

        hierarchy_levels[get_id(node_or_func)] = HierarchyDescriptor(node_or_func, level=max_input_level + 1)
        return max_input_level + 1

    if sort_nodes:
        # Node IDs are required by get_id() / _get_node_id() during level computation.
        with self.node_ids():
            for node in self.nodes:
                hierarchy_levels[get_id(node)] = HierarchyDescriptor(node, level=get_hierarchy_level(node))
        # NOTE(review): the flattened source does not show whether this assignment sits inside
        # the `with` block - confirm against the original file.
        self.nodes = [hd.node_or_func for hd in sorted(hierarchy_levels.values())]

    if sort_functions:
        self._merge_subgraph_functions()
        func_id_to_func.update({func.unique_id: func for func in self.functions})
        # Drop the node levels so the memo only holds functions from here on.
        hierarchy_levels.clear()
        for func in self.functions:
            hierarchy_levels[func.unique_id] = HierarchyDescriptor(func, level=get_hierarchy_level(func))
        self.functions = [hd.node_or_func for hd in sorted(hierarchy_levels.values())]

    return self
623
+
624
+ def tensors(self, check_duplicates=False):
625
+ """
626
+ Creates a tensor map of all the tensors used by this graph by walking over all nodes. Empty tensors are omitted
627
+ from this map.
628
+
629
+ Tensors are guaranteed to be in order of the nodes in the graph. Hence, if the graph is topologically sorted, the tensor map will be too.
630
+
631
+ Args:
632
+ check_duplicates (bool): Whether to fail if multiple tensors with the same name are encountered.
633
+
634
+ Raises:
635
+ OnnxGraphSurgeonException: If check_duplicates is True and multiple distinct tensors in the graph share the same name.
636
+
637
+ Returns:
638
+ OrderedDict[str, Tensor]: A mapping of tensor names to tensors.
639
+ """
640
+ tensor_map = OrderedDict()
641
+
642
+ def add_to_tensor_map(tensor):
643
+ """Add a tensor to the tensor_map if it is not empty and ensure no duplicate tensor names exist."""
644
+ if not tensor.is_empty():
645
+ if tensor.name in tensor_map and tensor_map[tensor.name] is not tensor:
646
+ msg = f"Found distinct tensors that share the same name:\n[id: {id(tensor_map[tensor.name])}] {tensor_map[tensor.name]}\n[id: {id(tensor)}] {tensor}\n"
647
+ msg += f"Note: Producer node(s) of first tensor:\n{tensor_map[tensor.name].inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}"
648
+
649
+ if check_duplicates:
650
+ G_LOGGER.critical(msg)
651
+ # G_LOGGER.warning(msg)
652
+
653
+ tensor_map[tensor.name] = tensor
654
+
655
+ # I/O tensors may not be attached to nodes.
656
+ for io_tensor in self.inputs:
657
+ add_to_tensor_map(io_tensor)
658
+
659
+ for node in self.nodes:
660
+ for tensor in node.inputs + node.outputs:
661
+ add_to_tensor_map(tensor)
662
+
663
+ for io_tensor in self.outputs:
664
+ add_to_tensor_map(io_tensor)
665
+
666
+ return tensor_map
667
+
668
def fold_constants(
    self,
    fold_shapes=True,
    recurse_subgraphs=True,
    partitioning=None,
    error_ok=True,
    flatten_subgraphs=True,
    size_threshold=None,
    should_exclude_node=None,
    recurse_functions=True,
):
    """
    Folds constants in-place in the graph. The graph's nodes and functions must be topologically sorted prior to
    calling this function (see `toposort()`).

    This function will not remove constants after folding them. In order to get rid of
    these hanging nodes, you can run the `cleanup()` function.

    *Note: Due to how this function is implemented, the graph must be exportable to ONNX,
    and evaluable in ONNX-Runtime. Additionally, ONNX-Runtime must be installed.*

    Args:
        fold_shapes (bool):
                Whether to fold `Shape` nodes in the graph.
                This requires shapes to be inferred in the graph, and can only fold
                static shapes.
                Defaults to True.
        recurse_subgraphs (bool):
                Whether to recursively fold constants in subgraphs.
                Defaults to True.
        partitioning (Union[str, None]):
                Whether/How to partition the graph so that errors in folding one
                part of a model do not affect other parts. Available modes are:

                - None: Do not partition the graph. If inference fails, no constants are folded.
                - "basic": Partition the graph. If inference fails in one partition, other partitions will
                        remain unaffected.
                - "recursive": Partition the graph recursively. If inference fails in a partition, the partition
                        will be further partitioned.

                Defaults to None.
        error_ok (bool):
                Whether inference errors should be suppressed.
                When this is False, any errors encountered during inference will be re-raised.
                Defaults to True.
        flatten_subgraphs (bool):
                Whether to flatten subgraphs where possible. For example, `If` nodes with a constant condition
                can be flattened into the parent graph.
        size_threshold (int):
                The maximum size threshold, in bytes, for which to fold constants.
                Any tensors larger than this value will not be folded.
                Set to ``None`` to disable the size threshold and always fold constants.
                For example, some models may apply ops like `Tile` or `Expand` to constants, which can
                result in very large tensors. Rather than pre-computing those constants and bloating
                the model size, it may be desirable to skip folding them and allow them to be computed
                at runtime.
                Defaults to None.
        should_exclude_node (Callable[[gs.Node], bool]):
                A callable that accepts an onnx-graphsurgeon node from the graph and reports whether it should
                be excluded from folding. This is only called for nodes which are otherwise foldable.
                Note that preventing a node from being folded also prevents its consumers from being folded.
                Defaults to a callable that always returns False.
        recurse_functions (bool):
                Whether to fold constants in this graph's Functions.
                Defaults to True.

    Returns:
        self
    """
    from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import (
        dtype_to_onnx,
        export_onnx,
    )

    custom_should_exclude_node = misc.default_value(should_exclude_node, lambda node: False)

    # Don't fold nodes with attribute values which are variable.
    # NOTE: this local function intentionally shadows the `should_exclude_node` parameter; the
    # user-supplied callable is still consulted through `custom_should_exclude_node`.
    def should_exclude_node(node):
        """Determine if an ONNX graph node should be excluded based on its attributes."""
        for attr_val in node.attrs.values():
            if isinstance(attr_val, Node.AttributeRef):
                return True
        return custom_should_exclude_node(node)

    PARTITIONING_MODES = [None, "basic", "recursive"]
    if partitioning not in PARTITIONING_MODES:
        G_LOGGER.critical(f"Argument for parameter 'partitioning' must be one of: {PARTITIONING_MODES}")
    ORT_PROVIDERS = ["CPUExecutionProvider"]

    G_LOGGER.debug(f"Folding constants in {self.name}")

    # We apply constant folding in 5 passes:
    # Pass 1 lowers 'Constant' nodes into Constant tensors.
    # Pass 2 elides casts applied to shape tensors. This is done separately from other shape folding
    #   since it operates on the original graph rather than a clone.
    # Pass 3 finds all Constant tensors in the graph, then finds all descendants which are dependent
    #   only on constants.
    # Pass 4 searches for Shape nodes that have variable inputs (i.e. not marked const in pass 1)
    #   and turns them into Constants iff the input has a statically known shape.
    # Pass 5 computes the descendants determined in Pass 3 using ONNX-Runtime and replaces them in the graph.

    # Pass 1: Lower constant nodes
    for tensor in self.tensors().values():
        if len(tensor.inputs) == 1:
            node = tensor.inputs[0]
            if node.op == "Constant" and tensor.outputs:
                if len(node.attrs) != 1:
                    G_LOGGER.warning("Constant node must contain exactly one attribute")
                    continue
                attr_name, attr_val = next(iter(node.attrs.items()))
                allowed_attrs = {
                    "value",
                    "value_float",
                    "value_floats",
                    "value_int",
                    "value_ints",
                }
                if attr_name not in allowed_attrs:
                    G_LOGGER.warning(f"Unsupported attribute for Constant node: {attr_name}")
                    continue
                if isinstance(attr_val, Node.AttributeRef):
                    continue
                elif isinstance(attr_val, Constant):
                    arr = attr_val._values  # Using ._values avoids copying
                else:
                    arr = np.array(attr_val, dtype=tensor.dtype)
                tensor.to_constant(arr)
                tensor.inputs.clear()

    # Pass 2: Run shape-tensor cast elision
    def run_cast_elision(node):
        """Perform cast elision optimization on an ONNX node to eliminate unnecessary cast operations."""
        import onnx

        # Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int)
        # This pattern is problematic for TensorRT since these operations may be performed on Shape Tensors, which
        # are not allowed to be floating point type. Attempt to fold the pattern here
        VALID_CAST_ELISION_OPS = {
            "Add",
            "Sub",
            "Mul",
            "Div",
            "Max",
            "Min",
            "Equal",
            "Greater",
            "Less",
            "Concat",
        }

        if node.op not in VALID_CAST_ELISION_OPS:
            return

        # If the uncasted outputs of this node have any consumers other than "Cast" nodes,
        # then we cannot elide the cast.
        for out_tensor in node.outputs:
            if out_tensor in self.outputs:
                return

            if any(out_node.op != "Cast" for out_node in out_tensor.outputs):
                return

        # Get list of input nodes that cast to float32
        inp_casts = [
            inp_node
            for inp_tensor in node.inputs
            for inp_node in inp_tensor.inputs
            if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT
        ]

        # No cast nodes found, return early
        if not inp_casts:
            return

        # Ensure that all input cast nodes are casting from the same type
        inp_dtypes = [dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts]
        if len(set(inp_dtypes)) != 1:
            return

        final_type = inp_dtypes[0]

        # Get list of output nodes that cast to int32 or int64
        out_casts = [
            out_node
            for out_tensor in node.outputs
            for out_node in out_tensor.outputs
            if out_node.op == "Cast"
            and out_node.attrs["to"] in {onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64}
        ]

        # No cast node found on outputs, return early
        if not out_casts:
            return

        # Ensure that all output cast nodes are casting to the same type and that this
        # matches the original type before the inputs were casted.
        out_dtypes = [out_cast.attrs["to"] for out_cast in out_casts]
        if len(set(out_dtypes)) != 1 or out_dtypes[0] != final_type:
            return

        # If all checks passed, reconnect inputs/outputs to the consumers/producers
        # of the Cast nodes.
        # Note that we need to be careful in how we rebind tensors since they may
        # be used by multiple nodes. Thus, it is not necessarily safe to assume that
        # `cast_node.inputs[0].outputs[0] == cast_node`.
        for index, inp in enumerate(node.inputs):
            if isinstance(inp, Constant):
                inp.values = inp.values.astype(onnx.helper.tensor_dtype_to_np_dtype(final_type))

            for cast in inp_casts:
                if cast.outputs[0] == inp:
                    node.inputs[index] = cast.inputs[0]

        for index, out in enumerate(node.outputs):
            for cast in out_casts:
                if cast.inputs[0] == out:
                    out_tensor = cast.outputs[0]
                    out_tensor.inputs.clear()  # Disconnect from Cast
                    node.outputs[index] = out_tensor

    if fold_shapes:
        # Perform shape tensor cast elision prior to most other folding
        G_LOGGER.debug(f"Performing shape tensor cast elision in {self.name}")
        try:
            with self.node_ids():
                for node in self.nodes:
                    run_cast_elision(node)
        except Exception as err:
            if not error_ok:
                raise err
            G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err))

    # Note that most of the remaining passes operate on a clone of the original graph.
    # Pass 3: Find all descendants of constant tensors

    graph_clone = self.copy()
    clone_tensors = graph_clone.tensors()

    # If 'self' is a Function, then these fields need to be set so it can be exported as an ONNX Graph.
    graph_clone.producer_name = ""
    graph_clone.producer_version = ""

    def update_foldable_outputs(graph_constants):
        """Updates the graph's outputs to ensure certain operations remain foldable."""

        def is_foldable(node):
            """Determines if a given node operation is foldable based on its type."""
            NO_FOLD_OPS = {
                "QuantizeLinear",
                "DequantizeLinear",
                "DynamicQuantizeLinear",
                "SequenceEmpty",
            }
            if node.op in NO_FOLD_OPS:
                return False

            def all_tensors_const(tensors):
                """Check if all tensors in a given list are constants in the graph."""
                return all(t.name in graph_constants for t in tensors if not t.is_empty())

            if not all_tensors_const(node.inputs):
                return False

            all_subgraph_foreign_tensors_const = True
            for subgraph in node.subgraphs():
                foreign_tensors = subgraph._foreign_tensors().values()
                all_subgraph_foreign_tensors_const &= all_tensors_const(foreign_tensors)

            return all_subgraph_foreign_tensors_const and not should_exclude_node(node)

        # Walks along the outputs of graph_constants to see if they can also be computed statically.
        # Since the graph is topologically sorted, this should find all constant nodes in the graph.
        for node in graph_clone.nodes:
            if is_foldable(node):
                graph_constants.update({out.name: out for out in node.outputs})
        return graph_constants

    graph_constants = {}
    for name, tensor in clone_tensors.items():
        if isinstance(tensor, Constant):
            # A constant that is the data input (index 0) of a Gather is only treated as a folding
            # seed when it has at most one consumer; all other constants are always seeds.
            if any((t.op == "Gather" and t.inputs.index(tensor) == 0) for t in tensor.outputs):
                if len(tensor.outputs) <= 1:
                    graph_constants[name] = tensor
            else:
                graph_constants[name] = tensor

    graph_constants = update_foldable_outputs(graph_constants)

    # Pass 4: Shape Folding

    def get_producer(tensor, op):
        """Get the producer of the specified tensor iff it matches op."""
        if len(tensor.inputs) != 1:
            return None

        node = tensor.inputs[0]
        return None if node.op != op else node

    def get_input(node, index=0):
        """Get the input tensor of a node iff the input tensor is not already marked a graph constant."""
        if node is None:
            return None

        inp = node.inputs[index]

        # If the input was already found to be a constant, it will be folded anyway.
        return None if inp.name in graph_constants else inp

    def get_scalar_value(tensor):
        """Gets the scalar value of a constant tensor with a single item."""
        return next(iter(tensor.values)) if tensor.shape else tensor.values

    def fold_shape(tensor):
        """Returns the input tensor shape if available, otherwise returns None.
        Handles Shape node with optional 'start' and 'end' attributes (opset 15+).
        """
        shape_node = get_producer(tensor, "Shape")
        inp = get_input(shape_node)
        if inp is None:
            return None

        if inp.shape is None or misc.is_dynamic_shape(inp.shape):
            return None

        full_shape = inp.shape
        num_dims = len(full_shape)

        # Get start and end attributes (default: start=0, end=None means full shape)
        start = shape_node.attrs.get("start", 0)
        end = shape_node.attrs.get("end", None)

        # Handle negative indices
        if start < 0:
            start = num_dims + start
        if end is None:
            end = num_dims
        elif end < 0:
            end = num_dims + end

        # Clamp to valid range
        start = max(0, min(start, num_dims))
        end = max(0, min(end, num_dims))

        if start > end:
            return None

        target_shape = full_shape[start:end]
        return np.array(target_shape, dtype=np.int64)

    def fold_shape_gather(tensor):
        """Retrieves and returns the shape of the input tensor as a NumPy array, otherwise returns None.
        Handles Shape node with optional 'start' and 'end' attributes (opset 15+).
        """
        gather = get_producer(tensor, "Gather")
        if gather is None:
            return None

        data = gather.inputs[0]
        indices_tensor = gather.inputs[1]

        shape_node = get_producer(data, "Shape")
        inp = get_input(shape_node)
        if inp is None or inp.shape is None:
            return None

        if not isinstance(indices_tensor, Constant):
            return None

        # Get the shape slice from Shape node (considering start/end attributes)
        full_shape = inp.shape
        num_dims = len(full_shape)

        start = shape_node.attrs.get("start", 0)
        end = shape_node.attrs.get("end", None)

        if start < 0:
            start = num_dims + start
        if end is None:
            end = num_dims
        elif end < 0:
            end = num_dims + end

        start = max(0, min(start, num_dims))
        end = max(0, min(end, num_dims))

        if start > end:
            return None

        shape_slice = full_shape[start:end]

        indices = indices_tensor.values
        if not indices.shape:  # Scalar-case
            idx = int(indices)
            # Handle negative indices relative to shape_slice
            if idx < 0:
                idx = len(shape_slice) + idx
            if idx < 0 or idx >= len(shape_slice):
                return None
            shape = shape_slice[idx]
            if misc.is_dynamic_dimension(shape):
                return None
        else:
            shape = []
            for index in indices:
                idx = int(index)
                # Handle negative indices relative to shape_slice
                if idx < 0:
                    idx = len(shape_slice) + idx
                if idx < 0 or idx >= len(shape_slice):
                    return None
                shape.append(shape_slice[idx])
            if misc.is_dynamic_shape(shape):
                return None

        return np.array(shape, dtype=np.int64)

    def fold_shape_slice(tensor):
        """Fold tensor shape slice information into a NumPy array of int64 type.
        Handles Shape node with optional 'start' and 'end' attributes (opset 15+).
        """
        slice_node = get_producer(tensor, "Slice")
        if slice_node is None:
            return None

        data = slice_node.inputs[0]

        if len(slice_node.inputs) >= 3:
            starts, ends = slice_node.inputs[1:3]
            if any(not isinstance(t, Constant) for t in [starts, ends]):
                return None
            starts, ends = get_scalar_value(starts), get_scalar_value(ends)
        elif "starts" in slice_node.attrs and "ends" in slice_node.attrs:
            starts, ends = slice_node.attrs["starts"][0], slice_node.attrs["ends"][0]
        else:
            return None

        shape_node = get_producer(data, "Shape")
        inp = get_input(shape_node)
        if inp is None or inp.shape is None:
            return None

        # For shape tensors, we can only slice on the 0th dimension.
        if len(slice_node.inputs) > 3:
            axes = slice_node.inputs[3]
            if not isinstance(axes, Constant):
                return None

            if get_scalar_value(axes) != 0:
                return None
        elif "axes" in slice_node.attrs:
            if slice_node.attrs["axes"][0] != 0:
                return None

        steps = 1
        if len(slice_node.inputs) > 4:
            steps = slice_node.inputs[4]
            if not isinstance(steps, Constant):
                return None
            steps = get_scalar_value(steps)
        elif "steps" in slice_node.attrs:
            steps = slice_node.attrs["steps"][0]

        # Get the shape slice from Shape node (considering start/end attributes)
        full_shape = inp.shape
        num_dims = len(full_shape)

        shape_start = shape_node.attrs.get("start", 0)
        shape_end = shape_node.attrs.get("end", None)

        if shape_start < 0:
            shape_start = num_dims + shape_start
        if shape_end is None:
            shape_end = num_dims
        elif shape_end < 0:
            shape_end = num_dims + shape_end

        shape_start = max(0, min(shape_start, num_dims))
        shape_end = max(0, min(shape_end, num_dims))

        if shape_start > shape_end:
            return None

        shape_slice = full_shape[shape_start:shape_end]

        # Apply the Slice operation on the shape_slice
        shape = shape_slice[starts:ends:steps]
        if misc.is_dynamic_shape(shape):
            return None

        return np.array(shape, dtype=np.int64)

    if fold_shapes:
        # NOTE: The order of shape folding passes is important to maximize how much we fold (phase-ordering problem).
        # This must be an ordered sequence: a set of function objects iterates in arbitrary order,
        # which would make the pass order (and therefore the folding result) unpredictable.
        SHAPE_FOLD_FUNCS = [fold_shape_gather, fold_shape_slice, fold_shape]
        for shape_fold_func in SHAPE_FOLD_FUNCS:
            try:
                for tensor in clone_tensors.values():
                    shape_of = shape_fold_func(tensor)

                    if shape_of is not None:
                        G_LOGGER.ultra_verbose(f"Folding shape tensor: {tensor.name} to: {shape_of}")
                        graph_constants[tensor.name] = tensor.to_constant(shape_of)
                        graph_constants[tensor.name].inputs.clear()
            except Exception as err:
                if not error_ok:
                    raise err
                G_LOGGER.warning(f"'{shape_fold_func.__name__}' routine failed with:\n{err}")
            else:
                graph_constants = update_foldable_outputs(graph_constants)

    # Pass 5: Evaluate all tensors descended from constants with ONNX-Runtime and replace them with constant values.

    def partition_and_infer(subgraph):
        """Evaluates and partitions the subgraph to infer constant values using ONNX-Runtime."""

        def get_out_node_ids():
            """Gets the final output nodes, identifying producer nodes of graph output tensors with no other
            outputs.
            """
            with subgraph.node_ids():
                out_node_ids = set()
                for out in subgraph.outputs:
                    if not out.outputs and not isinstance(out, Constant):
                        for n_inp in out.inputs:
                            out_node_ids.add(subgraph._get_node_id(n_inp))
            return out_node_ids

        # Compute each output node in a separate subgraph.
        out_node_ids = get_out_node_ids()
        constant_values = {}

        for index in out_node_ids:  # Have to use index since 'node' is not in part
            part = subgraph.copy()
            out_node = part.nodes[index]
            part.outputs = out_node.outputs
            part.name = f"Folding: {[out.name for out in part.outputs]}"
            part.cleanup(remove_unused_graph_inputs=True)
            names = [out.name for out in part.outputs]

            try:
                # Determining types is not trivial, and ONNX-RT does its own type inference.
                import onnxruntime as onnxrt

                sess = onnxrt.InferenceSession(
                    export_onnx(part, do_type_check=False).SerializeToString(),
                    providers=ORT_PROVIDERS,
                )
                values = sess.run(names, {})
            except Exception as err:
                G_LOGGER.warning(f"Inference failed for subgraph: {part.name}. Note: Error was:\n{err}")
                if partitioning == "recursive":
                    G_LOGGER.verbose("Attempting to recursively partition subgraph")
                    # Partition failed, peel off last node.
                    # We only need to remove one node, so avoid doing an expensive call to cleanup()
                    part.outputs = out_node.inputs
                    del part.nodes[part.nodes.index(out_node)]
                    out_node.outputs.clear()
                    out_node.inputs.clear()
                    # Only retry after the failing node has been peeled off. Retrying an unchanged
                    # partition (as would happen in "basic" mode) would fail in exactly the same way.
                    constant_values.update(partition_and_infer(part))
                else:
                    G_LOGGER.info("You may see better results if you set partitioning='recursive'")
                    if not error_ok:
                        raise err
            else:
                constant_values.update(dict(zip(names, values)))

        return constant_values

    # Only evaluate foldable values that have non-foldable outputs or are graph outputs.
    # Otherwise, if all the outputs are foldable, then we can just evaluate the outputs directly.
    # Additionally, if we can determine tensor size, do not evaluate tensors whose sizes exceed the size threshold.
    def should_eval_foldable(tensor):
        """Determine if foldable values should be evaluated based on output nature and tensor size constraints."""
        from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import (
            get_itemsize,
        )

        non_const = not isinstance(tensor, Constant)
        is_graph_output = not tensor.outputs
        has_non_foldable_outputs = any(out.name not in graph_constants for out in tensor.outputs)
        exceeds_size_threshold = (
            tensor.shape is not None
            and not misc.is_dynamic_shape(tensor.shape)
            and tensor.dtype is not None
            and size_threshold is not None
        ) and (misc.volume(tensor.shape) * get_itemsize(tensor.dtype) > size_threshold)

        return non_const and (is_graph_output or has_non_foldable_outputs) and not exceeds_size_threshold

    graph_clone.outputs = [t for t in graph_constants.values() if should_eval_foldable(t)]
    G_LOGGER.debug(f"Folding tensors: {graph_clone.outputs}")
    graph_clone.cleanup(remove_unused_graph_inputs=True, recurse_functions=False)

    # Using ._values avoids a deep copy of the values.
    constant_values = {
        name: tensor._values for name, tensor in graph_constants.items() if isinstance(tensor, Constant)
    }
    if graph_clone.outputs:
        if partitioning:
            constant_values.update(partition_and_infer(graph_clone))
        else:
            names = [t.name for t in graph_clone.outputs]
            try:
                import os
                import tempfile

                import onnx
                import onnxruntime as onnxrt

                onnx_model = export_onnx(graph_clone, do_type_check=False)
                if onnx_model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
                    # The model is too large to hand over as serialized bytes, so save it with its
                    # weights as external data in a temporary directory and load it by path instead.
                    tmp_dir = tempfile.TemporaryDirectory()
                    tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
                    # `location` is passed to onnx.save as a bare file name next to the model file.
                    location = f"{os.path.basename(tmp_path)}.data"
                    # Check for a stale data file inside the temporary directory. A bare relative
                    # name would be resolved against the current working directory and could
                    # delete an unrelated file there.
                    data_path = os.path.join(tmp_dir.name, location)
                    if os.path.exists(data_path):
                        os.remove(data_path)
                    onnx.save(
                        onnx_model,
                        tmp_path,
                        save_as_external_data=True,
                        all_tensors_to_one_file=True,
                        location=location,
                    )
                    onnx_model = tmp_path
                else:
                    onnx_model = onnx_model.SerializeToString()
                sess = onnxrt.InferenceSession(
                    onnx_model,
                    providers=ORT_PROVIDERS,
                )
                values = sess.run(names, {})
                constant_values.update(dict(zip(names, values)))
            except Exception as err:
                G_LOGGER.warning(
                    "Inference failed. You may want to try enabling partitioning to see better results. "
                    f"Note: Error was:\n{err}"
                )
                G_LOGGER.verbose(f"Note: Graph was:\n{graph_clone}")
                if not error_ok:
                    raise
    elif not constant_values:
        G_LOGGER.debug(
            f"Could not find any nodes in this graph ({self.name}) that can be folded. "
            "This could mean that constant folding has already been run on this graph. "
            "Skipping."
        )

    # Finally, replace the Variables in the original graph with constants.
    large_tensors = {}
    if constant_values:
        graph_tensors = self.tensors()
        for name, values in constant_values.items():
            tensor = graph_tensors[name]
            if isinstance(tensor, Constant) or not tensor.outputs:
                # No need to fold tensors that are already constant.
                # Tensors without any consumer are skipped here as well.
                continue

            if size_threshold is not None and values.nbytes > size_threshold:
                G_LOGGER.debug(
                    f"Will not fold: '{name}' since its size in bytes ({values.nbytes}) exceeds the size threshold ({size_threshold})"
                )
                continue
            elif size_threshold is None and values.nbytes > (1 << 20):
                # Still folded, but remembered so a single warning can list every tensor above 1 MiB.
                large_tensors[name] = values.nbytes

            tensor.to_constant(values)
            tensor.inputs.clear()  # Constants do not need inputs

        if large_tensors:
            large_tensors_mib = {
                tensor_name: f"{value // (1 << 20)} MiB" for tensor_name, value in large_tensors.items()
            }
            G_LOGGER.warning(
                "It looks like this model contains foldable nodes that produce large outputs.\n"
                "In order to avoid bloating the model, you may want to set a constant-folding size threshold.\n"
                f"Note: Large tensors and their corresponding sizes were: {large_tensors_mib}",
                mode=LogMode.ONCE,
            )

    # Folding subgraphs after the outer graph can lead to better folding.
    def fold_subgraphs():
        """Folds constants within subgraphs of the outer computational graph for optimization."""
        for subgraph in self.subgraphs():
            subgraph.fold_constants(
                fold_shapes=fold_shapes,
                recurse_subgraphs=recurse_subgraphs,
                partitioning=partitioning,
                error_ok=error_ok,
                flatten_subgraphs=flatten_subgraphs,
                size_threshold=size_threshold,
                recurse_functions=False,  # Functions are folded later
            )

    if recurse_subgraphs:
        fold_subgraphs()

    if flatten_subgraphs:
        # Flatten conditional subgraphs
        index = 0
        while index < len(self.nodes):
            node = self.nodes[index]
            if node.op == "If" and isinstance(node.inputs[0], Constant):
                G_LOGGER.debug(f"Flattening conditional: {node.name}")
                cond = get_scalar_value(node.inputs[0])
                subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"]
                # Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors
                for tensor in subgraph._local_tensors().values():
                    tensor.name += f"_subg_{index}_{subgraph.name}"

                # The subgraph outputs correspond to the If node outputs. Only the latter are visible
                # in the parent graph, so we rebind the producer nodes of the subgraph outputs to point
                # to the output tensors of the If instead.
                node_outputs = list(node.outputs)
                for node_out, subgraph_out in zip(node_outputs, subgraph.outputs):
                    node_out.inputs.clear()
                    for producer in subgraph_out.inputs:
                        for tensor_idx, out_tensor in enumerate(producer.outputs):
                            if out_tensor == subgraph_out:
                                producer.outputs[tensor_idx] = node_out

                # Copy subgraph nodes into parent graph at the index of the If.
                del self.nodes[index]
                self.nodes[index:index] = subgraph.nodes
                index += len(subgraph.nodes) - 1

            index += 1

    if recurse_functions:
        # Nodes which are constant-folded but not cleaned up can result in errors during inference,
        # so process functions in reverse topological order.
        for func in reversed(self.functions):
            func.fold_constants(
                fold_shapes=fold_shapes,
                recurse_subgraphs=recurse_subgraphs,
                partitioning=partitioning,
                error_ok=error_ok,
                flatten_subgraphs=flatten_subgraphs,
                size_threshold=size_threshold,
                should_exclude_node=should_exclude_node,
                recurse_functions=False,  # No infinite recursion
            )

    return self
1412
+
1413
def _generate_name(self, prefix: str, existing_names: set):
    """Return a name of the form ``<prefix>_<index>`` that is not present in ``existing_names``.

    The index comes from the ``self.name_idx`` counter, which is advanced once for every
    candidate that is tried, including the one that is finally returned.
    """
    candidate = f"{prefix}_{self.name_idx}"
    self.name_idx += 1
    # Keep drawing fresh indices until the candidate no longer clashes with a known name.
    while candidate in existing_names:
        candidate = f"{prefix}_{self.name_idx}"
        self.name_idx += 1
    return candidate
1424
+
1425
def layer(self, inputs=None, outputs=None, *args, **kwargs):
    """
    Build a node from the given inputs/outputs, append it to this graph and return its outputs.

    Each element of ``inputs`` and ``outputs`` is converted according to its type:

    - ``Tensor``:
        Used as-is. The caller is responsible for giving such tensors unique names.
    - ``str``:
        A new ``Variable`` is created; its name is the string with an index appended so that
        it does not clash with existing tensor names.
    - ``numpy.ndarray``:
        A new ``Constant`` holding the array is created, named with the prefix
        "onnx_graphsurgeon_constant" plus an index.
    - ``list`` / ``tuple`` of numbers, or a single number:
        A new ``Constant`` is created, named with the prefix "onnx_graphsurgeon_lst_constant"
        plus an index. The dtype is ``np.float32`` if any value is a Python ``float``,
        otherwise ``np.int64``.

    Any other element type is reported through ``G_LOGGER.critical``.

    Args:
        inputs (List[Union[Tensor, str, numpy.ndarray]]): The list of inputs
        outputs (List[Union[Tensor, str, numpy.ndarray]]): The list of outputs
        args/kwargs: These are passed directly to the constructor of Node. If no ``name`` keyword
            is given, one is generated with the prefix "onnx_graphsurgeon_node".

    Returns:
        List[Tensor]: The output tensors of the node
    """
    inputs = misc.default_value(inputs, [])
    outputs = misc.default_value(outputs, [])

    def process_io(io, existing_names):
        """Convert every element of ``io`` into a tensor and record the names that are now taken."""
        converted = []
        for item in io:
            if isinstance(item, Tensor):
                converted.append(item)
            elif isinstance(item, str):
                converted.append(Variable(name=self._generate_name(item, existing_names)))
            elif isinstance(item, np.ndarray):
                const_name = self._generate_name("onnx_graphsurgeon_constant", existing_names)
                converted.append(Constant(name=const_name, values=item))
            elif isinstance(item, (list, tuple, numbers.Number)):
                # Any Python float forces float32; everything else is stored as int64.
                if isinstance(item, (list, tuple)):
                    has_float = any(isinstance(value, float) for value in item)
                else:
                    has_float = isinstance(item, float)
                values = np.array(item, dtype=np.float32 if has_float else np.int64)
                const_name = self._generate_name("onnx_graphsurgeon_lst_constant", existing_names)
                converted.append(Constant(name=const_name, values=values))
            else:
                G_LOGGER.critical(
                    f"Unrecognized type passed to Graph.layer: {item}.\n"
                    "\tHint: Did you forget to unpack a list with `*`?\n"
                    "\tPlease use Tensors, strings, or NumPy arrays."
                )
            # Remember the newest tensor's name so later generated names cannot collide with it.
            latest_name = converted[-1].name
            if latest_name:
                existing_names.add(latest_name)
        return converted

    taken_names = set(self.tensors().keys())  # set for fast lookup
    inputs = process_io(inputs, taken_names)
    outputs = process_io(outputs, taken_names)

    if "name" not in kwargs:
        node_names = {node.name for node in self.nodes}
        kwargs["name"] = self._generate_name("onnx_graphsurgeon_node", node_names)

    new_node = Node(*args, **kwargs, inputs=inputs, outputs=outputs)
    self.nodes.append(new_node)
    return new_node.outputs
1503
+
1504
def copy(self, tensor_map: OrderedDict[str, Tensor] | None = None):
    """
    Copy the graph.

    Every node and tensor is copied. Weights and attributes are not deep-copied, except for
    ``Graph`` attributes, which are copied through their own ``copy`` method.

    Args:
        tensor_map (OrderedDict[str, Tensor]):
            A mapping of tensor names to tensors from the outer graph.
            This should be ``None`` if this is the outer-most graph.

    Returns:
        Graph: A copy of the graph.
    """
    # Work on a shallow copy so the caller's mapping is left untouched.
    tensor_map = copy.copy(misc.default_value(tensor_map, {}))

    # Layer the tensor copies from lowest to highest priority:
    #   1. every tensor reachable via `tensors()` - needed when a subgraph is cloned by itself,
    #      since that also covers tensors produced by outer graphs;
    #   2. copies already made by the outer graph;
    #   3. tensors produced locally, which take precedence over everything else.
    local_tensor_copies = {name: tensor.copy() for name, tensor in self.tensors().items()}
    local_tensor_copies.update(tensor_map)
    local_tensor_copies.update({name: tensor.copy() for name, tensor in self._local_tensors().items()})

    def get_tensor(name):
        """Look up the copied tensor for ``name``; an unnamed tensor maps to a new empty variable."""
        if name:
            return local_tensor_copies[name]
        return Variable.empty()

    def remap(tensors):
        """Return the copies corresponding to ``tensors``, in the same order."""
        return [get_tensor(tensor.name) for tensor in tensors]

    # Copy the nodes, wiring them to the copied tensors.
    new_nodes = [
        node.copy(
            inputs=remap(node.inputs),
            outputs=remap(node.outputs),
            tensor_map=local_tensor_copies,
        )
        for node in self.nodes
    ]

    new_graph_inputs = remap(self.inputs)
    new_graph_outputs = remap(self.outputs)
    return Graph(
        nodes=new_nodes,
        inputs=new_graph_inputs,
        outputs=new_graph_outputs,
        name=copy.copy(self.name),
        doc_string=copy.copy(self.doc_string),
        opset=copy.copy(self.opset),
        import_domains=self.import_domains,
        ir_version=self.ir_version,
        functions=copy.copy(self.functions),
    )
1559
+
1560
def __str__(self):
    """Describe the graph as text: name and opset, local function names, inputs, nodes and outputs."""
    node_lines = "\n".join(str(node) for node in self.nodes)
    function_names = ",".join(str(func.name) for func in self.functions)
    sections = [
        f"Graph {self.name} (Opset {self.opset})",
        f"Local Functions: [{function_names}]",
        f"Inputs: {self.inputs}",
        f"Nodes: {node_lines}",
        f"Outputs: {self.outputs}",
    ]
    return "\n".join(sections)
1572
+
1573
def __repr__(self):
    """Return the same text as ``__str__``."""
    text = self.__str__()
    return text