onnx-diagnostic 0.2.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +452 -0
  4. onnx_diagnostic/doc.py +4 -4
  5. onnx_diagnostic/export/__init__.py +2 -1
  6. onnx_diagnostic/export/dynamic_shapes.py +574 -23
  7. onnx_diagnostic/export/validate.py +170 -0
  8. onnx_diagnostic/ext_test_case.py +151 -31
  9. onnx_diagnostic/helpers/__init__.py +1 -0
  10. onnx_diagnostic/helpers/bench_run.py +450 -0
  11. onnx_diagnostic/helpers/cache_helper.py +216 -0
  12. onnx_diagnostic/helpers/config_helper.py +80 -0
  13. onnx_diagnostic/{helpers.py → helpers/helper.py} +341 -662
  14. onnx_diagnostic/helpers/memory_peak.py +249 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +921 -0
  16. onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +4 -3
  17. onnx_diagnostic/helpers/rt_helper.py +47 -0
  18. onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +149 -55
  19. onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
  20. onnx_diagnostic/reference/ort_evaluator.py +7 -2
  21. onnx_diagnostic/tasks/__init__.py +48 -0
  22. onnx_diagnostic/tasks/automatic_speech_recognition.py +165 -0
  23. onnx_diagnostic/tasks/fill_mask.py +67 -0
  24. onnx_diagnostic/tasks/image_classification.py +96 -0
  25. onnx_diagnostic/tasks/image_text_to_text.py +145 -0
  26. onnx_diagnostic/tasks/sentence_similarity.py +67 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +172 -0
  28. onnx_diagnostic/tasks/text_classification.py +67 -0
  29. onnx_diagnostic/tasks/text_generation.py +248 -0
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +106 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +111 -146
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +346 -57
  33. onnx_diagnostic/torch_export_patches/patch_inputs.py +203 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +41 -2
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +39 -49
  36. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  37. onnx_diagnostic/torch_models/hghub/hub_api.py +254 -0
  38. onnx_diagnostic/torch_models/hghub/hub_data.py +203 -0
  39. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3571 -0
  40. onnx_diagnostic/torch_models/hghub/model_inputs.py +151 -0
  41. onnx_diagnostic/torch_models/test_helper.py +1250 -0
  42. onnx_diagnostic/torch_models/untrained/llm_phi2.py +3 -4
  43. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +3 -4
  44. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  45. onnx_diagnostic/torch_onnx/sbs.py +439 -0
  46. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/METADATA +14 -4
  47. onnx_diagnostic-0.4.0.dist-info/RECORD +86 -0
  48. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/WHEEL +1 -1
  49. onnx_diagnostic/cache_helpers.py +0 -104
  50. onnx_diagnostic/onnx_tools.py +0 -260
  51. onnx_diagnostic-0.2.2.dist-info/RECORD +0 -59
  52. /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
  53. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
  54. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.2.2"
6
+ __version__ = "0.4.0"
7
7
  __author__ = "Xavier Dupré"
@@ -0,0 +1,4 @@
1
+ from ._command_lines_parser import main
2
+
3
+ if __name__ == "__main__":
4
+ main()
@@ -0,0 +1,452 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ import textwrap
5
+ import onnx
6
+ from typing import Any, List, Optional
7
+ from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
8
+ from textwrap import dedent
9
+
10
+
11
def get_parser_lighten() -> ArgumentParser:
    """Returns the argument parser for command ``lighten``.

    The command removes the weights from a heavy onnx model and stores
    the statistics needed to later restore random weights.
    """
    parser = ArgumentParser(
        prog="lighten",
        description=dedent(
            """
            Removes the weights from a heavy model, stores statistics to restore
            random weights.
            """
        ),
        epilog="This is mostly used to write unit tests without adding "
        "a big onnx file to the repository.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="onnx model to lighten",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="onnx model to output",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        # type=int was missing: ``-v 1`` left args.verbose as the string "1",
        # inconsistent with the ``validate`` parser which declares type=int.
        type=int,
        default=0,
        required=False,
        help="verbosity",
    )
    return parser
45
+
46
+
47
def _cmd_lighten(argv: List[Any]):
    """Runs command ``lighten``.

    Loads the model, removes its weights, writes the lighter model to
    ``--output`` and the statistics (json) next to it in ``<output>.stats``.
    """
    from .helpers.onnx_helper import onnx_lighten

    parser = get_parser_lighten()
    args = parser.parse_args(argv[1:])
    onx = onnx.load(args.input)
    new_onx, stats = onnx_lighten(onx, verbose=args.verbose)
    jstats = json.dumps(stats)
    if args.verbose:
        # bug fix: the f prefix was missing, "{args.input!r}" was printed
        # verbatim; the message also said "save" about the file being read.
        print(f"load file {args.input!r}")
    if args.verbose:
        print(f"write file {args.output!r}")
    with open(args.output, "wb") as f:
        f.write(new_onx.SerializeToString())
    name = f"{args.output}.stats"
    with open(name, "w") as f:
        f.write(jstats)
    if args.verbose:
        print("done")
