onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff compares the published contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +154 -3
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +435 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
- onnx_diagnostic/export/api.py +1 -0
- onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/ext_test_case.py +23 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/tasks/image_text_to_text.py +15 -5
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py (new file):

@@ -0,0 +1,80 @@
+import torch
+
+try:
+    import transformers.models.funnel.modeling_funnel
+
+    patch_funnel = True
+except ImportError:
+    patch_funnel = False
+
+if patch_funnel:
+    from transformers.models.funnel.modeling_funnel import _relative_shift_gather
+
+    class patched_FunnelAttentionStructure(torch.nn.Module):
+        _PATCHES_ = ["relative_pos"]
+        _PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure
+
+        def relative_pos(
+            self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
+        ) -> torch.Tensor:
+            if pooled_pos is None:
+                pooled_pos = pos
+            ref_point = pooled_pos[0] - pos[0]
+            # PATCHED
+            num_remove = shift * pooled_pos.shape[0]
+            max_dist = ref_point + num_remove * stride
+            min_dist = pooled_pos[0] - pos[-1]
+            return torch.arange(
+                max_dist.to(torch.long),
+                (min_dist - 1).to(torch.long),
+                torch.tensor(-stride, dtype=torch.long),
+                dtype=torch.long,
+                device=pos.device,
+            )
+
+    class patched_FunnelRelMultiheadAttention(torch.nn.Module):
+        _PATCHES_ = ["relative_positional_attention"]
+        _PATCHED_CLASS_ = (
+            transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
+        )
+
+        def relative_positional_attention(
+            self, position_embeds, q_head, context_len, cls_mask=None
+        ):
+            """Relative attention score for the positional encodings"""
+            # q_head has shape batch_size x sea_len x n_head x d_head
+            if self.config.attention_type == "factorized":
+                phi, pi, psi, omega = position_embeds
+                # Shape n_head x d_head
+                u = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape batch_size x sea_len x n_head x d_model
+                q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
+                q_r_attention_1 = q_r_attention * phi[:, None]
+                q_r_attention_2 = q_r_attention * pi[:, None]
+
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = torch.einsum(
+                    "bind,jd->bnij", q_r_attention_1, psi
+                ) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
+            else:
+                shift = 2 if q_head.shape[1] != context_len else 1
+                r = position_embeds[self.block_index][shift - 1]
+                # Shape n_head x d_head
+                v = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape max_rel_len x n_head x d_model
+                r_head = torch.einsum("td,dnh->tnh", r, w_r)
+                # Shape batch_size x n_head x seq_len x max_rel_len
+                positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+            if cls_mask is not None:
+                # PATCHED
+                positional_attn = positional_attn * cls_mask
+            return positional_attn
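The two classes above follow the package's patching convention: `_PATCHED_CLASS_` names the transformers class to patch and `_PATCHES_` lists the methods to replace on it. A minimal sketch of how such a declaration could be applied and undone (illustrative only; `apply_patch_class` and `undo_patch_class` are hypothetical helpers, the real machinery lives in onnx_diagnostic's torch_export_patches modules):

```python
def apply_patch_class(patch_cls):
    """Install every method listed in ``_PATCHES_`` onto ``_PATCHED_CLASS_``."""
    target = patch_cls._PATCHED_CLASS_
    originals = {}
    for name in patch_cls._PATCHES_:
        originals[name] = getattr(target, name)          # keep for restoration
        setattr(target, name, getattr(patch_cls, name))  # swap in the patched method
    return originals


def undo_patch_class(patch_cls, originals):
    """Restore the methods saved by ``apply_patch_class`` after export."""
    for name, method in originals.items():
        setattr(patch_cls._PATCHED_CLASS_, name, method)
```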
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py:

@@ -256,8 +256,23 @@ if patch_qwen2_5:
             return attn_output
 
     def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
-
-
+        import onnx_ir
+
+        first_float_tensor = next(
+            a
+            for a in args
+            if a is not None
+            and a.dtype
+            in {
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+                onnx_ir.DataType.BFLOAT16,
+                onnx_ir.DataType.FLOAT16,
+                onnx_ir.DataType.FLOAT,
+            }
+        )
+        dtype = first_float_tensor.dtype
         strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
         itype = torch_dtype_to_onnx_dtype(dtype)
         if strategy is not None:
@@ -269,7 +284,7 @@ if patch_qwen2_5:
        if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
            # first_tensor may be a SymbolicTensor (onnx).
            # is_cuda is not available.
-            if hasattr(
+            if hasattr(first_float_tensor, "is_cuda") and first_float_tensor.is_cuda:
                return "PACKED", itype
            return "LOOPMHA", itype
        raise AssertionError(
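`qwen_version_selector` picks an attention export strategy ("PACKED" vs "LOOPMHA") from the first floating-point argument: its torch dtype and the corresponding ONNX element type. A minimal sketch of the dtype selection and mapping it relies on (`torch_dtype_to_onnx_dtype` is the package's own helper; the small dictionary below is an illustrative stand-in, not its implementation):

```python
import torch
from onnx import TensorProto

# Illustrative subset only; onnx-diagnostic's torch_dtype_to_onnx_dtype covers more types.
_TORCH_TO_ONNX = {
    torch.float32: TensorProto.FLOAT,
    torch.float16: TensorProto.FLOAT16,
    torch.bfloat16: TensorProto.BFLOAT16,
}


def first_float_dtype(*args):
    """Return the dtype of the first non-None floating-point tensor."""
    return next(a.dtype for a in args if a is not None and a.dtype in _TORCH_TO_ONNX)


dtype = first_float_dtype(None, torch.ones(2, dtype=torch.int64), torch.ones(2, dtype=torch.float16))
assert _TORCH_TO_ONNX[dtype] == TensorProto.FLOAT16
```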
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py (continued):

@@ -733,3 +748,71 @@ if patch_qwen2_5:
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
+
+    class patched_Qwen2_5_VLModel:
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: Optional[torch.FloatTensor] = None,
+            video_features: Optional[torch.FloatTensor] = None,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+                special_video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.video_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_video_mask = special_video_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+                special_video_mask = input_ids == self.config.video_token_id
+
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+
+            # PATCHED: we should use torch._check
+            # but this fails for compilation. It cannot be verified with FakeTensors
+            # torch._check(
+            #     image_features is None
+            #     or inputs_embeds[special_image_mask].numel() == image_features.numel(),
+            #     lambda: (
+            #         f"Image features and image tokens do not match: tokens: "
+            #         f"{special_image_mask.sum()}, features {image_features.shape[0]}"
+            #     ),
+            # )
+
+            special_video_mask = (
+                special_video_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+
+            # PATCHED: we should use torch._check
+            # but this fails for compilation. It cannot be verified with FakeTensors
+            # torch._check(
+            #     video_features is None
+            #     or inputs_embeds[special_video_mask].numel() == video_features.numel(),
+            #     lambda: (
+            #         f"Videos features and video tokens do not match: tokens: "
+            #         f"{special_video_mask.sum()}, features {video_features.shape[0]}"
+            #     ),
+            # )
+
+            return special_image_mask, special_video_mask
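The patched `get_placeholder_mask` marks the positions of image and video placeholder tokens and broadcasts each boolean mask to the shape of the embeddings, avoiding the data-dependent `torch._check` assertions that cannot be traced with FakeTensors. A small self-contained illustration of that masking step (the token id and tensor sizes below are made up):

```python
import torch

# Hypothetical sizes and placeholder token id, just to show the broadcasting.
image_token_id = 99
input_ids = torch.tensor([[1, 99, 99, 2]])   # batch=1, seq=4
inputs_embeds = torch.randn(1, 4, 8)         # hidden size 8

special_image_mask = input_ids == image_token_id   # shape (1, 4)
special_image_mask = (
    special_image_mask.unsqueeze(-1)               # shape (1, 4, 1)
    .expand_as(inputs_embeds)                      # shape (1, 4, 8)
    .to(inputs_embeds.device)
)
# The selected positions are the placeholder slots that will later be
# overwritten with the vision features.
assert special_image_mask.sum() == 2 * 8
```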
onnx_diagnostic/torch_export_patches/patches/patch_torch.py:

@@ -5,6 +5,7 @@ import os
 import traceback
 from functools import reduce
 from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+import sympy
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
 
@@ -1091,3 +1092,17 @@ def patched__broadcast_in_dim_meta_level_2(
         new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
 
     return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+class patched_DynamicDimConstraintPrinter:
+    """
+    Patches
+    ``torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol``.
+    Valid for ``torch>=2.10``.
+    """
+
+    def _print_Symbol(self, expr: sympy.Symbol) -> str:
+        assert isinstance(expr, sympy.Symbol), str(type(expr))
+        if self.symbol_to_source.get(expr):
+            return self.symbol_to_source[expr][0].name
+        return str(expr)
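The patched `_print_Symbol` prints a shape symbol as the name of the input source it was created from when the constraint printer knows it, falling back to the plain symbol name. A minimal sketch of the same idea with a plain sympy printer (`symbol_to_source` is modeled here as a dict of lists of objects carrying a `.name`, which is an assumption about the real torch structure):

```python
import sympy
from sympy.printing.str import StrPrinter
from types import SimpleNamespace


class SourceNamePrinter(StrPrinter):
    """Print symbols using the name of their recorded source when available."""

    def __init__(self, symbol_to_source):
        super().__init__()
        self.symbol_to_source = symbol_to_source

    def _print_Symbol(self, expr):
        if self.symbol_to_source.get(expr):
            return self.symbol_to_source[expr][0].name
        return str(expr)


s0 = sympy.Symbol("s0")
printer = SourceNamePrinter({s0: [SimpleNamespace(name="batch")]})
print(printer.doprint(s0 + 1))  # -> "batch + 1"
```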
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:

@@ -1,29 +1,37 @@
 # transformers
 from typing import List
 from .patch_helper import _has_transformers
-
 from ._patch_transformers_attention import (
     patched_sdpa_attention_forward,
     patched_model_bart_eager_attention_forward,
     patched_modeling_marian_eager_attention_forward,
 )
+from ._patch_transformers_generation_mixin import patched_GenerationMixin
+from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
+from ._patch_transformers_rotary_embedding import (
+    patched__compute_dynamic_ntk_parameters,
+    patched_dynamic_rope_update,
+    patched_GemmaRotaryEmbedding,
+    patched_LlamaRotaryEmbedding,
+    patched_MistralRotaryEmbedding,
+    patched_MixtralRotaryEmbedding,
+    patched_PhiRotaryEmbedding,
+)
+from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
+from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+
+# transformers dependent patches
 
 from ._patch_transformers_cache_utils import patch_parse_processor_args
 
 if patch_parse_processor_args:
     from ._patch_transformers_cache_utils import patched_parse_processor_args
-
-from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
-
 from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache
 
 if patch_DynamicLayer:
     from ._patch_transformers_dynamic_cache import patched_DynamicLayer
 if patch_DynamicCache:
     from ._patch_transformers_dynamic_cache import patched_DynamicCache
-
-from ._patch_transformers_generation_mixin import patched_GenerationMixin
-
 from ._patch_transformers_masking_utils import patch_masking_utils
 
 if patch_masking_utils:
@@ -33,15 +41,7 @@ if patch_masking_utils:
         patched_sdpa_mask_recent_torch,
     )
 
-from ._patch_transformers_rotary_embedding import (
-    patched__compute_dynamic_ntk_parameters,
-    patched_dynamic_rope_update,
-    patched_GemmaRotaryEmbedding,
-    patched_LlamaRotaryEmbedding,
-    patched_MistralRotaryEmbedding,
-    patched_MixtralRotaryEmbedding,
-    patched_PhiRotaryEmbedding,
-)
+# transformers models dependent patches
 
 if _has_transformers("4.51"):
     from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding
@@ -54,16 +54,11 @@ if _has_transformers("4.52"):
 if _has_transformers("4.53"):
     from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding
 
-# Models
-
 from ._patch_transformers_gemma3 import patch_gemma3
 
 if patch_gemma3:
     from ._patch_transformers_gemma3 import patched_Gemma3Model
 
-from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
-
-
 from ._patch_transformers_qwen2 import patch_qwen2
 
 if patch_qwen2:
@@ -77,16 +72,20 @@ if patch_qwen2_5:
         patched_Qwen2_5_VisionTransformerPretrainedModel,
         patched_Qwen2_5_VLVisionAttentionOneIteration,
         patched_Qwen2_5_VLVisionAttention,
+        patched_Qwen2_5_VLModel,
         PLUGS as PLUGS_Qwen25,
     )
-
 from ._patch_transformers_qwen3 import patch_qwen3
 
 if patch_qwen3:
     from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock
+from ._patch_transformers_funnel import patch_funnel
 
-
-from .
+if patch_funnel:
+    from ._patch_transformers_funnel import (
+        patched_FunnelAttentionStructure,
+        patched_FunnelRelMultiheadAttention,
+    )
 
 
 def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]:  # noqa: F821
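This module gates each group of patches on the installed transformers version (`_has_transformers("4.53")`-style checks) and on per-model availability flags such as `patch_funnel`. A rough sketch of how such a version gate can be implemented (illustrative only; the package's `_has_transformers` helper may differ):

```python
from packaging.version import Version
import transformers


def has_transformers(min_version: str) -> bool:
    """Return True when the installed transformers is at least ``min_version``."""
    return Version(transformers.__version__) >= Version(min_version)


if has_transformers("4.53"):
    # Imports that only exist in recent transformers releases go here.
    pass
```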
onnx_diagnostic/torch_models/hghub/hub_api.py:

@@ -184,7 +184,18 @@ def _trygetattr(config, attname):
     return None
 
 
+def rewrite_architecture_name(name: Optional[str]) -> Optional[str]:
+    if name == "ConditionalDETRForObjectDetection":
+        return "ConditionalDetrForObjectDetection"
+    return name
+
+
 def architecture_from_config(config) -> Optional[str]:
+    """Guesses the architecture (class) of the model described by this config."""
+    return rewrite_architecture_name(_architecture_from_config(config))
+
+
+def _architecture_from_config(config) -> Optional[str]:
     """Guesses the architecture (class) of the model described by this config."""
     if isinstance(config, dict):
         if "_class_name" in config:
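`architecture_from_config` now normalizes architecture names whose casing differs between checkpoint configs and transformers class names. A quick illustration of the new helper's effect:

```python
from onnx_diagnostic.torch_models.hghub.hub_api import rewrite_architecture_name

# The checkpoint config spells "DETR" in upper case; transformers' class is
# ConditionalDetrForObjectDetection, so the name is normalized.
assert (
    rewrite_architecture_name("ConditionalDETRForObjectDetection")
    == "ConditionalDetrForObjectDetection"
)
# Any other architecture name passes through unchanged.
assert rewrite_architecture_name("GPT2LMHeadModel") == "GPT2LMHeadModel"
```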
onnx_diagnostic/torch_models/hghub/hub_data.py:

@@ -5,7 +5,10 @@ from typing import Dict, List
 
 __date__ = "2025-06-21"
 
-__data_arch_values__ = {
+__data_arch_values__ = {
+    "ConditionalDETRForObjectDetection": dict(image_size=224),
+    "ResNetForImageClassification": dict(image_size=224),
+}
 
 __data_arch__ = textwrap.dedent(
     """
@@ -32,6 +35,7 @@ __data_arch__ = textwrap.dedent(
     ConvNextV2Model,image-feature-extraction
     CosmosTransformer3DModel,image-to-video
     CvtModel,feature-extraction
+    ClvpModelForConditionalGeneration,audio-feature-extraction
     DPTModel,image-feature-extraction
     Data2VecAudioModel,feature-extraction
     Data2VecTextModel,feature-extraction
@@ -49,6 +53,8 @@ __data_arch__ = textwrap.dedent(
     ElectraModel,feature-extraction
     EsmModel,feature-extraction
     FalconMambaForCausalLM,text-generation
+    FunnelBaseModel,feature-extraction
+    FuyuForCausalLM,image-text-to-text
     GLPNModel,image-feature-extraction
     GPT2LMHeadModel,text-generation
     GPTBigCodeModel,feature-extraction
@@ -63,6 +69,7 @@ __data_arch__ = textwrap.dedent(
     Glm4vMoeForConditionalGeneration,image-text-to-text
     GraniteForCausalLM,text-generation
     GroupViTModel,feature-extraction
+    HeliumForCausalLM,text-generation
     HieraForImageClassification,image-classification
     HubertModel,feature-extraction
     IBertModel,feature-extraction
@@ -136,6 +143,7 @@ __data_arch__ = textwrap.dedent(
     SwinModel,image-feature-extraction
     Swinv2Model,image-feature-extraction
     T5ForConditionalGeneration,text2text-generation
+    T5GemmaForConditionalGeneration,text2text-generation
     TableTransformerModel,image-feature-extraction
     TableTransformerForObjectDetection,object-detection
     UNet2DConditionModel,text-to-image
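`__data_arch__` is a small CSV-like table, one `Architecture,task` pair per line, embedded as a dedented string. A minimal sketch of turning such a table into a lookup dict (the parsing below is an assumption about how the table is meant to be consumed, not the package's own code):

```python
import textwrap

data_arch = textwrap.dedent(
    """
    FunnelBaseModel,feature-extraction
    FuyuForCausalLM,image-text-to-text
    HeliumForCausalLM,text-generation
    """
)

# architecture name -> default task
arch_to_task = dict(
    line.strip().split(",", 1) for line in data_arch.splitlines() if line.strip()
)
assert arch_to_task["HeliumForCausalLM"] == "text-generation"
```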
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py:

@@ -55,6 +55,7 @@ Automatically generated:
 import base64
 import json
 import textwrap
+from typing import Any
 import transformers
 
 null = None
@@ -62,6 +63,22 @@ true = True
 false = False
 
 
+def _enforce_default(config_type: type, **kwargs) -> Any:
+    config = config_type(**kwargs)
+    for name in [
+        *[k for k in kwargs if k.endswith("_token_id")],
+        "attention_dropout",
+        "hidden_size",
+        "hidden_act",
+        "intermediate_size",
+        "max_position_embeddings",
+        "vocab_size",
+    ]:
+        if name in kwargs and (not hasattr(config, name) or getattr(config, name) is None):
+            setattr(config, name, kwargs[name])
+    return config
+
+
 def _ccached_arnir0_tiny_LLM():
     "arnir0/Tiny-LLM"
     return transformers.LlamaConfig(
@@ -4691,7 +4708,8 @@ def _ccached_zai_glm_45():
 
 def _ccached_microsoft_phi3_mini_128k_instruct():
     "microsoft/Phi-3-mini-128k-instruct"
-    return transformers.Phi3Config(
+    return _enforce_default(
+        transformers.Phi3Config,
         **{
             "_name_or_path": "Phi-3-mini-128k-instruct",
             "architectures": ["Phi3ForCausalLM"],
@@ -4827,13 +4845,14 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
             "use_cache": true,
             "attention_bias": false,
             "vocab_size": 32064,
-        }
+        },
     )
 
 
 def _ccached_google_gemma_3_4b_it_like():
     "google/gemma-3-4b-it"
-    return transformers.Gemma3Config(
+    return _enforce_default(
+        transformers.Gemma3Config,
         **{
             "architectures": ["Gemma3ForConditionalGeneration"],
             "boi_token_index": 255999,
@@ -4863,13 +4882,14 @@ def _ccached_google_gemma_3_4b_it_like():
                 "patch_size": 14,
                 "vision_use_head": false,
             },
-        }
+        },
     )
 
 
 def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
     "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
-    return transformers.Gemma3TextConfig(
+    return _enforce_default(
+        transformers.Gemma3TextConfig,
         **{
             "architectures": ["Gemma3ForCausalLM"],
             "attention_bias": false,
@@ -4901,13 +4921,14 @@ def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
             "transformers_version": "4.52.0.dev0",
             "use_cache": true,
             "vocab_size": 262144,
-        }
+        },
    )
 
 
 def _ccached_qwen_qwen2_5_vl_7b_instruct():
     "Qwen/Qwen2.5-VL-7B-Instruct"
-    return transformers.Qwen2_5_VLConfig(
+    return _enforce_default(
+        transformers.Qwen2_5_VLConfig,
         **{
             "architectures": ["Qwen2_5_VLForConditionalGeneration"],
             "attention_dropout": 0.0,
@@ -4954,5 +4975,5 @@ def _ccached_qwen_qwen2_5_vl_7b_instruct():
             },
             "rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]},
             "vocab_size": 152064,
-        }
+        },
     )
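`_enforce_default` works around transformers configs that silently drop or reset some keyword arguments: after building the config it copies back any listed kwarg that ended up missing or `None`. A small self-contained illustration with a dummy config class (the dummy class is made up; the helper body is the one from the diff above):

```python
from typing import Any


def _enforce_default(config_type: type, **kwargs) -> Any:
    config = config_type(**kwargs)
    for name in [
        *[k for k in kwargs if k.endswith("_token_id")],
        "attention_dropout",
        "hidden_size",
        "hidden_act",
        "intermediate_size",
        "max_position_embeddings",
        "vocab_size",
    ]:
        if name in kwargs and (not hasattr(config, name) or getattr(config, name) is None):
            setattr(config, name, kwargs[name])
    return config


class DummyConfig:
    """Stand-in for a config class whose constructor ignores ``vocab_size``."""

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if k != "vocab_size":  # pretend the constructor drops this kwarg
                setattr(self, k, v)


cfg = _enforce_default(DummyConfig, hidden_size=64, vocab_size=32000)
assert cfg.vocab_size == 32000  # restored by _enforce_default
```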
onnx_diagnostic/torch_models/hghub/model_inputs.py:

@@ -64,6 +64,7 @@ def get_untrained_model_with_inputs(
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
     submodule: Optional[str] = None,
+    skip_inputs: bool = False,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -93,6 +94,7 @@ def get_untrained_model_with_inputs(
         this function takes a configuration and a task (string)
         as arguments
     :param submodule: use a submodule instead of the main model
+    :param skip_inputs: do not generate the inputs
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well
 
@@ -332,13 +334,12 @@ def get_untrained_model_with_inputs(
                 f"[get_untrained_model_with_inputs] "
                 f"instantiate_specific_model(2) {cls_model}"
             )
-
         try:
             if type(config) is dict:
                 model = cls_model(**config)
             else:
                 model = cls_model(config)
-        except RuntimeError as e:
+        except (RuntimeError, AttributeError, ValueError) as e:
             raise RuntimeError(
                 f"Unable to instantiate class {cls_model.__name__} with\n{config}"
             ) from e
onnx_diagnostic/torch_models/hghub/model_inputs.py (continued):

@@ -350,23 +351,27 @@ def get_untrained_model_with_inputs(
     )
 
     # input kwargs
-    seed = int(os.environ.get("SEED", "17")) + 1
-    torch.manual_seed(seed)
-    kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
-    if verbose:
-        print(f"[get_untrained_model_with_inputs] use fct={fct}")
-    if os.environ.get("PRINT_CONFIG") in (1, "1"):
-        print(f"-- input kwargs for task {task!r}")
-        pprint.pprint(kwargs)
-    if inputs_kwargs:
-        kwargs.update(inputs_kwargs)
-
-    # This line is important. Some models may produce different
-    # outputs even with the same inputs in training mode.
-    model.eval()  # type: ignore[union-attr]
-    res = fct(model, config, add_second_input=add_second_input, **kwargs)
-
-    res["input_kwargs"] = kwargs
+    if not skip_inputs:
+        seed = int(os.environ.get("SEED", "17")) + 1
+        torch.manual_seed(seed)
+        kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
+        if verbose:
+            print(f"[get_untrained_model_with_inputs] use fct={fct}")
+        if os.environ.get("PRINT_CONFIG") in (1, "1"):
+            print(f"-- input kwargs for task {task!r}")
+            pprint.pprint(kwargs)
+        if inputs_kwargs:
+            kwargs.update(inputs_kwargs)
+
+        # This line is important. Some models may produce different
+        # outputs even with the same inputs in training mode.
+        model.eval()  # type: ignore[union-attr]
+        res = fct(model, config, add_second_input=add_second_input, **kwargs)
+
+        res["input_kwargs"] = kwargs
+    else:
+        res = {}
+
     res["model_kwargs"] = mkwargs
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)
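The new `skip_inputs` flag lets callers instantiate the untrained model without generating random inputs. A minimal usage sketch (passing the Hugging Face model id as the first positional argument is an assumption based on the other model ids referenced in this diff, such as "arnir0/Tiny-LLM"):

```python
from onnx_diagnostic.torch_models.hghub.model_inputs import (
    get_untrained_model_with_inputs,
)

# "arnir0/Tiny-LLM" appears in this diff as one of the cached configurations;
# treating the first positional argument as the model id is an assumption here.
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", skip_inputs=True)
print(sorted(data))  # model-related keys only; no randomly generated inputs
```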