onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,725 @@
1
+ import contextlib
2
+ import io
3
+ import itertools
4
+ import re
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+ import numpy as np
7
+ import onnx
8
+
9
+
10
def discover():
    """
    Discovers all model cases used to evaluate an exporter.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.eval import discover

        pprint.pprint(discover())
    """
    from . import model_cases

    found = {}
    for candidate in vars(model_cases).values():
        # skip module metadata (None placeholders, docstrings, ...)
        if candidate is None or isinstance(candidate, str):
            continue
        # only objects exposing a ``forward`` method qualify as model cases
        if not hasattr(candidate, "forward"):
            continue
        assert candidate.__name__ not in found, f"Case {candidate.__name__!r} is duplicated."
        assert hasattr(
            candidate, "_inputs"
        ), f"Attribute '_inputs' is missing from class {candidate}"
        assert hasattr(
            candidate, "_dynamic"
        ), f"Attribute '_dynamic' is missing from class {candidate}"
        found[candidate.__name__] = candidate
    return found
35
+
36
+
37
def evaluation(
    exporters: Tuple[str] = (
        "export-strict",
        "export-nostrict",
        "export-nostrict-decall",
        "export-strict-oblivious",
        "export-nostrict-oblivious",
        "export-nostrict-decall-oblivious",
    ),
    dynamic: Tuple[bool] = (False, True),
    cases: Optional[Union[str, Dict[str, type]]] = None,
    verbose: int = 0,
    quiet: bool = True,
) -> List[Dict[str, Any]]:
    """
    Evaluates exporter for a list of cases.

    :param exporters: exporters to evaluate
    :param dynamic: evaluate static shape and dynamic shapes
    :param cases: model cases to evaluate
    :param verbose: verbosity
    :param quiet: catch exception
    :return: results, list of dictionaries
    """
    if isinstance(exporters, str):
        exporters = (exporters,)
    if isinstance(dynamic, (bool, int)):
        dynamic = (dynamic,)

    if cases is None:
        cases = discover()
    elif cases in ("three", ["three"]):
        # shortcut: only evaluate the first three discovered cases
        all_cases = discover()
        cases = dict(list(all_cases.items())[:3])
    elif isinstance(cases, str):
        cases = (cases,)

    if isinstance(cases, (list, tuple)):
        all_cases = discover()
        new_cases = []  # type: ignore[var-annotated]
        for c in cases:
            if "*" in c or "?" in c:
                # regex
                reg = re.compile(c)
                new_cases.extend(k for k in all_cases if reg.match(k))
            else:
                new_cases.append(c)
        cases = {k: v for k, v in all_cases.items() if k in set(new_cases)}

    sorted_cases = sorted(cases.items())
    loop = list(itertools.product(sorted_cases, dynamic, exporters))
    # checked before the iterable may be wrapped into a generator below
    assert len(loop) > 0, f"No case to test for cases={cases!r}."
    if verbose:
        try:
            import tqdm

            loop = tqdm.tqdm(loop)
        except ImportError:
            # BUG FIX: the fallback generator used to be defined but never
            # assigned to ``loop``, so nothing was printed without tqdm.
            def _verbose_loop(items):
                for item in items:
                    print(f"[evaluation] {item}")
                    yield item

            loop = _verbose_loop(loop)

    obs = []
    for case, dyn, exporter in loop:
        name, cls_model = case
        res = run_exporter(exporter, cls_model, dyn, quiet=quiet, verbose=max(0, verbose - 1))
        res.update(dict(name=name, dynamic=int(dyn), exporter=exporter))
        obs.append(res)
    return obs
