onnx-ir 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +5 -2
- onnx_ir/_convenience/__init__.py +125 -4
- onnx_ir/_convenience/_constructors.py +6 -2
- onnx_ir/_core.py +261 -39
- onnx_ir/_enums.py +35 -25
- onnx_ir/_graph_containers.py +2 -2
- onnx_ir/_io.py +40 -4
- onnx_ir/_type_casting.py +2 -1
- onnx_ir/_version_utils.py +5 -48
- onnx_ir/convenience.py +3 -1
- onnx_ir/external_data.py +43 -3
- onnx_ir/passes/_pass_infra.py +1 -1
- onnx_ir/passes/common/_c_api_utils.py +1 -1
- onnx_ir/passes/common/onnx_checker.py +1 -1
- onnx_ir/passes/common/shape_inference.py +1 -1
- onnx_ir/passes/common/unused_removal.py +1 -1
- onnx_ir/serde.py +171 -6
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/METADATA +22 -4
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/RECORD +22 -22
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/top_level.txt +0 -0
onnx_ir/_version_utils.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""Version utils for testing."""
|
|
4
4
|
|
|
5
|
+
# pylint: disable=import-outside-toplevel
|
|
5
6
|
from __future__ import annotations
|
|
6
7
|
|
|
7
8
|
import packaging.version
|
|
@@ -9,7 +10,7 @@ import packaging.version
|
|
|
9
10
|
|
|
10
11
|
def onnx_older_than(version: str) -> bool:
|
|
11
12
|
"""Returns True if the ONNX version is older than the given version."""
|
|
12
|
-
import onnx  # pylint: disable=import-outside-toplevel
|
|
13
|
+
import onnx # noqa: TID251
|
|
13
14
|
|
|
14
15
|
return (
|
|
15
16
|
packaging.version.parse(onnx.__version__).release
|
|
@@ -19,7 +20,7 @@ def onnx_older_than(version: str) -> bool:
|
|
|
19
20
|
|
|
20
21
|
def torch_older_than(version: str) -> bool:
|
|
21
22
|
"""Returns True if the torch version is older than the given version."""
|
|
22
|
-
import torch  # pylint: disable=import-outside-toplevel
|
|
23
|
+
import torch
|
|
23
24
|
|
|
24
25
|
return (
|
|
25
26
|
packaging.version.parse(torch.__version__).release
|
|
@@ -27,42 +28,9 @@ def torch_older_than(version: str) -> bool:
|
|
|
27
28
|
)
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
def transformers_older_than(version: str) -> bool | None:
|
|
31
|
-
"""Returns True if the transformers version is older than the given version."""
|
|
32
|
-
try:
|
|
33
|
-
import transformers # pylint: disable=import-outside-toplevel
|
|
34
|
-
except ImportError:
|
|
35
|
-
return None
|
|
36
|
-
|
|
37
|
-
return (
|
|
38
|
-
packaging.version.parse(transformers.__version__).release
|
|
39
|
-
< packaging.version.parse(version).release
|
|
40
|
-
)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def is_onnxruntime_training() -> bool:
|
|
44
|
-
"""Returns True if the onnxruntime is onnxruntime-training."""
|
|
45
|
-
try:
|
|
46
|
-
from onnxruntime import training # pylint: disable=import-outside-toplevel
|
|
47
|
-
|
|
48
|
-
assert training
|
|
49
|
-
except ImportError:
|
|
50
|
-
# onnxruntime not training
|
|
51
|
-
return False
|
|
52
|
-
|
|
53
|
-
try:
|
|
54
|
-
from onnxruntime.capi.onnxruntime_pybind11_state import ( # pylint: disable=import-outside-toplevel
|
|
55
|
-
OrtValueVector,
|
|
56
|
-
)
|
|
57
|
-
except ImportError:
|
|
58
|
-
return False
|
|
59
|
-
|
|
60
|
-
return hasattr(OrtValueVector, "push_back_batch")
|
|
61
|
-
|
|
62
|
-
|
|
63
31
|
def onnxruntime_older_than(version: str) -> bool:
|
|
64
32
|
"""Returns True if the onnxruntime version is older than the given version."""
|
|
65
|
-
import onnxruntime  # pylint: disable=import-outside-toplevel
|
|
33
|
+
import onnxruntime
|
|
66
34
|
|
|
67
35
|
return (
|
|
68
36
|
packaging.version.parse(onnxruntime.__version__).release
|
|
@@ -72,20 +40,9 @@ def onnxruntime_older_than(version: str) -> bool:
|
|
|
72
40
|
|
|
73
41
|
def numpy_older_than(version: str) -> bool:
|
|
74
42
|
"""Returns True if the numpy version is older than the given version."""
|
|
75
|
-
import numpy  # pylint: disable=import-outside-toplevel
|
|
43
|
+
import numpy
|
|
76
44
|
|
|
77
45
|
return (
|
|
78
46
|
packaging.version.parse(numpy.__version__).release
|
|
79
47
|
< packaging.version.parse(version).release
|
|
80
48
|
)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def has_transformers():
|
|
84
|
-
"""Tells if transformers is installed."""
|
|
85
|
-
try:
|
|
86
|
-
import transformers # pylint: disable=import-outside-toplevel
|
|
87
|
-
|
|
88
|
-
assert transformers
|
|
89
|
-
return True # noqa
|
|
90
|
-
except ImportError:
|
|
91
|
-
return False
|
onnx_ir/convenience.py
CHANGED
|
@@ -7,15 +7,17 @@ from __future__ import annotations
|
|
|
7
7
|
__all__ = [
|
|
8
8
|
"convert_attribute",
|
|
9
9
|
"convert_attributes",
|
|
10
|
+
"create_value_mapping",
|
|
11
|
+
"get_const_tensor",
|
|
10
12
|
"replace_all_uses_with",
|
|
11
13
|
"replace_nodes_and_values",
|
|
12
|
-
"create_value_mapping",
|
|
13
14
|
]
|
|
14
15
|
|
|
15
16
|
from onnx_ir._convenience import (
|
|
16
17
|
convert_attribute,
|
|
17
18
|
convert_attributes,
|
|
18
19
|
create_value_mapping,
|
|
20
|
+
get_const_tensor,
|
|
19
21
|
replace_all_uses_with,
|
|
20
22
|
replace_nodes_and_values,
|
|
21
23
|
)
|
onnx_ir/external_data.py
CHANGED
|
@@ -4,12 +4,15 @@
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
+
from typing import Callable
|
|
8
|
+
|
|
7
9
|
__all__ = [
|
|
8
10
|
"set_base_dir",
|
|
9
11
|
"unload_from_model",
|
|
10
12
|
"load_to_model",
|
|
11
13
|
"convert_tensors_to_external",
|
|
12
14
|
"convert_tensors_from_external",
|
|
15
|
+
"CallbackInfo",
|
|
13
16
|
]
|
|
14
17
|
|
|
15
18
|
import dataclasses
|
|
@@ -48,6 +51,21 @@ class _ExternalDataInfo:
|
|
|
48
51
|
length: int
|
|
49
52
|
|
|
50
53
|
|
|
54
|
+
@dataclasses.dataclass
|
|
55
|
+
class CallbackInfo:
|
|
56
|
+
"""A class that shares information about a tensor that is to be saved as external data for callback functions.
|
|
57
|
+
|
|
58
|
+
Attributes:
|
|
59
|
+
total: The total number of tensors to save.
|
|
60
|
+
index: The index of the tensor being saved.
|
|
61
|
+
offset: The offset of the tensor in the external data file.
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
total: int
|
|
65
|
+
index: int
|
|
66
|
+
offset: int
|
|
67
|
+
|
|
68
|
+
|
|
51
69
|
def _all_tensors(
|
|
52
70
|
graph: _core.Graph | _core.GraphView, include_attributes: bool = False
|
|
53
71
|
) -> Iterator[_protocols.TensorProtocol]:
|
|
@@ -157,6 +175,7 @@ def _write_external_data(
|
|
|
157
175
|
tensors: Sequence[_protocols.TensorProtocol],
|
|
158
176
|
external_data_infos: Sequence[_ExternalDataInfo],
|
|
159
177
|
file_path: str | os.PathLike,
|
|
178
|
+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
|
|
160
179
|
) -> None:
|
|
161
180
|
"""Write tensor data to an external file according to information stored in ExternalDataInfo objects.
|
|
162
181
|
|
|
@@ -164,12 +183,26 @@ def _write_external_data(
|
|
|
164
183
|
tensors: Tensors to be written as external data.
|
|
165
184
|
external_data_infos: External data information stored for each tensor to be written as external data.
|
|
166
185
|
file_path: Location to which external data is to be stored.
|
|
186
|
+
callback: A callback function that is called for each tensor that is saved to external data
|
|
187
|
+
for debugging or logging purposes.
|
|
167
188
|
"""
|
|
168
|
-
|
|
189
|
+
tensors_count = len(tensors)
|
|
190
|
+
assert tensors_count == len(external_data_infos), (
|
|
169
191
|
"Number of tensors and external data infos should match"
|
|
170
192
|
)
|
|
171
193
|
with open(file_path, "wb") as data_file:
|
|
172
|
-
for tensor, tensor_info in zip(tensors, external_data_infos, strict=True):
|
|
194
|
+
for i, (tensor, tensor_info) in enumerate(
|
|
195
|
+
zip(tensors, external_data_infos, strict=True)
|
|
196
|
+
):
|
|
197
|
+
if callback is not None:
|
|
198
|
+
callback(
|
|
199
|
+
tensor,
|
|
200
|
+
CallbackInfo(
|
|
201
|
+
total=tensors_count,
|
|
202
|
+
index=i,
|
|
203
|
+
offset=tensor_info.offset,
|
|
204
|
+
),
|
|
205
|
+
)
|
|
173
206
|
current_offset = tensor_info.offset
|
|
174
207
|
assert tensor is not None
|
|
175
208
|
raw_data = tensor.tobytes()
|
|
@@ -228,6 +261,7 @@ def convert_tensors_to_external(
|
|
|
228
261
|
tensors: Sequence[_protocols.TensorProtocol],
|
|
229
262
|
base_dir: str | os.PathLike,
|
|
230
263
|
relative_path: str | os.PathLike,
|
|
264
|
+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
|
|
231
265
|
) -> list[_core.ExternalTensor]:
|
|
232
266
|
"""Convert a sequence of any TensorProtocol tensors to external tensors.
|
|
233
267
|
|
|
@@ -238,6 +272,8 @@ def convert_tensors_to_external(
|
|
|
238
272
|
tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
|
|
239
273
|
base_dir: Path of base directory.
|
|
240
274
|
relative_path: Path to which external data is to be stored, relative to the ONNX file.
|
|
275
|
+
callback: A callback function that is called for each tensor that is saved to external data
|
|
276
|
+
for debugging or logging purposes.
|
|
241
277
|
|
|
242
278
|
Returns:
|
|
243
279
|
A list of external tensors derived from a list of input tensors. The order
|
|
@@ -285,7 +321,7 @@ def convert_tensors_to_external(
|
|
|
285
321
|
external_info = _compute_external_data_info(tensor, current_offset)
|
|
286
322
|
external_data_infos.append(external_info)
|
|
287
323
|
current_offset = external_info.offset + external_info.length
|
|
288
|
-
_write_external_data(sorted_tensors, external_data_infos, path)
|
|
324
|
+
_write_external_data(sorted_tensors, external_data_infos, path, callback=callback)
|
|
289
325
|
|
|
290
326
|
# Create external tensor objects
|
|
291
327
|
external_tensors: list[_core.ExternalTensor] = [
|
|
@@ -336,6 +372,7 @@ def unload_from_model(
|
|
|
336
372
|
relative_path: str | os.PathLike,
|
|
337
373
|
*,
|
|
338
374
|
size_threshold_bytes: int = 0,
|
|
375
|
+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
|
|
339
376
|
) -> _core.Model:
|
|
340
377
|
"""Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file.
|
|
341
378
|
|
|
@@ -356,6 +393,8 @@ def unload_from_model(
|
|
|
356
393
|
relative_path: Path to which external data is to be stored, relative to the ONNX file.
|
|
357
394
|
E.g. "model.data"
|
|
358
395
|
size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
|
|
396
|
+
callback: A callback function that is called for each tensor that is saved to external data
|
|
397
|
+
for debugging or logging purposes.
|
|
359
398
|
|
|
360
399
|
Returns:
|
|
361
400
|
An ir.Model with all initializer data equal or above ``size_threshold_bytes``
|
|
@@ -384,6 +423,7 @@ def unload_from_model(
|
|
|
384
423
|
[v.const_value for v in initializers_to_become_external], # type: ignore[misc]
|
|
385
424
|
base_dir=base_dir,
|
|
386
425
|
relative_path=relative_path,
|
|
426
|
+
callback=callback,
|
|
387
427
|
)
|
|
388
428
|
|
|
389
429
|
# Replace the initializer values with external tensors and save the model
|
onnx_ir/passes/_pass_infra.py
CHANGED
onnx_ir/serde.py
CHANGED
|
@@ -37,6 +37,7 @@ __all__ = [
|
|
|
37
37
|
"deserialize_value_info_proto",
|
|
38
38
|
# Serialization
|
|
39
39
|
"to_proto",
|
|
40
|
+
"to_onnx_text",
|
|
40
41
|
"serialize_attribute_into",
|
|
41
42
|
"serialize_attribute",
|
|
42
43
|
"serialize_dimension_into",
|
|
@@ -62,14 +63,14 @@ __all__ = [
|
|
|
62
63
|
import collections
|
|
63
64
|
import logging
|
|
64
65
|
import os
|
|
65
|
-
from collections.abc import Mapping, Sequence
|
|
66
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
66
67
|
from typing import Any, Callable
|
|
67
68
|
|
|
68
69
|
import numpy as np
|
|
69
|
-
import onnx
|
|
70
|
-
import onnx.external_data_helper
|
|
70
|
+
import onnx # noqa: TID251
|
|
71
|
+
import onnx.external_data_helper # noqa: TID251
|
|
71
72
|
|
|
72
|
-
from onnx_ir import _core, _enums, _protocols, _type_casting
|
|
73
|
+
from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
|
|
73
74
|
|
|
74
75
|
if typing.TYPE_CHECKING:
|
|
75
76
|
import google.protobuf.internal.containers as proto_containers
|
|
@@ -190,13 +191,64 @@ def from_proto(proto: object) -> object:
|
|
|
190
191
|
)
|
|
191
192
|
|
|
192
193
|
|
|
193
|
-
def from_onnx_text(
|
|
194
|
+
def from_onnx_text(
|
|
195
|
+
model_text: str,
|
|
196
|
+
/,
|
|
197
|
+
initializers: Iterable[_protocols.TensorProtocol] | None = None,
|
|
198
|
+
) -> _core.Model:
|
|
194
199
|
"""Convert the ONNX textual representation to an IR model.
|
|
195
200
|
|
|
196
201
|
Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
|
|
202
|
+
|
|
203
|
+
Args:
|
|
204
|
+
model_text: The ONNX textual representation of the model.
|
|
205
|
+
initializers: Tensors to be added as initializers. If provided, these tensors
|
|
206
|
+
will be added to the model as initializers. If a name does not exist in the model,
|
|
207
|
+
a ValueError will be raised.
|
|
208
|
+
|
|
209
|
+
Returns:
|
|
210
|
+
The IR model corresponding to the ONNX textual representation.
|
|
211
|
+
|
|
212
|
+
Raises:
|
|
213
|
+
ValueError: If a tensor name in `initializers` does not match any value in the model.
|
|
197
214
|
"""
|
|
198
215
|
proto = onnx.parser.parse_model(model_text)
|
|
199
|
-
|
|
216
|
+
model = deserialize_model(proto)
|
|
217
|
+
values = _convenience.create_value_mapping(model.graph)
|
|
218
|
+
if initializers:
|
|
219
|
+
# Add initializers to the model
|
|
220
|
+
for tensor in initializers:
|
|
221
|
+
name = tensor.name
|
|
222
|
+
if not name:
|
|
223
|
+
raise ValueError(
|
|
224
|
+
"Initializer tensor must have a name. "
|
|
225
|
+
f"Please provide a name for the initializer: {tensor}"
|
|
226
|
+
)
|
|
227
|
+
if name not in values:
|
|
228
|
+
raise ValueError(f"Value '{name}' does not exist in model.")
|
|
229
|
+
initializer = values[name]
|
|
230
|
+
initializer.const_value = tensor
|
|
231
|
+
model.graph.register_initializer(initializer)
|
|
232
|
+
return model
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def to_onnx_text(
|
|
236
|
+
model: _protocols.ModelProtocol, /, exclude_initializers: bool = False
|
|
237
|
+
) -> str:
|
|
238
|
+
"""Convert the IR model to the ONNX textual representation.
|
|
239
|
+
|
|
240
|
+
Args:
|
|
241
|
+
model: The IR model to convert.
|
|
242
|
+
exclude_initializers: If True, the initializers will not be included in the output.
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
The ONNX textual representation of the model.
|
|
246
|
+
"""
|
|
247
|
+
proto = serialize_model(model)
|
|
248
|
+
if exclude_initializers:
|
|
249
|
+
del proto.graph.initializer[:]
|
|
250
|
+
text = onnx.printer.to_text(proto)
|
|
251
|
+
return text
|
|
200
252
|
|
|
201
253
|
|
|
202
254
|
@typing.overload
|
|
@@ -462,6 +514,14 @@ def _get_field(proto: Any, field: str) -> Any:
|
|
|
462
514
|
def deserialize_opset_import(
|
|
463
515
|
protos: Sequence[onnx.OperatorSetIdProto],
|
|
464
516
|
) -> dict[str, int]:
|
|
517
|
+
"""Deserialize a sequence of OperatorSetIdProto to opset imports mapping.
|
|
518
|
+
|
|
519
|
+
Args:
|
|
520
|
+
protos: The sequence of ONNX OperatorSetIdProto objects.
|
|
521
|
+
|
|
522
|
+
Returns:
|
|
523
|
+
A dictionary mapping domain strings to version integers.
|
|
524
|
+
"""
|
|
465
525
|
return {opset.domain: opset.version for opset in protos}
|
|
466
526
|
|
|
467
527
|
|
|
@@ -495,6 +555,14 @@ def _parse_experimental_function_value_info_name(
|
|
|
495
555
|
|
|
496
556
|
|
|
497
557
|
def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
|
|
558
|
+
"""Deserialize an ONNX ModelProto into an IR Model.
|
|
559
|
+
|
|
560
|
+
Args:
|
|
561
|
+
proto: The ONNX ModelProto to deserialize.
|
|
562
|
+
|
|
563
|
+
Returns:
|
|
564
|
+
An IR Model object representing the ONNX model.
|
|
565
|
+
"""
|
|
498
566
|
graph = _deserialize_graph(proto.graph, [])
|
|
499
567
|
graph.opset_imports.update(deserialize_opset_import(proto.opset_import))
|
|
500
568
|
|
|
@@ -699,6 +767,14 @@ def _deserialize_graph(
|
|
|
699
767
|
|
|
700
768
|
@_capture_errors(lambda proto: proto.name)
|
|
701
769
|
def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
|
|
770
|
+
"""Deserialize an ONNX FunctionProto into an IR Function.
|
|
771
|
+
|
|
772
|
+
Args:
|
|
773
|
+
proto: The ONNX FunctionProto to deserialize.
|
|
774
|
+
|
|
775
|
+
Returns:
|
|
776
|
+
An IR Function object representing the ONNX function.
|
|
777
|
+
"""
|
|
702
778
|
inputs = [_core.Input(name) for name in proto.input]
|
|
703
779
|
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
|
|
704
780
|
value_info = {info.name: info for info in getattr(proto, "value_info", [])}
|
|
@@ -741,6 +817,15 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
|
|
|
741
817
|
def deserialize_value_info_proto(
|
|
742
818
|
proto: onnx.ValueInfoProto, value: _core.Value | None
|
|
743
819
|
) -> _core.Value:
|
|
820
|
+
"""Deserialize an ONNX ValueInfoProto into an IR Value.
|
|
821
|
+
|
|
822
|
+
Args:
|
|
823
|
+
proto: The ONNX ValueInfoProto to deserialize.
|
|
824
|
+
value: An existing Value to update, or None to create a new one.
|
|
825
|
+
|
|
826
|
+
Returns:
|
|
827
|
+
An IR Value object with type and shape information populated from the proto.
|
|
828
|
+
"""
|
|
744
829
|
if value is None:
|
|
745
830
|
value = _core.Value(name=proto.name)
|
|
746
831
|
value.shape = deserialize_type_proto_for_shape(proto.type)
|
|
@@ -767,6 +852,14 @@ def _deserialize_quantization_annotation(
|
|
|
767
852
|
|
|
768
853
|
@_capture_errors(str)
|
|
769
854
|
def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
|
|
855
|
+
"""Deserialize an ONNX TensorShapeProto into an IR Shape.
|
|
856
|
+
|
|
857
|
+
Args:
|
|
858
|
+
proto: The ONNX TensorShapeProto to deserialize.
|
|
859
|
+
|
|
860
|
+
Returns:
|
|
861
|
+
An IR Shape object representing the tensor shape.
|
|
862
|
+
"""
|
|
770
863
|
# This logic handles when the shape is [] as well
|
|
771
864
|
dim_protos = proto.dim
|
|
772
865
|
deserialized_dim_denotations = [
|
|
@@ -779,6 +872,14 @@ def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
|
|
|
779
872
|
|
|
780
873
|
@_capture_errors(str)
|
|
781
874
|
def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
|
|
875
|
+
"""Extract and deserialize shape information from an ONNX TypeProto.
|
|
876
|
+
|
|
877
|
+
Args:
|
|
878
|
+
proto: The ONNX TypeProto to extract shape from.
|
|
879
|
+
|
|
880
|
+
Returns:
|
|
881
|
+
An IR Shape object if shape information is present, None otherwise.
|
|
882
|
+
"""
|
|
782
883
|
if proto.HasField("tensor_type"):
|
|
783
884
|
if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
|
|
784
885
|
return None
|
|
@@ -806,6 +907,14 @@ def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | Non
|
|
|
806
907
|
def deserialize_type_proto_for_type(
|
|
807
908
|
proto: onnx.TypeProto,
|
|
808
909
|
) -> _protocols.TypeProtocol | None:
|
|
910
|
+
"""Extract and deserialize type information from an ONNX TypeProto.
|
|
911
|
+
|
|
912
|
+
Args:
|
|
913
|
+
proto: The ONNX TypeProto to extract type from.
|
|
914
|
+
|
|
915
|
+
Returns:
|
|
916
|
+
An IR type object (TensorType, SequenceType, etc.) if type information is present, None otherwise.
|
|
917
|
+
"""
|
|
809
918
|
denotation = _get_field(proto, "denotation")
|
|
810
919
|
if proto.HasField("tensor_type"):
|
|
811
920
|
if (elem_type := _get_field(proto.tensor_type, "elem_type")) is None:
|
|
@@ -906,6 +1015,14 @@ _deserialize_string_string_maps = deserialize_metadata_props
|
|
|
906
1015
|
|
|
907
1016
|
|
|
908
1017
|
def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr:
|
|
1018
|
+
"""Deserialize an ONNX AttributeProto into an IR Attribute.
|
|
1019
|
+
|
|
1020
|
+
Args:
|
|
1021
|
+
proto: The ONNX AttributeProto to deserialize.
|
|
1022
|
+
|
|
1023
|
+
Returns:
|
|
1024
|
+
An IR Attribute object representing the ONNX attribute.
|
|
1025
|
+
"""
|
|
909
1026
|
return _deserialize_attribute(proto, [])
|
|
910
1027
|
|
|
911
1028
|
|
|
@@ -979,6 +1096,14 @@ def _deserialize_attribute(
|
|
|
979
1096
|
|
|
980
1097
|
|
|
981
1098
|
def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
|
|
1099
|
+
"""Deserialize an ONNX NodeProto into an IR Node.
|
|
1100
|
+
|
|
1101
|
+
Args:
|
|
1102
|
+
proto: The ONNX NodeProto to deserialize.
|
|
1103
|
+
|
|
1104
|
+
Returns:
|
|
1105
|
+
An IR Node object representing the ONNX node.
|
|
1106
|
+
"""
|
|
982
1107
|
return _deserialize_node(
|
|
983
1108
|
proto, scoped_values=[{}], value_info={}, quantization_annotations={}
|
|
984
1109
|
)
|
|
@@ -1097,6 +1222,14 @@ def _deserialize_node(
|
|
|
1097
1222
|
|
|
1098
1223
|
|
|
1099
1224
|
def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto:
|
|
1225
|
+
"""Serialize an IR Model to an ONNX ModelProto.
|
|
1226
|
+
|
|
1227
|
+
Args:
|
|
1228
|
+
model: The IR Model to serialize.
|
|
1229
|
+
|
|
1230
|
+
Returns:
|
|
1231
|
+
The serialized ONNX ModelProto object.
|
|
1232
|
+
"""
|
|
1100
1233
|
return serialize_model_into(onnx.ModelProto(), from_=model)
|
|
1101
1234
|
|
|
1102
1235
|
|
|
@@ -1418,6 +1551,14 @@ def serialize_function_into(
|
|
|
1418
1551
|
|
|
1419
1552
|
|
|
1420
1553
|
def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
|
|
1554
|
+
"""Serialize an IR Node to an ONNX NodeProto.
|
|
1555
|
+
|
|
1556
|
+
Args:
|
|
1557
|
+
node: The IR Node to serialize.
|
|
1558
|
+
|
|
1559
|
+
Returns:
|
|
1560
|
+
The serialized ONNX NodeProto object.
|
|
1561
|
+
"""
|
|
1421
1562
|
node_proto = onnx.NodeProto()
|
|
1422
1563
|
serialize_node_into(node_proto, from_=node)
|
|
1423
1564
|
return node_proto
|
|
@@ -1472,6 +1613,14 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc
|
|
|
1472
1613
|
|
|
1473
1614
|
|
|
1474
1615
|
def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:
|
|
1616
|
+
"""Serialize an IR Tensor to an ONNX TensorProto.
|
|
1617
|
+
|
|
1618
|
+
Args:
|
|
1619
|
+
tensor: The IR Tensor to serialize.
|
|
1620
|
+
|
|
1621
|
+
Returns:
|
|
1622
|
+
The serialized ONNX TensorProto object.
|
|
1623
|
+
"""
|
|
1475
1624
|
tensor_proto = onnx.TensorProto()
|
|
1476
1625
|
serialize_tensor_into(tensor_proto, from_=tensor)
|
|
1477
1626
|
return tensor_proto
|
|
@@ -1514,6 +1663,14 @@ def serialize_tensor_into(
|
|
|
1514
1663
|
|
|
1515
1664
|
|
|
1516
1665
|
def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.AttributeProto:
|
|
1666
|
+
"""Serialize an IR Attribute to an ONNX AttributeProto.
|
|
1667
|
+
|
|
1668
|
+
Args:
|
|
1669
|
+
attribute: The IR Attribute to serialize.
|
|
1670
|
+
|
|
1671
|
+
Returns:
|
|
1672
|
+
The serialized ONNX AttributeProto object.
|
|
1673
|
+
"""
|
|
1517
1674
|
attribute_proto = onnx.AttributeProto()
|
|
1518
1675
|
serialize_attribute_into(attribute_proto, from_=attribute)
|
|
1519
1676
|
return attribute_proto
|
|
@@ -1678,6 +1835,14 @@ def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtoc
|
|
|
1678
1835
|
|
|
1679
1836
|
|
|
1680
1837
|
def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto:
|
|
1838
|
+
"""Serialize an IR Type to an ONNX TypeProto.
|
|
1839
|
+
|
|
1840
|
+
Args:
|
|
1841
|
+
type_protocol: The IR Type to serialize.
|
|
1842
|
+
|
|
1843
|
+
Returns:
|
|
1844
|
+
The serialized ONNX TypeProto object.
|
|
1845
|
+
"""
|
|
1681
1846
|
type_proto = onnx.TypeProto()
|
|
1682
1847
|
serialize_type_into(type_proto, from_=type_protocol)
|
|
1683
1848
|
return type_proto
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-ir
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: Efficient in-memory representation for ONNX
|
|
5
5
|
Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
|
|
6
6
|
License: Apache License v2.0
|
|
7
|
-
Project-URL: Homepage, https://onnx.ai/
|
|
8
|
-
Project-URL: Issues, https://github.com/onnx/
|
|
9
|
-
Project-URL: Repository, https://github.com/onnx/
|
|
7
|
+
Project-URL: Homepage, https://onnx.ai/ir-py
|
|
8
|
+
Project-URL: Issues, https://github.com/onnx/ir-py/issues
|
|
9
|
+
Project-URL: Repository, https://github.com/onnx/ir-py
|
|
10
10
|
Classifier: Development Status :: 4 - Beta
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.9
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
@@ -33,6 +33,24 @@ Dynamic: license-file
|
|
|
33
33
|
|
|
34
34
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
35
35
|
|
|
36
|
+
## Getting Started
|
|
37
|
+
|
|
38
|
+
[onnx-ir documentation](https://onnx.ai/ir-py/)
|
|
39
|
+
|
|
40
|
+
### Installation
|
|
41
|
+
|
|
42
|
+
Via pip:
|
|
43
|
+
|
|
44
|
+
```
|
|
45
|
+
pip install onnx-ir
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
Or from source:
|
|
49
|
+
|
|
50
|
+
```
|
|
51
|
+
pip install git+https://github.com/onnx/ir-py.git
|
|
52
|
+
```
|
|
53
|
+
|
|
36
54
|
## Features ✨
|
|
37
55
|
|
|
38
56
|
- Full ONNX spec support: all valid models representable by ONNX protobuf, and a subset of invalid models (so you can load and fix them).
|