onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1141 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import sys
|
|
6
|
+
import textwrap
|
|
7
|
+
import onnx
|
|
8
|
+
from typing import Any, Dict, List, Optional, Union
|
|
9
|
+
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_parser_lighten() -> ArgumentParser:
    """Builds the argument parser for command ``lighten``.

    The command strips the weights out of a heavy onnx model and keeps
    the statistics needed to later restore random weights.

    :return: configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog="lighten",
        description=textwrap.dedent(
            """
            Removes the weights from a heavy model, stores statistics to restore
            random weights.
            """
        ),
        epilog="This is mostly used to write unit tests without adding "
        "a big onnx file to the repository.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="onnx model to lighten",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="onnx model to output",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        # bug fix: without type=int, "-v 0" produced the string "0",
        # which is truthy and silently enabled verbosity;
        # get_parser_find already declares type=int for the same option.
        type=int,
        required=False,
        help="verbosity",
    )
    return parser
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _cmd_lighten(argv: List[Any]):
    """Command ``lighten``: removes initializer weights from an onnx model.

    Loads ``args.input``, replaces the weights by statistics
    (delegated to ``onnx_lighten``), writes the lightened model to
    ``args.output`` and the statistics next to it in ``<output>.stats`` (json).

    :param argv: full command line, ``argv[0]`` is the command name
    """
    from .helpers.onnx_helper import onnx_lighten

    parser = get_parser_lighten()
    args = parser.parse_args(argv[1:])
    onx = onnx.load(args.input)
    new_onx, stats = onnx_lighten(onx, verbose=args.verbose)
    jstats = json.dumps(stats)
    if args.verbose:
        # bug fix: the f-prefix was missing, the message printed the
        # literal text "{args.input!r}" instead of the file name
        print(f"save file {args.input!r}")
    if args.verbose:
        # bug fix: same missing f-prefix as above
        print(f"write file {args.output!r}")
    with open(args.output, "wb") as f:
        f.write(new_onx.SerializeToString())
    name = f"{args.output}.stats"
    with open(name, "w") as f:
        f.write(jstats)
    if args.verbose:
        print("done")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_parser_unlighten() -> ArgumentParser:
    """Builds the argument parser for command ``unlighten``.

    The command restores random weights in a model previously processed
    by command ``lighten``; it expects a ``.stats`` file next to the input.

    :return: configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog="unlighten",
        description=textwrap.dedent(
            """
            Restores random weights for a model reduces with command lighten,
            the command expects to find a file nearby with extension '.stats'.
            """
        ),
        epilog="This is mostly used to write unit tests without adding "
        "a big onnx file to the repository.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="onnx model to unlighten",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="onnx model to output",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        # bug fix: without type=int, "-v 0" produced the truthy string "0";
        # consistent with get_parser_find which already uses type=int.
        type=int,
        required=False,
        help="verbosity",
    )
    return parser
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _cmd_unlighten(argv: List[Any]):
    """Command ``unlighten``: restores random weights in a lightened model.

    Reads ``args.input`` (and its ``.stats`` companion file, handled by
    ``onnx_unlighten``) and writes the restored model to ``args.output``.

    :param argv: full command line, ``argv[0]`` is the command name
    """
    from .helpers.onnx_helper import onnx_unlighten

    # bug fix: this command previously built its parser with
    # get_parser_lighten(); both parsers accept the same options but the
    # prog name and every help message were wrong for this command.
    parser = get_parser_unlighten()
    args = parser.parse_args(argv[1:])
    new_onx = onnx_unlighten(args.input, verbose=args.verbose)
    if args.verbose:
        print(f"save file {args.output}")
    with open(args.output, "wb") as f:
        f.write(new_onx.SerializeToString())
    if args.verbose:
        print("done")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def get_parser_print() -> ArgumentParser:
    """Builds the argument parser for command ``print``."""
    # the multi-line description of the positional argument is prepared first
    fmt_help = textwrap.dedent(
        """
        Prints out a model on the standard output.
        raw - just prints the model with print(...)
        printer - onnx.printer.to_text(...)
        pretty - an improved rendering
        text - uses GraphRendering
        """.strip(
            "\n"
        )
    )
    parser = ArgumentParser(
        prog="print",
        description="Prints the model on the standard output.",
        epilog="To show a model.",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "fmt",
        choices=["pretty", "raw", "text", "printer"],
        default="pretty",
        help=fmt_help,
    )
    parser.add_argument("input", type=str, help="onnx model to load")
    return parser
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _cmd_print(argv: List[Any]):
    """Command ``print``: renders an onnx model on stdout in the chosen format."""
    args = get_parser_print().parse_args(argv[1:])
    model = onnx.load(args.input)
    if args.fmt == "pretty":
        # imported lazily to keep the command line fast when unused
        from .helpers.onnx_helper import pretty_onnx

        print(pretty_onnx(model))
    elif args.fmt == "raw":
        print(model)
    elif args.fmt == "text":
        from .helpers.graph_helper import GraphRendering

        print(GraphRendering(model).text_rendering())
    elif args.fmt == "printer":
        print(onnx.printer.to_text(model))
    else:
        raise ValueError(f"Unexpected value fmt={args.fmt!r}")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def get_parser_find() -> ArgumentParser:
    """Builds the argument parser for command ``find``."""
    parser = ArgumentParser(
        prog="find",
        description=textwrap.dedent(
            """
            Look into a model and search for a set of names,
            tells which node is consuming or producing it.
            """
        ),
        epilog="Enables Some quick validation.",
    )
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="onnx model to unlighten"
    )
    parser.add_argument(
        "-n",
        "--names",
        type=str,
        required=False,
        help="Names to look at comma separated values, if 'SHADOW', "
        "search for shadowing names.",
    )
    parser.add_argument(
        "-v", "--verbose", default=0, type=int, required=False, help="verbosity"
    )
    parser.add_argument(
        "--v2",
        default=False,
        action=BooleanOptionalAction,
        help="Uses enumerate_results instead of onnx_find.",
    )
    return parser
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _cmd_find(argv: List[Any]):
    """Command ``find``: locates result names inside an onnx model.

    Three behaviours depending on the arguments:

    * ``--names SHADOW``: reports shadowing and post-shadowing names,
    * ``--v2``: walks the graph with ``enumerate_results``,
    * otherwise: delegates to ``onnx_find``.

    :param argv: full command line, ``argv[0]`` is the command name
    """
    from .helpers.onnx_helper import onnx_find, enumerate_results, shadowing_names

    parser = get_parser_find()
    args = parser.parse_args(argv[1:])
    if args.names is None:
        # robustness fix: --names is optional in the parser but every branch
        # below dereferences it; fail with a clear usage error instead of
        # an AttributeError on args.names.split.
        parser.error("-n/--names is required")
    if args.names == "SHADOW":
        onx = onnx.load(args.input, load_external_data=False)
        s, ps = shadowing_names(onx)[:2]
        print(f"shadowing names: {s}")
        print(f"post-shadowing names: {ps}")
    elif args.v2:
        onx = onnx.load(args.input, load_external_data=False)
        res = list(
            enumerate_results(onx, name=set(args.names.split(",")), verbose=args.verbose)
        )
        if not args.verbose:
            # in verbose mode enumerate_results already printed everything
            print("\n".join(map(str, res)))
    else:
        onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def get_parser_config() -> ArgumentParser:
    """Builds the argument parser for command ``config``."""
    description = textwrap.dedent(
        """
        Prints out a configuration for a model id,
        prints the associated task as well.
        """
    )
    parser = ArgumentParser(
        prog="config",
        description=description,
        formatter_class=RawTextHelpFormatter,
        epilog="",
    )
    parser.add_argument(
        "-m", "--mid", type=str, required=True, help="model id, usually `<author>/<name>`"
    )
    parser.add_argument(
        "-t",
        "--task",
        default=False,
        action=BooleanOptionalAction,
        help="Displays the task as well.",
    )
    parser.add_argument(
        "-c",
        "--cached",
        default=True,
        action=BooleanOptionalAction,
        help="Uses cached configuration, only available for some of them,\n"
        "mostly for unit test purposes.",
    )
    parser.add_argument(
        "--mop",
        metavar="KEY=VALUE",
        nargs="*",
        action=_ParseDict,
        help="Additional model options, use to change some parameters of the model, "
        "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
    )
    return parser
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def _cmd_config(argv: List[Any]):
    """Command ``config``: prints the configuration of a model id and,
    optionally, its task."""
    from .torch_models.hghub.hub_api import get_pretrained_config, task_from_id

    args = get_parser_config().parse_args(argv[1:])
    # NOTE(review): args.cached is parsed by get_parser_config but not
    # forwarded here — confirm whether get_pretrained_config should receive it.
    conf = get_pretrained_config(args.mid, **(args.mop or {}))
    print(conf)
    # highlight the attention/implementation related attributes
    for key, value in sorted(conf.__dict__.items()):
        if "_implementation" in key:
            print(f"config.{key}={value!r}")
    if args.task:
        print("------")
        print(f"task: {task_from_id(args.mid)}")
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
|
|
291
|
+
assert isinstance(value, str), f"value should be string but value={value!r}"
|
|
292
|
+
if value and value[0] == "{" and value[-1] == "}":
|
|
293
|
+
# a dictionary
|
|
294
|
+
return json.loads(value.replace("'", '"'))
|
|
295
|
+
return value
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class _ParseDict(argparse.Action):
|
|
299
|
+
def __call__(self, parser, namespace, values, option_string=None):
|
|
300
|
+
d = getattr(namespace, self.dest) or {}
|
|
301
|
+
|
|
302
|
+
if values:
|
|
303
|
+
for item in values:
|
|
304
|
+
split_items = item.split("=", 1)
|
|
305
|
+
key = split_items[0].strip() # we remove blanks around keys, as is logical
|
|
306
|
+
value = split_items[1]
|
|
307
|
+
|
|
308
|
+
if value in ("True", "true", "False", "false"):
|
|
309
|
+
d[key] = value in ("True", "true")
|
|
310
|
+
continue
|
|
311
|
+
try:
|
|
312
|
+
d[key] = int(value)
|
|
313
|
+
continue
|
|
314
|
+
except (TypeError, ValueError):
|
|
315
|
+
pass
|
|
316
|
+
try:
|
|
317
|
+
d[key] = float(value)
|
|
318
|
+
continue
|
|
319
|
+
except (TypeError, ValueError):
|
|
320
|
+
pass
|
|
321
|
+
d[key] = _parse_json(value)
|
|
322
|
+
|
|
323
|
+
setattr(namespace, self.dest, d)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class _BoolOrParseDictPatch(argparse.Action):
|
|
327
|
+
def __call__(self, parser, namespace, values, option_string=None):
|
|
328
|
+
|
|
329
|
+
if not values:
|
|
330
|
+
return
|
|
331
|
+
if len(values) == 1 and values[0] in (
|
|
332
|
+
"True",
|
|
333
|
+
"False",
|
|
334
|
+
"true",
|
|
335
|
+
"false",
|
|
336
|
+
"0",
|
|
337
|
+
"1",
|
|
338
|
+
0,
|
|
339
|
+
1,
|
|
340
|
+
):
|
|
341
|
+
setattr(namespace, self.dest, values[0] in ("True", "true", 1, "1"))
|
|
342
|
+
return
|
|
343
|
+
d = getattr(namespace, self.dest) or {}
|
|
344
|
+
if not isinstance(d, dict):
|
|
345
|
+
d = {
|
|
346
|
+
"patch_sympy": d,
|
|
347
|
+
"patch_torch": d,
|
|
348
|
+
"patch_transformers": d,
|
|
349
|
+
"patch_diffusers": d,
|
|
350
|
+
}
|
|
351
|
+
for item in values:
|
|
352
|
+
split_items = item.split("=", 1)
|
|
353
|
+
key = split_items[0].strip() # we remove blanks around keys, as is logical
|
|
354
|
+
value = split_items[1]
|
|
355
|
+
|
|
356
|
+
if value in ("True", "true", "False", "false"):
|
|
357
|
+
d[key] = value in ("True", "true")
|
|
358
|
+
continue
|
|
359
|
+
try:
|
|
360
|
+
d[key] = int(value)
|
|
361
|
+
continue
|
|
362
|
+
except (TypeError, ValueError):
|
|
363
|
+
pass
|
|
364
|
+
try:
|
|
365
|
+
d[key] = float(value)
|
|
366
|
+
continue
|
|
367
|
+
except (TypeError, ValueError):
|
|
368
|
+
pass
|
|
369
|
+
d[key] = _parse_json(value)
|
|
370
|
+
|
|
371
|
+
setattr(namespace, self.dest, d)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def get_parser_validate(name: str = "validate") -> ArgumentParser:
    """Builds the argument parser shared by commands ``validate`` and
    ``exportsample``.

    :param name: the prog name; a few arguments (``--ortfusiontype``,
        ``--repeat``, ``--warmup``, ``--ort-logs``, ``--quiet-input-sets``)
        are only registered when *name* is ``"validate"``.
    :return: configured :class:`argparse.ArgumentParser`
    """
    parser = ArgumentParser(
        prog=name,
        description=textwrap.dedent(
            """
            Validates a model for a particular task given the model id.
            It exports the model and then validates it by computing the discrepancies
            on different input sets.
            """
            if name == "validate"
            else """
            Creates a script to export a model for a particular task given the model id.
            """
        ),
        epilog=textwrap.dedent(
            f"""
            If the model id is specified, one untrained version of it is instantiated.
            Examples:

            python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
                --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
                --dtype float16 --device cuda --patch --export onnx-dynamo --opt ir

            python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
                --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
                --dtype float16 --device cuda --patch --export custom --opt default

            python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
                --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
                --dtype float16 --device cuda --export modelbuilder

            position_ids is usually not needed, they can be removed by adding:

            --drop position_ids

            The behaviour may be modified compare the original configuration,
            the following argument can be rope_scaling to dynamic:

            --mop \"rope_scaling={{'rope_type': 'dynamic', 'factor': 10.0}}\""

            You can profile the command line by running:

            pyinstrument -m onnx_diagnostic {name} ...
            pyinstrument -r html -o profile.html -m onnx_diagnostic {name} ...
            """
        ),
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
    parser.add_argument("-t", "--task", default=None, help="force the task to use")
    parser.add_argument("-e", "--export", help="export the model with this exporter")
    parser.add_argument("--opt", help="optimization to apply after the export")
    parser.add_argument(
        "-r",
        "--run",
        default=False,
        action=BooleanOptionalAction,
        help="Runs the model to check it runs.",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        default=False,
        action=BooleanOptionalAction,
        help="Catches exception, reports them in the summary.",
    )
    # --patch accepts either a boolean or KEY=VALUE overrides,
    # see _BoolOrParseDictPatch
    parser.add_argument(
        "--patch",
        default=True,
        action=_BoolOrParseDictPatch,
        nargs="*",
        help="Applies patches before exporting, it can be a boolean "
        "to enable to disable the patches or be more finetuned. It is possible to "
        "disable patch for torch by adding "
        '--patch "patch_sympy=False" --patch "patch_torch=False", '
        "default is True.",
    )
    parser.add_argument(
        "--rewrite",
        default=True,
        action=BooleanOptionalAction,
        help="Applies rewrite before exporting.",
    )
    parser.add_argument(
        "--stop-if-static",
        default=0,
        type=int,
        help="Raises an exception if a dynamic dimension becomes static.",
    )
    parser.add_argument(
        "--same-as-trained",
        default=False,
        action=BooleanOptionalAction,
        help="Validates or exports a model identical to the trained model but not trained.",
    )
    parser.add_argument(
        "--trained",
        default=False,
        action=BooleanOptionalAction,
        help="Validates or exports the trained model (requires downloading).",
    )
    parser.add_argument(
        "--inputs2",
        default=1,
        type=int,
        help="Validates or exports the model on a second set of inputs\n"
        "to check the exported model supports dynamism. The values is used "
        "as an increment to the first set of inputs. A high value may trick "
        "a different behavior in the model and missed by the exporter.",
    )
    parser.add_argument(
        "--runtime",
        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
        default="onnxruntime",
        help="onnx runtime to use, `onnxruntime` by default",
    )
    parser.add_argument(
        "-o",
        "--dump-folder",
        help="A folder is created to dumps statistics,\nexported program, onnx...",
    )
    parser.add_argument(
        "--drop",
        help="Drops the following inputs names, it should be a list\n"
        "with comma separated values, example:\n"
        "--drop position_ids",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=18,
        help="onnx opset to use, 18 by default",
    )
    parser.add_argument(
        "--subfolder",
        help="Subfolder where to find the model and the configuration.",
    )
    # only the validate command supports onnxruntime fusions
    if name == "validate":
        parser.add_argument(
            "--ortfusiontype",
            required=False,
            help="Applies onnxruntime fusion, this parameter should contain the\n"
            "model type or multiple values separated by `|`. `ALL` can be used\n"
            "to run them all.",
        )
    parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
    parser.add_argument("--dtype", help="Changes dtype if necessary.")
    parser.add_argument("--device", help="Changes the device if necessary.")
    # --iop/--mop collect KEY=VALUE pairs into dictionaries, see _ParseDict
    parser.add_argument(
        "--iop",
        metavar="KEY=VALUE",
        nargs="*",
        help="Additional input options, use to change the default"
        "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
        "\n --iop cls_cache=StaticCache",
        action=_ParseDict,
    )
    parser.add_argument(
        "--mop",
        metavar="KEY=VALUE",
        nargs="*",
        help="Additional model options, use to change some parameters of the model, "
        "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
        "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
        action=_ParseDict,
    )
    # timing arguments are only meaningful when the model is actually run
    if name == "validate":
        parser.add_argument(
            "--repeat",
            default=1,
            type=int,
            help="number of times to run the model to measures inference time",
        )
        parser.add_argument(
            "--warmup",
            default=0,
            type=int,
            help="number of times to run the model to do warmup",
        )
    parser.add_argument(
        "--outnames",
        help="This comma separated list defines the output names "
        "the onnx exporter should use.",
        default="",
    )
    if name == "validate":
        parser.add_argument(
            "--ort-logs",
            default=False,
            action=BooleanOptionalAction,
            help="Enables onnxruntime logging when the session is created",
        )
        parser.add_argument(
            "--quiet-input-sets",
            default="",
            help="Avoids raising an exception when an input sets does not work with "
            "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
        )
    return parser
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
def _cmd_validate(argv: List[Any]):
    """Command ``validate``: exports a model id and measures discrepancies.

    Three modes depending on the arguments:

    * neither ``--task`` nor ``--mid``: prints the list of supported tasks,
    * ``--task`` only: prints the dummy inputs and dynamic shapes used for
      that task,
    * ``--mid``: runs ``validate_model`` and prints the resulting summary
      as ``:key,value;`` lines.

    :param argv: full command line, ``argv[0]`` is the command name
    """
    from .helpers import string_type
    from .torch_models.validate import get_inputs_for_task, validate_model
    from .tasks import supported_tasks

    parser = get_parser_validate()
    args = parser.parse_args(argv[1:])
    if not args.task and not args.mid:
        print("-- list of supported tasks:")
        print("\n".join(supported_tasks()))
    elif not args.mid:
        data = get_inputs_for_task(args.task)
        if args.verbose:
            print(f"task: {args.task}")
        # widest input name, used to align the printed columns
        max_length = max(len(k) for k in data["inputs"]) + 1
        print("-- inputs")
        for k, v in data["inputs"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
        print("-- dynamic_shapes")
        for k, v in data["dynamic_shapes"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v)}")
    else:
        # Let's skip any invalid combination if known to be unsupported
        # (--opt only makes sense with an onnx or custom exporter)
        if (
            "onnx" not in (args.export or "")
            and "custom" not in (args.export or "")
            and (args.opt or "")
        ):
            print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
            return
        # args.patch is either a bool or a dict (see _BoolOrParseDictPatch);
        # normalize to a dict for validate_model
        patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
        summary, _data = validate_model(
            model_id=args.mid,
            task=args.task,
            do_run=args.run,
            verbose=args.verbose,
            quiet=args.quiet,
            same_as_pretrained=args.same_as_trained,
            use_pretrained=args.trained,
            dtype=args.dtype,
            device=args.device,
            patch=patch_dict,
            # rewriting is pointless when patching is globally disabled
            rewrite=args.rewrite and patch_dict.get("patch", True),
            stop_if_static=args.stop_if_static,
            optimization=args.opt,
            exporter=args.export,
            dump_folder=args.dump_folder,
            drop_inputs=None if not args.drop else args.drop.split(","),
            ortfusiontype=args.ortfusiontype,
            input_options=args.iop,
            model_options=args.mop,
            subfolder=args.subfolder,
            opset=args.opset,
            runtime=args.runtime,
            repeat=args.repeat,
            warmup=args.warmup,
            inputs2=args.inputs2,
            ort_logs=args.ort_logs,
            # NOTE(review): an empty --quiet-input-sets yields {''} here —
            # confirm validate_model treats that as "no input set".
            quiet_input_sets=set(args.quiet_input_sets.split(",")),
            output_names=(
                None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
            ),
        )
        print("")
        print("-- summary --")
        # machine-parsable ":key,value;" lines
        for k, v in sorted(summary.items()):
            print(f":{k},{v};")
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def _cmd_export_sample(argv: List[Any]):
    """Implements command ``exportsample``.

    Three modes depending on the arguments:

    * no ``--task`` and no ``--mid``: prints the list of supported tasks,
    * ``--task`` only: prints the dummy inputs and dynamic shapes for that task,
    * ``--mid``: generates a standalone python script exporting that model,
      printed on stdout or written into ``--dump-folder``.
    """
    from .helpers import string_type
    from .torch_models.validate import get_inputs_for_task, _make_folder_name
    from .torch_models.code_sample import code_sample
    from .tasks import supported_tasks

    # Reuses the 'validate' parser: both commands share the same options.
    parser = get_parser_validate("exportsample")
    args = parser.parse_args(argv[1:])
    if not args.task and not args.mid:
        print("-- list of supported tasks:")
        print("\n".join(supported_tasks()))
    elif not args.mid:
        # Task only: show what dummy inputs/dynamic shapes would be used.
        data = get_inputs_for_task(args.task)
        if args.verbose:
            print(f"task: {args.task}")
        # +1 so the longest key still gets one trailing space before ':'.
        max_length = max(len(k) for k in data["inputs"]) + 1
        print("-- inputs")
        for k, v in data["inputs"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
        print("-- dynamic_shapes")
        for k, v in data["dynamic_shapes"].items():
            print(f" + {k.ljust(max_length)}: {string_type(v)}")
    else:
        # Let's skip any invalid combination if known to be unsupported
        if (
            "onnx" not in (args.export or "")
            and "custom" not in (args.export or "")
            and (args.opt or "")
        ):
            print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
            return
        # args.patch may already be a dict (from _ParseNamedDict-style options)
        # or a plain flag; normalize to a dict.
        patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
        code = code_sample(
            model_id=args.mid,
            task=args.task,
            do_run=args.run,
            verbose=args.verbose,
            quiet=args.quiet,
            same_as_pretrained=args.same_as_trained,
            use_pretrained=args.trained,
            dtype=args.dtype,
            device=args.device,
            patch=patch_dict,
            # rewriting only makes sense when patching is enabled
            rewrite=args.rewrite and patch_dict.get("patch", True),
            stop_if_static=args.stop_if_static,
            optimization=args.opt,
            exporter=args.export,
            dump_folder=args.dump_folder,
            drop_inputs=None if not args.drop else args.drop.split(","),
            input_options=args.iop,
            model_options=args.mop,
            subfolder=args.subfolder,
            opset=args.opset,
            runtime=args.runtime,
            # < 2: empty or a single character cannot be a valid name list
            output_names=(
                None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
            ),
        )
        if args.dump_folder:
            os.makedirs(args.dump_folder, exist_ok=True)
            # Builds a file name encoding the full configuration; '/' appears
            # in model ids (e.g. org/model) and is not valid in file names.
            name = (
                _make_folder_name(
                    model_id=args.mid,
                    exporter=args.export,
                    optimization=args.opt,
                    dtype=args.dtype,
                    device=args.device,
                    subfolder=args.subfolder,
                    opset=args.opset,
                    drop_inputs=None if not args.drop else args.drop.split(","),
                    same_as_pretrained=args.same_as_trained,
                    use_pretrained=args.trained,
                    task=args.task,
                ).replace("/", "-")
                + ".py"
            )
            fullname = os.path.join(args.dump_folder, name)
            if args.verbose:
                print(f"-- prints code in {fullname!r}")
                print("--")
            with open(fullname, "w") as f:
                f.write(code)
            if args.verbose:
                print("-- done")
        else:
            print(code)
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def get_parser_stats() -> ArgumentParser:
    """Builds the command line parser for command ``stats``.

    Options: the onnx file to analyse (required), an optional output
    file for the statistics, verbosity, a begin/end window over the
    tensors and a regular expression filtering tensors by name.
    """
    p = ArgumentParser(
        prog="stats",
        description="Prints out statistics on an ONNX model.",
        epilog="",
    )
    p.add_argument("-i", "--input", type=str, required=True, help="ONNX file")
    p.add_argument(
        "-o",
        "--output",
        required=False,
        default="",
        help="outputs the statistics in a file",
    )
    p.add_argument("-v", "--verbose", required=False, default=1, type=int, help="verbosity")
    p.add_argument(
        "-e",
        "--end",
        required=False,
        default=-1,
        type=int,
        help="ends after this many tensors",
    )
    p.add_argument(
        "-b",
        "--begin",
        required=False,
        default=0,
        type=int,
        help="starts after this many tensors",
    )
    p.add_argument(
        "-r",
        "--regex",
        required=False,
        default="",
        type=str,
        help="Keeps only tensors whose name verifies "
        "this regular expression, empty = no filter.",
    )
    return p
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
def _cmd_stats(argv: List[Any]):
    """Implements command ``stats``.

    Loads an onnx model, iterates over its initializers and constants,
    computes statistics for each tensor and either prints them or saves
    them into a csv/xlsx file (``--output``).
    """
    from .helpers.onnx_helper import iterator_initializer_constant, tensor_statistics

    parser = get_parser_stats()
    args = parser.parse_args(argv[1:])
    assert os.path.exists(args.input), f"Missing filename {args.input!r}"
    if args.verbose:
        print(f"Loading {args.input}")
    onx = onnx.load(args.input)
    # empty regex means no filtering
    reg = re.compile(args.regex) if args.regex else None
    data = []
    for index, (name, init) in enumerate(iterator_initializer_constant(onx)):
        # NOTE(review): --begin/--end are compared against the global
        # enumeration index, so tensors skipped by the regex still count
        # towards the window.
        if reg and not reg.search(name):
            continue
        if index < args.begin:
            continue
        # end <= 0 means "no limit" (default is -1)
        if args.end > 0 and index >= args.end:
            break
        if args.verbose:
            print(f"processing {index + 1}: {name!r}")
        stats = tensor_statistics(init)
        if not args.output:
            print(f"{name}: {stats}")
        stats["name"] = name
        data.append(stats)
    if args.output:
        if args.verbose:
            print(f"saving into {args.output!r}")
        # imported lazily: pandas is only needed when an output file is requested
        import pandas

        df = pandas.DataFrame(data)
        # splitext returns (root, extension); only the extension is used
        ext = os.path.splitext(args.output)
        if ext[-1] == ".xlsx":
            df.to_excel(args.output, index=False)
        else:
            df.to_csv(args.output, index=False)
    if args.verbose:
        print("done.")
|
|
826
|
+
|
|
827
|
+
|
|
828
|
+
class _ParseNamedDict(argparse.Action):
    """argparse action parsing ``"<name>:<k1>=<v1>,<k2>=<v2>"`` into a dict.

    Each occurrence of the option adds one entry ``{name: {k1: v1, ...}}``
    to the destination, so repeating the option builds a dictionary of
    named dictionaries. All keys and values are kept as strings.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Malformed user input raises argparse.ArgumentError so argparse
        # reports a clean usage error. The original used ``assert`` for
        # these checks, which is stripped under ``python -O`` and crashes
        # with a traceback otherwise; ArgumentError is also consistent
        # with the existing check on '='.
        if ":" not in values:
            raise argparse.ArgumentError(self, f"':' missing from {values!r}")
        namespace_key, rest = values.split(":", 1)
        inner_dict = {}
        for pair in rest.split(","):
            if "=" not in pair:
                raise argparse.ArgumentError(self, f"Expected '=' in pair '{pair}'")
            key, value = pair.split("=", 1)
            inner_dict[key] = value
        if not inner_dict:
            raise argparse.ArgumentError(
                self, f"Unable to parse {rest!r} into a dictionary"
            )
        # First occurrence: the destination defaults to None, replace it
        # with an empty dict before accumulating.
        if not hasattr(namespace, self.dest) or getattr(namespace, self.dest) is None:
            setattr(namespace, self.dest, {})
        # Internal invariant (not user input): the destination must stay a dict.
        assert isinstance(
            getattr(namespace, self.dest), dict
        ), f"Unexpected type for namespace.{self.dest}={getattr(namespace, self.dest)}"
        getattr(namespace, self.dest).update({namespace_key: inner_dict})
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def get_parser_agg() -> ArgumentParser:
    """Builds the command line parser for command ``agg``."""
    p = ArgumentParser(
        prog="agg",
        description=textwrap.dedent(
            """
            Aggregates statistics coming from benchmarks.
            Every run is a row. Every row is indexed by some keys,
            and produces values. Every row has a date.
            The data can come any csv files produces by benchmarks,
            it can concatenates many csv files, or csv files inside zip files.
            It produces an excel file with many tabs, one per view.
            """
        ),
        epilog=textwrap.dedent(
            """
            examples:

            python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
            python -m onnx_diagnostic agg agg.xlsx raw/*.zip raw/*.csv -v 1 \\
                --no-raw --keep-last-date --filter-out "exporter:test-exporter"

            Another to create timeseries:

            python -m onnx_diagnostic agg history.xlsx raw/*.csv -v 1 --no-raw \\
                --no-recent
            """
        ),
        formatter_class=RawTextHelpFormatter,
    )
    # positional arguments: destination file then the raw inputs
    p.add_argument("output", help="output excel file")
    p.add_argument(
        "inputs",
        nargs="+",
        help="input csv or zip files, at least 1, it can be a name, or search path",
    )
    p.add_argument(
        "--filter", default="rawdata_.*.csv", help="filter for input files inside zip files"
    )
    # boolean flags (--recent / --no-recent, ...)
    p.add_argument(
        "--recent",
        default=True,
        action=BooleanOptionalAction,
        help="Keeps only the most recent experiment for the same of keys.",
    )
    p.add_argument(
        "--keep-last-date",
        default=False,
        action=BooleanOptionalAction,
        help="Rewrite all dates to the last one to simplifies the analysis, "
        "this assume changing the date does not add ambiguity, if any, option "
        "--recent should be added.",
    )
    p.add_argument(
        "--raw",
        default=True,
        action=BooleanOptionalAction,
        help="Keeps the raw data in a sheet.",
    )
    p.add_argument("-t", "--time", default="DATE", help="Date or time column")
    # comma-separated column selectors, regular expressions allowed
    p.add_argument(
        "-k",
        "--keys",
        default="^version_.*,^model_.*,device,opt_patterns,suite,memory_peak,"
        "machine,exporter,dynamic,rtopt,dtype,device,architecture",
        help="List of columns to consider as keys, "
        "multiple values are separated by `,`\n"
        "regular expressions are allowed",
    )
    p.add_argument(
        "--drop-keys",
        default="",
        help="Drops keys from the given list. Something it is faster "
        "to remove one than to select all the remaining ones.",
    )
    p.add_argument(
        "-w",
        "--values",
        default="^time_.*,^disc.*,^ERR_.*,CMD,^ITER.*,^onnx_.*,^op_onnx_.*,^peak_gpu_.*",
        help="List of columns to consider as values, "
        "multiple values are separated by `,`\n"
        "regular expressions are allowed",
    )
    p.add_argument(
        "-i", "--ignored", default="^version_.*", help="List of columns to ignore"
    )
    p.add_argument(
        "-f",
        "--formula",
        default="speedup,bucket[speedup],ERR1,n_models,n_model_eager,"
        "n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
        "n_model_pass,n_model_faster,"
        "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
        "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
        "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
        "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
        "n_node_constant,n_node_shape,n_node_expand,"
        "n_node_function,n_node_initializer,n_node_scatter,"
        "time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
        help="Columns to compute after the aggregation was done.",
    )
    p.add_argument(
        "--views",
        default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
        "bucket-speedup,raw-short,counts,peak-gpu,onnx",
        help=textwrap.dedent(
            """
            Views to add to the output files. Each view becomes a tab.
            A view is defined by its name, among
            agg-suite, agg-all, disc, speedup, time, time_export, err,
            cmd, bucket-speedup, raw-short, counts, peak-gpu, onnx.
            Their definition is part of class CubeLogsPerformance.
            """
        ),
    )
    p.add_argument(
        "--csv",
        default="raw-short",
        help="Views to dump as csv files.",
    )
    p.add_argument("-v", "--verbose", type=int, default=0, help="verbosity")
    p.add_argument(
        "--filter-in",
        default="",
        help="adds a filter to filter in data, syntax is\n"
        '``"<column1>:<value1>;<value2>//<column2>:<value3>"`` ...',
    )
    p.add_argument(
        "--filter-out",
        default="",
        help="adds a filter to filter out data, syntax is\n"
        '``"<column1>:<value1>;<value2>//<column2>:<value3>"`` ...',
    )
    # --sbs can be repeated, each occurrence adds a named configuration
    p.add_argument(
        "--sbs",
        help=textwrap.dedent(
            """
            Defines an exporter to compare to another, there must be at least
            two arguments defined with --sbs. Example:
            --sbs dynamo:exporter=onnx-dynamo,opt=ir,attn_impl=eager
            --sbs custom:exporter=custom,opt=default,attn_impl=eager
            """
        ),
        action=_ParseNamedDict,
    )
    return p
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
def _cmd_agg(argv: List[Any]):
    """Implements command ``agg``.

    Collects csv files (possibly inside zip files), filters rows,
    aggregates them with :class:`CubeLogsPerformance` and writes an
    excel file with one tab per requested view.
    """
    from .helpers._log_helper import open_dataframe, enumerate_csv_files, filter_data
    from .helpers.log_helper import CubeLogsPerformance

    parser = get_parser_agg()
    args = parser.parse_args(argv[1:])
    # used to select csv files inside zip files (default 'rawdata_.*.csv')
    reg = re.compile(args.filter)

    csv = list(
        enumerate_csv_files(
            args.inputs, verbose=args.verbose, filtering=lambda name: bool(reg.search(name))
        )
    )
    assert csv, f"No csv files in {args.inputs}, args.filter={args.filter!r}, csv={csv}"
    if args.verbose:
        # tqdm imported lazily: only needed to show a progress bar
        from tqdm import tqdm

        loop = tqdm(csv)
    else:
        loop = csv
    dfs = []
    for c in loop:
        df = open_dataframe(c)
        # every input must carry the time column (default 'DATE')
        assert (
            args.time in df.columns
        ), f"Missing time column {args.time!r} in {c!r}\n{df.head()}\n{sorted(df.columns)}"
        dfs.append(filter_data(df, filter_in=args.filter_in, filter_out=args.filter_out))

    # empty --drop-keys yields {''}, harmless: empty names are filtered below
    drop_keys = set(args.drop_keys.split(","))
    cube = CubeLogsPerformance(
        dfs,
        time=args.time,
        keys=[a for a in args.keys.split(",") if a and a not in drop_keys],
        values=[a for a in args.values.split(",") if a],
        ignored=[a for a in args.ignored.split(",") if a],
        recent=args.recent,
        # formulas are identified by their own name
        formulas={k: k for k in args.formula.split(",")},
        keep_last_date=args.keep_last_date,
    )
    # the cube is one verbosity level quieter than the command itself
    cube.load(verbose=max(args.verbose - 1, 0))
    if args.verbose:
        print(f"Dumps final file into {args.output!r}")
    cube.to_excel(
        args.output,
        {k: k for k in args.views.split(",")},
        verbose=args.verbose,
        csv=args.csv.split(","),
        raw=args.raw,
        time_mask=True,
        sbs=args.sbs,
    )
    if args.verbose:
        print(f"Wrote {args.output!r}")
|
|
1049
|
+
|
|
1050
|
+
|
|
1051
|
+
def get_main_parser() -> ArgumentParser:
    """Builds the top level parser: a single positional argument ``cmd``
    selecting one of the available sub-commands."""
    p = ArgumentParser(
        prog="onnx_diagnostic",
        description="onnx_diagnostic main command line.\n",
        formatter_class=RawTextHelpFormatter,
        epilog=textwrap.dedent(
            """
            Type 'python -m onnx_diagnostic <cmd> --help'
            to get help for a specific command.

            agg - aggregates statistics from multiple files
            config - prints a configuration for a model id
            exportsample - produces a code to export a model
            find - find node consuming or producing a result
            lighten - makes an onnx model lighter by removing the weights,
            print - prints the model on standard output
            stats - produces statistics on a model
            unlighten - restores an onnx model produces by the previous experiment
            validate - validate a model
            """
        ),
    )
    p.add_argument(
        "cmd",
        choices=[
            "agg",
            "config",
            "exportsample",
            "find",
            "lighten",
            "print",
            "stats",
            "unlighten",
            "validate",
        ],
        help="Selects a command.",
    )
    return p
|
|
1089
|
+
|
|
1090
|
+
|
|
1091
|
+
def main(argv: Optional[List[Any]] = None):
    """Entry point of the command line.

    :param argv: command line arguments without the program name,
        taken from ``sys.argv[1:]`` when None

    Dispatches ``argv[0]`` to the matching ``_cmd_*`` function. When no
    command is given, the command is unknown, or help is requested, the
    relevant parser is invoked instead; argparse then prints the help
    and exits the process (SystemExit).
    """
    # command name -> implementation
    fcts = dict(
        lighten=_cmd_lighten,
        unlighten=_cmd_unlighten,
        print=_cmd_print,
        find=_cmd_find,
        config=_cmd_config,
        validate=_cmd_validate,
        stats=_cmd_stats,
        agg=_cmd_agg,
        exportsample=_cmd_export_sample,
    )

    if argv is None:
        argv = sys.argv[1:]
    # Help path: no argument, a single unknown argument, or explicit -h/--help.
    if (
        len(argv) == 0
        or (len(argv) <= 1 and argv[0] not in fcts)
        or argv[-1] in ("--help", "-h")
    ):
        if len(argv) < 2:
            # global help (or error for an unknown single argument)
            parser = get_main_parser()
            parser.parse_args(argv)
        else:
            # per-command help: command name -> parser factory
            parsers = dict(
                lighten=get_parser_lighten,
                unlighten=get_parser_unlighten,
                print=get_parser_print,
                find=get_parser_find,
                config=get_parser_config,
                validate=get_parser_validate,
                stats=get_parser_stats,
                agg=get_parser_agg,
                exportsample=lambda: get_parser_validate("exportsample"),  # type: ignore[operator]
            )
            cmd = argv[0]
            if cmd not in parsers:
                raise ValueError(
                    f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
                )
            parser = parsers[cmd]()  # type: ignore[operator]
            parser.parse_args(argv[1:])
        # argparse exits (SystemExit) when printing help; reaching this
        # line means something unexpected happened.
        raise RuntimeError("The programme should have exited before.")

    cmd = argv[0]
    if cmd in fcts:
        # the full argv is forwarded, each _cmd_* strips argv[0] itself
        fcts[cmd](argv)
    else:
        raise ValueError(
            f"Unknown command {cmd!r}, use --help to get the list of known command."
        )
|