onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +154 -3
  3. onnx_diagnostic/ci_models/__init__.py +0 -0
  4. onnx_diagnostic/ci_models/ci_helpers.py +435 -0
  5. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  6. onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
  7. onnx_diagnostic/export/api.py +1 -0
  8. onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
  9. onnx_diagnostic/export/control_flow_onnx.py +23 -17
  10. onnx_diagnostic/ext_test_case.py +23 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/log_helper.py +1 -3
  13. onnx_diagnostic/helpers/optim_helper.py +116 -0
  14. onnx_diagnostic/tasks/image_text_to_text.py +15 -5
  15. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  16. onnx_diagnostic/tasks/text_generation.py +3 -0
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
  18. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  19. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  20. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
  22. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
  24. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  25. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  26. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
  27. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  28. onnx_diagnostic/torch_onnx/compare.py +357 -0
  29. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
  30. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
  31. onnx_diagnostic/export/control_flow.py +0 -214
  32. onnx_diagnostic/export/control_flow_research.py +0 -140
  33. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
  34. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
  35. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.8.5"
+__version__ = "0.8.7"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -1508,6 +1508,151 @@ def _cmd_sbs(argv: List[Any]):
     print("-- done")
 
 
+def get_parser_compare() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="compare",
+        description=textwrap.dedent(
+            """
+            Compares two onnx models by aligning the nodes of both models.
+            This is done through an edit distance.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            Each element (initializer, input, node, output) of the model
+            is converted into an observation. A distance is then defined
+            between two elements. Finally, the best alignment is computed
+            with an edit distance.
+            """
+        ),
+    )
+    parser.add_argument("model1", type=str, help="first model to compare")
+    parser.add_argument("model2", type=str, help="second model to compare")
+    return parser
+
+
+def _cmd_compare(argv: List[Any]):
+    import onnx
+    from .torch_onnx.compare import ObsCompare, ObsComparePair
+
+    parser = get_parser_compare()
+    args = parser.parse_args(argv[1:])
+    print(f"-- loading {args.model1!r}")
+    seq1 = ObsCompare.obs_sequence_from_model(onnx.load(args.model1, load_external_data=False))
+    print(f"-- loading {args.model2!r}")
+    seq2 = ObsCompare.obs_sequence_from_model(onnx.load(args.model2, load_external_data=False))
+    print("-- starts comparison")
+    dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
+    print(f"-- done with distance {dist}")
+    print(ObsComparePair.to_str(pair_cmp))
+
+
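The alignment is also available programmatically through the new module onnx_diagnostic/torch_onnx/compare.py (file 28 in the list above). A minimal sketch mirroring _cmd_compare, assuming two local files model1.onnx and model2.onnx (placeholder names):

    import onnx
    from onnx_diagnostic.torch_onnx.compare import ObsCompare, ObsComparePair

    # one observation per initializer, input, node and output
    seq1 = ObsCompare.obs_sequence_from_model(onnx.load("model1.onnx", load_external_data=False))
    seq2 = ObsCompare.obs_sequence_from_model(onnx.load("model2.onnx", load_external_data=False))

    # edit distance over the two observation sequences
    dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
    print(f"distance={dist}")
    print(ObsComparePair.to_str(pair_cmp))  # aligned, side-by-side view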
+def get_parser_optimize() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="optimize",
+        formatter_class=RawTextHelpFormatter,
+        description=textwrap.dedent(
+            """
+            Optimizes an onnx model by fusing nodes. It looks for patterns in the graph
+            and replaces them with the corresponding nodes. It also does basic optimization
+            such as removing identity nodes or unused nodes.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            The goal is to make the model faster.
+            The algorithm argument defines the patterns to apply or the set of patterns.
+            It is possible to show statistics or to remove a particular pattern.
+            Here are some environment variables which can be used to trigger
+            these behaviours.
+
+            Available with algorithms default and default+onnxruntime:
+
+            - DROPPATTERN=<pattern1,pattern2,...>: do not apply
+              those patterns when optimizing a model
+            - DUMPPATTERNS=<folder>: dumps all matched and applied
+              nodes when a pattern is applied
+            - PATTERN=<pattern1,pattern2,...>: increases verbosity for specific
+              patterns to understand why one pattern was not applied,
+              this shows which line is rejecting a pattern if it seems one pattern was missed
+            """
+        ),
+    )
+    parser.add_argument(
+        "algorithm",
+        choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"],
+        help="algorithm or patterns optimization to apply",
+    )
+    parser.add_argument("input", type=str, help="onnx model to optimize")
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=False,
+        help="onnx model to output, if empty, it adds .o-{algorithm}.onnx to the name",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=0,
+        required=False,
+        type=int,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "--infer-shapes",
+        default=True,
+        action=BooleanOptionalAction,
+        help="infer shapes before optimizing the model",
+    )
+    parser.add_argument(
+        "--processor",
+        default="",
+        help=textwrap.dedent(
+            """
+            optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
+            some operators are only available on one processor and might not
+            be used otherwise
+            """
+        ).strip("\n"),
+    )
+    parser.add_argument(
+        "--remove-shape-info",
+        default=True,
+        action=BooleanOptionalAction,
+        help="remove shape information before outputting the model",
+    )
+    return parser
+
+
+def _cmd_optimize(argv: List[Any]):
+    parser = get_parser_optimize()
+    args = parser.parse_args(argv[1:])
+
+    from .helpers.optim_helper import optimize_model
+
+    output = (
+        args.output
+        if args.output
+        else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx"
+    )
+
+    optimize_model(
+        args.algorithm,
+        args.input,
+        output=output,
+        verbose=args.verbose,
+        processor=args.processor,
+        infer_shapes=args.infer_shapes,
+        remove_shape_info=args.remove_shape_info,
+    )
+
+
+#############
+# main parser
+#############
+
+
 def get_main_parser() -> ArgumentParser:
     parser = ArgumentParser(
         prog="onnx_diagnostic",
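Both new sub-commands are wired into the dispatch tables below. A minimal sketch of driving them from Python, assuming main dispatches on argv[0] as the fcts dictionary suggests (file names are placeholders):

    from onnx_diagnostic._command_lines_parser import main

    # align two onnx models with an edit distance and print the result
    main(["compare", "model_a.onnx", "model_b.onnx"])

    # apply the default pattern set; writes model.o-default.onnx next to the input
    main(["optimize", "default", "model.onnx", "-v", "1"])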
@@ -1519,16 +1664,17 @@ def get_main_parser() -> ArgumentParser:
     to get help for a specific command.
 
     agg - aggregates statistics from multiple files
-    config - prints a configuration for a model id
+    config - prints a configuration for a model id (on HuggingFace Hub)
     dot - converts an onnx model into dot format
     exportsample - produces code to export a model
     find - finds the node consuming or producing a result
-    lighten - makes an onnx model lighter by removing the weights,
+    lighten - makes an onnx model lighter by removing the weights
+    optimize - optimizes an onnx model
     print - prints the model on standard output
     sbs - compares an exported program and an onnx model
     stats - produces statistics on a model
     unlighten - restores an onnx model produced by the previous experiment
-    validate - validate a model
+    validate - validate a model (knowing its model id on HuggingFace Hub)
     """
     ),
 )
@@ -1541,6 +1687,7 @@ def get_main_parser() -> ArgumentParser:
         "exportsample",
         "find",
         "lighten",
+        "optimize",
         "print",
         "sbs",
         "stats",
@@ -1555,11 +1702,13 @@
 def main(argv: Optional[List[Any]] = None):
     fcts = dict(
         agg=_cmd_agg,
+        compare=_cmd_compare,
         config=_cmd_config,
         dot=_cmd_dot,
         exportsample=_cmd_export_sample,
         find=_cmd_find,
         lighten=_cmd_lighten,
+        optimize=_cmd_optimize,
         print=_cmd_print,
         sbs=_cmd_sbs,
         stats=_cmd_stats,
@@ -1580,11 +1729,13 @@ def main(argv: Optional[List[Any]] = None):
     else:
         parsers = dict(
             agg=get_parser_agg,
+            compare=get_parser_compare,
             config=get_parser_config,
             dot=get_parser_dot,
             exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
             find=get_parser_find,
             lighten=get_parser_lighten,
+            optimize=get_parser_optimize,
             print=get_parser_print,
             sbs=get_parser_sbs,
             stats=get_parser_stats,
onnx_diagnostic/ci_models/__init__.py (file without changes)
onnx_diagnostic/ci_models/ci_helpers.py (new file)
@@ -0,0 +1,435 @@
+import datetime
+import os
+import time
+import subprocess
+from argparse import ArgumentParser, BooleanOptionalAction, RawTextHelpFormatter
+from typing import Any, Dict, List, Tuple
+import onnx
+
+
+def get_versions():
+    """
+    Returns the versions of the packages currently used.
+    The output is a dictionary.
+    Imports are delayed to be faster when running the help of the command line.
+    """
+    import onnx
+    import onnx_diagnostic
+    import onnxruntime
+    import torch
+    import transformers
+
+    return {
+        "transformers": transformers.__version__,
+        "onnxruntime": onnxruntime.__version__,
+        "onnx": onnx.__version__,
+        "onnx-diagnostic": onnx_diagnostic.__version__,
+        "torch": torch.__version__,
+    }
+
+
+def get_torch_dtype_from_command_line_args(dtype: str) -> "torch.dtype":  # noqa: F821
+    """
+    Returns the torch dtype based on the argument provided on the command line.
+
+    Imports are delayed to be faster when running the help of the command line.
+    """
+    import torch
+
+    torch_dtype = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "fp32": torch.float32,
+    }
+    assert (
+        dtype in torch_dtype
+    ), f"Unexpected dtype {dtype!r}, not found in {set(torch_dtype)}."
+    return torch_dtype[dtype]
+
+
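For instance (a quick check, not part of the diff), both long and short spellings map to the same dtype:

    import torch
    from onnx_diagnostic.ci_models.ci_helpers import get_torch_dtype_from_command_line_args

    assert get_torch_dtype_from_command_line_args("bf16") is torch.bfloat16
    assert get_torch_dtype_from_command_line_args("float32") is torch.float32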
+def get_parser(name: str, epilog: str = "") -> ArgumentParser:
+    """Creates a default parser for many models."""
+    parser = ArgumentParser(
+        prog=name,
+        description=f"""Export command line for model {name!r}.""",
+        epilog=epilog,
+        formatter_class=RawTextHelpFormatter,
+    )
+    parser.add_argument(
+        "-m",
+        "--mid",
+        type=str,
+        default="Qwen/Qwen2.5-VL-7B-Instruct",
+        help="model id, default is Qwen/Qwen2.5-VL-7B-Instruct",
+    )
+    parser.add_argument("-d", "--device", default="cpu", help="Device, cpu (default) or cuda.")
+    parser.add_argument(
+        "-t", "--dtype", default="float32", help="dtype, float32 (default) or float16"
+    )
+    parser.add_argument(
+        "-e", "--exporter", default="onnx-dynamo", help="exporter, default is onnx-dynamo"
+    )
+    parser.add_argument(
+        "--pretrained",
+        default=True,
+        help="use a pretrained model or a random model",
+        action=BooleanOptionalAction,
+    )
+    parser.add_argument(
+        "--second-input",
+        default=True,
+        help="check discrepancies with other inputs",
+        action=BooleanOptionalAction,
+    )
+    parser.add_argument(
+        "--zip",
+        default=False,
+        help="Creates a .zip file with the onnx file and its data file.",
+        action=BooleanOptionalAction,
+    )
+    parser.add_argument(
+        "-o",
+        "--output-folder",
+        default="dump_models",
+        help="Folder where to put the results.",
+    )
+    parser.add_argument(
+        "-x",
+        "--existing-onnx",
+        default="",
+        help="If an onnx file exists, only measures the discrepancies.",
+    )
+    parser.add_argument(
+        "-p",
+        "--part",
+        default="visual",
+        help="part of the model to export",
+    )
+    parser.add_argument(
+        "-a",
+        "--atol",
+        type=float,
+        default=2.0,
+        help="fails if the maximum discrepancy is above that threshold",
+    )
+    parser.add_argument(
+        "--mismatch01",
+        type=float,
+        default=0.1,
+        help="fails if the ratio of mismatches at level 0.1 is above that threshold",
+    )
+    parser.add_argument(
+        "--profile-exporter",
+        default=False,
+        help="Profiles the exporter and outputs an html document from pyinstrument",
+        action=BooleanOptionalAction,
+    )
+    return parser
+
+
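A quick way to see the defaults this shared parser bakes in (illustrative; export_qwen25_vl is one of the export scripts added in this release):

    from onnx_diagnostic.ci_models.ci_helpers import get_parser

    args = get_parser("export_qwen25_vl").parse_args([])
    assert args.mid == "Qwen/Qwen2.5-VL-7B-Instruct"
    assert (args.device, args.dtype, args.exporter) == ("cpu", "float32", "onnx-dynamo")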
+def remove_inplace_body_last_input_output_type_for_loop_because_they_might_be_sequences(
+    filename: str,
+):
+    """
+    Modifies an onnx file inplace. It wipes out the shapes stored
+    in ``model.graph.value_info`` because they are wrong when a Loop outputs
+    a sequence. It also removes the types of the last input and output in the
+    attribute 'body' of an operator Loop because they may be tensors when
+    sequences are expected.
+    This should not be needed in the future.
+    """
+    model = onnx.load(filename, load_external_data=False)
+    for node in model.graph.node:
+        if node.op_type == "Loop":
+            g = node.attribute[0].g
+            g.input[-1].type.CopyFrom(onnx.TypeProto())
+            g.output[-1].type.CopyFrom(onnx.TypeProto())
+    del model.graph.value_info[:]
+    model = onnx.shape_inference.infer_shapes(model)
+    onnx.save(model, filename, save_as_external_data=False)
+
+
+def simplify_model_id_for_a_filename(model_id: str) -> str:
+    """Changes a model id in a way it can be used in a filename."""
+    return model_id.lower().replace("/", ".")
+
+
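For example (not part of the diff), the default model id above becomes:

    from onnx_diagnostic.ci_models.ci_helpers import simplify_model_id_for_a_filename

    assert (
        simplify_model_id_for_a_filename("Qwen/Qwen2.5-VL-7B-Instruct")
        == "qwen.qwen2.5-vl-7b-instruct"
    )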
+def compute_expected_outputs(
+    output_filename: str, model_to_export: "torch.nn.Module", input_filename: str  # noqa: F821
+) -> Tuple[Any, List[Any], List[float]]:
+    """
+    Computes the expected outputs for a model.
+
+    It caches the expected outputs in a file. They are restored if the file exists
+    or computed and saved if not.
+
+    Imports are delayed to be faster when running the help of the command line.
+    """
+    import tqdm
+    import torch
+    from ..helpers import string_type
+
+    inputs = torch.load(input_filename, weights_only=False)
+    export_inputs = inputs["export_inputs"]
+    other_inputs = inputs["other_inputs"]
+
+    if os.path.exists(output_filename):
+        print(f"-- restore expected outputs from {output_filename!r}")
+        expected = torch.load(output_filename, weights_only=False)
+        export_expected = expected["export_expected"]
+        other_expected = expected["other_expected"]
+        durations = expected["durations"]
+    else:
+        print(
+            f"-- compute with inputs: "
+            f"{string_type(export_inputs, with_shape=True, with_device=True)}"
+        )
+        export_expected = model_to_export(**export_inputs)
+        print(f"-- got: {string_type(export_expected, with_shape=True)}")
+        print(
+            f"-- compute with inputs: "
+            f"{string_type(other_inputs, with_shape=True, with_device=True)}"
+        )
+        other_expected = []
+        durations = []
+        for other in tqdm.tqdm(other_inputs):
+            begin = time.perf_counter()
+            expected = model_to_export(**other)
+            other_expected.append(expected)
+            durations.append(time.perf_counter() - begin)
+        print(f"-- got: {string_type(other_expected, with_shape=True, with_device=True)}")
+
+        expected = dict(
+            export_expected=export_expected,
+            other_expected=other_expected,
+            durations=durations,
+        )
+        print(f"-- dump expected outputs into {output_filename!r}")
+        torch.save(expected, output_filename)
+        print(f"-- computation took {sum(durations)}")
+    print(
+        f"-- export_expected={string_type(export_expected, with_shape=True, with_device=True)}"
+    )
+    print(
+        f"-- other_expected={string_type(other_expected, with_shape=True, with_device=True)}"
+    )
+    return export_expected, other_expected, durations
+
+
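A minimal usage sketch, assuming inputs.pt was saved beforehand with torch.save and holds the two keys the function reads, export_inputs and other_inputs; model, inputs.pt and expected.pt are placeholders:

    from onnx_diagnostic.ci_models.ci_helpers import compute_expected_outputs

    # the first call runs the model and caches the results in "expected.pt",
    # subsequent calls restore them from that file
    export_expected, other_expected, durations = compute_expected_outputs(
        "expected.pt", model, "inputs.pt"
    )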
+def check_for_discrepancies_and_log_everything_into_a_json_file(
+    agg_stat_file: str,
+    stat_file: str,
+    export_duration: float,
+    device: str,
+    model_file: str,
+    cached_inputs: str,
+    cached_expected_outputs: str,
+    main_info: Dict[str, Any],
+    atol: float,
+    mismatch01: float,
+):
+    """
+    Checks discrepancies for a specific model.
+
+    Imports are delayed to be faster when running the help of the command line.
+
+    :param agg_stat_file: a file where the discrepancies are collected, this is used to
+        produce a table to make it easier to compare across types, devices, ...
+    :param stat_file: discrepancy results are dumped into that file
+    :param export_duration: export duration
+    :param device: targeted device (to select the onnxruntime provider)
+    :param model_file: onnx model file
+    :param cached_inputs: inputs saved with :func:`torch.save` and
+        restored with :func:`torch.load`,
+        needs to contain `export_inputs` (to check the model is valid),
+        and `other_inputs`, other sets of inputs to measure the discrepancies
+        and the speedup (rough estimation)
+    :param cached_expected_outputs: expected outputs saved with :func:`torch.save`
+        and restored with :func:`torch.load`,
+        needs to contain `export_expected` (to check the model is valid),
+        and `other_expected`, other sets of outputs to measure the discrepancies
+        and the speedup (rough estimation)
+    :param main_info: a dictionary with values used to tell which version, device, ...
+    :param atol: fails if the maximum absolute error is above this threshold
+    :param mismatch01: fails if the ratio of mismatches is above that threshold
+    """
+    import tqdm
+    import onnxruntime
+    import torch
+    from ..helpers import flatten_object, max_diff, string_type, string_diff
+
+    cached = (
+        torch.load(cached_inputs, weights_only=False),
+        torch.load(cached_expected_outputs, weights_only=False),
+    )
+    durations = cached[0].get("durations", [])
+    export_inputs = cached[0]["export_inputs"]
+    other_inputs = cached[0]["other_inputs"]
+    export_expected = cached[1]["export_expected"]
+    other_expected = cached[1]["other_expected"]
+
+    onx = onnx.load(model_file, load_external_data=False)
+    opsets = [d for d in onx.opset_import if d.domain == ""]
+    assert (
+        opsets
+    ), f"Unable to find standard opset in file {model_file!r}, opsets={onx.opset_import}"
+    opset = opsets[0].version
+
+    with open(stat_file, "w") as f:
+
+        def fprint(s):
+            print(s)
+            f.write(f"{s}\n")
+
+        fprint(f"-- export duration: {export_duration}")
+        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        if device == "cpu":
+            providers = providers[1:]
+        fprint(f"-- checking discrepancies with providers={providers!r}")
+        fprint(f"-- model_file={model_file!r}")
+        sess = onnxruntime.InferenceSession(model_file, providers=providers)
+
+        fprint(
+            f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
+        )
+        fprint(
+            f"-- export_expected "
+            f"{string_type(export_expected, with_shape=True, with_device=True)}"
+        )
+        feeds = dict(
+            zip(
+                [i.name for i in sess.get_inputs()],
+                [
+                    v.detach().cpu().numpy()
+                    for v in flatten_object(export_inputs, drop_keys=True)
+                ],
+            )
+        )
+        small = sess.run(None, feeds)
+        flat_export_expected = flatten_object(export_expected, drop_keys=True)
+        diff = max_diff(flat_export_expected, small, hist=[0.1, 0.01])
+        fprint(f"-- discrepancies={diff}")
+        assert diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01, (
+            f"absolute error {diff['abs']} is above {atol} or number of "
+            f"mismatches ({diff['rep']['>0.1'] / diff['n']}) is above "
+            f"{mismatch01}, discrepancies={string_diff(diff)}"
+        )
+
+        if other_inputs and other_expected:
+            feeds = [
+                dict(
+                    zip(
+                        [i.name for i in sess.get_inputs()],
+                        [
+                            v.detach().cpu().numpy()
+                            for v in flatten_object(inputs, drop_keys=True)
+                        ],
+                    )
+                )
+                for inputs in other_inputs
+            ]
+            fprint("")
+            fprint(f"-- inputs {string_type(feeds, with_shape=True, with_device=True)}")
+            fprint(
+                f"-- expected {string_type(other_expected, with_shape=True, with_device=True)}"
+            )
+            begin = time.perf_counter()
+            gots = []
+            for feed in tqdm.tqdm(feeds):
+                gots.append(sess.run(None, feed))
+            oduration = time.perf_counter() - begin
+            fprint(
+                f"-- torch duration={sum(durations[:len(gots)])}, onnx duration={oduration}, "
+                f"speedup={sum(durations[:len(gots)])/oduration} n={len(gots)}"
+            )
+
+            info = {
+                **main_info,
+                "timestamp": datetime.datetime.now().isoformat(),
+                "export_duration": export_duration,
+                "latency_torch": sum(durations[: len(gots)]),
+                "latency_ort": oduration,
+                "speedup": sum(durations[: len(gots)]) / oduration,
+                "latency_ort_n": len(gots),
+                "opset": opset,
+                **get_versions(),
+            }
+            with open(agg_stat_file, "a") as fs:
+                for fe, e, b in zip(feeds, other_expected, gots):
+                    flat_e = flatten_object(e, drop_keys=True)
+                    se = string_type(fe, with_shape=True)
+                    diff = max_diff(flat_e, b, hist=[0.1, 0.01])
+                    assert (
+                        diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01
+                    ), (
+                        f"absolute error {diff['abs']} is above {atol} or number "
+                        f"of mismatches ({diff['rep']['>0.1'] / diff['n']}) "
+                        f"is above {mismatch01}, discrepancies={string_diff(diff)}"
+                    )
+                    js = string_diff(diff, js=True, ratio=True, inputs=se, **info)
+                    fs.write(js)
+                    fs.write("\n")
+                    fprint(f"-- inputs={se} -- {js}")
+
+    if os.path.exists(agg_stat_file):
+        print(f"-- statistics from {agg_stat_file!r}")
+        import pandas
+
+        df = pandas.read_json(agg_stat_file, lines=True)
+        first = [
+            "timestamp",
+            "model_id",
+            "pretrained",
+            "part",
+            "device",
+            "dtype",
+            "attention",
+            "opset",
+        ]
+        index = [*first[1:], "exporter"]
+        df = df[[*first, *[c for c in df.columns if c not in set(first)]]]
+        df.to_excel(agg_stat_file + ".xlsx")
+
+        values = [
+            "abs",
+            "%>0.1",
+            "%>0.01",
+            "export_duration",
+            "speedup",
+            "latency_torch",
+            "latency_ort_n",
+        ]
+        agg = {
+            **{c: "max" for c in values if c != "speedup"},
+            "speedup": "min",
+        }
+        stat = df[[*index, *values]].groupby(index, dropna=False).agg(agg)
+        stat.to_excel(agg_stat_file + ".agg.xlsx")
+        stat = (
+            df[df.exporter != "custom"][[*index, *values]]
+            .groupby(index, dropna=False)
+            .agg(agg)
+        )
+        stat.to_excel(agg_stat_file + ".agg.onnx-dynamo.xlsx")
+
+
+def zip_model_and_data_into_a_single_file(zip_file: str, model_file: str):
+    """
+    Zips an onnx model and its data into a single file.
+
+    :param zip_file: zip file to create
+    :param model_file: onnx file
+    """
+    print()
+    print(f"-- make file {zip_file!r}")
+    cmd = ["zip", "-v", "-1", zip_file]
+    for name in [model_file, f"{model_file}.data"]:
+        print(f"-- add {name!r}")
+        cmd.append(name)
+    print(f"-- cmd: {' '.join(cmd)}")
+    subprocess.run(cmd, check=True)
+    print("-- done.")
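For example (illustrative names; the helper shells out to the zip binary, which must be installed):

    from onnx_diagnostic.ci_models.ci_helpers import zip_model_and_data_into_a_single_file

    # runs: zip -v -1 model.zip model.onnx model.onnx.data
    zip_model_and_data_into_a_single_file("model.zip", "model.onnx")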