onnx-ir 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (50)
  1. {onnx_ir-0.1.2/src/onnx_ir.egg-info → onnx_ir-0.1.3}/PKG-INFO +1 -1
  2. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/__init__.py +1 -1
  3. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_convenience/__init__.py +5 -0
  4. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_core.py +25 -3
  5. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_enums.py +2 -0
  6. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/__init__.py +4 -0
  7. onnx_ir-0.1.3/src/onnx_ir/passes/common/common_subexpression_elimination.py +206 -0
  8. onnx_ir-0.1.3/src/onnx_ir/passes/common/initializer_deduplication.py +56 -0
  9. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/serde.py +5 -0
  10. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/tensor_adapters.py +62 -7
  11. {onnx_ir-0.1.2 → onnx_ir-0.1.3/src/onnx_ir.egg-info}/PKG-INFO +1 -1
  12. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/SOURCES.txt +1 -0
  13. onnx_ir-0.1.2/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -177
  14. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/LICENSE +0 -0
  15. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/MANIFEST.in +0 -0
  16. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/README.md +0 -0
  17. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/pyproject.toml +0 -0
  18. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/setup.cfg +0 -0
  19. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_convenience/_constructors.py +0 -0
  20. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_display.py +0 -0
  21. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_graph_comparison.py +0 -0
  22. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_graph_containers.py +0 -0
  23. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_io.py +0 -0
  24. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_linked_list.py +0 -0
  25. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_metadata.py +0 -0
  26. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_name_authority.py +0 -0
  27. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_polyfill.py +0 -0
  28. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_protocols.py +0 -0
  29. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_tape.py +0 -0
  30. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  31. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_type_casting.py +0 -0
  32. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_version_utils.py +0 -0
  33. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/convenience.py +0 -0
  34. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/external_data.py +0 -0
  35. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/__init__.py +0 -0
  36. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/_pass_infra.py +0 -0
  37. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  38. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  39. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
  40. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/inliner.py +0 -0
  41. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  42. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  43. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  44. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  45. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/tape.py +0 -0
  46. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/testing.py +0 -0
  47. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/traversal.py +0 -0
  48. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  49. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/requires.txt +0 -0
  50. {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-ir
-Version: 0.1.2
+Version: 0.1.3
 Summary: Efficient in-memory representation for ONNX
 Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
 License: Apache License v2.0
src/onnx_ir/__init__.py
@@ -167,4 +167,4 @@ def __set_module() -> None:
 
 
 __set_module()
-__version__ = "0.1.2"
+__version__ = "0.1.3"
src/onnx_ir/_convenience/__init__.py
@@ -323,6 +323,9 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
     and the first value with that name is returned. Values with empty names
     are excluded from the mapping.
 
+    .. versionchanged:: 0.1.2
+        Values from subgraphs are now included in the mapping.
+
     Args:
         graph: The graph to extract the mapping from.
 
@@ -410,6 +413,8 @@ def get_const_tensor(
     it will propagate the shape and type of the constant tensor to the value
     if `propagate_shape_type` is set to True.
 
+    .. versionadded:: 0.1.2
+
     Args:
         value: The value to get the constant tensor from.
         propagate_shape_type: If True, the shape and type of the value will be
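
A minimal usage sketch for the new helper, assuming it is re-exported via `onnx_ir.convenience` and that it returns `None` when the value is not constant; the model path is a placeholder:

    import onnx_ir as ir

    model = ir.load("model.onnx")  # placeholder path
    value = model.graph.outputs[0]
    # New in 0.1.2: resolve the constant tensor behind a value, optionally
    # propagating its shape/type back onto the value.
    tensor = ir.convenience.get_const_tensor(value, propagate_shape_type=True)
    if tensor is not None:
        print(tensor.numpy())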
src/onnx_ir/_core.py
@@ -417,6 +417,9 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         else:
             self._shape = shape
             self._shape.freeze()
+        if isinstance(value, np.generic):
+            # Turn numpy scalar into a numpy array
+            value = np.array(value)  # type: ignore[assignment]
         if dtype is None:
             if isinstance(value, np.ndarray):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
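
The new branch normalizes numpy scalars before dtype inference. A minimal sketch, assuming `ir.Tensor` is the public alias for `_core.Tensor`:

    import numpy as np
    import onnx_ir as ir

    # np.float32(1.0) is an np.generic, not an ndarray; it is now wrapped in a
    # 0-D array so the dtype-inference branch below sees an ndarray.
    t = ir.Tensor(np.float32(1.0))
    print(t.dtype)  # DataType.FLOAT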
@@ -964,7 +967,10 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-
 
 
 class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
-    """A tensor that stores 4bit datatypes in packed format."""
+    """A tensor that stores 4bit datatypes in packed format.
+
+    .. versionadded:: 0.1.2
+    """
 
     __slots__ = (
         "_dtype",
@@ -2335,6 +2341,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.
 
+    .. versionchanged:: 0.1.1
+        Values with non-none producers will be rejected as graph inputs or initializers.
+
+    .. versionadded:: 0.1.1
+        Added ``add`` method to initializers and attributes.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.
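
A sketch of the ``add`` method mentioned above, with hypothetical constructor arguments (`ir.Value(name=..., const_value=...)` and `ir.tensor(...)` are assumptions based on the convenience constructors, not shown in this diff):

    import numpy as np
    import onnx_ir as ir

    model = ir.from_onnx_text("""
    <ir_version: 10, opset_import: ["" : 17]>
    agraph (float[2] x) => (float[2] y) {
        y = Identity(x)
    }
    """)
    # New in 0.1.1: add() registers the initializer keyed by the value's name.
    w = ir.Value(name="w", const_value=ir.tensor(np.ones(2, dtype=np.float32)))
    model.graph.initializers.add(w)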
@@ -2545,12 +2557,17 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)
 
     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph
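
A minimal sketch of the new ``subgraphs()`` traversal (the model path is a placeholder):

    import onnx_ir as ir

    model = ir.load("model.onnx")  # placeholder path
    # New in 0.1.2: walk nested graphs (If branches, Loop bodies) in one call.
    for subgraph in model.graph.subgraphs():
        print(subgraph.name, len(subgraph))  # Graph is a Sequence of nodes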
@@ -3216,12 +3233,17 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)
 
     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the function in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the function in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph
src/onnx_ir/_enums.py
@@ -120,6 +120,8 @@ class DataType(enum.IntEnum):
     def bitwidth(self) -> int:
         """Returns the bit width of the data type.
 
+        .. versionadded:: 0.1.2
+
         Raises:
            TypeError: If the data type is not supported.
        """
src/onnx_ir/passes/common/__init__.py
@@ -6,6 +6,7 @@ __all__ = [
     "CheckerPass",
     "ClearMetadataAndDocStringPass",
     "CommonSubexpressionEliminationPass",
+    "DeduplicateInitializersPass",
     "InlinePass",
     "LiftConstantsToInitializersPass",
     "LiftSubgraphInitializersToMainGraphPass",
@@ -29,6 +30,9 @@ from onnx_ir.passes.common.constant_manipulation import (
     LiftSubgraphInitializersToMainGraphPass,
     RemoveInitializersFromInputsPass,
 )
+from onnx_ir.passes.common.initializer_deduplication import (
+    DeduplicateInitializersPass,
+)
 from onnx_ir.passes.common.inliner import InlinePass
 from onnx_ir.passes.common.onnx_checker import CheckerPass
 from onnx_ir.passes.common.shape_inference import ShapeInferencePass
onnx_ir-0.1.3/src/onnx_ir/passes/common/common_subexpression_elimination.py (new file)
@@ -0,0 +1,206 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Eliminate common subexpression in ONNX graphs."""
+
+from __future__ import annotations
+
+__all__ = [
+    "CommonSubexpressionEliminationPass",
+]
+
+import logging
+from collections.abc import Sequence
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
+    """Eliminate common subexpression in ONNX graphs.
+
+    .. versionadded:: 0.1.1
+
+    .. versionchanged:: 0.1.3
+        Constant nodes with values smaller than ``size_limit`` will be CSE'd.
+
+    Attributes:
+        size_limit: The maximum size of the tensor to be csed. If the tensor contains
+            number of elements larger than size_limit, it will not be cse'd. Default is 10.
+
+    """
+
+    def __init__(self, size_limit: int = 10):
+        """Initialize the CommonSubexpressionEliminationPass."""
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        """Return the same ir.Model but with CSE applied to the graph."""
+        graph = model.graph
+        modified = self._eliminate_common_subexpression(graph)
+
+        return ir.passes.PassResult(
+            model,
+            modified=modified,
+        )
+
+    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+        """Eliminate common subexpression in ONNX graphs."""
+        modified: bool = False
+        # node to node identifier, length of outputs, inputs, and attributes
+        existing_node_info_to_the_node: dict[
+            tuple[
+                ir.OperatorIdentifier,
+                int,  # len(outputs)
+                tuple[int, ...],  # input ids
+                tuple[tuple[str, object], ...],  # attributes
+            ],
+            ir.Node,
+        ] = {}
+
+        for node in graph:
+            # Skip control flow ops like Loop and If.
+            control_flow_op: bool = False
+            # Skip large tensors to avoid cse weights and bias.
+            large_tensor: bool = False
+            # Use equality to check if the node is a common subexpression.
+            attributes = {}
+            for k, v in node.attributes.items():
+                # TODO(exporter team): CSE subgraphs.
+                # NOTE: control flow ops like Loop and If won't be CSEd
+                # because attribute: graph won't match.
+                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                    control_flow_op = True
+                    break
+                # The attribute value could be directly taken from the original
+                # protobuf, so we need to make a copy of it.
+                value = v.value
+                if v.type in (
+                    ir.AttributeType.INTS,
+                    ir.AttributeType.FLOATS,
+                    ir.AttributeType.STRINGS,
+                ):
+                    # For INT, FLOAT and STRING attributes, we convert them to tuples
+                    # to ensure they are hashable.
+                    value = tuple(value)
+                elif v.type is ir.AttributeType.TENSOR:
+                    if value.size > self.size_limit:
+                        # If the tensor is larger than the size limit, we skip it.
+                        large_tensor = True
+                        break
+                    np_value = value.numpy()
+
+                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                attributes[k] = value
+
+            if control_flow_op:
+                # If the node is a control flow op, we skip it.
+                logger.debug("Skipping control flow op %s", node)
+                continue
+
+            if large_tensor:
+                # If the node has a large tensor, we skip it.
+                logger.debug("Skipping large tensor in node %s", node)
+                continue
+
+            if _is_non_deterministic_op(node):
+                # If the node is a non-deterministic op, we skip it.
+                logger.debug("Skipping non-deterministic op %s", node)
+                continue
+
+            node_info = (
+                node.op_identifier(),
+                len(node.outputs),
+                tuple(id(input) for input in node.inputs),
+                tuple(sorted(attributes.items())),
+            )
+            # Check if the node is a common subexpression.
+            if node_info in existing_node_info_to_the_node:
+                # If it is, this node has an existing node with the same
+                # operator, number of outputs, inputs, and attributes.
+                # We replace the node with the existing node.
+                modified = True
+                existing_node = existing_node_info_to_the_node[node_info]
+                _remove_node_and_replace_values(
+                    graph,
+                    remove_node=node,
+                    remove_values=node.outputs,
+                    new_values=existing_node.outputs,
+                )
+                logger.debug("Reusing node %s", existing_node)
+            else:
+                # If it is not, add to the mapping.
+                existing_node_info_to_the_node[node_info] = node
+        return modified
+
+
+def _remove_node_and_replace_values(
+    graph: ir.Graph,
+    /,
+    remove_node: ir.Node,
+    remove_values: Sequence[ir.Value],
+    new_values: Sequence[ir.Value],
+) -> None:
+    """Replaces nodes and values in the graph or function.
+
+    Args:
+        graph: The graph to replace nodes and values in.
+        remove_node: The node to remove.
+        remove_values: The values to replace.
+        new_values: The values to replace with.
+    """
+    # Reconnect the users of the deleted values to use the new values
+    ir.convenience.replace_all_uses_with(remove_values, new_values)
+    # Update graph/function outputs if the node generates output
+    if any(remove_value.is_graph_output() for remove_value in remove_values):
+        replacement_mapping = dict(zip(remove_values, new_values))
+        for idx, graph_output in enumerate(graph.outputs):
+            if graph_output in replacement_mapping:
+                new_value = replacement_mapping[graph_output]
+                if new_value.is_graph_output() or new_value.is_graph_input():
+                    # If the new value is also a graph input/output, we need to
+                    # create a Identity node to preserve the remove_value and
+                    # prevent from changing new_value name.
+                    identity_node = ir.node(
+                        "Identity",
+                        inputs=[new_value],
+                        outputs=[
+                            ir.Value(
+                                name=graph_output.name,
+                                type=graph_output.type,
+                                shape=graph_output.shape,
+                            )
+                        ],
+                    )
+                    # reuse the name of the graph output
+                    graph.outputs[idx] = identity_node.outputs[0]
+                    graph.insert_before(
+                        remove_node,
+                        identity_node,
+                    )
+                else:
+                    # if new_value is not graph output, we just
+                    # update it to use old_value name.
+                    new_value.name = graph_output.name
+                    graph.outputs[idx] = new_value
+
+    graph.remove(remove_node, safe=True)
+
+
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+    non_deterministic_ops = frozenset(
+        {
+            "RandomUniform",
+            "RandomNormal",
+            "RandomUniformLike",
+            "RandomNormalLike",
+            "Multinomial",
+        }
+    )
+    return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
+
+
+def _is_onnx_domain(d: str) -> bool:
+    """Check if the domain is the ONNX domain."""
+    return d == ""
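
A minimal usage sketch for the rewritten pass, assuming onnx-ir's convention that a pass instance is callable on a model (the path is a placeholder):

    import onnx_ir as ir
    from onnx_ir.passes.common import CommonSubexpressionEliminationPass

    model = ir.load("model.onnx")  # placeholder path
    # Constant tensors up to size_limit elements are now eligible for CSE.
    result = CommonSubexpressionEliminationPass(size_limit=10)(model)
    print(result.modified)  # True if any duplicate node was eliminated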
onnx_ir-0.1.3/src/onnx_ir/passes/common/initializer_deduplication.py (new file)
@@ -0,0 +1,56 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = [
+    "DeduplicateInitializersPass",
+]
+
+
+import onnx_ir as ir
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the graph.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running pass.
+
+    .. versionadded:: 0.1.3
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        graph = model.graph
+        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+        modified = False
+
+        for initializer in tuple(graph.initializers.values()):
+            # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
+            # out from the main graph before running this pass.
+            const_val = initializer.const_value
+            if const_val is None:
+                # Skip if initializer has no constant value
+                continue
+
+            if const_val.size > self.size_limit:
+                continue
+
+            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+            if key in initializers:
+                modified = True
+                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
+                assert initializer.name is not None
+                graph.initializers.pop(initializer.name)
+            else:
+                initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
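
A usage sketch following the docstring's advice to lift subgraph initializers first; it assumes both passes are callable instances and that `LiftSubgraphInitializersToMainGraphPass` takes no constructor arguments (the path is a placeholder):

    import onnx_ir as ir
    from onnx_ir.passes.common import (
        DeduplicateInitializersPass,
        LiftSubgraphInitializersToMainGraphPass,
    )

    model = ir.load("model.onnx")  # placeholder path
    # Lift subgraph initializers so duplicates across subgraphs become visible.
    LiftSubgraphInitializersToMainGraphPass()(model)
    result = DeduplicateInitializersPass(size_limit=1024)(model)
    print(result.modified)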
src/onnx_ir/serde.py
@@ -200,6 +200,9 @@ def from_onnx_text(
 
     Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
 
+    .. versionchanged:: 0.1.2
+        Added the ``initializers`` argument.
+
     Args:
         model_text: The ONNX textual representation of the model.
         initializers: Tensors to be added as initializers. If provided, these tensors
@@ -237,6 +240,8 @@
 ) -> str:
     """Convert the IR model to the ONNX textual representation.
 
+    .. versionadded:: 0.1.2
+
     Args:
         model: The IR model to convert.
         exclude_initializers: If True, the initializers will not be included in the output.
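
A round-trip sketch covering both documented changes, assuming `from_onnx_text`/`to_onnx_text` are re-exported at the top level and that the `initializers` argument matches tensors to graph values by name:

    import numpy as np
    import onnx_ir as ir

    text = """
    <ir_version: 10, opset_import: ["" : 17]>
    agraph (float[2] x, float[2] w) => (float[2] y) {
        y = Add(x, w)
    }
    """
    # Assumed behavior: "w" is turned into an initializer by name matching.
    w = ir.tensor(np.array([1.0, 2.0], dtype=np.float32), name="w")
    model = ir.from_onnx_text(text, initializers=[w])
    print(ir.to_onnx_text(model))  # new in 0.1.2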
src/onnx_ir/tensor_adapters.py
@@ -29,6 +29,8 @@ Example::
 from __future__ import annotations
 
 __all__ = [
+    "from_torch_dtype",
+    "to_torch_dtype",
     "TorchTensor",
 ]
 
@@ -44,14 +46,17 @@ if TYPE_CHECKING:
     import torch
 
 
-class TorchTensor(_core.Tensor):
-    def __init__(
-        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
-    ):
-        # Pass the tensor as the raw data to ir.Tensor's constructor
+_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
+_ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None
+
+
+def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
+    """Convert a PyTorch dtype to an ONNX IR DataType."""
+    global _TORCH_DTYPE_TO_ONNX
+    if _TORCH_DTYPE_TO_ONNX is None:
         import torch
 
-        _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
+        _TORCH_DTYPE_TO_ONNX = {
             torch.bfloat16: ir.DataType.BFLOAT16,
             torch.bool: ir.DataType.BOOL,
             torch.complex128: ir.DataType.COMPLEX128,
@@ -72,8 +77,58 @@ class TorchTensor(_core.Tensor):
             torch.uint32: ir.DataType.UINT32,
             torch.uint64: ir.DataType.UINT64,
         }
+    if dtype not in _TORCH_DTYPE_TO_ONNX:
+        raise TypeError(
+            f"Unsupported PyTorch dtype '{dtype}'. "
+            "Please use a supported dtype from the list: "
+            f"{list(_TORCH_DTYPE_TO_ONNX.keys())}"
+        )
+    return _TORCH_DTYPE_TO_ONNX[dtype]
+
+
+def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
+    """Convert an ONNX IR DataType to a PyTorch dtype."""
+    global _ONNX_DTYPE_TO_TORCH
+    if _ONNX_DTYPE_TO_TORCH is None:
+        import torch
+
+        _ONNX_DTYPE_TO_TORCH = {
+            ir.DataType.BFLOAT16: torch.bfloat16,
+            ir.DataType.BOOL: torch.bool,
+            ir.DataType.COMPLEX128: torch.complex128,
+            ir.DataType.COMPLEX64: torch.complex64,
+            ir.DataType.FLOAT16: torch.float16,
+            ir.DataType.FLOAT: torch.float32,
+            ir.DataType.DOUBLE: torch.float64,
+            ir.DataType.FLOAT8E4M3FN: torch.float8_e4m3fn,
+            ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
+            ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
+            ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
+            ir.DataType.INT16: torch.int16,
+            ir.DataType.INT32: torch.int32,
+            ir.DataType.INT64: torch.int64,
+            ir.DataType.INT8: torch.int8,
+            ir.DataType.UINT8: torch.uint8,
+            ir.DataType.UINT16: torch.uint16,
+            ir.DataType.UINT32: torch.uint32,
+            ir.DataType.UINT64: torch.uint64,
+        }
+    if dtype not in _ONNX_DTYPE_TO_TORCH:
+        raise TypeError(
+            f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
+            "Please use a supported dtype from the list: "
+            f"{list(_ONNX_DTYPE_TO_TORCH.keys())}"
+        )
+    return _ONNX_DTYPE_TO_TORCH[dtype]
+
+
+class TorchTensor(_core.Tensor):
+    def __init__(
+        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+    ):
+        # Pass the tensor as the raw data to ir.Tensor's constructor
         super().__init__(
-            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+            tensor, dtype=from_torch_dtype(tensor.dtype), name=name, doc_string=doc_string
        )
 
     def numpy(self) -> npt.NDArray:
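
A short sketch of the newly public dtype helpers and the refactored `TorchTensor` (expected outputs shown in comments):

    import torch
    import onnx_ir as ir
    from onnx_ir.tensor_adapters import TorchTensor, from_torch_dtype, to_torch_dtype

    assert from_torch_dtype(torch.bfloat16) == ir.DataType.BFLOAT16
    assert to_torch_dtype(ir.DataType.FLOAT) is torch.float32

    # TorchTensor wraps the torch tensor as raw data; its dtype now resolves
    # through the public from_torch_dtype helper.
    t = TorchTensor(torch.tensor([1.0, 2.0]), name="w")
    print(t.dtype)  # DataType.FLOAT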
src/onnx_ir.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-ir
-Version: 0.1.2
+Version: 0.1.3
 Summary: Efficient in-memory representation for ONNX
 Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
 License: Apache License v2.0
src/onnx_ir.egg-info/SOURCES.txt
@@ -39,6 +39,7 @@ src/onnx_ir/passes/common/_c_api_utils.py
 src/onnx_ir/passes/common/clear_metadata_and_docstring.py
 src/onnx_ir/passes/common/common_subexpression_elimination.py
 src/onnx_ir/passes/common/constant_manipulation.py
+src/onnx_ir/passes/common/initializer_deduplication.py
 src/onnx_ir/passes/common/inliner.py
 src/onnx_ir/passes/common/onnx_checker.py
 src/onnx_ir/passes/common/shape_inference.py
onnx_ir-0.1.2/src/onnx_ir/passes/common/common_subexpression_elimination.py (removed)
@@ -1,177 +0,0 @@
-# Copyright (c) ONNX Project Contributors
-# SPDX-License-Identifier: Apache-2.0
-"""Eliminate common subexpression in ONNX graphs."""
-
-from __future__ import annotations
-
-__all__ = [
-    "CommonSubexpressionEliminationPass",
-]
-
-import logging
-from collections.abc import Sequence
-
-import onnx_ir as ir
-
-logger = logging.getLogger(__name__)
-
-
-class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
-    """Eliminate common subexpression in ONNX graphs."""
-
-    def call(self, model: ir.Model) -> ir.passes.PassResult:
-        """Return the same ir.Model but with CSE applied to the graph."""
-        modified = False
-        graph = model.graph
-
-        modified = _eliminate_common_subexpression(graph, modified)
-
-        return ir.passes.PassResult(
-            model,
-            modified=modified,
-        )
-
-
-def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
-    """Eliminate common subexpression in ONNX graphs."""
-    # node to node identifier, length of outputs, inputs, and attributes
-    existing_node_info_to_the_node: dict[
-        tuple[
-            ir.OperatorIdentifier,
-            int,  # len(outputs)
-            tuple[int, ...],  # input ids
-            tuple[tuple[str, object], ...],  # attributes
-        ],
-        ir.Node,
-    ] = {}
-
-    for node in graph:
-        # Skip control flow ops like Loop and If.
-        control_flow_op: bool = False
-        # Use equality to check if the node is a common subexpression.
-        attributes = {}
-        for k, v in node.attributes.items():
-            # TODO(exporter team): CSE subgraphs.
-            # NOTE: control flow ops like Loop and If won't be CSEd
-            # because attribute: graph won't match.
-            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
-                control_flow_op = True
-                logger.debug("Skipping control flow op %s", node)
-            # The attribute value could be directly taken from the original
-            # protobuf, so we need to make a copy of it.
-            value = v.value
-            if v.type in (
-                ir.AttributeType.INTS,
-                ir.AttributeType.FLOATS,
-                ir.AttributeType.STRINGS,
-            ):
-                # For INT, FLOAT and STRING attributes, we convert them to tuples
-                # to ensure they are hashable.
-                value = tuple(value)
-            attributes[k] = value
-
-        if control_flow_op:
-            # If the node is a control flow op, we skip it.
-            logger.debug("Skipping control flow op %s", node)
-            continue
-
-        if _is_non_deterministic_op(node):
-            # If the node is a non-deterministic op, we skip it.
-            logger.debug("Skipping non-deterministic op %s", node)
-            continue
-
-        node_info = (
-            node.op_identifier(),
-            len(node.outputs),
-            tuple(id(input) for input in node.inputs),
-            tuple(sorted(attributes.items())),
-        )
-        # Check if the node is a common subexpression.
-        if node_info in existing_node_info_to_the_node:
-            # If it is, this node has an existing node with the same
-            # operator, number of outputs, inputs, and attributes.
-            # We replace the node with the existing node.
-            modified = True
-            existing_node = existing_node_info_to_the_node[node_info]
-            _remove_node_and_replace_values(
-                graph,
-                remove_node=node,
-                remove_values=node.outputs,
-                new_values=existing_node.outputs,
-            )
-            logger.debug("Reusing node %s", existing_node)
-        else:
-            # If it is not, add to the mapping.
-            existing_node_info_to_the_node[node_info] = node
-    return modified
-
-
-def _remove_node_and_replace_values(
-    graph: ir.Graph,
-    /,
-    remove_node: ir.Node,
-    remove_values: Sequence[ir.Value],
-    new_values: Sequence[ir.Value],
-) -> None:
-    """Replaces nodes and values in the graph or function.
-
-    Args:
-        graph: The graph to replace nodes and values in.
-        remove_node: The node to remove.
-        remove_values: The values to replace.
-        new_values: The values to replace with.
-    """
-    # Reconnect the users of the deleted values to use the new values
-    ir.convenience.replace_all_uses_with(remove_values, new_values)
-    # Update graph/function outputs if the node generates output
-    if any(remove_value.is_graph_output() for remove_value in remove_values):
-        replacement_mapping = dict(zip(remove_values, new_values))
-        for idx, graph_output in enumerate(graph.outputs):
-            if graph_output in replacement_mapping:
-                new_value = replacement_mapping[graph_output]
-                if new_value.is_graph_output() or new_value.is_graph_input():
-                    # If the new value is also a graph input/output, we need to
-                    # create a Identity node to preserve the remove_value and
-                    # prevent from changing new_value name.
-                    identity_node = ir.node(
-                        "Identity",
-                        inputs=[new_value],
-                        outputs=[
-                            ir.Value(
-                                name=graph_output.name,
-                                type=graph_output.type,
-                                shape=graph_output.shape,
-                            )
-                        ],
-                    )
-                    # reuse the name of the graph output
-                    graph.outputs[idx] = identity_node.outputs[0]
-                    graph.insert_before(
-                        remove_node,
-                        identity_node,
-                    )
-                else:
-                    # if new_value is not graph output, we just
-                    # update it to use old_value name.
-                    new_value.name = graph_output.name
-                    graph.outputs[idx] = new_value
-
-    graph.remove(remove_node, safe=True)
-
-
-def _is_non_deterministic_op(node: ir.Node) -> bool:
-    non_deterministic_ops = frozenset(
-        {
-            "RandomUniform",
-            "RandomNormal",
-            "RandomUniformLike",
-            "RandomNormalLike",
-            "Multinomial",
-        }
-    )
-    return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
-
-
-def _is_onnx_domain(d: str) -> bool:
-    """Check if the domain is the ONNX domain."""
-    return d == ""