onnx-diagnostic 0.7.10__py3-none-any.whl → 0.7.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +13 -3
- onnx_diagnostic/helpers/cache_helper.py +8 -6
- onnx_diagnostic/helpers/log_helper.py +65 -12
- onnx_diagnostic/helpers/rt_helper.py +53 -36
- onnx_diagnostic/tasks/__init__.py +4 -2
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +11 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +5 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +73 -32
- onnx_diagnostic/torch_models/hghub/hub_data.py +3 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +70 -38
- onnx_diagnostic/torch_models/hghub/model_specific.py +27 -0
- onnx_diagnostic/torch_models/validate.py +329 -88
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/RECORD +19 -18
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.10.dist-info → onnx_diagnostic-0.7.12.dist-info}/top_level.txt +0 -0
The largest change is in `onnx_diagnostic/torch_models/validate.py` (+329 −88); the hunks below walk through it.

```diff
@@ -3,17 +3,15 @@ import inspect
 import os
 import pprint
 import sys
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import time
 import numpy as np
 import onnx
-import onnxscript
-import onnxscript.rewriter.ort_fusions as ort_fusions
 import torch
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
 from ..helpers.helper import flatten_object
-from ..helpers.rt_helper import make_feeds
+from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
 from ..helpers.torch_helper import to_any, torch_deepcopy
 from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
 from ..tasks import random_input_kwargs
```
```diff
@@ -113,6 +111,8 @@ def _make_folder_name(
     dtype: Optional[Union[str, torch.dtype]] = None,
     device: Optional[Union[str, torch.device]] = None,
     subfolder: Optional[str] = None,
+    opset: Optional[int] = None,
+    drop_inputs: Optional[List[str]] = None,
 ) -> str:
     "Creates a filename unique based on the given options."
     els = [model_id.replace("/", "_")]
@@ -136,6 +136,11 @@ def _make_folder_name(
     else:
         raise AssertionError(f"unexpected value for device={device}, sdev={sdev!r}")
     els.append(sdev)
+    if opset is not None:
+        els.append(f"op{opset}")
+    if drop_inputs:
+        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
+        els.append(f"I-{ii.upper()}")
     return "-".join(els)


```
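The two new options also feed into the dump-folder name: the opset becomes an `op{opset}` token and every dropped input is abbreviated to its first and last character. A standalone sketch of that abbreviation rule (input names chosen only for illustration):

```python
# Mirrors the naming rule added to _make_folder_name above:
# first and last character of each dropped input, upper-cased.
drop_inputs = ["position_ids", "attention_mask"]
ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
print(f"I-{ii.upper()}")  # I-PS-AK
```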
```diff
@@ -246,6 +251,7 @@ def _quiet_or_not_quiet(
         summary[f"time_{suffix}_latency_std"] = a.std()
         summary[f"time_{suffix}_latency_min"] = a.min()
         summary[f"time_{suffix}_latency_min"] = a.max()
+        summary[f"time_{suffix}_n"] = len(a)
     return res


```
```diff
@@ -262,6 +268,20 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return new_cfg


+def _preprocess_model_id(
+    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
+) -> Tuple[str, Optional[str], bool, bool]:
+    if subfolder or "//" not in model_id:
+        return model_id, subfolder, same_as_pretrained, use_pretrained
+    spl = model_id.split("//")
+    if spl[-1] == "pretrained":
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+    if spl[-1] in {"transformer", "vae"}:
+        # known subfolder
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
+    return model_id, subfolder, same_as_pretrained, use_pretrained
+
+
 def validate_model(
     model_id: str,
     task: Optional[str] = None,
```
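The helper gives `validate_model` a small addressing scheme: a model id may carry `//`-separated directives, where a trailing `pretrained` switches to the pretrained weights (`same_as_pretrained=True, use_pretrained=True`) and a known subfolder name (`transformer` or `vae`) is split off into the `subfolder` argument. A sketch of the resolution (the helper is private and the model id below is only illustrative):

```python
# Assuming _preprocess_model_id keeps the behavior shown in the diff above.
from onnx_diagnostic.torch_models.validate import _preprocess_model_id

resolved = _preprocess_model_id(
    "black-forest-labs/FLUX.1-dev//transformer//pretrained", None, False, False
)
print(resolved)
# ('black-forest-labs/FLUX.1-dev', 'transformer', True, True)
```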
```diff
@@ -290,6 +310,7 @@ def validate_model(
     warmup: int = 0,
     inputs2: int = 1,
     output_names: Optional[List[str]] = None,
+    ort_logs: bool = False,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -334,13 +355,15 @@
     :param subfolder: version or subfolders to uses when retrieving a model id
     :param opset: onnx opset to use for the conversion
     :param runtime: onnx runtime to use to check about discrepancies,
-
+        possible values ``onnxruntime``, ``torch``, ``orteval``,
+        ``orteval10``, ``ref`` only if `do_run` is true
     :param repeat: number of time to measure the model
     :param warmup: warmup the model first
     :param inputs2: checks that the second set of inputs is reunning as well,
         this ensures that the model does support dynamism, the value is used
         as an increment to the first set of values (added to dimensions)
     :param output_names: output names the onnx exporter should use
+    :param ort_logs: increases onnxruntime verbosity when creating the session
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

@@ -361,14 +384,23 @@

     The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
     exported model returns the same outputs as the original one, otherwise,
-    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
+    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
+    if ``runtime == 'torch'`` or
+    :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
+    if ``runtime == 'orteval'`` or
+    :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
+    if ``runtime == 'ref'``,
+    ``orteval10`` increases the verbosity.
     """
+    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+        model_id,
+        subfolder,
+        same_as_pretrained=same_as_pretrained,
+        use_pretrained=use_pretrained,
+    )
+    default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
     if isinstance(patch, bool):
-        patch_kwargs = (
-            dict(patch_transformers=True, patch_diffusers=True, patch=True)
-            if patch
-            else dict(patch=False)
-        )
+        patch_kwargs = default_patch if patch else dict(patch=False)
     elif isinstance(patch, str):
         patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
     else:
```
```diff
@@ -377,11 +409,13 @@
     if "patch" not in patch_kwargs:
         if any(patch_kwargs.values()):
             patch_kwargs["patch"] = True
+    elif len(patch) == 1 and patch.get("patch", False):
+        patch_kwargs.update(default_patch)

     assert not rewrite or patch_kwargs.get("patch", False), (
         f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
         f"patch must be True to enable rewriting, "
-        f"if --
+        f"if --patch=0 was specified on the command line, rewrites are disabled."
     )
     summary = version_summary()
     summary.update(
```
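Together, these two hunks make the `patch` argument normalize the same way whatever its type: a boolean expands to the full default set, a comma-separated string enables each named patch, and a dict containing only `{"patch": True}` is now expanded to the defaults too. A standalone sketch of the resulting rule (`normalize_patch` is a hypothetical name, not a library function):

```python
from typing import Any, Dict, Union

DEFAULT_PATCH = dict(patch_transformers=True, patch_diffusers=True, patch=True)

def normalize_patch(patch: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
    # hypothetical helper mirroring the normalization in validate_model
    if isinstance(patch, bool):
        return dict(DEFAULT_PATCH) if patch else dict(patch=False)
    if isinstance(patch, str):
        return {"patch": True, **{p: True for p in patch.split(",")}}
    kwargs = dict(patch)
    if "patch" not in kwargs:
        if any(kwargs.values()):
            kwargs["patch"] = True
    elif len(patch) == 1 and patch.get("patch", False):
        kwargs.update(DEFAULT_PATCH)
    return kwargs

print(normalize_patch("patch_transformers"))
# {'patch': True, 'patch_transformers': True}
print(normalize_patch({"patch": True}))
# {'patch': True, 'patch_transformers': True, 'patch_diffusers': True}
```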
```diff
@@ -412,7 +446,14 @@
     folder_name = None
     if dump_folder:
         folder_name = _make_folder_name(
-            model_id,
+            model_id,
+            exporter,
+            optimization,
+            dtype=dtype,
+            device=device,
+            subfolder=subfolder,
+            opset=opset,
+            drop_inputs=drop_inputs,
         )
         dump_folder = os.path.join(dump_folder, folder_name)
         if not os.path.exists(dump_folder):
@@ -508,6 +549,11 @@
     if verbose:
         print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")

+    # modelbuilder needs different treatments sometimes, so
+    # we mark it for later usage.
+    # for example, it has different past_kv ordering than
+    # flattened CacheObject
+    data["exporter"] = exporter
     data["input_options"] = iop
     data["model_options"] = mop
     data["model_dump_folder"] = dump_folder
@@ -743,6 +789,7 @@
         repeat=repeat,
         warmup=warmup,
         inputs2=inputs2,
+        ort_logs=ort_logs,
     )
     summary.update(summary_valid)

@@ -807,6 +854,8 @@
     )
     summary.update(summary_valid)

+    _compute_final_statistics(summary)
+
     if verbose:
         print("[validate_model] -- done (final)")
     if dump_stats:
@@ -819,15 +868,24 @@
 def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
     """Computes some statistics on the model itself."""
     onx = onnx.load(onnx_filename, load_external_data=False)
+    cache_functions = {(f.domain, f.name): f for f in onx.functions}
+    local_domains = set(f.domain for f in onx.functions)

     def node_iter(proto):
         if isinstance(proto, onnx.ModelProto):
-            yield from node_iter(proto.graph)
             for f in proto.functions:
                 yield from node_iter(f)
+            yield from node_iter(proto.graph)
         elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
             for node in proto.node:
                 yield node
+
+                # Let's inline the function
+                key = node.domain, node.op_type
+                if key in cache_functions:
+                    yield from node_iter(cache_functions[key])
+
+                # Let's continue
                 for att in node.attribute:
                     if att.type == onnx.AttributeProto.GRAPH:
                         yield from node_iter(att.g)
@@ -837,15 +895,29 @@ def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
         raise NotImplementedError(f"Unexpected type={type(proto)}")

     counts: Dict[str, Union[float, int]] = {}
+    n_nodes = 0
+    n_nodes_nocst = 0
     for proto in node_iter(onx):
         if isinstance(proto, onnx.NodeProto):
             key = f"n_node_{proto.op_type}"
+            n_nodes += 1
+            if proto.op_type != "Constant":
+                n_nodes_nocst += 1
+            if proto.domain in local_domains:
+                key = "n_node_local_function"
+                if key not in counts:
+                    counts[key] = 0
+                counts[key] += 1
         else:
             key = f"n_node_initializer_{proto.data_type}"

         if key not in counts:
             counts[key] = 0
         counts[key] += 1
+
+    counts["n_node_nodes"] = n_nodes
+    counts["n_node_nodes_nocst"] = n_nodes_nocst
+    counts["n_node_functions"] = len(onx.functions)
     return counts


```
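`compute_statistics` now descends into local functions: when a node's `(domain, op_type)` matches a function defined in the model, the iterator also yields that function's nodes, approximating statistics over the inlined graph. A minimal sketch of the same traversal on a toy model (assuming a recent `onnx` version where `make_model` accepts `functions=`):

```python
import onnx
import onnx.helper as oh

# Toy model with one local function; node_iter below mirrors the
# function-inlining walk added to compute_statistics.
fct = oh.make_function(
    "local", "Double", ["x"], ["y"],
    [oh.make_node("Add", ["x", "x"], ["y"])],
    opset_imports=[oh.make_opsetid("", 18)],
)
graph = oh.make_graph(
    [oh.make_node("Double", ["X"], ["Y"], domain="local")],
    "g",
    [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1])],
    [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1])],
)
model = oh.make_model(
    graph,
    functions=[fct],
    opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("local", 1)],
)

cache = {(f.domain, f.name): f for f in model.functions}

def node_iter(proto):
    if isinstance(proto, onnx.ModelProto):
        for f in proto.functions:
            yield from node_iter(f)
        yield from node_iter(proto.graph)
    else:
        for node in proto.node:
            yield node
            if (node.domain, node.op_type) in cache:
                # descend into the local function's body
                yield from node_iter(cache[node.domain, node.op_type])

print([n.op_type for n in node_iter(model)])  # ['Add', 'Double', 'Add']
```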
```diff
@@ -922,6 +994,26 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
     )


+_cache_export_times = []
+_main_export_function = torch.export.export
+
+
+def _torch_export_export(*args, _export=_main_export_function, **kwargs):
+    begin = time.perf_counter()
+    res = _export(*args, **kwargs)
+    duration = time.perf_counter() - begin
+    _cache_export_times.append(duration)
+    return res
+
+
+def _restore_torch_export_export(summary):
+    torch.export.export = _main_export_function
+    if _cache_export_times:
+        summary["time_torch_export_export"] = sum(_cache_export_times)
+        summary["time_torch_export_export_n"] = len(_cache_export_times)
+        _cache_export_times.clear()
+
+
 def call_exporter(
     data: Dict[str, Any],
     exporter: str,
```
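The release now measures the total time spent inside `torch.export.export` by temporarily monkey-patching it around each exporter call, then restoring it and folding the totals into the summary. The same pattern in isolation (a sketch, not the library code):

```python
import time
import torch

_times = []
_original = torch.export.export

def _timed_export(*args, _export=_original, **kwargs):
    # record wall-clock time of every torch.export.export call
    begin = time.perf_counter()
    try:
        return _export(*args, **kwargs)
    finally:
        _times.append(time.perf_counter() - begin)

torch.export.export = _timed_export
try:
    ...  # run an exporter that calls torch.export.export internally
finally:
    torch.export.export = _original
    print(f"{len(_times)} export call(s), {sum(_times):.3f}s total")
```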
```diff
@@ -947,6 +1039,9 @@
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
+    _cache_export_times.clear()
+    torch.export.export = _torch_export_export
+
     if exporter == "export" or exporter.startswith("export-"):
         # torch export
         summary, data = call_torch_export_export(
@@ -957,6 +1052,7 @@
             optimization=optimization,
             do_run=do_run,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter.startswith("onnx-"):
         # torch export
@@ -968,6 +1064,7 @@
             optimization=optimization,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter == "custom" or exporter.startswith("custom"):
         # torch export
@@ -980,6 +1077,7 @@
             dump_folder=dump_folder,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter == "modelbuilder":
         # torch export
@@ -991,6 +1089,7 @@
             optimization=optimization,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     raise NotImplementedError(
         f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
@@ -1134,6 +1233,7 @@ def validate_onnx_model(
     repeat: int = 1,
     warmup: int = 0,
     inputs2: int = 1,
+    ort_logs: bool = False,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Verifies that an onnx model produces the same
@@ -1146,12 +1246,13 @@
     :param quiet: catch exception or not
     :param verbose: verbosity
     :param flavour: use a different version of the inputs
-    :param runtime: onnx runtime to use, onnxruntime
+    :param runtime: onnx runtime to use, onnxruntime, torch, orteval, ref
     :param repeat: run that number of times the model
     :param warmup: warmup the model
     :param inputs2: to validate the model on the second input set
         to make sure the exported model supports dynamism, the value is
         used as an increment added to the first set of inputs (added to dimensions)
+    :param ort_logs: triggers the logs for onnxruntime
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
```
```diff
@@ -1193,23 +1294,71 @@
         f"{providers}..., flavour={flavour!r}"
     )

-    if runtime
+    if runtime == "onnxruntime":
+        if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
+            opts = onnxruntime.SessionOptions()
+            opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
+            if verbose:
+                print(
+                    f"[validate_onnx_model] saved optimized onnxruntime "
+                    f"in {opts.optimized_model_filepath!r}"
+                )
+            onnxruntime.InferenceSession(data["onnx_filename"], opts, providers=providers)
+            if verbose:
+                print("[validate_onnx_model] -- done")
+
+        if verbose:
+            print("[validate_onnx_model] runtime is onnxruntime")
+        sess_opts = onnxruntime.SessionOptions()
+        if ort_logs:
+            sess_opts.log_severity_level = 0
+            sess_opts.log_verbosity_level = 4
+        cls_runtime = lambda model, providers, _o=sess_opts: onnxruntime.InferenceSession(
+            (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
+            _o,
+            providers=providers,
+        )
+    elif runtime == "torch":
         from ..reference import TorchOnnxEvaluator

-
-
-
-
-            providers=providers,
+        if verbose:
+            print("[validate_onnx_model] runtime is TorchOnnxEvaluator")
+        cls_runtime = (
+            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
+                model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
-
-
-
+    elif runtime == "orteval":
+        from ..reference import OnnxruntimeEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator")
+        cls_runtime = (
+            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
                 model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
-
+    elif runtime == "orteval10":
+        from ..reference import OnnxruntimeEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)")
+        cls_runtime = (
+            lambda model, providers, _cls_=OnnxruntimeEvaluator: _cls_(  # type: ignore[misc]
+                model, providers=providers, verbose=10
+            )
+        )
+    elif runtime == "ref":
+        from ..reference import ExtendedReferenceEvaluator
+
+        if verbose:
+            print("[validate_onnx_model] runtime is ExtendedReferenceEvaluator")
+        cls_runtime = lambda model, providers, _cls_=ExtendedReferenceEvaluator: _cls_(  # type: ignore[misc]
+            model, verbose=max(verbose - 1, 0)
+        )
+    else:
+        raise ValueError(f"Unexpecteed runtime={runtime!r}")
+
     sess = _quiet_or_not_quiet(
         quiet,
         _mk("create_onnx_ort"),
```
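All supported runtimes now go through a single `cls_runtime` factory, so session creation is uniform regardless of backend, and setting the environment variable `DUMPORTOPT=1` additionally dumps the onnxruntime-optimized graph next to the model. In plain onnxruntime terms, what `ort_logs=True` and `DUMPORTOPT=1` translate to (file names are illustrative):

```python
import onnxruntime

opts = onnxruntime.SessionOptions()
opts.log_severity_level = 0          # ort_logs=True: verbose session logs
opts.log_verbosity_level = 4
opts.optimized_model_filepath = "model.rtopt.onnx"  # DUMPORTOPT=1: dump optimized graph

sess = onnxruntime.InferenceSession(
    "model.onnx", opts, providers=["CPUExecutionProvider"]
)
```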
```diff
@@ -1234,7 +1383,13 @@
         print(
             f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
         )
-    feeds = make_feeds(
+    feeds = make_feeds(
+        sess,
+        data[k_input],
+        use_numpy=True,
+        check_flatten=False,
+        is_modelbuilder=data["exporter"] == "modelbuilder",
+    )
     if verbose:
         print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
     summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
@@ -1254,6 +1409,13 @@
         repeat=repeat,
         warmup=warmup,
     )
+    # NOTE: modelbuilder has different order on past_kv outputs
+    if data["exporter"] == "modelbuilder":
+        logits = got[:1]
+        past_key_values = got[1:]
+        reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
+        got = logits + reorder_past_key_values
+
     if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
         return summary, data

```
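The real reordering lives in `onnx_diagnostic.helpers.rt_helper.reorder_modelbuilder_cache_to_torch`. Assuming ModelBuilder emits present key/value tensors interleaved per layer while the flattened torch cache groups all keys before all values (an assumption about the layout, not confirmed by this diff), the transformation amounts to:

```python
from typing import List, Sequence

def reorder_interleaved_to_grouped(past: Sequence) -> List:
    # assumption: ModelBuilder order is [k0, v0, k1, v1, ...] while the
    # torch-flattened cache expects [k0, k1, ..., v0, v1, ...]
    return list(past[0::2]) + list(past[1::2])

print(reorder_interleaved_to_grouped(["k0", "v0", "k1", "v1"]))
# ['k0', 'k1', 'v0', 'v1']
```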
```diff
@@ -1294,7 +1456,7 @@ def call_torch_export_onnx(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
-    available = {None, "", "ir", "os_ort"}
+    available = {None, "", "ir", "os_ort", "ir+default"}
     assert (
         optimization in available
     ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1384,12 +1546,34 @@
         print(epo)
         print("[call_torch_export_onnx] -- End of ONNXProgram")

-    if optimization in {"ir", "os_ort"}:
+    if optimization in {"ir", "os_ort", "ir+default"}:
         if verbose:
             print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
         if optimization == "ir":
             label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
+        elif optimization == "ir+default":
+            import onnxscript
+            from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+            def _ir_default_opt(epo):
+                onnxscript.optimizer.optimize_ir(epo.model)
+                onx = epo.model_proto
+                # not very efficient
+                gr = GraphBuilder(
+                    onx,
+                    infer_shapes_options=True,
+                    optimization_options=OptimizationOptions(patterns="default"),
+                )
+                cont = gr.to_onnx(large_model=True)
+                epo.model = cont.to_ir()
+
+            label, f_optim = "export_onnx_opt_ir_default", (
+                lambda epo=epo: _ir_default_opt(epo)
+            )
+
         else:
+            import onnxscript
+            import onnxscript.rewriter.ort_fusions as ort_fusions

             def _os_ort_optim(epo):
                 onnxscript.optimizer.optimize_ir(epo.model)
```
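The new `ir+default` value chains the onnxscript IR optimizer with the `default` pattern set of `experimental_experiment`'s `GraphBuilder`. Requesting it is just a matter of passing the new string to `validate_model` (a sketch; the model id and `onnx-dynamo` exporter value are illustrative assumptions):

```python
from onnx_diagnostic.torch_models.validate import validate_model

summary, data = validate_model(
    "arnir0/Tiny-LLM",          # illustrative model id
    do_run=True,
    exporter="onnx-dynamo",     # any onnx-* exporter reaches call_torch_export_onnx
    optimization="ir+default",  # new in 0.7.12
)
```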
```diff
@@ -1477,6 +1661,97 @@ def call_torch_export_model_builder(
     return summary, data


+def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
+    """
+    Processes statistics coming from the exporters.
+    It takes a sequence of dictionaries (like a data frame)
+    and extracts some metrics.
+    """
+
+    def _simplify(p):
+        for s in [
+            "remove_unused",
+            "constant_folding",
+            "remove_identity",
+            "remove_duplicated_initializer",
+            "dynamic_dimension_naming",
+            "inline",
+            "check",
+            "build_graph_for_pattern",
+            "pattern_optimization",
+        ]:
+            if s in p or s.replace("_", "-") in p:
+                return s
+        if p.startswith(("apply_", "match_")):
+            return p
+        return "other"
+
+    def _add(d, a, v, use_max=False):
+        if v:
+            if a not in d:
+                d[a] = v
+            elif use_max:
+                d[a] = max(d[a], v)
+            else:
+                d[a] += v
+
+    counts: Dict[str, Any] = {}
+    applied_pattern_time: Dict[str, Any] = {}
+    applied_pattern_n: Dict[str, Any] = {}
+    matching_pattern_time: Dict[str, Any] = {}
+    matching_pattern_n: Dict[str, Any] = {}
+
+    for obs in data:
+        pattern = _simplify(obs["pattern"])
+        _add(counts, "opt_nodes_added", obs.get("added", 0))
+        _add(counts, "opt_nodes_removed", obs.get("removed", 0))
+        _add(counts, "opt_time_steps", obs.get("time_in", 0))
+        _add(counts, "opt_n_steps", 1)
+        _add(
+            counts,
+            "opt_n_iteration",
+            max(counts.get("opt_n_iteration", 0), obs.get("iteration", 0)),
+            use_max=True,
+        )
+
+        if pattern.startswith("apply_"):
+            _add(counts, "opt_n_applied_patterns", 1)
+            _add(counts, "opt_time_applied_patterns", obs.get("time_in", 0))
+            _add(applied_pattern_time, pattern, obs.get("time_in", 0))
+            _add(applied_pattern_n, pattern, 1)
+        elif pattern.startswith("match_"):
+            _add(counts, "opt_n_matching_patterns", 1)
+            _add(counts, "opt_time_matching_patterns", obs.get("time_in", 0))
+            _add(matching_pattern_time, pattern, obs.get("time_in", 0))
+            _add(matching_pattern_n, pattern, 1)
+        else:
+            _add(counts, f"opt_time_{pattern}", obs.get("time_in", 0))
+            _add(counts, f"opt_n_{pattern}", 1)
+            _add(counts, f"opt_nodes_added_{pattern}", obs.get("added", 0))
+            _add(counts, f"opt_nodes_removed_{pattern}", obs.get("removed", 0))
+
+    if applied_pattern_time:
+        longest = max((v, k) for k, v in applied_pattern_time.items())
+        counts["opt_top_time_applied_pattern"], counts["opt_top_time_applied_pattern_arg"] = (
+            longest
+        )
+        longest = max((v, k) for k, v in applied_pattern_n.items())
+        counts["opt_top_n_applied_pattern"], counts["opt_top_n_applied_pattern_arg"] = longest
+
+    if matching_pattern_time:
+        longest = max((v, k) for k, v in matching_pattern_time.items())
+        (
+            counts["opt_top_time_matching_pattern"],
+            counts["opt_top_time_matching_pattern_arg"],
+        ) = longest
+        longest = max((v, k) for k, v in matching_pattern_n.items())
+        counts["opt_top_n_matching_pattern"], counts["opt_top_n_matching_pattern_arg"] = (
+            longest
+        )
+    counts["onnx_opt_optimized"] = 1
+    return counts
+
+
 def call_torch_export_custom(
     data: Dict[str, Any],
     exporter: str,
```
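`process_statistics` aggregates the per-step optimizer records (one dict per applied or matched pattern) into summary counters. A quick call on synthetic records, with keys shaped as in the diff above:

```python
from onnx_diagnostic.torch_models.validate import process_statistics

# Synthetic records shaped like the ones the optimizers emit.
stats = [
    {"pattern": "apply_TransposeTranspose", "added": 0, "removed": 2,
     "time_in": 0.004, "iteration": 1},
    {"pattern": "match_TransposeTranspose", "time_in": 0.001, "iteration": 1},
    {"pattern": "remove_unused", "removed": 5, "time_in": 0.002, "iteration": 2},
]
counts = process_statistics(stats)
print(counts["opt_nodes_removed"])              # 7
print(counts["opt_top_n_applied_pattern_arg"])  # 'apply_TransposeTranspose'
```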
```diff
@@ -1509,6 +1784,8 @@ def call_torch_export_custom(
         "default+onnxruntime+os_ort",
         None,
     }
+    if optimization == "none":
+        optimization = ""
     assert (
         optimization in available
     ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1604,67 +1881,10 @@ def call_torch_export_custom(
     if "ERR_export_onnx_c" in summary:
         return summary, data

-    new_stat = {}
+    new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
+    new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
     if "optimization" in opt_stats:
-
-        max_iter = 0
-        applied = {}
-        matched = set()
-        n_applied = 0
-        by_pattern = {}
-        by_pattern_n = {}
-        by_iter = {}
-        cst_added, cst_removed, cst_time_in = 0, 0, 0.0
-
-        for obs in opt_stats["optimization"]:
-            pattern = obs["pattern"]
-            if pattern == "constant_folding":
-                cst_added += obs.get("added", 0)
-                cst_removed += obs.get("removed", 0)
-                cst_time_in += obs.get("time_in", 0)
-            if pattern not in by_pattern:
-                by_pattern[pattern] = 0
-                by_pattern_n[pattern] = 0
-                by_iter[pattern] = 0
-            time_in += obs.get("time_in", 0)
-            added += obs.get("added", 0)
-            removed += obs.get("removed", 0)
-            max_iter = max(max_iter, obs.get("iteration", 0))
-            by_pattern[pattern] += obs.get("time_in", 0)
-            by_pattern_n[pattern] += obs.get("added", 0) - obs.get("removed", 0)
-            if not pattern.startswith("match"):
-                by_iter[pattern] = max(by_iter[pattern], obs.get("iteration", 0))
-            p = obs["pattern"]
-            if p.startswith("match_"):
-                matched.add(p)
-            elif p.startswith("apply_"):
-                key = f"op_opt_{p}"
-                key2 = f"op_opt_maxiter_{p}"
-                if key not in applied:
-                    applied[key] = 1
-                    applied[key2] = obs["iteration"]
-                else:
-                    applied[key] += 1
-                    applied[key2] = max(obs["iteration"], applied[key2])
-                n_applied += 1
-
-        new_stat.update(
-            dict(
-                onnx_opt_optimized=1,
-                op_opt_all_time_in=time_in,
-                op_opt_all_added=added,
-                op_opt_all_removed=removed,
-                op_opt_max_iter=max_iter,
-                op_opt_unique_matched=len(matched),
-                op_opt_unique_applied=len(applied),
-                op_opt_n_applied=n_applied,
-                time_export_optimization=time_in,
-                op_opt_export_optimization=time_in,
-                op_opt_cst_time_in=cst_time_in,
-                op_opt_cst_added=cst_added,
-                op_opt_cst_removed=cst_removed,
-            )
-        )
+        new_stat.update(process_statistics(opt_stats["optimization"]))

     summary.update(new_stat)
     assert epo is not None, "no onnx export was found"
```
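The rewritten aggregation first copies every `time_*` entry, re-exposes `stat_time_*` entries without their prefix (`k[5:]` strips `stat_`), then delegates the pattern bookkeeping to the new `process_statistics`. The prefix handling in isolation:

```python
opt_stats = {"time_export": 1.5, "stat_time_rewrite": 0.2, "optimization": []}
new_stat = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
print(new_stat)  # {'time_export': 1.5, 'time_rewrite': 0.2}
```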
```diff
@@ -1672,6 +1892,9 @@ def call_torch_export_custom(
         print("[call_torch_export_custom] done (export)")

     if os_ort:
+        import onnxscript
+        import onnxscript.rewriter.ort_fusions as ort_fusions
+
         if verbose:
             print("[call_torch_export_custom] conversion to IR...")
         begin = time.perf_counter()
@@ -1780,3 +2003,21 @@ def run_ort_fusion(
         f"opt_ort_{model_type}_duration": duration,
         f"opt_ort_{model_type}_duration_save": d,
     }, {f"opt_ort_{model_type}": output_path}
+
+
+def _compute_final_statistics(summary: Dict[str, Any]):
+    """
+    Updates inline the list of statistics. It adds:
+
+    - speedup
+    """
+    stats = {}
+    if (
+        "time_run_latency" in summary
+        and "time_run_onnx_ort_latency" in summary
+        and summary["time_run_onnx_ort_latency"] > 0
+    ):
+        stats["stat_estimated_speedup_ort"] = (
+            summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
+        )
+    summary.update(stats)
```
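The new final statistic is simply eager latency divided by onnxruntime latency. With illustrative numbers:

```python
summary = {"time_run_latency": 0.050, "time_run_onnx_ort_latency": 0.020}
if summary.get("time_run_onnx_ort_latency", 0) > 0:
    summary["stat_estimated_speedup_ort"] = (
        summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
    )
print(summary["stat_estimated_speedup_ort"])  # 2.5
```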