onnx-ir 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +154 -0
- onnx_ir/_convenience.py +439 -0
- onnx_ir/_core.py +2875 -0
- onnx_ir/_display.py +49 -0
- onnx_ir/_enums.py +154 -0
- onnx_ir/_external_data.py +323 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_internal/version_utils.py +118 -0
- onnx_ir/_io.py +50 -0
- onnx_ir/_linked_list.py +276 -0
- onnx_ir/_metadata.py +44 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_protocols.py +598 -0
- onnx_ir/_tape.py +104 -0
- onnx_ir/_thirdparty/asciichartpy.py +313 -0
- onnx_ir/_type_casting.py +91 -0
- onnx_ir/convenience.py +32 -0
- onnx_ir/passes/__init__.py +33 -0
- onnx_ir/passes/_pass_infra.py +172 -0
- onnx_ir/serde.py +1551 -0
- onnx_ir/traversal.py +82 -0
- onnx_ir-0.0.1.dist-info/LICENSE +22 -0
- onnx_ir-0.0.1.dist-info/METADATA +73 -0
- onnx_ir-0.0.1.dist-info/RECORD +26 -0
- onnx_ir-0.0.1.dist-info/WHEEL +5 -0
- onnx_ir-0.0.1.dist-info/top_level.txt +1 -0
onnx_ir/_display.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""Internal utilities for displaying the intermediate representation of a model.
|
|
4
|
+
|
|
5
|
+
NOTE: All third-party imports should be scoped and imported only when used to avoid
|
|
6
|
+
importing unnecessary dependencies.
|
|
7
|
+
"""
|
|
8
|
+
# pylint: disable=import-outside-toplevel
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def require_rich() -> Any:
    """Return the ``rich`` module if it can be imported, otherwise ``None``.

    Despite its name, this function never raises: callers must check for
    ``None`` and fall back to plain output when ``rich`` is not installed.

    Returns:
        The imported ``rich`` module, or ``None`` if importing it fails.
    """
    try:
        import rich
    except ImportError:
        return None
    return rich
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class PrettyPrintable:
    """Mixin providing :meth:`display`, which renders ``str(self)``."""

    def display(self, *, page: bool = False) -> None:
        """Print ``str(self)``, using ``rich`` when it is importable.

        Without ``rich``, the text is printed with the builtin ``print`` and a
        colored installation hint follows it.

        Args:
            page: Whether to page the output (only honored when ``rich`` is available).
        """
        rich = require_rich()
        text = str(self)

        if rich is not None:
            if not page:
                rich.print(text)
                return
            import rich.console

            pager_console = rich.console.Console()
            with pager_console.pager():
                pager_console.print(text)
            return

        print(text)
        # Color print this message
        print(
            f"\n\n\u001b[36mTip: Install the rich library with 'pip install rich' to pretty print this {self.__class__.__name__}.\u001b[0m"
        )
|
onnx_ir/_enums.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""ONNX IR enums that matches the ONNX spec."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import enum
|
|
8
|
+
|
|
9
|
+
import ml_dtypes
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AttributeType(enum.IntEnum):
    """Kinds of values an ONNX attribute can hold; numbering follows the ONNX spec."""

    UNDEFINED = 0
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
    TYPE_PROTO = 13
    TYPE_PROTOS = 14

    def __repr__(self) -> str:
        # Show just the member name, e.g. ``FLOATS``.
        return self.name

    def __str__(self) -> str:
        # ``str()`` and ``repr()`` render identically.
        return repr(self)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DataType(enum.IntEnum):
    """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""

    # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
    # but we should stick to the names used in the ONNX spec for consistency.
    UNDEFINED = 0
    FLOAT = 1
    UINT8 = 2
    INT8 = 3
    UINT16 = 4
    INT16 = 5
    INT32 = 6
    INT64 = 7
    STRING = 8
    BOOL = 9
    FLOAT16 = 10
    DOUBLE = 11
    UINT32 = 12
    UINT64 = 13
    COMPLEX64 = 14
    COMPLEX128 = 15
    BFLOAT16 = 16
    FLOAT8E4M3FN = 17
    FLOAT8E4M3FNUZ = 18
    FLOAT8E5M2 = 19
    FLOAT8E5M2FNUZ = 20
    UINT4 = 21
    INT4 = 22

    @classmethod
    def from_numpy(cls, dtype: np.dtype) -> DataType:
        """Returns the ONNX data type for the numpy dtype.

        Args:
            dtype: A ``numpy.dtype``, or anything ``numpy.dtype()`` accepts
                (e.g. ``np.float32`` or ``"int64"``). It is normalized with
                ``numpy.dtype()`` before the lookup.

        Raises:
            TypeError: If the data type is not supported by ONNX.
        """
        if dtype is None:
            # np.dtype(None) silently means float64; reject it instead of guessing.
            raise TypeError(f"Unsupported numpy data type: {dtype}")
        try:
            normalized = np.dtype(dtype)
        except (TypeError, ValueError) as e:
            raise TypeError(f"Unsupported numpy data type: {dtype}") from e
        if normalized not in _NP_TYPE_TO_DATA_TYPE:
            raise TypeError(f"Unsupported numpy data type: {dtype}")
        return cls(_NP_TYPE_TO_DATA_TYPE[normalized])

    @property
    def itemsize(self) -> float:
        """Returns the size of the data type in bytes.

        The 4-bit types (``UINT4``, ``INT4``) report a fractional size of 0.5.

        Raises:
            KeyError: For ``UNDEFINED``, which has no size.
        """
        return _ITEMSIZE_MAP[self]

    def numpy(self) -> np.dtype:
        """Returns the numpy dtype for the ONNX data type.

        Raises:
            TypeError: If the data type is not supported by numpy.
        """
        if self not in _DATA_TYPE_TO_NP_TYPE:
            raise TypeError(f"Numpy does not support ONNX data type: {self}")
        return _DATA_TYPE_TO_NP_TYPE[self]

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.__repr__()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# Size in bytes of one element of each ONNX data type; read by
# ``DataType.itemsize``. ``UNDEFINED`` has no entry, so its ``itemsize`` raises
# KeyError. The 4-bit types (UINT4, INT4) are listed as half a byte per element,
# which is why this table mixes ints and floats.
# NOTE(review): STRING is listed as 1 byte even though strings are variable
# length — confirm callers treat this as a placeholder.
_ITEMSIZE_MAP = {
    DataType.FLOAT: 4,
    DataType.UINT8: 1,
    DataType.INT8: 1,
    DataType.UINT16: 2,
    DataType.INT16: 2,
    DataType.INT32: 4,
    DataType.INT64: 8,
    DataType.STRING: 1,
    DataType.BOOL: 1,
    DataType.FLOAT16: 2,
    DataType.DOUBLE: 8,
    DataType.UINT32: 4,
    DataType.UINT64: 8,
    DataType.COMPLEX64: 8,
    DataType.COMPLEX128: 16,
    DataType.BFLOAT16: 2,
    DataType.FLOAT8E4M3FN: 1,
    DataType.FLOAT8E4M3FNUZ: 1,
    DataType.FLOAT8E5M2: 1,
    DataType.FLOAT8E5M2FNUZ: 1,
    DataType.UINT4: 0.5,
    DataType.INT4: 0.5,
}


