onnx-ir 0.1.2__tar.gz → 0.1.4__tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.

Potentially problematic release: this version of onnx-ir might be problematic.

Files changed (52)
  1. {onnx_ir-0.1.2/src/onnx_ir.egg-info → onnx_ir-0.1.4}/PKG-INFO +1 -1
  2. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/__init__.py +1 -1
  3. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_convenience/__init__.py +5 -0
  4. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_core.py +36 -22
  5. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_enums.py +44 -0
  6. onnx_ir-0.1.4/src/onnx_ir/_type_casting.py +50 -0
  7. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/__init__.py +4 -0
  8. onnx_ir-0.1.4/src/onnx_ir/passes/common/common_subexpression_elimination.py +206 -0
  9. onnx_ir-0.1.4/src/onnx_ir/passes/common/initializer_deduplication.py +56 -0
  10. onnx_ir-0.1.4/src/onnx_ir/py.typed +1 -0
  11. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/serde.py +77 -46
  12. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/tensor_adapters.py +62 -7
  13. {onnx_ir-0.1.2 → onnx_ir-0.1.4/src/onnx_ir.egg-info}/PKG-INFO +1 -1
  14. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/SOURCES.txt +2 -0
  15. onnx_ir-0.1.2/src/onnx_ir/_type_casting.py +0 -107
  16. onnx_ir-0.1.2/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -177
  17. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/LICENSE +0 -0
  18. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/MANIFEST.in +0 -0
  19. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/README.md +0 -0
  20. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/pyproject.toml +0 -0
  21. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/setup.cfg +0 -0
  22. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_convenience/_constructors.py +0 -0
  23. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_display.py +0 -0
  24. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_graph_comparison.py +0 -0
  25. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_graph_containers.py +0 -0
  26. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_io.py +0 -0
  27. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_linked_list.py +0 -0
  28. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_metadata.py +0 -0
  29. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_name_authority.py +0 -0
  30. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_polyfill.py +0 -0
  31. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_protocols.py +0 -0
  32. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_tape.py +0 -0
  33. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  34. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_version_utils.py +0 -0
  35. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/convenience.py +0 -0
  36. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/external_data.py +0 -0
  37. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/__init__.py +0 -0
  38. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/_pass_infra.py +0 -0
  39. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  40. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  41. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
  42. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/inliner.py +0 -0
  43. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  44. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  45. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  46. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  47. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/tape.py +0 -0
  48. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/testing.py +0 -0
  49. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/traversal.py +0 -0
  50. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  51. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/requires.txt +0 -0
  52. {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -167,4 +167,4 @@ def __set_module() -> None:
167
167
 
168
168
 
169
169
  __set_module()
170
- __version__ = "0.1.2"
170
+ __version__ = "0.1.4"
@@ -323,6 +323,9 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
323
323
  and the first value with that name is returned. Values with empty names
324
324
  are excluded from the mapping.
325
325
 
326
+ .. versionchanged:: 0.1.2
327
+ Values from subgraphs are now included in the mapping.
328
+
326
329
  Args:
327
330
  graph: The graph to extract the mapping from.
328
331
 
@@ -410,6 +413,8 @@ def get_const_tensor(
410
413
  it will propagate the shape and type of the constant tensor to the value
411
414
  if `propagate_shape_type` is set to True.
412
415
 
416
+ .. versionadded:: 0.1.2
417
+
413
418
  Args:
414
419
  value: The value to get the constant tensor from.
415
420
  propagate_shape_type: If True, the shape and type of the value will be
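
Example (hedged sketch) for the ``get_const_tensor`` helper documented above (new in 0.1.2). It assumes the function is re-exported via ``onnx_ir.convenience`` and that ``ir.Value`` accepts a ``const_value`` keyword; neither assumption is shown in this hunk:

    import numpy as np
    import onnx_ir as ir

    # Hypothetical value that carries a constant tensor.
    weight = ir.Value(name="weight", const_value=ir.tensor(np.array([1.0, 2.0], dtype=np.float32)))
    const_tensor = ir.convenience.get_const_tensor(weight, propagate_shape_type=True)
    if const_tensor is not None:
        print(const_tensor.numpy())  # expected: [1. 2.]
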
@@ -417,6 +417,9 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
417
417
  else:
418
418
  self._shape = shape
419
419
  self._shape.freeze()
420
+ if isinstance(value, np.generic):
421
+ # Turn numpy scalar into a numpy array
422
+ value = np.array(value) # type: ignore[assignment]
420
423
  if dtype is None:
421
424
  if isinstance(value, np.ndarray):
422
425
  self._dtype = _enums.DataType.from_numpy(value.dtype)
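
With the ``np.generic`` handling added above, ``ir.Tensor`` should also accept a bare NumPy scalar, not just an ``np.ndarray``. A hedged sketch; the output comments are expectations, not captured output:

    import numpy as np
    import onnx_ir as ir

    scalar = ir.Tensor(np.float32(3.14))  # a NumPy scalar, converted to a 0-d array internally
    print(scalar.dtype)    # expected: FLOAT
    print(scalar.numpy())  # expected: 3.14 as a 0-d array
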
@@ -654,15 +657,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
654
657
  self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
655
658
  return
656
659
  # Map the whole file into the memory
657
- # TODO(justinchuby): Verify if this would exhaust the memory address space
658
660
  with open(self.path, "rb") as f:
659
661
  self.raw = mmap.mmap(
660
662
  f.fileno(),
661
663
  0,
662
664
  access=mmap.ACCESS_READ,
663
665
  )
664
- # Handle the byte order correctly by always using little endian
665
- dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
666
+
666
667
  if self.dtype in {
667
668
  _enums.DataType.INT4,
668
669
  _enums.DataType.UINT4,
@@ -672,16 +673,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
672
673
  dt = np.dtype(np.uint8).newbyteorder("<")
673
674
  count = self.size // 2 + self.size % 2
674
675
  else:
676
+ # Handle the byte order correctly by always using little endian
677
+ dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
675
678
  count = self.size
679
+
676
680
  self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
677
681
  shape = self.shape.numpy()
678
- if self.dtype == _enums.DataType.INT4:
679
- # Unpack the int4 arrays
680
- self._array = _type_casting.unpack_int4(self._array, shape)
681
- elif self.dtype == _enums.DataType.UINT4:
682
- self._array = _type_casting.unpack_uint4(self._array, shape)
683
- elif self.dtype == _enums.DataType.FLOAT4E2M1:
684
- self._array = _type_casting.unpack_float4e2m1(self._array, shape)
682
+
683
+ if self.dtype.bitwidth == 4:
684
+ # Unpack the 4bit arrays
685
+ self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
686
+ self.dtype.numpy()
687
+ )
685
688
  else:
686
689
  self._array = self._array.reshape(shape)
687
690
 
@@ -964,7 +967,10 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
964
967
 
965
968
 
966
969
  class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
967
- """A tensor that stores 4bit datatypes in packed format."""
970
+ """A tensor that stores 4bit datatypes in packed format.
971
+
972
+ .. versionadded:: 0.1.2
973
+ """
968
974
 
969
975
  __slots__ = (
970
976
  "_dtype",
@@ -1065,15 +1071,7 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
1065
1071
  """
1066
1072
  array = self.numpy_packed()
1067
1073
  # ONNX IR returns the unpacked arrays
1068
- if self.dtype == _enums.DataType.INT4:
1069
- return _type_casting.unpack_int4(array, self.shape.numpy())
1070
- if self.dtype == _enums.DataType.UINT4:
1071
- return _type_casting.unpack_uint4(array, self.shape.numpy())
1072
- if self.dtype == _enums.DataType.FLOAT4E2M1:
1073
- return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
1074
- raise TypeError(
1075
- f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
1076
- )
1074
+ return _type_casting.unpack_4bitx2(array, self.shape.numpy()).view(self.dtype.numpy())
1077
1075
 
1078
1076
  def numpy_packed(self) -> npt.NDArray[np.uint8]:
1079
1077
  """Return the tensor as a packed array."""
@@ -2335,6 +2333,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2335
2333
  seen as a Sequence of nodes and should be used as such. For example, to obtain
2336
2334
  all nodes as a list, call ``list(graph)``.
2337
2335
 
2336
+ .. versionchanged:: 0.1.1
2337
+ Values with non-none producers will be rejected as graph inputs or initializers.
2338
+
2339
+ .. versionadded:: 0.1.1
2340
+ Added ``add`` method to initializers and attributes.
2341
+
2338
2342
  Attributes:
2339
2343
  name: The name of the graph.
2340
2344
  inputs: The input values of the graph.
@@ -2545,12 +2549,17 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2545
2549
  Consider using
2546
2550
  :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
2547
2551
  traversals on nodes.
2552
+
2553
+ .. versionadded:: 0.1.2
2548
2554
  """
2549
2555
  # NOTE: This is a method specific to Graph, not required by the protocol unless proven
2550
2556
  return onnx_ir.traversal.RecursiveGraphIterator(self)
2551
2557
 
2552
2558
  def subgraphs(self) -> Iterator[Graph]:
2553
- """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
2559
+ """Get all subgraphs in the graph in O(#nodes + #attributes) time.
2560
+
2561
+ .. versionadded:: 0.1.2
2562
+ """
2554
2563
  seen_graphs: set[Graph] = set()
2555
2564
  for node in onnx_ir.traversal.RecursiveGraphIterator(self):
2556
2565
  graph = node.graph
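
A minimal sketch of the new ``Graph.subgraphs()`` method added above; the model path is a hypothetical placeholder:

    import onnx_ir as ir

    model = ir.load("model.onnx")  # hypothetical path
    for subgraph in model.graph.subgraphs():
        print(subgraph.name, "nodes:", len(subgraph))
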
@@ -3216,12 +3225,17 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
3216
3225
  Consider using
3217
3226
  :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
3218
3227
  traversals on nodes.
3228
+
3229
+ .. versionadded:: 0.1.2
3219
3230
  """
3220
3231
  # NOTE: This is a method specific to Graph, not required by the protocol unless proven
3221
3232
  return onnx_ir.traversal.RecursiveGraphIterator(self)
3222
3233
 
3223
3234
  def subgraphs(self) -> Iterator[Graph]:
3224
- """Get all subgraphs in the function in O(#nodes + #attributes) time."""
3235
+ """Get all subgraphs in the function in O(#nodes + #attributes) time.
3236
+
3237
+ .. versionadded:: 0.1.2
3238
+ """
3225
3239
  seen_graphs: set[Graph] = set()
3226
3240
  for node in onnx_ir.traversal.RecursiveGraphIterator(self):
3227
3241
  graph = node.graph
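
Both ``subgraphs()`` docstrings above point to ``onnx_ir.traversal.RecursiveGraphIterator`` for more advanced traversals. A small sketch of recursive node iteration (model path is hypothetical):

    import onnx_ir as ir
    from onnx_ir import traversal

    model = ir.load("model.onnx")  # hypothetical path
    for node in traversal.RecursiveGraphIterator(model.graph):
        print(node.op_type, node.name)
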
@@ -120,6 +120,8 @@ class DataType(enum.IntEnum):
120
120
  def bitwidth(self) -> int:
121
121
  """Returns the bit width of the data type.
122
122
 
123
+ .. versionadded:: 0.1.2
124
+
123
125
  Raises:
124
126
  TypeError: If the data type is not supported.
125
127
  """
@@ -167,6 +169,48 @@ class DataType(enum.IntEnum):
167
169
  DataType.FLOAT4E2M1,
168
170
  }
169
171
 
172
+ def is_integer(self) -> bool:
173
+ """Returns True if the data type is an integer.
174
+
175
+ .. versionadded:: 0.1.4
176
+ """
177
+ return self in {
178
+ DataType.UINT8,
179
+ DataType.INT8,
180
+ DataType.UINT16,
181
+ DataType.INT16,
182
+ DataType.INT32,
183
+ DataType.INT64,
184
+ DataType.UINT32,
185
+ DataType.UINT64,
186
+ DataType.UINT4,
187
+ DataType.INT4,
188
+ }
189
+
190
+ def is_signed(self) -> bool:
191
+ """Returns True if the data type is a signed type.
192
+
193
+ .. versionadded:: 0.1.4
194
+ """
195
+ return self in {
196
+ DataType.FLOAT,
197
+ DataType.INT8,
198
+ DataType.INT16,
199
+ DataType.INT32,
200
+ DataType.INT64,
201
+ DataType.FLOAT16,
202
+ DataType.DOUBLE,
203
+ DataType.COMPLEX64,
204
+ DataType.COMPLEX128,
205
+ DataType.BFLOAT16,
206
+ DataType.FLOAT8E4M3FN,
207
+ DataType.FLOAT8E4M3FNUZ,
208
+ DataType.FLOAT8E5M2,
209
+ DataType.FLOAT8E5M2FNUZ,
210
+ DataType.INT4,
211
+ DataType.FLOAT4E2M1,
212
+ }
213
+
170
214
  def __repr__(self) -> str:
171
215
  return self.name
172
216
 
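
Together with the ``bitwidth`` property (0.1.2, noted earlier), the ``is_integer()``/``is_signed()`` helpers added above allow simple dtype introspection. Expected results are shown as comments:

    import onnx_ir as ir

    for dt in (ir.DataType.INT4, ir.DataType.UINT8, ir.DataType.BFLOAT16):
        print(dt, dt.bitwidth, dt.is_integer(), dt.is_signed())
    # INT4     4   True   True
    # UINT8    8   True   False
    # BFLOAT16 16  False  True
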
@@ -0,0 +1,50 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Numpy utilities for non-native type operation."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import typing
8
+ from collections.abc import Sequence
9
+
10
+ import numpy as np
11
+
12
+ if typing.TYPE_CHECKING:
13
+ import numpy.typing as npt
14
+
15
+
16
+ def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
17
+ """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
18
+ # Create a 1D copy
19
+ array_flat = array.ravel().view(np.uint8).copy()
20
+ size = array.size
21
+ odd_sized = size % 2 == 1
22
+ if odd_sized:
23
+ array_flat.resize([size + 1], refcheck=False)
24
+ array_flat &= 0x0F
25
+ array_flat[1::2] <<= 4
26
+ return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
27
+
28
+
29
+ def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
30
+ """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
31
+
32
+ Args:
33
+ data: A numpy array.
34
+ dims: The dimensions are used to reshape the unpacked buffer.
35
+
36
+ Returns:
37
+ A numpy array of int8/uint8 reshaped to dims.
38
+ """
39
+ assert data.dtype == np.uint8, "Input data must be of type uint8"
40
+ result = np.empty([data.size * 2], dtype=data.dtype)
41
+ array_low = data & np.uint8(0x0F)
42
+ array_high = data & np.uint8(0xF0)
43
+ array_high >>= np.uint8(4)
44
+ result[0::2] = array_low
45
+ result[1::2] = array_high
46
+ if result.size == np.prod(dims) + 1:
47
+ # handle single-element padding due to odd number of elements
48
+ result = result[:-1]
49
+ result.resize(dims, refcheck=False)
50
+ return result
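
A plain-NumPy illustration of the nibble layout implemented by ``pack_4bitx2``/``unpack_4bitx2`` above (the helpers are private, so this sketch re-creates the scheme rather than importing them): two 4-bit values share one byte, low nibble first, with one padding nibble when the element count is odd.

    import numpy as np

    values = np.array([1, 2, 3, 4, 5], dtype=np.uint8)  # odd length gets padded
    if values.size % 2:
        flat = np.append(values, np.uint8(0))
    else:
        flat = values.copy()
    flat = flat & 0x0F
    packed = (flat[0::2] | (flat[1::2] << 4)).astype(np.uint8)  # [0x21, 0x43, 0x05]

    unpacked = np.empty(packed.size * 2, dtype=np.uint8)
    unpacked[0::2] = packed & 0x0F  # low nibbles
    unpacked[1::2] = packed >> 4    # high nibbles
    assert np.array_equal(unpacked[: values.size], values)
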
@@ -6,6 +6,7 @@ __all__ = [
6
6
  "CheckerPass",
7
7
  "ClearMetadataAndDocStringPass",
8
8
  "CommonSubexpressionEliminationPass",
9
+ "DeduplicateInitializersPass",
9
10
  "InlinePass",
10
11
  "LiftConstantsToInitializersPass",
11
12
  "LiftSubgraphInitializersToMainGraphPass",
@@ -29,6 +30,9 @@ from onnx_ir.passes.common.constant_manipulation import (
29
30
  LiftSubgraphInitializersToMainGraphPass,
30
31
  RemoveInitializersFromInputsPass,
31
32
  )
33
+ from onnx_ir.passes.common.initializer_deduplication import (
34
+ DeduplicateInitializersPass,
35
+ )
32
36
  from onnx_ir.passes.common.inliner import InlinePass
33
37
  from onnx_ir.passes.common.onnx_checker import CheckerPass
34
38
  from onnx_ir.passes.common.shape_inference import ShapeInferencePass
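
With the ``__all__`` entry and import added above, the new pass should be importable directly from the common-passes namespace:

    from onnx_ir.passes.common import DeduplicateInitializersPass
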
@@ -0,0 +1,206 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Eliminate common subexpression in ONNX graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "CommonSubexpressionEliminationPass",
9
+ ]
10
+
11
+ import logging
12
+ from collections.abc import Sequence
13
+
14
+ import onnx_ir as ir
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
20
+ """Eliminate common subexpression in ONNX graphs.
21
+
22
+ .. versionadded:: 0.1.1
23
+
24
+ .. versionchanged:: 0.1.3
25
+ Constant nodes with values smaller than ``size_limit`` will be CSE'd.
26
+
27
+ Attributes:
28
+ size_limit: The maximum size of the tensor to be csed. If the tensor contains
29
+ number of elements larger than size_limit, it will not be cse'd. Default is 10.
30
+
31
+ """
32
+
33
+ def __init__(self, size_limit: int = 10):
34
+ """Initialize the CommonSubexpressionEliminationPass."""
35
+ super().__init__()
36
+ self.size_limit = size_limit
37
+
38
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
39
+ """Return the same ir.Model but with CSE applied to the graph."""
40
+ graph = model.graph
41
+ modified = self._eliminate_common_subexpression(graph)
42
+
43
+ return ir.passes.PassResult(
44
+ model,
45
+ modified=modified,
46
+ )
47
+
48
+ def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
49
+ """Eliminate common subexpression in ONNX graphs."""
50
+ modified: bool = False
51
+ # node to node identifier, length of outputs, inputs, and attributes
52
+ existing_node_info_to_the_node: dict[
53
+ tuple[
54
+ ir.OperatorIdentifier,
55
+ int, # len(outputs)
56
+ tuple[int, ...], # input ids
57
+ tuple[tuple[str, object], ...], # attributes
58
+ ],
59
+ ir.Node,
60
+ ] = {}
61
+
62
+ for node in graph:
63
+ # Skip control flow ops like Loop and If.
64
+ control_flow_op: bool = False
65
+ # Skip large tensors to avoid cse weights and bias.
66
+ large_tensor: bool = False
67
+ # Use equality to check if the node is a common subexpression.
68
+ attributes = {}
69
+ for k, v in node.attributes.items():
70
+ # TODO(exporter team): CSE subgraphs.
71
+ # NOTE: control flow ops like Loop and If won't be CSEd
72
+ # because attribute: graph won't match.
73
+ if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
74
+ control_flow_op = True
75
+ break
76
+ # The attribute value could be directly taken from the original
77
+ # protobuf, so we need to make a copy of it.
78
+ value = v.value
79
+ if v.type in (
80
+ ir.AttributeType.INTS,
81
+ ir.AttributeType.FLOATS,
82
+ ir.AttributeType.STRINGS,
83
+ ):
84
+ # For INT, FLOAT and STRING attributes, we convert them to tuples
85
+ # to ensure they are hashable.
86
+ value = tuple(value)
87
+ elif v.type is ir.AttributeType.TENSOR:
88
+ if value.size > self.size_limit:
89
+ # If the tensor is larger than the size limit, we skip it.
90
+ large_tensor = True
91
+ break
92
+ np_value = value.numpy()
93
+
94
+ value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
95
+ attributes[k] = value
96
+
97
+ if control_flow_op:
98
+ # If the node is a control flow op, we skip it.
99
+ logger.debug("Skipping control flow op %s", node)
100
+ continue
101
+
102
+ if large_tensor:
103
+ # If the node has a large tensor, we skip it.
104
+ logger.debug("Skipping large tensor in node %s", node)
105
+ continue
106
+
107
+ if _is_non_deterministic_op(node):
108
+ # If the node is a non-deterministic op, we skip it.
109
+ logger.debug("Skipping non-deterministic op %s", node)
110
+ continue
111
+
112
+ node_info = (
113
+ node.op_identifier(),
114
+ len(node.outputs),
115
+ tuple(id(input) for input in node.inputs),
116
+ tuple(sorted(attributes.items())),
117
+ )
118
+ # Check if the node is a common subexpression.
119
+ if node_info in existing_node_info_to_the_node:
120
+ # If it is, this node has an existing node with the same
121
+ # operator, number of outputs, inputs, and attributes.
122
+ # We replace the node with the existing node.
123
+ modified = True
124
+ existing_node = existing_node_info_to_the_node[node_info]
125
+ _remove_node_and_replace_values(
126
+ graph,
127
+ remove_node=node,
128
+ remove_values=node.outputs,
129
+ new_values=existing_node.outputs,
130
+ )
131
+ logger.debug("Reusing node %s", existing_node)
132
+ else:
133
+ # If it is not, add to the mapping.
134
+ existing_node_info_to_the_node[node_info] = node
135
+ return modified
136
+
137
+
138
+ def _remove_node_and_replace_values(
139
+ graph: ir.Graph,
140
+ /,
141
+ remove_node: ir.Node,
142
+ remove_values: Sequence[ir.Value],
143
+ new_values: Sequence[ir.Value],
144
+ ) -> None:
145
+ """Replaces nodes and values in the graph or function.
146
+
147
+ Args:
148
+ graph: The graph to replace nodes and values in.
149
+ remove_node: The node to remove.
150
+ remove_values: The values to replace.
151
+ new_values: The values to replace with.
152
+ """
153
+ # Reconnect the users of the deleted values to use the new values
154
+ ir.convenience.replace_all_uses_with(remove_values, new_values)
155
+ # Update graph/function outputs if the node generates output
156
+ if any(remove_value.is_graph_output() for remove_value in remove_values):
157
+ replacement_mapping = dict(zip(remove_values, new_values))
158
+ for idx, graph_output in enumerate(graph.outputs):
159
+ if graph_output in replacement_mapping:
160
+ new_value = replacement_mapping[graph_output]
161
+ if new_value.is_graph_output() or new_value.is_graph_input():
162
+ # If the new value is also a graph input/output, we need to
163
+ # create a Identity node to preserve the remove_value and
164
+ # prevent from changing new_value name.
165
+ identity_node = ir.node(
166
+ "Identity",
167
+ inputs=[new_value],
168
+ outputs=[
169
+ ir.Value(
170
+ name=graph_output.name,
171
+ type=graph_output.type,
172
+ shape=graph_output.shape,
173
+ )
174
+ ],
175
+ )
176
+ # reuse the name of the graph output
177
+ graph.outputs[idx] = identity_node.outputs[0]
178
+ graph.insert_before(
179
+ remove_node,
180
+ identity_node,
181
+ )
182
+ else:
183
+ # if new_value is not graph output, we just
184
+ # update it to use old_value name.
185
+ new_value.name = graph_output.name
186
+ graph.outputs[idx] = new_value
187
+
188
+ graph.remove(remove_node, safe=True)
189
+
190
+
191
+ def _is_non_deterministic_op(node: ir.Node) -> bool:
192
+ non_deterministic_ops = frozenset(
193
+ {
194
+ "RandomUniform",
195
+ "RandomNormal",
196
+ "RandomUniformLike",
197
+ "RandomNormalLike",
198
+ "Multinomial",
199
+ }
200
+ )
201
+ return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
202
+
203
+
204
+ def _is_onnx_domain(d: str) -> bool:
205
+ """Check if the domain is the ONNX domain."""
206
+ return d == ""
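
A hedged usage sketch for the pass defined above, assuming the standard callable pass protocol (a ``PassResult`` with ``model`` and ``modified``) and a hypothetical model path:

    import onnx_ir as ir
    from onnx_ir.passes.common import CommonSubexpressionEliminationPass

    model = ir.load("model.onnx")  # hypothetical path
    result = CommonSubexpressionEliminationPass(size_limit=10)(model)
    print("modified:", result.modified)
    ir.save(result.model, "model_cse.onnx")
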
@@ -0,0 +1,56 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Pass for removing duplicated initializer tensors from a graph."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "DeduplicateInitializersPass",
9
+ ]
10
+
11
+
12
+ import onnx_ir as ir
13
+
14
+
15
+ class DeduplicateInitializersPass(ir.passes.InPlacePass):
16
+ """Remove duplicated initializer tensors from the graph.
17
+
18
+ This pass detects initializers with identical shape, dtype, and content,
19
+ and replaces all duplicate references with a canonical one.
20
+
21
+ To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
22
+ to lift the initializers to the main graph first before running pass.
23
+
24
+ .. versionadded:: 0.1.3
25
+ """
26
+
27
+ def __init__(self, size_limit: int = 1024):
28
+ super().__init__()
29
+ self.size_limit = size_limit
30
+
31
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
32
+ graph = model.graph
33
+ initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
34
+ modified = False
35
+
36
+ for initializer in tuple(graph.initializers.values()):
37
+ # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
38
+ # out from the main graph before running this pass.
39
+ const_val = initializer.const_value
40
+ if const_val is None:
41
+ # Skip if initializer has no constant value
42
+ continue
43
+
44
+ if const_val.size > self.size_limit:
45
+ continue
46
+
47
+ key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
48
+ if key in initializers:
49
+ modified = True
50
+ ir.convenience.replace_all_uses_with(initializer, initializers[key]) # type: ignore[index]
51
+ assert initializer.name is not None
52
+ graph.initializers.pop(initializer.name)
53
+ else:
54
+ initializers[key] = initializer # type: ignore[index]
55
+
56
+ return ir.passes.PassResult(model=model, modified=modified)
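
A hedged sketch of running the new pass, following the docstring's advice to lift subgraph initializers first; default construction of the lifting pass is an assumption:

    import onnx_ir as ir
    from onnx_ir.passes.common import (
        DeduplicateInitializersPass,
        LiftSubgraphInitializersToMainGraphPass,
    )

    model = ir.load("model.onnx")  # hypothetical path
    LiftSubgraphInitializersToMainGraphPass()(model)  # lift subgraph initializers first
    result = DeduplicateInitializersPass(size_limit=1024)(model)
    print("duplicates removed:", result.modified)
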
@@ -0,0 +1 @@
1
+
@@ -74,7 +74,6 @@ from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
74
74
 
75
75
  if typing.TYPE_CHECKING:
76
76
  import google.protobuf.internal.containers as proto_containers
77
- import numpy.typing as npt
78
77
 
79
78
  logger = logging.getLogger(__name__)
80
79
 
@@ -117,13 +116,6 @@ def _little_endian_dtype(dtype) -> np.dtype:
117
116
  return np.dtype(dtype).newbyteorder("<")
118
117
 
119
118
 
120
- def _unflatten_complex(
121
- array: npt.NDArray[np.float32 | np.float64],
122
- ) -> npt.NDArray[np.complex64 | np.complex128]:
123
- """Convert the real representation of a complex dtype to the complex dtype."""
124
- return array[::2] + 1j * array[1::2]
125
-
126
-
127
119
  @typing.overload
128
120
  def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
129
121
  @typing.overload
@@ -200,6 +192,9 @@ def from_onnx_text(
200
192
 
201
193
  Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
202
194
 
195
+ .. versionchanged:: 0.1.2
196
+ Added the ``initializers`` argument.
197
+
203
198
  Args:
204
199
  model_text: The ONNX textual representation of the model.
205
200
  initializers: Tensors to be added as initializers. If provided, these tensors
@@ -237,6 +232,8 @@ def to_onnx_text(
237
232
  ) -> str:
238
233
  """Convert the IR model to the ONNX textual representation.
239
234
 
235
+ .. versionadded:: 0.1.2
236
+
240
237
  Args:
241
238
  model: The IR model to convert.
242
239
  exclude_initializers: If True, the initializers will not be included in the output.
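
A hedged sketch combining the two additions above (the ``initializers`` argument on ``from_onnx_text`` and the new ``to_onnx_text``). It assumes an initializer is matched by name to a value declared in the text and that ``ir.tensor`` accepts a ``name`` argument:

    import numpy as np
    import onnx_ir as ir

    text = """
    <ir_version: 10, opset_import: ["" : 21]>
    agraph (float[2] x, float[2] w) => (float[2] y) {
        y = Add(x, w)
    }
    """
    w = ir.tensor(np.array([1.0, 2.0], dtype=np.float32), name="w")
    model = ir.from_onnx_text(text, initializers=[w])
    print(ir.to_onnx_text(model, exclude_initializers=True))
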
@@ -386,54 +383,88 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
386
383
  "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
387
384
  )
388
385
 
386
+ shape = self._proto.dims
387
+
389
388
  if self._proto.HasField("raw_data"):
390
- array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<"))
391
- # Cannot return now, because we may need to unpack 4bit tensors
392
- elif dtype == _enums.DataType.STRING:
393
- return np.array(self._proto.string_data).reshape(self._proto.dims)
394
- elif self._proto.int32_data:
395
- array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
396
- if dtype in {_enums.DataType.FLOAT16, _enums.DataType.BFLOAT16}:
397
- # Reinterpret the int32 as float16 or bfloat16
398
- array = array.astype(np.uint16).view(dtype.numpy())
399
- elif dtype in {
389
+ if dtype.bitwidth == 4:
390
+ return _type_casting.unpack_4bitx2(
391
+ np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
392
+ ).view(dtype.numpy())
393
+ return np.frombuffer(
394
+ self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
395
+ ).reshape(shape)
396
+ if dtype == _enums.DataType.STRING:
397
+ return np.array(self._proto.string_data).reshape(shape)
398
+ if self._proto.int32_data:
399
+ assert dtype in {
400
+ _enums.DataType.BFLOAT16,
401
+ _enums.DataType.BOOL,
402
+ _enums.DataType.FLOAT16,
403
+ _enums.DataType.FLOAT4E2M1,
400
404
  _enums.DataType.FLOAT8E4M3FN,
401
405
  _enums.DataType.FLOAT8E4M3FNUZ,
402
406
  _enums.DataType.FLOAT8E5M2,
403
407
  _enums.DataType.FLOAT8E5M2FNUZ,
404
- }:
405
- array = array.astype(np.uint8).view(dtype.numpy())
406
- elif self._proto.int64_data:
407
- array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64))
408
- elif self._proto.uint64_data:
408
+ _enums.DataType.INT16,
409
+ _enums.DataType.INT32,
410
+ _enums.DataType.INT4,
411
+ _enums.DataType.INT8,
412
+ _enums.DataType.UINT16,
413
+ _enums.DataType.UINT4,
414
+ _enums.DataType.UINT8,
415
+ }, f"Unsupported dtype {dtype} for int32_data"
416
+ array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
417
+ if dtype.bitwidth == 32:
418
+ return array.reshape(shape)
419
+ if dtype.bitwidth == 16:
420
+ # Reinterpret the int32 as float16 or bfloat16
421
+ return array.astype(np.uint16).view(dtype.numpy()).reshape(shape)
422
+ if dtype.bitwidth == 8:
423
+ return array.astype(np.uint8).view(dtype.numpy()).reshape(shape)
424
+ if dtype.bitwidth == 4:
425
+ return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
426
+ dtype.numpy()
427
+ )
428
+ raise ValueError(
429
+ f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
430
+ )
431
+ if self._proto.int64_data:
432
+ assert dtype in {
433
+ _enums.DataType.INT64,
434
+ }, f"Unsupported dtype {dtype} for int64_data"
435
+ return np.array(
436
+ self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
437
+ ).reshape(shape)
438
+ if self._proto.uint64_data:
439
+ assert dtype in {
440
+ _enums.DataType.UINT64,
441
+ _enums.DataType.UINT32,
442
+ }, f"Unsupported dtype {dtype} for uint64_data"
409
443
  array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
410
- elif self._proto.float_data:
444
+ if dtype == _enums.DataType.UINT32:
445
+ return array.astype(np.uint32).reshape(shape)
446
+ return array.reshape(shape)
447
+ if self._proto.float_data:
448
+ assert dtype in {
449
+ _enums.DataType.FLOAT,
450
+ _enums.DataType.COMPLEX64,
451
+ }, f"Unsupported dtype {dtype} for float_data"
411
452
  array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
412
453
  if dtype == _enums.DataType.COMPLEX64:
413
- array = _unflatten_complex(array)
414
- elif self._proto.double_data:
454
+ return array.view(np.complex64).reshape(shape)
455
+ return array.reshape(shape)
456
+ if self._proto.double_data:
457
+ assert dtype in {
458
+ _enums.DataType.DOUBLE,
459
+ _enums.DataType.COMPLEX128,
460
+ }, f"Unsupported dtype {dtype} for double_data"
415
461
  array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
416
462
  if dtype == _enums.DataType.COMPLEX128:
417
- array = _unflatten_complex(array)
418
- else:
419
- # Empty tensor
420
- if not self._proto.dims:
421
- # When dims not precent and there is no data, we return an empty array
422
- return np.array([], dtype=dtype.numpy())
423
- else:
424
- # Otherwise we return a size 0 array with the correct shape
425
- return np.zeros(self._proto.dims, dtype=dtype.numpy())
426
-
427
- if dtype == _enums.DataType.INT4:
428
- return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
429
- elif dtype == _enums.DataType.UINT4:
430
- return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
431
- elif dtype == _enums.DataType.FLOAT4E2M1:
432
- return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
433
- else:
434
- # Otherwise convert to the correct dtype and reshape
435
- # Note we cannot use view() here because the storage dtype may not be the same size as the target
436
- return array.astype(dtype.numpy()).reshape(self._proto.dims)
463
+ return array.view(np.complex128).reshape(shape)
464
+ return array.reshape(shape)
465
+
466
+ # Empty tensor. We return a size 0 array with the correct shape
467
+ return np.zeros(shape, dtype=dtype.numpy())
437
468
 
438
469
  def tobytes(self) -> bytes:
439
470
  """Return the tensor as a byte string conformed to the ONNX specification, in little endian.
@@ -29,6 +29,8 @@ Example::
29
29
  from __future__ import annotations
30
30
 
31
31
  __all__ = [
32
+ "from_torch_dtype",
33
+ "to_torch_dtype",
32
34
  "TorchTensor",
33
35
  ]
34
36
 
@@ -44,14 +46,17 @@ if TYPE_CHECKING:
44
46
  import torch
45
47
 
46
48
 
47
- class TorchTensor(_core.Tensor):
48
- def __init__(
49
- self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
50
- ):
51
- # Pass the tensor as the raw data to ir.Tensor's constructor
49
+ _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
50
+ _ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None
51
+
52
+
53
+ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
54
+ """Convert a PyTorch dtype to an ONNX IR DataType."""
55
+ global _TORCH_DTYPE_TO_ONNX
56
+ if _TORCH_DTYPE_TO_ONNX is None:
52
57
  import torch
53
58
 
54
- _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
59
+ _TORCH_DTYPE_TO_ONNX = {
55
60
  torch.bfloat16: ir.DataType.BFLOAT16,
56
61
  torch.bool: ir.DataType.BOOL,
57
62
  torch.complex128: ir.DataType.COMPLEX128,
@@ -72,8 +77,58 @@ class TorchTensor(_core.Tensor):
72
77
  torch.uint32: ir.DataType.UINT32,
73
78
  torch.uint64: ir.DataType.UINT64,
74
79
  }
80
+ if dtype not in _TORCH_DTYPE_TO_ONNX:
81
+ raise TypeError(
82
+ f"Unsupported PyTorch dtype '{dtype}'. "
83
+ "Please use a supported dtype from the list: "
84
+ f"{list(_TORCH_DTYPE_TO_ONNX.keys())}"
85
+ )
86
+ return _TORCH_DTYPE_TO_ONNX[dtype]
87
+
88
+
89
+ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
90
+ """Convert an ONNX IR DataType to a PyTorch dtype."""
91
+ global _ONNX_DTYPE_TO_TORCH
92
+ if _ONNX_DTYPE_TO_TORCH is None:
93
+ import torch
94
+
95
+ _ONNX_DTYPE_TO_TORCH = {
96
+ ir.DataType.BFLOAT16: torch.bfloat16,
97
+ ir.DataType.BOOL: torch.bool,
98
+ ir.DataType.COMPLEX128: torch.complex128,
99
+ ir.DataType.COMPLEX64: torch.complex64,
100
+ ir.DataType.FLOAT16: torch.float16,
101
+ ir.DataType.FLOAT: torch.float32,
102
+ ir.DataType.DOUBLE: torch.float64,
103
+ ir.DataType.FLOAT8E4M3FN: torch.float8_e4m3fn,
104
+ ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
105
+ ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
106
+ ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
107
+ ir.DataType.INT16: torch.int16,
108
+ ir.DataType.INT32: torch.int32,
109
+ ir.DataType.INT64: torch.int64,
110
+ ir.DataType.INT8: torch.int8,
111
+ ir.DataType.UINT8: torch.uint8,
112
+ ir.DataType.UINT16: torch.uint16,
113
+ ir.DataType.UINT32: torch.uint32,
114
+ ir.DataType.UINT64: torch.uint64,
115
+ }
116
+ if dtype not in _ONNX_DTYPE_TO_TORCH:
117
+ raise TypeError(
118
+ f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
119
+ "Please use a supported dtype from the list: "
120
+ f"{list(_ONNX_DTYPE_TO_TORCH.keys())}"
121
+ )
122
+ return _ONNX_DTYPE_TO_TORCH[dtype]
123
+
124
+
125
+ class TorchTensor(_core.Tensor):
126
+ def __init__(
127
+ self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
128
+ ):
129
+ # Pass the tensor as the raw data to ir.Tensor's constructor
75
130
  super().__init__(
76
- tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
131
+ tensor, dtype=from_torch_dtype(tensor.dtype), name=name, doc_string=doc_string
77
132
  )
78
133
 
79
134
  def numpy(self) -> npt.NDArray:
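
The new ``from_torch_dtype``/``to_torch_dtype`` helpers (now also in ``__all__``) and ``TorchTensor`` might be used like this; expected outputs are shown as comments:

    import torch
    import onnx_ir as ir
    from onnx_ir.tensor_adapters import TorchTensor, from_torch_dtype, to_torch_dtype

    print(from_torch_dtype(torch.bfloat16))          # expected: BFLOAT16
    print(to_torch_dtype(ir.DataType.FLOAT8E4M3FN))  # expected: torch.float8_e4m3fn

    # TorchTensor keeps the torch tensor as the backing data; its dtype is mapped via from_torch_dtype.
    t = TorchTensor(torch.tensor([1.0, 2.0], dtype=torch.bfloat16), name="w")
    print(t.dtype)  # expected: BFLOAT16
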
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -19,6 +19,7 @@ src/onnx_ir/_type_casting.py
19
19
  src/onnx_ir/_version_utils.py
20
20
  src/onnx_ir/convenience.py
21
21
  src/onnx_ir/external_data.py
22
+ src/onnx_ir/py.typed
22
23
  src/onnx_ir/serde.py
23
24
  src/onnx_ir/tape.py
24
25
  src/onnx_ir/tensor_adapters.py
@@ -39,6 +40,7 @@ src/onnx_ir/passes/common/_c_api_utils.py
39
40
  src/onnx_ir/passes/common/clear_metadata_and_docstring.py
40
41
  src/onnx_ir/passes/common/common_subexpression_elimination.py
41
42
  src/onnx_ir/passes/common/constant_manipulation.py
43
+ src/onnx_ir/passes/common/initializer_deduplication.py
42
44
  src/onnx_ir/passes/common/inliner.py
43
45
  src/onnx_ir/passes/common/onnx_checker.py
44
46
  src/onnx_ir/passes/common/shape_inference.py
@@ -1,107 +0,0 @@
1
- # Copyright (c) ONNX Project Contributors
2
- # SPDX-License-Identifier: Apache-2.0
3
- """Numpy utilities for non-native type operation."""
4
- # TODO(justinchuby): Upstream the logic to onnx
5
-
6
- from __future__ import annotations
7
-
8
- import typing
9
- from collections.abc import Sequence
10
-
11
- import ml_dtypes
12
- import numpy as np
13
-
14
- if typing.TYPE_CHECKING:
15
- import numpy.typing as npt
16
-
17
-
18
- def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
19
- """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
20
- # Create a 1D copy
21
- array_flat = array.ravel().view(np.uint8).copy()
22
- size = array.size
23
- odd_sized = size % 2 == 1
24
- if odd_sized:
25
- array_flat.resize([size + 1], refcheck=False)
26
- array_flat &= 0x0F
27
- array_flat[1::2] <<= 4
28
- return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
29
-
30
-
31
- def _unpack_uint4_as_uint8(
32
- data: npt.NDArray[np.uint8], dims: Sequence[int]
33
- ) -> npt.NDArray[np.uint8]:
34
- """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
35
-
36
- Args:
37
- data: A numpy array.
38
- dims: The dimensions are used to reshape the unpacked buffer.
39
-
40
- Returns:
41
- A numpy array of int8/uint8 reshaped to dims.
42
- """
43
- assert data.dtype == np.uint8, "Input data must be of type uint8"
44
- result = np.empty([data.size * 2], dtype=data.dtype)
45
- array_low = data & np.uint8(0x0F)
46
- array_high = data & np.uint8(0xF0)
47
- array_high >>= np.uint8(4)
48
- result[0::2] = array_low
49
- result[1::2] = array_high
50
- if result.size == np.prod(dims) + 1:
51
- # handle single-element padding due to odd number of elements
52
- result = result[:-1]
53
- result.resize(dims, refcheck=False)
54
- return result
55
-
56
-
57
- def unpack_uint4(
58
- data: npt.NDArray[np.uint8], dims: Sequence[int]
59
- ) -> npt.NDArray[ml_dtypes.uint4]:
60
- """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
61
-
62
- Args:
63
- data: A numpy array.
64
- dims: The dimensions are used to reshape the unpacked buffer.
65
-
66
- Returns:
67
- A numpy array of int8/uint8 reshaped to dims.
68
- """
69
- return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.uint4)
70
-
71
-
72
- def _extend_int4_sign_bits(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]:
73
- """Extend 4-bit signed integer to 8-bit signed integer."""
74
- return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8)
75
-
76
-
77
- def unpack_int4(
78
- data: npt.NDArray[np.uint8], dims: Sequence[int]
79
- ) -> npt.NDArray[ml_dtypes.int4]:
80
- """Convert a packed (signed) int4 array to unpacked int4 array represented as int8.
81
-
82
- The sign bit is extended to the most significant bit of the int8.
83
-
84
- Args:
85
- data: A numpy array.
86
- dims: The dimensions are used to reshape the unpacked buffer.
87
-
88
- Returns:
89
- A numpy array of int8 reshaped to dims.
90
- """
91
- unpacked = _unpack_uint4_as_uint8(data, dims)
92
- return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
93
-
94
-
95
- def unpack_float4e2m1(
96
- data: npt.NDArray[np.uint8], dims: Sequence[int]
97
- ) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
98
- """Convert a packed float4e2m1 array to unpacked float4e2m1 array.
99
-
100
- Args:
101
- data: A numpy array.
102
- dims: The dimensions are used to reshape the unpacked buffer.
103
-
104
- Returns:
105
- A numpy array of float32 reshaped to dims.
106
- """
107
- return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
@@ -1,177 +0,0 @@
1
- # Copyright (c) ONNX Project Contributors
2
- # SPDX-License-Identifier: Apache-2.0
3
- """Eliminate common subexpression in ONNX graphs."""
4
-
5
- from __future__ import annotations
6
-
7
- __all__ = [
8
- "CommonSubexpressionEliminationPass",
9
- ]
10
-
11
- import logging
12
- from collections.abc import Sequence
13
-
14
- import onnx_ir as ir
15
-
16
- logger = logging.getLogger(__name__)
17
-
18
-
19
- class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
20
- """Eliminate common subexpression in ONNX graphs."""
21
-
22
- def call(self, model: ir.Model) -> ir.passes.PassResult:
23
- """Return the same ir.Model but with CSE applied to the graph."""
24
- modified = False
25
- graph = model.graph
26
-
27
- modified = _eliminate_common_subexpression(graph, modified)
28
-
29
- return ir.passes.PassResult(
30
- model,
31
- modified=modified,
32
- )
33
-
34
-
35
- def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
36
- """Eliminate common subexpression in ONNX graphs."""
37
- # node to node identifier, length of outputs, inputs, and attributes
38
- existing_node_info_to_the_node: dict[
39
- tuple[
40
- ir.OperatorIdentifier,
41
- int, # len(outputs)
42
- tuple[int, ...], # input ids
43
- tuple[tuple[str, object], ...], # attributes
44
- ],
45
- ir.Node,
46
- ] = {}
47
-
48
- for node in graph:
49
- # Skip control flow ops like Loop and If.
50
- control_flow_op: bool = False
51
- # Use equality to check if the node is a common subexpression.
52
- attributes = {}
53
- for k, v in node.attributes.items():
54
- # TODO(exporter team): CSE subgraphs.
55
- # NOTE: control flow ops like Loop and If won't be CSEd
56
- # because attribute: graph won't match.
57
- if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
58
- control_flow_op = True
59
- logger.debug("Skipping control flow op %s", node)
60
- # The attribute value could be directly taken from the original
61
- # protobuf, so we need to make a copy of it.
62
- value = v.value
63
- if v.type in (
64
- ir.AttributeType.INTS,
65
- ir.AttributeType.FLOATS,
66
- ir.AttributeType.STRINGS,
67
- ):
68
- # For INT, FLOAT and STRING attributes, we convert them to tuples
69
- # to ensure they are hashable.
70
- value = tuple(value)
71
- attributes[k] = value
72
-
73
- if control_flow_op:
74
- # If the node is a control flow op, we skip it.
75
- logger.debug("Skipping control flow op %s", node)
76
- continue
77
-
78
- if _is_non_deterministic_op(node):
79
- # If the node is a non-deterministic op, we skip it.
80
- logger.debug("Skipping non-deterministic op %s", node)
81
- continue
82
-
83
- node_info = (
84
- node.op_identifier(),
85
- len(node.outputs),
86
- tuple(id(input) for input in node.inputs),
87
- tuple(sorted(attributes.items())),
88
- )
89
- # Check if the node is a common subexpression.
90
- if node_info in existing_node_info_to_the_node:
91
- # If it is, this node has an existing node with the same
92
- # operator, number of outputs, inputs, and attributes.
93
- # We replace the node with the existing node.
94
- modified = True
95
- existing_node = existing_node_info_to_the_node[node_info]
96
- _remove_node_and_replace_values(
97
- graph,
98
- remove_node=node,
99
- remove_values=node.outputs,
100
- new_values=existing_node.outputs,
101
- )
102
- logger.debug("Reusing node %s", existing_node)
103
- else:
104
- # If it is not, add to the mapping.
105
- existing_node_info_to_the_node[node_info] = node
106
- return modified
107
-
108
-
109
- def _remove_node_and_replace_values(
110
- graph: ir.Graph,
111
- /,
112
- remove_node: ir.Node,
113
- remove_values: Sequence[ir.Value],
114
- new_values: Sequence[ir.Value],
115
- ) -> None:
116
- """Replaces nodes and values in the graph or function.
117
-
118
- Args:
119
- graph: The graph to replace nodes and values in.
120
- remove_node: The node to remove.
121
- remove_values: The values to replace.
122
- new_values: The values to replace with.
123
- """
124
- # Reconnect the users of the deleted values to use the new values
125
- ir.convenience.replace_all_uses_with(remove_values, new_values)
126
- # Update graph/function outputs if the node generates output
127
- if any(remove_value.is_graph_output() for remove_value in remove_values):
128
- replacement_mapping = dict(zip(remove_values, new_values))
129
- for idx, graph_output in enumerate(graph.outputs):
130
- if graph_output in replacement_mapping:
131
- new_value = replacement_mapping[graph_output]
132
- if new_value.is_graph_output() or new_value.is_graph_input():
133
- # If the new value is also a graph input/output, we need to
134
- # create a Identity node to preserve the remove_value and
135
- # prevent from changing new_value name.
136
- identity_node = ir.node(
137
- "Identity",
138
- inputs=[new_value],
139
- outputs=[
140
- ir.Value(
141
- name=graph_output.name,
142
- type=graph_output.type,
143
- shape=graph_output.shape,
144
- )
145
- ],
146
- )
147
- # reuse the name of the graph output
148
- graph.outputs[idx] = identity_node.outputs[0]
149
- graph.insert_before(
150
- remove_node,
151
- identity_node,
152
- )
153
- else:
154
- # if new_value is not graph output, we just
155
- # update it to use old_value name.
156
- new_value.name = graph_output.name
157
- graph.outputs[idx] = new_value
158
-
159
- graph.remove(remove_node, safe=True)
160
-
161
-
162
- def _is_non_deterministic_op(node: ir.Node) -> bool:
163
- non_deterministic_ops = frozenset(
164
- {
165
- "RandomUniform",
166
- "RandomNormal",
167
- "RandomUniformLike",
168
- "RandomNormalLike",
169
- "Multinomial",
170
- }
171
- )
172
- return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
173
-
174
-
175
- def _is_onnx_domain(d: str) -> bool:
176
- """Check if the domain is the ONNX domain."""
177
- return d == ""