onnx-ir 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

onnx_ir/_display.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Internal utilities for displaying the intermediate representation of a model.
4
+
5
+ NOTE: All third-party imports should be scoped and imported only when used to avoid
6
+ importing unnecessary dependencies.
7
+ """
8
+ # pylint: disable=import-outside-toplevel
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Any
13
+
14
+
15
def require_rich() -> Any:
    """Return the ``rich`` module if it can be imported, otherwise ``None``.

    This function never raises for a missing installation: an ``ImportError``
    is swallowed and reported by returning ``None`` so that callers can fall
    back to plain output.

    Returns:
        The imported ``rich`` module, or ``None`` if importing it raised
        ``ImportError``.
    """
    try:
        import rich
    except ImportError:
        return None
    return rich
22
+
23
+
24
class PrettyPrintable:
    """Base class providing a ``display`` method built on ``str(self)``."""

    def display(self, *, page: bool = False) -> None:
        """Print ``str(self)``, using ``rich`` when it is importable.

        When ``rich`` is not available the text is printed with the builtin
        ``print``, followed by a colored tip suggesting ``pip install rich``.

        Args:
            page: If True and ``rich`` is available, show the text through
                ``rich``'s console pager instead of printing it directly.
        """
        rich = require_rich()
        text = str(self)

        if rich is not None:
            if not page:
                rich.print(text)
                return
            # Rebinds the local name ``rich`` to the package after importing the submodule.
            import rich.console

            console = rich.console.Console()
            with console.pager():
                console.print(text)
            return

        print(text)
        # \u001b[36m switches the terminal color to cyan; \u001b[0m resets it.
        print(
            f"\n\n\u001b[36mTip: Install the rich library with 'pip install rich' to pretty print this {self.__class__.__name__}.\u001b[0m"
        )
onnx_ir/_enums.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """ONNX IR enums that matches the ONNX spec."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import enum
8
+
9
+ import ml_dtypes
10
+ import numpy as np
11
+
12
+
13
class AttributeType(enum.IntEnum):
    """Integer codes of the ONNX attribute kinds.

    Both ``repr`` and ``str`` of a member are just its name, e.g. ``"INTS"``.
    """

    UNDEFINED = 0
    # Single-valued kinds
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    # Repeated kinds
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    # Sparse tensors (single / repeated)
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
    # Type protos (single / repeated)
    TYPE_PROTO = 13
    TYPE_PROTOS = 14

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return repr(self)
37
+
38
+
39
class DataType(enum.IntEnum):
    """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""

    # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
    # but we should stick to the names used in the ONNX spec for consistency.
    UNDEFINED = 0
    FLOAT = 1
    UINT8 = 2
    INT8 = 3
    UINT16 = 4
    INT16 = 5
    INT32 = 6
    INT64 = 7
    STRING = 8
    BOOL = 9
    FLOAT16 = 10
    DOUBLE = 11
    UINT32 = 12
    UINT64 = 13
    COMPLEX64 = 14
    COMPLEX128 = 15
    BFLOAT16 = 16
    FLOAT8E4M3FN = 17
    FLOAT8E4M3FNUZ = 18
    FLOAT8E5M2 = 19
    FLOAT8E5M2FNUZ = 20
    UINT4 = 21
    INT4 = 22

    @classmethod
    def from_numpy(cls, dtype: np.dtype) -> DataType:
        """Returns the ONNX data type for the numpy dtype.

        Args:
            dtype: A ``numpy.dtype``, or anything ``np.dtype`` accepts, such as a
                scalar type (``np.float32``) or a dtype string (``"int64"``). The
                value is normalized with ``np.dtype`` before the lookup because
                scalar types and strings do not hash like the dtype keys of the
                lookup table.

        Raises:
            TypeError: If the data type is not supported by ONNX.
        """
        if dtype is None:
            # np.dtype(None) silently means float64; treat None as unsupported instead.
            raise TypeError(f"Unsupported numpy data type: {dtype}")
        try:
            normalized = np.dtype(dtype)
        except (TypeError, ValueError):
            raise TypeError(f"Unsupported numpy data type: {dtype}") from None
        if normalized not in _NP_TYPE_TO_DATA_TYPE:
            raise TypeError(f"Unsupported numpy data type: {dtype}")
        return cls(_NP_TYPE_TO_DATA_TYPE[normalized])

    @property
    def itemsize(self) -> float:
        """Returns the size of one element of the data type in bytes.

        The 4-bit types (``INT4``/``UINT4``) report ``0.5``. Looking up a type
        without an entry in ``_ITEMSIZE_MAP`` (e.g. ``UNDEFINED``) raises ``KeyError``.
        """
        return _ITEMSIZE_MAP[self]

    def numpy(self) -> np.dtype:
        """Returns the numpy dtype for the ONNX data type.

        Raises:
            TypeError: If the data type is not supported by numpy.
        """
        if self not in _DATA_TYPE_TO_NP_TYPE:
            raise TypeError(f"Numpy does not support ONNX data type: {self}")
        return _DATA_TYPE_TO_NP_TYPE[self]

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.__repr__()
99
+
100
+
101
# Size in bytes of a single element of each ONNX data type. The 4-bit types
# occupy half a byte per element and therefore map to 0.5. DataType.UNDEFINED
# deliberately has no entry.
_ITEMSIZE_MAP: dict[DataType, float] = {
    DataType.FLOAT: 4,
    DataType.UINT8: 1,
    DataType.INT8: 1,
    DataType.UINT16: 2,
    DataType.INT16: 2,
    DataType.INT32: 4,
    DataType.INT64: 8,
    DataType.STRING: 1,
    DataType.BOOL: 1,
    DataType.FLOAT16: 2,
    DataType.DOUBLE: 8,
    DataType.UINT32: 4,
    DataType.UINT64: 8,
    DataType.COMPLEX64: 8,
    DataType.COMPLEX128: 16,
    DataType.BFLOAT16: 2,
    DataType.FLOAT8E4M3FN: 1,
    DataType.FLOAT8E4M3FNUZ: 1,
    DataType.FLOAT8E5M2: 1,
    DataType.FLOAT8E5M2FNUZ: 1,
    DataType.UINT4: 0.5,
    DataType.INT4: 0.5,
}


# numpy dtype -> ONNX data type. Types numpy does not provide natively
# (bfloat16, the float8 variants and the 4-bit integers) come from ml_dtypes.
_NP_TYPE_TO_DATA_TYPE: dict[np.dtype, DataType] = {
    # Native numpy dtypes
    np.dtype("bool"): DataType.BOOL,
    np.dtype("complex128"): DataType.COMPLEX128,
    np.dtype("complex64"): DataType.COMPLEX64,
    np.dtype("float16"): DataType.FLOAT16,
    np.dtype("float32"): DataType.FLOAT,
    np.dtype("float64"): DataType.DOUBLE,
    np.dtype("int16"): DataType.INT16,
    np.dtype("int32"): DataType.INT32,
    np.dtype("int64"): DataType.INT64,
    np.dtype("int8"): DataType.INT8,
    np.dtype("object"): DataType.STRING,
    np.dtype("uint16"): DataType.UINT16,
    np.dtype("uint32"): DataType.UINT32,
    np.dtype("uint64"): DataType.UINT64,
    np.dtype("uint8"): DataType.UINT8,
    # Extension dtypes provided by ml_dtypes
    np.dtype(ml_dtypes.bfloat16): DataType.BFLOAT16,
    np.dtype(ml_dtypes.float8_e4m3fn): DataType.FLOAT8E4M3FN,
    np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
    np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
    np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
    np.dtype(ml_dtypes.int4): DataType.INT4,
    np.dtype(ml_dtypes.uint4): DataType.UINT4,
}

# ONNX data type -> numpy dtype: the inverse of the table above.
_DATA_TYPE_TO_NP_TYPE: dict[DataType, np.dtype] = dict(
    zip(_NP_TYPE_TO_DATA_TYPE.values(), _NP_TYPE_TO_DATA_TYPE.keys())
)
@@ -0,0 +1,323 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """External data related utilities."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = ["set_base_dir"]
8
+
9
+ import dataclasses
10
+ import os
11
+ from typing import Iterator, Sequence
12
+
13
+ from onnx_ir import _core, _enums, _protocols, traversal
14
+
15
# NOTE: If needed in the future, expose these as parameters of the function calls.

