onnx-ir 0.1.8__tar.gz → 0.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic; see the registry listing for more details.

Files changed (52)
  1. {onnx_ir-0.1.8/src/onnx_ir.egg-info → onnx_ir-0.1.9}/PKG-INFO +2 -2
  2. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/pyproject.toml +1 -1
  3. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/__init__.py +3 -2
  4. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_convenience/__init__.py +4 -4
  5. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_convenience/_constructors.py +75 -4
  6. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_core.py +58 -41
  7. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_enums.py +5 -8
  8. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/serde.py +29 -7
  9. {onnx_ir-0.1.8 → onnx_ir-0.1.9/src/onnx_ir.egg-info}/PKG-INFO +2 -2
  10. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir.egg-info/requires.txt +1 -1
  11. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/LICENSE +0 -0
  12. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/MANIFEST.in +0 -0
  13. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/README.md +0 -0
  14. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/setup.cfg +0 -0
  15. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_display.py +0 -0
  16. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_graph_comparison.py +0 -0
  17. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_graph_containers.py +0 -0
  18. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_io.py +0 -0
  19. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_linked_list.py +0 -0
  20. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_metadata.py +0 -0
  21. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_name_authority.py +0 -0
  22. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_polyfill.py +0 -0
  23. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_protocols.py +0 -0
  24. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_tape.py +0 -0
  25. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  26. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_type_casting.py +0 -0
  27. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/_version_utils.py +0 -0
  28. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/convenience.py +0 -0
  29. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/external_data.py +0 -0
  30. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/__init__.py +0 -0
  31. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/_pass_infra.py +0 -0
  32. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/__init__.py +0 -0
  33. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  34. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  35. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
  36. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
  37. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/identity_elimination.py +0 -0
  38. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/initializer_deduplication.py +0 -0
  39. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/inliner.py +0 -0
  40. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/naming.py +0 -0
  41. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  42. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  43. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  44. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  45. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/py.typed +0 -0
  46. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/tape.py +0 -0
  47. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/tensor_adapters.py +0 -0
  48. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/testing.py +0 -0
  49. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir/traversal.py +0 -0
  50. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir.egg-info/SOURCES.txt +0 -0
  51. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  52. {onnx_ir-0.1.8 → onnx_ir-0.1.9}/src/onnx_ir.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License-Expression: Apache-2.0
@@ -19,7 +19,7 @@ License-File: LICENSE
19
19
  Requires-Dist: numpy
20
20
  Requires-Dist: onnx>=1.16
21
21
  Requires-Dist: typing_extensions>=4.10
22
- Requires-Dist: ml_dtypes
22
+ Requires-Dist: ml_dtypes>=0.5.0
23
23
  Dynamic: license-file
24
24
 
25
25
  # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>
@@ -21,7 +21,7 @@ classifiers = [
21
21
  "Programming Language :: Python :: 3.12",
22
22
  "Programming Language :: Python :: 3.13",
23
23
  ]
24
- dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
24
+ dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes>=0.5.0"]
25
25
 
26
26
  [project.urls]
27
27
  Homepage = "https://onnx.ai/ir-py"
