onnx-ir 0.1.1-py3-none-any.whl → 0.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


onnx_ir/_type_casting.py CHANGED
@@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
     import numpy.typing as npt
 
 
-def pack_int4(array: np.ndarray) -> npt.NDArray[np.uint8]:
+def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
     """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
     # Create a 1D copy
     array_flat = array.ravel().view(np.uint8).copy()
@@ -40,6 +40,7 @@ def _unpack_uint4_as_uint8(
     Returns:
        A numpy array of int8/uint8 reshaped to dims.
     """
+    assert data.dtype == np.uint8, "Input data must be of type uint8"
     result = np.empty([data.size * 2], dtype=data.dtype)
     array_low = data & np.uint8(0x0F)
     array_high = data & np.uint8(0xF0)
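The rename from `pack_int4` to `pack_4bitx2` matches what the helper actually does: it packs two 4-bit elements (signed or unsigned) into each byte, and the new assertion makes the uint8 precondition of `_unpack_uint4_as_uint8` explicit. A minimal standalone sketch of the same pack/unpack scheme, in plain numpy and independent of the private module:

```python
import numpy as np

def pack_4bitx2(array: np.ndarray) -> np.ndarray:
    """Pack pairs of 4-bit values (stored one per byte) into single bytes."""
    flat = array.ravel().view(np.uint8).copy()
    if flat.size % 2:
        flat = np.append(flat, np.uint8(0))  # pad odd-length input
    # Low nibble from even elements, high nibble from odd elements.
    return (flat[0::2] & 0x0F) | (flat[1::2] << 4)

def unpack_uint4(data: np.ndarray) -> np.ndarray:
    """Inverse operation: split each byte back into two 4-bit values."""
    assert data.dtype == np.uint8, "Input data must be of type uint8"
    result = np.empty(data.size * 2, dtype=np.uint8)
    result[0::2] = data & np.uint8(0x0F)
    result[1::2] = data >> 4
    return result

values = np.array([1, 2, 3, 4], dtype=np.uint8)
packed = pack_4bitx2(values)  # two bytes: 0x21, 0x43
assert np.array_equal(unpack_uint4(packed), values)
```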
onnx_ir/_version_utils.py CHANGED
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """Version utils for testing."""
 
+# pylint: disable=import-outside-toplevel
 from __future__ import annotations
 
 import packaging.version
@@ -9,7 +10,7 @@ import packaging.version
 
 def onnx_older_than(version: str) -> bool:
     """Returns True if the ONNX version is older than the given version."""
-    import onnx  # pylint: disable=import-outside-toplevel
+    import onnx  # noqa: TID251
 
     return (
         packaging.version.parse(onnx.__version__).release
@@ -19,7 +20,7 @@ def onnx_older_than(version: str) -> bool:
 
 def torch_older_than(version: str) -> bool:
     """Returns True if the torch version is older than the given version."""
-    import torch  # pylint: disable=import-outside-toplevel
+    import torch
 
     return (
         packaging.version.parse(torch.__version__).release
@@ -27,42 +28,9 @@ def torch_older_than(version: str) -> bool:
     )
 
 
-def transformers_older_than(version: str) -> bool | None:
-    """Returns True if the transformers version is older than the given version."""
-    try:
-        import transformers  # pylint: disable=import-outside-toplevel
-    except ImportError:
-        return None
-
-    return (
-        packaging.version.parse(transformers.__version__).release
-        < packaging.version.parse(version).release
-    )
-
-
-def is_onnxruntime_training() -> bool:
-    """Returns True if the onnxruntime is onnxruntime-training."""
-    try:
-        from onnxruntime import training  # pylint: disable=import-outside-toplevel
-
-        assert training
-    except ImportError:
-        # onnxruntime not training
-        return False
-
-    try:
-        from onnxruntime.capi.onnxruntime_pybind11_state import (  # pylint: disable=import-outside-toplevel
-            OrtValueVector,
-        )
-    except ImportError:
-        return False
-
-    return hasattr(OrtValueVector, "push_back_batch")
-
-
 def onnxruntime_older_than(version: str) -> bool:
     """Returns True if the onnxruntime version is older than the given version."""
-    import onnxruntime  # pylint: disable=import-outside-toplevel
+    import onnxruntime
 
     return (
         packaging.version.parse(onnxruntime.__version__).release
@@ -72,20 +40,9 @@ def onnxruntime_older_than(version: str) -> bool:
 
 def numpy_older_than(version: str) -> bool:
     """Returns True if the numpy version is older than the given version."""
-    import numpy  # pylint: disable=import-outside-toplevel
+    import numpy
 
     return (
         packaging.version.parse(numpy.__version__).release
         < packaging.version.parse(version).release
    )
-
-
-def has_transformers():
-    """Tells if transformers is installed."""
-    try:
-        import transformers  # pylint: disable=import-outside-toplevel
-
-        assert transformers
-        return True  # noqa
-    except ImportError:
-        return False
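The surviving helpers compare only the `.release` tuple, so pre-release and dev suffixes are ignored: `packaging.version.parse("1.17.0rc1").release` is `(1, 17, 0)`, which does not compare as older than `(1, 17)`. Their typical use is gating tests, sketched here with a hypothetical test class:

```python
import unittest

from onnx_ir._version_utils import numpy_older_than, onnx_older_than

class TypeCastingTest(unittest.TestCase):  # hypothetical test class for illustration
    @unittest.skipIf(onnx_older_than("1.17"), "requires onnx>=1.17")
    def test_feature_needing_recent_onnx(self):
        ...

    @unittest.skipIf(numpy_older_than("2.0"), "requires numpy>=2.0")
    def test_feature_needing_numpy_2(self):
        ...
```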
onnx_ir/convenience.py CHANGED
@@ -7,15 +7,17 @@ from __future__ import annotations
 __all__ = [
     "convert_attribute",
     "convert_attributes",
+    "create_value_mapping",
+    "get_const_tensor",
     "replace_all_uses_with",
     "replace_nodes_and_values",
-    "create_value_mapping",
 ]
 
 from onnx_ir._convenience import (
     convert_attribute,
     convert_attributes,
     create_value_mapping,
+    get_const_tensor,
     replace_all_uses_with,
     replace_nodes_and_values,
 )
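`get_const_tensor` is newly re-exported from the private `onnx_ir._convenience` module; its signature is not shown in this diff. The sketch below assumes the semantics suggested by the name: resolving an `ir.Value` to its constant tensor when one is statically known, and returning `None` otherwise.

```python
import onnx_ir as ir

def constant_shape(value: ir.Value) -> tuple[int, ...] | None:
    # Assumption: get_const_tensor(value) returns the tensor behind an
    # initializer or Constant-produced value, or None if not constant.
    tensor = ir.convenience.get_const_tensor(value)
    if tensor is None:
        return None
    return tuple(tensor.shape)
```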
onnx_ir/external_data.py CHANGED
@@ -4,12 +4,15 @@
 
 from __future__ import annotations
 
+from typing import Callable
+
 __all__ = [
     "set_base_dir",
     "unload_from_model",
     "load_to_model",
     "convert_tensors_to_external",
     "convert_tensors_from_external",
+    "CallbackInfo",
 ]
 
 import dataclasses
@@ -48,6 +51,21 @@ class _ExternalDataInfo:
     length: int
 
 
+@dataclasses.dataclass
+class CallbackInfo:
+    """A class that shares information about a tensor that is to be saved as external data for callback functions.
+
+    Attributes:
+        total: The total number of tensors to save.
+        index: The index of the tensor being saved.
+        offset: The offset of the tensor in the external data file.
+    """
+
+    total: int
+    index: int
+    offset: int
+
+
 def _all_tensors(
     graph: _core.Graph | _core.GraphView, include_attributes: bool = False
 ) -> Iterator[_protocols.TensorProtocol]:
@@ -157,6 +175,7 @@ def _write_external_data(
     tensors: Sequence[_protocols.TensorProtocol],
     external_data_infos: Sequence[_ExternalDataInfo],
     file_path: str | os.PathLike,
+    callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
 ) -> None:
     """Write tensor data to an external file according to information stored in ExternalDataInfo objects.
 
@@ -164,12 +183,26 @@ def _write_external_data(
         tensors: Tensors to be written as external data.
         external_data_infos: External data information stored for each tensor to be written as external data.
         file_path: Location to which external data is to be stored.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
     """
-    assert len(tensors) == len(external_data_infos), (
+    tensors_count = len(tensors)
+    assert tensors_count == len(external_data_infos), (
         "Number of tensors and external data infos should match"
     )
     with open(file_path, "wb") as data_file:
-        for tensor, tensor_info in zip(tensors, external_data_infos, strict=True):
+        for i, (tensor, tensor_info) in enumerate(
+            zip(tensors, external_data_infos, strict=True)
+        ):
+            if callback is not None:
+                callback(
+                    tensor,
+                    CallbackInfo(
+                        total=tensors_count,
+                        index=i,
+                        offset=tensor_info.offset,
+                    ),
+                )
             current_offset = tensor_info.offset
             assert tensor is not None
             raw_data = tensor.tobytes()
@@ -228,6 +261,7 @@ def convert_tensors_to_external(
     tensors: Sequence[_protocols.TensorProtocol],
     base_dir: str | os.PathLike,
     relative_path: str | os.PathLike,
+    callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
 ) -> list[_core.ExternalTensor]:
     """Convert a sequence of any TensorProtocol tensors to external tensors.
 
@@ -238,6 +272,8 @@ def convert_tensors_to_external(
         tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
         base_dir: Path of base directory.
         relative_path: Path to which external data is to be stored, relative to the ONNX file.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
 
     Returns:
         A list of external tensors derived from a list of input tensors. The order
@@ -285,7 +321,7 @@ def convert_tensors_to_external(
         external_info = _compute_external_data_info(tensor, current_offset)
         external_data_infos.append(external_info)
         current_offset = external_info.offset + external_info.length
-    _write_external_data(sorted_tensors, external_data_infos, path)
+    _write_external_data(sorted_tensors, external_data_infos, path, callback=callback)
 
     # Create external tensor objects
     external_tensors: list[_core.ExternalTensor] = [
@@ -336,6 +372,7 @@ def unload_from_model(
     relative_path: str | os.PathLike,
     *,
     size_threshold_bytes: int = 0,
+    callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
 ) -> _core.Model:
     """Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file.
 
@@ -356,6 +393,8 @@ def unload_from_model(
         relative_path: Path to which external data is to be stored, relative to the ONNX file.
             E.g. "model.data"
         size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
 
     Returns:
         An ir.Model with all initializer data equal or above ``size_threshold_bytes``
@@ -384,6 +423,7 @@ def unload_from_model(
         [v.const_value for v in initializers_to_become_external],  # type: ignore[misc]
         base_dir=base_dir,
         relative_path=relative_path,
+        callback=callback,
     )
 
     # Replace the initializer values with external tensors and save the model
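The new `callback` parameter threads from `unload_from_model` through `convert_tensors_to_external` down to `_write_external_data`, where it fires once per tensor just before the bytes are written. A sketch of a progress-reporting callback; the file paths are placeholders, and the `info` fields follow the `CallbackInfo` definition above:

```python
import onnx_ir as ir
from onnx_ir import external_data

def log_progress(tensor: ir.TensorProtocol, info: external_data.CallbackInfo) -> None:
    # Called once per tensor, just before its bytes are written.
    print(f"[{info.index + 1}/{info.total}] writing {tensor.name!r} at offset {info.offset}")

model = ir.load("model.onnx")  # placeholder input path
model = external_data.unload_from_model(
    model,
    base_dir=".",
    relative_path="model.data",
    size_threshold_bytes=1024,
    callback=log_progress,
)
ir.save(model, "model_external.onnx")
```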
@@ -127,7 +127,7 @@ class PassBase(abc.ABC):
 
         # Check postconditions
         try:
-            self.ensures(model)
+            self.ensures(result.model)
         except PostconditionError:
             raise
         except Exception as e:
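This one-line fix matters for passes that return a model other than the one passed in: postconditions now run on `result.model`, the model the pass actually produced, rather than the original input. A hedged sketch of a pass with a postcondition, assuming `ensures` takes the model as its only argument (consistent with the call site above):

```python
import onnx_ir as ir

class NameAllOutputsPass(ir.passes.InPlacePass):  # hypothetical example pass
    def call(self, model: ir.Model) -> ir.passes.PassResult:
        modified = False
        for i, output in enumerate(model.graph.outputs):
            if not output.name:
                output.name = f"output_{i}"
                modified = True
        return ir.passes.PassResult(model, modified=modified)

    def ensures(self, model: ir.Model) -> None:
        # After the fix, this receives result.model (the model the pass
        # returned), not necessarily the object passed to __call__.
        if not all(output.name for output in model.graph.outputs):
            raise RuntimeError("every graph output should be named after this pass")
```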
onnx_ir/passes/common/__init__.py CHANGED
@@ -6,6 +6,7 @@ __all__ = [
     "CheckerPass",
     "ClearMetadataAndDocStringPass",
     "CommonSubexpressionEliminationPass",
+    "DeduplicateInitializersPass",
     "InlinePass",
     "LiftConstantsToInitializersPass",
     "LiftSubgraphInitializersToMainGraphPass",
@@ -29,6 +30,9 @@ from onnx_ir.passes.common.constant_manipulation import (
     LiftSubgraphInitializersToMainGraphPass,
     RemoveInitializersFromInputsPass,
 )
+from onnx_ir.passes.common.initializer_deduplication import (
+    DeduplicateInitializersPass,
+)
 from onnx_ir.passes.common.inliner import InlinePass
 from onnx_ir.passes.common.onnx_checker import CheckerPass
 from onnx_ir.passes.common.shape_inference import ShapeInferencePass
onnx_ir/passes/common/common_subexpression_elimination.py CHANGED
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Callable, TypeVar
 import onnx_ir as ir
 
 if TYPE_CHECKING:
-    import onnx
+    import onnx  # noqa: TID251
 
 
 logger = logging.getLogger(__name__)
@@ -17,93 +17,122 @@ logger = logging.getLogger(__name__)
 
 
 class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
-    """Eliminate common subexpression in ONNX graphs."""
+    """Eliminate common subexpression in ONNX graphs.
+
+    .. versionadded:: 0.1.1
+
+    .. versionchanged:: 0.1.3
+        Constant nodes with values smaller than ``size_limit`` will be CSE'd.
+
+    Attributes:
+        size_limit: The maximum size of the tensor to be csed. If the tensor contains
+            number of elements larger than size_limit, it will not be cse'd. Default is 10.
+
+    """
+
+    def __init__(self, size_limit: int = 10):
+        """Initialize the CommonSubexpressionEliminationPass."""
+        super().__init__()
+        self.size_limit = size_limit
 
     def call(self, model: ir.Model) -> ir.passes.PassResult:
         """Return the same ir.Model but with CSE applied to the graph."""
-        modified = False
         graph = model.graph
-
-        modified = _eliminate_common_subexpression(graph, modified)
+        modified = self._eliminate_common_subexpression(graph)
 
         return ir.passes.PassResult(
             model,
             modified=modified,
         )
 
-
-def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
-    """Eliminate common subexpression in ONNX graphs."""
-    # node to node identifier, length of outputs, inputs, and attributes
-    existing_node_info_to_the_node: dict[
-        tuple[
-            ir.OperatorIdentifier,
-            int,  # len(outputs)
-            tuple[int, ...],  # input ids
-            tuple[tuple[str, object], ...],  # attributes
-        ],
-        ir.Node,
-    ] = {}
-
-    for node in graph:
-        # Skip control flow ops like Loop and If.
-        control_flow_op: bool = False
-        # Use equality to check if the node is a common subexpression.
-        attributes = {}
-        for k, v in node.attributes.items():
-            # TODO(exporter team): CSE subgraphs.
-            # NOTE: control flow ops like Loop and If won't be CSEd
-            # because attribute: graph won't match.
-            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
-                control_flow_op = True
+    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+        """Eliminate common subexpression in ONNX graphs."""
+        modified: bool = False
+        # node to node identifier, length of outputs, inputs, and attributes
+        existing_node_info_to_the_node: dict[
+            tuple[
+                ir.OperatorIdentifier,
+                int,  # len(outputs)
+                tuple[int, ...],  # input ids
+                tuple[tuple[str, object], ...],  # attributes
+            ],
+            ir.Node,
+        ] = {}
+
+        for node in graph:
+            # Skip control flow ops like Loop and If.
+            control_flow_op: bool = False
+            # Skip large tensors to avoid cse weights and bias.
+            large_tensor: bool = False
+            # Use equality to check if the node is a common subexpression.
+            attributes = {}
+            for k, v in node.attributes.items():
+                # TODO(exporter team): CSE subgraphs.
+                # NOTE: control flow ops like Loop and If won't be CSEd
+                # because attribute: graph won't match.
+                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                    control_flow_op = True
+                    break
+                # The attribute value could be directly taken from the original
+                # protobuf, so we need to make a copy of it.
+                value = v.value
+                if v.type in (
+                    ir.AttributeType.INTS,
+                    ir.AttributeType.FLOATS,
+                    ir.AttributeType.STRINGS,
+                ):
+                    # For INT, FLOAT and STRING attributes, we convert them to tuples
+                    # to ensure they are hashable.
+                    value = tuple(value)
+                elif v.type is ir.AttributeType.TENSOR:
+                    if value.size > self.size_limit:
+                        # If the tensor is larger than the size limit, we skip it.
+                        large_tensor = True
+                        break
+                    np_value = value.numpy()
+
+                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                attributes[k] = value
+
+            if control_flow_op:
+                # If the node is a control flow op, we skip it.
                 logger.debug("Skipping control flow op %s", node)
-            # The attribute value could be directly taken from the original
-            # protobuf, so we need to make a copy of it.
-            value = v.value
-            if v.type in (
-                ir.AttributeType.INTS,
-                ir.AttributeType.FLOATS,
-                ir.AttributeType.STRINGS,
-            ):
-                # For INT, FLOAT and STRING attributes, we convert them to tuples
-                # to ensure they are hashable.
-                value = tuple(value)
-            attributes[k] = value
-
-        if control_flow_op:
-            # If the node is a control flow op, we skip it.
-            logger.debug("Skipping control flow op %s", node)
-            continue
-
-        if _is_non_deterministic_op(node):
-            # If the node is a non-deterministic op, we skip it.
-            logger.debug("Skipping non-deterministic op %s", node)
-            continue
-
-        node_info = (
-            node.op_identifier(),
-            len(node.outputs),
-            tuple(id(input) for input in node.inputs),
-            tuple(sorted(attributes.items())),
-        )
-        # Check if the node is a common subexpression.
-        if node_info in existing_node_info_to_the_node:
-            # If it is, this node has an existing node with the same
-            # operator, number of outputs, inputs, and attributes.
-            # We replace the node with the existing node.
-            modified = True
-            existing_node = existing_node_info_to_the_node[node_info]
-            _remove_node_and_replace_values(
-                graph,
-                remove_node=node,
-                remove_values=node.outputs,
-                new_values=existing_node.outputs,
+                continue
+
+            if large_tensor:
+                # If the node has a large tensor, we skip it.
+                logger.debug("Skipping large tensor in node %s", node)
+                continue
+
+            if _is_non_deterministic_op(node):
+                # If the node is a non-deterministic op, we skip it.
+                logger.debug("Skipping non-deterministic op %s", node)
+                continue
+
+            node_info = (
+                node.op_identifier(),
+                len(node.outputs),
+                tuple(id(input) for input in node.inputs),
+                tuple(sorted(attributes.items())),
            )
-            logger.debug("Reusing node %s", existing_node)
-        else:
-            # If it is not, add to the mapping.
-            existing_node_info_to_the_node[node_info] = node
-    return modified
+            # Check if the node is a common subexpression.
+            if node_info in existing_node_info_to_the_node:
+                # If it is, this node has an existing node with the same
+                # operator, number of outputs, inputs, and attributes.
+                # We replace the node with the existing node.
+                modified = True
+                existing_node = existing_node_info_to_the_node[node_info]
+                _remove_node_and_replace_values(
+                    graph,
+                    remove_node=node,
+                    remove_values=node.outputs,
+                    new_values=existing_node.outputs,
+                )
+                logger.debug("Reusing node %s", existing_node)
+            else:
+                # If it is not, add to the mapping.
+                existing_node_info_to_the_node[node_info] = node
+        return modified
 
 
 def _remove_node_and_replace_values(
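With the new constructor, the cutoff is tunable: Constant tensor attributes with at most `size_limit` elements are keyed by shape, dtype, and raw bytes and participate in CSE, while larger tensors are skipped so weights are never hashed or merged. Usage sketch (the input path is a placeholder):

```python
import onnx_ir as ir
from onnx_ir.passes.common import CommonSubexpressionEliminationPass

model = ir.load("model.onnx")  # placeholder input path
# Deduplicate repeated subexpressions; Constant tensors with more than
# 32 elements are left alone to avoid hashing large weights.
result = CommonSubexpressionEliminationPass(size_limit=32)(model)
print("modified:", result.modified)
```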
onnx_ir/passes/common/initializer_deduplication.py ADDED
@@ -0,0 +1,56 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = [
+    "DeduplicateInitializersPass",
+]
+
+
+import onnx_ir as ir
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the graph.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running pass.
+
+    .. versionadded:: 0.1.3
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        graph = model.graph
+        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+        modified = False
+
+        for initializer in tuple(graph.initializers.values()):
+            # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
+            # out from the main graph before running this pass.
+            const_val = initializer.const_value
+            if const_val is None:
+                # Skip if initializer has no constant value
+                continue
+
+            if const_val.size > self.size_limit:
+                continue
+
+            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+            if key in initializers:
+                modified = True
+                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
+                assert initializer.name is not None
+                graph.initializers.pop(initializer.name)
+            else:
+                initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
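Initializers are keyed by exact content, `(dtype, shape, bytes)`, so only byte-identical tensors merge, and anything above `size_limit` elements is left untouched. A usage sketch that lifts subgraph initializers first, as the docstring recommends (the input path is a placeholder, and the no-argument constructor for the lifting pass is an assumption):

```python
import onnx_ir as ir
from onnx_ir.passes.common import (
    DeduplicateInitializersPass,
    LiftSubgraphInitializersToMainGraphPass,
)

model = ir.load("model.onnx")  # placeholder input path
# Lift subgraph initializers so duplicates inside If/Loop bodies also merge.
LiftSubgraphInitializersToMainGraphPass()(model)
result = DeduplicateInitializersPass(size_limit=4096)(model)
print("modified:", result.modified)
```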
@@ -10,7 +10,7 @@ __all__ = [
 
 from typing import Literal
 
-import onnx
+import onnx  # noqa: TID251
 
 import onnx_ir as ir
 from onnx_ir.passes.common import _c_api_utils
@@ -11,7 +11,7 @@ __all__ = [
 
 import logging
 
-import onnx
+import onnx  # noqa: TID251
 
 import onnx_ir as ir
 from onnx_ir.passes.common import _c_api_utils
@@ -10,7 +10,7 @@ __all__ = [
 
 import logging
 
-import onnx
+import onnx  # noqa: TID251
 
 import onnx_ir as ir