onnx-diagnostic 0.8.0: onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/reference/ort_evaluator.py
@@ -0,0 +1,652 @@
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+ import numpy as np
+ from onnx import (
+     AttributeProto,
+     GraphProto,
+     FunctionProto,
+     ModelProto,
+     NodeProto,
+     TypeProto,
+     ValueInfoProto,
+     helper as oh,
+     load,
+     save as onnx_save,
+     shape_inference as shi,
+ )
+ from onnx.defs import onnx_opset_version
+ import onnxruntime
+ from ..helpers import string_type
+ from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
+ from ..helpers.ort_session import (
+     InferenceSessionForTorch,
+     InferenceSessionForNumpy,
+     _InferenceSession,
+ )
+ from ..helpers.torch_helper import to_tensor
+ from .report_results_comparison import ReportResultComparison
+ from .evaluator import ExtendedReferenceEvaluator
+
+
+ PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
+ Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
+
+
+ class OnnxruntimeEvaluator:
+     """
+     This class loads an ONNX model and then executes its nodes one by one
+     with onnxruntime. The class is mostly meant for debugging.
+
+     :param proto: proto or filename
+     :param session_options: session options
+     :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
+     :param nvtx: enables NVIDIA NVTX events
+     :param enable_profiling: enables onnxruntime profiling
+     :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
+     :param log_severity_level: see :class:`onnxruntime.SessionOptions`
+     :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
+     :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
+     :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
+     :param use_training_api: uses the onnxruntime-training API
+     :param verbose: verbosity level
+     :param local_functions: additional local functions
+     :param ir_version: ir version to use when unknown
+     :param opsets: opsets to use when unknown
+     :param whole: if True, does not split the model node by node
+     :param torch_or_numpy: forces the use of one of them, True for torch,
+         False for numpy, None to let the class choose
+     """
+
+     def __init__(
+         self,
+         proto: Union[str, Proto, "OnnxruntimeEvaluator"],
+         session_options: Optional[onnxruntime.SessionOptions] = None,
+         providers: Optional[Union[str, List[str]]] = None,
+         nvtx: bool = False,
+         enable_profiling: bool = False,
+         graph_optimization_level: Union[onnxruntime.GraphOptimizationLevel, bool] = None,
+         log_severity_level: Optional[int] = None,
+         log_verbosity_level: Optional[int] = None,
+         optimized_model_filepath: Optional[str] = None,
+         disable_aot_function_inlining: Optional[bool] = None,
+         use_training_api: bool = False,
+         verbose: int = 0,
+         local_functions: Optional[
+             Dict[Tuple[str, str], Union[Proto, "OnnxruntimeEvaluator"]]
+         ] = None,
+         ir_version: int = 10,
+         opsets: Optional[Union[int, Dict[str, int]]] = None,
+         whole: bool = False,
+         torch_or_numpy: Optional[bool] = None,
+     ):
+         if isinstance(proto, str):
+             self.proto: Proto = load(proto)
+         elif isinstance(proto, OnnxruntimeEvaluator):
+             assert isinstance(
+                 proto.proto, PROTO
+             ), f"Unexpected type for proto.proto {type(proto.proto)}"
+             self.proto = proto.proto
+         else:
+             self.proto = proto
+         assert isinstance(
+             self.proto, PROTO
+         ), f"Unexpected type for self.proto {type(self.proto)}"
+
+         self._cache: Dict[
+             Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]]  # noqa: UP037
+         ] = {}
+         self.ir_version = ir_version
+         self.opsets = opsets
+         self.session_kwargs: Dict[str, Any] = dict(
+             session_options=session_options,
+             providers=providers,
+             nvtx=nvtx,
+             enable_profiling=enable_profiling,
+             graph_optimization_level=graph_optimization_level,
+             log_severity_level=log_severity_level,
+             log_verbosity_level=log_verbosity_level,
+             optimized_model_filepath=optimized_model_filepath,
+             disable_aot_function_inlining=disable_aot_function_inlining,
+             use_training_api=use_training_api,
+         )
+         self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
+
+         self.verbose = verbose
+         self.torch_or_numpy = torch_or_numpy
+         self.sess_: Optional[_InferenceSession] = None
+         if whole:
+             self.nodes: Optional[List[NodeProto]] = None
+             self.rt_inits_: Optional[Dict[str, Any]] = None
+             self.rt_nodes_: Optional[List[NodeProto]] = None
+         else:
+             self.nodes = (
+                 [self.proto]
+                 if isinstance(self.proto, NodeProto)
+                 else (
+                     list(
+                         self.proto.graph.node
+                         if hasattr(self.proto, "graph")
+                         else self.proto.node
+                     )
+                 )
+             )
+             self.rt_inits_ = (
+                 {
+                     init.name: self.to_tensor_or_array(init)
+                     for init in self.proto.graph.initializer
+                 }
+                 if hasattr(self.proto, "graph")
+                 else {}
+             )
+             self.rt_nodes_ = self.nodes.copy()
+
+         self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = (  # noqa: UP037
+             {(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
+             if hasattr(self.proto, "functions")
+             else {}
+         )
+         if local_functions:
+             self.local_functions.update(local_functions)
+         self.garbage_collector = self._build_garbage_collector() if self.rt_nodes_ else {}
+
+     @property
+     def input_names(self) -> List[str]:
+         "Returns input names."
+         assert self.proto, "self.proto is empty"
+         if isinstance(self.proto, NodeProto):
+             assert isinstance(
+                 self.nodes, list
+             ), f"Unexpected type {type(self.nodes)} for self.nodes"
+             return self.nodes[0].input
+         return [
+             getattr(o, "name", o)
+             for o in (
+                 self.proto.graph.input if hasattr(self.proto, "graph") else self.proto.input
+             )
+         ]
+
+     @property
+     def output_names(self) -> List[str]:
+         "Returns output names."
+         assert self.proto, "self.proto is empty"
+         if isinstance(self.proto, NodeProto):
+             assert isinstance(
+                 self.nodes, list
+             ), f"Unexpected type {type(self.nodes)} for self.nodes"
+             return self.nodes[0].output
+         return [
+             getattr(o, "name", o)
+             for o in (
+                 self.proto.graph.output if hasattr(self.proto, "graph") else self.proto.output
+             )
+         ]
+
+     @property
+     def input_types(self) -> List[TypeProto]:
+         "Returns input types."
+         if not isinstance(self.proto, (ModelProto, GraphProto)):
+             raise ValueError(f"Cannot guess input types for type {type(self.proto)}")
+         g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
+         return [i.type for i in g.input]
+
+     @property
+     def output_types(self) -> List[TypeProto]:
+         "Returns output types."
+         if not isinstance(self.proto, (ModelProto, GraphProto)):
+             raise ValueError(f"Cannot guess output types for type {type(self.proto)}")
+         g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
+         return [i.type for i in g.output]
+
+     def _log_arg(self, a: Any) -> Any:
+         if isinstance(a, (str, int, float)):
+             return a
+         device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
+         if hasattr(a, "shape"):
+             prefix = "A:" if hasattr(a, "astype") else "T:"
+             if self.verbose < 4:  # noqa: PLR2004
+                 return f"{prefix}{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
+             elements = a.ravel().tolist()
+             if len(elements) > 10:  # noqa: PLR2004
+                 elements = elements[:10]
+                 return f"{prefix}{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
+             return f"{prefix}{device}{a.dtype}:{a.shape}:{elements}"
+         if hasattr(a, "append"):
+             return ", ".join(map(self._log_arg, a))
+         return a
+
+     def _log(self, level: int, pattern: str, *args: Any) -> None:
+         if level < self.verbose:
+             new_args = [self._log_arg(a) for a in args]
+             print(pattern % tuple(new_args))
+
+     def _is_local_function(self, node: NodeProto) -> bool:
+         return (node.domain, node.op_type) in self.local_functions
+
+     def run(
+         self,
+         outputs: Optional[List[str]],
+         feed_inputs: Dict[str, Any],
+         intermediate: bool = False,
+         report_cmp: Optional[ReportResultComparison] = None,
+     ) -> Union[Dict[str, Any], List[Any]]:
+         """
+         Runs the model.
+
+         :param outputs: required outputs or None for all
+         :param feed_inputs: inputs
+         :param intermediate: returns all intermediate results instead of the last ones
+         :param report_cmp: used as a reference, every intermediate result is
+             compared to every existing one; if not empty, it is an instance of
+             :class:`onnx_diagnostic.reference.ReportResultComparison`
+         :return: outputs, as a list if intermediate is False,
+             as a dictionary if intermediate is True
+         """
+         if self.rt_nodes_ is None:
+             # runs the model as a whole
+             if self.sess_ is None:
+                 assert self.proto, "self.proto is empty"
+                 _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
+             assert self.sess_, "mypy not happy"
+             return self.sess_.run(outputs, feed_inputs)
+         if outputs is None:
+             outputs = self.output_names
+         results: Dict[str, Any] = (self.rt_inits_ or {}).copy()
+
+         for k, v in results.items():
+             self._log(2, " +C %s: %s", k, v)
+         for k, v in feed_inputs.items():
+             assert not isinstance(v, str), f"Unexpected type str for {k!r}"
+             self._log(2, " +I %s: %s", k, v)
+             results[k] = v
+
+         for i_node, node in enumerate(self.rt_nodes_ or []):
+             self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
+             for i in node.input:
+                 if i != "" and i not in results:
+                     raise RuntimeError(
+                         f"Unable to find input {i!r} in known results {sorted(results)}, "
+                         f"self.rt_inits_ has {sorted((self.rt_inits_ or {}))}, "
+                         f"feed_inputs has {sorted(feed_inputs)}."
+                     )
+             inputs = [(results[i] if i != "" else None) for i in node.input]
+             # node_outputs avoids shadowing the `outputs` argument
+             if node.op_type == "If" and node.domain == "":
+                 node_outputs = self._run_if(node, inputs, results)
+             elif node.op_type in {"Scan", "Loop"} and node.domain == "":
+                 node_outputs = self._run_scan(node, inputs, results)
+             elif self._is_local_function(node):
+                 node_outputs = self._run_local(node, inputs, results)
+             else:
+                 node_outputs = self._run(node, inputs, results)
+             for name, value in zip(node.output, node_outputs):
+                 if name == "":
+                     continue
+                 self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
+                 assert isinstance(name, str), f"unexpected type for name {type(name)}"
+                 results[name] = value
+             if report_cmp:
+                 reported = report_cmp.report(dict(zip(node.output, node_outputs)))
+                 if self.verbose > 1:
+                     print(f" -- report {len(reported)} comparisons")
+             if not intermediate:
+                 self._clean_unused_inplace(i_node, node, results)
+
+         if intermediate:
+             return results
+         output_names = self.output_names
+         for name in output_names:
+             if name == "":
+                 continue
+             if name not in results:
+                 raise RuntimeError(
+                     f"Unable to find output name {name!r} "
+                     f"in {sorted(results)}, proto is\n{pretty_onnx(self.proto)}"
+                 )
+         return [results[name] for name in output_names if name != ""]
+
+     def _build_garbage_collector(self) -> Dict[str, int]:
+         """
+         Memorizes, for every node, which results are no longer needed afterwards.
+         Returns a dictionary mapping every result name to the index of the last
+         node using it.
+         """
+         needed = {}
+         for i, node in enumerate(self.rt_nodes_ or []):
+             for name in node.input:
+                 needed[name] = i
+             if node.op_type in {"Scan", "If", "Loop"}:
+                 hidden = self._get_hidden_node_inputs(node)
+                 for name in hidden:
+                     needed[name] = i
+         if isinstance(self.proto, ModelProto):
+             for o in self.proto.graph.output:
+                 needed[o.name] = len(self.rt_nodes_ or [])
+         elif isinstance(self.proto, GraphProto):
+             for o in self.proto.output:
+                 needed[o.name] = len(self.rt_nodes_ or [])
+         elif isinstance(self.proto, FunctionProto):
+             for o in self.proto.output:
+                 needed[o] = len(self.rt_nodes_ or [])
+         return needed
+
+     def _clean_unused_inplace(self, i_node: int, node: NodeProto, results: Dict[str, Any]):
+         """
+         Cleans all results which are no longer needed. Some models require
+         freeing memory to be able to run.
+         """
+         if not self.garbage_collector:
+             return
+         for name in node.input:
+             if self.garbage_collector[name] == i_node and name in results:
+                 if self.verbose:
+                     t = results[name]
+                     print(f" - deletes: {name} - {t.dtype}:{t.shape}")
+                 del results[name]
+         if node.op_type in {"Scan", "If", "Loop"}:
+             hidden = self._get_hidden_node_inputs(node)
+             for name in hidden:
+                 if self.garbage_collector[name] == i_node and name in results:
+                     if self.verbose:
+                         t = results[name]
+                         print(f" - deletes: {name} - {t.dtype}:{t.shape}")
+                     del results[name]
+
+     def _make_model_proto(
+         self,
+         nodes: Sequence[NodeProto],
+         vinputs: Sequence[ValueInfoProto],
+         voutputs: Sequence[ValueInfoProto],
+     ) -> ModelProto:
+         onx = oh.make_model(
+             oh.make_graph(nodes, "-", vinputs, voutputs),
+             ir_version=getattr(self.proto, "ir_version", self.ir_version),
+             functions=getattr(self.proto, "functions", None),
+         )
+         del onx.opset_import[:]
+         if hasattr(self.proto, "opset_import"):
+             onx.opset_import.extend(self.proto.opset_import)
+         elif self.opsets:
+             if isinstance(self.opsets, int):
+                 onx.opset_import.append(oh.make_opsetid("", self.opsets))
+             else:
+                 onx.opset_import.extend(
+                     [oh.make_opsetid(k, v) for k, v in self.opsets.items()]
+                 )
+         else:
+             onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
+
+         # Running shape inference here helps catching bugs earlier.
+         onx = shi.infer_shapes(onx)
+         return onx
+
+     @classmethod
+     def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
+         """
+         Returns the hidden inputs (inputs coming from an upper context)
+         used by a subgraph.
+         """
+         hidden = set()
+         memo = set(i.name for i in graph.initializer)
+         memo |= set(i.name for i in graph.sparse_initializer)
+         for node in graph.node:
+             for i in node.input:
+                 if i not in memo:
+                     hidden.add(i)
+             for att in node.attribute:
+                 if att.type == AttributeProto.GRAPH and att.g:
+                     hid = cls._get_hidden_inputs(att.g)
+                     less = set(h for h in hid if h not in memo)
+                     hidden |= less
+             memo |= set(node.output)
+         return hidden
+
+     @classmethod
+     def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
+         """Calls _get_hidden_inputs on every graph attribute of a node."""
+         if node.op_type not in {"Loop", "Scan", "If"}:
+             return set()
+         hidden = set()
+         for att in node.attribute:
+             if att.type == AttributeProto.GRAPH:
+                 hidden |= cls._get_hidden_inputs(att.g)
+         return hidden - (hidden & set(node.input))
+
+     def _get_sess(
+         self, node: Union[ModelProto, NodeProto], inputs: List[Any]
+     ) -> Tuple[ModelProto, _InferenceSession]:
+         if isinstance(node, ModelProto):
+             onx = node
+         else:
+             assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
+             if node.op_type == "Constant":
+                 # Evaluates the constant to retrieve its exact type and shape.
+                 ref = ExtendedReferenceEvaluator(node)
+                 cst = ref.run(None, {})[0]
+                 vinputs: List[ValueInfoProto] = []
+                 voutputs = [
+                     oh.make_tensor_value_info(
+                         node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
+                     )
+                 ]
+             else:
+                 unique_names = set()
+                 vinputs = []
+                 for i, it in zip(node.input, inputs):
+                     if i == "" or i in unique_names:
+                         continue
+                     unique_names.add(i)
+                     value = oh.make_tensor_value_info(
+                         i, dtype_to_tensor_dtype(it.dtype), it.shape
+                     )
+                     vinputs.append(value)
+
+                 # no need to run shape inference
+                 voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
+
+             onx = self._make_model_proto([node], vinputs, voutputs)
+
+         cls = (
+             InferenceSessionForNumpy
+             if any(isinstance(i, np.ndarray) for i in inputs)
+             and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
+             else InferenceSessionForTorch
+         )
+         try:
+             sess = cls(onx, **self.session_kwargs)
+         except (
+             onnxruntime.capi.onnxruntime_pybind11_state.Fail,
+             onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
+             onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
+         ) as e:
+             onnx_save(onx, "_debug_OnnxruntimeEvaluator_last_failure.onnx")
+             raise RuntimeError(
+                 f"Unable to create a session with inputs\n{string_type(inputs)}"
+                 f"\ndue to {e}\n{pretty_onnx(onx)}"
+             ) from e
+         return onx, sess
+
+     def _get_sess_init_subgraph(
+         self, node: NodeProto, inputs: List[Any], context: Dict[str, Any], g: GraphProto
+     ) -> List[Any]:
+         unique_names = set()
+         vinputs = []
+         for i, it in zip(node.input, inputs):
+             if i == "" or i in unique_names:
+                 continue
+             unique_names.add(i)
+             value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
+             vinputs.append(value)
+
+         reduced_set = self._get_hidden_inputs(g)
+         for i, v in context.items():
+             if i in reduced_set and i not in unique_names:
+                 unique_names.add(i)
+                 value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
+                 vinputs.append(value)
+         return vinputs
+
+     def _get_sess_if(
+         self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
+     ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
+         g = None
+         for att in node.attribute:
+             if att.name == branch:
+                 g = att.g
+         assert g, f"Missing attribute {branch!r}"
+         vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
+
+         voutputs = g.output
+
+         identities = [
+             oh.make_node("Identity", [iname], [ginput.name])
+             for iname, ginput in zip(node.input, g.input)
+         ]
+
+         onx = self._make_model_proto([*identities, *g.node], vinputs, voutputs)
+         sess = OnnxruntimeEvaluator(
+             onx,
+             local_functions=self.local_functions,
+             verbose=self.verbose,
+             ir_version=self.ir_version,
+             opsets=self.opsets,
+             torch_or_numpy=self.torch_or_numpy,
+             **self.session_kwargs,
+         )
+         return onx, sess
+
+     def _get_sess_local(
+         self, node: NodeProto, inputs: List[Any]
+     ) -> Tuple[FunctionProto, "OnnxruntimeEvaluator"]:
+         ev = self.local_functions[node.domain, node.op_type]
+         sess = OnnxruntimeEvaluator(
+             ev,
+             local_functions=self.local_functions,
+             verbose=self.verbose,
+             ir_version=self.ir_version,
+             opsets=self.opsets,
+             torch_or_numpy=self.torch_or_numpy,
+             **self.session_kwargs,
+         )
+         return ev.proto, sess
+
+     def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
+         """Runs a node."""
+         types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
+         key = (id(node), *types)
+         if key in self._cache:
+             sess = self._cache[key][1]
+         else:
+             onx, sess = self._get_sess(node, inputs)
+             self._cache[key] = onx, sess
+
+         feeds = dict(zip(node.input, inputs))
+         if "" in feeds:
+             feeds[""] = np.array([0], dtype=np.float32)
+
+         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+         outputs = list(sess.run(None, feeds))
+         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
+         return outputs
+
+     def _run_if(
+         self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
+     ) -> List[Any]:
+         """Runs an If node."""
+         feeds = dict(zip(node.input, inputs))
+         feeds.update(results)
+         if feeds[node.input[0]]:
+             name = "then_branch"
+         else:
+             name = "else_branch"
+
+         key = (id(node), name)
+         if key in self._cache:
+             sess = self._cache[key][1]
+         else:
+             self._cache[key] = _onx, sess = self._get_sess_if(node, name, inputs, results)
+
+         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+         feeds = {name: results[name] for name in sess.input_names}
+         outputs = sess.run(None, feeds)
+         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
+         return outputs
+
+     def _get_sess_scan(
+         self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
+     ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
+         g = None
+         for att in node.attribute:
+             if att.name == branch:
+                 g = att.g
+         assert g, f"Missing attribute {branch!r}"
+         vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
+
+         begin = 0 if node.op_type == "Scan" else 1
+         voutputs = []
+         for name, _goutput in zip(node.output, g.output[begin:]):
+             v = ValueInfoProto()
+             # v.ParseFromString(goutput.SerializeToString())
+             v.name = name
+             voutputs.append(v)
+
+         # identities = []
+         # for iname, ginput in zip(node.input, g.input):
+         #     identities.append(oh.make_node("Identity", [iname], [ginput.name]))
+
+         onx = self._make_model_proto([node], vinputs, voutputs)
+         sess = OnnxruntimeEvaluator(
+             onx,
+             local_functions=self.local_functions,
+             verbose=self.verbose,
+             ir_version=self.ir_version,
+             opsets=self.opsets,
+             torch_or_numpy=self.torch_or_numpy,
+             whole=True,
+             **self.session_kwargs,
+         )
+         return onx, sess
+
+     def _run_scan(
+         self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
+     ) -> List[Any]:
+         """Runs a Scan or Loop node."""
+         feeds = dict(zip(node.input, inputs))
+         feeds.update(results)
+         name = "body"
+         key = (id(node), name)
+         if key in self._cache:
+             sess = self._cache[key][1]
+         else:
+             self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)
+
+         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+         feeds = {name: results[name] for name in sess.input_names}
+         outputs = sess.run(None, feeds)
+         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
+         return outputs
+
+     def _run_local(
+         self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
+     ) -> List[Any]:
+         """Runs a node calling a local function."""
+         types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
+         key = (id(node), *types)
+         if key in self._cache:
+             sess = self._cache[key][1]
+         else:
+             onx, sess = self._get_sess_local(node, inputs)
+             self._cache[key] = onx, sess
+
+         replace = dict(zip(node.input, sess.input_names))
+         assert len(node.input) == len(sess.input_names), (
+             f"Input mismatch: input_names={sess.input_names}, "
+             f"replace={replace}, "
+             f"type(self.proto)={type(self.proto)}, and node=\n{node}"
+         )
+         feeds = {replace[i]: v for i, v in zip(node.input, inputs)}
+         if "" in feeds:
+             feeds[""] = np.array([0], dtype=np.float32)
+
+         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+         outputs = sess.run(None, feeds)
+         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
+         return outputs
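
The diff above defines the whole node-by-node evaluator. As a quick orientation, the sketch below shows how it might be driven: it builds a tiny two-node model with onnx.helper (the model and feeds are illustrative, not part of the package) and imports OnnxruntimeEvaluator from the module shown above. With intermediate=True, run returns a dictionary holding every intermediate result instead of a list of final outputs.

    # Hedged usage sketch; only the OnnxruntimeEvaluator API comes from the
    # diff above, the toy model below is an assumption made for illustration.
    import numpy as np
    import onnx.helper as oh
    from onnx import TensorProto
    from onnx_diagnostic.reference.ort_evaluator import OnnxruntimeEvaluator

    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Mul", ["X", "X"], ["sq"]),
                oh.make_node("Add", ["sq", "X"], ["Y"]),
            ],
            "demo",
            [oh.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2])],
            [oh.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    ev = OnnxruntimeEvaluator(model, providers="CPU", verbose=1)
    feeds = {"X": np.ones((2, 2), dtype=np.float32)}
    # Each node (Mul, then Add) is compiled into a one-node model and run
    # with onnxruntime; intermediate=True keeps every result, "sq" included.
    results = ev.run(None, feeds, intermediate=True)
    print(results["sq"], results["Y"])

Since every node becomes its own onnxruntime session, a failure or a numerical mismatch can be pinned to a single operator, which is the debugging scenario the class docstring describes.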
onnx_diagnostic/reference/quantized_tensor.py
@@ -0,0 +1,46 @@
+ import numpy as np
+
+
+ class QuantizedTensor:
+     """
+     Quantizes a tensor into the range [0, 255] (uint8, asymmetric quantization).
+
+     :param tensor: original tensor
+     """
+
+     def __init__(self, tensor):
+         _min = tensor.min()
+         _max = tensor.max()
+         _min = min(_min, 0)
+         _max = max(_max, 0)
+         qmin = 0
+         qmax = 255
+
+         self.scale_ = np.array((_max - _min) / (qmax - qmin), dtype=tensor.dtype)
+         initial_zero_point = qmin - _min / self.scale_
+         self.zero_point_ = np.array(
+             int(max(qmin, min(qmax, initial_zero_point))), dtype=np.uint8
+         )
+         self.quantized_ = np.maximum(
+             0, np.minimum(qmax, (tensor / self.scale_).astype(int) + self.zero_point_)
+         ).astype(self.zero_point_.dtype)
+
+     @property
+     def shape(self):
+         "accessor"
+         return self.quantized_.shape
+
+     @property
+     def scale(self):
+         "accessor"
+         return self.scale_
+
+     @property
+     def zero_point(self):
+         "accessor"
+         return self.zero_point_
+
+     @property
+     def qtensor(self):
+         "accessor"
+         return self.quantized_
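
QuantizedTensor is plain asymmetric uint8 quantization: the value range is first widened to include 0, then scale = (max - min) / 255 and zero_point = -min / scale clipped to [0, 255]. Below is a minimal round-trip sketch; the dequantization line is written by hand from the standard formula x ~ scale * (q - zero_point) and is not a method of the class.

    # Hedged sketch; QuantizedTensor and its module path come from the diff
    # above, the dequantization step is the standard inverse, assumed here.
    import numpy as np
    from onnx_diagnostic.reference.quantized_tensor import QuantizedTensor

    t = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
    qt = QuantizedTensor(t)
    print(qt.scale, qt.zero_point)  # one scale and one zero point per tensor
    print(qt.qtensor)               # uint8 values, here [0, 85, 127, 255]

    # Dequantize by hand to measure the quantization error.
    restored = qt.scale * (qt.qtensor.astype(np.float32) - qt.zero_point.astype(np.float32))
    print(np.abs(restored - t).max())  # small, bounded by about one scale step

Note that the quantized values are computed with a truncating astype(int) rather than a rounding step, so the error can approach a full scale step instead of half of one.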