onnx-ir 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


onnx_ir/__init__.py ADDED
@@ -0,0 +1,154 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ """In-memory intermediate representation for ONNX graphs."""
+
+ __all__ = [
+     # Modules
+     "serde",
+     # IR classes
+     "Tensor",
+     "ExternalTensor",
+     "StringTensor",
+     "SymbolicDim",
+     "Shape",
+     "TensorType",
+     "OptionalType",
+     "SequenceType",
+     "SparseTensorType",
+     "TypeAndShape",
+     "Value",
+     "Attr",
+     "RefAttr",
+     "Node",
+     "Function",
+     "Graph",
+     "GraphView",
+     "Model",
+     # Constructors
+     "AttrFloat32",
+     "AttrFloat32s",
+     "AttrGraph",
+     "AttrGraphs",
+     "AttrInt64",
+     "AttrInt64s",
+     "AttrSparseTensor",
+     "AttrSparseTensors",
+     "AttrString",
+     "AttrStrings",
+     "AttrTensor",
+     "AttrTensors",
+     "AttrTypeProto",
+     "AttrTypeProtos",
+     "Input",
+     # Protocols
+     "ArrayCompatible",
+     "DLPackCompatible",
+     "TensorProtocol",
+     "ValueProtocol",
+     "ModelProtocol",
+     "NodeProtocol",
+     "GraphProtocol",
+     "GraphViewProtocol",
+     "AttributeProtocol",
+     "ReferenceAttributeProtocol",
+     "SparseTensorProtocol",
+     "SymbolicDimProtocol",
+     "ShapeProtocol",
+     "TypeProtocol",
+     "MapTypeProtocol",
+     "FunctionProtocol",
+     # Enums
+     "AttributeType",
+     "DataType",
+     # Types
+     "OperatorIdentifier",
+     # Protobuf compatible types
+     "TensorProtoTensor",
+     # Conversion functions
+     "from_proto",
+     "to_proto",
+     # IR Tensor initializer
+     "tensor",
+     # Pass infrastructure
+     "passes",
+     "traversal",
+     # IO
+     "load",
+     "save",
+ ]
+
+ from onnx_ir import passes, serde, traversal
+ from onnx_ir._convenience import tensor
+ from onnx_ir._core import (
+     Attr,
+     AttrFloat32,
+     AttrFloat32s,
+     AttrGraph,
+     AttrGraphs,
+     AttrInt64,
+     AttrInt64s,
+     AttrSparseTensor,
+     AttrSparseTensors,
+     AttrString,
+     AttrStrings,
+     AttrTensor,
+     AttrTensors,
+     AttrTypeProto,
+     AttrTypeProtos,
+     ExternalTensor,
+     Function,
+     Graph,
+     GraphView,
+     Input,
+     Model,
+     Node,
+     OptionalType,
+     RefAttr,
+     SequenceType,
+     Shape,
+     SparseTensorType,
+     StringTensor,
+     SymbolicDim,
+     Tensor,
+     TensorType,
+     TypeAndShape,
+     Value,
+ )
+ from onnx_ir._enums import (
+     AttributeType,
+     DataType,
+ )
+ from onnx_ir._io import load, save
+ from onnx_ir._protocols import (
+     ArrayCompatible,
+     AttributeProtocol,
+     DLPackCompatible,
+     FunctionProtocol,
+     GraphProtocol,
+     GraphViewProtocol,
+     MapTypeProtocol,
+     ModelProtocol,
+     NodeProtocol,
+     OperatorIdentifier,
+     ReferenceAttributeProtocol,
+     ShapeProtocol,
+     SparseTensorProtocol,
+     SymbolicDimProtocol,
+     TensorProtocol,
+     TypeProtocol,
+     ValueProtocol,
+ )
+ from onnx_ir.serde import TensorProtoTensor, from_proto, to_proto
+
+
+ DEBUG: bool = False
+
+
+ def __set_module() -> None:
+     """Set the module of all functions in this module to this public module."""
+     global_dict = globals()
+     for name in __all__:
+         global_dict[name].__module__ = __name__
+
+
+ __set_module()
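
For orientation, a minimal usage sketch of the surface re-exported above; the values are illustrative and the snippet relies only on constructors that also appear in the doctests further down (ir.tensor, ir.Input, ir.Node):

import numpy as np
import onnx_ir as ir

# Wrap a numpy array as an IR tensor; the dtype is inferred from the array.
weight = ir.tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32), name="w")
assert weight.dtype == ir.DataType.FLOAT

# Build a tiny piece of graph structure: one input feeding a Relu node.
x = ir.Input("x")
relu = ir.Node("", "Relu", [x])            # domain "", op_type "Relu"
assert relu.outputs[0].producer() is relu  # each output knows its producing node
assert len(relu.outputs[0].uses()) == 0    # and which nodes consume it (none yet)
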
@@ -0,0 +1,439 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ """Convenience methods for constructing and manipulating the IR.
+
+ This is an internal-only module. Methods here should be exposed in
+ convenience.py only after they have proven to be useful.
+ """
+
+ from __future__ import annotations
+
+ __all__ = [
+     "convert_attribute",
+     "convert_attributes",
+     "replace_all_uses_with",
+ ]
+
+ import typing
+ from typing import Mapping, Sequence, Union
+
+ import numpy as np
+ import onnx
+
+ from onnx_ir import _core, _enums, _protocols, serde
+
+ if typing.TYPE_CHECKING:
+     import numpy.typing as npt
+
+ SupportedAttrTypes = Union[
+     str,
+     int,
+     float,
+     Sequence[int],
+     Sequence[float],
+     Sequence[str],
+     _protocols.TensorProtocol,  # This includes all in-memory tensor types
+     onnx.TensorProto,
+     _core.Attr,
+     _core.RefAttr,
+     _protocols.GraphProtocol,
+     Sequence[_protocols.GraphProtocol],
+     _protocols.TypeProtocol,
+     Sequence[_protocols.TypeProtocol],
+     None,
+ ]
+
+
+ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
+     """Infer the attribute type based on the type of the Python object."""
+     if isinstance(attr, int):
+         return _enums.AttributeType.INT
+     if isinstance(attr, float):
+         return _enums.AttributeType.FLOAT
+     if isinstance(attr, str):
+         return _enums.AttributeType.STRING
+     if isinstance(attr, (_core.Attr, _core.RefAttr)):
+         return attr.type
+     if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
+         return _enums.AttributeType.INTS
+     if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
+         return _enums.AttributeType.FLOATS
+     if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
+         return _enums.AttributeType.STRINGS
+     if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
+         # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
+         return _enums.AttributeType.TENSOR
+     if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
+         return _enums.AttributeType.GRAPH
+     if isinstance(attr, Sequence) and all(
+         isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
+     ):
+         return _enums.AttributeType.GRAPHS
+     if isinstance(
+         attr,
+         (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
+     ):
+         return _enums.AttributeType.TYPE_PROTO
+     if isinstance(attr, Sequence) and all(
+         isinstance(
+             x,
+             (
+                 _core.TensorType,
+                 _core.SequenceType,
+                 _core.OptionalType,
+                 _protocols.TypeProtocol,
+             ),
+         )
+         for x in attr
+     ):
+         return _enums.AttributeType.TYPE_PROTOS
+     raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
+
+
+ def convert_attribute(
+     name: str,
+     attr: SupportedAttrTypes,
+     attr_type: _enums.AttributeType | None = None,
+ ) -> _core.Attr | _core.RefAttr:
+     """Convert a Python object to a _core.Attr object.
+
+     This method is useful when constructing nodes with attributes. It infers the
+     attribute type based on the type of the Python value.
+
+     Args:
+         name: The name of the attribute.
+         attr: The value of the attribute.
+         attr_type: The type of the attribute. This is required when attr is None.
+             When provided, it overrides the inferred type.
+
+     Returns:
+         An ``Attr`` object.
+
+     Raises:
+         ValueError: If :param:`attr` is ``None`` and :param:`attr_type` is not provided.
+         TypeError: If the type of the attribute is not supported.
+     """
+     if attr is None:
+         if attr_type is None:
+             raise ValueError("attr_type must be provided when attr is None")
+         return _core.Attr(name, attr_type, None)
+
+     if isinstance(attr, (_core.Attr, _core.RefAttr)):
+         if attr.name != name:
+             raise ValueError(
+                 f"Attribute name '{attr.name}' does not match provided name '{name}'"
+             )
+         if attr_type is not None and attr.type != attr_type:
+             raise ValueError(
+                 f"Attribute type '{attr.type}' does not match provided type '{attr_type}'"
+             )
+         return attr
+
+     if attr_type is None:
+         attr_type = _infer_attribute_type(attr)
+
+     if attr_type == _enums.AttributeType.INT:
+         return _core.AttrInt64(name, attr)  # type: ignore
+     if attr_type == _enums.AttributeType.FLOAT:
+         return _core.AttrFloat32(name, attr)  # type: ignore
+     if attr_type == _enums.AttributeType.STRING:
+         return _core.AttrString(name, attr)  # type: ignore
+     if attr_type == _enums.AttributeType.INTS:
+         return _core.AttrInt64s(name, attr)  # type: ignore
+     if attr_type == _enums.AttributeType.FLOATS:
+         return _core.AttrFloat32s(name, attr)  # type: ignore
+     if attr_type == _enums.AttributeType.STRINGS:
+         return _core.AttrStrings(name, attr)  # type: ignore
+     if attr_type == _enums.AttributeType.TENSOR:
+         if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
+             return _core.AttrTensor(name, attr)
+         if isinstance(attr, onnx.TensorProto):
+             return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
+     if attr_type == _enums.AttributeType.GRAPH:
+         return _core.AttrGraph(name, attr)  # type: ignore[arg-type]
+     if attr_type == _enums.AttributeType.GRAPHS:
+         return _core.AttrGraphs(name, attr)  # type: ignore[arg-type]
+     if attr_type == _enums.AttributeType.TYPE_PROTO:
+         return _core.AttrTypeProto(name, attr)  # type: ignore[arg-type]
+     if attr_type == _enums.AttributeType.TYPE_PROTOS:
+         return _core.AttrTypeProtos(name, attr)  # type: ignore[arg-type]
+     raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
+
+
+ def convert_attributes(
+     attrs: Mapping[str, SupportedAttrTypes],
+ ) -> list[_core.Attr | _core.RefAttr]:
+     """Convert a dictionary of attributes to a list of _core.Attr objects.
+
+     It infers the attribute type based on the type of the value. The supported
+     types are: int, float, str, Sequence[int], Sequence[float], Sequence[str],
+     tensors (:class:`_core.Tensor` or :class:`onnx.TensorProto`), graphs,
+     type protos, and :class:`_core.Attr`::
+
+         >>> import onnx_ir as ir
+         >>> import onnx
+         >>> import numpy as np
+         >>> attrs = {
+         ...     "int": 1,
+         ...     "float": 1.0,
+         ...     "str": "hello",
+         ...     "ints": [1, 2, 3],
+         ...     "floats": [1.0, 2.0, 3.0],
+         ...     "strings": ["hello", "world"],
+         ...     "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])),
+         ...     "tensor_proto":
+         ...         onnx.TensorProto(
+         ...             dims=[3],
+         ...             data_type=onnx.TensorProto.FLOAT,
+         ...             float_data=[1.0, 2.0, 3.0],
+         ...             name="proto",
+         ...         ),
+         ...     "graph": ir.Graph([], [], nodes=[], name="graph0"),
+         ...     "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")],
+         ...     "type_proto": ir.TensorType(ir.DataType.FLOAT),
+         ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
+         ... }
+         >>> convert_attributes(attrs)
+         [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(name='proto')), Attr('graph', GRAPH, Graph(
+             name='graph0',
+             inputs=(
+         <BLANKLINE>
+             ),
+             outputs=(
+         <BLANKLINE>
+             ),
+             len()=0
+         )), Attr('graphs', GRAPHS, [Graph(
+             name='graph1',
+             inputs=(
+         <BLANKLINE>
+             ),
+             outputs=(
+         <BLANKLINE>
+             ),
+             len()=0
+         ), Graph(
+             name='graph2',
+             inputs=(
+         <BLANKLINE>
+             ),
+             outputs=(
+         <BLANKLINE>
+             ),
+             len()=0
+         )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
+
+     Args:
+         attrs: A dictionary of {<attribute name>: <python objects>} to convert.
+
+     Returns:
+         A list of _core.Attr objects.
+     """
+     attributes: list[_core.Attr | _core.RefAttr] = []
+     for name, attr in attrs.items():
+         if attr is not None:
+             attributes.append(convert_attribute(name, attr))
+     return attributes
+
+
+ def replace_all_uses_with(
+     values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
+     replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
+ ) -> None:
+     """Replace all uses of the given values with the replacements.
+
+     This is useful when nodes in the graph are replaced with new nodes, where
+     the old users need to be updated to use the outputs of the new nodes.
+
+     For example, suppose we have the following graph::
+
+         A -> {B, C}
+
+     We want to replace the node A with a new node D::
+
+         >>> import onnx_ir as ir
+         >>> input = ir.Input("input")
+         >>> node_a = ir.Node("", "A", [input])
+         >>> node_b = ir.Node("", "B", node_a.outputs)
+         >>> node_c = ir.Node("", "C", node_a.outputs)
+         >>> node_d = ir.Node("", "D", [input])
+         >>> replace_all_uses_with(node_a.outputs, node_d.outputs)
+         >>> len(node_b.inputs)
+         1
+         >>> node_b.inputs[0].producer().op_type
+         'D'
+         >>> len(node_c.inputs)
+         1
+         >>> node_c.inputs[0].producer().op_type
+         'D'
+         >>> len(node_a.outputs[0].uses())
+         0
+
+     When values and replacements are sequences, they are zipped into pairs. All
+     users of the first value are replaced with the first replacement, and so on.
+
+     .. note::
+         You still need to update the graph outputs if any of the values being
+         replaced are part of the graph outputs. Be sure to remove the old nodes
+         from the graph using ``graph.remove()`` if they are no longer needed.
+
+     Args:
+         values: The value or values to be replaced.
+         replacements: The new value or values to use as inputs.
+     """
+     if not isinstance(values, Sequence):
+         values = (values,)
+     if not isinstance(replacements, Sequence):
+         replacements = (replacements,)
+     if len(values) != len(replacements):
+         raise ValueError("The number of values and replacements must match.")
+     for value, replacement in zip(values, replacements):
+         for user_node, index in tuple(value.uses()):
+             user_node.replace_input_with(index, replacement)
+
+
+ def tensor(
+     value: npt.ArrayLike
+     | onnx.TensorProto
+     | _protocols.DLPackCompatible
+     | _protocols.ArrayCompatible,
+     dtype: _enums.DataType | None = None,
+     name: str | None = None,
+     doc_string: str | None = None,
+ ) -> _protocols.TensorProtocol:
+     """Create a tensor value from an ArrayLike object or a TensorProto.
+
+     The dtype must match the value. Reinterpretation of the value is
+     not supported, unless the value is a plain Python object, in which case
+     it is converted to a numpy array with the given dtype.
+
+     :param:`value` can be a numpy array, a plain Python object, or a TensorProto.
+
+     Example::
+
+         >>> import onnx_ir as ir
+         >>> import numpy as np
+         >>> import ml_dtypes
+         >>> import onnx
+         >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
+         Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
+         >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
+         Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
+         >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
+         >>> tp_tensor.numpy()
+         array(0.5, dtype=float32)
+
+     Args:
+         value: The value to create the tensor from.
+         dtype: The data type of the tensor.
+         name: The name of the tensor.
+         doc_string: The documentation string of the tensor.
+
+     Returns:
+         A tensor value.
+
+     Raises:
+         ValueError: If the dtype does not match the value when value is not a plain Python
+             object like ``list[int]``.
+     """
+     if isinstance(value, _protocols.TensorProtocol):
+         if dtype is not None and dtype != value.dtype:
+             raise ValueError(
+                 f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
+                 "You do not have to specify the dtype when value is a Tensor."
+             )
+         return value
+     if isinstance(value, onnx.TensorProto):
+         tensor_ = serde.deserialize_tensor(value)
+         if name is not None:
+             tensor_.name = name
+         if doc_string is not None:
+             tensor_.doc_string = doc_string
+         if dtype is not None and dtype != tensor_.dtype:
+             raise ValueError(
+                 f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}. "
+                 "You do not have to specify the dtype when value is a TensorProto."
+             )
+     elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
+         tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)
+     else:
+         if dtype is not None:
+             numpy_dtype = dtype.numpy()
+         else:
+             numpy_dtype = None
+         array = np.array(value, dtype=numpy_dtype)
+         tensor_ = _core.Tensor(
+             array,
+             dtype=dtype,
+             shape=_core.Shape(array.shape),
+             name=name,
+             doc_string=doc_string,
+         )
+     return tensor_
+
+
+ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
+     """Return a dictionary mapping names to values in the graph.
+
+     The mapping does not include values from subgraphs.
+
+     Args:
+         graph: The graph to extract the mapping from.
+
+     Returns:
+         A dictionary mapping names to values.
+     """
+     values = {}
+     values.update(graph.initializers)
+     # The names of the values can be None or "", which we need to exclude
+     for input in graph.inputs:
+         if not input.name:
+             continue
+         values[input.name] = input
+     for node in graph:
+         for value in node.outputs:
+             if not value.name:
+                 continue
+             values[value.name] = value
+     return values
+
+
+ def replace_nodes_and_values(
+     graph_or_function: _core.Graph | _core.Function,
+     /,
+     insertion_point: _core.Node,
+     old_nodes: Sequence[_core.Node],
+     new_nodes: Sequence[_core.Node],
+     old_values: Sequence[_core.Value],
+     new_values: Sequence[_core.Value],
+ ) -> None:
+     """Replaces nodes and values in the graph or function.
+
+     Args:
+         graph_or_function: The graph or function to replace nodes and values in.
+         insertion_point: The node to insert the new nodes after.
+         old_nodes: The nodes to replace.
+         new_nodes: The nodes to replace with.
+         old_values: The values to replace.
+         new_values: The values to replace with.
+     """
+
+     for old_value, new_value in zip(old_values, new_values):
+         # Propagate relevant info from old value to new value
+         # TODO(Rama): Perhaps this should be a separate utility function. Also, consider
+         # merging old and new type/shape info.
+         new_value.type = old_value.type
+         new_value.shape = old_value.shape
+         new_value.const_value = old_value.const_value
+         new_value.name = old_value.name
+
+     # Reconnect the users of the deleted values to use the new values
+     replace_all_uses_with(old_values, new_values)
+     # Update graph/function outputs if the node generates output
+     replacement_mapping = dict(zip(old_values, new_values))
+     for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
+         if graph_or_function_output in replacement_mapping:
+             graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
+
+     # insert new nodes after the index node
+     graph_or_function.insert_after(insertion_point, new_nodes)
+     graph_or_function.remove(old_nodes, safe=True)
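
A minimal sketch of how the helpers in this module fit together, assuming the module is importable as onnx_ir._convenience (the path the package __init__ imports tensor from); the graph, node names, and attribute values below are illustrative only:

import onnx_ir as ir
from onnx_ir._convenience import convert_attributes, create_value_mapping

# Plain Python values are converted to IR attributes with inferred types
# (FLOAT, INTS and STRING here, per _infer_attribute_type above).
attrs = convert_attributes({"alpha": 0.5, "axes": [0, 1], "mode": "constant"})
assert [a.type for a in attrs] == [
    ir.AttributeType.FLOAT,
    ir.AttributeType.INTS,
    ir.AttributeType.STRING,
]

# create_value_mapping indexes a graph's named values: inputs and node outputs.
x = ir.Input("x")
node_a = ir.Node("", "A", [x])
node_a.outputs[0].name = "a_out"
node_b = ir.Node("", "B", node_a.outputs)
node_b.outputs[0].name = "b_out"
graph = ir.Graph([x], node_b.outputs, nodes=[node_a, node_b], name="g")
assert sorted(create_value_mapping(graph)) == ["a_out", "b_out", "x"]
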