# Whether large tensors should start at an aligned offset in the external data
# file (page and allocation-granularity aligned, for mmap support). Alignment is
# achieved by zero-padding after the previous tensor's data, and is applied only
# to tensors whose data is larger than _ALIGN_THRESHOLD.
_ALIGN_OFFSET = True
# Size threshold in bytes (1 MiB). A low threshold wastes file space on padding
# for small initializers.
_ALIGN_THRESHOLD = 1024 * 1024
# Allocation granularity for mmap() support: typically 64 KiB on Windows and
# 4 KiB on other operating systems.
_ALLOCATION_GRANULARITY = 64 * 1024
22
+
23
+
24
@dataclasses.dataclass
class _ExternalDataInfo:
    """Where one tensor's bytes are placed inside an external data file.

    Attributes:
        name: Name of the tensor stored as external data (may be None).
        offset: Byte position in the file at which the tensor data starts.
        length: Size of the tensor data in bytes.
    """

    name: str | None  # tensor name, copied from the tensor
    offset: int  # start of the data within the file, in bytes
    length: int  # number of bytes written for the tensor
38
+
39
+
40
def _all_tensors(
    graph: _core.Graph | _core.GraphView, include_attributes: bool = False
) -> Iterator[_protocols.TensorProtocol]:
    """Yield the tensors held by ``graph``.

    Initializer tensors (non-None ``const_value``) are yielded first. If
    ``include_attributes`` is True, the nodes produced by
    ``traversal.RecursiveGraphIterator`` are then scanned, and tensors held in
    their TENSOR / TENSORS attributes are yielded. Reference attributes
    (``_core.RefAttr``) are skipped.

    Args:
        graph: The graph to traverse tensors on.
        include_attributes: Whether to include tensors in attributes.

    Yields:
        Tensors in the graph.
    """
    yield from (
        initializer.const_value
        for initializer in graph.initializers.values()
        if initializer.const_value is not None
    )
    if include_attributes:
        for node in traversal.RecursiveGraphIterator(graph):
            for attribute in node.attributes.values():
                if isinstance(attribute, _core.RefAttr):
                    continue
                kind = attribute.type
                if kind == _enums.AttributeType.TENSOR:
                    if attribute.value is not None:
                        yield attribute.value
                elif kind == _enums.AttributeType.TENSORS:
                    if attribute.value is not None:
                        yield from attribute.value
