onnx-ir 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

onnx_ir/__init__.py CHANGED
@@ -164,4 +164,4 @@ def __set_module() -> None:
164
164
 
165
165
 
166
166
  __set_module()
167
- __version__ = "0.1.0"
167
+ __version__ = "0.1.1"
onnx_ir/_core.py CHANGED
@@ -22,13 +22,12 @@ import os
22
22
  import sys
23
23
  import textwrap
24
24
  import typing
25
- from collections import OrderedDict
26
25
  from collections.abc import (
27
26
  Collection,
28
27
  Hashable,
29
28
  Iterable,
30
29
  Iterator,
31
- MutableMapping,
30
+ Mapping,
32
31
  MutableSequence,
33
32
  Sequence,
34
33
  )
@@ -1325,7 +1324,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
1325
1324
  domain: str,
1326
1325
  op_type: str,
1327
1326
  inputs: Iterable[Value | None],
1328
- attributes: Iterable[Attr] = (),
1327
+ attributes: Iterable[Attr] | Mapping[str, Attr] = (),
1329
1328
  *,
1330
1329
  overload: str = "",
1331
1330
  num_outputs: int | None = None,
@@ -1371,15 +1370,10 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
1371
1370
  self._inputs: tuple[Value | None, ...] = tuple(inputs)
1372
1371
  # Values belong to their defining nodes. The values list is immutable
1373
1372
  self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
1374
- attributes = tuple(attributes)
1375
- if attributes and not isinstance(attributes[0], Attr):
1376
- raise TypeError(
1377
- f"Expected the attributes to be Attr, got {type(attributes[0])}. "
1378
- "If you are copying the attributes from another node, make sure you call "
1379
- "node.attributes.values() because it is a dictionary."
1380
- )
1381
- self._attributes: OrderedDict[str, Attr] = OrderedDict(
1382
- (attr.name, attr) for attr in attributes
1373
+ if isinstance(attributes, Mapping):
1374
+ attributes = tuple(attributes.values())
1375
+ self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
1376
+ attributes
1383
1377
  )
1384
1378
  self._overload: str = overload
1385
1379
  # TODO(justinchuby): Potentially support a version range
@@ -1637,7 +1631,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
1637
1631
  raise AttributeError("outputs is immutable. Please create a new node instead.")
1638
1632
 
1639
1633
  @property
1640
- def attributes(self) -> OrderedDict[str, Attr]:
1634
+ def attributes(self) -> _graph_containers.Attributes:
1641
1635
  """The attributes of the node."""
1642
1636
  return self._attributes
1643
1637
 
@@ -2201,17 +2195,9 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2201
2195
  # Private fields that are not to be accessed by any other classes
2202
2196
  self._inputs = _graph_containers.GraphInputs(self, inputs)
2203
2197
  self._outputs = _graph_containers.GraphOutputs(self, outputs)
2204
- self._initializers = _graph_containers.GraphInitializers(self)
2205
- for initializer in initializers:
2206
- if isinstance(initializer, str):
2207
- raise TypeError(
2208
- "Initializer must be a Value, not a string. "
2209
- "If you are copying the initializers from another graph, "
2210
- "make sure you call graph.initializers.values() because it is a dictionary."
2211
- )
2212
- if initializer.name is None:
2213
- raise ValueError(f"Initializer must have a name: {initializer}")
2214
- self._initializers[initializer.name] = initializer
2198
+ self._initializers = _graph_containers.GraphInitializers(
2199
+ self, {initializer.name: initializer for initializer in initializers}
2200
+ )
2215
2201
  self._doc_string = doc_string
2216
2202
  self._opset_imports = opset_imports or {}
2217
2203
  self._metadata: _metadata.MetadataStore | None = None
@@ -2234,7 +2220,19 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2234
2220
  return self._outputs
2235
2221
 
2236
2222
  @property
2237
- def initializers(self) -> MutableMapping[str, Value]:
2223
+ def initializers(self) -> _graph_containers.GraphInitializers:
2224
+ """The initializers of the graph as a ``MutableMapping[str, Value]``.
2225
+
2226
+ The keys are the names of the initializers. The values are the :class:`Value` objects.
2227
+
2228
+ This property additionally supports the ``add`` method, which takes a :class:`Value`
2229
+ and adds it to the initializers if it is not already present.
2230
+
2231
+ .. note::
2232
+ When setting an initializer with ``graph.initializers[key] = value``,
2233
+ if the value does not have a name, it will be assigned ``key`` as its name.
2234
+
2235
+ """
2238
2236
  return self._initializers