108
+
109
+
110
+ def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821
111
+ """Flatten inputs."""
112
+ if x is None:
113
+ return x
114
+ import torch
115
+
116
+ if isinstance(x, (list, tuple)):
117
+ res = []
118
+ for i in x:
119
+ if i is None or isinstance(
120
+ i,
121
+ (
122
+ torch.Tensor,
123
+ torch.SymInt,
124
+ torch.SymFloat,
125
+ int,
126
+ float,
127
+ ),
128
+ ):
129
+ res.append(i)
130
+ else:
131
+ res.extend(_flatten_inputs(i))
132
+ return tuple(res) if isinstance(x, tuple) else res
133
+ raise AssertionError(f"Unexpected type {type(x)} for x")
134
+
135
+
136
+ def _to_numpy(x):
137
+ if hasattr(x, "numpy"):
138
+ return x.numpy()
139
+ if isinstance(x, int):
140
+ # onnxruntime does not like scalar
141
+ return np.array([x], dtype=np.int64)
142
+ if isinstance(x, float):
143
+ # onnxruntime does not like scalar
144
+ return np.array([x], dtype=np.float32)
145
+ if isinstance(x, list):
146
+ return [_to_numpy(_) for _ in x]
147
+ if isinstance(x, tuple):
148
+ return tuple(_to_numpy(_) for _ in x)
149
+ raise TypeError(f"Unable to convert type {type(x)}, x={x} into numpy")
150
+
151
+
152
def _make_feeds(names, args):
    """Build an onnxruntime feed dictionary mapping input names to numpy values."""
    if len(names) == len(args):
        return {name: _to_numpy(value) for name, value in zip(names, args)}
    if len(names) > len(args):
        # more input names than arguments: arguments hold nested containers
        flattened = _flatten_inputs(args)
        return {name: _to_numpy(value) for name, value in zip(names, flattened)}
    from ...helpers import string_type

    raise RuntimeError(
        f"Unable to handle names={names!r} and args={string_type(args, limit=20)}"
    )
163
+
164
+
165
+ def _clone(x):
166
+ if hasattr(x, "clone"):
167
+ return x.clone()
168
+ if isinstance(x, (int, float)):
169
+ return x
170
+ if isinstance(x, list):
171
+ return [_clone(_) for _ in x]
172
+ if isinstance(x, tuple):
173
+ return tuple(_clone(_) for _ in x)
174
+ raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy")
175
+
176
+
177
+ def _wrap_torch_export(*args, backed_size_oblivious=False, **kwargs):
178
+ import torch
179
+
180
+ if backed_size_oblivious:
181
+ with torch.fx.experimental._config.patch(backed_size_oblivious=True):
182
+ return torch.export.export(*args, **kwargs)
183
+ return torch.export.export(*args, **kwargs)
184
+
185
+
186
def _make_exporter_export(
    exporter: str,
    model: "torch.nn.Module",  # noqa: F821
    inputs: Tuple[Any, ...],
    dynamic_shapes: Optional[Any] = None,
    verbose: int = 0,
    quiet: bool = True,
) -> Union[Dict, Callable]:
    """
    Exports a model with :func:`torch.export.export` (or a tracer)
    and returns a callable.

    :param exporter: exporter name; the substrings ``-nostrict``,
        ``-oblivious``, ``-dec``/``-decall`` select strict mode,
        size-oblivious reasoning and decompositions
    :param model: model to export
    :param inputs: positional inputs used for the export
    :param dynamic_shapes: dynamic shapes forwarded to the exporter
    :param verbose: verbosity, stdout/stderr are silenced when ``verbose < 2``
    :param quiet: when True, failures are returned as
        ``dict(error=..., success=0, error_step="export")`` instead of raising
    :return: a callable running the exported model, or an error dictionary
    """
    import torch

    backed_size_oblivious = "-oblivious" in exporter
    strict = "-nostrict" not in exporter

    plain = exporter in (
        "export-strict",
        "export-strict-oblivious",
        "export-nostrict",
        "export-nostrict-oblivious",
        "export-oblivious",
    )
    with_dec = exporter in (
        "export-strict-dec",
        "export-strict-decall",
        "export-strict-dec-oblivious",
        "export-strict-decall-oblivious",
        "export-nostrict-dec",
        "export-nostrict-decall",
        "export-nostrict-dec-oblivious",
        "export-nostrict-decall-oblivious",
    )

    if plain or with_dec:
        try:
            # DEDUP: the verbose and silenced paths used to duplicate the
            # whole export call; ExitStack keeps a single call site and only
            # redirects stdout/stderr when verbose < 2.
            with contextlib.ExitStack() as stack:
                if verbose < 2:
                    stack.enter_context(contextlib.redirect_stdout(io.StringIO()))
                    stack.enter_context(contextlib.redirect_stderr(io.StringIO()))
                exported = _wrap_torch_export(
                    model,
                    inputs,
                    dynamic_shapes=dynamic_shapes,
                    strict=strict,
                    backed_size_oblivious=backed_size_oblivious,
                )
                if with_dec:
                    if verbose >= 9:
                        print("-- graph before decomposition")
                        print(exported.graph)
                    exported = (
                        exported.run_decompositions()
                        if "decall" in exporter
                        else exported.run_decompositions({})
                    )
        except Exception as e:
            if not quiet:
                raise
            return dict(error=str(e), success=0, error_step="export")
        if verbose >= 9:
            print("-- graph after decomposition" if with_dec else "-- graph")
            print(exported.graph)
        return exported.module()

    if exporter == "export-tracing":
        from experimental_experiment.torch_interpreter.tracing import CustomTracer

        try:
            with contextlib.ExitStack() as stack:
                if verbose < 2:
                    stack.enter_context(contextlib.redirect_stdout(io.StringIO()))
                    stack.enter_context(contextlib.redirect_stderr(io.StringIO()))
                graph = CustomTracer().trace(model)
                mod = torch.fx.GraphModule(model, graph)
        except Exception as e:
            if not quiet:
                raise
            return dict(error=str(e), success=0, error_step="export")
        if verbose >= 9:
            print("-- graph")
            print(graph)
        return mod
    raise AssertionError(f"Unexpected exporter={exporter!r}")
