onnx-ir 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +1 -1
- onnx_ir/passes/common/__init__.py +2 -0
- onnx_ir/passes/common/naming.py +286 -0
- onnx_ir/tensor_adapters.py +14 -2
- onnx_ir/traversal.py +35 -0
- {onnx_ir-0.1.5.dist-info → onnx_ir-0.1.6.dist-info}/METADATA +1 -2
- {onnx_ir-0.1.5.dist-info → onnx_ir-0.1.6.dist-info}/RECORD +10 -9
- {onnx_ir-0.1.5.dist-info → onnx_ir-0.1.6.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.5.dist-info → onnx_ir-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.5.dist-info → onnx_ir-0.1.6.dist-info}/top_level.txt +0 -0
onnx_ir/__init__.py
CHANGED
|
@@ -11,6 +11,7 @@ __all__ = [
|
|
|
11
11
|
"InlinePass",
|
|
12
12
|
"LiftConstantsToInitializersPass",
|
|
13
13
|
"LiftSubgraphInitializersToMainGraphPass",
|
|
14
|
+
"NameFixPass",
|
|
14
15
|
"RemoveInitializersFromInputsPass",
|
|
15
16
|
"RemoveUnusedFunctionsPass",
|
|
16
17
|
"RemoveUnusedNodesPass",
|
|
@@ -38,6 +39,7 @@ from onnx_ir.passes.common.initializer_deduplication import (
|
|
|
38
39
|
DeduplicateInitializersPass,
|
|
39
40
|
)
|
|
40
41
|
from onnx_ir.passes.common.inliner import InlinePass
|
|
42
|
+
from onnx_ir.passes.common.naming import NameFixPass
|
|
41
43
|
from onnx_ir.passes.common.onnx_checker import CheckerPass
|
|
42
44
|
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
|
|
43
45
|
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Name fix pass for ensuring unique names for all values and nodes."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"NameFixPass",
|
|
9
|
+
"NameGenerator",
|
|
10
|
+
"SimpleNameGenerator",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
import collections
|
|
14
|
+
import logging
|
|
15
|
+
from typing import Protocol
|
|
16
|
+
|
|
17
|
+
import onnx_ir as ir
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class NameGenerator(Protocol):
    """Structural interface for objects that propose names for nodes and values.

    Implementations only return a *preferred* name. Callers such as
    :class:`NameFixPass` are responsible for making the final name unique.
    """

    def generate_node_name(self, node: ir.Node) -> str:
        """Return the preferred name for ``node``."""
        ...

    def generate_value_name(self, value: ir.Value) -> str:
        """Return the preferred name for ``value``."""
        ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SimpleNameGenerator(NameGenerator):
    """Default :class:`NameGenerator` that reuses existing names when available.

    Unnamed nodes fall back to ``"node"`` and unnamed values fall back to ``"v"``.
    The returned names are only preferences; :class:`NameFixPass` makes them
    unique by appending numeric suffixes when needed.
    """

    def generate_node_name(self, node: ir.Node) -> str:
        """Return the node's current name, or ``"node"`` if it has none."""
        existing = node.name
        if existing:
            return existing
        return "node"

    def generate_value_name(self, value: ir.Value) -> str:
        """Return the value's current name, or ``"v"`` if it has none."""
        existing = value.name
        return existing if existing else "v"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class NameFixPass(ir.passes.InPlacePass):
    """Pass for fixing names to ensure all values and nodes have unique names.

    This pass ensures that:

    1. Graph inputs and outputs have unique names. They are processed before any
       node, so on a conflict it is the intermediate value that gets renamed.
    2. All intermediate values have unique names (unnamed values are assigned names).
    3. Values in subgraphs have names that are unique within their graph and do not
       collide with names already recorded in enclosing (parent) graphs.
    4. All nodes have unique names within their graph.

    Value names recorded inside a subgraph are forgotten when the traversal leaves
    that subgraph, so sibling subgraphs may reuse the same local names. The main
    graph and each function in ``model.functions`` are processed independently.

    You can customize the name generation functions for nodes and values by passing
    an object that implements the :class:`NameGenerator` protocol.

    For example, you can use a custom naming scheme like this::

        class CustomNameGenerator:
            def generate_node_name(self, node: ir.Node) -> str:
                return f"custom_node_{node.op_type}"

            def generate_value_name(self, value: ir.Value) -> str:
                return f"custom_value_{value.type}"

        name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())

    .. versionadded:: 0.1.6
    """

    def __init__(
        self,
        name_generator: NameGenerator | None = None,
    ) -> None:
        """Initialize the NameFixPass with an optional custom name generator.

        Args:
            name_generator (NameGenerator, optional): An object implementing
                :class:`NameGenerator` to customize name generation for nodes and
                values. If not provided, defaults to :class:`SimpleNameGenerator`,
                which uses the node's or value's existing name or a generic name
                like "node" or "v".
        """
        super().__init__()
        # Compare against None explicitly so that a generator object which happens
        # to be falsy (e.g. defines __len__ returning 0) is not silently replaced.
        self._name_generator: NameGenerator = (
            name_generator if name_generator is not None else SimpleNameGenerator()
        )

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Fix names in the main graph and in every function of ``model``."""
        # Process the main graph
        modified = self._fix_graph_names(model.graph)

        # Process functions. The fix is evaluated before `or` so every function is
        # processed even after an earlier one reported a modification.
        for function in model.functions.values():
            modified = self._fix_graph_names(function) or modified

        return ir.passes.PassResult(model, modified=modified)

    def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
        """Fix names in a graph or function, including its nested subgraphs.

        Args:
            graph_like: The main graph or a function to process. Subgraphs held in
                node attributes are visited recursively.

        Returns:
            True if any node or value was named or renamed.
        """
        modified = False

        # Values that have already been processed. A value can be reached several
        # times (e.g. as a graph input and as a node input) but is handled once.
        seen_values: set[ir.Value] = set()

        # One set of used names per graph scope on the traversal stack.
        # The first set is a dummy placeholder so that there is always a [-1] scope
        # for access (even though we don't write to it).
        scoped_used_value_names: list[set[str]] = [set()]
        scoped_used_node_names: list[set[str]] = [set()]

        # Suffix counters for generating unique names, keyed by preferred name and
        # shared by every scope of this graph-like.
        value_counter: collections.Counter[str] = collections.Counter()
        node_counter: collections.Counter[str] = collections.Counter()

        def process_value(value: ir.Value) -> None:
            """Name/rename ``value`` in the innermost scope and track modification."""
            nonlocal modified
            if self._process_value(
                value, scoped_used_value_names[-1], seen_values, value_counter
            ):
                modified = True

        def enter_graph(current_graph) -> None:
            """Callback for entering a (sub)graph: open a new naming scope."""
            # Value names from enclosing graphs stay reserved inside the new scope;
            # node names only need to be unique within their own graph.
            scoped_used_value_names.append(set(scoped_used_value_names[-1]))
            scoped_used_node_names.append(set())

            # Step 1: Fix graph input names first (they have precedence)
            for input_value in current_graph.inputs:
                process_value(input_value)

            # Step 2: Fix graph output names (they have precedence)
            for output_value in current_graph.outputs:
                process_value(output_value)

            if isinstance(current_graph, ir.Graph):
                # For graphs, also fix initializers. Iterate over a snapshot:
                # renaming an initializer pops and re-adds it in this mapping, and
                # mutating a dict while iterating its view is not allowed.
                for initializer in list(current_graph.initializers.values()):
                    process_value(initializer)

        def exit_graph(_) -> None:
            """Callback for exiting a (sub)graph: discard its naming scope."""
            scoped_used_value_names.pop()
            scoped_used_node_names.pop()

        # Step 3: Process all nodes and their values
        for node in ir.traversal.RecursiveGraphIterator(
            graph_like, enter_graph=enter_graph, exit_graph=exit_graph
        ):
            # Fix node name
            used_node_names = scoped_used_node_names[-1]
            if not node.name:
                node_modified = self._assign_node_name(node, used_node_names, node_counter)
            else:
                node_modified = self._fix_duplicate_node_name(
                    node, used_node_names, node_counter
                )
            if node_modified:
                modified = True

            # Fix input value names (values processed before are skipped)
            for input_value in node.inputs:
                if input_value is not None:
                    process_value(input_value)

            # Fix output value names (values processed before are skipped)
            for output_value in node.outputs:
                process_value(output_value)

        return modified

    def _process_value(
        self,
        value: ir.Value,
        used_value_names: set[str],
        seen_values: set[ir.Value],
        value_counter: collections.Counter,
    ) -> bool:
        """Name or rename ``value`` unless it has been processed before.

        Args:
            value: The value to process.
            used_value_names: Names already taken in the current scope; updated
                in place with the value's final name.
            seen_values: Values already processed; ``value`` is added to it.
            value_counter: Suffix counters used to generate unique names.

        Returns:
            True if the value was named or renamed.
        """
        if value in seen_values:
            return False

        if not value.name:
            modified = self._assign_value_name(value, used_value_names, value_counter)
        else:
            old_name = value.name
            modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
            if modified and value.is_initializer():
                # Initializers are looked up by name in their graph, so move the
                # entry to the new name. Only this path needs the owning graph,
                # so the check is limited to it.
                graph = value.graph
                assert graph is not None, "an initializer must belong to a graph"
                graph.initializers.pop(old_name)
                # Add the initializer back with the new name
                graph.initializers.add(value)

        # Record that this value now has its final name
        assert value.name is not None
        seen_values.add(value)
        return modified

    def _assign_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Assign a unique name to an unnamed value. Always returns True."""
        assert not value.name, (
            "value should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed value", value.name)
        return True

    def _assign_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Assign a unique name to an unnamed node. Always returns True."""
        assert not node.name, (
            "node should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed node", node.name)
        return True

    def _fix_duplicate_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Keep a named value's name if free, else rename it. Returns True if renamed."""
        original_name = value.name

        assert original_name, (
            "value should have a name already if function is called correctly"
        )

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique starting from the generator's preference
        base_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
        return True

    def _fix_duplicate_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Keep a named node's name if free, else rename it. Returns True if renamed."""
        original_name = node.name

        assert original_name, "node should have a name already if function is called correctly"

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique starting from the generator's preference
        base_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
        return True
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _find_and_record_next_unique_name(
    preferred_name: str, used_names: set[str], counter: collections.Counter
) -> str:
    """Pick a name that is not yet in ``used_names`` and reserve it.

    ``preferred_name`` is returned unchanged when it is still free. Otherwise the
    entry for ``preferred_name`` in ``counter`` is advanced until
    ``f"{preferred_name}_{n}"`` is free. The chosen name is added to
    ``used_names`` before it is returned.

    Args:
        preferred_name: The name to start from.
        used_names: Names already taken; updated in place with the result.
        counter: Suffix counters keyed by preferred name; updated in place.

    Returns:
        The reserved, unique name.
    """
    if preferred_name not in used_names:
        used_names.add(preferred_name)
        return preferred_name

    while True:
        counter[preferred_name] += 1
        candidate = f"{preferred_name}_{counter[preferred_name]}"
        if candidate not in used_names:
            used_names.add(candidate)
            return candidate
|
onnx_ir/tensor_adapters.py
CHANGED
|
@@ -68,7 +68,6 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
|
|
|
68
68
|
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
|
|
69
69
|
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
|
|
70
70
|
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
|
|
71
|
-
torch.float8_e8m0fnu: ir.DataType.FLOAT8E8M0,
|
|
72
71
|
torch.int16: ir.DataType.INT16,
|
|
73
72
|
torch.int32: ir.DataType.INT32,
|
|
74
73
|
torch.int64: ir.DataType.INT64,
|
|
@@ -78,6 +77,10 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
|
|
|
78
77
|
torch.uint32: ir.DataType.UINT32,
|
|
79
78
|
torch.uint64: ir.DataType.UINT64,
|
|
80
79
|
}
|
|
80
|
+
if hasattr(torch, "float8_e8m0fnu"):
|
|
81
|
+
# torch.float8_e8m0fnu is available in PyTorch 2.7+
|
|
82
|
+
_TORCH_DTYPE_TO_ONNX[torch.float8_e8m0fnu] = ir.DataType.FLOAT8E8M0
|
|
83
|
+
|
|
81
84
|
if dtype not in _TORCH_DTYPE_TO_ONNX:
|
|
82
85
|
raise TypeError(
|
|
83
86
|
f"Unsupported PyTorch dtype '{dtype}'. "
|
|
@@ -105,7 +108,6 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
|
|
|
105
108
|
ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
|
|
106
109
|
ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
|
|
107
110
|
ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
|
|
108
|
-
ir.DataType.FLOAT8E8M0: torch.float8_e8m0fnu,
|
|
109
111
|
ir.DataType.INT16: torch.int16,
|
|
110
112
|
ir.DataType.INT32: torch.int32,
|
|
111
113
|
ir.DataType.INT64: torch.int64,
|
|
@@ -115,7 +117,17 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
|
|
|
115
117
|
ir.DataType.UINT32: torch.uint32,
|
|
116
118
|
ir.DataType.UINT64: torch.uint64,
|
|
117
119
|
}
|
|
120
|
+
|
|
121
|
+
if hasattr(torch, "float8_e8m0fnu"):
|
|
122
|
+
# torch.float8_e8m0fnu is available in PyTorch 2.7+
|
|
123
|
+
_ONNX_DTYPE_TO_TORCH[ir.DataType.FLOAT8E8M0] = torch.float8_e8m0fnu
|
|
124
|
+
|
|
118
125
|
if dtype not in _ONNX_DTYPE_TO_TORCH:
|
|
126
|
+
if dtype == ir.DataType.FLOAT8E8M0:
|
|
127
|
+
raise ValueError(
|
|
128
|
+
"The requested DataType 'FLOAT8E8M0' is only supported in PyTorch 2.7+. "
|
|
129
|
+
"Please upgrade your PyTorch version to use this dtype."
|
|
130
|
+
)
|
|
119
131
|
raise TypeError(
|
|
120
132
|
f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
|
|
121
133
|
"Please use a supported dtype from the list: "
|
onnx_ir/traversal.py
CHANGED
|
@@ -25,19 +25,33 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
|
|
|
25
25
|
*,
|
|
26
26
|
recursive: Callable[[_core.Node], bool] | None = None,
|
|
27
27
|
reverse: bool = False,
|
|
28
|
+
enter_graph: Callable[[GraphLike], None] | None = None,
|
|
29
|
+
exit_graph: Callable[[GraphLike], None] | None = None,
|
|
28
30
|
):
|
|
29
31
|
"""Iterate over the nodes in the graph, recursively visiting subgraphs.
|
|
30
32
|
|
|
33
|
+
This iterator allows for traversing the nodes of a graph and its subgraphs
|
|
34
|
+
in a depth-first manner. It supports optional callbacks for entering and exiting
|
|
35
|
+
subgraphs, as well as a callback `recursive` to determine whether to visit subgraphs
|
|
36
|
+
contained within nodes.
|
|
37
|
+
|
|
38
|
+
.. versionadded:: 0.1.6
|
|
39
|
+
Added the `enter_graph` and `exit_graph` callbacks.
|
|
40
|
+
|
|
31
41
|
Args:
|
|
32
42
|
graph_like: The graph to traverse.
|
|
33
43
|
recursive: A callback that determines whether to recursively visit the subgraphs
|
|
34
44
|
contained in a node. If not provided, all nodes in subgraphs are visited.
|
|
35
45
|
reverse: Whether to iterate in reverse order.
|
|
46
|
+
enter_graph: An optional callback that is called when entering a subgraph.
|
|
47
|
+
exit_graph: An optional callback that is called when exiting a subgraph.
|
|
36
48
|
"""
|
|
37
49
|
self._graph = graph_like
|
|
38
50
|
self._recursive = recursive
|
|
39
51
|
self._reverse = reverse
|
|
40
52
|
self._iterator = self._recursive_node_iter(graph_like)
|
|
53
|
+
self._enter_graph = enter_graph
|
|
54
|
+
self._exit_graph = exit_graph
|
|
41
55
|
|
|
42
56
|
def __iter__(self) -> Self:
|
|
43
57
|
self._iterator = self._recursive_node_iter(self._graph)
|
|
@@ -50,34 +64,55 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
|
|
|
50
64
|
self, graph: _core.Graph | _core.Function | _core.GraphView
|
|
51
65
|
) -> Iterator[_core.Node]:
|
|
52
66
|
iterable = reversed(graph) if self._reverse else graph
|
|
67
|
+
|
|
68
|
+
if self._enter_graph is not None:
|
|
69
|
+
self._enter_graph(graph)
|
|
70
|
+
|
|
53
71
|
for node in iterable: # type: ignore[union-attr]
|
|
54
72
|
yield node
|
|
55
73
|
if self._recursive is not None and not self._recursive(node):
|
|
56
74
|
continue
|
|
57
75
|
yield from self._iterate_subgraphs(node)
|
|
58
76
|
|
|
77
|
+
if self._exit_graph is not None:
|
|
78
|
+
self._exit_graph(graph)
|
|
79
|
+
|
|
59
80
|
def _iterate_subgraphs(self, node: _core.Node):
|
|
60
81
|
for attr in node.attributes.values():
|
|
61
82
|
if not isinstance(attr, _core.Attr):
|
|
62
83
|
continue
|
|
63
84
|
if attr.type == _enums.AttributeType.GRAPH:
|
|
85
|
+
if self._enter_graph is not None:
|
|
86
|
+
self._enter_graph(attr.value)
|
|
64
87
|
yield from RecursiveGraphIterator(
|
|
65
88
|
attr.value,
|
|
66
89
|
recursive=self._recursive,
|
|
67
90
|
reverse=self._reverse,
|
|
91
|
+
enter_graph=self._enter_graph,
|
|
92
|
+
exit_graph=self._exit_graph,
|
|
68
93
|
)
|
|
94
|
+
if self._exit_graph is not None:
|
|
95
|
+
self._exit_graph(attr.value)
|
|
69
96
|
elif attr.type == _enums.AttributeType.GRAPHS:
|
|
70
97
|
graphs = reversed(attr.value) if self._reverse else attr.value
|
|
71
98
|
for graph in graphs:
|
|
99
|
+
if self._enter_graph is not None:
|
|
100
|
+
self._enter_graph(graph)
|
|
72
101
|
yield from RecursiveGraphIterator(
|
|
73
102
|
graph,
|
|
74
103
|
recursive=self._recursive,
|
|
75
104
|
reverse=self._reverse,
|
|
105
|
+
enter_graph=self._enter_graph,
|
|
106
|
+
exit_graph=self._exit_graph,
|
|
76
107
|
)
|
|
108
|
+
if self._exit_graph is not None:
|
|
109
|
+
self._exit_graph(graph)
|
|
77
110
|
|
|
78
111
|
def __reversed__(self) -> Iterator[_core.Node]:
|
|
79
112
|
return RecursiveGraphIterator(
|
|
80
113
|
self._graph,
|
|
81
114
|
recursive=self._recursive,
|
|
82
115
|
reverse=not self._reverse,
|
|
116
|
+
enter_graph=self._enter_graph,
|
|
117
|
+
exit_graph=self._exit_graph,
|
|
83
118
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-ir
|
|
3
|
-
Version: 0.1.5
|
|
3
|
+
Version: 0.1.6
|
|
4
4
|
Summary: Efficient in-memory representation for ONNX
|
|
5
5
|
Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
|
|
6
6
|
License: Apache License v2.0
|
|
@@ -29,7 +29,6 @@ Dynamic: license-file
|
|
|
29
29
|
[](https://pypi.org/project/onnx-ir)
|
|
30
30
|
[](https://github.com/astral-sh/ruff)
|
|
31
31
|
[](https://codecov.io/gh/onnx/ir-py)
|
|
32
|
-
[](https://deepwiki.com/onnx/ir-py)
|
|
33
32
|
[](https://pepy.tech/projects/onnx-ir)
|
|
34
33
|
|
|
35
34
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
onnx_ir/__init__.py,sha256=
|
|
1
|
+
onnx_ir/__init__.py,sha256=rsm-93uR-9LRHYGjVec4xA1qqUwzrvArfL7SYVdax9E,3424
|
|
2
2
|
onnx_ir/_core.py,sha256=CtRwtDb__hK0MJLWsrNNu5n_xz6TlbJctDLw8UDQAZQ,137454
|
|
3
3
|
onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
|
|
4
4
|
onnx_ir/_enums.py,sha256=SxC-GGgPrmdz6UsMhx7xT9-6VmkZ6j1oVzDqNUHr3Rc,9659
|
|
@@ -18,15 +18,15 @@ onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,1807
|
|
|
18
18
|
onnx_ir/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
19
19
|
onnx_ir/serde.py,sha256=bFQg5XYlDTvZsT_gDO_mPYedkMj_HcUbBvQuxLlRKvc,75980
|
|
20
20
|
onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
|
|
21
|
-
onnx_ir/tensor_adapters.py,sha256=
|
|
21
|
+
onnx_ir/tensor_adapters.py,sha256=YffUeZDZi8thxm-4nF2cL6cNSJSVmLm4A3IbEzwY8QQ,7233
|
|
22
22
|
onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
|
|
23
|
-
onnx_ir/traversal.py,sha256=
|
|
23
|
+
onnx_ir/traversal.py,sha256=Wy4XphwuapAvm94-5iaz6G8LjIoMFpY7qfPfXzYViEE,4488
|
|
24
24
|
onnx_ir/_convenience/__init__.py,sha256=DQ-Bz1wTiZJEARCFxDqZvYexWviGmwvDzE_1hR-vp0Q,19182
|
|
25
25
|
onnx_ir/_convenience/_constructors.py,sha256=5GhlYy_xCE2ng7l_4cNx06WQsNDyvS-0U1HgOpPKJEk,8347
|
|
26
26
|
onnx_ir/_thirdparty/asciichartpy.py,sha256=afQ0fsqko2uYRPAR4TZBrQxvCb4eN8lxZ2yDFbVQq_s,10533
|
|
27
27
|
onnx_ir/passes/__init__.py,sha256=M_Tcl_-qGSNPluFIvOoeDyh0qAwNayaYyXDS5UJUJPQ,764
|
|
28
28
|
onnx_ir/passes/_pass_infra.py,sha256=xIOw_zZIuOqD4Z_wZ4OvsqXfh2IZMoMlDp1xQ_MPQlc,9567
|
|
29
|
-
onnx_ir/passes/common/__init__.py,sha256=
|
|
29
|
+
onnx_ir/passes/common/__init__.py,sha256=vYRzXo4a_c_1Ad7UNCTHsghKIJJngOQNUWlwCaMrXcE,1658
|
|
30
30
|
onnx_ir/passes/common/_c_api_utils.py,sha256=g6riA6xNGVWaO5YjVHZ0krrfslWHmRlryRkwB8X56cg,2907
|
|
31
31
|
onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMOGjdvVwG9yHQTkSphUgDlM0ME,2365
|
|
32
32
|
onnx_ir/passes/common/common_subexpression_elimination.py,sha256=wZ1zEPdCshYB_ifP9fCAVfzQkesE6uhCfzCuL2qO5fA,7948
|
|
@@ -34,12 +34,13 @@ onnx_ir/passes/common/constant_manipulation.py,sha256=_fGDwn0Axl2Q8APfc2m_mLMH28
|
|
|
34
34
|
onnx_ir/passes/common/identity_elimination.py,sha256=FyqnJxFUq9Ga9XyUJ3myjzr36InYSW-oJgDTrUrBORY,3663
|
|
35
35
|
onnx_ir/passes/common/initializer_deduplication.py,sha256=4CIVFYfdXUlmF2sAx560c_pTwYVXtX5hcSwWzUKm5uc,2061
|
|
36
36
|
onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
|
|
37
|
+
onnx_ir/passes/common/naming.py,sha256=kEqIYBVweFvZSJcG8wi8o9_Dmk-NswCp_niuzrq-ubk,10926
|
|
37
38
|
onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6svE0cCyDew,1691
|
|
38
39
|
onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
|
|
39
40
|
onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
|
|
40
41
|
onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
|
|
41
|
-
onnx_ir-0.1.
|
|
42
|
-
onnx_ir-0.1.
|
|
43
|
-
onnx_ir-0.1.
|
|
44
|
-
onnx_ir-0.1.
|
|
45
|
-
onnx_ir-0.1.
|
|
42
|
+
onnx_ir-0.1.6.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
43
|
+
onnx_ir-0.1.6.dist-info/METADATA,sha256=egWNVHaVs8LxpXuBvap5GgGpGGJ9v0Do9ZasSw_x2MM,3523
|
|
44
|
+
onnx_ir-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
45
|
+
onnx_ir-0.1.6.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
|
|
46
|
+
onnx_ir-0.1.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|