2239
2237
 
2240
2238
  def register_initializer(self, value: Value) -> None:
@@ -2263,15 +2261,11 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2263
2261
  " it is not the same object: existing={self._initializers[value.name]!r},"
2264
2262
  f" new={value!r}"
2265
2263
  )
2266
- if value.producer() is not None:
2267
- raise ValueError(
2268
- f"Value '{value!r}' is produced by a node and cannot be an initializer."
2269
- )
2270
2264
  if value.const_value is None:
2271
2265
  raise ValueError(
2272
2266
  f"Value '{value!r}' must have its const_value set to be an initializer."
2273
2267
  )
2274
- self._initializers[value.name] = value
2268
+ self._initializers.add(value)
2275
2269
 
2276
2270
  @property
2277
2271
  def doc_string(self) -> str | None:
@@ -2701,7 +2695,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
2701
2695
  outputs: Sequence[Value],
2702
2696
  *,
2703
2697
  nodes: Iterable[Node],
2704
- initializers: Sequence[_protocols.ValueProtocol] = (),
2698
+ initializers: Sequence[Value] = (),
2705
2699
  doc_string: str | None = None,
2706
2700
  opset_imports: dict[str, int] | None = None,
2707
2701
  name: str | None = None,
@@ -2710,10 +2704,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
2710
2704
  self.name = name
2711
2705
  self.inputs = tuple(inputs)
2712
2706
  self.outputs = tuple(outputs)
2713
- for initializer in initializers:
2714
- if initializer.name is None:
2715
- raise ValueError(f"Initializer must have a name: {initializer}")
2716
- self.initializers = {tensor.name: tensor for tensor in initializers}
2707
+ self.initializers = {initializer.name: initializer for initializer in initializers}
2717
2708
  self.doc_string = doc_string
2718
2709
  self.opset_imports = opset_imports or {}
2719
2710
  self._metadata: _metadata.MetadataStore | None = None
@@ -2927,13 +2918,15 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2927
2918
  # Ensure the inputs and outputs of the function belong to a graph
2928
2919
  # and not from an outer scope
2929
2920
  graph: Graph,
2930
- attributes: Sequence[Attr],
2921
+ attributes: Iterable[Attr] | Mapping[str, Attr],
2931
2922
  ) -> None:
2932
2923
  self._domain = domain
2933
2924
  self._name = name
2934
2925
  self._overload = overload
2935
2926
  self._graph = graph
2936
- self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
2927
+ if isinstance(attributes, Mapping):
2928
+ attributes = tuple(attributes.values())
2929
+ self._attributes = _graph_containers.Attributes(attributes)
2937
2930
 
2938
2931
  def identifier(self) -> _protocols.OperatorIdentifier:
2939
2932
  return self.domain, self.name, self.overload
@@ -2971,7 +2964,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2971
2964
  return self._graph.outputs
2972
2965
 
2973
2966
  @property
2974
- def attributes(self) -> OrderedDict[str, Attr]:
2967
+ def attributes(self) -> _graph_containers.Attributes:
2975
2968
  return self._attributes
2976
2969
 
2977
2970
  @typing.overload
@@ -12,13 +12,16 @@ __all__ = [
12
12
  ]
13
13
 
14
14
  import collections
15
- from collections.abc import Iterable
16
- from typing import TYPE_CHECKING, SupportsIndex
15
+ import logging
16
+ from collections.abc import Iterable, Sequence
17
+ from typing import SupportsIndex, TypeVar
17
18
 
18
19
  import onnx_ir
20
+ from onnx_ir import _core, _protocols
19
21
 
20
- if TYPE_CHECKING:
21
- from onnx_ir import _core
22
+ T = TypeVar("T")
23
+
24
+ logger = logging.getLogger(__name__)
22
25
 
23
26
 
24
27
  class _GraphIO(collections.UserList["_core.Value"]):
@@ -152,6 +155,10 @@ class GraphInputs(_GraphIO):
152
155
  raise ValueError(
153
156
  f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
154
157
  )
