onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of onnx-ir might be problematic.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +874 -257
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +40 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +217 -0
- onnx_ir/passes/common/inliner.py +332 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.1.dist-info/METADATA +53 -0
- onnx_ir-0.1.1.dist-info/RECORD +42 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
onnx_ir/_name_authority.py
CHANGED
onnx_ir/_polyfill.py
ADDED
@@ -0,0 +1,26 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Polyfill for Python builtin functions."""
+
+import sys
+from collections.abc import Sequence
+from typing import Any
+
+if sys.version_info >= (3, 10):
+    zip = zip  # pylint: disable=self-assigning-variable
+else:
+    # zip(..., strict=True) was added in Python 3.10
+    # TODO: Remove this polyfill when we drop support for Python 3.9
+    _python_zip = zip
+
+    def zip(a: Sequence[Any], b: Sequence[Any], strict: bool = False):
+        """Polyfill for Python's zip function.
+
+        This is a special version which only supports two Sequence inputs.
+
+        Raises:
+            ValueError: If the iterables have different lengths and strict is True.
+        """
+        if len(a) != len(b) and strict:
+            raise ValueError("zip() argument lengths must be equal")
+        return _python_zip(a, b)
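For reference, a minimal usage sketch of the strict zip behavior added above (importing the private _polyfill module directly here purely for illustration):

from onnx_ir._polyfill import zip  # private module introduced in this release

pairs = list(zip([1, 2, 3], ["a", "b", "c"], strict=True))
# -> [(1, 'a'), (2, 'b'), (3, 'c')]

try:
    list(zip([1, 2], ["a", "b", "c"], strict=True))
except ValueError:
    # Mismatched lengths raise ValueError when strict=True, both on
    # Python >= 3.10 (builtin zip) and on 3.9 (the polyfill above).
    pass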
onnx_ir/_protocols.py
CHANGED
@@ -1,5 +1,5 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """Protocols for the ONNX IR.
 
 This file defines the interfaces for tools to interact with the IR. The interfaces
@@ -31,18 +31,20 @@ tools.
 from __future__ import annotations
 
 import typing
-from
-
+from collections import OrderedDict
+from collections.abc import (
     Collection,
     Iterable,
     Iterator,
     Mapping,
     MutableMapping,
     MutableSequence,
-    OrderedDict,
-    Protocol,
     Sequence,
-
+)
+from typing import (
+    Any,
+    Literal,
+    Protocol,
 )
 
 from onnx_ir import _enums
@@ -52,7 +54,7 @@ if typing.TYPE_CHECKING:
     from typing_extensions import TypeAlias
 
     # An identifier that will uniquely identify an operator. E.g (domain, op_type, overload)
-    OperatorIdentifier: TypeAlias =
+    OperatorIdentifier: TypeAlias = tuple[str, str, str]
 
 
 @typing.runtime_checkable
@@ -277,6 +279,11 @@ class GraphProtocol(Protocol):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.
 
+    .. :note::
+        ``quantization_annotation`` is deserialized into the Value's ``meta`` field
+        under the ``quant_parameter_tensor_names`` key. Values that are stored
+        under this key will be serialized as quantization annotations.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.
@@ -288,7 +295,6 @@ class GraphProtocol(Protocol):
         meta: Metadata store for graph transform passes.
     """
 
-    # TODO(justinchuby): Support quantization_annotation
     name: str | None
     inputs: MutableSequence[ValueProtocol]
    outputs: MutableSequence[ValueProtocol]
@@ -316,11 +322,15 @@ class GraphProtocol(Protocol):
         """Remove a node from the graph."""
         ...
 
-    def insert_after(
+    def insert_after(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes after the given node."""
         ...
 
-    def insert_before(
+    def insert_before(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes before the given node."""
         ...
 
@@ -414,6 +424,8 @@ class AttributeProtocol(Protocol):
     value: Any
     doc_string: str | None
 
+    def is_ref(self) -> Literal[False]: ...
+
 
 @typing.runtime_checkable
 class ReferenceAttributeProtocol(Protocol):
@@ -433,6 +445,8 @@ class ReferenceAttributeProtocol(Protocol):
     type: _enums.AttributeType
     doc_string: str | None
 
+    def is_ref(self) -> Literal[True]: ...
+
 
 @typing.runtime_checkable
 class SparseTensorProtocol(Protocol):
@@ -585,11 +599,15 @@ class FunctionProtocol(Protocol):
         """Remove a node from the function."""
         ...
 
-    def insert_after(
+    def insert_after(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes after the given node."""
         ...
 
-    def insert_before(
+    def insert_before(
+        self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
+    ) -> None:
         """Insert new nodes before the given node."""
         ...
 
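To illustrate the new is_ref() protocol method, a minimal sketch of how a pass could partition a node's attributes, assuming the concrete ir.Attr / ir.RefAttr classes implement the protocol method added here:

import onnx_ir as ir

def split_attributes(node: ir.Node) -> tuple[list, list]:
    """Hypothetical helper: partition a node's attributes by is_ref()."""
    concrete, refs = [], []
    for attr in node.attributes.values():
        # is_ref() returns True only for reference attributes, so a pass
        # no longer needs isinstance checks to tell them apart.
        (refs if attr.is_ref() else concrete).append(attr)
    return concrete, refs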
onnx_ir/_tape.py
CHANGED
@@ -1,77 +1,182 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """Convenience methods for constructing the IR."""
 
-# NOTE: This is a temporary solution for constructing the IR. It should be replaced
-# with a more permanent solution in the future.
-
 from __future__ import annotations
 
-from
+from collections.abc import Mapping, Sequence
+from typing import (
+    Any,
+    Optional,
+)
 
 import onnx_ir as ir
 from onnx_ir import _convenience
 
+# A type representing the domains/versions used in creating nodes in IR.
+UsedOpsets = set[tuple[str, Optional[int]]]
+
+
+class Tape:
+    """Tape class.
+
+    A tape is a recorder that collects nodes and initializers that are created so
+    that they can be used for creating a graph.
+
+    Example::
+
+        import onnx_ir as ir
+
+        tape = ir.tape.Tape()
+        a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
+        b: ir.Value = ...
+        c: ir.Value = ...
+        x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
+        y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
+        model = ir.Model(
+            graph := ir.Graph(
+                inputs=[b, c],
+                outputs=[y],
+                nodes=tape.nodes,
+                initializers=tape.initializers
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
 
-
-
+    Attributes:
+        graph_like: The graph to append the new nodes and initializers to. When
+            it is None, the nodes and initializers are creating without owned by a graph.
+            Initializers will not be added to functions because it is not supported by ONNX.
+    """
 
-    def __init__(self) -> None:
+    def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
         self._nodes: list[ir.Node] = []
+        self._initializers: list[ir.Value] = []
+        self._used_opsets: UsedOpsets = set()
+        self.graph_like = graph_like
 
-    def
-        return
+    def __repr__(self) -> str:
+        return f"Tape(nodes={self._nodes}, initializers={self._initializers})"
 
     @property
     def nodes(self) -> Sequence[ir.Node]:
         return tuple(self._nodes)
 
+    @property
+    def initializers(self) -> Sequence[ir.Value]:
+        return tuple(self._initializers)
+
+    @property
+    def used_opsets(self) -> UsedOpsets:
+        return self._used_opsets
+
     def op(
         self,
         op_type: str,
         inputs: Sequence[ir.Value | None],
         attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
+        *,
         domain: str = "",
+        overload: str = "",
+        version: int | None = None,
+        graph: ir.Graph | None = None,
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
+        output: ir.Value | None = None,
     ) -> ir.Value:
         if attributes is None:
-            attrs: Sequence[ir.Attr
+            attrs: Sequence[ir.Attr] = ()
         else:
             attrs = _convenience.convert_attributes(attributes)
-
+        output_kwargs: dict[str, Any]
+        if output is None:
+            output_kwargs = dict(num_outputs=1)
+        else:
+            output_kwargs = dict(outputs=[output])
+        node = ir.Node(
+            domain,
+            op_type,
+            inputs,
+            attributes=attrs,
+            **output_kwargs,
+            overload=overload,
+            version=version,
+            graph=graph or self.graph_like,
+            name=name,
+            doc_string=doc_string,
+            metadata_props=metadata_props,
+        )
         self._nodes.append(node)
+        self._used_opsets.add((domain, version))
 
         return node.outputs[0]
 
-    def
+    def op_multi_out(
         self,
         op_type: str,
         inputs: Sequence[ir.Value | None],
         attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
         *,
-        num_outputs: int,
+        num_outputs: int | None = None,
+        outputs: Sequence[ir.Value] | None = None,
         domain: str = "",
+        overload: str = "",
+        version: int | None = None,
+        graph: ir.Graph | None = None,
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
     ) -> Sequence[ir.Value]:
+        if num_outputs is None and outputs is None:
+            raise ValueError("Either num_outputs or outputs must be provided.")
+        if num_outputs is not None and outputs is not None:
+            raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.")
+        output_kwargs: dict[str, Any]
+        if outputs is None:
+            output_kwargs = dict(num_outputs=num_outputs)
+        else:
+            output_kwargs = dict(outputs=outputs)
         if attributes is None:
-            attrs: Sequence[ir.Attr
+            attrs: Sequence[ir.Attr] = ()
         else:
             attrs = _convenience.convert_attributes(attributes)
-        node = ir.Node(
+        node = ir.Node(
+            domain,
+            op_type,
+            inputs,
+            attributes=attrs,
+            **output_kwargs,
+            overload=overload,
+            version=version,
+            graph=graph or self.graph_like,
+            name=name,
+            doc_string=doc_string,
+            metadata_props=metadata_props,
+        )
         self._nodes.append(node)
+        self._used_opsets.add((domain, version))
 
         return node.outputs
 
-
-
-
+    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
+        name = name or tensor.name
+        if name is None:
+            raise ValueError("Name must be provided for initializer.")
+        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
+        value = ir.Value(
+            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
+        )
+        self._initializers.append(value)
+        if isinstance(self.graph_like, ir.Graph):
+            self.graph_like.register_initializer(value)
+        return value
 
 
 class Builder(Tape):
     """An extension of the tape that provides a more convenient API for constructing the IR."""
 
-    def __init__(self):
-        super().__init__()
-        self._used_opsets: UsedOpsets = []
-
     def __getattr__(self, op_type: str) -> Any:
         return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
 
@@ -85,20 +190,22 @@ class Builder(Tape):
         assert isinstance(outputs, int)
         num_outputs = outputs
 
-        self._used_opsets.append((domain, version))
         if num_outputs == 1:
-            value = super().op(
+            value = super().op(
+                op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version
+            )
             if isinstance(outputs, Sequence):
                 value.name = outputs[0]
             return value
-        values = super().
-            op_type,
+        values = super().op_multi_out(
+            op_type,
+            inputs=inputs,
+            attributes=kwargs,
+            domain=domain,
+            version=version,
+            num_outputs=num_outputs,
        )
         if isinstance(outputs, Sequence):
             for value, name in zip(values, outputs):
                 value.name = name
         return values
-
-    @property
-    def used_opsets(self) -> UsedOpsets:
-        return self._used_opsets
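As a quick orientation for the expanded Tape API above, a minimal sketch (the ops, shapes, and opset version are illustrative choices, not taken from the package's tests):

import onnx_ir as ir

tape = ir.tape.Tape()
w = tape.initializer(ir.tensor([1.0, 2.0, 3.0], name="w"))
x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3]))
y = tape.op("Mul", [x, w])
out, mask = tape.op_multi_out("Dropout", [y], num_outputs=2, version=22)
graph = ir.Graph(
    inputs=[x],
    outputs=[out, mask],
    nodes=tape.nodes,
    initializers=tape.initializers,
    opset_imports={"": 22},
)
# tape.used_opsets now holds (domain, version) pairs such as ("", None) and
# ("", 22), which a builder can use to fill in opset_imports.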
onnx_ir/_thirdparty/asciichartpy.py
CHANGED
@@ -1,6 +1,3 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-#
 # Copyright © 2016 Igor Kroitor
 #
 # MIT License
@@ -32,8 +29,8 @@ options to tune the output.
 
 from __future__ import annotations
 
+from collections.abc import Mapping
 from math import ceil, floor, isnan
-from typing import Mapping
 
 black = "\033[30m"
 red = "\033[31m"
onnx_ir/_type_casting.py
CHANGED
@@ -1,12 +1,12 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """Numpy utilities for non-native type operation."""
 # TODO(justinchuby): Upstream the logic to onnx
 
 from __future__ import annotations
 
 import typing
-from
+from collections.abc import Sequence
 
 import ml_dtypes
 import numpy as np
@@ -89,3 +89,18 @@ def unpack_int4(
     """
     unpacked = _unpack_uint4_as_uint8(data, dims)
     return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
+
+
+def unpack_float4e2m1(
+    data: npt.NDArray[np.uint8], dims: Sequence[int]
+) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
+    """Convert a packed float4e2m1 array to unpacked float4e2m1 array.
+
+    Args:
+        data: A numpy array.
+        dims: The dimensions are used to reshape the unpacked buffer.
+
+    Returns:
+        A numpy array of float32 reshaped to dims.
+    """
+    return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
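A small sketch of how the new unpack_float4e2m1 helper could be exercised; the packed byte value and the low-nibble-first layout are assumptions based on the sibling unpack_int4 helper, and the private module is imported directly only for illustration:

import numpy as np
from onnx_ir import _type_casting

# One byte holds two 4-bit FLOAT4E2M1 codes (assumed low nibble first, as in unpack_int4).
packed = np.array([0b0010_0001], dtype=np.uint8)
unpacked = _type_casting.unpack_float4e2m1(packed, (2,))
print(unpacked.dtype)               # ml_dtypes float4_e2m1fn
print(unpacked.astype(np.float32))  # the two decoded 4-bit values as float32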
onnx_ir/{_internal/version_utils.py → _version_utils.py}
CHANGED
@@ -1,12 +1,9 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """Version utils for testing."""
 
 from __future__ import annotations
 
-import warnings
-from typing import Callable, Sequence
-
 import packaging.version
 
 
@@ -92,27 +89,3 @@ def has_transformers():
         return True  # noqa
     except ImportError:
         return False
-
-
-def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable:  # type: ignore[arg-type]
-    """Catches warnings.
-
-    Args:
-        warns: warnings to ignore
-
-    Returns:
-        decorated function
-    """
-
-    def wrapper(fct):
-        if warns is None:
-            raise AssertionError(f"warns cannot be None for '{fct}'.")
-
-        def call_f(self):
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore", warns)  # type: ignore[arg-type]
-                return fct(self)
-
-        return call_f
-
-    return wrapper
onnx_ir/convenience.py
CHANGED
@@ -1,5 +1,5 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """Convenience methods for constructing and manipulating the IR."""
 
 from __future__ import annotations
@@ -9,11 +9,13 @@ __all__ = [
     "convert_attributes",
     "replace_all_uses_with",
     "replace_nodes_and_values",
+    "create_value_mapping",
 ]
 
 from onnx_ir._convenience import (
     convert_attribute,
     convert_attributes,
+    create_value_mapping,
     replace_all_uses_with,
     replace_nodes_and_values,
 )