onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,736 @@
|
|
|
1
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
2
|
+
import onnx
|
|
3
|
+
import numpy as np
|
|
4
|
+
import numpy.typing as npt
|
|
5
|
+
import torch
|
|
6
|
+
from torch._C import _from_dlpack
|
|
7
|
+
import onnxruntime
|
|
8
|
+
from onnxruntime.capi import _pybind_state as ORTC
|
|
9
|
+
from .helper import size_type
|
|
10
|
+
from .onnx_helper import (
|
|
11
|
+
onnx_dtype_to_np_dtype,
|
|
12
|
+
np_dtype_to_tensor_dtype,
|
|
13
|
+
onnx_dtype_name,
|
|
14
|
+
)
|
|
15
|
+
from .torch_helper import torch_dtype_to_onnx_dtype
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Maps a torch device index (the value returned by ``torch.Tensor.get_device()``,
# key -1 being the CPU) to the matching onnxruntime ``OrtDevice``.
# Only the CPU entry exists at import time; CUDA entries are added by
# ``_InferenceSession.__init__`` for every device reported by torch.
DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class _InferenceSession:
    """
    Base wrapper around :class:`onnxruntime.InferenceSession`.

    It creates the session when it receives a model or a path, caches the
    input/output metadata and implements the post-processing shared by
    the numpy and torch flavors.
    """

    @classmethod
    def has_onnxruntime_training(cls):
        """
        Tells if onnxruntime_training is installed.

        :return: True if ``onnxruntime.training`` can be imported and
            ``OrtValueVector`` exposes ``push_back_batch``
        """
        try:
            from onnxruntime import training
        except ImportError:
            # onnxruntime without training support
            training = None
        if training is None:
            return False

        try:
            from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
        except ImportError:
            return False

        return hasattr(OrtValueVector, "push_back_batch")

    def __init__(
        self,
        sess: Union[onnx.ModelProto, str, onnxruntime.InferenceSession],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[Any]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: Optional[bool] = None,
    ):
        """
        :param sess: model, path to a model or an existing inference session
        :param session_options: options, cannot be combined with *providers*,
            *graph_optimization_level*, *log_severity_level*,
            *log_verbosity_level*
        :param providers: `None` (CPU), `"CPU"`, `"CUDA"` or a list of providers
        :param nvtx: enable nvidia events
        :param enable_profiling: see :class:`onnxruntime.SessionOptions`
        :param graph_optimization_level: a level or a boolean,
            True means ORT_ENABLE_ALL, False means ORT_DISABLE_ALL
        :param log_severity_level: see :class:`onnxruntime.SessionOptions`
        :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
        :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
        :param disable_aot_function_inlining: sets the session config entry
            ``session.disable_aot_function_inlining``
        :param use_training_api: use onnxruntime-training API,
            None means it is used if available
        :raises ValueError: if *providers* is an unknown string
        :raises RuntimeError: if onnxruntime fails to create the session,
            a ModelProto is then saved in
            ``_debug_InferenceSession_last_failure.onnx``
        """
        can_use_training_api = True
        if isinstance(sess, (onnx.ModelProto, str)):
            if isinstance(sess, onnx.ModelProto):
                # The training API relies too much on numpy, it cannot be used
                # when an initializer has an element type >= BFLOAT16.
                can_use_training_api = not any(
                    i.data_type >= onnx.TensorProto.BFLOAT16
                    for i in sess.graph.initializer
                )
            assert session_options is None or (
                providers is None
                and graph_optimization_level is None
                and log_severity_level is None
                and log_verbosity_level is None
            ), "session_options is defined, it is impossible to overwrite any option."
            if session_options is None:
                session_options = onnxruntime.SessionOptions()
            # The following options are also applied to (and modify)
            # session_options when it was given by the caller.
            if enable_profiling:
                session_options.enable_profiling = enable_profiling
            if optimized_model_filepath:
                session_options.optimized_model_filepath = optimized_model_filepath
            if log_severity_level is not None:
                session_options.log_severity_level = log_severity_level
            if log_verbosity_level is not None:
                session_options.log_verbosity_level = log_verbosity_level
            if graph_optimization_level is not None:
                if isinstance(graph_optimization_level, bool):
                    session_options.graph_optimization_level = (
                        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
                        if graph_optimization_level
                        else onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
                    )
                else:
                    session_options.graph_optimization_level = graph_optimization_level
            if disable_aot_function_inlining:
                session_options.add_session_config_entry(
                    "session.disable_aot_function_inlining", "1"
                )
            if providers is None:
                providers = ["CPUExecutionProvider"]
            if isinstance(providers, str):
                if providers.lower() == "cpu":
                    providers = ["CPUExecutionProvider"]
                elif providers.lower() == "cuda":
                    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
                else:
                    raise ValueError(f"Unexpected value for providers={providers!r}")
            try:
                sess = onnxruntime.InferenceSession(
                    sess if isinstance(sess, str) else sess.SerializeToString(),
                    session_options,
                    providers=providers,
                )
            except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
                # Saves the failing model on disk to help investigating.
                if isinstance(sess, onnx.ModelProto):
                    debug_path = "_debug_InferenceSession_last_failure.onnx"
                    onnx.save(
                        sess,
                        debug_path,
                        save_as_external_data=True,
                        all_tensors_to_one_file=True,
                    )
                else:
                    debug_path = sess
                raise RuntimeError(
                    f"Unable to create a session stored in {debug_path!r}, "
                    f"providers={providers}"
                ) from e
        else:
            assert (
                session_options is None
                and providers is None
                and graph_optimization_level is None
                and log_severity_level is None
                and log_verbosity_level is None
            ), f"First input is {type(sess)}, it is impossible to overwrite any option."

        self.sess = sess
        self.input_names = [i.name for i in sess.get_inputs()]
        self.output_names = [i.name for i in sess.get_outputs()]
        self.input_shapes = [i.shape for i in sess.get_inputs()]
        self.output_shapes = [i.shape for i in sess.get_outputs()]
        self.input_types = [i.type for i in sess.get_inputs()]
        self.output_types = [i.type for i in sess.get_outputs()]
        self.torch = torch
        self.nvtx = nvtx
        self.run_options = onnxruntime.RunOptions()

        if log_severity_level is not None:
            self.run_options.log_severity_level = log_severity_level
        if log_verbosity_level is not None:
            self.run_options.log_verbosity_level = log_verbosity_level

        self.use_training_api = can_use_training_api and (
            self.has_onnxruntime_training() if use_training_api is None else use_training_api
        )

        # Registers every CUDA device visible to torch in the module-level map.
        if torch.cuda.device_count() > 0:
            for i in range(torch.cuda.device_count()):
                DEVICES[i] = ORTC.OrtDevice(
                    ORTC.OrtDevice.cuda(), ORTC.OrtDevice.default_memory(), i
                )

        self._torch_from_dlpack = _from_dlpack
        # One flag per session output, in the session output order.
        self.sess_bool_outputs = [i.type == "tensor(bool)" for i in sess.get_outputs()]

    def run(
        self,
        output_names: Optional[List[str]],
        feeds: Union[Dict[str, np.ndarray], Dict[str, ORTC.OrtValue]],
    ) -> Union[List[np.ndarray], List[ORTC.OrtValue]]:
        """
        Calls :meth:`onnxruntime.InferenceSession.run`.

        If any feed is a :class:`numpy.ndarray`, the regular ``run`` is used,
        otherwise feeds are expected to be OrtValue and
        ``run_with_ort_values`` is called.

        :param output_names: requested outputs or None for all
        :param feeds: inputs
        :return: list of outputs
        """
        if any(isinstance(t, np.ndarray) for t in feeds.values()):
            return self.sess.run(output_names, feeds)
        ort_outputs = self.sess._sess.run_with_ort_values(
            feeds, output_names or self.output_names, self.run_options
        )
        return self._post_process_inplace(ort_outputs, output_names)

    def _post_process_inplace(self, outputs, output_names: Optional[List[str]] = None):
        """
        Casts outputs declared as ``tensor(bool)`` back to a boolean type.

        :param outputs: mutable sequence of outputs, modified inplace
        :param output_names: names of the outputs in *outputs*; None means
            *outputs* follows the session output order
        :return: *outputs*
        """
        if output_names:
            # A subset or a different order was requested: bool-ness must be
            # looked up by name, not by position.
            bool_names = {
                name
                for name, is_bool in zip(self.output_names, self.sess_bool_outputs)
                if is_bool
            }
            flags = [name in bool_names for name in output_names]
        else:
            flags = self.sess_bool_outputs
        for i, (o, is_bool) in enumerate(zip(outputs, flags)):
            if not is_bool or o is None:
                # None is returned for optional outputs without value.
                continue
            if isinstance(o, np.ndarray):
                if o.dtype != np.bool_:
                    outputs[i] = o.astype(np.bool_)
            elif o.dtype != torch.bool:
                outputs[i] = o.to(torch.bool)
        return outputs
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class InferenceSessionForNumpy(_InferenceSession):
    """
    Wraps an `onnxruntime.InferenceSession` to overload method `run`
    to support :class:`numpy.ndarray`.

    :param sess: model or inference session
    :param session_options: options
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable nvidia events
    :param enable_profiling: see :class:`onnxruntime.SessionOptions`
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use onnxruntime-training API
    """

    def __init__(
        self,
        sess: Union[onnx.ModelProto, str, onnxruntime.InferenceSession],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[str]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: Optional[bool] = None,
    ):
        super().__init__(
            sess,
            session_options=session_options,
            providers=providers,
            nvtx=nvtx,
            enable_profiling=enable_profiling,
            graph_optimization_level=graph_optimization_level,
            log_severity_level=log_severity_level,
            log_verbosity_level=log_verbosity_level,
            optimized_model_filepath=optimized_model_filepath,
            disable_aot_function_inlining=disable_aot_function_inlining,
            use_training_api=use_training_api,
        )

    def run(
        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
    ) -> List[Optional[npt.ArrayLike]]:
        """
        Calls :meth:`onnxruntime.InferenceSession.run` through DLPack.

        :param output_names: requested outputs or None for all
        :param feeds: dictionary of inputs
        :return: list of arrays, None for an output without value
        """
        # sess.run does not support bfloat16, that is why run_dlpack is used.
        return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

    def run_dlpack(
        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
    ) -> Tuple[Optional[npt.ArrayLike], ...]:
        """
        Same as :meth:`onnxruntime.InferenceSession.run` except that
        feeds is a dictionary of :class:`np.ndarray`
        (:class:`torch.Tensor` values are accepted as well).
        Inputs with an empty name are skipped.

        :param output_names: requested outputs or None for all
        :param feeds: dictionary of inputs
        :return: tuple of arrays, None for an output without value
        """
        # Keeps alive every buffer created in this function until the run
        # is over (presumably the OrtValues built from numpy arrays do not
        # hold a reference to them).
        memory = []
        new_feeds = {}
        for name, value in feeds.items():
            if not name:
                continue
            if isinstance(value, np.ndarray):
                new_feeds[name] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                    value, np_dtype_to_tensor_dtype(value.dtype)
                )
                continue

            # Otherwise a torch.Tensor is expected, DLPack export requires
            # a contiguous layout.
            if not value.is_contiguous():
                value = value.contiguous()
                memory.append(value)
            if value.dtype == torch.bool:
                as_numpy = value.detach().cpu().numpy()
                memory.append(as_numpy)
                new_feeds[name] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                    as_numpy, onnx.TensorProto.BOOL
                )
            else:
                new_feeds[name] = ORTC.OrtValue.from_dlpack(value.__dlpack__(), False)

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("run_with_ort_values")
        ort_outputs = self.sess._sess.run_with_ort_values(
            new_feeds, output_names or self.output_names, self.run_options
        )
        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        return self._ortvalues_to_numpy_tensor(ort_outputs)

    def _ortvalues_to_numpy_tensor(
        self,
        ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
    ) -> Tuple[Optional[npt.ArrayLike], ...]:
        """
        Converts OrtValue into numpy arrays.

        Types numpy supports natively go through :func:`numpy.from_dlpack`
        (or a copy if DLPack refuses the value), 2-byte types with an element
        type >= BFLOAT16 are converted through torch and reinterpreted with
        the numpy dtype returned by ``onnx_dtype_to_np_dtype``.

        :param ortvalues: values to convert
        :return: tuple of arrays, None for a value without content
        """
        if len(ortvalues) == 0:
            return tuple()

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
        res: List[Optional[npt.ArrayLike]] = []
        for i, ov in enumerate(ortvalues):
            if not ov.has_value():
                res.append(None)
                continue

            el_type = ov.element_type()
            if el_type < onnx.TensorProto.BFLOAT16:
                try:
                    a = np.from_dlpack(ov)
                except RuntimeError as e:
                    assert "ORT only supports contiguous tensor for now." in str(e), (
                        f"As it says, non-contiguous OrtValue are not supported "
                        f"through DLPack, i={i}, the error is different {e}"
                    )
                    # We make a copy in that case.
                    a = ov.numpy()
                res.append(a)
                continue

            # No easy conversion, torch is used to read the raw 16-bit values.
            size = size_type(el_type)
            assert size == 2, f"Not implemented for type {onnx_dtype_name(el_type)}"
            tch = torch.from_dlpack(ov.to_dlpack())
            # .cpu() is a no-op for CPU tensors, numpy cannot read device memory.
            raw_bits = tch.view(torch.uint16).cpu().numpy()
            res.append(raw_bits.view(onnx_dtype_to_np_dtype(el_type)))

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        return tuple(res)
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class InferenceSessionForTorch(_InferenceSession):
    """
    Wraps an `onnxruntime.InferenceSession` to overload method `run`
    to support :class:`torch.Tensor`.

    :param sess: model or inference session
    :param session_options: options
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable nvidia events
    :param enable_profiling: see :class:`onnxruntime.SessionOptions`
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use onnxruntime-training API
    """

    def __init__(
        self,
        sess: Union[onnx.ModelProto, str, onnxruntime.InferenceSession],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[str]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: Optional[bool] = None,
    ):
        super().__init__(
            sess,
            session_options=session_options,
            providers=providers,
            nvtx=nvtx,
            enable_profiling=enable_profiling,
            graph_optimization_level=graph_optimization_level,
            log_severity_level=log_severity_level,
            log_verbosity_level=log_verbosity_level,
            optimized_model_filepath=optimized_model_filepath,
            disable_aot_function_inlining=disable_aot_function_inlining,
            use_training_api=use_training_api,
        )

    def _get_ortvalues_from_torch_tensors(
        self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
    ) -> Tuple[ORTC.OrtValueVector, List[onnxruntime.OrtDevice]]:
        """
        Builds an OrtValueVector pointing to the memory of *tensors*.

        :param tensors: input tensors, the caller must keep them alive
            while the returned vector is used
        :param n_outputs: number of outputs
        :return: the vector and one output device per output, all outputs
            are placed on the highest device index found among the inputs
        """
        assert tensors is not None, "tensors cannot be None"
        ortvalues = ORTC.OrtValueVector()
        ortvalues.reserve(len(tensors))
        dtypes = []
        shapes = []
        data_ptrs = []
        devices = []

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("_get_ortvalues_from_torch_tensors.1")
        max_device = -1
        new_tensors = []
        for tensor in tensors:
            assert isinstance(tensor, self.torch.Tensor), f"Unexpected type {type(tensor)}"
            dtypes.append(onnx_dtype_to_np_dtype(torch_dtype_to_onnx_dtype(tensor.dtype)))
            shapes.append(tensor.size())
            data_ptrs.append(tensor.data_ptr())
            d = tensor.get_device()  # -1 for CPU
            devices.append(DEVICES[d])
            new_tensors.append(tensor)
            max_device = max(max_device, d)

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
            self.torch.cuda.nvtx.range_push("_get_ortvalues_from_torch_tensors.2")

        assert isinstance(max_device, int), f"unexpected type for device={max_device!r}"
        ortvalues.push_back_batch(new_tensors, data_ptrs, dtypes, shapes, devices)
        output_devices = [DEVICES[max_device] for _ in range(n_outputs)]

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        return ortvalues, output_devices

    def _ortvalues_to_torch_tensor(
        self,
        ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
    ) -> Tuple[torch.Tensor, ...]:
        """
        Converts OrtValue into torch tensors through DLPack
        (no explicit device transfer).

        :param ortvalues: values to convert, a list or an OrtValueVector
        :return: tuple of tensors, None for a value without content
        """
        if len(ortvalues) == 0:
            return tuple()

        # The batched conversion only exists on OrtValueVector.
        if hasattr(ortvalues, "to_dlpacks") and all(ov.has_value() for ov in ortvalues):
            if self.nvtx:
                self.torch.cuda.nvtx.range_push("_ortvalues_to_torch_tensor.1")
            res = ortvalues.to_dlpacks(self._torch_from_dlpack)
            if self.nvtx:
                self.torch.cuda.nvtx.range_pop()
        else:
            if self.nvtx:
                self.torch.cuda.nvtx.range_push("_ortvalues_to_torch_tensor.2")
            res = [
                self._torch_from_dlpack(ov.to_dlpack()) if ov.has_value() else None
                for ov in ortvalues
            ]
            if self.nvtx:
                self.torch.cuda.nvtx.range_pop()
        return tuple(res)

    def run(  # type: ignore
        self, output_names: Optional[List[str]], feeds: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, ...]:
        """
        Same as :meth:`onnxruntime.InferenceSession.run` except that
        feeds is a dictionary of :class:`torch.Tensor`.

        :param output_names: requested outputs or None for all
        :param feeds: dictionary of inputs
        :return: outputs as torch tensors
        """
        if self.use_training_api:
            inputs = [feeds[i] for i in self.input_names]
            return self.run_training_api(*inputs, output_names=output_names)
        return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))

    def run_training_api(
        self, *inputs, output_names: Optional[List[str]] = None
    ) -> Tuple[torch.Tensor, ...]:
        """
        Calls the former training API now implemented in onnxruntime as well.

        :param inputs: list of :class:`torch.Tensor`, in the session input order
        :param output_names: requested outputs or None for all
        :return: tuple of :class:`torch.Tensor`
        """
        if output_names is None:
            output_names = self.output_names
        # push_back_batch receives data pointers and shapes but no strides,
        # so a contiguous layout is enforced here (no-op for contiguous
        # tensors). Keeping the copies in this local variable keeps them alive
        # until the run is over.
        inputs = tuple(
            t.contiguous() if isinstance(t, torch.Tensor) else t for t in inputs
        )
        ortvalues, output_devices = self._get_ortvalues_from_torch_tensors(
            inputs, len(output_names)
        )

        if self.nvtx:
            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")

        ort_outputs = ORTC.OrtValueVector()
        self.sess.run_with_ortvaluevector(
            self.run_options,
            self.input_names,
            ortvalues,
            output_names,
            ort_outputs,
            output_devices,
        )

        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()

        return self._ortvalues_to_torch_tensor(ort_outputs)

    def run_dlpack(
        self, output_names: Optional[List[str]], feeds: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, ...]:
        """
        Same as :meth:`onnxruntime.InferenceSession.run` except that
        feeds is a dictionary of :class:`torch.Tensor`.
        Outputs are wrapped through DLPack without an explicit device transfer.

        :param output_names: requested outputs or None for all
        :param feeds: dictionary of inputs
        :return: tuple of tensors, None for an output without value
        """
        # Keeps alive the buffers created in this function until the run is
        # over: the numpy arrays used for boolean inputs and any contiguous
        # copy they are a view of.
        memory = []
        new_feeds = {}
        for k, v in feeds.items():
            assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
            if not v.is_contiguous():
                v = v.contiguous()
            if v.dtype == torch.bool:
                # It does not work with dlpack
                # unless onnxruntime updates the version it is using.
                # numpy cannot read device memory, hence .cpu().
                vi = v.detach().cpu().numpy()
                memory.append(vi)
                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
                    vi, onnx.TensorProto.BOOL
                )
            else:
                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
        if self.nvtx:
            self.torch.cuda.nvtx.range_push("run_with_ort_values")
        ort_outputs = self.sess._sess.run_with_ort_values(
            new_feeds, output_names or self.output_names, self.run_options
        )
        if self.nvtx:
            self.torch.cuda.nvtx.range_pop()
        return self._ortvalues_to_torch_tensor(ort_outputs)
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def investigate_onnxruntime_issue(
|
|
518
|
+
proto: Union[onnx.ModelProto, str],
|
|
519
|
+
session_options: Optional[onnxruntime.SessionOptions] = None,
|
|
520
|
+
providers: Optional[Union[str, List[str]]] = None,
|
|
521
|
+
nvtx: bool = False,
|
|
522
|
+
enable_profiling: bool = False,
|
|
523
|
+
graph_optimization_level: Union[onnxruntime.GraphOptimizationLevel, bool] = None,
|
|
524
|
+
log_severity_level: Optional[int] = None,
|
|
525
|
+
log_verbosity_level: Optional[int] = None,
|
|
526
|
+
optimized_model_filepath: Optional[str] = None,
|
|
527
|
+
disable_aot_function_inlining: Optional[bool] = None,
|
|
528
|
+
use_training_api: Optional[bool] = None,
|
|
529
|
+
onnx_to_session: Optional[
|
|
530
|
+
Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
|
|
531
|
+
] = None,
|
|
532
|
+
# if model needs to be run.
|
|
533
|
+
feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, npt.ArrayLike]]] = None,
|
|
534
|
+
verbose: int = 0,
|
|
535
|
+
dump_filename: Optional[str] = None,
|
|
536
|
+
infer_shapes: bool = True,
|
|
537
|
+
quiet: bool = False,
|
|
538
|
+
):
|
|
539
|
+
"""
|
|
540
|
+
Invgestigates a crashing model. It tries every node until
|
|
541
|
+
it crashes by adding the ones one by one in the model.
|
|
542
|
+
|
|
543
|
+
:param proto: model or inference session
|
|
544
|
+
:param session_options: options
|
|
545
|
+
:param providers: providers
|
|
546
|
+
:param nvtx: enable nvidia events
|
|
547
|
+
:param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
|
|
548
|
+
:param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
|
|
549
|
+
:param log_severity_level: see :class:`onnxruntime.SessionOptions`
|
|
550
|
+
:param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
|
|
551
|
+
:param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
|
|
552
|
+
:param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
|
|
553
|
+
:param use_training_api: use onnxruntime-traning API
|
|
554
|
+
:param onnx_to_session: function to load a model into an inference session if
|
|
555
|
+
automated way implemented in this function is not enough,
|
|
556
|
+
if it is equal ``cpu_session``, the callable becomes:
|
|
557
|
+
``lambda model: onnxruntime.InferenceSession(
|
|
558
|
+
model.SerializeToString(), providers=["CPUExecutionProvider"])``
|
|
559
|
+
:param feeds: run onnxruntime as well
|
|
560
|
+
:param verbosity: verbosity level
|
|
561
|
+
:param dump_filename: if not None, the function dumps the last model run
|
|
562
|
+
:param infer_shapes: run shape inference
|
|
563
|
+
:param quiet: if True, raises an exception, False, just stops and
|
|
564
|
+
return the failing node
|
|
565
|
+
|
|
566
|
+
The most simple use:
|
|
567
|
+
|
|
568
|
+
.. code-block:: python
|
|
569
|
+
|
|
570
|
+
investigate_onnxruntime_issue(
|
|
571
|
+
model,
|
|
572
|
+
feeds=feeds,
|
|
573
|
+
verbose=10,
|
|
574
|
+
dump_filename="test_investigate_onnxruntime_issue_callable.onnx",
|
|
575
|
+
onnx_to_session="cpu_session",
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
Full example:
|
|
579
|
+
|
|
580
|
+
.. runpython::
|
|
581
|
+
:showcode:
|
|
582
|
+
|
|
583
|
+
import numpy as np
|
|
584
|
+
import onnx
|
|
585
|
+
import onnx.helper as oh
|
|
586
|
+
from onnx_diagnostic.helpers.ort_session import investigate_onnxruntime_issue
|
|
587
|
+
|
|
588
|
+
TFLOAT = onnx.TensorProto.FLOAT
|
|
589
|
+
model = oh.make_model(
|
|
590
|
+
oh.make_graph(
|
|
591
|
+
[
|
|
592
|
+
oh.make_node("Add", ["x", "y"], ["gggg"]),
|
|
593
|
+
oh.make_node("Add", ["gggg", "z"], ["final"]),
|
|
594
|
+
],
|
|
595
|
+
"dummy",
|
|
596
|
+
[
|
|
597
|
+
oh.make_tensor_value_info("x", TFLOAT, [None, None]),
|
|
598
|
+
oh.make_tensor_value_info("y", TFLOAT, [None, None]),
|
|
599
|
+
oh.make_tensor_value_info("z", TFLOAT, [None, None]),
|
|
600
|
+
],
|
|
601
|
+
[oh.make_tensor_value_info("final", TFLOAT, [None, None])],
|
|
602
|
+
),
|
|
603
|
+
opset_imports=[oh.make_opsetid("", 18)],
|
|
604
|
+
ir_version=9,
|
|
605
|
+
)
|
|
606
|
+
onnx.checker.check_model(model)
|
|
607
|
+
feeds = {
|
|
608
|
+
"x": np.random.rand(5, 6).astype(np.float32),
|
|
609
|
+
"y": np.random.rand(5, 6).astype(np.float32),
|
|
610
|
+
"z": np.random.rand(5, 6).astype(np.float32),
|
|
611
|
+
}
|
|
612
|
+
investigate_onnxruntime_issue(
|
|
613
|
+
model,
|
|
614
|
+
feeds=feeds,
|
|
615
|
+
verbose=1,
|
|
616
|
+
graph_optimization_level=False,
|
|
617
|
+
dump_filename="last_issue.onnx",
|
|
618
|
+
)
|
|
619
|
+
"""
|
|
620
|
+
onx = (
|
|
621
|
+
proto
|
|
622
|
+
if isinstance(proto, onnx.ModelProto)
|
|
623
|
+
else onnx.load(proto, load_external_data=False)
|
|
624
|
+
)
|
|
625
|
+
input_names = [i.name for i in onx.graph.input]
|
|
626
|
+
if verbose:
|
|
627
|
+
print(
|
|
628
|
+
f"[investigate_onnxruntime_issue] found "
|
|
629
|
+
f"{len(onx.graph.node)} nodes and {len(input_names)} inputs"
|
|
630
|
+
)
|
|
631
|
+
if infer_shapes:
|
|
632
|
+
if verbose:
|
|
633
|
+
print("[investigate_onnxruntime_issue] run shape inference")
|
|
634
|
+
onx = onnx.shape_inference.infer_shapes(onx)
|
|
635
|
+
|
|
636
|
+
if isinstance(onnx_to_session, str):
|
|
637
|
+
if onnx_to_session == "cpu_session":
|
|
638
|
+
import onnxruntime
|
|
639
|
+
|
|
640
|
+
onnx_to_session = lambda model: onnxruntime.InferenceSession( # noqa: E731
|
|
641
|
+
model.SerializeToString(), providers=["CPUExecutionProvider"]
|
|
642
|
+
)
|
|
643
|
+
else:
|
|
644
|
+
raise ValueError(f"Unexpected value onnx_to_session={onnx_to_session!r}")
|
|
645
|
+
else:
|
|
646
|
+
cls = (
|
|
647
|
+
InferenceSessionForNumpy
|
|
648
|
+
if feeds is None or any(isinstance(v, np.ndarray) for v in feeds.values())
|
|
649
|
+
else InferenceSessionForTorch
|
|
650
|
+
)
|
|
651
|
+
if verbose and not onnx_to_session:
|
|
652
|
+
print(f"[investigate_onnxruntime_issue] cls={cls}")
|
|
653
|
+
|
|
654
|
+
for i in range(len(onx.graph.node)):
|
|
655
|
+
node = onx.graph.node[i]
|
|
656
|
+
if verbose:
|
|
657
|
+
print(
|
|
658
|
+
f"[investigate_onnxruntime_issue] + node {i}: "
|
|
659
|
+
f"{node.op_type}({', '.join(node.input)}) -> "
|
|
660
|
+
f"{', '.join(node.output)}"
|
|
661
|
+
)
|
|
662
|
+
ext = onnx.utils.Extractor(onx)
|
|
663
|
+
if quiet:
|
|
664
|
+
try:
|
|
665
|
+
extracted = ext.extract_model(input_names, node.output)
|
|
666
|
+
except Exception as e:
|
|
667
|
+
if verbose > 0:
|
|
668
|
+
print(
|
|
669
|
+
f"[investigate_onnxruntime_issue] cannot extract "
|
|
670
|
+
f"model at node {i} due to {e}"
|
|
671
|
+
)
|
|
672
|
+
return node
|
|
673
|
+
else:
|
|
674
|
+
extracted = ext.extract_model(input_names, node.output)
|
|
675
|
+
|
|
676
|
+
if dump_filename:
|
|
677
|
+
if verbose > 1:
|
|
678
|
+
print(f"[investigate_onnxruntime_issue] save into {dump_filename}")
|
|
679
|
+
onnx.save(extracted, dump_filename)
|
|
680
|
+
|
|
681
|
+
if verbose > 1:
|
|
682
|
+
print("[investigate_onnxruntime_issue] create the session")
|
|
683
|
+
|
|
684
|
+
def _make_session(proto):
|
|
685
|
+
if onnx_to_session:
|
|
686
|
+
return onnx_to_session(proto)
|
|
687
|
+
return cls(
|
|
688
|
+
proto,
|
|
689
|
+
session_options=session_options,
|
|
690
|
+
providers=providers,
|
|
691
|
+
nvtx=nvtx,
|
|
692
|
+
enable_profiling=enable_profiling,
|
|
693
|
+
graph_optimization_level=graph_optimization_level,
|
|
694
|
+
log_severity_level=log_severity_level,
|
|
695
|
+
log_verbosity_level=log_verbosity_level,
|
|
696
|
+
optimized_model_filepath=optimized_model_filepath,
|
|
697
|
+
disable_aot_function_inlining=disable_aot_function_inlining,
|
|
698
|
+
use_training_api=use_training_api,
|
|
699
|
+
)
|
|
700
|
+
|
|
701
|
+
if quiet:
|
|
702
|
+
try:
|
|
703
|
+
sess = _make_session(extracted)
|
|
704
|
+
except Exception as e:
|
|
705
|
+
if verbose > 0:
|
|
706
|
+
print(
|
|
707
|
+
f"[investigate_onnxruntime_issue] cannot create session "
|
|
708
|
+
f"at node {i} due to {e}"
|
|
709
|
+
)
|
|
710
|
+
return node
|
|
711
|
+
else:
|
|
712
|
+
sess = _make_session(extracted)
|
|
713
|
+
|
|
714
|
+
if not feeds:
|
|
715
|
+
if verbose > 1:
|
|
716
|
+
print("[investigate_onnxruntime_issue] session created")
|
|
717
|
+
continue
|
|
718
|
+
|
|
719
|
+
if verbose > 1:
|
|
720
|
+
print("[investigate_onnxruntime_issue] running session")
|
|
721
|
+
|
|
722
|
+
if quiet:
|
|
723
|
+
try:
|
|
724
|
+
sess.run(None, feeds)
|
|
725
|
+
except Exception as e:
|
|
726
|
+
if verbose > 0:
|
|
727
|
+
print(
|
|
728
|
+
f"[investigate_onnxruntime_issue] cannot run session "
|
|
729
|
+
f"at node {i} due to {e}"
|
|
730
|
+
)
|
|
731
|
+
return node
|
|
732
|
+
else:
|
|
733
|
+
sess.run(None, feeds)
|
|
734
|
+
|
|
735
|
+
if verbose > 0:
|
|
736
|
+
print("[investigate_onnxruntime_issue] done.")
|