66
+
67
+
68
def get_parser_unlighten() -> ArgumentParser:
    """Returns the argument parser for command ``unlighten``.

    The command restores random weights into a model previously reduced
    with command ``lighten``.
    """
    parser = ArgumentParser(
        prog="unlighten",
        description=dedent(
            """
            Restores random weights for a model reduced with command lighten,
            the command expects to find a file nearby with extension '.stats'.
            """
        ),
        epilog="This is mostly used to write unit tests without adding "
        "a big onnx file to the repository.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="onnx model to unlighten",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="onnx model to output",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        # type=int was missing: verbosity given on the command line
        # stayed a string, inconsistent with the ``validate`` parser.
        type=int,
        default=0,
        required=False,
        help="verbosity",
    )
    return parser
102
+
103
+
104
def _cmd_unlighten(argv: List[Any]):
    """Runs command ``unlighten``.

    Restores random weights into a model reduced with command ``lighten``
    (using the ``.stats`` file stored next to it) and writes the result.
    """
    from .helpers.onnx_helper import onnx_unlighten

    # bug fix: the command was parsing its arguments with the parser of
    # command ``lighten`` (get_parser_lighten), not its own parser.
    parser = get_parser_unlighten()
    args = parser.parse_args(argv[1:])
    new_onx = onnx_unlighten(args.input, verbose=args.verbose)
    if args.verbose:
        print(f"save file {args.output}")
    with open(args.output, "wb") as f:
        f.write(new_onx.SerializeToString())
    if args.verbose:
        print("done")
116
+
117
+
118
def get_parser_print() -> ArgumentParser:
    """Builds the argument parser for command ``print``."""
    p = ArgumentParser(
        prog="print",
        description=dedent(
            """
            Prints the model on the standard output.
            """
        ),
        epilog="To show a model.",
    )
    p.add_argument("fmt", choices=["pretty", "raw"], help="Format to use.", default="pretty")
    p.add_argument("input", type=str, help="onnx model to load")
    return p
133
+
134
+
135
def _cmd_print(argv: List[Any]):
    """Runs command ``print``: loads a model and prints it on stdout."""
    args = get_parser_print().parse_args(argv[1:])
    model = onnx.load(args.input)
    if args.fmt == "raw":
        print(model)
        return
    if args.fmt == "pretty":
        from .helpers.onnx_helper import pretty_onnx

        print(pretty_onnx(model))
        return
    # unreachable in practice, the parser restricts fmt with choices
    raise ValueError(f"Unexpected value fmt={args.fmt!r}")
147
+
148
+
149
def get_parser_find() -> ArgumentParser:
    """Returns the argument parser for command ``find``.

    The command looks for a set of result names inside a model and tells
    which node produces or consumes them.
    """
    parser = ArgumentParser(
        prog="find",
        description=dedent(
            """
            Look into a model and search for a set of names,
            tells which node is consuming or producing it.
            """
        ),
        # typo fix: "Enables Some quick validation."
        epilog="Enables some quick validation.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        # copy-paste fix: the help said "onnx model to unlighten"
        help="onnx model to load",
    )
    parser.add_argument(
        "-n",
        "--names",
        type=str,
        required=False,
        help="names to look at comma separated values",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        # type=int was missing, consistent with the ``validate`` parser
        type=int,
        default=0,
        required=False,
        help="verbosity",
    )
    return parser
182
+
183
+
184
def _cmd_find(argv: List[Any]):
    """Runs command ``find``: searches names in a model."""
    from .helpers.onnx_helper import onnx_find

    parser = get_parser_find()
    args = parser.parse_args(argv[1:])
    # robustness: --names is optional, calling split on None raised
    # AttributeError; an empty watch set is used instead.
    watch = set(args.names.split(",")) if args.names else set()
    onnx_find(args.input, verbose=args.verbose, watch=watch)
190
+
191
+
192
def get_parser_config() -> ArgumentParser:
    """Builds the argument parser for command ``config``."""
    p = ArgumentParser(
        prog="config",
        description=dedent(
            """
            Prints out a configuration for a model id,
            prints the associated task as well.
            """
        ),
        epilog="",
    )
    p.add_argument(
        "-m",
        "--mid",
        type=str,
        required=True,
        help="model id, usually <author>/<name>",
    )
    p.add_argument(
        "-t",
        "--task",
        default=False,
        action=BooleanOptionalAction,
        help="displays the task as well",
    )
    return p
