onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/onnx_helper.py
@@ -0,0 +1,1200 @@
import functools
import json
import os
import sys
import warnings
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
import numpy.typing as npt
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    TensorProto,
    ValueInfoProto,
    load as onnx_load,
)

def _make_stat(init: TensorProto) -> Dict[str, float]:
    """
    Produces statistics.

    :param init: tensor
    :return: statistics
    """
    ar = onh.to_array(init)
    return dict(
        mean=float(ar.mean()),
        std=float(ar.std()),
        shape=ar.shape,
        itype=np_dtype_to_tensor_dtype(ar.dtype),
        min=float(ar.min()),
        max=float(ar.max()),
    )

def onnx_lighten(
    onx: Union[str, ModelProto],
    verbose: int = 0,
) -> Tuple[ModelProto, Dict[str, Dict[str, float]]]:
    """
    Creates a model without big initializers but stores statistics
    into dictionaries. The function can be reversed with
    :func:`onnx_diagnostic.helpers.onnx_helper.onnx_unlighten`.
    The model is modified inplace.

    :param onx: model
    :param verbose: verbosity
    :return: new model, statistics
    """
    if isinstance(onx, str):
        if verbose:
            print(f"[onnx_lighten] load {onx!r}")
        model = onnx.load(onx)
    else:
        assert isinstance(onx, ModelProto), f"Unexpected type {type(onx)}"
        model = onx

    keep = []
    stats = []
    for init in model.graph.initializer:
        shape = init.dims
        size = np.prod(shape)
        if size > 2**12:
            stat = _make_stat(init)
            stats.append((init.name, stat))
            if verbose:
                print(f"[onnx_lighten] remove initializer {init.name!r} stat={stat}")
        else:
            keep.append(init)

    del model.graph.initializer[:]
    model.graph.initializer.extend(keep)
    return model, dict(stats)

def _get_tensor(min=None, max=None, mean=None, std=None, shape=None, itype=None):
    assert itype is not None, "itype must be specified."
    assert shape is not None, "shape must be specified."
    dtype = tensor_dtype_to_np_dtype(itype)
    if (mean is None or std is None) or (
        min is not None and max is not None and abs(max - min - 1) < 0.01
    ):
        if min is None:
            min = 0
        if max is None:
            max = 1  # default to the [0, 1] range rather than the degenerate [0, 0]
        return (np.random.random(shape) * (max - min) + min).astype(dtype)
    assert std is not None and mean is not None, f"mean={mean} or std={std} is None"
    # draw from N(mean, std) so the requested statistics are actually applied
    t = (np.random.randn(*shape) * std + mean).astype(dtype)
    return t

def onnx_unlighten(
    onx: Union[str, ModelProto],
    stats: Optional[Dict[str, Dict[str, float]]] = None,
    verbose: int = 0,
) -> ModelProto:
    """
    Function fixing the model produced by function
    :func:`onnx_diagnostic.helpers.onnx_helper.onnx_lighten`.
    The model is modified inplace.

    :param onx: model
    :param stats: statistics, can be None if onx is a file,
        then it loads the file ``<filename>.stats``,
        assuming it is in json format
    :param verbose: verbosity
    :return: new model
    """
    if isinstance(onx, str):
        if stats is None:
            fstats = f"{onx}.stats"
            assert os.path.exists(fstats), f"File {fstats!r} is missing."
            if verbose:
                print(f"[onnx_unlighten] load {fstats!r}")
            with open(fstats, "r") as f:
                stats = json.load(f)
        if verbose:
            print(f"[onnx_unlighten] load {onx!r}")
        model = onnx.load(onx)
    else:
        assert isinstance(onx, ModelProto), f"Unexpected type {type(onx)}"
        model = onx
        assert stats is not None, "stats is missing"

    keep = []
    for name, stat in stats.items():
        t = _get_tensor(**stat)
        init = from_array_extended(t, name=name)
        keep.append(init)

    model.graph.initializer.extend(keep)
    return model

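# A minimal sketch, not part of the packaged file, showing how the two
# functions above fit together. It builds a tiny model whose single
# initializer exceeds the 2**12-element threshold, strips it, then
# regenerates a random look-alike from the recorded statistics.
def _example_lighten_roundtrip():
    big = from_array_extended(np.zeros((128, 64), dtype=np.float32), name="W")
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("MatMul", ["X", "W"], ["Y"])],
            "demo",
            [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, 128])],
            [oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None, 64])],
            [big],
        )
    )
    light, stats = onnx_lighten(model)  # "W" (8192 elements) is removed
    return onnx_unlighten(light, stats)  # and rebuilt from its statistics
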
def _validate_graph(
    g: GraphProto,
    existing: Set[str],
    verbose: int = 0,
    watch: Optional[Set[str]] = None,
    path: Optional[Sequence[str]] = None,
):
    found = []
    path = path or ["root"]
    set_init = set(i.name for i in g.initializer)
    set_input = set(i.name for i in g.input)
    existing |= set_init | set_input
    if watch and set_init & watch:
        if verbose:
            print(f"-- found init {set_init & watch} in {path}")
        found.extend([i for i in g.initializer if i.name in set_init & watch])
    if watch and set_input & watch:
        if verbose:
            print(f"-- found input {set_input & watch} in {path}")
        found.extend([i for i in g.input if i.name in set_input & watch])
    try:
        import tqdm

        loop = tqdm.tqdm(g.node) if verbose else g.node
    except ImportError:
        loop = g.node

    for node in loop:
        ins = set(node.input) & existing
        if ins != set(node.input):
            raise AssertionError(
                f"One input is missing from node.input={node.input}, "
                f"existing={ins}, path={'/'.join(path)}, "
                f"node: {node.op_type}[{node.name}]"
            )
        if watch and ins & watch:
            if verbose:
                print(
                    f"-- found input {ins & watch} in "
                    f"{'/'.join(path)}/{node.op_type}[{node.name}]"
                )
            found.append(node)
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                found.extend(
                    _validate_graph(
                        att.g,
                        existing.copy(),
                        watch=watch,
                        path=[*path, f"{node.op_type}[{node.name}]"],
                        verbose=verbose,
                    )
                )
        existing |= set(node.output)
        if watch and set(node.output) & watch:
            if verbose:
                print(
                    f"-- found output {set(node.output) & watch} "
                    f"in {'/'.join(path)}/{node.op_type}[{node.name}]"
                )
            found.append(node)
    out = set(o.name for o in g.output)
    ins = out & existing
    if ins != out:
        raise AssertionError(
            f"One output is missing, out={out}, existing={ins}, path={path}"
        )
    return found

def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[str]] = None):
    existing = set(g.input)
    found = []
    for node in g.node:
        ins = set(node.input) & existing
        if ins != set(node.input):
            raise AssertionError(
                f"One input is missing from node.input={node.input}, existing={ins}"
            )
        if watch and ins & watch:
            if verbose:
                print(f"-- found input {ins & watch} in {node.op_type}[{node.name}]")
            found.append(node)
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                # recurse into the subgraph (att.g), mirroring _validate_graph
                found.extend(
                    _validate_graph(
                        att.g, existing.copy(), watch=watch, path=[g.name], verbose=verbose
                    )
                )
        existing |= set(node.output)
        if watch and set(node.output) & watch:
            if verbose:
                print(
                    f"-- found output {set(node.output) & watch} "
                    f"in {node.op_type}[{node.name}]"
                )
            found.append(node)
    out = set(g.output)
    ins = out & existing
    if ins != out:
        raise AssertionError(
            f"One output is missing, out={out}, existing={ins}, path={g.name}"
        )
    return found

def onnx_find(
    onx: Union[str, ModelProto], verbose: int = 0, watch: Optional[Set[str]] = None
) -> List[Union[NodeProto, TensorProto]]:
    """
    Looks for nodes producing or consuming some results.

    :param onx: model
    :param verbose: verbosity
    :param watch: names to search for
    :return: list of nodes
    """
    if isinstance(onx, str):
        onx = onnx.load(onx, load_external_data=False)
    found = []
    found.extend(_validate_graph(onx.graph, set(), verbose=verbose, watch=watch))
    for f in onx.functions:
        found.extend(_validate_function(f, watch=watch, verbose=verbose))
    if verbose and found:
        print(f"-- found {len(found)} nodes")
    return found

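# A minimal sketch, assuming a local file "model.onnx" (hypothetical path)
# and a result named "attn_output" (hypothetical name): list every node
# producing or consuming that result, including inside subgraphs.
def _example_onnx_find():
    for n in onnx_find("model.onnx", verbose=1, watch={"attn_output"}):
        print(pretty_onnx(n))
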
def check_model_ort(
    onx: ModelProto,
    providers: Optional[Union[str, List[Any]]] = None,
    dump_file: Optional[str] = None,
) -> "onnxruntime.InferenceSession":  # noqa: F821
    """
    Loads a model with onnxruntime.

    :param onx: ModelProto
    :param providers: list of providers, None for CPU, cpu for CPU, cuda for CUDA
    :param dump_file: if not empty, dumps the model into this file if
        an error happened
    :return: InferenceSession
    """
    from onnxruntime import InferenceSession

    if providers is None or providers == "cpu":
        providers = ["CPUExecutionProvider"]
    elif not isinstance(providers, list) and providers.startswith("cuda"):
        device_id = 0 if ":" not in providers else int(providers.split(":")[1])
        providers = [
            ("CUDAExecutionProvider", {"device_id": device_id}),
            ("CPUExecutionProvider", {}),
        ]

    if isinstance(onx, str):
        try:
            return InferenceSession(onx, providers=providers)
        except Exception as e:
            if dump_file:
                # onx is a filename here, the loaded proto is dumped
                onnx.save(onnx.load(onx), dump_file)
            raise AssertionError(  # noqa: B904
                f"onnxruntime cannot load the model "
                f"due to {e}\n{pretty_onnx(onnx.load(onx))}"
            )
    try:
        return InferenceSession(onx.SerializeToString(), providers=providers)
    except Exception as e:
        if dump_file:
            onnx.save(onx, dump_file)
        raise AssertionError(  # noqa: B904
            f"onnxruntime cannot load the model due to {e}\n{pretty_onnx(onx)}"
        )

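# A minimal sketch, assuming onnxruntime is installed and ``model`` is a
# valid ModelProto: load it on CPU and keep a dump if loading fails.
def _example_check_model_ort(model: ModelProto):
    sess = check_model_ort(model, providers="cpu", dump_file="failed.onnx")
    print([i.name for i in sess.get_inputs()])
    return sess
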
@functools.cache
def onnx_dtype_name(itype: int, exc: bool = True) -> str:
    """
    Returns the ONNX name for a specific element type.

    .. runpython::
        :showcode:

        import onnx
        from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

        itype = onnx.TensorProto.BFLOAT16
        print(onnx_dtype_name(itype))
        print(onnx_dtype_name(7))
    """
    for k in dir(TensorProto):
        if "FLOAT" in k or "INT" in k or "TEXT" in k or "BOOL" in k:
            v = getattr(TensorProto, k)
            if v == itype:
                return k
    if exc:
        raise ValueError(f"Unexpected value itype: {itype}")
    if itype == 0:
        return "UNDEFINED"
    return "UNEXPECTED"

def pretty_onnx(
    onx: Union[FunctionProto, GraphProto, ModelProto, ValueInfoProto, str],
    with_attributes: bool = False,
    highlight: Optional[Set[str]] = None,
    shape_inference: bool = False,
) -> str:
    """
    Displays an onnx proto in a better way.

    :param onx: proto to display
    :param with_attributes: displays attributes as well, if only a node is printed
    :param highlight: to highlight some names
    :param shape_inference: run shape inference before printing the model
    :return: text
    """
    assert onx is not None, "onx cannot be None"
    if isinstance(onx, str):
        onx = onnx_load(onx, load_external_data=False)
    assert onx is not None, "onx cannot be None"

    if shape_inference:
        onx = onnx.shape_inference.infer_shapes(onx)

    if isinstance(onx, ValueInfoProto):
        name = onx.name
        itype = onx.type.tensor_type.elem_type
        shape = tuple((d.dim_param or d.dim_value) for d in onx.type.tensor_type.shape.dim)
        shape_str = ",".join(map(str, shape))
        return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"

    if isinstance(onx, AttributeProto):
        att = onx
        if att.type == AttributeProto.INT:
            return f"{att.name}={att.i}"
        if att.type == AttributeProto.INTS:
            return f"{att.name}={att.ints}"
        if att.type == AttributeProto.FLOAT:
            return f"{att.name}={att.f}"
        if att.type == AttributeProto.FLOATS:
            return f"{att.name}={att.floats}"
        if att.type == AttributeProto.STRING:
            return f"{att.name}={att.s!r}"
        if att.type == AttributeProto.TENSOR:
            v = to_array_extended(att.t)
            assert hasattr(v, "reshape"), f"not a tensor {type(v)}"
            assert hasattr(v, "shape"), f"not a tensor {type(v)}"
            vf = v.reshape((-1,))
            if vf.size < 10:
                tt = f"[{', '.join(map(str, vf))}]"
            else:
                tt = f"[{', '.join(map(str, vf[:10]))}, ...]"
            if len(v.shape) != 1:
                return f"{att.name}=tensor({tt}, dtype={v.dtype}).reshape({v.shape})"
            return f"{att.name}=tensor({tt}, dtype={v.dtype})"
        raise NotImplementedError(
            f"pretty_onnx not implemented yet for AttributeProto={att!r}"
        )

    if isinstance(onx, NodeProto):

        def _high(n):
            if highlight and n in highlight:
                return f"**{n}**"
            return n

        text = (
            f"{onx.op_type}({', '.join(map(_high, onx.input))})"
            f" -> {', '.join(map(_high, onx.output))}"
        )
        if onx.domain:
            text = f"{onx.domain}.{text}"
        if not with_attributes or not onx.attribute:
            return text
        rows = []
        for att in onx.attribute:
            rows.append(pretty_onnx(att))
        if len(rows) > 1:
            suffix = "\n".join(f"    {s}" for s in rows)
            return f"{text}\n{suffix}"
        return f"{text} --- {rows[0]}"

    if isinstance(onx, TensorProto):
        shape = "x".join(map(str, onx.dims))
        return f"TensorProto:{onx.data_type}:{shape}:{onx.name}"

    try:
        from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

        if isinstance(onx, FunctionProto):
            return (
                f"function: {onx.name}[{onx.domain}]\n"
                f"{onnx_simple_text_plot(onx, recursive=True)}"
            )
        return onnx_simple_text_plot(onx, recursive=True)
    except ImportError:
        from onnx.printer import to_text

        return to_text(onx)

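# A minimal sketch: with a NodeProto, pretty_onnx returns a one-line summary,
# optionally with attributes and highlighted names.
def _example_pretty_onnx():
    node = oh.make_node("Gemm", ["X", "W", "B"], ["Y"], alpha=1.0, transB=1)
    print(pretty_onnx(node, with_attributes=True, highlight={"W"}))
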
def get_onnx_signature(model: ModelProto) -> Tuple[Tuple[str, Any], ...]:
    """
    Produces a tuple of tuples corresponding to the signature.

    :param model: model
    :return: signature
    """
    sig: List[Any] = []
    for i in model.graph.input:
        dt = i.type
        if dt.HasField("sequence_type"):
            dst = dt.sequence_type.elem_type
            tdt = dst.tensor_type
            el = tdt.elem_type
            shape = tuple(d.dim_param or d.dim_value for d in tdt.shape.dim)
            sig.append((i.name, [(i.name, el, shape)]))
        elif dt.HasField("tensor_type"):
            el = dt.tensor_type.elem_type
            shape = tuple(d.dim_param or d.dim_value for d in dt.tensor_type.shape.dim)
            sig.append((i.name, el, shape))
        else:
            raise AssertionError(f"Unable to interpret dt={dt!r} in {i!r}")
    return tuple(sig)

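# A minimal sketch: a model with one tensor input comes back as
# ``(name, element_type, shape)`` tuples.
def _example_get_onnx_signature():
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("Identity", ["X"], ["Y"])],
            "sig_demo",
            [oh.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", 8])],
            [oh.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", 8])],
        )
    )
    print(get_onnx_signature(model))  # (('X', 1, ('batch', 8)),)
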
def convert_endian(tensor: TensorProto) -> None:
    """Call to convert endianness of raw data in tensor.

    Args:
        tensor: TensorProto to be converted.
    """
    tensor_dtype = tensor.data_type
    np_dtype = tensor_dtype_to_np_dtype(tensor_dtype)
    tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()

def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
    """
    Converts a numpy array to a tensor def assuming the dtype
    is defined in ml_dtypes.

    Args:
        arr: a numpy array.
        name: (optional) the name of the tensor.

    Returns:
        TensorProto: the converted tensor def.
    """
    import ml_dtypes

    assert isinstance(arr, np.ndarray), f"arr must be of type numpy.ndarray, got {type(arr)}"

    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name

    if arr.dtype == ml_dtypes.bfloat16:
        dtype = TensorProto.BFLOAT16
    elif arr.dtype == ml_dtypes.float8_e4m3fn:
        dtype = TensorProto.FLOAT8E4M3FN
    elif arr.dtype == ml_dtypes.float8_e4m3fnuz:
        dtype = TensorProto.FLOAT8E4M3FNUZ
    elif arr.dtype == ml_dtypes.float8_e5m2:
        dtype = TensorProto.FLOAT8E5M2
    elif arr.dtype == ml_dtypes.float8_e5m2fnuz:
        dtype = TensorProto.FLOAT8E5M2FNUZ
    else:
        raise NotImplementedError(f"No conversion from {arr.dtype}")
    tensor.data_type = dtype
    tensor.raw_data = arr.tobytes()  # note: tobytes() is only after 1.9.
    if sys.byteorder == "big":
        convert_endian(tensor)
    return tensor

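# A minimal sketch, assuming ml_dtypes is installed: a bfloat16 array becomes
# a TensorProto with data_type BFLOAT16.
def _example_from_array_ml_dtypes():
    import ml_dtypes

    arr = np.array([1.0, 2.5], dtype=ml_dtypes.bfloat16)
    proto = from_array_ml_dtypes(arr, name="w_bf16")
    print(onnx_dtype_name(proto.data_type))  # BFLOAT16
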
_STORAGE_TYPE = {
    TensorProto.FLOAT16: np.int16,
    TensorProto.BFLOAT16: np.int16,
}

def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
    """
    Converts an array into a :class:`onnx.TensorProto`.

    :param tensor: numpy array or torch tensor
    :param name: name
    :return: TensorProto
    """
    if not isinstance(tensor, np.ndarray):
        import torch
        from .torch_helper import proto_from_tensor

        assert isinstance(
            tensor, torch.Tensor
        ), f"Unable to convert type {type(tensor)} into TensorProto."
        return proto_from_tensor(tensor, name=name)

    try:
        from onnx.reference.ops.op_cast import (
            bfloat16,
            float8e4m3fn,
            float8e4m3fnuz,
            float8e5m2,
            float8e5m2fnuz,
        )
    except ImportError:
        bfloat16 = None

    if bfloat16 is None:
        return onh.from_array(tensor, name)

    dt = tensor.dtype
    if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
        to = TensorProto.FLOAT8E4M3FN
        dt_to = np.uint8
    elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
        to = TensorProto.FLOAT8E4M3FNUZ
        dt_to = np.uint8
    elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
        to = TensorProto.FLOAT8E5M2
        dt_to = np.uint8
    elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
        to = TensorProto.FLOAT8E5M2FNUZ
        dt_to = np.uint8
    elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
        to = TensorProto.BFLOAT16
        dt_to = np.uint16
    else:
        try:
            import ml_dtypes
        except ImportError:
            ml_dtypes = None
        if ml_dtypes is not None and (
            tensor.dtype == ml_dtypes.bfloat16
            or tensor.dtype == ml_dtypes.float8_e4m3fn
            or tensor.dtype == ml_dtypes.float8_e4m3fnuz
            or tensor.dtype == ml_dtypes.float8_e5m2
            or tensor.dtype == ml_dtypes.float8_e5m2fnuz
        ):
            return from_array_ml_dtypes(tensor, name)
        return onh.from_array(tensor, name)

    t = onh.from_array(tensor.astype(dt_to), name)
    t.data_type = to
    return t

def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
    """Converts :class:`onnx.TensorProto` into a numpy array."""
    arr = onh.to_array(proto)
    if proto.data_type >= onnx.TensorProto.BFLOAT16:
        # Types not supported by numpy, view through ml_dtypes
        np_dtype = onnx_dtype_to_np_dtype(proto.data_type)
        return arr.view(np_dtype)
    return arr

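# A minimal sketch of the extended round trip, assuming ml_dtypes is
# installed: bfloat16 survives conversion to TensorProto and back.
def _example_bfloat16_roundtrip():
    import ml_dtypes

    arr = np.array([[0.5, -1.0]], dtype=ml_dtypes.bfloat16)
    proto = from_array_extended(arr, name="t")
    back = to_array_extended(proto)
    assert back.dtype == ml_dtypes.bfloat16 and back.shape == (1, 2)
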
def onnx_dtype_to_np_dtype(itype: int) -> Any:
    """
    Converts an onnx type into a numpy dtype.
    That includes :epkg:`ml_dtypes` dtypes.

    :param itype: onnx dtype
    :return: numpy dtype
    """
    if itype == TensorProto.FLOAT:
        return np.float32
    if itype == TensorProto.FLOAT16:
        return np.float16
    if itype == TensorProto.BFLOAT16:
        import ml_dtypes

        return ml_dtypes.bfloat16
    if itype == TensorProto.DOUBLE:
        return np.float64
    if itype == TensorProto.INT32:
        return np.int32
    if itype == TensorProto.INT64:
        return np.int64
    if itype == TensorProto.UINT32:
        return np.uint32
    if itype == TensorProto.UINT64:
        return np.uint64
    if itype == TensorProto.BOOL:
        return np.bool_
    if itype == TensorProto.INT16:
        return np.int16
    if itype == TensorProto.UINT16:
        return np.uint16
    if itype == TensorProto.INT8:
        return np.int8
    if itype == TensorProto.UINT8:
        return np.uint8
    if itype == TensorProto.COMPLEX64:
        return np.complex64
    if itype == TensorProto.COMPLEX128:
        return np.complex128
    raise NotImplementedError(
        f"Unable to convert onnx type {onnx_dtype_name(itype)} to a numpy dtype."
    )

def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F821
    """
    Converts a torch dtype or numpy dtype into an onnx element type.

    :param dt: dtype
    :return: onnx type
    """
    try:
        return np_dtype_to_tensor_dtype(dt)
    except (KeyError, TypeError, ValueError):
        pass
    from .torch_helper import torch_dtype_to_onnx_dtype

    return torch_dtype_to_onnx_dtype(dt)

def np_dtype_to_tensor_dtype(dt: np.dtype) -> int:  # noqa: F821
    """
    Converts a numpy dtype into an onnx element type.

    :param dt: dtype
    :return: onnx type
    """
    try:
        return oh.np_dtype_to_tensor_dtype(dt)
    except ValueError:
        try:
            import ml_dtypes
        except ImportError:
            ml_dtypes = None  # type: ignore
        if ml_dtypes is not None:
            if dt == ml_dtypes.bfloat16:
                return TensorProto.BFLOAT16
            if dt == ml_dtypes.float8_e4m3fn:
                return TensorProto.FLOAT8E4M3FN
            if dt == ml_dtypes.float8_e4m3fnuz:
                return TensorProto.FLOAT8E4M3FNUZ
            if dt == ml_dtypes.float8_e5m2:
                return TensorProto.FLOAT8E5M2
            if dt == ml_dtypes.float8_e5m2fnuz:
                return TensorProto.FLOAT8E5M2FNUZ
    if dt == np.float32:
        return TensorProto.FLOAT
    if dt == np.float16:
        return TensorProto.FLOAT16
    if dt == np.float64:
        return TensorProto.DOUBLE
    if dt == np.int64:
        return TensorProto.INT64
    if dt == np.uint64:
        return TensorProto.UINT64
    if dt == np.int16:
        return TensorProto.INT16
    if dt == np.uint16:
        return TensorProto.UINT16
    if dt == np.int32:
        return TensorProto.INT32
    if dt == np.int8:
        return TensorProto.INT8
    if dt == np.uint8:
        return TensorProto.UINT8
    if dt == np.uint32:
        return TensorProto.UINT32
    if dt == np.bool_:
        return TensorProto.BOOL
    if dt == np.complex64:
        return TensorProto.COMPLEX64
    if dt == np.complex128:
        return TensorProto.COMPLEX128
    raise ValueError(f"Unable to convert type {dt}")

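# A minimal sketch: the converters above are inverses of each other for the
# dtypes numpy supports natively.
def _example_dtype_mappings():
    assert onnx_dtype_to_np_dtype(TensorProto.FLOAT) == np.float32
    assert np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT
    assert dtype_to_tensor_dtype(np.dtype("int64")) == TensorProto.INT64
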
def type_info(itype: int, att: str):
    """
    Returns the minimum or maximum value for a type.

    :param itype: onnx type
    :param att: 'min' or 'max'
    :return: value
    """
    if itype in {TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE}:
        dtype = tensor_dtype_to_np_dtype(itype)
        fi = np.finfo(dtype)
    elif itype == TensorProto.BFLOAT16:
        import ml_dtypes

        dtype = tensor_dtype_to_np_dtype(itype)
        fi = ml_dtypes.finfo(dtype)  # type: ignore
    else:
        dtype = tensor_dtype_to_np_dtype(itype)
        fi = np.iinfo(dtype)  # type: ignore
    if att == "min":
        return fi.min
    if att == "max":
        return fi.max
    raise ValueError(f"Unexpected value {att!r}")

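# A minimal sketch: compare the representable range of float16 and bfloat16
# (the latter requires ml_dtypes).
def _example_type_info():
    print(type_info(TensorProto.FLOAT16, "max"))   # 65504.0
    print(type_info(TensorProto.BFLOAT16, "max"))  # ~3.39e38
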
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """
    Converts a TensorProto's data_type to the corresponding numpy dtype.
    It can be used while making tensors.

    :param tensor_dtype: TensorProto's data_type
    :return: numpy's data_type
    """
    if tensor_dtype >= 16:
        try:
            import ml_dtypes  # noqa: F401
        except ImportError as e:
            raise ValueError(
                f"Unsupported value for tensor_dtype, "
                f"numpy does not support onnx type {tensor_dtype}. "
                f"ml_dtypes can be used."
            ) from e

        mapping: Dict[int, np.dtype] = {
            TensorProto.BFLOAT16: ml_dtypes.bfloat16,
            TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn,
            TensorProto.FLOAT8E4M3FNUZ: ml_dtypes.float8_e4m3fnuz,
            TensorProto.FLOAT8E5M2: ml_dtypes.float8_e5m2,
            TensorProto.FLOAT8E5M2FNUZ: ml_dtypes.float8_e5m2fnuz,
        }
        assert (
            tensor_dtype in mapping
        ), f"Unable to find tensor_dtype={tensor_dtype!r} in mapping={mapping}"
        return mapping[tensor_dtype]

    return oh.tensor_dtype_to_np_dtype(tensor_dtype)

def iterator_initializer_constant(
    model: Union[FunctionProto, GraphProto, ModelProto],
    use_numpy: bool = True,
    prefix: str = "",
) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]:  # noqa: F821
    """
    Iterates on initializers and constants in an onnx model.

    :param model: model
    :param use_numpy: use numpy or pytorch
    :param prefix: prefix for subgraphs
    :return: iterator
    """
    if not isinstance(model, FunctionProto):
        graph = model if isinstance(model, GraphProto) else model.graph
        if not use_numpy:
            from .torch_helper import to_tensor
        if prefix:
            prefix += "."
        for init in graph.initializer:
            yield f"{prefix}{init.name}", (
                to_array_extended(init) if use_numpy else to_tensor(init)
            )
        nodes = graph.node
        name = graph.name
        if isinstance(model, ModelProto):
            for f in model.functions:
                yield from iterator_initializer_constant(
                    f, use_numpy=use_numpy, prefix=f"{prefix}{f.name}"
                )
    else:
        nodes = model.node
        name = model.name
    for node in nodes:
        if node.op_type == "Constant" and node.domain == "":
            from ..reference import ExtendedReferenceEvaluator as Inference

            if not use_numpy:
                import torch
            sess = Inference(node)
            value = sess.run(None, {})[0]
            yield f"{prefix}{node.output[0]}", (
                value if use_numpy else torch.from_numpy(value)
            )

        # "Body" in the original source looks like a typo for "If",
        # which is the standard control-flow operator carrying subgraphs
        if node.op_type in {"Loop", "If", "Scan"}:
            for att in node.attribute:
                assert (
                    att.type != onnx.AttributeProto.GRAPHS
                ), "Not implemented for type AttributeProto.GRAPHS."
                if att.type == onnx.AttributeProto.GRAPH:
                    yield from iterator_initializer_constant(
                        att.g, use_numpy=use_numpy, prefix=f"{prefix}{name}"
                    )

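# A minimal sketch: walk every initializer and Constant node of a model as
# numpy arrays. ``model`` is any ModelProto.
def _example_iterator_initializer_constant(model: ModelProto):
    for name, value in iterator_initializer_constant(model, use_numpy=True):
        print(name, getattr(value, "shape", None))
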
def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union[float, str]]:
    """
    Produces statistics on a tensor.

    :param tensor: tensor
    :return: statistics

    .. runpython::
        :showcode:

        import pprint
        import numpy as np
        from onnx_diagnostic.helpers.onnx_helper import tensor_statistics

        t = np.random.rand(40, 50).astype(np.float16)
        pprint.pprint(tensor_statistics(t))
    """
    from .helper import size_type

    if isinstance(tensor, TensorProto):
        tensor = to_array_extended(tensor)
    itype = np_dtype_to_tensor_dtype(tensor.dtype)
    stat = dict(
        mean=float(tensor.mean()),
        std=float(tensor.std()),
        shape="x".join(map(str, tensor.shape)),
        numel=tensor.size,
        size=tensor.size * size_type(tensor.dtype),
        itype=itype,
        stype=onnx_dtype_name(itype),
        min=float(tensor.min()),
        max=float(tensor.max()),
        nnan=float(np.isnan(tensor).sum()),
    )

    if tensor.size < 8:
        return stat

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            hist = np.array(
                [0, 1e-10, 1e-8, 1e-7, 1e-6, 1e-5, 0.0001, 0.001, 0.01, 0.1,
                 0.5, 1, 1.96, 10, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e10, 1e50],
                dtype=tensor.dtype,
            )
        except OverflowError as e:
            from .helper import string_type

            raise ValueError(
                f"Unable to convert one value into {tensor.dtype}, "
                f"tensor={string_type(tensor, with_shape=True)}"
            ) from e
        hist = np.array(sorted(set(hist[~np.isinf(hist)])), dtype=tensor.dtype)
        ind = np.digitize(np.abs(tensor).reshape((-1,)), hist, right=True)
        cou = np.bincount(ind, minlength=ind.shape[0] + 1)
        stat.update(
            dict(zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))]))
        )
        ii = (np.arange(9) + 1) / 10
        qu = np.quantile(tensor, ii)
        stat.update({f"q{i}": float(q) for i, q in zip(ii, qu)})
    return stat

class NodeCoordinates:
    """
    A way to locate a node: ``path`` is a tuple of triples
    (node index, node type, node name).
    """

    __slots__ = ("node", "path")

    def __init__(
        self,
        node: Union[onnx.TensorProto, NodeProto, str],
        path: Tuple[Tuple[int, str, str], ...],
    ):
        assert isinstance(path, tuple), f"Unexpected type {type(path)} for path"
        assert all(isinstance(t, tuple) for t in path), f"Unexpected type in path={path}"
        self.node = node
        self.path = path

    def __str__(self) -> str:
        "usual"
        if isinstance(self.node, str):
            return f"{self.path_to_str()} :: {self.node!r}"
        return f"{self.path_to_str()} :: {pretty_onnx(self.node)}"

    def path_to_str(self) -> str:
        "String representing the coordinates."
        return "x".join(f"({':'.join(map(str, t))})" for t in self.path)

class ResultFound:
    """
    Class returned by :func:`enumerate_results`.
    """

    __slots__ = ("consumer", "name", "producer")

    def __init__(
        self,
        name: str,
        producer: Optional[NodeCoordinates],
        consumer: Optional[NodeCoordinates],
    ):
        assert isinstance(name, str), f"unexpected type {type(name)} for name"
        self.name = name
        self.producer = producer
        self.consumer = consumer

    def __str__(self) -> str:
        "usual"
        return (
            f"<< {self.name} - {self.consumer}"
            if self.producer is None
            else f">> {self.name} - {self.producer}"
        )

def enumerate_results(
    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
    name: Union[Set[str], str],
    verbose: int = 0,
    coordinates: Optional[List[Tuple[int, str, str]]] = None,
) -> Iterator[ResultFound]:
    """
    Iterates on all nodes and attributes to find where a name is used.

    :param proto: a proto
    :param name: name or names to find
    :param verbose: verbosity
    :param coordinates: coordinates of a node
    :return: iterator on :class:`ResultFound`
    """
    if not isinstance(name, set):
        name = {name}
    coordinates = coordinates or []
    assert all(
        isinstance(c, tuple) for c in coordinates
    ), f"Unexpected type in coordinates={coordinates}"
    indent = " " * len(coordinates)
    if isinstance(proto, ModelProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into ModelProto...")
        yield from enumerate_results(proto.graph, name, verbose=verbose)
    elif isinstance(proto, FunctionProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into FunctionProto...")
        for i in proto.input:
            if i in name:
                r = ResultFound(
                    i,
                    NodeCoordinates(i, tuple([*coordinates, (-1, "INPUT", "")])),  # noqa: C409
                    None,
                )
                if verbose > 1:
                    print(f"[enumerate_results] {indent}-- {r}")
                yield r
        yield from enumerate_results(proto.node, name, verbose=verbose)
        for i in proto.output:
            if i in name:
                r = ResultFound(
                    i,
                    None,
                    NodeCoordinates(
                        i, tuple([*coordinates, (len(proto.node), "OUTPUT", "")])  # noqa: C409
                    ),
                )
                if verbose > 1:
                    print(f"[enumerate_results] {indent}-- {r}")
                yield r
    elif isinstance(proto, GraphProto):
        if verbose:
            print(f"[enumerate_results] {indent}searching for {name!r} into GraphProto...")
        for i in proto.initializer:
            if i.name in name:
                r = ResultFound(
                    i.name,
                    NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])),  # noqa: C409
                    None,
                )
                if verbose > 1:
                    print(f"[enumerate_results] {indent}-- {r}")
                yield r
        for i in proto.sparse_initializer:
            if i.name in name:
                r = ResultFound(
                    i.name,
                    NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])),  # noqa: C409
                    None,
                )
                if verbose > 1:
                    print(f"[enumerate_results] {indent}-- {r}")
                yield r
        for i in proto.input:
            if i.name in name:
                r = ResultFound(
                    i.name,
                    NodeCoordinates(i, tuple([*coordinates, (-1, "INPUT", "")])),  # noqa: C409
                    None,
                )
                if verbose > 1:
                    print(f"[enumerate_results] {indent}-- {r}")
                yield r
        yield from enumerate_results(
            proto.node, name, verbose=verbose, coordinates=coordinates
        )
        for i in proto.output:
            if i.name in name:
                r = ResultFound(
                    i.name,
                    None,
                    NodeCoordinates(
                        i, tuple([*coordinates, (len(proto.node), "OUTPUT", "")])  # noqa: C409
                    ),
                )
                if verbose > 1:
                    print(f"[enumerate_results] {indent}-- {r}")
                yield r
    else:
        if verbose:
            print(
                f"[enumerate_results] {indent}searching for {name!r} into List[NodeProto]..."
            )
        for node_i, node in enumerate(proto):
            if set(node.input) & name:
                for n in node.input:
                    if n in name:
                        r = ResultFound(
                            n,
                            NodeCoordinates(
                                node,
                                tuple(  # noqa: C409
                                    [*coordinates, (node_i, node.op_type, node.name)]
                                ),
                            ),
                            None,
                        )
                        if verbose > 1:
                            print(f"[enumerate_results] {indent}-- {r}")
                        yield r
            if node.op_type in {"If", "Scan", "Loop", "SequenceMap"}:
                for att in node.attribute:
                    if att.type == onnx.AttributeProto.GRAPH:
                        yield from enumerate_results(
                            att.g,
                            name,
                            verbose=verbose,
                            coordinates=[*coordinates, (node_i, node.op_type, node.name)],
                        )
            if set(node.output) & name:
                for n in node.output:
                    if n in name:
                        r = ResultFound(
                            n,
                            None,
                            NodeCoordinates(
                                node,
                                tuple(  # noqa: C409
                                    [*coordinates, (node_i, node.op_type, node.name)]
                                ),
                            ),
                        )
                        if verbose > 1:
                            print(f"[enumerate_results] {indent}-- {r}")
                        yield r
        if verbose:
            print(f"[enumerate_results] {indent}done")

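# A minimal sketch: locate where a result named "hidden_states" (hypothetical
# name) is used, including inside If/Loop/Scan subgraphs. ``model`` is any
# ModelProto.
def _example_enumerate_results(model: ModelProto):
    for found in enumerate_results(model, "hidden_states", verbose=1):
        print(found)
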
def shadowing_names(
    proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
    verbose: int = 0,
    existing: Optional[Set[str]] = None,
    shadow_context: Optional[Set[str]] = None,
    post_shadow_context: Optional[Set[str]] = None,
) -> Tuple[Set[str], Set[str], Set[str]]:
    """
    Returns the shadowing names, the names created in the main graph
    after they were created in a subgraph, and the names created by the nodes.
    """
    if isinstance(proto, ModelProto):
        return shadowing_names(proto.graph, verbose=verbose)
    if isinstance(proto, GraphProto):
        assert (
            existing is None and shadow_context is None
        ), "existing must be None if nodes is None"
        return shadowing_names(
            proto.node,
            verbose=verbose,
            existing=set(i.name for i in proto.initializer)
            | set(i.name for i in proto.sparse_initializer)
            | set(i.name for i in proto.input if i.name),
            shadow_context=set(),
            post_shadow_context=set(),
        )
    if isinstance(proto, FunctionProto):
        assert (
            existing is None and shadow_context is None
        ), "existing must be None if nodes is None"
        return shadowing_names(
            proto.node,
            verbose=verbose,
            existing=set(i for i in proto.input if i),
            shadow_context=set(),
            post_shadow_context=set(),
        )

    assert (
        existing is not None and shadow_context is not None
    ), "existing must not be None if nodes is not None"
    shadow = set()
    shadow_context = shadow_context.copy()
    existing = existing.copy()
    created = set()
    post_shadow = set()
    for node in proto:
        not_empty = set(n for n in node.input if n)
        intersection = not_empty & existing
        assert len(intersection) == len(not_empty), (
            f"One input in {not_empty}, node={pretty_onnx(node)} "
            f"was not found in {existing}"
        )
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                g = att.g
                shadow |= set(i.name for i in g.input) & shadow_context
                shadow |= set(i.name for i in g.initializer) & shadow_context
                shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
                s, _ps, c = shadowing_names(
                    g.node, verbose=verbose, existing=existing, shadow_context=existing
                )
                shadow |= s
                created |= c

        not_empty = set(n for n in node.output if n)
        post_shadow |= not_empty & created
        shadow |= not_empty & shadow_context
        existing |= not_empty
        created |= not_empty
    return shadow, post_shadow, created
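
# A minimal sketch: shadowing detection on a whole model. A name is shadowed
# when a subgraph redefines a result already defined in an enclosing scope.
def _example_shadowing_names(model: ModelProto):
    shadow, post_shadow, created = shadowing_names(model)
    print("shadowing names:", shadow)
    print("redefined after a subgraph created them:", post_shadow)
    print("number of results created by nodes:", len(created))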