onnx-diagnostic 0.7.3__py3-none-any.whl → 0.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +82 -12
- onnx_diagnostic/export/shape_helper.py +71 -0
- onnx_diagnostic/helpers/_log_helper.py +461 -0
- onnx_diagnostic/helpers/cache_helper.py +11 -1
- onnx_diagnostic/helpers/log_helper.py +404 -315
- onnx_diagnostic/reference/ops/op_cast_like.py +12 -8
- onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
- onnx_diagnostic/tasks/feature_extraction.py +92 -7
- onnx_diagnostic/tasks/fill_mask.py +6 -2
- onnx_diagnostic/tasks/image_classification.py +7 -3
- onnx_diagnostic/tasks/image_text_to_text.py +6 -2
- onnx_diagnostic/tasks/mixture_of_expert.py +1 -1
- onnx_diagnostic/tasks/object_detection.py +7 -3
- onnx_diagnostic/tasks/sentence_similarity.py +6 -2
- onnx_diagnostic/tasks/summarization.py +6 -2
- onnx_diagnostic/tasks/text2text_generation.py +8 -4
- onnx_diagnostic/tasks/text_classification.py +6 -2
- onnx_diagnostic/tasks/text_generation.py +5 -3
- onnx_diagnostic/tasks/text_to_image.py +6 -2
- onnx_diagnostic/tasks/zero_shot_image_classification.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +63 -7
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +188 -51
- onnx_diagnostic/torch_models/hghub/model_inputs.py +6 -1
- onnx_diagnostic/torch_models/validate.py +49 -10
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/RECORD +30 -29
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/top_level.txt +0 -0
|
@@ -7,59 +7,107 @@ import torch
|
|
|
7
7
|
import transformers
|
|
8
8
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
9
9
|
from transformers.cache_utils import StaticCache, Cache, DynamicCache
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
import transformers.masking_utils
|
|
13
|
+
|
|
14
|
+
patch_masking_utils = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
patch_masking_utils = False
|
|
17
|
+
|
|
10
18
|
from ...ext_test_case import has_transformers
|
|
11
19
|
from ...helpers.torch_helper import is_torchdynamo_exporting
|
|
12
20
|
|
|
13
21
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
from
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
22
|
+
if patch_masking_utils:
|
|
23
|
+
# Introduced in 4.52
|
|
24
|
+
from transformers.masking_utils import causal_mask_function, sdpa_mask
|
|
25
|
+
|
|
26
|
+
def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
|
|
27
|
+
"""manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
|
|
28
|
+
from ...helpers import string_type
|
|
29
|
+
|
|
30
|
+
dimensions: List[Tuple[Optional[int], ...]] = [
|
|
31
|
+
(None, None, None, 0),
|
|
32
|
+
(None, None, 0, None),
|
|
33
|
+
]
|
|
34
|
+
if bh_indices:
|
|
35
|
+
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
|
|
36
|
+
# reshape
|
|
37
|
+
dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
|
|
38
|
+
dimensions = tuple(reversed(dimensions))
|
|
39
|
+
indices = tuple(shape.index(-1) for shape in dimensions)
|
|
40
|
+
|
|
41
|
+
# unsqueeze
|
|
42
|
+
udimensions = [
|
|
43
|
+
tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
def vector_mask_function(
|
|
47
|
+
*args, mask_function=mask_function, dimensions=dimensions, indices=indices
|
|
48
|
+
):
|
|
49
|
+
assert len(args) == len(dimensions) == len(udimensions), (
|
|
50
|
+
f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
|
|
51
|
+
f"and udimensions={udimensions}."
|
|
52
|
+
)
|
|
53
|
+
assert len(indices) == len(args), (
|
|
54
|
+
f"Mismatch between args={string_type(args)} and indices={indices}, "
|
|
55
|
+
f"they should have the same length."
|
|
56
|
+
)
|
|
57
|
+
for a in args:
|
|
58
|
+
assert (
|
|
59
|
+
a.ndim == 1
|
|
60
|
+
), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
|
|
61
|
+
torch._check(a.shape[0] > 0)
|
|
62
|
+
|
|
63
|
+
new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
|
|
64
|
+
# new_args = [
|
|
65
|
+
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
|
|
66
|
+
# for a, dims in zip(args, udimensions)
|
|
67
|
+
# ]
|
|
68
|
+
max_shape = tuple(args[i].shape[0] for i in indices)
|
|
69
|
+
# if is_torchdynamo_exporting():
|
|
70
|
+
# for a in args:
|
|
71
|
+
# # The exporter should export with a dimension > 1
|
|
72
|
+
# # to make sure it is dynamic.
|
|
73
|
+
# torch._check(a.shape[0] > 1)
|
|
74
|
+
expanded_args = [a.expand(max_shape) for a in new_args]
|
|
75
|
+
return mask_function(*expanded_args)
|
|
76
|
+
|
|
77
|
+
return vector_mask_function
|
|
78
|
+
|
|
79
|
+
def patched_eager_mask(
|
|
80
|
+
batch_size: int,
|
|
81
|
+
cache_position: torch.Tensor,
|
|
82
|
+
kv_length: int,
|
|
83
|
+
kv_offset: int = 0,
|
|
84
|
+
mask_function: Callable = causal_mask_function,
|
|
85
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
86
|
+
dtype: torch.dtype = torch.float32,
|
|
87
|
+
**kwargs,
|
|
88
|
+
) -> torch.Tensor:
|
|
89
|
+
"""manual patch for function ``transformers.masking_utils.eager_mask``."""
|
|
90
|
+
# The masks for eager attention are simply the boolean masks from sdpa, cast to 0 and -inf
|
|
91
|
+
_ = kwargs.pop("allow_is_causal_skip", None)
|
|
92
|
+
mask = sdpa_mask(
|
|
93
|
+
batch_size=batch_size,
|
|
94
|
+
cache_position=cache_position,
|
|
95
|
+
kv_length=kv_length,
|
|
96
|
+
kv_offset=kv_offset,
|
|
97
|
+
mask_function=mask_function,
|
|
98
|
+
attention_mask=attention_mask,
|
|
99
|
+
allow_is_causal_skip=False,
|
|
100
|
+
allow_torch_fix=False,
|
|
101
|
+
**kwargs,
|
|
42
102
|
)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
|
|
52
|
-
# for a, dims in zip(args, udimensions)
|
|
53
|
-
# ]
|
|
54
|
-
max_shape = tuple(args[i].shape[0] for i in indices)
|
|
55
|
-
# if is_torchdynamo_exporting():
|
|
56
|
-
# for a in args:
|
|
57
|
-
# # The exporter should export with a dimension > 1 to make sure it is dynamic.
|
|
58
|
-
# torch._check(a.shape[0] > 1)
|
|
59
|
-
expanded_args = [a.expand(max_shape) for a in new_args]
|
|
60
|
-
return mask_function(*expanded_args)
|
|
61
|
-
|
|
62
|
-
return vector_mask_function
|
|
103
|
+
min_dtype = torch.finfo(dtype).min
|
|
104
|
+
# The patched line.
|
|
105
|
+
# we need 0s where the tokens should be taken into account,
|
|
106
|
+
# and -inf otherwise (mask is already of boolean type)
|
|
107
|
+
# mask =
|
|
108
|
+
# torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
|
|
109
|
+
mask = (~mask).to(dtype) * min_dtype
|
|
110
|
+
return mask
|
|
63
111
|
|
|
64
112
|
|
|
65
113
|
def _patch_make_causal_mask(
|
|
@@ -207,7 +255,8 @@ class patched_DynamicCache:
|
|
|
207
255
|
"""
|
|
208
256
|
# Update the number of seen tokens
|
|
209
257
|
if layer_idx == 0:
|
|
210
|
-
self
|
|
258
|
+
if hasattr(self, "_seen_tokens"):
|
|
259
|
+
self._seen_tokens += key_states.shape[-2]
|
|
211
260
|
|
|
212
261
|
# Update the cache
|
|
213
262
|
if key_states is not None:
|
|
@@ -246,7 +295,8 @@ class patched_DynamicCache:
|
|
|
246
295
|
if self.get_seq_length() <= max_length:
|
|
247
296
|
return
|
|
248
297
|
|
|
249
|
-
self
|
|
298
|
+
if hasattr(self, "_seen_tokens"):
|
|
299
|
+
self._seen_tokens = max_length
|
|
250
300
|
for idx in range(len(self.key_cache)):
|
|
251
301
|
if self.key_cache[idx].numel():
|
|
252
302
|
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
|
|
@@ -814,6 +864,91 @@ def patched_dynamic_rope_update(rope_forward):
|
|
|
814
864
|
return wrapper
|
|
815
865
|
|
|
816
866
|
|
|
867
|
+
def common_eager_attention_forward(
|
|
868
|
+
module: torch.nn.Module,
|
|
869
|
+
query: torch.Tensor,
|
|
870
|
+
key: torch.Tensor,
|
|
871
|
+
value: torch.Tensor,
|
|
872
|
+
attention_mask: Optional[torch.Tensor],
|
|
873
|
+
scaling: Optional[float] = None,
|
|
874
|
+
dropout: float = 0.0,
|
|
875
|
+
head_mask: Optional[torch.Tensor] = None,
|
|
876
|
+
**kwargs,
|
|
877
|
+
):
|
|
878
|
+
if scaling is None:
|
|
879
|
+
scaling = query.size(-1) ** -0.5
|
|
880
|
+
|
|
881
|
+
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
|
882
|
+
if attention_mask is not None:
|
|
883
|
+
# The two following lines were added.
|
|
884
|
+
if attention_mask is not None and attention_mask.ndim == 4:
|
|
885
|
+
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
|
886
|
+
attn_weights = attn_weights + attention_mask
|
|
887
|
+
|
|
888
|
+
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
|
889
|
+
|
|
890
|
+
if head_mask is not None:
|
|
891
|
+
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
|
892
|
+
|
|
893
|
+
attn_weights = torch.nn.functional.dropout(
|
|
894
|
+
attn_weights, p=dropout, training=module.training
|
|
895
|
+
)
|
|
896
|
+
attn_output = torch.matmul(attn_weights, value)
|
|
897
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
898
|
+
|
|
899
|
+
return attn_output, attn_weights
|
|
900
|
+
|
|
901
|
+
|
|
902
|
+
def patched_model_bart_eager_attention_forward(
|
|
903
|
+
module: torch.nn.Module,
|
|
904
|
+
query: torch.Tensor,
|
|
905
|
+
key: torch.Tensor,
|
|
906
|
+
value: torch.Tensor,
|
|
907
|
+
attention_mask: Optional[torch.Tensor],
|
|
908
|
+
scaling: Optional[float] = None,
|
|
909
|
+
dropout: float = 0.0,
|
|
910
|
+
head_mask: Optional[torch.Tensor] = None,
|
|
911
|
+
**kwargs,
|
|
912
|
+
):
|
|
913
|
+
"""[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
|
|
914
|
+
return common_eager_attention_forward(
|
|
915
|
+
module,
|
|
916
|
+
query,
|
|
917
|
+
key,
|
|
918
|
+
value,
|
|
919
|
+
attention_mask=attention_mask,
|
|
920
|
+
scaling=scaling,
|
|
921
|
+
dropout=dropout,
|
|
922
|
+
head_mask=head_mask,
|
|
923
|
+
**kwargs,
|
|
924
|
+
)
|
|
925
|
+
|
|
926
|
+
|
|
927
|
+
def patched_modeling_marian_eager_attention_forward(
|
|
928
|
+
module: torch.nn.Module,
|
|
929
|
+
query: torch.Tensor,
|
|
930
|
+
key: torch.Tensor,
|
|
931
|
+
value: torch.Tensor,
|
|
932
|
+
attention_mask: Optional[torch.Tensor],
|
|
933
|
+
scaling: Optional[float] = None,
|
|
934
|
+
dropout: float = 0.0,
|
|
935
|
+
head_mask: Optional[torch.Tensor] = None,
|
|
936
|
+
**kwargs,
|
|
937
|
+
):
|
|
938
|
+
"""[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
|
|
939
|
+
return common_eager_attention_forward(
|
|
940
|
+
module,
|
|
941
|
+
query,
|
|
942
|
+
key,
|
|
943
|
+
value,
|
|
944
|
+
attention_mask=attention_mask,
|
|
945
|
+
scaling=scaling,
|
|
946
|
+
dropout=dropout,
|
|
947
|
+
head_mask=head_mask,
|
|
948
|
+
**kwargs,
|
|
949
|
+
)
|
|
950
|
+
|
|
951
|
+
|
|
817
952
|
class common_RotaryEmbedding(torch.nn.Module):
|
|
818
953
|
@torch.no_grad()
|
|
819
954
|
@patched_dynamic_rope_update
|
|
@@ -1045,4 +1180,6 @@ class patched_IdeficsAttention(torch.nn.Module):
|
|
|
1045
1180
|
if output_attentions:
|
|
1046
1181
|
attn_weights = None
|
|
1047
1182
|
|
|
1048
|
-
|
|
1183
|
+
if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
|
|
1184
|
+
return attn_output, attn_weights, past_key_value
|
|
1185
|
+
return attn_output, attn_weights
|
|
@@ -26,7 +26,7 @@ def get_untrained_model_with_inputs(
|
|
|
26
26
|
use_pretrained: bool = False,
|
|
27
27
|
same_as_pretrained: bool = False,
|
|
28
28
|
use_preinstalled: bool = True,
|
|
29
|
-
add_second_input:
|
|
29
|
+
add_second_input: int = 1,
|
|
30
30
|
subfolder: Optional[str] = None,
|
|
31
31
|
use_only_preinstalled: bool = False,
|
|
32
32
|
) -> Dict[str, Any]:
|
|
@@ -144,6 +144,11 @@ def get_untrained_model_with_inputs(
|
|
|
144
144
|
f"[get_untrained_model_with_inputs] config._attn_implementation="
|
|
145
145
|
f"{config._attn_implementation!r}" # type: ignore[union-attr]
|
|
146
146
|
)
|
|
147
|
+
elif verbose:
|
|
148
|
+
print(
|
|
149
|
+
f"[get_untrained_model_with_inputs] default config._attn_implementation="
|
|
150
|
+
f"{getattr(config, '_attn_implementation', '?')!r}" # type: ignore[union-attr]
|
|
151
|
+
)
|
|
147
152
|
|
|
148
153
|
if type(config) is dict and "_diffusers_version" in config:
|
|
149
154
|
import diffusers
|
|
@@ -18,7 +18,6 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
|
18
18
|
from ..tasks import random_input_kwargs
|
|
19
19
|
from ..torch_export_patches import torch_export_patches
|
|
20
20
|
from ..torch_export_patches.patch_inputs import use_dyn_not_str
|
|
21
|
-
from ..reference import TorchOnnxEvaluator
|
|
22
21
|
from .hghub import get_untrained_model_with_inputs
|
|
23
22
|
|
|
24
23
|
|
|
@@ -157,6 +156,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]:
|
|
|
157
156
|
"version_torch": torch.__version__,
|
|
158
157
|
"version_numpy": numpy.__version__,
|
|
159
158
|
}
|
|
159
|
+
try:
|
|
160
|
+
import scipy
|
|
161
|
+
|
|
162
|
+
summary["version_scipy"] = getattr(scipy, "__version__", "?")
|
|
163
|
+
except ImportError:
|
|
164
|
+
pass
|
|
160
165
|
try:
|
|
161
166
|
import transformers
|
|
162
167
|
|
|
@@ -181,6 +186,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]:
|
|
|
181
186
|
summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?")
|
|
182
187
|
except ImportError:
|
|
183
188
|
pass
|
|
189
|
+
try:
|
|
190
|
+
import onnx_ir
|
|
191
|
+
|
|
192
|
+
summary["version_onnx_ir"] = getattr(onnx_ir, "__version__", "?")
|
|
193
|
+
except ImportError:
|
|
194
|
+
pass
|
|
184
195
|
import onnx_diagnostic
|
|
185
196
|
|
|
186
197
|
summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__
|
|
@@ -276,7 +287,8 @@ def validate_model(
|
|
|
276
287
|
runtime: str = "onnxruntime",
|
|
277
288
|
repeat: int = 1,
|
|
278
289
|
warmup: int = 0,
|
|
279
|
-
inputs2:
|
|
290
|
+
inputs2: int = 1,
|
|
291
|
+
output_names: Optional[List[str]] = None,
|
|
280
292
|
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
|
|
281
293
|
"""
|
|
282
294
|
Validates a model.
|
|
@@ -325,7 +337,9 @@ def validate_model(
|
|
|
325
337
|
:param repeat: number of times to measure the model
|
|
326
338
|
:param warmup: warmup the model first
|
|
327
339
|
:param inputs2: checks that the second set of inputs is running as well,
|
|
328
|
-
this ensures that the model does support dynamism
|
|
340
|
+
this ensures that the model does support dynamism, the value is used
|
|
341
|
+
as an increment to the first set of values (added to dimensions)
|
|
342
|
+
:param output_names: output names the onnx exporter should use
|
|
329
343
|
:return: two dictionaries, one with some metrics,
|
|
330
344
|
another one with whatever the function produces
|
|
331
345
|
|
|
@@ -421,6 +435,7 @@ def validate_model(
|
|
|
421
435
|
)
|
|
422
436
|
print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
|
|
423
437
|
print(f"[validate_model] dump_folder={dump_folder!r}")
|
|
438
|
+
print(f"[validate_model] output_names={output_names}")
|
|
424
439
|
summary["model_id"] = model_id
|
|
425
440
|
summary["model_subfolder"] = subfolder or ""
|
|
426
441
|
|
|
@@ -619,6 +634,7 @@ def validate_model(
|
|
|
619
634
|
optimization=optimization,
|
|
620
635
|
do_run=do_run,
|
|
621
636
|
dump_folder=dump_folder,
|
|
637
|
+
output_names=output_names,
|
|
622
638
|
)
|
|
623
639
|
else:
|
|
624
640
|
data["inputs_export"] = data["inputs"]
|
|
@@ -631,6 +647,7 @@ def validate_model(
|
|
|
631
647
|
optimization=optimization,
|
|
632
648
|
do_run=do_run,
|
|
633
649
|
dump_folder=dump_folder,
|
|
650
|
+
output_names=output_names,
|
|
634
651
|
)
|
|
635
652
|
summary.update(summary_export)
|
|
636
653
|
|
|
@@ -856,6 +873,7 @@ def call_exporter(
|
|
|
856
873
|
optimization: Optional[str] = None,
|
|
857
874
|
do_run: bool = False,
|
|
858
875
|
dump_folder: Optional[str] = None,
|
|
876
|
+
output_names: Optional[List[str]] = None,
|
|
859
877
|
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
|
|
860
878
|
"""
|
|
861
879
|
Calls an exporter on a model;
|
|
@@ -868,6 +886,7 @@ def call_exporter(
|
|
|
868
886
|
:param optimization: optimization to do
|
|
869
887
|
:param do_run: runs and compute discrepancies
|
|
870
888
|
:param dump_folder: to dump additional information
|
|
889
|
+
:param output_names: list of output names to use with the onnx exporter
|
|
871
890
|
:return: two dictionaries, one with some metrics,
|
|
872
891
|
another one with whatever the function produces
|
|
873
892
|
"""
|
|
@@ -890,6 +909,7 @@ def call_exporter(
|
|
|
890
909
|
quiet=quiet,
|
|
891
910
|
verbose=verbose,
|
|
892
911
|
optimization=optimization,
|
|
912
|
+
output_names=output_names,
|
|
893
913
|
)
|
|
894
914
|
return summary, data
|
|
895
915
|
if exporter == "custom" or exporter.startswith("custom"):
|
|
@@ -901,6 +921,7 @@ def call_exporter(
|
|
|
901
921
|
verbose=verbose,
|
|
902
922
|
optimization=optimization,
|
|
903
923
|
dump_folder=dump_folder,
|
|
924
|
+
output_names=output_names,
|
|
904
925
|
)
|
|
905
926
|
return summary, data
|
|
906
927
|
if exporter == "modelbuilder":
|
|
@@ -911,6 +932,7 @@ def call_exporter(
|
|
|
911
932
|
quiet=quiet,
|
|
912
933
|
verbose=verbose,
|
|
913
934
|
optimization=optimization,
|
|
935
|
+
output_names=output_names,
|
|
914
936
|
)
|
|
915
937
|
return summary, data
|
|
916
938
|
raise NotImplementedError(
|
|
@@ -1054,7 +1076,7 @@ def validate_onnx_model(
|
|
|
1054
1076
|
runtime: str = "onnxruntime",
|
|
1055
1077
|
repeat: int = 1,
|
|
1056
1078
|
warmup: int = 0,
|
|
1057
|
-
inputs2:
|
|
1079
|
+
inputs2: int = 1,
|
|
1058
1080
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
1059
1081
|
"""
|
|
1060
1082
|
Verifies that an onnx model produces the same
|
|
@@ -1070,14 +1092,15 @@ def validate_onnx_model(
|
|
|
1070
1092
|
:param runtime: onnx runtime to use, onnxruntime or torch
|
|
1071
1093
|
:param repeat: number of times to run the model
|
|
1072
1094
|
:param warmup: warmup the model
|
|
1073
|
-
:param
|
|
1074
|
-
to make sure the exported model supports dynamism
|
|
1095
|
+
:param inputs2: to validate the model on the second input set
|
|
1096
|
+
to make sure the exported model supports dynamism, the value is
|
|
1097
|
+
used as an increment added to the first set of inputs (added to dimensions)
|
|
1075
1098
|
:return: two dictionaries, one with some metrics,
|
|
1076
1099
|
another one with whatever the function produces
|
|
1077
1100
|
"""
|
|
1078
1101
|
import onnxruntime
|
|
1079
1102
|
|
|
1080
|
-
def _mk(key):
|
|
1103
|
+
def _mk(key, flavour=flavour):
|
|
1081
1104
|
return f"{key}_{flavour}" if flavour else key
|
|
1082
1105
|
|
|
1083
1106
|
summary: Dict[str, Any] = {}
|
|
@@ -1113,6 +1136,9 @@ def validate_onnx_model(
|
|
|
1113
1136
|
f"{providers}..., flavour={flavour!r}"
|
|
1114
1137
|
)
|
|
1115
1138
|
|
|
1139
|
+
if runtime != "onnxruntime":
|
|
1140
|
+
from ..reference import TorchOnnxEvaluator
|
|
1141
|
+
|
|
1116
1142
|
cls_runtime = (
|
|
1117
1143
|
(
|
|
1118
1144
|
lambda model, providers: onnxruntime.InferenceSession(
|
|
@@ -1122,14 +1148,14 @@ def validate_onnx_model(
|
|
|
1122
1148
|
)
|
|
1123
1149
|
if runtime == "onnxruntime"
|
|
1124
1150
|
else (
|
|
1125
|
-
lambda model, providers:
|
|
1151
|
+
lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_( # type: ignore[misc]
|
|
1126
1152
|
model, providers=providers, verbose=max(verbose - 1, 0)
|
|
1127
1153
|
)
|
|
1128
1154
|
)
|
|
1129
1155
|
)
|
|
1130
1156
|
sess = _quiet_or_not_quiet(
|
|
1131
1157
|
quiet,
|
|
1132
|
-
_mk("
|
|
1158
|
+
_mk("create_onnx_ort"),
|
|
1133
1159
|
summary,
|
|
1134
1160
|
data,
|
|
1135
1161
|
(lambda source=source, providers=providers: cls_runtime(source, providers)),
|
|
@@ -1164,7 +1190,7 @@ def validate_onnx_model(
|
|
|
1164
1190
|
|
|
1165
1191
|
got = _quiet_or_not_quiet(
|
|
1166
1192
|
quiet,
|
|
1167
|
-
_mk(f"
|
|
1193
|
+
_mk(f"run_onnx_ort{suffix}"),
|
|
1168
1194
|
summary,
|
|
1169
1195
|
data,
|
|
1170
1196
|
(lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
|
|
@@ -1195,6 +1221,7 @@ def call_torch_export_onnx(
|
|
|
1195
1221
|
quiet: bool = False,
|
|
1196
1222
|
verbose: int = 0,
|
|
1197
1223
|
optimization: Optional[str] = None,
|
|
1224
|
+
output_names: Optional[List[str]] = None,
|
|
1198
1225
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
1199
1226
|
"""
|
|
1200
1227
|
Exports a model into onnx.
|
|
@@ -1206,6 +1233,7 @@ def call_torch_export_onnx(
|
|
|
1206
1233
|
:param quiet: catch exception or not
|
|
1207
1234
|
:param verbose: verbosity
|
|
1208
1235
|
:param optimization: optimization to do
|
|
1236
|
+
:param output_names: output names to use
|
|
1209
1237
|
:return: two dictionaries, one with some metrics,
|
|
1210
1238
|
another one with whatever the function produces
|
|
1211
1239
|
"""
|
|
@@ -1260,6 +1288,8 @@ def call_torch_export_onnx(
|
|
|
1260
1288
|
print("[call_torch_export_onnx] dynamo=False so...")
|
|
1261
1289
|
print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
|
|
1262
1290
|
print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
|
|
1291
|
+
if output_names:
|
|
1292
|
+
export_export_kwargs["output_names"] = output_names
|
|
1263
1293
|
if opset:
|
|
1264
1294
|
export_export_kwargs["opset_version"] = opset
|
|
1265
1295
|
if verbose:
|
|
@@ -1330,6 +1360,7 @@ def call_torch_export_model_builder(
|
|
|
1330
1360
|
quiet: bool = False,
|
|
1331
1361
|
verbose: int = 0,
|
|
1332
1362
|
optimization: Optional[str] = None,
|
|
1363
|
+
output_names: Optional[List[str]] = None,
|
|
1333
1364
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
1334
1365
|
"""
|
|
1335
1366
|
Exports a model into onnx with :epkg:`ModelBuilder`.
|
|
@@ -1340,6 +1371,7 @@ def call_torch_export_model_builder(
|
|
|
1340
1371
|
:param quiet: catch exception or not
|
|
1341
1372
|
:param verbose: verbosity
|
|
1342
1373
|
:param optimization: optimization to do
|
|
1374
|
+
:param output_names: list of output names to use
|
|
1343
1375
|
:return: two dictionaries, one with some metrics,
|
|
1344
1376
|
another one with whatever the function produces
|
|
1345
1377
|
"""
|
|
@@ -1353,6 +1385,9 @@ def call_torch_export_model_builder(
|
|
|
1353
1385
|
provider = data.get("model_device", "cpu")
|
|
1354
1386
|
dump_folder = data.get("model_dump_folder", "")
|
|
1355
1387
|
assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
|
|
1388
|
+
assert (
|
|
1389
|
+
not output_names
|
|
1390
|
+
), f"output_names not empty, not supported yet, output_names={output_names}"
|
|
1356
1391
|
cache_dir = os.path.join(dump_folder, "cache_mb")
|
|
1357
1392
|
if not os.path.exists(cache_dir):
|
|
1358
1393
|
os.makedirs(cache_dir)
|
|
@@ -1392,6 +1427,7 @@ def call_torch_export_custom(
|
|
|
1392
1427
|
verbose: int = 0,
|
|
1393
1428
|
optimization: Optional[str] = None,
|
|
1394
1429
|
dump_folder: Optional[str] = None,
|
|
1430
|
+
output_names: Optional[List[str]] = None,
|
|
1395
1431
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
1396
1432
|
"""
|
|
1397
1433
|
Exports a model into onnx.
|
|
@@ -1404,6 +1440,7 @@ def call_torch_export_custom(
|
|
|
1404
1440
|
:param verbose: verbosity
|
|
1405
1441
|
:param optimization: optimization to do
|
|
1406
1442
|
:param dump_folder: to store additional information
|
|
1443
|
+
:param output_names: list of output names to use
|
|
1407
1444
|
:return: two dictionaries, one with some metrics,
|
|
1408
1445
|
another one with whatever the function produces
|
|
1409
1446
|
"""
|
|
@@ -1488,6 +1525,8 @@ def call_torch_export_custom(
|
|
|
1488
1525
|
)
|
|
1489
1526
|
if opset:
|
|
1490
1527
|
kws["target_opset"] = opset
|
|
1528
|
+
if output_names:
|
|
1529
|
+
kws["output_names"] = output_names
|
|
1491
1530
|
|
|
1492
1531
|
epo, opt_stats = _quiet_or_not_quiet(
|
|
1493
1532
|
quiet,
|