67
+
68
+
69
def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None:
    """Set ``base_dir`` on every external tensor found in ``graph``.

    Tensors are gathered from both the initializers and the node attributes;
    only ``_core.ExternalTensor`` instances are modified, in place.

    Args:
        graph: The graph to traverse tensors on.
        base_dir: The base directory. This is the directory where the ONNX file is.
    """
    external_tensors = (
        candidate
        for candidate in _all_tensors(graph, include_attributes=True)
        if isinstance(candidate, _core.ExternalTensor)
    )
    for external_tensor in external_tensors:
        external_tensor.base_dir = base_dir
79
+
80
+
81
def _load_external_data_file(
    tensors: Sequence[_protocols.TensorProtocol],
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
) -> list[_protocols.TensorProtocol]:
    """Replace tensors backed by ``base_path/relative_path`` with in-memory copies.

    Each ``_core.ExternalTensor`` whose file is the same file as
    ``base_path/relative_path`` has its data copied into a new ``_core.Tensor``,
    and ``release()`` is called on it. All other tensors are passed through
    unchanged, and the input order is preserved.

    Args:
        tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
        base_path: Path of base directory.
        relative_path: Path to which external data is to be stored, relative to the ONNX file.

    Returns:
        A list of ir.Tensor values.
    """
    destination = os.path.join(base_path, relative_path)
    loaded: list[_protocols.TensorProtocol] = []
    for tensor in tensors:
        backed_by_destination = isinstance(
            tensor, _core.ExternalTensor
        ) and os.path.samefile(tensor.path, destination)
        if not backed_by_destination:
            loaded.append(tensor)
            continue
        # Copy: .numpy() references data in a file whose contents will be overwritten.
        in_memory = tensor.numpy().copy()
        tensor.release()
        loaded.append(_core.Tensor(in_memory, name=tensor.name, dtype=tensor.dtype))
    return loaded
