onnx-diagnostic 0.7.3-py3-none-any.whl → 0.7.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +82 -12
  3. onnx_diagnostic/export/shape_helper.py +71 -0
  4. onnx_diagnostic/helpers/_log_helper.py +461 -0
  5. onnx_diagnostic/helpers/cache_helper.py +11 -1
  6. onnx_diagnostic/helpers/log_helper.py +404 -315
  7. onnx_diagnostic/reference/ops/op_cast_like.py +12 -8
  8. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  9. onnx_diagnostic/tasks/feature_extraction.py +92 -7
  10. onnx_diagnostic/tasks/fill_mask.py +6 -2
  11. onnx_diagnostic/tasks/image_classification.py +7 -3
  12. onnx_diagnostic/tasks/image_text_to_text.py +6 -2
  13. onnx_diagnostic/tasks/mixture_of_expert.py +1 -1
  14. onnx_diagnostic/tasks/object_detection.py +7 -3
  15. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  16. onnx_diagnostic/tasks/summarization.py +6 -2
  17. onnx_diagnostic/tasks/text2text_generation.py +8 -4
  18. onnx_diagnostic/tasks/text_classification.py +6 -2
  19. onnx_diagnostic/tasks/text_generation.py +5 -3
  20. onnx_diagnostic/tasks/text_to_image.py +6 -2
  21. onnx_diagnostic/tasks/zero_shot_image_classification.py +6 -2
  22. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +63 -7
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +188 -51
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +6 -1
  25. onnx_diagnostic/torch_models/validate.py +49 -10
  26. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/METADATA +1 -1
  27. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/RECORD +30 -29
  28. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -7,59 +7,107 @@ import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
+
+try:
+    import transformers.masking_utils
+
+    patch_masking_utils = True
+except ImportError:
+    patch_masking_utils = False
+
 from ...ext_test_case import has_transformers
 from ...helpers.torch_helper import is_torchdynamo_exporting


-def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
-    from ...helpers import string_type
-
-    dimensions: List[Tuple[Optional[int], ...]] = [
-        (None, None, None, 0),
-        (None, None, 0, None),
-    ]
-    if bh_indices:
-        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
-    # reshape
-    dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
-    dimensions = tuple(reversed(dimensions))
-    indices = tuple(shape.index(-1) for shape in dimensions)
-
-    # unsqueeze
-    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
-
-    def vector_mask_function(
-        *args, mask_function=mask_function, dimensions=dimensions, indices=indices
-    ):
-        assert len(args) == len(dimensions) == len(udimensions), (
-            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
-            f"and udimensions={udimensions}."
-        )
-        assert len(indices) == len(args), (
-            f"Mismatch between args={string_type(args)} and indices={indices}, "
-            f"they should have the same length."
+if patch_masking_utils:
+    # Introduced in 4.52
+    from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+    def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+        """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+        from ...helpers import string_type
+
+        dimensions: List[Tuple[Optional[int], ...]] = [
+            (None, None, None, 0),
+            (None, None, 0, None),
+        ]
+        if bh_indices:
+            dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+        # reshape
+        dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
+        dimensions = tuple(reversed(dimensions))
+        indices = tuple(shape.index(-1) for shape in dimensions)
+
+        # unsqueeze
+        udimensions = [
+            tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
+        ]
+
+        def vector_mask_function(
+            *args, mask_function=mask_function, dimensions=dimensions, indices=indices
+        ):
+            assert len(args) == len(dimensions) == len(udimensions), (
+                f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+                f"and udimensions={udimensions}."
+            )
+            assert len(indices) == len(args), (
+                f"Mismatch between args={string_type(args)} and indices={indices}, "
+                f"they should have the same length."
+            )
+            for a in args:
+                assert (
+                    a.ndim == 1
+                ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+                torch._check(a.shape[0] > 0)
+
+            new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+            # new_args = [
+            #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+            #     for a, dims in zip(args, udimensions)
+            # ]
+            max_shape = tuple(args[i].shape[0] for i in indices)
+            # if is_torchdynamo_exporting():
+            #     for a in args:
+            #         # The exporter should export with a dimension > 1
+            #         # to make sure it is dynamic.
+            #         torch._check(a.shape[0] > 1)
+            expanded_args = [a.expand(max_shape) for a in new_args]
+            return mask_function(*expanded_args)
+
+        return vector_mask_function
+
+    def patched_eager_mask(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        dtype: torch.dtype = torch.float32,
+        **kwargs,
+    ) -> torch.Tensor:
+        """manual patch for function ``transformers.masking_utils.eager_mask``."""
+        # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
+        _ = kwargs.pop("allow_is_causal_skip", None)
+        mask = sdpa_mask(
+            batch_size=batch_size,
+            cache_position=cache_position,
+            kv_length=kv_length,
+            kv_offset=kv_offset,
+            mask_function=mask_function,
+            attention_mask=attention_mask,
+            allow_is_causal_skip=False,
+            allow_torch_fix=False,
+            **kwargs,
         )
-        for a in args:
-            assert (
-                a.ndim == 1
-            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
-            torch._check(a.shape[0] > 0)
-
-        new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
-        # new_args = [
-        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
-        #     for a, dims in zip(args, udimensions)
-        # ]
-        max_shape = tuple(args[i].shape[0] for i in indices)
-        # if is_torchdynamo_exporting():
-        #     for a in args:
-        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
-        #         torch._check(a.shape[0] > 1)
-        expanded_args = [a.expand(max_shape) for a in new_args]
-        return mask_function(*expanded_args)
-
-    return vector_mask_function
+        min_dtype = torch.finfo(dtype).min
+        # The patched line.
+        # we need 0s where the tokens should be taken into account,
+        # and -inf otherwise (mask is already of boolean type)
+        # mask =
+        #   torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+        mask = (~mask).to(dtype) * min_dtype
+        return mask


 def _patch_make_causal_mask(
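The patch above replaces torch.vmap, which torch.export cannot trace, with plain reshape/expand broadcasting. A minimal sketch of the same trick, outside the diff (the causal function below is a stand-in for transformers.masking_utils.causal_mask_function; the shapes are illustrative):

    import torch

    def causal(batch_idx, head_idx, q_idx, kv_idx):
        # stand-in for transformers.masking_utils.causal_mask_function
        return kv_idx <= q_idx

    b, h, q, kv = 2, 4, 3, 5
    args = [torch.arange(b), torch.arange(h), torch.arange(q), torch.arange(kv)]
    # one non-trivial axis per index tensor, as computed by patched__vmap_for_bhqkv
    shapes = [(-1, 1, 1, 1), (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)]
    expanded = [a.reshape(s).expand(b, h, q, kv) for a, s in zip(args, shapes)]
    mask = causal(*expanded)  # boolean tensor of shape (b, h, q, kv)

    # patched_eager_mask then builds the additive mask without torch.where:
    # 0.0 where attended, finfo(dtype).min elsewhere.
    additive = (~mask).to(torch.float32) * torch.finfo(torch.float32).min
    assert additive.shape == (2, 4, 3, 5)

Calling the scalar mask function on the expanded index grids produces the same result as vmapping it over each axis, but traces as ordinary tensor operations.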
@@ -207,7 +255,8 @@ class patched_DynamicCache:
         """
         # Update the number of seen tokens
         if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
+            if hasattr(self, "_seen_tokens"):
+                self._seen_tokens += key_states.shape[-2]

         # Update the cache
         if key_states is not None:
@@ -246,7 +295,8 @@ class patched_DynamicCache:
         if self.get_seq_length() <= max_length:
             return

-        self._seen_tokens = max_length
+        if hasattr(self, "_seen_tokens"):
+            self._seen_tokens = max_length
         for idx in range(len(self.key_cache)):
             if self.key_cache[idx].numel():
                 self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
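Both hasattr guards address the same compatibility issue: recent transformers releases removed the private _seen_tokens counter from DynamicCache, so the patched methods only touch it when the installed version still defines it. A hedged sketch of the pattern:

    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    if hasattr(cache, "_seen_tokens"):  # older transformers still track this
        cache._seen_tokens += 4
    # newer versions derive the length from the cached tensors instead,
    # e.g. via cache.get_seq_length()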
@@ -814,6 +864,91 @@ def patched_dynamic_rope_update(rope_forward):
     return wrapper


+def common_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        # The two following lines were added.
+        if attention_mask is not None and attention_mask.ndim == 4:
+            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+    attn_weights = torch.nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+def patched_model_bart_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
+def patched_modeling_marian_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
 class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
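common_eager_attention_forward factors the bart and marian eager-attention patches into one helper; the two added lines slice a 4-D attention mask down to the key length before it is added to the scores. A usage sketch with assumed shapes (the import path follows file 23 in the list above):

    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
        common_eager_attention_forward,
    )

    module = torch.nn.Module()       # only .training is read by the helper
    q = torch.randn(1, 8, 6, 64)     # (batch, heads, q_len, head_dim)
    k = torch.randn(1, 8, 10, 64)    # kv_len = 10
    v = torch.randn(1, 8, 10, 64)
    mask = torch.zeros(1, 1, 6, 12)  # deliberately longer than kv_len

    out, weights = common_eager_attention_forward(module, q, k, v, mask)
    print(out.shape)      # torch.Size([1, 6, 8, 64]) after the final transpose
    print(weights.shape)  # torch.Size([1, 8, 6, 10])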
@@ -1045,4 +1180,6 @@ class patched_IdeficsAttention(torch.nn.Module):
         if output_attentions:
             attn_weights = None

-        return attn_output, attn_weights, past_key_value
+        if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
+            return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
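The version gate mirrors the transformers API change around 4.54, where attention modules stopped returning past_key_value. Callers that must handle both shapes can normalize the tuple, as in this hedged sketch (unpack_attention_result is not part of the package):

    def unpack_attention_result(result):
        # (attn_output, attn_weights[, past_key_value]) -> always a 3-tuple
        attn_output, attn_weights = result[0], result[1]
        past_key_value = result[2] if len(result) > 2 else None
        return attn_output, attn_weights, past_key_value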
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -26,7 +26,7 @@ def get_untrained_model_with_inputs(
     use_pretrained: bool = False,
     same_as_pretrained: bool = False,
     use_preinstalled: bool = True,
-    add_second_input: bool = False,
+    add_second_input: int = 1,
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
 ) -> Dict[str, Any]:
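add_second_input changes from a boolean to an integer. Judging by the matching inputs2 docstring in validate.py further down, a non-zero value both requests the second input set and serves as the increment added to the dynamic dimensions when building it. Hypothetical usage with a small model id:

    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

    # add_second_input=2 would presumably grow the dynamic dimensions of the
    # second input set by 2 instead of 1 (assumption based on the docstring).
    data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=1)
    model, inputs = data["model"], data["inputs"]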
@@ -144,6 +144,11 @@ def get_untrained_model_with_inputs(
             f"[get_untrained_model_with_inputs] config._attn_implementation="
             f"{config._attn_implementation!r}"  # type: ignore[union-attr]
         )
+    elif verbose:
+        print(
+            f"[get_untrained_model_with_inputs] default config._attn_implementation="
+            f"{getattr(config, '_attn_implementation', '?')!r}"  # type: ignore[union-attr]
+        )

     if type(config) is dict and "_diffusers_version" in config:
         import diffusers
onnx_diagnostic/torch_models/validate.py

@@ -18,7 +18,6 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
 from ..tasks import random_input_kwargs
 from ..torch_export_patches import torch_export_patches
 from ..torch_export_patches.patch_inputs import use_dyn_not_str
-from ..reference import TorchOnnxEvaluator
 from .hghub import get_untrained_model_with_inputs


@@ -157,6 +156,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]:
         "version_torch": torch.__version__,
         "version_numpy": numpy.__version__,
     }
+    try:
+        import scipy
+
+        summary["version_scipy"] = getattr(scipy, "__version__", "?")
+    except ImportError:
+        pass
     try:
         import transformers

@@ -181,6 +186,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]:
         summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?")
     except ImportError:
         pass
+    try:
+        import onnx_ir
+
+        summary["version_onnx_ir"] = getattr(onnx_ir, "__version__", "?")
+    except ImportError:
+        pass
     import onnx_diagnostic

     summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__
@@ -276,7 +287,8 @@ def validate_model(
     runtime: str = "onnxruntime",
     repeat: int = 1,
     warmup: int = 0,
-    inputs2: bool = True,
+    inputs2: int = 1,
+    output_names: Optional[List[str]] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -325,7 +337,9 @@ def validate_model(
     :param repeat: number of times to measure the model
     :param warmup: warmup the model first
     :param inputs2: checks that the second set of inputs is running as well,
-        this ensures that the model does support dynamism
+        this ensures that the model does support dynamism, the value is used
+        as an increment to the first set of values (added to dimensions)
+    :param output_names: output names the onnx exporter should use
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

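With inputs2 now an integer, the value is the increment added to the dynamic dimensions of the first input set. The sketch below illustrates the intended effect only; the actual second set is produced by the task helpers changed in this release, which are not shown in this diff:

    import torch

    first = {"input_ids": torch.randint(0, 100, (2, 8))}
    increment = 1  # the value passed as inputs2
    second = {
        k: torch.randint(0, 100, tuple(d + increment for d in v.shape))
        for k, v in first.items()
    }
    print({k: tuple(v.shape) for k, v in second.items()})  # {'input_ids': (3, 9)}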
@@ -421,6 +435,7 @@ def validate_model(
         )
         print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
         print(f"[validate_model] dump_folder={dump_folder!r}")
+        print(f"[validate_model] output_names={output_names}")
     summary["model_id"] = model_id
     summary["model_subfolder"] = subfolder or ""

@@ -619,6 +634,7 @@ def validate_model(
             optimization=optimization,
             do_run=do_run,
             dump_folder=dump_folder,
+            output_names=output_names,
         )
     else:
         data["inputs_export"] = data["inputs"]
@@ -631,6 +647,7 @@ def validate_model(
             optimization=optimization,
             do_run=do_run,
             dump_folder=dump_folder,
+            output_names=output_names,
         )
     summary.update(summary_export)

@@ -856,6 +873,7 @@ def call_exporter(
     optimization: Optional[str] = None,
     do_run: bool = False,
     dump_folder: Optional[str] = None,
+    output_names: Optional[List[str]] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Calls an exporter on a model;
@@ -868,6 +886,7 @@ def call_exporter(
     :param optimization: optimization to do
     :param do_run: runs and computes discrepancies
     :param dump_folder: to dump additional information
+    :param output_names: list of output names to use with the onnx exporter
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -890,6 +909,7 @@ def call_exporter(
             quiet=quiet,
             verbose=verbose,
             optimization=optimization,
+            output_names=output_names,
         )
         return summary, data
     if exporter == "custom" or exporter.startswith("custom"):
@@ -901,6 +921,7 @@ def call_exporter(
             verbose=verbose,
             optimization=optimization,
             dump_folder=dump_folder,
+            output_names=output_names,
        )
         return summary, data
     if exporter == "modelbuilder":
@@ -911,6 +932,7 @@ def call_exporter(
             quiet=quiet,
             verbose=verbose,
             optimization=optimization,
+            output_names=output_names,
         )
         return summary, data
     raise NotImplementedError(
@@ -1054,7 +1076,7 @@ def validate_onnx_model(
     runtime: str = "onnxruntime",
     repeat: int = 1,
     warmup: int = 0,
-    inputs2: bool = True,
+    inputs2: int = 1,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Verifies that an onnx model produces the same
@@ -1070,14 +1092,15 @@ def validate_onnx_model(
     :param runtime: onnx runtime to use, onnxruntime or torch
     :param repeat: run the model that number of times
     :param warmup: warmup the model
-    :param inputs: to validate the model on the second input set
-        to make sure the exported model supports dynamism
+    :param inputs2: to validate the model on the second input set
+        to make sure the exported model supports dynamism, the value is
+        used as an increment added to the first set of inputs (added to dimensions)
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
     import onnxruntime

-    def _mk(key):
+    def _mk(key, flavour=flavour):
         return f"{key}_{flavour}" if flavour else key

     summary: Dict[str, Any] = {}
@@ -1113,6 +1136,9 @@ def validate_onnx_model(
         f"{providers}..., flavour={flavour!r}"
         )

+    if runtime != "onnxruntime":
+        from ..reference import TorchOnnxEvaluator
+
     cls_runtime = (
         (
             lambda model, providers: onnxruntime.InferenceSession(
@@ -1122,14 +1148,14 @@ def validate_onnx_model(
         )
         if runtime == "onnxruntime"
         else (
-            lambda model, providers: TorchOnnxEvaluator(
+            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
                 model, providers=providers, verbose=max(verbose - 1, 0)
             )
         )
     )
     sess = _quiet_or_not_quiet(
         quiet,
-        _mk("onnx_ort_create"),
+        _mk("create_onnx_ort"),
         summary,
         data,
         (lambda source=source, providers=providers: cls_runtime(source, providers)),
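The two -/+ pairs above use the same Python idiom: a default argument is evaluated once, when the function object is created. _mk(key, flavour=flavour) freezes the current flavour into the helper, and _cls_=TorchOnnxEvaluator binds the evaluator class at lambda-creation time, which matters now that the import only happens when runtime != "onnxruntime". A minimal demonstration of the binding behavior:

    flavour = "a"

    def _mk(key, flavour=flavour):  # "a" is captured here, at definition time
        return f"{key}_{flavour}" if flavour else key

    flavour = "b"  # rebinding the outer name no longer affects _mk
    assert _mk("create_onnx_ort") == "create_onnx_ort_a"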
@@ -1164,7 +1190,7 @@ def validate_onnx_model(

     got = _quiet_or_not_quiet(
         quiet,
-        _mk(f"time_onnx_ort_run{suffix}"),
+        _mk(f"run_onnx_ort{suffix}"),
         summary,
         data,
         (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
@@ -1195,6 +1221,7 @@ def call_torch_export_onnx(
     quiet: bool = False,
     verbose: int = 0,
     optimization: Optional[str] = None,
+    output_names: Optional[List[str]] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
@@ -1206,6 +1233,7 @@ def call_torch_export_onnx(
     :param quiet: catch exception or not
     :param verbose: verbosity
     :param optimization: optimization to do
+    :param output_names: output names to use
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1260,6 +1288,8 @@ def call_torch_export_onnx(
         print("[call_torch_export_onnx] dynamo=False so...")
         print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
         print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
+    if output_names:
+        export_export_kwargs["output_names"] = output_names
     if opset:
         export_export_kwargs["opset_version"] = opset
     if verbose:
@@ -1330,6 +1360,7 @@ def call_torch_export_model_builder(
     quiet: bool = False,
     verbose: int = 0,
     optimization: Optional[str] = None,
+    output_names: Optional[List[str]] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1340,6 +1371,7 @@ def call_torch_export_model_builder(
     :param quiet: catch exception or not
     :param verbose: verbosity
     :param optimization: optimization to do
+    :param output_names: list of output names to use
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -1353,6 +1385,9 @@ def call_torch_export_model_builder(
     provider = data.get("model_device", "cpu")
     dump_folder = data.get("model_dump_folder", "")
     assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
+    assert (
+        not output_names
+    ), f"output_names not empty, not supported yet, output_names={output_names}"
     cache_dir = os.path.join(dump_folder, "cache_mb")
     if not os.path.exists(cache_dir):
         os.makedirs(cache_dir)
@@ -1392,6 +1427,7 @@ def call_torch_export_custom(
     verbose: int = 0,
     optimization: Optional[str] = None,
     dump_folder: Optional[str] = None,
+    output_names: Optional[List[str]] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
1404
1440
  :param verbose: verbosity
1405
1441
  :param optimization: optimization to do
1406
1442
  :param dump_folder: to store additional information
1443
+ :param output_names: list of output names to use
1407
1444
  :return: two dictionaries, one with some metrics,
1408
1445
  another one with whatever the function produces
1409
1446
  """
@@ -1488,6 +1525,8 @@ def call_torch_export_custom(
         )
     if opset:
         kws["target_opset"] = opset
+    if output_names:
+        kws["output_names"] = output_names

     epo, opt_stats = _quiet_or_not_quiet(
         quiet,
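Taken together, the validate.py hunks thread a single new output_names option from validate_model through call_exporter into the torch.onnx and custom exporters (ModelBuilder still rejects it, per the assert above). A hypothetical end-to-end call:

    from onnx_diagnostic.torch_models.validate import validate_model

    summary, data = validate_model(
        "arnir0/Tiny-LLM",        # hypothetical model id
        exporter="onnx-dynamo",
        output_names=["logits"],  # forwarded down to the exporter
        dump_folder="dump_test",
        verbose=1,
    )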
{onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-diagnostic
-Version: 0.7.3
+Version: 0.7.5
 Summary: Investigate ONNX models
 Home-page: https://github.com/sdpython/onnx-diagnostic
 Author: Xavier Dupré