onnx-ir 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +1 -1
- onnx_ir/_convenience/__init__.py +5 -0
- onnx_ir/_core.py +25 -3
- onnx_ir/_enums.py +2 -0
- onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +104 -75
- onnx_ir/passes/common/initializer_deduplication.py +56 -0
- onnx_ir/serde.py +5 -0
- onnx_ir/tensor_adapters.py +62 -7
- {onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/METADATA +1 -1
- {onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/RECORD +14 -13
- {onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/top_level.txt +0 -0
onnx_ir/__init__.py
CHANGED
onnx_ir/_convenience/__init__.py
CHANGED
@@ -323,6 +323,9 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
     and the first value with that name is returned. Values with empty names
     are excluded from the mapping.

+    .. versionchanged:: 0.1.2
+        Values from subgraphs are now included in the mapping.
+
     Args:
         graph: The graph to extract the mapping from.

@@ -410,6 +413,8 @@ def get_const_tensor(
     it will propagate the shape and type of the constant tensor to the value
     if `propagate_shape_type` is set to True.

+    .. versionadded:: 0.1.2
+
     Args:
         value: The value to get the constant tensor from.
         propagate_shape_type: If True, the shape and type of the value will be

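A minimal usage sketch of the two helpers documented above. The loader call and the re-export of these helpers under `onnx_ir.convenience` are assumptions, and `model.onnx` is a placeholder path:

```python
import onnx_ir as ir

model = ir.load("model.onnx")  # placeholder path

# Since 0.1.2 the mapping also covers values defined inside subgraphs.
values = ir.convenience.create_value_mapping(model.graph)

# get_const_tensor (new in 0.1.2) resolves a value to its constant tensor,
# optionally propagating shape/type back onto the value.
for name, value in values.items():
    tensor = ir.convenience.get_const_tensor(value, propagate_shape_type=True)
    if tensor is not None:
        print(name, tensor.dtype, tensor.shape)
```
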
onnx_ir/_core.py
CHANGED
@@ -417,6 +417,9 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         else:
             self._shape = shape
             self._shape.freeze()
+        if isinstance(value, np.generic):
+            # Turn numpy scalar into a numpy array
+            value = np.array(value)  # type: ignore[assignment]
         if dtype is None:
             if isinstance(value, np.ndarray):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)

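The new branch above means `ir.Tensor` now accepts a bare NumPy scalar where an array was previously required. A short sketch, with the behavior inferred from the added lines:

```python
import numpy as np
import onnx_ir as ir

# A np.generic scalar is converted to a 0-D array before dtype/shape deduction,
# so this now behaves like passing np.array(3.14, dtype=np.float32).
scalar_tensor = ir.Tensor(np.float32(3.14), name="pi_ish")
print(scalar_tensor.dtype)  # DataType.FLOAT
print(scalar_tensor.shape)  # empty (rank-0) shape
```
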
@@ -964,7 +967,10 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors


 class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
-    """A tensor that stores 4bit datatypes in packed format."""
+    """A tensor that stores 4bit datatypes in packed format.
+
+    .. versionadded:: 0.1.2
+    """

     __slots__ = (
         "_dtype",

@@ -2335,6 +2341,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.

+    .. versionchanged:: 0.1.1
+        Values with non-none producers will be rejected as graph inputs or initializers.
+
+    .. versionadded:: 0.1.1
+        Added ``add`` method to initializers and attributes.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.

@@ -2545,12 +2557,17 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)

     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph

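A small sketch of the two traversal entry points mentioned in these docstrings; the model path is a placeholder and the loader call is assumed:

```python
import onnx_ir as ir

model = ir.load("model_with_subgraphs.onnx")  # placeholder path

# New in 0.1.2: yields every nested graph (If/Loop bodies, etc.) exactly once.
for subgraph in model.graph.subgraphs():
    print(subgraph.name, "nodes:", len(subgraph))

# For node-level traversal that descends into subgraphs, the docstrings point
# to RecursiveGraphIterator instead.
for node in ir.traversal.RecursiveGraphIterator(model.graph):
    print(node.op_type)
```
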
@@ -3216,12 +3233,17 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)

     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the function in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the function in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph

onnx_ir/_enums.py
CHANGED

onnx_ir/passes/common/__init__.py
CHANGED

@@ -6,6 +6,7 @@ __all__ = [
     "CheckerPass",
     "ClearMetadataAndDocStringPass",
     "CommonSubexpressionEliminationPass",
+    "DeduplicateInitializersPass",
     "InlinePass",
     "LiftConstantsToInitializersPass",
     "LiftSubgraphInitializersToMainGraphPass",
@@ -29,6 +30,9 @@ from onnx_ir.passes.common.constant_manipulation import (
     LiftSubgraphInitializersToMainGraphPass,
     RemoveInitializersFromInputsPass,
 )
+from onnx_ir.passes.common.initializer_deduplication import (
+    DeduplicateInitializersPass,
+)
 from onnx_ir.passes.common.inliner import InlinePass
 from onnx_ir.passes.common.onnx_checker import CheckerPass
 from onnx_ir.passes.common.shape_inference import ShapeInferencePass

onnx_ir/passes/common/common_subexpression_elimination.py
CHANGED

@@ -17,93 +17,122 @@ logger = logging.getLogger(__name__)


 class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
-    """Eliminate common subexpression in ONNX graphs."""
+    """Eliminate common subexpression in ONNX graphs.
+
+    .. versionadded:: 0.1.1
+
+    .. versionchanged:: 0.1.3
+        Constant nodes with values smaller than ``size_limit`` will be CSE'd.
+
+    Attributes:
+        size_limit: The maximum size of the tensor to be csed. If the tensor contains
+            number of elements larger than size_limit, it will not be cse'd. Default is 10.
+    """
+
+    def __init__(self, size_limit: int = 10):
+        """Initialize the CommonSubexpressionEliminationPass."""
+        super().__init__()
+        self.size_limit = size_limit

     def call(self, model: ir.Model) -> ir.passes.PassResult:
         """Return the same ir.Model but with CSE applied to the graph."""
-        modified = False
         graph = model.graph
-
-        modified = _eliminate_common_subexpression(graph, modified)
+        modified = self._eliminate_common_subexpression(graph)

         return ir.passes.PassResult(
             model,
             modified=modified,
         )

+    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+        """Eliminate common subexpression in ONNX graphs."""
+        modified: bool = False
+        # node to node identifier, length of outputs, inputs, and attributes
+        existing_node_info_to_the_node: dict[
+            tuple[
+                ir.OperatorIdentifier,
+                int,  # len(outputs)
+                tuple[int, ...],  # input ids
+                tuple[tuple[str, object], ...],  # attributes
+            ],
+            ir.Node,
+        ] = {}
+
+        for node in graph:
+            # Skip control flow ops like Loop and If.
+            control_flow_op: bool = False
+            # Skip large tensors to avoid cse weights and bias.
+            large_tensor: bool = False
+            # Use equality to check if the node is a common subexpression.
+            attributes = {}
+            for k, v in node.attributes.items():
+                # TODO(exporter team): CSE subgraphs.
+                # NOTE: control flow ops like Loop and If won't be CSEd
+                # because attribute: graph won't match.
+                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                    control_flow_op = True
+                    break
+                # The attribute value could be directly taken from the original
+                # protobuf, so we need to make a copy of it.
+                value = v.value
+                if v.type in (
+                    ir.AttributeType.INTS,
+                    ir.AttributeType.FLOATS,
+                    ir.AttributeType.STRINGS,
+                ):
+                    # For INT, FLOAT and STRING attributes, we convert them to tuples
+                    # to ensure they are hashable.
+                    value = tuple(value)
+                elif v.type is ir.AttributeType.TENSOR:
+                    if value.size > self.size_limit:
+                        # If the tensor is larger than the size limit, we skip it.
+                        large_tensor = True
+                        break
+                    np_value = value.numpy()
+
+                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                attributes[k] = value
+
+            if control_flow_op:
+                # If the node is a control flow op, we skip it.
                 logger.debug("Skipping control flow op %s", node)
+                continue
+
+            if large_tensor:
+                # If the node has a large tensor, we skip it.
+                logger.debug("Skipping large tensor in node %s", node)
+                continue
+
+            if _is_non_deterministic_op(node):
+                # If the node is a non-deterministic op, we skip it.
+                logger.debug("Skipping non-deterministic op %s", node)
+                continue
+
+            node_info = (
+                node.op_identifier(),
+                len(node.outputs),
+                tuple(id(input) for input in node.inputs),
+                tuple(sorted(attributes.items())),
             )
+            # Check if the node is a common subexpression.
+            if node_info in existing_node_info_to_the_node:
+                # If it is, this node has an existing node with the same
+                # operator, number of outputs, inputs, and attributes.
+                # We replace the node with the existing node.
+                modified = True
+                existing_node = existing_node_info_to_the_node[node_info]
+                _remove_node_and_replace_values(
+                    graph,
+                    remove_node=node,
+                    remove_values=node.outputs,
+                    new_values=existing_node.outputs,
+                )
+                logger.debug("Reusing node %s", existing_node)
+            else:
+                # If it is not, add to the mapping.
+                existing_node_info_to_the_node[node_info] = node
+        return modified


 def _remove_node_and_replace_values(

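A short, hedged usage sketch. It assumes the usual `ir.passes` convention that a pass instance is called on a model and returns the `PassResult` seen above; the model path is a placeholder:

```python
import onnx_ir as ir
from onnx_ir.passes.common import CommonSubexpressionEliminationPass

model = ir.load("model.onnx")  # placeholder path

# size_limit (new in 0.1.3) caps how many elements a constant-tensor attribute
# may have and still be considered for CSE; larger tensors are left alone.
result = CommonSubexpressionEliminationPass(size_limit=10)(model)
print("graph modified:", result.modified)
```
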
onnx_ir/passes/common/initializer_deduplication.py
ADDED

@@ -0,0 +1,56 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = [
+    "DeduplicateInitializersPass",
+]
+
+
+import onnx_ir as ir
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the graph.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running pass.
+
+    .. versionadded:: 0.1.3
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        graph = model.graph
+        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+        modified = False
+
+        for initializer in tuple(graph.initializers.values()):
+            # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
+            # out from the main graph before running this pass.
+            const_val = initializer.const_value
+            if const_val is None:
+                # Skip if initializer has no constant value
+                continue
+
+            if const_val.size > self.size_limit:
+                continue
+
+            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+            if key in initializers:
+                modified = True
+                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
+                assert initializer.name is not None
+                graph.initializers.pop(initializer.name)
+            else:
+                initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)

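A hedged usage sketch. Per the docstring above, run `LiftSubgraphInitializersToMainGraphPass` first if initializers living in subgraphs should be included; the path below is a placeholder:

```python
import onnx_ir as ir
from onnx_ir.passes.common import DeduplicateInitializersPass

model = ir.load("model.onnx")  # placeholder path

# Initializers with identical dtype, shape, and raw bytes collapse into one;
# anything with more elements than size_limit (default 1024) is skipped.
result = DeduplicateInitializersPass(size_limit=1024)(model)
print("duplicates removed:", result.modified)
```
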
onnx_ir/serde.py
CHANGED
@@ -200,6 +200,9 @@ def from_onnx_text(

     Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html

+    .. versionchanged:: 0.1.2
+        Added the ``initializers`` argument.
+
     Args:
         model_text: The ONNX textual representation of the model.
         initializers: Tensors to be added as initializers. If provided, these tensors
@@ -237,6 +240,8 @@ def to_onnx_text(
 ) -> str:
     """Convert the IR model to the ONNX textual representation.

+    .. versionadded:: 0.1.2
+
     Args:
         model: The IR model to convert.
         exclude_initializers: If True, the initializers will not be included in the output.

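A hedged round-trip sketch. The textual model below and the assumption that entries in `initializers` are matched to graph values by tensor name are illustrative only; the diff itself only documents the `initializers` argument and `to_onnx_text`:

```python
import numpy as np
import onnx_ir as ir
from onnx_ir import serde

model_text = """
<ir_version: 10, opset_import: ["" : 21]>
agraph (float[2] X, float[2] W) => (float[2] Y) {
    Y = Add(X, W)
}
"""

# New in 0.1.2: supply tensors to be attached as initializers while parsing.
w = ir.Tensor(np.array([1.0, 2.0], dtype=np.float32), name="W")
model = serde.from_onnx_text(model_text, initializers=[w])

# New in 0.1.2: serialize an IR model back to the ONNX textual form.
print(serde.to_onnx_text(model))
```
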
onnx_ir/tensor_adapters.py
CHANGED
@@ -29,6 +29,8 @@ Example::
 from __future__ import annotations

 __all__ = [
+    "from_torch_dtype",
+    "to_torch_dtype",
     "TorchTensor",
 ]

@@ -44,14 +46,17 @@ if TYPE_CHECKING:
     import torch


+_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
+_ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None
+
+
+def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
+    """Convert a PyTorch dtype to an ONNX IR DataType."""
+    global _TORCH_DTYPE_TO_ONNX
+    if _TORCH_DTYPE_TO_ONNX is None:
         import torch

-        _TORCH_DTYPE_TO_ONNX
+        _TORCH_DTYPE_TO_ONNX = {
             torch.bfloat16: ir.DataType.BFLOAT16,
             torch.bool: ir.DataType.BOOL,
             torch.complex128: ir.DataType.COMPLEX128,
@@ -72,8 +77,58 @@ class TorchTensor(_core.Tensor):
             torch.uint32: ir.DataType.UINT32,
             torch.uint64: ir.DataType.UINT64,
         }
+    if dtype not in _TORCH_DTYPE_TO_ONNX:
+        raise TypeError(
+            f"Unsupported PyTorch dtype '{dtype}'. "
+            "Please use a supported dtype from the list: "
+            f"{list(_TORCH_DTYPE_TO_ONNX.keys())}"
+        )
+    return _TORCH_DTYPE_TO_ONNX[dtype]
+
+
+def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
+    """Convert an ONNX IR DataType to a PyTorch dtype."""
+    global _ONNX_DTYPE_TO_TORCH
+    if _ONNX_DTYPE_TO_TORCH is None:
+        import torch
+
+        _ONNX_DTYPE_TO_TORCH = {
+            ir.DataType.BFLOAT16: torch.bfloat16,
+            ir.DataType.BOOL: torch.bool,
+            ir.DataType.COMPLEX128: torch.complex128,
+            ir.DataType.COMPLEX64: torch.complex64,
+            ir.DataType.FLOAT16: torch.float16,
+            ir.DataType.FLOAT: torch.float32,
+            ir.DataType.DOUBLE: torch.float64,
+            ir.DataType.FLOAT8E4M3FN: torch.float8_e4m3fn,
+            ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
+            ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
+            ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
+            ir.DataType.INT16: torch.int16,
+            ir.DataType.INT32: torch.int32,
+            ir.DataType.INT64: torch.int64,
+            ir.DataType.INT8: torch.int8,
+            ir.DataType.UINT8: torch.uint8,
+            ir.DataType.UINT16: torch.uint16,
+            ir.DataType.UINT32: torch.uint32,
+            ir.DataType.UINT64: torch.uint64,
+        }
+    if dtype not in _ONNX_DTYPE_TO_TORCH:
+        raise TypeError(
+            f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
+            "Please use a supported dtype from the list: "
+            f"{list(_ONNX_DTYPE_TO_TORCH.keys())}"
+        )
+    return _ONNX_DTYPE_TO_TORCH[dtype]
+
+
+class TorchTensor(_core.Tensor):
+    def __init__(
+        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+    ):
+        # Pass the tensor as the raw data to ir.Tensor's constructor
         super().__init__(
-            tensor, dtype=
+            tensor, dtype=from_torch_dtype(tensor.dtype), name=name, doc_string=doc_string
         )

     def numpy(self) -> npt.NDArray:

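A hedged sketch of the newly exported helpers. The dtype pairs come straight from the tables above; calling `.numpy()` on the wrapper is assumed to follow the existing `TorchTensor` behavior:

```python
import torch
import onnx_ir as ir
from onnx_ir import tensor_adapters

# New public dtype conversions in both directions.
assert tensor_adapters.from_torch_dtype(torch.bfloat16) == ir.DataType.BFLOAT16
assert tensor_adapters.to_torch_dtype(ir.DataType.FLOAT) == torch.float32

# TorchTensor now derives its IR dtype via from_torch_dtype.
weight = tensor_adapters.TorchTensor(torch.ones(2, 2), name="weight")
print(weight.dtype)          # DataType.FLOAT
print(weight.numpy().shape)  # (2, 2)
```
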
{onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/RECORD
CHANGED

@@ -1,7 +1,7 @@
-onnx_ir/__init__.py,sha256=
-onnx_ir/_core.py,sha256
+onnx_ir/__init__.py,sha256=5KP1Ngl2qyWiqb5S0Ol5owYsbU0geo4LFwGwN8EXTIk,3424
+onnx_ir/_core.py,sha256=-9BpVTZHuHQ9jsms33wqu4NjMEaDF_M57sIuVxYcM1I,137964
 onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
-onnx_ir/_enums.py,sha256=
+onnx_ir/_enums.py,sha256=4lmm_DFKEtz6PqNw6gt6GcqrBYHisctgKMsUbQCm5N8,8252
 onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
 onnx_ir/_graph_containers.py,sha256=PRKrshRZ5rzWCgRs1TefzJq9n8wyo7OqeKy3XxMhyys,14265
 onnx_ir/_io.py,sha256=GWwA4XOZ-ZX1cgibgaYD0K0O5d9LX21ZwcBN02Wrh04,5205
@@ -15,28 +15,29 @@ onnx_ir/_type_casting.py,sha256=8iZDVrNAx_FwRVt48G4tkzIOFu3I6AsETpH3fdxcyEI,3387
 onnx_ir/_version_utils.py,sha256=bZThuE7meVHFOY1DLsmss9WshVIp9iig7udGfDbVaK4,1333
 onnx_ir/convenience.py,sha256=0B1epuXZCSmY4FbW2vaYfR-t5ubxBZ1UruiytHs-zFw,917
 onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,18079
-onnx_ir/serde.py,sha256=
+onnx_ir/serde.py,sha256=YkbYfQMwn0YAzTd3tVDSWJ-NBiSVsG-74T6xk3e5iTU,75073
 onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
-onnx_ir/tensor_adapters.py,sha256=
+onnx_ir/tensor_adapters.py,sha256=dXuapwfFcpLhjKC6AOqCXbtY3WvDaEHoCNPwjnUK7_o,6565
 onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
 onnx_ir/traversal.py,sha256=Z69wzYBNljn1S7PhVTYgwMftrfsdEBLoa0JYteOhLL0,2863
-onnx_ir/_convenience/__init__.py,sha256=
+onnx_ir/_convenience/__init__.py,sha256=DQ-Bz1wTiZJEARCFxDqZvYexWviGmwvDzE_1hR-vp0Q,19182
 onnx_ir/_convenience/_constructors.py,sha256=5GhlYy_xCE2ng7l_4cNx06WQsNDyvS-0U1HgOpPKJEk,8347
 onnx_ir/_thirdparty/asciichartpy.py,sha256=afQ0fsqko2uYRPAR4TZBrQxvCb4eN8lxZ2yDFbVQq_s,10533
 onnx_ir/passes/__init__.py,sha256=M_Tcl_-qGSNPluFIvOoeDyh0qAwNayaYyXDS5UJUJPQ,764
 onnx_ir/passes/_pass_infra.py,sha256=xIOw_zZIuOqD4Z_wZ4OvsqXfh2IZMoMlDp1xQ_MPQlc,9567
-onnx_ir/passes/common/__init__.py,sha256=
+onnx_ir/passes/common/__init__.py,sha256=GrrscfBekrIjxrYusgvTgP80OrgY1GMJwZMInRQmcL4,1467
 onnx_ir/passes/common/_c_api_utils.py,sha256=g6riA6xNGVWaO5YjVHZ0krrfslWHmRlryRkwB8X56cg,2907
 onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMOGjdvVwG9yHQTkSphUgDlM0ME,2365
-onnx_ir/passes/common/common_subexpression_elimination.py,sha256=
+onnx_ir/passes/common/common_subexpression_elimination.py,sha256=wZ1zEPdCshYB_ifP9fCAVfzQkesE6uhCfzCuL2qO5fA,7948
 onnx_ir/passes/common/constant_manipulation.py,sha256=_fGDwn0Axl2Q8APfc2m_mLMH28T-Mc9kIlpzBXoe3q4,8779
+onnx_ir/passes/common/initializer_deduplication.py,sha256=4CIVFYfdXUlmF2sAx560c_pTwYVXtX5hcSwWzUKm5uc,2061
 onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
 onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6svE0cCyDew,1691
 onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
 onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
 onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
-onnx_ir-0.1.
-onnx_ir-0.1.
-onnx_ir-0.1.
-onnx_ir-0.1.
-onnx_ir-0.1.
+onnx_ir-0.1.3.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+onnx_ir-0.1.3.dist-info/METADATA,sha256=vKG8o_nAUJfjM05rahv0g-FCeHkHXIwCAcuYzSY6PH8,4782
+onnx_ir-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+onnx_ir-0.1.3.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
+onnx_ir-0.1.3.dist-info/RECORD,,

{onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/WHEEL
File without changes

{onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/licenses/LICENSE
File without changes

{onnx_ir-0.1.2.dist-info → onnx_ir-0.1.3.dist-info}/top_level.txt
File without changes