109
+
110
+
111
def _compute_new_offset(
    current_offset: int,
    tensor_size: int,
    align_offset: bool = _ALIGN_OFFSET,
    align_threshold: int = _ALIGN_THRESHOLD,
    allocation_granularity: int = _ALLOCATION_GRANULARITY,
) -> int:
    """Return the offset at which a tensor of ``tensor_size`` bytes should be written.

    Args:
        current_offset: Current location in the file at which tensor data would be written.
        tensor_size: Size of the tensor data to be written to file.
        align_offset: Whether large tensors are aligned (page and allocation
            granularity aligned, for mmap support). The gap is later filled by
            zero-padding after the previous tensor's data.
        align_threshold: Only tensors strictly larger than this many bytes are
            aligned. A low threshold wastes file space for small initializers.
        allocation_granularity: The allocation granularity for mmap() support.
            Typically 64KB for Windows and 4KB for other OSes.

    Returns:
        ``current_offset`` unchanged, or rounded up to the next multiple of
        ``max(4096, allocation_granularity)`` when alignment applies.
    """
    should_align = align_offset and tensor_size > align_threshold
    if not should_align:
        return current_offset
    # Never align to less than a 4096-byte page.
    boundary = max(4096, allocation_granularity)
    return (current_offset + boundary - 1) // boundary * boundary
135
+
136
+
137
def _compute_external_data_info(
    tensor: _protocols.TensorProtocol,
    current_offset: int,
) -> _ExternalDataInfo:
    """Decide where ``tensor`` goes in the external data file.

    Args:
        tensor: The tensor to be stored as external data.
        current_offset: First free byte offset in the file.

    Returns:
        An ``_ExternalDataInfo`` holding the tensor's name, its start offset
        (possibly moved forward by ``_compute_new_offset`` for alignment) and
        its size ``tensor.nbytes``.
    """
    size_in_bytes = tensor.nbytes
    start = _compute_new_offset(current_offset, size_in_bytes)
    return _ExternalDataInfo(name=tensor.name, offset=start, length=size_in_bytes)
152
+
153
+
154
def _save_external_data(
    external_data_info: list[tuple[_protocols.TensorProtocol, _ExternalDataInfo]],
    file_path: str | os.PathLike,
) -> None:
    """Write each tensor's bytes at its recorded offset in a new file.

    The file at ``file_path`` is truncated and rewritten. Entries are written in
    list order; a gap between the end of the file and an entry's offset is
    filled with zero bytes.

    Args:
        external_data_info: Pairs of (tensor, placement info) to write.
        file_path: Location to which external data is to be stored.
    """
    with open(file_path, "wb") as out:
        for tensor, placement in external_data_info:
            assert tensor is not None
            payload = tensor.tobytes()
            if isinstance(tensor, _core.ExternalTensor):
                # Its bytes have already been read into ``payload``.
                tensor.release()
            gap = placement.offset - out.tell()
            if gap > 0:
                out.write(b"\0" * gap)
            out.write(payload)
