onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (46)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +874 -257
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +373 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +40 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
  27. onnx_ir/passes/common/constant_manipulation.py +217 -0
  28. onnx_ir/passes/common/inliner.py +332 -0
  29. onnx_ir/passes/common/onnx_checker.py +57 -0
  30. onnx_ir/passes/common/shape_inference.py +112 -0
  31. onnx_ir/passes/common/topological_sort.py +33 -0
  32. onnx_ir/passes/common/unused_removal.py +196 -0
  33. onnx_ir/serde.py +288 -124
  34. onnx_ir/tape.py +15 -0
  35. onnx_ir/tensor_adapters.py +122 -0
  36. onnx_ir/testing.py +197 -0
  37. onnx_ir/traversal.py +4 -3
  38. onnx_ir-0.1.1.dist-info/METADATA +53 -0
  39. onnx_ir-0.1.1.dist-info/RECORD +42 -0
  40. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
  41. onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
  42. onnx_ir/_external_data.py +0 -323
  43. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  44. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  45. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  46. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
onnx_ir/passes/_pass_infra.py
@@ -1,5 +1,5 @@
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT License.
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
  #
  # This module implements some APIs described in
  # https://pytorch.org/executorch/stable/compiler-custom-compiler-passes.html
@@ -16,10 +16,14 @@ from __future__ import annotations

  import dataclasses
  import logging
- from typing import Sequence
+ from collections.abc import Sequence
+ from typing import Literal, final

  __all__ = [
      "PassBase",
+     "Sequential",
+     "InPlacePass",
+     "FunctionalPass",
      "PassManager",
      "PassResult",
      # Errors
@@ -58,7 +62,7 @@ class PassResult:

      Attributes:
          model: The transformed model.
-         modified: Whether the model was modified.
+         modified: Whether the resulting model is different from the input model.
      """

      model: ir.Model
@@ -68,14 +72,89 @@ class PassResult:
  class PassBase(abc.ABC):
      """Base class for all passes.

-     Class attributes:
-         in_place: Whether the pass modifies the model in place.
+     ``in_place`` and ``changes_input`` properties and what they mean:
+
+     +------------+------------------+----------------------------+
+     |            | changes_inputs   | not changes_inputs         |
+     +------------+------------------+----------------------------+
+     | in_place   | in place         | Side-effect-only pass      |
+     +------------+------------------+----------------------------+
+     | not        | destructive      | functional                 |
+     | in_place   |                  |                            |
+     +------------+------------------+----------------------------+
      """

-     in_place: bool = True
+     @property
+     @abc.abstractmethod
+     def in_place(self) -> bool:
+         """Whether the pass modifies the model in place and returns it.
+
+         If True, the pass will return the same model object that was passed in.
+         If False, the pass will return a new model object.
+         """
+         raise NotImplementedError
+
+     @property
+     @abc.abstractmethod
+     def changes_input(self) -> bool:
+         """Whether the pass modifies input model."""
+         raise NotImplementedError
+
+     @property
+     def destructive(self) -> bool:
+         """Whether the pass will destroy the input model when ``in_place=False``.

-     def __call__(self, model: ir.Model) -> PassResult:
-         return self.call(model)
+         A pass is destructive if it is not in place and it modifies the input model.
+         """
+         return not self.in_place and self.changes_input
+
+     def __call__(self, model_or_result: ir.Model | PassResult, /) -> PassResult:
+         if isinstance(model_or_result, PassResult):
+             model = model_or_result.model
+         else:
+             model = model_or_result
+         # Check preconditions
+         try:
+             self.requires(model)
+         except PreconditionError:
+             raise
+         except Exception as e:
+             raise PreconditionError(
+                 f"Pre-condition for pass '{self.__class__.__name__}' failed"
+             ) from e
+
+         result = self.call(model)
+
+         # Check postconditions
+         try:
+             self.ensures(model)
+         except PostconditionError:
+             raise
+         except Exception as e:
+             raise PostconditionError(
+                 f"Post-condition for pass '{self.__class__.__name__}' failed"
+             ) from e
+
+         if not isinstance(result, PassResult):
+             raise TypeError(
+                 f"The result of the pass '{self.__class__.__name__}' should be type PassResult. "
+                 "Please create one with ir.passes.PassResult()."
+             )
+
+         # Checks that the declared in-place property is respected
+         if self.in_place and result.model is not model:
+             raise PassError(
+                 f"The pass '{self.__class__.__name__}' is declared in-place, "
+                 "but the model returned is *not* the same object as the input model. "
+                 "Pass developer: Pass should return the same model object or the in_place property should return False."
+             )
+         if not self.in_place and result.model is model:
+             raise PassError(
+                 f"The pass '{self.__class__.__name__}' is declared not in-place, "
+                 "but the model returned *is* the same object as the input model. "
+                 "Pass developer: Pass should return a new model object or the in_place property should return True."
+             )
+         return result

      @abc.abstractmethod
      def call(self, model: ir.Model) -> PassResult:
@@ -97,76 +176,114 @@ class PassBase(abc.ABC):
          del model # Unused


- class PassManager:
+ class InPlacePass(PassBase):
+     """A pass that modifies the input model in place and returns it."""
+
+     @property
+     @final
+     def in_place(self) -> Literal[True]:
+         """An in-place pass is in place."""
+         return True
+
+     @property
+     @final
+     def changes_input(self) -> Literal[True]:
+         """An in-place pass changes the input model."""
+         return True
+
+
+ class FunctionalPass(PassBase):
+     """A pass that returns a new model but does not modify the input model."""
+
+     @property
+     @final
+     def in_place(self) -> Literal[False]:
+         """A functional pass is not in place."""
+         return False
+
+     @property
+     @final
+     def changes_input(self) -> Literal[False]:
+         """A functional pass does not change the input model."""
+         return False
+
+
+ class Sequential(PassBase):
+     """Run a sequence of passes in order."""
+
+     def __init__(self, *passes: PassBase):
+         if not passes:
+             raise ValueError("Sequential must take at least one pass")
+         self.passes = passes
+         self._in_place = all(pass_.in_place for pass_ in passes)
+         # The reason changes_inputs is decided by the first pass is that if the first pass is either in-place,
+         # or if it is not designed to be in-place but somehow changes the input (destructive),
+         # this pass sequence will change inputs.
+         self._changes_input = self.passes[0].changes_input or self.passes[0].in_place
+
+     @property
+     def in_place(self) -> bool:
+         return self._in_place
+
+     @property
+     def changes_input(self) -> bool:
+         return self._changes_input
+
+     def call(self, model: ir.Model) -> PassResult:
+         modified = False
+         for i, pass_ in enumerate(self.passes):
+             logger.debug("Running the %s-th pass '%s'", i, pass_)
+             try:
+                 pass_result = pass_(model)
+             except Exception as e:
+                 prev_pass_names = [str(p) for p in self.passes[:i]]
+                 raise PassError(
+                     f"An error occurred when running the '{pass_}' pass after the "
+                     f"following passes: {prev_pass_names}"
+                 ) from e
+
+             model = pass_result.model
+             modified = modified or pass_result.modified
+
+         return PassResult(model, modified)
+
+
+ class PassManager(Sequential):
      """Pass manager for the IR.

-     The PassManager is a callable that runs a sequence of passes on a model.
+     The PassManager is a Pass that runs a sequence of passes on a model.

      Attributes:
          passes: The passes to run.
-         check_invariants: Whether to check invariants before and after each pass.
          steps: The number of times to run the passes.
+         early_stop: Whether to stop running the passes if the graph stops changing.
      """

      def __init__(
          self,
          passes: Sequence[PassBase],
-         check_invariants: bool = False,
          steps: int = 1,
+         early_stop: bool = True,
      ):
          # TODO(justinchuby): Implement constraints
-         self.passes = list(passes)
-         self.check_invariants = check_invariants
+         super().__init__(*passes)
          self.steps = steps
+         self.early_stop = early_stop

-     def __call__(self, model: ir.Model) -> PassResult:
+     def call(self, model: ir.Model) -> PassResult:
          """Run the set of passes `steps` number of times or until the graph stops changing."""
          overall_modified = False
          for step in range(self.steps):
-             step_result = self._run_one_step(model, step)
+             try:
+                 # Call the call method of Sequential
+                 step_result = super().call(model)
+             except Exception as e:
+                 raise PassError(f"An error occurred at step {step}") from e
              model = step_result.model
              modified = step_result.modified
              overall_modified = overall_modified or modified
              # If the graph no longer changes, then we can stop running these passes
-             if not modified:
+             if not modified and self.early_stop:
                  logger.info("PassManager: No more graph changes detected after step %s", step)
                  break
          return PassResult(model, overall_modified)
-
-     def _run_one_step(self, model: ir.Model, step: int) -> PassResult:
-         modified = False
-         for i, pass_ in enumerate(self.passes):
-             logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step)
-
-             # 1. Check preconditions
-             if self.check_invariants:
-                 try:
-                     pass_.requires(model)
-                 except Exception as e:
-                     raise PreconditionError(f"Pre-condition failed for {pass_}") from e
-
-             # 2. Run the pass
-             try:
-                 pass_result = pass_(model)
-             except Exception as e:
-                 prev_pass_names = [str(p) for p in self.passes[:i]]
-                 raise PassError(
-                     f"An error occurred when running the '{pass_}' pass after the "
-                     f"following passes: {prev_pass_names} during step {step}"
-                 ) from e
-             if not isinstance(pass_result, PassResult):
-                 raise TypeError(
-                     f"The result of the pass {pass_} should be type PassResult."
-                     "Please create one with ir.passes.PassResult()."
-                 )
-
-             model = pass_result.model
-             modified = modified or pass_result.modified
-
-             # 3. Check postconditions
-             if self.check_invariants:
-                 try:
-                     pass_.ensures(model)
-                 except Exception as e:
-                     raise PostconditionError(f"Post-condition failed for {pass_}") from e
-         return PassResult(model, modified)
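
The rewritten infrastructure replaces the old in_place class attribute with abstract properties and moves the invariant checks from PassManager into PassBase.__call__. A minimal sketch of how a custom pass would plug into the new API; the StripDocStringsPass class and the model path are hypothetical, and ir.load is assumed to be onnx_ir's model-loading helper:

import onnx_ir as ir


class StripDocStringsPass(ir.passes.InPlacePass):
    """Hypothetical example pass: clear doc strings on every node in the main graph."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        modified = False
        for node in model.graph:
            if node.doc_string:
                node.doc_string = None
                modified = True
        # Must return a PassResult; PassBase.__call__ verifies the type and the in-place contract.
        return ir.passes.PassResult(model, modified=modified)


# PassManager now derives from Sequential and, by default, stops early once a step makes no changes.
manager = ir.passes.PassManager([StripDocStringsPass()], steps=3, early_stop=True)
model = ir.load("model.onnx")  # assumed I/O helper from onnx_ir
result = manager(model)
print(result.modified)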
onnx_ir/passes/common/__init__.py (new file)
@@ -0,0 +1,40 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+
+ __all__ = [
+     "AddInitializersToInputsPass",
+     "CheckerPass",
+     "ClearMetadataAndDocStringPass",
+     "CommonSubexpressionEliminationPass",
+     "InlinePass",
+     "LiftConstantsToInitializersPass",
+     "LiftSubgraphInitializersToMainGraphPass",
+     "RemoveInitializersFromInputsPass",
+     "RemoveUnusedFunctionsPass",
+     "RemoveUnusedNodesPass",
+     "RemoveUnusedOpsetsPass",
+     "ShapeInferencePass",
+     "TopologicalSortPass",
+ ]
+
+ from onnx_ir.passes.common.clear_metadata_and_docstring import (
+     ClearMetadataAndDocStringPass,
+ )
+ from onnx_ir.passes.common.common_subexpression_elimination import (
+     CommonSubexpressionEliminationPass,
+ )
+ from onnx_ir.passes.common.constant_manipulation import (
+     AddInitializersToInputsPass,
+     LiftConstantsToInitializersPass,
+     LiftSubgraphInitializersToMainGraphPass,
+     RemoveInitializersFromInputsPass,
+ )
+ from onnx_ir.passes.common.inliner import InlinePass
+ from onnx_ir.passes.common.onnx_checker import CheckerPass
+ from onnx_ir.passes.common.shape_inference import ShapeInferencePass
+ from onnx_ir.passes.common.topological_sort import TopologicalSortPass
+ from onnx_ir.passes.common.unused_removal import (
+     RemoveUnusedFunctionsPass,
+     RemoveUnusedNodesPass,
+     RemoveUnusedOpsetsPass,
+ )
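
These re-exports let the common passes be assembled without importing the private modules. A small sketch of composing a few of them with the PassManager from _pass_infra.py; the ordering and the model path are illustrative only, and ir.load is assumed to be onnx_ir's loader:

import onnx_ir as ir
from onnx_ir.passes.common import (
    ClearMetadataAndDocStringPass,
    CommonSubexpressionEliminationPass,
    RemoveUnusedNodesPass,
    TopologicalSortPass,
)

# An illustrative cleanup ordering; the passes run in sequence and early_stop defaults to True.
cleanup = ir.passes.PassManager(
    [
        CommonSubexpressionEliminationPass(),
        RemoveUnusedNodesPass(),
        TopologicalSortPass(),
        ClearMetadataAndDocStringPass(),
    ]
)
result = cleanup(ir.load("model.onnx"))  # assumed I/O helper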
onnx_ir/passes/common/_c_api_utils.py (new file)
@@ -0,0 +1,76 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Utilities for interfacing with onnx C APIs."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import TYPE_CHECKING, Callable, TypeVar
+
+ import onnx_ir as ir
+
+ if TYPE_CHECKING:
+     import onnx
+
+
+ logger = logging.getLogger(__name__)
+ # Temporarily remove initializers larger than this size to keep model size down
+ # for the onnx.shape_inference call because it needs to serialize the model
+ _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
+ _R = TypeVar("_R")
+
+
+ def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
+     """Call an ONNX C API function by temporarily removing initializers.
+
+     This is necessary because the ONNX C API does not support large models
+     with initializers that have large tensor values. The input model is left
+     unchanged no matter the call succeeds or not.
+
+     Args:
+         func: Partially applied function that takes a model proto and returns anything.
+         model: The IR model to pass to the API function.
+
+     Returns:
+         The resulting ModelProto that contains the result of the API call.
+     """
+     # Store the original initializer values so they can be restored
+     initializer_values = tuple(model.graph.initializers.values())
+     tensors = {v.name: v.const_value for v in initializer_values}
+     original_inputs_len = len(model.graph.inputs)
+
+     # Turn the initializers into inputs and clear the initializers
+     # to limit the model size
+     for initializer in initializer_values:
+         # Make sure the initializer has its shape/type set
+         assert initializer.const_value is not None
+         if initializer.shape is None:
+             initializer.shape = initializer.const_value.shape # type: ignore[assignment]
+         if initializer.dtype is None:
+             initializer.dtype = initializer.const_value.dtype
+         if initializer not in model.graph.inputs:
+             model.graph.inputs.append(initializer)
+         if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
+             # Temporarily remove the initializer value to reduce model size
+             # for onnx.shape_inference
+             initializer.const_value = None
+             assert initializer.name is not None
+             model.graph.initializers.pop(initializer.name)
+
+     proto = ir.serde.serialize_model(model)
+
+     try:
+         # Call the ONNX C API function
+         result = func(proto)
+     finally:
+         # Restore the original initializer values so the model is unchanged
+         for initializer in initializer_values:
+             initializer.const_value = tensors[initializer.name]
+             model.graph.register_initializer(initializer)
+
+         # Restore the original inputs
+         inputs = model.graph.inputs[:original_inputs_len]
+         model.graph.inputs.clear()
+         model.graph.inputs.extend(inputs)
+
+     return result
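
The helper is private, but its calling convention follows from the signature: the caller partially applies an onnx C-API entry point so that it only needs the serialized proto. A sketch of how a shape-inference wrapper might use it; onnx.shape_inference.infer_shapes is the standard onnx API, while the wrapper function name here is hypothetical:

import functools

import onnx
import onnx_ir as ir
from onnx_ir.passes.common import _c_api_utils


def infer_shapes_proto(model: ir.Model) -> onnx.ModelProto:
    # Partially apply the C API call; call_onnx_api strips big initializers,
    # serializes the model, invokes the function, then restores the model.
    func = functools.partial(onnx.shape_inference.infer_shapes, data_prop=True)
    return _c_api_utils.call_onnx_api(func, model)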
onnx_ir/passes/common/clear_metadata_and_docstring.py (new file)
@@ -0,0 +1,60 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Clear all metadata and docstring from the model, graphs, nodes, and functions."""
+
+ from __future__ import annotations
+
+ __all__ = [
+     "ClearMetadataAndDocStringPass",
+ ]
+
+ import logging
+
+ import onnx_ir as ir
+
+ logger = logging.getLogger(__name__)
+
+
+ class ClearMetadataAndDocStringPass(ir.passes.InPlacePass):
+     """Clear all metadata and docstring from the model, graphs, nodes, and functions."""
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         # 0. TODO: Should we clean model metadata and docstring?
+
+         # 1. Clean up the graph and the belonged nodes metadata properties
+         modified = self._clear_graph_or_function_metadata_and_docstring(model.graph)
+
+         # 2. Clean up all of the functions metadata properties
+         for function in model.functions.values():
+             modified = (
+                 self._clear_graph_or_function_metadata_and_docstring(function) or modified
+             )
+         return ir.passes.PassResult(model, modified=modified)
+
+     def _clear_graph_or_function_metadata_and_docstring(
+         self,
+         graph_or_function: ir.Graph | ir.Function,
+     ) -> bool:
+         """Clear metadata and docstring from the graph or function."""
+         checked_graphs_or_functions: set[ir.Graph | ir.Function] = set()
+         modified = False
+         # Clean up all of the nodes metadata properties
+         for node in ir.traversal.RecursiveGraphIterator(graph_or_function):
+             if node.metadata_props:
+                 modified = True
+                 logger.debug("Removed metadata from %s nodes", node.name)
+             node.metadata_props.clear()
+             node.doc_string = None
+
+             # Clean up the owning graph/function metadata properties
+             # and doc_string if the graph/function is not already checked
+             assert node.graph is not None
+             if node.graph not in checked_graphs_or_functions and (
+                 node.graph.metadata_props or node.graph.doc_string
+             ):
+                 modified = True
+                 logger.debug("Removed metadata from %s graph/function", node.graph.name)
+                 node.graph.metadata_props.clear()
+                 node.graph.doc_string = None
+                 checked_graphs_or_functions.add(node.graph)
+         return modified
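
Because the pass derives from InPlacePass, invoking it returns the same model object with metadata_props and doc_string fields cleared throughout. A brief usage sketch; ir.load and ir.save are assumed to be onnx_ir's I/O helpers and the paths are illustrative:

import onnx_ir as ir
from onnx_ir.passes.common import ClearMetadataAndDocStringPass

model = ir.load("model.onnx")  # assumed I/O helper
result = ClearMetadataAndDocStringPass()(model)
assert result.model is model  # in-place contract enforced by PassBase.__call__
ir.save(result.model, "model_clean.onnx")  # assumed I/O helper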
onnx_ir/passes/common/common_subexpression_elimination.py (new file)
@@ -0,0 +1,177 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Eliminate common subexpression in ONNX graphs."""
+
+ from __future__ import annotations
+
+ __all__ = [
+     "CommonSubexpressionEliminationPass",
+ ]
+
+ import logging
+ from collections.abc import Sequence
+
+ import onnx_ir as ir
+
+ logger = logging.getLogger(__name__)
+
+
+ class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
+     """Eliminate common subexpression in ONNX graphs."""
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         """Return the same ir.Model but with CSE applied to the graph."""
+         modified = False
+         graph = model.graph
+
+         modified = _eliminate_common_subexpression(graph, modified)
+
+         return ir.passes.PassResult(
+             model,
+             modified=modified,
+         )
+
+
+ def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
+     """Eliminate common subexpression in ONNX graphs."""
+     # node to node identifier, length of outputs, inputs, and attributes
+     existing_node_info_to_the_node: dict[
+         tuple[
+             ir.OperatorIdentifier,
+             int,  # len(outputs)
+             tuple[int, ...],  # input ids
+             tuple[tuple[str, object], ...],  # attributes
+         ],
+         ir.Node,
+     ] = {}
+
+     for node in graph:
+         # Skip control flow ops like Loop and If.
+         control_flow_op: bool = False
+         # Use equality to check if the node is a common subexpression.
+         attributes = {}
+         for k, v in node.attributes.items():
+             # TODO(exporter team): CSE subgraphs.
+             # NOTE: control flow ops like Loop and If won't be CSEd
+             # because attribute: graph won't match.
+             if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                 control_flow_op = True
+                 logger.debug("Skipping control flow op %s", node)
+             # The attribute value could be directly taken from the original
+             # protobuf, so we need to make a copy of it.
+             value = v.value
+             if v.type in (
+                 ir.AttributeType.INTS,
+                 ir.AttributeType.FLOATS,
+                 ir.AttributeType.STRINGS,
+             ):
+                 # For INT, FLOAT and STRING attributes, we convert them to tuples
+                 # to ensure they are hashable.
+                 value = tuple(value)
+             attributes[k] = value
+
+         if control_flow_op:
+             # If the node is a control flow op, we skip it.
+             logger.debug("Skipping control flow op %s", node)
+             continue
+
+         if _is_non_deterministic_op(node):
+             # If the node is a non-deterministic op, we skip it.
+             logger.debug("Skipping non-deterministic op %s", node)
+             continue
+
+         node_info = (
+             node.op_identifier(),
+             len(node.outputs),
+             tuple(id(input) for input in node.inputs),
+             tuple(sorted(attributes.items())),
+         )
+         # Check if the node is a common subexpression.
+         if node_info in existing_node_info_to_the_node:
+             # If it is, this node has an existing node with the same
+             # operator, number of outputs, inputs, and attributes.
+             # We replace the node with the existing node.
+             modified = True
+             existing_node = existing_node_info_to_the_node[node_info]
+             _remove_node_and_replace_values(
+                 graph,
+                 remove_node=node,
+                 remove_values=node.outputs,
+                 new_values=existing_node.outputs,
+             )
+             logger.debug("Reusing node %s", existing_node)
+         else:
+             # If it is not, add to the mapping.
+             existing_node_info_to_the_node[node_info] = node
+     return modified
+
+
+ def _remove_node_and_replace_values(
+     graph: ir.Graph,
+     /,
+     remove_node: ir.Node,
+     remove_values: Sequence[ir.Value],
+     new_values: Sequence[ir.Value],
+ ) -> None:
+     """Replaces nodes and values in the graph or function.
+
+     Args:
+         graph: The graph to replace nodes and values in.
+         remove_node: The node to remove.
+         remove_values: The values to replace.
+         new_values: The values to replace with.
+     """
+     # Reconnect the users of the deleted values to use the new values
+     ir.convenience.replace_all_uses_with(remove_values, new_values)
+     # Update graph/function outputs if the node generates output
+     if any(remove_value.is_graph_output() for remove_value in remove_values):
+         replacement_mapping = dict(zip(remove_values, new_values))
+         for idx, graph_output in enumerate(graph.outputs):
+             if graph_output in replacement_mapping:
+                 new_value = replacement_mapping[graph_output]
+                 if new_value.is_graph_output() or new_value.is_graph_input():
+                     # If the new value is also a graph input/output, we need to
+                     # create a Identity node to preserve the remove_value and
+                     # prevent from changing new_value name.
+                     identity_node = ir.node(
+                         "Identity",
+                         inputs=[new_value],
+                         outputs=[
+                             ir.Value(
+                                 name=graph_output.name,
+                                 type=graph_output.type,
+                                 shape=graph_output.shape,
+                             )
+                         ],
+                     )
+                     # reuse the name of the graph output
+                     graph.outputs[idx] = identity_node.outputs[0]
+                     graph.insert_before(
+                         remove_node,
+                         identity_node,
+                     )
+                 else:
+                     # if new_value is not graph output, we just
+                     # update it to use old_value name.
+                     new_value.name = graph_output.name
+                     graph.outputs[idx] = new_value
+
+     graph.remove(remove_node, safe=True)
+
+
+ def _is_non_deterministic_op(node: ir.Node) -> bool:
+     non_deterministic_ops = frozenset(
+         {
+             "RandomUniform",
+             "RandomNormal",
+             "RandomUniformLike",
+             "RandomNormalLike",
+             "Multinomial",
+         }
+     )
+     return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
+
+
+ def _is_onnx_domain(d: str) -> bool:
+     """Check if the domain is the ONNX domain."""
+     return d == ""
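
Two nodes are merged only when they share the op identifier, the output count, the identical input Value objects (compared by id), and equal hashable attributes; subgraph-carrying and non-deterministic ops are skipped. A short sketch of applying the pass to an existing model; the model path is illustrative and ir.load is assumed to be onnx_ir's loader:

import onnx_ir as ir
from onnx_ir.passes.common import CommonSubexpressionEliminationPass

model = ir.load("model.onnx")  # assumed I/O helper
node_count_before = sum(1 for _ in model.graph)  # count nodes by iterating the graph
result = CommonSubexpressionEliminationPass()(model)
node_count_after = sum(1 for _ in result.model.graph)
print(result.modified, node_count_before, "->", node_count_after)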