onnx_diagnostic-0.4.0-py3-none-any.whl → onnx_diagnostic-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +43 -1
  3. onnx_diagnostic/export/dynamic_shapes.py +7 -3
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +11 -1
  6. onnx_diagnostic/helpers/config_helper.py +7 -2
  7. onnx_diagnostic/helpers/helper.py +31 -0
  8. onnx_diagnostic/helpers/torch_test_helper.py +6 -0
  9. onnx_diagnostic/tasks/__init__.py +6 -2
  10. onnx_diagnostic/tasks/automatic_speech_recognition.py +22 -4
  11. onnx_diagnostic/tasks/feature_extraction.py +76 -0
  12. onnx_diagnostic/tasks/fill_mask.py +14 -3
  13. onnx_diagnostic/tasks/image_classification.py +16 -3
  14. onnx_diagnostic/tasks/image_text_to_text.py +24 -4
  15. onnx_diagnostic/tasks/mixture_of_expert.py +76 -0
  16. onnx_diagnostic/tasks/sentence_similarity.py +14 -3
  17. onnx_diagnostic/tasks/text2text_generation.py +19 -3
  18. onnx_diagnostic/tasks/text_classification.py +14 -3
  19. onnx_diagnostic/tasks/text_generation.py +69 -48
  20. onnx_diagnostic/tasks/zero_shot_image_classification.py +18 -3
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -2
  22. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +6 -1
  23. onnx_diagnostic/torch_models/hghub/hub_api.py +12 -5
  24. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -0
  25. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +71 -0
  26. onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
  27. onnx_diagnostic/torch_models/test_helper.py +23 -5
  28. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/METADATA +1 -1
  29. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/RECORD +32 -30
  30. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/WHEEL +1 -1
  31. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/licenses/LICENSE.txt +0 -0
  32. {onnx_diagnostic-0.4.0.dist-info → onnx_diagnostic-0.4.2.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/tasks/text2text_generation.py
+++ b/onnx_diagnostic/tasks/text2text_generation.py
@@ -6,7 +6,7 @@ from ..helpers.config_helper import update_config, check_hasattr, _pick
 __TASK__ = "text2text-generation"
 
 
-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_decoder_layers"):
@@ -28,6 +28,7 @@ def get_inputs(
     batch_size: int = 2,
     sequence_length: int = 30,
     sequence_length2: int = 3,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -125,10 +126,25 @@ def get_inputs(
         # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
         # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            head_dim=head_dim,
+            encoder_dim=encoder_dim,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2 + 1,
+            **kwargs,
+        )["inputs"]
+    return res
 
 
-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.
 
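Note on the new flag: with add_second_input=True the returned dictionary gains an
"inputs2" entry built by re-invoking get_inputs with every dynamic dimension bumped
by one, so a caller can check that a model (or its export) accepts more than one
shape. A minimal sketch of what a caller sees (argument values are illustrative,
not defaults taken from this diff):

    data = get_inputs(
        model,
        config,
        dummy_max_token_id=1023,
        batch_size=2,
        sequence_length=30,
        sequence_length2=3,
        add_second_input=True,
    )
    sorted(data)              # ["dynamic_shapes", "inputs", "inputs2"]
    model(**data["inputs"])   # batch_size=2, sequence_length=30
    model(**data["inputs2"])  # batch_size=3, sequence_length=31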
--- a/onnx_diagnostic/tasks/text_classification.py
+++ b/onnx_diagnostic/tasks/text_classification.py
@@ -5,7 +5,7 @@ from ..helpers.config_helper import update_config, check_hasattr
 __TASK__ = "text-classification"
 
 
-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
     kwargs = dict(
@@ -22,6 +22,7 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -48,10 +49,20 @@ def get_inputs(
         token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            **kwargs,
+        )["inputs"]
+    return res
 
 
-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.
 
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -11,7 +11,7 @@ from ..helpers.config_helper import update_config, check_hasattr, _pick
 __TASK__ = "text-generation"
 
 
-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
     check_hasattr(
@@ -71,6 +71,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -144,58 +145,78 @@ def get_inputs(
                 ]
             ),
         )
-        return dict(inputs=inputs, dynamic_shapes=shapes)
-
-    if head_dim is None:
-        assert config, "head_dim is None, the value cannot be set without a configuration"
-        head_dim = config.hidden_size // config.num_attention_heads
+        res = dict(inputs=inputs, dynamic_shapes=shapes)
+    else:
+        if head_dim is None:
+            assert config, "head_dim is None, the value cannot be set without a configuration"
+            head_dim = config.hidden_size // config.num_attention_heads
 
-    shapes = {
-        "input_ids": {0: batch, 1: seq_length},
-        "attention_mask": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-        ],
-    }
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "attention_mask": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "position_ids": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "past_key_values": [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+        }
 
-    make_cache = (
-        make_sliding_window_cache
-        if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
-        else make_dynamic_cache
-    )
+        make_cache = (
+            make_sliding_window_cache
+            if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
+            else make_dynamic_cache
+        )
 
-    inputs = dict(
-        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
-        past_key_values=make_cache(
-            [
-                (
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
-                )
-                for i in range(num_hidden_layers)
-            ]
-        ),
-    )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+        inputs = dict(
+            input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
+                torch.int64
+            ),
+            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                torch.int64
+            ),
+            position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+            .to(torch.int64)
+            .expand((batch_size, -1)),
+            past_key_values=make_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+        res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2 + 1,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            **kwargs,
+        )["inputs"]
+    return res
 
 
-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.
 
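The second-input mechanism pairs naturally with dynamic-shape export: export with
the first set, then run the exported module on the second set to surface any
dimension that was accidentally specialized. A hedged sketch (standard
torch.export usage; the argument values are made up):

    import torch

    data = get_inputs(
        model, config, dummy_max_token_id=1023, num_hidden_layers=2,
        add_second_input=True,
    )
    ep = torch.export.export(
        model, (), kwargs=data["inputs"], dynamic_shapes=data["dynamic_shapes"]
    )
    ep.module()(**data["inputs2"])  # fails if a dynamic dimension was specialized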
--- a/onnx_diagnostic/tasks/zero_shot_image_classification.py
+++ b/onnx_diagnostic/tasks/zero_shot_image_classification.py
@@ -5,7 +5,7 @@ from ..helpers.config_helper import update_config, check_hasattr
 __TASK__ = "zero-shot-image-classification"
 
 
-def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "vision_config", "text_config")
     check_hasattr(config.vision_config, "num_hidden_layers", "num_attention_heads")
@@ -34,6 +34,7 @@ def get_inputs(
     input_height: int = 224,
     input_channels: int = 3,
     batch_size_image=3,
+    add_second_input: bool = False,
     **kwargs,  # unused
 ):
     """
@@ -81,10 +82,24 @@ def get_inputs(
             batch_size_image, input_channels, input_width, input_height
         ).clamp(-1, 1),
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            input_width=input_width,
+            input_height=input_height,
+            input_channels=input_channels,
+            batch_size_image=batch_size_image + 1,
+            **kwargs,
+        )["inputs"]
+    return res
 
 
-def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.
 
--- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -44,7 +44,7 @@ def _catch_produce_guards_and_solve_constraints(
             raise
         if verbose:
             print(
-                f"[_catch_produce_guards_and_solve_constraints] ERROR"
+                f"[_catch_produce_guards_and_solve_constraints] ERROR: "
                 f"produce_guards_and_solve_constraints failed, "
                 f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
                 f"fake_mode={fake_mode}\n"
@@ -54,6 +54,7 @@ def _catch_produce_guards_and_solve_constraints(
                 f"_is_torch_jit_trace={_is_torch_jit_trace}\n"
                 f"exc={e}\ngm={gm}"
             )
+        torch._dynamo.reset()
 
 
 def patch__check_input_constraints_for_graph(
@@ -70,13 +71,14 @@ def patch__check_input_constraints_for_graph(
             raise
         if verbose:
             print(
-                f"[_check_input_constraints_for_graph] ERROR"
+                f"[_check_input_constraints_for_graph] ERROR: "
                 f"_check_input_constraints_for_graph failed, "
                 f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
                 f"input_placeholders={input_placeholders}\n"
                 f"range_constraints={range_constraints}\n"
                 f"exc={e}"
             )
+        torch._dynamo.reset()
 
 
 def patched_infer_size(a, b):
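Both patched functions now call torch._dynamo.reset() after reporting the
swallowed error; resetting clears dynamo's compilation caches so a later export
attempt starts clean instead of reusing guards from the failed one. A short
sketch of the same pattern in isolation (the surrounding function is an
assumption, not code from the package):

    import torch

    def try_export_twice(mod, args):
        try:
            return torch.export.export(mod, args)
        except Exception as exc:
            print(f"first attempt failed: {exc}")
            torch._dynamo.reset()  # drop cached guards/compilations
            return torch.export.export(mod, args)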
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -5,6 +5,7 @@ import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
+from ...ext_test_case import has_transformers
 from ...helpers.torch_test_helper import is_torchdynamo_exporting
 
 
@@ -50,7 +51,8 @@ class patched_AttentionMaskConverter:
     ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
     """
 
-    _PATCHES_ = ["_make_causal_mask"]
+    # This method was fixed in 4.51 at least.
+    _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
     _PATCHED_CLASS_ = AttentionMaskConverter
 
     @staticmethod
@@ -69,6 +71,9 @@ class patched_AttentionMaskConverter:
         This static method may be called with ``AttentionMaskConverter._make_causal_mask``
        or ``self._make_causal_mask``. That changes this argument is receives.
        That should not matter but...
+        The patch should be implemented in another way. static methods do not play well
+        with a simple replacement.
+        Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
         """
         if args:
             index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
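The gate relies on the package's own has_transformers helper (imported above).
An equivalent standalone check, sketched with the standard packaging library
rather than the helper, would be:

    from packaging.version import Version
    import transformers

    # only patch _make_causal_mask on transformers older than 4.48.3,
    # where the upstream fix is not yet available
    needs_patch = Version(transformers.__version__) < Version("4.48.3")
    patches = ["_make_causal_mask"] if needs_patch else []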
--- a/onnx_diagnostic/torch_models/hghub/hub_api.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -1,8 +1,10 @@
+import copy
 import functools
 import os
 from typing import Any, Dict, List, Optional, Union
 import transformers
 from huggingface_hub import HfApi, model_info
+from ...helpers.config_helper import update_config
 from . import hub_data_cached_configs
 from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
 
@@ -30,7 +32,7 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig]:
     return res
 
 
-def get_cached_configuration(name: str) -> Optional[transformers.PretrainedConfig]:
+def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.PretrainedConfig]:
     """
     Returns cached configuration to avoid having to many accesses to internet.
     It returns None if not Cache. The list of cached models follows.
@@ -46,14 +48,18 @@ def get_cached_configuration(name: str) -> Optional[transformers.PretrainedConfig]:
     cached = _retrieve_cached_configurations()
     assert cached, "no cached configuration, which is weird"
     if name in cached:
-        return cached[name]()
+        conf = cached[name]()
+        if kwargs:
+            conf = copy.deepcopy(conf)
+            update_config(conf, kwargs)
+        return conf
     if os.environ.get("NOHTTP", ""):
         raise AssertionError(f"Unable to find {name!r} in {sorted(cached)}")
     return None
 
 
 def get_pretrained_config(
-    model_id: str, trust_remote_code: bool = True, use_preinstalled: bool = True
+    model_id: str, trust_remote_code: bool = True, use_preinstalled: bool = True, **kwargs
 ) -> Any:
     """
     Returns the config for a model_id.
@@ -65,14 +71,15 @@ def get_pretrained_config(
         accessing the network, if available, it is returned by
         :func:`get_cached_configuration`, the cached list is mostly for
         unit tests
+    :param kwargs: additional kwargs
     :return: a configuration
     """
     if use_preinstalled:
-        conf = get_cached_configuration(model_id)
+        conf = get_cached_configuration(model_id, **kwargs)
         if conf is not None:
             return conf
     return transformers.AutoConfig.from_pretrained(
-        model_id, trust_remote_code=trust_remote_code
+        model_id, trust_remote_code=trust_remote_code, **kwargs
     )
 
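The practical effect: a cached configuration can now be customized per call
without corrupting the cache, because the entry is deep-copied before
update_config applies the kwargs. A hypothetical usage (the kwarg assumes the
config exposes num_hidden_layers, as facebook/bart-base does in this diff):

    from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config

    config = get_pretrained_config("facebook/bart-base", num_hidden_layers=2)
    # a later call without kwargs still returns the original cached values
    full = get_pretrained_config("facebook/bart-base")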
--- a/onnx_diagnostic/torch_models/hghub/hub_data.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -13,6 +13,7 @@ __data_arch__ = textwrap.dedent(
     ASTModel,feature-extraction
     AlbertModel,feature-extraction
     BeitForImageClassification,image-classification
+    BartModel,feature-extraction
     BertForMaskedLM,fill-mask
     BertForSequenceClassification,text-classification
     BertModel,sentence-similarity
@@ -76,6 +77,7 @@ __data_arch__ = textwrap.dedent(
     MobileNetV2Model,image-feature-extraction
     MobileViTForImageClassification,image-classification
     ModernBertForMaskedLM,fill-mask
+    Phi4MMForCausalLM,MoE
     MoonshineForConditionalGeneration,automatic-speech-recognition
     MptForCausalLM,text-generation
     MusicgenForConditionalGeneration,text-to-audio
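These two rows extend the architecture-to-task table that the package resolves
through load_architecture_task (imported by hub_api in this diff). A hedged
sketch of the lookup they enable, assuming load_architecture_task returns the
parsed mapping:

    from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task

    mapping = load_architecture_task()
    mapping["BartModel"]          # "feature-extraction"
    mapping["Phi4MMForCausalLM"]  # "MoE", handled by tasks/mixture_of_expert.py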
--- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -3569,3 +3569,74 @@ def _ccached_tiiuae_falcon_mamba_tiny_dev():
             "vocab_size": 65024,
         }
     )
+
+
+def _ccached_facebook_bart_base():
+    "facebook/bart-base"
+    return transformers.BartConfig(
+        **{
+            "_name_or_path": "bart-base",
+            "activation_dropout": 0.1,
+            "activation_function": "gelu",
+            "add_bias_logits": false,
+            "add_final_layer_norm": false,
+            "architectures": ["BartModel"],
+            "attention_dropout": 0.1,
+            "bos_token_id": 0,
+            "classif_dropout": 0.1,
+            "classifier_dropout": 0.0,
+            "d_model": 768,
+            "decoder_attention_heads": 12,
+            "decoder_ffn_dim": 3072,
+            "decoder_layerdrop": 0.0,
+            "decoder_layers": 6,
+            "decoder_start_token_id": 2,
+            "dropout": 0.1,
+            "early_stopping": true,
+            "encoder_attention_heads": 12,
+            "encoder_ffn_dim": 3072,
+            "encoder_layerdrop": 0.0,
+            "encoder_layers": 6,
+            "eos_token_id": 2,
+            "forced_eos_token_id": 2,
+            "forced_bos_token_id": 0,
+            "gradient_checkpointing": false,
+            "id2label": {"0": "LABEL_0", "1": "LABEL_1", "2": "LABEL_2"},
+            "init_std": 0.02,
+            "is_encoder_decoder": true,
+            "label2id": {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
+            "max_position_embeddings": 1024,
+            "model_type": "bart",
+            "no_repeat_ngram_size": 3,
+            "normalize_before": false,
+            "normalize_embedding": true,
+            "num_beams": 4,
+            "num_hidden_layers": 6,
+            "pad_token_id": 1,
+            "scale_embedding": false,
+            "task_specific_params": {
+                "summarization": {
+                    "length_penalty": 1.0,
+                    "max_length": 128,
+                    "min_length": 12,
+                    "num_beams": 4,
+                },
+                "summarization_cnn": {
+                    "length_penalty": 2.0,
+                    "max_length": 142,
+                    "min_length": 56,
+                    "num_beams": 4,
+                },
+                "summarization_xsum": {
+                    "length_penalty": 1.0,
+                    "max_length": 62,
+                    "min_length": 11,
+                    "num_beams": 6,
+                },
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.12.0.dev0",
+            "use_cache": true,
+            "vocab_size": 50265,
+        }
+    )
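The lowercase false/true literals in this entry suggest the module defines
JSON-style aliases so configurations can be pasted verbatim from a model's
config.json; a minimal sketch of that pattern (an assumption about the module
preamble, which this hunk does not show):

    # presumed module-level aliases for pasted JSON
    false, true, null = False, True, None

    cfg = {
        "add_bias_logits": false,  # copied straight from config.json
        "use_cache": true,
    }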
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -17,6 +17,7 @@ def get_untrained_model_with_inputs(
     dynamic_rope: Optional[bool] = None,
     same_as_pretrained: bool = False,
     use_preinstalled: bool = True,
+    add_second_input: bool = False,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -34,6 +35,8 @@ def get_untrained_model_with_inputs(
     :param same_as_pretrained: if True, do not change the default values
         to get a smaller model
     :param use_preinstalled: use preinstalled configurations
+    :param add_second_input: provides a second inputs to check a model
+        supports different shapes
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration
 
     Example:
@@ -58,7 +61,9 @@ def get_untrained_model_with_inputs(
         if use_preinstalled:
             print(f"[get_untrained_model_with_inputs] use preinstalled {model_id!r}")
     if config is None:
-        config = get_pretrained_config(model_id, use_preinstalled=use_preinstalled)
+        config = get_pretrained_config(
+            model_id, use_preinstalled=use_preinstalled, **(model_kwargs or {})
+        )
     archs = config.architectures  # type: ignore
     assert archs is not None and len(archs) == 1, (
         f"Unable to determine the architecture for model {model_id!r}, "
@@ -67,7 +72,6 @@ def get_untrained_model_with_inputs(
     arch = archs[0]
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
-    config = get_pretrained_config(model_id)
     if verbose:
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
     task = task_from_arch(arch)
@@ -106,7 +110,7 @@ def get_untrained_model_with_inputs(
     # This line is important. Some models may produce different
     # outputs even with the same inputs in training mode.
     model.eval()
-    res = fct(model, config, **kwargs)
+    res = fct(model, config, add_second_input=add_second_input, **kwargs)
 
     res["input_kwargs"] = kwargs
     res["model_kwargs"] = mkwargs
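A hypothetical end-to-end use of the new flag (facebook/bart-base is a safe
example here since this release adds it to the cached configurations):

    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

    data = get_untrained_model_with_inputs("facebook/bart-base", add_second_input=True)
    model = data["model"]
    model(**data["inputs"])   # first shape set
    model(**data["inputs2"])  # second set, dynamic dimensions bumped by one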
--- a/onnx_diagnostic/torch_models/test_helper.py
+++ b/onnx_diagnostic/torch_models/test_helper.py
@@ -11,10 +11,10 @@ from ..helpers.helper import flatten_object
 from ..helpers.rt_helper import make_feeds
 from ..helpers.torch_test_helper import to_any, torch_deepcopy
 from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
+from ..tasks import random_input_kwargs
 from ..torch_export_patches import bypass_export_some_errors
 from ..torch_export_patches.patch_inputs import use_dyn_not_str
 from .hghub import get_untrained_model_with_inputs
-from .hghub.model_inputs import random_input_kwargs
 
 
 def empty(value: Any) -> bool:
@@ -221,6 +221,7 @@ def validate_model(
     drop_inputs: Optional[List[str]] = None,
     ortfusiontype: Optional[str] = None,
     input_options: Optional[Dict[str, Any]] = None,
+    model_options: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -251,6 +252,8 @@ def validate_model(
         see :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`
     :param input_options: additional options to define the dummy inputs
         used to export
+    :param model_options: additional options when creating the model such as
+        ``num_hidden_layers`` or ``attn_implementation``
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -289,10 +292,13 @@ def validate_model(
 
     if verbose:
         print(f"[validate_model] validate model id {model_id!r}")
+        if model_options:
+            print(f"[validate_model] model_options={model_options!r}")
         print(f"[validate_model] get dummy inputs with input_options={input_options}...")
     summary["model_id"] = model_id
 
     iop = input_options or {}
+    mop = model_options or {}
     data = _quiet_or_not_quiet(
         quiet,
         "create",
@@ -301,18 +307,28 @@ def validate_model(
         (
             lambda mid=model_id, v=verbose, task=task, tr=trained, iop=iop: (
                 get_untrained_model_with_inputs(
-                    mid, verbose=v, task=task, same_as_pretrained=tr, inputs_kwargs=iop
+                    mid,
+                    verbose=v,
+                    task=task,
+                    same_as_pretrained=tr,
+                    inputs_kwargs=iop,
+                    model_kwargs=mop,
                 )
             )
         ),
     )
-    data["input_options"] = input_options
+    data["input_options"] = iop
+    data["model_options"] = mop
+    if iop:
+        summary["input_options"] = str(iop)
+    if mop:
+        summary["model_options"] = str(mop)
     if "ERR_create" in summary:
         return summary, data
 
     if drop_inputs:
         if verbose:
-            print(f"[validate_model] -- drop inputs {drop_inputs!r}")
+            print(f"[validate_model] -- drop inputs: {drop_inputs!r}")
             print(f"[validate_model] current inputs: {string_type(data['inputs'])}")
             print(
                 f"[validate_model] current dynnamic_shapes: "
@@ -505,7 +521,9 @@ def validate_model(
     if verbose:
         print("[validate_model] done (dump)")
 
-    if not exporter or not exporter.startswith(("onnx-", "custom-")):
+    if not exporter or (
+        not exporter.startswith(("onnx-", "custom-")) and exporter != "custom"
+    ):
         if verbose:
             print("[validate_model] -- done (final)")
         if dump_stats:
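Putting the new options together, a hypothetical validation call (parameter
names other than input_options and model_options are taken from this diff;
the values are illustrative):

    from onnx_diagnostic.torch_models.test_helper import validate_model

    summary, data = validate_model(
        "facebook/bart-base",
        exporter="custom",                       # now recognized alongside "custom-*"
        input_options={"batch_size": 3},
        model_options={"num_hidden_layers": 2},  # forwarded as model_kwargs
    )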
--- a/onnx_diagnostic-0.4.0.dist-info/METADATA
+++ b/onnx_diagnostic-0.4.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-diagnostic
-Version: 0.4.0
+Version: 0.4.2
 Summary: Investigate ONNX models
 Home-page: https://github.com/sdpython/onnx-diagnostic
 Author: Xavier Dupré