176
+
177
+
178
def _convert_as_external_tensors(
    external_data_info: list[tuple[_protocols.TensorProtocol, _ExternalDataInfo]],
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
) -> list[_core.ExternalTensor]:
    """Build one ``_core.ExternalTensor`` per written tensor.

    Each result references ``normpath(relative_path)`` under
    ``normpath(base_path)`` at the recorded offset and length, and keeps the
    source tensor's dtype, shape and name. Order follows ``external_data_info``.

    Args:
        external_data_info: Pairs of (tensor, placement info) that were written.
        base_path: Path of base directory.
        relative_path: Path to which external data is stored, relative to the ONNX file.

    Returns:
        A list of external tensors.
    """
    converted: list[_core.ExternalTensor] = []
    for source, placement in external_data_info:
        assert source is not None
        converted.append(
            _core.ExternalTensor(
                os.path.normpath(relative_path),
                placement.offset,
                placement.length,
                source.dtype,  # type: ignore[arg-type]
                shape=source.shape,  # type: ignore[arg-type]
                name=source.name,  # type: ignore[arg-type]
                base_dir=os.path.normpath(base_path),
            )
        )
    return converted
207
+
208
+
209
def convert_tensors_to_external(
    tensors: Sequence[_protocols.TensorProtocol],
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
    load_external_to_memory: bool = False,
) -> list[_core.ExternalTensor]:
    """Convert a sequence of any TensorProtocol tensors to external tensors.

    All tensor data is written to ``base_path/relative_path``, overwriting that
    file. Smaller tensors are written first; large tensors may start at aligned
    offsets (see ``_compute_new_offset``).

    If the destination file already exists and some input tensors are external
    tensors backed by it, their data must stay readable while the file is
    rewritten:

    * with ``load_external_to_memory=True`` they are loaded into memory first;
    * otherwise the existing file is moved to ``base_path/tmp/relative_path``.
      The affected tensors whose location matches ``relative_path`` are
      re-pointed there, and the temporary copy is deleted after the new file
      has been written.

    Args:
        tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
        base_path: Path of base directory.
        relative_path: Path to which external data is to be stored, relative to the ONNX file.
        load_external_to_memory: If set to true, loads external tensors present in the same file path as destination path to memory.

    Returns:
        A list of external tensors, in the same order as ``tensors``.
    """
    path = os.path.join(base_path, relative_path)
    # Create missing parent directories. dirname() is empty when the destination
    # has no directory component, and os.makedirs("") would raise.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    tmp_file_path: str | None = None
    if os.path.exists(path):
        # External tensors whose data currently lives in the destination file.
        # FIXME(shubhambhokare1): If there is a non-initializer tensor that is referring to this file, that tensor is now invalid. This is a special case we are ok not handling right now.
        tensors_using_file = [
            tensor
            for tensor in tensors
            if isinstance(tensor, _core.ExternalTensor) and os.path.samefile(path, tensor.path)
        ]
        if tensors_using_file:
            if load_external_to_memory:
                tensors = _load_external_data_file(tensors, base_path, relative_path)
            else:
                tmp_dir = os.path.join(base_path, "tmp")
                tmp_file_path = os.path.join(tmp_dir, relative_path)
                # relative_path may contain sub-directories; create all of them.
                os.makedirs(os.path.dirname(tmp_file_path), exist_ok=True)
                # Move the existing data aside so it stays readable while the
                # destination is rewritten. os.replace also overwrites a stale
                # temporary file on Windows, where os.rename would fail.
                os.replace(path, tmp_file_path)
                normalized_location = os.path.normpath(relative_path)
                for tensor in tensors_using_file:
                    # Changing base_dir only relocates tensors whose location is
                    # relative_path. Compare normalized strings so that str/PathLike
                    # and spellings like "./data.bin" are treated alike, and only
                    # touch tensors that really use this file.
                    if os.path.normpath(tensor.location) == normalized_location:
                        tensor.base_dir = tmp_dir

    # Write smaller tensors first so that alignment padding is only spent on
    # the large tensors placed at the end of the file.
    sorted_indices = sorted(range(len(tensors)), key=lambda i: tensors[i].nbytes)
    external_data_info: list[tuple[_protocols.TensorProtocol, _ExternalDataInfo]] = []
    current_offset = 0
    for index in sorted_indices:
        tensor = tensors[index]
        tensor_info = _compute_external_data_info(tensor, current_offset)
        external_data_info.append((tensor, tensor_info))
        current_offset = tensor_info.offset + tensor_info.length
    _save_external_data(external_data_info, path)

    # Convert the written tensors to ExternalTensors (still in size order).
    sorted_external_tensors = _convert_as_external_tensors(
        external_data_info, base_path, relative_path
    )
    # Undo the size sort so the result lines up with the input ``tensors``.
    restored = sorted(
        zip(sorted_indices, sorted_external_tensors), key=lambda pair: pair[0]
    )
    external_tensors = [external_tensor for _, external_tensor in restored]

    # Clean up the temporary copy of the previous data if one was created.
    if tmp_file_path is not None and os.path.exists(tmp_file_path):
        os.remove(tmp_file_path)

    return external_tensors
