onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. More details are available on the original registry page (the link was lost when this page was extracted).

Files changed (45)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +857 -233
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +268 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +36 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/constant_manipulation.py +232 -0
  27. onnx_ir/passes/common/inliner.py +331 -0
  28. onnx_ir/passes/common/onnx_checker.py +57 -0
  29. onnx_ir/passes/common/shape_inference.py +112 -0
  30. onnx_ir/passes/common/topological_sort.py +33 -0
  31. onnx_ir/passes/common/unused_removal.py +196 -0
  32. onnx_ir/serde.py +288 -124
  33. onnx_ir/tape.py +15 -0
  34. onnx_ir/tensor_adapters.py +122 -0
  35. onnx_ir/testing.py +197 -0
  36. onnx_ir/traversal.py +4 -3
  37. onnx_ir-0.1.0.dist-info/METADATA +53 -0
  38. onnx_ir-0.1.0.dist-info/RECORD +41 -0
  39. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
  40. onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
  41. onnx_ir/_external_data.py +0 -323
  42. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  43. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  44. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  45. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/_protocols.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Protocols for the ONNX IR.
4
4
 
5
5
  This file defines the interfaces for tools to interact with the IR. The interfaces
@@ -31,18 +31,20 @@ tools.
31
31
  from __future__ import annotations
32
32
 
33
33
  import typing