# We use ml_dtypes to support dtypes that are not in numpy.
# Numpy dtype -> ONNX data type, read by ``DataType.from_numpy``. ONNX STRING
# tensors are represented with numpy's ``object`` dtype. Every value below is
# distinct, so this table can be inverted without losing entries.
_NP_TYPE_TO_DATA_TYPE = {
    np.dtype("bool"): DataType.BOOL,
    np.dtype("complex128"): DataType.COMPLEX128,
    np.dtype("complex64"): DataType.COMPLEX64,
    np.dtype("float16"): DataType.FLOAT16,
    np.dtype("float32"): DataType.FLOAT,
    np.dtype("float64"): DataType.DOUBLE,
    np.dtype("int16"): DataType.INT16,
    np.dtype("int32"): DataType.INT32,
    np.dtype("int64"): DataType.INT64,
    np.dtype("int8"): DataType.INT8,
    np.dtype("object"): DataType.STRING,
    np.dtype("uint16"): DataType.UINT16,
    np.dtype("uint32"): DataType.UINT32,
    np.dtype("uint64"): DataType.UINT64,
    np.dtype("uint8"): DataType.UINT8,
    np.dtype(ml_dtypes.bfloat16): DataType.BFLOAT16,
    np.dtype(ml_dtypes.float8_e4m3fn): DataType.FLOAT8E4M3FN,
    np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
    np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
    np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
    np.dtype(ml_dtypes.int4): DataType.INT4,
    np.dtype(ml_dtypes.uint4): DataType.UINT4,
}

