onnx-ir 0.1.0-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. (The original page linked to more details here; that link was not preserved in this text.)

onnx_ir/_enums.py CHANGED
@@ -114,7 +114,18 @@ class DataType(enum.IntEnum):
114
114
  @property
115
115
  def itemsize(self) -> float:
116
116
  """Returns the size of the data type in bytes."""
117
- return _ITEMSIZE_MAP[self]
117
+ return self.bitwidth / 8
118
+
119
+ @property
120
+ def bitwidth(self) -> int:
121
+ """Returns the bit width of the data type.
122
+
123
+ Raises:
124
+ TypeError: If the data type is not supported.
125
+ """
126
+ if self not in _BITWIDTH_MAP:
127
+ raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
128
+ return _BITWIDTH_MAP[self]
118
129
 
119
130
  def numpy(self) -> np.dtype:
120
131
  """Returns the numpy dtype for the ONNX data type.
@@ -163,30 +174,29 @@ class DataType(enum.IntEnum):
163
174
  return self.__repr__()
164
175
 
165
176
 
166
- _ITEMSIZE_MAP = {
167
- DataType.FLOAT: 4,
168
- DataType.UINT8: 1,
169
- DataType.INT8: 1,
170
- DataType.UINT16: 2,
171
- DataType.INT16: 2,
172
- DataType.INT32: 4,
173
- DataType.INT64: 8,
174
- DataType.STRING: 1,
175
- DataType.BOOL: 1,
176
- DataType.FLOAT16: 2,
177
- DataType.DOUBLE: 8,
178
- DataType.UINT32: 4,
179
- DataType.UINT64: 8,
180
- DataType.COMPLEX64: 8,
181
- DataType.COMPLEX128: 16,
182
- DataType.BFLOAT16: 2,
183
- DataType.FLOAT8E4M3FN: 1,
184
- DataType.FLOAT8E4M3FNUZ: 1,
185
- DataType.FLOAT8E5M2: 1,
186
- DataType.FLOAT8E5M2FNUZ: 1,
187
- DataType.UINT4: 0.5,
188
- DataType.INT4: 0.5,
189
- DataType.FLOAT4E2M1: 0.5,
177
+ _BITWIDTH_MAP = {
178
+ DataType.FLOAT: 32,
179
+ DataType.UINT8: 8,
180
+ DataType.INT8: 8,
181
+ DataType.UINT16: 16,
182
+ DataType.INT16: 16,
183
+ DataType.INT32: 32,
184
+ DataType.INT64: 64,
185
+ DataType.BOOL: 8,
186
+ DataType.FLOAT16: 16,
187
+ DataType.DOUBLE: 64,
188
+ DataType.UINT32: 32,
189
+ DataType.UINT64: 64,
190
+ DataType.COMPLEX64: 64, # 2 * 32
191
+ DataType.COMPLEX128: 128, # 2 * 64
192
+ DataType.BFLOAT16: 16,
193
+ DataType.FLOAT8E4M3FN: 8,
194
+ DataType.FLOAT8E4M3FNUZ: 8,
195
+ DataType.FLOAT8E5M2: 8,
196
+ DataType.FLOAT8E5M2FNUZ: 8,
197
+ DataType.UINT4: 4,
198
+ DataType.INT4: 4,
199
+ DataType.FLOAT4E2M1: 4,
190
200
  }
191
201
 
192
202
 
@@ -12,13 +12,16 @@ __all__ = [
12
12
  ]
13
13
 
14
14
  import collections
15
- from collections.abc import Iterable
16
- from typing import TYPE_CHECKING, SupportsIndex
15
+ import logging
16
+ from collections.abc import Iterable, Sequence
17
+ from typing import SupportsIndex, TypeVar
17
18
 
18
19
  import onnx_ir
20
+ from onnx_ir import _core, _protocols
19
21
 
20
- if TYPE_CHECKING:
21
- from onnx_ir import _core
22
+ T = TypeVar("T")
23
+
24
+ logger = logging.getLogger(__name__)
22
25
 
23
26
 
24
27
  class _GraphIO(collections.UserList["_core.Value"]):
@@ -152,6 +155,10 @@ class GraphInputs(_GraphIO):
152
155
  raise ValueError(
153
156
  f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
154
157
  )
158
+ if value.producer() is not None:
159
+ raise ValueError(
160
+ f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
161
+ )
155
162
  self._ref_counter[value] += 1
156
163
  value._is_graph_input = True
157
164
  value._graph = self._graph
@@ -209,7 +216,7 @@ class GraphOutputs(_GraphIO):
209
216
 
210
217
 
211
218
  class GraphInitializers(collections.UserDict[str, "_core.Value"]):
212
- """The initializers of a Graph."""
219
+ """The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods."""
213
220
 
214
221
  def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
215
222
  # Perform checks first in _set_graph before modifying the data structure with super().__init__()
@@ -244,12 +251,23 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
244
251
 
245
252
  def __setitem__(self, key: str, value: _core.Value) -> None:
246
253
  """Set an initializer for the graph."""
247
- if key != value.name:
254
+ if not isinstance(value, _core.Value):
255
+ raise TypeError(f"value must be a Value object, not {type(value)}")
256
+ if not isinstance(key, str):
257
+ raise TypeError(f"Value name must be a string, not {type(key)}")
258
+ if key == "":
259
+ raise ValueError("Value name cannot be an empty string")
260
+ if not value.name:
261
+ logger.info("Value %s does not have a name, setting it to '%s'", value, key)
262
+ value.name = key
263
+ elif key != value.name:
248
264
  raise ValueError(
249
- f"Key '{key}' does not match the name of the value '{value.name}'"
265
+ f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
266
+ )
267
+ if value.producer() is not None:
268
+ raise ValueError(
269
+ f"Value '{value}' is produced by a node and cannot be a graph initializer"
250
270
  )
251
- if not isinstance(key, str):
252
- raise TypeError(f"Key must be a string, not {type(key)}")
253
271
  if key in self.data:
254
272
  # If the key already exists, unset the old value
255
273
  old_value = self.data[key]
@@ -266,3 +284,90 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
266
284
  # the dictionary is not modified
267
285
  self._maybe_unset_graph(value)
268
286
  super().__delitem__(key)
287
+
288
+ def add(self, value: _core.Value) -> None:
289
+ """Add an initializer to the graph."""
290
+ self[value.name] = value # type: ignore[index]
291
+
292
+
293
+ class Attributes(collections.UserDict[str, "_core.Attr"]):
294
+ """The attributes of a Node as ``dict[str, Attr]`` with additional access methods."""
295
+
296
+ def __init__(self, attrs: Iterable[_core.Attr]):
297
+ super().__init__({attr.name: attr for attr in attrs})
298
+
299
+ def __setitem__(self, key: str, value: _core.Attr) -> None:
300
+ """Set an attribute for the node."""
301
+ if type(key) is not str:
302
+ raise TypeError(f"Key must be a string, not {type(key)}")
303
+ if not isinstance(value, _core.Attr):
304
+ raise TypeError(f"Value must be an Attr, not {type(value)}")
305
+ super().__setitem__(key, value)
306
+
307
+ def add(self, value: _core.Attr) -> None:
308
+ """Add an attribute to the node."""
309
+ self[value.name] = value
310
+
311
+ def get_int(self, key: str, default: T = None) -> int | T: # type: ignore[assignment]
312
+ """Get the integer value of the attribute."""
313
+ if key in self:
314
+ return self[key].as_int()
315
+ return default
316
+
317
+ def get_float(self, key: str, default: T = None) -> float | T: # type: ignore[assignment]
318
+ """Get the float value of the attribute."""
319
+ if key in self:
320
+ return self[key].as_float()
321
+ return default
322
+
323
+ def get_string(self, key: str, default: T = None) -> str | T: # type: ignore[assignment]
324
+ """Get the string value of the attribute."""
325
+ if key in self:
326
+ return self[key].as_string()
327
+ return default
328
+
329
+ def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T: # type: ignore[assignment]
330
+ """Get the tensor value of the attribute."""
331
+ if key in self:
332
+ return self[key].as_tensor()
333
+ return default
334
+
335
+ def get_graph(self, key: str, default: T = None) -> _core.Graph | T: # type: ignore[assignment]
336
+ """Get the graph value of the attribute."""
337
+ if key in self:
338
+ return self[key].as_graph()
339
+ return default
340
+
341
+ def get_ints(self, key: str, default: T = None) -> Sequence[int] | T: # type: ignore[assignment]
342
+ """Get the Sequence of integers from the attribute."""
343
+ if key in self:
344
+ return self[key].as_ints()
345
+ return default
346
+
347
+ def get_floats(self, key: str, default: T = None) -> Sequence[float] | T: # type: ignore[assignment]
348
+ """Get the Sequence of floats from the attribute."""
349
+ if key in self:
350
+ return self[key].as_floats()
351
+ return default
352
+
353
+ def get_strings(self, key: str, default: T = None) -> Sequence[str] | T: # type: ignore[assignment]
354
+ """Get the Sequence of strings from the attribute."""
355
+ if key in self:
356
+ return self[key].as_strings()
357
+ return default
358
+
359
+ def get_tensors(
360
+ self,
361
+ key: str,
362
+ default: T = None, # type: ignore[assignment]
363
+ ) -> Sequence[_protocols.TensorProtocol] | T:
364
+ """Get the Sequence of tensors from the attribute."""
365
+ if key in self:
366
+ return self[key].as_tensors()
367
+ return default
368
+
369
+ def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T: # type: ignore[assignment]
370
+ """Get the Sequence of graphs from the attribute."""
371
+ if key in self:
372
+ return self[key].as_graphs()
373
+ return default
onnx_ir/_io.py CHANGED
@@ -7,10 +7,11 @@ from __future__ import annotations
7
7
  __all__ = ["load", "save"]
8
8
 
9
9
  import os
10
+ from typing import Callable
10
11
 
11
- import onnx
12
+ import onnx # noqa: TID251
12
13
 
13
- from onnx_ir import _core, serde
14
+ from onnx_ir import _core, _protocols, serde
14
15
  from onnx_ir import external_data as _external_data
15
16
  from onnx_ir._polyfill import zip
16
17
 
@@ -43,6 +44,8 @@ def save(
43
44
  format: str | None = None,
44
45
  external_data: str | os.PathLike | None = None,
45
46
  size_threshold_bytes: int = 256,
47
+ callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
48
+ | None = None,
46
49
  ) -> None:
47
50
  """Save an ONNX model to a file.
