onnx-ir 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

onnx_ir/__init__.py CHANGED
@@ -167,4 +167,4 @@ def __set_module() -> None:
167
167
 
168
168
 
169
169
  __set_module()
170
- __version__ = "0.1.3"
170
+ __version__ = "0.1.5"
onnx_ir/_core.py CHANGED
@@ -78,6 +78,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
78
78
  _enums.DataType.FLOAT8E4M3FNUZ,
79
79
  _enums.DataType.FLOAT8E5M2,
80
80
  _enums.DataType.FLOAT8E5M2FNUZ,
81
+ _enums.DataType.FLOAT8E8M0,
81
82
  _enums.DataType.INT4,
82
83
  _enums.DataType.UINT4,
83
84
  _enums.DataType.FLOAT4E2M1,
@@ -261,6 +262,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
261
262
  ml_dtypes.float8_e4m3fn,
262
263
  ml_dtypes.float8_e5m2fnuz,
263
264
  ml_dtypes.float8_e5m2,
265
+ ml_dtypes.float8_e8m0fnu,
264
266
  ):
265
267
  raise TypeError(
266
268
  f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
@@ -319,6 +321,8 @@ def _maybe_view_np_array_with_ml_dtypes(
319
321
  return array.view(ml_dtypes.float8_e5m2)
320
322
  if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
321
323
  return array.view(ml_dtypes.float8_e5m2fnuz)
324
+ if dtype == _enums.DataType.FLOAT8E8M0:
325
+ return array.view(ml_dtypes.float8_e8m0fnu)
322
326
  if dtype == _enums.DataType.INT4:
323
327
  return array.view(ml_dtypes.int4)
324
328
  if dtype == _enums.DataType.UINT4:
@@ -657,15 +661,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
657
661
  self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
658
662
  return
659
663
  # Map the whole file into the memory
660
- # TODO(justinchuby): Verify if this would exhaust the memory address space
661
664
  with open(self.path, "rb") as f:
662
665
  self.raw = mmap.mmap(
663
666
  f.fileno(),
664
667
  0,
665
668
  access=mmap.ACCESS_READ,
666
669
  )
667
- # Handle the byte order correctly by always using little endian
668
- dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
670
+
669
671
  if self.dtype in {
670
672
  _enums.DataType.INT4,
671
673
  _enums.DataType.UINT4,
@@ -675,16 +677,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
675
677
  dt = np.dtype(np.uint8).newbyteorder("<")
676
678
  count = self.size // 2 + self.size % 2
677
679
  else:
680
+ # Handle the byte order correctly by always using little endian
681
+ dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
678
682
  count = self.size
683
+
679
684
  self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
680
685
  shape = self.shape.numpy()
681
- if self.dtype == _enums.DataType.INT4:
682
- # Unpack the int4 arrays
683
- self._array = _type_casting.unpack_int4(self._array, shape)
684
- elif self.dtype == _enums.DataType.UINT4:
685
- self._array = _type_casting.unpack_uint4(self._array, shape)
686
- elif self.dtype == _enums.DataType.FLOAT4E2M1:
687
- self._array = _type_casting.unpack_float4e2m1(self._array, shape)
686
+
687
+ if self.dtype.bitwidth == 4:
688
+ # Unpack the 4bit arrays
689
+ self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
690
+ self.dtype.numpy()
691
+ )
688
692
  else:
689
693
  self._array = self._array.reshape(shape)
690
694
 
@@ -1071,15 +1075,7 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
1071
1075
  """
1072
1076
  array = self.numpy_packed()
1073
1077
  # ONNX IR returns the unpacked arrays
1074
- if self.dtype == _enums.DataType.INT4:
1075
- return _type_casting.unpack_int4(array, self.shape.numpy())
1076
- if self.dtype == _enums.DataType.UINT4:
1077
- return _type_casting.unpack_uint4(array, self.shape.numpy())
1078
- if self.dtype == _enums.DataType.FLOAT4E2M1:
1079
- return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
1080
- raise TypeError(
1081
- f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
1082
- )
1078
+ return _type_casting.unpack_4bitx2(array, self.shape.numpy()).view(self.dtype.numpy())
1083
1079
 
1084
1080
  def numpy_packed(self) -> npt.NDArray[np.uint8]:
1085
1081
  """Return the tensor as a packed array."""
onnx_ir/_enums.py CHANGED
@@ -65,6 +65,7 @@ class DataType(enum.IntEnum):
65
65
  UINT4 = 21
66
66
  INT4 = 22
67
67
  FLOAT4E2M1 = 23
68
+ FLOAT8E8M0 = 24
68
69
 
69
70
  @classmethod
70
71
  def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -81,6 +82,7 @@ class DataType(enum.IntEnum):
81
82
 
82
83
  # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
83
84
  # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
85
+ # TODO(#137): Remove this when ONNX 1.19 is the minimum requirement
84
86
  if hasattr(dtype, "names"):
85
87
  if dtype.names == ("bfloat16",):
86
88
  return DataType.BFLOAT16
@@ -167,6 +169,50 @@ class DataType(enum.IntEnum):
167
169
  DataType.FLOAT8E5M2,
168
170
  DataType.FLOAT8E5M2FNUZ,
169
171
  DataType.FLOAT4E2M1,
172
+ DataType.FLOAT8E8M0,
173
+ }
174
+
175
+ def is_integer(self) -> bool:
176
+ """Returns True if the data type is an integer.
177
+
178
+ .. versionadded:: 0.1.4
179
+ """
180
+ return self in {
181
+ DataType.UINT8,
182
+ DataType.INT8,
183
+ DataType.UINT16,
184
+ DataType.INT16,
185
+ DataType.INT32,
186
+ DataType.INT64,
187
+ DataType.UINT32,
188
+ DataType.UINT64,
189
+ DataType.UINT4,
190
+ DataType.INT4,
191
+ }
192
+
193
+ def is_signed(self) -> bool:
194
+ """Returns True if the data type is a signed type.
195
+
196
+ .. versionadded:: 0.1.4
197
+ """
198
+ return self in {
199
+ DataType.FLOAT,
200
+ DataType.INT8,
201
+ DataType.INT16,
202
+ DataType.INT32,
203
+ DataType.INT64,
204
+ DataType.FLOAT16,
205
+ DataType.DOUBLE,
206
+ DataType.COMPLEX64,
207
+ DataType.COMPLEX128,
208
+ DataType.BFLOAT16,
209
+ DataType.FLOAT8E4M3FN,
210
+ DataType.FLOAT8E4M3FNUZ,
211
+ DataType.FLOAT8E5M2,
212
+ DataType.FLOAT8E5M2FNUZ,
213
+ DataType.INT4,
214
+ DataType.FLOAT4E2M1,
215
+ DataType.FLOAT8E8M0,
170
216
  }
171
217
 
172
218
  def __repr__(self) -> str:
@@ -199,6 +245,7 @@ _BITWIDTH_MAP = {
199
245
  DataType.UINT4: 4,
200
246
  DataType.INT4: 4,
201
247
  DataType.FLOAT4E2M1: 4,
248
+ DataType.FLOAT8E8M0: 8,
202
249
  }
203
250
 
204
251
 
@@ -224,6 +271,7 @@ _NP_TYPE_TO_DATA_TYPE = {
224
271
  np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
225
272
  np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
226
273
  np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
274
+ np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
227
275
  np.dtype(ml_dtypes.int4): DataType.INT4,
228
276
  np.dtype(ml_dtypes.uint4): DataType.UINT4,
229
277
  }
@@ -248,6 +296,7 @@ _DATA_TYPE_TO_SHORT_NAME = {
248
296
  DataType.FLOAT8E5M2: "f8e5m2",
249
297
  DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
250
298
  DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
299
+ DataType.FLOAT8E8M0: "f8e8m0",
251
300
  DataType.FLOAT4E2M1: "f4e2m1",
252
301
  DataType.COMPLEX64: "c64",
253
302
  DataType.COMPLEX128: "c128",
onnx_ir/_type_casting.py CHANGED
@@ -1,14 +1,12 @@
1
1
  # Copyright (c) ONNX Project Contributors
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """Numpy utilities for non-native type operation."""
4
- # TODO(justinchuby): Upstream the logic to onnx
5
4
 
6
5
  from __future__ import annotations
7
6
 
8
7
  import typing
9
8
  from collections.abc import Sequence
10
9
 
11
- import ml_dtypes
12
10
  import numpy as np
13
11
 
14
12
  if typing.TYPE_CHECKING:
@@ -28,9 +26,7 @@ def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
28
26
  return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
29
27
 
30
28
 
31
- def _unpack_uint4_as_uint8(
32
- data: npt.NDArray[np.uint8], dims: Sequence[int]
33
- ) -> npt.NDArray[np.uint8]:
29
+ def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
34
30
  """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
35
31
 
36
32
  Args:
@@ -52,56 +48,3 @@ def _unpack_uint4_as_uint8(
52
48
  result = result[:-1]
53
49
  result.resize(dims, refcheck=False)
54
50
  return result
55
-
56
-
57
- def unpack_uint4(
58
- data: npt.NDArray[np.uint8], dims: Sequence[int]
59
- ) -> npt.NDArray[ml_dtypes.uint4]:
60
- """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
61
-
62
- Args:
63
- data: A numpy array.
64
- dims: The dimensions are used to reshape the unpacked buffer.
65
-
66
- Returns:
67
- A numpy array of int8/uint8 reshaped to dims.
68
- """
69
- return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.uint4)
70
-
71
-
72
- def _extend_int4_sign_bits(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]:
73
- """Extend 4-bit signed integer to 8-bit signed integer."""
74
- return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8)
75
-
76
-
77
- def unpack_int4(
78
- data: npt.NDArray[np.uint8], dims: Sequence[int]
79
- ) -> npt.NDArray[ml_dtypes.int4]:
80
- """Convert a packed (signed) int4 array to unpacked int4 array represented as int8.
81
-
82
- The sign bit is extended to the most significant bit of the int8.
83
-
84
- Args:
85
- data: A numpy array.
86
- dims: The dimensions are used to reshape the unpacked buffer.
87
-
88
- Returns:
89
- A numpy array of int8 reshaped to dims.
90
- """
91
- unpacked = _unpack_uint4_as_uint8(data, dims)
92
- return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
93
-
94
-
95
- def unpack_float4e2m1(
96
- data: npt.NDArray[np.uint8], dims: Sequence[int]
97
- ) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
98
- """Convert a packed float4e2m1 array to unpacked float4e2m1 array.
99
-
100
- Args:
101
- data: A numpy array.
102
- dims: The dimensions are used to reshape the unpacked buffer.
103
-
104
- Returns:
105
- A numpy array of float32 reshaped to dims.
106
- """
107
- return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
@@ -7,6 +7,7 @@ __all__ = [
7
7
  "ClearMetadataAndDocStringPass",
8
8
  "CommonSubexpressionEliminationPass",
9
9
  "DeduplicateInitializersPass",
10
+ "IdentityEliminationPass",
10
11
  "InlinePass",
11
12
  "LiftConstantsToInitializersPass",
12
13
  "LiftSubgraphInitializersToMainGraphPass",
@@ -30,6 +31,9 @@ from onnx_ir.passes.common.constant_manipulation import (
30
31
  LiftSubgraphInitializersToMainGraphPass,
31
32
  RemoveInitializersFromInputsPass,
32
33
  )
34
+ from onnx_ir.passes.common.identity_elimination import (
35
+ IdentityEliminationPass,
36
+ )
33
37
  from onnx_ir.passes.common.initializer_deduplication import (
34
38
  DeduplicateInitializersPass,
35
39
  )
@@ -0,0 +1,97 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Identity elimination pass for removing redundant Identity nodes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "IdentityEliminationPass",
9
+ ]
10
+
11
+ import logging
12
+
13
+ import onnx_ir as ir
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class IdentityEliminationPass(ir.passes.InPlacePass):
19
+ """Pass for eliminating redundant Identity nodes.
20
+
21
+ This pass removes Identity nodes according to the following rules:
22
+ 1. For any node of the form `y = Identity(x)`, where `y` is not an output
23
+ of any graph, replace all uses of `y` with a use of `x`, and remove the node.
24
+ 2. If `y` is an output of a graph, and `x` is not an input of any graph,
25
+ we can still do the elimination, but the value `x` should be renamed to be `y`.
26
+ 3. If `y` is a graph-output and `x` is a graph-input, we cannot eliminate
27
+ the node. It should be retained.
28
+ """
29
+
30
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
31
+ """Main entry point for the identity elimination pass."""
32
+ modified = False
33
+
34
+ # Use RecursiveGraphIterator to process all nodes in the model graph and subgraphs
35
+ for node in ir.traversal.RecursiveGraphIterator(model.graph):
36
+ if self._try_eliminate_identity_node(node):
37
+ modified = True
38
+
39
+ # Process nodes in functions
40
+ for function in model.functions.values():
41
+ for node in ir.traversal.RecursiveGraphIterator(function):
42
+ if self._try_eliminate_identity_node(node):
43
+ modified = True
44
+
45
+ if modified:
46
+ logger.info("Identity elimination pass modified the model")
47
+
48
+ return ir.passes.PassResult(model, modified=modified)
49
+
50
+ def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
51
+ """Try to eliminate a single identity node. Returns True if modified."""
52
+ if node.op_type != "Identity" or node.domain != "":
53
+ return False
54
+
55
+ if len(node.inputs) != 1 or len(node.outputs) != 1:
56
+ # Invalid Identity node, skip
57
+ return False
58
+
59
+ input_value = node.inputs[0]
60
+ output_value = node.outputs[0]
61
+
62
+ if input_value is None:
63
+ # Cannot eliminate if input is None
64
+ return False
65
+
66
+ # Get the graph that contains this node
67
+ graph_like = node.graph
68
+ assert graph_like is not None, "Node must be in a graph"
69
+
70
+ output_is_graph_output = output_value.is_graph_output()
71
+ input_is_graph_input = input_value.is_graph_input()
72
+
73
+ # Case 3: Both output is graph output and input is graph input - keep the node
74
+ if output_is_graph_output and input_is_graph_input:
75
+ return False
76
+
77
+ # Case 1 & 2 (merged): Eliminate the identity node
78
+ # Replace all uses of output with input
79
+ ir.convenience.replace_all_uses_with(output_value, input_value)
80
+
81
+ # If output is a graph output, we need to rename input and update graph outputs
82
+ if output_is_graph_output:
83
+ # Store the original output name
84
+ original_output_name = output_value.name
85
+
86
+ # Update the input value to have the output's name
87
+ input_value.name = original_output_name
88
+
89
+ # Update graph outputs to point to the input value
90
+ for idx, graph_output in enumerate(graph_like.outputs):
91
+ if graph_output is output_value:
92
+ graph_like.outputs[idx] = input_value
93
+
94
+ # Remove the identity node
95
+ graph_like.remove(node, safe=True)
96
+ logger.debug("Eliminated identity node: %s", node)
97
+ return True
onnx_ir/py.typed ADDED
@@ -0,0 +1 @@
1
+
onnx_ir/serde.py CHANGED
@@ -74,7 +74,6 @@ from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
74
74
 
75
75
  if typing.TYPE_CHECKING:
76
76
  import google.protobuf.internal.containers as proto_containers
77
- import numpy.typing as npt
78
77
 
79
78
  logger = logging.getLogger(__name__)
80
79
 
@@ -117,13 +116,6 @@ def _little_endian_dtype(dtype) -> np.dtype:
117
116
  return np.dtype(dtype).newbyteorder("<")
118
117
 
119
118
 
120
- def _unflatten_complex(
121
- array: npt.NDArray[np.float32 | np.float64],
122
- ) -> npt.NDArray[np.complex64 | np.complex128]:
123
- """Convert the real representation of a complex dtype to the complex dtype."""
124
- return array[::2] + 1j * array[1::2]
125
-
126
-
127
119
  @typing.overload
128
120
  def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
129
121
  @typing.overload
@@ -391,54 +383,89 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
391
383
  "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
392
384
  )
