onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +2 -2
  2. onnx_diagnostic/_command_lines_parser.py +39 -1
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +14 -5
  5. onnx_diagnostic/ext_test_case.py +15 -1
  6. onnx_diagnostic/helpers/args_helper.py +1 -1
  7. onnx_diagnostic/helpers/graph_helper.py +386 -0
  8. onnx_diagnostic/helpers/helper.py +30 -5
  9. onnx_diagnostic/helpers/model_builder_helper.py +349 -0
  10. onnx_diagnostic/helpers/rt_helper.py +69 -1
  11. onnx_diagnostic/helpers/torch_helper.py +2 -0
  12. onnx_diagnostic/reference/__init__.py +1 -0
  13. onnx_diagnostic/reference/torch_evaluator.py +518 -0
  14. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  15. onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
  16. onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
  17. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  18. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
  19. onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
  20. onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
  21. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  22. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  23. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  24. onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
  25. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  26. onnx_diagnostic/tasks/__init__.py +22 -1
  27. onnx_diagnostic/tasks/image_classification.py +2 -2
  28. onnx_diagnostic/tasks/text_generation.py +3 -3
  29. onnx_diagnostic/torch_export_patches/eval/__init__.py +690 -0
  30. onnx_diagnostic/torch_export_patches/eval/model_cases.py +883 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
  33. onnx_diagnostic/torch_export_patches/patch_module_helper.py +148 -28
  34. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +117 -1
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  37. onnx_diagnostic/torch_models/test_helper.py +225 -22
  38. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  39. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
  42. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
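
The largest addition is a new evaluation module, onnx_diagnostic/torch_export_patches/eval/__init__.py, reproduced in full below. As a quick orientation, here is a minimal usage sketch assembled from the module's own docstrings and defaults (the exporter names and the "three" shortcut appear verbatim in the code below); treat it as a sketch, not documented API:

    import pprint
    from onnx_diagnostic.torch_export_patches.eval import discover, evaluation

    # list every registered model case (name -> model class)
    pprint.pprint(discover())

    # run two exporters on the first three cases, static shapes only
    results = evaluation(
        exporters=("export-strict", "export-nostrict"),
        dynamic=(False,),
        cases="three",
        verbose=1,
    )
    pprint.pprint(results[0])

With quiet=True (the default), failures are stored in the returned dictionaries under error and error_step instead of being raised.
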
onnx_diagnostic/torch_export_patches/eval/__init__.py (added, +690 lines)
@@ -0,0 +1,690 @@
+import contextlib
+import io
+import itertools
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import numpy as np
+import onnx
+
+
+def discover():
+    """
+    Discovers all model cases used to evaluate an exporter.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        from onnx_diagnostic.torch_export_patches.eval import discover
+
+        pprint.pprint(discover())
+    """
+    from . import model_cases
+
+    res = {}
+    for m in model_cases.__dict__.values():
+        if m is None or isinstance(m, str):
+            continue
+        if not hasattr(m, "forward"):
+            continue
+        assert m.__name__ not in res, f"Case {m.__name__!r} is duplicated."
+        assert hasattr(m, "_inputs"), f"Attribute '_inputs' is missing from class {m}"
+        assert hasattr(m, "_dynamic"), f"Attribute '_dynamic' is missing from class {m}"
+        res[m.__name__] = m
+    return res
+
+
+def evaluation(
+    exporters: Tuple[str] = (
+        "export-strict",
+        "export-nostrict",
+        "export-nostrict-decall",
+    ),
+    dynamic: Tuple[bool] = (False, True),
+    cases: Optional[Union[str, Dict[str, type]]] = None,
+    verbose: int = 0,
+    quiet: bool = True,
+) -> List[Dict[str, Any]]:
+    """
+    Evaluates exporters on a list of model cases.
+
+    :param exporters: exporters to evaluate
+    :param dynamic: evaluate static shapes, dynamic shapes, or both
+    :param cases: model cases to evaluate
+    :param verbose: verbosity
+    :param quiet: if True, catches exceptions and reports them in the results
+    :return: results, a list of dictionaries
+    """
+    if isinstance(exporters, str):
+        exporters = (exporters,)
+    if isinstance(dynamic, (bool, int)):
+        dynamic = (dynamic,)
+
+    if cases is None:
+        cases = discover()
+    elif cases in ("three", ["three"]):
+        all_cases = discover()
+        cases = dict(list(all_cases.items())[:3])
+    elif isinstance(cases, str):
+        cases = (cases,)
+
+    if isinstance(cases, (list, tuple)):
+        all_cases = discover()
+        new_cases = []  # type: ignore[var-annotated]
+        for c in cases:
+            if "*" in c or "?" in c:
+                # the name is interpreted as a regular expression
+                reg = re.compile(c)
+                new_cases.extend(k for k in all_cases if reg.match(k))
+            else:
+                new_cases.append(c)
+        cases = {k: v for k, v in all_cases.items() if k in set(new_cases)}
+
+    sorted_cases = sorted(cases.items())
+    loop = list(itertools.product(sorted_cases, dynamic, exporters))
+    assert len(loop) > 0, f"No case to test for cases={cases!r}."
+    if verbose:
+        try:
+            import tqdm
+
+            loop = tqdm.tqdm(loop)
+        except ImportError:
+            # fallback when tqdm is not installed: print every case as it runs
+            items = loop
+
+            def _loop():
+                for item in items:
+                    print(f"[evaluation] {item}")
+                    yield item
+
+            loop = _loop()
+
+    obs = []
+    for case, dyn, exporter in loop:
+        name, cls_model = case
+        res = run_exporter(exporter, cls_model, dyn, quiet=quiet, verbose=max(0, verbose - 1))
+        res.update(dict(name=name, dynamic=int(dyn), exporter=exporter))
+        obs.append(res)
+    return obs
+
+
+def _flatten_inputs(x: Any) -> List["torch.Tensor"]:  # noqa: F821
+    """
+    Flatten inputs.
+    """
+    if x is None:
+        return x
+    import torch
+
+    if isinstance(x, (list, tuple)):
+        res = []
+        for i in x:
+            if i is None or isinstance(
+                i,
+                (
+                    torch.Tensor,
+                    torch.SymInt,
+                    torch.SymFloat,
+                    int,
+                    float,
+                ),
+            ):
+                res.append(i)
+            else:
+                res.extend(_flatten_inputs(i))
+        return tuple(res) if isinstance(x, tuple) else res
+    raise AssertionError(f"Unexpected type {type(x)} for x")
+
+
+def _to_numpy(x):
+    if hasattr(x, "numpy"):
+        return x.numpy()
+    if isinstance(x, int):
+        # onnxruntime does not like scalar
+        return np.array([x], dtype=np.int64)
+    if isinstance(x, float):
+        # onnxruntime does not like scalar
+        return np.array([x], dtype=np.float32)
+    if isinstance(x, list):
+        return [_to_numpy(_) for _ in x]
+    if isinstance(x, tuple):
+        return tuple(_to_numpy(_) for _ in x)
+    raise TypeError(f"Unable to convert type {type(x)}, x={x} into numpy")
+
+
+def _make_feeds(names, args):
+    if len(names) == len(args):
+        return {k: _to_numpy(v) for k, v in zip(names, args)}
+    if len(names) > len(args):
+        flats = _flatten_inputs(args)
+        return {k: _to_numpy(v) for k, v in zip(names, flats)}
+    from ...helpers import string_type
+
+    raise RuntimeError(
+        f"Unable to handle names={names!r} and args={string_type(args, limit=20)}"
+    )
+
+
+def _clone(x):
+    if hasattr(x, "clone"):
+        return x.clone()
+    if isinstance(x, (int, float)):
+        return x
+    if isinstance(x, list):
+        return [_clone(_) for _ in x]
+    if isinstance(x, tuple):
+        return tuple(_clone(_) for _ in x)
+    raise TypeError(f"Unable to clone type {type(x)}, x={x}")
+
+
+def _make_exporter_export(
+    exporter: str,
+    model: "torch.nn.Module",  # noqa: F821
+    inputs: Tuple[Any, ...],
+    dynamic_shapes: Optional[Any] = None,
+    verbose: int = 0,
+    quiet: bool = True,
+) -> Union[Dict, Callable]:
+    import torch
+
+    if exporter == "export-strict":
+        try:
+            if verbose >= 2:
+                exported = torch.export.export(
+                    model, inputs, dynamic_shapes=dynamic_shapes, strict=True
+                )
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    exported = torch.export.export(
+                        model, inputs, dynamic_shapes=dynamic_shapes, strict=True
+                    )
+        except Exception as e:
+            if not quiet:
+                raise
+            return dict(error=str(e), success=0, error_step="export")
+        if verbose >= 9:
+            print("-- graph")
+            print(exported.graph)
+        return exported.module()
+    if exporter in ("export-strict-dec", "export-strict-decall"):
+        try:
+            if verbose >= 2:
+                exported = torch.export.export(
+                    model, inputs, dynamic_shapes=dynamic_shapes, strict=True
+                )
+                if verbose >= 9:
+                    print("-- graph before decomposition")
+                    print(exported.graph)
+                exported = (
+                    exported.run_decompositions()
+                    if "decall" in exporter
+                    else exported.run_decompositions({})
+                )
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    exported = torch.export.export(
+                        model, inputs, dynamic_shapes=dynamic_shapes, strict=True
+                    )
+                    if verbose >= 9:
+                        print("-- graph before decomposition")
+                        print(exported.graph)
+                    exported = (
+                        exported.run_decompositions()
+                        if "decall" in exporter
+                        else exported.run_decompositions({})
+                    )
+        except Exception as e:
+            if not quiet:
+                raise
+            return dict(error=str(e), success=0, error_step="export")
+        if verbose >= 9:
+            print("-- graph after decomposition")
+            print(exported.graph)
+        return exported.module()
+    if exporter == "export-nostrict":
+        try:
+            if verbose >= 2:
+                exported = torch.export.export(
+                    model, inputs, dynamic_shapes=dynamic_shapes, strict=False
+                )
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    exported = torch.export.export(
+                        model, inputs, dynamic_shapes=dynamic_shapes, strict=False
+                    )
+        except Exception as e:
+            if not quiet:
+                raise
+            return dict(error=str(e), success=0, error_step="export")
+        if verbose >= 9:
+            print("-- graph")
+            print(exported.graph)
+        return exported.module()
+    if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
+        try:
+            if verbose >= 2:
+                exported = torch.export.export(
+                    model, inputs, dynamic_shapes=dynamic_shapes, strict=False
+                )
+                if verbose >= 9:
+                    print("-- graph before decomposition")
+                    print(exported.graph)
+                exported = (
+                    exported.run_decompositions()
+                    if "decall" in exporter
+                    else exported.run_decompositions({})
+                )
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    exported = torch.export.export(
+                        model, inputs, dynamic_shapes=dynamic_shapes, strict=False
+                    )
+                    if verbose >= 9:
+                        print("-- graph before decomposition")
+                        print(exported.graph)
+                    exported = (
+                        exported.run_decompositions()
+                        if "decall" in exporter
+                        else exported.run_decompositions({})
+                    )
+        except Exception as e:
+            if not quiet:
+                raise
+            return dict(error=str(e), success=0, error_step="export")
+        if verbose >= 9:
+            print("-- graph after decomposition")
+            print(exported.graph)
+        return exported.module()
+    if exporter == "export-tracing":
+        from experimental_experiment.torch_interpreter.tracing import CustomTracer
+
+        try:
+            if verbose >= 2:
+                graph = CustomTracer().trace(model)
+                mod = torch.fx.GraphModule(model, graph)
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    graph = CustomTracer().trace(model)
+                    mod = torch.fx.GraphModule(model, graph)
+        except Exception as e:
+            if not quiet:
+                raise
+            return dict(error=str(e), success=0, error_step="export")
+        if verbose >= 9:
+            print("-- graph")
+            print(graph)
+        return mod
+    raise AssertionError(f"Unexpected exporter={exporter!r}")
+
+
+def _make_exporter_onnx(
+    exporter: str,
+    model: "torch.nn.Module",  # noqa: F821
+    inputs: Tuple[Any, ...],
+    dynamic_shapes: Optional[Any] = None,
+    verbose: int = 0,
+    quiet: bool = True,
+) -> Union[Dict, Tuple[onnx.ModelProto, Any]]:
+    from ...helpers import string_type
+
+    if exporter.startswith("custom"):
+        from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
+
+        opts = {}
+        opts["strict"] = "-nostrict" not in exporter
+        opts["fallback"] = "-fallback" in exporter
+        opts["tracing"] = "-tracing" in exporter
+        opts["jit"] = "-jit" in exporter
+        if "-dec" in exporter:
+            opts["decomposition_table"] = "all" if "-decall" in exporter else "default"
+        try:
+            if verbose >= 2:
+                onx, builder = to_onnx(
+                    model,
+                    inputs,
+                    dynamic_shapes=dynamic_shapes,
+                    export_options=ExportOptions(**opts),
+                    return_builder=True,
+                )
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    onx, builder = to_onnx(
+                        model,
+                        inputs,
+                        dynamic_shapes=dynamic_shapes,
+                        export_options=ExportOptions(**opts),
+                        return_builder=True,
+                    )
+        except Exception as e:
+            if not quiet:
+                raise RuntimeError(
+                    f"Unable to convert model={model.__class__.__name__}, "
+                    f"input={string_type(inputs[0], with_shape=True)}, "
+                    f"dynamic_shapes={dynamic_shapes}, "
+                    f"exporter={exporter!r}"
+                ) from e
+            return dict(error=str(e), success=0, error_step="export")
+        return onx, builder
+
+    if exporter == "dynamo":
+        import torch
+
+        try:
+            if verbose >= 2:
+                onx = torch.onnx.export(
+                    model,
+                    inputs,
+                    dynamic_shapes=dynamic_shapes,
+                    dynamo=True,
+                    report=True,
+                ).model_proto
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    onx = torch.onnx.export(
+                        model,
+                        inputs,
+                        dynamic_shapes=dynamic_shapes,
+                        dynamo=True,
+                    ).model_proto
+        except Exception as e:
+            if not quiet:
+                raise RuntimeError(
+                    f"Unable to convert model={model.__class__.__name__}, "
+                    f"input={string_type(inputs[0], with_shape=True)}, "
+                    f"dynamic_shapes={dynamic_shapes}, "
+                    f"exporter={exporter!r}"
+                ) from e
+            return dict(error=str(e), success=0, error_step="export")
+        return onx, None
+
+    if exporter == "dynamo-ir":
+        import torch
+
+        try:
+            if verbose >= 2:
+                ep = torch.onnx.export(
+                    model,
+                    inputs,
+                    dynamic_shapes=dynamic_shapes,
+                    dynamo=True,
+                    report=True,
+                )
+                ep.optimize()
+                onx = ep.model_proto
+            else:
+                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
+                    io.StringIO()
+                ):
+                    ep = torch.onnx.export(
+                        model,
+                        inputs,
+                        dynamic_shapes=dynamic_shapes,
+                        dynamo=True,
+                    )
+                    ep.optimize()
+                    onx = ep.model_proto
+        except Exception as e:
+            if not quiet:
+                raise RuntimeError(
+                    f"Unable to convert model={model.__class__.__name__}, "
+                    f"input={string_type(inputs[0], with_shape=True)}, "
+                    f"dynamic_shapes={dynamic_shapes}, "
+                    f"exporter={exporter!r}"
+                ) from e
+            return dict(error=str(e), success=0, error_step="export")
+        return onx, None
+    raise AssertionError(f"Unexpected exporter={exporter!r}")
+
+
+def run_exporter(
+    exporter: str,
+    cls_model: type,
+    dynamic: bool = False,
+    quiet: bool = False,
+    verbose: int = 0,
+) -> Dict[str, Any]:
+    """
+    Runs an exporter on one model case and returns whether it fails or not.
+
+    :param exporter: exporter
+    :param cls_model: model class to create
+    :param dynamic: use dynamic shapes or not
+    :param quiet: catch exceptions or raise them
+    :param verbose: verbosity
+    :return: results
+    """
+    from onnx_diagnostic.helpers import max_diff, string_type
+    from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
+
+    assert hasattr(
+        cls_model, "_inputs"
+    ), f"Attribute '_inputs' is missing from class {cls_model}"
+
+    model = cls_model()
+    inputs = cls_model._inputs
+    if isinstance(inputs, tuple):
+        inputs = [inputs]
+    if dynamic:
+        assert hasattr(
+            cls_model, "_dynamic"
+        ), f"Attribute '_dynamic' is missing from class {cls_model}"
+        dynamic_shapes = cls_model._dynamic
+    else:
+        dynamic_shapes = None
+
+    base = dict(inputs=inputs, model=model, dynamic_shapes=dynamic_shapes)
+
+    if verbose > 0:
+        print(
+            f"[run_exporter] exporter={exporter}, model={cls_model.__name__}, "
+            f"dynamic={dynamic}, inputs={string_type(inputs, with_shape=True)}"
+        )
+
+    builder = None
+    onx = None
+
+    if exporter.startswith("export-"):
+        mod = _make_exporter_export(
+            exporter,
+            model,
+            inputs[0],
+            dynamic_shapes=dynamic_shapes,
+            verbose=verbose,
+            quiet=quiet,
+        )
+        if isinstance(mod, dict):
+            # something went wrong
+            return mod
+    else:
+        res = _make_exporter_onnx(
+            exporter,
+            model,
+            inputs[0],
+            dynamic_shapes=dynamic_shapes,
+            verbose=verbose,
+            quiet=quiet,
+        )
+        if isinstance(res, dict):
+            # something went wrong
+            return res
+
+        onx, builder = res
+        if verbose >= 9:
+            print("[run_exporter] onnx model")
+            print(
+                builder.pretty_text(add_fx_graph=True)
+                if builder is not None
+                else pretty_onnx(onx)
+            )
+        if verbose >= 2:
+            onnx.save(onx, f"evaluation-{model.__class__.__name__}-{dynamic}-{exporter}.onnx")
+
+        names = [i.name for i in onx.graph.input]
+        flats = _flatten_inputs(inputs[0]) if len(names) > len(inputs[0]) else inputs[0]
+
+        assert quiet or len(names) == len(flats), (
+            f"Input mismatch, inputs[0]={string_type(inputs[0])} "
+            f"but names={names!r}, "
+            f"model={cls_model.__name__}, export={exporter!r}"
+        )
+        if len(names) != len(flats):
+            res = dict(
+                error=f"Input mismatch, inputs[0]={string_type(inputs[0])} "
+                f"but names={names!r}, model={cls_model.__name__}, export={exporter!r}",
+                success=0,
+                error_step="inputs",
+            )
+            res.update(base)
+            return res
+
+        import onnxruntime
+
+        try:
+            sess = onnxruntime.InferenceSession(
+                onx.SerializeToString(), providers=["CPUExecutionProvider"]
+            )
+        except Exception as e:
+            if not quiet:
+                raise
+            res = dict(error=str(e), success=0, error_step="ort-init")
+            res.update(base)
+            return res
+
+        mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args))  # noqa: E731
+
+    # we need to clone because some models modify their inputs inplace
+    try:
+        expected = model(*_clone(inputs[0]))
+    except Exception as e:
+        if not quiet:
+            raise RuntimeError(
+                f"eager mode failed, inputs=\n{string_type(inputs[0], with_shape=True)} "
+                f"\nmodel=\n{type(model)}"
+            ) from e
+        res = dict(error=str(e), success=0, error_step="eager")
+        res.update(base)
+        return res
+    try:
+        got = mod(*inputs[0])
+    except Exception as e:
+        if not quiet:
+            raise RuntimeError(
+                f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} "
+                f"\nmodel=\n{pretty_onnx(onx)}"
+            ) from e
+        res = dict(error=str(e), success=0, error_step="run.0")
+        res.update(base)
+        return res
+
+    base["expected"] = expected
+    base["obtained"] = got
+
+    try:
+        disc = max_diff(expected, got)
+    except Exception as e:
+        if not quiet:
+            raise
+        res = dict(error=str(e), success=0, error_step="discrepancy")
+        res.update(base)
+        return res
+
+    if verbose >= 5 and np.isinf(disc["abs"]):
+        print("[run_exporter] comparison issues with")
+        print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
+        print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
+        print(f"-- got={string_type(got, with_shape=True, limit=20)}")
+    elif verbose >= 9:
+        print("[run_exporter] inputs and outputs")
+        print(
+            f"-- inputs="
+            f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
+        )
+        print(
+            f"-- expected="
+            f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
+        )
+        print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
+    del disc["n"]
+    del disc["sum"]
+    disc.update(
+        dict(
+            success=1 if disc["abs"] < 0.1 else 0,
+            model_cls=model.__class__,
+            exported=mod,  # type: ignore[dict-item]
+            onnx=onx,  # type: ignore[dict-item]
+        )
+    )
+    if disc["abs"] >= 0.1:
+        disc["error"] = "diff.0"
+        disc["error_step"] = "diff.0"
+        if verbose >= 9:
+            max_diff(expected, got, verbose=verbose)
+    else:
+        disc["success"] = 1
+
+    if dynamic and onx is not None:
+        ds = []
+        for i in onx.graph.input:
+            if i.type.tensor_type:
+                for di, dim in enumerate(i.type.tensor_type.shape.dim):
+                    if dim.dim_param:
+                        ds.append((i.name, di, dim.dim_param))
+        if verbose >= 2:
+            print(f"[run_exporter] dynamic dimensions={ds}")
+        if not ds:
+            return dict(error="no dynamic shape", success=0, error_step="dynamic")
+
+    if dynamic and len(inputs) > 1:
+        for index, i in enumerate(inputs):
+            expected = model(*_clone(i))
+            try:
+                got = mod(*i)
+            except Exception as e:
+                if not quiet:
+                    raise RuntimeError(
+                        f"onnxruntime failed,\n-- feeds=\n{string_type(i, with_shape=True)} "
+                        f"exporter={exporter!r}, dynamic_shapes={dynamic_shapes}"
+                        f"\n-- model=\n{pretty_onnx(onx) if onx is not None else type(model)}"
+                    ) from e
+                return dict(error=str(e), success=0, error_step=f"run.{index}")
+
+            try:
+                d = max_diff(expected, got)
+            except Exception as e:
+                if not quiet:
+                    raise
+                return dict(error=str(e), success=0, error_step=f"discrepancy.{index}")
+
+            if verbose >= 5 and np.isinf(d["abs"]):
+                print(f"[run_exporter] comparison issues iteration {index}")
+                print(f"-- inputs={string_type(i, with_shape=True)}")
+                print(f"-- expected={string_type(expected, with_shape=True)}")
+                print(f"-- got={string_type(got, with_shape=True)}")
+            elif verbose >= 9:
+                print(f"[run_exporter] inputs and outputs iteration {index}")
+                print(f"-- inputs={string_type(i, with_shape=True, with_min_max=True)}")
+                print(
+                    f"-- expected={string_type(expected, with_shape=True, with_min_max=True)}"
+                )
+                print(f"-- got={string_type(got, with_shape=True, with_min_max=True)}")
+            del d["n"]
+            del d["sum"]
+            if d["abs"] >= 0.1:
+                d["error"] = f"diff.{index}"
+                d["error_step"] = f"diff.{index}"
+                d["success"] = 0
+            disc.update(d)
+
+    disc.update(base)
+    return disc
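
For reference, a hedged sketch of driving a single case through run_exporter directly; the signature and the result keys (success, error_step, abs) come from the code above, while picking the first sorted case is only for illustration:

    from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter

    cases = discover()                      # {case_name: model class}
    name, cls_model = sorted(cases.items())[0]

    # quiet=True stores failures in the result instead of raising
    res = run_exporter("export-nostrict", cls_model, dynamic=False, quiet=True)
    print(name, res["success"], res.get("error_step"), res.get("abs"))

A run is counted as a success when the maximum absolute discrepancy between eager mode and the exported model stays below 0.1; otherwise error_step records the first failing stage (export, ort-init, eager, run.N, discrepancy, dynamic, or diff.N).
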