onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +281 -80
- onnx_diagnostic/doc.py +22 -0
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +78 -8
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +1744 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +54 -73
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +21 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +72 -18
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +81 -8
- onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +58 -14
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +166 -106
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +44 -41
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -11,7 +11,7 @@ from ...helpers.torch_helper import is_torchdynamo_exporting
 
 
 def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """
+    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
     from ...helpers import string_type
 
     dimensions: List[Tuple[Optional[int], ...]] = [
@@ -534,19 +534,169 @@ class patched_GenerationMixin:
         return model_inputs
 
 
-def
+def patched__compute_dynamic_ntk_parameters(
+    config: Optional[transformers.PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+    **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+    """
+    manual patch:
+    ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
+
+    Computes the inverse frequencies with NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length,
+            used to update the dynamic RoPE at inference time.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous
+            RoPE class instantiation, will be removed in v4.45.
+
+    Returns:
+        Tuple of (`torch.Tensor`, `float`),
+        containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the
+        computed cos/sin (unused in this type of RoPE).
     """
-
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_dynamic_ntk_parameters`, got "
+            f"`rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+        max_position_embeddings = rope_kwargs["max_position_embeddings"]
+        factor = rope_kwargs["factor"]
+    elif config is not None:
+        base = config.rope_theta
+        partial_rotary_factor = (
+            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        )
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        dim = int(head_dim * partial_rotary_factor)
+        max_position_embeddings = config.max_position_embeddings
+        factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # seq_len: default to max_position_embeddings, e.g. at init time
+    # seq_len = seq_len if seq_len is not None and
+    # seq_len > max_position_embeddings else max_position_embeddings
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    else:
+        torch._check(isinstance(seq_len, torch.Tensor))
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+
+    # Compute the inverse frequencies
+    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
+        dim / (dim - 2)
+    )
+    inv_freq = 1.0 / (
+        base
+        ** (
+            torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
+            / dim
+        )
+    )
+    return inv_freq, attention_factor
+
+
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
+
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
+
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings =
+                        self.config.original_max_position_embeddings
+                else:
+                    original_max_position_embeddings =
+                        self.config.max_position_embeddings
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if seq_len < self.original_max_seq_len and
+                        self.max_seq_len_cached > self.original_max_seq_len:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """
 
     def longrope_frequency_update(self, position_ids, device):
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
         # At export time, seq_len is unknown.
-        long_inv_freq, _ =
+        long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
         )
         original_inv_freq = self.original_inv_freq.to(device)
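The patched NTK helper above replaces the data-dependent `if seq_len > max_position_embeddings` branch with `torch.maximum` and `torch._check` so the computation stays traceable at export time. The following standalone sketch is not part of the package; it simply reproduces the same inverse-frequency formula eagerly, with illustrative values (`dim=64`, `base=10000.0`, `factor=2.0`):

```python
import torch


def ntk_inv_freq(dim: int, base: float, factor: float,
                 max_position_embeddings: int, seq_len: int) -> torch.Tensor:
    # NTK scaling: grow the RoPE base once the sequence exceeds the trained window.
    seq_len = max(seq_len, max_position_embeddings)
    scaled_base = base * (
        (factor * seq_len / max_position_embeddings) - (factor - 1)
    ) ** (dim / (dim - 2))
    # Standard RoPE inverse frequencies, computed from the rescaled base.
    exponent = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
    return 1.0 / (scaled_base ** exponent)


print(ntk_inv_freq(dim=64, base=10000.0, factor=2.0,
                   max_position_embeddings=2048, seq_len=4096).shape)  # torch.Size([32])
```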
@@ -565,21 +715,70 @@ def patched_dynamic_rope_update(rope_forward):
         # self.inv_freq = self.original_inv_freq
 
     def dynamic_frequency_update(self, position_ids, device):
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows.
+        # The test should always True.
+        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #     inv_freq, self.attention_scaling = self.rope_init_fn(
+        #         self.config, device, seq_len=seq_len
+        #     )
+        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #     self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #     cond,
+        #     (lambda x, y: x.clone()),
+        #     (lambda x, y: y.clone()),
+        #     [attention_scaling, self.attention_scaling],
+        # )
+
         seq_len = torch.max(position_ids) + 1
-
-
-
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.max_seq_len_cached = seq_len
+        long_inv_freq, self.attention_scaling = rope_init_fn(
+            self.config, device, seq_len=seq_len
+        )
 
-
-
-
-
-
-
-
+        # Second test to translate.
+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+        # But in that case the following condition is a way to restore the original cache.
+
+        # if (
+        #     seq_len < self.original_max_seq_len
+        #     and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #     self.original_inv_freq = self.original_inv_freq.to(device)
+        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #     self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        self.inv_freq = inv_freq
 
     @wraps(rope_forward)
     def wrapper(self, x, position_ids):
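The rewrite above turns the cache-growth branch into a functional `torch.cond` so that `torch.export` captures both branches in one graph instead of failing on a data-dependent `if`. A minimal sketch of that pattern, with made-up tensors standing in for the real `inv_freq` buffers:

```python
import torch

long_inv_freq = torch.rand(32)       # frequencies recomputed for the longer sequence
original_inv_freq = torch.rand(32)   # frequencies cached at construction time
seq_len = torch.tensor(4096)
original_max_seq_len = 2048

# Both branches receive the same operands and return tensors with identical
# shape and dtype, which is what torch.cond (and the exporter) requires.
inv_freq = torch.cond(
    (seq_len >= original_max_seq_len).item(),
    lambda x, y: x.clone(),   # sequence grew past the original window
    lambda x, y: y.clone(),   # still inside the original window
    [long_inv_freq, original_inv_freq],
)
print(inv_freq.shape)  # torch.Size([32])
```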
@@ -619,3 +818,152 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class patched_IdeficsEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # if seq_len > self.max_seq_len_cached:
+        #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
+            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
+            freqs = torch.einsum("i,j->ij", t, inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+        def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
+            torch._check(seq_len.item() <= cos_cached.shape[0])
+            co = cos_cached[: seq_len.item()].detach().clone()
+            torch._check(seq_len.item() <= sin_cached.shape[0])
+            si = sin_cached[: seq_len.item()].detach().clone()
+            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
+
+        cos_cached, sin_cached = torch.cond(
+            (seq_len > self.max_seq_len_cached).item(),
+            _set_cos_sin_cache_then,
+            _set_cos_sin_cache_else,
+            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
+        )
+        return cos_cached, sin_cached
+
+
+class patched_IdeficsAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        if not is_cross_attention:
+            key_states = (
+                self.k_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+        else:
+            _, kv_len, _ = (
+                key_value_states.size()
+            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = (
+                self.k_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += cache_position[0]
+
+        if not is_cross_attention:
+            rotary_length = torch.maximum(
+                torch.tensor(kv_seq_len, dtype=torch.int64),
+                torch.tensor(q_len, dtype=torch.int64),
+            )
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
+            query_states, key_states = (
+                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids
+                )
+            )
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models;
+            # cache_position needed for the static cache
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        attention_interface: Callable = (
+            transformers.models.idefics.modeling_idefics.eager_attention_forward
+        )
+
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and output_attentions:
+                transformers.models.idefics.modeling_idefics.logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                    "`output_attentions=True`. Falling back to "
+                    "eager attention. This warning can be removed using the argument "
+                    '`attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
--- a/onnx_diagnostic/torch_models/hghub/hub_api.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -2,9 +2,11 @@ import copy
 import functools
 import json
 import os
+import pprint
+import sys
 from typing import Any, Dict, List, Optional, Union
 import transformers
-from huggingface_hub import HfApi, model_info, hf_hub_download
+from huggingface_hub import HfApi, model_info, hf_hub_download, list_repo_files
 from ...helpers.config_helper import update_config
 from . import hub_data_cached_configs
 from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
@@ -33,10 +35,14 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig
     return res
 
 
-def get_cached_configuration(
+def get_cached_configuration(
+    name: str, exc: bool = False, **kwargs
+) -> Optional[transformers.PretrainedConfig]:
     """
     Returns cached configuration to avoid having to many accesses to internet.
     It returns None if not Cache. The list of cached models follows.
+    If *exc* is True or if environment variable ``NOHTTP`` is defined,
+    the function raises an exception if *name* is not found.
 
     .. runpython::
 
@@ -54,8 +60,11 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
         conf = copy.deepcopy(conf)
         update_config(conf, kwargs)
         return conf
-
-
+    assert not exc and not os.environ.get("NOHTTP", ""), (
+        f"Unable to find {name!r} (exc={exc}, "
+        f"NOHTTP={os.environ.get('NOHTTP', '')!r}) "
+        f"in {pprint.pformat(sorted(cached))}"
+    )
     return None
 
 
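With the new `exc` flag (and the `NOHTTP` environment variable), a cache miss now fails with an assertion listing the cached ids instead of silently returning `None`. A hedged usage sketch; the model id is only an example and has to be in the preinstalled list for the call to succeed:

```python
import os
from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration

os.environ["NOHTTP"] = "1"  # make any cache miss raise instead of returning None
config = get_cached_configuration("arnir0/Tiny-LLM", exc=True)  # illustrative id
if config is not None:
    print(type(config), config.architectures)
```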
@@ -64,6 +73,7 @@ def get_pretrained_config(
     trust_remote_code: bool = True,
     use_preinstalled: bool = True,
     subfolder: Optional[str] = None,
+    use_only_preinstalled: bool = False,
     **kwargs,
 ) -> Any:
     """
@@ -77,13 +87,20 @@ def get_pretrained_config(
         :func:`get_cached_configuration`, the cached list is mostly for
         unit tests
     :param subfolder: subfolder for the given model id
+    :param use_only_preinstalled: if True, raises an exception if not preinstalled
     :param kwargs: additional kwargs
     :return: a configuration
     """
     if use_preinstalled:
-        conf = get_cached_configuration(
+        conf = get_cached_configuration(
+            model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
+        )
         if conf is not None:
             return conf
+    assert not use_only_preinstalled, (
+        f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
+        f"use_preinstalled={use_preinstalled!r}"
+    )
     if subfolder:
         try:
             return transformers.AutoConfig.from_pretrained(
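`get_pretrained_config` gains a matching `use_only_preinstalled` flag: when set, the configuration must come from the preinstalled cache and the Hugging Face fallback is never attempted. A short sketch under the same assumption that the id is in the cached list:

```python
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config

config = get_pretrained_config(
    "arnir0/Tiny-LLM",           # illustrative id, expected to be preinstalled
    use_preinstalled=True,
    use_only_preinstalled=True,  # assert rather than query huggingface.co
)
print(config.__class__.__name__)
```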
@@ -122,12 +139,15 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
 
 
 @functools.cache
-def task_from_arch(
+def task_from_arch(
+    arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
+) -> str:
     """
     This function relies on stored information. That information needs to be refresh.
 
     :param arch: architecture name
     :param default_value: default value in case the task cannot be determined
+    :param model_id: unused unless the architecture does not help.
     :return: task
 
     .. runpython::
@@ -140,9 +160,16 @@ def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
         <onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
     """
     data = load_architecture_task()
+    if arch not in data and model_id:
+        # Let's try with the model id.
+        return task_from_id(model_id)
     if default_value is not None:
         return data.get(arch, default_value)
-    assert arch in data,
+    assert arch in data, (
+        f"Architecture {arch!r} is unknown, last refresh in {__date__}. "
+        f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__`` "
+        f"needs to be updated (model_id={(model_id or '?')!r})."
+    )
     return data[arch]
 
 
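The lookup now accepts a `model_id` fallback for architectures missing from the stored table (the diff also adds per-model rows such as `emilyalsentzer/Bio_ClinicalBERT,fill-mask`). A small sketch, assuming the architecture table shipped with 0.7.1:

```python
from onnx_diagnostic.torch_models.hghub.hub_api import task_from_arch

# Known architecture: resolved directly from the stored table.
print(task_from_arch("Phi3ForCausalLM"))  # text-generation
# Unknown architecture: default_value avoids the assertion; passing model_id
# instead would let task_from_arch defer to task_from_id.
print(task_from_arch("NotAnArchitecture", default_value="text-generation"))
```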
@@ -160,6 +187,7 @@ def task_from_id(
         if the task cannot be determined
     :param pretrained: uses the config
     :param fall_back_to_pretrained: falls back to pretrained config
+    :param exc: raises an exception if True
     :return: task
     """
     if not pretrained:
@@ -175,9 +203,14 @@ def task_from_id(
         guess = _guess_task_from_config(config)
         if guess is not None:
             return guess
+    data = load_architecture_task()
+    if model_id in data:
+        return data[model_id]
     assert config.architectures is not None and len(config.architectures) == 1, (
         f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
-        f"architectures={config.architectures} in config={config}"
+        f"architectures={config.architectures} in config={config}. "
+        f"The task can be added in "
+        f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
     )
     return task_from_arch(config.architectures[0], default_value=default_value)
 
@@ -295,3 +328,43 @@ def enumerate_model_list(
             n -= 1
             if n == 0:
                 break
+
+
+def download_code_modelid(
+    model_id: str, verbose: int = 0, add_path_to_sys_path: bool = True
+) -> List[str]:
+    """
+    Downloads the code for a given model id.
+
+    :param model_id: model id
+    :param verbose: verbosity
+    :param add_path_to_sys_path: add folder where the files are downloaded to sys.path
+    :return: list of downloaded files
+    """
+    if verbose:
+        print(f"[download_code_modelid] retrieve file list for {model_id!r}")
+    files = list_repo_files(model_id)
+    pyfiles = [name for name in files if os.path.splitext(name)[-1] == ".py"]
+    if verbose:
+        print(f"[download_code_modelid] python files {pyfiles}")
+    absfiles = []
+    paths = set()
+    for i, name in enumerate(pyfiles):
+        if verbose:
+            print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
+        r = hf_hub_download(repo_id=model_id, filename=name)
+        p = os.path.split(r)[0]
+        paths.add(p)
+        absfiles.append(r)
+    if add_path_to_sys_path:
+        for p in paths:
+            init = os.path.join(p, "__init__.py")
+            if not os.path.exists(init):
+                with open(init, "w"):
+                    pass
+            if p in sys.path:
+                continue
+            if verbose:
+                print(f"[download_code_modelid] add {p!r} to 'sys.path'")
+            sys.path.insert(0, p)
+    return absfiles
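`download_code_modelid` pulls every `*.py` file of a repository that ships custom modeling code, adds an `__init__.py` to the download folder if needed, and puts that folder on `sys.path` so the custom modules can be imported locally. A hedged usage sketch; the repository id is an example only and the download requires network access:

```python
from onnx_diagnostic.torch_models.hghub.hub_api import download_code_modelid

# Downloads the repository's Python files and registers their folder on sys.path.
files = download_code_modelid("microsoft/Phi-3.5-MoE-instruct", verbose=1)
for f in files:
    print(f)
```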
--- a/onnx_diagnostic/torch_models/hghub/hub_data.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -3,7 +3,7 @@ import functools
 import textwrap
 from typing import Dict, List
 
-__date__ = "2025-
+__date__ = "2025-06-21"
 
 __data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
 
@@ -52,6 +52,8 @@ __data_arch__ = textwrap.dedent(
     GPTNeoModel,feature-extraction
     GPTNeoXForCausalLM,text-generation
     GemmaForCausalLM,text-generation
+    Gemma2ForCausalLM,text-generation
+    Gemma3ForConditionalGeneration,image-text-to-text
     GraniteForCausalLM,text-generation
     GroupViTModel,feature-extraction
     HieraForImageClassification,image-classification
@@ -97,6 +99,7 @@ __data_arch__ = textwrap.dedent(
     PegasusModel,feature-extraction
     Phi3ForCausalLM,text-generation
     PhiForCausalLM,text-generation
+    PhiMoEForCausalLM,text-generation
     Pix2StructForConditionalGeneration,image-to-text
     PLBartForConditionalGeneration,text2text-generation
     PoolFormerModel,image-feature-extraction
@@ -144,7 +147,8 @@ __data_arch__ = textwrap.dedent(
     XLMRobertaModel,sentence-similarity
     Wav2Vec2ForCTC,automatic-speech-recognition
     YolosForObjectDetection,object-detection
-    YolosModel,image-feature-extraction
+    YolosModel,image-feature-extraction
+    emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
 )
 
 __data_tasks__ = [