48
51
 
@@ -52,6 +55,30 @@ def save(
52
55
  to load the newly saved model, or provide a different external data path that
53
56
  is not currently referenced by any tensors in the model.
54
57
 
58
+ .. tip::
59
+
60
+ A simple progress bar can be implemented by passing a callback function as the following::
61
+
62
+ import onnx_ir as ir
63
+ import tqdm
64
+
65
+ with tqdm.tqdm() as pbar:
66
+ total_set = False
67
+
68
+ def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
69
+ nonlocal total_set
70
+ if not total_set:
71
+ pbar.total = metadata.total
72
+ total_set = True
73
+
74
+ pbar.update()
75
+ pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")
76
+
77
+ ir.save(
78
+ ...,
79
+ callback=callback,
80
+ )
81
+
55
82
  Args:
56
83
  model: The model to save.
57
84
  path: The path to save the model to. E.g. "model.onnx".
@@ -65,6 +92,8 @@ def save(
65
92
  it will be serialized in the ONNX Proto message.
66
93
  size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
67
94
  Effective only when ``external_data`` is set.
95
+ callback: A callback function that is called for each tensor that is saved to external data
96
+ for debugging or logging purposes.
68
97
 
69
98
  Raises:
70
99
  ValueError: If the external data path is an absolute path.
@@ -77,12 +106,19 @@ def save(
77
106
  base_dir = os.path.dirname(path)
78
107
 
79
108
  # Store the original initializer values so they can be restored if modify_model=False
80
- initializer_values = tuple(model.graph.initializers.values())
109
+ initializer_values: list[_core.Value] = []
110
+ for graph in model.graphs():
111
+ # Collect from all subgraphs as well
112
+ initializer_values.extend(graph.initializers.values())
81
113
  tensors = [v.const_value for v in initializer_values]
82
114
 
83
115
  try:
84
116
  model = _external_data.unload_from_model(
85
- model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes
117
+ model,
118
+ base_dir,
119
+ external_data,
120
+ size_threshold_bytes=size_threshold_bytes,
121
+ callback=callback,
86
122
  )
87
123
  proto = serde.serialize_model(model)
88
124
  onnx.save(proto, path, format=format)
onnx_ir/_type_casting.py CHANGED
@@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
15
15
  import numpy.typing as npt
16
16
 
17
17
 
18
- def pack_int4(array: np.ndarray) -> npt.NDArray[np.uint8]:
18
+ def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
19
19
  """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
20
20
  # Create a 1D copy
21
21
  array_flat = array.ravel().view(np.uint8).copy()
@@ -40,6 +40,7 @@ def _unpack_uint4_as_uint8(
40
40
  Returns:
41
41
  A numpy array of int8/uint8 reshaped to dims.
42
42
  """
43
+ assert data.dtype == np.uint8, "Input data must be of type uint8"
43
44
  result = np.empty([data.size * 2], dtype=data.dtype)
44
45
  array_low = data & np.uint8(0x0F)
45
46
  array_high = data & np.uint8(0xF0)
onnx_ir/_version_utils.py CHANGED
@@ -2,6 +2,7 @@
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """Version utils for testing."""
4
4
 
5
+ # pylint: disable=import-outside-toplevel
5
6
  from __future__ import annotations
6
7
 
7
8
  import packaging.version
@@ -9,7 +10,7 @@ import packaging.version
9
10
 
10
11
  def onnx_older_than(version: str) -> bool:
11
12
  """Returns True if the ONNX version is older than the given version."""
12
- import onnx # pylint: disable=import-outside-toplevel
13
+ import onnx # noqa: TID251
13
14
 
14
15
  return (
15
16
  packaging.version.parse(onnx.__version__).release
@@ -19,7 +20,7 @@ def onnx_older_than(version: str) -> bool:
19
20
 
20
21
  def torch_older_than(version: str) -> bool:
21
22
  """Returns True if the torch version is older than the given version."""
22
- import torch # pylint: disable=import-outside-toplevel
23
+ import torch
23
24
 
24
25
  return (
25
26
  packaging.version.parse(torch.__version__).release
@@ -27,42 +28,9 @@ def torch_older_than(version: str) -> bool:
27
28
  )
28
29
 
29
30
 
30
- def transformers_older_than(version: str) -> bool | None:
31
- """Returns True if the transformers version is older than the given version."""
32
- try:
33
- import transformers # pylint: disable=import-outside-toplevel
34
- except ImportError:
35
- return None
36
-
37
- return (
38
- packaging.version.parse(transformers.__version__).release
39
- < packaging.version.parse(version).release
40
- )
41
-
42
-
43
- def is_onnxruntime_training() -> bool:
44
- """Returns True if the onnxruntime is onnxruntime-training."""
45
- try:
46
- from onnxruntime import training # pylint: disable=import-outside-toplevel
47
-
48
- assert training
49
- except ImportError:
50
- # onnxruntime not training
51
- return False
52
-
53
- try:
54
- from onnxruntime.capi.onnxruntime_pybind11_state import ( # pylint: disable=import-outside-toplevel
55
- OrtValueVector,
56
- )
57
- except ImportError:
58
- return False
59
-
60
- return hasattr(OrtValueVector, "push_back_batch")
61
-
62
-
63
31
  def onnxruntime_older_than(version: str) -> bool:
64
32
  """Returns True if the onnxruntime version is older than the given version."""
65
- import onnxruntime # pylint: disable=import-outside-toplevel
33
+ import onnxruntime
66
34
 
67
35
  return (
68
36
  packaging.version.parse(onnxruntime.__version__).release
@@ -72,20 +40,9 @@ def onnxruntime_older_than(version: str) -> bool:
72
40
 
73
41
  def numpy_older_than(version: str) -> bool:
74
42
  """Returns True if the numpy version is older than the given version."""
75
- import numpy # pylint: disable=import-outside-toplevel
43
+ import numpy
76
44
 
77
45
  return (
78
46
  packaging.version.parse(numpy.__version__).release
79
47
  < packaging.version.parse(version).release
80
48
  )
81
-
82
-
83
- def has_transformers():
84
- """Tells if transformers is installed."""
85
- try:
86
- import transformers # pylint: disable=import-outside-toplevel
87
-
88
- assert transformers
89
- return True # noqa
90
- except ImportError:
91
- return False
onnx_ir/convenience.py CHANGED
@@ -7,15 +7,17 @@ from __future__ import annotations
7
7
  __all__ = [
8
8
  "convert_attribute",
9
9
  "convert_attributes",
10
+ "create_value_mapping",
11
+ "get_const_tensor",
10
12
  "replace_all_uses_with",
11
13
  "replace_nodes_and_values",
12
- "create_value_mapping",
13
14
  ]
14
15
 
15
16
  from onnx_ir._convenience import (
16
17
  convert_attribute,
17
18
  convert_attributes,
18
19
  create_value_mapping,
20
+ get_const_tensor,
19
21
  replace_all_uses_with,
20
22
  replace_nodes_and_values,
21
23
  )
onnx_ir/external_data.py CHANGED
@@ -4,12 +4,15 @@
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
+ from typing import Callable
8
+
7
9
  __all__ = [
8
10
  "set_base_dir",
9
11
  "unload_from_model",
10
12
  "load_to_model",
11
13
  "convert_tensors_to_external",
12
14
  "convert_tensors_from_external",
15
+ "CallbackInfo",
13
16
  ]
14
17
 
15
18
  import dataclasses
@@ -48,6 +51,21 @@ class _ExternalDataInfo:
48
51
  length: int
49
52
 
50
53
 
54
+ @dataclasses.dataclass
55
+ class CallbackInfo:
56
+ """A class that shares information about a tensor that is to be saved as external data for callback functions.
57
+
58
+ Attributes:
59
+ total: The total number of tensors to save.
60
+ index: The index of the tensor being saved.
61
+ offset: The offset of the tensor in the external data file.
62
+ """
63
+
64
+ total: int
65
+ index: int
66
+ offset: int
67
+
68
+
51
69
  def _all_tensors(
52
70
  graph: _core.Graph | _core.GraphView, include_attributes: bool = False
53
71
  ) -> Iterator[_protocols.TensorProtocol]:
@@ -157,6 +175,7 @@ def _write_external_data(
157
175
  tensors: Sequence[_protocols.TensorProtocol],
158
176
  external_data_infos: Sequence[_ExternalDataInfo],
159
177
  file_path: str | os.PathLike,
178
+ callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
160
179
  ) -> None:
161
180
  """Write tensor data to an external file according to information stored in ExternalDataInfo objects.
162
181
 
@@ -164,12 +183,26 @@ def _write_external_data(
164
183
  tensors: Tensors to be written as external data.
165
184
  external_data_infos: External data information stored for each tensor to be written as external data.
166
185
  file_path: Location to which external data is to be stored.
186
+ callback: A callback function that is called for each tensor that is saved to external data
187
+ for debugging or logging purposes.
167
188
  """
168
- assert len(tensors) == len(external_data_infos), (
189
+ tensors_count = len(tensors)
190
+ assert tensors_count == len(external_data_infos), (
169
191
  "Number of tensors and external data infos should match"
170
192
  )
171
193
  with open(file_path, "wb") as data_file:
172
- for tensor, tensor_info in zip(tensors, external_data_infos, strict=True):
194
+ for i, (tensor, tensor_info) in enumerate(
195
+ zip(tensors, external_data_infos, strict=True)
196
+ ):
197
+ if callback is not None:
198
+ callback(
199
+ tensor,
200
+ CallbackInfo(
201
+ total=tensors_count,
202
+ index=i,
203
+ offset=tensor_info.offset,
204
+ ),
205
+ )
173
206
  current_offset = tensor_info.offset
174
207
  assert tensor is not None
175
208
  raw_data = tensor.tobytes()
@@ -228,6 +261,7 @@ def convert_tensors_to_external(
228
261
  tensors: Sequence[_protocols.TensorProtocol],
229
262
  base_dir: str | os.PathLike,
230
263
  relative_path: str | os.PathLike,
264
+ callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
231
265
  ) -> list[_core.ExternalTensor]:
232
266
  """Convert a sequence of any TensorProtocol tensors to external tensors.
233
267
 
@@ -238,6 +272,8 @@ def convert_tensors_to_external(
238
272
  tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
239
273
  base_dir: Path of base directory.
240
274
  relative_path: Path to which external data is to be stored, relative to the ONNX file.
275
+ callback: A callback function that is called for each tensor that is saved to external data
276
+ for debugging or logging purposes.
241
277
 
242
278
  Returns:
243
279
  A list of external tensors derived from a list of input tensors. The order
@@ -285,7 +321,7 @@ def convert_tensors_to_external(
285
321
  external_info = _compute_external_data_info(tensor, current_offset)
286
322
  external_data_infos.append(external_info)
287
323
  current_offset = external_info.offset + external_info.length
288
- _write_external_data(sorted_tensors, external_data_infos, path)
324
+ _write_external_data(sorted_tensors, external_data_infos, path, callback=callback)
289
325
 
290
326
  # Create external tensor objects
291
327
  external_tensors: list[_core.ExternalTensor] = [
@@ -336,6 +372,7 @@ def unload_from_model(
336
372
  relative_path: str | os.PathLike,
337
373
  *,
338
374
  size_threshold_bytes: int = 0,
375
+ callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
339
376
  ) -> _core.Model:
340
377
  """Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file.
341
378
 
@@ -356,6 +393,8 @@ def unload_from_model(
356
393
  relative_path: Path to which external data is to be stored, relative to the ONNX file.
357
394
  E.g. "model.data"
358
395
  size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
396
+ callback: A callback function that is called for each tensor that is saved to external data
397
+ for debugging or logging purposes.
359
398
 
360
399
  Returns:
361
400
  An ir.Model with all initializer data equal or above ``size_threshold_bytes``
@@ -384,6 +423,7 @@ def unload_from_model(
384
423
  [v.const_value for v in initializers_to_become_external], # type: ignore[misc]
385
424
  base_dir=base_dir,
386
425
  relative_path=relative_path,
426
+ callback=callback,
387
427
  )
388
428
 
389
429
  # Replace the initializer values with external tensors and save the model
@@ -127,7 +127,7 @@ class PassBase(abc.ABC):
127
127
 
128
128
  # Check postconditions
129
129
  try:
130
- self.ensures(model)
130
+ self.ensures(result.model)
131
131
  except PostconditionError:
132
132
  raise
133
133
  except Exception as e:
@@ -5,6 +5,7 @@ __all__ = [
5
5
  "AddInitializersToInputsPass",
6
6
  "CheckerPass",
7
7
  "ClearMetadataAndDocStringPass",
8
+ "CommonSubexpressionEliminationPass",
8
9
  "InlinePass",
9
10
  "LiftConstantsToInitializersPass",
10
11
  "LiftSubgraphInitializersToMainGraphPass",
@@ -19,6 +20,9 @@ __all__ = [
19
20
  from onnx_ir.passes.common.clear_metadata_and_docstring import (
20
21
  ClearMetadataAndDocStringPass,
21
22
  )
23
+ from onnx_ir.passes.common.common_subexpression_elimination import (
24
+ CommonSubexpressionEliminationPass,
25
+ )
22
26
  from onnx_ir.passes.common.constant_manipulation import (
23
27
  AddInitializersToInputsPass,
24
28
  LiftConstantsToInitializersPass,
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Callable, TypeVar
10
10
  import onnx_ir as ir
11
11
 
12
12
  if TYPE_CHECKING:
13
- import onnx
13
+ import onnx # noqa: TID251
14
14
 
15
15
 
16
16
  logger = logging.getLogger(__name__)