393
385
 
386
+ shape = self._proto.dims
387
+
394
388
  if self._proto.HasField("raw_data"):
395
- array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<"))
396
- # Cannot return now, because we may need to unpack 4bit tensors
397
- elif dtype == _enums.DataType.STRING:
398
- return np.array(self._proto.string_data).reshape(self._proto.dims)
399
- elif self._proto.int32_data:
400
- array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
401
- if dtype in {_enums.DataType.FLOAT16, _enums.DataType.BFLOAT16}:
402
- # Reinterpret the int32 as float16 or bfloat16
403
- array = array.astype(np.uint16).view(dtype.numpy())
404
- elif dtype in {
389
+ if dtype.bitwidth == 4:
390
+ return _type_casting.unpack_4bitx2(
391
+ np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
392
+ ).view(dtype.numpy())
393
+ return np.frombuffer(
394
+ self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
395
+ ).reshape(shape)
396
+ if dtype == _enums.DataType.STRING:
397
+ return np.array(self._proto.string_data).reshape(shape)
398
+ if self._proto.int32_data:
399
+ assert dtype in {
400
+ _enums.DataType.BFLOAT16,
401
+ _enums.DataType.BOOL,
402
+ _enums.DataType.FLOAT16,
403
+ _enums.DataType.FLOAT4E2M1,
405
404
  _enums.DataType.FLOAT8E4M3FN,
406
405
  _enums.DataType.FLOAT8E4M3FNUZ,
407
406
  _enums.DataType.FLOAT8E5M2,
408
407
  _enums.DataType.FLOAT8E5M2FNUZ,
409
- }:
410
- array = array.astype(np.uint8).view(dtype.numpy())
411
- elif self._proto.int64_data:
412
- array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64))
413
- elif self._proto.uint64_data:
408
+ _enums.DataType.FLOAT8E8M0,
409
+ _enums.DataType.INT16,
410
+ _enums.DataType.INT32,
411
+ _enums.DataType.INT4,
412
+ _enums.DataType.INT8,
413
+ _enums.DataType.UINT16,
414
+ _enums.DataType.UINT4,
415
+ _enums.DataType.UINT8,
416
+ }, f"Unsupported dtype {dtype} for int32_data"
417
+ array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
418
+ if dtype.bitwidth == 32:
419
+ return array.reshape(shape)
420
+ if dtype.bitwidth == 16:
421
+ # Reinterpret the int32 as float16 or bfloat16
422
+ return array.astype(np.uint16).view(dtype.numpy()).reshape(shape)
423
+ if dtype.bitwidth == 8:
424
+ return array.astype(np.uint8).view(dtype.numpy()).reshape(shape)
425
+ if dtype.bitwidth == 4:
426
+ return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
427
+ dtype.numpy()
428
+ )
429
+ raise ValueError(
430
+ f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
431
+ )
432
+ if self._proto.int64_data:
433
+ assert dtype in {
434
+ _enums.DataType.INT64,
435
+ }, f"Unsupported dtype {dtype} for int64_data"
436
+ return np.array(
437
+ self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
438
+ ).reshape(shape)
439
+ if self._proto.uint64_data:
440
+ assert dtype in {
441
+ _enums.DataType.UINT64,
442
+ _enums.DataType.UINT32,
443
+ }, f"Unsupported dtype {dtype} for uint64_data"
414
444
  array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