34
- from typing import (
35
- Any,
34
+ from collections import OrderedDict
35
+ from collections.abc import (
36
36
  Collection,
37
37
  Iterable,
38
38
  Iterator,
39
39
  Mapping,
40
40
  MutableMapping,
41
41
  MutableSequence,
42
- OrderedDict,
43
- Protocol,
44
42
  Sequence,
45
- Tuple,
43
+ )
44
+ from typing import (
45
+ Any,
46
+ Literal,
47
+ Protocol,
46
48
  )
47
49
 
48
50
  from onnx_ir import _enums
@@ -52,7 +54,7 @@ if typing.TYPE_CHECKING:
52
54
  from typing_extensions import TypeAlias
53
55
 
54
56
  # An identifier that will uniquely identify an operator. E.g (domain, op_type, overload)
55
- OperatorIdentifier: TypeAlias = Tuple[str, str, str]
57
+ OperatorIdentifier: TypeAlias = tuple[str, str, str]
56
58
 
57
59
 
58
60
  @typing.runtime_checkable
@@ -277,6 +279,11 @@ class GraphProtocol(Protocol):
277
279
  seen as a Sequence of nodes and should be used as such. For example, to obtain
278
280
  all nodes as a list, call ``list(graph)``.
279
281
 
282
+ .. :note::
283
+ ``quantization_annotation`` is deserialized into the Value's ``meta`` field
284
+ under the ``quant_parameter_tensor_names`` key. Values that are stored
285
+ under this key will be serialized as quantization annotations.
286
+
280
287
  Attributes:
281
288
  name: The name of the graph.
282
289
  inputs: The input values of the graph.
@@ -288,7 +295,6 @@ class GraphProtocol(Protocol):
288
295
  meta: Metadata store for graph transform passes.
289
296
  """
290
297
 
291
- # TODO(justinchuby): Support quantization_annotation
292
298
  name: str | None
293
299
  inputs: MutableSequence[ValueProtocol]
294
300
  outputs: MutableSequence[ValueProtocol]
@@ -316,11 +322,15 @@ class GraphProtocol(Protocol):
316
322
  """Remove a node from the graph."""
317
323
  ...
318
324
 
319
- def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
325
+ def insert_after(
326
+ self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
327
+ ) -> None:
320
328
  """Insert new nodes after the given node."""
321
329
  ...
322
330
 
323
- def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
331
+ def insert_before(
332
+ self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
333
+ ) -> None:
324
334
  """Insert new nodes before the given node."""
325
335
  ...
326
336
 
@@ -414,6 +424,8 @@ class AttributeProtocol(Protocol):
414
424
  value: Any
415
425
  doc_string: str | None
416
426
 
427
+ def is_ref(self) -> Literal[False]: ...
428
+
417
429
 
418
430
  @typing.runtime_checkable
419
431
  class ReferenceAttributeProtocol(Protocol):
@@ -433,6 +445,8 @@ class ReferenceAttributeProtocol(Protocol):
433
445
  type: _enums.AttributeType
434
446
  doc_string: str | None
435
447
 
448
+ def is_ref(self) -> Literal[True]: ...
449
+
436
450
 
437
451
  @typing.runtime_checkable
438
452
  class SparseTensorProtocol(Protocol):
@@ -585,11 +599,15 @@ class FunctionProtocol(Protocol):
585
599
  """Remove a node from the function."""
586
600
  ...
587
601
 
588
- def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
602
+ def insert_after(
603
+ self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
604
+ ) -> None:
589
605
  """Insert new nodes after the given node."""
590
606
  ...
591
607
 
592
- def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None:
608
+ def insert_before(
609
+ self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, /
610
+ ) -> None:
593
611
  """Insert new nodes before the given node."""
594
612
  ...
595
613
 
onnx_ir/_tape.py CHANGED
@@ -1,77 +1,182 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Convenience methods for constructing the IR."""
4
4
 
5
- # NOTE: This is a temporary solution for constructing the IR. It should be replaced
6
- # with a more permanent solution in the future.
7
-
8
5
  from __future__ import annotations
9
6
 
10
- from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple
7
+ from collections.abc import Mapping, Sequence
8
+ from typing import (
9
+ Any,
10
+ Optional,
11
+ )
11
12
 
12
13
  import onnx_ir as ir
13
14
  from onnx_ir import _convenience
14
15
 
16
+ # A type representing the domains/versions used in creating nodes in IR.
17
+ UsedOpsets = set[tuple[str, Optional[int]]]
18
+
19
+
20
+ class Tape:
21
+ """Tape class.
22
+
23
+ A tape is a recorder that collects nodes and initializers that are created so
24
+ that they can be used for creating a graph.
25
+
26
+ Example::
27
+
28
+ import onnx_ir as ir
29
+
30
+ tape = ir.tape.Tape()
31
+ a = tape.initializer(ir.tensor([1, 2, 3], name="a"))
32
+ b: ir.Value = ...
33
+ c: ir.Value = ...
34
+ x = tape.op("Add", [a, b], attributes={"alpha": 1.0})
35
+ y = tape.op("Mul", [x, c], attributes={"beta": 2.0})
36
+ model = ir.Model(
37
+ graph := ir.Graph(
38
+ inputs=[b, c],
39
+ outputs=[y],
40
+ nodes=tape.nodes,
41
+ initializers=tape.initializers
42
+ opset_imports={"": 20},
43
+ ),
44
+ ir_version=10,
45
+ )
15
46
 
16
- class Tape(Iterable[ir.Node]):
17
- """A tape for recording nodes that are created."""
47
+ Attributes:
48
+ graph_like: The graph to append the new nodes and initializers to. When
49
+ it is None, the nodes and initializers are creating without owned by a graph.
50
+ Initializers will not be added to functions because it is not supported by ONNX.
51
+ """
18
52
 
19
- def __init__(self) -> None:
53
+ def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
20
54
  self._nodes: list[ir.Node] = []
55
+ self._initializers: list[ir.Value] = []
56
+ self._used_opsets: UsedOpsets = set()
57
+ self.graph_like = graph_like
21
58
 
22
- def __iter__(self) -> Iterator[ir.Node]:
23
- return iter(self._nodes)
59
+ def __repr__(self) -> str:
60
+ return f"Tape(nodes={self._nodes}, initializers={self._initializers})"
24
61
 
25
62
  @property
26
63
  def nodes(self) -> Sequence[ir.Node]:
27
64
  return tuple(self._nodes)
28
65
 
66
+ @property
67
+ def initializers(self) -> Sequence[ir.Value]:
68
+ return tuple(self._initializers)
69
+
70
+ @property
71
+ def used_opsets(self) -> UsedOpsets:
72
+ return self._used_opsets
73
+
29
74
  def op(
30
75
  self,
31
76
  op_type: str,
32
77
  inputs: Sequence[ir.Value | None],
33
78
  attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
79
+ *,
34
80
  domain: str = "",
81
+ overload: str = "",
82
+ version: int | None = None,
83
+ graph: ir.Graph | None = None,
84
+ name: str | None = None,
85
+ doc_string: str | None = None,
86
+ metadata_props: dict[str, str] | None = None,
87
+ output: ir.Value | None = None,
35
88
  ) -> ir.Value:
36
89
  if attributes is None:
37
- attrs: Sequence[ir.Attr | ir.RefAttr] = ()
90
+ attrs: Sequence[ir.Attr] = ()
38
91
  else:
39
92
  attrs = _convenience.convert_attributes(attributes)
40
- node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=1)
93
+ output_kwargs: dict[str, Any]
94
+ if output is None:
95
+ output_kwargs = dict(num_outputs=1)
96
+ else:
97
+ output_kwargs = dict(outputs=[output])
98
+ node = ir.Node(
99
+ domain,
100
+ op_type,
101
+ inputs,
102
+ attributes=attrs,
103
+ **output_kwargs,
104
+ overload=overload,
105
+ version=version,
106
+ graph=graph or self.graph_like,
107
+ name=name,
108
+ doc_string=doc_string,
109
+ metadata_props=metadata_props,
110
+ )
41
111
  self._nodes.append(node)
