onnx-ir 0.1.0.tar.gz → 0.1.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic; see the registry's advisory page for this release for more details.
- {onnx_ir-0.1.0/src/onnx_ir.egg-info → onnx_ir-0.1.1}/PKG-INFO +1 -1
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_core.py +31 -38
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_graph_containers.py +113 -8
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir-0.1.1/src/onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/constant_manipulation.py +10 -25
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/inliner.py +4 -3
- {onnx_ir-0.1.0 → onnx_ir-0.1.1/src/onnx_ir.egg-info}/PKG-INFO +1 -1
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir.egg-info/SOURCES.txt +1 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/LICENSE +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/MANIFEST.in +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/README.md +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/pyproject.toml +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/setup.cfg +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_convenience/__init__.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_enums.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_protocols.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_type_casting.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/external_data.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/serde.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/tensor_adapters.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir/traversal.py +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.0 → onnx_ir-0.1.1}/src/onnx_ir.egg-info/top_level.txt +0 -0
|
@@ -22,13 +22,12 @@ import os
|
|
|
22
22
|
import sys
|
|
23
23
|
import textwrap
|
|
24
24
|
import typing
|
|
25
|
-
from collections import OrderedDict
|
|
26
25
|
from collections.abc import (
|
|
27
26
|
Collection,
|
|
28
27
|
Hashable,
|
|
29
28
|
Iterable,
|
|
30
29
|
Iterator,
|
|
31
|
-
|
|
30
|
+
Mapping,
|
|
32
31
|
MutableSequence,
|
|
33
32
|
Sequence,
|
|
34
33
|
)
|
|
@@ -1325,7 +1324,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1325
1324
|
domain: str,
|
|
1326
1325
|
op_type: str,
|
|
1327
1326
|
inputs: Iterable[Value | None],
|
|
1328
|
-
attributes: Iterable[Attr] = (),
|
|
1327
|
+
attributes: Iterable[Attr] | Mapping[str, Attr] = (),
|
|
1329
1328
|
*,
|
|
1330
1329
|
overload: str = "",
|
|
1331
1330
|
num_outputs: int | None = None,
|
|
@@ -1371,15 +1370,10 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1371
1370
|
self._inputs: tuple[Value | None, ...] = tuple(inputs)
|
|
1372
1371
|
# Values belong to their defining nodes. The values list is immutable
|
|
1373
1372
|
self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
"If you are copying the attributes from another node, make sure you call "
|
|
1379
|
-
"node.attributes.values() because it is a dictionary."
|
|
1380
|
-
)
|
|
1381
|
-
self._attributes: OrderedDict[str, Attr] = OrderedDict(
|
|
1382
|
-
(attr.name, attr) for attr in attributes
|
|
1373
|
+
if isinstance(attributes, Mapping):
|
|
1374
|
+
attributes = tuple(attributes.values())
|
|
1375
|
+
self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
|
|
1376
|
+
attributes
|
|
1383
1377
|
)
|
|
1384
1378
|
self._overload: str = overload
|
|
1385
1379
|
# TODO(justinchuby): Potentially support a version range
|
|
@@ -1637,7 +1631,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1637
1631
|
raise AttributeError("outputs is immutable. Please create a new node instead.")
|
|
1638
1632
|
|
|
1639
1633
|
@property
|
|
1640
|
-
def attributes(self) ->
|
|
1634
|
+
def attributes(self) -> _graph_containers.Attributes:
|
|
1641
1635
|
"""The attributes of the node."""
|
|
1642
1636
|
return self._attributes
|
|
1643
1637
|
|
|
@@ -2201,17 +2195,9 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2201
2195
|
# Private fields that are not to be accessed by any other classes
|
|
2202
2196
|
self._inputs = _graph_containers.GraphInputs(self, inputs)
|
|
2203
2197
|
self._outputs = _graph_containers.GraphOutputs(self, outputs)
|
|
2204
|
-
self._initializers = _graph_containers.GraphInitializers(
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
raise TypeError(
|
|
2208
|
-
"Initializer must be a Value, not a string. "
|
|
2209
|
-
"If you are copying the initializers from another graph, "
|
|
2210
|
-
"make sure you call graph.initializers.values() because it is a dictionary."
|
|
2211
|
-
)
|
|
2212
|
-
if initializer.name is None:
|
|
2213
|
-
raise ValueError(f"Initializer must have a name: {initializer}")
|
|
2214
|
-
self._initializers[initializer.name] = initializer
|
|
2198
|
+
self._initializers = _graph_containers.GraphInitializers(
|
|
2199
|
+
self, {initializer.name: initializer for initializer in initializers}
|
|
2200
|
+
)
|
|
2215
2201
|
self._doc_string = doc_string
|
|
2216
2202
|
self._opset_imports = opset_imports or {}
|
|
2217
2203
|
self._metadata: _metadata.MetadataStore | None = None
|
|
@@ -2234,7 +2220,19 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2234
2220
|
return self._outputs
|
|
2235
2221
|
|
|
2236
2222
|
@property
|
|
2237
|
-
def initializers(self) ->
|
|
2223
|
+
def initializers(self) -> _graph_containers.GraphInitializers:
|
|
2224
|
+
"""The initializers of the graph as a ``MutableMapping[str, Value]``.
|
|
2225
|
+
|
|
2226
|
+
The keys are the names of the initializers. The values are the :class:`Value` objects.
|
|
2227
|
+
|
|
2228
|
+
This property additionally supports the ``add`` method, which takes a :class:`Value`
|
|
2229
|
+
and adds it to the initializers if it is not already present.
|
|
2230
|
+
|
|
2231
|
+
.. note::
|
|
2232
|
+
When setting an initializer with ``graph.initializers[key] = value``,
|
|
2233
|
+
if the value does not have a name, it will be assigned ``key`` as its name.
|
|
2234
|
+
|
|
2235
|
+
"""
|
|
2238
2236
|
return self._initializers
|
|
2239
2237
|
|
|
2240
2238
|
def register_initializer(self, value: Value) -> None:
|
|
@@ -2263,15 +2261,11 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2263
2261
|
" it is not the same object: existing={self._initializers[value.name]!r},"
|
|
2264
2262
|
f" new={value!r}"
|
|
2265
2263
|
)
|
|
2266
|
-
if value.producer() is not None:
|
|
2267
|
-
raise ValueError(
|
|
2268
|
-
f"Value '{value!r}' is produced by a node and cannot be an initializer."
|
|
2269
|
-
)
|
|
2270
2264
|
if value.const_value is None:
|
|
2271
2265
|
raise ValueError(
|
|
2272
2266
|
f"Value '{value!r}' must have its const_value set to be an initializer."
|
|
2273
2267
|
)
|
|
2274
|
-
self._initializers
|
|
2268
|
+
self._initializers.add(value)
|
|
2275
2269
|
|
|
2276
2270
|
@property
|
|
2277
2271
|
def doc_string(self) -> str | None:
|
|
@@ -2701,7 +2695,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
|
|
|
2701
2695
|
outputs: Sequence[Value],
|
|
2702
2696
|
*,
|
|
2703
2697
|
nodes: Iterable[Node],
|
|
2704
|
-
initializers: Sequence[
|
|
2698
|
+
initializers: Sequence[Value] = (),
|
|
2705
2699
|
doc_string: str | None = None,
|
|
2706
2700
|
opset_imports: dict[str, int] | None = None,
|
|
2707
2701
|
name: str | None = None,
|
|
@@ -2710,10 +2704,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
|
|
|
2710
2704
|
self.name = name
|
|
2711
2705
|
self.inputs = tuple(inputs)
|
|
2712
2706
|
self.outputs = tuple(outputs)
|
|
2713
|
-
for initializer in initializers
|
|
2714
|
-
if initializer.name is None:
|
|
2715
|
-
raise ValueError(f"Initializer must have a name: {initializer}")
|
|
2716
|
-
self.initializers = {tensor.name: tensor for tensor in initializers}
|
|
2707
|
+
self.initializers = {initializer.name: initializer for initializer in initializers}
|
|
2717
2708
|
self.doc_string = doc_string
|
|
2718
2709
|
self.opset_imports = opset_imports or {}
|
|
2719
2710
|
self._metadata: _metadata.MetadataStore | None = None
|
|
@@ -2927,13 +2918,15 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2927
2918
|
# Ensure the inputs and outputs of the function belong to a graph
|
|
2928
2919
|
# and not from an outer scope
|
|
2929
2920
|
graph: Graph,
|
|
2930
|
-
attributes:
|
|
2921
|
+
attributes: Iterable[Attr] | Mapping[str, Attr],
|
|
2931
2922
|
) -> None:
|
|
2932
2923
|
self._domain = domain
|
|
2933
2924
|
self._name = name
|
|
2934
2925
|
self._overload = overload
|
|
2935
2926
|
self._graph = graph
|
|
2936
|
-
|
|
2927
|
+
if isinstance(attributes, Mapping):
|
|
2928
|
+
attributes = tuple(attributes.values())
|
|
2929
|
+
self._attributes = _graph_containers.Attributes(attributes)
|
|
2937
2930
|
|
|
2938
2931
|
def identifier(self) -> _protocols.OperatorIdentifier:
|
|
2939
2932
|
return self.domain, self.name, self.overload
|
|
@@ -2971,7 +2964,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2971
2964
|
return self._graph.outputs
|
|
2972
2965
|
|
|
2973
2966
|
@property
|
|
2974
|
-
def attributes(self) ->
|
|
2967
|
+
def attributes(self) -> _graph_containers.Attributes:
|
|
2975
2968
|
return self._attributes
|
|
2976
2969
|
|
|
2977
2970
|
@typing.overload
|
|
@@ -12,13 +12,16 @@ __all__ = [
|
|
|
12
12
|
]
|
|
13
13
|
|
|
14
14
|
import collections
|
|
15
|
-
|
|
16
|
-
from
|
|
15
|
+
import logging
|
|
16
|
+
from collections.abc import Iterable, Sequence
|
|
17
|
+
from typing import SupportsIndex, TypeVar
|
|
17
18
|
|
|
18
19
|
import onnx_ir
|
|
20
|
+
from onnx_ir import _core, _protocols
|
|
19
21
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
+
T = TypeVar("T")
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class _GraphIO(collections.UserList["_core.Value"]):
|
|
@@ -152,6 +155,10 @@ class GraphInputs(_GraphIO):
|
|
|
152
155
|
raise ValueError(
|
|
153
156
|
f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
|
|
154
157
|
)
|
|
158
|
+
if value.producer() is not None:
|
|
159
|
+
raise ValueError(
|
|
160
|
+
f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
|
|
161
|
+
)
|
|
155
162
|
self._ref_counter[value] += 1
|
|
156
163
|
value._is_graph_input = True
|
|
157
164
|
value._graph = self._graph
|
|
@@ -244,12 +251,23 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
|
|
|
244
251
|
|
|
245
252
|
def __setitem__(self, key: str, value: _core.Value) -> None:
|
|
246
253
|
"""Set an initializer for the graph."""
|
|
247
|
-
if
|
|
254
|
+
if not isinstance(value, _core.Value):
|
|
255
|
+
raise TypeError(f"value must be a Value object, not {type(value)}")
|
|
256
|
+
if not isinstance(key, str):
|
|
257
|
+
raise TypeError(f"Value name must be a string, not {type(key)}")
|
|
258
|
+
if key == "":
|
|
259
|
+
raise ValueError("Value name cannot be an empty string")
|
|
260
|
+
if not value.name:
|
|
261
|
+
logger.info("Value %s does not have a name, setting it to '%s'", value, key)
|
|
262
|
+
value.name = key
|
|
263
|
+
elif key != value.name:
|
|
248
264
|
raise ValueError(
|
|
249
|
-
f"Key '{key}' does not match the name of the value '{value.name}'"
|
|
265
|
+
f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
|
|
266
|
+
)
|
|
267
|
+
if value.producer() is not None:
|
|
268
|
+
raise ValueError(
|
|
269
|
+
f"Value '{value}' is produced by a node and cannot be a graph initializer"
|
|
250
270
|
)
|
|
251
|
-
if not isinstance(key, str):
|
|
252
|
-
raise TypeError(f"Key must be a string, not {type(key)}")
|
|
253
271
|
if key in self.data:
|
|
254
272
|
# If the key already exists, unset the old value
|
|
255
273
|
old_value = self.data[key]
|
|
@@ -266,3 +284,90 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
|
|
|
266
284
|
# the dictionary is not modified
|
|
267
285
|
self._maybe_unset_graph(value)
|
|
268
286
|
super().__delitem__(key)
|
|
287
|
+
|
|
288
|
+
def add(self, value: _core.Value) -> None:
    """Register ``value`` as an initializer keyed by its own name.

    This is shorthand for ``self[value.name] = value``, so the validation done by
    item assignment applies: ``value.name`` must be a non-empty string and
    ``value`` must not be produced by a node.
    """
    initializer_name = value.name
    self[initializer_name] = value  # type: ignore[index]
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class Attributes(collections.UserDict[str, "_core.Attr"]):
    """Mapping from attribute name to :class:`Attr` for a single node.

    Item assignment is type-checked: keys must be exactly ``str`` and values
    must be ``Attr`` instances. Each ``get_*`` accessor converts the stored
    attribute with the matching ``Attr.as_*`` method and falls back to
    ``default`` when the key is absent.
    """

    def __init__(self, attrs: Iterable[_core.Attr]):
        # Each attribute is keyed by its own name. UserDict populates itself via
        # update(), which goes through __setitem__, so the checks below apply.
        super().__init__({attr.name: attr for attr in attrs})

    def __setitem__(self, key: str, value: _core.Attr) -> None:
        """Store ``value`` under ``key`` after validating both.

        Raises:
            TypeError: If ``key`` is not exactly a ``str`` (subclasses are
                rejected) or ``value`` is not an ``Attr``.
        """
        if type(key) is not str:
            raise TypeError(f"Key must be a string, not {type(key)}")
        if isinstance(value, _core.Attr):
            super().__setitem__(key, value)
            return
        raise TypeError(f"Value must be an Attr, not {type(value)}")

    def add(self, value: _core.Attr) -> None:
        """Insert ``value`` keyed by its ``name``, replacing a same-named entry."""
        attr_name = value.name
        self[attr_name] = value

    def get_int(self, key: str, default: T = None) -> int | T:  # type: ignore[assignment]
        """Return attribute ``key`` as an int, or ``default`` if it is missing."""
        return self[key].as_int() if key in self else default

    def get_float(self, key: str, default: T = None) -> float | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a float, or ``default`` if it is missing."""
        return self[key].as_float() if key in self else default

    def get_string(self, key: str, default: T = None) -> str | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a string, or ``default`` if it is missing."""
        return self[key].as_string() if key in self else default

    def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a tensor, or ``default`` if it is missing."""
        return self[key].as_tensor() if key in self else default

    def get_graph(self, key: str, default: T = None) -> _core.Graph | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a graph, or ``default`` if it is missing."""
        return self[key].as_graph() if key in self else default

    def get_ints(self, key: str, default: T = None) -> Sequence[int] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of ints, or ``default`` if missing."""
        return self[key].as_ints() if key in self else default

    def get_floats(self, key: str, default: T = None) -> Sequence[float] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of floats, or ``default`` if missing."""
        return self[key].as_floats() if key in self else default

    def get_strings(self, key: str, default: T = None) -> Sequence[str] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of strings, or ``default`` if missing."""
        return self[key].as_strings() if key in self else default

    def get_tensors(
        self,
        key: str,
        default: T = None,  # type: ignore[assignment]
    ) -> Sequence[_protocols.TensorProtocol] | T:
        """Return attribute ``key`` as a sequence of tensors, or ``default`` if missing."""
        return self[key].as_tensors() if key in self else default

    def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of graphs, or ``default`` if missing."""
        return self[key].as_graphs() if key in self else default