315
+
316
+
317
def _make_exporter_onnx(
    exporter: str,
    model: "torch.nn.Module",  # noqa: F821
    inputs: Tuple[Any, ...],
    dynamic_shapes: Optional[Any] = None,
    verbose: int = 0,
    quiet: bool = True,
) -> Union[Dict, Tuple[onnx.ModelProto, Any]]:
    """
    Converts a model into ONNX.

    :param exporter: ``custom*`` (experimental_experiment), ``dynamo`` or
        ``dynamo-ir`` (torch.onnx with optimization)
    :param model: model to convert
    :param inputs: positional inputs used for the export
    :param dynamic_shapes: dynamic shapes forwarded to the exporter
    :param verbose: verbosity, stdout/stderr are silenced when ``verbose < 2``
    :param quiet: when True, failures are returned as
        ``dict(error=..., success=0, error_step="export")`` instead of raising
    :return: ``(model_proto, builder_or_None)`` or an error dictionary
    """
    from ...helpers import string_type

    if exporter.startswith("custom"):
        from experimental_experiment.torch_interpreter import to_onnx, ExportOptions

        opts = {}
        opts["strict"] = "-strict" in exporter
        opts["fallback"] = "-fallback" in exporter
        opts["tracing"] = "-tracing" in exporter
        opts["jit"] = "-jit" in exporter
        if "-dec" in exporter:
            opts["decomposition_table"] = "all" if "-decall" in exporter else "default"
        try:
            # DEDUP: a single call site, silenced only when verbose < 2.
            with contextlib.ExitStack() as stack:
                if verbose < 2:
                    stack.enter_context(contextlib.redirect_stdout(io.StringIO()))
                    stack.enter_context(contextlib.redirect_stderr(io.StringIO()))
                onx, builder = to_onnx(
                    model,
                    inputs,
                    dynamic_shapes=dynamic_shapes,
                    export_options=ExportOptions(**opts),
                    return_builder=True,
                )
        except Exception as e:
            if not quiet:
                raise RuntimeError(
                    f"Unable to convert model={model.__class__.__name__}, "
                    f"input={string_type(inputs[0], with_shape=True)}, "
                    f"dynamic_shapes={dynamic_shapes}, "
                    f"exporter={exporter!r}"
                ) from e
            return dict(error=str(e), success=0, error_step="export")
        return onx, builder

    if exporter in ("dynamo", "dynamo-ir"):
        import torch

        try:
            kwargs = dict(dynamic_shapes=dynamic_shapes, dynamo=True)
            if verbose >= 2:
                # the conversion report is only requested in verbose mode
                kwargs["report"] = True
            with contextlib.ExitStack() as stack:
                if verbose < 2:
                    stack.enter_context(contextlib.redirect_stdout(io.StringIO()))
                    stack.enter_context(contextlib.redirect_stderr(io.StringIO()))
                ep = torch.onnx.export(model, inputs, **kwargs)
                if exporter == "dynamo-ir":
                    ep.optimize()
                onx = ep.model_proto
        except Exception as e:
            if not quiet:
                raise RuntimeError(
                    f"Unable to convert model={model.__class__.__name__}, "
                    f"input={string_type(inputs[0], with_shape=True)}, "
                    f"dynamic_shapes={dynamic_shapes}, "
                    f"exporter={exporter!r}"
                ) from e
            return dict(error=str(e), success=0, error_step="export")
        return onx, None
    raise AssertionError(f"Unexpected exporter={exporter!r}")
