onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of onnx-ir might be problematic.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +857 -233
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +268 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +36 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/constant_manipulation.py +232 -0
- onnx_ir/passes/common/inliner.py +331 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.0.dist-info/METADATA +53 -0
- onnx_ir-0.1.0.dist-info/RECORD +41 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/__init__.py
CHANGED
@@ -1,14 +1,19 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """In-memory intermediate representation for ONNX graphs."""
 
 __all__ = [
     # Modules
     "serde",
+    "traversal",
+    "convenience",
+    "external_data",
+    "tape",
     # IR classes
     "Tensor",
     "ExternalTensor",
     "StringTensor",
+    "LazyTensor",
     "SymbolicDim",
     "Shape",
     "TensorType",
@@ -66,19 +71,24 @@ __all__ = [
     "TensorProtoTensor",
     # Conversion functions
     "from_proto",
+    "from_onnx_text",
     "to_proto",
-    #
+    # Convenience constructors
     "tensor",
+    "node",
     # Pass infrastructure
     "passes",
-    "traversal",
     # IO
     "load",
     "save",
+    # Flags
+    "DEBUG",
 ]
 
-
-
+import types
+
+from onnx_ir import convenience, external_data, passes, serde, tape, traversal
+from onnx_ir._convenience._constructors import node, tensor
 from onnx_ir._core import (
     Attr,
     AttrFloat32,
@@ -100,6 +110,7 @@ from onnx_ir._core import (
     Graph,
     GraphView,
     Input,
+    LazyTensor,
     Model,
     Node,
     OptionalType,
@@ -138,17 +149,19 @@ from onnx_ir._protocols import (
     TypeProtocol,
     ValueProtocol,
 )
-from onnx_ir.serde import TensorProtoTensor, from_proto, to_proto
-
+from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
 
-DEBUG
+DEBUG = False
 
 
 def __set_module() -> None:
     """Set the module of all functions in this module to this public module."""
     global_dict = globals()
     for name in __all__:
-        global_dict[name]
+        obj = global_dict[name]
+        if hasattr(obj, "__module__") and not isinstance(obj, types.GenericAlias):
+            obj.__module__ = __name__
 
 
 __set_module()
+__version__ = "0.1.0"
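The practical upshot of the `__init__.py` changes above: the convenience constructors, the new submodules, and a module-level `DEBUG` flag become importable straight from the package root, and the package now reports `__version__`. A minimal usage sketch based on the constructor docstrings that appear later in this diff (illustrative only, not an official example from the release):

```python
import numpy as np
import onnx_ir as ir

# ir.tensor and ir.node are now re-exported at the package root.
weight = ir.tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32), name="weight")

x = ir.Input("X", shape=ir.Shape([3]), type=ir.TensorType(ir.DataType.FLOAT))
add_node = ir.node("Add", inputs=[x, x], name="add0")

assert add_node.op_type == "Add"
assert ir.DEBUG is False          # new module-level flag
assert ir.__version__ == "0.1.0"  # version string newly set in __init__.py
```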
onnx_ir/{_convenience.py → _convenience/__init__.py}
RENAMED
@@ -1,5 +1,5 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """Convenience methods for constructing and manipulating the IR.
 
 This is an internal only module. We should choose to expose some of the methods
@@ -12,19 +12,17 @@ __all__ = [
     "convert_attribute",
     "convert_attributes",
     "replace_all_uses_with",
+    "create_value_mapping",
+    "replace_nodes_and_values",
 ]
 
-import
-from typing import
+from collections.abc import Mapping, Sequence
+from typing import Union
 
-import numpy as np
 import onnx
 
 from onnx_ir import _core, _enums, _protocols, serde
 
-if typing.TYPE_CHECKING:
-    import numpy.typing as npt
-
 SupportedAttrTypes = Union[
     str,
     int,
@@ -35,9 +33,9 @@ SupportedAttrTypes = Union[
     _protocols.TensorProtocol,  # This includes all in-memory tensor types
     onnx.TensorProto,
     _core.Attr,
-    _core.RefAttr,
     _protocols.GraphProtocol,
     Sequence[_protocols.GraphProtocol],
+    onnx.GraphProto,
     _protocols.TypeProtocol,
     Sequence[_protocols.TypeProtocol],
     None,
@@ -52,7 +50,7 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
        return _enums.AttributeType.FLOAT
    if isinstance(attr, str):
        return _enums.AttributeType.STRING
-    if isinstance(attr,
+    if isinstance(attr, _core.Attr):
        return attr.type
    if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
        return _enums.AttributeType.INTS
@@ -63,10 +61,15 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
    if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
        # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
        return _enums.AttributeType.TENSOR
-    if isinstance(attr, (
+    if isinstance(attr, Sequence) and all(
+        isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
+        for x in attr
+    ):
+        return _enums.AttributeType.TENSORS
+    if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
        return _enums.AttributeType.GRAPH
    if isinstance(attr, Sequence) and all(
-        isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
+        isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
    ):
        return _enums.AttributeType.GRAPHS
    if isinstance(
@@ -94,7 +97,7 @@ def convert_attribute(
    name: str,
    attr: SupportedAttrTypes,
    attr_type: _enums.AttributeType | None = None,
-) -> _core.Attr
+) -> _core.Attr:
    """Convert a Python object to a _core.Attr object.
 
    This method is useful when constructing nodes with attributes. It infers the
@@ -110,7 +113,7 @@ def convert_attribute(
        A ``Attr`` object.
 
    Raises:
-        ValueError: If
+        ValueError: If ``attr`` is ``None`` and ``attr_type`` is not provided.
        TypeError: If the type of the attribute is not supported.
    """
    if attr is None:
@@ -118,7 +121,7 @@ def convert_attribute(
            raise ValueError("attr_type must be provided when attr is None")
        return _core.Attr(name, attr_type, None)
 
-    if isinstance(attr,
+    if isinstance(attr, _core.Attr):
        if attr.name != name:
            raise ValueError(
                f"Attribute name '{attr.name}' does not match provided name '{name}'"
@@ -148,11 +151,27 @@ def convert_attribute(
        if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
            return _core.AttrTensor(name, attr)
        if isinstance(attr, onnx.TensorProto):
-            return _core.AttrTensor(name, serde.
+            return _core.AttrTensor(name, serde.deserialize_tensor(attr))
+    if attr_type == _enums.AttributeType.TENSORS:
+        tensors = []
+        for t in attr:  # type: ignore[union-attr]
+            if isinstance(t, onnx.TensorProto):
+                tensors.append(_core.AttrTensor(name, serde.deserialize_tensor(t)))
+            else:
+                tensors.append(t)  # type: ignore[arg-type]
+        return _core.AttrTensors(name, tensors)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.GRAPH:
+        if isinstance(attr, onnx.GraphProto):
+            attr = serde.deserialize_graph(attr)
        return _core.AttrGraph(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.GRAPHS:
-
+        graphs = []
+        for graph in attr:  # type: ignore[union-attr]
+            if isinstance(graph, onnx.GraphProto):
+                graphs.append(serde.deserialize_graph(graph))
+            else:
+                graphs.append(graph)  # type: ignore[arg-type]
+        return _core.AttrGraphs(name, graphs)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTO:
        return _core.AttrTypeProto(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTOS:
@@ -162,7 +181,7 @@ def convert_attribute(
 
 def convert_attributes(
    attrs: Mapping[str, SupportedAttrTypes],
-) -> list[_core.Attr
+) -> list[_core.Attr]:
    """Convert a dictionary of attributes to a list of _core.Attr objects.
 
    It infers the attribute type based on the type of the value. The supported
@@ -193,7 +212,7 @@ def convert_attributes(
        ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
        ... }
        >>> convert_attributes(attrs)
-        [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(name='proto')), Attr('graph', INTS, Graph(
+        [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph(
            name='graph0',
            inputs=(
        <BLANKLINE>
@@ -228,7 +247,7 @@ def convert_attributes(
    Returns:
        A list of _core.Attr objects.
    """
-    attributes: list[_core.Attr
+    attributes: list[_core.Attr] = []
    for name, attr in attrs.items():
        if attr is not None:
            attributes.append(convert_attribute(name, attr))
@@ -291,86 +310,6 @@ def replace_all_uses_with(
            user_node.replace_input_with(index, replacement)
 
 
-def tensor(
-    value: npt.ArrayLike
-    | onnx.TensorProto
-    | _protocols.DLPackCompatible
-    | _protocols.ArrayCompatible,
-    dtype: _enums.DataType | None = None,
-    name: str | None = None,
-    doc_string: str | None = None,
-) -> _protocols.TensorProtocol:
-    """Create a tensor value from an ArrayLike object or a TensorProto.
-
-    The dtype must match the value. Reinterpretation of the value is
-    not supported, unless if the value is a plain Python object, in which case
-    it is converted to a numpy array with the given dtype.
-
-    :param:`value` can be a numpy array, a plain Python object, or a TensorProto.
-
-    Example::
-
-        >>> import onnx_ir as ir
-        >>> import numpy as np
-        >>> import ml_dtypes
-        >>> import onnx
-        >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
-        Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
-        >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
-        Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
-        >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
-        >>> tp_tensor.numpy()
-        array(0.5, dtype=float32)
-
-    Args:
-        value: The numpy array to create the tensor from.
-        dtype: The data type of the tensor.
-        name: The name of the tensor.
-        doc_string: The documentation string of the tensor.
-
-    Returns:
-        A tensor value.
-
-    Raises:
-        ValueError: If the dtype does not match the value when value is not a plain Python
-            object like ``list[int]``.
-    """
-    if isinstance(value, _protocols.TensorProtocol):
-        if dtype is not None and dtype != value.dtype:
-            raise ValueError(
-                f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
-                "You do not have to specify the dtype when value is a Tensor."
-            )
-        return value
-    if isinstance(value, onnx.TensorProto):
-        tensor_ = serde.deserialize_tensor(value)
-        if name is not None:
-            tensor_.name = name
-        if doc_string is not None:
-            tensor_.doc_string = doc_string
-        if dtype is not None and dtype != tensor_.dtype:
-            raise ValueError(
-                f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
-                "You do not have to specify the dtype when value is a TensorProto."
-            )
-    elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
-        tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
-    else:
-        if dtype is not None:
-            numpy_dtype = dtype.numpy()
-        else:
-            numpy_dtype = None
-        array = np.array(value, dtype=numpy_dtype)
-        tensor_ = _core.Tensor(
-            array,
-            dtype=dtype,
-            shape=_core.Shape(array.shape),
-            name=name,
-            doc_string=name,
-        )
-    return tensor_
-
-
 def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
    """Return a dictionary mapping names to values in the graph.
 
@@ -382,7 +321,7 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
    Returns:
        A dictionary mapping names to values.
    """
-    values = {}
+    values: dict[str, _core.Value] = {}
    values.update(graph.initializers)
    # The names of the values can be None or "", which we need to exclude
    for input in graph.inputs:
@@ -416,7 +355,6 @@ def replace_nodes_and_values(
        old_values: The values to replace.
        new_values: The values to replace with.
    """
-
    for old_value, new_value in zip(old_values, new_values):
        # Propagate relevant info from old value to new value
        # TODO(Rama): Perhaps this should be a separate utility function. Also, consider
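Net effect of the attribute-conversion changes above: a sequence of tensors is now inferred as a TENSORS attribute, and raw `onnx.GraphProto` values are accepted and deserialized through `serde.deserialize_graph`. A small sketch of that behavior, assuming `convert_attribute` and `AttributeType` are reachable from the public `onnx_ir` / `onnx_ir.convenience` namespaces as the internal `__all__` lists suggest:

```python
import numpy as np
import onnx
import onnx_ir as ir
from onnx_ir import convenience

# A sequence of in-memory tensors is inferred as a TENSORS attribute.
tensors_attr = convenience.convert_attribute(
    "consts",
    [ir.tensor(np.array([1.0], dtype=np.float32)),
     ir.tensor(np.array([2.0], dtype=np.float32))],
)
assert tensors_attr.type == ir.AttributeType.TENSORS

# A raw GraphProto is converted to an ir.Graph before being wrapped in AttrGraph.
graph_attr = convenience.convert_attribute("body", onnx.GraphProto(name="subgraph"))
assert graph_attr.type == ir.AttributeType.GRAPH
```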
onnx_ir/_convenience/_constructors.py
ADDED
@@ -0,0 +1,213 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Convenience constructors for IR objects."""
+
+from __future__ import annotations
+
+__all__ = [
+    "tensor",
+    "node",
+]
+
+import typing
+from collections.abc import Mapping, Sequence
+
+import numpy as np
+import onnx
+
+from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
+
+if typing.TYPE_CHECKING:
+    import numpy.typing as npt
+
+    import onnx_ir as ir
+
+
+def tensor(
+    value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
+    dtype: _enums.DataType | None = None,
+    name: str | None = None,
+    doc_string: str | None = None,
+) -> _protocols.TensorProtocol:
+    """Create a tensor value from an ArrayLike object or a TensorProto.
+
+    The dtype must match the value. Reinterpretation of the value is
+    not supported, unless if the value is a plain Python object, in which case
+    it is converted to a numpy array with the given dtype.
+
+    ``value`` can be a numpy array, a plain Python object, or a TensorProto.
+
+    Example::
+
+        >>> import onnx_ir as ir
+        >>> import numpy as np
+        >>> import ml_dtypes
+        >>> import onnx
+        >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
+        Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
+        >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
+        Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
+        >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
+        >>> tp_tensor.numpy()
+        array(0.5, dtype=float32)
+        >>> import torch
+        >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
+        TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor')
+
+    Args:
+        value: The numpy array to create the tensor from.
+        dtype: The data type of the tensor.
+        name: The name of the tensor.
+        doc_string: The documentation string of the tensor.
+
+    Returns:
+        A tensor value.
+
+    Raises:
+        ValueError: If the dtype does not match the value when value is not a plain Python
+            object like ``list[int]``.
+    """
+    if isinstance(value, _protocols.TensorProtocol):
+        if dtype is not None and dtype != value.dtype:
+            raise ValueError(
+                f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
+                "You do not have to specify the dtype when value is a Tensor."
+            )
+        return value
+    if isinstance(value, onnx.TensorProto):
+        tensor_ = serde.deserialize_tensor(value)
+        if name is not None:
+            tensor_.name = name
+        if doc_string is not None:
+            tensor_.doc_string = doc_string
+        if dtype is not None and dtype != tensor_.dtype:
+            raise ValueError(
+                f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
+                "You do not have to specify the dtype when value is a TensorProto."
+            )
+        return tensor_
+    elif str(type(value)) == "<class 'torch.Tensor'>":
+        # NOTE: We use str(type(...)) and do not import torch for type checking
+        # as it creates overhead during import
+        return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string)  # type: ignore[arg-type]
+    elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
+        return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)
+
+    # Plain (numerical) Python object. Determine the numpy dtype and use np.array to construct the tensor
+    if dtype is not None:
+        if not isinstance(dtype, _enums.DataType):
+            raise TypeError(f"dtype must be an instance of DataType. dtype={dtype}")
+        numpy_dtype = dtype.numpy()
+    elif isinstance(value, Sequence) and not value:
+        raise ValueError("dtype must be specified when value is an empty sequence.")
+    elif isinstance(value, int) and not isinstance(value, bool):
+        # Specify int64 for ints because on Windows this may be int32
+        numpy_dtype = np.dtype(np.int64)
+    elif isinstance(value, float):
+        # If the value is a single float, we use np.float32 as the default dtype
+        numpy_dtype = np.dtype(np.float32)
+    elif isinstance(value, Sequence) and value:
+        if all((isinstance(elem, int) and not isinstance(elem, bool)) for elem in value):
+            numpy_dtype = np.dtype(np.int64)
+        elif all(isinstance(elem, float) for elem in value):
+            # If the value is a sequence of floats, we use np.float32 as the default dtype
+            numpy_dtype = np.dtype(np.float32)
+        else:
+            numpy_dtype = None
+    else:
+        numpy_dtype = None
+
+    array = np.array(value, dtype=numpy_dtype)
+
+    # Handle string tensors by encoding them
+    if isinstance(value, str) or (
+        isinstance(value, Sequence) and value and all(isinstance(elem, str) for elem in value)
+    ):
+        array = np.strings.encode(array, encoding="utf-8")
+        return _core.StringTensor(
+            array,
+            shape=_core.Shape(array.shape),
+            name=name,
+            doc_string=doc_string,
+        )
+
+    return _core.Tensor(
+        array,
+        dtype=dtype,
+        shape=_core.Shape(array.shape),
+        name=name,
+        doc_string=doc_string,
+    )
+
+
+def node(
+    op_type: str,
+    inputs: Sequence[ir.Value | None],
+    attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
+    *,
+    domain: str = "",
+    overload: str = "",
+    num_outputs: int | None = None,
+    outputs: Sequence[ir.Value] | None = None,
+    version: int | None = None,
+    graph: ir.Graph | None = None,
+    name: str | None = None,
+    doc_string: str | None = None,
+    metadata_props: dict[str, str] | None = None,
+) -> ir.Node:
+    """Create an :class:`ir.Node`.
+
+    This is a convenience constructor for creating a Node that supports Python
+    objects as attributes.
+
+    Example::
+
+        >>> import onnx_ir as ir
+        >>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
+        >>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
+        >>> node = ir.node(
+        ...     "SomeOp",
+        ...     inputs=[input_a, input_b],
+        ...     attributes={"alpha": 1.0, "some_list": [1, 2, 3]},
+        ...     domain="some.domain",
+        ...     name="node_name"
+        ... )
+        >>> node.op_type
+        'SomeOp'
+
+    Args:
+        op_type: The name of the operator.
+        inputs: The input values. When an input is None, it is an empty input.
+        attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
+        overload: The overload name when the node is invoking a function.
+        domain: The domain of the operator. For onnx operators, this is an empty string.
+        num_outputs: The number of outputs of the node. If not specified, the number is 1.
+        outputs: The output values. If None, the outputs are created during initialization.
+        version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
+        graph: The graph that the node belongs to. If None, the node is not added to any graph.
+            A `Node` must belong to zero or one graph.
+        name: The name of the node. If None, the node is anonymous.
+        doc_string: The documentation string.
+        metadata_props: The metadata properties.
+
+    Returns:
+        A node with the given op_type and inputs.
+    """
+    if attributes is None:
+        attrs: Sequence[ir.Attr] = ()
+    else:
+        attrs = _convenience.convert_attributes(attributes)
+    return _core.Node(
+        domain=domain,
+        op_type=op_type,
+        inputs=inputs,
+        attributes=attrs,
+        overload=overload,
+        num_outputs=num_outputs,
+        outputs=outputs,
+        version=version,
+        graph=graph,
+        name=name,
+        doc_string=doc_string,
+        metadata_props=metadata_props,
+    )
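One behavioral detail worth noting in the new `tensor()` constructor: plain Python numbers now get deterministic default dtypes (INT64 for ints, FLOAT for floats) rather than whatever `np.array` picks on the host platform, and string inputs are encoded and routed to `StringTensor`. A quick illustration of what the code above implies (a sketch, not an official example from the package):

```python
import onnx_ir as ir

t_int = ir.tensor([1, 2, 3])            # ints default to int64, even on Windows
t_float = ir.tensor([2.5, 3.5])         # floats default to float32
t_str = ir.tensor(["hello", "world"])   # strings become a UTF-8 encoded StringTensor

assert t_int.dtype == ir.DataType.INT64
assert t_float.dtype == ir.DataType.FLOAT
assert t_str.dtype == ir.DataType.STRING
```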