onnx-diagnostic 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +411 -0
  4. onnx_diagnostic/doc.py +4 -4
  5. onnx_diagnostic/export/__init__.py +1 -1
  6. onnx_diagnostic/export/dynamic_shapes.py +433 -22
  7. onnx_diagnostic/ext_test_case.py +86 -29
  8. onnx_diagnostic/helpers/__init__.py +1 -0
  9. onnx_diagnostic/helpers/bench_run.py +450 -0
  10. onnx_diagnostic/{cache_helpers.py → helpers/cache_helper.py} +41 -5
  11. onnx_diagnostic/{helpers.py → helpers/helper.py} +136 -659
  12. onnx_diagnostic/helpers/memory_peak.py +249 -0
  13. onnx_diagnostic/helpers/onnx_helper.py +921 -0
  14. onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +42 -3
  15. onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +138 -55
  16. onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
  17. onnx_diagnostic/reference/ort_evaluator.py +7 -2
  18. onnx_diagnostic/torch_export_patches/__init__.py +107 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +137 -33
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +13 -2
  21. onnx_diagnostic/torch_export_patches/patch_inputs.py +174 -0
  22. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -2
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +4 -4
  24. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  25. onnx_diagnostic/torch_models/hghub/hub_api.py +234 -0
  26. onnx_diagnostic/torch_models/hghub/hub_data.py +195 -0
  27. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3259 -0
  28. onnx_diagnostic/torch_models/hghub/model_inputs.py +727 -0
  29. onnx_diagnostic/torch_models/test_helper.py +827 -0
  30. onnx_diagnostic/torch_models/untrained/llm_phi2.py +3 -4
  31. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +3 -4
  32. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  33. onnx_diagnostic/torch_onnx/sbs.py +439 -0
  34. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/METADATA +2 -2
  35. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/RECORD +39 -25
  36. onnx_diagnostic/onnx_tools.py +0 -260
  37. /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
  38. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/WHEEL +0 -0
  39. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/licenses/LICENSE.txt +0 -0
  40. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.2.2"
6
+ __version__ = "0.3.0"
7
7
  __author__ = "Xavier Dupré"
@@ -0,0 +1,4 @@
1
"""Allows running the package as ``python -m onnx_diagnostic <cmd> ...``."""
from ._command_lines_parser import main

if __name__ == "__main__":
    main()
@@ -0,0 +1,411 @@
1
+ import json
2
+ import sys
3
+ import textwrap
4
+ import onnx
5
+ from typing import Any, List, Optional
6
+ from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
7
+ from textwrap import dedent
8
+
9
+
10
def get_parser_lighten() -> ArgumentParser:
    """Builds the argument parser for command ``lighten``.

    The command removes the weights from a heavy model and stores the
    statistics needed to restore random weights later.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog="lighten",
        description=dedent(
            """
            Removes the weights from a heavy model, stores statistics to restore
            random weights.
            """
        ),
        epilog="This is mostly used to write unit tests without adding "
        "a big onnx file to the repository.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="onnx model to lighten",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="onnx model to output",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        # without type=int, ``-v 1`` produced the string "1" instead of 1
        type=int,
        required=False,
        help="verbosity",
    )
    return parser
44
+
45
+
46
def _cmd_lighten(argv: List[Any]):
    """Entry point for command ``lighten``.

    Strips the weights out of the input model, writes the reduced model to
    the output path and the weight statistics to ``<output>.stats``.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    from .helpers.onnx_helper import onnx_lighten

    parser = get_parser_lighten()
    args = parser.parse_args(argv[1:])
    onx = onnx.load(args.input)
    new_onx, stats = onnx_lighten(onx, verbose=args.verbose)
    jstats = json.dumps(stats)
    if args.verbose:
        # the f-prefix was missing: the literal "{args.input!r}" was printed
        print(f"save file {args.input!r}")
    if args.verbose:
        print(f"write file {args.output!r}")
    with open(args.output, "wb") as f:
        f.write(new_onx.SerializeToString())
    name = f"{args.output}.stats"
    with open(name, "w") as f:
        f.write(jstats)
    if args.verbose:
        print("done")
