onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +154 -3
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +435 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
- onnx_diagnostic/export/api.py +1 -0
- onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/ext_test_case.py +23 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/tasks/image_text_to_text.py +15 -5
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/_command_lines_parser.py
CHANGED
@@ -1508,6 +1508,151 @@ def _cmd_sbs(argv: List[Any]):
     print("-- done")
 
 
+def get_parser_compare() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="compare",
+        description=textwrap.dedent(
+            """
+            Compares two onnx models by aligning the nodes between both models.
+            This is done through an edit distance.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            Each element (initializer, input, node, output) of the model
+            is converted into an observation. Then it defines a distance between
+            two elements. And finally, it finds the best alignment with
+            an edit distance.
+            """
+        ),
+    )
+    parser.add_argument("model1", type=str, help="first model to compare")
+    parser.add_argument("model2", type=str, help="second model to compare")
+    return parser
+
+
+def _cmd_compare(argv: List[Any]):
+    import onnx
+    from .torch_onnx.compare import ObsCompare, ObsComparePair
+
+    parser = get_parser_compare()
+    args = parser.parse_args(argv[1:])
+    print(f"-- loading {args.model1!r}")
+    seq1 = ObsCompare.obs_sequence_from_model(onnx.load(args.model1, load_external_data=False))
+    print(f"-- loading {args.model2!r}")
+    seq2 = ObsCompare.obs_sequence_from_model(onnx.load(args.model2, load_external_data=False))
+    print("-- starts comparison")
+    dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
+    print(f"-- done with distance {dist}")
+    print(ObsComparePair.to_str(pair_cmp))
+
+
+def get_parser_optimize() -> ArgumentParser:
+    parser = ArgumentParser(
+        prog="optimize",
+        formatter_class=RawTextHelpFormatter,
+        description=textwrap.dedent(
+            """
+            Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
+            and replaces them by the corresponding nodes. It also does basic optimization
+            such as removing identity nodes or unused nodes.
+            """
+        ),
+        epilog=textwrap.dedent(
+            """
+            The goal is to make the model faster.
+            Argument patterns defines the patterns to apply or the set of patterns.
+            It is possible to show statistics or to remove a particular pattern.
+            Here are some environment variables which can be used to trigger
+            these displays.
+
+            Available options for algorithms default and default+runtime:
+
+            - DROPPATTERN=<pattern1,pattern2,...>: do not apply
+              those patterns when optimizing a model
+            - DUMPPATTERNS=<folder>: dumps all matched and applied
+              nodes when a pattern is applied
+            - PATTERN=<pattern1,pattern2,...>: increase verbosity for specific
+              patterns to understand why one pattern was not applied,
+              this shows which line is rejecting a pattern if it seems one pattern was missed
+            """
+        ),
+    )
+    parser.add_argument(
+        "algorithm",
+        choices=["ir", "os_ort", "slim", "default", "default+onnxruntime"],
+        help="algorithm or patterns optimization to apply",
+    )
+    parser.add_argument("input", type=str, help="onnx model to optimize")
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        required=False,
+        help="onnx model to output, if empty, it adds .opt-{algorithm}.onnx to the name",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=0,
+        required=False,
+        type=int,
+        help="verbosity",
+    )
+    parser.add_argument(
+        "--infer-shapes",
+        default=True,
+        action=BooleanOptionalAction,
+        help="infer shapes before optimizing the model",
+    )
+    parser.add_argument(
+        "--processor",
+        default="",
+        help=textwrap.dedent(
+            """
+            optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
+            some operators are only available on one processor, it might not be used
+            with all
+            """
+        ).strip("\n"),
+    )
+    parser.add_argument(
+        "--remove-shape-info",
+        default=True,
+        action=BooleanOptionalAction,
+        help="remove shape information before outputting the model",
+    )
+    return parser
+
+
+def _cmd_optimize(argv: List[Any]):
+    parser = get_parser_optimize()
+    args = parser.parse_args(argv[1:])
+
+    from .helpers.optim_helper import optimize_model
+
+    output = (
+        args.output
+        if args.output
+        else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx"
+    )
+
+    optimize_model(
+        args.algorithm,
+        args.input,
+        output=output,
+        verbose=args.verbose,
+        processor=args.processor,
+        infer_shapes=args.infer_shapes,
+        remove_shape_info=args.remove_shape_info,
+    )
+
+
+#############
+# main parser
+#############
+
+
 def get_main_parser() -> ArgumentParser:
     parser = ArgumentParser(
         prog="onnx_diagnostic",
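
The compare epilog above only sketches the method in prose: turn every initializer, input, node, and output into an observation, define a pairwise distance, then align the two observation sequences with an edit distance. For intuition, here is a minimal sketch of such an alignment with unit costs over plain operator names; it illustrates the technique only and is not the actual ObsCompare/ObsComparePair implementation from onnx_diagnostic.torch_onnx.compare.

# Illustrative edit-distance alignment over two "observation" sequences.
from typing import List, Tuple


def align(seq1: List[str], seq2: List[str]) -> Tuple[float, List[Tuple[int, int]]]:
    """Classic edit-distance DP; returns the distance and the aligned index pairs."""
    n, m = len(seq1), len(seq2)
    # dist[i][j] = cost of aligning seq1[:i] with seq2[:j]
    dist = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dist[i][0] = float(i)
    for j in range(1, m + 1):
        dist[0][j] = float(j)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            sub = 0.0 if seq1[i - 1] == seq2[j - 1] else 1.0
            dist[i][j] = min(
                dist[i - 1][j] + 1,        # deletion
                dist[i][j - 1] + 1,        # insertion
                dist[i - 1][j - 1] + sub,  # match or substitution
            )
    # walk back through the table to recover the alignment path
    path, i, j = [], n, m
    while i > 0 and j > 0:
        sub = 0.0 if seq1[i - 1] == seq2[j - 1] else 1.0
        if dist[i][j] == dist[i - 1][j - 1] + sub:
            path.append((i - 1, j - 1))
            i, j = i - 1, j - 1
        elif dist[i][j] == dist[i - 1][j] + 1:
            i -= 1
        else:
            j -= 1
    return dist[n][m], path[::-1]


print(align(["MatMul", "Add", "Relu"], ["MatMul", "Relu"]))
# (1.0, [(0, 0), (2, 1)])

On ["MatMul", "Add", "Relu"] vs ["MatMul", "Relu"], the sketch returns a distance of 1.0 and the alignment [(0, 0), (2, 1)]: "Add" has no counterpart in the second model.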
@@ -1519,16 +1664,17 @@ def get_main_parser() -> ArgumentParser:
             to get help for a specific command.
 
             agg - aggregates statistics from multiple files
-            config - prints a configuration for a model id
+            config - prints a configuration for a model id (on HuggingFace Hub)
             dot - converts an onnx model into dot format
             exportsample - produces code to export a model
             find - find node consuming or producing a result
-            lighten - makes an onnx model lighter by removing the weights
+            lighten - makes an onnx model lighter by removing the weights
+            optimize - optimizes an onnx model
             print - prints the model on standard output
             sbs - compares an exported program and an onnx model
             stats - produces statistics on a model
             unlighten - restores an onnx model produced by the previous experiment
-            validate - validate a model
+            validate - validate a model (knowing its model id on HuggingFace Hub)
             """
         ),
     )
@@ -1541,6 +1687,7 @@ def get_main_parser() -> ArgumentParser:
             "exportsample",
             "find",
             "lighten",
+            "optimize",
             "print",
             "sbs",
             "stats",
@@ -1555,11 +1702,13 @@ def get_main_parser() -> ArgumentParser:
 def main(argv: Optional[List[Any]] = None):
     fcts = dict(
         agg=_cmd_agg,
+        compare=_cmd_compare,
         config=_cmd_config,
         dot=_cmd_dot,
         exportsample=_cmd_export_sample,
         find=_cmd_find,
         lighten=_cmd_lighten,
+        optimize=_cmd_optimize,
         print=_cmd_print,
         sbs=_cmd_sbs,
         stats=_cmd_stats,
@@ -1580,11 +1729,13 @@ def main(argv: Optional[List[Any]] = None):
     else:
         parsers = dict(
             agg=get_parser_agg,
+            compare=get_parser_compare,
             config=get_parser_config,
             dot=get_parser_dot,
             exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
             find=get_parser_find,
             lighten=get_parser_lighten,
+            optimize=get_parser_optimize,
             print=get_parser_print,
             sbs=get_parser_sbs,
             stats=get_parser_stats,
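
Both additions are reachable through main, which dispatches on the first argument and hands argv[1:] to the matching _cmd_* function. A plausible way to drive them from Python; the model file names below are placeholders:

from onnx_diagnostic._command_lines_parser import main

# align the elements of two models and print the paired observations
main(["compare", "model_a.onnx", "model_b.onnx"])

# optimize with the "default" pattern set; without -o the result is
# written next to the input as model_a.o-default.onnx
main(["optimize", "default", "model_a.onnx", "-v", "1"])

DROPPATTERN, DUMPPATTERNS, and PATTERN from the optimize epilog are environment variables, so they would be set (for instance via os.environ) before the call.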
onnx_diagnostic/ci_models/__init__.py
ADDED
File without changes

onnx_diagnostic/ci_models/ci_helpers.py
ADDED
@@ -0,0 +1,435 @@
+import datetime
+import os
+import time
+import subprocess
+from argparse import ArgumentParser, BooleanOptionalAction, RawTextHelpFormatter
+from typing import Any, Dict, List, Tuple
+import onnx
+
+
+def get_versions():
+    """
+    Returns the versions of the packages currently used.
+    The output is a dictionary.
+    The function uses delayed imports to keep startup fast.
+    """
+    import onnx
+    import onnx_diagnostic
+    import onnxruntime
+    import torch
+    import transformers
+
+    return {
+        "transformers": transformers.__version__,
+        "onnxruntime": onnxruntime.__version__,
+        "onnx": onnx.__version__,
+        "onnx-diagnostic": onnx_diagnostic.__version__,
+        "torch": torch.__version__,
+    }
+
+
+def get_torch_dtype_from_command_line_args(dtype: str) -> "torch.dtype":  # noqa: F821
+    """
+    Returns the torch dtype based on the argument provided on the command line.
+
+    Imports are delayed to be faster when running the help of the command line.
+    """
+    import torch
+
+    torch_dtype = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "fp32": torch.float32,
+    }
+    assert (
+        dtype in torch_dtype
+    ), f"Unexpected dtype {dtype!r}, not found in {set(torch_dtype)}."
+    return torch_dtype[dtype]
+
+
+def get_parser(name: str, epilog: str = "") -> ArgumentParser:
+    """Creates a default parser for many models."""
+    parser = ArgumentParser(
+        prog=name,
+        description=f"""Export command line for model {name!r}.""",
+        epilog=epilog,
+        formatter_class=RawTextHelpFormatter,
+    )
+    parser.add_argument(
+        "-m",
+        "--mid",
+        type=str,
+        default="Qwen/Qwen2.5-VL-7B-Instruct",
+        help="model id, default is Qwen/Qwen2.5-VL-7B-Instruct",
+    )
+    parser.add_argument("-d", "--device", default="cpu", help="Device, cpu (default) or cuda.")
+    parser.add_argument(
+        "-t", "--dtype", default="float32", help="dtype, float32 (default) or float16"
+    )
+    parser.add_argument(
+        "-e", "--exporter", default="onnx-dynamo", help="exporter, default is onnx-dynamo"
+    )
+    parser.add_argument(
+        "--pretrained",
+        default=True,
+        help="use a pretrained model or a random model",
+        action=BooleanOptionalAction,
+    )
+    parser.add_argument(
+        "--second-input",
+        default=True,
+        help="check discrepancies with other inputs",
+        action=BooleanOptionalAction,
+    )
+    parser.add_argument(
+        "--zip",
+        default=False,
+        help="Creates a .zip file with the onnx file and the data file.",
+        action=BooleanOptionalAction,
+    )
+    parser.add_argument(
+        "-o",
+        "--output-folder",
+        default="dump_models",
+        help="Folder where to put the results.",
+        action=BooleanOptionalAction,
+    )
+    parser.add_argument(
+        "-x",
+        "--existing-onnx",
+        default="",
+        help="If an onnx file exists, only measures the discrepancies.",
+    )
+    parser.add_argument(
+        "-p",
+        "--part",
+        default="visual",
+        help="part of the model to export",
+    )
+    parser.add_argument(
+        "-a",
+        "--atol",
+        type=float,
+        default=2.0,
+        help="fails if the maximum discrepancy is above that threshold",
+    )
+    parser.add_argument(
+        "--mismatch01",
+        type=float,
+        default=0.1,
+        help="fails if the ratio of mismatches at level 0.1 is above that threshold",
+    )
+    parser.add_argument(
+        "--profile-exporter",
+        default=False,
+        help="Profiles the exporter and outputs an html document from pyinstrument",
+        action=BooleanOptionalAction,
+    )
+    return parser
+
+
+def remove_inplace_body_last_input_output_type_for_loop_because_they_might_be_sequences(
+    filename: str,
+):
+    """
+    Modifies an onnx file in place. It wipes out the shapes provided
+    in ``model.graph.value_info`` because they are wrong when a Loop outputs
+    a sequence. It also removes the types in attribute 'body'
+    of a Loop operator because it may be a tensor when a sequence is expected.
+    This should not be needed in the future.
+    """
+    model = onnx.load(filename, load_external_data=False)
+    for node in model.graph.node:
+        if node.op_type == "Loop":
+            g = node.attribute[0].g
+            g.input[-1].type.CopyFrom(onnx.TypeProto())
+            g.output[-1].type.CopyFrom(onnx.TypeProto())
+    del model.graph.value_info[:]
+    model = onnx.shape_inference.infer_shapes(model)
+    onnx.save(model, filename, save_as_external_data=False)
+
+
+def simplify_model_id_for_a_filename(model_id: str) -> str:
+    """Changes a model id in a way it can be used in a filename."""
+    return model_id.lower().replace("/", ".")
+
+
+def compute_expected_outputs(
+    output_filename: str, model_to_export: "torch.nn.Module", input_filename: str  # noqa: F821
+) -> Tuple[Any, List[Any], List[float]]:
+    """
+    Computes the expected outputs for a model.
+
+    It caches the expected outputs in a file. They are restored if the file exists
+    or computed and saved if not.
+
+    Imports are delayed to be faster when running the help of the command line.
+    """
+    import tqdm
+    import torch
+    from ..helpers import string_type
+
+    inputs = torch.load(input_filename, weights_only=False)
+    export_inputs = inputs["export_inputs"]
+    other_inputs = inputs["other_inputs"]
+
+    if os.path.exists(output_filename):
+        print(f"-- restore expected outputs from {output_filename!r}")
+        expected = torch.load(output_filename, weights_only=False)
+        export_expected = expected["export_expected"]
+        other_expected = expected["other_expected"]
+        durations = expected["durations"]
+    else:
+        print(
+            f"-- compute with inputs: "
+            f"{string_type(export_inputs, with_shape=True, with_device=True)}"
+        )
+        export_expected = model_to_export(**export_inputs)
+        print(f"-- got: {string_type(export_expected, with_shape=True)}")
+        print(
+            f"-- compute with inputs: "
+            f"{string_type(other_inputs, with_shape=True, with_device=True)}"
+        )
+        other_expected = []
+        durations = []
+        for other in tqdm.tqdm(other_inputs):
+            begin = time.perf_counter()
+            expected = model_to_export(**other)
+            other_expected.append(expected)
+            durations.append(time.perf_counter() - begin)
+        print(f"-- got: {string_type(other_expected, with_shape=True, with_device=True)}")
+
+        expected = dict(
+            export_expected=export_expected,
+            other_expected=other_expected,
+            durations=durations,
+        )
+        print(f"-- dump expected outputs into {output_filename!r}")
+        torch.save(expected, output_filename)
+    print(f"-- computation took {sum(durations)}")
+    print(
+        f"-- export_expected={string_type(export_expected, with_shape=True, with_device=True)}"
+    )
+    print(
+        f"-- other_expected={string_type(other_expected, with_shape=True, with_device=True)}"
+    )
+    return export_expected, other_expected, durations
+
+
+def check_for_discrepancies_and_log_everything_into_a_json_file(
+    agg_stat_file: str,
+    stat_file: str,
+    export_duration: float,
+    device: str,
+    model_file: str,
+    cached_inputs: str,
+    cached_expected_outputs: str,
+    main_info: Dict[str, Any],
+    atol: float,
+    mismatch01: float,
+):
+    """
+    Checks discrepancies for a specific model.
+
+    Imports are delayed to be faster when running the help of the command line.
+
+    :param agg_stat_file: a file where the discrepancies are collected, this is used to
+        produce a table to make it easier to compare across types, devices, ...
+    :param stat_file: discrepancy results are dumped into that file
+    :param export_duration: export duration
+    :param device: targeted device (to select the onnxruntime provider)
+    :param model_file: onnx model file
+    :param cached_inputs: inputs saved with :func:`torch.save` and
+        restored with :func:`torch.load`,
+        needs to contain `export_inputs` (to check the model is valid)
+        and `other_inputs`, other sets of inputs to measure the discrepancies
+        and the speedup (rough estimation)
+    :param cached_expected_outputs: expected outputs saved with :func:`torch.save`
+        and restored with :func:`torch.load`,
+        needs to contain `export_expected` (to check the model is valid)
+        and `other_expected`, other sets of outputs to measure the discrepancies
+        and the speedup (rough estimation)
+    :param main_info: a dictionary with values used to tell which version, device, ...
+    :param atol: assert if the tolerance is above this
+    :param mismatch01: assert if the ratio of mismatches is above that threshold
+    """
+    import tqdm
+    import onnxruntime
+    import torch
+    from ..helpers import flatten_object, max_diff, string_type, string_diff
+
+    cached = (
+        torch.load(cached_inputs, weights_only=False),
+        torch.load(cached_expected_outputs, weights_only=False),
+    )
+    durations = cached[0].get("durations", [])
+    export_inputs = cached[0]["export_inputs"]
+    other_inputs = cached[0]["other_inputs"]
+    export_expected = cached[1]["export_expected"]
+    other_expected = cached[1]["other_expected"]
+
+    onx = onnx.load(model_file, load_external_data=False)
+    opsets = [d for d in onx.opset_import if d.domain == ""]
+    assert (
+        opsets
+    ), f"Unable to find standard opset in file {model_file!r}, opsets={onx.opset_import}"
+    opset = opsets[0].version
+
+    with open(stat_file, "w") as f:
+
+        def fprint(s):
+            print(s)
+            f.write(f"{s}\n")
+
+        fprint(f"-- export duration: {export_duration}")
+        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        if device == "cpu":
+            providers = providers[1:]
+        fprint(f"-- checking discrepancies with providers={providers!r}")
+        fprint(f"-- model_file={model_file!r}")
+        sess = onnxruntime.InferenceSession(model_file, providers=providers)
+
+        fprint(
+            f"-- export_inputs {string_type(export_inputs, with_shape=True, with_device=True)}"
+        )
+        fprint(
+            f"-- export_expected "
+            f"{string_type(export_expected, with_shape=True, with_device=True)}"
+        )
+        feeds = dict(
+            zip(
+                [i.name for i in sess.get_inputs()],
+                [
+                    v.detach().cpu().numpy()
+                    for v in flatten_object(export_inputs, drop_keys=True)
+                ],
+            )
+        )
+        small = sess.run(None, feeds)
+        flat_export_expected = flatten_object(export_expected, drop_keys=True)
+        diff = max_diff(flat_export_expected, small, hist=[0.1, 0.01])
+        fprint(f"-- discrepancies={diff}")
+        assert diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01, (
+            f"absolute error {diff['abs']} is above {atol} or number of "
+            f"mismatches ({diff['rep']['>0.1'] / diff['n']}) is above "
+            f"{mismatch01}, discrepancies={string_diff(diff)}"
+        )
+
+        if other_inputs and other_expected:
+            feeds = [
+                dict(
+                    zip(
+                        [i.name for i in sess.get_inputs()],
+                        [
+                            v.detach().cpu().numpy()
+                            for v in flatten_object(inputs, drop_keys=True)
+                        ],
+                    )
+                )
+                for inputs in other_inputs
+            ]
+            fprint("")
+            fprint(f"-- inputs {string_type(feeds, with_shape=True, with_device=True)}")
+            fprint(
+                f"-- expected {string_type(other_expected, with_shape=True, with_device=True)}"
+            )
+            begin = time.perf_counter()
+            gots = []
+            for feed in tqdm.tqdm(feeds):
+                gots.append(sess.run(None, feed))
+            oduration = time.perf_counter() - begin
+            fprint(
+                f"-- torch duration={sum(durations[:len(gots)])}, onnx duration={oduration}, "
+                f"speedup={sum(durations[:len(gots)])/oduration} n={len(gots)}"
+            )
+
+            info = {
+                **main_info,
+                "timestamp": datetime.datetime.now().isoformat(),
+                "export_duration": export_duration,
+                "latency_torch": sum(durations[: len(gots)]),
+                "latency_ort": oduration,
+                "speedup": sum(durations[: len(gots)]) / oduration,
+                "latency_ort_n": len(gots),
+                "opset": opset,
+                **get_versions(),
+            }
+            with open(agg_stat_file, "a") as fs:
+                for fe, e, b in zip(feeds, other_expected, gots):
+                    flat_e = flatten_object(e, drop_keys=True)
+                    se = string_type(fe, with_shape=True)
+                    diff = max_diff(flat_e, b, hist=[0.1, 0.01])
+                    assert (
+                        diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01
+                    ), (
+                        f"absolute error {diff['abs']} is above {atol} or the number "
+                        f"of mismatches ({diff['rep']['>0.1'] / diff['n']}) "
+                        f"is above {mismatch01}, discrepancies={string_diff(diff)}"
+                    )
+                    js = string_diff(diff, js=True, ratio=True, inputs=se, **info)
+                    fs.write(js)
+                    fs.write("\n")
+                    fprint(f"-- inputs={se} -- {js}")
+
+    if os.path.exists(agg_stat_file):
+        print(f"-- statistics from {agg_stat_file!r}")
+        import pandas
+
+        df = pandas.read_json(agg_stat_file, lines=True)
+        first = [
+            "timestamp",
+            "model_id",
+            "pretrained",
+            "part",
+            "device",
+            "dtype",
+            "attention",
+            "opset",
+        ]
+        index = [*first[1:], "exporter"]
+        df = df[[*first, *[c for c in df.columns if c not in set(first)]]]
+        df.to_excel(agg_stat_file + ".xlsx")
+
+        values = [
+            "abs",
+            "%>0.1",
+            "%>0.01",
+            "export_duration",
+            "speedup",
+            "latency_torch",
+            "latency_ort_n",
+        ]
+        agg = {
+            **{c: "max" for c in values if c != "speedup"},
+            "speedup": "min",
+        }
+        stat = df[[*index, *values]].groupby(index, dropna=False).agg(agg)
+        stat.to_excel(agg_stat_file + ".agg.xlsx")
+        stat = (
+            df[df.exporter != "custom"][[*index, *values]]
+            .groupby(index, dropna=False)
+            .agg(agg)
+        )
+        stat.to_excel(agg_stat_file + ".agg.onnx-dynamo.xlsx")
+
+
+def zip_model_and_data_into_a_single_file(zip_file: str, model_file: str):
+    """
+    Zips an onnx model and its data into a single file.
+
+    :param zip_file: zip file to create
+    :param model_file: onnx file
+    """
+    print()
+    print(f"-- make file {zip_file!r}")
+    cmd = ["zip", "-v", "-1", zip_file]
+    for name in [model_file, f"{model_file}.data"]:
+        print(f"-- add {name!r}")
+        cmd.append(name)
+    print(f"-- cmd: {' '.join(cmd)}")
+    subprocess.run(cmd, check=True)
+    print("-- done.")