218
+
219
+
220
def _cmd_config(argv: List[Any]):
    """Runs command ``config``: prints a model configuration, optionally its task."""
    from .torch_models.hghub.hub_api import get_pretrained_config, task_from_id

    args = get_parser_config().parse_args(argv[1:])
    print(get_pretrained_config(args.mid))
    if args.task:
        print("------")
        print(f"task: {task_from_id(args.mid)}")
229
+
230
+
231
+ class _ParseDict(argparse.Action):
232
+ def __call__(self, parser, namespace, values, option_string=None):
233
+ d = getattr(namespace, self.dest) or {}
234
+
235
+ if values:
236
+ for item in values:
237
+ split_items = item.split("=", 1)
238
+ key = split_items[0].strip() # we remove blanks around keys, as is logical
239
+ value = split_items[1]
240
+
241
+ d[key] = value
242
+
243
+ setattr(namespace, self.dest, d)
244
+
245
+
246
def get_parser_validate() -> ArgumentParser:
    """Returns the argument parser for command ``validate``.

    Without ``--mid`` and ``--task`` the command lists the supported tasks;
    with only ``--task`` it prints dummy inputs for that task; with ``--mid``
    it instantiates (an untrained version of) the model and validates it.
    """
    parser = ArgumentParser(
        # consistency fix: the command is exposed as ``validate``
        # (see the main parser), not ``test``.
        prog="validate",
        description=dedent(
            """
            Prints out dummy inputs for a particular task or a model id.
            If both mid and task are empty, the command line displays the list
            of supported tasks.
            """
        ),
        epilog="If the model id is specified, one untrained version of it is instantiated.",
    )
    parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
    parser.add_argument("-t", "--task", default=None, help="force the task to use")
    parser.add_argument("-e", "--export", help="export the model with this exporter")
    parser.add_argument("--opt", help="optimization to apply after the export")
    parser.add_argument(
        "-r",
        "--run",
        default=False,
        action=BooleanOptionalAction,
        help="runs the model to check it runs",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        default=False,
        action=BooleanOptionalAction,
        help="catches exception, report them in the summary",
    )
    parser.add_argument(
        "-p",
        "--patch",
        default=True,
        action=BooleanOptionalAction,
        help="applies patches before exporting",
    )
    parser.add_argument(
        "--stop-if-static",
        default=0,
        type=int,
        help="raises an exception if a dynamic dimension becomes static",
    )
    parser.add_argument(
        "--trained",
        default=False,
        action=BooleanOptionalAction,
        help="validate the trained model (requires downloading)",
    )
    parser.add_argument(
        "-o",
        "--dump-folder",
        help="if not empty, a folder is created to dumps statistics, "
        "exported program, onnx...",
    )
    parser.add_argument(
        "--drop",
        help="drops the following inputs names, it should be a list "
        "with comma separated values",
    )
    parser.add_argument(
        "--ortfusiontype",
        required=False,
        help="applies onnxruntime fusion, this parameter should contain the "
        "model type or multiple values separated by `|`. `ALL` can be used "
        "to run them all",
    )
    parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
    parser.add_argument("--dtype", help="changes dtype if necessary")
    parser.add_argument("--device", help="changes the device if necessary")
    parser.add_argument(
        "--iop",
        metavar="KEY=VALUE",
        nargs="*",
        help="Additional input options, use to change the default "
        "inputs use to export, example: --iop cls_cache=SlidingWindowCache",
        action=_ParseDict,
    )
    return parser
325
+
326
+
327
def _cmd_validate(argv: List[Any]):
    """Runs command ``validate``.

    Three modes: no mid and no task lists the supported tasks;
    a task alone prints its dummy inputs and dynamic shapes;
    a model id validates the model and prints a summary.
    """
    from .helpers import string_type
    from .torch_models.test_helper import get_inputs_for_task, validate_model
    from .tasks import supported_tasks

    parser = get_parser_validate()
    args = parser.parse_args(argv[1:])
    if not args.task and not args.mid:
        print("-- list of supported tasks:")
        print("\n".join(supported_tasks()))
    elif not args.mid:
        data = get_inputs_for_task(args.task)
        if args.verbose:
            print(f"task: {args.task}")
        # robustness: default=0 avoids a ValueError when a task
        # defines no input at all
        max_length = max((len(k) for k in data["inputs"]), default=0) + 1
        print("-- inputs")
        for k, v in data["inputs"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
        print("-- dynamic_shapes")
        for k, v in data["dynamic_shapes"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v)}")
    else:
        # Let's skip any invalid combination if known to be unsupported
        if (
            "onnx" not in (args.export or "")
            and "custom" not in (args.export or "")
            and (args.opt or "")
        ):
            print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
            return
        summary, _data = validate_model(
            model_id=args.mid,
            task=args.task,
            do_run=args.run,
            verbose=args.verbose,
            quiet=args.quiet,
            trained=args.trained,
            dtype=args.dtype,
            device=args.device,
            patch=args.patch,
            stop_if_static=args.stop_if_static,
            optimization=args.opt,
            exporter=args.export,
            dump_folder=args.dump_folder,
            drop_inputs=None if not args.drop else args.drop.split(","),
            ortfusiontype=args.ortfusiontype,
            input_options=args.iop,
        )
        print("")
        print("-- summary --")
        for k, v in sorted(summary.items()):
            print(f":{k},{v};")
