onnx-ir 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic.

onnx_ir/__init__.py CHANGED
@@ -14,6 +14,7 @@ __all__ = [
14
14
  "ExternalTensor",
15
15
  "StringTensor",
16
16
  "LazyTensor",
17
+ "PackedTensor",
17
18
  "SymbolicDim",
18
19
  "Shape",
19
20
  "TensorType",
@@ -73,6 +74,7 @@ __all__ = [
73
74
  "from_proto",
74
75
  "from_onnx_text",
75
76
  "to_proto",
77
+ "to_onnx_text",
76
78
  # Convenience constructors
77
79
  "tensor",
78
80
  "node",
@@ -114,6 +116,7 @@ from onnx_ir._core import (
114
116
  Model,
115
117
  Node,
116
118
  OptionalType,
119
+ PackedTensor,
117
120
  RefAttr,
118
121
  SequenceType,
119
122
  Shape,
@@ -149,7 +152,7 @@ from onnx_ir._protocols import (
149
152
  TypeProtocol,
150
153
  ValueProtocol,
151
154
  )
152
- from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
155
+ from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_onnx_text, to_proto
153
156
 
154
157
  DEBUG = False
155
158
 
@@ -164,4 +167,4 @@ def __set_module() -> None:
164
167
 
165
168
 
166
169
  __set_module()
167
- __version__ = "0.1.1"
170
+ __version__ = "0.1.3"
@@ -14,14 +14,17 @@ __all__ = [
14
14
  "replace_all_uses_with",
15
15
  "create_value_mapping",
16
16
  "replace_nodes_and_values",
17
+ "get_const_tensor",
17
18
  ]
18
19
 
20
+ import logging
19
21
  from collections.abc import Mapping, Sequence
20
22
  from typing import Union
21
23
 
22
- import onnx
24
+ import numpy as np
25
+ import onnx # noqa: TID251
23
26
 
24
- from onnx_ir import _core, _enums, _protocols, serde
27
+ from onnx_ir import _core, _enums, _protocols, serde, traversal
25
28
 
26
29
  SupportedAttrTypes = Union[
27
30
  str,
@@ -42,6 +45,9 @@ SupportedAttrTypes = Union[
42
45
  ]
43
46
 
44
47
 
48
+ logger = logging.getLogger(__name__)
49
+
50
+
45
51
  def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
46
52
  """Infer the attribute type based on the type of the Python object."""
47
53
  if isinstance(attr, int):
@@ -313,7 +319,12 @@ def replace_all_uses_with(
313
319
  def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
314
320
  """Return a dictionary mapping names to values in the graph.
315
321
 
316
- The mapping does not include values from subgraphs.
322
+ The mapping includes values from subgraphs. Duplicated names are omitted,
323
+ and the first value with that name is returned. Values with empty names
324
+ are excluded from the mapping.
325
+
326
+ .. versionchanged:: 0.1.2
327
+ Values from subgraphs are now included in the mapping.
317
328
 
318
329
  Args:
319
330
  graph: The graph to extract the mapping from.
@@ -327,11 +338,23 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
327
338
  for input in graph.inputs:
328
339
  if not input.name:
329
340
  continue
341
+ if input.name in values:
342
+ continue
330
343
  values[input.name] = input
331
- for node in graph:
344
+ for node in traversal.RecursiveGraphIterator(graph):
345
+ for value in node.inputs:
346
+ if not value:
347
+ continue
348
+ if not value.name:
349
+ continue
350
+ if value.name in values:
351
+ continue
352
+ values[value.name] = value
332
353
  for value in node.outputs:
333
354
  if not value.name:
334
355
  continue
356
+ if value.name in values:
357
+ continue
335
358
  values[value.name] = value
336
359
  return values
337
360
 
@@ -375,3 +398,106 @@ def replace_nodes_and_values(
375
398
  # insert new nodes after the index node
376
399
  graph_or_function.insert_after(insertion_point, new_nodes)
377
400
  graph_or_function.remove(old_nodes, safe=True)
401
+
402
+
403
+ def get_const_tensor(
404
+ value: _core.Value, propagate_shape_type: bool = False
405
+ ) -> _protocols.TensorProtocol | None:
406
+ """Get the constant tensor from a value, if it exists.
407
+
408
+ A constant tensor can be obtained if the value has a ``const_value`` set
409
+ (as in the case of an initializer) or if the value is produced by a
410
+ Constant node.
411
+
412
+ This function will not alter the ``const_value`` of the value, but
413
+ it will propagate the shape and type of the constant tensor to the value
414
+ if `propagate_shape_type` is set to True.
415
+
416
+ .. versionadded:: 0.1.2
417
+
418
+ Args:
419
+ value: The value to get the constant tensor from.
420
+ propagate_shape_type: If True, the shape and type of the value will be
421
+ propagated to the Value.
422
+
423
+ Returns:
424
+ The constant tensor if it exists, otherwise None.
425
+
426
+ Raises:
427
+ ValueError: If the Constant node does not have exactly one output or
428
+ one attribute.
429
+ """
430
+ tensor = None
431
+ if value.const_value is not None:
432
+ tensor = value.const_value
433
+ else:
434
+ node = value.producer()
435
+ if node is None:
436
+ # Potentially a graph input
437
+ return None
438
+ if node.op_type != "Constant" or node.domain != "":
439
+ # Not a Constant node or not in the ONNX domain
440
+ return None
441
+ if len(node.outputs) != 1:
442
+ raise ValueError(
443
+ f"Constant node '{node.name}' must have exactly one output, "
444
+ f"but has {len(node.outputs)} outputs."
445
+ )
446
+ if len(node.attributes) != 1:
447
+ raise ValueError(
448
+ f"Constant node '{node.name}' must have exactly one attribute, "
449
+ f"but has {len(node.attributes)} attributes."
450
+ )
451
+
452
+ attr_name, attr_value = next(iter(node.attributes.items()))
453
+
454
+ if attr_value.is_ref():
455
+ # TODO: Make it easier to resolve a reference attribute.
456
+ # For now we just return None
457
+ return None
458
+
459
+ ir_value = node.outputs[0]
460
+ if attr_name in {"value_float", "value_floats"}:
461
+ tensor = _core.Tensor(
462
+ np.array(attr_value.value, dtype=np.float32), name=ir_value.name
463
+ )
464
+ elif attr_name in {"value_int", "value_ints"}:
465
+ tensor = _core.Tensor(
466
+ np.array(attr_value.value, dtype=np.int64), name=ir_value.name
467
+ )
468
+ elif attr_name in {"value_string", "value_strings"}:
469
+ tensor = _core.StringTensor(
470
+ np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name
471
+ )
472
+ elif attr_name == "value":
473
+ tensor = attr_value.as_tensor()
474
+ else:
475
+ raise ValueError(
476
+ f"Unsupported attribute '{attr_name}' in Constant node '{node.name}'. "
477
+ "Expected one of 'value_float', 'value_floats', 'value_int', "
478
+ "'value_ints', 'value_string', 'value_strings', or 'value'."
479
+ )
480
+ # Assign the name of the constant value to the tensor
481
+ tensor.name = value.name
482
+ if tensor is not None and propagate_shape_type:
483
+ # Propagate the shape and type of the tensor to the value
484
+ if value.shape is not None and value.shape != tensor.shape:
485
+ logger.warning(
486
+ "Value '%s' has a shape %s that differs from "
487
+ "the constant tensor's shape %s. The value's shape will be updated.",
488
+ value,
489
+ value.shape,
490
+ tensor.shape,
491
+ )
492
+ value.shape = tensor.shape # type: ignore[assignment]
493
+ new_value_type = _core.TensorType(tensor.dtype)
494
+ if value.type is not None and value.type != new_value_type:
495
+ logger.warning(
496
+ "Value '%s' has a type '%s' that differs from "
497
+ "the constant tensor's type '%s'. The value's type will be updated.",
498
+ value,
499
+ value.type,
500
+ new_value_type,
501
+ )
502
+ value.type = new_value_type
503
+ return tensor
@@ -13,7 +13,7 @@ import typing
13
13
  from collections.abc import Mapping, Sequence
14
14
 
15
15
  import numpy as np
16
- import onnx
16
+ import onnx # noqa: TID251
17
17
 
18
18
  from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
19
19
 
@@ -37,6 +37,10 @@ def tensor(
37
37
 
38
38
  ``value`` can be a numpy array, a plain Python object, or a TensorProto.
39
39
 
40
+ .. warning::
41
+ For 4bit dtypes, the value must be unpacked. Use :class:`~onnx_ir.PackedTensor`
42
+ to create a tensor with packed data.
43
+
40
44
  Example::
41
45
 
42
46
  >>> import onnx_ir as ir
@@ -155,7 +159,7 @@ def node(
155
159
  doc_string: str | None = None,
156
160
  metadata_props: dict[str, str] | None = None,
157
161
  ) -> ir.Node:
158
- """Create an :class:`ir.Node`.
162
+ """Create an :class:`~onnx_ir.Node`.
159
163
 
160
164
  This is a convenience constructor for creating a Node that supports Python
161
165
  objects as attributes.