onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,558 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Convenience methods for constructing and manipulating the IR.
4
+
5
+ This is an internal only module. We should choose to expose some of the methods
6
+ in convenience.py after they are proven to be useful.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
# Names exported by ``from ... import *``. This is an internal module; these
# helpers are candidates for re-export from ``convenience.py``.
__all__ = [
    "convert_attribute",
    "convert_attributes",
    "replace_all_uses_with",
    "create_value_mapping",
    "replace_nodes_and_values",
    "get_const_tensor",
]
19
+
20
+ import logging
21
+ from collections.abc import Iterable, Mapping, Sequence
22
+ from typing import Union
23
+
24
+ import numpy as np
25
+ import onnx # noqa: TID251
26
+
27
+ from onnx_ir import _core, _enums, _protocols, serde, traversal
28
+
29
# Python values accepted by ``convert_attribute`` / ``convert_attributes``;
# ``_infer_attribute_type`` maps each of them onto an ONNX attribute type.
SupportedAttrTypes = Union[
    # Scalars and homogeneous scalar sequences (INT/FLOAT/STRING and plurals)
    str,
    int,
    float,
    Sequence[int],
    Sequence[float],
    Sequence[str],
    # Tensors (TENSOR)
    _protocols.TensorProtocol,  # This includes all in-memory tensor types
    onnx.TensorProto,
    # An already-built attribute is validated and returned by convert_attribute
    _core.Attr,
    # Graphs (GRAPH / GRAPHS)
    _protocols.GraphProtocol,
    Sequence[_protocols.GraphProtocol],
    onnx.GraphProto,
    # Type protos (TYPE_PROTO / TYPE_PROTOS)
    _protocols.TypeProtocol,
    Sequence[_protocols.TypeProtocol],
    # None requires an explicit attr_type in convert_attribute
    None,
]


# Module-level logger used for ambiguity/mismatch warnings below.
logger = logging.getLogger(__name__)
49
+
50
+
51
def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
    """Pick the ONNX attribute type that fits the Python object ``attr``.

    Checks run in this order: ``int``, ``float``, ``str``, an existing
    ``Attr`` (whose own type is reused), graphs, tensors, type protos, and
    finally sequences. A sequence is classified by the kind shared by all of
    its elements. An empty sequence is ambiguous, so a warning is logged and
    INTS is returned.

    Args:
        attr: The Python value to classify.

    Returns:
        The inferred :class:`_enums.AttributeType`.

    Raises:
        TypeError: If ``attr`` matches none of the supported kinds.
    """
    kinds = _enums.AttributeType
    graph_classes = (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)
    # TensorProtocol is last in the tuple: isinstance on a Protocol is slower
    tensor_classes = (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)
    type_classes = (
        _core.TensorType,
        _core.SequenceType,
        _core.OptionalType,
        _protocols.TypeProtocol,
    )

    # bool is a subclass of int, so booleans are reported as INT.
    for scalar_class, scalar_kind in (
        (int, kinds.INT),
        (float, kinds.FLOAT),
        (str, kinds.STRING),
    ):
        if isinstance(attr, scalar_class):
            return scalar_kind

    if isinstance(attr, _core.Attr):
        return attr.type

    for single_classes, single_kind in (
        (graph_classes, kinds.GRAPH),
        (tensor_classes, kinds.TENSOR),
        (type_classes, kinds.TYPE_PROTO),
    ):
        if isinstance(attr, single_classes):
            return single_kind

    if not isinstance(attr, Sequence):
        raise TypeError(f"Unsupported attribute type: '{type(attr)}'")

    if not attr:
        logger.warning(
            "Attribute type is ambiguous because it is an empty sequence. "
            "Please create an Attr with an explicit type. Defaulted to INTS"
        )
        return kinds.INTS

    # The first element kind that every item satisfies wins.
    for element_classes, sequence_kind in (
        (int, kinds.INTS),
        (float, kinds.FLOATS),
        (str, kinds.STRINGS),
        (tensor_classes, kinds.TENSORS),
        (graph_classes, kinds.GRAPHS),
        (type_classes, kinds.TYPE_PROTOS),
    ):
        if all(isinstance(element, element_classes) for element in attr):
            return sequence_kind

    raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