|
|
@@ -5,6 +5,7 @@ __all__ = [
|
|
|
5
5
|
"AddInitializersToInputsPass",
|
|
6
6
|
"CheckerPass",
|
|
7
7
|
"ClearMetadataAndDocStringPass",
|
|
8
|
+
"CommonSubexpressionEliminationPass",
|
|
8
9
|
"InlinePass",
|
|
9
10
|
"LiftConstantsToInitializersPass",
|
|
10
11
|
"LiftSubgraphInitializersToMainGraphPass",
|
|
@@ -19,6 +20,9 @@ __all__ = [
|
|
|
19
20
|
from onnx_ir.passes.common.clear_metadata_and_docstring import (
|
|
20
21
|
ClearMetadataAndDocStringPass,
|
|
21
22
|
)
|
|
23
|
+
from onnx_ir.passes.common.common_subexpression_elimination import (
|
|
24
|
+
CommonSubexpressionEliminationPass,
|
|
25
|
+
)
|
|
22
26
|
from onnx_ir.passes.common.constant_manipulation import (
|
|
23
27
|
AddInitializersToInputsPass,
|
|
24
28
|
LiftConstantsToInitializersPass,
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Eliminate common subexpression in ONNX graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"CommonSubexpressionEliminationPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from collections.abc import Sequence
|
|
13
|
+
|
|
14
|
+
import onnx_ir as ir
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """Eliminate common subexpression in ONNX graphs.

    Only the top-level nodes of ``model.graph`` are examined. When a node repeats
    an earlier node's operator, output count, inputs and attributes, it is
    removed and its consumers are rewired to the earlier node's outputs.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Apply CSE to ``model.graph`` in place and report whether it changed."""
        was_modified = _eliminate_common_subexpression(model.graph, False)
        return ir.passes.PassResult(model, modified=was_modified)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Eliminate common subexpressions among the top-level nodes of ``graph``.

    Two nodes are treated as equivalent when they share the operator identifier,
    the number of outputs, the input values (compared by identity) and equal
    attribute values. The later duplicate is removed and its outputs are
    replaced by the outputs of the first occurrence.

    The following nodes are left untouched:

    * nodes with GRAPH/GRAPHS attributes (control-flow ops such as If/Loop);
    * non-deterministic ops;
    * nodes whose attribute values cannot be hashed.

    Subgraphs are not visited.

    Args:
        graph: The graph to optimize in place.
        modified: Modification flag accumulated by the caller.

    Returns:
        ``True`` if any node was eliminated, otherwise ``modified`` unchanged.
    """
    # Sequence-valued attribute types. Their values may be stored as lists, so
    # they are converted to tuples to be usable inside the hashable lookup key.
    sequence_attr_types = (
        ir.AttributeType.INTS,
        ir.AttributeType.FLOATS,
        ir.AttributeType.STRINGS,
        ir.AttributeType.TENSORS,
    )
    # node identifier, length of outputs, inputs, and attributes -> first node seen
    existing_node_info_to_the_node: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        control_flow_op = False
        attributes: dict[str, object] = {}
        for k, v in node.attributes.items():
            # TODO(exporter team): CSE subgraphs.
            # NOTE: control flow ops like Loop and If won't be CSEd because
            # graph attributes won't match; stop inspecting further attributes.
            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
                control_flow_op = True
                break
            value = v.value
            if v.type in sequence_attr_types:
                # Copy into a tuple: the stored value may be a mutable list
                # (e.g. taken from the original protobuf) and is unhashable.
                value = tuple(value)
            attributes[k] = value

        if control_flow_op:
            logger.debug("Skipping control flow op %s", node)
            continue

        if _is_non_deterministic_op(node):
            # Merging two random draws would make them identical; keep both.
            logger.debug("Skipping non-deterministic op %s", node)
            continue

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            tuple(id(inp) for inp in node.inputs),
            tuple(sorted(attributes.items())),
        )
        try:
            existing_node = existing_node_info_to_the_node.get(node_info)
        except TypeError:
            # An attribute value not converted above is unhashable (for example,
            # a list); such a node cannot be keyed, so leave it as is.
            logger.debug("Skipping node with unhashable attributes %s", node)
            continue

        if existing_node is None:
            # First occurrence: remember it as the canonical node.
            existing_node_info_to_the_node[node_info] = node
            continue

        # Duplicate: same operator, number of outputs, inputs, and attributes.
        modified = True
        _remove_node_and_replace_values(
            graph,
            remove_node=node,
            remove_values=node.outputs,
            new_values=existing_node.outputs,
        )
        logger.debug("Reusing node %s", existing_node)
    return modified
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _remove_node_and_replace_values(
    graph: ir.Graph,
    /,
    remove_node: ir.Node,
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Delete ``remove_node`` after redirecting its outputs to ``new_values``.

    ``remove_values[i]`` is replaced by ``new_values[i]`` for every consumer.
    Graph outputs are patched too. If a replacement value is already a graph
    input or output, a new ``Identity`` node is inserted just before
    ``remove_node``, so the graph output keeps its original name without
    renaming the replacement. Otherwise the replacement value itself is renamed
    to the graph output's name and used directly.

    Args:
        graph: The graph (or function body) that owns ``remove_node``.
        remove_node: The node to delete.
        remove_values: Outputs of ``remove_node`` to be replaced.
        new_values: Replacement values, paired positionally with ``remove_values``.
    """
    # Rewire every consumer of the values being removed.
    ir.convenience.replace_all_uses_with(remove_values, new_values)

    feeds_graph_output = any(value.is_graph_output() for value in remove_values)
    if feeds_graph_output:
        old_to_new = dict(zip(remove_values, new_values))
        for position, old_output in enumerate(graph.outputs):
            if old_output not in old_to_new:
                continue
            replacement = old_to_new[old_output]
            on_boundary = replacement.is_graph_output() or replacement.is_graph_input()
            if not on_boundary:
                # Nothing else depends on the replacement's name; adopt the output name.
                replacement.name = old_output.name
                graph.outputs[position] = replacement
                continue
            # Renaming would alter another graph input/output, so alias via Identity.
            alias = ir.Value(
                name=old_output.name,
                type=old_output.type,
                shape=old_output.shape,
            )
            identity_node = ir.node("Identity", inputs=[replacement], outputs=[alias])
            graph.outputs[position] = identity_node.outputs[0]
            graph.insert_before(remove_node, identity_node)

    graph.remove(remove_node, safe=True)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# ONNX-domain operators whose outputs may differ between evaluations with
# identical inputs, so two instances must never be merged by CSE.
# Dropout is included because it draws a random mask when training_mode is true.
_NON_DETERMINISTIC_OPS = frozenset(
    {
        "Bernoulli",
        "Dropout",
        "Multinomial",
        "RandomNormal",
        "RandomNormalLike",
        "RandomUniform",
        "RandomUniformLike",
    }
)


def _is_non_deterministic_op(node: ir.Node) -> bool:
    """Return True if ``node`` is an ONNX-domain op with random outputs.

    Args:
        node: The node to classify; only ``op_type`` and ``domain`` are read.
    """
    return node.op_type in _NON_DETERMINISTIC_OPS and _is_onnx_domain(node.domain)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _is_onnx_domain(d: str) -> bool:
|
|
176
|
+
"""Check if the domain is the ONNX domain."""
|
|
177
|
+
return d == ""
|
|
@@ -139,35 +139,11 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
|
|
|
139
139
|
for further processing or optimization.
|
|
140
140
|
|
|
141
141
|
Initializers that are also graph inputs will not be lifted.
|
|
142
|
-
|
|
143
|
-
Preconditions:
|
|
144
|
-
- All initializers in the model must have unique names across the main graph and subgraphs.
|
|
145
142
|
"""
|
|
146
143
|
|
|
147
|
-
def requires(self, model: ir.Model) -> None:
|
|
148
|
-
"""Ensure all initializer names are unique."""
|
|
149
|
-
registered_initializer_names: set[str] = set()
|
|
150
|
-
duplicated_initializers: list[ir.Value] = []
|
|
151
|
-
for graph in model.graphs():
|
|
152
|
-
for initializer in graph.initializers.values():
|
|
153
|
-
if initializer.name is None:
|
|
154
|
-
raise ir.passes.PreconditionError(
|
|
155
|
-
f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}"
|
|
156
|
-
)
|
|
157
|
-
if initializer.name in registered_initializer_names:
|
|
158
|
-
duplicated_initializers.append(initializer)
|
|
159
|
-
else:
|
|
160
|
-
registered_initializer_names.add(initializer.name)
|
|
161
|
-
if duplicated_initializers:
|
|
162
|
-
raise ir.passes.PreconditionError(
|
|
163
|
-
"Found duplicated initializers in the model. "
|
|
164
|
-
"Initializer name must be unique across the main graph and subgraphs. "
|
|
165
|
-
"Please ensure all initializers have unique names. Duplicated: "
|
|
166
|
-
f"{duplicated_initializers!r}"
|
|
167
|
-
)
|
|
168
|
-
|
|
169
144
|
def call(self, model: ir.Model) -> ir.passes.PassResult:
|
|
170
145
|
count = 0
|
|
146
|
+
registered_initializer_names: dict[str, int] = {}
|
|
171
147
|
for graph in model.graphs():
|
|
172
148
|
if graph is model.graph:
|
|
173
149
|
continue
|
|
@@ -182,6 +158,15 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
|
|
|
182
158
|
continue
|
|
183
159
|
# Remove the initializer from the subgraph
|
|
184
160
|
graph.initializers.pop(name)
|
|
161
|
+
# To avoid name conflicts, we need to rename the initializer
|
|
162
|
+
# to a unique name in the main graph
|
|
163
|
+
if name in registered_initializer_names:
|
|
164
|
+
name_count = registered_initializer_names[name]
|
|
165
|
+
initializer.name = f"{name}_{name_count}"
|
|
166
|
+
registered_initializer_names[name] = name_count + 1
|
|
167
|
+
else:
|
|
168
|
+
assert initializer.name is not None
|
|
169
|
+
registered_initializer_names[initializer.name] = 1
|
|
185
170
|
model.graph.register_initializer(initializer)
|
|
186
171
|
count += 1
|
|
187
172
|
logger.debug(
|
|
@@ -9,7 +9,7 @@ import dataclasses
|
|
|
9
9
|
__all__ = ["InlinePass", "InlinePassResult"]
|
|
10
10
|
|
|
11
11
|
from collections import defaultdict
|
|
12
|
-
from collections.abc import Iterable, Sequence
|
|
12
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
13
13
|
|
|
14
14
|
import onnx_ir as ir
|
|
15
15
|
import onnx_ir.convenience as _ir_convenience
|
|
@@ -52,7 +52,7 @@ class _CopyReplace:
|
|
|
52
52
|
def __init__(
|
|
53
53
|
self,
|
|
54
54
|
inliner: InlinePass,
|
|
55
|
-
attr_map:
|
|
55
|
+
attr_map: Mapping[str, ir.Attr],
|
|
56
56
|
value_map: dict[ir.Value, ir.Value | None],
|
|
57
57
|
metadata_props: dict[str, str],
|
|
58
58
|
call_stack: CallStack,
|
|
@@ -96,6 +96,7 @@ class _CopyReplace:
|
|
|
96
96
|
return attr
|
|
97
97
|
assert attr.is_ref()
|
|
98
98
|
ref_attr_name = attr.ref_attr_name
|
|
99
|
+
assert ref_attr_name is not None, "Reference attribute must have a name"
|
|
99
100
|
if ref_attr_name in self._attr_map:
|
|
100
101
|
ref_attr = self._attr_map[ref_attr_name]
|
|
101
102
|
if not ref_attr.is_ref():
|
|
@@ -237,7 +238,7 @@ class InlinePass(ir.passes.InPlacePass):
|
|
|
237
238
|
)
|
|
238
239
|
|
|
239
240
|
# Identify substitutions for both inputs and attributes of the function:
|
|
240
|
-
attributes:
|
|
241
|
+
attributes: Mapping[str, ir.Attr] = node.attributes
|
|
241
242
|
default_attr_values = {
|
|
242
243
|
attr.name: attr
|
|
243
244
|
for attr in function.attributes.values()
|
|
@@ -37,6 +37,7 @@ src/onnx_ir/passes/_pass_infra.py
|
|
|
37
37
|
src/onnx_ir/passes/common/__init__.py
|
|
38
38
|
src/onnx_ir/passes/common/_c_api_utils.py
|
|
39
39
|
src/onnx_ir/passes/common/clear_metadata_and_docstring.py
|
|
40
|
+
src/onnx_ir/passes/common/common_subexpression_elimination.py
|
|
40
41
|
src/onnx_ir/passes/common/constant_manipulation.py
|
|
41
42
|
src/onnx_ir/passes/common/inliner.py
|
|
42
43
|
src/onnx_ir/passes/common/onnx_checker.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|