onnx_diagnostic-0.8.4-py3-none-any.whl → onnx_diagnostic-0.8.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +67 -9
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +430 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +560 -0
- onnx_diagnostic/export/api.py +15 -4
- onnx_diagnostic/export/cf_simple_loop_for.py +352 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/export/onnx_plug.py +60 -6
- onnx_diagnostic/ext_test_case.py +14 -0
- onnx_diagnostic/helpers/helper.py +26 -27
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +16 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +103 -31
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/RECORD +22 -19
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/helper.py

@@ -2,6 +2,7 @@ import ast
 import enum
 import inspect
 import itertools
+import json
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
@@ -1373,11 +1374,7 @@ def max_diff(
         if hist:
             if isinstance(hist, bool):
                 hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
-
-            cou = np.bincount(ind, minlength=ind.shape[0] + 1)
-            res["rep"] = dict(
-                zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))])
-            )
+            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
         return res  # type: ignore

     if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
@@ -1493,27 +1490,11 @@ def max_diff(
             dev=dev,
         )
         if hist:
-            if isinstance(hist,
-
-
-                res["rep"] = {
-                    f">{hist[0]}": (diff > hist[0]).sum().item(),
-                    f">{hist[1]}": (diff > hist[1]).sum().item(),
-                }
-            else:
-                if isinstance(hist, bool):
-                    hist = torch.tensor(
-                        [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
-                    )
-                hist = torch.tensor(hist).to(diff.device)
-                ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
-                cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
-                res["rep"] = dict(
-                    zip(
-                        [f">{x}" for x in hist],
-                        [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
-                    )
+            if isinstance(hist, bool):
+                hist = torch.tensor(
+                    [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
                 )
+            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
         return res  # type: ignore

     if isinstance(expected, int) and isinstance(got, torch.Tensor):
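
Note: both branches above (NumPy and torch) now build the histogram the same way: for each threshold in hist, the report counts how many absolute differences exceed it, instead of bucketizing and accumulating cumulative counts. A minimal sketch of the new computation (tensor values hypothetical):

    import torch

    diff = torch.tensor([0.0, 0.002, 0.05, 2.0])  # hypothetical absolute differences
    hist = torch.tensor([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
    # same comprehension as the patched max_diff: one exceedance count per threshold
    rep = {f">{h}": (diff > h).sum().item() for h in hist}
    # here rep counts 3 values above 0, 2 above 0.01, 1 above 0.1 and 1 above 1
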
@@ -1750,8 +1731,26 @@ def max_diff(
     )


-def string_diff(diff: Dict[str, Any]) -> str:
-    """
+def string_diff(diff: Dict[str, Any], js: bool = False, ratio: bool = False, **kwargs) -> str:
+    """
+    Renders discrepancies return by :func:`max_diff` into one string.
+
+    :param diff: differences
+    :param js: json format
+    :param ratio: display mismatch ratio
+    :param kwargs: addition values to add in the json format
+    """
+    if js:
+        if "rep" in diff:
+            rep = diff["rep"]
+            diff = {**{k: v for k, v in diff.items() if k != "rep"}, **rep}
+            if ratio:
+                for k, v in rep.items():
+                    diff[f"%{k}"] = v / diff["n"]
+                diff["mean"] = diff["sum"] / diff["n"]
+        diff.update(kwargs)
+        return json.dumps(diff)
+
     # dict(abs=, rel=, sum=, n=n_diff, dnan=)
     if "dev" in diff:
         ddiff = {k: v for k, v in diff.items() if k != "dev"}
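
string_diff keeps its one-line rendering but gains a JSON mode. A hedged usage sketch, with a hand-built diff dictionary standing in for the output of max_diff(..., hist=True):

    from onnx_diagnostic.helpers.helper import string_diff

    # hypothetical result of max_diff: the keys mirror the comment in the code above
    diff = {"abs": 0.01, "rel": 0.002, "sum": 1.5, "n": 150, "dnan": 0.0,
            "rep": {">0.0": 12, ">0.1": 1}}
    # flattens "rep", adds %-ratios and mean=sum/n, merges extra kwargs, returns JSON
    print(string_diff(diff, js=True, ratio=True, model="demo"))
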
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -818,6 +818,7 @@ def torch_export_patches(
     rewrite: Optional[List[Callable]] = None,
     dump_rewriting: Optional[str] = None,
     patch_details: Optional[PatchDetails] = None,
+    profile: Optional[str] = None,
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -850,6 +851,8 @@ def torch_export_patches(
     :param dump_rewriting: dumps rewriting information in file beginning with that prefix
     :param patch_details: if specified, this class is used to stored every rewritten done.
     :param verbose: to show which patches is applied
+    :param profile: starts profiling whatever is called inside the context manager,
+        output the profiling into a text file

     The list of available patches.

@@ -1017,10 +1020,23 @@ def torch_export_patches(
     if verbose:
         print("[torch_export_patches] done patching")

+    if profile:
+        from pyinstrument import Profiler
+
+        profiler = Profiler()
+        profiler.start()
+    else:
+        profiler = None
+
     try:
         yield fct_callable
     finally:

+        if profiler:
+            profiler.stop()
+            with open(profile, "w") as f:
+                f.write(profiler.output_html())
+
         # unpatch

         if verbose:
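
The new profile argument wraps everything executed inside the context manager with pyinstrument and writes the report (HTML, via output_html) to the given path when the context exits. A usage sketch, model and inputs hypothetical and assuming pyinstrument is installed:

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    model = torch.nn.Linear(4, 4)        # placeholder model
    inputs = (torch.randn(2, 4),)

    # the profiling report is written to export_profile.html when the context exits
    with torch_export_patches(profile="export_profile.html"):
        ep = torch.export.export(model, inputs)
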
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py

@@ -1,3 +1,4 @@
+import inspect
 from typing import Callable, List, Optional, Tuple
 import torch

@@ -19,6 +20,12 @@ if patch_masking_utils:
         prepare_padding_mask,
     )

+    _prepare_padding_mask_kwargs = (
+        dict(_slice=False)
+        if "_slice" in inspect.signature(prepare_padding_mask).parameters
+        else {}
+    )
+
     try:
         # transformers>=5.0
        from transformers.masking_utils import (
@@ -132,7 +139,9 @@ if patch_masking_utils:
     ) -> Optional[torch.Tensor]:
         """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
         q_length = cache_position.shape[0]
-        padding_mask = prepare_padding_mask(
+        padding_mask = prepare_padding_mask(
+            attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
+        )
         if allow_is_causal_skip and _ignore_causal_mask_sdpa(
             padding_mask, q_length, kv_length, kv_offset, local_size
         ):
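
The _slice keyword only exists in some transformers releases, so the patch inspects prepare_padding_mask once at import time and forwards the keyword only when the signature accepts it. The same defensive pattern on a stand-in function (names hypothetical):

    import inspect

    def prepare_padding_mask_stub(attention_mask, kv_length, kv_offset):
        ...  # stands in for the transformers helper, here without a _slice parameter

    kwargs = (
        dict(_slice=False)
        if "_slice" in inspect.signature(prepare_padding_mask_stub).parameters
        else {}
    )
    assert kwargs == {}  # the stub does not accept _slice, so nothing is forwarded
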
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

@@ -24,7 +24,7 @@ if patch_qwen2_5:

     onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
     op = onnxscript.opset22
-
+    op23 = onnxscript.onnx_opset.opset23
     msft_op = onnxscript.values.Opset("com.microsoft", 1)
     STOPAT = (
         int(os.environ.get("STOPAT", None))
@@ -101,7 +101,7 @@ if patch_qwen2_5:
         return attn_output_4d

     @onnxscript.script(opset=onnx_plugs_op)
-    def
+    def LoopAttention23(
         query_states,
         key_states,
         value_states,
@@ -109,26 +109,26 @@ if patch_qwen2_5:
         scaling: float = 0.11180339887498948,
         num_heads: int = 16,
     ):
-        to_3d_shape =
-        query_transposed =
-        output_shape =
-        query_3d =
-        value_3d =
-        key_3d =
-        cu_seqlens =
-        num_patches =
-        seq_axis =
-        seq_axis_int32 =
-        seq_attn =
+        to_3d_shape = op23.Constant(value_ints=[0, 0, -1])
+        query_transposed = op23.Transpose(query_states, perm=[0, 2, 1, 3])
+        output_shape = op23.Shape(query_transposed)
+        query_3d = op23.Reshape(query_transposed, to_3d_shape)
+        value_3d = op23.Reshape(op23.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        key_3d = op23.Reshape(op23.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        cu_seqlens = op23.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+        num_patches = op23.Size(cu_seqlens) - 1
+        seq_axis = op23.Constant(value_ints=[1])
+        seq_axis_int32 = op23.Cast(seq_axis, to=onnx.TensorProto.INT32)
+        seq_attn = op23.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
         for i_patch in range(num_patches):
-            i_1d =
+            i_1d = op23.Reshape(i_patch, [1])
             i_plus_1_1d = i_1d + 1
-            start =
-            end =
-            query_i =
-            key_i =
-            value_i =
-            mha_output =
+            start = op23.Gather(cu_seqlens, i_1d, axis=0)
+            end = op23.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+            query_i = op23.Slice(query_3d, start, end, seq_axis_int32)
+            key_i = op23.Slice(key_3d, start, end, seq_axis_int32)
+            value_i = op23.Slice(value_3d, start, end, seq_axis_int32)
+            mha_output = op23.Attention(
                 query_i,
                 key_i,
                 value_i,
@@ -137,9 +137,9 @@ if patch_qwen2_5:
                 kv_num_heads=num_heads,
                 softmax_precision=onnx.TensorProto.FLOAT,
             )
-            seq_attn =
-        attn_output =
-        attn_output_4d =
+            seq_attn = op23.SequenceInsert(seq_attn, mha_output)
+        attn_output = op23.ConcatFromSequence(seq_attn, axis=1)
+        attn_output_4d = op23.Reshape(attn_output, output_shape)
         return attn_output_4d

     @onnxscript.script(opset=onnx_plugs_op)
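
LoopAttention23 is the opset-23 flavour of the plugged attention: the query/key/value tensors are flattened to 3-D, cu_seqlens marks the patch boundaries, and every patch is sliced out and sent through the Attention operator before the per-patch results are concatenated back along the sequence axis. A rough PyTorch reference of that loop (a sketch only; the shapes and the head split are assumptions, and the final 4-D reshape from the diff is omitted):

    import torch
    import torch.nn.functional as F

    def loop_attention_reference(query_3d, key_3d, value_3d, cu_seqlens, scaling, num_heads):
        # query_3d/key_3d/value_3d: (batch, total_seq, hidden); cu_seqlens: patch boundaries
        outputs = []
        for i in range(cu_seqlens.numel() - 1):
            start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
            qkv = []
            for t in (query_3d, key_3d, value_3d):
                c = t[:, start:end]                       # one patch along the sequence axis
                b, s, h = c.shape
                qkv.append(c.view(b, s, num_heads, h // num_heads).transpose(1, 2))
            q, k, v = qkv
            att = F.scaled_dot_product_attention(q, k, v, scale=scaling)  # per-patch attention
            outputs.append(att.transpose(1, 2).reshape(b, end - start, -1))
        return torch.cat(outputs, dim=1)
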
@@ -256,20 +256,24 @@ if patch_qwen2_5:
         return attn_output

     def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
-
-
+        first_float_tensor = next(
+            a
+            for a in args
+            if a is not None and a.dtype in {torch.float16, torch.float32, torch.bfloat16}
+        )
+        dtype = first_float_tensor.dtype
         strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
         itype = torch_dtype_to_onnx_dtype(dtype)
         if strategy is not None:
             return strategy, itype
         if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
-            if opset >=
-                return "
+            if opset >= 23:
+                return "LOOPA23", itype
             return "LOOPMHA", itype
         if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
             # first_tensor may be a SymbolicTensor (onnx).
             # is_cuda is not available.
-            if hasattr(
+            if hasattr(first_float_tensor, "is_cuda") and first_float_tensor.is_cuda:
                 return "PACKED", itype
             return "LOOPMHA", itype
         raise AssertionError(
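
qwen_version_selector now picks the dtype from the first floating-point argument and knows about the opset-23 kernel: float32 with opset >= 23 selects LOOPA23, float32 otherwise (and float16 off CUDA) falls back to LOOPMHA, and float16 on CUDA keeps PACKED. A hypothetical call on CPU float32 tensors, assuming no strategy override is configured:

    q = torch.randn(1, 16, 32, 80)   # placeholder query/key/value
    strategy, itype = qwen_version_selector(23, q, q, q)
    # per the logic above: strategy == "LOOPA23", itype == onnx.TensorProto.FLOAT
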
@@ -288,9 +292,9 @@ if patch_qwen2_5:
         ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
             PackedAttention.to_function_proto()
         ),
-        ("
-        ("
-            onnx.TensorProto.FLOAT16,
+        ("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
+        ("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
+            onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
         ),
         ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
             LoopMHAAttention.to_function_proto()
@@ -733,3 +737,71 @@ if patch_qwen2_5:
             attn_output = attn_output.reshape(seq_length, -1).contiguous()
             attn_output = self.proj(attn_output)
             return attn_output
+
+
+    class patched_Qwen2_5_VLModel:
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: Optional[torch.FloatTensor] = None,
+            video_features: Optional[torch.FloatTensor] = None,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+                special_video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.video_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_video_mask = special_video_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+                special_video_mask = input_ids == self.config.video_token_id
+
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+
+            # PATCHED: we should use torch._check
+            # but this fails for compilation. It cannot be verified with FakeTensors
+            # torch._check(
+            #     image_features is None
+            #     or inputs_embeds[special_image_mask].numel() == image_features.numel(),
+            #     lambda: (
+            #         f"Image features and image tokens do not match: tokens: "
+            #         f"{special_image_mask.sum()}, features {image_features.shape[0]}"
+            #     ),
+            # )
+
+            special_video_mask = (
+                special_video_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+
+            # PATCHED: we should use torch._check
+            # but this fails for compilation. It cannot be verified with FakeTensors
+            # torch._check(
+            #     video_features is None
+            #     or inputs_embeds[special_video_mask].numel() == video_features.numel(),
+            #     lambda: (
+            #         f"Videos features and video tokens do not match: tokens: "
+            #         f"{special_video_mask.sum()}, features {video_features.shape[0]}"
+            #     ),
+            # )
+
+            return special_image_mask, special_video_mask
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -55,6 +55,7 @@ Automatically generated:
 import base64
 import json
 import textwrap
+from typing import Any
 import transformers

 null = None
@@ -62,6 +63,22 @@ true = True
 false = False


+def _enforce_default(config_type: type, **kwargs) -> Any:
+    config = config_type(**kwargs)
+    for name in [
+        *[k for k in kwargs if k.endswith("_token_id")],
+        "attention_dropout",
+        "hidden_size",
+        "hidden_act",
+        "intermediate_size",
+        "max_position_embeddings",
+        "vocab_size",
+    ]:
+        if name in kwargs and (not hasattr(config, name) or getattr(config, name) is None):
+            setattr(config, name, kwargs[name])
+    return config
+
+
 def _ccached_arnir0_tiny_LLM():
     "arnir0/Tiny-LLM"
     return transformers.LlamaConfig(
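
_enforce_default re-applies selected keyword values (every *_token_id plus a few size and activation fields) after the config constructor has run, so the cached configs keep those attributes even if a transformers release stops storing them. A small sketch with a stand-in config class (hypothetical, not a transformers config):

    class _DemoConfig:
        def __init__(self, **kwargs):
            # constructor that drops vocab_size and eos_token_id, like a stripped-down config
            self.hidden_size = kwargs.get("hidden_size")

    cfg = _enforce_default(_DemoConfig, hidden_size=32, vocab_size=32064, eos_token_id=2)
    assert cfg.vocab_size == 32064 and cfg.eos_token_id == 2  # re-attached by the helper
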
@@ -4691,7 +4708,8 @@ def _ccached_zai_glm_45():

 def _ccached_microsoft_phi3_mini_128k_instruct():
     "microsoft/Phi-3-mini-128k-instruct"
-    return
+    return _enforce_default(
+        transformers.Phi3Config,
         **{
             "_name_or_path": "Phi-3-mini-128k-instruct",
             "architectures": ["Phi3ForCausalLM"],
@@ -4827,13 +4845,14 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
             "use_cache": true,
             "attention_bias": false,
             "vocab_size": 32064,
-        }
+        },
     )


 def _ccached_google_gemma_3_4b_it_like():
     "google/gemma-3-4b-it"
-    return
+    return _enforce_default(
+        transformers.Gemma3Config,
         **{
             "architectures": ["Gemma3ForConditionalGeneration"],
             "boi_token_index": 255999,
@@ -4863,13 +4882,14 @@ def _ccached_google_gemma_3_4b_it_like():
                 "patch_size": 14,
                 "vision_use_head": false,
             },
-        }
+        },
     )


 def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
     "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
-    return
+    return _enforce_default(
+        transformers.Gemma3TextConfig,
         **{
             "architectures": ["Gemma3ForCausalLM"],
             "attention_bias": false,
@@ -4901,13 +4921,14 @@ def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
             "transformers_version": "4.52.0.dev0",
             "use_cache": true,
             "vocab_size": 262144,
-        }
+        },
     )


 def _ccached_qwen_qwen2_5_vl_7b_instruct():
     "Qwen/Qwen2.5-VL-7B-Instruct"
-    return
+    return _enforce_default(
+        transformers.Qwen2_5_VLConfig,
         **{
             "architectures": ["Qwen2_5_VLForConditionalGeneration"],
             "attention_dropout": 0.0,
@@ -4954,5 +4975,5 @@ def _ccached_qwen_qwen2_5_vl_7b_instruct():
             },
             "rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]},
             "vocab_size": 152064,
-        }
+        },
     )