onnx-diagnostic 0.7.12__py3-none-any.whl → 0.7.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +7 -2
- onnx_diagnostic/export/dynamic_shapes.py +11 -2
- onnx_diagnostic/helpers/helper.py +11 -5
- onnx_diagnostic/helpers/log_helper.py +53 -17
- onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
- onnx_diagnostic/helpers/model_builder_helper.py +1 -0
- onnx_diagnostic/helpers/rt_helper.py +2 -1
- onnx_diagnostic/helpers/torch_helper.py +31 -7
- onnx_diagnostic/reference/torch_evaluator.py +2 -2
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/image_text_to_text.py +256 -141
- onnx_diagnostic/tasks/text_generation.py +30 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +184 -151
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +20 -5
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +52 -20
- onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +540 -10
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +55 -5
- onnx_diagnostic/torch_models/validate.py +116 -50
- onnx_diagnostic/torch_onnx/sbs.py +2 -1
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/METADATA +11 -31
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/RECORD +29 -27
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,13 +1,20 @@
 import inspect
 import math
+import os
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache
+from transformers.generation.utils import (
+    GenerateNonBeamOutput,
+    GenerationConfig,
+    StoppingCriteriaList,
+    LogitsProcessorList,
+)

 try:
     from transformers.cache_utils import parse_processor_args  # noqa: F401
@@ -114,6 +121,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
+        # PATCHED: this line called the patched version of sdpa_mask
         mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
             cache_position=cache_position,
@@ -126,7 +134,7 @@ if patch_masking_utils:
             **kwargs,
         )
         min_dtype = torch.finfo(dtype).min
-        #
+        # PATCHED: the following line
         # we need 0s where the tokens should be taken into account,
         # and -inf otherwise (mask is already of boolean type)
         # mask =
@@ -158,6 +166,7 @@ if patch_masking_utils:
         mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
         batch_arange = torch.arange(batch_size, device=cache_position.device)
         head_arange = torch.arange(1, device=cache_position.device)
+        # PATCHED: this line calls the patched version of vmap_for_bhqkv
         causal_mask = patched__vmap_for_bhqkv(mask_function)(
             batch_arange, head_arange, cache_position, kv_arange
         )
@@ -214,6 +223,7 @@ if patch_DynamicLayer:
             self.dtype, self.device = key_states.dtype, key_states.device
             new_shape = list(key_states.shape)
             new_shape[-2] = 0
+            # PATCHED: used a tensor with an empty shape and not en empty list to initialize
             self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
             self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
     if patch_is_initialized:
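
The `DynamicLayer` patch above initializes the key/value cache from a tensor whose sequence axis has length 0 rather than from an empty list, so the cache keeps the same rank and dtype as the incoming states. A minimal sketch of why that matters (the shapes are illustrative, not taken from the package):

import torch

# Illustrative shapes only: (batch, heads, seq, head_dim).
key_states = torch.randn(2, 8, 4, 64)
# A zero-length sequence axis keeps the rank identical to key_states,
# so torch.cat behaves the same in eager mode and in an exported graph.
empty_cache = torch.empty(2, 8, 0, 64, dtype=key_states.dtype)
updated = torch.cat([empty_cache, key_states], dim=-2)
print(updated.shape)  # torch.Size([2, 8, 4, 64])
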
@@ -248,6 +258,8 @@ def _patch_make_causal_mask(
         diagonal = past_key_values_length - sliding_window - 1

         context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
+        # and used masked_fill instead of masked_fill_
         # In this case, the current implementation of torch fails (17/12/2024).
         # Try model Phi-3.5-Mini-Instruct.
         mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
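
The `_patch_make_causal_mask` change above drops the `is_torchdynamo_compiling()` clone and replaces the in-place `masked_fill_` with the out-of-place `masked_fill`, which is easier for `torch.export` to trace. A minimal sketch of the same pattern, with made-up sizes:

import torch

def sliding_window_mask(mask: torch.Tensor, diagonal: int, dtype: torch.dtype) -> torch.Tensor:
    # masked_fill returns a new tensor instead of mutating `mask` in place.
    context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
    return mask.masked_fill(context_mask, torch.finfo(dtype).min)

m = sliding_window_mask(torch.zeros(4, 4), diagonal=-2, dtype=torch.float32)
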
@@ -455,7 +467,16 @@ class patched_GenerationMixin:
     _PATCHES_ = [
         "_cache_dependant_input_preparation",
         "_cache_dependant_input_preparation_exporting",
-
+        (
+            None
+            if pv.Version(transformers.__version__) >= pv.Version("4.56")
+            else "prepare_inputs_for_generation"
+        ),
+        (
+            "_sample"
+            if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
+            else None
+        ),
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin

@@ -588,7 +609,7 @@ class patched_GenerationMixin:
         model_inputs = {}
         # - some models don't have `Cache` support
         # (which implies they don't expect `cache_position` in `forward`)
-        if self
+        if getattr(self, "_supports_cache_class", False):
             model_inputs["cache_position"] = cache_position
         # - `cache_position` was not a mandatory input in
         # `prepare_inputs_for_generation` for those models, and this
@@ -728,6 +749,192 @@ class patched_GenerationMixin:
         model_inputs.pop("labels", None)
         return model_inputs

+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: "LogitsProcessorList",  # noqa: F821
+        stopping_criteria: "StoppingCriteriaList",  # noqa: F821
+        generation_config: "GenerationConfig",  # noqa: F821
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
+        **model_kwargs,
+    ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
+        """
+        2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
+        """
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            if self.config._attn_implementation == "flash_attention_2":
+                # only raise warning if the user passed an explicit compile-config
+                if (
+                    generation_config.compile_config is not None
+                    and generation_config.compile_config.fullgraph
+                ):
+                    generation_config.compile_config.fullgraph = False
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            # PATCHED: the two following lines, next_tokens can 2D already for this model
+            next_tokens_2d = (
+                next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
+            )
+            input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+
+            # This is needed to properly delete outputs.logits which may be very large
+            # for first iteration
+            # Otherwise a reference to outputs is kept which keeps
+            # the logits alive in the next iteration
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return transformers.generation.utils.GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return transformers.generation.utils.GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+

 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
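
The only functional change in this long `_sample` rewrite is the block flagged `PATCHED`: the selected tokens may already be 2D for some models (the Gemma3 case this release targets), so the trailing dimension is added only when it is missing. A small illustration with made-up token ids:

import torch

input_ids = torch.tensor([[1, 2, 3]])
next_tokens = torch.tensor([4])  # usually shape (batch,)
# Only add the trailing dimension when the tokens are not already (batch, 1).
next_tokens_2d = next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
print(input_ids)  # tensor([[1, 2, 3, 4]])
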
@@ -791,6 +998,7 @@ def patched__compute_dynamic_ntk_parameters(
     if seq_len is None:
         seq_len = max_position_embeddings
     else:
+        # PATCHED: remove the line using max
         torch._check(isinstance(seq_len, torch.Tensor))
         seq_len = torch.maximum(
             seq_len,
@@ -896,6 +1104,7 @@ def patched_dynamic_rope_update(rope_forward):
         )
         original_inv_freq = self.original_inv_freq.to(device)

+        # PATCHED: uses torch.cond instead of a test
        cond = (seq_len > original_max_position_embeddings).item()
         inv_freq = torch.cond(
             cond,
@@ -967,6 +1176,7 @@ def patched_dynamic_rope_update(rope_forward):

         original_inv_freq = self.original_inv_freq.to(device)
         cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
         inv_freq = torch.cond(
             cond,
             (lambda x, y: x.clone()),
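
Both rope-update hunks replace a Python `if` on a runtime value with `torch.cond`, so both branches remain in the exported graph instead of only the branch taken while tracing. A minimal, self-contained sketch of that pattern (the names and the 4096 threshold are illustrative, not the package's values):

import torch

def pick_inv_freq(seq_len, long_inv_freq, original_inv_freq):
    # torch.cond traces both branches; each branch receives the operands and
    # must return tensors of matching shape/dtype, hence the clones.
    return torch.cond(
        seq_len > 4096,
        lambda a, b: a.clone(),  # long-context frequencies
        lambda a, b: b.clone(),  # original frequencies
        [long_inv_freq, original_inv_freq],
    )

out = pick_inv_freq(torch.tensor(8192), torch.full((8,), 0.5), torch.full((8,), 1.0))
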
@@ -1002,6 +1212,7 @@ def common_eager_attention_forward(

     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
     if attention_mask is not None:
+        # PATCHED
         # The two following lines were added.
         if attention_mask is not None and attention_mask.ndim == 4:
             attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -1074,6 +1285,7 @@ def patched_modeling_marian_eager_attention_forward(
 class common_RotaryEmbedding(torch.nn.Module):
     # This may cause some issues.
     # @torch.no_grad()
+    # PATCHED: the decorator
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
         inv_freq_expanded = (
@@ -1629,3 +1841,56 @@ if patch_qwen3:
                 batch_size, sequence_length, hidden_dim
             )
             return final_hidden_states, router_logits
+
+
+try:
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
+
+    patch_gemma3 = True
+except ImportError:
+    patch_gemma3 = False
+
+
+if patch_gemma3:
+
+    class patched_Gemma3Model(torch.nn.Module):
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+        _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: torch.FloatTensor,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = special_image_mask.sum()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            # PATCHED: torch._check
+            # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            #     raise ValueError( ... )
+            torch._check(
+                inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                lambda: (
+                    f"Image features and image tokens do not match: tokens: "
+                    f"{n_image_tokens}, features {n_image_features}"
+                ),
+            )
+            return special_image_mask
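
The Gemma3 patch above swaps the upstream `if ...: raise ValueError(...)` for `torch._check`, which turns a data-dependent consistency test into a constraint the exporter can record instead of a Python exception raised on traced tensors. A minimal sketch of the pattern (names and shapes are illustrative):

import torch

def check_image_tokens(inputs_embeds, image_mask, image_features):
    # torch._check takes a boolean condition and a lambda building the message.
    torch._check(
        inputs_embeds[image_mask].numel() == image_features.numel(),
        lambda: "image tokens and image features do not match",
    )

embeds = torch.zeros(1, 4, 8)
mask = torch.zeros(1, 4, 8, dtype=torch.bool)
mask[:, :2, :] = True
check_image_tokens(embeds, mask, torch.ones(2, 8))  # 16 elements on both sides
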
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -4829,3 +4829,39 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
             "vocab_size": 32064,
         }
     )
+
+
+def _ccached_google_gemma_3_4b_it_like():
+    "google/gemma-3-4b-it"
+    return transformers.Gemma3Config(
+        **{
+            "architectures": ["Gemma3ForConditionalGeneration"],
+            "boi_token_index": 255999,
+            "eoi_token_index": 256000,
+            "eos_token_id": [1, 106],
+            "image_token_index": 262144,
+            "initializer_range": 0.02,
+            "mm_tokens_per_image": 256,
+            "model_type": "gemma3",
+            "text_config": {
+                "hidden_size": 2560,
+                "intermediate_size": 10240,
+                "model_type": "gemma3_text",
+                "num_hidden_layers": 34,
+                "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
+                "sliding_window": 1024,
+            },
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.50.0.dev0",
+            "vision_config": {
+                "hidden_size": 1152,
+                "image_size": 896,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "patch_size": 14,
+                "vision_use_head": false,
+            },
+        }
+    )
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -25,6 +25,20 @@ def _code_needing_rewriting(model: Any) -> Any:
     return code_needing_rewriting(model)


+def _preprocess_model_id(
+    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
+) -> Tuple[str, Optional[str], bool, bool]:
+    if subfolder or "//" not in model_id:
+        return model_id, subfolder, same_as_pretrained, use_pretrained
+    spl = model_id.split("//")
+    if spl[-1] == "pretrained":
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+    if spl[-1] in {"transformer", "vae"}:
+        # known subfolder
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
+    return model_id, subfolder, same_as_pretrained, use_pretrained
+
+
 def get_untrained_model_with_inputs(
     model_id: str,
     config: Optional[Any] = None,
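
The new `_preprocess_model_id` helper lets a model id carry options after a `//` separator: `//pretrained` requests the pretrained weights, and a known subfolder name is split off the id. Illustrative calls based on the function above (the model ids are examples only):

from onnx_diagnostic.torch_models.hghub.model_inputs import _preprocess_model_id

# "//pretrained" clears the subfolder and turns on same_as_pretrained/use_pretrained.
_preprocess_model_id("google/gemma-3-4b-it//pretrained", None, False, False)
# -> ("google/gemma-3-4b-it", "", True, True)

# A known subfolder suffix ("transformer" or "vae") is split off the id.
_preprocess_model_id("some-org/some-diffusion-model//vae", None, False, False)
# -> ("some-org/some-diffusion-model", "vae", False, False)
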
@@ -57,7 +71,7 @@ get_untrained_model_with_inputs(
         to get a smaller model
     :param use_pretrained: download the pretrained weights as well
     :param use_preinstalled: use preinstalled configurations
-    :param add_second_input: provides
+    :param add_second_input: provides others inputs to check a model
         supports different shapes
     :param subfolder: subfolder to use for this model id
     :param use_only_preinstalled: use only preinstalled version
@@ -85,8 +99,16 @@ get_untrained_model_with_inputs(
             f"model_id={model_id!r}, preinstalled model is only available "
             f"if use_only_preinstalled is False."
         )
+    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+        model_id,
+        subfolder,
+        same_as_pretrained=same_as_pretrained,
+        use_pretrained=use_pretrained,
+    )
     if verbose:
-        print(
+        print(
+            f"[get_untrained_model_with_inputs] model_id={model_id!r}, subfolder={subfolder!r}"
+        )
         if use_preinstalled:
             print(f"[get_untrained_model_with_inputs] use preinstalled {model_id!r}")
     if config is None:
@@ -178,7 +200,7 @@ get_untrained_model_with_inputs(

     if verbose:
         print(
-            f"[get_untrained_model_with_inputs] package_source={package_source.__name__}"
+            f"[get_untrained_model_with_inputs] package_source={package_source.__name__} "
             f"from {package_source.__file__}"
         )
     if use_pretrained:
@@ -193,7 +215,7 @@ get_untrained_model_with_inputs(
         )
         if verbose:
             print(
-                f"[get_untrained_model_with_inputs] -- done in "
+                f"[get_untrained_model_with_inputs] -- done(1) in "
                 f"{time.perf_counter() - begin}s"
             )
     else:
@@ -250,14 +272,36 @@ get_untrained_model_with_inputs(
         )
     if verbose:
         print(
-            f"[get_untrained_model_with_inputs] -- done in "
+            f"[get_untrained_model_with_inputs] -- done(2) in "
             f"{time.perf_counter() - begin}s"
         )

     seed = int(os.environ.get("SEED", "17"))
     torch.manual_seed(seed)
+
+    if verbose:
+        begin = time.perf_counter()
+        print(
+            f"[get_untrained_model_with_inputs] "
+            f"instantiate_specific_model {cls_model}"
+        )
+
     model = instantiate_specific_model(cls_model, config)
+
+    if verbose:
+        print(
+            f"[get_untrained_model_with_inputs] -- done(3) in "
+            f"{time.perf_counter() - begin}s (model is {type(model)})"
+        )
+
     if model is None:
+
+        if verbose:
+            print(
+                f"[get_untrained_model_with_inputs] "
+                f"instantiate_specific_model(2) {cls_model}"
+            )
+
         try:
             if type(config) is dict:
                 model = cls_model(**config)
@@ -268,6 +312,12 @@ get_untrained_model_with_inputs(
                 f"Unable to instantiate class {cls_model.__name__} with\n{config}"
             ) from e

+    if verbose:
+        print(
+            f"[get_untrained_model_with_inputs] -- done(4) in "
+            f"{time.perf_counter() - begin}s (model is {type(model)})"
+        )
+
     # input kwargs
     seed = int(os.environ.get("SEED", "17")) + 1
     torch.manual_seed(seed)