65
+
66
+
67
def get_parser_unlighten() -> ArgumentParser:
    """Builds the argument parser for command ``unlighten``.

    The command restores random weights into a model previously reduced by
    command ``lighten``; a ``.stats`` file is expected next to the input.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog="unlighten",
        description=dedent(
            """
            Restores random weights for a model reduces with command lighten,
            the command expects to find a file nearby with extension '.stats'.
            """
        ),
        epilog="This is mostly used to write unit tests without adding "
        "a big onnx file to the repository.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="onnx model to unlighten",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="onnx model to output",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        # without type=int, ``-v 1`` produced the string "1" instead of 1
        type=int,
        required=False,
        help="verbosity",
    )
    return parser
101
+
102
+
103
def _cmd_unlighten(argv: List[Any]):
    """Entry point for command ``unlighten``.

    Restores random weights into a model previously reduced with command
    ``lighten`` and saves the result.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    from .helpers.onnx_helper import onnx_unlighten

    # was ``get_parser_lighten()``: the wrong parser showed the help,
    # prog name and description of the other command
    parser = get_parser_unlighten()
    args = parser.parse_args(argv[1:])
    new_onx = onnx_unlighten(args.input, verbose=args.verbose)
    if args.verbose:
        print(f"save file {args.output}")
    with open(args.output, "wb") as f:
        f.write(new_onx.SerializeToString())
    if args.verbose:
        print("done")
115
+
116
+
117
def get_parser_print() -> ArgumentParser:
    """Builds the argument parser for command ``print``.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog="print",
        description=dedent(
            """
            Prints the model on the standard output.
            """
        ),
        epilog="To show a model.",
    )
    parser.add_argument(
        "fmt",
        choices=["pretty", "raw"],
        # nargs="?" was missing: a positional default only applies when the
        # argument may be omitted, otherwise argparse always requires it
        nargs="?",
        help="Format to use.",
        default="pretty",
    )
    parser.add_argument("input", type=str, help="onnx model to load")
    return parser
132
+
133
+
134
def _cmd_print(argv: List[Any]):
    """Entry point for command ``print``: dumps a model on standard output
    either raw or in a condensed, readable form.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    args = get_parser_print().parse_args(argv[1:])
    model = onnx.load(args.input)
    if args.fmt == "pretty":
        from .helpers.onnx_helper import pretty_onnx

        print(pretty_onnx(model))
    elif args.fmt == "raw":
        print(model)
    else:
        raise ValueError(f"Unexpected value fmt={args.fmt!r}")
146
+
147
+
148
def get_parser_find() -> ArgumentParser:
    """Builds the argument parser for command ``find``.

    The command locates the nodes producing or consuming a given set of
    result names inside a model.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog="find",
        description=dedent(
            """
            Look into a model and search for a set of names,
            tells which node is consuming or producing it.
            """
        ),
        epilog="Enables some quick validation.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        # help said "onnx model to unlighten": copy-paste from another command
        help="onnx model to inspect",
    )
    parser.add_argument(
        "-n",
        "--names",
        type=str,
        required=False,
        help="names to look at comma separated values",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        # without type=int, ``-v 1`` produced the string "1" instead of 1
        type=int,
        required=False,
        help="verbosity",
    )
    return parser
181
+
182
+
183
def _cmd_find(argv: List[Any]):
    """Entry point for command ``find``: searches a model for a set of
    result names and reports the nodes producing or consuming them.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    from .helpers.onnx_helper import onnx_find

    parser = get_parser_find()
    args = parser.parse_args(argv[1:])
    # --names is optional: without this guard, ``args.names.split`` raised
    # AttributeError when the option was omitted
    watch = set(args.names.split(",")) if args.names else set()
    onnx_find(args.input, verbose=args.verbose, watch=watch)
189
+
190
+
191
def get_parser_config() -> ArgumentParser:
    """Builds the argument parser for command ``config``.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog="config",
        description=dedent(
            """
            Prints out a configuration for a model id,
            prints the associated task as well.
            """
        ),
        epilog="",
    )
    parser.add_argument(
        "-m",
        "--mid",
        type=str,
        required=True,
        help="model id, usually <author>/<name>",
    )
    parser.add_argument(
        "-t",
        "--task",
        default=False,
        action=BooleanOptionalAction,
        help="displays the task as well",
    )
    return parser
217
+
218
+
219
def _cmd_config(argv: List[Any]):
    """Entry point for command ``config``: prints the configuration of a
    model id and, on demand, the associated task.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    from .torch_models.hghub.hub_api import get_pretrained_config, task_from_id

    args = get_parser_config().parse_args(argv[1:])
    print(get_pretrained_config(args.mid))
    if args.task:
        print("------")
        print(f"task: {task_from_id(args.mid)}")