@@ -78,6 +78,7 @@ __all__ = [
78
78
  # Convenience constructors
79
79
  "tensor",
80
80
  "node",
81
+ "val",
81
82
  # Pass infrastructure
82
83
  "passes",
83
84
  # IO
@@ -90,7 +91,7 @@ __all__ = [
90
91
  import types
91
92
 
92
93
  from onnx_ir import convenience, external_data, passes, serde, tape, traversal
93
- from onnx_ir._convenience._constructors import node, tensor
94
+ from onnx_ir._convenience._constructors import node, tensor, val
94
95
  from onnx_ir._core import (
95
96
  Attr,
96
97
  AttrFloat32,
@@ -167,4 +168,4 @@ def __set_module() -> None:
167
168
 
168
169
 
169
170
  __set_module()
170
- __version__ = "0.1.8"
171
+ __version__ = "0.1.9"
@@ -226,7 +226,7 @@ def convert_attributes(
226
226
  ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
227
227
  ... }
228
228
  >>> convert_attributes(attrs)
229
- [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
229
+ [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, (1, 2, 3)), Attr('floats', FLOATS, (1.0, 2.0, 3.0)), Attr('strings', STRINGS, ('hello', 'world')), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
230
230
  name='graph0',
231
231
  inputs=(
232
232
  <BLANKLINE>
@@ -235,7 +235,7 @@ def convert_attributes(
235
235
  <BLANKLINE>
236
236
  ),
237
237
  len()=0
238
- )), Attr('graphs', GRAPHS, [Graph(
238
+ )), Attr('graphs', GRAPHS, (Graph(
239
239
  name='graph1',
240
240
  inputs=(
241
241
  <BLANKLINE>
@@ -253,7 +253,7 @@ def convert_attributes(
253
253
  <BLANKLINE>
254
254
  ),
255
255
  len()=0
256
- )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
256
+ ))), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, (Tensor(FLOAT), Tensor(FLOAT)))]
257
257
 
258
258
  .. important::
259
259
  An empty sequence should be created with an explicit type by initializing
@@ -293,7 +293,7 @@ def replace_all_uses_with(
293
293
  We want to replace the node A with a new node D::
294
294
 
295
295
  >>> import onnx_ir as ir
296
- >>> input = ir.Input("input")
296
+ >>> input = ir.val("input")
297
297
  >>> node_a = ir.Node("", "A", [input])
298
298
  >>> node_b = ir.Node("", "B", node_a.outputs)
299
299
  >>> node_c = ir.Node("", "C", node_a.outputs)
@@ -25,7 +25,7 @@ if typing.TYPE_CHECKING:
25
25
 
26
26
  def tensor(
27
27
  value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
28
- dtype: _enums.DataType | None = None,
28
+ dtype: ir.DataType | None = None,
29
29
  name: str | None = None,
30
30
  doc_string: str | None = None,
31
31
  ) -> _protocols.TensorProtocol:
@@ -159,7 +159,7 @@ def node(
159
159
  doc_string: str | None = None,
160
160
  metadata_props: dict[str, str] | None = None,
161
161
  ) -> ir.Node:
162
- """Create an :class:`~onnx_ir.Node`.
162
+ """Create a :class:`~onnx_ir.Node`.
163
163
 
164
164
  This is a convenience constructor for creating a Node that supports Python
165
165
  objects as attributes.
@@ -167,8 +167,8 @@ def node(
167
167
  Example::
168
168
 
169
169
  >>> import onnx_ir as ir
170
- >>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
171
- >>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
170
+ >>> input_a = ir.val("A", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
171
+ >>> input_b = ir.val("B", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
172
172
  >>> node = ir.node(
173
173
  ... "SomeOp",
174
174
  ... inputs=[input_a, input_b],
@@ -215,3 +215,74 @@ def node(
215
215
  doc_string=doc_string,
216
216
  metadata_props=metadata_props,
217
217
  )
218
+
219
+
220
+ def val(
221
+ name: str | None,
222
+ dtype: ir.DataType | None = None,
223
+ shape: ir.Shape | Sequence[int | str | None] | None = None,
224
+ *,
225
+ type: ir.TypeProtocol | None = None,
226
+ const_value: ir.TensorProtocol | None = None,
227
+ ) -> ir.Value:
228
+ """Create a :class:`~onnx_ir.Value` with the given name and type.
229
+
230
+ This is a convenience constructor for creating a Value that allows you to specify
231
+ dtype and shape in a more relaxed manner. Whereas to create a Value directly, you
232
+ need to create a :class:`~onnx_ir.TypeProtocol` and :class:`~onnx_ir.Shape` object
233
+ first, this function allows you to specify dtype as a :class:`~onnx_ir.DataType`
234
+ and shape as a sequence of integers or symbolic dimensions.
235
+
236
+ Example::
237
+
238
+ >>> import onnx_ir as ir
239
+ >>> t = ir.val("x", ir.DataType.FLOAT, ["N", 42, 3])
240
+ >>> t.name
241
+ 'x'
242
+ >>> t.type
243
+ Tensor(FLOAT)
244
+ >>> t.shape
245
+ Shape([SymbolicDim(N), 42, 3])
246
+
247
+ .. versionadded:: 0.1.9
248
+
249
+ Args:
250
+ name: The name of the value.
251
+ dtype: The data type of the TensorType of the value. This is used only when type is None.
252
+ shape: The shape of the value.
253
+ type: The type of the value. Only one of dtype and type can be specified.
254
+ const_value: The constant tensor that initializes the value. Supply this argument
255
+ when you want to create an initializer. The type and shape can be obtained from the tensor.
256
+
257
+ Returns:
258
+ A Value object.
259
+ """
260
+ if const_value is not None:
261
+ const_tensor_type = _core.TensorType(const_value.dtype)
262
+ if type is not None and type != const_tensor_type:
263
+ raise ValueError(
264
+ f"The type does not match the const_value. type={type} but const_value has type {const_tensor_type}. "
265
+ "You do not have to specify the type when const_value is provided."
266
+ )
267
+ if dtype is not None and dtype != const_value.dtype:
268
+ raise ValueError(
269
+ f"The dtype does not match the const_value. dtype={dtype} but const_value has dtype {const_value.dtype}. "
270
+ "You do not have to specify the dtype when const_value is provided."
271
+ )
272
+ if shape is not None and _core.Shape(shape) != const_value.shape:
273
+ raise ValueError(
274
+ f"The shape does not match the const_value. shape={shape} but const_value has shape {const_value.shape}. "
275
+ "You do not have to specify the shape when const_value is provided."
276
+ )
277
+ return _core.Value(
278
+ name=name,
279
+ type=const_tensor_type,
280
+ shape=_core.Shape(const_value.shape), # type: ignore
281
+ const_value=const_value,
282
+ )
283
+
284
+ if type is None and dtype is not None:
285
+ type = _core.TensorType(dtype)
286
+ if shape is not None and not isinstance(shape, _core.Shape):
287
+ shape = _core.Shape(shape)
288
+ return _core.Value(name=name, type=type, shape=shape)
@@ -45,7 +45,7 @@ from typing import (
45
45
 
46
46
  import ml_dtypes
47
47
  import numpy as np
48
- from typing_extensions import TypeIs
48
+ from typing_extensions import TypeIs, deprecated
49
49
 
50
50
  import onnx_ir
51
51
  from onnx_ir import (
@@ -2275,6 +2275,7 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
2275
2275
  return self._is_initializer
2276
2276
 
2277
2277
 
2278
+ @deprecated("Input is deprecated since 0.1.9. Use ir.val(...) instead.")
2278
2279
  def Input(
2279
2280
  name: str | None = None,
2280
2281
  shape: Shape | None = None,
@@ -3397,6 +3398,31 @@ class Attr(
3397
3398
  *,
3398
3399
  doc_string: str | None = None,
3399
3400
  ) -> None:
3401
+ # Quick checks to ensure that INT and FLOAT attributes are stored as int and float,
3402
+ # not np.int32, np.float32, bool, etc.
3403
+ # This also allows errors to be raised at the time of construction instead of later
3404
+ # during serialization.
3405
+ # TODO(justinchuby): Use case matching when we drop support for Python 3.9
3406
+ if value is None:
3407
+ # Value can be None for reference attributes or when it is used as a
3408
+ # placeholder for schemas
3409
+ pass
3410
+ elif type == _enums.AttributeType.INT:
3411
+ value = int(value)
3412
+ elif type == _enums.AttributeType.FLOAT:
3413
+ value = float(value)
3414
+ elif type == _enums.AttributeType.INTS:
3415
+ value = tuple(int(v) for v in value)
3416
+ elif type == _enums.AttributeType.FLOATS:
3417
+ value = tuple(float(v) for v in value)
3418
+ elif type in {
3419
+ _enums.AttributeType.STRINGS,
3420
+ _enums.AttributeType.TENSORS,
3421
+ _enums.AttributeType.GRAPHS,
3422
+ _enums.AttributeType.TYPE_PROTOS,
3423
+ }:
3424
+ value = tuple(value)
3425
+
3400
3426
  self._name = name
3401
3427
  self._type = type
3402
3428
  self._value = value
@@ -3458,7 +3484,7 @@ class Attr(
3458
3484
  return f"@{self.ref_attr_name}"
3459
3485
  if self.type == _enums.AttributeType.GRAPH:
3460
3486
  return textwrap.indent("\n" + str(self.value), " " * 4)
3461
- return str(self.value)
3487
+ return repr(self.value)
3462
3488
 
3463
3489
  def __repr__(self) -> str:
3464
3490
  if self.is_ref():
@@ -3472,8 +3498,8 @@ class Attr(
3472
3498
  raise TypeError(
3473
3499
  f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
3474
3500
  )
3475
- # Do not use isinstance check because it may prevent np.float32 etc. from being used
3476
- return float(self.value)
3501
+ # value is guaranteed to be a float in the constructor
3502
+ return self.value
3477
3503
 
3478
3504
  def as_int(self) -> int:
3479
3505
  """Get the attribute value as an int."""
@@ -3481,8 +3507,8 @@ class Attr(
3481
3507
  raise TypeError(
3482
3508
  f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
3483
3509
  )
3484
- # Do not use isinstance check because it may prevent np.int32 etc. from being used
3485
- return int(self.value)
3510
+ # value is guaranteed to be an int in the constructor
3511
+ return self.value
3486
3512
 
3487
3513
  def as_string(self) -> str:
3488
3514
  """Get the attribute value as a string."""
@@ -3490,9 +3516,10 @@ class Attr(
3490
3516
  raise TypeError(
3491
3517
  f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
3492
3518
  )
3493
- if not isinstance(self.value, str):
3519
+ value = self.value
3520
+ if not isinstance(value, str):
3494
3521
  raise TypeError(f"Value of attribute '{self!r}' is not a string.")
3495
- return self.value
3522
+ return value
3496
3523
 
3497
3524
  def as_tensor(self) -> _protocols.TensorProtocol:
3498
3525
  """Get the attribute value as a tensor."""
@@ -3500,9 +3527,10 @@ class Attr(
3500
3527
  raise TypeError(
3501
3528
  f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
3502
3529
  )
3503
- if not isinstance(self.value, _protocols.TensorProtocol):
3530
+ value = self.value
3531
+ if not isinstance(value, _protocols.TensorProtocol):
3504
3532
  raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
3505
- return self.value
3533
+ return value
3506
3534
 
3507
3535
  def as_graph(self) -> Graph:
3508
3536
  """Get the attribute value as a graph."""
@@ -3510,75 +3538,64 @@ class Attr(
3510
3538
  raise TypeError(
3511
3539
  f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
3512
3540
  )
3513
- if not isinstance(self.value, Graph):
3541
+ value = self.value
3542
+ if not isinstance(value, Graph):
3514
3543
  raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
3515
- return self.value
3544
+ return value
3516
3545
 
3517
- def as_floats(self) -> Sequence[float]:
3546
+ def as_floats(self) -> tuple[float, ...]:
3518
3547
  """Get the attribute value as a sequence of floats."""
3519
3548
  if self.type != _enums.AttributeType.FLOATS:
3520
3549
  raise TypeError(
3521
3550
  f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
3522
3551
  )
3523
- if not isinstance(self.value, Sequence):
3524
- raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3525
- # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
3526
- # Create a copy of the list to prevent mutation
3527
- return [float(v) for v in self.value]
3552
+ # value is guaranteed to be a sequence of float in the constructor
3553
+ return self.value
3528
3554
 
3529
- def as_ints(self) -> Sequence[int]:
3555
+ def as_ints(self) -> tuple[int, ...]:
3530
3556
  """Get the attribute value as a sequence of ints."""
3531
3557
  if self.type != _enums.AttributeType.INTS:
3532
3558
  raise TypeError(
3533
3559
  f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
3534
3560
  )
3535
- if not isinstance(self.value, Sequence):
3536
- raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3537
- # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
3538
- # Create a copy of the list to prevent mutation
3539
- return list(self.value)
3561
+ # value is guaranteed to be a sequence of int in the constructor
3562
+ return self.value
3540
3563
 
3541
- def as_strings(self) -> Sequence[str]:
3564
+ def as_strings(self) -> tuple[str, ...]:
3542
3565
  """Get the attribute value as a sequence of strings."""
3543
3566
  if self.type != _enums.AttributeType.STRINGS:
3544
3567
  raise TypeError(
3545
3568
  f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
3546
3569
  )
3547
- if not isinstance(self.value, Sequence):
3548
- raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3549
3570
  if onnx_ir.DEBUG:
3550
3571
  if not all(isinstance(x, str) for x in self.value):
3551
3572
  raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
3552
- # Create a copy of the list to prevent mutation
3553
- return list(self.value)
3573
+ # value is guaranteed to be a sequence in the constructor
3574
+ return self.value
3554
3575
 
3555
- def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
3576
+ def as_tensors(self) -> tuple[_protocols.TensorProtocol, ...]:
3556
3577
  """Get the attribute value as a sequence of tensors."""
3557
3578
  if self.type != _enums.AttributeType.TENSORS:
3558
3579
  raise TypeError(
3559
3580
  f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
3560
3581
  )
3561
- if not isinstance(self.value, Sequence):
3562
- raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3563
3582
  if onnx_ir.DEBUG:
3564
3583
  if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
3565
3584
  raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
3566
- # Create a copy of the list to prevent mutation
3567
- return list(self.value)
3585
+ # value is guaranteed to be a sequence in the constructor
3586
+ return tuple(self.value)
3568
3587
 
3569
- def as_graphs(self) -> Sequence[Graph]:
3588
+ def as_graphs(self) -> tuple[Graph, ...]:
3570
3589
  """Get the attribute value as a sequence of graphs."""
3571
3590
  if self.type != _enums.AttributeType.GRAPHS:
3572
3591
  raise TypeError(
3573
3592
  f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
3574
3593
  )
3575
- if not isinstance(self.value, Sequence):
3576
- raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3577
3594
  if onnx_ir.DEBUG:
3578
3595
  if not all(isinstance(x, Graph) for x in self.value):
3579
3596
  raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
3580
- # Create a copy of the list to prevent mutation
3581
- return list(self.value)
3597
+ # value is guaranteed to be a sequence in the constructor
3598
+ return tuple(self.value)
3582
3599
 
3583
3600
 
3584
3601
  # NOTE: The following functions are just for convenience
@@ -3605,7 +3622,7 @@ def RefAttr(
3605
3622
  return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
3606
3623
 
3607
3624
 
3608
- def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
3625
+ def AttrFloat32(name: str, value: float | np.floating, doc_string: str | None = None) -> Attr:
3609
3626
  """Create a float attribute."""
3610
3627
  # NOTE: The function name is capitalized to maintain API backward compatibility.
3611
3628
  return Attr(
@@ -3616,7 +3633,7 @@ def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
3616
3633
  )
3617
3634
 
3618
3635
 
3619
- def AttrInt64(name: str, value: int, doc_string: str | None = None) -> Attr:
3636
+ def AttrInt64(name: str, value: int | np.integer, doc_string: str | None = None) -> Attr:
3620
3637
  """Create an int attribute."""
3621
3638
  # NOTE: The function name is capitalized to maintain API backward compatibility.
3622
3639
  return Attr(
@@ -357,7 +357,10 @@ class DataType(enum.IntEnum):
357
357
  }
358
358
 
359
359
  def is_string(self) -> bool:
360
- """Returns True if the data type is a string type."""
360
+ """Returns True if the data type is a string type.
361
+
362
+ .. versionadded:: 0.1.8
363
+ """
361
364
  return self == DataType.STRING
362
365
 
363
366
  def __repr__(self) -> str:
@@ -419,15 +422,9 @@ _NP_TYPE_TO_DATA_TYPE = {
419
422
  np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
420
423
  np.dtype(ml_dtypes.int4): DataType.INT4,
421
424
  np.dtype(ml_dtypes.uint4): DataType.UINT4,
425
+ np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1,
422
426
  }
423
427
 
424
- # TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
425
- _NP_TYPE_TO_DATA_TYPE.update(
426
- {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
427
- if hasattr(ml_dtypes, "float4_e2m1fn")
428
- else {}
429
- )
430
-
431
428
  # ONNX DataType to Numpy dtype.
432
429
  _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
433
430
 
@@ -711,7 +711,7 @@ def _deserialize_graph(
711
711
 
712
712
  # Create values for initializers and inputs
713
713
  initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
714
- inputs = [_core.Input(info.name) for info in proto.input]
714
+ inputs = [_core.Value(name=info.name) for info in proto.input]
715
715
  for info, value in zip(proto.input, inputs):
716
716
  deserialize_value_info_proto(info, value)
717
717
 
@@ -869,7 +869,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
869
869
  Returns:
870
870
  An IR Function object representing the ONNX function.
871
871
  """
872
- inputs = [_core.Input(name) for name in proto.input]
872
+ inputs = [_core.Value(name=name) for name in proto.input]
873
873
  values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
874
874
  value_info = {info.name: info for info in getattr(proto, "value_info", [])}
875
875
 
@@ -1143,7 +1143,19 @@ def _deserialize_attribute(
1143
1143
  if type_ == _enums.AttributeType.FLOAT:
1144
1144
  return _core.AttrFloat32(name, proto.f, doc_string=doc_string)
1145
1145
  if type_ == _enums.AttributeType.STRING:
1146
- return _core.AttrString(name, proto.s.decode("utf-8"), doc_string=doc_string)
1146
+ try:
1147
+ return _core.AttrString(name, proto.s.decode("utf-8"), doc_string=doc_string)
1148
+ except UnicodeDecodeError:
1149
+ # Even though onnx.ai/onnx/repo-docs/IR.html#attributes requires the attribute
1150
+ # for strings to be utf-8 encoded bytes, custom ops may still store arbitrary data there
1151
+ logger.warning(
1152
+ "Attribute %r contains invalid UTF-8 bytes. ONNX spec requires string attributes "
1153
+ "to be UTF-8 encoded so the model is invalid. We will skip decoding the attribute and "
1154
+ "use the bytes as attribute value",
1155
+ name,
1156
+ )
1157
+ return _core.Attr(name, type_, proto.s, doc_string=doc_string)
1158
+
1147
1159
  if type_ == _enums.AttributeType.INTS:
1148
1160
  return _core.AttrInt64s(name, proto.ints, doc_string=doc_string)
1149
1161
  if type_ == _enums.AttributeType.FLOATS:
@@ -1784,16 +1796,26 @@ def _fill_in_value_for_attribute(
1784
1796
  ) -> None:
1785
1797
  if type_ == _enums.AttributeType.INT:
1786
1798
  # value: int
1787
- # Cast bool to int, for example
1788
- attribute_proto.i = int(value)
1799
+ attribute_proto.i = value
1789
1800
  attribute_proto.type = onnx.AttributeProto.INT
1790
1801
  elif type_ == _enums.AttributeType.FLOAT:
1791
1802
  # value: float
1792
- attribute_proto.f = float(value)
1803
+ attribute_proto.f = value
1793
1804
  attribute_proto.type = onnx.AttributeProto.FLOAT
1794
1805
  elif type_ == _enums.AttributeType.STRING:
1795
1806
  # value: str
1796
- attribute_proto.s = value.encode("utf-8")
1807
+ if type(value) is bytes:
1808
+ # Even though onnx.ai/onnx/repo-docs/IR.html#attributes requires the attribute
1809
+ # for strings to be utf-8 encoded bytes, custom ops may still store arbitrary data there
1810
+ logger.warning(
1811
+ "Value in attribute %r should be a string but is instead bytes. ONNX "
1812
+ "spec requires string attributes to be UTF-8 encoded so the model is invalid. "
1813
+ "We will skip encoding the attribute and use the bytes as attribute value",
1814
+ attribute_proto.name,
1815
+ )
1816
+ attribute_proto.s = value
1817
+ else:
1818
+ attribute_proto.s = value.encode("utf-8")
1797
1819
  attribute_proto.type = onnx.AttributeProto.STRING
1798
1820
  elif type_ == _enums.AttributeType.INTS:
1799
1821
  # value: Sequence[int]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License-Expression: Apache-2.0
@@ -19,7 +19,7 @@ License-File: LICENSE
19
19
  Requires-Dist: numpy
20
20
  Requires-Dist: onnx>=1.16
21
21
  Requires-Dist: typing_extensions>=4.10
22
- Requires-Dist: ml_dtypes
22
+ Requires-Dist: ml_dtypes>=0.5.0
23
23
  Dynamic: license-file
24
24
 
25
25
  # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>
@@ -1,4 +1,4 @@
1
1
  numpy
2
2
  onnx>=1.16
3
3
  typing_extensions>=4.10
4
- ml_dtypes
4
+ ml_dtypes>=0.5.0
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes