onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2124 @@
import gc
import datetime
import inspect
import os
import pprint
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
import time
import numpy as np
import onnx
import torch
from ..export import CoupleInputsDynamicShapes
from ..helpers import max_diff, string_type, string_diff
from ..helpers.helper import flatten_object
from ..helpers.rt_helper import make_feeds
from ..helpers.torch_helper import to_any, torch_deepcopy
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from ..tasks import random_input_kwargs
from ..torch_export_patches import (
    torch_export_patches,
    register_additional_serialization_functions,
)
from ..torch_export_patches.patch_inputs import use_dyn_not_str
from .hghub import get_untrained_model_with_inputs
from .hghub.model_inputs import _preprocess_model_id

def empty(value: Any) -> bool:
    """Tells if the value is empty."""
    if isinstance(value, (str, list, dict, tuple, set)):
        return not bool(value)
    if value is None:
        return True
    return False


def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, Any]:
    """
    Returns dummy inputs for a specific task.

    :param task: requested task
    :param config: returns dummy inputs for a specific config if available
    :return: dummy inputs and dynamic shapes
    """
    kwargs, f = random_input_kwargs(config, task)
    return f(model=None, config=config, **kwargs)

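An illustrative sketch (the exact content of the returned dictionary depends on the task implementations in onnx_diagnostic.tasks, and some tasks may require a configuration):

    from onnx_diagnostic.torch_models.validate import get_inputs_for_task

    res = get_inputs_for_task("text-generation")
    print(sorted(res))  # typically includes "inputs" and "dynamic_shapes"
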
def split_args_kwargs(inputs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Splits into args, kwargs."""
    if isinstance(inputs, dict):
        return (), inputs
    if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[1], dict):
        return inputs
    assert isinstance(inputs, tuple), f"Unexpected inputs {string_type(inputs)}"
    return inputs, {}


def make_inputs(
    args: Optional[Tuple[Any, ...]], kwargs: Optional[Dict[str, Any]] = None
) -> Any:
    """Returns either args, kwargs or both depending on which ones are empty."""
    assert args or kwargs, "No input was given."
    if not args:
        return kwargs
    if not kwargs:
        return args
    return args, kwargs

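A quick sketch of the conventions these two helpers implement: split_args_kwargs normalizes any accepted input form into an (args, kwargs) pair, and make_inputs folds such a pair back into its most compact form:

    from onnx_diagnostic.torch_models.validate import make_inputs, split_args_kwargs

    x = object()  # placeholder input
    assert split_args_kwargs({"x": x}) == ((), {"x": x})            # dict -> kwargs only
    assert split_args_kwargs((x,)) == ((x,), {})                    # tuple -> args only
    assert split_args_kwargs(((x,), {"y": x})) == ((x,), {"y": x})  # already split
    assert make_inputs((), {"x": x}) == {"x": x}                    # empty args dropped
    assert make_inputs((x,), {}) == (x,)                            # empty kwargs dropped
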
def filter_inputs(
    inputs: Any,
    drop_names: List[str],
    model: Optional[Union[torch.nn.Module, List[str]]] = None,
    dynamic_shapes: Optional[Any] = None,
):
    """
    Drops some inputs from the given inputs.
    It updates the dynamic shapes as well.
    """
    args, kwargs = split_args_kwargs(inputs)
    set_drop_names = set(drop_names)
    kwargs = {k: v for k, v in kwargs.items() if k not in set_drop_names}
    dyn = (
        {k: v for k, v in dynamic_shapes.items() if k not in set_drop_names}
        if dynamic_shapes and isinstance(dynamic_shapes, dict)
        else dynamic_shapes
    )
    if not args or all(i in kwargs for i in set_drop_names):
        return make_inputs(args, kwargs), dyn
    assert model, (
        f"we need the model to get the parameter name but model is None, "
        f"input_names={drop_names} and args={string_type(args)}"
    )
    pnames = (
        list(inspect.signature(model.forward).parameters)
        if isinstance(model, torch.nn.Module)
        else model
    )
    new_args = []
    new_ds = []
    for i, a in enumerate(args):
        if isinstance(dynamic_shapes, tuple):
            new_ds.append(None if pnames[i] in set_drop_names else dynamic_shapes[i])
        new_args.append(None if pnames[i] in set_drop_names else a)
    new_inputs = make_inputs(tuple(new_args), kwargs)
    if new_ds:
        return new_inputs, tuple(new_ds)
    return new_inputs, dyn

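A sketch of the common case: dropping a named keyword input also removes its entry from the dynamic shapes (the model argument is only needed when positional inputs must be matched to parameter names):

    from onnx_diagnostic.torch_models.validate import filter_inputs

    inputs = {"input_ids": "ids", "position_ids": "pos"}
    ds = {"input_ids": {0: "batch"}, "position_ids": {0: "batch"}}
    new_inputs, new_ds = filter_inputs(inputs, drop_names=["position_ids"], dynamic_shapes=ds)
    assert new_inputs == {"input_ids": "ids"}
    assert new_ds == {"input_ids": {0: "batch"}}
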
def _make_folder_name(
    model_id: str,
    exporter: Optional[str],
    optimization: Optional[str] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Union[str, torch.device]] = None,
    subfolder: Optional[str] = None,
    opset: Optional[int] = None,
    drop_inputs: Optional[List[str]] = None,
    same_as_pretrained: bool = False,
    use_pretrained: bool = False,
    task: Optional[str] = None,
) -> str:
    "Creates a unique folder name based on the given options."
    els = [model_id.replace("/", "_")]
    if subfolder:
        els.append(subfolder.replace("/", "_"))
    if task:
        els.append(task)
    if drop_inputs:
        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
        els.append(f"I-{ii.upper()}")
    if use_pretrained:
        els.append("TRAINED")
    elif same_as_pretrained:
        els.append("SAMESIZE")
    if exporter:
        els.append(exporter)
    if optimization:
        els.append(optimization)
    if dtype is not None and dtype:
        stype = dtype if isinstance(dtype, str) else str(dtype)
        stype = stype.replace("float", "f").replace("uint", "u").replace("int", "i")
        els.append(stype)
    if device is not None and device:
        sdev = device if isinstance(device, str) else str(device)
        sdev = sdev.lower()
        if "cpu" in sdev:
            sdev = "cpu"
        elif "cuda" in sdev:
            sdev = "cuda"
        else:
            raise AssertionError(f"unexpected value for device={device}, sdev={sdev!r}")
        els.append(sdev)
    if opset is not None:
        els.append(f"op{opset}")
    return "/".join([e for e in els if e])

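For instance, the options below produce the nested folder name shown in the assert (a sketch; the function is private and the naming scheme may change):

    from onnx_diagnostic.torch_models.validate import _make_folder_name

    name = _make_folder_name(
        "microsoft/phi-2", "onnx-dynamo", "ir", dtype="float16", device="cuda", opset=18
    )
    assert name == "microsoft_phi-2/onnx-dynamo/ir/f16/cuda/op18"
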
def version_summary() -> Dict[str, Union[int, float, str]]:
    """
    Example:

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_models.validate import version_summary

        pprint.pprint(version_summary())
    """
    import numpy

    summary: Dict[str, Union[int, float, str]] = {
        "version_torch": torch.__version__,
        "version_numpy": numpy.__version__,
    }
    try:
        import scipy

        summary["version_scipy"] = getattr(scipy, "__version__", "?")
    except ImportError:
        pass
    try:
        import transformers

        summary["version_transformers"] = getattr(transformers, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnx

        summary["version_onnx"] = getattr(onnx, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnxscript

        summary["version_onnxscript"] = getattr(onnxscript, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnxruntime

        summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnx_ir

        summary["version_onnx_ir"] = getattr(onnx_ir, "__version__", "?")
    except ImportError:
        pass
    import onnx_diagnostic

    summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__
    summary["version_date"] = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    return summary

def _quiet_or_not_quiet(
    quiet: bool,
    suffix: str,
    summary: Dict[str, Any],
    data: Optional[Dict[str, Any]],
    fct: Callable,
    repeat: int = 1,
    warmup: int = 0,
) -> Any:
    begin = time.perf_counter()
    if quiet:
        try:
            res = fct()
            summary[f"time_{suffix}"] = time.perf_counter() - begin
            if warmup + repeat == 1:
                return res
        except Exception as e:
            summary[f"ERR_{suffix}"] = str(e)
            summary[f"time_{suffix}"] = time.perf_counter() - begin
            if data is None:
                return {f"ERR_{suffix}": e}
            data[f"ERR_{suffix}"] = e
            return None
    else:
        res = fct()
        summary[f"time_{suffix}"] = time.perf_counter() - begin
    if warmup + repeat > 1:
        if suffix == "run":
            res = torch_deepcopy(res)
        summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
        summary[f"{suffix}_warmup"] = warmup
        summary[f"{suffix}_repeat"] = repeat
        last_ = None
        end_w = max(0, warmup - 1)
        for _w in range(end_w):
            t = fct()
            _ = string_type(t, with_shape=True, with_min_max=True)
            if _ != last_ or _w == end_w - 1:
                summary[f"io_{suffix}_{_w+1}"] = _
                last_ = _
        summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
        times = []
        for _r in range(repeat):
            begin = time.perf_counter()
            t = fct()
            times.append(time.perf_counter() - begin)
        a = np.array(times, dtype=np.float64)
        a.sort()
        i5 = max(1, a.shape[0] * 5 // 100)
        i2 = max(1, a.shape[0] * 2 // 100)
        summary[f"time_{suffix}_latency"] = a.mean()
        summary[f"time_{suffix}_latency_std"] = a.std()
        summary[f"time_{suffix}_latency_min"] = a.min()
        summary[f"time_{suffix}_latency_max"] = a.max()
        summary[f"time_{suffix}_latency_098"] = a[-i2]
        summary[f"time_{suffix}_latency_095"] = a[-i5]
        summary[f"time_{suffix}_latency_005"] = a[i5]
        summary[f"time_{suffix}_latency_002"] = a[i2]
        summary[f"time_{suffix}_n"] = len(a)
        summary[f"time_{suffix}_latency_m98"] = a[i2:-i2].mean()

    return res

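The percentile bookkeeping above reads as follows: after sorting the repeat timings, index i2 (2% of the sample size, at least 1) and i5 (5%) pick the tail values, and the m98 entry is the mean with the 2% fastest and 2% slowest runs trimmed. A sketch with 100 fake timings:

    import numpy as np

    a = np.array(range(100), dtype=np.float64)  # already sorted fake latencies
    i5 = max(1, a.shape[0] * 5 // 100)  # -> 5
    i2 = max(1, a.shape[0] * 2 // 100)  # -> 2
    assert a[-i2] == 98.0               # time_..._latency_098, ~98th percentile
    assert a[-i5] == 95.0               # time_..._latency_095, ~95th percentile
    assert a[i2:-i2].mean() == 49.5     # time_..._latency_m98, trimmed mean
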
def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Shrinks the configuration before it gets added to the information to log."""
    new_cfg = {}
    for k, v in cfg.items():
        new_cfg[k] = (
            v
            if (not isinstance(v, (list, tuple, set, dict)) or len(v) < 50)
            else (v.__class__("...") if isinstance(v, (list, tuple)) else "...")
        )
    return new_cfg

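For example (a sketch): scalar values pass through unchanged while containers with 50 or more elements are collapsed:

    from onnx_diagnostic.torch_models.validate import shrink_config

    cfg = {"hidden_size": 1024, "vocab": list(range(100))}
    small = shrink_config(cfg)
    assert small["hidden_size"] == 1024
    assert small["vocab"] == list("...")  # collapsed to ['.', '.', '.']
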
def make_patch_kwargs(
    patch: Union[bool, str, Dict[str, bool]] = False,
    rewrite: bool = False,
) -> Dict[str, Any]:
    """Creates patch arguments."""
    default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
    if isinstance(patch, bool):
        patch_kwargs = default_patch if patch else dict(patch=False)
    elif isinstance(patch, str):
        patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
    else:
        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
        patch_kwargs = patch.copy()
        if "patch" not in patch_kwargs:
            if any(patch_kwargs.values()):
                patch_kwargs["patch"] = True
        elif len(patch) == 1 and patch.get("patch", False):
            patch_kwargs.update(default_patch)

    assert not rewrite or patch_kwargs.get("patch", False), (
        f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
        f"patch must be True to enable rewriting, "
        f"if --patch=0 was specified on the command line, rewrites are disabled."
    )
    return patch_kwargs

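The three accepted forms of patch map as follows (a sketch):

    from onnx_diagnostic.torch_models.validate import make_patch_kwargs

    assert make_patch_kwargs(False) == {"patch": False}
    assert make_patch_kwargs(True) == {
        "patch_transformers": True, "patch_diffusers": True, "patch": True
    }
    # a comma-separated string enables only the listed patches
    assert make_patch_kwargs("patch_transformers") == {
        "patch": True, "patch_transformers": True
    }
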
def validate_model(
    model_id: str,
    task: Optional[str] = None,
    do_run: bool = False,
    exporter: Optional[str] = None,
    do_same: bool = False,
    verbose: int = 0,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Union[str, torch.device]] = None,
    same_as_pretrained: bool = False,
    use_pretrained: bool = False,
    optimization: Optional[str] = None,
    quiet: bool = False,
    patch: Union[bool, str, Dict[str, bool]] = False,
    rewrite: bool = False,
    stop_if_static: int = 1,
    dump_folder: Optional[str] = None,
    drop_inputs: Optional[List[str]] = None,
    ortfusiontype: Optional[str] = None,
    input_options: Optional[Dict[str, Any]] = None,
    model_options: Optional[Dict[str, Any]] = None,
    subfolder: Optional[str] = None,
    opset: Optional[int] = None,
    runtime: str = "onnxruntime",
    repeat: int = 1,
    warmup: int = 0,
    inputs2: int = 1,
    output_names: Optional[List[str]] = None,
    ort_logs: bool = False,
    quiet_input_sets: Optional[Set[str]] = None,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
    """
    Validates a model.
    The function can also be called through the command line
    :ref:`l-cmd-validate`.

    :param model_id: model id to validate
    :param task: task used to generate the necessary inputs,
        can be left empty to use the default task for this model
        if it can be determined
    :param do_run: checks the model works with the defined inputs
    :param exporter: exports the model using this exporter,
        available list: ``export-strict``, ``export-nostrict``, ...
        see below
    :param do_same: checks the discrepancies of the exported model
    :param verbose: verbosity level
    :param dtype: uses this dtype to check the model
    :param device: does the verification on this device
    :param same_as_pretrained: uses a model equivalent to the trained one,
        this is not always possible
    :param use_pretrained: uses the trained model, not the untrained one
    :param optimization: optimization to apply to the exported model,
        depends on the exporter
    :param quiet: if quiet, catches exceptions whenever an issue arises
    :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
        if True before exporting,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
        a string can be used to specify only one of them
    :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param stop_if_static: stops if a dynamic dimension becomes static,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param dump_folder: dumps everything in a subfolder of this one
    :param drop_inputs: drops this list of inputs (given their names)
    :param ortfusiontype: runs ort fusion, the parameter defines the fusion type,
        it accepts multiple values separated by ``|``,
        see :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`
    :param input_options: additional options to define the dummy inputs
        used to export
    :param model_options: additional options when creating the model such as
        ``num_hidden_layers`` or ``attn_implementation``
    :param subfolder: version or subfolder to use when retrieving a model id
    :param opset: onnx opset to use for the conversion
    :param runtime: onnx runtime to use to check for discrepancies,
        possible values ``onnxruntime``, ``torch``, ``orteval``,
        ``orteval10``, ``ref``, only used if `do_run` is true
    :param repeat: number of times to measure the model
    :param warmup: warms up the model first
    :param inputs2: checks that other sets of inputs run as well,
        this ensures that the model does support dynamism, the value is used
        as an increment to the first set of values (added to dimensions),
        or an empty cache for example
    :param output_names: output names the onnx exporter should use
    :param ort_logs: increases onnxruntime verbosity when creating the session
    :param quiet_input_sets: avoids raising an exception if the inputs belong to that set
        even if quiet is False
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces

    The following environment variables can be used to print out some
    information:

    * ``PRINT_CONFIG``: prints the model configuration

    The following exporters are available:

    * ``export-nostrict``: runs :func:`torch.export.export` (..., strict=False)
    * ``onnx-dynamo``: runs :func:`torch.onnx.export` (...),
      models can be optimized with ``optimization`` in ``("ir", "os_ort")``
    * ``modelbuilder``: uses :epkg:`ModelBuilder` to build the onnx model
    * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
      models can be optimized with ``optimization`` in
      ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``

    The default runtime, :epkg:`onnxruntime`, is used to validate a model and check that the
    exported model returns the same outputs as the original one; otherwise,
    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
    if ``runtime == 'torch'``,
    :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
    if ``runtime == 'orteval'``, or
    :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
    if ``runtime == 'ref'``;
    ``orteval10`` increases the verbosity.

    .. versionchanged:: 0.7.13
        *inputs2* no longer means only a second set of inputs but several,
        such as ``input_empty_cache``,
        which refers to a set of inputs using an empty cache.
    """
    main_validation_begin = time.perf_counter()
    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
        model_id,
        subfolder,
        same_as_pretrained=same_as_pretrained,
        use_pretrained=use_pretrained,
    )
    time_preprocess_model_id = time.perf_counter() - main_validation_begin
    patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)

    summary = version_summary()
    summary.update(
        dict(
            version_model_id=model_id,
            version_do_run=str(do_run),
            version_dtype=str(dtype or ""),
            version_device=str(device or ""),
            version_same_as_pretrained=str(same_as_pretrained),
            version_use_pretrained=str(use_pretrained),
            version_optimization=optimization or "",
            version_quiet=str(quiet),
            version_patch=str(patch),
            version_patch_kwargs=str(patch_kwargs).replace(" ", ""),
            version_rewrite=str(rewrite),
            version_dump_folder=dump_folder or "",
            version_drop_inputs=str(list(drop_inputs or "")),
            version_ortfusiontype=ortfusiontype or "",
            version_stop_if_static=str(stop_if_static),
            version_exporter=exporter or "",
            version_runtime=runtime,
            version_inputs2=inputs2,
            time_preprocess_model_id=time_preprocess_model_id,
        )
    )
    if opset:
        summary["version_opset"] = opset

    folder_name = None
    if dump_folder:
        folder_name = _make_folder_name(
            model_id,
            exporter,
            optimization,
            dtype=dtype,
            device=device,
            subfolder=subfolder,
            opset=opset,
            drop_inputs=drop_inputs,
            use_pretrained=use_pretrained,
            same_as_pretrained=same_as_pretrained,
            task=task,
        )
        dump_folder = os.path.join(dump_folder, folder_name)
        if not os.path.exists(dump_folder):
            os.makedirs(dump_folder)
        summary["dump_folder"] = dump_folder
        summary["dump_folder_name"] = folder_name
        if verbose:
            print(f"[validate_model] dump into {folder_name!r}")

    if verbose:
        if subfolder:
            print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
        else:
            print(f"[validate_model] validate model id {model_id!r}")
        if task:
            print(f"[validate_model] with task {task!r}")
        print(f"[validate_model] patch={patch!r}")
        if model_options:
            print(f"[validate_model] model_options={model_options!r}")
        print(f"[validate_model] get dummy inputs with input_options={input_options}...")
        print(
            f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, "
            f"stop_if_static={stop_if_static}"
        )
        print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
        print(f"[validate_model] dump_folder={dump_folder!r}")
        print(f"[validate_model] output_names={output_names}")
    summary["model_id"] = model_id
    summary["model_subfolder"] = subfolder or ""

    iop = input_options or {}
    mop = model_options or {}
    data = _quiet_or_not_quiet(
        quiet,
        "create_torch_model",
        summary,
        None,
        (
            lambda mid=model_id, v=verbose, task=task, uptr=use_pretrained, tr=same_as_pretrained, iop=iop, sub=subfolder, i2=inputs2: (  # noqa: E501
                get_untrained_model_with_inputs(
                    mid,
                    verbose=v,
                    task=task,
                    use_pretrained=uptr,
                    same_as_pretrained=tr,
                    inputs_kwargs=iop,
                    model_kwargs=mop,
                    subfolder=sub,
                    add_second_input=i2,
                )
            )
        ),
    )

    second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]

    if dump_folder:
        with open(os.path.join(dump_folder, "model_config.txt"), "w") as f:
            f.write(f"model_id: {model_id}\n------\n")
            f.write(
                pprint.pformat(
                    data["configuration"]
                    if type(data["configuration"]) is dict
                    else data["configuration"].to_dict()
                )
            )
        dump_info = data.get("dump_info", None)
        if dump_info:
            with open(os.path.join(dump_folder, "model_dump_info.txt"), "w") as f:
                f.write(f"model_id: {model_id}\n------\n")
                f.write(pprint.pformat(dump_info))

    if exporter == "modelbuilder":
        # Models used with ModelBuilder do not like batch size > 1.
        # Let's change that.
        for k in ["inputs", "inputs2"]:
            if k not in data:
                continue
            if verbose:
                print(f"[validate_model] set batch=1 for data[{k!r}]")
                print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
            cpl = CoupleInputsDynamicShapes(
                tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
            )
            with register_additional_serialization_functions(patch_transformers=True):  # type: ignore[arg-type]
                data[k] = cpl.change_dynamic_dimensions(
                    desired_values=dict(batch=1), only_desired=True
                )
            if verbose:
                print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")

    # modelbuilder needs different treatments sometimes, so
    # we mark it for later usage.
    # for example, it has a different past_kv ordering than
    # a flattened CacheObject
    data["exporter"] = exporter
    data["input_options"] = iop
    data["model_options"] = mop
    data["model_dump_folder"] = dump_folder
    if dtype:
        data["model_dtype"] = dtype if isinstance(dtype, str) else str(dtype)
    if device:
        data["model_device"] = str(device)
    if opset:
        data["model_opset"] = opset
    if "rewrite" in data:
        if rewrite:
            summary["model_rewrite"] = str(data["rewrite"])
            if verbose:
                print(f"[validate_model] model_rewrite={summary['model_rewrite']}")
        else:
            del data["rewrite"]
            if verbose:
                print("[validate_model] no rewrite")
    if os.environ.get("PRINT_CONFIG", "0") in (1, "1"):
        print("[validate_model] -- PRINT CONFIG")
        print("-- type(config)", type(data["configuration"]))
        print(data["configuration"])
        print("[validate_model] -- END PRINT CONFIG")
    if iop:
        summary["input_options"] = str(iop)
    if mop:
        summary["model_options"] = str(mop)
    if "ERR_create" in summary:
        return summary, data

    if drop_inputs:
        if verbose:
            print(f"[validate_model] -- drop inputs: {drop_inputs!r}")
            print(f"[validate_model] current inputs: {string_type(data['inputs'])}")
            print(
                f"[validate_model] current dynamic_shapes: "
                f"{string_type(data['dynamic_shapes'])}"
            )
        data["inputs"], data["dynamic_shapes"] = filter_inputs(
            data["inputs"],
            drop_names=drop_inputs,
            model=data["model"],
            dynamic_shapes=data["dynamic_shapes"],
        )
        if verbose:
            print(f"[validate_model] new inputs: {string_type(data['inputs'])}")
            print(f"[validate_model] new dynamic_shapes: {string_type(data['dynamic_shapes'])}")
        if second_input_keys:
            for k in second_input_keys:
                data[k], _ = filter_inputs(
                    data[k],
                    drop_names=drop_inputs,
                    model=data["model"],
                    dynamic_shapes=data["dynamic_shapes"],
                )

    if not empty(dtype):
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        if verbose:
            print(f"[validate_model] dtype conversion to {dtype}")
        data["model"] = to_any(data["model"], dtype)  # type: ignore
        data["inputs"] = to_any(data["inputs"], dtype)  # type: ignore
        summary["model_dtype"] = str(dtype)
        if second_input_keys:
            for k in second_input_keys:
                data[k] = to_any(data[k], dtype)  # type: ignore

    if not empty(device):
        if verbose:
            print(f"[validate_model] device conversion to {device}")
        data["model"] = to_any(data["model"], device)  # type: ignore
        data["inputs"] = to_any(data["inputs"], device)  # type: ignore
        summary["model_device"] = str(device)
        if second_input_keys:
            for k in second_input_keys:
                data[k] = to_any(data[k], device)  # type: ignore

    for k in ["task", "size", "n_weights"]:
        summary[f"model_{k.replace('_','')}"] = data[k]
    summary["second_input_keys"] = ",".join(second_input_keys)
    summary["model_inputs_options"] = str(input_options or "")
    summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
    summary["model_shapes"] = string_type(data["dynamic_shapes"])
    summary["model_class"] = data["model"].__class__.__name__
    summary["model_module"] = str(data["model"].__class__.__module__)
    if summary["model_module"] in sys.modules:
        summary["model_file"] = str(sys.modules[summary["model_module"]].__file__)  # type: ignore[index]
    summary["model_config_class"] = data["configuration"].__class__.__name__
    summary["model_config"] = str(
        shrink_config(
            data["configuration"]
            if type(data["configuration"]) is dict
            else data["configuration"].to_dict()
        )
    ).replace(" ", "")
    summary["model_id"] = model_id

    if verbose:
        print("[validate_model] --")
        print(f"[validate_model] task={data['task']}")
        print(f"[validate_model] size={data['size'] / 2**20} Mb")
        print(f"[validate_model] n_weights={data['n_weights'] / 1e6} millions parameters")
        for k, v in data["inputs"].items():
            print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}")
        for k, v in data["dynamic_shapes"].items():
            print(f"[validate_model] +SHAPE {k}={string_type(v)}")
        print(f"[validate_model] second_input_keys={second_input_keys}")
        print("[validate_model] --")

    if do_run:
        validation_begin = time.perf_counter()

        _validate_do_run_model(
            data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
        )
        if second_input_keys:
            for k in second_input_keys:
                _validate_do_run_model(
                    data,
                    summary,
                    k,
                    f"run2{k[6:]}",
                    f"run_expected2{k[6:]}",
                    verbose,
                    1,
                    0,
                    quiet,
                )

        summary["time_total_validation_torch"] = time.perf_counter() - validation_begin

    if exporter:
        print(
            f"[validate_model] -- export the model with {exporter!r}, "
            f"optimization={optimization!r}"
        )
        exporter_begin = time.perf_counter()
        if patch_kwargs:
            if verbose:
                print(
                    f"[validate_model] applies patches before exporting "
                    f"stop_if_static={stop_if_static}"
                )
            with torch_export_patches(  # type: ignore
                stop_if_static=stop_if_static,
                verbose=max(0, verbose - 1),
                rewrite=data.get("rewrite", None),
                dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
                **patch_kwargs,  # type: ignore[arg-type]
            ) as modificator:
                data["inputs_export"] = modificator(data["inputs"])  # type: ignore

                if do_run:
                    _validate_do_run_exported_program(data, summary, verbose, quiet)

                # data is modified inplace
                summary_export, data = call_exporter(
                    exporter=exporter,
                    data=data,
                    quiet=quiet,
                    verbose=verbose,
                    optimization=optimization,
                    do_run=do_run,
                    dump_folder=dump_folder,
                    output_names=output_names,
                )
        else:
            data["inputs_export"] = data["inputs"]
            # data is modified inplace
            summary_export, data = call_exporter(
                exporter=exporter,
                data=data,
                quiet=quiet,
                verbose=verbose,
                optimization=optimization,
                do_run=do_run,
                dump_folder=dump_folder,
                output_names=output_names,
            )

        summary.update(summary_export)
        summary["time_total_exporter"] = time.perf_counter() - exporter_begin

    dump_stats = None
    if dump_folder:
        if "exported_program" in data:
            ep = data["exported_program"]
            if verbose:
                print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
            assert isinstance(
                folder_name, str
            ), f"folder_name={folder_name!r} should be a string"
            folder_name = folder_name.replace("/", "-")
            with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
                f.write(str(ep))
            torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
            with open(os.path.join(dump_folder, f"{folder_name}.graph"), "w") as f:
                f.write(str(ep.graph))
            if verbose:
                print("[validate_model] done (dump ep)")
        if "onnx_program" in data:
            assert isinstance(
                folder_name, str
            ), f"folder_name={folder_name!r} should be a string"
            folder_name = folder_name.replace("/", "-")
            epo = data["onnx_program"]
            if verbose:
                print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
            onnx_filename = os.path.join(dump_folder, f"{folder_name}.onnx")
            begin = time.perf_counter()
            if isinstance(epo, onnx.model_container.ModelContainer):
                epo.save(onnx_filename, all_tensors_to_one_file=True)
            elif isinstance(epo, onnx.ModelProto):
                if os.path.exists(f"{onnx_filename}.data"):
                    os.remove(f"{onnx_filename}.data")
                onnx.save(
                    epo,
                    onnx_filename,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                    location=f"{os.path.split(onnx_filename)[-1]}.data",
                )
            else:
                epo.save(onnx_filename, external_data=True)
            duration = time.perf_counter() - begin
            if verbose:
                print(f"[validate_model] done (dump onnx) in {duration}")
            data["onnx_filename"] = onnx_filename
            summary["time_onnx_save"] = duration
            summary.update(compute_statistics(onnx_filename))
            del epo

        if verbose:
            print(f"[validate_model] dumps statistics in {dump_folder!r}...")
        dump_stats = os.path.join(dump_folder, f"{folder_name}.stats")
        with open(dump_stats, "w") as f:
            for k, v in sorted(summary.items()):
                f.write(f":{k}:{v};\n")
        if verbose:
            print("[validate_model] done (dump)")

    if not exporter or (
        not exporter.startswith(("onnx-", "custom-"))
        and exporter not in ("custom", "modelbuilder")
    ):
        if verbose:
            print("[validate_model] -- done (final)")
        if dump_stats:
            with open(dump_stats, "w") as f:
                for k, v in sorted(summary.items()):
                    f.write(f":{k}:{v};\n")
        return summary, data

    if do_run:
        # Let's move the model to CPU to make sure it frees GPU memory.
        if verbose:
            # It does not really work for the time being and the model
            # gets loaded twice, once by torch, once by onnxruntime
            print("[validation_model] -- delete the model")
        for key in ["model", "onnx_program", "config"]:
            if key in data:
                del data[key]
        if device is not None and "cuda" in str(device).lower():
            torch.cuda.empty_cache()
        gc.collect()
        print("[validation_model] -- done")

        validation_begin = time.perf_counter()
        summary_valid, data = validate_onnx_model(
            data=data,
            quiet=quiet,
            verbose=verbose,
            runtime=runtime,
            repeat=repeat,
            warmup=warmup,
            second_input_keys=second_input_keys,
            ort_logs=ort_logs,
            quiet_input_sets=quiet_input_sets,
        )
        summary.update(summary_valid)
        summary["time_total_validation_onnx"] = time.perf_counter() - validation_begin

    if ortfusiontype and "onnx_filename" in data:
        assert (
            "configuration" in data
        ), f"missing configuration in data, cannot run ort fusion for model_id={model_id}"
        config = data["configuration"]
        assert hasattr(
            config, "hidden_size"
        ), f"Missing attribute hidden_size in configuration {config}"
        hidden_size = config.hidden_size
        assert hasattr(
            config, "num_attention_heads"
        ), f"Missing attribute num_attention_heads in configuration {config}"
        num_attention_heads = config.num_attention_heads

        if ortfusiontype == "ALL":
            from onnxruntime.transformers.optimizer import MODEL_TYPES

            model_types = sorted(MODEL_TYPES)
        else:
            model_types = ortfusiontype.split("|")
        for model_type in model_types:
            flavour = f"ort{model_type}"
            summary[f"version_{flavour}_hidden_size"] = hidden_size
            summary[f"version_{flavour}_num_attention_heads"] = num_attention_heads

            begin = time.perf_counter()
            if verbose:
                print(f"[validate_model] run onnxruntime fusion for {model_type!r}")
            input_filename = data["onnx_filename"]
            output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx"
            ort_sum, ort_data = run_ort_fusion(
                input_filename,
                output_path,
                model_type=model_type,
                num_attention_heads=num_attention_heads,
                hidden_size=hidden_size,
            )
            summary.update(ort_sum)
            data.update(ort_data)
            data[f"onnx_filename_{flavour}"] = output_path
            duration = time.perf_counter() - begin
            summary[f"time_ortfusion_{flavour}"] = duration
            if verbose:
                print(
                    f"[validate_model] done {model_type!r} in {duration}, "
                    f"saved into {output_path!r}"
                )

            if do_run:
                summary_valid, data = validate_onnx_model(
                    data=data,
                    quiet=quiet,
                    verbose=verbose,
                    flavour=flavour,
                    runtime=runtime,
                    repeat=repeat,
                    warmup=warmup,
                    second_input_keys=second_input_keys,
                    quiet_input_sets=quiet_input_sets,
                )
                summary.update(summary_valid)

    _compute_final_statistics(summary)
    summary["time_total"] = time.perf_counter() - main_validation_begin

    if verbose:
        print("[validate_model] -- done (final)")
    if dump_stats:
        # Dumps again the statistics.
        with open(dump_stats, "w") as f:
            for k, v in sorted(summary.items()):
                f.write(f":{k}:{v};\n")
    return summary, data

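A typical invocation, mirroring the command line described in the docstring (a sketch; the model id and options are illustrative):

    from onnx_diagnostic.torch_models.validate import validate_model

    summary, data = validate_model(
        "arnir0/Tiny-LLM",        # illustrative model id
        do_run=True,              # run the torch model and the exported model
        exporter="onnx-dynamo",
        patch=True,               # apply transformers/diffusers patches before export
        dump_folder="dump_validate",
    )
    print(summary.get("model_task"), summary.get("time_total"))
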
def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
    """Computes some statistics on the model itself."""
    onx = onnx.load(onnx_filename, load_external_data=False)
    cache_functions = {(f.domain, f.name): f for f in onx.functions}
    local_domains = set(f.domain for f in onx.functions)

    def node_iter(proto):
        if isinstance(proto, onnx.ModelProto):
            for f in proto.functions:
                yield from node_iter(f)
            yield from node_iter(proto.graph)
        elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
            for node in proto.node:
                yield node

                # Let's inline the function
                key = node.domain, node.op_type
                if key in cache_functions:
                    yield from node_iter(cache_functions[key])

                # Let's continue
                for att in node.attribute:
                    if att.type == onnx.AttributeProto.GRAPH:
                        yield from node_iter(att.g)
            if hasattr(proto, "initializer"):
                yield from proto.initializer
        else:
            raise NotImplementedError(f"Unexpected type={type(proto)}")

    counts: Dict[str, Union[float, int]] = {}
    n_nodes = 0
    n_nodes_nocst = 0
    for proto in node_iter(onx):
        if isinstance(proto, onnx.NodeProto):
            key = f"n_node_{proto.op_type}"
            n_nodes += 1
            if proto.op_type != "Constant":
                n_nodes_nocst += 1
            if proto.domain in local_domains:
                key = "n_node_local_function"
            if key not in counts:
                counts[key] = 0
            counts[key] += 1
        else:
            key = f"n_node_initializer_{proto.data_type}"

            if key not in counts:
                counts[key] = 0
            counts[key] += 1

    counts["n_node_nodes"] = n_nodes
    counts["n_node_nodes_nocst"] = n_nodes_nocst
    counts["n_node_functions"] = len(onx.functions)
    return counts

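The returned dictionary only holds counters; for a small model it could look like the following (a sketch, the path and values are illustrative and depend on the graph):

    from onnx_diagnostic.torch_models.validate import compute_statistics

    stats = compute_statistics("model.onnx")
    # e.g. {"n_node_MatMul": 12, "n_node_initializer_1": 30,
    #       "n_node_nodes": 57, "n_node_nodes_nocst": 50, "n_node_functions": 0}
    print(stats)
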
def _validate_do_run_model(
    data, summary, key, tag, expected_tag, verbose, repeat, warmup, quiet
):
    if verbose:
        print(f"[validate_model] -- run the model inputs={key!r}...")
        print(f"[validate_model] {key}={string_type(data[key], with_shape=True)}")
    # We make a copy of the inputs just in case the model modifies them inplace
    hash_inputs = string_type(data[key], with_shape=True)
    inputs = torch_deepcopy(data[key])
    model = data["model"]

    expected = _quiet_or_not_quiet(
        quiet,
        tag,
        summary,
        data,
        (lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
        repeat=repeat,
        warmup=warmup,
    )
    if f"ERR_{tag}" in summary:
        return summary, data

    summary[expected_tag] = string_type(expected, with_shape=True)
    if verbose:
        print(f"[validate_model] done ([{tag}])")
    data[expected_tag] = expected
    assert hash_inputs == string_type(data[key], with_shape=True), (
        f"The model execution modified the inputs:\n"
        f"before: {hash_inputs}\n"
        f" after: {string_type(data[key], with_shape=True)}"
    )

def _validate_do_run_exported_program(data, summary, verbose, quiet):

    # We run the model a second time to check the patch did not
    # introduce any discrepancies
    if verbose:
        print("[validate_model] run patched model...")
        print(
            f"[validate_model] patched inputs="
            f"{string_type(data['inputs_export'], with_shape=True)}"
        )
    hash_inputs = string_type(data["inputs_export"], with_shape=True)

    # We make a copy of the inputs just in case the model modifies them inplace
    inputs = torch_deepcopy(data["inputs_export"])
    model = data["model"]

    expected = _quiet_or_not_quiet(
        quiet,
        "run_patched",
        summary,
        data,
        (lambda m=model, inp=inputs: m(**inp)),
    )
    if "ERR_run_patched" in summary:
        return summary, data

    disc = max_diff(data["run_expected"], expected)
    for k, v in disc.items():
        summary[f"disc_patched_{k}"] = str(v)
    if verbose:
        print("[validate_model] done (patched run)")
        print(f"[validate_model] patched discrepancies={string_diff(disc)}")
    assert hash_inputs == string_type(data["inputs_export"], with_shape=True), (
        f"The model execution modified the inputs:\n"
        f"before: {hash_inputs}\n"
        f" after: {string_type(data['inputs_export'], with_shape=True)}"
    )


_cache_export_times = []
_main_export_function = torch.export.export


def _torch_export_export(*args, _export=_main_export_function, **kwargs):
    begin = time.perf_counter()
    res = _export(*args, **kwargs)
    duration = time.perf_counter() - begin
    _cache_export_times.append(duration)
    return res


def _restore_torch_export_export(summary):
    torch.export.export = _main_export_function
    if _cache_export_times:
        summary["time_torch_export_export"] = sum(_cache_export_times)
        summary["time_torch_export_export_n"] = len(_cache_export_times)
        _cache_export_times.clear()

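The pattern implemented above — swap torch.export.export for a timing wrapper, accumulate durations in a module-level list, then restore the original and report — can be reproduced for any callable. A minimal generic sketch (names are illustrative):

    import time

    _durations = []

    def timed(fct):
        # wraps fct so every call appends its duration to _durations
        def wrapper(*args, **kwargs):
            begin = time.perf_counter()
            try:
                return fct(*args, **kwargs)
            finally:
                _durations.append(time.perf_counter() - begin)
        return wrapper

    slow = timed(lambda: sum(range(10**5)))
    slow()
    print(f"{len(_durations)} call(s), total {sum(_durations):.6f}s")
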
def call_exporter(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    do_run: bool = False,
    dump_folder: Optional[str] = None,
    output_names: Optional[List[str]] = None,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
    """
    Calls an exporter on a model.
    If a patch must be applied, it should be applied before this function.

    :param data: dictionary with all the necessary inputs
    :param exporter: exporter to call
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param optimization: optimization to apply
    :param do_run: runs the model and computes discrepancies
    :param dump_folder: to dump additional information
    :param output_names: list of output names to use with the onnx exporter
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    _cache_export_times.clear()
    torch.export.export = _torch_export_export

    if exporter == "export" or exporter.startswith("export-"):
        # torch.export.export
        summary, data = call_torch_export_export(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
            do_run=do_run,
        )
        _restore_torch_export_export(summary)
        return summary, data
    if exporter.startswith("onnx-"):
        # torch.onnx.export
        summary, data = call_torch_export_onnx(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
            output_names=output_names,
        )
        _restore_torch_export_export(summary)
        return summary, data
    if exporter.startswith("custom"):
        # custom exporter
        summary, data = call_torch_export_custom(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
            dump_folder=dump_folder,
            output_names=output_names,
        )
        _restore_torch_export_export(summary)
        return summary, data
    if exporter == "modelbuilder":
        # ModelBuilder
        summary, data = call_torch_export_model_builder(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
            output_names=output_names,
        )
        _restore_torch_export_export(summary)
        return summary, data
    raise NotImplementedError(
        f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
        f"exporter must start with 'onnx-', 'custom', 'export' or be 'modelbuilder' "
        f"(e.g. onnx-dynamo, custom, export), optimization can be 'ir', "
        f"'default', 'default+onnxruntime', "
        f"'default+onnxruntime+os_ort', 'os_ort'"
    )


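# A minimal usage sketch for ``call_exporter`` (``my_model`` and
# ``my_inputs`` are hypothetical placeholders, the real ``data`` dictionary
# is prepared by the validation pipeline):
#
#     data = {"model": my_model, "inputs_export": my_inputs}
#     summary, data = call_exporter(
#         data, exporter="export-nostrict", do_run=True, verbose=1
#     )
#     # summary gathers metrics ("export_exporter", timings, discrepancies),
#     # data gains keys such as "exported_program" or "onnx_program".

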
def call_torch_export_export(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    do_run: bool = False,
):
    """
    Exports a model with :func:`torch.export.export`.
    If a patch must be applied, it should be applied before this function.

    :param data: dictionary with all the necessary inputs, the dictionary must
        contain keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param optimization: optimization to apply
    :param do_run: runs the exported model and computes discrepancies
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    assert exporter in {
        "export",
        "export-strict",
        "export-nostrict",
    }, f"Unexpected value for exporter={exporter!r}"
    assert not optimization, f"No optimization is implemented for exporter={exporter!r}"
    assert "model" in data, f"model is missing from data: {sorted(data)}"
    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
    summary: Dict[str, Union[str, int, float]] = {}
    strict = "-strict" in exporter
    args, kwargs = split_args_kwargs(data["inputs_export"])
    ds = data.get("dynamic_shapes", None)

    summary["export_exporter"] = exporter
    summary["export_optimization"] = optimization or ""
    summary["export_strict"] = strict
    summary["export_args"] = string_type(args, with_shape=True)
    summary["export_kwargs"] = string_type(kwargs, with_shape=True)
    summary["export_dynamic_shapes"] = string_type(ds)

    # There is an issue with dynamic shapes: [[],[]] becomes [].
    dse = use_dyn_not_str(ds)
    # dse = CoupleInputsDynamicShapes(args, kwargs, ds).replace_string_by()

    summary["export_dynamic_shapes_export_export"] = string_type(dse)

    if verbose:
        print(
            f"[call_torch_export_export] exporter={exporter!r}, "
            f"strict={strict}, optimization={optimization!r}"
        )
        print(f"[call_torch_export_export] args={string_type(args, with_shape=True)}")
        print(f"[call_torch_export_export] kwargs={string_type(kwargs, with_shape=True)}")
        print(f"[call_torch_export_export] dynamic_shapes={string_type(ds)}")
        print(f"[call_torch_export_export] dynamic_shapes_export_export={string_type(dse)}")
        print("[call_torch_export_export] export...")

    model = data["model"]
    ep = _quiet_or_not_quiet(
        quiet,
        "export_export",
        summary,
        data,
        (
            lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: (
                torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s)
            )
        ),
    )
    if "ERR_export_export" in summary:
        return summary, data

    summary["export_graph_nodes"] = len(ep.graph.nodes)
    if verbose:
        print(
            f"[call_torch_export_export] done (export) "
            f"with {summary['export_graph_nodes']} nodes"
        )
    data["exported_program"] = ep
    if verbose > 1:
        print("[call_torch_export_export] -- ExportedProgram")
        print(ep)
        print("[call_torch_export_export] -- End of ExportedProgram")

    if do_run:
        # We check for discrepancies.
        if verbose:
            print("[validate_model] run exported model...")
            print(
                f"[validate_model] patched inputs="
                f"{string_type(data['inputs_export'], with_shape=True)}"
            )
        hash_inputs = string_type(data["inputs_export"], with_shape=True)

        # We make a copy of the inputs just in case the model modifies them inplace.
        inputs = torch_deepcopy(data["inputs_export"])
        model = ep.module()

        expected = _quiet_or_not_quiet(
            quiet,
            "run_exported",
            summary,
            data,
            (lambda m=model, inputs=inputs: m(**inputs)),
        )
        if "ERR_run_exported" in summary:
            return summary, data

        disc = max_diff(data["run_expected"], expected)
        for k, v in disc.items():
            summary[f"disc_exported_{k}"] = str(v)
        if verbose:
            print("[validate_model] done (exported run)")
            print(f"[validate_model] exported discrepancies={string_diff(disc)}")
        assert hash_inputs == string_type(data["inputs_export"], with_shape=True), (
            f"The exported model execution modified the inputs:\n"
            f"before: {hash_inputs}\n"
            f" after: {string_type(data['inputs_export'], with_shape=True)}"
        )
    return summary, data


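# Sketch: the suffix of ``exporter`` drives the ``strict`` flag passed to
# :func:`torch.export.export` (placeholders below are hypothetical):
#
#     summary, data = call_torch_export_export(
#         data={"model": my_model, "inputs_export": my_inputs},
#         exporter="export-strict",   # -> torch.export.export(..., strict=True)
#         do_run=True,
#     )
#     ep = data["exported_program"]   # a torch.export.ExportedProgram

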
def validate_onnx_model(
    data: Dict[str, Any],
    quiet: bool = False,
    verbose: int = 0,
    flavour: Optional[str] = None,
    runtime: str = "onnxruntime",
    repeat: int = 1,
    warmup: int = 0,
    second_input_keys: Optional[List[str]] = None,
    ort_logs: bool = False,
    quiet_input_sets: Optional[Set[str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Verifies that an onnx model produces the expected outputs.
    It uses ``data["onnx_filename"]`` as the input onnx filename
    or ``data[f"onnx_filename_{flavour}"]`` if *flavour* is specified.

    :param data: dictionary with all the necessary inputs, the dictionary must
        contain keys ``model`` and ``inputs_export``
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param flavour: use a different version of the inputs
    :param runtime: onnx runtime to use: onnxruntime, torch, orteval, ref
    :param repeat: number of times to run the model
    :param warmup: number of warmup runs
    :param second_input_keys: to validate the model on other input sets
        and make sure the exported model supports dynamism; the value is
        used as an increment added to the dimensions of the first set of inputs
    :param ort_logs: triggers the logs for onnxruntime
    :param quiet_input_sets: avoid raising an exception for these sets of inputs
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    import onnxruntime

    def _mk(key, flavour=flavour):
        return f"{key}_{flavour}" if flavour else key

    summary: Dict[str, Any] = {}
    flat_inputs = flatten_object(data["inputs"], drop_keys=True)
    d = flat_inputs[0].get_device()
    providers = (
        ["CPUExecutionProvider"]
        if d < 0
        else ["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    input_data_key = f"onnx_filename_{flavour}" if flavour else "onnx_filename"

    if input_data_key in data:
        source = data[input_data_key]
        if not os.path.exists(source):
            if verbose:
                print(f"[validate_onnx_model] missing {source!r}")
            summary[_mk("ERR_onnx_missing")] = f"FileNotFoundError({source!r})"
            return summary, data
        summary[input_data_key] = source
        summary[_mk("onnx_size")] = os.stat(source).st_size
    else:
        assert not flavour, f"flavour={flavour!r}, the filename must be saved."
        assert (
            "onnx_program" in data
        ), f"onnx_program is missing from data which has {sorted(data)}"
        source = data["onnx_program"].model_proto.SerializeToString()
        assert len(source) < 2**31, f"The model is bigger than 2 Gb: {len(source) / 2**30} Gb"
        summary[_mk("onnx_size")] = len(source)
    if verbose:
        print(
            f"[validate_onnx_model] verify onnx model with providers "
            f"{providers}..., flavour={flavour!r}"
        )

    if runtime == "onnxruntime":
        if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
            opts = onnxruntime.SessionOptions()
            opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
            if verbose:
                print(
                    f"[validate_onnx_model] saved the onnxruntime-optimized "
                    f"model in {opts.optimized_model_filepath!r}"
                )
            onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers)
            if verbose:
                print("[validate_onnx_model] -- done")

        if verbose:
            print("[validate_onnx_model] runtime is onnxruntime")
        sess_opts = onnxruntime.SessionOptions()
        if ort_logs:
            sess_opts.log_severity_level = 0
            sess_opts.log_verbosity_level = 4
        cls_runtime = lambda model, providers, _o=sess_opts: onnxruntime.InferenceSession(
            (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
            _o,
            providers=providers,
        )
    elif runtime == "torch":
        from ..reference import TorchOnnxEvaluator

        if verbose:
            print("[validate_onnx_model] runtime is TorchOnnxEvaluator")
        cls_runtime = (
            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
                model, providers=providers, verbose=max(verbose - 1, 0)
            )
        )
    elif runtime == "orteval":
        from ..reference import OnnxruntimeEvaluator

        if verbose:
            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator")
        cls_runtime = (
            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
                model, providers=providers, verbose=max(verbose - 1, 0)
            )
        )
    elif runtime == "orteval10":
        from ..reference import OnnxruntimeEvaluator

        if verbose:
            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)")
        cls_runtime = (
            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
                model, providers=providers, verbose=10
            )
        )
    elif runtime == "ref":
        from ..reference import ExtendedReferenceEvaluator

        if verbose:
            print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator")
        cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_(  # type: ignore[misc]
            model, verbose=max(verbose - 1, 0)
        )
    else:
        raise ValueError(f"Unexpected runtime={runtime!r}")

    sess = _quiet_or_not_quiet(
        quiet,
        _mk("create_onnx_ort"),
        summary,
        data,
        (lambda source=source, providers=providers: cls_runtime(source, providers)),
    )
    if f"ERR_{_mk('create_onnx_ort')}" in summary:
        return summary, data

    data[_mk("onnx_ort_sess")] = sess
    if verbose:
        print(f"[validate_onnx_model] done (ort_session) flavour={flavour!r}")

    keys = [("inputs", "run_expected", "")]
    if second_input_keys:
        keys.extend([(k, f"run_expected2{k[6:]}", f"2{k[6:]}") for k in second_input_keys])
    if verbose:
        print(f"[validate_onnx_model] -- keys={keys}")
    for k_input, k_expected, suffix in keys:
        # make_feeds
        assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
        assert k_expected in data, f"Unable to find {k_expected!r} in {sorted(data)}"
        if verbose:
            print(f"[validate_onnx_model] -- make_feeds for {k_input!r}...")
            print(
                f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
            )
        feeds = make_feeds(
            sess,
            data[k_input],
            use_numpy=True,
            check_flatten=False,
            is_modelbuilder=data["exporter"] == "modelbuilder",  # to remove position_ids
        )
        if verbose:
            print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
        summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
        if verbose:
            print("[validate_onnx_model] done (make_feeds)")

        # run ort
        if verbose:
            print(f"[validate_onnx_model] run session on inputs 'inputs{suffix}'...")
            if quiet_input_sets and f"inputs{suffix}" in quiet_input_sets:
                print(f"[validate_onnx_model] quiet_input_sets={quiet_input_sets}")

        got = _quiet_or_not_quiet(
            quiet or (quiet_input_sets is not None and f"inputs{suffix}" in quiet_input_sets),
            _mk(f"run_onnx_ort{suffix}"),
            summary,
            data,
            (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
            repeat=repeat,
            warmup=warmup,
        )
        if f"ERR_{_mk(f'run_onnx_ort{suffix}')}" in summary:
            return summary, data

        summary[f"run_feeds_{k_input}"] = string_type(feeds, with_shape=True, with_device=True)
        summary[f"run_output_{k_input}"] = string_type(got, with_shape=True, with_device=True)
        if verbose:
            print("[validate_onnx_model] done (run)")
            print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")

        # compute discrepancies
        disc = max_diff(data[k_expected], got, flatten=True)
        if verbose:
            print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
        for k, v in disc.items():
            summary[_mk(f"disc_onnx_ort_run{suffix}_{k}")] = v
    return summary, data


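# Sketch of a validation call once the previous steps filled ``data`` with
# "onnx_filename", "inputs", "run_expected" and "exporter" (the exact set of
# discrepancy keys depends on what ``max_diff`` returns, this is an
# assumption):
#
#     summary, data = validate_onnx_model(
#         data, runtime="onnxruntime", repeat=10, warmup=2, verbose=1
#     )
#     # summary holds keys of the form "disc_onnx_ort_run*" with the
#     # maximum discrepancies between the expected and the onnx outputs.

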
def call_torch_export_onnx(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    output_names: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Exports a model into onnx.
    If a patch must be applied, it should be applied before this function.

    :param data: dictionary with all the necessary inputs, the dictionary must
        contain keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param optimization: optimization to apply
    :param output_names: output names to use
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    available = {None, "", "ir", "os_ort", "ir+default"}
    assert (
        optimization in available
    ), f"unexpected value for optimization={optimization}, available={available}"
    assert exporter in {
        "onnx-dynamo",
        "onnx-script",
    }, f"Unexpected value for exporter={exporter!r}"
    assert "model" in data, f"model is missing from data: {sorted(data)}"
    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
    summary: Dict[str, Union[str, int, float]] = {}
    dynamo = "dynamo" in exporter
    args, kwargs = split_args_kwargs(data["inputs_export"])
    ds = data.get("dynamic_shapes", None)
    if verbose:
        print(
            f"[call_torch_export_onnx] exporter={exporter!r}, "
            f"optimization={optimization!r}"
        )
        print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
        print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
        print(f"[call_torch_export_onnx] dynamic_shapes={string_type(ds)}")
        print("[call_torch_export_onnx] export...")
    summary["export_exporter"] = exporter
    summary["export_optimization"] = optimization or ""
    summary["export_dynamo"] = dynamo
    summary["export_args"] = string_type(args, with_shape=True)
    summary["export_kwargs"] = string_type(kwargs, with_shape=True)
    opset = data.get("model_opset", None)
    if opset:
        summary["export_opset"] = opset

    if dynamo:
        export_export_kwargs = dict(dynamo=True, dynamic_shapes=ds)
    else:
        export_export_kwargs = dict(
            dynamo=False,
            dynamic_axes={
                k: v
                for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)  # type: ignore[arg-type]
                .replace_by_string()
                .items()
                if isinstance(v, dict)
            },
        )
        args = tuple(flatten_unflatten_for_dynamic_shapes(a) for a in args)
        kwargs = {k: flatten_unflatten_for_dynamic_shapes(v) for k, v in kwargs.items()}
        if verbose:
            print("[call_torch_export_onnx] dynamo=False so...")
            print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
            print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
    if output_names:
        export_export_kwargs["output_names"] = output_names
    if opset:
        export_export_kwargs["opset_version"] = opset
    if verbose:
        print(
            f"[call_torch_export_onnx] export_export_kwargs="
            f"{string_type(export_export_kwargs, with_shape=True)}"
        )
    model = data["model"]

    epo = _quiet_or_not_quiet(
        quiet,
        "export_onnx",
        summary,
        data,
        (
            lambda m=model, args=args, kws=kwargs, ekws=export_export_kwargs: (
                torch.onnx.export(
                    m,
                    args,
                    kwargs=kws,
                    **ekws,
                )
            )
        ),
    )
    if "ERR_export_onnx" in summary:
        return summary, data

    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_onnx] done (export)")
    data["onnx_program"] = epo
    if verbose > 5:
        print("[call_torch_export_onnx] -- ONNXProgram")
        print(epo)
        print("[call_torch_export_onnx] -- End of ONNXProgram")

    if optimization in {"ir", "os_ort", "ir+default"}:
        if verbose:
            print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
        if optimization == "ir":
            label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
        elif optimization == "ir+default":
            import onnxscript
            from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

            def _ir_default_opt(epo):
                onnxscript.optimizer.optimize_ir(epo.model)
                onx = epo.model_proto
                # not very efficient
                gr = GraphBuilder(
                    onx,
                    infer_shapes_options=True,
                    optimization_options=OptimizationOptions(patterns="default"),
                )
                cont = gr.to_onnx(large_model=True)
                epo.model = cont.to_ir()

            label, f_optim = "export_onnx_opt_ir_default", (
                lambda epo=epo: _ir_default_opt(epo)
            )

        else:
            import onnxscript
            import onnxscript.rewriter.ort_fusions as ort_fusions

            def _os_ort_optim(epo):
                onnxscript.optimizer.optimize_ir(epo.model)
                optimized = ort_fusions.optimize_for_ort(epo.model)
                if isinstance(optimized, tuple):
                    for k, v in optimized[1].items():
                        summary[f"op_opt_fused_{k}"] = v
                    epo.model = optimized[0]
                else:
                    epo.model = optimized

            label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
        _quiet_or_not_quiet(quiet, label, summary, data, f_optim)
        if f"ERR_{label}" in summary:
            return summary, data
        if verbose:
            print("[call_torch_export_onnx] done (optimization)")

    return summary, data


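# Sketch: "onnx-dynamo" forwards ``dynamic_shapes`` to the exporter while
# "onnx-script" first converts them into ``dynamic_axes`` (``data`` below is
# the pipeline dictionary, a hypothetical placeholder):
#
#     summary, data = call_torch_export_onnx(
#         data, exporter="onnx-dynamo", optimization="ir", verbose=1
#     )
#     epo = data["onnx_program"]      # the resulting ONNX program

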
def call_torch_export_model_builder(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    output_names: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Exports a model into onnx with :epkg:`ModelBuilder`.

    :param data: dictionary with all the necessary inputs, the dictionary must
        contain keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param optimization: optimization to apply
    :param output_names: list of output names to use
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    from ..helpers.model_builder_helper import create_model_builder, save_model_builder

    assert optimization in (
        None,
        "",
    ), f"unexpected value for optimization={optimization}, none is available"
    precision = data.get("model_dtype", "fp32")
    provider = data.get("model_device", "cpu")
    dump_folder = data.get("model_dump_folder", "")
    assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
    assert (
        not output_names
    ), f"output_names not empty, not supported yet, output_names={output_names}"
    cache_dir = os.path.join(dump_folder, "cache_mb")
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    summary: Dict[str, Any] = {}

    epo = _quiet_or_not_quiet(
        quiet,
        "export_model_builder",
        summary,
        data,
        (
            lambda m=data["model"], c=data[
                "configuration"
            ], p=precision, pr=provider, cd=cache_dir: (
                save_model_builder(
                    create_model_builder(
                        c, m, precision=p, execution_provider=pr, cache_dir=cd
                    )
                )
            )
        ),
    )
    if "ERR_export_model_builder" in summary:
        return summary, data

    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_model_builder] done (export)")
    data["onnx_program"] = epo
    return summary, data


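# Sketch: ModelBuilder needs the model configuration and a dump folder, both
# read from ``data`` (placeholders below are hypothetical):
#
#     data = {
#         "model": my_model,
#         "inputs_export": my_inputs,
#         "configuration": my_config,
#         "model_dump_folder": "dump/",
#         "model_dtype": "fp32",
#         "model_device": "cpu",
#     }
#     summary, data = call_torch_export_model_builder(data, "modelbuilder")

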
def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
    """
    Processes statistics coming from the exporters.
    It takes a sequence of dictionaries (like a data frame)
    and extracts some metrics.
    """

    def _simplify(p):
        for s in [
            "remove_unused",
            "constant_folding",
            "remove_identity",
            "remove_duplicated_initializer",
            "remove_duplicated_shape",
            "dynamic_dimension_naming",
            "inline",
            "check",
            "build_graph_for_pattern",
            "pattern_optimization",
            "topological_sort",
        ]:
            if s in p or s.replace("_", "-") in p:
                return s
        if p.startswith(("apply_", "match_")):
            return p
        return "other"

    def _add(d, a, v, use_max=False):
        if v:
            if a not in d:
                d[a] = v
            elif use_max:
                d[a] = max(d[a], v)
            else:
                d[a] += v

    counts: Dict[str, Any] = {}
    applied_pattern_time: Dict[str, Any] = {}
    applied_pattern_n: Dict[str, Any] = {}
    matching_pattern_time: Dict[str, Any] = {}
    matching_pattern_n: Dict[str, Any] = {}

    for obs in data:
        pattern = _simplify(obs["pattern"])
        _add(counts, "opt_nodes_added", obs.get("added", 0))
        _add(counts, "opt_nodes_removed", obs.get("removed", 0))
        _add(counts, "opt_time_steps", obs.get("time_in", 0))
        _add(counts, "opt_n_steps", 1)
        _add(
            counts,
            "opt_n_iteration",
            max(counts.get("opt_n_iteration", 0), obs.get("iteration", 0)),
            use_max=True,
        )

        if pattern.startswith("apply_"):
            _add(counts, "opt_n_applied_patterns", 1)
            _add(counts, "opt_time_applied_patterns", obs.get("time_in", 0))
            _add(applied_pattern_time, pattern, obs.get("time_in", 0))
            _add(applied_pattern_n, pattern, 1)
        elif pattern.startswith("match_"):
            _add(counts, "opt_n_matching_patterns", 1)
            _add(counts, "opt_time_matching_patterns", obs.get("time_in", 0))
            _add(matching_pattern_time, pattern, obs.get("time_in", 0))
            _add(matching_pattern_n, pattern, 1)
        else:
            _add(counts, f"opt_time_{pattern}", obs.get("time_in", 0))
            _add(counts, f"opt_n_{pattern}", 1)
            _add(counts, f"opt_nodes_added_{pattern}", obs.get("added", 0))
            _add(counts, f"opt_nodes_removed_{pattern}", obs.get("removed", 0))

    if applied_pattern_time:
        longest = max((v, k) for k, v in applied_pattern_time.items())
        counts["opt_top_time_applied_pattern"], counts["opt_top_time_applied_pattern_arg"] = (
            longest
        )
        longest = max((v, k) for k, v in applied_pattern_n.items())
        counts["opt_top_n_applied_pattern"], counts["opt_top_n_applied_pattern_arg"] = longest

    if matching_pattern_time:
        longest = max((v, k) for k, v in matching_pattern_time.items())
        (
            counts["opt_top_time_matching_pattern"],
            counts["opt_top_time_matching_pattern_arg"],
        ) = longest
        longest = max((v, k) for k, v in matching_pattern_n.items())
        counts["opt_top_n_matching_pattern"], counts["opt_top_n_matching_pattern_arg"] = (
            longest
        )
    counts["onnx_opt_optimized"] = 1
    return counts


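# A small worked example: two applications of the same pattern and one
# generic cleaning step are aggregated into counters (made-up values):
#
#     stats = process_statistics([
#         {"pattern": "apply_MatMulAddPattern", "time_in": 0.2, "iteration": 1},
#         {"pattern": "apply_MatMulAddPattern", "time_in": 0.1, "iteration": 2},
#         {"pattern": "remove_unused", "time_in": 0.05, "removed": 3, "iteration": 2},
#     ])
#     # stats["opt_n_applied_patterns"] == 2
#     # stats["opt_time_applied_patterns"] is about 0.3
#     # stats["opt_nodes_removed"] == 3, stats["opt_n_iteration"] == 2

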
def call_torch_export_custom(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    dump_folder: Optional[str] = None,
    output_names: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Exports a model into onnx.
    If a patch must be applied, it should be applied before this function.

    :param data: dictionary with all the necessary inputs, the dictionary must
        contain keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param optimization: optimization to apply
    :param dump_folder: to store additional information
    :param output_names: list of output names to use
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    available = {
        "",
        "default",
        "default+onnxruntime",
        "default+os_ort",
        "default+onnxruntime+os_ort",
        None,
    }
    if optimization == "none":
        optimization = ""
    assert (
        optimization in available
    ), f"unexpected value for optimization={optimization}, available={available}"
    available = {
        "custom",
        "custom-strict",
        "custom-strict-default",
        "custom-strict-all",
        "custom-nostrict",
        "custom-nostrict-default",
        "custom-nostrict-all",
        "custom-noinline",
        "custom-strict-noinline",
        "custom-strict-default-noinline",
        "custom-strict-all-noinline",
        "custom-nostrict-noinline",
        "custom-nostrict-default-noinline",
        "custom-nostrict-all-noinline",
        "custom-dec",
        "custom-decall",
        "custom-fake",
    }
    assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
    assert "model" in data, f"model is missing from data: {sorted(data)}"
    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
    summary: Dict[str, Union[str, int, float]] = {}
    strict = "-strict" in exporter
    args, kwargs = split_args_kwargs(data["inputs_export"])
    ds = data.get("dynamic_shapes", None)
    if "-fake" in exporter:
        from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

        if verbose:
            print("[call_torch_export_custom] switching to FakeTensor")
        assert not args, f"Exporter {exporter!r} not implemented with fake tensors."
        kwargs = torch_deepcopy(kwargs)
        kwargs, _ = make_fake_with_dynamic_dimensions(kwargs, dynamic_shapes=ds)
    opset = data.get("model_opset", None)
    if opset:
        summary["export_opset"] = opset
    if verbose:
        print(
            f"[call_torch_export_custom] exporter={exporter!r}, "
            f"optimization={optimization!r}"
        )
        print(f"[call_torch_export_custom] args={string_type(args, with_shape=True)}")
        print(f"[call_torch_export_custom] kwargs={string_type(kwargs, with_shape=True)}")
        print(f"[call_torch_export_custom] dynamic_shapes={string_type(ds)}")
        print("[call_torch_export_custom] export...")
    summary["export_exporter"] = exporter
    summary["export_optimization"] = optimization or ""
    summary["export_strict"] = strict
    summary["export_args"] = string_type(args, with_shape=True)
    summary["export_kwargs"] = string_type(kwargs, with_shape=True)

    from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
    from experimental_experiment.xbuilder import OptimizationOptions

    spl = optimization.split("+") if optimization else []
    os_ort = "os_ort" in spl
    optimization = "+".join(_ for _ in spl if _ != "os_ort")

    export_options = ExportOptions(
        strict=strict,
        decomposition_table=(
            "default"
            if ("-default" in exporter or "-dec" in exporter)
            else ("all" if ("-all" in exporter or "-decall" in exporter) else None)
        ),
        save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
    )
    inline = "-noinline" not in exporter
    options = OptimizationOptions(patterns=optimization) if optimization else None
    model = data["model"]
    kws = dict(
        dynamic_shapes=ds,
        export_options=export_options,
        options=options,
        optimize=bool(optimization),
        large_model=True,
        return_optimize_report=True,
        verbose=max(verbose - 2, 0),
        inline=inline,
    )
    if opset:
        kws["target_opset"] = opset
    if output_names:
        kws["output_names"] = output_names

    epo, opt_stats = _quiet_or_not_quiet(
        quiet,
        "export_onnx_c",
        summary,
        data,
        (
            lambda m=model, args=args, kwargs=kwargs, kws=kws: (
                to_onnx(
                    m,
                    args,
                    kwargs=kwargs,
                    **kws,
                )
            )
        ),
    )
    if "ERR_export_onnx_c" in summary:
        return summary, data

    new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
    new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
    if "optimization" in opt_stats:
        new_stat.update(process_statistics(opt_stats["optimization"]))

    summary.update(new_stat)
    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_custom] done (export)")

    if os_ort:
        import onnxscript
        import onnxscript.rewriter.ort_fusions as ort_fusions

        if verbose:
            print("[call_torch_export_custom] conversion to IR...")
        begin = time.perf_counter()
        ir_model = epo.to_ir()
        duration = time.perf_counter() - begin
        summary["time_optim_to_ir"] = duration
        if verbose:
            print(f"[call_torch_export_custom] done in {duration}")
            print("[call_torch_export_custom] start optimization...")
        begin = time.perf_counter()
        onnxscript.optimizer.optimize_ir(ir_model)
        ir_optimized = ort_fusions.optimize_for_ort(ir_model)
        if isinstance(ir_optimized, tuple):
            report = ir_optimized[1]
            for k, v in report.items():
                summary[f"op_opt_fused_{k}"] = v
            ir_optimized = ir_optimized[0]
        epo.model = ir_optimized
        duration = time.perf_counter() - begin
        summary["time_optim_os_ort"] = duration
        if verbose:
            print(f"[call_torch_export_custom] done in {duration}")

    data["onnx_program"] = epo
    return summary, data


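# Sketch: the exporter name encodes the export options, e.g.
# "custom-nostrict-default" selects strict=False and the "default"
# decomposition table (``data`` is the hypothetical pipeline dictionary):
#
#     summary, data = call_torch_export_custom(
#         data, exporter="custom-nostrict", optimization="default+os_ort",
#         dump_folder="dump/", verbose=1
#     )
#     epo = data["onnx_program"]

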
def run_ort_fusion(
    model_or_path: Union[str, onnx.ModelProto],
    output_path: str,
    num_attention_heads: int,
    hidden_size: int,
    model_type: str = "bert",
    verbose: int = 0,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Runs :epkg:`onnxruntime` fusion optimizer.

    :param model_or_path: path to the ModelProto or the ModelProto itself
    :param output_path: path where the optimized model is saved
    :param num_attention_heads: number of heads, usually ``config.num_attention_heads``
    :param hidden_size: hidden size, usually ``config.hidden_size``
    :param model_type: type of optimization, see below
    :param verbose: verbosity
    :return: two dictionaries, summary and data

    Supported values for ``model_type``:

    .. runpython::
        :showcode:

        import pprint
        from onnxruntime.transformers.optimizer import MODEL_TYPES

        pprint.pprint(sorted(MODEL_TYPES))
    """
    from onnxruntime.transformers.optimizer import optimize_by_fusion
    from onnxruntime.transformers.fusion_options import FusionOptions

    opts = FusionOptions(model_type)

    if isinstance(model_or_path, str):
        if verbose:
            print(f"[run_ort_fusion] loads {model_or_path!r}")
        onx = onnx.load(model_or_path)
    else:
        onx = model_or_path
    begin = time.perf_counter()
    n_nodes = len(onx.graph.node)
    if verbose:
        print(
            f"[run_ort_fusion] starts optimization for "
            f"model_type={model_type!r} with {n_nodes} nodes"
        )
    try:
        new_onx = optimize_by_fusion(
            onx,
            model_type=model_type,
            num_heads=num_attention_heads,
            hidden_size=hidden_size,
            optimization_options=opts,
        )
    except Exception as e:
        duration = time.perf_counter() - begin
        if verbose:
            print(f"[run_ort_fusion] failed in {duration} for model_type={model_type!r}")
        return {
            f"ERR_opt_ort_{model_type}": str(e),
            f"opt_ort_{model_type}_duration": duration,
        }, {}

    duration = time.perf_counter() - begin
    delta = len(new_onx.model.graph.node)
    if verbose:
        print(f"[run_ort_fusion] done in {duration} with {delta} nodes")
        print(f"[run_ort_fusion] save to {output_path!r}")
    begin = time.perf_counter()
    new_onx.save_model_to_file(output_path, use_external_data_format=True)
    d = time.perf_counter() - begin
    if verbose:
        print(f"[run_ort_fusion] done in {d}")
    return {
        f"opt_ort_{model_type}_n_nodes1": n_nodes,
        f"opt_ort_{model_type}_n_nodes2": delta,
        f"opt_ort_{model_type}_delta_node": delta - n_nodes,
        f"opt_ort_{model_type}_duration": duration,
        f"opt_ort_{model_type}_duration_save": d,
    }, {f"opt_ort_{model_type}": output_path}


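# Sketch (the head count and hidden size usually come from the model
# configuration, paths below are placeholders):
#
#     summary, produced = run_ort_fusion(
#         "model.onnx",
#         "model.opt.onnx",
#         num_attention_heads=config.num_attention_heads,
#         hidden_size=config.hidden_size,
#         model_type="bert",
#     )
#     # on success, produced == {"opt_ort_bert": "model.opt.onnx"}

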
def _compute_final_statistics(summary: Dict[str, Any]):
    """
    Updates the statistics inline. It adds:

    - speedup
    """
    stats = {}
    if (
        "time_run_latency" in summary
        and "time_run_onnx_ort_latency" in summary
        and summary["time_run_onnx_ort_latency"] > 0
    ):
        stats["stat_estimated_speedup_ort"] = (
            summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
        )
        stats["stat_estimated_speedup_ort_m98"] = (
            summary["time_run_latency_m98"] / summary["time_run_onnx_ort_latency_m98"]
        )
    summary.update(stats)