112
+ self._used_opsets.add((domain, version))
42
113
 
43
114
  return node.outputs[0]
44
115
 
45
- def op_multi_output(
116
+ def op_multi_out(
46
117
  self,
47
118
  op_type: str,
48
119
  inputs: Sequence[ir.Value | None],
49
120
  attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
50
121
  *,
51
- num_outputs: int,
122
+ num_outputs: int | None = None,
123
+ outputs: Sequence[ir.Value] | None = None,
52
124
  domain: str = "",
125
+ overload: str = "",
126
+ version: int | None = None,
127
+ graph: ir.Graph | None = None,
128
+ name: str | None = None,
129
+ doc_string: str | None = None,
130
+ metadata_props: dict[str, str] | None = None,
53
131
  ) -> Sequence[ir.Value]:
132
+ if num_outputs is None and outputs is None:
133
+ raise ValueError("Either num_outputs or outputs must be provided.")
134
+ if num_outputs is not None and outputs is not None:
135
+ raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.")
136
+ output_kwargs: dict[str, Any]
137
+ if outputs is None:
138
+ output_kwargs = dict(num_outputs=num_outputs)
139
+ else:
140
+ output_kwargs = dict(outputs=outputs)
54
141
  if attributes is None:
55
- attrs: Sequence[ir.Attr | ir.RefAttr] = ()
142
+ attrs: Sequence[ir.Attr] = ()
56
143
  else:
57
144
  attrs = _convenience.convert_attributes(attributes)
58
- node = ir.Node(domain, op_type, inputs, attributes=attrs, num_outputs=num_outputs)
145
+ node = ir.Node(
146
+ domain,
147
+ op_type,
148
+ inputs,
149
+ attributes=attrs,
150
+ **output_kwargs,
151
+ overload=overload,
152
+ version=version,
153
+ graph=graph or self.graph_like,
154
+ name=name,
155
+ doc_string=doc_string,
156
+ metadata_props=metadata_props,
157
+ )
59
158
  self._nodes.append(node)
159
+ self._used_opsets.add((domain, version))
60
160
 
61
161
  return node.outputs
62
162
 
63
-
64
- # A type representing the domains/versions used in creating nodes in IR.
65
- UsedOpsets = List[Tuple[str, Optional[int]]]
163
+ def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
164
+ name = name or tensor.name
165
+ if name is None:
166
+ raise ValueError("Name must be provided for initializer.")
167
+ shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
168
+ value = ir.Value(
169
+ name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
170
+ )
171
+ self._initializers.append(value)
172
+ if isinstance(self.graph_like, ir.Graph):
173
+ self.graph_like.register_initializer(value)
174
+ return value
66
175
 
67
176
 
68
177
  class Builder(Tape):
69
178
  """An extension of the tape that provides a more convenient API for constructing the IR."""
70
179
 
71
- def __init__(self):
72
- super().__init__()
73
- self._used_opsets: UsedOpsets = []
74
-
75
180
  def __getattr__(self, op_type: str) -> Any:
76
181
  return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)
77
182
 