441
+
442
+
443
def _compares_on_one_example(
    model: Callable, inputs: Tuple[Any, ...], mod: Callable, verbose: int, quiet: bool
) -> Tuple[Any, Any, Dict]:
    """
    Runs the eager model and the exported/converted one on the same inputs
    and measures the discrepancies.

    :param model: eager model (reference)
    :param inputs: positional inputs
    :param mod: exported model or a callable wrapping an onnxruntime session
    :param verbose: verbosity
    :param quiet: if True, failures are reported through the returned
        dictionary instead of raising
    :return: ``(expected, got, results)`` where ``expected``/``got`` may be
        None when the corresponding run failed; ``results`` holds the
        discrepancy metrics or ``error``/``error_step`` keys
    """
    from onnx_diagnostic.helpers import max_diff, string_type

    try:
        # inputs are cloned because some models mutate them in place
        expected = model(*_clone(inputs))
    except Exception as e:
        if not quiet:
            raise RuntimeError(
                f"eager mode failed=\n{string_type(inputs, with_shape=True)} "
                f"\nmodel=\n{type(model)}"
            ) from e
        res = dict(error=str(e), success=0, error_step="eager")
        return None, None, res
    try:
        got = mod(*inputs)
    except Exception as e:
        if not quiet:
            raise RuntimeError(
                f"onnxruntime failed, feeds=\n{string_type(inputs, with_shape=True)}"
            ) from e
        res = dict(error=str(e), success=0, error_step="run.0")
        return expected, None, res

    try:
        disc = max_diff(expected, got)
    except Exception as e:
        if not quiet:
            raise
        res = dict(error=str(e), success=0, error_step="discrepancy")
        return expected, got, res

    # an infinite absolute difference means the comparison itself broke down
    if verbose >= 5 and np.isinf(disc["abs"]):
        print("[run_exporter] comparison issues with")
        print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
        print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
        print(f"-- got={string_type(got, with_shape=True, limit=20)}")
    elif verbose >= 9:
        print("[run_exporter] inputs and outputs")
        print(
            f"-- inputs="
            f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
        )
        print(
            f"-- expected="
            f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
        )
        print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
    # keep only the summary metrics, drop intermediate accumulators
    del disc["n"]
    del disc["sum"]
    disc.update(
        dict(
            success=1 if disc["abs"] < 0.1 else 0,
            model_cls=model.__class__,  # type: ignore[dict-item]
            exported=mod,  # type: ignore[dict-item]
        )
    )
    # 0.1 is the absolute-difference threshold deciding success
    if disc["abs"] >= 0.1:
        disc["error"] = "diff.0"
        disc["error_step"] = "diff.0"
        if verbose >= 9:
            # rerun with verbosity to show where the discrepancy comes from
            max_diff(expected, got, verbose=verbose)
    else:
        disc["success"] = 1
    return expected, got, disc