285
+
286
+
287
def to_external_data(
    model: _core.Model,
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
    load_external_to_memory: bool = False,
) -> _core.Model:
    """Set all tensors with raw data as external data.

    Every initializer of ``model.graph`` that has a ``const_value`` is written
    to ``base_path/relative_path``. Its ``const_value`` is then replaced in
    place by an external tensor referencing that file. Initializers without a
    ``const_value`` are left untouched, and tensors stored in node attributes
    are not converted.

    Args:
        model: Model to process. It is modified in place.
        base_path: Path of base directory.
        relative_path: Path to which external data is to be stored, relative to the ONNX file.
        load_external_to_memory: If True, external tensors already stored in
            the destination file are loaded into memory before the file is
            rewritten. Otherwise the existing file is temporarily moved under
            ``base_path`` while the new file is written.

    Returns:
        The same ``model`` object, with its initializers converted to external tensors.
    """
    # Collect the tensors to store externally, remembering which initializer
    # each one came from.
    # TODO: Currently attributes not handled, eventually try to use _all_tensors to include attrs
    values_with_data: list[_core.Value] = []
    tensors: list[_protocols.TensorProtocol] = []
    for value in model.graph.initializers.values():
        if value.const_value is not None:
            values_with_data.append(value)
            tensors.append(value.const_value)

    external_tensors = convert_tensors_to_external(
        tensors,
        base_path,
        relative_path,
        load_external_to_memory=load_external_to_memory,
    )

    # Pair results with the initializers they were produced from. Zipping
    # against *all* initializers would misalign the pairs as soon as one
    # initializer has no const_value.
    for value, external_tensor in zip(values_with_data, external_tensors):
        value.const_value = external_tensor
    return model
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Utilities for comparing IR graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from onnx_ir import _core
8
+
9
+ # NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs
10
+ # NOTE(justinchuby): A graph may be specified with a set of inputs and outputs
11
+
12
+
13
def topologically_equal(graph1: _core.Graph, graph2: _core.Graph) -> bool:
    """Check whether two graphs are topologically equivalent, ignoring initializers.

    Not implemented yet: every call raises ``NotImplementedError``.

    Args:
        graph1: The first graph to compare.
        graph2: The second graph to compare.

    Returns:
        True if the graphs are equal, False otherwise (once implemented).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Version utils for testing."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+ from typing import Callable, Sequence
9
+
10
+ import packaging.version
11
+
12
+
13
def onnx_older_than(version: str) -> bool:
    """Check whether the installed ``onnx`` predates ``version``.

    Only the numeric release tuples (``Version.release``) are compared, so
    pre-, post- and dev-release suffixes are ignored. Importing ``onnx`` is not
    guarded, so an ``ImportError`` propagates if it is not installed.
    """
    import onnx  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(onnx.__version__).release
    required = packaging.version.parse(version).release
    return installed < required
21
+
22
+
23
def torch_older_than(version: str) -> bool:
    """Check whether the installed ``torch`` predates ``version``.

    Only the numeric release tuples (``Version.release``) are compared, so
    local suffixes such as ``+cu121`` and pre-release tags are ignored.
    Importing ``torch`` is not guarded.
    """
    import torch  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(torch.__version__).release
    required = packaging.version.parse(version).release
    return installed < required
31
+
32
+
33
+ def transformers_older_than(version: str) -> bool | None:
34
+ """Returns True if the transformers version is older than the given version."""
35
+ try:
36
+ import transformers # pylint: disable=import-outside-toplevel
37
+ except ImportError:
38
+ return None
39
+
40
+ return (
41
+ packaging.version.parse(transformers.__version__).release
42
+ < packaging.version.parse(version).release
43
+ )
44
+
45
+
46
def is_onnxruntime_training() -> bool:
    """Tell whether the installed onnxruntime is an onnxruntime-training build.

    Returns True only if ``onnxruntime.training`` can be imported and the
    ``OrtValueVector`` binding has a ``push_back_batch`` attribute. Any
    ``ImportError`` along the way (including onnxruntime being absent) yields False.
    """
    try:
        from onnxruntime import training  # pylint: disable=import-outside-toplevel
    except ImportError:
        # Not an onnxruntime-training build.
        return False
    else:
        assert training

    try:
        from onnxruntime.capi.onnxruntime_pybind11_state import (  # pylint: disable=import-outside-toplevel
            OrtValueVector,
        )
    except ImportError:
        return False
    else:
        return hasattr(OrtValueVector, "push_back_batch")
64
+
65
+
66
def onnxruntime_older_than(version: str) -> bool:
    """Check whether the installed ``onnxruntime`` predates ``version``.

    Only the numeric release tuples (``Version.release``) are compared.
    Importing ``onnxruntime`` is not guarded.
    """
    import onnxruntime  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(onnxruntime.__version__).release
    required = packaging.version.parse(version).release
    return installed < required
74
+
75
+
76
def numpy_older_than(version: str) -> bool:
    """Check whether the installed ``numpy`` predates ``version``.

    Only the numeric release tuples (``Version.release``) are compared, so
    e.g. ``2.0.0rc1`` counts as ``2.0.0``.
    """
    import numpy  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(numpy.__version__).release
    required = packaging.version.parse(version).release
    return installed < required
84
+
85
+
86
def has_transformers():
    """Return True if ``transformers`` can be imported, False on ``ImportError``."""
    try:
        import transformers  # pylint: disable=import-outside-toplevel
    except ImportError:
        return False
    assert transformers
    return True  # noqa
95
+
96
+
97
def ignore_warnings(warns: type[Warning] | Sequence[type[Warning]]) -> Callable:
    """Decorator factory that silences warning categories while a function runs.

    Example::

        @ignore_warnings(DeprecationWarning)
        def test_something(self): ...

    Args:
        warns: A warning category, or a sequence of warning categories, to ignore.

    Returns:
        A decorator. The decorated function accepts the same arguments as the
        original and returns its result. The warning filters are restored by
        ``warnings.catch_warnings`` once the call finishes.

    Raises:
        AssertionError: When the decorator is applied while ``warns`` is None.
    """
    import functools  # pylint: disable=import-outside-toplevel

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")
        # warnings.simplefilter takes a single category, so normalize to a tuple.
        categories = (warns,) if isinstance(warns, type) else tuple(warns)

        @functools.wraps(fct)
        def call_f(*args, **kwargs):
            with warnings.catch_warnings():
                for category in categories:
                    warnings.simplefilter("ignore", category)
                return fct(*args, **kwargs)

        return call_f

    return wrapper