379
+
380
+
381
def get_main_parser() -> ArgumentParser:
    """Builds the top-level parser listing every available subcommand."""
    p = ArgumentParser(
        prog="onnx_diagnostic",
        description="onnx_diagnostic main command line.\n",
        formatter_class=RawTextHelpFormatter,
        epilog=textwrap.dedent(
            """
        Type 'python -m onnx_diagnostic <cmd> --help'
        to get help for a specific command.

        config     - prints a configuration for a model id
        find       - find node consuming or producing a result
        lighten    - makes an onnx model lighter by removing the weights,
        unlighten  - restores an onnx model produces by the previous experiment
        print      - prints the model on standard output
        validate   - validate a model
        """
        ),
    )
    p.add_argument(
        "cmd",
        choices=["config", "find", "lighten", "print", "unlighten", "validate"],
        help="Selects a command.",
    )
    return p
406
+
407
+
408
def main(argv: Optional[List[Any]] = None):
    """Entry point for ``python -m onnx_diagnostic``.

    :param argv: command line arguments, defaults to ``sys.argv[1:]``

    Dispatches to one command function; help requests are routed to the
    matching parser, which prints usage and exits (argparse calls
    ``sys.exit``).
    """
    fcts = dict(
        lighten=_cmd_lighten,
        unlighten=_cmd_unlighten,
        print=_cmd_print,
        find=_cmd_find,
        config=_cmd_config,
        validate=_cmd_validate,
    )

    if argv is None:
        argv = sys.argv[1:]
    if (
        len(argv) == 0
        or (len(argv) <= 1 and argv[0] not in fcts)
        or argv[-1] in ("--help", "-h")
    ):
        if len(argv) < 2:
            parser = get_main_parser()
            parser.parse_args(argv)
        else:
            parsers = dict(
                lighten=get_parser_lighten,
                unlighten=get_parser_unlighten,
                print=get_parser_print,
                find=get_parser_find,
                config=get_parser_config,
                validate=get_parser_validate,
            )
            cmd = argv[0]
            if cmd not in parsers:
                raise ValueError(
                    # idiom fix: sorted() already returns a list,
                    # list(sorted(...)) was redundant (same repr printed)
                    f"Unknown command {cmd!r}, it should be in {sorted(parsers)}."
                )
            parser = parsers[cmd]()
            parser.parse_args(argv[1:])
        # argparse exits after printing help, this line should be unreachable
        raise RuntimeError("The programme should have exited before.")

    cmd = argv[0]
    if cmd in fcts:
        fcts[cmd](argv)
    else:
        raise ValueError(
            f"Unknown command {cmd!r}, use --help to get the list of known command."
        )
onnx_diagnostic/doc.py CHANGED
@@ -8,14 +8,14 @@ def reset_torch_transformers(gallery_conf, fname):
8
8
 
9
9
 
10
10
  def plot_legend(
11
- text: str, text_bottom: str = "", color: str = "green", fontsize: int = 35
11
+ text: str, text_bottom: str = "", color: str = "green", fontsize: int = 15
12
12
  ) -> "matplotlib.axes.Axes": # noqa: F821
13
13
  import matplotlib.pyplot as plt
14
14
 
15
- fig = plt.figure()
15
+ fig = plt.figure(figsize=(2, 2))
16
16
  ax = fig.add_subplot()
17
17
  ax.axis([0, 5, 0, 5])
18
- ax.text(2.5, 4, "END", fontsize=50, horizontalalignment="center")
18
+ ax.text(2.5, 4, "END", fontsize=10, horizontalalignment="center")
19
19
  ax.text(
20
20
  2.5,
21
21
  2.5,
@@ -26,7 +26,7 @@ def plot_legend(
26
26
  verticalalignment="center",
27
27
  )
28
28
  if text_bottom:
29
- ax.text(4.5, 0.5, text_bottom, fontsize=20, horizontalalignment="right")
29
+ ax.text(4.5, 0.5, text_bottom, fontsize=7, horizontalalignment="right")
30
30
  ax.grid(False)
31
31
  ax.set_axis_off()
32
32
  return ax
@@ -1 +1,2 @@
1
- from .dynamic_shapes import ModelInputs
1
+ from .dynamic_shapes import CoupleInputsDynamicShapes, ModelInputs
2
+ from .validate import validate_ep