onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/torch_models/code_sample.py
+++ b/onnx_diagnostic/torch_models/code_sample.py
@@ -236,7 +236,7 @@ def code_sample(
         )
         )
     """
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
         model_kwargs=mop,
         subfolder=subfolder,
         add_second_input=False,
+        submodule=submodule,
     )
     if drop_inputs:
         update = {}
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:


 def _preprocess_model_id(
-    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
-) -> Tuple[str, Optional[str], bool, bool]:
+    model_id: str,
+    subfolder: Optional[str],
+    same_as_pretrained: bool,
+    use_pretrained: bool,
+    submodule: Optional[str] = None,
+) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
+    if "::" in model_id:
+        assert (
+            not submodule
+        ), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
+        model_id, submodule = model_id.split("::", maxsplit=1)
     if subfolder or "//" not in model_id:
-        return model_id, subfolder, same_as_pretrained, use_pretrained
+        return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
     spl = model_id.split("//")
     if spl[-1] == "pretrained":
-        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
-    return model_id, subfolder, same_as_pretrained, use_pretrained
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained, submodule
+    return model_id, subfolder, same_as_pretrained, use_pretrained, submodule


 def get_untrained_model_with_inputs(
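The new `::` separator lets a model id carry the submodule directly, which is why all downstream functions now receive and return the extra field. A minimal sketch of the parsing behavior implied by this hunk; the model id is illustrative only:

    from onnx_diagnostic.torch_models.hghub.model_inputs import _preprocess_model_id

    # "model_id::submodule" splits once on "::"; the tail becomes the
    # submodule, the head remains the model id.
    res = _preprocess_model_id("microsoft/phi-2::model", None, False, False)
    assert res == ("microsoft/phi-2", None, False, False, "model")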
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
+    submodule: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@
         <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
         this function takes a configuration and a task (string)
         as arguments
+    :param submodule: use a submodule instead of the main model
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -108,11 +119,12 @@
             f"model_id={model_id!r}, preinstalled model is only available "
             f"if use_only_preinstalled is False."
         )
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     if verbose:
         print(
@@ -147,6 +159,8 @@
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+        if submodule:
+            print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
     if task is None:
         task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
     if verbose:
@@ -357,6 +371,19 @@
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)

+    if submodule:
+        path = submodule.split("::") if "::" in submodule else [submodule]
+        for p in path:
+            assert hasattr(model, p), (
+                f"Unable to find submodule {p!r} in in class {type(model)}, "
+                f"submodule={submodule!r}, possible candidates: "
+                f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
+            )
+            model = getattr(model, p)
+
+        if verbose:
+            print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
+
     sizes = compute_model_size(model)
     res["model"] = model
     res["configuration"] = config
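When the submodule itself contains `::`, each segment is resolved with `getattr` in turn, so nested modules can be reached. A self-contained sketch of the same walk, with hypothetical module classes standing in for a real transformers model:

    import torch

    class Decoder(torch.nn.Module):  # hypothetical stand-ins
        def __init__(self):
            super().__init__()
            self.mlp = torch.nn.Linear(4, 4)

    class Wrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.decoder = Decoder()

    model = Wrapper()
    for p in "decoder::mlp".split("::"):  # same traversal as the hunk above
        assert hasattr(model, p), f"no submodule {p!r}"
        model = getattr(model, p)
    print(type(model).__name__)  # Linear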
--- a/onnx_diagnostic/torch_models/validate.py
+++ b/onnx_diagnostic/torch_models/validate.py
@@ -349,13 +349,15 @@ def _prepare_validation(
     verbose,
     output_names,
     dump_folder,
+    submodule,
 ):
     main_validation_begin = time.perf_counter()
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     time_preprocess_model_id = time.perf_counter() - main_validation_begin
     patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
@@ -364,6 +366,7 @@ def _prepare_validation(
     summary.update(
         dict(
             version_model_id=model_id,
+            version_submodule=submodule,
             version_do_run=str(do_run),
             version_dtype=str(dtype or ""),
             version_device=str(device or ""),
@@ -444,6 +447,7 @@ def _prepare_validation(
         dump_folder,
         folder_name,
         patch_kwargs,
+        submodule,
     )


@@ -460,6 +464,7 @@ def _get_untrained_model_with_inputs(
     inputs2,
     quiet,
     dump_folder,
+    submodule,
 ):
     iop = input_options or {}
     mop = model_options or {}
@@ -480,6 +485,7 @@
                 model_kwargs=mop,
                 subfolder=sub,
                 add_second_input=i2,
+                submodule=submodule,
             )
         )
     ),
@@ -671,7 +677,16 @@ def _call_exporter(
     do_run,
     output_names,
     exporter_options,
+    save_ep,
 ):
+    if save_ep and dump_folder:
+        for name in data:
+            if name.startswith("inputs"):
+                if verbose:
+                    print(f"[validate_model] -- dump {name!r}")
+                filename = os.path.join(dump_folder, f"{save_ep}.{name}.pt")
+                torch.save(data[name], filename)
+
     if exporter:
         expop = exporter_options or {}
         if verbose:
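Each input set is therefore written as `{save_ep}.{name}.pt` under the dump folder, and the counterpart is a plain `torch.load`. A sketch, assuming `save_ep="ep"`, a dump folder named `dump_test`, and the default input set:

    import torch

    # weights_only=False: the saved object is a dict of tensors and
    # cache objects, not a bare state dict
    inputs = torch.load("dump_test/ep.inputs.pt", weights_only=False)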
@@ -711,6 +726,7 @@ def _call_exporter(
             dump_folder=dump_folder,
             output_names=output_names,
             exporter_options=expop,
+            save_ep=save_ep,
         )
     else:
         data["inputs_export"] = data["inputs"]
@@ -831,6 +847,8 @@ def validate_model(
     output_names: Optional[List[str]] = None,
     ort_logs: bool = False,
     quiet_input_sets: Optional[Set[str]] = None,
+    save_ep: Optional[str] = None,
+    submodule: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -889,6 +907,9 @@
     :param ort_logs: increases onnxruntime verbosity when creating the session
     :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
         even if quiet is False
+    :param save_ep: if not empty, this can be used to save the input sets and
+        the exported program
+    :param submodule: to test not the model but a submodule of this model
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

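Taken together, the two new keyword arguments plug into the existing entry point. A hedged usage sketch; the model id, folder, and prefix are illustrative, and the positional `model_id` argument is assumed from the surrounding calls:

    from onnx_diagnostic.torch_models.validate import validate_model

    summary, data = validate_model(
        "microsoft/phi-2::model",  # "::" selects a submodule of the model
        do_run=True,
        dump_folder="dump_test",
        save_ep="ep",              # prefix for the saved inputs and exported program
    )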
@@ -952,6 +973,8 @@
         subfolder=subfolder,
         use_pretrained=use_pretrained,
         same_as_pretrained=same_as_pretrained,
+        save_ep=save_ep,
+        submodule=submodule,
     )
     if dump_folder:
         with open(dump_stats, "w") as f:
@@ -1038,6 +1061,8 @@ def _validate_model_step1(
     subfolder,
     use_pretrained,
     same_as_pretrained,
+    save_ep,
+    submodule,
 ):
     assert not do_same or do_run, (
         f"Discrepancies cannot be measured if the model is not run, "
@@ -1052,6 +1077,7 @@
         dump_folder,
         folder_name,
         patch_kwargs,
+        submodule,
     ) = _prepare_validation(
         model_id=model_id,
         subfolder=subfolder,
@@ -1078,6 +1104,7 @@
         verbose=verbose,
         output_names=output_names,
         dump_folder=dump_folder,
+        submodule=submodule,
     )

     data, iop, mop = _get_untrained_model_with_inputs(
@@ -1093,6 +1120,7 @@
         inputs2=inputs2,
         quiet=quiet,
         dump_folder=dump_folder,
+        submodule=submodule,
     )

     second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
@@ -1153,6 +1181,7 @@
         do_run=do_run,
         output_names=output_names,
         exporter_options=exporter_options,
+        save_ep=save_ep,
     )

     cont, dump_stats = _dump_onnx_model(
@@ -1426,6 +1455,7 @@ def call_exporter(
     dump_folder: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Calls an exporter on a model;
@@ -1440,6 +1470,7 @@ def call_exporter(
     :param dump_folder: to dump additional information
     :param output_names: list of output names to use with the onnx exporter
     :param exporter_options: exporter options
+    :param save_ep: saves the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1456,6 +1487,8 @@
             optimization=optimization,
             do_run=do_run,
             exporter_options=exporter_options,
+            save_ep=save_ep,
+            dump_folder=dump_folder,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1469,6 +1502,8 @@ def call_exporter(
             optimization=optimization,
             output_names=output_names,
             exporter_options=exporter_options,
+            dump_folder=dump_folder,
+            save_ep=save_ep,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1483,6 +1518,7 @@ def call_exporter(
             dump_folder=dump_folder,
             output_names=output_names,
             exporter_options=exporter_options,
+            save_ep=save_ep,
         )
         _restore_torch_export_export(summary)
         return summary, data
@@ -1516,6 +1552,8 @@ def call_torch_export_export(
     optimization: Optional[str] = None,
     do_run: bool = False,
     exporter_options: Optional[Dict[str, Any]] = None,
+    dump_folder: Optional[str] = None,
+    save_ep: Optional[str] = None,
 ):
     """
     Exports a model with :func:`torch.export.export`.
@@ -1529,6 +1567,8 @@
     :param optimization: optimization to do
     :param do_run: runs and compute discrepancies
     :param exporter_options: additional options given to the exporter
+    :param dump_folder: folder where to dump the exported program
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1604,6 +1644,12 @@
         print(ep)
         print("[call_torch_export_export] -- End of ExportedProgram")

+    if dump_folder and save_ep:
+        fname = f"{save_ep}.pt2"
+        if verbose:
+            print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
+        torch.export.save(ep, os.path.join(dump_folder, fname))
+
     if do_run:
         # We check for discrepancies.
         if verbose:
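The saved `.pt2` file round-trips through `torch.export`, so reloading it is a one-liner. A sketch, assuming `dump_folder="dump_test"` and `save_ep="ep"`:

    import torch

    ep = torch.export.load("dump_test/ep.pt2")  # ExportedProgram
    module = ep.module()                        # runnable GraphModule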
@@ -1880,6 +1926,8 @@ def call_torch_export_onnx(
     optimization: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    dump_folder: Optional[str] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
@@ -1893,6 +1941,8 @@
     :param optimization: optimization to do
     :param output_names: output names to use
     :param exporter_options: additional options to give the exporter
+    :param dump_folder: to know where to dump the exported program
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1986,6 +2036,12 @@
             return summary, data

     assert epo is not None, "no onnx export was found"
+    if dump_folder and save_ep:
+        fname = f"{save_ep}.pt2"
+        if verbose:
+            print(f"[call_torch_export_export] -- save the exported program in {fname!r}")
+        torch.export.save(epo.exported_program, os.path.join(dump_folder, fname))
+
     if verbose:
         print("[call_torch_export_onnx] done (export)")
     data["onnx_program"] = epo
@@ -2219,6 +2275,7 @@ def call_torch_export_custom(
     dump_folder: Optional[str] = None,
     output_names: Optional[List[str]] = None,
     exporter_options: Optional[Dict[str, Any]] = None,
+    save_ep: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
@@ -2233,6 +2290,7 @@
     :param dump_folder: to store additional information
     :param output_names: list of output names to use
     :param exporter_options: additional exporter options
+    :param save_ep: to save the exported program
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -2345,7 +2403,11 @@
         export_options = ExportOptions(
             strict=strict,
             decomposition_table=decomposition_table,
-            save_ep=(
+            save_ep=(
+                (os.path.join(dump_folder, f"{exporter}.ep"), 2**35 if save_ep else 2**18)
+                if dump_folder
+                else None
+            ),
             **exporter_options,
         )
         options = OptimizationOptions(patterns=optimization) if optimization else None
--- a/onnx_diagnostic/torch_onnx/runtime_info.py
+++ b/onnx_diagnostic/torch_onnx/runtime_info.py
@@ -4,6 +4,7 @@ import onnx
 import torch
 from ..api import TensorLike
 from ..helpers import string_type
+from ..helpers.onnx_helper import get_hidden_inputs


 class RuntimeValueKind(enum.IntEnum):
@@ -151,30 +152,6 @@ class RuntimeValue:
         return self.kind == RuntimeValueKind.INITIALIZER


-def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
-    """
-    Returns the hidden inputs (inputs coming from an upper context)
-    used by a subgraph.
-    """
-    hidden = set()
-    memo = (
-        set(i.name for i in graph.initializer)
-        | set(i.name for i in graph.sparse_initializer)
-        | set(i.name for i in graph.input)
-    )
-    for node in graph.node:
-        for i in node.input:
-            if i not in memo:
-                hidden.add(i)
-        for att in node.attribute:
-            if att.type == onnx.AttributeProto.GRAPH and att.g:
-                hid = get_hidden_inputs(att.g)
-                less = set(h for h in hid if h not in memo)
-                hidden |= less
-        memo |= set(node.output)
-    return hidden
-
-
 def set_is_shape(
     node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
 ) -> List[str]: