onnx-ir 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. The original page linked to further details, but that link was not preserved in this copy.

onnx_ir/__init__.py CHANGED
@@ -167,4 +167,4 @@ def __set_module() -> None:
167
167
 
168
168
 
169
169
  __set_module()
170
- __version__ = "0.1.5"
170
+ __version__ = "0.1.7"
@@ -58,44 +58,52 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
58
58
  return _enums.AttributeType.STRING
59
59
  if isinstance(attr, _core.Attr):
60
60
  return attr.type
61
- if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
62
- return _enums.AttributeType.INTS
63
- if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
64
- return _enums.AttributeType.FLOATS
65
- if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
66
- return _enums.AttributeType.STRINGS
61
+ if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
62
+ return _enums.AttributeType.GRAPH
67
63
  if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
68
64
  # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
69
65
  return _enums.AttributeType.TENSOR
70
- if isinstance(attr, Sequence) and all(
71
- isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
72
- for x in attr
73
- ):
74
- return _enums.AttributeType.TENSORS
75
- if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
76
- return _enums.AttributeType.GRAPH
77
- if isinstance(attr, Sequence) and all(
78
- isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
79
- ):
80
- return _enums.AttributeType.GRAPHS
81
66
  if isinstance(
82
67
  attr,
83
68
  (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
84
69
  ):
85
70
  return _enums.AttributeType.TYPE_PROTO
86
- if isinstance(attr, Sequence) and all(
87
- isinstance(
88
- x,
89
- (
90
- _core.TensorType,
91
- _core.SequenceType,
92
- _core.OptionalType,
93
- _protocols.TypeProtocol,
94
- ),
95
- )
96
- for x in attr
97
- ):
98
- return _enums.AttributeType.TYPE_PROTOS
71
+ if isinstance(attr, Sequence):
72
+ if not attr:
73
+ logger.warning(
74
+ "Attribute type is ambiguous because it is an empty sequence. "
75
+ "Please create an Attr with an explicit type. Defaulted to INTS"
76
+ )
77
+ return _enums.AttributeType.INTS
78
+ if all(isinstance(x, int) for x in attr):
79
+ return _enums.AttributeType.INTS
80
+ if all(isinstance(x, float) for x in attr):
81
+ return _enums.AttributeType.FLOATS
82
+ if all(isinstance(x, str) for x in attr):
83
+ return _enums.AttributeType.STRINGS
84
+ if all(
85
+ isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
86
+ for x in attr
87
+ ):
88
+ return _enums.AttributeType.TENSORS
89
+ if all(
90
+ isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol))
91
+ for x in attr
92
+ ):
93
+ return _enums.AttributeType.GRAPHS
94
+ if all(
95
+ isinstance(
96
+ x,
97
+ (
98
+ _core.TensorType,
99
+ _core.SequenceType,
100
+ _core.OptionalType,
101
+ _protocols.TypeProtocol,
102
+ ),
103
+ )
104
+ for x in attr
105
+ ):
106
+ return _enums.AttributeType.TYPE_PROTOS
99
107
  raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
100
108
 
101
109
 
@@ -218,7 +226,7 @@ def convert_attributes(
218
226
  ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
219
227
  ... }
220
228
  >>> convert_attributes(attrs)
221
- [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph(
229
+ [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
222
230
  name='graph0',
223
231
  inputs=(
224
232
  <BLANKLINE>
@@ -247,11 +255,20 @@ def convert_attributes(
247
255
  len()=0
248
256
  )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
249
257
 
258
+ .. important::
259
+ An empty sequence should be created with an explicit type by initializing
260
+ an Attr object with an attribute type to avoid type ambiguity. For example::
261
+
262
+ ir.Attr("empty", [], type=ir.AttributeType.INTS)
263
+
250
264
  Args:
251
265
  attrs: A dictionary of {<attribute name>: <python objects>} to convert.
252
266
 
253
267
  Returns:
254
- A list of _core.Attr objects.
268
+ A list of :class:`_core.Attr` objects.
269
+
270
+ Raises:
271
+ TypeError: If an attribute type is not supported.
255
272
  """
256
273
  attributes: list[_core.Attr] = []
257
274
  for name, attr in attrs.items():
onnx_ir/_core.py CHANGED
@@ -2564,14 +2564,23 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2564
2564
 
2565
2565
  .. versionadded:: 0.1.2
2566
2566
  """
2567
- seen_graphs: set[Graph] = set()
2568
- for node in onnx_ir.traversal.RecursiveGraphIterator(self):
2569
- graph = node.graph
2567
+ # Use a dict to preserve order
2568
+ seen_graphs: dict[Graph, None] = {}
2569
+
2570
+ # Need to use the enter_graph callback so that empty subgraphs are collected
2571
+ def enter_subgraph(graph) -> None:
2570
2572
  if graph is self:
2571
- continue
2572
- if graph is not None and graph not in seen_graphs:
2573
- seen_graphs.add(graph)
2574
- yield graph
2573
+ return
2574
+ if not isinstance(graph, Graph):
2575
+ raise TypeError(
2576
+ f"Expected a Graph, got {type(graph)}. The model may be invalid"
2577
+ )
2578
+ if graph not in seen_graphs:
2579
+ seen_graphs[graph] = None
2580
+
2581
+ for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
2582
+ pass
2583
+ yield from seen_graphs.keys()
2575
2584
 
2576
2585
  # Mutation methods
2577
2586
  def append(self, node: Node, /) -> None:
@@ -3180,6 +3189,21 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
3180
3189
  def attributes(self) -> _graph_containers.Attributes:
3181
3190
  return self._attributes
3182
3191
 
3192
+ @property
3193
+ def graph(self) -> Graph:
3194
+ """The underlying Graph object that contains the nodes of this function.
3195
+
3196
+ Only use this graph for identity comparison::
3197
+
3198
+ if value.graph is function.graph:
3199
+ # Do something with the value that belongs to this function
3200
+
3201
+ Otherwise use the Function object directly to access the nodes and other properties.
3202
+
3203
+ .. versionadded:: 0.1.7
3204
+ """
3205
+ return self._graph
3206
+
3183
3207
  @typing.overload
3184
3208
  def __getitem__(self, index: int) -> Node: ...
3185
3209
  @typing.overload
@@ -3240,14 +3264,22 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
3240
3264
 
3241
3265
  .. versionadded:: 0.1.2
3242
3266
  """
3243
- seen_graphs: set[Graph] = set()
3244
- for node in onnx_ir.traversal.RecursiveGraphIterator(self):
3245
- graph = node.graph
3246
- if graph is self._graph:
3247
- continue
3248
- if graph is not None and graph not in seen_graphs:
3249
- seen_graphs.add(graph)
3250
- yield graph
3267
+ seen_graphs: dict[Graph, None] = {}
3268
+
3269
+ # Need to use the enter_graph callback so that empty subgraphs are collected
3270
+ def enter_subgraph(graph) -> None:
3271
+ if graph is self:
3272
+ return
3273
+ if not isinstance(graph, Graph):
3274
+ raise TypeError(
3275
+ f"Expected a Graph, got {type(graph)}. The model may be invalid"
3276
+ )
3277
+ if graph not in seen_graphs:
3278
+ seen_graphs[graph] = None
3279
+
3280
+ for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
3281
+ pass
3282
+ yield from seen_graphs.keys()
3251
3283
 
3252
3284
  # Mutation methods
3253
3285
  def append(self, node: Node, /) -> None:
@@ -3349,7 +3381,7 @@ class Attr(
3349
3381
  ):
3350
3382
  """Base class for ONNX attributes or references."""
3351
3383
 
3352
- __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
3384
+ __slots__ = ("_metadata", "_name", "_ref_attr_name", "_type", "_value", "doc_string")
3353
3385
 
3354
3386
  def __init__(
3355
3387
  self,
@@ -3365,6 +3397,7 @@ class Attr(
3365
3397
  self._value = value
3366
3398
  self._ref_attr_name = ref_attr_name
3367
3399
  self.doc_string = doc_string
3400
+ self._metadata: _metadata.MetadataStore | None = None
3368
3401
 
3369
3402
  @property
3370
3403
  def name(self) -> str:
@@ -3386,6 +3419,17 @@ class Attr(
3386
3419
  def ref_attr_name(self) -> str | None:
3387
3420
  return self._ref_attr_name
3388
3421
 
3422
+ @property
3423
+ def meta(self) -> _metadata.MetadataStore:
3424
+ """The metadata store for intermediate analysis.
3425
+
3426
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
3427
+ to the ONNX proto.
3428
+ """
3429
+ if self._metadata is None:
3430
+ self._metadata = _metadata.MetadataStore()
3431
+ return self._metadata
3432
+
3389
3433
  def is_ref(self) -> bool:
3390
3434
  """Check if this attribute is a reference attribute."""
3391
3435
  return self.ref_attr_name is not None
onnx_ir/passes/common/__init__.py CHANGED
@@ -6,11 +6,13 @@ __all__ = [
6
6
  "CheckerPass",
7
7
  "ClearMetadataAndDocStringPass",
8
8
  "CommonSubexpressionEliminationPass",
9
+ "DeduplicateHashedInitializersPass",
9
10
  "DeduplicateInitializersPass",
10
11
  "IdentityEliminationPass",
11
12
  "InlinePass",
12
13
  "LiftConstantsToInitializersPass",
13
14
  "LiftSubgraphInitializersToMainGraphPass",
15
+ "NameFixPass",
14
16
  "RemoveInitializersFromInputsPass",
15
17
  "RemoveUnusedFunctionsPass",
16
18
  "RemoveUnusedNodesPass",
@@ -35,9 +37,11 @@ from onnx_ir.passes.common.identity_elimination import (
35
37
  IdentityEliminationPass,
36
38
  )
37
39
  from onnx_ir.passes.common.initializer_deduplication import (
40
+ DeduplicateHashedInitializersPass,
38
41
  DeduplicateInitializersPass,
39
42
  )
40
43
  from onnx_ir.passes.common.inliner import InlinePass
44
+ from onnx_ir.passes.common.naming import NameFixPass
41
45
  from onnx_ir.passes.common.onnx_checker import CheckerPass
42
46
  from onnx_ir.passes.common.shape_inference import ShapeInferencePass
43
47
  from onnx_ir.passes.common.topological_sort import TopologicalSortPass
@@ -148,6 +148,7 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
148
148
  if graph is model.graph:
149
149
  continue
150
150
  for name in tuple(graph.initializers):
151
+ assert name is not None
151
152
  initializer = graph.initializers[name]
152
153
  if initializer.is_graph_input():
153
154
  # Skip the ones that are also graph inputs
@@ -156,17 +157,24 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
156
157
  initializer.name,
157
158
  )
158
159
  continue
160
+ if initializer.is_graph_output():
161
+ logger.debug(
162
+ "Initializer '%s' is used as output, so it can't be lifted",
163
+ initializer.name,
164
+ )
165
+ continue
159
166
  # Remove the initializer from the subgraph
160
167
  graph.initializers.pop(name)
161
168
  # To avoid name conflicts, we need to rename the initializer
162
169
  # to a unique name in the main graph
163
- if name in registered_initializer_names:
164
- name_count = registered_initializer_names[name]
165
- initializer.name = f"{name}_{name_count}"
166
- registered_initializer_names[name] = name_count + 1
167
- else:
168
- assert initializer.name is not None
169
- registered_initializer_names[initializer.name] = 1
170
+ new_name = name
171
+ while new_name in model.graph.initializers:
172
+ if name in registered_initializer_names:
173
+ registered_initializer_names[name] += 1
174
+ else:
175
+ registered_initializer_names[name] = 1
176
+ new_name = f"{name}_{registered_initializer_names[name]}"
177
+ initializer.name = new_name
170
178
  model.graph.register_initializer(initializer)
171
179
  count += 1
172
180
  logger.debug(
onnx_ir/passes/common/identity_elimination.py CHANGED
@@ -19,6 +19,7 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
19
19
  """Pass for eliminating redundant Identity nodes.
20
20
 
21
21
  This pass removes Identity nodes according to the following rules:
22
+
22
23
  1. For any node of the form `y = Identity(x)`, where `y` is not an output
23
24
  of any graph, replace all uses of `y` with a use of `x`, and remove the node.
24
25
  2. If `y` is an output of a graph, and `x` is not an input of any graph,
onnx_ir/passes/common/initializer_deduplication.py CHANGED
@@ -4,24 +4,68 @@
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- __all__ = [
8
- "DeduplicateInitializersPass",
9
- ]
7
+ __all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
10
8
 
11
9
 
10
+ import hashlib
11
+ import logging
12
+
12
13
  import onnx_ir as ir
13
14
 
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
19
+ """Check if the initializer should be skipped for deduplication."""
20
+ if initializer.is_graph_input() or initializer.is_graph_output():
21
+ # Skip graph inputs and outputs
22
+ logger.warning(
23
+ "Skipped deduplication of initializer '%s' as it is a graph input or output",
24
+ initializer.name,
25
+ )
26
+ return True
27
+
28
+ const_val = initializer.const_value
29
+ if const_val is None:
30
+ # Skip if initializer has no constant value
31
+ logger.warning(
32
+ "Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
33
+ initializer.name,
34
+ )
35
+ return True
36
+
37
+ if const_val.size > size_limit:
38
+ # Skip if the initializer is larger than the size limit
39
+ logger.debug(
40
+ "Skipped initializer '%s' as it exceeds the size limit of %d elements",
41
+ initializer.name,
42
+ size_limit,
43
+ )
44
+ return True
45
+
46
+ if const_val.dtype == ir.DataType.STRING:
47
+ # Skip string initializers as they don't have a bytes representation
48
+ logger.warning(
49
+ "Skipped deduplication of string initializer '%s' (unsupported yet)",
50
+ initializer.name,
51
+ )
52
+ return True
53
+ return False
54
+
14
55
 
15
56
  class DeduplicateInitializersPass(ir.passes.InPlacePass):
16
- """Remove duplicated initializer tensors from the graph.
57
+ """Remove duplicated initializer tensors from the main graph and all subgraphs.
17
58
 
18
59
  This pass detects initializers with identical shape, dtype, and content,
19
60
  and replaces all duplicate references with a canonical one.
20
61
 
21
- To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
62
+ Initializers are deduplicated within each graph. To deduplicate initializers
63
+ in the model globally (across graphs), use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
22
64
  to lift the initializers to the main graph first before running pass.
23
65
 
24
66
  .. versionadded:: 0.1.3
67
+ .. versionchanged:: 0.1.7
68
+ This pass now deduplicates initializers in subgraphs as well.
25
69
  """
26
70
 
27
71
  def __init__(self, size_limit: int = 1024):
@@ -29,28 +73,95 @@ class DeduplicateInitializersPass(ir.passes.InPlacePass):
29
73
  self.size_limit = size_limit
30
74
 
31
75
  def call(self, model: ir.Model) -> ir.passes.PassResult:
32
- graph = model.graph
33
- initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
34
76
  modified = False
35
77
 
36
- for initializer in tuple(graph.initializers.values()):
37
- # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
38
- # out from the main graph before running this pass.
39
- const_val = initializer.const_value
40
- if const_val is None:
41
- # Skip if initializer has no constant value
42
- continue
43
-
44
- if const_val.size > self.size_limit:
45
- continue
46
-
47
- key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
48
- if key in initializers:
49
- modified = True
50
- ir.convenience.replace_all_uses_with(initializer, initializers[key]) # type: ignore[index]
51
- assert initializer.name is not None
52
- graph.initializers.pop(initializer.name)
53
- else:
54
- initializers[key] = initializer # type: ignore[index]
78
+ for graph in model.graphs():
79
+ initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
80
+ for initializer in tuple(graph.initializers.values()):
81
+ if _should_skip_initializer(initializer, self.size_limit):
82
+ continue
83
+
84
+ const_val = initializer.const_value
85
+ assert const_val is not None
86
+
87
+ key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
88
+ if key in initializers:
89
+ modified = True
90
+ initializer_to_keep = initializers[key] # type: ignore[index]
91
+ ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
92
+ assert initializer.name is not None
93
+ graph.initializers.pop(initializer.name)
94
+ logger.info(
95
+ "Replaced initializer '%s' with existing initializer '%s'",
96
+ initializer.name,
97
+ initializer_to_keep.name,
98
+ )
99
+ else:
100
+ initializers[key] = initializer # type: ignore[index]
101
+
102
+ return ir.passes.PassResult(model=model, modified=modified)
103
+
104
+
105
+ class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
106
+ """Remove duplicated initializer tensors (using a hashed method) from the graph.
107
+
108
+ This pass detects initializers with identical shape, dtype, and hashed content,
109
+ and replaces all duplicate references with a canonical one.
110
+
111
+ This pass should have a lower peak memory usage than :class:`DeduplicateInitializersPass`
112
+ as it does not store the full tensor data in memory, but instead uses a hash of the tensor data.
113
+
114
+ .. versionadded:: 0.1.7
115
+ """
116
+
117
+ def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
118
+ super().__init__()
119
+ # 4 GB default size limit for deduplication
120
+ self.size_limit = size_limit
121
+
122
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
123
+ modified = False
124
+
125
+ for graph in model.graphs():
126
+ initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}
127
+
128
+ for initializer in tuple(graph.initializers.values()):
129
+ if _should_skip_initializer(initializer, self.size_limit):
130
+ continue
131
+
132
+ const_val = initializer.const_value
133
+ assert const_val is not None
134
+
135
+ # Hash tensor data to avoid storing large amounts of data in memory
136
+ hashed = hashlib.sha512()
137
+ tensor_data = const_val.numpy()
138
+ hashed.update(tensor_data)
139
+ tensor_digest = hashed.hexdigest()
140
+
141
+ tensor_dims = tuple(const_val.shape.numpy())
142
+
143
+ key = (const_val.dtype, tensor_dims, tensor_digest)
144
+
145
+ if key in initializers:
146
+ if initializers[key].const_value.tobytes() != const_val.tobytes():
147
+ logger.warning(
148
+ "Initializer deduplication failed: "
149
+ "hashes match but values differ with values %s and %s",
150
+ initializers[key],
151
+ initializer,
152
+ )
153
+ continue
154
+ modified = True
155
+ initializer_to_keep = initializers[key] # type: ignore[index]
156
+ ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
157
+ assert initializer.name is not None
158
+ graph.initializers.pop(initializer.name)
159
+ logger.info(
160
+ "Replaced initializer '%s' with existing initializer '%s'",
161
+ initializer.name,
162
+ initializer_to_keep.name,
163
+ )
164
+ else:
165
+ initializers[key] = initializer # type: ignore[index]
55
166
 
56
167
  return ir.passes.PassResult(model=model, modified=modified)
onnx_ir/passes/common/naming.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Name fix pass for ensuring unique names for all values and nodes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "NameFixPass",
9
+ "NameGenerator",
10
+ "SimpleNameGenerator",
11
+ ]
12
+
13
+ import collections
14
+ import logging
15
+ from typing import Protocol
16
+
17
+ import onnx_ir as ir
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class NameGenerator(Protocol):
23
+ def generate_node_name(self, node: ir.Node) -> str:
24
+ """Generate a preferred name for a node."""
25
+ ...
26
+
27
+ def generate_value_name(self, value: ir.Value) -> str:
28
+ """Generate a preferred name for a value."""
29
+ ...
30
+
31
+
32
+ class SimpleNameGenerator(NameGenerator):
33
+ """Base class for name generation functions."""
34
+
35
+ def generate_node_name(self, node: ir.Node) -> str:
36
+ """Generate a preferred name for a node."""
37
+ return node.name or "node"
38
+
39
+ def generate_value_name(self, value: ir.Value) -> str:
40
+ """Generate a preferred name for a value."""
41
+ return value.name or "v"
42
+
43
+
44
+ class NameFixPass(ir.passes.InPlacePass):
45
+ """Pass for fixing names to ensure all values and nodes have unique names.
46
+
47
+ This pass ensures that:
48
+ 1. Graph inputs and outputs have unique names (take precedence)
49
+ 2. All intermediate values have unique names (assign names to unnamed values)
50
+ 3. All values in subgraphs have unique names within their graph and parent graphs
51
+ 4. All nodes have unique names within their graph
52
+
53
+ The pass maintains global uniqueness across the entire model.
54
+
55
+ You can customize the name generation functions for nodes and values by passing
56
+ a subclass of :class:`NameGenerator`.
57
+
58
+ For example, you can use a custom naming scheme like this::
59
+
60
+ class CustomNameGenerator:
61
+ def custom_node_name(node: ir.Node) -> str:
62
+ return f"custom_node_{node.op_type}"
63
+
64
+ def custom_value_name(value: ir.Value) -> str:
65
+ return f"custom_value_{value.type}"
66
+
67
+ name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())
68
+
69
+ .. versionadded:: 0.1.6
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ name_generator: NameGenerator | None = None,
75
+ ) -> None:
76
+ """Initialize the NameFixPass with custom name generation functions.
77
+
78
+ Args:
79
+ name_generator (NameGenerator, optional): An instance of a subclass of
80
+ :class:`NameGenerator` to customize name generation for nodes and values.
81
+ If not provided, defaults to a basic implementation that uses
82
+ the node's or value's existing name or a generic name like "node" or "v".
83
+ """
84
+ super().__init__()
85
+ self._name_generator = name_generator or SimpleNameGenerator()
86
+
87
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
88
+ # Process the main graph
89
+ modified = self._fix_graph_names(model.graph)
90
+
91
+ # Process functions
92
+ for function in model.functions.values():
93
+ modified = self._fix_graph_names(function) or modified
94
+
95
+ return ir.passes.PassResult(model, modified=modified)
96
+
97
+ def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
98
+ """Fix names in a graph and return whether modifications were made."""
99
+ modified = False
100
+
101
+ # Set to track which values have been assigned names
102
+ seen_values: set[ir.Value] = set()
103
+
104
+ # The first set is a dummy placeholder so that there is always a [-1] scope for access
105
+ # (even though we don't write to it)
106
+ scoped_used_value_names: list[set[str]] = [set()]
107
+ scoped_used_node_names: list[set[str]] = [set()]
108
+
109
+ # Counters for generating unique names (using list to pass by reference)
110
+ value_counter = collections.Counter()
111
+ node_counter = collections.Counter()
112
+
113
+ def enter_graph(graph_like) -> None:
114
+ """Callback for entering a subgraph."""
115
+ # Initialize new scopes with all names from the parent scope
116
+ scoped_used_value_names.append(set(scoped_used_value_names[-1]))
117
+ scoped_used_node_names.append(set())
118
+
119
+ nonlocal modified
120
+
121
+ # Step 1: Fix graph input names first (they have precedence)
122
+ for input_value in graph_like.inputs:
123
+ if self._process_value(
124
+ input_value, scoped_used_value_names[-1], seen_values, value_counter
125
+ ):
126
+ modified = True
127
+
128
+ # Step 2: Fix graph output names (they have precedence)
129
+ for output_value in graph_like.outputs:
130
+ if self._process_value(
131
+ output_value, scoped_used_value_names[-1], seen_values, value_counter
132
+ ):
133
+ modified = True
134
+
135
+ if isinstance(graph_like, ir.Graph):
136
+ # For graphs, also fix initializers
137
+ for initializer in graph_like.initializers.values():
138
+ if self._process_value(
139
+ initializer, scoped_used_value_names[-1], seen_values, value_counter
140
+ ):
141
+ modified = True
142
+
143
+ def exit_graph(_) -> None:
144
+ """Callback for exiting a subgraph."""
145
+ # Pop the current scope
146
+ scoped_used_value_names.pop()
147
+ scoped_used_node_names.pop()
148
+
149
+ # Step 3: Process all nodes and their values
150
+ for node in ir.traversal.RecursiveGraphIterator(
151
+ graph_like, enter_graph=enter_graph, exit_graph=exit_graph
152
+ ):
153
+ # Fix node name
154
+ if not node.name:
155
+ if self._assign_node_name(node, scoped_used_node_names[-1], node_counter):
156
+ modified = True
157
+ else:
158
+ if self._fix_duplicate_node_name(
159
+ node, scoped_used_node_names[-1], node_counter
160
+ ):
161
+ modified = True
162
+
163
+ # Fix input value names (only if not already processed)
164
+ for input_value in node.inputs:
165
+ if input_value is not None:
166
+ if self._process_value(
167
+ input_value, scoped_used_value_names[-1], seen_values, value_counter
168
+ ):
169
+ modified = True
170
+
171
+ # Fix output value names (only if not already processed)
172
+ for output_value in node.outputs:
173
+ if self._process_value(
174
+ output_value, scoped_used_value_names[-1], seen_values, value_counter
175
+ ):
176
+ modified = True
177
+
178
+ return modified
179
+
180
+ def _process_value(
181
+ self,
182
+ value: ir.Value,
183
+ used_value_names: set[str],
184
+ seen_values: set[ir.Value],
185
+ value_counter: collections.Counter,
186
+ ) -> bool:
187
+ """Process a value only if it hasn't been processed before."""
188
+ if value in seen_values:
189
+ return False
190
+
191
+ modified = False
192
+
193
+ if not value.name:
194
+ modified = self._assign_value_name(value, used_value_names, value_counter)
195
+ else:
196
+ old_name = value.name
197
+ modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
198
+ if modified:
199
+ assert value.graph is not None
200
+ if value.is_initializer():
201
+ value.graph.initializers.pop(old_name)
202
+ # Add the initializer back with the new name
203
+ value.graph.initializers.add(value)
204
+
205
+ # Record the final name for this value
206
+ assert value.name is not None
207
+ seen_values.add(value)
208
+ return modified
209
+
210
+ def _assign_value_name(
211
+ self, value: ir.Value, used_names: set[str], counter: collections.Counter
212
+ ) -> bool:
213
+ """Assign a name to an unnamed value. Returns True if modified."""
214
+ assert not value.name, (
215
+ "value should not have a name already if function is called correctly"
216
+ )
217
+
218
+ preferred_name = self._name_generator.generate_value_name(value)
219
+ value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
220
+ logger.debug("Assigned name %s to unnamed value", value.name)
221
+ return True
222
+
223
+ def _assign_node_name(
224
+ self, node: ir.Node, used_names: set[str], counter: collections.Counter
225
+ ) -> bool:
226
+ """Assign a name to an unnamed node. Returns True if modified."""
227
+ assert not node.name, (
228
+ "node should not have a name already if function is called correctly"
229
+ )
230
+
231
+ preferred_name = self._name_generator.generate_node_name(node)
232
+ node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
233
+ logger.debug("Assigned name %s to unnamed node", node.name)
234
+ return True
235
+
236
+ def _fix_duplicate_value_name(
237
+ self, value: ir.Value, used_names: set[str], counter: collections.Counter
238
+ ) -> bool:
239
+ """Fix a value's name if it conflicts with existing names. Returns True if modified."""
240
+ original_name = value.name
241
+
242
+ assert original_name, (
243
+ "value should have a name already if function is called correctly"
244
+ )
245
+
246
+ if original_name not in used_names:
247
+ # Name is unique, just record it
248
+ used_names.add(original_name)
249
+ return False
250
+
251
+ # If name is already used, make it unique
252
+ base_name = self._name_generator.generate_value_name(value)
253
+ value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
254
+ logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
255
+ return True
256
+
257
+ def _fix_duplicate_node_name(
258
+ self, node: ir.Node, used_names: set[str], counter: collections.Counter
259
+ ) -> bool:
260
+ """Fix a node's name if it conflicts with existing names. Returns True if modified."""
261
+ original_name = node.name
262
+
263
+ assert original_name, "node should have a name already if function is called correctly"
264
+
265
+ if original_name not in used_names:
266
+ # Name is unique, just record it
267
+ used_names.add(original_name)
268
+ return False
269
+
270
+ # If name is already used, make it unique
271
+ base_name = self._name_generator.generate_node_name(node)
272
+ node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
273
+ logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
274
+ return True
275
+
276
+
277
+ def _find_and_record_next_unique_name(
278
+ preferred_name: str, used_names: set[str], counter: collections.Counter
279
+ ) -> str:
280
+ """Generate a unique name based on the preferred name and current counter."""
281
+ new_name = preferred_name
282
+ while new_name in used_names:
283
+ counter[preferred_name] += 1
284
+ new_name = f"{preferred_name}_{counter[preferred_name]}"
285
+ used_names.add(new_name)
286
+ return new_name
onnx_ir/serde.py CHANGED
@@ -682,8 +682,8 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
682
682
  Returns:
683
683
  IR Graph.
684
684
 
685
- .. versionadded:: 0.3
686
- Support for *quantization_annotation* is added.
685
+ .. versionadded:: 0.1.3
686
+ Support for `quantization_annotation` is added.
687
687
  """
688
688
  return _deserialize_graph(proto, [])
689
689
 
@@ -760,6 +760,18 @@ def _deserialize_graph(
760
760
  # Build the value info dictionary to allow for quick lookup for this graph scope
761
761
  value_info = {info.name: info for info in proto.value_info}
762
762
 
763
+ # Declare values for all node outputs from this graph scope. This is necessary
764
+ # to handle the case where a node in a subgraph uses a value that is declared out
765
+ # of order in the outer graph. Declaring the values first allows us to find the
766
+ # values later when deserializing the nodes in subgraphs.
767
+ for node in proto.node:
768
+ _declare_node_outputs(
769
+ node,
770
+ values,
771
+ value_info=value_info,
772
+ quantization_annotations=quantization_annotations,
773
+ )
774
+
763
775
  # Deserialize nodes with all known values
764
776
  nodes = [
765
777
  _deserialize_node(node, scoped_values, value_info, quantization_annotations)
@@ -798,6 +810,55 @@ def _deserialize_graph(
798
810
  )
799
811
 
800
812
 
813
+ def _declare_node_outputs(
814
+ proto: onnx.NodeProto,
815
+ current_value_scope: dict[str, _core.Value],
816
+ value_info: dict[str, onnx.ValueInfoProto],
817
+ quantization_annotations: dict[str, onnx.TensorAnnotation],
818
+ ) -> None:
819
+ """Declare outputs for a node in the current graph scope.
820
+
821
+ This is necessary to handle the case where a node in a subgraph uses a value that is declared
822
+ out of order in the outer graph. Declaring the values first allows us to find the values later
823
+ when deserializing the nodes in subgraphs.
824
+
825
+ Args:
826
+ proto: The ONNX NodeProto to declare outputs for.
827
+ current_value_scope: The current scope of values, mapping value names to their corresponding Value objects.
828
+ value_info: A dictionary mapping value names to their corresponding ValueInfoProto.
829
+ quantization_annotations: A dictionary mapping tensor names to their corresponding TensorAnnotation.
830
+
831
+ Raises:
832
+ ValueError: If an output name is redeclared in the current graph scope.
833
+ """
834
+ for output_name in proto.output:
835
+ if output_name == "":
836
+ continue
837
+ if output_name in current_value_scope:
838
+ raise ValueError(
839
+ f"Output '{output_name}' is redeclared in the current graph scope. "
840
+ f"Original declaration {current_value_scope[output_name]}. "
841
+ f"New declaration: by operator '{proto.op_type}' of node '{proto.name}'. "
842
+ "The model is invalid"
843
+ )
844
+
845
+ # Create the value and add it to the current scope.
846
+ value = _core.Value(name=output_name)
847
+ current_value_scope[output_name] = value
848
+ # Fill in shape/type information if they exist
849
+ if output_name in value_info:
850
+ deserialize_value_info_proto(value_info[output_name], value)
851
+ else:
852
+ logger.debug(
853
+ "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
854
+ output_name,
855
+ proto.name,
856
+ proto.op_type,
857
+ )
858
+ if output_name in quantization_annotations:
859
+ _deserialize_quantization_annotation(quantization_annotations[output_name], value)
860
+
861
+
801
862
  @_capture_errors(lambda proto: proto.name)
802
863
  def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
803
864
  """Deserialize an ONNX FunctionProto into an IR Function.
@@ -812,7 +873,14 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
812
873
  values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
813
874
  value_info = {info.name: info for info in getattr(proto, "value_info", [])}
814
875
 
815
- # TODO(justinchuby): Handle unsorted nodes
876
+ for node in proto.node:
877
+ _declare_node_outputs(
878
+ node,
879
+ values,
880
+ value_info=value_info,
881
+ quantization_annotations={},
882
+ )
883
+
816
884
  nodes = [
817
885
  _deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
818
886
  for node in proto.node
@@ -1137,8 +1205,15 @@ def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
1137
1205
  Returns:
1138
1206
  An IR Node object representing the ONNX node.
1139
1207
  """
1208
+ value_scope: dict[str, _core.Value] = {}
1209
+ _declare_node_outputs(
1210
+ proto,
1211
+ value_scope,
1212
+ value_info={},
1213
+ quantization_annotations={},
1214
+ )
1140
1215
  return _deserialize_node(
1141
- proto, scoped_values=[{}], value_info={}, quantization_annotations={}
1216
+ proto, scoped_values=[value_scope], value_info={}, quantization_annotations={}
1142
1217
  )
1143
1218
 
1144
1219
 
@@ -1161,18 +1236,18 @@ def _deserialize_node(
1161
1236
  for values in reversed(scoped_values):
1162
1237
  if input_name not in values:
1163
1238
  continue
1239
+
1164
1240
  node_inputs.append(values[input_name])
1165
1241
  found = True
1166
1242
  del values # Remove the reference so it is not used by mistake
1167
1243
  break
1168
1244
  if not found:
1169
- # If the input is not found, we know the graph may be unsorted and
1170
- # the input may be a supposed-to-be initializer or an output of a node that comes later.
1171
- # Here we create the value with the name and add it to the current scope.
1172
- # Nodes need to check the value pool for potentially initialized outputs
1245
+ # If the input is not found, we know the graph is invalid because the value
1246
+ # is not declared. We will still create a new input for the node so that
1247
+ # it can be fixed later.
1173
1248
  logger.warning(
1174
- "Input '%s' of node '%s(%s::%s:%s)' not found in any scope. "
1175
- "The graph may be unsorted. Creating a new input (current depth: %s) .",
1249
+ "Input '%s' of node '%s' (%s::%s:%s) cannot be found in any scope. "
1250
+ "The model is invalid but we will still create a new input for the node (current depth: %s)",
1176
1251
  input_name,
1177
1252
  proto.name,
1178
1253
  proto.domain,
@@ -1208,35 +1283,22 @@ def _deserialize_node(
1208
1283
  node_outputs.append(_core.Value(name=""))
1209
1284
  continue
1210
1285
 
1211
- # 1. When the graph is unsorted, we may be able to find the output already created
1286
+ # The outputs should already be declared in the current scope by _declare_node_outputs.
1287
+ #
1288
+ # When the graph is unsorted, we may be able to find the output already created
1212
1289
  # as an input to some other nodes in the current scope.
1213
1290
  # Note that a value is always owned by the producing node. Even though a value
1214
1291
  # can be created when parsing inputs of other nodes, the new node created here
1215
1292
  # that produces the value will assume ownership. It is then impossible to transfer
1216
1293
  # the ownership to any other node.
1217
-
1294
+ #
1218
1295
  # The output can only be found in the current scope. It is impossible for
1219
1296
  # a node to produce an output that is not in its own scope.
1220
1297
  current_scope = scoped_values[-1]
1221
- if output_name in current_scope:
1222
- value = current_scope[output_name]
1223
- else:
1224
- # 2. Common scenario: the graph is sorted and this is the first time we see the output.
1225
- # Create the value and add it to the current scope.
1226
- value = _core.Value(name=output_name)
1227
- current_scope[output_name] = value
1228
- # Fill in shape/type information if they exist
1229
- if output_name in value_info:
1230
- deserialize_value_info_proto(value_info[output_name], value)
1231
- else:
1232
- logger.debug(
1233
- "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
1234
- output_name,
1235
- proto.name,
1236
- proto.op_type,
1237
- )
1238
- if output_name in quantization_annotations:
1239
- _deserialize_quantization_annotation(quantization_annotations[output_name], value)
1298
+ assert output_name in current_scope, (
1299
+ f"Output '{output_name}' not found in the current scope. This is unexpected"
1300
+ )
1301
+ value = current_scope[output_name]
1240
1302
  node_outputs.append(value)
1241
1303
  return _core.Node(
1242
1304
  proto.domain,
@@ -1469,8 +1531,6 @@ def serialize_graph_into(
1469
1531
  serialize_value_into(graph_proto.input.add(), input_)
1470
1532
  if input_.name not in from_.initializers:
1471
1533
  # Annotations for initializers will be added below to avoid double adding
1472
- # TODO(justinchuby): We should add a method is_initializer() on Value when
1473
- # the initializer list is tracked
1474
1534
  _maybe_add_quantization_annotation(graph_proto, input_)
1475
1535
  input_names = {input_.name for input_ in from_.inputs}
1476
1536
  # TODO(justinchuby): Support sparse_initializer
@@ -1818,7 +1878,7 @@ def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.
1818
1878
  return value_info_proto
1819
1879
 
1820
1880
 
1821
- @_capture_errors(lambda value_info_proto, from_: repr(from_))
1881
+ @_capture_errors(lambda value_info_proto, from_, name="": repr(from_))
1822
1882
  def serialize_value_into(
1823
1883
  value_info_proto: onnx.ValueInfoProto,
1824
1884
  from_: _protocols.ValueProtocol,
@@ -68,7 +68,6 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
68
68
  torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
69
69
  torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
70
70
  torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
71
- torch.float8_e8m0fnu: ir.DataType.FLOAT8E8M0,
72
71
  torch.int16: ir.DataType.INT16,
73
72
  torch.int32: ir.DataType.INT32,
74
73
  torch.int64: ir.DataType.INT64,
@@ -78,6 +77,10 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
78
77
  torch.uint32: ir.DataType.UINT32,
79
78
  torch.uint64: ir.DataType.UINT64,
80
79
  }
80
+ if hasattr(torch, "float8_e8m0fnu"):
81
+ # torch.float8_e8m0fnu is available in PyTorch 2.7+
82
+ _TORCH_DTYPE_TO_ONNX[torch.float8_e8m0fnu] = ir.DataType.FLOAT8E8M0
83
+
81
84
  if dtype not in _TORCH_DTYPE_TO_ONNX:
82
85
  raise TypeError(
83
86
  f"Unsupported PyTorch dtype '{dtype}'. "
@@ -105,7 +108,6 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
105
108
  ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
106
109
  ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
107
110
  ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
108
- ir.DataType.FLOAT8E8M0: torch.float8_e8m0fnu,
109
111
  ir.DataType.INT16: torch.int16,
110
112
  ir.DataType.INT32: torch.int32,
111
113
  ir.DataType.INT64: torch.int64,
@@ -115,7 +117,17 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
115
117
  ir.DataType.UINT32: torch.uint32,
116
118
  ir.DataType.UINT64: torch.uint64,
117
119
  }
120
+
121
+ if hasattr(torch, "float8_e8m0fnu"):
122
+ # torch.float8_e8m0fnu is available in PyTorch 2.7+
123
+ _ONNX_DTYPE_TO_TORCH[ir.DataType.FLOAT8E8M0] = torch.float8_e8m0fnu
124
+
118
125
  if dtype not in _ONNX_DTYPE_TO_TORCH:
126
+ if dtype == ir.DataType.FLOAT8E8M0:
127
+ raise ValueError(
128
+ "The requested DataType 'FLOAT8E8M0' is only supported in PyTorch 2.7+. "
129
+ "Please upgrade your PyTorch version to use this dtype."
130
+ )
119
131
  raise TypeError(
120
132
  f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
121
133
  "Please use a supported dtype from the list: "
onnx_ir/traversal.py CHANGED
@@ -25,19 +25,33 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
25
25
  *,
26
26
  recursive: Callable[[_core.Node], bool] | None = None,
27
27
  reverse: bool = False,
28
+ enter_graph: Callable[[GraphLike], None] | None = None,
29
+ exit_graph: Callable[[GraphLike], None] | None = None,
28
30
  ):
29
31
  """Iterate over the nodes in the graph, recursively visiting subgraphs.
30
32
 
33
+ This iterator allows for traversing the nodes of a graph and its subgraphs
34
+ in a depth-first manner. It supports optional callbacks for entering and exiting
35
+ subgraphs, as well as a callback `recursive` to determine whether to visit subgraphs
36
+ contained within nodes.
37
+
38
+ .. versionadded:: 0.1.6
39
+ Added the `enter_graph` and `exit_graph` callbacks.
40
+
31
41
  Args:
32
42
  graph_like: The graph to traverse.
33
43
  recursive: A callback that determines whether to recursively visit the subgraphs
34
44
  contained in a node. If not provided, all nodes in subgraphs are visited.
35
45
  reverse: Whether to iterate in reverse order.
46
+ enter_graph: An optional callback that is called when entering a subgraph.
47
+ exit_graph: An optional callback that is called when exiting a subgraph.
36
48
  """
37
49
  self._graph = graph_like
38
50
  self._recursive = recursive
39
51
  self._reverse = reverse
40
52
  self._iterator = self._recursive_node_iter(graph_like)
53
+ self._enter_graph = enter_graph
54
+ self._exit_graph = exit_graph
41
55
 
42
56
  def __iter__(self) -> Self:
43
57
  self._iterator = self._recursive_node_iter(self._graph)
@@ -50,34 +64,55 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
50
64
  self, graph: _core.Graph | _core.Function | _core.GraphView
51
65
  ) -> Iterator[_core.Node]:
52
66
  iterable = reversed(graph) if self._reverse else graph
67
+
68
+ if self._enter_graph is not None:
69
+ self._enter_graph(graph)
70
+
53
71
  for node in iterable: # type: ignore[union-attr]
54
72
  yield node
55
73
  if self._recursive is not None and not self._recursive(node):
56
74
  continue
57
75
  yield from self._iterate_subgraphs(node)
58
76
 
77
+ if self._exit_graph is not None:
78
+ self._exit_graph(graph)
79
+
59
80
  def _iterate_subgraphs(self, node: _core.Node):
60
81
  for attr in node.attributes.values():
61
82
  if not isinstance(attr, _core.Attr):
62
83
  continue
63
84
  if attr.type == _enums.AttributeType.GRAPH:
85
+ if self._enter_graph is not None:
86
+ self._enter_graph(attr.value)
64
87
  yield from RecursiveGraphIterator(
65
88
  attr.value,
66
89
  recursive=self._recursive,
67
90
  reverse=self._reverse,
91
+ enter_graph=self._enter_graph,
92
+ exit_graph=self._exit_graph,
68
93
  )
94
+ if self._exit_graph is not None:
95
+ self._exit_graph(attr.value)
69
96
  elif attr.type == _enums.AttributeType.GRAPHS:
70
97
  graphs = reversed(attr.value) if self._reverse else attr.value
71
98
  for graph in graphs:
99
+ if self._enter_graph is not None:
100
+ self._enter_graph(graph)
72
101
  yield from RecursiveGraphIterator(
73
102
  graph,
74
103
  recursive=self._recursive,
75
104
  reverse=self._reverse,
105
+ enter_graph=self._enter_graph,
106
+ exit_graph=self._exit_graph,
76
107
  )
108
+ if self._exit_graph is not None:
109
+ self._exit_graph(graph)
77
110
 
78
111
  def __reversed__(self) -> Iterator[_core.Node]:
79
112
  return RecursiveGraphIterator(
80
113
  self._graph,
81
114
  recursive=self._recursive,
82
115
  reverse=not self._reverse,
116
+ enter_graph=self._enter_graph,
117
+ exit_graph=self._exit_graph,
83
118
  )
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
- License: Apache License v2.0
6
+ License-Expression: Apache-2.0
7
7
  Project-URL: Homepage, https://onnx.ai/ir-py
8
8
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
9
9
  Project-URL: Repository, https://github.com/onnx/ir-py
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.10
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3.12
15
15
  Classifier: Programming Language :: Python :: 3.13
16
- Classifier: License :: OSI Approved :: Apache Software License
17
16
  Requires-Python: >=3.9
18
17
  Description-Content-Type: text/markdown
19
18
  License-File: LICENSE
@@ -29,7 +28,6 @@ Dynamic: license-file
29
28
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
30
29
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
31
30
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
32
- [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
33
31
  [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
34
32
 
35
33
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
@@ -1,5 +1,5 @@
1
- onnx_ir/__init__.py,sha256=_995K-JXuL0upLulUJxCXziF1gMcehH3gzea2eukCyM,3424
2
- onnx_ir/_core.py,sha256=CtRwtDb__hK0MJLWsrNNu5n_xz6TlbJctDLw8UDQAZQ,137454
1
+ onnx_ir/__init__.py,sha256=GkXeM2FSKjT0TUO8ezCJdT1yHZKdtQ6keZKx2a3BluI,3424
2
+ onnx_ir/_core.py,sha256=XQRd43VQj72qBGLa_4x9NEjjfhM0rxJ7qT6sLKA_rGA,139032
3
3
  onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
4
4
  onnx_ir/_enums.py,sha256=SxC-GGgPrmdz6UsMhx7xT9-6VmkZ6j1oVzDqNUHr3Rc,9659
5
5
  onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
@@ -16,30 +16,31 @@ onnx_ir/_version_utils.py,sha256=bZThuE7meVHFOY1DLsmss9WshVIp9iig7udGfDbVaK4,133
16
16
  onnx_ir/convenience.py,sha256=0B1epuXZCSmY4FbW2vaYfR-t5ubxBZ1UruiytHs-zFw,917
17
17
  onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,18079
18
18
  onnx_ir/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
19
- onnx_ir/serde.py,sha256=bFQg5XYlDTvZsT_gDO_mPYedkMj_HcUbBvQuxLlRKvc,75980
19
+ onnx_ir/serde.py,sha256=Ld00k4L_TJ50T8FA0myV0C1hLr7EqwujZk6bBr_nGLQ,78174
20
20
  onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
21
- onnx_ir/tensor_adapters.py,sha256=Pl2eLXa1VQh0nZy6NFMBr_9BRY_OPoKQX1oa4K7ecUo,6717
21
+ onnx_ir/tensor_adapters.py,sha256=YffUeZDZi8thxm-4nF2cL6cNSJSVmLm4A3IbEzwY8QQ,7233
22
22
  onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
23
- onnx_ir/traversal.py,sha256=Z69wzYBNljn1S7PhVTYgwMftrfsdEBLoa0JYteOhLL0,2863
24
- onnx_ir/_convenience/__init__.py,sha256=DQ-Bz1wTiZJEARCFxDqZvYexWviGmwvDzE_1hR-vp0Q,19182
23
+ onnx_ir/traversal.py,sha256=Wy4XphwuapAvm94-5iaz6G8LjIoMFpY7qfPfXzYViEE,4488
24
+ onnx_ir/_convenience/__init__.py,sha256=bXUxjZ_91idQJ33zWtByQ0J4VsWCUdvAy9iIflpLtW8,19754
25
25
  onnx_ir/_convenience/_constructors.py,sha256=5GhlYy_xCE2ng7l_4cNx06WQsNDyvS-0U1HgOpPKJEk,8347
26
26
  onnx_ir/_thirdparty/asciichartpy.py,sha256=afQ0fsqko2uYRPAR4TZBrQxvCb4eN8lxZ2yDFbVQq_s,10533
27
27
  onnx_ir/passes/__init__.py,sha256=M_Tcl_-qGSNPluFIvOoeDyh0qAwNayaYyXDS5UJUJPQ,764
28
28
  onnx_ir/passes/_pass_infra.py,sha256=xIOw_zZIuOqD4Z_wZ4OvsqXfh2IZMoMlDp1xQ_MPQlc,9567
29
- onnx_ir/passes/common/__init__.py,sha256=LWkH39XATj1lQz82cVrxtle6YiZZ8RkT1fVZNthiTLI,1586
29
+ onnx_ir/passes/common/__init__.py,sha256=NYBrXvkq_CbWRwcZptSsF4u_-1zfN_BClLhVQY0pwYc,1738
30
30
  onnx_ir/passes/common/_c_api_utils.py,sha256=g6riA6xNGVWaO5YjVHZ0krrfslWHmRlryRkwB8X56cg,2907
31
31
  onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMOGjdvVwG9yHQTkSphUgDlM0ME,2365
32
32
  onnx_ir/passes/common/common_subexpression_elimination.py,sha256=wZ1zEPdCshYB_ifP9fCAVfzQkesE6uhCfzCuL2qO5fA,7948
33
- onnx_ir/passes/common/constant_manipulation.py,sha256=_fGDwn0Axl2Q8APfc2m_mLMH28T-Mc9kIlpzBXoe3q4,8779
34
- onnx_ir/passes/common/identity_elimination.py,sha256=FyqnJxFUq9Ga9XyUJ3myjzr36InYSW-oJgDTrUrBORY,3663
35
- onnx_ir/passes/common/initializer_deduplication.py,sha256=4CIVFYfdXUlmF2sAx560c_pTwYVXtX5hcSwWzUKm5uc,2061
33
+ onnx_ir/passes/common/constant_manipulation.py,sha256=dFzzqbpRecJJrYf6edvR_sdr4F0gV-1wEtDXsQ7fStM,9101
34
+ onnx_ir/passes/common/identity_elimination.py,sha256=wN8g8uPGn6IIQ6Jf1lo6nGTXvpWyiSQtT_CfmtvZpwA,3664
35
+ onnx_ir/passes/common/initializer_deduplication.py,sha256=k6IZdXrjANbVhTQCQAPIePUjqF83NG3YGwEYThYJJ7o,6655
36
36
  onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
37
+ onnx_ir/passes/common/naming.py,sha256=NNKc9IPrmzm3J0zGQILfooayVzfdXDYHY9DHex1hFgs,10927
37
38
  onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6svE0cCyDew,1691
38
39
  onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
39
40
  onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
40
41
  onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
41
- onnx_ir-0.1.5.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
42
- onnx_ir-0.1.5.dist-info/METADATA,sha256=SHH7BxuFCKIsWyRKQyOKbXRtZX8n0ryietlWDPPLBvA,4884
43
- onnx_ir-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
44
- onnx_ir-0.1.5.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
45
- onnx_ir-0.1.5.dist-info/RECORD,,
42
+ onnx_ir-0.1.7.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
43
+ onnx_ir-0.1.7.dist-info/METADATA,sha256=M4-BdpNXpv18P_tALf6KdUdXeCO2JrVxbxtzs4HCmJI,3462
44
+ onnx_ir-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ onnx_ir-0.1.7.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
46
+ onnx_ir-0.1.7.dist-info/RECORD,,