415
- elif self._proto.float_data:
445
+ if dtype == _enums.DataType.UINT32:
446
+ return array.astype(np.uint32).reshape(shape)
447
+ return array.reshape(shape)
448
+ if self._proto.float_data:
449
+ assert dtype in {
450
+ _enums.DataType.FLOAT,
451
+ _enums.DataType.COMPLEX64,
452
+ }, f"Unsupported dtype {dtype} for float_data"
416
453
  array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
417
454
  if dtype == _enums.DataType.COMPLEX64:
418
- array = _unflatten_complex(array)
419
- elif self._proto.double_data:
455
+ return array.view(np.complex64).reshape(shape)
456
+ return array.reshape(shape)
457
+ if self._proto.double_data:
458
+ assert dtype in {
459
+ _enums.DataType.DOUBLE,
460
+ _enums.DataType.COMPLEX128,
461
+ }, f"Unsupported dtype {dtype} for double_data"
420
462
  array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
421
463
  if dtype == _enums.DataType.COMPLEX128:
422
- array = _unflatten_complex(array)
423
- else:
424
- # Empty tensor
425
- if not self._proto.dims:
426
- # When dims not precent and there is no data, we return an empty array
427
- return np.array([], dtype=dtype.numpy())
428
- else:
429
- # Otherwise we return a size 0 array with the correct shape
430
- return np.zeros(self._proto.dims, dtype=dtype.numpy())
431
-
432
- if dtype == _enums.DataType.INT4:
433
- return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
434
- elif dtype == _enums.DataType.UINT4:
435
- return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
436
- elif dtype == _enums.DataType.FLOAT4E2M1:
437
- return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
438
- else:
439
- # Otherwise convert to the correct dtype and reshape
440
- # Note we cannot use view() here because the storage dtype may not be the same size as the target
441
- return array.astype(dtype.numpy()).reshape(self._proto.dims)
464
+ return array.view(np.complex128).reshape(shape)
465
+ return array.reshape(shape)
466
+
467
+ # Empty tensor. We return a size 0 array with the correct shape
468
+ return np.zeros(shape, dtype=dtype.numpy())
442
469
 
