onnx-ir 0.1.7-py3-none-any.whl → 0.1.9-py3-none-any.whl

This diff shows the content of the publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between the two published versions.

onnx_ir/__init__.py CHANGED
@@ -78,6 +78,7 @@ __all__ = [
      # Convenience constructors
      "tensor",
      "node",
+     "val",
      # Pass infrastructure
      "passes",
      # IO
@@ -90,7 +91,7 @@ __all__ = [
  import types
 
  from onnx_ir import convenience, external_data, passes, serde, tape, traversal
- from onnx_ir._convenience._constructors import node, tensor
+ from onnx_ir._convenience._constructors import node, tensor, val
  from onnx_ir._core import (
      Attr,
      AttrFloat32,
@@ -167,4 +168,4 @@ def __set_module() -> None:
 
 
  __set_module()
- __version__ = "0.1.7"
+ __version__ = "0.1.9"
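With this release the `val` convenience constructor is re-exported from the package root next to `tensor` and `node`. A minimal sketch of picking it up there (assuming onnx-ir 0.1.9 is installed):

```python
import onnx_ir as ir

print(ir.__version__)  # "0.1.9" for this release

# val is now available at the package root, alongside ir.tensor and ir.node
x = ir.val("x", ir.DataType.FLOAT, [1, 3])
assert isinstance(x, ir.Value)
```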
onnx_ir/_convenience/__init__.py CHANGED
@@ -226,7 +226,7 @@ def convert_attributes(
      ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
      ... }
      >>> convert_attributes(attrs)
-     [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
+     [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, (1, 2, 3)), Attr('floats', FLOATS, (1.0, 2.0, 3.0)), Attr('strings', STRINGS, ('hello', 'world')), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
          name='graph0',
          inputs=(
      <BLANKLINE>
@@ -235,7 +235,7 @@ def convert_attributes(
      <BLANKLINE>
          ),
          len()=0
-     )), Attr('graphs', GRAPHS, [Graph(
+     )), Attr('graphs', GRAPHS, (Graph(
          name='graph1',
          inputs=(
      <BLANKLINE>
@@ -253,7 +253,7 @@ def convert_attributes(
      <BLANKLINE>
          ),
          len()=0
-     )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
+     ))), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, (Tensor(FLOAT), Tensor(FLOAT)))]
 
      .. important::
          An empty sequence should be created with an explicit type by initializing
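The doctest change reflects that sequence-valued attributes (INTS, FLOATS, STRINGS, GRAPHS, TYPE_PROTOS) are now stored as tuples, because values are normalized in the `Attr` constructor (see the `_core.py` hunks below). A short sketch of what a caller sees, assuming `convert_attributes` is reachable through the public `onnx_ir.convenience` module; the attribute names are only illustrative:

```python
import onnx_ir as ir

attrs = ir.convenience.convert_attributes({"kernel_shape": [3, 3], "alpha": 0.5})
kernel_shape, alpha = attrs

print(kernel_shape)            # e.g. Attr('kernel_shape', INTS, (3, 3))
print(kernel_shape.as_ints())  # (3, 3) -- an immutable tuple, returned without copying
print(alpha.as_float())        # 0.5
```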
@@ -293,7 +293,7 @@ def replace_all_uses_with(
      We want to replace the node A with a new node D::
 
          >>> import onnx_ir as ir
-         >>> input = ir.Input("input")
+         >>> input = ir.val("input")
          >>> node_a = ir.Node("", "A", [input])
          >>> node_b = ir.Node("", "B", node_a.outputs)
          >>> node_c = ir.Node("", "C", node_a.outputs)
onnx_ir/_convenience/_constructors.py CHANGED
@@ -25,7 +25,7 @@ if typing.TYPE_CHECKING:
 
  def tensor(
      value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
-     dtype: _enums.DataType | None = None,
+     dtype: ir.DataType | None = None,
      name: str | None = None,
      doc_string: str | None = None,
  ) -> _protocols.TensorProtocol:
@@ -159,7 +159,7 @@ def node(
      doc_string: str | None = None,
      metadata_props: dict[str, str] | None = None,
  ) -> ir.Node:
-     """Create an :class:`~onnx_ir.Node`.
+     """Create a :class:`~onnx_ir.Node`.
 
      This is a convenience constructor for creating a Node that supports Python
      objects as attributes.
@@ -167,8 +167,8 @@ def node(
      Example::
 
          >>> import onnx_ir as ir
-         >>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
-         >>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
+         >>> input_a = ir.val("A", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
+         >>> input_b = ir.val("B", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
          >>> node = ir.node(
          ...     "SomeOp",
          ...     inputs=[input_a, input_b],
@@ -215,3 +215,74 @@ def node(
          doc_string=doc_string,
          metadata_props=metadata_props,
      )
+ 
+ 
+ def val(
+     name: str | None,
+     dtype: ir.DataType | None = None,
+     shape: ir.Shape | Sequence[int | str | None] | None = None,
+     *,
+     type: ir.TypeProtocol | None = None,
+     const_value: ir.TensorProtocol | None = None,
+ ) -> ir.Value:
+     """Create a :class:`~onnx_ir.Value` with the given name and type.
+ 
+     This is a convenience constructor for creating a Value that allows you to specify
+     dtype and shape in a more relaxed manner. Whereas to create a Value directly, you
+     need to create a :class:`~onnx_ir.TypeProtocol` and :class:`~onnx_ir.Shape` object
+     first, this function allows you to specify dtype as a :class:`~onnx_ir.DataType`
+     and shape as a sequence of integers or symbolic dimensions.
+ 
+     Example::
+ 
+         >>> import onnx_ir as ir
+         >>> t = ir.val("x", ir.DataType.FLOAT, ["N", 42, 3])
+         >>> t.name
+         'x'
+         >>> t.type
+         Tensor(FLOAT)
+         >>> t.shape
+         Shape([SymbolicDim(N), 42, 3])
+ 
+     .. versionadded:: 0.1.9
+ 
+     Args:
+         name: The name of the value.
+         dtype: The data type of the TensorType of the value. This is used only when type is None.
+         shape: The shape of the value.
+         type: The type of the value. Only one of dtype and type can be specified.
+         const_value: The constant tensor that initializes the value. Supply this argument
+             when you want to create an initializer. The type and shape can be obtained from the tensor.
+ 
+     Returns:
+         A Value object.
+     """
+     if const_value is not None:
+         const_tensor_type = _core.TensorType(const_value.dtype)
+         if type is not None and type != const_tensor_type:
+             raise ValueError(
+                 f"The type does not match the const_value. type={type} but const_value has type {const_tensor_type}. "
+                 "You do not have to specify the type when const_value is provided."
+             )
+         if dtype is not None and dtype != const_value.dtype:
+             raise ValueError(
+                 f"The dtype does not match the const_value. dtype={dtype} but const_value has dtype {const_value.dtype}. "
+                 "You do not have to specify the dtype when const_value is provided."
+             )
+         if shape is not None and _core.Shape(shape) != const_value.shape:
+             raise ValueError(
+                 f"The shape does not match the const_value. shape={shape} but const_value has shape {const_value.shape}. "
+                 "You do not have to specify the shape when const_value is provided."
+             )
+         return _core.Value(
+             name=name,
+             type=const_tensor_type,
+             shape=_core.Shape(const_value.shape),  # type: ignore
+             const_value=const_value,
+         )
+ 
+     if type is None and dtype is not None:
+         type = _core.TensorType(dtype)
+     if shape is not None and not isinstance(shape, _core.Shape):
+         shape = _core.Shape(shape)
+     return _core.Value(name=name, type=type, shape=shape)
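Beyond the doctest, `val` covers the two common cases: graph inputs with symbolic dimensions, and initializers built from a constant tensor. A short sketch (the tensor contents are made up for illustration):

```python
import numpy as np
import onnx_ir as ir

# A graph input with a symbolic batch dimension
x = ir.val("x", ir.DataType.FLOAT, ["batch", 4])
print(x.shape)  # a Shape with SymbolicDim("batch") and the concrete dim 4

# An initializer: type and shape are taken from const_value, so neither needs to be passed
w = ir.val("w", const_value=ir.tensor(np.ones((4, 2), dtype=np.float32), name="w"))
print(w.type)   # a TensorType with dtype FLOAT
print(w.shape)  # Shape with dims (4, 2)
assert w.const_value is not None
```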
onnx_ir/_core.py CHANGED
@@ -45,7 +45,7 @@ from typing import (
 
  import ml_dtypes
  import numpy as np
- from typing_extensions import TypeIs
+ from typing_extensions import TypeIs, deprecated
 
  import onnx_ir
  from onnx_ir import (
@@ -836,6 +836,11 @@ class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=to
          """The shape of the tensor. Immutable."""
          return self._shape
 
+     @property
+     def nbytes(self) -> int:
+         """The number of bytes in the tensor."""
+         return sum(len(string) for string in self.string_data())
+ 
      @property
      def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
          """Backing data of the tensor. Immutable."""
@@ -2270,6 +2275,7 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
          return self._is_initializer
 
 
+ @deprecated("Input is deprecated since 0.1.9. Use ir.val(...) instead.")
  def Input(
      name: str | None = None,
      shape: Shape | None = None,
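`Input` remains for backward compatibility but is now marked deprecated via `typing_extensions.deprecated`; the message points at `ir.val`. Migration is a one-line change, sketched here:

```python
import onnx_ir as ir

# Before (deprecated since 0.1.9):
# x = ir.Input("x", shape=ir.Shape([1, 3]), type=ir.TensorType(ir.DataType.FLOAT))

# After:
x = ir.val("x", dtype=ir.DataType.FLOAT, shape=[1, 3])
```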
@@ -3392,6 +3398,31 @@ class Attr(
          *,
          doc_string: str | None = None,
      ) -> None:
+         # Quick checks to ensure that INT and FLOAT attributes are stored as int and float,
+         # not np.int32, np.float32, bool, etc.
+         # This also allows errors to be raised at the time of construction instead of later
+         # during serialization.
+         # TODO(justinchuby): Use case matching when we drop support for Python 3.9
+         if value is None:
+             # Value can be None for reference attributes or when it is used as a
+             # placeholder for schemas
+             pass
+         elif type == _enums.AttributeType.INT:
+             value = int(value)
+         elif type == _enums.AttributeType.FLOAT:
+             value = float(value)
+         elif type == _enums.AttributeType.INTS:
+             value = tuple(int(v) for v in value)
+         elif type == _enums.AttributeType.FLOATS:
+             value = tuple(float(v) for v in value)
+         elif type in {
+             _enums.AttributeType.STRINGS,
+             _enums.AttributeType.TENSORS,
+             _enums.AttributeType.GRAPHS,
+             _enums.AttributeType.TYPE_PROTOS,
+         }:
+             value = tuple(value)
+ 
          self._name = name
          self._type = type
          self._value = value
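Normalizing in the constructor means numpy scalars and arrays passed as attribute values are coerced to plain Python `int`/`float` and immutable tuples up front, so problems surface at construction time rather than during serialization. A brief sketch:

```python
import numpy as np
import onnx_ir as ir

axis = ir.Attr("axis", ir.AttributeType.INT, np.int64(1))
axes = ir.Attr("axes", ir.AttributeType.INTS, np.array([0, 2], dtype=np.int64))

print(type(axis.value))  # <class 'int'>
print(axes.as_ints())    # (0, 2) -- stored as a tuple, returned without copying
```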
@@ -3453,7 +3484,7 @@ class Attr(
              return f"@{self.ref_attr_name}"
          if self.type == _enums.AttributeType.GRAPH:
              return textwrap.indent("\n" + str(self.value), " " * 4)
-         return str(self.value)
+         return repr(self.value)
 
      def __repr__(self) -> str:
          if self.is_ref():
@@ -3467,8 +3498,8 @@ class Attr(
              raise TypeError(
                  f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
              )
-         # Do not use isinstance check because it may prevent np.float32 etc. from being used
-         return float(self.value)
+         # value is guaranteed to be a float in the constructor
+         return self.value
 
      def as_int(self) -> int:
          """Get the attribute value as an int."""
@@ -3476,8 +3507,8 @@ class Attr(
              raise TypeError(
                  f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
              )
-         # Do not use isinstance check because it may prevent np.int32 etc. from being used
-         return int(self.value)
+         # value is guaranteed to be an int in the constructor
+         return self.value
 
      def as_string(self) -> str:
          """Get the attribute value as a string."""
@@ -3485,9 +3516,10 @@ class Attr(
              raise TypeError(
                  f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
              )
-         if not isinstance(self.value, str):
+         value = self.value
+         if not isinstance(value, str):
              raise TypeError(f"Value of attribute '{self!r}' is not a string.")
-         return self.value
+         return value
 
      def as_tensor(self) -> _protocols.TensorProtocol:
          """Get the attribute value as a tensor."""
@@ -3495,9 +3527,10 @@ class Attr(
              raise TypeError(
                  f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
              )
-         if not isinstance(self.value, _protocols.TensorProtocol):
+         value = self.value
+         if not isinstance(value, _protocols.TensorProtocol):
              raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
-         return self.value
+         return value
 
      def as_graph(self) -> Graph:
          """Get the attribute value as a graph."""
@@ -3505,75 +3538,64 @@ class Attr(
              raise TypeError(
                  f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
              )
-         if not isinstance(self.value, Graph):
+         value = self.value
+         if not isinstance(value, Graph):
              raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
-         return self.value
+         return value
 
-     def as_floats(self) -> Sequence[float]:
+     def as_floats(self) -> tuple[float, ...]:
          """Get the attribute value as a sequence of floats."""
          if self.type != _enums.AttributeType.FLOATS:
              raise TypeError(
                  f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
              )
-         if not isinstance(self.value, Sequence):
-             raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
-         # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
-         # Create a copy of the list to prevent mutation
-         return [float(v) for v in self.value]
+         # value is guaranteed to be a sequence of float in the constructor
+         return self.value
 
-     def as_ints(self) -> Sequence[int]:
+     def as_ints(self) -> tuple[int, ...]:
          """Get the attribute value as a sequence of ints."""
          if self.type != _enums.AttributeType.INTS:
              raise TypeError(
                  f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
              )
-         if not isinstance(self.value, Sequence):
-             raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
-         # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
-         # Create a copy of the list to prevent mutation
-         return list(self.value)
+         # value is guaranteed to be a sequence of int in the constructor
+         return self.value
 
-     def as_strings(self) -> Sequence[str]:
+     def as_strings(self) -> tuple[str, ...]:
          """Get the attribute value as a sequence of strings."""
          if self.type != _enums.AttributeType.STRINGS:
              raise TypeError(
                  f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
              )
-         if not isinstance(self.value, Sequence):
-             raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
          if onnx_ir.DEBUG:
              if not all(isinstance(x, str) for x in self.value):
                  raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
-         # Create a copy of the list to prevent mutation
-         return list(self.value)
+         # value is guaranteed to be a sequence in the constructor
+         return self.value
 
-     def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
+     def as_tensors(self) -> tuple[_protocols.TensorProtocol, ...]:
          """Get the attribute value as a sequence of tensors."""
          if self.type != _enums.AttributeType.TENSORS:
              raise TypeError(
                  f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
              )
-         if not isinstance(self.value, Sequence):
-             raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
          if onnx_ir.DEBUG:
              if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
                  raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
-         # Create a copy of the list to prevent mutation
-         return list(self.value)
+         # value is guaranteed to be a sequence in the constructor
+         return tuple(self.value)
 
-     def as_graphs(self) -> Sequence[Graph]:
+     def as_graphs(self) -> tuple[Graph, ...]:
          """Get the attribute value as a sequence of graphs."""
          if self.type != _enums.AttributeType.GRAPHS:
              raise TypeError(
                  f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
              )
-         if not isinstance(self.value, Sequence):
-             raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
          if onnx_ir.DEBUG:
              if not all(isinstance(x, Graph) for x in self.value):
                  raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
-         # Create a copy of the list to prevent mutation
-         return list(self.value)
+         # value is guaranteed to be a sequence in the constructor
+         return tuple(self.value)
 
 
  # NOTE: The following functions are just for convenience
@@ -3600,7 +3622,7 @@ def RefAttr(
      return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
 
 
- def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
+ def AttrFloat32(name: str, value: float | np.floating, doc_string: str | None = None) -> Attr:
      """Create a float attribute."""
      # NOTE: The function name is capitalized to maintain API backward compatibility.
      return Attr(
@@ -3611,7 +3633,7 @@ def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
      )
 
 
- def AttrInt64(name: str, value: int, doc_string: str | None = None) -> Attr:
+ def AttrInt64(name: str, value: int | np.integer, doc_string: str | None = None) -> Attr:
      """Create an int attribute."""
      # NOTE: The function name is capitalized to maintain API backward compatibility.
      return Attr(
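The widened signatures only document what the constructor already tolerates: numpy scalars are accepted and converted to plain Python numbers by the `Attr` normalization shown above. A quick sketch:

```python
import numpy as np
import onnx_ir as ir

a = ir.AttrInt64("axis", np.int64(-1))
b = ir.AttrFloat32("epsilon", np.float32(1e-5))

assert isinstance(a.value, int) and a.value == -1
assert isinstance(b.value, float)
```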
onnx_ir/_enums.py CHANGED
@@ -5,6 +5,7 @@
  from __future__ import annotations
 
  import enum
+ from typing import Any
 
  import ml_dtypes
  import numpy as np
@@ -77,7 +78,7 @@ class DataType(enum.IntEnum):
          if dtype in _NP_TYPE_TO_DATA_TYPE:
              return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
 
-         if np.issubdtype(dtype, np.str_):
+         if np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_):
              return DataType.STRING
 
          # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
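Numpy byte-string dtypes (`np.bytes_`, i.e. `S...` dtypes) now map to `DataType.STRING` just like unicode string dtypes. A sketch, assuming the lookup above is reached through the public `DataType.from_numpy` classmethod:

```python
import numpy as np
import onnx_ir as ir

assert ir.DataType.from_numpy(np.dtype(np.str_)) == ir.DataType.STRING
assert ir.DataType.from_numpy(np.dtype(np.bytes_)) == ir.DataType.STRING  # new in this range
```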
@@ -131,6 +132,146 @@ class DataType(enum.IntEnum):
              raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
          return _BITWIDTH_MAP[self]
 
+     @property
+     def exponent_bitwidth(self) -> int:
+         """Returns the bit width of the exponent for floating-point types.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not supported.
+         """
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).nexp
+ 
+         raise TypeError(f"Exponent not available for ONNX data type: {self}")
+ 
+     @property
+     def mantissa_bitwidth(self) -> int:
+         """Returns the bit width of the mantissa for floating-point types.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not supported.
+         """
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).nmant
+ 
+         raise TypeError(f"Mantissa not available for ONNX data type: {self}")
+ 
+     @property
+     def eps(self) -> int | np.floating[Any]:
+         """Returns the difference between 1.0 and the next smallest representable float larger than 1.0 for the ONNX data type.
+ 
+         Returns 1 for integers.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 1
+ 
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).eps
+ 
+         raise TypeError(f"Eps not available for ONNX data type: {self}")
+ 
+     @property
+     def tiny(self) -> int | np.floating[Any]:
+         """Returns the smallest positive non-zero value for the ONNX data type.
+ 
+         Returns 1 for integers.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 1
+ 
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).tiny
+ 
+         raise TypeError(f"Tiny not available for ONNX data type: {self}")
+ 
+     @property
+     def min(self) -> int | np.floating[Any]:
+         """Returns the minimum representable value for the ONNX data type.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return ml_dtypes.iinfo(self.numpy()).min
+ 
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).min
+ 
+         raise TypeError(f"Minimum not available for ONNX data type: {self}")
+ 
+     @property
+     def max(self) -> int | np.floating[Any]:
+         """Returns the maximum representable value for the ONNX data type.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return ml_dtypes.iinfo(self.numpy()).max
+ 
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).max
+ 
+         raise TypeError(f"Maximum not available for ONNX data type: {self}")
+ 
+     @property
+     def precision(self) -> int:
+         """Returns the precision for the ONNX dtype if supported.
+ 
+         For floats returns the approximate number of decimal digits to which
+         this kind of float is precise. Returns 0 for integers.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 0
+ 
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).precision
+ 
+         raise TypeError(f"Precision not available for ONNX data type: {self}")
+ 
+     @property
+     def resolution(self) -> int | np.floating[Any]:
+         """Returns the resolution for the ONNX dtype if supported.
+ 
+         Returns the approximate decimal resolution of this type, i.e.,
+         10**-precision. Returns 1 for integers.
+ 
+         .. versionadded:: 0.1.8
+ 
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 1
+ 
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).resolution
+ 
+         raise TypeError(f"Resolution not available for ONNX data type: {self}")
+ 
      def numpy(self) -> np.dtype:
          """Returns the numpy dtype for the ONNX data type.
 
@@ -215,6 +356,13 @@ class DataType(enum.IntEnum):
          DataType.FLOAT8E8M0,
      }
 
+     def is_string(self) -> bool:
+         """Returns True if the data type is a string type.
+ 
+         .. versionadded:: 0.1.8
+         """
+         return self == DataType.STRING
+ 
      def __repr__(self) -> str:
          return self.name
 
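These properties surface `ml_dtypes.finfo`/`iinfo` metadata directly on the enum, which is convenient when writing quantization or casting passes. A short sketch; the printed values are whatever `ml_dtypes` reports for the dtype:

```python
import onnx_ir as ir

fp16 = ir.DataType.FLOAT16
print(fp16.exponent_bitwidth, fp16.mantissa_bitwidth)  # 5 10
print(fp16.max, fp16.eps)                              # 65504.0 0.000977

print(ir.DataType.INT8.min, ir.DataType.INT8.max)      # -128 127
print(ir.DataType.STRING.is_string())                  # True
```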
@@ -274,15 +422,9 @@ _NP_TYPE_TO_DATA_TYPE = {
      np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
      np.dtype(ml_dtypes.int4): DataType.INT4,
      np.dtype(ml_dtypes.uint4): DataType.UINT4,
+     np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1,
  }
 
- # TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
- _NP_TYPE_TO_DATA_TYPE.update(
-     {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
-     if hasattr(ml_dtypes, "float4_e2m1fn")
-     else {}
- )
- 
  # ONNX DataType to Numpy dtype.
  _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
 
onnx_ir/passes/common/initializer_deduplication.py CHANGED
@@ -10,6 +10,8 @@ __all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
  import hashlib
  import logging
 
+ import numpy as np
+ 
  import onnx_ir as ir
 
  logger = logging.getLogger(__name__)
@@ -42,17 +44,27 @@ def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
              size_limit,
          )
          return True
- 
-     if const_val.dtype == ir.DataType.STRING:
-         # Skip string initializers as they don't have a bytes representation
-         logger.warning(
-             "Skipped deduplication of string initializer '%s' (unsupported yet)",
-             initializer.name,
-         )
-         return True
      return False
 
 
+ def _tobytes(val):
+     """StringTensor does not support tobytes. Use 'string_data' instead.
+ 
+     However, 'string_data' yields a list of bytes which cannot be hashed, i.e.,
+     cannot be used to index into a dict. To generate keys for identifying
+     tensors in initializer deduplication the following converts the list of
+     bytes to an array of fixed-length strings which can be flattened into a
+     bytes-string. This, together with the tensor shape, is sufficient for
+     identifying tensors for deduplication, but it differs from the
+     representation used for serializing tensors (that is string_data) by adding
+     padding bytes so that each string occupies the same number of consecutive
+     bytes in the flattened .tobytes representation.
+     """
+     if val.dtype.is_string():
+         return np.array(val.string_data()).tobytes()
+     return val.tobytes()
+ 
+ 
  class DeduplicateInitializersPass(ir.passes.InPlacePass):
      """Remove duplicated initializer tensors from the main graph and all subgraphs.
 
@@ -84,7 +96,7 @@ class DeduplicateInitializersPass(ir.passes.InPlacePass):
              const_val = initializer.const_value
              assert const_val is not None
 
-             key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+             key = (const_val.dtype, tuple(const_val.shape), _tobytes(const_val))
              if key in initializers:
                  modified = True
                  initializer_to_keep = initializers[key]  # type: ignore[index]
@@ -143,7 +155,7 @@ class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
              key = (const_val.dtype, tensor_dims, tensor_digest)
 
              if key in initializers:
-                 if initializers[key].const_value.tobytes() != const_val.tobytes():
+                 if _tobytes(initializers[key].const_value) != _tobytes(const_val):
                      logger.warning(
                          "Initializer deduplication failed: "
                          "hashes match but values differ with values %s and %s",
onnx_ir/serde.py CHANGED
@@ -711,7 +711,7 @@ def _deserialize_graph(
 
      # Create values for initializers and inputs
      initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
-     inputs = [_core.Input(info.name) for info in proto.input]
+     inputs = [_core.Value(name=info.name) for info in proto.input]
      for info, value in zip(proto.input, inputs):
          deserialize_value_info_proto(info, value)
 
@@ -869,7 +869,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
      Returns:
          An IR Function object representing the ONNX function.
      """
-     inputs = [_core.Input(name) for name in proto.input]
+     inputs = [_core.Value(name=name) for name in proto.input]
      values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
      value_info = {info.name: info for info in getattr(proto, "value_info", [])}
 
@@ -1143,7 +1143,19 @@ def _deserialize_attribute(
      if type_ == _enums.AttributeType.FLOAT:
          return _core.AttrFloat32(name, proto.f, doc_string=doc_string)
      if type_ == _enums.AttributeType.STRING:
-         return _core.AttrString(name, proto.s.decode("utf-8"), doc_string=doc_string)
+         try:
+             return _core.AttrString(name, proto.s.decode("utf-8"), doc_string=doc_string)
+         except UnicodeDecodeError:
+             # Even though onnx.ai/onnx/repo-docs/IR.html#attributes requires the attribute
+             # for strings to be utf-8 encoded bytes, custom ops may still store arbitrary data there
+             logger.warning(
+                 "Attribute %r contains invalid UTF-8 bytes. ONNX spec requires string attributes "
+                 "to be UTF-8 encoded so the model is invalid. We will skip decoding the attribute and "
+                 "use the bytes as attribute value",
+                 name,
+             )
+             return _core.Attr(name, type_, proto.s, doc_string=doc_string)
+ 
      if type_ == _enums.AttributeType.INTS:
          return _core.AttrInt64s(name, proto.ints, doc_string=doc_string)
      if type_ == _enums.AttributeType.FLOATS:
@@ -1792,7 +1804,18 @@ def _fill_in_value_for_attribute(
          attribute_proto.type = onnx.AttributeProto.FLOAT
      elif type_ == _enums.AttributeType.STRING:
          # value: str
-         attribute_proto.s = value.encode("utf-8")
+         if type(value) is bytes:
+             # Even though onnx.ai/onnx/repo-docs/IR.html#attributes requires the attribute
+             # for strings to be utf-8 encoded bytes, custom ops may still store arbitrary data there
+             logger.warning(
+                 "Value in attribute %r should be a string but is instead bytes. ONNX "
+                 "spec requires string attributes to be UTF-8 encoded so the model is invalid. "
+                 "We will skip encoding the attribute and use the bytes as attribute value",
+                 attribute_proto.name,
+             )
+             attribute_proto.s = value
+         else:
+             attribute_proto.s = value.encode("utf-8")
          attribute_proto.type = onnx.AttributeProto.STRING
      elif type_ == _enums.AttributeType.INTS:
          # value: Sequence[int]
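Together the two serde hunks let a model whose string attribute carries non-UTF-8 bytes survive a deserialize/serialize round trip instead of raising: the raw bytes are kept on the `Attr` and written back to `AttributeProto.s` verbatim, with a warning in both directions. In IR form such an attribute simply holds `bytes` as its value; a small sketch:

```python
import onnx_ir as ir

# A STRING attribute whose payload is not valid UTF-8 (e.g. produced by a custom op)
raw = b"\xff\xfe\x00binary"
attr = ir.Attr("blob", ir.AttributeType.STRING, raw)

print(attr.value)  # b'\xff\xfe\x00binary' -- kept as bytes; as_string() still requires str
```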
onnx_ir-0.1.7.dist-info/METADATA → onnx_ir-0.1.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: onnx-ir
- Version: 0.1.7
+ Version: 0.1.9
  Summary: Efficient in-memory representation for ONNX
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
  License-Expression: Apache-2.0
@@ -19,10 +19,10 @@ License-File: LICENSE
  Requires-Dist: numpy
  Requires-Dist: onnx>=1.16
  Requires-Dist: typing_extensions>=4.10
- Requires-Dist: ml_dtypes
+ Requires-Dist: ml_dtypes>=0.5.0
  Dynamic: license-file
 
- # ONNX IR
+ # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>
 
  [![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
@@ -60,6 +60,10 @@ pip install git+https://github.com/onnx/ir-py.git
  - Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way.
  - No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.
 
+ ## Concept Diagram
+ 
+ ![Concept Diagram](docs/resource/onnx-ir-entities.svg)
+ 
  ## Code Organization 🗺️
 
  - [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR.
onnx_ir-0.1.7.dist-info/RECORD → onnx_ir-0.1.9.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
- onnx_ir/__init__.py,sha256=GkXeM2FSKjT0TUO8ezCJdT1yHZKdtQ6keZKx2a3BluI,3424
- onnx_ir/_core.py,sha256=XQRd43VQj72qBGLa_4x9NEjjfhM0rxJ7qT6sLKA_rGA,139032
+ onnx_ir/__init__.py,sha256=GONmwgFPw_4lRywnqZUQz_oOG8p-JP-PwaUAiYKls8Q,3440
+ onnx_ir/_core.py,sha256=ALDyEiVvZP6bsAmnBSYKPgCeKBHqcYVv5_wAHwRhf20,139578
  onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
- onnx_ir/_enums.py,sha256=SxC-GGgPrmdz6UsMhx7xT9-6VmkZ6j1oVzDqNUHr3Rc,9659
+ onnx_ir/_enums.py,sha256=E7WQ7yQzulBeimamc9q_k4fEUoyH_2PWtaOMpwck_W0,13915
  onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
  onnx_ir/_graph_containers.py,sha256=PRKrshRZ5rzWCgRs1TefzJq9n8wyo7OqeKy3XxMhyys,14265
  onnx_ir/_io.py,sha256=GWwA4XOZ-ZX1cgibgaYD0K0O5d9LX21ZwcBN02Wrh04,5205
@@ -16,13 +16,13 @@ onnx_ir/_version_utils.py,sha256=bZThuE7meVHFOY1DLsmss9WshVIp9iig7udGfDbVaK4,133
  onnx_ir/convenience.py,sha256=0B1epuXZCSmY4FbW2vaYfR-t5ubxBZ1UruiytHs-zFw,917
  onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,18079
  onnx_ir/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
- onnx_ir/serde.py,sha256=Ld00k4L_TJ50T8FA0myV0C1hLr7EqwujZk6bBr_nGLQ,78174
+ onnx_ir/serde.py,sha256=2us-I2h3_BXxM9aYyXFOdq5v_cLvLIwyHGpIZbhL2W0,79459
  onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
  onnx_ir/tensor_adapters.py,sha256=YffUeZDZi8thxm-4nF2cL6cNSJSVmLm4A3IbEzwY8QQ,7233
  onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
  onnx_ir/traversal.py,sha256=Wy4XphwuapAvm94-5iaz6G8LjIoMFpY7qfPfXzYViEE,4488
- onnx_ir/_convenience/__init__.py,sha256=bXUxjZ_91idQJ33zWtByQ0J4VsWCUdvAy9iIflpLtW8,19754
- onnx_ir/_convenience/_constructors.py,sha256=5GhlYy_xCE2ng7l_4cNx06WQsNDyvS-0U1HgOpPKJEk,8347
+ onnx_ir/_convenience/__init__.py,sha256=SO7kc8RXVKEUODGh0q2Y7WgmbUsOjYSixmKFx_A0DAQ,19752
+ onnx_ir/_convenience/_constructors.py,sha256=ETYrhJ5eg4ozf4K9C-5mT1vw1lxrdLCcWE4CJEGbl-k,11304
  onnx_ir/_thirdparty/asciichartpy.py,sha256=afQ0fsqko2uYRPAR4TZBrQxvCb4eN8lxZ2yDFbVQq_s,10533
  onnx_ir/passes/__init__.py,sha256=M_Tcl_-qGSNPluFIvOoeDyh0qAwNayaYyXDS5UJUJPQ,764
  onnx_ir/passes/_pass_infra.py,sha256=xIOw_zZIuOqD4Z_wZ4OvsqXfh2IZMoMlDp1xQ_MPQlc,9567
@@ -32,15 +32,15 @@ onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMO
  onnx_ir/passes/common/common_subexpression_elimination.py,sha256=wZ1zEPdCshYB_ifP9fCAVfzQkesE6uhCfzCuL2qO5fA,7948
  onnx_ir/passes/common/constant_manipulation.py,sha256=dFzzqbpRecJJrYf6edvR_sdr4F0gV-1wEtDXsQ7fStM,9101
  onnx_ir/passes/common/identity_elimination.py,sha256=wN8g8uPGn6IIQ6Jf1lo6nGTXvpWyiSQtT_CfmtvZpwA,3664
- onnx_ir/passes/common/initializer_deduplication.py,sha256=k6IZdXrjANbVhTQCQAPIePUjqF83NG3YGwEYThYJJ7o,6655
+ onnx_ir/passes/common/initializer_deduplication.py,sha256=gKrXTMFAtCkMmiIm8zWzwPnwSbRdZxunJeAt_jFU-vY,7253
  onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
  onnx_ir/passes/common/naming.py,sha256=NNKc9IPrmzm3J0zGQILfooayVzfdXDYHY9DHex1hFgs,10927
  onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6svE0cCyDew,1691
  onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
  onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
  onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
- onnx_ir-0.1.7.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- onnx_ir-0.1.7.dist-info/METADATA,sha256=M4-BdpNXpv18P_tALf6KdUdXeCO2JrVxbxtzs4HCmJI,3462
- onnx_ir-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- onnx_ir-0.1.7.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
- onnx_ir-0.1.7.dist-info/RECORD,,
+ onnx_ir-0.1.9.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ onnx_ir-0.1.9.dist-info/METADATA,sha256=UnIaOip9p965JE-B8Kb3cuUrDTuoGoQurFmLAgsWdAA,3604
+ onnx_ir-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ onnx_ir-0.1.9.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
+ onnx_ir-0.1.9.dist-info/RECORD,,