onnx-diagnostic 0.5.0-py3-none-any.whl → 0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,621 @@
1
+ import contextlib
2
+ import io
3
+ import itertools
4
+ import re
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+ import numpy as np
7
+ import onnx
8
+
9
+
10
+ def discover():
11
+ """
12
+ Discovers all model cases used to evaluate an exporter.
13
+
14
+ .. runpython::
15
+ :showcode:
16
+
17
+ import pprint
18
+ from onnx_diagnostic.torch_export_patches.eval import discover
19
+
20
+ pprint.pprint(discover())
21
+ """
22
+ from . import model_cases
23
+
24
+ res = {}
25
+ for m in model_cases.__dict__.values():
26
+ if m is None or isinstance(m, str):
27
+ continue
28
+ if not hasattr(m, "forward"):
29
+ continue
30
+ assert m.__name__ not in res, f"Case {m.__name__!r} is duplicated."
31
+ assert hasattr(m, "_inputs"), f"Attribute '_inputs' is missing from class {m}"
32
+ assert hasattr(m, "_dynamic"), f"Attribute '_dynamic' is missing from class {m}"
33
+ res[m.__name__] = m
34
+ return res
35
+
36
+
37
+ def evaluation(
38
+ exporters: Tuple[str] = (
39
+ "export-strict",
40
+ "export-nostrict",
41
+ "export-nostrict-decall",
42
+ ),
43
+ dynamic: Tuple[bool] = (False, True),
44
+ cases: Optional[Union[str, Dict[str, type]]] = None,
45
+ verbose: int = 0,
46
+ quiet: bool = True,
47
+ ) -> List[Dict[str, Any]]:
48
+ """
49
+ Evaluates exporter for a list of cases.
50
+
51
+ :param exporters: exporters to evaluate
52
+ :param dynamic: evaluate static shape and dynamic shapes
53
+ :param cases: model cases to evaluate
54
+ :param verbose: verbosity
55
+ :param quiet: catch exception
56
+ :return: results, list of dictionaries
57
+ """
58
+ if isinstance(exporters, str):
59
+ exporters = (exporters,)
60
+ if isinstance(dynamic, (bool, int)):
61
+ dynamic = (dynamic,)
62
+
63
+ if cases is None:
64
+ cases = discover()
65
+ elif cases in ("three", ["three"]):
66
+ all_cases = discover()
67
+ cases = dict(list(all_cases.items())[:3])
68
+ elif isinstance(cases, str):
69
+ cases = (cases,)
70
+
71
+ if isinstance(cases, (list, tuple)):
72
+ all_cases = discover()
73
+ new_cases = [] # type: ignore[var-annotated]
74
+ for c in cases:
75
+ if "*" in c or "?" in c:
76
+ # regex
77
+ reg = re.compile(c)
78
+ new_cases.extend(k for k in all_cases if reg.match(k))
79
+ else:
80
+ new_cases.append(c)
81
+ cases = {k: v for k, v in all_cases.items() if k in set(new_cases)}
82
+
83
+ sorted_cases = sorted(cases.items())
84
+ loop = list(itertools.product(sorted_cases, dynamic, exporters))
85
+ if verbose:
86
+ try:
87
+ import tqdm
88
+
89
+ loop = tqdm.tqdm(loop)
90
+ except ImportError:
91
+
92
+ def _loop():
93
+ for _ in loop:
94
+ print(f"[evaluation] {_}")
95
+ yield _
96
+
97
+ assert len(loop) > 0, f"No case to test for cases={cases!r}."
98
+ obs = []
99
+ for case, dyn, exporter in loop:
100
+ name, cls_model = case
101
+ res = run_exporter(exporter, cls_model, dyn, quiet=quiet, verbose=max(0, verbose - 1))
102
+ res.update(dict(name=name, dynamic=int(dyn), exporter=exporter))
103
+ obs.append(res)
104
+ return obs
105
+
106
+
107
+ def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821
108
+ """
109
+ Flatten inputs.
110
+ """
111
+ if x is None:
112
+ return x
113
+ import torch
114
+
115
+ if isinstance(x, (list, tuple)):
116
+ res = []
117
+ for i in x:
118
+ if i is None or isinstance(
119
+ i,
120
+ (
121
+ torch.Tensor,
122
+ torch.SymInt,
123
+ torch.SymFloat,
124
+ int,
125
+ float,
126
+ ),
127
+ ):
128
+ res.append(i)
129
+ else:
130
+ res.extend(_flatten_inputs(i))
131
+ return tuple(res) if isinstance(x, tuple) else res
132
+ raise AssertionError(f"Unexpected type {type(x)} for x")
133
+
134
+
135
+ def _to_numpy(x):
136
+ if hasattr(x, "numpy"):
137
+ return x.numpy()
138
+ if isinstance(x, int):
139
+ # onnxruntime does not like scalar
140
+ return np.array([x], dtype=np.int64)
141
+ if isinstance(x, float):
142
+ # onnxruntime does not like scalar
143
+ return np.array([x], dtype=np.float32)
144
+ if isinstance(x, list):
145
+ return [_to_numpy(_) for _ in x]
146
+ if isinstance(x, tuple):
147
+ return tuple(_to_numpy(_) for _ in x)
148
+ raise TypeError(f"Unable to convert type {type(x)}, x={x} into numpy")
149
+
150
+
151
+ def _make_feeds(names, args):
152
+ if len(names) == len(args):
153
+ return {k: _to_numpy(v) for k, v in zip(names, args)}
154
+ if len(names) > len(args):
155
+ flats = _flatten_inputs(args)
156
+ return {k: _to_numpy(v) for k, v in zip(names, flats)}
157
+ from ...helpers import string_type
158
+
159
+ raise RuntimeError(
160
+ f"Unable to handle names={names!r} and args={string_type(args, limit=20)}"
161
+ )
162
+
163
+
164
+ def _clone(x):
165
+ if hasattr(x, "clone"):
166
+ return x.clone()
167
+ if isinstance(x, (int, float)):
168
+ return x
169
+ if isinstance(x, list):
170
+ return [_clone(_) for _ in x]
171
+ if isinstance(x, tuple):
172
+ return tuple(_clone(_) for _ in x)
173
+ raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy")
174
+
175
+
176
+ def _make_exporter_export(
177
+ exporter: str,
178
+ model: "torch.nn.Module", # noqa: F821
179
+ inputs: Tuple[Any, ...],
180
+ dynamic_shapes: Optional[Any] = None,
181
+ verbose: int = 0,
182
+ quiet: bool = True,
183
+ ) -> Union[Dict, Callable]:
184
+ import torch
185
+
186
+ if exporter == "export-strict":
187
+ try:
188
+ exported = torch.export.export(
189
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=True
190
+ )
191
+ except Exception as e:
192
+ if not quiet:
193
+ raise
194
+ return dict(error=str(e), success=0, error_step="export")
195
+ if verbose >= 9:
196
+ print("-- graph")
197
+ print(exported.graph)
198
+ return exported.module()
199
+ if exporter in ("export-strict-dec", "export-strict-decall"):
200
+ try:
201
+ exported = torch.export.export(
202
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=True
203
+ )
204
+ if verbose >= 9:
205
+ print("-- graph before decomposition")
206
+ print(exported.graph)
207
+ exported = (
208
+ exported.run_decompositions()
209
+ if "decall" in exporter
210
+ else exported.run_decompositions({})
211
+ )
212
+ except Exception as e:
213
+ if not quiet:
214
+ raise
215
+ return dict(error=str(e), success=0, error_step="export")
216
+ if verbose >= 9:
217
+ print("-- graph after decomposition")
218
+ print(exported.graph)
219
+ return exported.module()
220
+ if exporter == "export-nostrict":
221
+ try:
222
+ exported = torch.export.export(
223
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=False
224
+ )
225
+ except Exception as e:
226
+ if not quiet:
227
+ raise
228
+ return dict(error=str(e), success=0, error_step="export")
229
+ if verbose >= 9:
230
+ print("-- graph")
231
+ print(exported.graph)
232
+ return exported.module()
233
+ if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
234
+ try:
235
+ exported = torch.export.export(
236
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=False
237
+ )
238
+ if verbose >= 9:
239
+ print("-- graph before decomposition")
240
+ print(exported.graph)
241
+ exported = (
242
+ exported.run_decompositions()
243
+ if "decall" in exporter
244
+ else exported.run_decompositions({})
245
+ )
246
+ except Exception as e:
247
+ if not quiet:
248
+ raise
249
+ return dict(error=str(e), success=0, error_step="export")
250
+ if verbose >= 9:
251
+ print("-- graph after decomposition")
252
+ print(exported.graph)
253
+ return exported.module()
254
+ if exporter == "export-tracing":
255
+ from experimental_experiment.torch_interpreter.tracing import CustomTracer
256
+
257
+ try:
258
+ graph = CustomTracer().trace(model)
259
+ mod = torch.fx.GraphModule(model, graph)
260
+ except Exception as e:
261
+ if not quiet:
262
+ raise
263
+ return dict(error=str(e), success=0, error_step="export")
264
+ if verbose >= 9:
265
+ print("-- graph")
266
+ print(graph)
267
+ return mod
268
+ raise AssertionError(f"Unexpected exporter={exporter!r}")
269
+
270
+
271
+ def _make_exporter_onnx(
272
+ exporter: str,
273
+ model: "torch.nn.Module", # noqa: F821
274
+ inputs: Tuple[Any, ...],
275
+ dynamic_shapes: Optional[Any] = None,
276
+ verbose: int = 0,
277
+ quiet: bool = True,
278
+ ) -> Union[Dict, Tuple[onnx.ModelProto, Any]]:
279
+ from ...helpers import string_type
280
+
281
+ if exporter.startswith("custom"):
282
+ from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
283
+
284
+ opts = {}
285
+ opts["strict"] = "-nostrict" not in exporter
286
+ opts["fallback"] = "-fallback" in exporter
287
+ opts["tracing"] = "-tracing" in exporter
288
+ opts["jit"] = "-jit" in exporter
289
+ if "-dec" in exporter:
290
+ opts["decomposition_table"] = "all" if "-decall" in exporter else "default"
291
+ try:
292
+ onx, builder = to_onnx(
293
+ model,
294
+ inputs,
295
+ dynamic_shapes=dynamic_shapes,
296
+ export_options=ExportOptions(**opts),
297
+ return_builder=True,
298
+ )
299
+ except Exception as e:
300
+ if not quiet:
301
+ raise RuntimeError(
302
+ f"Unable to convert model={model.__class__.__name__}, "
303
+ f"input={string_type(inputs[0], with_shape=True)}, "
304
+ f"dynamic_shapes={dynamic_shapes}, "
305
+ f"exporter={exporter!r}"
306
+ ) from e
307
+ return dict(error=str(e), success=0, error_step="export")
308
+ return onx, builder
309
+ if exporter == "dynamo":
310
+ import torch
311
+
312
+ try:
313
+ if verbose >= 2:
314
+ onx = torch.onnx.export(
315
+ model,
316
+ inputs,
317
+ dynamic_shapes=dynamic_shapes,
318
+ dynamo=True,
319
+ report=True,
320
+ ).model_proto
321
+ else:
322
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
323
+ io.StringIO()
324
+ ):
325
+ onx = torch.onnx.export(
326
+ model,
327
+ inputs,
328
+ dynamic_shapes=dynamic_shapes,
329
+ dynamo=True,
330
+ ).model_proto
331
+ except Exception as e:
332
+ if not quiet:
333
+ raise RuntimeError(
334
+ f"Unable to convert model={model.__class__.__name__}, "
335
+ f"input={string_type(inputs[0], with_shape=True)}, "
336
+ f"dynamic_shapes={dynamic_shapes}, "
337
+ f"exporter={exporter!r}"
338
+ ) from e
339
+ return dict(error=str(e), success=0, error_step="export")
340
+ return onx, None
341
+ if exporter == "dynamo-ir":
342
+ import torch
343
+
344
+ try:
345
+ if verbose >= 2:
346
+ ep = torch.onnx.export(
347
+ model,
348
+ inputs,
349
+ dynamic_shapes=dynamic_shapes,
350
+ dynamo=True,
351
+ report=True,
352
+ )
353
+ ep.optimize()
354
+ onx = ep.model_proto
355
+ else:
356
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
357
+ io.StringIO()
358
+ ):
359
+ ep = torch.onnx.export(
360
+ model,
361
+ inputs,
362
+ dynamic_shapes=dynamic_shapes,
363
+ dynamo=True,
364
+ )
365
+ ep.optimize()
366
+ onx = ep.model_proto
367
+ except Exception as e:
368
+ if not quiet:
369
+ raise RuntimeError(
370
+ f"Unable to convert model={model.__class__.__name__}, "
371
+ f"input={string_type(inputs[0], with_shape=True)}, "
372
+ f"dynamic_shapes={dynamic_shapes}, "
373
+ f"exporter={exporter!r}"
374
+ ) from e
375
+ return dict(error=str(e), success=0, error_step="export")
376
+ return onx, None
377
+ raise AssertionError(f"Unexpected exporter={exporter!r}")
378
+
379
+
380
+ def run_exporter(
381
+ exporter: str,
382
+ cls_model: type,
383
+ dynamic: bool = False,
384
+ quiet: bool = False,
385
+ verbose: int = 0,
386
+ ) -> Dict[str, Any]:
387
+ """
388
+ Runs an exporter and returns whether it fails or not.
389
+
390
+ :param exporter: exporter
391
+ :param cls_model: model class to create
392
+ :param inputs: list of inputs to try
393
+ :param dynamic: use dynamic shape or not
394
+ :param quiet: raise exception or not
395
+ :param verbose: verbosity
396
+ :return: results
397
+ """
398
+ from onnx_diagnostic.helpers import max_diff, string_type
399
+ from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
400
+
401
+ assert hasattr(
402
+ cls_model, "_inputs"
403
+ ), f"Attribute '_inputs' is missing from class {cls_model}"
404
+
405
+ model = cls_model()
406
+ inputs = cls_model._inputs
407
+ if isinstance(inputs, tuple):
408
+ inputs = [inputs]
409
+ if dynamic:
410
+ assert hasattr(
411
+ cls_model, "_dynamic"
412
+ ), f"Attribute '_inputs' is missing from class {cls_model}"
413
+ dynamic_shapes = cls_model._dynamic
414
+ else:
415
+ dynamic_shapes = None
416
+
417
+ base = dict(inputs=inputs, model=model, dynamic_shapes=dynamic_shapes)
418
+
419
+ if verbose > 0:
420
+ print(
421
+ f"[run_exporter] exporter={exporter}, model={cls_model.__name__}, "
422
+ f"dynamic={dynamic}, inputs={string_type(inputs, with_shape=True)}"
423
+ )
424
+
425
+ builder = None
426
+ onx = None
427
+
428
+ if exporter.startswith("export-"):
429
+ mod = _make_exporter_export(
430
+ exporter,
431
+ model,
432
+ inputs[0],
433
+ dynamic_shapes=dynamic_shapes,
434
+ verbose=verbose,
435
+ quiet=quiet,
436
+ )
437
+ if isinstance(mod, dict):
438
+ # something went wrong
439
+ return mod
440
+ else:
441
+ res = _make_exporter_onnx(
442
+ exporter,
443
+ model,
444
+ inputs[0],
445
+ dynamic_shapes=dynamic_shapes,
446
+ verbose=verbose,
447
+ quiet=quiet,
448
+ )
449
+ if isinstance(res, dict):
450
+ # something went wrong
451
+ return res
452
+
453
+ onx, builder = res
454
+ if verbose >= 9:
455
+ print("[run_exporter] onnx model")
456
+ print(
457
+ builder.pretty_text(add_fx_graph=True)
458
+ if builder is not None
459
+ else pretty_onnx(onx)
460
+ )
461
+ if verbose >= 2:
462
+ onnx.save(onx, f"evaluation-{model.__class__.__name__}-{dynamic}-{exporter}.onnx")
463
+
464
+ names = [i.name for i in onx.graph.input]
465
+ flats = _flatten_inputs(inputs[0]) if len(names) > len(inputs[0]) else inputs[0]
466
+
467
+ assert quiet or len(names) == len(flats), (
468
+ f"Input mismatch, inputs[0]={string_type(inputs[0])} "
469
+ f"inputs but names={names!r}, "
470
+ f"model={cls_model.__name__}, export={exporter!r}"
471
+ )
472
+ if len(names) != len(flats):
473
+ res = dict(
474
+ error=f"Input mismatch, inputs[0]={string_type(inputs[0])} "
475
+ f"but names={names!r}, model={cls_model.__name__}, export={exporter!r}",
476
+ success=0,
477
+ error_step="inputs",
478
+ )
479
+ res.update(base)
480
+ return res
481
+
482
+ import onnxruntime
483
+
484
+ try:
485
+ sess = onnxruntime.InferenceSession(
486
+ onx.SerializeToString(), providers=["CPUExecutionProvider"]
487
+ )
488
+ except Exception as e:
489
+ if not quiet:
490
+ raise
491
+ res = dict(error=str(e), success=0, error_step="ort-init")
492
+ res.update(base)
493
+ return res
494
+
495
+ mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args)) # noqa: E731
496
+
497
+ # we need to clone for models modifying the inputs
498
+ try:
499
+ expected = model(*_clone(inputs[0]))
500
+ except Exception as e:
501
+ if not quiet:
502
+ raise RuntimeError(
503
+ f"eager mode failed=\n{string_type(inputs[0], with_shape=True)} "
504
+ f"\nmodel=\n{type(model)}"
505
+ ) from e
506
+ res = dict(error=str(e), success=0, error_step="eager")
507
+ res.update(base)
508
+ return res
509
+ try:
510
+ got = mod(*inputs[0])
511
+ except Exception as e:
512
+ if not quiet:
513
+ raise RuntimeError(
514
+ f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} "
515
+ f"\nmodel=\n{pretty_onnx(onx)}"
516
+ ) from e
517
+ res = dict(error=str(e), success=0, error_step="run.0")
518
+ res.update(base)
519
+ return res
520
+
521
+ base["expected"] = expected
522
+ base["obtained"] = got
523
+
524
+ try:
525
+ disc = max_diff(expected, got)
526
+ except Exception as e:
527
+ if not quiet:
528
+ raise
529
+ res = dict(error=str(e), success=0, error_step="discrepancy")
530
+ res.update(base)
531
+ return res
532
+
533
+ if verbose >= 5 and np.isinf(disc["abs"]):
534
+ print("[run_exporter] comparison issues with")
535
+ print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
536
+ print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
537
+ print(f"-- got={string_type(got, with_shape=True, limit=20)}")
538
+ elif verbose >= 9:
539
+ print("[run_exporter] inputs and outputs")
540
+ print(
541
+ f"-- inputs="
542
+ f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
543
+ )
544
+ print(
545
+ f"-- expected="
546
+ f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
547
+ )
548
+ print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
549
+ del disc["n"]
550
+ del disc["sum"]
551
+ disc.update(
552
+ dict(
553
+ success=1 if disc["abs"] < 0.1 else 0,
554
+ model_cls=model.__class__,
555
+ exported=mod, # type: ignore[dict-item]
556
+ onnx=onx, # type: ignore[dict-item]
557
+ )
558
+ )
559
+ if disc["abs"] >= 0.1:
560
+ disc["error"] = "diff.0"
561
+ disc["error_step"] = "diff.0"
562
+ if verbose >= 9:
563
+ max_diff(expected, got, verbose=verbose)
564
+ else:
565
+ disc["success"] = 1
566
+
567
+ if dynamic and onx is not None:
568
+ ds = []
569
+ for i in onx.graph.input:
570
+ if i.type.tensor_type:
571
+ for di, dim in enumerate(i.type.tensor_type.shape.dim):
572
+ if dim.dim_param:
573
+ ds.append((i.name, di, dim.dim_param))
574
+ if verbose >= 2:
575
+ print(f"[run_exporter] dynamic dimension={ds}")
576
+ if not ds:
577
+ return dict(error="no dynamic shape", success=0, error_step="dynamic")
578
+
579
+ if dynamic and len(inputs) > 1:
580
+ for index, i in enumerate(inputs):
581
+ expected = model(*_clone(i))
582
+ try:
583
+ got = mod(*i)
584
+ except Exception as e:
585
+ if not quiet:
586
+ raise RuntimeError(
587
+ f"onnxruntime failed,\n-- feeds=\n{string_type(i, with_shape=True)} "
588
+ f"exporter={exporter!r}, dynamic_shapes={dynamic_shapes}"
589
+ f"\n-- model=\n{pretty_onnx(onx) if onx is not None else type(model)}"
590
+ ) from e
591
+ return dict(error=str(e), success=0, error_step=f"run.{index}")
592
+
593
+ try:
594
+ d = max_diff(expected, got)
595
+ except Exception as e:
596
+ if not quiet:
597
+ raise
598
+ return dict(error=str(e), success=0, error_step=f"discrepancy.{index}")
599
+
600
+ if verbose >= 5 and np.isinf(d["abs"]):
601
+ print(f"[run_exporter] comparison issues iteration {index}")
602
+ print(f"-- inputs={string_type(i, with_shape=True)}")
603
+ print(f"-- expected={string_type(expected, with_shape=True)}")
604
+ print(f"-- got={string_type(got, with_shape=True)}")
605
+ elif verbose >= 9:
606
+ print(f"[run_exporter] inputs and outputs iteration {index}")
607
+ print(f"-- inputs={string_type(i, with_shape=True, with_min_max=True)}")
608
+ print(
609
+ f"-- expected={string_type(expected, with_shape=True, with_min_max=True)}"
610
+ )
611
+ print(f"-- got={string_type(got, with_shape=True, with_min_max=True)}")
612
+ del d["n"]
613
+ del d["sum"]
614
+ if d["abs"] >= 0.1:
615
+ d["error"] = f"diff.{index}"
616
+ d["error_step"] = f"diff.{index}"
617
+ d["success"] = 0
618
+ disc.update(d)
619
+
620
+ disc.update(base)
621
+ return disc