443
470
  def tobytes(self) -> bytes:
444
471
  """Return the tensor as a byte string conformed to the ONNX specification, in little endian.
@@ -479,6 +506,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
479
506
  _enums.DataType.FLOAT8E4M3FNUZ,
480
507
  _enums.DataType.FLOAT8E5M2,
481
508
  _enums.DataType.FLOAT8E5M2FNUZ,
509
+ _enums.DataType.FLOAT8E8M0,
482
510
  _enums.DataType.INT4,
483
511
  _enums.DataType.UINT4,
484
512
  _enums.DataType.FLOAT4E2M1,
@@ -68,6 +68,7 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
68
68
  torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
69
69
  torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
70
70
  torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
71
+ torch.float8_e8m0fnu: ir.DataType.FLOAT8E8M0,
71
72
  torch.int16: ir.DataType.INT16,
72
73
  torch.int32: ir.DataType.INT32,
73
74
  torch.int64: ir.DataType.INT64,
@@ -104,6 +105,7 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
104
105
  ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
105
106
  ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
106
107
  ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
108
+ ir.DataType.FLOAT8E8M0: torch.float8_e8m0fnu,
107
109
  ir.DataType.INT16: torch.int16,
108
110
  ir.DataType.INT32: torch.int32,
109
111
  ir.DataType.INT64: torch.int64,
@@ -142,6 +144,7 @@ class TorchTensor(_core.Tensor):
142
144
  ir.DataType.FLOAT8E4M3FNUZ,
143
145
  ir.DataType.FLOAT8E5M2,
144
146
  ir.DataType.FLOAT8E5M2FNUZ,
147
+ ir.DataType.FLOAT8E8M0,
145
148
  }:
146
149
  return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
147
150
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -30,6 +30,7 @@ Dynamic: license-file
30
30
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
31
31
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
32
32
  [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
33
+ [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
33
34
 
34
35
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
35
36
 
@@ -1,7 +1,7 @@
1
- onnx_ir/__init__.py,sha256=5KP1Ngl2qyWiqb5S0Ol5owYsbU0geo4LFwGwN8EXTIk,3424
2
- onnx_ir/_core.py,sha256=-9BpVTZHuHQ9jsms33wqu4NjMEaDF_M57sIuVxYcM1I,137964
1
+ onnx_ir/__init__.py,sha256=_995K-JXuL0upLulUJxCXziF1gMcehH3gzea2eukCyM,3424
2
+ onnx_ir/_core.py,sha256=CtRwtDb__hK0MJLWsrNNu5n_xz6TlbJctDLw8UDQAZQ,137454
3
3
  onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
4
- onnx_ir/_enums.py,sha256=4lmm_DFKEtz6PqNw6gt6GcqrBYHisctgKMsUbQCm5N8,8252
4
+ onnx_ir/_enums.py,sha256=SxC-GGgPrmdz6UsMhx7xT9-6VmkZ6j1oVzDqNUHr3Rc,9659
5
5
  onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
6
6
  onnx_ir/_graph_containers.py,sha256=PRKrshRZ5rzWCgRs1TefzJq9n8wyo7OqeKy3XxMhyys,14265
7
7
  onnx_ir/_io.py,sha256=GWwA4XOZ-ZX1cgibgaYD0K0O5d9LX21ZwcBN02Wrh04,5205
@@ -11,13 +11,14 @@ onnx_ir/_name_authority.py,sha256=PnoV9TRgMLussZNufWavJXosDWx5avPfldVjMWEEz18,30
11
11
  onnx_ir/_polyfill.py,sha256=LzAGBKQbVDlURC0tgQgaxgkYU4rESgCYnqVs-u-Vsx8,887
12
12
  onnx_ir/_protocols.py,sha256=M29sIOAvtdlis3QtBvCQPH4pnvSwhJCQNCvs3IrN9FY,21276
13
13
  onnx_ir/_tape.py,sha256=nEGY6VZVKuB8FDyXeYr0MTq8j7E4HKOE2yN8qpz4ia0,7007
14
- onnx_ir/_type_casting.py,sha256=8iZDVrNAx_FwRVt48G4tkzIOFu3I6AsETpH3fdxcyEI,3387
14
+ onnx_ir/_type_casting.py,sha256=hbikTmgFEu0SEfnbgv2R1LbpuPQ2MCfqto3-oLWhcBc,1645
15
15
  onnx_ir/_version_utils.py,sha256=bZThuE7meVHFOY1DLsmss9WshVIp9iig7udGfDbVaK4,1333
16
16
  onnx_ir/convenience.py,sha256=0B1epuXZCSmY4FbW2vaYfR-t5ubxBZ1UruiytHs-zFw,917
17
17
  onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,18079
18
- onnx_ir/serde.py,sha256=YkbYfQMwn0YAzTd3tVDSWJ-NBiSVsG-74T6xk3e5iTU,75073
18
+ onnx_ir/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
19
+ onnx_ir/serde.py,sha256=bFQg5XYlDTvZsT_gDO_mPYedkMj_HcUbBvQuxLlRKvc,75980
19
20
  onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
20
- onnx_ir/tensor_adapters.py,sha256=dXuapwfFcpLhjKC6AOqCXbtY3WvDaEHoCNPwjnUK7_o,6565
21
+ onnx_ir/tensor_adapters.py,sha256=Pl2eLXa1VQh0nZy6NFMBr_9BRY_OPoKQX1oa4K7ecUo,6717
21
22
  onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
22
23
  onnx_ir/traversal.py,sha256=Z69wzYBNljn1S7PhVTYgwMftrfsdEBLoa0JYteOhLL0,2863
23
24
  onnx_ir/_convenience/__init__.py,sha256=DQ-Bz1wTiZJEARCFxDqZvYexWviGmwvDzE_1hR-vp0Q,19182
@@ -25,19 +26,20 @@ onnx_ir/_convenience/_constructors.py,sha256=5GhlYy_xCE2ng7l_4cNx06WQsNDyvS-0U1H
25
26
  onnx_ir/_thirdparty/asciichartpy.py,sha256=afQ0fsqko2uYRPAR4TZBrQxvCb4eN8lxZ2yDFbVQq_s,10533
26
27
  onnx_ir/passes/__init__.py,sha256=M_Tcl_-qGSNPluFIvOoeDyh0qAwNayaYyXDS5UJUJPQ,764
27
28
  onnx_ir/passes/_pass_infra.py,sha256=xIOw_zZIuOqD4Z_wZ4OvsqXfh2IZMoMlDp1xQ_MPQlc,9567
28
- onnx_ir/passes/common/__init__.py,sha256=GrrscfBekrIjxrYusgvTgP80OrgY1GMJwZMInRQmcL4,1467
29
+ onnx_ir/passes/common/__init__.py,sha256=LWkH39XATj1lQz82cVrxtle6YiZZ8RkT1fVZNthiTLI,1586
29
30
  onnx_ir/passes/common/_c_api_utils.py,sha256=g6riA6xNGVWaO5YjVHZ0krrfslWHmRlryRkwB8X56cg,2907
30
31
  onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMOGjdvVwG9yHQTkSphUgDlM0ME,2365
31
32
  onnx_ir/passes/common/common_subexpression_elimination.py,sha256=wZ1zEPdCshYB_ifP9fCAVfzQkesE6uhCfzCuL2qO5fA,7948
32
33
  onnx_ir/passes/common/constant_manipulation.py,sha256=_fGDwn0Axl2Q8APfc2m_mLMH28T-Mc9kIlpzBXoe3q4,8779
34
+ onnx_ir/passes/common/identity_elimination.py,sha256=FyqnJxFUq9Ga9XyUJ3myjzr36InYSW-oJgDTrUrBORY,3663
33
35
  onnx_ir/passes/common/initializer_deduplication.py,sha256=4CIVFYfdXUlmF2sAx560c_pTwYVXtX5hcSwWzUKm5uc,2061
34
36
  onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
35
37
  onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6svE0cCyDew,1691
36
38
  onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
37
39
  onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
38
40
  onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
39
- onnx_ir-0.1.3.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
40
- onnx_ir-0.1.3.dist-info/METADATA,sha256=vKG8o_nAUJfjM05rahv0g-FCeHkHXIwCAcuYzSY6PH8,4782
41
- onnx_ir-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
42
- onnx_ir-0.1.3.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
43
- onnx_ir-0.1.3.dist-info/RECORD,,
41
+ onnx_ir-0.1.5.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
42
+ onnx_ir-0.1.5.dist-info/METADATA,sha256=SHH7BxuFCKIsWyRKQyOKbXRtZX8n0ryietlWDPPLBvA,4884
43
+ onnx_ir-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
44
+ onnx_ir-0.1.5.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
45
+ onnx_ir-0.1.5.dist-info/RECORD,,