onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +108 -3
- onnx_diagnostic/ci_models/ci_helpers.py +12 -7
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
- onnx_diagnostic/export/api.py +1 -0
- onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
- onnx_diagnostic/ext_test_case.py +9 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/tasks/image_text_to_text.py +15 -5
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +29 -26
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
onnx_diagnostic/_command_lines_parser.py
CHANGED
@@ -1547,6 +1547,107 @@ def _cmd_compare(argv: List[Any]):
     print(ObsComparePair.to_str(pair_cmp))


+def get_parser_optimize() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="optimize",
+        formatter_class=RawTextHelpFormatter,
+        description=textwrap.dedent(
+            """
+            Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
+            and replaces them by the corresponding nodes. It also does basic optimization
+            such as removing identity nodes or unused nodes.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            The goal is to make the model faster.
+            Argument patterns defines the patterns to apply or the set of patterns.
+            It is possible to show statistics or to remove a particular pattern.
+            Here are some environment variables which can be used to trigger
+            these displays.
+
+            Available options algorithms, default and default+runtime:
+
+            - DROPPATTERN=<pattern1,patterns2,...>: do not apply
+              those patterns when optimizing a model
+            - DUMPPATTERNS=<folder>: dumps all matched and applied
+              nodes when a pattern is applied
+            - PATTERN=<pattern1,pattern2,...>: increase verbosity for specific
+              patterns to understand why one pattern was not applied,
+              this shows which line is rejecting a pattern if it seems one pattern was missed
+            """
+        ),
+    )
+    parser.add_argument(
+        "algorithm",
+        choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"],
+        help="algorithm or patterns optimization to apply",
+    )
+    parser.add_argument("input", type=str, help="onnx model to optimize")
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=False,
+        help="onnx model to output, if empty, if adds .opt-{algorithm}.onnx to the name",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=0,
+        required=False,
+        type=int,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "--infer-shapes",
+        default=True,
+        action=BooleanOptionalAction,
+        help="infer shapes before optimizing the model",
+    )
+    parser.add_argument(
+        "--processor",
+        default="",
+        help=textwrap.dedent(
+            """
+            optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
+            some operators are only available in one processor, it might be not used
+            with all
+            """
+        ).strip("\n"),
+    )
+    parser.add_argument(
+        "--remove-shape-info",
+        default=True,
+        action=BooleanOptionalAction,
+        help="remove shape information before outputting the model",
+    )
+    return parser
+
+
+def _cmd_optimize(argv: List[Any]):
+    parser = get_parser_optimize()
+    args = parser.parse_args(argv[1:])
+
+    from .helpers.optim_helper import optimize_model
+
+    output = (
+        args.output
+        if args.output
+        else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx"
+    )
+
+    optimize_model(
+        args.algorithm,
+        args.input,
+        output=output,
+        verbose=args.verbose,
+        processor=args.processor,
+        infer_shapes=args.infer_shapes,
+        remove_shape_info=args.remove_shape_info,
+    )
+
+
 #############
 # main parser
 #############
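Taken together, this hunk wires up a complete subcommand: get_parser_optimize builds the argument parser and _cmd_optimize forwards the parsed options to optimize_model from the new helpers/optim_helper.py. Below is a minimal sketch of driving the new parser from Python; it assumes onnx-diagnostic 0.8.7 is installed, uses "model.onnx" as a placeholder path, and the commented-out main call mirrors the equivalent command-line invocation (main is shown further below).

# Sketch only: exercise the new "optimize" subcommand through its own parser.
# "model.onnx" is a placeholder; point it at a real model before calling main().
from onnx_diagnostic._command_lines_parser import get_parser_optimize, main

args = get_parser_optimize().parse_args(
    ["default", "model.onnx", "-v", "1", "--no-infer-shapes"]
)
print(args.algorithm, args.input, args.verbose, args.infer_shapes)

# Equivalent to the command line `optimize default model.onnx -v 1`; without -o,
# _cmd_optimize writes the result next to the input as "model.o-default.onnx".
# main(["optimize", "default", "model.onnx", "-v", "1"])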
@@ -1563,16 +1664,17 @@ def get_main_parser() -> ArgumentParser:
             to get help for a specific command.

             agg          - aggregates statistics from multiple files
-            config       - prints a configuration for a model id
+            config       - prints a configuration for a model id (on HuggingFace Hub)
             dot          - converts an onnx model into dot format
             exportsample - produces a code to export a model
             find         - find node consuming or producing a result
-            lighten      - makes an onnx model lighter by removing the weights
+            lighten      - makes an onnx model lighter by removing the weights
+            optimize     - optimizes an onnx model
             print        - prints the model on standard output
             sbs          - compares an exported program and a onnx model
             stats        - produces statistics on a model
             unlighten    - restores an onnx model produces by the previous experiment
-            validate     - validate a model
+            validate     - validate a model (knowing its model id on HuggginFace Hub)
             """
         ),
     )
@@ -1585,6 +1687,7 @@ def get_main_parser() -> ArgumentParser:
             "exportsample",
             "find",
             "lighten",
+            "optimize",
             "print",
             "sbs",
             "stats",
@@ -1605,6 +1708,7 @@ def main(argv: Optional[List[Any]] = None):
         exportsample=_cmd_export_sample,
         find=_cmd_find,
         lighten=_cmd_lighten,
+        optimize=_cmd_optimize,
         print=_cmd_print,
         sbs=_cmd_sbs,
         stats=_cmd_stats,
@@ -1631,6 +1735,7 @@ def main(argv: Optional[List[Any]] = None):
         exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
         find=get_parser_find,
         lighten=get_parser_lighten,
+        optimize=get_parser_optimize,
         print=get_parser_print,
         sbs=get_parser_sbs,
         stats=get_parser_stats,
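The remaining hunks register the command in four places: the help text listing the available commands, the list of accepted command names, the mapping from command name to its _cmd_* handler, and the mapping from command name to its get_parser_* factory. A self-contained sketch of that dispatch pattern follows; every identifier in it is hypothetical and only illustrates how main routes argv to a handler that parses argv[1:].

# Hypothetical, standalone illustration of the name -> handler / name -> parser
# dispatch used by the command-line module.
from argparse import ArgumentParser
from typing import Any, Callable, Dict, List


def get_parser_hello() -> ArgumentParser:
    parser = ArgumentParser(prog="hello")
    parser.add_argument("name")
    return parser


def _cmd_hello(argv: List[Any]) -> None:
    args = get_parser_hello().parse_args(argv[1:])
    print(f"hello {args.name}")


COMMANDS: Dict[str, Callable[[List[Any]], None]] = {"hello": _cmd_hello}
PARSERS: Dict[str, Callable[[], ArgumentParser]] = {"hello": get_parser_hello}


def main(argv: List[Any]) -> None:
    # argv[0] selects the handler, the full argv is forwarded untouched.
    COMMANDS[argv[0]](argv)


main(["hello", "world"])  # prints "hello world"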
onnx_diagnostic/ci_models/ci_helpers.py
CHANGED
@@ -2,7 +2,7 @@ import datetime
 import os
 import time
 import subprocess
-from argparse import ArgumentParser, BooleanOptionalAction
+from argparse import ArgumentParser, BooleanOptionalAction, RawTextHelpFormatter
 from typing import Any, Dict, List, Tuple
 import onnx

@@ -50,10 +50,13 @@ def get_torch_dtype_from_command_line_args(dtype: str) -> "torch.dtype": # noqa
     return torch_dtype[dtype]


-def get_parser(name: str) -> ArgumentParser:
+def get_parser(name: str, epilog: str = "") -> ArgumentParser:
     """Creates a default parser for many models."""
     parser = ArgumentParser(
-        prog=name,
+        prog=name,
+        description=f"""Export command line for model {name!r}.""",
+        epilog=epilog,
+        formatter_class=RawTextHelpFormatter,
     )
     parser.add_argument(
         "-m",
@@ -110,7 +113,7 @@ def get_parser(name: str) -> ArgumentParser:
         "-a",
         "--atol",
         type=float,
-        default=
+        default=2.0,
         help="fails if the maximum discrepancy is above that threshold",
     )
     parser.add_argument(
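With this change, every export script built on ci_helpers can pass its own epilog, which RawTextHelpFormatter renders with line breaks preserved, and -a/--atol now defaults to 2.0. A short sketch, assuming onnx-diagnostic 0.8.7 is installed; the prog name and epilog text are only illustrative.

# Sketch: build the shared export parser with a custom epilog.
from onnx_diagnostic.ci_models.ci_helpers import get_parser

parser = get_parser(
    "export_qwen25_vl",
    epilog="Notes:\n  line breaks in this epilog are kept as written.",
)
parser.print_help()
print(parser.get_default("atol"))  # 2.0 in 0.8.7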
@@ -311,7 +314,8 @@ check_for_discrepancies_and_log_everything_into_a_json_file(
     diff = max_diff(flat_export_expected, small, hist=[0.1, 0.01])
     fprint(f"-- discrepancies={diff}")
     assert diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01, (
-        f"
+        f"absolute error {diff['abs']} is above {atol} or number of "
+        f"mismatches ({diff['rep']['>0.1'] / diff['n']}) is above "
         f"{mismatch01}, dicrepancies={string_diff(diff)}"
     )

@@ -362,8 +366,9 @@ check_for_discrepancies_and_log_everything_into_a_json_file(
     assert (
         diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01
     ), (
-        f"
-        f"
+        f"absolute error {diff['abs']} is above {atol} or number "
+        f" of mismatches ({diff['rep']['>0.1'] / diff['n']}) "
+        f"is above {mismatch01}, dicrepancies={string_diff(diff)}"
     )
     js = string_diff(diff, js=True, ratio=True, inputs=se, **info)
     fs.write(js)
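The rewritten assertion messages spell out what is being checked: the maximum absolute error must stay below atol, and the fraction of values differing by more than 0.1 (the ">0.1" bucket of the histogram requested through hist=[0.1, 0.01]) must stay below mismatch01. The sketch below reproduces that check in isolation; the layout of the diff dictionary is inferred from this file and is not the library's documented structure.

# Standalone sketch of the tolerance check described by the new assertion messages.
from typing import Dict


def check_discrepancies(diff: Dict, atol: float = 2.0, mismatch01: float = 0.1) -> None:
    ratio = diff["rep"][">0.1"] / diff["n"]
    assert diff["abs"] <= atol and ratio <= mismatch01, (
        f"absolute error {diff['abs']} is above {atol} or number of "
        f"mismatches ({ratio}) is above {mismatch01}"
    )


# Passes: max absolute error 0.5 <= 2.0 and only 3/1000 values differ by more than 0.1.
check_discrepancies({"abs": 0.5, "n": 1000, "rep": {">0.1": 3}})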