158
+ if value.producer() is not None:
159
+ raise ValueError(
160
+ f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
161
+ )
155
162
  self._ref_counter[value] += 1
156
163
  value._is_graph_input = True
157
164
  value._graph = self._graph
@@ -244,12 +251,23 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
244
251
 
245
252
  def __setitem__(self, key: str, value: _core.Value) -> None:
246
253
  """Set an initializer for the graph."""
247
- if key != value.name:
254
+ if not isinstance(value, _core.Value):
255
+ raise TypeError(f"value must be a Value object, not {type(value)}")
256
+ if not isinstance(key, str):
257
+ raise TypeError(f"Value name must be a string, not {type(key)}")
258
+ if key == "":
259
+ raise ValueError("Value name cannot be an empty string")
260
+ if not value.name:
261
+ logger.info("Value %s does not have a name, setting it to '%s'", value, key)
262
+ value.name = key
263
+ elif key != value.name:
248
264
  raise ValueError(
249
- f"Key '{key}' does not match the name of the value '{value.name}'"
265
+ f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
266
+ )
267
+ if value.producer() is not None:
268
+ raise ValueError(
269
+ f"Value '{value}' is produced by a node and cannot be a graph initializer"
250
270
  )
251
- if not isinstance(key, str):
252
- raise TypeError(f"Key must be a string, not {type(key)}")
253
271
  if key in self.data:
254
272
  # If the key already exists, unset the old value
255
273
  old_value = self.data[key]
@@ -266,3 +284,90 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
266
284
  # the dictionary is not modified
267
285
  self._maybe_unset_graph(value)
268
286
  super().__delitem__(key)