228
+
229
+
230
def get_parser_validate() -> ArgumentParser:
    """Builds the argument parser for command ``validate``.

    Without a model id or a task, the command lists the supported tasks;
    with a task only, it prints the dummy inputs; with a model id, it
    validates (and optionally exports) the model.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        # was ``prog="test"``: the command is registered and dispatched
        # as ``validate`` in main(), the help should show the same name
        prog="validate",
        description=dedent(
            """
            Prints out dummy inputs for a particular task or a model id.
            If both mid and task are empty, the command line displays the list
            of supported tasks.
            """
        ),
        epilog="If the model id is specified, one untrained version of it is instantiated.",
    )
    parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
    parser.add_argument("-t", "--task", default=None, help="force the task to use")
    parser.add_argument("-e", "--export", help="export the model with this exporter")
    parser.add_argument("--opt", help="optimization to apply after the export")
    parser.add_argument(
        "-r",
        "--run",
        default=False,
        action=BooleanOptionalAction,
        help="runs the model to check it runs",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        default=False,
        action=BooleanOptionalAction,
        help="catches exception, report them in the summary",
    )
    parser.add_argument(
        "-p",
        "--patch",
        default=True,
        action=BooleanOptionalAction,
        help="applies patches before exporting",
    )
    parser.add_argument(
        "--stop-if-static",
        default=0,
        type=int,
        help="raises an exception if a dynamic dimension becomes static",
    )
    parser.add_argument(
        "--trained",
        default=False,
        action=BooleanOptionalAction,
        help="validate the trained model (requires downloading)",
    )
    parser.add_argument(
        "-o",
        "--dump-folder",
        help="if not empty, a folder is created to dumps statistics, "
        "exported program, onnx...",
    )
    parser.add_argument(
        "--drop",
        help="drops the following inputs names, it should be a list "
        "with comma separated values",
    )
    parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
    parser.add_argument("--dtype", help="changes dtype if necessary")
    parser.add_argument("--device", help="changes the device if necessary")
    return parser
294
+
295
+
296
def _cmd_validate(argv: List[Any]):
    """Entry point for command ``validate``.

    Lists the supported tasks when neither a model id nor a task is given,
    prints the dummy inputs and dynamic shapes for a task, or validates a
    model id and prints a summary.

    :param argv: command line arguments, ``argv[0]`` is the command name
    """
    from .helpers import string_type
    from .torch_models.test_helper import get_inputs_for_task, validate_model, _ds_clean
    from .torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks

    args = get_parser_validate().parse_args(argv[1:])
    if not args.mid:
        if not args.task:
            # nothing specified: only enumerate what is supported
            print("-- list of supported tasks:")
            print("\n".join(sorted(get_get_inputs_function_for_tasks())))
            return
        # task only: show the dummy inputs and their dynamic shapes
        data = get_inputs_for_task(args.task)
        if args.verbose:
            print(f"task: {args.task}")
        max_length = max(len(name) for name in data["inputs"]) + 1
        print("-- inputs")
        for name, value in data["inputs"].items():
            print(f" + {name.ljust(max_length)}: {string_type(value, with_shape=True)}")
        print("-- dynamic_shapes")
        for name, value in data["dynamic_shapes"].items():
            print(f" + {name.ljust(max_length)}: {_ds_clean(value)}")
        return

    # model id specified: run the full validation
    summary, _data = validate_model(
        model_id=args.mid,
        task=args.task,
        do_run=args.run,
        verbose=args.verbose,
        quiet=args.quiet,
        trained=args.trained,
        dtype=args.dtype,
        device=args.device,
        patch=args.patch,
        stop_if_static=args.stop_if_static,
        optimization=args.opt,
        exporter=args.export,
        dump_folder=args.dump_folder,
        drop_inputs=args.drop.split(",") if args.drop else None,
    )
    print("")
    print("-- summary --")
    for key, value in sorted(summary.items()):
        print(f":{key},{value};")
338
+
339
+
340
def get_main_parser() -> ArgumentParser:
    """Builds the top-level parser whose single positional argument selects
    one of the available commands.

    :return: the configured :class:`argparse.ArgumentParser`
    """
    commands = ["config", "find", "lighten", "print", "unlighten", "validate"]
    parser = ArgumentParser(
        prog="onnx_diagnostic",
        description="onnx_diagnostic main command line.\n",
        formatter_class=RawTextHelpFormatter,
        epilog=textwrap.dedent(
            """
            Type 'python -m onnx_diagnostic <cmd> --help'
            to get help for a specific command.

            config - prints a configuration for a model id
            find - find node consuming or producing a result
            lighten - makes an onnx model lighter by removing the weights,
            unlighten - restores an onnx model produces by the previous experiment
            print - prints the model on standard output
            validate - validate a model
            """
        ),
    )
    parser.add_argument("cmd", choices=commands, help="Selects a command.")
    return parser
365
+
366
+
367
def main(argv: Optional[List[Any]] = None):
    """Dispatches the command line to the function implementing the
    requested command.

    :param argv: command line arguments; defaults to ``sys.argv[1:]``;
        ``argv[0]`` is the command name
    :raises ValueError: if the command is unknown
    """
    commands = dict(
        lighten=_cmd_lighten,
        unlighten=_cmd_unlighten,
        print=_cmd_print,
        find=_cmd_find,
        config=_cmd_config,
        validate=_cmd_validate,
    )

    if argv is None:
        argv = sys.argv[1:]

    # help path: no arguments, a lone unknown token, or an explicit --help/-h
    needs_help = (
        not argv
        or (len(argv) <= 1 and argv[0] not in commands)
        or argv[-1] in ("--help", "-h")
    )
    if needs_help:
        if len(argv) < 2:
            # global help
            get_main_parser().parse_args(argv)
        else:
            # help for one specific command
            parsers = dict(
                lighten=get_parser_lighten,
                unlighten=get_parser_unlighten,
                print=get_parser_print,
                find=get_parser_find,
                config=get_parser_config,
                validate=get_parser_validate,
            )
            cmd = argv[0]
            if cmd not in parsers:
                raise ValueError(
                    f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
                )
            parsers[cmd]().parse_args(argv[1:])
        # argparse exits after printing help, this line is unreachable
        raise RuntimeError("The programme should have exited before.")

    cmd = argv[0]
    if cmd in commands:
        commands[cmd](argv)
    else:
        raise ValueError(
            f"Unknown command {cmd!r}, use --help to get the list of known command."
        )
onnx_diagnostic/doc.py CHANGED
@@ -8,14 +8,14 @@ def reset_torch_transformers(gallery_conf, fname):
8
8
 
9
9
 
10
10
  def plot_legend(
11
- text: str, text_bottom: str = "", color: str = "green", fontsize: int = 35
11
+ text: str, text_bottom: str = "", color: str = "green", fontsize: int = 15
12
12
  ) -> "matplotlib.axes.Axes": # noqa: F821
13
13
  import matplotlib.pyplot as plt
14
14
 
15
- fig = plt.figure()
15
+ fig = plt.figure(figsize=(2, 2))
16
16
  ax = fig.add_subplot()
17
17
  ax.axis([0, 5, 0, 5])
18
- ax.text(2.5, 4, "END", fontsize=50, horizontalalignment="center")
18
+ ax.text(2.5, 4, "END", fontsize=10, horizontalalignment="center")
19
19
  ax.text(
20
20
  2.5,
21
21
  2.5,
@@ -26,7 +26,7 @@ def plot_legend(
26
26
  verticalalignment="center",
27
27
  )
28
28
  if text_bottom:
29
- ax.text(4.5, 0.5, text_bottom, fontsize=20, horizontalalignment="right")
29
+ ax.text(4.5, 0.5, text_bottom, fontsize=7, horizontalalignment="right")
30
30
  ax.grid(False)
31
31
  ax.set_axis_off()
32
32
  return ax
@@ -1 +1 @@
1
- from .dynamic_shapes import ModelInputs
1
+ from .dynamic_shapes import CoupleInputsDynamicShapes, ModelInputs