509
+
510
+
511
def run_exporter(
    exporter: str,
    cls_model: type,
    dynamic: bool = False,
    quiet: bool = False,
    verbose: int = 0,
) -> Dict[str, Any]:
    """
    Runs an exporter and returns whether it fails or not.

    :param exporter: exporter
    :param cls_model: model class to create, must define ``_inputs``
        (and ``_dynamic`` when ``dynamic`` is True)
    :param dynamic: use dynamic shape or not
    :param quiet: raise exception or not
    :param verbose: verbosity
    :return: results
    """
    from onnx_diagnostic.helpers import max_diff, string_type
    from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

    assert hasattr(
        cls_model, "_inputs"
    ), f"Attribute '_inputs' is missing from class {cls_model}"

    model = cls_model()
    inputs = cls_model._inputs
    # optional extra inputs used to validate the exported model
    valid = getattr(cls_model, "_valid", None)
    if isinstance(inputs, tuple):
        inputs = [inputs]
    if dynamic:
        # BUG FIX: the message used to mention '_inputs' instead of '_dynamic'
        assert hasattr(
            cls_model, "_dynamic"
        ), f"Attribute '_dynamic' is missing from class {cls_model}"
        dynamic_shapes = cls_model._dynamic
    else:
        dynamic_shapes = None

    base = dict(inputs=inputs, model=model, dynamic_shapes=dynamic_shapes)

    if verbose > 0:
        print(
            f"[run_exporter] exporter={exporter}, model={cls_model.__name__}, "
            f"dynamic={dynamic}, inputs={string_type(inputs, with_shape=True)}"
        )

    builder = None
    onx = None

    if exporter.startswith("export-"):
        mod = _make_exporter_export(
            exporter,
            model,
            inputs[0],
            dynamic_shapes=dynamic_shapes,
            verbose=verbose,
            quiet=quiet,
        )
        if isinstance(mod, dict):
            # something went wrong
            return mod
    else:
        res = _make_exporter_onnx(
            exporter,
            model,
            inputs[0],
            dynamic_shapes=dynamic_shapes,
            verbose=verbose,
            quiet=quiet,
        )
        if isinstance(res, dict):
            # something went wrong
            return res

        onx, builder = res
        base["onx"] = onx
        base["builder"] = builder
        if verbose >= 9:
            print("[run_exporter] onnx model")
            print(
                builder.pretty_text(add_fx_graph=True)
                if builder is not None
                else pretty_onnx(onx)
            )
        if verbose >= 2:
            onnx.save(onx, f"evaluation-{model.__class__.__name__}-{dynamic}-{exporter}.onnx")

        names = [i.name for i in onx.graph.input]
        # flatten nested inputs when the graph expects more inputs than given
        flats = _flatten_inputs(inputs[0]) if len(names) > len(inputs[0]) else inputs[0]

        assert quiet or len(names) == len(flats), (
            f"Input mismatch, inputs[0]={string_type(inputs[0])} "
            f"inputs but names={names!r}, "
            f"model={cls_model.__name__}, export={exporter!r}"
        )
        if len(names) != len(flats):
            res = dict(
                error=f"Input mismatch, inputs[0]={string_type(inputs[0])} "
                f"but names={names!r}, model={cls_model.__name__}, export={exporter!r}",
                success=0,
                error_step="inputs",
            )
            res.update(base)
            return res

        import onnxruntime

        try:
            sess = onnxruntime.InferenceSession(
                onx.SerializeToString(), providers=["CPUExecutionProvider"]
            )
        except Exception as e:
            if not quiet:
                raise
            res = dict(error=str(e), success=0, error_step="ort-init")
            res.update(base)
            return res

        mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args))  # noqa: E731

    # we need to clone for models modifying the inputs
    expected, got, disc = _compares_on_one_example(model, inputs[0], mod, verbose, quiet)
    if expected is not None:
        base["expected"] = expected
    if got is not None:
        base["obtained"] = got
    disc.update(base)
    disc["onnx"] = onx  # type: ignore[dict-item]

    if valid is not None:
        for valid_inputs in valid:
            expected, got, _disc = _compares_on_one_example(
                model, valid_inputs, mod, verbose, quiet
            )
            # BUG FIX: the condition used to read ``"abs" not in disc and
            # (np.isnan(disc["abs"]) ...)`` which either raised KeyError or
            # was dead code; the validation run ``_disc`` is the one to check.
            if "abs" in _disc and (np.isnan(_disc["abs"]) or _disc["abs"] > 1e-3):
                _disc["issue-abs"] = _disc["abs"]
                _disc["issue-rel"] = _disc["rel"]
                _disc["issue-inputs"] = string_type(
                    valid_inputs, with_shape=True, with_min_max=True
                )
                _disc["issue-expected"] = string_type(
                    expected, with_shape=True, with_min_max=True
                )
                _disc["issue-obtained"] = string_type(got, with_shape=True, with_min_max=True)
                # BUG FIX: the issue fields were written to ``_disc`` which was
                # then discarded; propagate them to the returned dictionary.
                for key in (
                    "issue-abs",
                    "issue-rel",
                    "issue-inputs",
                    "issue-expected",
                    "issue-obtained",
                ):
                    disc[key] = _disc[key]
                if not quiet:
                    raise RuntimeError(
                        f"validation failed,"
                        f"\n-- inputs=\n{string_type(_disc['issue-inputs'])} "
                        f"\n-- exporter={exporter!r}\n-- dynamic_shapes={dynamic_shapes}, "
                        f"\n-- expected={_disc['issue-expected']}"
                        f"\n-- obtained={_disc['issue-obtained']}"
                    )
                break

    if dynamic and onx is not None:
        # verify the exported graph really carries dynamic dimensions
        ds = []
        for i in onx.graph.input:
            if i.type.tensor_type:
                for di, dim in enumerate(i.type.tensor_type.shape.dim):
                    if dim.dim_param:
                        ds.append((i.name, di, dim.dim_param))
        if verbose >= 2:
            print(f"[run_exporter] dynamic dimension={ds}")
        if not ds:
            return dict(error="no dynamic shape", success=0, error_step="dynamic")

    if dynamic and len(inputs) > 1:
        # replay the remaining input sets to exercise the dynamic dimensions
        for index, i in enumerate(inputs):
            if quiet:
                try:
                    expected = model(*_clone(i))
                except Exception as e:
                    return dict(error=str(e), success=0, error_step=f"run0.{index}")
            else:
                expected = model(*_clone(i))
            try:
                got = mod(*i)
            except Exception as e:
                if not quiet:
                    raise RuntimeError(
                        f"onnxruntime failed,\n-- feeds=\n{string_type(i, with_shape=True)} "
                        f"exporter={exporter!r}, dynamic_shapes={dynamic_shapes}"
                        f"\n-- model=\n{pretty_onnx(onx) if onx is not None else type(model)}"
                    ) from e
                return dict(error=str(e), success=0, error_step=f"run.{index}")

            try:
                d = max_diff(expected, got)
            except Exception as e:
                if not quiet:
                    raise
                return dict(error=str(e), success=0, error_step=f"discrepancy.{index}")

            if verbose >= 5 and np.isinf(d["abs"]):
                print(f"[run_exporter] comparison issues iteration {index}")
                print(f"-- inputs={string_type(i, with_shape=True)}")
                print(f"-- expected={string_type(expected, with_shape=True)}")
                print(f"-- got={string_type(got, with_shape=True)}")
            elif verbose >= 9:
                print(f"[run_exporter] inputs and outputs iteration {index}")
                print(f"-- inputs={string_type(i, with_shape=True, with_min_max=True)}")
                print(
                    f"-- expected={string_type(expected, with_shape=True, with_min_max=True)}"
                )
                print(f"-- got={string_type(got, with_shape=True, with_min_max=True)}")
            del d["n"]
            del d["sum"]
            if d["abs"] >= 0.1:
                d["error"] = f"diff.{index}"
                d["error_step"] = f"diff.{index}"
                d["success"] = 0
            disc.update(d)

    disc.update(base)
    return disc