onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic; see the package's registry page for more details.

Files changed (45)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +857 -233
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +268 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +36 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/constant_manipulation.py +232 -0
  27. onnx_ir/passes/common/inliner.py +331 -0
  28. onnx_ir/passes/common/onnx_checker.py +57 -0
  29. onnx_ir/passes/common/shape_inference.py +112 -0
  30. onnx_ir/passes/common/topological_sort.py +33 -0
  31. onnx_ir/passes/common/unused_removal.py +196 -0
  32. onnx_ir/serde.py +288 -124
  33. onnx_ir/tape.py +15 -0
  34. onnx_ir/tensor_adapters.py +122 -0
  35. onnx_ir/testing.py +197 -0
  36. onnx_ir/traversal.py +4 -3
  37. onnx_ir-0.1.0.dist-info/METADATA +53 -0
  38. onnx_ir-0.1.0.dist-info/RECORD +41 -0
  39. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
  40. onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
  41. onnx_ir/_external_data.py +0 -323
  42. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  43. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  44. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  45. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/__init__.py CHANGED
@@ -1,14 +1,19 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """In-memory intermediate representation for ONNX graphs."""
4
4
 
5
5
  __all__ = [
6
6
  # Modules
7
7
  "serde",
8
+ "traversal",
9
+ "convenience",
10
+ "external_data",
11
+ "tape",
8
12
  # IR classes
9
13
  "Tensor",
10
14
  "ExternalTensor",
11
15
  "StringTensor",
16
+ "LazyTensor",
12
17
  "SymbolicDim",
13
18
  "Shape",
14
19
  "TensorType",
@@ -66,19 +71,24 @@ __all__ = [
66
71
  "TensorProtoTensor",
67
72
  # Conversion functions
68
73
  "from_proto",
74
+ "from_onnx_text",
69
75
  "to_proto",
70
- # IR Tensor initializer
76
+ # Convenience constructors
71
77
  "tensor",
78
+ "node",
72
79
  # Pass infrastructure
73
80
  "passes",
74
- "traversal",
75
81
  # IO
76
82
  "load",
77
83
  "save",
84
+ # Flags
85
+ "DEBUG",
78
86
  ]
79
87
 
80
- from onnx_ir import passes, serde, traversal
81
- from onnx_ir._convenience import tensor
88
+ import types
89
+
90
+ from onnx_ir import convenience, external_data, passes, serde, tape, traversal
91
+ from onnx_ir._convenience._constructors import node, tensor
82
92
  from onnx_ir._core import (
83
93
  Attr,
84
94
  AttrFloat32,
@@ -100,6 +110,7 @@ from onnx_ir._core import (
100
110
  Graph,
101
111
  GraphView,
102
112
  Input,
113
+ LazyTensor,
103
114
  Model,
104
115
  Node,
105
116
  OptionalType,
@@ -138,17 +149,19 @@ from onnx_ir._protocols import (
138
149
  TypeProtocol,
139
150
  ValueProtocol,
140
151
  )
141
- from onnx_ir.serde import TensorProtoTensor, from_proto, to_proto
142
-
152
+ from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
143
153
 
144
- DEBUG: bool = False
154
+ DEBUG = False
145
155
 
146
156
 
147
157
  def __set_module() -> None:
148
158
  """Set the module of all functions in this module to this public module."""
149
159
  global_dict = globals()
150
160
  for name in __all__:
151
- global_dict[name].__module__ = __name__
161
+ obj = global_dict[name]
162
+ if hasattr(obj, "__module__") and not isinstance(obj, types.GenericAlias):
163
+ obj.__module__ = __name__
152
164
 
153
165
 
154
166
  __set_module()
167
+ __version__ = "0.1.0"
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Convenience methods for constructing and manipulating the IR.
4
4
 
5
5
  This is an internal only module. We should choose to expose some of the methods
@@ -12,19 +12,17 @@ __all__ = [
12
12
  "convert_attribute",
13
13
  "convert_attributes",
14
14
  "replace_all_uses_with",
15
+ "create_value_mapping",
16
+ "replace_nodes_and_values",
15
17
  ]
16
18
 
17
- import typing
18
- from typing import Mapping, Sequence, Union
19
+ from collections.abc import Mapping, Sequence
20
+ from typing import Union
19
21
 
20
- import numpy as np
21
22
  import onnx
22
23
 
23
24
  from onnx_ir import _core, _enums, _protocols, serde
24
25
 
25
- if typing.TYPE_CHECKING:
26
- import numpy.typing as npt
27
-
28
26
  SupportedAttrTypes = Union[
29
27
  str,
30
28
  int,
@@ -35,9 +33,9 @@ SupportedAttrTypes = Union[
35
33
  _protocols.TensorProtocol, # This includes all in-memory tensor types
36
34
  onnx.TensorProto,
37
35
  _core.Attr,
38
- _core.RefAttr,
39
36
  _protocols.GraphProtocol,
40
37
  Sequence[_protocols.GraphProtocol],
38
+ onnx.GraphProto,
41
39
  _protocols.TypeProtocol,
42
40
  Sequence[_protocols.TypeProtocol],
43
41
  None,
@@ -52,7 +50,7 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
52
50
  return _enums.AttributeType.FLOAT
53
51
  if isinstance(attr, str):
54
52
  return _enums.AttributeType.STRING
55
- if isinstance(attr, (_core.Attr, _core.RefAttr)):
53
+ if isinstance(attr, _core.Attr):
56
54
  return attr.type
57
55
  if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
58
56
  return _enums.AttributeType.INTS
@@ -63,10 +61,15 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
63
61
  if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
64
62
  # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
65
63
  return _enums.AttributeType.TENSOR
66
- if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
64
+ if isinstance(attr, Sequence) and all(
65
+ isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
66
+ for x in attr
67
+ ):
68
+ return _enums.AttributeType.TENSORS
69
+ if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
67
70
  return _enums.AttributeType.GRAPH
68
71
  if isinstance(attr, Sequence) and all(
69
- isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
72
+ isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
70
73
  ):
71
74
  return _enums.AttributeType.GRAPHS
72
75
  if isinstance(
@@ -94,7 +97,7 @@ def convert_attribute(
94
97
  name: str,
95
98
  attr: SupportedAttrTypes,
96
99
  attr_type: _enums.AttributeType | None = None,
97
- ) -> _core.Attr | _core.RefAttr:
100
+ ) -> _core.Attr:
98
101
  """Convert a Python object to a _core.Attr object.
99
102
 
100
103
  This method is useful when constructing nodes with attributes. It infers the
@@ -110,7 +113,7 @@ def convert_attribute(
110
113
  A ``Attr`` object.
111
114
 
112
115
  Raises:
113
- ValueError: If :param:`attr` is ``None`` and :param:`attr_type` is not provided.
116
+ ValueError: If ``attr`` is ``None`` and ``attr_type`` is not provided.
114
117
  TypeError: If the type of the attribute is not supported.
115
118
  """
116
119
  if attr is None:
@@ -118,7 +121,7 @@ def convert_attribute(
118
121
  raise ValueError("attr_type must be provided when attr is None")
119
122
  return _core.Attr(name, attr_type, None)
120
123
 
121
- if isinstance(attr, (_core.Attr, _core.RefAttr)):
124
+ if isinstance(attr, _core.Attr):
122
125
  if attr.name != name:
123
126
  raise ValueError(
124
127
  f"Attribute name '{attr.name}' does not match provided name '{name}'"
@@ -148,11 +151,27 @@ def convert_attribute(
148
151
  if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
149
152
  return _core.AttrTensor(name, attr)
150
153
  if isinstance(attr, onnx.TensorProto):
151
- return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
154
+ return _core.AttrTensor(name, serde.deserialize_tensor(attr))
155
+ if attr_type == _enums.AttributeType.TENSORS:
156
+ tensors = []
157
+ for t in attr: # type: ignore[union-attr]
158
+ if isinstance(t, onnx.TensorProto):
159
+ tensors.append(_core.AttrTensor(name, serde.deserialize_tensor(t)))
160
+ else:
161
+ tensors.append(t) # type: ignore[arg-type]
162
+ return _core.AttrTensors(name, tensors) # type: ignore[arg-type]
152
163
  if attr_type == _enums.AttributeType.GRAPH:
164
+ if isinstance(attr, onnx.GraphProto):
165
+ attr = serde.deserialize_graph(attr)
153
166
  return _core.AttrGraph(name, attr) # type: ignore[arg-type]
154
167
  if attr_type == _enums.AttributeType.GRAPHS:
155
- return _core.AttrGraphs(name, attr) # type: ignore[arg-type]
168
+ graphs = []
169
+ for graph in attr: # type: ignore[union-attr]
170
+ if isinstance(graph, onnx.GraphProto):
171
+ graphs.append(serde.deserialize_graph(graph))
172
+ else:
173
+ graphs.append(graph) # type: ignore[arg-type]
174
+ return _core.AttrGraphs(name, graphs) # type: ignore[arg-type]
156
175
  if attr_type == _enums.AttributeType.TYPE_PROTO:
157
176
  return _core.AttrTypeProto(name, attr) # type: ignore[arg-type]
158
177
  if attr_type == _enums.AttributeType.TYPE_PROTOS:
@@ -162,7 +181,7 @@ def convert_attribute(
162
181
 
163
182
  def convert_attributes(
164
183
  attrs: Mapping[str, SupportedAttrTypes],
165
- ) -> list[_core.Attr | _core.RefAttr]:
184
+ ) -> list[_core.Attr]:
166
185
  """Convert a dictionary of attributes to a list of _core.Attr objects.
167
186
 
168
187
  It infers the attribute type based on the type of the value. The supported
@@ -193,7 +212,7 @@ def convert_attributes(
193
212
  ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
194
213
  ... }
195
214
  >>> convert_attributes(attrs)
196
- [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(name='proto')), Attr('graph', INTS, Graph(
215
+ [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph(
197
216
  name='graph0',
198
217
  inputs=(
199
218
  <BLANKLINE>
@@ -228,7 +247,7 @@ def convert_attributes(
228
247
  Returns:
229
248
  A list of _core.Attr objects.
230
249
  """
231
- attributes: list[_core.Attr | _core.RefAttr] = []
250
+ attributes: list[_core.Attr] = []
232
251
  for name, attr in attrs.items():
233
252
  if attr is not None:
234
253
  attributes.append(convert_attribute(name, attr))
@@ -291,86 +310,6 @@ def replace_all_uses_with(
291
310
  user_node.replace_input_with(index, replacement)
292
311
 
293
312
 
294
- def tensor(
295
- value: npt.ArrayLike
296
- | onnx.TensorProto
297
- | _protocols.DLPackCompatible
298
- | _protocols.ArrayCompatible,
299
- dtype: _enums.DataType | None = None,
300
- name: str | None = None,
301
- doc_string: str | None = None,
302
- ) -> _protocols.TensorProtocol:
303
- """Create a tensor value from an ArrayLike object or a TensorProto.
304
-
305
- The dtype must match the value. Reinterpretation of the value is
306
- not supported, unless if the value is a plain Python object, in which case
307
- it is converted to a numpy array with the given dtype.
308
-
309
- :param:`value` can be a numpy array, a plain Python object, or a TensorProto.
310
-
311
- Example::
312
-
313
- >>> import onnx_ir as ir
314
- >>> import numpy as np
315
- >>> import ml_dtypes
316
- >>> import onnx
317
- >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
318
- Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
319
- >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
320
- Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
321
- >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
322
- >>> tp_tensor.numpy()
323
- array(0.5, dtype=float32)
324
-
325
- Args:
326
- value: The numpy array to create the tensor from.
327
- dtype: The data type of the tensor.
328
- name: The name of the tensor.
329
- doc_string: The documentation string of the tensor.
330
-
331
- Returns:
332
- A tensor value.
333
-
334
- Raises:
335
- ValueError: If the dtype does not match the value when value is not a plain Python
336
- object like ``list[int]``.
337
- """
338
- if isinstance(value, _protocols.TensorProtocol):
339
- if dtype is not None and dtype != value.dtype:
340
- raise ValueError(
341
- f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
342
- "You do not have to specify the dtype when value is a Tensor."
343
- )
344
- return value
345
- if isinstance(value, onnx.TensorProto):
346
- tensor_ = serde.deserialize_tensor(value)
347
- if name is not None:
348
- tensor_.name = name
349
- if doc_string is not None:
350
- tensor_.doc_string = doc_string
351
- if dtype is not None and dtype != tensor_.dtype:
352
- raise ValueError(
353
- f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
354
- "You do not have to specify the dtype when value is a TensorProto."
355
- )
356
- elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
357
- tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
358
- else:
359
- if dtype is not None:
360
- numpy_dtype = dtype.numpy()
361
- else:
362
- numpy_dtype = None
363
- array = np.array(value, dtype=numpy_dtype)
364
- tensor_ = _core.Tensor(
365
- array,
366
- dtype=dtype,
367
- shape=_core.Shape(array.shape),
368
- name=name,
369
- doc_string=name,
370
- )
371
- return tensor_
372
-
373
-
374
313
  def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
375
314
  """Return a dictionary mapping names to values in the graph.
376
315
 
@@ -382,7 +321,7 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
382
321
  Returns:
383
322
  A dictionary mapping names to values.
384
323
  """
385
- values = {}
324
+ values: dict[str, _core.Value] = {}
386
325
  values.update(graph.initializers)
387
326
  # The names of the values can be None or "", which we need to exclude
388
327
  for input in graph.inputs:
@@ -416,7 +355,6 @@ def replace_nodes_and_values(
416
355
  old_values: The values to replace.
417
356
  new_values: The values to replace with.
418
357
  """
419
-
420
358
  for old_value, new_value in zip(old_values, new_values):
421
359
  # Propagate relevant info from old value to new value
422
360
  # TODO(Rama): Perhaps this should be a separate utility function. Also, consider
@@ -0,0 +1,213 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Convenience constructors for IR objects."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "tensor",
9
+ "node",
10
+ ]
11
+
12
+ import typing
13
+ from collections.abc import Mapping, Sequence
14
+
15
+ import numpy as np
16
+ import onnx
17
+
18
+ from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
19
+
20
+ if typing.TYPE_CHECKING:
21
+ import numpy.typing as npt
22
+
23
+ import onnx_ir as ir
24
+
25
+
26
+ def tensor(
27
+ value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
28
+ dtype: _enums.DataType | None = None,
29
+ name: str | None = None,
30
+ doc_string: str | None = None,
31
+ ) -> _protocols.TensorProtocol:
32
+ """Create a tensor value from an ArrayLike object or a TensorProto.
33
+
34
+ The dtype must match the value. Reinterpretation of the value is
35
+ not supported, unless if the value is a plain Python object, in which case
36
+ it is converted to a numpy array with the given dtype.
37
+
38
+ ``value`` can be a numpy array, a plain Python object, or a TensorProto.
39
+
40
+ Example::
41
+
42
+ >>> import onnx_ir as ir
43
+ >>> import numpy as np
44
+ >>> import ml_dtypes
45
+ >>> import onnx
46
+ >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
47
+ Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
48
+ >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
49
+ Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
50
+ >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
51
+ >>> tp_tensor.numpy()
52
+ array(0.5, dtype=float32)
53
+ >>> import torch
54
+ >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
55
+ TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor')
56
+
57
+ Args:
58
+ value: The numpy array to create the tensor from.
59
+ dtype: The data type of the tensor.
60
+ name: The name of the tensor.
61
+ doc_string: The documentation string of the tensor.
62
+
63
+ Returns:
64
+ A tensor value.
65
+
66
+ Raises:
67
+ ValueError: If the dtype does not match the value when value is not a plain Python
68
+ object like ``list[int]``.
69
+ """
70
+ if isinstance(value, _protocols.TensorProtocol):
71
+ if dtype is not None and dtype != value.dtype:
72
+ raise ValueError(
73
+ f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
74
+ "You do not have to specify the dtype when value is a Tensor."
75
+ )
76
+ return value
77
+ if isinstance(value, onnx.TensorProto):
78
+ tensor_ = serde.deserialize_tensor(value)
79
+ if name is not None:
80
+ tensor_.name = name
81
+ if doc_string is not None:
82
+ tensor_.doc_string = doc_string
83
+ if dtype is not None and dtype != tensor_.dtype:
84
+ raise ValueError(
85
+ f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
86
+ "You do not have to specify the dtype when value is a TensorProto."
87
+ )
88
+ return tensor_
89
+ elif str(type(value)) == "<class 'torch.Tensor'>":
90
+ # NOTE: We use str(type(...)) and do not import torch for type checking
91
+ # as it creates overhead during import
92
+ return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type]
93
+ elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
94
+ return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)
95
+
96
+ # Plain (numerical) Python object. Determine the numpy dtype and use np.array to construct the tensor
97
+ if dtype is not None:
98
+ if not isinstance(dtype, _enums.DataType):
99
+ raise TypeError(f"dtype must be an instance of DataType. dtype={dtype}")
100
+ numpy_dtype = dtype.numpy()
101
+ elif isinstance(value, Sequence) and not value:
102
+ raise ValueError("dtype must be specified when value is an empty sequence.")
103
+ elif isinstance(value, int) and not isinstance(value, bool):
104
+ # Specify int64 for ints because on Windows this may be int32
105
+ numpy_dtype = np.dtype(np.int64)
106
+ elif isinstance(value, float):
107
+ # If the value is a single float, we use np.float32 as the default dtype
108
+ numpy_dtype = np.dtype(np.float32)
109
+ elif isinstance(value, Sequence) and value:
110
+ if all((isinstance(elem, int) and not isinstance(elem, bool)) for elem in value):
111
+ numpy_dtype = np.dtype(np.int64)
112
+ elif all(isinstance(elem, float) for elem in value):
113
+ # If the value is a sequence of floats, we use np.float32 as the default dtype
114
+ numpy_dtype = np.dtype(np.float32)
115
+ else:
116
+ numpy_dtype = None
117
+ else:
118
+ numpy_dtype = None
119
+
120
+ array = np.array(value, dtype=numpy_dtype)
121
+
122
+ # Handle string tensors by encoding them
123
+ if isinstance(value, str) or (
124
+ isinstance(value, Sequence) and value and all(isinstance(elem, str) for elem in value)
125
+ ):
126
+ array = np.strings.encode(array, encoding="utf-8")
127
+ return _core.StringTensor(
128
+ array,
129
+ shape=_core.Shape(array.shape),
130
+ name=name,
131
+ doc_string=doc_string,
132
+ )
133
+
134
+ return _core.Tensor(
135
+ array,
136
+ dtype=dtype,
137
+ shape=_core.Shape(array.shape),
138
+ name=name,
139
+ doc_string=doc_string,
140
+ )
141
+
142
+
143
+ def node(
144
+ op_type: str,
145
+ inputs: Sequence[ir.Value | None],
146
+ attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
147
+ *,
148
+ domain: str = "",
149
+ overload: str = "",
150
+ num_outputs: int | None = None,
151
+ outputs: Sequence[ir.Value] | None = None,
152
+ version: int | None = None,
153
+ graph: ir.Graph | None = None,
154
+ name: str | None = None,
155
+ doc_string: str | None = None,
156
+ metadata_props: dict[str, str] | None = None,
157
+ ) -> ir.Node:
158
+ """Create an :class:`ir.Node`.
159
+
160
+ This is a convenience constructor for creating a Node that supports Python
161
+ objects as attributes.
162
+
163
+ Example::
164
+
165
+ >>> import onnx_ir as ir
166
+ >>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
167
+ >>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
168
+ >>> node = ir.node(
169
+ ... "SomeOp",
170
+ ... inputs=[input_a, input_b],
171
+ ... attributes={"alpha": 1.0, "some_list": [1, 2, 3]},
172
+ ... domain="some.domain",
173
+ ... name="node_name"
174
+ ... )
175
+ >>> node.op_type
176
+ 'SomeOp'
177
+
178
+ Args:
179
+ op_type: The name of the operator.
180
+ inputs: The input values. When an input is None, it is an empty input.
181
+ attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
182
+ overload: The overload name when the node is invoking a function.
183
+ domain: The domain of the operator. For onnx operators, this is an empty string.
184
+ num_outputs: The number of outputs of the node. If not specified, the number is 1.
185
+ outputs: The output values. If None, the outputs are created during initialization.
186
+ version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
187
+ graph: The graph that the node belongs to. If None, the node is not added to any graph.
188
+ A `Node` must belong to zero or one graph.
189
+ name: The name of the node. If None, the node is anonymous.
190
+ doc_string: The documentation string.
191
+ metadata_props: The metadata properties.
192
+
193
+ Returns:
194
+ A node with the given op_type and inputs.
195
+ """
196
+ if attributes is None:
197
+ attrs: Sequence[ir.Attr] = ()
198
+ else:
199
+ attrs = _convenience.convert_attributes(attributes)
200
+ return _core.Node(
201
+ domain=domain,
202
+ op_type=op_type,
203
+ inputs=inputs,
204
+ attributes=attrs,
205
+ overload=overload,
206
+ num_outputs=num_outputs,
207
+ outputs=outputs,
208
+ version=version,
209
+ graph=graph,
210
+ name=name,
211
+ doc_string=doc_string,
212
+ metadata_props=metadata_props,
213
+ )