onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """

- __version__ = "0.8.4"
+ __version__ = "0.8.6"
  __author__ = "Xavier Dupré"
@@ -198,15 +198,19 @@ def get_parser_print() -> ArgumentParser:
  )
  parser.add_argument(
  "fmt",
- choices=["pretty", "raw", "text", "printer"],
+ choices=["dot", "pretty", "printer", "raw", "shape", "text"],
  default="pretty",
  help=textwrap.dedent(
  """
  Prints out a model on the standard output.
- raw - just prints the model with print(...)
- printer - onnx.printer.to_text(...)
+
+ dot - converts the graph into dot
  pretty - an improved rendering
+ printer - onnx.printer.to_text(...)
+ raw - just prints the model with print(...)
+ shape - prints every node with input and output shapes
  text - uses GraphRendering
+
  """.strip(
  "\n"
  )
@@ -232,6 +236,14 @@ def _cmd_print(argv: List[Any]):
  from .helpers.graph_helper import GraphRendering

  print(GraphRendering(onx).text_rendering())
+ elif args.fmt == "shape":
+ from experimental_experiment.xbuilder import GraphBuilder
+
+ print(GraphBuilder(onx).pretty_text())
+ elif args.fmt == "dot":
+ from .helpers.dot_helper import to_dot
+
+ print(to_dot(onx))
  else:
  raise ValueError(f"Unexpected value fmt={args.fmt!r}")

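The new branches reuse the dispatch pattern of the existing formats. As a rough sketch of how the extended command could be exercised, assuming the module path below (the diff does not name the file) and a model positional argument whose exact name this hunk does not show:

    # module path is an assumption; the diff does not name the file
    from onnx_diagnostic._command_lines_parser import main

    main(["print", "shape", "model.onnx"])  # new: every node with input and output shapes
    main(["print", "dot", "model.onnx"])    # new: dot rendering of the graph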
@@ -517,12 +529,12 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
  nargs="*",
  help=textwrap.dedent(
  """
- Applies patches before exporting, it can be a boolean
- to enable to disable the patches or be more finetuned
- (default is True). It is possible to disable patch for torch
- by adding:
- --patch "patch_sympy=False" --patch "patch_torch=False"
- """.strip(
+ Applies patches before exporting. It can be a boolean
+ to enable or disable the patches, or be more fine-tuned
+ (default is True). It is possible to disable the patches for torch
+ by adding:
+ --patch "patch_sympy=False" --patch "patch_torch=False"
+ """.strip(
  "\n"
  )
  ),
@@ -1496,6 +1508,50 @@ def _cmd_sbs(argv: List[Any]):
  print("-- done")


+ def get_parser_compare() -> ArgumentParser:
+ parser = ArgumentParser(
+ prog="compare",
+ description=textwrap.dedent(
+ """
+ Compares two onnx models by aligning the nodes between both models.
+ This is done through an edit distance.
+ """
+ ),
+ epilog=textwrap.dedent(
+ """
+ Each element (initializer, input, node, output) of the model
+ is converted into an observation. Then it defines a distance between
+ two elements. And finally, it finds the best alignment with
+ an edit distance.
+ """
+ ),
+ )
+ parser.add_argument("model1", type=str, help="first model to compare")
+ parser.add_argument("model2", type=str, help="second model to compare")
+ return parser
+
+
+ def _cmd_compare(argv: List[Any]):
+ import onnx
+ from .torch_onnx.compare import ObsCompare, ObsComparePair
+
+ parser = get_parser_compare()
+ args = parser.parse_args(argv[1:])
+ print(f"-- loading {args.model1!r}")
+ seq1 = ObsCompare.obs_sequence_from_model(onnx.load(args.model1, load_external_data=False))
+ print(f"-- loading {args.model2!r}")
+ seq2 = ObsCompare.obs_sequence_from_model(onnx.load(args.model2, load_external_data=False))
+ print("-- starts comparison")
+ dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
+ print(f"-- done with distance {dist}")
+ print(ObsComparePair.to_str(pair_cmp))
+
+
+ #############
+ # main parser
+ #############
+
+
  def get_main_parser() -> ArgumentParser:
  parser = ArgumentParser(
  prog="onnx_diagnostic",
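The epilog above describes a classic sequence alignment. A minimal sketch of the underlying dynamic program, not the actual ObsComparePair implementation; obs_distance and the gap cost below are placeholder assumptions:

    from typing import Any, Callable, List

    def align_cost(
        seq1: List[Any],
        seq2: List[Any],
        obs_distance: Callable[[Any, Any], float],
        gap: float = 1.0,
    ) -> float:
        # d[i][j] = cost of the best alignment of seq1[:i] against seq2[:j]
        n, m = len(seq1), len(seq2)
        d = [[0.0] * (m + 1) for _ in range(n + 1)]
        for i in range(1, n + 1):
            d[i][0] = i * gap
        for j in range(1, m + 1):
            d[0][j] = j * gap
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                d[i][j] = min(
                    d[i - 1][j] + gap,  # element of seq1 left unmatched
                    d[i][j - 1] + gap,  # element of seq2 left unmatched
                    d[i - 1][j - 1] + obs_distance(seq1[i - 1], seq2[j - 1]),
                )
        return d[n][m]

The _path returned by ObsComparePair.distance_sequence is presumably the traceback over such a table, pairing the aligned elements.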
@@ -1543,6 +1599,7 @@ def get_main_parser() -> ArgumentParser:
  def main(argv: Optional[List[Any]] = None):
  fcts = dict(
  agg=_cmd_agg,
+ compare=_cmd_compare,
  config=_cmd_config,
  dot=_cmd_dot,
  exportsample=_cmd_export_sample,
@@ -1568,6 +1625,7 @@ def main(argv: Optional[List[Any]] = None):
  else:
  parsers = dict(
  agg=get_parser_agg,
+ compare=get_parser_compare,
  config=get_parser_config,
  dot=get_parser_dot,
  exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
File without changes
@@ -0,0 +1,430 @@
+ import datetime
+ import os
+ import time
+ import subprocess
+ from argparse import ArgumentParser, BooleanOptionalAction
+ from typing import Any, Dict, List, Tuple
+ import onnx
+
+
+ def get_versions():
+ """
+ Returns the versions of the packages currently used.
+ The output is a dictionary.
+ Imports are delayed to keep the command line fast at startup.
+ """
+ import onnx
+ import onnx_diagnostic
+ import onnxruntime
+ import torch
+ import transformers
+
+ return {
+ "transformers": transformers.__version__,
+ "onnxruntime": onnxruntime.__version__,
+ "onnx": onnx.__version__,
+ "onnx-diagnostic": onnx_diagnostic.__version__,
+ "torch": torch.__version__,
+ }
+
+
+ def get_torch_dtype_from_command_line_args(dtype: str) -> "torch.dtype":  # noqa: F821
+ """
+ Returns the torch dtype based on the argument provided on the command line.
+
+ Imports are delayed to be faster when running the help of the command line.
+ """
+ import torch
+
+ torch_dtype = {
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ "float32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "fp32": torch.float32,
+ }
+ assert (
+ dtype in torch_dtype
+ ), f"Unexpected dtype {dtype!r}, not found in {set(torch_dtype)}."
+ return torch_dtype[dtype]
+
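Under the mapping above, the short and long spellings of each precision resolve to the same torch dtype; a quick check, assuming the function defined above is in scope:

    import torch

    assert get_torch_dtype_from_command_line_args("bf16") is torch.bfloat16
    assert get_torch_dtype_from_command_line_args("bfloat16") is torch.bfloat16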
+
+ def get_parser(name: str) -> ArgumentParser:
+ """Creates a default parser for many models."""
+ parser = ArgumentParser(
+ prog=name, description=f"""Export command line for model {name!r}."""
+ )
+ parser.add_argument(
+ "-m",
+ "--mid",
+ type=str,
+ default="Qwen/Qwen2.5-VL-7B-Instruct",
+ help="model id, default is Qwen/Qwen2.5-VL-7B-Instruct",
+ )
+ parser.add_argument("-d", "--device", default="cpu", help="Device, cpu (default) or cuda.")
+ parser.add_argument(
+ "-t", "--dtype", default="float32", help="dtype, float32 (default) or float16"
+ )
+ parser.add_argument(
+ "-e", "--exporter", default="onnx-dynamo", help="exporter, default is onnx-dynamo"
+ )
+ parser.add_argument(
+ "--pretrained",
+ default=True,
+ help="use a pretrained model or a random model",
+ action=BooleanOptionalAction,
+ )
+ parser.add_argument(
+ "--second-input",
+ default=True,
+ help="check discrepancies with other inputs",
+ action=BooleanOptionalAction,
+ )
+ parser.add_argument(
+ "--zip",
+ default=False,
+ help="Creates a .zip file with the onnx file and its data file.",
+ action=BooleanOptionalAction,
+ )
+ parser.add_argument(
+ "-o",
+ "--output-folder",
+ default="dump_models",
+ help="Folder where to put the results.",
+ )
+ parser.add_argument(
+ "-x",
+ "--existing-onnx",
+ default="",
+ help="If an onnx file exists, only measures the discrepancies.",
+ )
+ parser.add_argument(
+ "-p",
+ "--part",
+ default="visual",
+ help="part of the model to export",
+ )
+ parser.add_argument(
+ "-a",
+ "--atol",
+ type=float,
+ default=1.0,
+ help="fails if the maximum discrepancy is above that threshold",
+ )
+ parser.add_argument(
+ "--mismatch01",
+ type=float,
+ default=0.1,
+ help="fails if the ratio of mismatches at level 0.1 is above that threshold",
+ )
+ parser.add_argument(
+ "--profile-exporter",
+ default=False,
+ help="Profiles the exporter and outputs an html document from pyinstrument",
+ action=BooleanOptionalAction,
+ )
+ return parser
+
+
+ def remove_inplace_body_last_input_output_type_for_loop_because_they_might_be_sequences(
+ filename: str,
+ ):
+ """
+ Modifies an onnx file in place. It wipes out the shapes provided
+ in ``model.graph.value_info`` because they are wrong when a Loop outputs
+ a sequence. It also removes the types of the last input and output of the
+ 'body' attribute of a Loop operator because they may be declared as tensors
+ when sequences are expected.
+ This should not be needed in the future.
+ """
+ model = onnx.load(filename, load_external_data=False)
+ for node in model.graph.node:
+ if node.op_type == "Loop":
+ g = node.attribute[0].g
+ g.input[-1].type.CopyFrom(onnx.TypeProto())
+ g.output[-1].type.CopyFrom(onnx.TypeProto())
+ del model.graph.value_info[:]
+ model = onnx.shape_inference.infer_shapes(model)
+ onnx.save(model, filename, save_as_external_data=False)
+
+
+ def simplify_model_id_for_a_filename(model_id: str) -> str:
+ """Changes a model id so that it can be used in a filename."""
+ return model_id.lower().replace("/", ".")
+
+
+ def compute_expected_outputs(
+ output_filename: str, model_to_export: "torch.nn.Module", input_filename: str  # noqa: F821
+ ) -> Tuple[Any, List[Any], List[float]]:
+ """
+ Computes the expected outputs for a model.
+
+ It caches the expected outputs in a file. They are restored if the file exists
+ or computed and saved if not.
+
+ Imports are delayed to be faster when running the help of the command line.
+ """
+ import tqdm
+ import torch
+ from ..helpers import string_type
+
+ inputs = torch.load(input_filename, weights_only=False)
+ export_inputs = inputs["export_inputs"]
+ other_inputs = inputs["other_inputs"]
+
+ if os.path.exists(output_filename):
+ print(f"-- restore expected outputs from {output_filename!r}")
+ expected = torch.load(output_filename, weights_only=False)
+ export_expected = expected["export_expected"]
+ other_expected = expected["other_expected"]
+ durations = expected["durations"]
+ else:
+ print(
+ f"-- compute with inputs: "
+ f"{string_type(export_inputs, with_shape=True, with_device=True)}"
+ )
+ export_expected = model_to_export(**export_inputs)
+ print(f"-- got: {string_type(export_expected, with_shape=True)}")
+ print(
+ f"-- compute with inputs: "
+ f"{string_type(other_inputs, with_shape=True, with_device=True)}"
+ )
+ other_expected = []
+ durations = []
+ for other in tqdm.tqdm(other_inputs):
+ begin = time.perf_counter()
+ expected = model_to_export(**other)
+ other_expected.append(expected)
+ durations.append(time.perf_counter() - begin)
+ print(f"-- got: {string_type(other_expected, with_shape=True, with_device=True)}")
+
+ expected = dict(
+ export_expected=export_expected,
+ other_expected=other_expected,
+ durations=durations,
+ )
+ print(f"-- dump expected outputs into {output_filename!r}")
+ torch.save(expected, output_filename)
+ print(f"-- computation took {sum(durations)}")
+ print(
+ f"-- export_expected={string_type(export_expected, with_shape=True, with_device=True)}"
+ )
+ print(
+ f"-- other_expected={string_type(other_expected, with_shape=True, with_device=True)}"
+ )
+ return export_expected, other_expected, durations
+
+
+ def check_for_discrepancies_and_log_everything_into_a_json_file(
+ agg_stat_file: str,
+ stat_file: str,
+ export_duration: float,
+ device: str,
+ model_file: str,
+ cached_inputs: str,
+ cached_expected_outputs: str,
+ main_info: Dict[str, Any],
+ atol: float,
+ mismatch01: float,
+ ):
+ """
+ Checks discrepancies for a specific model.
+
+ Imports are delayed to be faster when running the help of the command line.
+
+ :param agg_stat_file: a file where the discrepancies are collected, this is used to
+ produce a table to make it easier to compare across types, devices, ...
+ :param stat_file: discrepancy results are dumped into that file
+ :param export_duration: export duration
+ :param device: targeted device (to select the onnxruntime provider)
+ :param model_file: onnx model file
+ :param cached_inputs: inputs saved with :func:`torch.save` and
+ restored with :func:`torch.load`,
+ needs to contain `export_inputs` (to check the model is valid),
+ and `other_inputs`, other sets of inputs to measure the discrepancies
+ and the speed-up (rough estimation)
+ :param cached_expected_outputs: expected outputs saved with :func:`torch.save`
+ and restored with :func:`torch.load`,
+ needs to contain `export_expected` (to check the model is valid),
+ and `other_expected`, other sets of outputs to measure the discrepancies
+ and the speed-up (rough estimation)
+ :param main_info: a dictionary with values used to tell which version, device, ...
+ :param atol: assert if the tolerance is above this
+ :param mismatch01: assert if the ratio of mismatches is above that threshold
+ """
+ import tqdm
+ import onnxruntime
+ import torch
+ from ..helpers import flatten_object, max_diff, string_type, string_diff
+
+ cached = (
+ torch.load(cached_inputs, weights_only=False),
+ torch.load(cached_expected_outputs, weights_only=False),
+ )
+ durations = cached[0].get("durations", [])
+ export_inputs = cached[0]["export_inputs"]
+ other_inputs = cached[0]["other_inputs"]
+ export_expected = cached[1]["export_expected"]
+ other_expected = cached[1]["other_expected"]
+
+ onx = onnx.load(model_file, load_external_data=False)
+ opsets = [d for d in onx.opset_import if d.domain == ""]
+ assert (
+ opsets
+ ), f"Unable to find standard opset in file {model_file!r}, opsets={onx.opset_import}"
+ opset = opsets[0].version
+
+ with open(stat_file, "w") as f:
+
+ def fprint(s):
+ print(s)
+ f.write(f"{s}\n")
+
+ fprint(f"-- export duration: {export_duration}")
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ if device == "cpu":
+ providers = providers[1:]
+ fprint(f"-- checking discrepancies with providers={providers!r}")
+ fprint(f"-- model_file={model_file!r}")
+ sess = onnxruntime.InferenceSession(model_file, providers=providers)
+
+ fprint(
+ f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
+ )
+ fprint(
+ f"-- export_expected "
+ f"{string_type(export_expected, with_shape=True, with_device=True)}"
+ )
+ feeds = dict(
+ zip(
+ [i.name for i in sess.get_inputs()],
+ [
+ v.detach().cpu().numpy()
+ for v in flatten_object(export_inputs, drop_keys=True)
+ ],
+ )
+ )
+ small = sess.run(None, feeds)
+ flat_export_expected = flatten_object(export_expected, drop_keys=True)
+ diff = max_diff(flat_export_expected, small, hist=[0.1, 0.01])
+ fprint(f"-- discrepancies={diff}")
+ assert diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01, (
+ f"absolute tolerance is above {atol} or the number of mismatches is above "
+ f"{mismatch01}, discrepancies={string_diff(diff)}"
+ )
+
+ if other_inputs and other_expected:
+ feeds = [
+ dict(
+ zip(
+ [i.name for i in sess.get_inputs()],
+ [
+ v.detach().cpu().numpy()
+ for v in flatten_object(inputs, drop_keys=True)
+ ],
+ )
+ )
+ for inputs in other_inputs
+ ]
+ fprint("")
+ fprint(f"-- inputs {string_type(feeds, with_shape=True, with_device=True)}")
+ fprint(
+ f"-- expected {string_type(other_expected, with_shape=True, with_device=True)}"
+ )
+ begin = time.perf_counter()
+ gots = []
+ for feed in tqdm.tqdm(feeds):
+ gots.append(sess.run(None, feed))
+ oduration = time.perf_counter() - begin
+ fprint(
+ f"-- torch duration={sum(durations[:len(gots)])}, onnx duration={oduration}, "
+ f"speedup={sum(durations[:len(gots)])/oduration} n={len(gots)}"
+ )
+
+ info = {
+ **main_info,
+ "timestamp": datetime.datetime.now().isoformat(),
+ "export_duration": export_duration,
+ "latency_torch": sum(durations[: len(gots)]),
+ "latency_ort": oduration,
+ "speedup": sum(durations[: len(gots)]) / oduration,
+ "latency_ort_n": len(gots),
+ "opset": opset,
+ **get_versions(),
+ }
+ with open(agg_stat_file, "a") as fs:
+ for fe, e, b in zip(feeds, other_expected, gots):
+ flat_e = flatten_object(e, drop_keys=True)
+ se = string_type(fe, with_shape=True)
+ diff = max_diff(flat_e, b, hist=[0.1, 0.01])
+ assert (
+ diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01
+ ), (
+ f"absolute tolerance is above {atol} or the number of mismatches is "
+ f"above {mismatch01}, discrepancies={string_diff(diff)}"
+ )
+ js = string_diff(diff, js=True, ratio=True, inputs=se, **info)
+ fs.write(js)
+ fs.write("\n")
+ fprint(f"-- inputs={se} -- {js}")
+
+ if os.path.exists(agg_stat_file):
+ print(f"-- statistics from {agg_stat_file!r}")
+ import pandas
+
+ df = pandas.read_json(agg_stat_file, lines=True)
+ first = [
+ "timestamp",
+ "model_id",
+ "pretrained",
+ "part",
+ "device",
+ "dtype",
+ "attention",
+ "opset",
+ ]
+ index = [*first[1:], "exporter"]
+ df = df[[*first, *[c for c in df.columns if c not in set(first)]]]
+ df.to_excel(agg_stat_file + ".xlsx")
+
+ values = [
+ "abs",
+ "%>0.1",
+ "%>0.01",
+ "export_duration",
+ "speedup",
+ "latency_torch",
+ "latency_ort_n",
+ ]
+ agg = {
+ **{c: "max" for c in values if c != "speedup"},
+ "speedup": "min",
+ }
+ stat = df[[*index, *values]].groupby(index, dropna=False).agg(agg)
+ stat.to_excel(agg_stat_file + ".agg.xlsx")
+ stat = (
+ df[df.exporter != "custom"][[*index, *values]]
+ .groupby(index, dropna=False)
+ .agg(agg)
+ )
+ stat.to_excel(agg_stat_file + ".agg.onnx-dynamo.xlsx")
+
+
+ def zip_model_and_data_into_a_single_file(zip_file: str, model_file: str):
+ """
+ Zips an onnx model and its data into a single file.
+
+ :param zip_file: zip file to create
+ :param model_file: onnx file
+ """
+ print()
+ print(f"-- make file {zip_file!r}")
+ cmd = ["zip", "-v", "-1", zip_file]
+ for name in [model_file, f"{model_file}.data"]:
+ print(f"-- add {name!r}")
+ cmd.append(name)
+ print(f"-- cmd: {' '.join(cmd)}")
+ subprocess.run(cmd, check=True)
+ print("-- done.")