108
+
109
+
110
def convert_attribute(
    name: str,
    attr: SupportedAttrTypes,
    attr_type: _enums.AttributeType | None = None,
) -> _core.Attr:
    """Convert a Python object to a _core.Attr object.

    This method is useful when constructing nodes with attributes. It infers the
    attribute type based on the type of the Python value.

    ``onnx.TensorProto`` and ``onnx.GraphProto`` values are deserialized into IR
    objects. This applies both to single values (TENSOR / GRAPH) and to elements
    of sequences (TENSORS / GRAPHS).

    Args:
        name: The name of the attribute.
        attr: The value of the attribute.
        attr_type: The type of the attribute. This is required when attr is None.
            When provided, it overrides the inferred type.

    Returns:
        An ``Attr`` object. An ``Attr`` passed as ``attr`` is returned unchanged
        once its name (and its type, if ``attr_type`` is given) has been checked.

    Raises:
        ValueError: If ``attr`` is ``None`` and ``attr_type`` is not provided,
            or if ``attr`` is an ``Attr`` whose name or type does not match.
        TypeError: If the type of the attribute is not supported.
    """
    if attr is None:
        if attr_type is None:
            raise ValueError("attr_type must be provided when attr is None")
        return _core.Attr(name, attr_type, None)

    if isinstance(attr, _core.Attr):
        if attr.name != name:
            raise ValueError(
                f"Attribute name '{attr.name}' does not match provided name '{name}'"
            )
        if attr_type is not None and attr.type != attr_type:
            raise ValueError(
                f"Attribute type '{attr.type}' does not match provided type '{attr_type}'"
            )
        return attr

    if attr_type is None:
        attr_type = _infer_attribute_type(attr)

    if attr_type == _enums.AttributeType.INT:
        return _core.AttrInt64(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.FLOAT:
        return _core.AttrFloat32(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.STRING:
        return _core.AttrString(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.INTS:
        return _core.AttrInt64s(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.FLOATS:
        return _core.AttrFloat32s(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.STRINGS:
        return _core.AttrStrings(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.TENSOR:
        if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
            return _core.AttrTensor(name, attr)
        if isinstance(attr, onnx.TensorProto):
            return _core.AttrTensor(name, serde.deserialize_tensor(attr))
        # Any other value falls through to the TypeError at the end.
    if attr_type == _enums.AttributeType.TENSORS:
        # Each element must be a tensor itself, not an Attr wrapping one.
        # TensorProtos are deserialized, mirroring the GRAPHS branch below.
        tensors = [
            serde.deserialize_tensor(t) if isinstance(t, onnx.TensorProto) else t
            for t in attr  # type: ignore[union-attr]
        ]
        return _core.AttrTensors(name, tensors)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.GRAPH:
        if isinstance(attr, onnx.GraphProto):
            attr = serde.deserialize_graph(attr)
        return _core.AttrGraph(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.GRAPHS:
        graphs = [
            serde.deserialize_graph(graph) if isinstance(graph, onnx.GraphProto) else graph
            for graph in attr  # type: ignore[union-attr]
        ]
        return _core.AttrGraphs(name, graphs)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTO:
        return _core.AttrTypeProto(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTOS:
        return _core.AttrTypeProtos(name, attr)  # type: ignore[arg-type]
    raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
194
+
195
+
196
def convert_attributes(
    attrs: Mapping[str, SupportedAttrTypes],
) -> list[_core.Attr]:
    """Convert a dictionary of attributes to a list of _core.Attr objects.

    It infers the attribute type based on the type of the value. Each value is
    converted with ``convert_attribute``. The supported types are:

    - int, float and str;
    - Sequence[int], Sequence[float] and Sequence[str];
    - tensors (:class:`_core.Tensor` or ``onnx.TensorProto``) and sequences of them;
    - graphs (:class:`_core.Graph` or ``onnx.GraphProto``) and sequences of them;
    - type protos (e.g. :class:`_core.TensorType`) and sequences of them;
    - :class:`_core.Attr`.

    Entries whose value is ``None`` are skipped and produce no attribute.
    For example::

        >>> import onnx_ir as ir
        >>> import onnx
        >>> import numpy as np
        >>> attrs = {
        ...     "int": 1,
        ...     "float": 1.0,
        ...     "str": "hello",
        ...     "ints": [1, 2, 3],
        ...     "floats": [1.0, 2.0, 3.0],
        ...     "strings": ["hello", "world"],
        ...     "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])),
        ...     "tensor_proto":
        ...         onnx.TensorProto(
        ...             dims=[3],
        ...             data_type=onnx.TensorProto.FLOAT,
        ...             float_data=[1.0, 2.0, 3.0],
        ...             name="proto",
        ...         ),
        ...     "graph": ir.Graph([], [], nodes=[], name="graph0"),
        ...     "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")],
        ...     "type_proto": ir.TensorType(ir.DataType.FLOAT),
        ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
        ... }
        >>> convert_attributes(attrs)
        [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, (1, 2, 3)), Attr('floats', FLOATS, (1.0, 2.0, 3.0)), Attr('strings', STRINGS, ('hello', 'world')), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
            name='graph0',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        )), Attr('graphs', GRAPHS, (Graph(
            name='graph1',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        ), Graph(
            name='graph2',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        ))), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, (Tensor(FLOAT), Tensor(FLOAT)))]

    .. important::
        An empty sequence should be created with an explicit type by initializing
        an Attr object with an attribute type to avoid type ambiguity. The
        arguments are (name, type, value), for example::

            ir.Attr("empty", ir.AttributeType.INTS, [])

    Args:
        attrs: A dictionary of {<attribute name>: <python objects>} to convert.

    Returns:
        A list of :class:`_core.Attr` objects, in the iteration order of
        ``attrs``, with ``None``-valued entries omitted.

    Raises:
        TypeError: If an attribute type is not supported.
        ValueError: If a value is an ``Attr`` whose name differs from its key.
    """
    # None values are skipped: convert_attribute would raise ValueError for
    # them because no explicit attr_type can be supplied through this mapping.
    return [
        convert_attribute(name, attr)
        for name, attr in attrs.items()
        if attr is not None
    ]
278
+
279
+
280
def replace_all_uses_with(
    values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
    replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
    replace_graph_outputs: bool = False,
) -> None:
    """Replace all uses of the given values with the replacements.

    Use this when nodes in a graph are swapped for new nodes and the consumers
    of the old outputs must be rewired to the outputs of the new nodes.

    For example, given the graph::

        A -> {B, C}

    replace node A with a new node D::

        >>> import onnx_ir as ir
        >>> input = ir.val("input")
        >>> node_a = ir.Node("", "A", [input])
        >>> node_b = ir.Node("", "B", node_a.outputs)
        >>> node_c = ir.Node("", "C", node_a.outputs)
        >>> node_d = ir.Node("", "D", [input])
        >>> replace_all_uses_with(node_a.outputs, node_d.outputs)
        >>> len(node_b.inputs)
        1
        >>> node_b.inputs[0].producer().op_type
        'D'
        >>> len(node_c.inputs)
        1
        >>> node_c.inputs[0].producer().op_type
        'D'
        >>> len(node_a.outputs[0].uses())
        0

    A single value (not a sequence) is treated as a one-element sequence. The
    two sequences are paired up by position: all users of the first value are
    switched to the first replacement, and so on.

    .. note::
        Be sure to remove the old nodes from the graph using ``graph.remove()``
        if they are no longer needed, or use :class:`onnx_ir.passes.common.RemoveUnusedNodesPass`
        to remove all unused nodes in the graph.

    .. tip::
        **Handling graph outputs**

        To also replace graph outputs that reference the values being replaced,
        either set ``replace_graph_outputs`` to True, or update the graph outputs
        manually before calling this function. Otherwise an error is raised
        when ``replace_graph_outputs=False``.

        A value must not appear more than once among the graph outputs; that
        is invalid. Add an Identity node on each duplicated output to keep the
        ONNX graph valid.

        You may also want to give the replacement value the name of the value
        it replaces, so that the name is preserved when it is a graph output.

    .. versionadded:: 0.1.12
        The ``replace_graph_outputs`` parameter is added.

    .. versionchanged:: 0.1.12
        ValueError is raised when ``replace_graph_outputs`` is False and the
        value to replace is a graph output.

    Args:
        values: The value or values to be replaced.
        replacements: The new value or values to use as inputs.
        replace_graph_outputs: If True, graph outputs that reference the values
            being replaced will also be updated to reference the replacements.

    Raises:
        ValueError: If the numbers of values and replacements differ, or when
            ``replace_graph_outputs`` is False and a value to replace is a
            graph output.
    """
    old_values = values if isinstance(values, Sequence) else (values,)
    new_values = (
        replacements if isinstance(replacements, Sequence) else (replacements,)
    )
    if len(old_values) != len(new_values):
        raise ValueError("The number of values and replacements must match.")
    for position, old_value in enumerate(old_values):
        # Each value rewires its own consumers (and graph outputs, if requested).
        old_value.replace_all_uses_with(
            new_values[position], replace_graph_outputs=replace_graph_outputs
        )
361
+
362
+
363
def create_value_mapping(
    graph: _core.Graph | _core.GraphView | _core.Function,
    *,
    include_subgraphs: bool = True,
) -> dict[str, _core.Value]:
    """Map value names to values for everything reachable from ``graph``.

    Values are recorded in this order:

    1. initializers (skipped for functions);
    2. graph inputs;
    3. the inputs and outputs of each node, in iteration order.

    When a name occurs more than once, the first value recorded under it is
    kept. Values whose name is ``None`` or empty are left out.

    .. versionchanged:: 0.1.2
        Values from subgraphs are now included in the mapping.

    .. versionadded:: 0.1.14
        The ``include_subgraphs`` parameter.

    Args:
        graph: The graph to extract the mapping from.
        include_subgraphs: If True, nodes of subgraphs are visited too (via
            ``traversal.RecursiveGraphIterator``), and their input and output
            values are included in the mapping.

    Returns:
        A dictionary mapping names to values.
    """
    mapping: dict[str, _core.Value] = {}

    def record(value: _core.Value) -> None:
        # First occurrence wins; unnamed values (None or "") are skipped.
        if value.name and value.name not in mapping:
            mapping[value.name] = value

    if not isinstance(graph, _core.Function):
        mapping.update(graph.initializers)
    for graph_input in graph.inputs:
        record(graph_input)

    nodes: Iterable[_core.Node] = (
        traversal.RecursiveGraphIterator(graph) if include_subgraphs else graph
    )
    for node in nodes:
        for node_input in node.inputs:
            # Node inputs may be falsy (presumably None for omitted optional inputs)
            if node_input:
                record(node_input)
        for node_output in node.outputs:
            record(node_output)
    return mapping
417
+
418
+
419
def replace_nodes_and_values(
    graph_or_function: _core.Graph | _core.Function,
    /,
    insertion_point: _core.Node,
    old_nodes: Sequence[_core.Node],
    new_nodes: Sequence[_core.Node],
    old_values: Sequence[_core.Value],
    new_values: Sequence[_core.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    The steps are:

    1. For each ``(old_value, new_value)`` pair, copy each of ``type``,
       ``shape``, ``const_value`` and ``name`` from ``old_value`` onto
       ``new_value`` when the old attribute is not ``None``.
    2. Redirect all uses of the old values, including graph outputs, to the
       new values.
    3. Insert ``new_nodes`` after ``insertion_point``.
    4. Remove ``old_nodes``.

    Args:
        graph_or_function: The graph or function to replace nodes and values in.
        insertion_point: The node to insert the new nodes after.
        old_nodes: The nodes to replace.
        new_nodes: The nodes to replace with.
        old_values: The values to replace.
        new_values: The values to replace with.

    Raises:
        ValueError: If ``old_values`` and ``new_values`` have different lengths.
            This is checked before anything is modified.
    """
    # Validate before mutating anything. The zip below silently truncates, and
    # replace_all_uses_with would only detect the mismatch after the new values
    # had already been modified.
    if len(old_values) != len(new_values):
        raise ValueError("The number of values and replacements must match.")

    for old_value, new_value in zip(old_values, new_values):
        # Propagate relevant info from old value to new value
        # TODO(Rama): Perhaps this should be a separate utility function.
        new_value.type = old_value.type if old_value.type is not None else new_value.type
        new_value.shape = old_value.shape if old_value.shape is not None else new_value.shape
        new_value.const_value = (
            old_value.const_value
            if old_value.const_value is not None
            else new_value.const_value
        )
        new_value.name = old_value.name if old_value.name is not None else new_value.name

    # Reconnect the users of the deleted values to use the new values
    replace_all_uses_with(old_values, new_values, replace_graph_outputs=True)

    # insert new nodes after the index node
    graph_or_function.insert_after(insertion_point, new_nodes)
    graph_or_function.remove(old_nodes, safe=True)
456
+
457
+
458
+ def get_const_tensor(
459
+ value: _core.Value, propagate_shape_type: bool = False
460
+ ) -> _protocols.TensorProtocol | None:
461
+ """Get the constant tensor from a value, if it exists.
462
+
463
+ A constant tensor can be obtained if the value has a ``const_value`` set
464
+ (as in the case of an initializer) or if the value is produced by a
465
+ Constant node.
466
+
467
+ This function will not alter the ``const_value`` of the value, but
468
+ it will propagate the shape and type of the constant tensor to the value
469
+ if `propagate_shape_type` is set to True.
470
+
471
+ .. versionadded:: 0.1.2
472
+
473
+ Args:
474
+ value: The value to get the constant tensor from.
475
+ propagate_shape_type: If True, the shape and type of the value will be
476
+ propagated to the Value.
477
+
478
+ Returns:
479
+ The constant tensor if it exists, otherwise None.
480
+
481
+ Raises:
482
+ ValueError: If the Constant node does not have exactly one output or
483
+ one attribute.
484
+ """
485
+ tensor = None
486
+ if value.const_value is not None:
487
+ tensor = value.const_value
488
+ else:
489
+ node = value.producer()
490
+ if node is None:
491
+ # Potentially a graph input
492
+ return None
493
+ if node.op_type != "Constant" or node.domain != "":
494
+ # Not a Constant node or not in the ONNX domain
495
+ return None
496
+ if len(node.outputs) != 1:
497
+ raise ValueError(
498
+ f"Constant node '{node.name}' must have exactly one output, "
499
+ f"but has {len(node.outputs)} outputs."
500
+ )
501
+ if len(node.attributes) != 1:
502
+ raise ValueError(
503
+ f"Constant node '{node.name}' must have exactly one attribute, "
504
+ f"but has {len(node.attributes)} attributes."
505
+ )
506
+
507
+ attr_name, attr_value = next(iter(node.attributes.items()))
508
+
509
+ if attr_value.is_ref():
510
+ # TODO: Make it easier to resolve a reference attribute.
511
+ # For now we just return None
512
+ return None
513
+
514
+ ir_value = node.outputs[0]
515
+ if attr_name in {"value_float", "value_floats"}:
516
+ tensor = _core.Tensor(
517
+ np.array(attr_value.value, dtype=np.float32), name=ir_value.name
518
+ )
519
+ elif attr_name in {"value_int", "value_ints"}:
520
+ tensor = _core.Tensor(
521
+ np.array(attr_value.value, dtype=np.int64), name=ir_value.name
522
+ )
523
+ elif attr_name in {"value_string", "value_strings"}:
524
+ tensor = _core.StringTensor(
525
+ np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name
526
+ )
527
+ elif attr_name == "value":
528
+ tensor = attr_value.as_tensor()
529
+ else:
530
+ raise ValueError(
531
+ f"Unsupported attribute '{attr_name}' in Constant node '{node.name}'. "
532
+ "Expected one of 'value_float', 'value_floats', 'value_int', "
533
+ "'value_ints', 'value_string', 'value_strings', or 'value'."
534
+ )
535
+ # Assign the name of the constant value to the tensor
536
+ tensor.name = value.name
537
+ if tensor is not None and propagate_shape_type:
538
+ # Propagate the shape and type of the tensor to the value
539
+ if value.shape is not None and value.shape != tensor.shape:
540
+ logger.warning(
541
+ "Value '%s' has a shape %s that differs from "
542
+ "the constant tensor's shape %s. The value's shape will be updated.",
543
+ value,
544
+ value.shape,
545
+ tensor.shape,
546
+ )
547
+ value.shape = tensor.shape # type: ignore[assignment]
548
+ new_value_type = _core.TensorType(tensor.dtype)
549
+ if value.type is not None and value.type != new_value_type:
550
+ logger.warning(
551
+ "Value '%s' has a type '%s' that differs from "
552
+ "the constant tensor's type '%s'. The value's type will be updated.",
553
+ value,
554
+ value.type,
555
+ new_value_type,
556
+ )
557
+ value.type = new_value_type
558
+ return tensor