onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/torch_models/validate.py
@@ -0,0 +1,2124 @@
+ import gc
+ import datetime
+ import inspect
+ import os
+ import pprint
+ import sys
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+ import time
+ import numpy as np
+ import onnx
+ import torch
+ from ..export import CoupleInputsDynamicShapes
+ from ..helpers import max_diff, string_type, string_diff
+ from ..helpers.helper import flatten_object
+ from ..helpers.rt_helper import make_feeds
+ from ..helpers.torch_helper import to_any, torch_deepcopy
+ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
+ from ..tasks import random_input_kwargs
+ from ..torch_export_patches import (
+     torch_export_patches,
+     register_additional_serialization_functions,
+ )
+ from ..torch_export_patches.patch_inputs import use_dyn_not_str
+ from .hghub import get_untrained_model_with_inputs
+ from .hghub.model_inputs import _preprocess_model_id
+
+
+ def empty(value: Any) -> bool:
+     """Tells if the value is empty."""
+     if isinstance(value, (str, list, dict, tuple, set)):
+         return not bool(value)
+     if value is None:
+         return True
+     return False
+
+
+ def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, Any]:
+     """
+     Returns dummy inputs for a specific task.
+
+     :param task: requested task
+     :param config: returns dummy inputs for a specific config if available
+     :return: dummy inputs and dynamic shapes
+     """
+     kwargs, f = random_input_kwargs(config, task)
+     return f(model=None, config=config, **kwargs)
+
+
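Illustrative usage, not part of the package source; the exact keys of the returned dictionary depend on the task, but it typically carries the dummy inputs together with their dynamic shapes:

    res = get_inputs_for_task("text-generation")
    print(sorted(res))  # usually includes the inputs and the dynamic shapes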
+ def split_args_kwargs(inputs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+     """Splits into args, kwargs."""
+     if isinstance(inputs, dict):
+         return (), inputs
+     if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[1], dict):
+         return inputs
+     assert isinstance(inputs, tuple), f"Unexpected inputs {string_type(inputs)}"
+     return inputs, {}
+
+
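The three input layouts the helper accepts, sketched (illustrative only, values follow directly from the code above):

    split_args_kwargs({"x": 1})            # -> ((), {"x": 1})
    split_args_kwargs(((1, 2), {"x": 1}))  # -> ((1, 2), {"x": 1})
    split_args_kwargs((1, 2))              # -> ((1, 2), {})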
+ def make_inputs(
+     args: Optional[Tuple[Any, ...]], kwargs: Optional[Dict[str, Any]] = None
+ ) -> Any:
+     """Returns either args, kwargs or both depending on which ones are empty."""
+     assert args or kwargs, "No input was given."
+     if not args:
+         return kwargs
+     if not kwargs:
+         return args
+     return args, kwargs
+
+
+ def filter_inputs(
+     inputs: Any,
+     drop_names: List[str],
+     model: Optional[Union[torch.nn.Module, List[str]]] = None,
+     dynamic_shapes: Optional[Any] = None,
+ ):
+     """
+     Drops some inputs from the given inputs.
+     It updates the dynamic shapes as well.
+     """
+     args, kwargs = split_args_kwargs(inputs)
+     set_drop_names = set(drop_names)
+     kwargs = {k: v for k, v in kwargs.items() if k not in set_drop_names}
+     dyn = (
+         {k: v for k, v in dynamic_shapes.items() if k not in set_drop_names}
+         if dynamic_shapes and isinstance(dynamic_shapes, dict)
+         else dynamic_shapes
+     )
+     if not args or all(i in kwargs for i in set_drop_names):
+         return make_inputs(args, kwargs), dyn
+     assert model, (
+         f"we need the model to get the parameter name but model is None, "
+         f"input_names={drop_names} and args={string_type(args)}"
+     )
+     pnames = (
+         list(inspect.signature(model.forward).parameters)
+         if isinstance(model, torch.nn.Module)
+         else model
+     )
+     new_args = []
+     new_ds = []
+     for i, a in enumerate(args):
+         if isinstance(dynamic_shapes, tuple):
+             new_ds.append(None if pnames[i] in set_drop_names else dynamic_shapes[i])
+         new_args.append(None if pnames[i] in set_drop_names else a)
+     new_inputs = make_inputs(tuple(new_args), kwargs)
+     if new_ds:
+         return new_inputs, tuple(new_ds)
+     return new_inputs, dyn
+
+
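A sketch of the common case, keyword inputs with a dictionary of dynamic shapes (ids and mask stand for assumed tensors):

    inputs = {"input_ids": ids, "attention_mask": mask}
    ds = {"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}}
    new_inputs, new_ds = filter_inputs(inputs, ["attention_mask"], dynamic_shapes=ds)
    # new_inputs == {"input_ids": ids}, new_ds == {"input_ids": {0: "batch"}}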
+ def _make_folder_name(
+     model_id: str,
+     exporter: Optional[str],
+     optimization: Optional[str] = None,
+     dtype: Optional[Union[str, torch.dtype]] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     subfolder: Optional[str] = None,
+     opset: Optional[int] = None,
+     drop_inputs: Optional[List[str]] = None,
+     same_as_pretrained: bool = False,
+     use_pretrained: bool = False,
+     task: Optional[str] = None,
+ ) -> str:
+     "Creates a unique folder name based on the given options."
+     els = [model_id.replace("/", "_")]
+     if subfolder:
+         els.append(subfolder.replace("/", "_"))
+     if task:
+         els.append(task)
+     if drop_inputs:
+         ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
+         els.append(f"I-{ii.upper()}")
+     if use_pretrained:
+         els.append("TRAINED")
+     elif same_as_pretrained:
+         els.append("SAMESIZE")
+     if exporter:
+         els.append(exporter)
+     if optimization:
+         els.append(optimization)
+     if dtype is not None and dtype:
+         stype = dtype if isinstance(dtype, str) else str(dtype)
+         stype = stype.replace("float", "f").replace("uint", "u").replace("int", "i")
+         els.append(stype)
+     if device is not None and device:
+         sdev = device if isinstance(device, str) else str(device)
+         sdev = sdev.lower()
+         if "cpu" in sdev:
+             sdev = "cpu"
+         elif "cuda" in sdev:
+             sdev = "cuda"
+         else:
+             raise AssertionError(f"unexpected value for device={device}, sdev={sdev!r}")
+         els.append(sdev)
+     if opset is not None:
+         els.append(f"op{opset}")
+     return "/".join([e for e in els if e])
+
+
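A worked example following the rules above: "float16" shrinks to "f16", any CUDA device normalizes to "cuda", the opset is prefixed with "op", and the parts are joined with "/":

    _make_folder_name("microsoft/phi-2", "onnx-dynamo",
                      dtype="float16", device="cuda", opset=18)
    # -> "microsoft_phi-2/onnx-dynamo/f16/cuda/op18"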
+ def version_summary() -> Dict[str, Union[int, float, str]]:
+     """
+     Example:
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         from onnx_diagnostic.torch_models.validate import version_summary
+
+         pprint.pprint(version_summary())
+     """
+     import numpy
+
+     summary: Dict[str, Union[int, float, str]] = {
+         "version_torch": torch.__version__,
+         "version_numpy": numpy.__version__,
+     }
+     try:
+         import scipy
+
+         summary["version_scipy"] = getattr(scipy, "__version__", "?")
+     except ImportError:
+         pass
+     try:
+         import transformers
+
+         summary["version_transformers"] = getattr(transformers, "__version__", "?")
+     except ImportError:
+         pass
+     try:
+         import onnx
+
+         summary["version_onnx"] = getattr(onnx, "__version__", "?")
+     except ImportError:
+         pass
+     try:
+         import onnxscript
+
+         summary["version_onnxscript"] = getattr(onnxscript, "__version__", "?")
+     except ImportError:
+         pass
+     try:
+         import onnxruntime
+
+         summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?")
+     except ImportError:
+         pass
+     try:
+         import onnx_ir
+
+         summary["version_onnx_ir"] = getattr(onnx_ir, "__version__", "?")
+     except ImportError:
+         pass
+     import onnx_diagnostic
+
+     summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__
+     summary["version_date"] = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+     return summary
+
+
+ def _quiet_or_not_quiet(
+     quiet: bool,
+     suffix: str,
+     summary: Dict[str, Any],
+     data: Optional[Dict[str, Any]],
+     fct: Callable,
+     repeat: int = 1,
+     warmup: int = 0,
+ ) -> Any:
+     begin = time.perf_counter()
+     if quiet:
+         try:
+             res = fct()
+             summary[f"time_{suffix}"] = time.perf_counter() - begin
+             if warmup + repeat == 1:
+                 return res
+         except Exception as e:
+             summary[f"ERR_{suffix}"] = str(e)
+             summary[f"time_{suffix}"] = time.perf_counter() - begin
+             if data is None:
+                 return {f"ERR_{suffix}": e}
+             data[f"ERR_{suffix}"] = e
+             return None
+     else:
+         res = fct()
+         summary[f"time_{suffix}"] = time.perf_counter() - begin
+     if warmup + repeat > 1:
+         if suffix == "run":
+             res = torch_deepcopy(res)
+         summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
+         summary[f"{suffix}_warmup"] = warmup
+         summary[f"{suffix}_repeat"] = repeat
+         last_ = None
+         end_w = max(0, warmup - 1)
+         for _w in range(end_w):
+             t = fct()
+             _ = string_type(t, with_shape=True, with_min_max=True)
+             if _ != last_ or _w == end_w - 1:
+                 summary[f"io_{suffix}_{_w+1}"] = _
+             last_ = _
+         summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
+         times = []
+         for _r in range(repeat):
+             begin = time.perf_counter()
+             t = fct()
+             times.append(time.perf_counter() - begin)
+         a = np.array(times, dtype=np.float64)
+         a.sort()
+         i5 = max(1, a.shape[0] * 5 // 100)
+         i2 = max(1, a.shape[0] * 2 // 100)
+         summary[f"time_{suffix}_latency"] = a.mean()
+         summary[f"time_{suffix}_latency_std"] = a.std()
+         summary[f"time_{suffix}_latency_min"] = a.min()
+         summary[f"time_{suffix}_latency_max"] = a.max()
+         summary[f"time_{suffix}_latency_098"] = a[-i2]
+         summary[f"time_{suffix}_latency_095"] = a[-i5]
+         summary[f"time_{suffix}_latency_005"] = a[i5]
+         summary[f"time_{suffix}_latency_002"] = a[i2]
+         summary[f"time_{suffix}_n"] = len(a)
+         summary[f"time_{suffix}_latency_m98"] = a[i2:-i2].mean()
+
+     return res
+
+
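The latency statistics are read off the sorted array of timings. A worked example: with repeat=100, i5 = max(1, 100 * 5 // 100) = 5 and i2 = 2, so:

    # a = timings sorted ascending, len(a) == 100
    # a[-5]          -> time_<suffix>_latency_095 (~95th percentile)
    # a[2]           -> time_<suffix>_latency_002 (~2nd percentile)
    # a[2:-2].mean() -> time_<suffix>_latency_m98 (mean without the 2% tails)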
+ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+     """Shrinks the configuration before it gets added to the information to log."""
+     new_cfg = {}
+     for k, v in cfg.items():
+
+         new_cfg[k] = (
+             v
+             if (not isinstance(v, (list, tuple, set, dict)) or len(v) < 50)
+             else (v.__class__("...") if isinstance(v, (list, tuple)) else "...")
+         )
+     return new_cfg
+
+
+ def make_patch_kwargs(
+     patch: Union[bool, str, Dict[str, bool]] = False,
+     rewrite: bool = False,
+ ) -> Dict[str, Any]:
+     """Creates patch arguments."""
+     default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
+     if isinstance(patch, bool):
+         patch_kwargs = default_patch if patch else dict(patch=False)
+     elif isinstance(patch, str):
+         patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
+     else:
+         assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
+         patch_kwargs = patch.copy()
+         if "patch" not in patch_kwargs:
+             if any(patch_kwargs.values()):
+                 patch_kwargs["patch"] = True
+         elif len(patch) == 1 and patch.get("patch", False):
+             patch_kwargs.update(default_patch)
+
+     assert not rewrite or patch_kwargs.get("patch", False), (
+         f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
+         f"patch must be True to enable rewriting, "
+         f"if --patch=0 was specified on the command line, rewrites are disabled."
+     )
+     return patch_kwargs
+
+
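The accepted forms, sketched (results follow directly from the code above):

    make_patch_kwargs(True)
    # -> {"patch_transformers": True, "patch_diffusers": True, "patch": True}
    make_patch_kwargs(False)
    # -> {"patch": False}
    make_patch_kwargs("patch_transformers")
    # -> {"patch": True, "patch_transformers": True}
    make_patch_kwargs({"patch_diffusers": True})
    # -> {"patch_diffusers": True, "patch": True}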
+ def validate_model(
+     model_id: str,
+     task: Optional[str] = None,
+     do_run: bool = False,
+     exporter: Optional[str] = None,
+     do_same: bool = False,
+     verbose: int = 0,
+     dtype: Optional[Union[str, torch.dtype]] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     same_as_pretrained: bool = False,
+     use_pretrained: bool = False,
+     optimization: Optional[str] = None,
+     quiet: bool = False,
+     patch: Union[bool, str, Dict[str, bool]] = False,
+     rewrite: bool = False,
+     stop_if_static: int = 1,
+     dump_folder: Optional[str] = None,
+     drop_inputs: Optional[List[str]] = None,
+     ortfusiontype: Optional[str] = None,
+     input_options: Optional[Dict[str, Any]] = None,
+     model_options: Optional[Dict[str, Any]] = None,
+     subfolder: Optional[str] = None,
+     opset: Optional[int] = None,
+     runtime: str = "onnxruntime",
+     repeat: int = 1,
+     warmup: int = 0,
+     inputs2: int = 1,
+     output_names: Optional[List[str]] = None,
+     ort_logs: bool = False,
+     quiet_input_sets: Optional[Set[str]] = None,
+ ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
+     """
+     Validates a model.
+     The function can also be called through the command line
+     :ref:`l-cmd-validate`.
+
+     :param model_id: model id to validate
+     :param task: task used to generate the necessary inputs,
+         can be left empty to use the default task for this model
+         if it can be determined
+     :param do_run: checks the model works with the defined inputs
+     :param exporter: exports the model using this exporter,
+         available list: ``export-strict``, ``export-nostrict``, ...
+         see below
+     :param do_same: checks the discrepancies of the exported model
+     :param verbose: verbosity level
+     :param dtype: uses this dtype to check the model
+     :param device: does the verification on this device
+     :param same_as_pretrained: uses a model equivalent to the trained one,
+         this is not always possible
+     :param use_pretrained: uses the trained model, not the untrained one
+     :param optimization: optimization to apply to the exported model,
+         depends on the exporter
+     :param quiet: if quiet, catches the exception if any issue arises
+     :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
+         if True before exporting,
+         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
+         a string can be used to specify only one of them
+     :param rewrite: applies known rewritings (``patch_transformers=True``) before exporting,
+         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+     :param stop_if_static: stops if a dynamic dimension becomes static,
+         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+     :param dump_folder: dumps everything in a subfolder of this one
+     :param drop_inputs: drops this list of inputs (given their names)
+     :param ortfusiontype: runs ort fusion, the parameter defines the fusion type,
+         it accepts multiple values separated by ``|``,
+         see :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`
+     :param input_options: additional options to define the dummy inputs
+         used to export
+     :param model_options: additional options when creating the model such as
+         ``num_hidden_layers`` or ``attn_implementation``
+     :param subfolder: version or subfolder to use when retrieving a model id
+     :param opset: onnx opset to use for the conversion
+     :param runtime: onnx runtime to use to check for discrepancies,
+         possible values ``onnxruntime``, ``torch``, ``orteval``,
+         ``orteval10``, ``ref``, only used if *do_run* is true
+     :param repeat: number of times to measure the model
+     :param warmup: warms up the model first
+     :param inputs2: checks that other sets of inputs run as well,
+         this ensures that the model does support dynamism, the value is used
+         as an increment added to the dimensions of the first set of values,
+         or an empty cache for example
+     :param output_names: output names the onnx exporter should use
+     :param ort_logs: increases onnxruntime verbosity when creating the session
+     :param quiet_input_sets: avoids raising an exception if the input set belongs
+         to this set, even if quiet is False
+     :return: two dictionaries, one with some metrics,
+         another one with whatever the function produces
+
+     The following environment variables can be used to print out some
+     information:
+
+     * ``PRINT_CONFIG``: prints the model configuration
+
+     The following exporters are available:
+
+     * ``export-nostrict``: runs :func:`torch.export.export` (..., strict=False)
+     * ``onnx-dynamo``: runs :func:`torch.onnx.export` (...),
+       models can be optimized with ``optimization`` in ``("ir", "os_ort")``
+     * ``modelbuilder``: uses :epkg:`ModelBuilder` to build the onnx model
+     * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
+       models can be optimized with ``optimization`` in
+       ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
+
+     The default runtime, :epkg:`onnxruntime`, is used to validate the exported
+     model and check that it returns the same outputs as the original one.
+     Otherwise, the runtime is
+     :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
+     if ``runtime == 'torch'``,
+     :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
+     if ``runtime == 'orteval'``, or
+     :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
+     if ``runtime == 'ref'``;
+     ``orteval10`` increases the verbosity.
+
+     .. versionchanged:: 0.7.13
+         *inputs2* no longer means only a second set of inputs but many,
+         such as ``input_empty_cache``,
+         which refers to a set of inputs using an empty cache.
+     """
+     main_validation_begin = time.perf_counter()
+     model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+         model_id,
+         subfolder,
+         same_as_pretrained=same_as_pretrained,
+         use_pretrained=use_pretrained,
+     )
+     time_preprocess_model_id = time.perf_counter() - main_validation_begin
+     patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
+
+     summary = version_summary()
+     summary.update(
+         dict(
+             version_model_id=model_id,
+             version_do_run=str(do_run),
+             version_dtype=str(dtype or ""),
+             version_device=str(device or ""),
+             version_same_as_pretrained=str(same_as_pretrained),
+             version_use_pretrained=str(use_pretrained),
+             version_optimization=optimization or "",
+             version_quiet=str(quiet),
+             version_patch=str(patch),
+             version_patch_kwargs=str(patch_kwargs).replace(" ", ""),
+             version_rewrite=str(rewrite),
+             version_dump_folder=dump_folder or "",
+             version_drop_inputs=str(list(drop_inputs or "")),
+             version_ortfusiontype=ortfusiontype or "",
+             version_stop_if_static=str(stop_if_static),
+             version_exporter=exporter or "",
+             version_runtime=runtime,
+             version_inputs2=inputs2,
+             time_preprocess_model_id=time_preprocess_model_id,
+         )
+     )
+     if opset:
+         summary["version_opset"] = opset
+
+     folder_name = None
+     if dump_folder:
+         folder_name = _make_folder_name(
+             model_id,
+             exporter,
+             optimization,
+             dtype=dtype,
+             device=device,
+             subfolder=subfolder,
+             opset=opset,
+             drop_inputs=drop_inputs,
+             use_pretrained=use_pretrained,
+             same_as_pretrained=same_as_pretrained,
+             task=task,
+         )
+         dump_folder = os.path.join(dump_folder, folder_name)
+         if not os.path.exists(dump_folder):
+             os.makedirs(dump_folder)
+         summary["dump_folder"] = dump_folder
+         summary["dump_folder_name"] = folder_name
+         if verbose:
+             print(f"[validate_model] dump into {folder_name!r}")
+
+     if verbose:
+         if subfolder:
+             print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
+         else:
+             print(f"[validate_model] validate model id {model_id!r}")
+         if task:
+             print(f"[validate_model] with task {task!r}")
+         print(f"[validate_model] patch={patch!r}")
+         if model_options:
+             print(f"[validate_model] model_options={model_options!r}")
+         print(f"[validate_model] get dummy inputs with input_options={input_options}...")
+         print(
+             f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, "
+             f"stop_if_static={stop_if_static}"
+         )
+         print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
+         print(f"[validate_model] dump_folder={dump_folder!r}")
+         print(f"[validate_model] output_names={output_names}")
+     summary["model_id"] = model_id
+     summary["model_subfolder"] = subfolder or ""
+
+     iop = input_options or {}
+     mop = model_options or {}
+     data = _quiet_or_not_quiet(
+         quiet,
+         "create_torch_model",
+         summary,
+         None,
+         (
+             lambda mid=model_id, v=verbose, task=task, uptr=use_pretrained, tr=same_as_pretrained, iop=iop, sub=subfolder, i2=inputs2: (  # noqa: E501
+                 get_untrained_model_with_inputs(
+                     mid,
+                     verbose=v,
+                     task=task,
+                     use_pretrained=uptr,
+                     same_as_pretrained=tr,
+                     inputs_kwargs=iop,
+                     model_kwargs=mop,
+                     subfolder=sub,
+                     add_second_input=i2,
+                 )
+             )
+         ),
+     )
+
+     second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
+
+     if dump_folder:
+         with open(os.path.join(dump_folder, "model_config.txt"), "w") as f:
+             f.write(f"model_id: {model_id}\n------\n")
+             f.write(
+                 pprint.pformat(
+                     data["configuration"]
+                     if type(data["configuration"]) is dict
+                     else data["configuration"].to_dict()
+                 )
+             )
+         dump_info = data.get("dump_info", None)
+         if dump_info:
+             with open(os.path.join(dump_folder, "model_dump_info.txt"), "w") as f:
+                 f.write(f"model_id: {model_id}\n------\n")
+                 f.write(pprint.pformat(dump_info))
+
+     if exporter == "modelbuilder":
+         # Models used with ModelBuilder do not like batch size > 1.
+         # Let's change that.
+         for k in ["inputs", "inputs2"]:
+             if k not in data:
+                 continue
+             if verbose:
+                 print(f"[validate_model] set batch=1 for data[{k!r}]")
+                 print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
+             cpl = CoupleInputsDynamicShapes(
+                 tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
+             )
+             with register_additional_serialization_functions(patch_transformers=True):  # type: ignore[arg-type]
+                 data[k] = cpl.change_dynamic_dimensions(
+                     desired_values=dict(batch=1), only_desired=True
+                 )
+             if verbose:
+                 print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
+
+     # modelbuilder sometimes needs a different treatment, so
+     # we mark it for later usage.
+     # For example, it uses a different past_kv ordering than
+     # a flattened CacheObject.
+     data["exporter"] = exporter
+     data["input_options"] = iop
+     data["model_options"] = mop
+     data["model_dump_folder"] = dump_folder
+     if dtype:
+         data["model_dtype"] = dtype if isinstance(dtype, str) else str(dtype)
+     if device:
+         data["model_device"] = str(device)
+     if opset:
+         data["model_opset"] = opset
+     if "rewrite" in data:
+         if rewrite:
+             summary["model_rewrite"] = str(data["rewrite"])
+             if verbose:
+                 print(f"[validate_model] model_rewrite={summary['model_rewrite']}")
+         else:
+             del data["rewrite"]
+             if verbose:
+                 print("[validate_model] no rewrite")
+     if os.environ.get("PRINT_CONFIG", "0") in (1, "1"):
+         print("[validate_model] -- PRINT CONFIG")
+         print("-- type(config)", type(data["configuration"]))
+         print(data["configuration"])
+         print("[validate_model] -- END PRINT CONFIG")
+     if iop:
+         summary["input_options"] = str(iop)
+     if mop:
+         summary["model_options"] = str(mop)
+     if "ERR_create_torch_model" in summary:
+         return summary, data
+
+     if drop_inputs:
+         if verbose:
+             print(f"[validate_model] -- drop inputs: {drop_inputs!r}")
+             print(f"[validate_model] current inputs: {string_type(data['inputs'])}")
+             print(
+                 f"[validate_model] current dynamic_shapes: "
+                 f"{string_type(data['dynamic_shapes'])}"
+             )
+         data["inputs"], data["dynamic_shapes"] = filter_inputs(
+             data["inputs"],
+             drop_names=drop_inputs,
+             model=data["model"],
+             dynamic_shapes=data["dynamic_shapes"],
+         )
+         if verbose:
+             print(f"[validate_model] new inputs: {string_type(data['inputs'])}")
+             print(f"[validate_model] new dynamic_shapes: {string_type(data['dynamic_shapes'])}")
+         if second_input_keys:
+             for k in second_input_keys:
+                 data[k], _ = filter_inputs(
+                     data[k],
+                     drop_names=drop_inputs,
+                     model=data["model"],
+                     dynamic_shapes=data["dynamic_shapes"],
+                 )
+
+     if not empty(dtype):
+         if isinstance(dtype, str):
+             dtype = getattr(torch, dtype)
+         if verbose:
+             print(f"[validate_model] dtype conversion to {dtype}")
+         data["model"] = to_any(data["model"], dtype)  # type: ignore
+         data["inputs"] = to_any(data["inputs"], dtype)  # type: ignore
+         summary["model_dtype"] = str(dtype)
+         if second_input_keys:
+             for k in second_input_keys:
+                 data[k] = to_any(data[k], dtype)  # type: ignore
+
+     if not empty(device):
+         if verbose:
+             print(f"[validate_model] device conversion to {device}")
+         data["model"] = to_any(data["model"], device)  # type: ignore
+         data["inputs"] = to_any(data["inputs"], device)  # type: ignore
+         summary["model_device"] = str(device)
+         if second_input_keys:
+             for k in second_input_keys:
+                 data[k] = to_any(data[k], device)  # type: ignore
+
+     for k in ["task", "size", "n_weights"]:
+         summary[f"model_{k.replace('_','')}"] = data[k]
+     summary["second_input_keys"] = ",".join(second_input_keys)
+     summary["model_inputs_options"] = str(input_options or "")
+     summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
+     summary["model_shapes"] = string_type(data["dynamic_shapes"])
+     summary["model_class"] = data["model"].__class__.__name__
+     summary["model_module"] = str(data["model"].__class__.__module__)
+     if summary["model_module"] in sys.modules:
+         summary["model_file"] = str(sys.modules[summary["model_module"]].__file__)  # type: ignore[index]
+     summary["model_config_class"] = data["configuration"].__class__.__name__
+     summary["model_config"] = str(
+         shrink_config(
+             data["configuration"]
+             if type(data["configuration"]) is dict
+             else data["configuration"].to_dict()
+         )
+     ).replace(" ", "")
+     summary["model_id"] = model_id
+
+     if verbose:
+         print("[validate_model] --")
+         print(f"[validate_model] task={data['task']}")
+         print(f"[validate_model] size={data['size'] / 2**20} Mb")
+         print(f"[validate_model] n_weights={data['n_weights'] / 1e6} million parameters")
+         for k, v in data["inputs"].items():
+             print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}")
+         for k, v in data["dynamic_shapes"].items():
+             print(f"[validate_model] +SHAPE {k}={string_type(v)}")
+         print(f"[validate_model] second_input_keys={second_input_keys}")
+         print("[validate_model] --")
+
+     if do_run:
+         validation_begin = time.perf_counter()
+
+         _validate_do_run_model(
+             data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
+         )
+         if second_input_keys:
+             for k in second_input_keys:
+                 _validate_do_run_model(
+                     data,
+                     summary,
+                     k,
+                     f"run2{k[6:]}",
+                     f"run_expected2{k[6:]}",
+                     verbose,
+                     1,
+                     0,
+                     quiet,
+                 )
+
+         summary["time_total_validation_torch"] = time.perf_counter() - validation_begin
+
+     if exporter:
+         print(
+             f"[validate_model] -- export the model with {exporter!r}, "
+             f"optimization={optimization!r}"
+         )
+         exporter_begin = time.perf_counter()
+         if patch_kwargs:
+             if verbose:
+                 print(
+                     f"[validate_model] applies patches before exporting "
+                     f"stop_if_static={stop_if_static}"
+                 )
+             with torch_export_patches(  # type: ignore
+                 stop_if_static=stop_if_static,
+                 verbose=max(0, verbose - 1),
+                 rewrite=data.get("rewrite", None),
+                 dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
+                 **patch_kwargs,  # type: ignore[arg-type]
+             ) as modificator:
+                 data["inputs_export"] = modificator(data["inputs"])  # type: ignore
+
+                 if do_run:
+                     _validate_do_run_exported_program(data, summary, verbose, quiet)
+
+                 # data is modified inplace
+                 summary_export, data = call_exporter(
+                     exporter=exporter,
+                     data=data,
+                     quiet=quiet,
+                     verbose=verbose,
+                     optimization=optimization,
+                     do_run=do_run,
+                     dump_folder=dump_folder,
+                     output_names=output_names,
+                 )
+         else:
+             data["inputs_export"] = data["inputs"]
+             # data is modified inplace
+             summary_export, data = call_exporter(
+                 exporter=exporter,
+                 data=data,
+                 quiet=quiet,
+                 verbose=verbose,
+                 optimization=optimization,
+                 do_run=do_run,
+                 dump_folder=dump_folder,
+                 output_names=output_names,
+             )
+
+         summary.update(summary_export)
+         summary["time_total_exporter"] = time.perf_counter() - exporter_begin
+
+     dump_stats = None
+     if dump_folder:
+         if "exported_program" in data:
+             ep = data["exported_program"]
+             if verbose:
+                 print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
+             assert isinstance(
+                 folder_name, str
+             ), f"folder_name={folder_name!r} should be a string"
+             folder_name = folder_name.replace("/", "-")
+             with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
+                 f.write(str(ep))
+             torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
+             with open(os.path.join(dump_folder, f"{folder_name}.graph"), "w") as f:
+                 f.write(str(ep.graph))
+             if verbose:
+                 print("[validate_model] done (dump ep)")
+         if "onnx_program" in data:
+             assert isinstance(
+                 folder_name, str
+             ), f"folder_name={folder_name!r} should be a string"
+             folder_name = folder_name.replace("/", "-")
+             epo = data["onnx_program"]
+             if verbose:
+                 print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
+             onnx_filename = os.path.join(dump_folder, f"{folder_name}.onnx")
+             begin = time.perf_counter()
+             if isinstance(epo, onnx.model_container.ModelContainer):
+                 epo.save(onnx_filename, all_tensors_to_one_file=True)
+             elif isinstance(epo, onnx.ModelProto):
+                 if os.path.exists(f"{onnx_filename}.data"):
+                     os.remove(f"{onnx_filename}.data")
+                 onnx.save(
+                     epo,
+                     onnx_filename,
+                     save_as_external_data=True,
+                     all_tensors_to_one_file=True,
+                     location=f"{os.path.split(onnx_filename)[-1]}.data",
+                 )
+             else:
+                 epo.save(onnx_filename, external_data=True)
+             duration = time.perf_counter() - begin
+             if verbose:
+                 print(f"[validate_model] done (dump onnx) in {duration}")
+             data["onnx_filename"] = onnx_filename
+             summary["time_onnx_save"] = duration
+             summary.update(compute_statistics(onnx_filename))
+             del epo
+
+         if verbose:
+             print(f"[validate_model] dumps statistics in {dump_folder!r}...")
+         dump_stats = os.path.join(dump_folder, f"{folder_name}.stats")
+         with open(dump_stats, "w") as f:
+             for k, v in sorted(summary.items()):
+                 f.write(f":{k}:{v};\n")
+         if verbose:
+             print("[validate_model] done (dump)")
+
+     if not exporter or (
+         not exporter.startswith(("onnx-", "custom-"))
+         and exporter not in ("custom", "modelbuilder")
+     ):
+         if verbose:
+             print("[validate_model] -- done (final)")
+         if dump_stats:
+             with open(dump_stats, "w") as f:
+                 for k, v in sorted(summary.items()):
+                     f.write(f":{k}:{v};\n")
+         return summary, data
+
+     if do_run:
+         # Let's move the model to CPU to make sure it frees GPU memory.
+         if verbose:
+             # It does not really work for the time being and the model
+             # gets loaded twice, once by torch, once by onnxruntime.
+             print("[validate_model] -- delete the model")
+         for key in ["model", "onnx_program", "config"]:
+             if key in data:
+                 del data[key]
+         if device is not None and "cuda" in str(device).lower():
+             torch.cuda.empty_cache()
+         gc.collect()
+         print("[validate_model] -- done")
+
+         validation_begin = time.perf_counter()
+         summary_valid, data = validate_onnx_model(
+             data=data,
+             quiet=quiet,
+             verbose=verbose,
+             runtime=runtime,
+             repeat=repeat,
+             warmup=warmup,
+             second_input_keys=second_input_keys,
+             ort_logs=ort_logs,
+             quiet_input_sets=quiet_input_sets,
+         )
+         summary.update(summary_valid)
+         summary["time_total_validation_onnx"] = time.perf_counter() - validation_begin
+
+     if ortfusiontype and "onnx_filename" in data:
+         assert (
+             "configuration" in data
+         ), f"missing configuration in data, cannot run ort fusion for model_id={model_id}"
+         config = data["configuration"]
+         assert hasattr(
+             config, "hidden_size"
+         ), f"Missing attribute hidden_size in configuration {config}"
+         hidden_size = config.hidden_size
+         assert hasattr(
+             config, "num_attention_heads"
+         ), f"Missing attribute num_attention_heads in configuration {config}"
+         num_attention_heads = config.num_attention_heads
+
+         if ortfusiontype == "ALL":
+             from onnxruntime.transformers.optimizer import MODEL_TYPES
+
+             model_types = sorted(MODEL_TYPES)
+         else:
+             model_types = ortfusiontype.split("|")
+         for model_type in model_types:
+             flavour = f"ort{model_type}"
+             summary[f"version_{flavour}_hidden_size"] = hidden_size
+             summary[f"version_{flavour}_num_attention_heads"] = num_attention_heads
+
+             begin = time.perf_counter()
+             if verbose:
+                 print(f"[validate_model] run onnxruntime fusion for {model_type!r}")
+             input_filename = data["onnx_filename"]
+             output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx"
+             ort_sum, ort_data = run_ort_fusion(
+                 input_filename,
+                 output_path,
+                 model_type=model_type,
+                 num_attention_heads=num_attention_heads,
+                 hidden_size=hidden_size,
+             )
+             summary.update(ort_sum)
+             data.update(ort_data)
+             data[f"onnx_filename_{flavour}"] = output_path
+             duration = time.perf_counter() - begin
+             summary[f"time_ortfusion_{flavour}"] = duration
+             if verbose:
+                 print(
+                     f"[validate_model] done {model_type!r} in {duration}, "
+                     f"saved into {output_path!r}"
+                 )
+
+             if do_run:
+                 summary_valid, data = validate_onnx_model(
+                     data=data,
+                     quiet=quiet,
+                     verbose=verbose,
+                     flavour=flavour,
+                     runtime=runtime,
+                     repeat=repeat,
+                     warmup=warmup,
+                     second_input_keys=second_input_keys,
+                     quiet_input_sets=quiet_input_sets,
+                 )
+                 summary.update(summary_valid)
+
+     _compute_final_statistics(summary)
+     summary["time_total"] = time.perf_counter() - main_validation_begin
+
+     if verbose:
+         print("[validate_model] -- done (final)")
+     if dump_stats:
+         # Dumps the statistics again.
+         with open(dump_stats, "w") as f:
+             for k, v in sorted(summary.items()):
+                 f.write(f":{k}:{v};\n")
+     return summary, data
+
+
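A minimal way to drive the whole pipeline, as a hedged sketch; the model id is only an example of a tiny model and the untrained copy is built locally by default:

    summary, data = validate_model(
        "arnir0/Tiny-LLM",           # example model id
        do_run=True,                 # run the torch model on dummy inputs
        exporter="export-nostrict",  # torch.export.export(..., strict=False)
        patch=True,                  # apply the transformers/diffusers patches
        verbose=1,
    )
    print(summary["model_task"], summary.get("time_export_export"))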
+ def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
+     """Computes some statistics on the model itself."""
+     onx = onnx.load(onnx_filename, load_external_data=False)
+     cache_functions = {(f.domain, f.name): f for f in onx.functions}
+     local_domains = set(f.domain for f in onx.functions)
+
+     def node_iter(proto):
+         if isinstance(proto, onnx.ModelProto):
+             for f in proto.functions:
+                 yield from node_iter(f)
+             yield from node_iter(proto.graph)
+         elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
+             for node in proto.node:
+                 yield node
+
+                 # Let's inline the function.
+                 key = node.domain, node.op_type
+                 if key in cache_functions:
+                     yield from node_iter(cache_functions[key])
+
+                 # Let's continue with the subgraphs.
+                 for att in node.attribute:
+                     if att.type == onnx.AttributeProto.GRAPH:
+                         yield from node_iter(att.g)
+             if hasattr(proto, "initializer"):
+                 yield from proto.initializer
+         else:
+             raise NotImplementedError(f"Unexpected type={type(proto)}")
+
+     counts: Dict[str, Union[float, int]] = {}
+     n_nodes = 0
+     n_nodes_nocst = 0
+     for proto in node_iter(onx):
+         if isinstance(proto, onnx.NodeProto):
+             key = f"n_node_{proto.op_type}"
+             n_nodes += 1
+             if proto.op_type != "Constant":
+                 n_nodes_nocst += 1
+             if proto.domain in local_domains:
+                 key = "n_node_local_function"
+                 if key not in counts:
+                     counts[key] = 0
+                 counts[key] += 1
+         else:
+             key = f"n_node_initializer_{proto.data_type}"
+
+         if key not in counts:
+             counts[key] = 0
+         counts[key] += 1
+
+     counts["n_node_nodes"] = n_nodes
+     counts["n_node_nodes_nocst"] = n_nodes_nocst
+     counts["n_node_functions"] = len(onx.functions)
+     return counts
+
+
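Illustrative output shape, with key names taken from the code above (the file name is hypothetical):

    stats = compute_statistics("model.onnx")
    # e.g. {"n_node_MatMul": 12, "n_node_initializer_1": 30,
    #       "n_node_nodes": 84, "n_node_nodes_nocst": 70, "n_node_functions": 0}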
+ def _validate_do_run_model(
+     data, summary, key, tag, expected_tag, verbose, repeat, warmup, quiet
+ ):
+     if verbose:
+         print(f"[validate_model] -- run the model inputs={key!r}...")
+         print(f"[validate_model] {key}={string_type(data[key], with_shape=True)}")
+     # We make a copy of the inputs just in case the model modifies them inplace.
+     hash_inputs = string_type(data[key], with_shape=True)
+     inputs = torch_deepcopy(data[key])
+     model = data["model"]
+
+     expected = _quiet_or_not_quiet(
+         quiet,
+         tag,
+         summary,
+         data,
+         (lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
+         repeat=repeat,
+         warmup=warmup,
+     )
+     if f"ERR_{tag}" in summary:
+         return summary, data
+
+     summary[expected_tag] = string_type(expected, with_shape=True)
+     if verbose:
+         print(f"[validate_model] done ([{tag}])")
+     data[expected_tag] = expected
+     assert hash_inputs == string_type(data[key], with_shape=True), (
+         f"The model execution modified the inputs:\n"
+         f"before: {hash_inputs}\n"
+         f" after: {string_type(data[key], with_shape=True)}"
+     )
+
+
+ def _validate_do_run_exported_program(data, summary, verbose, quiet):
+
+     # We run the model a second time to check the patches did not
+     # introduce any discrepancies.
+     if verbose:
+         print("[validate_model] run patched model...")
+         print(
+             f"[validate_model] patched inputs="
+             f"{string_type(data['inputs_export'], with_shape=True)}"
+         )
+     hash_inputs = string_type(data["inputs_export"], with_shape=True)
+
+     # We make a copy of the inputs just in case the model modifies them inplace.
+     inputs = torch_deepcopy(data["inputs_export"])
+     model = data["model"]
+
+     expected = _quiet_or_not_quiet(
+         quiet,
+         "run_patched",
+         summary,
+         data,
+         (lambda m=model, inp=inputs: m(**inp)),
+     )
+     if "ERR_run_patched" in summary:
+         return summary, data
+
+     disc = max_diff(data["run_expected"], expected)
+     for k, v in disc.items():
+         summary[f"disc_patched_{k}"] = str(v)
+     if verbose:
+         print("[validate_model] done (patched run)")
+         print(f"[validate_model] patched discrepancies={string_diff(disc)}")
+     assert hash_inputs == string_type(data["inputs_export"], with_shape=True), (
+         f"The model execution modified the inputs:\n"
+         f"before: {hash_inputs}\n"
+         f" after: {string_type(data['inputs_export'], with_shape=True)}"
+     )
+
+
+ _cache_export_times = []
+ _main_export_function = torch.export.export
+
+
+ def _torch_export_export(*args, _export=_main_export_function, **kwargs):
+     begin = time.perf_counter()
+     res = _export(*args, **kwargs)
+     duration = time.perf_counter() - begin
+     _cache_export_times.append(duration)
+     return res
+
+
+ def _restore_torch_export_export(summary):
+     torch.export.export = _main_export_function
+     if _cache_export_times:
+         summary["time_torch_export_export"] = sum(_cache_export_times)
+         summary["time_torch_export_export_n"] = len(_cache_export_times)
+     _cache_export_times.clear()
+
+
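The pair above is a simple timing monkey-patch: call_exporter installs the wrapper, the exporter may then call torch.export.export any number of times, and _restore_torch_export_export puts the original back while folding the captured durations into the summary. Sketched flow (names from this file):

    torch.export.export = _torch_export_export   # install the timed wrapper
    # ... the exporter runs, possibly exporting several times ...
    _restore_torch_export_export(summary)        # restore and record
    # summary["time_torch_export_export"]   -> sum of captured durations
    # summary["time_torch_export_export_n"] -> number of captured calls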
+ def call_exporter(
+     data: Dict[str, Any],
+     exporter: str,
+     quiet: bool = False,
+     verbose: int = 0,
+     optimization: Optional[str] = None,
+     do_run: bool = False,
+     dump_folder: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
+ ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
+     """
+     Calls an exporter on a model.
+     If a patch must be applied, it should be applied before this function.
+
+     :param data: dictionary with all the necessary inputs
+     :param exporter: exporter to call
+     :param quiet: catch exceptions or not
+     :param verbose: verbosity
+     :param optimization: optimization to apply
+     :param do_run: runs and computes discrepancies
+     :param dump_folder: to dump additional information
+     :param output_names: list of output names to use with the onnx exporter
+     :return: two dictionaries, one with some metrics,
+         another one with whatever the function produces
+     """
+     _cache_export_times.clear()
+     torch.export.export = _torch_export_export
+
+     if exporter == "export" or exporter.startswith("export-"):
+         # torch.export.export
+         summary, data = call_torch_export_export(
+             exporter=exporter,
+             data=data,
+             quiet=quiet,
+             verbose=verbose,
+             optimization=optimization,
+             do_run=do_run,
+         )
+         _restore_torch_export_export(summary)
+         return summary, data
+     if exporter.startswith("onnx-"):
+         # torch.onnx.export
+         summary, data = call_torch_export_onnx(
+             exporter=exporter,
+             data=data,
+             quiet=quiet,
+             verbose=verbose,
+             optimization=optimization,
+             output_names=output_names,
+         )
+         _restore_torch_export_export(summary)
+         return summary, data
+     if exporter == "custom" or exporter.startswith("custom"):
+         # custom exporter
+         summary, data = call_torch_export_custom(
+             exporter=exporter,
+             data=data,
+             quiet=quiet,
+             verbose=verbose,
+             optimization=optimization,
+             dump_folder=dump_folder,
+             output_names=output_names,
+         )
+         _restore_torch_export_export(summary)
+         return summary, data
+     if exporter == "modelbuilder":
+         # ModelBuilder
+         summary, data = call_torch_export_model_builder(
+             exporter=exporter,
+             data=data,
+             quiet=quiet,
+             verbose=verbose,
+             optimization=optimization,
+             output_names=output_names,
+         )
+         _restore_torch_export_export(summary)
+         return summary, data
+     raise NotImplementedError(
+         f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
+         f"exporter must start with 'onnx-', 'custom', 'export', 'modelbuilder' "
+         f"(onnx-dynamo, custom, export), optimization can be 'ir', "
+         f"'default', 'default+onnxruntime', "
+         f"'default+onnxruntime+os_ort', 'ir', 'os_ort'"
+     )
+
+
+ def call_torch_export_export(
+     data: Dict[str, Any],
+     exporter: str,
+     quiet: bool = False,
+     verbose: int = 0,
+     optimization: Optional[str] = None,
+     do_run: bool = False,
+ ):
+     """
+     Exports a model with :func:`torch.export.export`.
+     If a patch must be applied, it should be applied before this function.
+
+     :param data: dictionary with all the necessary inputs, the dictionary must
+         contain the keys ``model`` and ``inputs_export``
+     :param exporter: exporter to call
+     :param quiet: catch exceptions or not
+     :param verbose: verbosity
+     :param optimization: optimization to apply
+     :param do_run: runs and computes discrepancies
+     :return: two dictionaries, one with some metrics,
+         another one with whatever the function produces
+     """
+     assert exporter in {
+         "export",
+         "export-strict",
+         "export-nostrict",
+     }, f"Unexpected value for exporter={exporter!r}"
+     assert not optimization, f"No optimization is implemented for exporter={exporter!r}"
+     assert "model" in data, f"model is missing from data: {sorted(data)}"
+     assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
+     summary: Dict[str, Union[str, int, float]] = {}
+     strict = "-strict" in exporter
+     args, kwargs = split_args_kwargs(data["inputs_export"])
+     ds = data.get("dynamic_shapes", None)
+
+     summary["export_exporter"] = exporter
+     summary["export_optimization"] = optimization or ""
+     summary["export_strict"] = strict
+     summary["export_args"] = string_type(args, with_shape=True)
+     summary["export_kwargs"] = string_type(kwargs, with_shape=True)
+     summary["export_dynamic_shapes"] = string_type(ds)
+
+     # There is an issue with DynamicShape: [[],[]] becomes [].
+     dse = use_dyn_not_str(ds)
+     # dse = CoupleInputsDynamicShapes(args, kwargs, ds).replace_string_by()
+
+     summary["export_dynamic_shapes_export_export"] = string_type(dse)
+
+     if verbose:
+         print(
+             f"[call_torch_export_export] exporter={exporter!r}, "
+             f"strict={strict}, optimization={optimization!r}"
+         )
+         print(f"[call_torch_export_export] args={string_type(args, with_shape=True)}")
+         print(f"[call_torch_export_export] kwargs={string_type(kwargs, with_shape=True)}")
+         print(f"[call_torch_export_export] dynamic_shapes={string_type(ds)}")
+         print(f"[call_torch_export_export] dynamic_shapes_export_export={string_type(dse)}")
+         print("[call_torch_export_export] export...")
+
+     model = data["model"]
+     ep = _quiet_or_not_quiet(
+         quiet,
+         "export_export",
+         summary,
+         data,
+         (
+             lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: (
+                 torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s)
+             )
+         ),
+     )
+     if "ERR_export_export" in summary:
+         return summary, data
+
+     summary["export_graph_nodes"] = len(ep.graph.nodes)
+     if verbose:
+         print(
+             f"[call_torch_export_export] done (export) "
+             f"with {summary['export_graph_nodes']} nodes"
+         )
+     data["exported_program"] = ep
+     if verbose > 1:
+         print("[call_torch_export_export] -- ExportedProgram")
+         print(ep)
+         print("[call_torch_export_export] -- End of ExportedProgram")
+
+     if do_run:
+         # We check for discrepancies.
+         if verbose:
+             print("[validate_model] run exported model...")
+             print(
+                 f"[validate_model] patched inputs="
+                 f"{string_type(data['inputs_export'], with_shape=True)}"
+             )
+         hash_inputs = string_type(data["inputs_export"], with_shape=True)
+
+         # We make a copy of the inputs just in case the model modifies them inplace.
+         inputs = torch_deepcopy(data["inputs_export"])
+         model = ep.module()
+
+         expected = _quiet_or_not_quiet(
+             quiet,
+             "run_exported",
+             summary,
+             data,
+             (lambda m=model, inputs=inputs: m(**inputs)),
+         )
+         if "ERR_run_exported" in summary:
+             return summary, data
+
+         disc = max_diff(data["run_expected"], expected)
+         for k, v in disc.items():
+             summary[f"disc_exported_{k}"] = str(v)
+         if verbose:
+             print("[validate_model] done (exported run)")
+             print(f"[validate_model] exported discrepancies={string_diff(disc)}")
+         assert hash_inputs == string_type(data["inputs_export"], with_shape=True), (
+             f"The exported model execution modified the inputs:\n"
+             f"before: {hash_inputs}\n"
+             f" after: {string_type(data['inputs_export'], with_shape=True)}"
+         )
+     return summary, data
+
+
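Minimal data dictionary the function expects, as an illustrative sketch (model and ids are assumed to be defined; nothing else is required when do_run is False):

    data = {
        "model": model,                       # a torch.nn.Module
        "inputs_export": {"input_ids": ids},  # dict, tuple or (args, kwargs)
        "dynamic_shapes": {"input_ids": {0: "batch"}},
    }
    summary, data = call_torch_export_export(data, "export-nostrict")
    ep = data.get("exported_program")         # torch.export.ExportedProgram on success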
1309
+ def validate_onnx_model(
+     data: Dict[str, Any],
+     quiet: bool = False,
+     verbose: int = 0,
+     flavour: Optional[str] = None,
+     runtime: str = "onnxruntime",
+     repeat: int = 1,
+     warmup: int = 0,
+     second_input_keys: Optional[List[str]] = None,
+     ort_logs: bool = False,
+     quiet_input_sets: Optional[Set[str]] = None,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     Verifies that an onnx model produces the expected outputs.
+     It uses ``data["onnx_filename"]`` as the input onnx filename
+     or ``data["onnx_filename_{flavour}"]`` if *flavour* is specified.
+
+     :param data: dictionary with all the necessary inputs, the dictionary must
+         contain the keys ``model`` and ``inputs_export``
+     :param quiet: catch exceptions or not
+     :param verbose: verbosity level
+     :param flavour: use a different version of the inputs
+     :param runtime: onnx runtime to use: onnxruntime, torch, orteval, orteval10, ref
+     :param repeat: number of times to run the model
+     :param warmup: number of warmup runs
+     :param second_input_keys: names of additional input sets used to validate
+         the model and make sure the exported model supports dynamism, every set
+         is built by adding an increment to the dimensions of the first set of inputs
+     :param ort_logs: triggers the logs for onnxruntime
+     :param quiet_input_sets: avoid raising an exception for these sets of inputs
+     :return: two dictionaries, one with some metrics,
+         another one with whatever the function produces
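+
+     A typical call, assuming ``data`` was filled by a previous export step
+     (hypothetical values)::
+
+         summary, data = validate_onnx_model(data, runtime="onnxruntime", repeat=3)
+         print({k: v for k, v in summary.items() if k.startswith("disc_")})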
+ """
1343
+ import onnxruntime
1344
+
1345
+ def _mk(key, flavour=flavour):
1346
+ return f"{key}_{flavour}" if flavour else key
1347
+
1348
+ summary: Dict[str, Any] = {}
1349
+ flat_inputs = flatten_object(data["inputs"], drop_keys=True)
1350
+ d = flat_inputs[0].get_device()
1351
+ providers = (
1352
+ ["CPUExecutionProvider"]
1353
+ if d < 0
1354
+ else ["CUDAExecutionProvider", "CPUExecutionProvider"]
1355
+ )
1356
+ input_data_key = f"onnx_filename_{flavour}" if flavour else "onnx_filename"
1357
+
1358
+ if input_data_key in data:
1359
+ source = data[input_data_key]
1360
+ if not os.path.exists(source):
1361
+ if verbose:
1362
+ print(f"[validate_onnx_model] missing {source!r}")
1363
+ summary[_mk("ERR_onnx_missing")] = f"FileNotFoundError({source!r})"
1364
+ return summary, data
1365
+ summary[input_data_key] = source
1366
+ summary[_mk("onnx_size")] = os.stat(source).st_size
1367
+ else:
1368
+ assert not flavour, f"flavour={flavour!r}, the filename must be saved."
1369
+ assert (
1370
+ "onnx_program" in data
1371
+ ), f"onnx_program is missing from data which has {sorted(data)}"
1372
+ source = data["onnx_program"].model_proto.SerializeToString()
1373
+ assert len(source) < 2**31, f"The model is highger than 2Gb: {len(source) / 2**30} Gb"
1374
+ summary[_mk("onnx_size")] = len(source)
1375
+ if verbose:
1376
+ print(
1377
+ f"[validate_onnx_model] verify onnx model with providers "
1378
+ f"{providers}..., flavour={flavour!r}"
1379
+ )
1380
+
1381
+ if runtime == "onnxruntime":
1382
+ if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
1383
+ opts = onnxruntime.SessionOptions()
1384
+ opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
1385
+ if verbose:
1386
+ print(
1387
+ f"[validate_onnx_model] saved optimized onnxruntime "
1388
+ f"in {opts.optimized_model_filepath!r}"
1389
+ )
1390
+ onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers)
1391
+ if verbose:
1392
+ print("[validate_onnx_model] -- done")
1393
+
1394
+ if verbose:
1395
+ print("[validate_onnx_model] runtime is onnxruntime")
1396
+ sess_opts = onnxruntime.SessionOptions()
1397
+ if ort_logs:
1398
+ sess_opts.log_severity_level = 0
1399
+ sess_opts.log_verbosity_level = 4
1400
+ cls_runtime = lambda model, providers, _o=sess_opts: onnxruntime.InferenceSession(
1401
+ (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
1402
+ _o,
1403
+ providers=providers,
1404
+ )
1405
+ elif runtime == "torch":
1406
+ from ..reference import TorchOnnxEvaluator
1407
+
1408
+ if verbose:
1409
+ print("[validate_onnx_model] runtime is TorchOnnxEvaluator")
1410
+ cls_runtime = (
1411
+ lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_( # type: ignore[misc]
1412
+ model, providers=providers, verbose=max(verbose - 1, 0)
1413
+ )
1414
+ )
1415
+ elif runtime == "orteval":
1416
+ from ..reference import OnnxruntimeEvaluator
1417
+
1418
+ if verbose:
1419
+ print("[validate_onnx_model] runtime is OnnxruntimeEvaluator")
1420
+ cls_runtime = (
1421
+ lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_( # type: ignore[misc]
1422
+ model, providers=providers, verbose=max(verbose - 1, 0)
1423
+ )
1424
+ )
1425
+ elif runtime == "orteval10":
1426
+ from ..reference import OnnxruntimeEvaluator
1427
+
1428
+ if verbose:
1429
+ print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)")
1430
+ cls_runtime = (
1431
+ lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_( # type: ignore[misc]
1432
+ model, providers=providers, verbose=10
1433
+ )
1434
+ )
1435
+ elif runtime == "ref":
1436
+ from ..reference import ExtendedReferenceEvaluator
1437
+
1438
+ if verbose:
1439
+ print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator")
1440
+ cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_( # type: ignore[misc]
1441
+ model, verbose=max(verbose - 1, 0)
1442
+ )
1443
+ else:
1444
+ raise ValueError(f"Unexpecteed runtime={runtime!r}")
1445
+
1446
+ sess = _quiet_or_not_quiet(
1447
+ quiet,
1448
+ _mk("create_onnx_ort"),
1449
+ summary,
1450
+ data,
1451
+ (lambda source=source, providers=providers: cls_runtime(source, providers)),
1452
+ )
1453
+ if f"ERR_{_mk('onnx_ort_create')}" in summary:
1454
+ return summary, data
1455
+
1456
+ data[_mk("onnx_ort_sess")] = sess
1457
+ if verbose:
1458
+ print(f"[validate_onnx_model] done (ort_session) flavour={flavour!r}")
1459
+
1460
+ keys = [("inputs", "run_expected", "")]
1461
+ if second_input_keys:
1462
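+         # for a key such as "inputsX", this appends ("inputsX", "run_expected2X", "2X")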
+         keys.extend([(k, f"run_expected2{k[6:]}", f"2{k[6:]}") for k in second_input_keys])
+     if verbose:
+         print(f"[validate_onnx_model] -- keys={keys}")
+     for k_input, k_expected, suffix in keys:
+         # make_feeds
+         assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
+         assert k_expected in data, f"Unable to find {k_expected!r} in {sorted(data)}"
+         if verbose:
+             print(f"[validate_onnx_model] -- make_feeds for {k_input!r}...")
+             print(
+                 f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
+             )
+         feeds = make_feeds(
+             sess,
+             data[k_input],
+             use_numpy=True,
+             check_flatten=False,
+             is_modelbuilder=data["exporter"] == "modelbuilder",  # to remove position_ids
+         )
+         if verbose:
+             print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
+         summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
+         if verbose:
+             print("[validate_onnx_model] done (make_feeds)")
+
+         # run ort
+         if verbose:
+             print(f"[validate_onnx_model] run session on inputs 'inputs{suffix}'...")
+         if quiet_input_sets and f"inputs{suffix}" in quiet_input_sets:
+             print(f"[validate_onnx_model] quiet_input_sets={quiet_input_sets}")
+
+         got = _quiet_or_not_quiet(
+             quiet or (quiet_input_sets is not None and f"inputs{suffix}" in quiet_input_sets),
+             _mk(f"run_onnx_ort{suffix}"),
+             summary,
+             data,
+             (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
+             repeat=repeat,
+             warmup=warmup,
+         )
+         if f"ERR_{_mk(f'run_onnx_ort{suffix}')}" in summary:
+             return summary, data
+
+         summary[f"run_feeds_{k_input}"] = string_type(feeds, with_shape=True, with_device=True)
+         summary[f"run_output_{k_input}"] = string_type(got, with_shape=True, with_device=True)
+         if verbose:
+             print("[validate_onnx_model] done (run)")
+             print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")
+
+         # compute discrepancies
+         disc = max_diff(data[k_expected], got, flatten=True)
+         if verbose:
+             print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
+         for k, v in disc.items():
+             summary[_mk(f"disc_onnx_ort_run{suffix}_{k}")] = v
+     return summary, data
+
+
+ def call_torch_export_onnx(
+     data: Dict[str, Any],
+     exporter: str,
+     quiet: bool = False,
+     verbose: int = 0,
+     optimization: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     Exports a model into onnx.
+     If a patch must be applied, it should be applied before calling this function.
+
+     :param data: dictionary with all the necessary inputs, the dictionary must
+         contain the keys ``model`` and ``inputs_export``
+     :param exporter: exporter to call
+     :param quiet: catch exceptions or not
+     :param verbose: verbosity level
+     :param optimization: optimization to apply
+     :param output_names: output names to use
+     :return: two dictionaries, one with some metrics,
+         another one with whatever the function produces
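+
+     A typical call (hypothetical values)::
+
+         summary, data = call_torch_export_onnx(data, "onnx-dynamo", optimization="ir")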
+ """
1542
+ available = {None, "", "ir", "os_ort", "ir+default"}
1543
+ assert (
1544
+ optimization in available
1545
+ ), f"unexpected value for optimization={optimization}, available={available}"
1546
+ assert exporter in {
1547
+ "onnx-dynamo",
1548
+ "onnx-script",
1549
+ }, f"Unexpected value for exporter={exporter!r}"
1550
+ assert "model" in data, f"model is missing from data: {sorted(data)}"
1551
+ assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
1552
+ summary: Dict[str, Union[str, int, float]] = {}
1553
+ dynamo = "dynamo" in exporter
1554
+ args, kwargs = split_args_kwargs(data["inputs_export"])
1555
+ ds = data.get("dynamic_shapes", None)
1556
+ if verbose:
1557
+ print(
1558
+ f"[call_torch_export_onnx] exporter={exporter!r}, "
1559
+ f"optimization={optimization!r}"
1560
+ )
1561
+ print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
1562
+ print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
1563
+ print(f"[call_torch_export_onnx] dynamic_shapes={string_type(ds)}")
1564
+ print("[call_torch_export_onnx] export...")
1565
+ summary["export_exporter"] = exporter
1566
+ summary["export_optimization"] = optimization or ""
1567
+ summary["export_dynamo"] = dynamo
1568
+ summary["export_args"] = string_type(args, with_shape=True)
1569
+ summary["export_kwargs"] = string_type(kwargs, with_shape=True)
1570
+ opset = data.get("model_opset", None)
1571
+ if opset:
1572
+ summary["export_opset"] = opset
1573
+
1574
+ if dynamo:
1575
+ export_export_kwargs = dict(dynamo=True, dynamic_shapes=ds)
1576
+ else:
1577
+ export_export_kwargs = dict(
1578
+ dynamo=False,
1579
+ dynamic_axes={
1580
+ k: v
1581
+ for k, v in CoupleInputsDynamicShapes(args, kwargs, ds) # type: ignore[arg-type]
1582
+ .replace_by_string()
1583
+ .items()
1584
+ if isinstance(v, dict)
1585
+ },
1586
+ )
1587
+ args = tuple(flatten_unflatten_for_dynamic_shapes(a) for a in args)
1588
+ kwargs = {k: flatten_unflatten_for_dynamic_shapes(v) for k, v in kwargs.items()}
1589
+ if verbose:
1590
+ print("[call_torch_export_onnx] dynamo=False so...")
1591
+ print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
1592
+ print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
1593
+ if output_names:
1594
+ export_export_kwargs["output_names"] = output_names
1595
+ if opset:
1596
+ export_export_kwargs["opset_version"] = opset
1597
+ if verbose:
1598
+ print(
1599
+ f"[call_torch_export_onnx] export_export_kwargs="
1600
+ f"{string_type(export_export_kwargs, with_shape=True)}"
1601
+ )
1602
+ model = data["model"]
1603
+
1604
+ epo = _quiet_or_not_quiet(
1605
+ quiet,
1606
+ "export_onnx",
1607
+ summary,
1608
+ data,
1609
+ (
1610
+ lambda m=model, args=args, kws=kwargs, ekws=export_export_kwargs: (
1611
+ torch.onnx.export(
1612
+ m,
1613
+ args,
1614
+ kwargs=kws,
1615
+ **ekws,
1616
+ )
1617
+ )
1618
+ ),
1619
+ )
1620
+ if "ERR_export_onnx" in summary:
1621
+ return summary, data
1622
+
1623
+ assert epo is not None, "no onnx export was found"
1624
+ if verbose:
1625
+ print("[call_torch_export_onnx] done (export)")
1626
+ data["onnx_program"] = epo
1627
+ if verbose > 5:
1628
+ print("[call_torch_export_onnx] -- ONNXProgram")
1629
+ print(epo)
1630
+ print("[call_torch_export_onnx] -- End of ONNXProgram")
1631
+
1632
+ if optimization in {"ir", "os_ort", "ir+default"}:
1633
+ if verbose:
1634
+ print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
1635
+ if optimization == "ir":
1636
+ label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
1637
+ elif optimization == "ir+default":
1638
+ import onnxscript
1639
+ from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
1640
+
1641
+ def _ir_default_opt(epo):
1642
+ onnxscript.optimizer.optimize_ir(epo.model)
1643
+ onx = epo.model_proto
1644
+ # not very efficient
1645
+ gr = GraphBuilder(
1646
+ onx,
1647
+ infer_shapes_options=True,
1648
+ optimization_options=OptimizationOptions(patterns="default"),
1649
+ )
1650
+ cont = gr.to_onnx(large_model=True)
1651
+ epo.model = cont.to_ir()
1652
+
1653
+ label, f_optim = "export_onnx_opt_ir_default", (
1654
+ lambda epo=epo: _ir_default_opt(epo)
1655
+ )
1656
+
1657
+ else:
1658
+ import onnxscript
1659
+ import onnxscript.rewriter.ort_fusions as ort_fusions
1660
+
1661
+ def _os_ort_optim(epo):
1662
+ onnxscript.optimizer.optimize_ir(epo.model)
1663
+ optimized = ort_fusions.optimize_for_ort(epo.model)
1664
+ if isinstance(optimized, tuple):
1665
+ for k, v in optimized[1].items():
1666
+ summary[f"op_opt_fused_{k}"] = v
1667
+ epo.model = optimized[0]
1668
+ else:
1669
+ epo.model = optimized
1670
+
1671
+ label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
1672
+ _quiet_or_not_quiet(quiet, label, summary, data, f_optim)
1673
+ if "ERR_export_onnx_opt_ir" in summary:
1674
+ return summary, data
1675
+ if verbose:
1676
+ print("[call_torch_export_onnx] done (optimization)")
1677
+
1678
+ return summary, data
1679
+
1680
+
1681
+ def call_torch_export_model_builder(
+     data: Dict[str, Any],
+     exporter: str,
+     quiet: bool = False,
+     verbose: int = 0,
+     optimization: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     Exports a model into onnx with :epkg:`ModelBuilder`.
+
+     :param data: dictionary with all the necessary inputs, the dictionary must
+         contain the keys ``model`` and ``inputs_export``
+     :param exporter: exporter to call
+     :param quiet: catch exceptions or not
+     :param verbose: verbosity level
+     :param optimization: optimization to apply
+     :param output_names: list of output names to use
+     :return: two dictionaries, one with some metrics,
+         another one with whatever the function produces
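+
+     A typical call, assuming ``data`` also provides ``configuration`` and
+     ``model_dump_folder`` (hypothetical values)::
+
+         summary, data = call_torch_export_model_builder(data, "modelbuilder")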
+ """
1702
+ from ..helpers.model_builder_helper import create_model_builder, save_model_builder
1703
+
1704
+ assert optimization in (
1705
+ None,
1706
+ "",
1707
+ ), f"unexpected value for optimization={optimization}, none is available"
1708
+ precision = data.get("model_dtype", "fp32")
1709
+ provider = data.get("model_device", "cpu")
1710
+ dump_folder = data.get("model_dump_folder", "")
1711
+ assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
1712
+ assert (
1713
+ not output_names
1714
+ ), f"output_names not empty, not supported yet, output_names={output_names}"
1715
+ cache_dir = os.path.join(dump_folder, "cache_mb")
1716
+ if not os.path.exists(cache_dir):
1717
+ os.makedirs(cache_dir)
1718
+ summary: Dict[str, Any] = {}
1719
+
1720
+ epo = _quiet_or_not_quiet(
1721
+ quiet,
1722
+ "export_model_builder",
1723
+ summary,
1724
+ data,
1725
+ (
1726
+ lambda m=data["model"], c=data[
1727
+ "configuration"
1728
+ ], p=precision, pr=provider, cd=cache_dir: (
1729
+ save_model_builder(
1730
+ create_model_builder(
1731
+ c, m, precision=p, execution_provider=pr, cache_dir=cd
1732
+ )
1733
+ )
1734
+ )
1735
+ ),
1736
+ )
1737
+ if "ERR_export_model_builder" in summary:
1738
+ return summary, data
1739
+
1740
+ assert epo is not None, "no onnx export was found"
1741
+ if verbose:
1742
+ print("[call_torch_export_model_builder] done (export)")
1743
+ data["onnx_program"] = epo
1744
+ return summary, data
1745
+
1746
+
1747
+ def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
+     """
+     Processes statistics coming from the exporters.
+     It takes a sequence of dictionaries (like a data frame)
+     and extracts some metrics.
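+
+     Every row is expected to carry at least a ``pattern`` key; ``added``,
+     ``removed``, ``time_in`` and ``iteration`` are optional. For instance
+     (hypothetical values)::
+
+         stats = process_statistics(
+             [{"pattern": "apply_TransposeTranspose", "added": 1, "removed": 2, "time_in": 0.01}]
+         )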
+ """
1753
+
1754
+ def _simplify(p):
1755
+ for s in [
1756
+ "remove_unused",
1757
+ "constant_folding",
1758
+ "remove_identity",
1759
+ "remove_duplicated_initializer",
1760
+ "remove_duplicated_shape",
1761
+ "dynamic_dimension_naming",
1762
+ "inline",
1763
+ "check",
1764
+ "build_graph_for_pattern",
1765
+ "pattern_optimization",
1766
+ "topological_sort",
1767
+ ]:
1768
+ if s in p or s.replace("_", "-") in p:
1769
+ return s
1770
+ if p.startswith(("apply_", "match_")):
1771
+ return p
1772
+ return "other"
1773
+
1774
+ def _add(d, a, v, use_max=False):
1775
+ if v:
1776
+ if a not in d:
1777
+ d[a] = v
1778
+ elif use_max:
1779
+ d[a] = max(d[a], v)
1780
+ else:
1781
+ d[a] += v
1782
+
1783
+ counts: Dict[str, Any] = {}
1784
+ applied_pattern_time: Dict[str, Any] = {}
1785
+ applied_pattern_n: Dict[str, Any] = {}
1786
+ matching_pattern_time: Dict[str, Any] = {}
1787
+ matching_pattern_n: Dict[str, Any] = {}
1788
+
1789
+ for obs in data:
1790
+ pattern = _simplify(obs["pattern"])
1791
+ _add(counts, "opt_nodes_added", obs.get("added", 0))
1792
+ _add(counts, "opt_nodes_removed", obs.get("removed", 0))
1793
+ _add(counts, "opt_time_steps", obs.get("time_in", 0))
1794
+ _add(counts, "opt_n_steps", 1)
1795
+ _add(
1796
+ counts,
1797
+ "opt_n_iteration",
1798
+ max(counts.get("opt_n_iteration", 0), obs.get("iteration", 0)),
1799
+ use_max=True,
1800
+ )
1801
+
1802
+ if pattern.startswith("apply_"):
1803
+ _add(counts, "opt_n_applied_patterns", 1)
1804
+ _add(counts, "opt_time_applied_patterns", obs.get("time_in", 0))
1805
+ _add(applied_pattern_time, pattern, obs.get("time_in", 0))
1806
+ _add(applied_pattern_n, pattern, 1)
1807
+ elif pattern.startswith("match_"):
1808
+ _add(counts, "opt_n_matching_patterns", 1)
1809
+ _add(counts, "opt_time_matching_patterns", obs.get("time_in", 0))
1810
+ _add(matching_pattern_time, pattern, obs.get("time_in", 0))
1811
+ _add(matching_pattern_n, pattern, 1)
1812
+ else:
1813
+ _add(counts, f"opt_time_{pattern}", obs.get("time_in", 0))
1814
+ _add(counts, f"opt_n_{pattern}", 1)
1815
+ _add(counts, f"opt_nodes_added_{pattern}", obs.get("added", 0))
1816
+ _add(counts, f"opt_nodes_removed_{pattern}", obs.get("removed", 0))
1817
+
1818
+ if applied_pattern_time:
1819
+ longest = max((v, k) for k, v in applied_pattern_time.items())
1820
+ counts["opt_top_time_applied_pattern"], counts["opt_top_time_applied_pattern_arg"] = (
1821
+ longest
1822
+ )
1823
+ longest = max((v, k) for k, v in applied_pattern_n.items())
1824
+ counts["opt_top_n_applied_pattern"], counts["opt_top_n_applied_pattern_arg"] = longest
1825
+
1826
+ if matching_pattern_time:
1827
+ longest = max((v, k) for k, v in matching_pattern_time.items())
1828
+ (
1829
+ counts["opt_top_time_matching_pattern"],
1830
+ counts["opt_top_time_matching_pattern_arg"],
1831
+ ) = longest
1832
+ longest = max((v, k) for k, v in matching_pattern_n.items())
1833
+ counts["opt_top_n_matching_pattern"], counts["opt_top_n_matching_pattern_arg"] = (
1834
+ longest
1835
+ )
1836
+ counts["onnx_opt_optimized"] = 1
1837
+ return counts
1838
+
1839
+
1840
+ def call_torch_export_custom(
+     data: Dict[str, Any],
+     exporter: str,
+     quiet: bool = False,
+     verbose: int = 0,
+     optimization: Optional[str] = None,
+     dump_folder: Optional[str] = None,
+     output_names: Optional[List[str]] = None,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     Exports a model into onnx.
+     If a patch must be applied, it should be applied before calling this function.
+
+     :param data: dictionary with all the necessary inputs, the dictionary must
+         contain the keys ``model`` and ``inputs_export``
+     :param exporter: exporter to call
+     :param quiet: catch exceptions or not
+     :param verbose: verbosity level
+     :param optimization: optimization to apply
+     :param dump_folder: to store additional information
+     :param output_names: list of output names to use
+     :return: two dictionaries, one with some metrics,
+         another one with whatever the function produces
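+
+     A typical call (hypothetical values)::
+
+         summary, data = call_torch_export_custom(
+             data, "custom-nostrict", optimization="default"
+         )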
+ """
1864
+ available = {
1865
+ "",
1866
+ "default",
1867
+ "default+onnxruntime",
1868
+ "default+os_ort",
1869
+ "default+onnxruntime+os_ort",
1870
+ None,
1871
+ }
1872
+ if optimization == "none":
1873
+ optimization = ""
1874
+ assert (
1875
+ optimization in available
1876
+ ), f"unexpected value for optimization={optimization}, available={available}"
1877
+ available = {
1878
+ "custom",
1879
+ "custom-strict",
1880
+ "custom-strict-default",
1881
+ "custom-strict-all",
1882
+ "custom-nostrict",
1883
+ "custom-nostrict-default",
1884
+ "custom-nostrict-all",
1885
+ "custom-noinline",
1886
+ "custom-strict-noinline",
1887
+ "custom-strict-default-noinline",
1888
+ "custom-strict-all-noinline",
1889
+ "custom-nostrict-noinline",
1890
+ "custom-nostrict-default-noinline",
1891
+ "custom-nostrict-all-noinline",
1892
+ "custom-dec",
1893
+ "custom-decall",
1894
+ "custom-fake",
1895
+ }
1896
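+     # The exporter name encodes its options: "-strict"/"-nostrict" picks the
+     # strict mode of torch.export.export, "-dec"/"-default" and "-decall"/"-all"
+     # pick the decomposition table, "-noinline" skips inlining and "-fake"
+     # exports with fake tensors.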
+     assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
+     assert "model" in data, f"model is missing from data: {sorted(data)}"
+     assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
+     summary: Dict[str, Union[str, int, float]] = {}
+     strict = "-strict" in exporter
+     args, kwargs = split_args_kwargs(data["inputs_export"])
+     ds = data.get("dynamic_shapes", None)
+     if "-fake" in exporter:
+         from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
+
+         if verbose:
+             print("[call_torch_export_custom] switching to FakeTensor")
+         assert not args, f"Exporter {exporter!r} not implemented with fake tensors."
+         kwargs = torch_deepcopy(kwargs)
+         kwargs, _ = make_fake_with_dynamic_dimensions(kwargs, dynamic_shapes=ds)
+     opset = data.get("model_opset", None)
+     if opset:
+         summary["export_opset"] = opset
+     if verbose:
+         print(
+             f"[call_torch_export_custom] exporter={exporter!r}, "
+             f"optimization={optimization!r}"
+         )
+         print(f"[call_torch_export_custom] args={string_type(args, with_shape=True)}")
+         print(f"[call_torch_export_custom] kwargs={string_type(kwargs, with_shape=True)}")
+         print(f"[call_torch_export_custom] dynamic_shapes={string_type(ds)}")
+         print("[call_torch_export_custom] export...")
+     summary["export_exporter"] = exporter
+     summary["export_optimization"] = optimization or ""
+     summary["export_strict"] = strict
+     summary["export_args"] = string_type(args, with_shape=True)
+     summary["export_kwargs"] = string_type(kwargs, with_shape=True)
+
+     from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
+     from experimental_experiment.xbuilder import OptimizationOptions
+
+     spl = optimization.split("+") if optimization else []
+     os_ort = "os_ort" in spl
+     optimization = "+".join(_ for _ in spl if _ != "os_ort")
+
+     export_options = ExportOptions(
+         strict=strict,
+         decomposition_table=(
+             "all"
+             if ("-all" in exporter or "-decall" in exporter)
+             else ("default" if ("-default" in exporter or "-dec" in exporter) else None)
+         ),
+         save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
+     )
+     inline = "-noinline" not in exporter
+     options = OptimizationOptions(patterns=optimization) if optimization else None
+     model = data["model"]
+     kws = dict(
+         dynamic_shapes=ds,
+         export_options=export_options,
+         options=options,
+         optimize=bool(optimization),
+         large_model=True,
+         return_optimize_report=True,
+         verbose=max(verbose - 2, 0),
+         inline=inline,
+     )
+     if opset:
+         kws["target_opset"] = opset
+     if output_names:
+         kws["output_names"] = output_names
+
+     epo, opt_stats = _quiet_or_not_quiet(
+         quiet,
+         "export_export_onnx_c",
+         summary,
+         data,
+         (
+             lambda m=model, args=args, kwargs=kwargs, kws=kws: (
+                 to_onnx(
+                     m,
+                     args,
+                     kwargs=kwargs,
+                     **kws,
+                 )
+             )
+         ),
+     )
+     if "ERR_export_export_onnx_c" in summary:
+         return summary, data
+
+     new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
+     new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
+     if "optimization" in opt_stats:
+         new_stat.update(process_statistics(opt_stats["optimization"]))
+
+     summary.update(new_stat)
+     assert epo is not None, "no onnx export was found"
+     if verbose:
+         print("[call_torch_export_custom] done (export)")
+
+     if os_ort:
+         import onnxscript
+         import onnxscript.rewriter.ort_fusions as ort_fusions
+
+         if verbose:
+             print("[call_torch_export_custom] conversion to IR...")
+         begin = time.perf_counter()
+         ir_model = epo.to_ir()
+         duration = time.perf_counter() - begin
+         summary["time_optim_to_ir"] = duration
+         if verbose:
+             print(f"[call_torch_export_custom] done in {duration}")
+             print("[call_torch_export_custom] start optimization...")
+         begin = time.perf_counter()
+         onnxscript.optimizer.optimize_ir(ir_model)
+         ir_optimized = ort_fusions.optimize_for_ort(ir_model)
+         if isinstance(ir_optimized, tuple):
+             report = ir_optimized[1]
+             for k, v in report.items():
+                 summary[f"op_opt_fused_{k}"] = v
+             ir_optimized = ir_optimized[0]
+         epo.model = ir_optimized
+         duration = time.perf_counter() - begin
+         summary["time_optim_os_ort"] = duration
+         if verbose:
+             print(f"[call_torch_export_custom] done in {duration}")
+
+     data["onnx_program"] = epo
+     return summary, data
+
+
+ def run_ort_fusion(
+     model_or_path: Union[str, onnx.ModelProto],
+     output_path: str,
+     num_attention_heads: int,
+     hidden_size: int,
+     model_type: str = "bert",
+     verbose: int = 0,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     Runs :epkg:`onnxruntime` fusion optimizer.
+
+     :param model_or_path: path to the ModelProto or the ModelProto itself
+     :param output_path: where to save the optimized model
+     :param num_attention_heads: number of attention heads, usually ``config.num_attention_heads``
+     :param hidden_size: hidden size, usually ``config.hidden_size``
+     :param model_type: type of optimization, see below
+     :param verbose: verbosity level
+     :return: two dictionaries, summary and data
+
+     Supported values for ``model_type``:
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         from onnxruntime.transformers.optimizer import MODEL_TYPES
+
+         pprint.pprint(sorted(MODEL_TYPES))
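+
+     A typical call (hypothetical values)::
+
+         summary, data = run_ort_fusion(
+             "model.onnx", "model.opt.onnx", num_attention_heads=12, hidden_size=768
+         )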
+ """
2052
+ from onnxruntime.transformers.optimizer import optimize_by_fusion
2053
+ from onnxruntime.transformers.fusion_options import FusionOptions
2054
+
2055
+ opts = FusionOptions(model_type)
2056
+
2057
+ if isinstance(model_or_path, str):
2058
+ if verbose:
2059
+ print(f"[run_ort_fusion] loads {model_or_path!r}")
2060
+ onx = onnx.load(model_or_path)
2061
+ else:
2062
+ onx = model_or_path
2063
+ begin = time.perf_counter()
2064
+ n_nodes = len(onx.graph.node)
2065
+ if verbose:
2066
+ print(
2067
+ f"[run_ort_fusion] starts optimization for "
2068
+ f"model_type={model_type!r} with {n_nodes} nodes"
2069
+ )
2070
+ try:
2071
+ new_onx = optimize_by_fusion(
2072
+ onx,
2073
+ model_type=model_type,
2074
+ num_heads=num_attention_heads,
2075
+ hidden_size=hidden_size,
2076
+ optimization_options=opts,
2077
+ )
2078
+ except Exception as e:
2079
+ duration = time.perf_counter() - begin
2080
+ if verbose:
2081
+ print(f"[run_ort_fusion] failed in {duration} for model_type={model_type!r}")
2082
+ return {
2083
+ f"ERR_opt_ort_{model_type}": str(e),
2084
+ f"opt_ort_{model_type}_duration": duration,
2085
+ }, {}
2086
+
2087
+ duration = time.perf_counter() - begin
2088
+ delta = len(new_onx.model.graph.node)
2089
+ if verbose:
2090
+ print(f"[run_ort_fusion] done in {duration} with {delta} nodes")
2091
+ print(f"[run_ort_fusion] save to {output_path!r}")
2092
+ begin = time.perf_counter()
2093
+ new_onx.save_model_to_file(output_path, use_external_data_format=True)
2094
+ d = time.perf_counter() - begin
2095
+ if verbose:
2096
+ print(f"[run_ort_fusion] done in {d}")
2097
+ return {
2098
+ f"opt_ort_{model_type}_n_nodes1": n_nodes,
2099
+ f"opt_ort_{model_type}_n_nodes2": delta,
2100
+ f"opt_ort_{model_type}_delta_node": delta - n_nodes,
2101
+ f"opt_ort_{model_type}_duration": duration,
2102
+ f"opt_ort_{model_type}_duration_save": d,
2103
+ }, {f"opt_ort_{model_type}": output_path}
2104
+
2105
+
2106
+ def _compute_final_statistics(summary: Dict[str, Any]):
+     """
+     Updates the dictionary of statistics in place. It adds:
+
+     - speedup
+     """
+     stats = {}
+     if (
+         "time_run_latency" in summary
+         and "time_run_onnx_ort_latency" in summary
+         and summary["time_run_onnx_ort_latency"] > 0
+     ):
+         stats["stat_estimated_speedup_ort"] = (
+             summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
+         )
+         stats["stat_estimated_speedup_ort_m98"] = (
+             summary["time_run_latency_m98"] / summary["time_run_onnx_ort_latency_m98"]
+         )
+     summary.update(stats)
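+
+
+ # A worked example for ``_compute_final_statistics`` (hypothetical numbers):
+ # with time_run_latency=0.20 (torch) and time_run_onnx_ort_latency=0.05
+ # (onnxruntime), stat_estimated_speedup_ort = 0.20 / 0.05 = 4.0. The "_m98"
+ # variant computes the same ratio from the *_m98 latency keys and assumes they
+ # are present whenever the plain latencies are.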