onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/ort_session.py
@@ -0,0 +1,736 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import onnx
import numpy as np
import numpy.typing as npt
import torch
from torch._C import _from_dlpack
import onnxruntime
from onnxruntime.capi import _pybind_state as ORTC
from .helper import size_type
from .onnx_helper import (
    onnx_dtype_to_np_dtype,
    np_dtype_to_tensor_dtype,
    onnx_dtype_name,
)
from .torch_helper import torch_dtype_to_onnx_dtype


DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}

class _InferenceSession:

    @classmethod
    def has_onnxruntime_training(cls):
        """Tells if onnxruntime_training is installed."""
        try:
            from onnxruntime import training
        except ImportError:
            # onnxruntime is not the training build
            training = None
        if training is None:
            return False

        try:
            from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
        except ImportError:
            return False

        if not hasattr(OrtValueVector, "push_back_batch"):
            return False
        return True

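    # Note (an assumption, not part of the package): this check gates the
    # OrtValueVector fast path used by run_training_api in
    # InferenceSessionForTorch; a standard `pip install onnxruntime` build
    # is expected to make it return False.
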
    def __init__(
        self,
        sess: Union[onnx.ModelProto, str, onnxruntime.InferenceSession],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[Any]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: Optional[bool] = None,
    ):
        # onnxruntime is imported when needed as the import takes
        # a couple of seconds when it includes the CUDA EP.
        can_use_training_api = True
        if isinstance(sess, (onnx.ModelProto, str)):
            if isinstance(sess, onnx.ModelProto):
                for i in sess.graph.initializer:
                    if i.data_type >= onnx.TensorProto.BFLOAT16:
                        # Cannot use the training API as it relies too much on numpy.
                        can_use_training_api = False
                        break
            assert session_options is None or (
                providers is None
                and graph_optimization_level is None
                and log_severity_level is None
                and log_verbosity_level is None
            ), "session_options is already defined, no other option can be overwritten."
            if session_options is None:
                session_options = onnxruntime.SessionOptions()
                if enable_profiling:
                    session_options.enable_profiling = enable_profiling
                if optimized_model_filepath:
                    session_options.optimized_model_filepath = optimized_model_filepath
                if log_severity_level is not None:
                    session_options.log_severity_level = log_severity_level
                if log_verbosity_level is not None:
                    session_options.log_verbosity_level = log_verbosity_level
                if graph_optimization_level is not None:
                    if isinstance(graph_optimization_level, bool):
                        session_options.graph_optimization_level = (
                            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
                            if graph_optimization_level
                            else onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
                        )
                    else:
                        session_options.graph_optimization_level = graph_optimization_level
                if disable_aot_function_inlining:
                    session_options.add_session_config_entry(
                        "session.disable_aot_function_inlining", "1"
                    )
            if providers is None:
                providers = ["CPUExecutionProvider"]
            if isinstance(providers, str):
                if providers.lower() == "cpu":
                    providers = ["CPUExecutionProvider"]
                elif providers.lower() == "cuda":
                    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
                else:
                    raise ValueError(f"Unexpected value for providers={providers!r}")
            try:
                sess = onnxruntime.InferenceSession(
                    sess if isinstance(sess, str) else sess.SerializeToString(),
                    session_options,
                    providers=providers,
                )
            except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
                if isinstance(sess, onnx.ModelProto):
                    debug_path = "_debug_InferenceSession_last_failure.onnx"
                    onnx.save(
                        sess,
                        debug_path,
                        save_as_external_data=True,
                        all_tensors_to_one_file=True,
                    )
                else:
                    debug_path = sess
                raise RuntimeError(
                    f"Unable to create a session stored in {debug_path!r}, "
                    f"providers={providers}"
                ) from e
        else:
            assert (
                session_options is None
                and providers is None
                and graph_optimization_level is None
                and log_severity_level is None
                and log_verbosity_level is None
            ), f"First input is {type(sess)}, no other option can be overwritten."

        self.sess = sess
        self.input_names = [i.name for i in sess.get_inputs()]
        self.output_names = [i.name for i in sess.get_outputs()]
        self.input_shapes = [i.shape for i in sess.get_inputs()]
        self.output_shapes = [i.shape for i in sess.get_outputs()]
        self.input_types = [i.type for i in sess.get_inputs()]
        self.output_types = [i.type for i in sess.get_outputs()]
        self.torch = torch
        self.nvtx = nvtx
        self.run_options = onnxruntime.RunOptions()

        if log_severity_level is not None:
            self.run_options.log_severity_level = log_severity_level
        if log_verbosity_level is not None:
            self.run_options.log_verbosity_level = log_verbosity_level

        self.use_training_api = can_use_training_api and (
            self.has_onnxruntime_training() if use_training_api is None else use_training_api
        )

        if torch.cuda.device_count() > 0:
            for i in range(torch.cuda.device_count()):
                DEVICES[i] = ORTC.OrtDevice(
                    ORTC.OrtDevice.cuda(), ORTC.OrtDevice.default_memory(), i
                )

        self._torch_from_dlpack = _from_dlpack
        self.sess_bool_outputs = [i.type == "tensor(bool)" for i in sess.get_outputs()]

    def run(
        self,
        output_names: Optional[List[str]],
        feeds: Union[Dict[str, np.ndarray], Dict[str, ORTC.OrtValue]],
    ) -> Union[List[np.ndarray], List[ORTC.OrtValue]]:
        """Calls :meth:`onnxruntime.InferenceSession.run`."""
        if any(isinstance(t, np.ndarray) for t in feeds.values()):
            return self.sess.run(output_names, feeds)
        ort_outputs = self.sess._sess.run_with_ort_values(
            feeds, output_names or self.output_names, self.run_options
        )
        return self._post_process_inplace(ort_outputs)

    def _post_process_inplace(self, outputs):
        for i in range(len(outputs)):
            o = outputs[i]
            if self.sess_bool_outputs[i]:
                if isinstance(o, np.ndarray):
                    if o.dtype != np.bool_:
                        outputs[i] = o.astype(np.bool_)
                else:
                    if o.dtype != torch.bool:
                        outputs[i] = o.to(torch.bool)
        return outputs


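# Note (an assumption based on the comments below): the DLPack exchange used
# by onnxruntime here has no boolean dtype, so boolean feeds are routed
# through numpy and boolean outputs are cast back by _post_process_inplace:
#
#     sess = InferenceSessionForTorch("mask.onnx")  # hypothetical model
#     (mask,) = sess.run(None, {"x": torch.randn(2, 3)})
#     mask.dtype  # torch.bool, restored by _post_process_inplace

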
class InferenceSessionForNumpy(_InferenceSession):
    """
    Wraps an `onnxruntime.InferenceSession` to overload method `run`
    to support :class:`numpy.ndarray`.

    :param sess: model or inference session
    :param session_options: options
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable NVIDIA NVTX events
    :param enable_profiling: see :class:`onnxruntime.SessionOptions`
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use the onnxruntime-training API
    """

    def __init__(
        self,
        sess: Union[onnx.ModelProto, str, onnxruntime.InferenceSession],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[str]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: Optional[bool] = None,
    ):
        super().__init__(
            sess,
            session_options=session_options,
            providers=providers,
            nvtx=nvtx,
            enable_profiling=enable_profiling,
            graph_optimization_level=graph_optimization_level,
            log_severity_level=log_severity_level,
            log_verbosity_level=log_verbosity_level,
            optimized_model_filepath=optimized_model_filepath,
            disable_aot_function_inlining=disable_aot_function_inlining,
            use_training_api=use_training_api,
        )

    def run(
        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
    ) -> List[Optional[npt.ArrayLike]]:
        """Calls :meth:`onnxruntime.InferenceSession.run`."""
        # sess.run does not support bfloat16
        # res = self.sess.run(output_names, feeds)
        return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

    def run_dlpack(
        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
    ) -> Tuple[Optional[npt.ArrayLike], ...]:
        """
        Same as :meth:`onnxruntime.InferenceSession.run` except that
        feeds is a dictionary of :class:`numpy.ndarray`.
        The output device is CPU even if the outputs are on CUDA.
        """
        memory = []
        new_feeds = {}
        for k, v in feeds.items():
            if not k:
                continue
            if isinstance(v, np.ndarray):
                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                    v, np_dtype_to_tensor_dtype(v.dtype)
                )
            elif v.dtype == torch.bool:
                vi = v.detach().cpu().numpy()
                memory.append(vi)
                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                    vi, onnx.TensorProto.BOOL
                )
            else:
                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("run_with_ort_values")
        ort_outputs = self.sess._sess.run_with_ort_values(
            new_feeds, output_names or self.output_names, self.run_options
        )
        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        pth_outputs = self._ortvalues_to_numpy_tensor(ort_outputs)
        return pth_outputs

    def _ortvalues_to_numpy_tensor(
        self,
        ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
    ) -> Tuple[Optional[npt.ArrayLike], ...]:
        if len(ortvalues) == 0:
            return tuple()

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
        res: List[Optional[npt.ArrayLike]] = []
        for i in range(len(ortvalues)):
            if not ortvalues[i].has_value():
                res.append(None)
                continue

            el_type = ortvalues[i].element_type()
            if el_type < onnx.TensorProto.BFLOAT16:
                try:
                    a = np.from_dlpack(ortvalues[i])
                except RuntimeError as e:
                    assert "ORT only supports contiguous tensor for now." in str(e), (
                        f"As it says, non-contiguous OrtValues are not supported "
                        f"through DLPack, i={i}, the error is different: {e}"
                    )
                    # We make a copy in that case.
                    a = ortvalues[i].numpy()
                res.append(a)
                continue

            # no easy conversion, let's use torch
            tch = torch.from_dlpack(ortvalues[i].to_dlpack())
            size = size_type(el_type)
            assert size == 2, f"Not implemented for type {onnx_dtype_name(el_type)}"
            arr = tch.view(torch.uint16).numpy()

            dtype = onnx_dtype_to_np_dtype(el_type)
            res.append(arr.view(dtype))

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        return tuple(res)


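# Usage sketch (an assumption, not part of the package): feeding numpy
# arrays, with a hypothetical model file "model.onnx" taking one float
# input named "x":
#
#     sess = InferenceSessionForNumpy("model.onnx", providers="cpu")
#     outputs = sess.run(None, {"x": np.random.rand(2, 3).astype(np.float32)})

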
class InferenceSessionForTorch(_InferenceSession):
    """
    Wraps an `onnxruntime.InferenceSession` to overload method `run`
    to support :class:`torch.Tensor`.

    :param sess: model or inference session
    :param session_options: options
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable NVIDIA NVTX events
    :param enable_profiling: see :class:`onnxruntime.SessionOptions`
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use the onnxruntime-training API
    """

    def __init__(
        self,
        sess: Union[onnx.ModelProto, str, onnxruntime.InferenceSession],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[str]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: Optional[bool] = None,
    ):
        super().__init__(
            sess,
            session_options=session_options,
            providers=providers,
            nvtx=nvtx,
            enable_profiling=enable_profiling,
            graph_optimization_level=graph_optimization_level,
            log_severity_level=log_severity_level,
            log_verbosity_level=log_verbosity_level,
            optimized_model_filepath=optimized_model_filepath,
            disable_aot_function_inlining=disable_aot_function_inlining,
            use_training_api=use_training_api,
        )

    def _get_ortvalues_from_torch_tensors(
        self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
    ) -> Tuple[ORTC.OrtValueVector, List[onnxruntime.OrtDevice]]:
        assert tensors is not None, "tensors cannot be None"
        ortvalues = ORTC.OrtValueVector()
        ortvalues.reserve(len(tensors))
        dtypes = []
        shapes = []
        data_ptrs = []
        devices = []

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("_get_ortvalues_from_torch_tensors.1")
        max_device = -1
        new_tensors = []
        for tensor in tensors:
            assert isinstance(tensor, self.torch.Tensor), f"Unexpected type {type(tensor)}"
            dtypes.append(onnx_dtype_to_np_dtype(torch_dtype_to_onnx_dtype(tensor.dtype)))
            shapes.append(tensor.size())
            data_ptrs.append(tensor.data_ptr())
            d = tensor.get_device()
            devices.append(DEVICES[d])
            new_tensors.append(tensor)
            max_device = max(max_device, d)

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
            self.torch.cuda.nvtx.range_push("_get_ortvalues_from_torch_tensors.2")

        assert isinstance(max_device, int), f"unexpected type for device={max_device!r}"
        ortvalues.push_back_batch(new_tensors, data_ptrs, dtypes, shapes, devices)
        output_devices = []
        for _ in range(n_outputs):
            dev = DEVICES[max_device]
            output_devices.append(dev)

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        return ortvalues, output_devices

    def _ortvalues_to_torch_tensor(
        self,
        ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
    ) -> Tuple[torch.Tensor, ...]:
        if len(ortvalues) == 0:
            return tuple()

        if all(ortvalues[i].has_value() for i in range(len(ortvalues))):
            if self.nvtx:
                self.torch.cuda.nvtx.range_push("_ortvalues_to_torch_tensor.1")
            res = ortvalues.to_dlpacks(_from_dlpack)
            if self.nvtx:
                self.torch.cuda.nvtx.range_pop()
        else:
            if self.nvtx:
                self.torch.cuda.nvtx.range_push("_ortvalues_to_torch_tensor.2")
            res = []
            for i in range(len(ortvalues)):
                res.append(
                    self._torch_from_dlpack(ortvalues[i].to_dlpack())
                    if ortvalues[i].has_value()
                    else None
                )
            if self.nvtx:
                self.torch.cuda.nvtx.range_pop()
        return tuple(res)

    def run(  # type: ignore
        self, output_names: Optional[List[str]], feeds: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, ...]:
        """
        Same as :meth:`onnxruntime.InferenceSession.run` except that
        feeds is a dictionary of :class:`torch.Tensor`.
        """
        if self.use_training_api:
            inputs = [feeds[i] for i in self.input_names]
            return self.run_training_api(*inputs, output_names=output_names)
        return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

    def run_training_api(
        self, *inputs, output_names: Optional[List[str]] = None
    ) -> Tuple[torch.Tensor, ...]:
        """
        Calls the former training API, now also implemented in onnxruntime.

        :param inputs: list of :class:`torch.Tensor`
        :param output_names: requested outputs or None for all
        :return: tuple of :class:`torch.Tensor`
        """
        if output_names is None:
            output_names = self.output_names
        ortvalues, output_devices = self._get_ortvalues_from_torch_tensors(
            inputs, len(output_names)
        )

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")

        ort_outputs = ORTC.OrtValueVector()
        self.sess.run_with_ortvaluevector(
            self.run_options,
            self.input_names,
            ortvalues,
            output_names,
            ort_outputs,
            output_devices,
        )

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()

        pth_outputs = self._ortvalues_to_torch_tensor(ort_outputs)
        return pth_outputs

    def run_dlpack(
        self, output_names: Optional[List[str]], feeds: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, ...]:
        """
        Same as :meth:`onnxruntime.InferenceSession.run` except that
        feeds is a dictionary of :class:`torch.Tensor`.
        The output device is CPU even if the outputs are on CUDA.
        """
        new_feeds = {}
        for k, v in feeds.items():
            assert hasattr(v, "__dlpack__"), f"unexpected class {type(v)}, no __dlpack__"
            if not v.is_contiguous():
                v = v.contiguous()
            if v.dtype == torch.bool:
                # Boolean tensors do not work through dlpack
                # unless onnxruntime updates the DLPack version it is using.
                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                    v.detach().cpu().numpy(), onnx.TensorProto.BOOL
                )
            else:
                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
        if self.nvtx:
            self.torch.cuda.nvtx.range_push("run_with_ort_values")
        ort_outputs = self.sess._sess.run_with_ort_values(
            new_feeds, output_names or self.output_names, self.run_options
        )
        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        pth_outputs = self._ortvalues_to_torch_tensor(ort_outputs)
        return pth_outputs


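# Usage sketch (an assumption, not part of the package): feeding torch
# tensors directly, with a hypothetical model file "model.onnx" taking one
# float input named "x":
#
#     sess = InferenceSessionForTorch("model.onnx", providers="cpu")
#     outputs = sess.run(None, {"x": torch.randn(2, 3)})
#     # outputs is a tuple/list of torch.Tensor, booleans cast to torch.bool

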
def investigate_onnxruntime_issue(
    proto: Union[onnx.ModelProto, str],
    session_options: Optional[onnxruntime.SessionOptions] = None,
    providers: Optional[Union[str, List[str]]] = None,
    nvtx: bool = False,
    enable_profiling: bool = False,
    graph_optimization_level: Optional[
        Union[onnxruntime.GraphOptimizationLevel, bool]
    ] = None,
    log_severity_level: Optional[int] = None,
    log_verbosity_level: Optional[int] = None,
    optimized_model_filepath: Optional[str] = None,
    disable_aot_function_inlining: Optional[bool] = None,
    use_training_api: Optional[bool] = None,
    onnx_to_session: Optional[
        Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
    ] = None,
    # if the model needs to be run.
    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, npt.ArrayLike]]] = None,
    verbose: int = 0,
    dump_filename: Optional[str] = None,
    infer_shapes: bool = True,
    quiet: bool = False,
):
    """
    Investigates a crashing model: the graph is rebuilt node by node,
    and each partial model is given to onnxruntime until one of them
    reproduces the failure.

    :param proto: model or inference session
    :param session_options: options
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable NVIDIA NVTX events
    :param enable_profiling: see :class:`onnxruntime.SessionOptions`
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use the onnxruntime-training API
    :param onnx_to_session: function to load a model into an inference session when
        the automated way implemented in this function is not enough;
        if it is equal to ``"cpu_session"``, the callable becomes:
        ``lambda model: onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"])``
    :param feeds: if not None, every partial model is run with these inputs as well
    :param verbose: verbosity level
    :param dump_filename: if not None, the function dumps the last model run
    :param infer_shapes: run shape inference
    :param quiet: if True, catches the exception and returns the failing node,
        otherwise lets the exception propagate

    The most simple use:

    .. code-block:: python

        investigate_onnxruntime_issue(
            model,
            feeds=feeds,
            verbose=10,
            dump_filename="test_investigate_onnxruntime_issue_callable.onnx",
            onnx_to_session="cpu_session",
        )

    Full example:

    .. runpython::
        :showcode:

        import numpy as np
        import onnx
        import onnx.helper as oh
        from onnx_diagnostic.helpers.ort_session import investigate_onnxruntime_issue

        TFLOAT = onnx.TensorProto.FLOAT
        model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("Add", ["x", "y"], ["gggg"]),
                    oh.make_node("Add", ["gggg", "z"], ["final"]),
                ],
                "dummy",
                [
                    oh.make_tensor_value_info("x", TFLOAT, [None, None]),
                    oh.make_tensor_value_info("y", TFLOAT, [None, None]),
                    oh.make_tensor_value_info("z", TFLOAT, [None, None]),
                ],
                [oh.make_tensor_value_info("final", TFLOAT, [None, None])],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )
        onnx.checker.check_model(model)
        feeds = {
            "x": np.random.rand(5, 6).astype(np.float32),
            "y": np.random.rand(5, 6).astype(np.float32),
            "z": np.random.rand(5, 6).astype(np.float32),
        }
        investigate_onnxruntime_issue(
            model,
            feeds=feeds,
            verbose=1,
            graph_optimization_level=False,
            dump_filename="last_issue.onnx",
        )
    """
    onx = (
        proto
        if isinstance(proto, onnx.ModelProto)
        else onnx.load(proto, load_external_data=False)
    )
    input_names = [i.name for i in onx.graph.input]
    if verbose:
        print(
            f"[investigate_onnxruntime_issue] found "
            f"{len(onx.graph.node)} nodes and {len(input_names)} inputs"
        )
    if infer_shapes:
        if verbose:
            print("[investigate_onnxruntime_issue] run shape inference")
        onx = onnx.shape_inference.infer_shapes(onx)

    if isinstance(onnx_to_session, str):
        if onnx_to_session == "cpu_session":
            import onnxruntime

            onnx_to_session = lambda model: onnxruntime.InferenceSession(  # noqa: E731
                model.SerializeToString(), providers=["CPUExecutionProvider"]
            )
        else:
            raise ValueError(f"Unexpected value onnx_to_session={onnx_to_session!r}")
    else:
        cls = (
            InferenceSessionForNumpy
            if feeds is None or any(isinstance(v, np.ndarray) for v in feeds.values())
            else InferenceSessionForTorch
        )
        if verbose and not onnx_to_session:
            print(f"[investigate_onnxruntime_issue] cls={cls}")

    for i in range(len(onx.graph.node)):
        node = onx.graph.node[i]
        if verbose:
            print(
                f"[investigate_onnxruntime_issue] + node {i}: "
                f"{node.op_type}({', '.join(node.input)}) -> "
                f"{', '.join(node.output)}"
            )
        ext = onnx.utils.Extractor(onx)
        if quiet:
            try:
                extracted = ext.extract_model(input_names, node.output)
            except Exception as e:
                if verbose > 0:
                    print(
                        f"[investigate_onnxruntime_issue] cannot extract "
                        f"model at node {i} due to {e}"
                    )
                return node
        else:
            extracted = ext.extract_model(input_names, node.output)

        if dump_filename:
            if verbose > 1:
                print(f"[investigate_onnxruntime_issue] save into {dump_filename}")
            onnx.save(extracted, dump_filename)

        if verbose > 1:
            print("[investigate_onnxruntime_issue] create the session")

        def _make_session(proto):
            if onnx_to_session:
                return onnx_to_session(proto)
            return cls(
                proto,
                session_options=session_options,
                providers=providers,
                nvtx=nvtx,
                enable_profiling=enable_profiling,
                graph_optimization_level=graph_optimization_level,
                log_severity_level=log_severity_level,
                log_verbosity_level=log_verbosity_level,
                optimized_model_filepath=optimized_model_filepath,
                disable_aot_function_inlining=disable_aot_function_inlining,
                use_training_api=use_training_api,
            )

        if quiet:
            try:
                sess = _make_session(extracted)
            except Exception as e:
                if verbose > 0:
                    print(
                        f"[investigate_onnxruntime_issue] cannot create session "
                        f"at node {i} due to {e}"
                    )
                return node
        else:
            sess = _make_session(extracted)

        if not feeds:
            if verbose > 1:
                print("[investigate_onnxruntime_issue] session created")
            continue

        if verbose > 1:
            print("[investigate_onnxruntime_issue] running session")

        if quiet:
            try:
                sess.run(None, feeds)
            except Exception as e:
                if verbose > 0:
                    print(
                        f"[investigate_onnxruntime_issue] cannot run session "
                        f"at node {i} due to {e}"
                    )
                return node
        else:
            sess.run(None, feeds)

    if verbose > 0:
        print("[investigate_onnxruntime_issue] done.")