onnx-ir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic; see the package registry's page for this release for more details.

onnx_ir/__init__.py CHANGED
@@ -14,6 +14,7 @@ __all__ = [
14
14
  "ExternalTensor",
15
15
  "StringTensor",
16
16
  "LazyTensor",
17
+ "PackedTensor",
17
18
  "SymbolicDim",
18
19
  "Shape",
19
20
  "TensorType",
@@ -73,6 +74,7 @@ __all__ = [
73
74
  "from_proto",
74
75
  "from_onnx_text",
75
76
  "to_proto",
77
+ "to_onnx_text",
76
78
  # Convenience constructors
77
79
  "tensor",
78
80
  "node",
@@ -114,6 +116,7 @@ from onnx_ir._core import (
114
116
  Model,
115
117
  Node,
116
118
  OptionalType,
119
+ PackedTensor,
117
120
  RefAttr,
118
121
  SequenceType,
119
122
  Shape,
@@ -149,7 +152,7 @@ from onnx_ir._protocols import (
149
152
  TypeProtocol,
150
153
  ValueProtocol,
151
154
  )
152
- from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
155
+ from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_onnx_text, to_proto
153
156
 
154
157
  DEBUG = False
155
158
 
@@ -164,4 +167,4 @@ def __set_module() -> None:
164
167
 
165
168
 
166
169
  __set_module()
167
- __version__ = "0.1.0"
170
+ __version__ = "0.1.2"
@@ -14,14 +14,17 @@ __all__ = [
14
14
  "replace_all_uses_with",
15
15
  "create_value_mapping",
16
16
  "replace_nodes_and_values",
17
+ "get_const_tensor",
17
18
  ]
18
19
 
20
+ import logging
19
21
  from collections.abc import Mapping, Sequence
20
22
  from typing import Union
21
23
 
22
- import onnx
24
+ import numpy as np
25
+ import onnx # noqa: TID251
23
26
 
24
- from onnx_ir import _core, _enums, _protocols, serde
27
+ from onnx_ir import _core, _enums, _protocols, serde, traversal
25
28
 
26
29
  SupportedAttrTypes = Union[
27
30
  str,
@@ -42,6 +45,9 @@ SupportedAttrTypes = Union[
42
45
  ]
43
46
 
44
47
 
48
+ logger = logging.getLogger(__name__)
49
+
50
+
45
51
  def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
46
52
  """Infer the attribute type based on the type of the Python object."""
47
53
  if isinstance(attr, int):
@@ -313,7 +319,9 @@ def replace_all_uses_with(
313
319
  def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
314
320
  """Return a dictionary mapping names to values in the graph.
315
321
 
316
- The mapping does not include values from subgraphs.
322
+ The mapping includes values from subgraphs. Duplicated names are omitted,
323
+ and the first value with that name is returned. Values with empty names
324
+ are excluded from the mapping.
317
325
 
318
326
  Args:
319
327
  graph: The graph to extract the mapping from.
@@ -327,11 +335,23 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
327
335
  for input in graph.inputs:
328
336
  if not input.name:
329
337
  continue
338
+ if input.name in values:
339
+ continue
330
340
  values[input.name] = input
331
- for node in graph:
341
+ for node in traversal.RecursiveGraphIterator(graph):
342
+ for value in node.inputs:
343
+ if not value:
344
+ continue
345
+ if not value.name:
346
+ continue
347
+ if value.name in values:
348
+ continue
349
+ values[value.name] = value
332
350
  for value in node.outputs:
333
351
  if not value.name:
334
352
  continue
353
+ if value.name in values:
354
+ continue
335
355
  values[value.name] = value
336
356
  return values
337
357
 
@@ -375,3 +395,104 @@ def replace_nodes_and_values(
375
395
  # insert new nodes after the index node
376
396
  graph_or_function.insert_after(insertion_point, new_nodes)
377
397
  graph_or_function.remove(old_nodes, safe=True)
398
+
399
+
400
def _encode_string_attribute(
    raw: str | bytes | Sequence[str | bytes],
) -> bytes | list[bytes]:
    """Encode a string-attribute payload as bytes suitable for ``np.bytes_``.

    ``np.array(..., dtype=np.bytes_)`` converts ``str`` input with the ASCII
    codec and raises ``UnicodeEncodeError`` on any non-ASCII character, so
    ``str`` items are encoded to UTF-8 explicitly. ``bytes`` items are passed
    through unchanged.

    Args:
        raw: A single string (``value_string``) or a sequence of strings
            (``value_strings``), each either ``str`` or ``bytes``.

    Returns:
        A single ``bytes`` object for scalar input, or a list of ``bytes``
        for sequence input.
    """
    if isinstance(raw, str):
        return raw.encode("utf-8")
    if isinstance(raw, bytes):
        return raw
    return [item.encode("utf-8") if isinstance(item, str) else item for item in raw]


def get_const_tensor(
    value: _core.Value, propagate_shape_type: bool = False
) -> _protocols.TensorProtocol | None:
    """Get the constant tensor from a value, if it exists.

    A constant tensor can be obtained if the value has a ``const_value`` set
    (as in the case of an initializer) or if the value is produced by a
    Constant node in the default ONNX domain (``""``).

    This function will not alter the ``const_value`` of the value, but
    it will propagate the shape and type of the constant tensor to the value
    if ``propagate_shape_type`` is set to True.

    Args:
        value: The value to get the constant tensor from.
        propagate_shape_type: If True, the shape and type of the constant
            tensor will be propagated to ``value``. A shape or type that
            already exists on ``value`` and differs is overwritten, and a
            warning is logged.

    Returns:
        The constant tensor if it exists, otherwise None. None is also
        returned when the Constant node's attribute is a reference attribute.

    Raises:
        ValueError: If the Constant node does not have exactly one output or
            one attribute, or if its attribute is not one of ``value``,
            ``value_float(s)``, ``value_int(s)`` or ``value_string(s)``.
    """
    tensor = None
    if value.const_value is not None:
        tensor = value.const_value
    else:
        node = value.producer()
        if node is None:
            # Potentially a graph input
            return None
        if node.op_type != "Constant" or node.domain != "":
            # Not a Constant node or not in the ONNX domain
            return None
        if len(node.outputs) != 1:
            raise ValueError(
                f"Constant node '{node.name}' must have exactly one output, "
                f"but has {len(node.outputs)} outputs."
            )
        if len(node.attributes) != 1:
            raise ValueError(
                f"Constant node '{node.name}' must have exactly one attribute, "
                f"but has {len(node.attributes)} attributes."
            )

        attr_name, attr_value = next(iter(node.attributes.items()))

        if attr_value.is_ref():
            # TODO: Make it easier to resolve a reference attribute.
            # For now we just return None
            return None

        ir_value = node.outputs[0]
        if attr_name in {"value_float", "value_floats"}:
            tensor = _core.Tensor(
                np.array(attr_value.value, dtype=np.float32), name=ir_value.name
            )
        elif attr_name in {"value_int", "value_ints"}:
            tensor = _core.Tensor(
                np.array(attr_value.value, dtype=np.int64), name=ir_value.name
            )
        elif attr_name in {"value_string", "value_strings"}:
            # Encode explicitly: numpy's bytes_ conversion of str is ASCII-only
            # and would fail on non-ASCII string constants.
            tensor = _core.StringTensor(
                np.array(_encode_string_attribute(attr_value.value), dtype=np.bytes_),
                name=ir_value.name,
            )
        elif attr_name == "value":
            tensor = attr_value.as_tensor()
        else:
            raise ValueError(
                f"Unsupported attribute '{attr_name}' in Constant node '{node.name}'. "
                "Expected one of 'value_float', 'value_floats', 'value_int', "
                "'value_ints', 'value_string', 'value_strings', or 'value'."
            )
        # Assign the name of the constant value to the tensor.
        # NOTE(review): for the 'value' attribute this may rename the
        # attribute's own tensor object in place if as_tensor() returns the
        # stored instance rather than a copy — confirm this is intended.
        tensor.name = value.name
    if tensor is not None and propagate_shape_type:
        # Propagate the shape and type of the tensor to the value
        if value.shape is not None and value.shape != tensor.shape:
            logger.warning(
                "Value '%s' has a shape %s that differs from "
                "the constant tensor's shape %s. The value's shape will be updated.",
                value,
                value.shape,
                tensor.shape,
            )
        value.shape = tensor.shape  # type: ignore[assignment]
        new_value_type = _core.TensorType(tensor.dtype)
        if value.type is not None and value.type != new_value_type:
            logger.warning(
                "Value '%s' has a type '%s' that differs from "
                "the constant tensor's type '%s'. The value's type will be updated.",
                value,
                value.type,
                new_value_type,
            )
        value.type = new_value_type
    return tensor
@@ -13,7 +13,7 @@ import typing
13
13
  from collections.abc import Mapping, Sequence
14
14
 
15
15
  import numpy as np
16
- import onnx
16
+ import onnx # noqa: TID251
17
17
 
18
18
  from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
19
19
 
@@ -37,6 +37,10 @@ def tensor(
37
37
 
38
38
  ``value`` can be a numpy array, a plain Python object, or a TensorProto.
39
39
 
40
+ .. warning::
41
+ For 4bit dtypes, the value must be unpacked. Use :class:`~onnx_ir.PackedTensor`
42
+ to create a tensor with packed data.
43
+
40
44
  Example::
41
45
 
42
46
  >>> import onnx_ir as ir
@@ -155,7 +159,7 @@ def node(
155
159
  doc_string: str | None = None,
156
160
  metadata_props: dict[str, str] | None = None,
157
161
  ) -> ir.Node:
158
- """Create an :class:`ir.Node`.
162
+ """Create an :class:`~onnx_ir.Node`.
159
163
 
160
164
  This is a convenience constructor for creating a Node that supports Python
161
165
  objects as attributes.