onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -3
  3. onnx_diagnostic/ci_models/ci_helpers.py +12 -7
  4. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  5. onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
  6. onnx_diagnostic/export/api.py +295 -5
  7. onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
  8. onnx_diagnostic/export/dynamic_shapes.py +45 -3
  9. onnx_diagnostic/export/shape_helper.py +1 -0
  10. onnx_diagnostic/ext_test_case.py +9 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/cache_helper.py +0 -8
  13. onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
  14. onnx_diagnostic/helpers/helper.py +30 -1
  15. onnx_diagnostic/helpers/log_helper.py +1 -3
  16. onnx_diagnostic/helpers/optim_helper.py +116 -0
  17. onnx_diagnostic/helpers/ort_session.py +5 -0
  18. onnx_diagnostic/tasks/image_text_to_text.py +19 -9
  19. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  20. onnx_diagnostic/tasks/text_generation.py +3 -0
  21. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
  22. onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
  23. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  24. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
  31. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  33. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  34. onnx_diagnostic/torch_models/validate.py +48 -0
  35. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/METADATA +3 -1
  36. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/RECORD +39 -36
  37. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/WHEEL +0 -0
  38. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.8.6"
+__version__ = "0.8.8"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -1547,6 +1547,107 @@ def _cmd_compare(argv: List[Any]):
     print(ObsComparePair.to_str(pair_cmp))
 
 
+def get_parser_optimize() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="optimize",
+        formatter_class=RawTextHelpFormatter,
+        description=textwrap.dedent(
+            """
+            Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
+            and replaces them by the corresponding nodes. It also does basic optimization
+            such as removing identity nodes or unused nodes.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            The goal is to make the model faster.
+            Argument patterns defines the patterns to apply or the set of patterns.
+            It is possible to show statistics or to remove a particular pattern.
+            Here are some environment variables which can be used to trigger
+            these displays.
+
+            Available options algorithms, default and default+runtime:
+
+            - DROPPATTERN=<pattern1,patterns2,...>: do not apply
+              those patterns when optimizing a model
+            - DUMPPATTERNS=<folder>: dumps all matched and applied
+              nodes when a pattern is applied
+            - PATTERN=<pattern1,pattern2,...>: increase verbosity for specific
+              patterns to understand why one pattern was not applied,
+              this shows which line is rejecting a pattern if it seems one pattern was missed
+            """
+        ),
+    )
+    parser.add_argument(
+        "algorithm",
+        choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"],
+        help="algorithm or patterns optimization to apply",
+    )
+    parser.add_argument("input", type=str, help="onnx model to optimize")
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=False,
+        help="onnx model to output, if empty, if adds .opt-{algorithm}.onnx to the name",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=0,
+        required=False,
+        type=int,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "--infer-shapes",
+        default=True,
+        action=BooleanOptionalAction,
+        help="infer shapes before optimizing the model",
+    )
+    parser.add_argument(
+        "--processor",
+        default="",
+        help=textwrap.dedent(
+            """
+            optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
+            some operators are only available in one processor, it might be not used
+            with all
+            """
+        ).strip("\n"),
+    )
+    parser.add_argument(
+        "--remove-shape-info",
+        default=True,
+        action=BooleanOptionalAction,
+        help="remove shape information before outputting the model",
+    )
+    return parser
+
+
+def _cmd_optimize(argv: List[Any]):
+    parser = get_parser_optimize()
+    args = parser.parse_args(argv[1:])
+
+    from .helpers.optim_helper import optimize_model
+
+    output = (
+        args.output
+        if args.output
+        else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx"
+    )
+
+    optimize_model(
+        args.algorithm,
+        args.input,
+        output=output,
+        verbose=args.verbose,
+        processor=args.processor,
+        infer_shapes=args.infer_shapes,
+        remove_shape_info=args.remove_shape_info,
+    )
+
+
 #############
 # main parser
 #############
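
The `_cmd_optimize` body above forwards everything to `optimize_model` from the new `onnx_diagnostic/helpers/optim_helper.py` (listed above with +116 lines but not shown in this diff). A minimal sketch of calling that helper directly, assuming the positional and keyword names visible at the call site; the file names and option values are illustrative only:

    from onnx_diagnostic.helpers.optim_helper import optimize_model

    # Same call shape as _cmd_optimize above: two positional arguments, then keywords.
    optimize_model(
        "default",                      # algorithm, same choices as the CLI exposes
        "model.onnx",                   # input model, illustrative path
        output="model.o-default.onnx",  # mirrors the default naming in _cmd_optimize
        verbose=1,
        processor="CPU",                # empty string by default on the CLI
        infer_shapes=True,
        remove_shape_info=True,
    )
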
onnx_diagnostic/_command_lines_parser.py (continued)
@@ -1563,16 +1664,17 @@ def get_main_parser() -> ArgumentParser:
            to get help for a specific command.
 
            agg - aggregates statistics from multiple files
-           config - prints a configuration for a model id
+           config - prints a configuration for a model id (on HuggingFace Hub)
            dot - converts an onnx model into dot format
            exportsample - produces a code to export a model
            find - find node consuming or producing a result
-           lighten - makes an onnx model lighter by removing the weights,
+           lighten - makes an onnx model lighter by removing the weights
+           optimize - optimizes an onnx model
            print - prints the model on standard output
            sbs - compares an exported program and a onnx model
            stats - produces statistics on a model
            unlighten - restores an onnx model produces by the previous experiment
-           validate - validate a model
+           validate - validate a model (knowing its model id on HuggginFace Hub)
            """
        ),
    )
@@ -1585,6 +1687,7 @@ def get_main_parser() -> ArgumentParser:
            "exportsample",
            "find",
            "lighten",
+           "optimize",
            "print",
            "sbs",
            "stats",
@@ -1605,6 +1708,7 @@ def main(argv: Optional[List[Any]] = None):
        exportsample=_cmd_export_sample,
        find=_cmd_find,
        lighten=_cmd_lighten,
+       optimize=_cmd_optimize,
        print=_cmd_print,
        sbs=_cmd_sbs,
        stats=_cmd_stats,
@@ -1631,6 +1735,7 @@ def main(argv: Optional[List[Any]] = None):
        exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
        find=get_parser_find,
        lighten=get_parser_lighten,
+       optimize=get_parser_optimize,
        print=get_parser_print,
        sbs=get_parser_sbs,
        stats=get_parser_stats,
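
With the registration above, the subcommand becomes reachable through `main()`. A hedged sketch of driving it programmatically, assuming `main` forwards the remaining arguments to `_cmd_optimize` as the command and parser tables above suggest; the model path is illustrative:

    from onnx_diagnostic._command_lines_parser import main

    main(
        [
            "optimize",           # subcommand added in this release
            "default",            # algorithm
            "model.onnx",         # input model, illustrative path
            "-v", "1",            # verbosity
            "--no-infer-shapes",  # negative form generated by BooleanOptionalAction
        ]
    )
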
onnx_diagnostic/ci_models/ci_helpers.py
@@ -2,7 +2,7 @@ import datetime
 import os
 import time
 import subprocess
-from argparse import ArgumentParser, BooleanOptionalAction
+from argparse import ArgumentParser, BooleanOptionalAction, RawTextHelpFormatter
 from typing import Any, Dict, List, Tuple
 import onnx
 
@@ -50,10 +50,13 @@ def get_torch_dtype_from_command_line_args(dtype: str) -> "torch.dtype": # noqa
     return torch_dtype[dtype]
 
 
-def get_parser(name: str) -> ArgumentParser:
+def get_parser(name: str, epilog: str = "") -> ArgumentParser:
     """Creates a default parser for many models."""
     parser = ArgumentParser(
-        prog=name, description=f"""Export command line for model {name!r}."""
+        prog=name,
+        description=f"""Export command line for model {name!r}.""",
+        epilog=epilog,
+        formatter_class=RawTextHelpFormatter,
     )
     parser.add_argument(
         "-m",
@@ -110,7 +113,7 @@ def get_parser(name: str) -> ArgumentParser:
         "-a",
         "--atol",
         type=float,
-        default=1.0,
+        default=2.0,
         help="fails if the maximum discrepancy is above that threshold",
     )
     parser.add_argument(
@@ -311,7 +314,8 @@ check_for_discrepancies_and_log_everything_into_a_json_file(
     diff = max_diff(flat_export_expected, small, hist=[0.1, 0.01])
     fprint(f"-- discrepancies={diff}")
     assert diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01, (
-        f"absolution tolerance is above {atol} or number of mismatches is above "
+        f"absolute error {diff['abs']} is above {atol} or number of "
+        f"mismatches ({diff['rep']['>0.1'] / diff['n']}) is above "
         f"{mismatch01}, dicrepancies={string_diff(diff)}"
     )
 
@@ -362,8 +366,9 @@ check_for_discrepancies_and_log_everything_into_a_json_file(
     assert (
         diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01
     ), (
-        f"absolution tolerance is above {atol} or number of mismatches is "
-        f"above {mismatch01}, dicrepancies={string_diff(diff)}"
+        f"absolute error {diff['abs']} is above {atol} or number "
+        f" of mismatches ({diff['rep']['>0.1'] / diff['n']}) "
+        f"is above {mismatch01}, dicrepancies={string_diff(diff)}"
     )
     js = string_diff(diff, js=True, ratio=True, inputs=se, **info)
     fs.write(js)
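
Both assertions above implement the same acceptance rule: the absolute error must stay below `atol` and the fraction of elements off by more than 0.1 must stay below `mismatch01`. A small sketch with toy numbers, assuming `max_diff(..., hist=[0.1, 0.01])` returns a dictionary shaped like the keys used above; the values are illustrative only:

    atol, mismatch01 = 2.0, 0.05  # 2.0 is the new CLI default for --atol
    diff = {"abs": 0.3, "n": 1000, "rep": {">0.1": 12, ">0.01": 80}}  # toy values

    ok = diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01
    print(ok)  # True: 0.3 <= 2.0 and 12 / 1000 = 0.012 <= 0.05
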