onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff compares the contents of two versions of the package as they were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +87 -77
- onnx_diagnostic/doc.py +22 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +59 -0
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +585 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +54 -73
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +21 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +3 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +158 -103
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +41 -39
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -11,7 +11,7 @@ from ...helpers.torch_helper import is_torchdynamo_exporting
 
 
 def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """
+    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
     from ...helpers import string_type
 
     dimensions: List[Tuple[Optional[int], ...]] = [
@@ -534,19 +534,169 @@ class patched_GenerationMixin:
         return model_inputs
 
 
-def
+def patched__compute_dynamic_ntk_parameters(
+    config: Optional[transformers.PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+    **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+    """
+    manual patch:
+    ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
+
+    Computes the inverse frequencies with NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length,
+            used to update the dynamic RoPE at inference time.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous
+            RoPE class instantiation, will be removed in v4.45.
+
+    Returns:
+        Tuple of (`torch.Tensor`, `float`),
+        containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the
+        computed cos/sin (unused in this type of RoPE).
     """
-
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_dynamic_ntk_parameters`, got "
+            f"`rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+        max_position_embeddings = rope_kwargs["max_position_embeddings"]
+        factor = rope_kwargs["factor"]
+    elif config is not None:
+        base = config.rope_theta
+        partial_rotary_factor = (
+            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        )
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        dim = int(head_dim * partial_rotary_factor)
+        max_position_embeddings = config.max_position_embeddings
+        factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # seq_len: default to max_position_embeddings, e.g. at init time
+    # seq_len = seq_len if seq_len is not None and
+    # seq_len > max_position_embeddings else max_position_embeddings
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    else:
+        torch._check(isinstance(seq_len, torch.Tensor))
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+
+    # Compute the inverse frequencies
+    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
+        dim / (dim - 2)
+    )
+    inv_freq = 1.0 / (
+        base
+        ** (
+            torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
+            / dim
+        )
+    )
+    return inv_freq, attention_factor
+
+
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
+
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
+
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings =
+                        self.config.original_max_position_embeddings
+                else:
+                    original_max_position_embeddings =
+                        self.config.max_position_embeddings
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if seq_len < self.original_max_seq_len and
+                        self.max_seq_len_cached > self.original_max_seq_len:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """
 
     def longrope_frequency_update(self, position_ids, device):
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
        else:
            original_max_position_embeddings = self.config.max_position_embeddings
        # At export time, seq_len is unknown.
-        long_inv_freq, _ =
+        long_inv_freq, _ = rope_init_fn(
            self.config, device, seq_len=original_max_position_embeddings + 1
        )
        original_inv_freq = self.original_inv_freq.to(device)
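
As a side note, the scaling arithmetic performed by the new `patched__compute_dynamic_ntk_parameters` can be restated in a few lines. The sketch below is not part of the package; the configuration values (`base`, `dim`, `factor`, sequence lengths) are illustrative assumptions.

```python
import torch

# Illustrative values only (not taken from any specific model configuration).
base = 10000.0                 # config.rope_theta
dim = 64                       # head_dim * partial_rotary_factor
max_position_embeddings = 2048
factor = 2.0                   # config.rope_scaling["factor"]
seq_len = 4096                 # current length, longer than the training context

# Grow the RoPE base when the sequence exceeds the training context (dynamic NTK).
scaled_base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
    dim / (dim - 2)
)
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
print(inv_freq.shape)  # torch.Size([32]): one frequency per pair of rotary dimensions
```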
@@ -565,21 +715,70 @@ def patched_dynamic_rope_update(rope_forward):
         # self.inv_freq = self.original_inv_freq
 
     def dynamic_frequency_update(self, position_ids, device):
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows.
+        # The test should always True.
+        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #     inv_freq, self.attention_scaling = self.rope_init_fn(
+        #         self.config, device, seq_len=seq_len
+        #     )
+        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #     self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #     cond,
+        #     (lambda x, y: x.clone()),
+        #     (lambda x, y: y.clone()),
+        #     [attention_scaling, self.attention_scaling],
+        # )
+
         seq_len = torch.max(position_ids) + 1
-
-
-
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.max_seq_len_cached = seq_len
+        long_inv_freq, self.attention_scaling = rope_init_fn(
+            self.config, device, seq_len=seq_len
+        )
 
-
-
-
-
-
-
-
+        # Second test to translate.
+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+        # But in that case the following condition is a way to restore the original cache.
+
+        # if (
+        #     seq_len < self.original_max_seq_len
+        #     and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #     self.original_inv_freq = self.original_inv_freq.to(device)
+        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #     self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        self.inv_freq = inv_freq
 
     @wraps(rope_forward)
     def wrapper(self, x, position_ids):
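
The rewritten `dynamic_frequency_update` above replaces a data-dependent Python `if` with `torch.cond`, so that both branches stay in the exported graph. Below is a minimal, self-contained sketch of that pattern with an invented module and threshold; it is not code from the package.

```python
import torch


class PickInvFreq(torch.nn.Module):
    """Selects one of two precomputed frequency tensors with torch.cond."""

    def forward(self, seq_len, long_inv_freq, original_inv_freq):
        # Both branches must take the same operands and return tensors with
        # matching shape/dtype; .clone() avoids returning aliased inputs.
        return torch.cond(
            seq_len >= 16,                    # 0-dim boolean tensor as predicate
            lambda a, b: a.clone(),           # sequence grew past the threshold
            lambda a, b: b.clone(),           # otherwise keep the original cache
            [long_inv_freq, original_inv_freq],
        )


mod = PickInvFreq()
args = (torch.tensor(32), torch.rand(8), torch.rand(8))
print(mod(*args))                             # eager mode runs a single branch
ep = torch.export.export(mod, args)           # the exported graph keeps both branches
print(ep.graph)
```

The patch applies the same idea with `(seq_len >= self.original_max_seq_len).item()` as the predicate.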
@@ -619,3 +818,152 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class patched_IdeficsEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # if seq_len > self.max_seq_len_cached:
+        #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
+            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
+            freqs = torch.einsum("i,j->ij", t, inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+        def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
+            torch._check(seq_len.item() <= cos_cached.shape[0])
+            co = cos_cached[: seq_len.item()].detach().clone()
+            torch._check(seq_len.item() <= sin_cached.shape[0])
+            si = sin_cached[: seq_len.item()].detach().clone()
+            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
+
+        cos_cached, sin_cached = torch.cond(
+            (seq_len > self.max_seq_len_cached).item(),
+            _set_cos_sin_cache_then,
+            _set_cos_sin_cache_else,
+            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
+        )
+        return cos_cached, sin_cached
+
+
+class patched_IdeficsAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        if not is_cross_attention:
+            key_states = (
+                self.k_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+        else:
+            _, kv_len, _ = (
+                key_value_states.size()
+            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = (
+                self.k_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += cache_position[0]
+
+        if not is_cross_attention:
+            rotary_length = torch.maximum(
+                torch.tensor(kv_seq_len, dtype=torch.int64),
+                torch.tensor(q_len, dtype=torch.int64),
+            )
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
+            query_states, key_states = (
+                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids
+                )
+            )
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models;
+            # cache_position needed for the static cache
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        attention_interface: Callable = (
+            transformers.models.idefics.modeling_idefics.eager_attention_forward
+        )
+
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and output_attentions:
+                transformers.models.idefics.modeling_idefics.logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                    "`output_attentions=True`. Falling back to "
+                    "eager attention. This warning can be removed using the argument "
+                    '`attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
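
Both Idefics patches also guard data-dependent slicing with `torch._check`, which bounds an otherwise unknown length for the exporter. The small sketch below illustrates that idea with an invented cache size; it runs eagerly, and under `torch.export` the same checks become constraints on the unbacked symbol produced by `.item()`.

```python
import torch


class SlicedCosCache(torch.nn.Module):
    def __init__(self, max_len: int = 128, dim: int = 16):
        super().__init__()
        self.register_buffer("cos_cached", torch.rand(max_len, dim))

    def forward(self, seq_len: torch.Tensor):
        n = seq_len.item()                           # unbacked symbolic int at export time
        torch._check(n >= 0)                         # lower bound for the slice
        torch._check(n <= self.cos_cached.shape[0])  # never read past the cache
        return self.cos_cached[:n].detach().clone()


cache = SlicedCosCache()
print(cache(torch.tensor(32)).shape)  # torch.Size([32, 16])
```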
onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -2,6 +2,7 @@ import copy
 import functools
 import json
 import os
+import pprint
 from typing import Any, Dict, List, Optional, Union
 import transformers
 from huggingface_hub import HfApi, model_info, hf_hub_download

@@ -33,10 +34,14 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig
     return res
 
 
-def get_cached_configuration(
+def get_cached_configuration(
+    name: str, exc: bool = False, **kwargs
+) -> Optional[transformers.PretrainedConfig]:
     """
     Returns cached configuration to avoid having to many accesses to internet.
     It returns None if not Cache. The list of cached models follows.
+    If *exc* is True or if environment variable ``NOHTTP`` is defined,
+    the function raises an exception if *name* is not found.
 
     .. runpython::
 

@@ -54,8 +59,11 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
         conf = copy.deepcopy(conf)
         update_config(conf, kwargs)
         return conf
-
-
+    assert not exc and not os.environ.get("NOHTTP", ""), (
+        f"Unable to find {name!r} (exc={exc}, "
+        f"NOHTTP={os.environ.get('NOHTTP', '')!r}) "
+        f"in {pprint.pformat(sorted(cached))}"
+    )
     return None
 
 

@@ -64,6 +72,7 @@ def get_pretrained_config(
     trust_remote_code: bool = True,
     use_preinstalled: bool = True,
     subfolder: Optional[str] = None,
+    use_only_preinstalled: bool = False,
     **kwargs,
 ) -> Any:
     """

@@ -77,13 +86,20 @@ def get_pretrained_config(
         :func:`get_cached_configuration`, the cached list is mostly for
         unit tests
     :param subfolder: subfolder for the given model id
+    :param use_only_preinstalled: if True, raises an exception if not preinstalled
     :param kwargs: additional kwargs
     :return: a configuration
     """
     if use_preinstalled:
-        conf = get_cached_configuration(
+        conf = get_cached_configuration(
+            model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
+        )
         if conf is not None:
             return conf
+        assert not use_only_preinstalled, (
+            f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
+            f"use_preinstalled={use_preinstalled!r}"
+        )
     if subfolder:
         try:
             return transformers.AutoConfig.from_pretrained(