onnx-ir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +5 -2
- onnx_ir/_convenience/__init__.py +125 -4
- onnx_ir/_convenience/_constructors.py +6 -2
- onnx_ir/_core.py +291 -76
- onnx_ir/_enums.py +35 -25
- onnx_ir/_graph_containers.py +114 -9
- onnx_ir/_io.py +40 -4
- onnx_ir/_type_casting.py +2 -1
- onnx_ir/_version_utils.py +5 -48
- onnx_ir/convenience.py +3 -1
- onnx_ir/external_data.py +43 -3
- onnx_ir/passes/_pass_infra.py +1 -1
- onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir/passes/common/_c_api_utils.py +1 -1
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +10 -25
- onnx_ir/passes/common/inliner.py +4 -3
- onnx_ir/passes/common/onnx_checker.py +1 -1
- onnx_ir/passes/common/shape_inference.py +1 -1
- onnx_ir/passes/common/unused_removal.py +1 -1
- onnx_ir/serde.py +171 -6
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/METADATA +22 -4
- onnx_ir-0.1.2.dist-info/RECORD +42 -0
- onnx_ir-0.1.0.dist-info/RECORD +0 -41
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/top_level.txt +0 -0
onnx_ir/_enums.py
CHANGED
|
@@ -114,7 +114,18 @@ class DataType(enum.IntEnum):
|
|
|
114
114
|
@property
|
|
115
115
|
def itemsize(self) -> float:
|
|
116
116
|
"""Returns the size of the data type in bytes."""
|
|
117
|
-
return
|
|
117
|
+
return self.bitwidth / 8
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def bitwidth(self) -> int:
|
|
121
|
+
"""Returns the bit width of the data type.
|
|
122
|
+
|
|
123
|
+
Raises:
|
|
124
|
+
TypeError: If the data type is not supported.
|
|
125
|
+
"""
|
|
126
|
+
if self not in _BITWIDTH_MAP:
|
|
127
|
+
raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
|
|
128
|
+
return _BITWIDTH_MAP[self]
|
|
118
129
|
|
|
119
130
|
def numpy(self) -> np.dtype:
|
|
120
131
|
"""Returns the numpy dtype for the ONNX data type.
|
|
@@ -163,30 +174,29 @@ class DataType(enum.IntEnum):
|
|
|
163
174
|
return self.__repr__()
|
|
164
175
|
|
|
165
176
|
|
|
166
|
-
|
|
167
|
-
DataType.FLOAT:
|
|
168
|
-
DataType.UINT8:
|
|
169
|
-
DataType.INT8:
|
|
170
|
-
DataType.UINT16:
|
|
171
|
-
DataType.INT16:
|
|
172
|
-
DataType.INT32:
|
|
173
|
-
DataType.INT64:
|
|
174
|
-
DataType.
|
|
175
|
-
DataType.
|
|
176
|
-
DataType.
|
|
177
|
-
DataType.
|
|
178
|
-
DataType.
|
|
179
|
-
DataType.
|
|
180
|
-
DataType.
|
|
181
|
-
DataType.
|
|
182
|
-
DataType.
|
|
183
|
-
DataType.
|
|
184
|
-
DataType.
|
|
185
|
-
DataType.
|
|
186
|
-
DataType.
|
|
187
|
-
DataType.
|
|
188
|
-
DataType.
|
|
189
|
-
DataType.FLOAT4E2M1: 0.5,
|
|
177
|
+
_BITWIDTH_MAP = {
|
|
178
|
+
DataType.FLOAT: 32,
|
|
179
|
+
DataType.UINT8: 8,
|
|
180
|
+
DataType.INT8: 8,
|
|
181
|
+
DataType.UINT16: 16,
|
|
182
|
+
DataType.INT16: 16,
|
|
183
|
+
DataType.INT32: 32,
|
|
184
|
+
DataType.INT64: 64,
|
|
185
|
+
DataType.BOOL: 8,
|
|
186
|
+
DataType.FLOAT16: 16,
|
|
187
|
+
DataType.DOUBLE: 64,
|
|
188
|
+
DataType.UINT32: 32,
|
|
189
|
+
DataType.UINT64: 64,
|
|
190
|
+
DataType.COMPLEX64: 64, # 2 * 32
|
|
191
|
+
DataType.COMPLEX128: 128, # 2 * 64
|
|
192
|
+
DataType.BFLOAT16: 16,
|
|
193
|
+
DataType.FLOAT8E4M3FN: 8,
|
|
194
|
+
DataType.FLOAT8E4M3FNUZ: 8,
|
|
195
|
+
DataType.FLOAT8E5M2: 8,
|
|
196
|
+
DataType.FLOAT8E5M2FNUZ: 8,
|
|
197
|
+
DataType.UINT4: 4,
|
|
198
|
+
DataType.INT4: 4,
|
|
199
|
+
DataType.FLOAT4E2M1: 4,
|
|
190
200
|
}
|
|
191
201
|
|
|
192
202
|
|
onnx_ir/_graph_containers.py
CHANGED
|
@@ -12,13 +12,16 @@ __all__ = [
|
|
|
12
12
|
]
|
|
13
13
|
|
|
14
14
|
import collections
|
|
15
|
-
|
|
16
|
-
from
|
|
15
|
+
import logging
|
|
16
|
+
from collections.abc import Iterable, Sequence
|
|
17
|
+
from typing import SupportsIndex, TypeVar
|
|
17
18
|
|
|
18
19
|
import onnx_ir
|
|
20
|
+
from onnx_ir import _core, _protocols
|
|
19
21
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
+
T = TypeVar("T")
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class _GraphIO(collections.UserList["_core.Value"]):
|
|
@@ -152,6 +155,10 @@ class GraphInputs(_GraphIO):
|
|
|
152
155
|
raise ValueError(
|
|
153
156
|
f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
|
|
154
157
|
)
|
|
158
|
+
if value.producer() is not None:
|
|
159
|
+
raise ValueError(
|
|
160
|
+
f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
|
|
161
|
+
)
|
|
155
162
|
self._ref_counter[value] += 1
|
|
156
163
|
value._is_graph_input = True
|
|
157
164
|
value._graph = self._graph
|
|
@@ -209,7 +216,7 @@ class GraphOutputs(_GraphIO):
|
|
|
209
216
|
|
|
210
217
|
|
|
211
218
|
class GraphInitializers(collections.UserDict[str, "_core.Value"]):
|
|
212
|
-
"""The initializers of a Graph."""
|
|
219
|
+
"""The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods."""
|
|
213
220
|
|
|
214
221
|
def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
|
|
215
222
|
# Perform checks first in _set_graph before modifying the data structure with super().__init__()
|
|
@@ -244,12 +251,23 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
|
|
|
244
251
|
|
|
245
252
|
def __setitem__(self, key: str, value: _core.Value) -> None:
|
|
246
253
|
"""Set an initializer for the graph."""
|
|
247
|
-
if
|
|
254
|
+
if not isinstance(value, _core.Value):
|
|
255
|
+
raise TypeError(f"value must be a Value object, not {type(value)}")
|
|
256
|
+
if not isinstance(key, str):
|
|
257
|
+
raise TypeError(f"Value name must be a string, not {type(key)}")
|
|
258
|
+
if key == "":
|
|
259
|
+
raise ValueError("Value name cannot be an empty string")
|
|
260
|
+
if not value.name:
|
|
261
|
+
logger.info("Value %s does not have a name, setting it to '%s'", value, key)
|
|
262
|
+
value.name = key
|
|
263
|
+
elif key != value.name:
|
|
248
264
|
raise ValueError(
|
|
249
|
-
f"Key '{key}' does not match the name of the value '{value.name}'"
|
|
265
|
+
f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
|
|
266
|
+
)
|
|
267
|
+
if value.producer() is not None:
|
|
268
|
+
raise ValueError(
|
|
269
|
+
f"Value '{value}' is produced by a node and cannot be a graph initializer"
|
|
250
270
|
)
|
|
251
|
-
if not isinstance(key, str):
|
|
252
|
-
raise TypeError(f"Key must be a string, not {type(key)}")
|
|
253
271
|
if key in self.data:
|
|
254
272
|
# If the key already exists, unset the old value
|
|
255
273
|
old_value = self.data[key]
|
|
@@ -266,3 +284,90 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
|
|
|
266
284
|
# the dictionary is not modified
|
|
267
285
|
self._maybe_unset_graph(value)
|
|
268
286
|
super().__delitem__(key)
|
|
287
|
+
|
|
288
|
+
def add(self, value: _core.Value) -> None:
|
|
289
|
+
"""Add an initializer to the graph."""
|
|
290
|
+
self[value.name] = value # type: ignore[index]
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class Attributes(collections.UserDict[str, "_core.Attr"]):
|
|
294
|
+
"""The attributes of a Node as ``dict[str, Attr]`` with additional access methods."""
|
|
295
|
+
|
|
296
|
+
def __init__(self, attrs: Iterable[_core.Attr]):
|
|
297
|
+
super().__init__({attr.name: attr for attr in attrs})
|
|
298
|
+
|
|
299
|
+
def __setitem__(self, key: str, value: _core.Attr) -> None:
|
|
300
|
+
"""Set an attribute for the node."""
|
|
301
|
+
if type(key) is not str:
|
|
302
|
+
raise TypeError(f"Key must be a string, not {type(key)}")
|
|
303
|
+
if not isinstance(value, _core.Attr):
|
|
304
|
+
raise TypeError(f"Value must be an Attr, not {type(value)}")
|
|
305
|
+
super().__setitem__(key, value)
|
|
306
|
+
|
|
307
|
+
def add(self, value: _core.Attr) -> None:
|
|
308
|
+
"""Add an attribute to the node."""
|
|
309
|
+
self[value.name] = value
|
|
310
|
+
|
|
311
|
+
def get_int(self, key: str, default: T = None) -> int | T: # type: ignore[assignment]
|
|
312
|
+
"""Get the integer value of the attribute."""
|
|
313
|
+
if key in self:
|
|
314
|
+
return self[key].as_int()
|
|
315
|
+
return default
|
|
316
|
+
|
|
317
|
+
def get_float(self, key: str, default: T = None) -> float | T: # type: ignore[assignment]
|
|
318
|
+
"""Get the float value of the attribute."""
|
|
319
|
+
if key in self:
|
|
320
|
+
return self[key].as_float()
|
|
321
|
+
return default
|
|
322
|
+
|
|
323
|
+
def get_string(self, key: str, default: T = None) -> str | T: # type: ignore[assignment]
|
|
324
|
+
"""Get the string value of the attribute."""
|
|
325
|
+
if key in self:
|
|
326
|
+
return self[key].as_string()
|
|
327
|
+
return default
|
|
328
|
+
|
|
329
|
+
def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T: # type: ignore[assignment]
|
|
330
|
+
"""Get the tensor value of the attribute."""
|
|
331
|
+
if key in self:
|
|
332
|
+
return self[key].as_tensor()
|
|
333
|
+
return default
|
|
334
|
+
|
|
335
|
+
def get_graph(self, key: str, default: T = None) -> _core.Graph | T: # type: ignore[assignment]
|
|
336
|
+
"""Get the graph value of the attribute."""
|
|
337
|
+
if key in self:
|
|
338
|
+
return self[key].as_graph()
|
|
339
|
+
return default
|
|
340
|
+
|
|
341
|
+
def get_ints(self, key: str, default: T = None) -> Sequence[int] | T: # type: ignore[assignment]
|
|
342
|
+
"""Get the Sequence of integers from the attribute."""
|
|
343
|
+
if key in self:
|
|
344
|
+
return self[key].as_ints()
|
|
345
|
+
return default
|
|
346
|
+
|
|
347
|
+
def get_floats(self, key: str, default: T = None) -> Sequence[float] | T: # type: ignore[assignment]
|
|
348
|
+
"""Get the Sequence of floats from the attribute."""
|
|
349
|
+
if key in self:
|
|
350
|
+
return self[key].as_floats()
|
|
351
|
+
return default
|
|
352
|
+
|
|
353
|
+
def get_strings(self, key: str, default: T = None) -> Sequence[str] | T: # type: ignore[assignment]
|
|
354
|
+
"""Get the Sequence of strings from the attribute."""
|
|
355
|
+
if key in self:
|
|
356
|
+
return self[key].as_strings()
|
|
357
|
+
return default
|
|
358
|
+
|
|
359
|
+
def get_tensors(
|
|
360
|
+
self,
|
|
361
|
+
key: str,
|
|
362
|
+
default: T = None, # type: ignore[assignment]
|
|
363
|
+
) -> Sequence[_protocols.TensorProtocol] | T:
|
|
364
|
+
"""Get the Sequence of tensors from the attribute."""
|
|
365
|
+
if key in self:
|
|
366
|
+
return self[key].as_tensors()
|
|
367
|
+
return default
|
|
368
|
+
|
|
369
|
+
def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T: # type: ignore[assignment]
|
|
370
|
+
"""Get the Sequence of graphs from the attribute."""
|
|
371
|
+
if key in self:
|
|
372
|
+
return self[key].as_graphs()
|
|
373
|
+
return default
|
onnx_ir/_io.py
CHANGED
|
@@ -7,10 +7,11 @@ from __future__ import annotations
|
|
|
7
7
|
__all__ = ["load", "save"]
|
|
8
8
|
|
|
9
9
|
import os
|
|
10
|
+
from typing import Callable
|
|
10
11
|
|
|
11
|
-
import onnx
|
|
12
|
+
import onnx # noqa: TID251
|
|
12
13
|
|
|
13
|
-
from onnx_ir import _core, serde
|
|
14
|
+
from onnx_ir import _core, _protocols, serde
|
|
14
15
|
from onnx_ir import external_data as _external_data
|
|
15
16
|
from onnx_ir._polyfill import zip
|
|
16
17
|
|
|
@@ -43,6 +44,8 @@ def save(
|
|
|
43
44
|
format: str | None = None,
|
|
44
45
|
external_data: str | os.PathLike | None = None,
|
|
45
46
|
size_threshold_bytes: int = 256,
|
|
47
|
+
callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
|
|
48
|
+
| None = None,
|
|
46
49
|
) -> None:
|
|
47
50
|
"""Save an ONNX model to a file.
|
|
48
51
|
|
|
@@ -52,6 +55,30 @@ def save(
|
|
|
52
55
|
to load the newly saved model, or provide a different external data path that
|
|
53
56
|
is not currently referenced by any tensors in the model.
|
|
54
57
|
|
|
58
|
+
.. tip::
|
|
59
|
+
|
|
60
|
+
A simple progress bar can be implemented by passing a callback function as the following::
|
|
61
|
+
|
|
62
|
+
import onnx_ir as ir
|
|
63
|
+
import tqdm
|
|
64
|
+
|
|
65
|
+
with tqdm.tqdm() as pbar:
|
|
66
|
+
total_set = False
|
|
67
|
+
|
|
68
|
+
def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
|
|
69
|
+
nonlocal total_set
|
|
70
|
+
if not total_set:
|
|
71
|
+
pbar.total = metadata.total
|
|
72
|
+
total_set = True
|
|
73
|
+
|
|
74
|
+
pbar.update()
|
|
75
|
+
pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")
|
|
76
|
+
|
|
77
|
+
ir.save(
|
|
78
|
+
...,
|
|
79
|
+
callback=callback,
|
|
80
|
+
)
|
|
81
|
+
|
|
55
82
|
Args:
|
|
56
83
|
model: The model to save.
|
|
57
84
|
path: The path to save the model to. E.g. "model.onnx".
|
|
@@ -65,6 +92,8 @@ def save(
|
|
|
65
92
|
it will be serialized in the ONNX Proto message.
|
|
66
93
|
size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
|
|
67
94
|
Effective only when ``external_data`` is set.
|
|
95
|
+
callback: A callback function that is called for each tensor that is saved to external data
|
|
96
|
+
for debugging or logging purposes.
|
|
68
97
|
|
|
69
98
|
Raises:
|
|
70
99
|
ValueError: If the external data path is an absolute path.
|
|
@@ -77,12 +106,19 @@ def save(
|
|
|
77
106
|
base_dir = os.path.dirname(path)
|
|
78
107
|
|
|
79
108
|
# Store the original initializer values so they can be restored if modify_model=False
|
|
80
|
-
initializer_values =
|
|
109
|
+
initializer_values: list[_core.Value] = []
|
|
110
|
+
for graph in model.graphs():
|
|
111
|
+
# Collect from all subgraphs as well
|
|
112
|
+
initializer_values.extend(graph.initializers.values())
|
|
81
113
|
tensors = [v.const_value for v in initializer_values]
|
|
82
114
|
|
|
83
115
|
try:
|
|
84
116
|
model = _external_data.unload_from_model(
|
|
85
|
-
model,
|
|
117
|
+
model,
|
|
118
|
+
base_dir,
|
|
119
|
+
external_data,
|
|
120
|
+
size_threshold_bytes=size_threshold_bytes,
|
|
121
|
+
callback=callback,
|
|
86
122
|
)
|
|
87
123
|
proto = serde.serialize_model(model)
|
|
88
124
|
onnx.save(proto, path, format=format)
|
onnx_ir/_type_casting.py
CHANGED
|
@@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
|
|
|
15
15
|
import numpy.typing as npt
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
def
|
|
18
|
+
def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
|
|
19
19
|
"""Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
|
|
20
20
|
# Create a 1D copy
|
|
21
21
|
array_flat = array.ravel().view(np.uint8).copy()
|
|
@@ -40,6 +40,7 @@ def _unpack_uint4_as_uint8(
|
|
|
40
40
|
Returns:
|
|
41
41
|
A numpy array of int8/uint8 reshaped to dims.
|
|
42
42
|
"""
|
|
43
|
+
assert data.dtype == np.uint8, "Input data must be of type uint8"
|
|
43
44
|
result = np.empty([data.size * 2], dtype=data.dtype)
|
|
44
45
|
array_low = data & np.uint8(0x0F)
|
|
45
46
|
array_high = data & np.uint8(0xF0)
|
onnx_ir/_version_utils.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""Version utils for testing."""
|
|
4
4
|
|
|
5
|
+
# pylint: disable=import-outside-toplevel
|
|
5
6
|
from __future__ import annotations
|
|
6
7
|
|
|
7
8
|
import packaging.version
|
|
@@ -9,7 +10,7 @@ import packaging.version
|
|
|
9
10
|
|
|
10
11
|
def onnx_older_than(version: str) -> bool:
|
|
11
12
|
"""Returns True if the ONNX version is older than the given version."""
|
|
12
|
-
import onnx #
|
|
13
|
+
import onnx # noqa: TID251
|
|
13
14
|
|
|
14
15
|
return (
|
|
15
16
|
packaging.version.parse(onnx.__version__).release
|
|
@@ -19,7 +20,7 @@ def onnx_older_than(version: str) -> bool:
|
|
|
19
20
|
|
|
20
21
|
def torch_older_than(version: str) -> bool:
|
|
21
22
|
"""Returns True if the torch version is older than the given version."""
|
|
22
|
-
import torch
|
|
23
|
+
import torch
|
|
23
24
|
|
|
24
25
|
return (
|
|
25
26
|
packaging.version.parse(torch.__version__).release
|
|
@@ -27,42 +28,9 @@ def torch_older_than(version: str) -> bool:
|
|
|
27
28
|
)
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
def transformers_older_than(version: str) -> bool | None:
|
|
31
|
-
"""Returns True if the transformers version is older than the given version."""
|
|
32
|
-
try:
|
|
33
|
-
import transformers # pylint: disable=import-outside-toplevel
|
|
34
|
-
except ImportError:
|
|
35
|
-
return None
|
|
36
|
-
|
|
37
|
-
return (
|
|
38
|
-
packaging.version.parse(transformers.__version__).release
|
|
39
|
-
< packaging.version.parse(version).release
|
|
40
|
-
)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def is_onnxruntime_training() -> bool:
|
|
44
|
-
"""Returns True if the onnxruntime is onnxruntime-training."""
|
|
45
|
-
try:
|
|
46
|
-
from onnxruntime import training # pylint: disable=import-outside-toplevel
|
|
47
|
-
|
|
48
|
-
assert training
|
|
49
|
-
except ImportError:
|
|
50
|
-
# onnxruntime not training
|
|
51
|
-
return False
|
|
52
|
-
|
|
53
|
-
try:
|
|
54
|
-
from onnxruntime.capi.onnxruntime_pybind11_state import ( # pylint: disable=import-outside-toplevel
|
|
55
|
-
OrtValueVector,
|
|
56
|
-
)
|
|
57
|
-
except ImportError:
|
|
58
|
-
return False
|
|
59
|
-
|
|
60
|
-
return hasattr(OrtValueVector, "push_back_batch")
|
|
61
|
-
|
|
62
|
-
|
|
63
31
|
def onnxruntime_older_than(version: str) -> bool:
|
|
64
32
|
"""Returns True if the onnxruntime version is older than the given version."""
|
|
65
|
-
import onnxruntime
|
|
33
|
+
import onnxruntime
|
|
66
34
|
|
|
67
35
|
return (
|
|
68
36
|
packaging.version.parse(onnxruntime.__version__).release
|
|
@@ -72,20 +40,9 @@ def onnxruntime_older_than(version: str) -> bool:
|
|
|
72
40
|
|
|
73
41
|
def numpy_older_than(version: str) -> bool:
|
|
74
42
|
"""Returns True if the numpy version is older than the given version."""
|
|
75
|
-
import numpy
|
|
43
|
+
import numpy
|
|
76
44
|
|
|
77
45
|
return (
|
|
78
46
|
packaging.version.parse(numpy.__version__).release
|
|
79
47
|
< packaging.version.parse(version).release
|
|
80
48
|
)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def has_transformers():
|
|
84
|
-
"""Tells if transformers is installed."""
|
|
85
|
-
try:
|
|
86
|
-
import transformers # pylint: disable=import-outside-toplevel
|
|
87
|
-
|
|
88
|
-
assert transformers
|
|
89
|
-
return True # noqa
|
|
90
|
-
except ImportError:
|
|
91
|
-
return False
|
onnx_ir/convenience.py
CHANGED
|
@@ -7,15 +7,17 @@ from __future__ import annotations
|
|
|
7
7
|
__all__ = [
|
|
8
8
|
"convert_attribute",
|
|
9
9
|
"convert_attributes",
|
|
10
|
+
"create_value_mapping",
|
|
11
|
+
"get_const_tensor",
|
|
10
12
|
"replace_all_uses_with",
|
|
11
13
|
"replace_nodes_and_values",
|
|
12
|
-
"create_value_mapping",
|
|
13
14
|
]
|
|
14
15
|
|
|
15
16
|
from onnx_ir._convenience import (
|
|
16
17
|
convert_attribute,
|
|
17
18
|
convert_attributes,
|
|
18
19
|
create_value_mapping,
|
|
20
|
+
get_const_tensor,
|
|
19
21
|
replace_all_uses_with,
|
|
20
22
|
replace_nodes_and_values,
|
|
21
23
|
)
|
onnx_ir/external_data.py
CHANGED
|
@@ -4,12 +4,15 @@
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
+
from typing import Callable
|
|
8
|
+
|
|
7
9
|
__all__ = [
|
|
8
10
|
"set_base_dir",
|
|
9
11
|
"unload_from_model",
|
|
10
12
|
"load_to_model",
|
|
11
13
|
"convert_tensors_to_external",
|
|
12
14
|
"convert_tensors_from_external",
|
|
15
|
+
"CallbackInfo",
|
|
13
16
|
]
|
|
14
17
|
|
|
15
18
|
import dataclasses
|
|
@@ -48,6 +51,21 @@ class _ExternalDataInfo:
|
|
|
48
51
|
length: int
|
|
49
52
|
|
|
50
53
|
|
|
54
|
+
@dataclasses.dataclass
|
|
55
|
+
class CallbackInfo:
|
|
56
|
+
"""A class that shares information about a tensor that is to be saved as external data for callback functions.
|
|
57
|
+
|
|
58
|
+
Attributes:
|
|
59
|
+
total: The total number of tensors to save.
|
|
60
|
+
index: The index of the tensor being saved.
|
|
61
|
+
offset: The offset of the tensor in the external data file.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
total: int
|
|
65
|
+
index: int
|
|
66
|
+
offset: int
|
|
67
|
+
|
|
68
|
+
|
|
51
69
|
def _all_tensors(
|
|
52
70
|
graph: _core.Graph | _core.GraphView, include_attributes: bool = False
|
|
53
71
|
) -> Iterator[_protocols.TensorProtocol]:
|
|
@@ -157,6 +175,7 @@ def _write_external_data(
|
|
|
157
175
|
tensors: Sequence[_protocols.TensorProtocol],
|
|
158
176
|
external_data_infos: Sequence[_ExternalDataInfo],
|
|
159
177
|
file_path: str | os.PathLike,
|
|
178
|
+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
|
|
160
179
|
) -> None:
|
|
161
180
|
"""Write tensor data to an external file according to information stored in ExternalDataInfo objects.
|
|
162
181
|
|
|
@@ -164,12 +183,26 @@ def _write_external_data(
|
|
|
164
183
|
tensors: Tensors to be written as external data.
|
|
165
184
|
external_data_infos: External data information stored for each tensor to be written as external data.
|
|
166
185
|
file_path: Location to which external data is to be stored.
|
|
186
|
+
callback: A callback function that is called for each tensor that is saved to external data
|
|
187
|
+
for debugging or logging purposes.
|
|
167
188
|
"""
|
|
168
|
-
|
|
189
|
+
tensors_count = len(tensors)
|
|
190
|
+
assert tensors_count == len(external_data_infos), (
|
|
169
191
|
"Number of tensors and external data infos should match"
|
|
170
192
|
)
|
|
171
193
|
with open(file_path, "wb") as data_file:
|
|
172
|
-
for tensor, tensor_info in
|
|
194
|
+
for i, (tensor, tensor_info) in enumerate(
|
|
195
|
+
zip(tensors, external_data_infos, strict=True)
|
|
196
|
+
):
|
|
197
|
+
if callback is not None:
|
|
198
|
+
callback(
|
|
199
|
+
tensor,
|
|
200
|
+
CallbackInfo(
|
|
201
|
+
total=tensors_count,
|
|
202
|
+
index=i,
|
|
203
|
+
offset=tensor_info.offset,
|
|
204
|
+
),
|
|
205
|
+
)
|
|
173
206
|
current_offset = tensor_info.offset
|
|
174
207
|
assert tensor is not None
|
|
175
208
|
raw_data = tensor.tobytes()
|
|
@@ -228,6 +261,7 @@ def convert_tensors_to_external(
|
|
|
228
261
|
tensors: Sequence[_protocols.TensorProtocol],
|
|
229
262
|
base_dir: str | os.PathLike,
|
|
230
263
|
relative_path: str | os.PathLike,
|
|
264
|
+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
|
|
231
265
|
) -> list[_core.ExternalTensor]:
|
|
232
266
|
"""Convert a sequence of any TensorProtocol tensors to external tensors.
|
|
233
267
|
|
|
@@ -238,6 +272,8 @@ def convert_tensors_to_external(
|
|
|
238
272
|
tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
|
|
239
273
|
base_dir: Path of base directory.
|
|
240
274
|
relative_path: Path to which external data is to be stored, relative to the ONNX file.
|
|
275
|
+
callback: A callback function that is called for each tensor that is saved to external data
|
|
276
|
+
for debugging or logging purposes.
|
|
241
277
|
|
|
242
278
|
Returns:
|
|
243
279
|
A list of external tensors derived from a list of input tensors. The order
|
|
@@ -285,7 +321,7 @@ def convert_tensors_to_external(
|
|
|
285
321
|
external_info = _compute_external_data_info(tensor, current_offset)
|
|
286
322
|
external_data_infos.append(external_info)
|
|
287
323
|
current_offset = external_info.offset + external_info.length
|
|
288
|
-
_write_external_data(sorted_tensors, external_data_infos, path)
|
|
324
|
+
_write_external_data(sorted_tensors, external_data_infos, path, callback=callback)
|
|
289
325
|
|
|
290
326
|
# Create external tensor objects
|
|
291
327
|
external_tensors: list[_core.ExternalTensor] = [
|
|
@@ -336,6 +372,7 @@ def unload_from_model(
|
|
|
336
372
|
relative_path: str | os.PathLike,
|
|
337
373
|
*,
|
|
338
374
|
size_threshold_bytes: int = 0,
|
|
375
|
+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
|
|
339
376
|
) -> _core.Model:
|
|
340
377
|
"""Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file.
|
|
341
378
|
|
|
@@ -356,6 +393,8 @@ def unload_from_model(
|
|
|
356
393
|
relative_path: Path to which external data is to be stored, relative to the ONNX file.
|
|
357
394
|
E.g. "model.data"
|
|
358
395
|
size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
|
|
396
|
+
callback: A callback function that is called for each tensor that is saved to external data
|
|
397
|
+
for debugging or logging purposes.
|
|
359
398
|
|
|
360
399
|
Returns:
|
|
361
400
|
An ir.Model with all initializer data equal or above ``size_threshold_bytes``
|
|
@@ -384,6 +423,7 @@ def unload_from_model(
|
|
|
384
423
|
[v.const_value for v in initializers_to_become_external], # type: ignore[misc]
|
|
385
424
|
base_dir=base_dir,
|
|
386
425
|
relative_path=relative_path,
|
|
426
|
+
callback=callback,
|
|
387
427
|
)
|
|
388
428
|
|
|
389
429
|
# Replace the initializer values with external tensors and save the model
|
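onnx_ir/passes/_pass_infra.py
CHANGED
[The +1 -1 diff body for this file is missing from this extract.]
onnx_ir/passes/common/__init__.py
CHANGED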
|
@@ -5,6 +5,7 @@ __all__ = [
|
|
|
5
5
|
"AddInitializersToInputsPass",
|
|
6
6
|
"CheckerPass",
|
|
7
7
|
"ClearMetadataAndDocStringPass",
|
|
8
|
+
"CommonSubexpressionEliminationPass",
|
|
8
9
|
"InlinePass",
|
|
9
10
|
"LiftConstantsToInitializersPass",
|
|
10
11
|
"LiftSubgraphInitializersToMainGraphPass",
|
|
@@ -19,6 +20,9 @@ __all__ = [
|
|
|
19
20
|
from onnx_ir.passes.common.clear_metadata_and_docstring import (
|
|
20
21
|
ClearMetadataAndDocStringPass,
|
|
21
22
|
)
|
|
23
|
+
from onnx_ir.passes.common.common_subexpression_elimination import (
|
|
24
|
+
CommonSubexpressionEliminationPass,
|
|
25
|
+
)
|
|
22
26
|
from onnx_ir.passes.common.constant_manipulation import (
|
|
23
27
|
AddInitializersToInputsPass,
|
|
24
28
|
LiftConstantsToInitializersPass,
|