# ONNX DataType to Numpy dtype.
# Inverse of the table above, read by ``DataType.numpy``. Types with no entry
# there (e.g. UNDEFINED) are absent here too, so ``numpy()`` raises for them.
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""External data related utilities."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
# Only ``set_base_dir`` is exported via star-import; the other non-underscore
# helpers in this module (``convert_tensors_to_external``, ``to_external_data``)
# are not listed here.
__all__ = ["set_base_dir"]
|
|
8
|
+
|
|
9
|
+
import dataclasses
import math
import os
import shutil
import tempfile
from typing import Iterator, Sequence

from onnx_ir import _core, _enums, _protocols, traversal
|
|
14
|
+
|
|
15
|
+
# Note: If needed in future, add these as parameters to the function calls
|
|
16
|
+
# align_offset: Offset will always be page aligned and alloction granularity aligned for mmap support. This is done by padding previous tensor data with zeros keeping same length. Tensor data will be aligned if > align_threshold
|
|
17
|
+
_ALIGN_OFFSET = True
|
|
18
|
+
# align_threshold: Alignment threshold for size of data. Having a low threshold will waste file space for small initializers. Only when tensor's data is > the page_align_threshold it will be force aligned.
|
|
19
|
+
_ALIGN_THRESHOLD = 1048576 # 1MB
|
|
20
|
+
# allocation_granularity: The allocation Granularity for mmap() support. Typically 64KB for Windows & 4KB for other OSes.
|
|
21
|
+
_ALLOCATION_GRANULARITY = 65536 # 64KB
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclasses.dataclass
|
|
25
|
+
class _ExternalDataInfo:
|
|
26
|
+
"""
|
|
27
|
+
A class that stores information about a tensor that is to be stored as external data.
|
|
28
|
+
|
|
29
|
+
Attributes:
|
|
30
|
+
name: The name of the tensor that is to be stored as external data.
|
|
31
|
+
offset: The offset is used to determine where exactly in the file the external data is written to.
|
|
32
|
+
length: Stores the size of the tensor.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
name: str | None
|
|
36
|
+
offset: int
|
|
37
|
+
length: int
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _all_tensors(
|
|
41
|
+
graph: _core.Graph | _core.GraphView, include_attributes: bool = False
|
|
42
|
+
) -> Iterator[_protocols.TensorProtocol]:
|
|
43
|
+
"""Iterate over all tensors in the graph.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
graph: The graph to traverse tensors on.
|
|
47
|
+
include_attributes: Whether to include tensors in attributes.
|
|
48
|
+
|
|
49
|
+
Yields:
|
|
50
|
+
Tensors in the graph.
|
|
51
|
+
"""
|
|
52
|
+
# Yield all tensors in initializers
|
|
53
|
+
for value in graph.initializers.values():
|
|
54
|
+
if value.const_value is not None:
|
|
55
|
+
yield value.const_value
|
|
56
|
+
if not include_attributes:
|
|
57
|
+
return
|
|
58
|
+
# Look at constant attributes in nodes
|
|
59
|
+
for node in traversal.RecursiveGraphIterator(graph):
|
|
60
|
+
for attr in node.attributes.values():
|
|
61
|
+
if isinstance(attr, _core.RefAttr):
|
|
62
|
+
continue
|
|
63
|
+
if attr.type == _enums.AttributeType.TENSOR and attr.value is not None:
|
|
64
|
+
yield attr.value
|
|
65
|
+
elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None:
|
|
66
|
+
yield from attr.value
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None:
    """Point every external tensor in ``graph`` at ``base_dir``.

    Tensors stored in node attributes are included along with initializers;
    tensors that are not ``ExternalTensor`` instances are left untouched.

    Args:
        graph: The graph to traverse tensors on.
        base_dir: The base directory. This is the directory where the ONNX file is.
    """
    for candidate in _all_tensors(graph, include_attributes=True):
        if not isinstance(candidate, _core.ExternalTensor):
            continue
        candidate.base_dir = base_dir
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _load_external_data_file(
    tensors: Sequence[_protocols.TensorProtocol],
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
) -> list[_protocols.TensorProtocol]:
    """Load into memory every external tensor whose data lives at ``relative_path``.

    Args:
        tensors: Tensors to scan. Each ``ExternalTensor`` backed by the file at
            ``base_path/relative_path`` is replaced by an in-memory ``Tensor``
            copy (and released); every other tensor is passed through as-is.
        base_path: Path of base directory.
        relative_path: Path of the external data file, relative to ``base_path``.

    Returns:
        A new list with the same length and order as ``tensors``.
    """
    # The destination does not change per tensor; compute it once.
    target_path = os.path.join(base_path, relative_path)
    updated_tensors: list[_protocols.TensorProtocol] = []
    for tensor in tensors:
        if isinstance(tensor, _core.ExternalTensor) and os.path.samefile(
            tensor.path, target_path
        ):
            # Copy the data as the .numpy() call references data from a file whose
            # data is eventually modified.
            tensor_data = tensor.numpy().copy()
            tensor.release()
            tensor = _core.Tensor(tensor_data, name=tensor.name, dtype=tensor.dtype)
        updated_tensors.append(tensor)
    return updated_tensors
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _compute_new_offset(
    current_offset: int,
    tensor_size: int,
    align_offset: bool = _ALIGN_OFFSET,
    align_threshold: int = _ALIGN_THRESHOLD,
    allocation_granularity: int = _ALLOCATION_GRANULARITY,
) -> int:
    """Compute the offset to align the tensor data based on the current offset.

    Args:
        current_offset: Current location in the file at which tensor data will be written to.
        tensor_size: Size of the tensor data to be written to file.
        align_offset: Whether to align large tensors at all. When enabled, the
            offset is rounded up to a multiple of both the page size (4096
            bytes) and ``allocation_granularity``, for mmap() support.
        align_threshold: Only tensors strictly larger than this many bytes are
            aligned. A low threshold wastes file space on padding for small
            initializers.
        allocation_granularity: The allocation granularity for mmap() support.
            Typically 64KB for Windows & 4KB for other OSes. Non-positive values
            fall back to page alignment only.

    Returns:
        The updated offset value.
    """
    if not (align_offset and tensor_size > align_threshold):
        return current_offset

    page_size = 4096
    if allocation_granularity > 0:
        # The smallest step aligned to both the page size and the granularity is
        # their least common multiple. For the usual power-of-two values this
        # equals max(page_size, granularity), but for other granularities max()
        # would not be page aligned.
        alignment_factor = (
            page_size * allocation_granularity // math.gcd(page_size, allocation_granularity)
        )
    else:
        alignment_factor = page_size
    # Align to the next page or alloc granularity
    return (current_offset + alignment_factor - 1) // alignment_factor * alignment_factor
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _compute_external_data_info(
    tensor: _protocols.TensorProtocol,
    current_offset: int,
) -> _ExternalDataInfo:
    """Decide where ``tensor`` will be placed in the external data file.

    The offset starts from ``current_offset`` and is aligned by
    ``_compute_new_offset``; the length is ``tensor.nbytes``.
    """
    nbytes = tensor.nbytes
    aligned_offset = _compute_new_offset(current_offset, nbytes)
    return _ExternalDataInfo(name=tensor.name, offset=aligned_offset, length=nbytes)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _save_external_data(
    external_data_info: list[tuple[_protocols.TensorProtocol, _ExternalDataInfo]],
    file_path: str | os.PathLike,
) -> None:
    """Write each tensor's bytes to ``file_path`` at its recorded offset.

    The file is opened in ``"wb"`` mode, so any previous content is discarded.
    When a recorded offset lies beyond the current end of the file, the gap is
    filled with zero bytes before the tensor data is written.

    Args:
        external_data_info: A collection of external data information stored for each tensor to be written as external data.
        file_path: Location to which external data is to be stored.
    """
    with open(file_path, "wb") as data_file:
        for tensor, placement in external_data_info:
            target_offset = placement.offset
            assert tensor is not None
            payload = tensor.tobytes()
            if isinstance(tensor, _core.ExternalTensor):
                # The bytes have been read; release the external tensor before writing.
                tensor.release()
            end_of_file = data_file.tell()
            if target_offset > end_of_file:
                data_file.write(b"\0" * (target_offset - end_of_file))
            data_file.write(payload)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _convert_as_external_tensors(
    external_data_info: list[tuple[_protocols.TensorProtocol, _ExternalDataInfo]],
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
) -> list[_core.ExternalTensor]:
    """Build an ``ExternalTensor`` for each tensor that was written as external data.

    Each result keeps the source tensor's dtype, shape and name and points at
    the normalized ``relative_path`` under the normalized ``base_path``.

    Args:
        external_data_info: Pairs of (tensor, placement) describing where each
            tensor's bytes were written.
        base_path: Path of base directory.
        relative_path: Path to which external data is to be stored, relative to the ONNX file.

    Returns:
        One external tensor per entry, in the same order as ``external_data_info``.
    """
    converted: list[_core.ExternalTensor] = []
    for source_tensor, placement in external_data_info:
        assert source_tensor is not None
        converted.append(
            _core.ExternalTensor(
                os.path.normpath(relative_path),
                placement.offset,
                placement.length,
                source_tensor.dtype,  # type: ignore[arg-type]
                shape=source_tensor.shape,  # type: ignore[arg-type]
                name=source_tensor.name,  # type: ignore[arg-type]
                base_dir=os.path.normpath(base_path),
            )
        )
    return converted
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _stash_external_data_file(
    path: str,
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
    tensors_using_file: Sequence[_core.ExternalTensor],
) -> str:
    """Move the data file at ``path`` into a fresh temporary directory under ``base_path``.

    ``tensors_using_file`` get their ``base_dir`` re-pointed at the temporary
    directory so their data stays readable while ``path`` is overwritten.
    NOTE: this assumes each such tensor's ``location`` matches ``relative_path``
    (up to normalization) — confirm for tensors loaded with other locations.

    Returns:
        The temporary directory; the caller is responsible for removing it.
    """
    # A fresh directory (rather than a fixed "tmp" name) cannot clobber user files.
    tmp_dir = tempfile.mkdtemp(dir=base_path)
    tmp_file = os.path.join(tmp_dir, relative_path)
    # relative_path may contain subdirectories that must exist before the move.
    os.makedirs(os.path.dirname(tmp_file), exist_ok=True)
    os.rename(path, tmp_file)
    for tensor in tensors_using_file:
        tensor.base_dir = tmp_dir
    return tmp_dir


def convert_tensors_to_external(
    tensors: Sequence[_protocols.TensorProtocol],
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
    load_external_to_memory: bool = False,
) -> list[_core.ExternalTensor]:
    """Convert a sequence of any TensorProtocol tensors to external tensors.

    All tensor data is written to ``base_path/relative_path``, which is
    overwritten. Tensors are written smallest first, so only the large tensors
    that cross the alignment threshold get padded. The returned list follows
    the input order.

    If the destination file already exists and some of ``tensors`` are
    ``ExternalTensor`` objects backed by it, their data is preserved first:
    with ``load_external_to_memory`` they are loaded into memory. Otherwise the
    file is moved to a temporary directory under ``base_path`` and read from
    there. That directory is removed after the new file has been written.

    Args:
        tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
        base_path: Path of base directory.
        relative_path: Path to which external data is to be stored, relative to the ONNX file.
        load_external_to_memory: If set to true, loads external tensors present in the same file path as destination path to memory.

    Returns:
        A list of external tensors derived from a list of input tensors, in the same order.
    """
    path = os.path.join(base_path, relative_path)
    # Create the destination's parent directories. dirname is empty when the file
    # lives in the current directory, and os.makedirs("") would raise.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    tmp_dir: str | None = None
    # Check if file exists. Preserve pre-existing external data if it does.
    if os.path.exists(path):
        # Find the tensors in the model that are using the destination file.
        # FIXME(shubhambhokare1): If there is a non-initializer tensor that is referring to this file, that tensor is now invalid. This is a special case we are ok not handling right now.
        tensors_using_file = [
            tensor
            for tensor in tensors
            if isinstance(tensor, _core.ExternalTensor) and os.path.samefile(path, tensor.path)
        ]
        if tensors_using_file:
            if load_external_to_memory:
                tensors = _load_external_data_file(tensors, base_path, relative_path)
            else:
                tmp_dir = _stash_external_data_file(
                    path, base_path, relative_path, tensors_using_file
                )

    # Sort all tensors by size in order to avoid unnecessary alignment: the smaller
    # tensors are written earlier and alignment is performed for the larger ones.
    sorted_indices = sorted(range(len(tensors)), key=lambda i: tensors[i].nbytes)
    external_data_info: list[tuple[_protocols.TensorProtocol, _ExternalDataInfo]] = []
    current_offset = 0
    for index in sorted_indices:
        tensor = tensors[index]
        tensor_info = _compute_external_data_info(tensor, current_offset)
        external_data_info.append((tensor, tensor_info))
        current_offset = tensor_info.offset + tensor_info.length
    _save_external_data(external_data_info, path)

    # Convert initializers to ExternalTensors
    external_tensors = _convert_as_external_tensors(
        external_data_info, base_path, relative_path
    )
    # Restore the caller's order: external_tensors[k] was built from
    # tensors[sorted_indices[k]].
    external_tensors = [
        external_tensors[k]
        for k in sorted(range(len(external_tensors)), key=lambda k: sorted_indices[k])
    ]

    # The stashed data has been rewritten into ``path``; drop the temporary copy.
    if tmp_dir is not None:
        shutil.rmtree(tmp_dir)

    return external_tensors
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def to_external_data(
    model: _core.Model,
    base_path: str | os.PathLike,
    relative_path: str | os.PathLike,
    load_external_to_memory: bool = False,
) -> _core.Model:
    """Set all tensors with raw data as external data.

    Every initializer whose ``const_value`` is set has its data written to
    ``base_path/relative_path``. Its ``const_value`` is replaced in place by the
    corresponding external tensor. Initializers without a ``const_value`` are
    left untouched.

    Args:
        model: Model to process.
        base_path: Path of base directory.
        relative_path: Path to which external data is to be stored, relative to the ONNX file.
        load_external_to_memory: If set to true, loads external tensors present in the
            same file path as destination path to memory before the file is rewritten.
            Otherwise, the existing file is moved to a temporary location and read from
            there while the new file is written.

    Returns:
        The same ``model`` object, modified in place.
    """

    # Get all the tensors in the graph which are to be stored as external data.
    # TODO: Currently attributes not handled, eventually try to use _all_tensors to include attrs
    # Keep each owning value next to its tensor. Initializers without a
    # const_value are skipped, so zipping the results against *all*
    # initializers would hand external tensors to the wrong values.
    values_with_data = [
        value for value in model.graph.initializers.values() if value.const_value is not None
    ]
    tensors: list[_protocols.TensorProtocol] = [
        value.const_value  # type: ignore[misc]
        for value in values_with_data
    ]

    external_tensors = convert_tensors_to_external(
        tensors,
        base_path,
        relative_path,
        load_external_to_memory=load_external_to_memory,
    )

    for value, external_tensor in zip(values_with_data, external_tensors):
        value.const_value = external_tensor
    return model
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""Utilities for comparing IR graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from onnx_ir import _core
|
|
8
|
+
|
|
9
|
+
# NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs
|
|
10
|
+
# NOTE(justinchuby): A graph may be specified with a set of inputs and outputs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def topologically_equal(graph1: _core.Graph, graph2: _core.Graph) -> bool:
    """Return true if the two graphs are topologically equivalent, without considering initializers.

    Not implemented yet: every call raises ``NotImplementedError``.

    Args:
        graph1: The first graph to compare.
        graph2: The second graph to compare.

    Returns:
        True if the graphs are equal, False otherwise.
    """
    raise NotImplementedError
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""Version utils for testing."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import warnings
|
|
8
|
+
from typing import Callable, Sequence
|
|
9
|
+
|
|
10
|
+
import packaging.version
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def onnx_older_than(version: str) -> bool:
    """Returns True if the ONNX version is older than the given version.

    Only the release segment is compared, so pre-, post-, dev-release and local
    tags are ignored. Release segments of different lengths are zero-padded
    before comparing, as PEP 440 does, so ``1.16`` and ``1.16.0`` are equal.
    """
    import onnx  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(onnx.__version__).release
    requested = packaging.version.parse(version).release
    # ``release`` tuples keep their written length; pad the shorter one so a
    # trailing ``.0`` does not make otherwise-equal versions compare unequal.
    width = max(len(installed), len(requested))
    installed += (0,) * (width - len(installed))
    requested += (0,) * (width - len(requested))
    return installed < requested
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def torch_older_than(version: str) -> bool:
    """Returns True if the torch version is older than the given version.

    Only the release segment is compared, so pre-, post-, dev-release and local
    tags (e.g. ``+cu121``) are ignored. Release segments of different lengths
    are zero-padded before comparing, as PEP 440 does.
    """
    import torch  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(torch.__version__).release
    requested = packaging.version.parse(version).release
    # Pad the shorter release tuple so ``2.3`` and ``2.3.0`` compare equal.
    width = max(len(installed), len(requested))
    installed += (0,) * (width - len(installed))
    requested += (0,) * (width - len(requested))
    return installed < requested
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def transformers_older_than(version: str) -> bool | None:
    """Returns True if the transformers version is older than the given version.

    Returns None when ``transformers`` cannot be imported. Only the release
    segment is compared, and release segments of different lengths are
    zero-padded before comparing, as PEP 440 does.
    """
    try:
        import transformers  # pylint: disable=import-outside-toplevel
    except ImportError:
        return None

    installed = packaging.version.parse(transformers.__version__).release
    requested = packaging.version.parse(version).release
    # Pad the shorter release tuple so e.g. ``4.40`` and ``4.40.0`` compare equal.
    width = max(len(installed), len(requested))
    installed += (0,) * (width - len(installed))
    requested += (0,) * (width - len(requested))
    return installed < requested
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def is_onnxruntime_training() -> bool:
    """Return whether the installed onnxruntime build has training support.

    Training support is assumed when ``onnxruntime.training`` is importable and
    ``OrtValueVector`` exposes ``push_back_batch``. Returns False when either
    import fails, including when onnxruntime is not installed at all.
    """
    try:
        from onnxruntime import training  # pylint: disable=import-outside-toplevel
    except ImportError:
        # onnxruntime not training
        return False
    assert training

    try:
        from onnxruntime.capi.onnxruntime_pybind11_state import (  # pylint: disable=import-outside-toplevel
            OrtValueVector,
        )
    except ImportError:
        return False
    return hasattr(OrtValueVector, "push_back_batch")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def onnxruntime_older_than(version: str) -> bool:
    """Returns True if the onnxruntime version is older than the given version.

    Only the release segment is compared, so pre-, post-, dev-release and local
    tags are ignored. Release segments of different lengths are zero-padded
    before comparing, as PEP 440 does.
    """
    import onnxruntime  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(onnxruntime.__version__).release
    requested = packaging.version.parse(version).release
    # Pad the shorter release tuple so ``1.18`` and ``1.18.0`` compare equal.
    width = max(len(installed), len(requested))
    installed += (0,) * (width - len(installed))
    requested += (0,) * (width - len(requested))
    return installed < requested
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def numpy_older_than(version: str) -> bool:
    """Returns True if the numpy version is older than the given version.

    Only the release segment is compared, so pre-, post-, dev-release and local
    tags are ignored. Release segments of different lengths are zero-padded
    before comparing, as PEP 440 does.
    """
    import numpy  # pylint: disable=import-outside-toplevel

    installed = packaging.version.parse(numpy.__version__).release
    requested = packaging.version.parse(version).release
    # Pad the shorter release tuple so ``2.0`` and ``2.0.0`` compare equal.
    width = max(len(installed), len(requested))
    installed += (0,) * (width - len(installed))
    requested += (0,) * (width - len(requested))
    return installed < requested
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def has_transformers():
    """Return True when the ``transformers`` package can be imported, else False."""
    try:
        import transformers  # pylint: disable=import-outside-toplevel
    except ImportError:
        return False
    assert transformers
    return True  # noqa
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def ignore_warnings(warns: type[Warning] | Sequence[type[Warning]]) -> Callable:
    """Decorator factory that silences the given warning categories.

    Args:
        warns: A warning class, or a sequence of warning classes, to ignore
            while the decorated callable runs.

    Returns:
        A decorator. The wrapped callable accepts any arguments (so it works
        for plain functions as well as test methods taking ``self``) and keeps
        the wrapped callable's metadata.

    Raises:
        AssertionError: When applied with ``warns`` set to None.
    """
    import functools  # pylint: disable=import-outside-toplevel

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")
        # warnings.simplefilter accepts a single category only, so expand sequences.
        categories = (warns,) if isinstance(warns, type) else tuple(warns)

        @functools.wraps(fct)
        def call_f(*args, **kwargs):
            with warnings.catch_warnings():
                for category in categories:
                    warnings.simplefilter("ignore", category)
                return fct(*args, **kwargs)

        return call_f

    return wrapper
|