@@ -85,20 +190,22 @@ class Builder(Tape):
85
190
  assert isinstance(outputs, int)
86
191
  num_outputs = outputs
87
192
 
88
- self._used_opsets.append((domain, version))
89
193
  if num_outputs == 1:
90
- value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
194
+ value = super().op(
195
+ op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version
196
+ )
91
197
  if isinstance(outputs, Sequence):
92
198
  value.name = outputs[0]
93
199
  return value
94
- values = super().op_multi_output(
95
- op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
200
+ values = super().op_multi_out(
201
+ op_type,
202
+ inputs=inputs,
203
+ attributes=kwargs,
204
+ domain=domain,
205
+ version=version,
206
+ num_outputs=num_outputs,
96
207
  )
97
208
  if isinstance(outputs, Sequence):
98
209
  for value, name in zip(values, outputs):
99
210
  value.name = name
100
211
  return values
101
-
102
- @property
103
- def used_opsets(self) -> UsedOpsets:
104
- return self._used_opsets
onnx_ir/_thirdparty/asciichartpy.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
3
- #
4
1
  # Copyright © 2016 Igor Kroitor
5
2
  #
6
3
  # MIT License
@@ -32,8 +29,8 @@ options to tune the output.
32
29
 
33
30
  from __future__ import annotations
34
31
 
32
+ from collections.abc import Mapping
35
33
  from math import ceil, floor, isnan
36
- from typing import Mapping
37
34
 
38
35
  black = "\033[30m"
39
36
  red = "\033[31m"
onnx_ir/_type_casting.py CHANGED
@@ -1,12 +1,12 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Numpy utilities for non-native type operation."""
4
4
  # TODO(justinchuby): Upstream the logic to onnx
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
8
  import typing
9
- from typing import Sequence
9
+ from collections.abc import Sequence
10
10
 
11
11
  import ml_dtypes
12
12
  import numpy as np
@@ -89,3 +89,18 @@ def unpack_int4(
89
89
  """
90
90
  unpacked = _unpack_uint4_as_uint8(data, dims)
91
91
  return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
92
+
93
+
94
+ def unpack_float4e2m1(
95
+ data: npt.NDArray[np.uint8], dims: Sequence[int]
96
+ ) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
97
+ """Convert a packed float4e2m1 array to unpacked float4e2m1 array.
98
+
99
+ Args:
100
+ data: A numpy array.
101
+ dims: The dimensions are used to reshape the unpacked buffer.
102
+
103
+ Returns:
104
+ A numpy array of float32 reshaped to dims.
105
+ """
106
+ return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
onnx_ir/{_internal/version_utils.py → _version_utils.py} CHANGED
@@ -1,12 +1,9 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Version utils for testing."""
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- import warnings
8
- from typing import Callable, Sequence
9
-
10
7
  import packaging.version
11
8
 
12
9
 
@@ -92,27 +89,3 @@ def has_transformers():
92
89
  return True # noqa
93
90
  except ImportError:
94
91
  return False
95
-
96
-
97
- def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable: # type: ignore[arg-type]
98
- """Catches warnings.
99
-
100
- Args:
101
- warns: warnings to ignore
102
-
103
- Returns:
104
- decorated function
105
- """
106
-
107
- def wrapper(fct):
108
- if warns is None:
109
- raise AssertionError(f"warns cannot be None for '{fct}'.")
110
-
111
- def call_f(self):
112
- with warnings.catch_warnings():
113
- warnings.simplefilter("ignore", warns) # type: ignore[arg-type]
114
- return fct(self)
115
-
116
- return call_f
117
-
118
- return wrapper
onnx_ir/convenience.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Convenience methods for constructing and manipulating the IR."""
4
4
 
5
5
  from __future__ import annotations
@@ -9,11 +9,13 @@ __all__ = [
9
9
  "convert_attributes",
10
10
  "replace_all_uses_with",
11
11
  "replace_nodes_and_values",
12
+ "create_value_mapping",
12
13
  ]
13
14
 
14
15
  from onnx_ir._convenience import (
15
16
  convert_attribute,
16
17
  convert_attributes,
18
+ create_value_mapping,
17
19
  replace_all_uses_with,
18
20
  replace_nodes_and_values,
19
21
  )