287
+
288
+ def add(self, value: _core.Value) -> None:
289
+ """Add an initializer to the graph."""
290
+ self[value.name] = value # type: ignore[index]
291
+
292
+
293
+ class Attributes(collections.UserDict[str, "_core.Attr"]):
294
+ """The attributes of a Node."""
295
+
296
+ def __init__(self, attrs: Iterable[_core.Attr]):
297
+ super().__init__({attr.name: attr for attr in attrs})
298
+
299
+ def __setitem__(self, key: str, value: _core.Attr) -> None:
300
+ """Set an attribute for the node."""
301
+ if type(key) is not str:
302
+ raise TypeError(f"Key must be a string, not {type(key)}")
303
+ if not isinstance(value, _core.Attr):
304
+ raise TypeError(f"Value must be an Attr, not {type(value)}")
305
+ super().__setitem__(key, value)
306
+
307
+ def add(self, value: _core.Attr) -> None:
308
+ """Add an attribute to the node."""
309
+ self[value.name] = value
310
+
311
+ def get_int(self, key: str, default: T = None) -> int | T: # type: ignore[assignment]
312
+ """Get the integer value of the attribute."""
313
+ if key in self:
314
+ return self[key].as_int()
315
+ return default
316
+
317
+ def get_float(self, key: str, default: T = None) -> float | T: # type: ignore[assignment]
318
+ """Get the float value of the attribute."""
319
+ if key in self:
320
+ return self[key].as_float()
321
+ return default
322
+
323
+ def get_string(self, key: str, default: T = None) -> str | T: # type: ignore[assignment]
324
+ """Get the string value of the attribute."""
325
+ if key in self:
326
+ return self[key].as_string()
327
+ return default
328
+
329
+ def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T: # type: ignore[assignment]
330
+ """Get the tensor value of the attribute."""
331
+ if key in self:
332
+ return self[key].as_tensor()
333
+ return default
334
+
335
+ def get_graph(self, key: str, default: T = None) -> _core.Graph | T: # type: ignore[assignment]
336
+ """Get the graph value of the attribute."""
337
+ if key in self:
338
+ return self[key].as_graph()
339
+ return default
340
+
341
+ def get_ints(self, key: str, default: T = None) -> Sequence[int] | T: # type: ignore[assignment]
342
+ """Get the Sequence of integers from the attribute."""
343
+ if key in self:
344
+ return self[key].as_ints()
345
+ return default
346
+
347
+ def get_floats(self, key: str, default: T = None) -> Sequence[float] | T: # type: ignore[assignment]
348
+ """Get the Sequence of floats from the attribute."""
349
+ if key in self:
350
+ return self[key].as_floats()
351
+ return default
352
+
353
+ def get_strings(self, key: str, default: T = None) -> Sequence[str] | T: # type: ignore[assignment]
354
+ """Get the Sequence of strings from the attribute."""
355
+ if key in self:
356
+ return self[key].as_strings()
357
+ return default
358
+
359
+ def get_tensors(
360
+ self,
361
+ key: str,
362
+ default: T = None, # type: ignore[assignment]
363
+ ) -> Sequence[_protocols.TensorProtocol] | T:
364
+ """Get the Sequence of tensors from the attribute."""
365
+ if key in self:
366
+ return self[key].as_tensors()
367
+ return default
368
+
369
+ def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T: # type: ignore[assignment]
370
+ """Get the Sequence of graphs from the attribute."""
371
+ if key in self:
372
+ return self[key].as_graphs()
373
+ return default
@@ -5,6 +5,7 @@ __all__ = [
5
5
  "AddInitializersToInputsPass",
6
6
  "CheckerPass",
7
7
  "ClearMetadataAndDocStringPass",
8
+ "CommonSubexpressionEliminationPass",
8
9
  "InlinePass",
9
10
  "LiftConstantsToInitializersPass",
10
11
  "LiftSubgraphInitializersToMainGraphPass",
@@ -19,6 +20,9 @@ __all__ = [
19
20
  from onnx_ir.passes.common.clear_metadata_and_docstring import (
20
21
  ClearMetadataAndDocStringPass,
21
22
  )
23
+ from onnx_ir.passes.common.common_subexpression_elimination import (
24
+ CommonSubexpressionEliminationPass,
25
+ )
22
26
  from onnx_ir.passes.common.constant_manipulation import (
23
27
  AddInitializersToInputsPass,
24
28
  LiftConstantsToInitializersPass,
@@ -0,0 +1,177 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Eliminate common subexpression in ONNX graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "CommonSubexpressionEliminationPass",
9
+ ]
10
+
11
+ import logging
12
+ from collections.abc import Sequence
13
+
14
+ import onnx_ir as ir
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
20
+ """Eliminate common subexpression in ONNX graphs."""
21
+
22
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
23
+ """Return the same ir.Model but with CSE applied to the graph."""
24
+ modified = False
25
+ graph = model.graph
26
+
27
+ modified = _eliminate_common_subexpression(graph, modified)
28
+
29
+ return ir.passes.PassResult(
30
+ model,
31
+ modified=modified,
32
+ )
33
+
34
+
35
+ def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
36
+ """Eliminate common subexpression in ONNX graphs."""
37
+ # node to node identifier, length of outputs, inputs, and attributes
38
+ existing_node_info_to_the_node: dict[
39
+ tuple[
40
+ ir.OperatorIdentifier,
41
+ int, # len(outputs)
42
+ tuple[int, ...], # input ids
43
+ tuple[tuple[str, object], ...], # attributes
44
+ ],
45
+ ir.Node,
46
+ ] = {}
47
+
48
+ for node in graph:
49
+ # Skip control flow ops like Loop and If.
50
+ control_flow_op: bool = False
51
+ # Use equality to check if the node is a common subexpression.
52
+ attributes = {}
53
+ for k, v in node.attributes.items():
54
+ # TODO(exporter team): CSE subgraphs.
55
+ # NOTE: control flow ops like Loop and If won't be CSEd
56
+ # because attribute: graph won't match.
57
+ if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
58
+ control_flow_op = True
59
+ logger.debug("Skipping control flow op %s", node)
60
+ # The attribute value could be directly taken from the original
61
+ # protobuf, so we need to make a copy of it.
62
+ value = v.value
63
+ if v.type in (
64
+ ir.AttributeType.INTS,
65
+ ir.AttributeType.FLOATS,
66
+ ir.AttributeType.STRINGS,
67
+ ):
68
+ # For INT, FLOAT and STRING attributes, we convert them to tuples
69
+ # to ensure they are hashable.
70
+ value = tuple(value)
71
+ attributes[k] = value
72
+
73
+ if control_flow_op:
74
+ # If the node is a control flow op, we skip it.
75
+ logger.debug("Skipping control flow op %s", node)
76
+ continue
77
+
78
+ if _is_non_deterministic_op(node):
79
+ # If the node is a non-deterministic op, we skip it.
80
+ logger.debug("Skipping non-deterministic op %s", node)
81
+ continue
82
+
83
+ node_info = (
84
+ node.op_identifier(),
85
+ len(node.outputs),
86
+ tuple(id(input) for input in node.inputs),
87
+ tuple(sorted(attributes.items())),
88
+ )
89
+ # Check if the node is a common subexpression.
90
+ if node_info in existing_node_info_to_the_node:
91
+ # If it is, this node has an existing node with the same
92
+ # operator, number of outputs, inputs, and attributes.
93
+ # We replace the node with the existing node.
94
+ modified = True
95
+ existing_node = existing_node_info_to_the_node[node_info]
96
+ _remove_node_and_replace_values(
97
+ graph,
98
+ remove_node=node,
99
+ remove_values=node.outputs,
100
+ new_values=existing_node.outputs,
101
+ )
102
+ logger.debug("Reusing node %s", existing_node)
103
+ else:
104
+ # If it is not, add to the mapping.
105
+ existing_node_info_to_the_node[node_info] = node
106
+ return modified
107
+
108
+
109
+ def _remove_node_and_replace_values(
110
+ graph: ir.Graph,
111
+ /,
112
+ remove_node: ir.Node,
113
+ remove_values: Sequence[ir.Value],
114
+ new_values: Sequence[ir.Value],
115
+ ) -> None:
116
+ """Replaces nodes and values in the graph or function.
117
+
118
+ Args:
119
+ graph: The graph to replace nodes and values in.
120
+ remove_node: The node to remove.
121
+ remove_values: The values to replace.
122
+ new_values: The values to replace with.
123
+ """
124
+ # Reconnect the users of the deleted values to use the new values
125
+ ir.convenience.replace_all_uses_with(remove_values, new_values)
126
+ # Update graph/function outputs if the node generates output
127
+ if any(remove_value.is_graph_output() for remove_value in remove_values):
128
+ replacement_mapping = dict(zip(remove_values, new_values))
129
+ for idx, graph_output in enumerate(graph.outputs):
130
+ if graph_output in replacement_mapping:
131
+ new_value = replacement_mapping[graph_output]
132
+ if new_value.is_graph_output() or new_value.is_graph_input():
133
+ # If the new value is also a graph input/output, we need to
134
+ # create a Identity node to preserve the remove_value and
135
+ # prevent from changing new_value name.
136
+ identity_node = ir.node(
137
+ "Identity",
138
+ inputs=[new_value],
139
+ outputs=[
140
+ ir.Value(
141
+ name=graph_output.name,
142
+ type=graph_output.type,
143
+ shape=graph_output.shape,
144
+ )
145
+ ],
146
+ )
147
+ # reuse the name of the graph output
148
+ graph.outputs[idx] = identity_node.outputs[0]
149
+ graph.insert_before(
150
+ remove_node,
151
+ identity_node,
152
+ )
153
+ else:
154
+ # if new_value is not graph output, we just
155
+ # update it to use old_value name.
156
+ new_value.name = graph_output.name
157
+ graph.outputs[idx] = new_value
158
+
159
+ graph.remove(remove_node, safe=True)
160
+
161
+
162
+ def _is_non_deterministic_op(node: ir.Node) -> bool:
163
+ non_deterministic_ops = frozenset(
164
+ {
165
+ "RandomUniform",
166
+ "RandomNormal",
167
+ "RandomUniformLike",
168
+ "RandomNormalLike",
169
+ "Multinomial",
170
+ }
171
+ )
172
+ return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
173
+
174
+
175
+ def _is_onnx_domain(d: str) -> bool:
176
+ """Check if the domain is the ONNX domain."""
177
+ return d == ""
@@ -139,35 +139,11 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
139
139
  for further processing or optimization.
140
140
 
141
141
  Initializers that are also graph inputs will not be lifted.
142
-
143
- Preconditions:
144
- - All initializers in the model must have unique names across the main graph and subgraphs.
145
142
  """
146
143
 
147
- def requires(self, model: ir.Model) -> None:
148
- """Ensure all initializer names are unique."""
149
- registered_initializer_names: set[str] = set()
150
- duplicated_initializers: list[ir.Value] = []
151
- for graph in model.graphs():
152
- for initializer in graph.initializers.values():
153
- if initializer.name is None:
154
- raise ir.passes.PreconditionError(
155
- f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}"
156
- )
157
- if initializer.name in registered_initializer_names:
158
- duplicated_initializers.append(initializer)
159
- else:
160
- registered_initializer_names.add(initializer.name)
161
- if duplicated_initializers:
162
- raise ir.passes.PreconditionError(
163
- "Found duplicated initializers in the model. "
164
- "Initializer name must be unique across the main graph and subgraphs. "
165
- "Please ensure all initializers have unique names. Duplicated: "
166
- f"{duplicated_initializers!r}"
167
- )
168
-
169
144
  def call(self, model: ir.Model) -> ir.passes.PassResult:
170
145
  count = 0
146
+ registered_initializer_names: dict[str, int] = {}
171
147
  for graph in model.graphs():
172
148
  if graph is model.graph:
173
149
  continue
@@ -182,6 +158,15 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
182
158
  continue
183
159
  # Remove the initializer from the subgraph
184
160
  graph.initializers.pop(name)
161
+ # To avoid name conflicts, we need to rename the initializer
162
+ # to a unique name in the main graph
163
+ if name in registered_initializer_names:
164
+ name_count = registered_initializer_names[name]
165
+ initializer.name = f"{name}_{name_count}"
166
+ registered_initializer_names[name] = name_count + 1
167
+ else:
168
+ assert initializer.name is not None
169
+ registered_initializer_names[initializer.name] = 1
185
170
  model.graph.register_initializer(initializer)
186
171
  count += 1
187
172
  logger.debug(
@@ -9,7 +9,7 @@ import dataclasses
9
9
  __all__ = ["InlinePass", "InlinePassResult"]
10
10
 
11
11
  from collections import defaultdict
12
- from collections.abc import Iterable, Sequence
12
+ from collections.abc import Iterable, Mapping, Sequence
13
13
 
14
14
  import onnx_ir as ir
15
15
  import onnx_ir.convenience as _ir_convenience
@@ -52,7 +52,7 @@ class _CopyReplace:
52
52
  def __init__(
53
53
  self,
54
54
  inliner: InlinePass,
55
- attr_map: dict[str, ir.Attr],
55
+ attr_map: Mapping[str, ir.Attr],
56
56
  value_map: dict[ir.Value, ir.Value | None],
57
57
  metadata_props: dict[str, str],
58
58
  call_stack: CallStack,
@@ -96,6 +96,7 @@ class _CopyReplace:
96
96
  return attr
97
97
  assert attr.is_ref()
98
98
  ref_attr_name = attr.ref_attr_name
99
+ assert ref_attr_name is not None, "Reference attribute must have a name"
99
100
  if ref_attr_name in self._attr_map:
100
101
  ref_attr = self._attr_map[ref_attr_name]
101
102
  if not ref_attr.is_ref():
@@ -237,7 +238,7 @@ class InlinePass(ir.passes.InPlacePass):
237
238
  )
238
239
 
239
240
  # Identify substitutions for both inputs and attributes of the function:
240
- attributes: dict[str, ir.Attr] = node.attributes
241
+ attributes: Mapping[str, ir.Attr] = node.attributes
241
242
  default_attr_values = {
242
243
  attr.name: attr
243
244
  for attr in function.attributes.values()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -1,9 +1,9 @@
1
- onnx_ir/__init__.py,sha256=igWUy_HVjkfd34fuqT-33X-iEu5QONDkj9Fq5Pa_wJg,3352
2
- onnx_ir/_core.py,sha256=sw3UcXLFQoFE0kFeEP3Rvn0i3rcdk6W-iOkyAaC5Lyg,128801
1
+ onnx_ir/__init__.py,sha256=0fD02tkU7-bC9BfPS68TP2500619oJ8NZyGx3CdGmVk,3352
2
+ onnx_ir/_core.py,sha256=7nufz-9r8J3d6R4BzmRKq0DwmWosOZp3ICNr9MfMG0E,128316
3
3
  onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
4
4
  onnx_ir/_enums.py,sha256=zMvRvYyxOg0Rf3DCQ5Sn1TyZ5znj4NuGO-OAOKZCiDM,7880
5
5
  onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
6
- onnx_ir/_graph_containers.py,sha256=3EQLwxhFxc4b1DR4wwrW6F6sC1yr2hhUrZ2io4gmZRE,9884
6
+ onnx_ir/_graph_containers.py,sha256=hK3R3OrQTMXF8_z9Kx1DBtJriq_NQx8MUAFy7GpTZ2U,14154
7
7
  onnx_ir/_io.py,sha256=XmVqvM2lyX7QtXGr0KcV4bboRGTOPJ8BP4YtQ-jh4dg,3886
8
8
  onnx_ir/_linked_list.py,sha256=PXVcbHLMXHLZ6DxZnElnJLWfhBPvYcXUxM8Y3K4J7lM,10592
9
9
  onnx_ir/_metadata.py,sha256=lzmCaYy4kAJrPW-PSGKF4a78LisxF0hiofySNQ3Mwhg,1544
@@ -25,17 +25,18 @@ onnx_ir/_convenience/_constructors.py,sha256=nA0tytizoFhQeN6gpxVx3khJQXq-tRtIh0U
25
25
  onnx_ir/_thirdparty/asciichartpy.py,sha256=afQ0fsqko2uYRPAR4TZBrQxvCb4eN8lxZ2yDFbVQq_s,10533
26
26
  onnx_ir/passes/__init__.py,sha256=M_Tcl_-qGSNPluFIvOoeDyh0qAwNayaYyXDS5UJUJPQ,764
27
27
  onnx_ir/passes/_pass_infra.py,sha256=HEzxDbXjIUPVubv4pxsPTFXiCDPoiM_tPEoEH1mHO70,9560
28
- onnx_ir/passes/common/__init__.py,sha256=lLBRSVPh90_aGWkUBmN-D0hcDgBFjQf-Juli7Cs6zv8,1182
28
+ onnx_ir/passes/common/__init__.py,sha256=aHjx2y7L7LJChixmKsSUCdiaTP1u-zSmcmEISduqeG4,1335
29
29
  onnx_ir/passes/common/_c_api_utils.py,sha256=cr0vOhnZ-0lOcZV_mOS3Gn-cUK73CPzjAjfbYA-PJuQ,2891
30
30
  onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMOGjdvVwG9yHQTkSphUgDlM0ME,2365
31
- onnx_ir/passes/common/constant_manipulation.py,sha256=RfcQa-KijcTzNTbixOz_YDs2ssfgxdzsic4T9rgzAfI,9456
32
- onnx_ir/passes/common/inliner.py,sha256=8yg79ae764qaJPnDFnbdBE8sR0Y4MbpZQNfjBo72Gkg,13606
31
+ onnx_ir/passes/common/common_subexpression_elimination.py,sha256=WMsTAI-12A3iVqptmWw0tiBmGwVKsls5VNxZEbjvp2A,6527
32
+ onnx_ir/passes/common/constant_manipulation.py,sha256=_fGDwn0Axl2Q8APfc2m_mLMH28T-Mc9kIlpzBXoe3q4,8779
33
+ onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
33
34
  onnx_ir/passes/common/onnx_checker.py,sha256=4RdWgleYHs36pRRiUCbojkBrw80b1LX88xmj5NLclMg,1675
34
35
  onnx_ir/passes/common/shape_inference.py,sha256=J5VWsLbx9dPwV1JTuaRBObliiVHEb978AxHq_9dOGII,3976
35
36
  onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
36
37
  onnx_ir/passes/common/unused_removal.py,sha256=n1Vr8kSv3HGZyxFin_Kyx79GasfmhlQRVdJ0hGeZnv0,7597
37
- onnx_ir-0.1.0.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
38
- onnx_ir-0.1.0.dist-info/METADATA,sha256=Ujo2fno9o-4TNFx62wsebgHkpb0PdtZoS_lia0BVmt4,4586
39
- onnx_ir-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
- onnx_ir-0.1.0.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
41
- onnx_ir-0.1.0.dist-info/RECORD,,
38
+ onnx_ir-0.1.1.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
39
+ onnx_ir-0.1.1.dist-info/METADATA,sha256=W3i284mv7QuWNNkjRy7x_zHEsMwgUpXvmoux6VE0vZQ,4586
40
+ onnx_ir-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
41
+ onnx_ir-0.1.1.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
42
+ onnx_ir-0.1.1.dist-info/RECORD,,