onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +154 -3
  3. onnx_diagnostic/ci_models/__init__.py +0 -0
  4. onnx_diagnostic/ci_models/ci_helpers.py +435 -0
  5. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  6. onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
  7. onnx_diagnostic/export/api.py +1 -0
  8. onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
  9. onnx_diagnostic/export/control_flow_onnx.py +23 -17
  10. onnx_diagnostic/ext_test_case.py +23 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/log_helper.py +1 -3
  13. onnx_diagnostic/helpers/optim_helper.py +116 -0
  14. onnx_diagnostic/tasks/image_text_to_text.py +15 -5
  15. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  16. onnx_diagnostic/tasks/text_generation.py +3 -0
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
  18. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  19. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  20. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
  22. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
  24. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  25. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  26. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
  27. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  28. onnx_diagnostic/torch_onnx/compare.py +357 -0
  29. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
  30. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
  31. onnx_diagnostic/export/control_flow.py +0 -214
  32. onnx_diagnostic/export/control_flow_research.py +0 -140
  33. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
  34. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
  35. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py

@@ -0,0 +1,80 @@
+ import torch
+
+ try:
+     import transformers.models.funnel.modeling_funnel
+
+     patch_funnel = True
+ except ImportError:
+     patch_funnel = False
+
+ if patch_funnel:
+     from transformers.models.funnel.modeling_funnel import _relative_shift_gather
+
+     class patched_FunnelAttentionStructure(torch.nn.Module):
+         _PATCHES_ = ["relative_pos"]
+         _PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure
+
+         def relative_pos(
+             self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
+         ) -> torch.Tensor:
+             if pooled_pos is None:
+                 pooled_pos = pos
+             ref_point = pooled_pos[0] - pos[0]
+             # PATCHED
+             num_remove = shift * pooled_pos.shape[0]
+             max_dist = ref_point + num_remove * stride
+             min_dist = pooled_pos[0] - pos[-1]
+             return torch.arange(
+                 max_dist.to(torch.long),
+                 (min_dist - 1).to(torch.long),
+                 torch.tensor(-stride, dtype=torch.long),
+                 dtype=torch.long,
+                 device=pos.device,
+             )
+
+     class patched_FunnelRelMultiheadAttention(torch.nn.Module):
+         _PATCHES_ = ["relative_positional_attention"]
+         _PATCHED_CLASS_ = (
+             transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
+         )
+
+         def relative_positional_attention(
+             self, position_embeds, q_head, context_len, cls_mask=None
+         ):
+             """Relative attention score for the positional encodings"""
+             # q_head has shape batch_size x sea_len x n_head x d_head
+             if self.config.attention_type == "factorized":
+                 phi, pi, psi, omega = position_embeds
+                 # Shape n_head x d_head
+                 u = self.r_r_bias * self.scale
+                 # Shape d_model x n_head x d_head
+                 w_r = self.r_kernel
+
+                 # Shape batch_size x sea_len x n_head x d_model
+                 q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
+                 q_r_attention_1 = q_r_attention * phi[:, None]
+                 q_r_attention_2 = q_r_attention * pi[:, None]
+
+                 # Shape batch_size x n_head x seq_len x context_len
+                 positional_attn = torch.einsum(
+                     "bind,jd->bnij", q_r_attention_1, psi
+                 ) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
+             else:
+                 shift = 2 if q_head.shape[1] != context_len else 1
+                 r = position_embeds[self.block_index][shift - 1]
+                 # Shape n_head x d_head
+                 v = self.r_r_bias * self.scale
+                 # Shape d_model x n_head x d_head
+                 w_r = self.r_kernel
+
+                 # Shape max_rel_len x n_head x d_model
+                 r_head = torch.einsum("td,dnh->tnh", r, w_r)
+                 # Shape batch_size x n_head x seq_len x max_rel_len
+                 positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
+                 # Shape batch_size x n_head x seq_len x context_len
+                 positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+             if cls_mask is not None:
+                 # PATCHED
+                 positional_attn = positional_attn * cls_mask
+             return positional_attn
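Both classes follow the package's patch convention visible above: _PATCHED_CLASS_ points at the transformers class to patch and _PATCHES_ lists the method names to swap in. A minimal sketch of how such a registry can apply and undo the swap (the apply_patches helper below is illustrative only, not the package's actual entry point):

    import contextlib

    @contextlib.contextmanager
    def apply_patches(*patch_classes):
        # Swap every method listed in _PATCHES_ onto the target class,
        # then restore the originals on exit.
        saved = []
        for patch_cls in patch_classes:
            target = patch_cls._PATCHED_CLASS_
            for name in patch_cls._PATCHES_:
                saved.append((target, name, getattr(target, name)))
                setattr(target, name, getattr(patch_cls, name))
        try:
            yield
        finally:
            for target, name, original in reversed(saved):
                setattr(target, name, original)

    # e.g. with apply_patches(patched_FunnelAttentionStructure): torch.export.export(...)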
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

@@ -256,8 +256,23 @@ if patch_qwen2_5:
          return attn_output

      def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
-         first_tensor = next(a for a in args if a is not None)
-         dtype = first_tensor.dtype
+         import onnx_ir
+
+         first_float_tensor = next(
+             a
+             for a in args
+             if a is not None
+             and a.dtype
+             in {
+                 torch.float16,
+                 torch.float32,
+                 torch.bfloat16,
+                 onnx_ir.DataType.BFLOAT16,
+                 onnx_ir.DataType.FLOAT16,
+                 onnx_ir.DataType.FLOAT,
+             }
+         )
+         dtype = first_float_tensor.dtype
          strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
          itype = torch_dtype_to_onnx_dtype(dtype)
          if strategy is not None:
@@ -269,7 +284,7 @@ if patch_qwen2_5:
          if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
              # first_tensor may be a SymbolicTensor (onnx).
              # is_cuda is not available.
-             if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
+             if hasattr(first_float_tensor, "is_cuda") and first_float_tensor.is_cuda:
                  return "PACKED", itype
              return "LOOPMHA", itype
          raise AssertionError(
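The selector now skips non-floating arguments (for example integer position ids, or onnx_ir symbolic tensors with integer types) and keys the attention strategy off the first float-like tensor instead of the first non-None one. A toy illustration of that filtering rule, using plain torch tensors only (not the library code):

    import torch

    FLOAT_DTYPES = {torch.float16, torch.float32, torch.bfloat16}

    def first_float_dtype(*args):
        # mirror the selection rule: first non-None argument with a floating dtype
        return next(a.dtype for a in args if a is not None and a.dtype in FLOAT_DTYPES)

    position_ids = torch.arange(8)                          # int64: what "first not None" would pick
    hidden_states = torch.randn(8, 16, dtype=torch.float16)
    assert first_float_dtype(position_ids, hidden_states) == torch.float16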
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

@@ -733,3 +748,71 @@ if patch_qwen2_5:
              attn_output = attn_output.reshape(seq_length, -1).contiguous()
              attn_output = self.proj(attn_output)
              return attn_output
+
+     class patched_Qwen2_5_VLModel:
+         _PATCHES_ = ["get_placeholder_mask"]
+         _PATCHED_CLASS_ = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel
+
+         def get_placeholder_mask(
+             self,
+             input_ids: torch.LongTensor,
+             inputs_embeds: torch.FloatTensor,
+             image_features: Optional[torch.FloatTensor] = None,
+             video_features: Optional[torch.FloatTensor] = None,
+         ):
+             if input_ids is None:
+                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                     torch.tensor(
+                         self.config.image_token_id,
+                         dtype=torch.long,
+                         device=inputs_embeds.device,
+                     )
+                 )
+                 special_image_mask = special_image_mask.all(-1)
+                 special_video_mask = inputs_embeds == self.get_input_embeddings()(
+                     torch.tensor(
+                         self.config.video_token_id,
+                         dtype=torch.long,
+                         device=inputs_embeds.device,
+                     )
+                 )
+                 special_video_mask = special_video_mask.all(-1)
+             else:
+                 special_image_mask = input_ids == self.config.image_token_id
+                 special_video_mask = input_ids == self.config.video_token_id
+
+             special_image_mask = (
+                 special_image_mask.unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
+             )
+
+             # PATCHED: we should use torch._check
+             # but this fails for compilation. It cannot be verified with FakeTensors
+             # torch._check(
+             #     image_features is None
+             #     or inputs_embeds[special_image_mask].numel() == image_features.numel(),
+             #     lambda: (
+             #         f"Image features and image tokens do not match: tokens: "
+             #         f"{special_image_mask.sum()}, features {image_features.shape[0]}"
+             #     ),
+             # )
+
+             special_video_mask = (
+                 special_video_mask.unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
+             )
+
+             # PATCHED: we should use torch._check
+             # but this fails for compilation. It cannot be verified with FakeTensors
+             # torch._check(
+             #     video_features is None
+             #     or inputs_embeds[special_video_mask].numel() == video_features.numel(),
+             #     lambda: (
+             #         f"Videos features and video tokens do not match: tokens: "
+             #         f"{special_video_mask.sum()}, features {video_features.shape[0]}"
+             #     ),
+             # )
+
+             return special_image_mask, special_video_mask
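The patched get_placeholder_mask keeps the mask construction but comments out the torch._check validations, which cannot be evaluated on FakeTensors during export or compilation. The masks themselves are plain comparisons against the image/video token ids, expanded to the embedding shape, as this toy example shows (not the model code):

    import torch

    image_token_id = 99
    input_ids = torch.tensor([[1, 99, 99, 2]])
    inputs_embeds = torch.zeros(1, 4, 8)

    special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
    assert special_image_mask.shape == inputs_embeds.shape
    assert int(special_image_mask.sum()) == 2 * 8  # two image tokens, 8 embedding dims each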
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -5,6 +5,7 @@ import os
  import traceback
  from functools import reduce
  from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+ import sympy
  import torch
  from torch._subclasses.fake_tensor import FakeTensorMode

@@ -1091,3 +1092,17 @@ def patched__broadcast_in_dim_meta_level_2(
          new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

      return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+ class patched_DynamicDimConstraintPrinter:
+     """
+     Patches
+     ``torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol``.
+     Valid for ``torch>=2.10``.
+     """
+
+     def _print_Symbol(self, expr: sympy.Symbol) -> str:
+         assert isinstance(expr, sympy.Symbol), str(type(expr))
+         if self.symbol_to_source.get(expr):
+             return self.symbol_to_source[expr][0].name
+         return str(expr)
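The patched printer returns the recorded source name when one exists and otherwise falls back to str(expr). A standalone sketch of that decision, with a plain dict and SimpleNamespace standing in for torch's internal symbol_to_source structure (the source name is a made-up example):

    import sympy
    from types import SimpleNamespace

    def print_symbol(symbol_to_source, expr):
        # same decision as the patched _print_Symbol
        if symbol_to_source.get(expr):
            return symbol_to_source[expr][0].name
        return str(expr)

    s0, s1 = sympy.symbols("s0 s1")
    sources = {s0: [SimpleNamespace(name="L['x'].size()[0]")]}
    assert print_symbol(sources, s0) == "L['x'].size()[0]"
    assert print_symbol(sources, s1) == "s1"  # symbols without a recorded source print as themselves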
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,29 +1,37 @@
  # transformers
  from typing import List
  from .patch_helper import _has_transformers
-
  from ._patch_transformers_attention import (
      patched_sdpa_attention_forward,
      patched_model_bart_eager_attention_forward,
      patched_modeling_marian_eager_attention_forward,
  )
+ from ._patch_transformers_generation_mixin import patched_GenerationMixin
+ from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
+ from ._patch_transformers_rotary_embedding import (
+     patched__compute_dynamic_ntk_parameters,
+     patched_dynamic_rope_update,
+     patched_GemmaRotaryEmbedding,
+     patched_LlamaRotaryEmbedding,
+     patched_MistralRotaryEmbedding,
+     patched_MixtralRotaryEmbedding,
+     patched_PhiRotaryEmbedding,
+ )
+ from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
+ from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+
+ # transformers dependent patches

  from ._patch_transformers_cache_utils import patch_parse_processor_args

  if patch_parse_processor_args:
      from ._patch_transformers_cache_utils import patched_parse_processor_args
-
- from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
-
  from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache

  if patch_DynamicLayer:
      from ._patch_transformers_dynamic_cache import patched_DynamicLayer
  if patch_DynamicCache:
      from ._patch_transformers_dynamic_cache import patched_DynamicCache
-
- from ._patch_transformers_generation_mixin import patched_GenerationMixin
-
  from ._patch_transformers_masking_utils import patch_masking_utils

  if patch_masking_utils:
@@ -33,15 +41,7 @@ if patch_masking_utils:
          patched_sdpa_mask_recent_torch,
      )

- from ._patch_transformers_rotary_embedding import (
-     patched__compute_dynamic_ntk_parameters,
-     patched_dynamic_rope_update,
-     patched_GemmaRotaryEmbedding,
-     patched_LlamaRotaryEmbedding,
-     patched_MistralRotaryEmbedding,
-     patched_MixtralRotaryEmbedding,
-     patched_PhiRotaryEmbedding,
- )
+ # transformers models dependent patches

  if _has_transformers("4.51"):
      from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding
@@ -54,16 +54,11 @@ if _has_transformers("4.52"):
  if _has_transformers("4.53"):
      from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding

- # Models
-
  from ._patch_transformers_gemma3 import patch_gemma3

  if patch_gemma3:
      from ._patch_transformers_gemma3 import patched_Gemma3Model

- from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
-
-
  from ._patch_transformers_qwen2 import patch_qwen2

  if patch_qwen2:
@@ -77,16 +72,20 @@ if patch_qwen2:
          patched_Qwen2_5_VisionTransformerPretrainedModel,
          patched_Qwen2_5_VLVisionAttentionOneIteration,
          patched_Qwen2_5_VLVisionAttention,
+         patched_Qwen2_5_VLModel,
          PLUGS as PLUGS_Qwen25,
      )
-
  from ._patch_transformers_qwen3 import patch_qwen3

  if patch_qwen3:
      from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock
+ from ._patch_transformers_funnel import patch_funnel

-
- from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+ if patch_funnel:
+     from ._patch_transformers_funnel import (
+         patched_FunnelAttentionStructure,
+         patched_FunnelRelMultiheadAttention,
+     )


  def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]:  # noqa: F821
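These imports are what the patching layer exposes; the new funnel patches are pulled in conditionally like the other model-specific patches. A hedged usage sketch of turning the patches on around an export call, assuming the torch_export_patches context manager and its patch_transformers flag behave as described in the project's documentation (check the docs for the exact signature):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed entry point

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x.sin() + 1

    # the flag name is assumed; the context manager installs/uninstalls the patches
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(Tiny(), (torch.randn(2, 3),))
    print(ep.graph)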
onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -184,7 +184,18 @@ def _trygetattr(config, attname):
      return None


+ def rewrite_architecture_name(name: Optional[str]) -> Optional[str]:
+     if name == "ConditionalDETRForObjectDetection":
+         return "ConditionalDetrForObjectDetection"
+     return name
+
+
  def architecture_from_config(config) -> Optional[str]:
+     """Guesses the architecture (class) of the model described by this config."""
+     return rewrite_architecture_name(_architecture_from_config(config))
+
+
+ def _architecture_from_config(config) -> Optional[str]:
      """Guesses the architecture (class) of the model described by this config."""
      if isinstance(config, dict):
          if "_class_name" in config:
onnx_diagnostic/torch_models/hghub/hub_data.py

@@ -5,7 +5,10 @@ from typing import Dict, List

  __date__ = "2025-06-21"

- __data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
+ __data_arch_values__ = {
+     "ConditionalDETRForObjectDetection": dict(image_size=224),
+     "ResNetForImageClassification": dict(image_size=224),
+ }

  __data_arch__ = textwrap.dedent(
      """
@@ -32,6 +35,7 @@ __data_arch__ = textwrap.dedent(
      ConvNextV2Model,image-feature-extraction
      CosmosTransformer3DModel,image-to-video
      CvtModel,feature-extraction
+     ClvpModelForConditionalGeneration,audio-feature-extraction
      DPTModel,image-feature-extraction
      Data2VecAudioModel,feature-extraction
      Data2VecTextModel,feature-extraction
@@ -49,6 +53,8 @@ __data_arch__ = textwrap.dedent(
      ElectraModel,feature-extraction
      EsmModel,feature-extraction
      FalconMambaForCausalLM,text-generation
+     FunnelBaseModel,feature-extraction
+     FuyuForCausalLM,image-text-to-text
      GLPNModel,image-feature-extraction
      GPT2LMHeadModel,text-generation
      GPTBigCodeModel,feature-extraction
@@ -63,6 +69,7 @@ __data_arch__ = textwrap.dedent(
      Glm4vMoeForConditionalGeneration,image-text-to-text
      GraniteForCausalLM,text-generation
      GroupViTModel,feature-extraction
+     HeliumForCausalLM,text-generation
      HieraForImageClassification,image-classification
      HubertModel,feature-extraction
      IBertModel,feature-extraction
@@ -136,6 +143,7 @@ __data_arch__ = textwrap.dedent(
      SwinModel,image-feature-extraction
      Swinv2Model,image-feature-extraction
      T5ForConditionalGeneration,text2text-generation
+     T5GemmaForConditionalGeneration,text2text-generation
      TableTransformerModel,image-feature-extraction
      TableTransformerForObjectDetection,object-detection
      UNet2DConditionModel,text-to-image
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -55,6 +55,7 @@ Automatically generated:
  import base64
  import json
  import textwrap
+ from typing import Any
  import transformers

  null = None
@@ -62,6 +63,22 @@ true = True
  false = False


+ def _enforce_default(config_type: type, **kwargs) -> Any:
+     config = config_type(**kwargs)
+     for name in [
+         *[k for k in kwargs if k.endswith("_token_id")],
+         "attention_dropout",
+         "hidden_size",
+         "hidden_act",
+         "intermediate_size",
+         "max_position_embeddings",
+         "vocab_size",
+     ]:
+         if name in kwargs and (not hasattr(config, name) or getattr(config, name) is None):
+             setattr(config, name, kwargs[name])
+     return config
+
+
  def _ccached_arnir0_tiny_LLM():
      "arnir0/Tiny-LLM"
      return transformers.LlamaConfig(
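_enforce_default builds the config and then writes back any keyword that the config class silently dropped or reset to None (token ids, hidden_size, vocab_size, ...). A toy stand-in showing the mechanism without transformers (ToyConfig and enforce_default are illustrative, not the package's code):

    class ToyConfig:
        def __init__(self, **kwargs):
            self.hidden_size = kwargs.get("hidden_size")
            # vocab_size is silently dropped by this constructor

    def enforce_default(config_type, **kwargs):
        # same idea as _enforce_default: re-apply kwargs the constructor lost
        config = config_type(**kwargs)
        for name in ("vocab_size", "hidden_size"):
            if name in kwargs and (not hasattr(config, name) or getattr(config, name) is None):
                setattr(config, name, kwargs[name])
        return config

    cfg = enforce_default(ToyConfig, hidden_size=64, vocab_size=32064)
    assert cfg.vocab_size == 32064 and cfg.hidden_size == 64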
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -4691,7 +4708,8 @@ def _ccached_zai_glm_45():

  def _ccached_microsoft_phi3_mini_128k_instruct():
      "microsoft/Phi-3-mini-128k-instruct"
-     return transformers.Phi3Config(
+     return _enforce_default(
+         transformers.Phi3Config,
          **{
              "_name_or_path": "Phi-3-mini-128k-instruct",
              "architectures": ["Phi3ForCausalLM"],
@@ -4827,13 +4845,14 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
              "use_cache": true,
              "attention_bias": false,
              "vocab_size": 32064,
-         }
+         },
      )


  def _ccached_google_gemma_3_4b_it_like():
      "google/gemma-3-4b-it"
-     return transformers.Gemma3Config(
+     return _enforce_default(
+         transformers.Gemma3Config,
          **{
              "architectures": ["Gemma3ForConditionalGeneration"],
              "boi_token_index": 255999,
@@ -4863,13 +4882,14 @@ def _ccached_google_gemma_3_4b_it_like():
                  "patch_size": 14,
                  "vision_use_head": false,
              },
-         }
+         },
      )


  def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
      "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
-     return transformers.Gemma3TextConfig(
+     return _enforce_default(
+         transformers.Gemma3TextConfig,
          **{
              "architectures": ["Gemma3ForCausalLM"],
              "attention_bias": false,
@@ -4901,13 +4921,14 @@ def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
              "transformers_version": "4.52.0.dev0",
              "use_cache": true,
              "vocab_size": 262144,
-         }
+         },
      )


  def _ccached_qwen_qwen2_5_vl_7b_instruct():
      "Qwen/Qwen2.5-VL-7B-Instruct"
-     return transformers.Qwen2_5_VLConfig(
+     return _enforce_default(
+         transformers.Qwen2_5_VLConfig,
          **{
              "architectures": ["Qwen2_5_VLForConditionalGeneration"],
              "attention_dropout": 0.0,
@@ -4954,5 +4975,5 @@ def _ccached_qwen_qwen2_5_vl_7b_instruct():
              },
              "rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]},
              "vocab_size": 152064,
-         }
+         },
      )
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -64,6 +64,7 @@ def get_untrained_model_with_inputs(
      use_only_preinstalled: bool = False,
      config_reduction: Optional[Callable[[Any, str], Dict]] = None,
      submodule: Optional[str] = None,
+     skip_inputs: bool = False,
  ) -> Dict[str, Any]:
      """
      Gets a non initialized model similar to the original model
@@ -93,6 +94,7 @@
          this function takes a configuration and a task (string)
          as arguments
      :param submodule: use a submodule instead of the main model
+     :param skip_inputs: do not generate the inputs
      :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
          some necessary rewriting as well

@@ -332,13 +334,12 @@
              f"[get_untrained_model_with_inputs] "
              f"instantiate_specific_model(2) {cls_model}"
          )
-
      try:
          if type(config) is dict:
              model = cls_model(**config)
          else:
              model = cls_model(config)
-     except RuntimeError as e:
+     except (RuntimeError, AttributeError, ValueError) as e:
          raise RuntimeError(
              f"Unable to instantiate class {cls_model.__name__} with\n{config}"
          ) from e
@@ -350,23 +351,27 @@
      )

      # input kwargs
-     seed = int(os.environ.get("SEED", "17")) + 1
-     torch.manual_seed(seed)
-     kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
-     if verbose:
-         print(f"[get_untrained_model_with_inputs] use fct={fct}")
-         if os.environ.get("PRINT_CONFIG") in (1, "1"):
-             print(f"-- input kwargs for task {task!r}")
-             pprint.pprint(kwargs)
-     if inputs_kwargs:
-         kwargs.update(inputs_kwargs)
-
-     # This line is important. Some models may produce different
-     # outputs even with the same inputs in training mode.
-     model.eval()  # type: ignore[union-attr]
-     res = fct(model, config, add_second_input=add_second_input, **kwargs)
-
-     res["input_kwargs"] = kwargs
+     if not skip_inputs:
+         seed = int(os.environ.get("SEED", "17")) + 1
+         torch.manual_seed(seed)
+         kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
+         if verbose:
+             print(f"[get_untrained_model_with_inputs] use fct={fct}")
+             if os.environ.get("PRINT_CONFIG") in (1, "1"):
+                 print(f"-- input kwargs for task {task!r}")
+                 pprint.pprint(kwargs)
+         if inputs_kwargs:
+             kwargs.update(inputs_kwargs)
+
+         # This line is important. Some models may produce different
+         # outputs even with the same inputs in training mode.
+         model.eval()  # type: ignore[union-attr]
+         res = fct(model, config, add_second_input=add_second_input, **kwargs)
+
+         res["input_kwargs"] = kwargs
+     else:
+         res = {}
+
      res["model_kwargs"] = mkwargs
      if diff_config is not None:
          res["dump_info"] = dict(config_diff=diff_config)