onnx-diagnostic 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -77
  3. onnx_diagnostic/doc.py +68 -0
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +59 -0
  6. onnx_diagnostic/helpers/config_helper.py +8 -4
  7. onnx_diagnostic/helpers/doc_helper.py +27 -7
  8. onnx_diagnostic/helpers/helper.py +30 -3
  9. onnx_diagnostic/helpers/log_helper.py +585 -0
  10. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  11. onnx_diagnostic/helpers/model_builder_helper.py +57 -73
  12. onnx_diagnostic/helpers/onnx_helper.py +291 -7
  13. onnx_diagnostic/helpers/torch_helper.py +18 -2
  14. onnx_diagnostic/reference/__init__.py +1 -0
  15. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  16. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  17. onnx_diagnostic/reference/torch_evaluator.py +23 -2
  18. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  19. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  20. onnx_diagnostic/tasks/fill_mask.py +3 -0
  21. onnx_diagnostic/tasks/image_classification.py +7 -1
  22. onnx_diagnostic/tasks/image_text_to_text.py +3 -0
  23. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  24. onnx_diagnostic/tasks/object_detection.py +3 -0
  25. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  26. onnx_diagnostic/tasks/summarization.py +3 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  28. onnx_diagnostic/tasks/text_classification.py +3 -0
  29. onnx_diagnostic/tasks/text_generation.py +90 -43
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  33. onnx_diagnostic/torch_export_patches/patch_module_helper.py +1 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  35. onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
  38. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  39. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +174 -114
  40. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
  41. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +44 -42
  42. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
  43. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  44. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,8 @@
+ import functools
+ import importlib
  import contextlib
- from typing import Any, Callable, Dict, List, Optional
+ import re
+ from typing import Any, Callable, Dict, List, Optional, Tuple
  from .onnx_export_serialization import (
      register_cache_serialization,
      unregister_cache_serialization,
@@ -7,6 +10,41 @@ from .onnx_export_serialization import (
  from .patches import patch_transformers as patch_transformers_list


+ def get_function(name: str) -> Tuple[type, Callable]:
+     """Returns the module and the function based on its name."""
+     spl = name.split(".")
+     module_name = ".".join(spl[:-1])
+     fname = spl[-1]
+     mod = importlib.import_module(module_name)
+     return mod, getattr(mod, fname)
+
+
+ @functools.lru_cache
+ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
+     """Returns the list of patches to make for a specific module."""
+     to_patch = []
+     for k in dir(mod):
+         if k.startswith("patched_"):
+             v = getattr(mod, k)
+             if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                 to_patch.append(v)
+             else:
+                 # a function
+                 doc = v.__doc__.lstrip()
+                 if doc.startswith("manual patch"):
+                     continue
+                 reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+                 fall = reg.findall(doc)
+                 assert (
+                     len(fall) == 1
+                 ), f"Unable to find patching information for {v} in \n{doc}"
+                 fmod, f = get_function(fall[0])
+                 to_patch.append({"module": fmod, "function": f, "patch": v})
+
+     name = mod.__name__
+     return name, to_patch
+
+
  def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
      """
      Applies all patches defined in classes prefixed by ``patched_``
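
The hunk above makes ``get_patches`` recognize two shapes of patch declarations in a module: classes prefixed with ``patched_`` carrying ``_PATCHED_CLASS_``/``_PATCHES_``, and standalone ``patched_*`` functions whose docstring embeds a ``[patch:module.function]`` tag resolved by ``get_function`` (docstrings starting with "manual patch" are skipped and wired up by hand). A minimal, self-contained sketch of the two shapes; every name below is hypothetical, not part of the package:

    class SomeModel:                          # stands in for the original class
        def forward(self, x):
            return x


    class patched_SomeModel:
        _PATCHED_CLASS_ = SomeModel           # class whose methods get swapped
        _PATCHES_ = ["forward"]               # method names replaced on SomeModel

        def forward(self, x):
            return 2 * x                      # export-friendly replacement


    def patched_normalize(x):
        """[patch:mylib.ops.normalize]

        Function patch: get_patches reads the tag, get_function imports
        ``mylib.ops`` and fetches ``normalize`` as the callable to replace.
        """
        return x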
@@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
          to_patch = mod
          name = "list"
      else:
-         to_patch = []
-         for k in dir(mod):
-             if k.startswith("patched_"):
-                 v = getattr(mod, k)
-                 if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                     to_patch.append(v)
-         name = mod.__name__
+         name, to_patch = get_patches(mod, verbose)

      res = {}
      for cls in to_patch:
+         if isinstance(cls, dict):
+             # a function
+             keep = {}
+             original = cls["module"]
+             f = cls["function"]
+             res[f] = f
+             if verbose:
+                 print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+             setattr(original, f.__name__, cls["patch"])
+             continue
+
          original = cls._PATCHED_CLASS_
          methods = cls._PATCHES_
          if verbose:
@@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
          to_patch = mod
          name = "list"
      else:
-         to_patch = []
-         for k in dir(mod):
-             if k.startswith("patched_"):
-                 v = getattr(mod, k)
-                 if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                     to_patch.append(v)
-         name = mod.__name__
-     set_patch = set(to_patch)
+         name, to_patch = get_patches(mod, verbose)
+
+     set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
+     dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}

      for cls, methods in info.items():
-         assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
+         if cls in set_patch_cls:
+             if verbose:
+                 print(
+                     f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
+                 )
+             original = cls._PATCHED_CLASS_
+             for n, v in methods.items():
+                 if v is None:
+                     # The method did not exist. We remove it.
+                     delattr(original, n)
+                 else:
+                     setattr(original, n, v)
+             continue
+         assert cls in dict_patch_fct, (
+             f"No patch registered for {cls} in {mod} "
+             f"(found {set_patch_cls} and {set(dict_patch_fct)})"
+         )
+         patch = dict_patch_fct[cls]
          if verbose:
-             print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
-         original = cls._PATCHED_CLASS_
-         for n, v in methods.items():
-             if v is None:
-                 # The method did not exist. We remove it.
-                 delattr(original, n)
-             else:
-                 setattr(original, n, v)
+             print(
+                 f"[unpatch_module_or_classes] function "
+                 f"{patch['module'].__name__}.{cls.__name__}"
+             )
+         setattr(patch["module"], cls.__name__, patch["function"])


  @contextlib.contextmanager
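
The apply/restore pair is symmetric: class patches are undone by restoring the saved methods, function patches by putting the original callable back on its module. A hedged usage sketch, assuming these hunks come from ``onnx_diagnostic/torch_export_patches/onnx_export_errors.py`` as the file list suggests:

    from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
        patch_module_or_classes,
        unpatch_module_or_classes,
    )
    from onnx_diagnostic.torch_export_patches.patches import patch_transformers

    # Apply every patch declared in the module and keep what is needed to undo them.
    info = patch_module_or_classes(patch_transformers, verbose=1)
    try:
        ...  # run torch.export / ONNX export with the patches in place
    finally:
        unpatch_module_or_classes(patch_transformers, info, verbose=1)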
@@ -9,9 +9,11 @@ from transformers.cache_utils import (
      MambaCache,
      EncoderDecoderCache,
      SlidingWindowCache,
+     StaticCache,
  )
  from transformers.modeling_outputs import BaseModelOutput
  from ..helpers import string_type
+ from ..helpers.cache_helper import make_static_cache


  PATCH_OF_PATCHES: Set[Any] = set()
@@ -175,6 +177,13 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]
              flatten_with_keys_sliding_window_cache,
              verbose=verbose,
          ),
+         StaticCache=register_class_serialization(
+             StaticCache,
+             flatten_static_cache,
+             unflatten_static_cache,
+             flatten_with_keys_static_cache,
+             verbose=verbose,
+         ),
      )


@@ -309,6 +318,34 @@ def unflatten_dynamic_cache(
      return cache


+ ##############
+ # DynamicCache
+ ##############
+
+
+ def flatten_static_cache(
+     cache: StaticCache,
+ ) -> Tuple[List[Any], torch.utils._pytree.Context]:
+     """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+     flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+     return [f[1] for f in flat], [f[0] for f in flat]
+
+
+ def flatten_with_keys_static_cache(
+     cache: StaticCache,
+ ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+     """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+     values, context = flatten_static_cache(cache)
+     return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+ def unflatten_static_cache(
+     values: List[Any], context: torch.utils._pytree.Context, output_type=None
+ ) -> StaticCache:
+     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+     return make_static_cache(list(zip(values[0], values[1])))
+
+
  ####################
  # SlidingWindowCache
  ####################
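
The three new functions give ``StaticCache`` the same flatten/unflatten contract as the other caches, which is what ``register_class_serialization`` needs to register it as a pytree node. A hedged round-trip sketch, assuming this hunk belongs to ``onnx_export_serialization.py`` (the +37/-0 entry in the file list) and that ``make_static_cache`` accepts per-layer ``(key, value)`` pairs, as the ``zip`` in ``unflatten_static_cache`` implies:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache
    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        flatten_static_cache,
        unflatten_static_cache,
    )

    # Two layers of (key, value) tensors shaped [batch, heads, seq, head_dim].
    pairs = [(torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)) for _ in range(2)]
    cache = make_static_cache(pairs)

    values, context = flatten_static_cache(cache)        # ([key_cache, value_cache], names)
    restored = unflatten_static_cache(values, context)   # rebuilt through make_static_cache
    assert len(restored.key_cache) == len(cache.key_cache)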
@@ -80,6 +80,7 @@ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
          "AutoformerModel": "AutoformerEncoderLayer",
          "BartEncoderLayer": "BartEncoderLayer",
          "BartForConditionalGeneration": "BartEncoderLayer",
+         "BartModel": "BartEncoderLayer",
          "BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
          "BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
          "BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",
@@ -11,7 +11,7 @@ from ...helpers.torch_helper import is_torchdynamo_exporting


  def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-     """Patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+     """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
      from ...helpers import string_type

      dimensions: List[Tuple[Optional[int], ...]] = [
@@ -534,19 +534,169 @@ class patched_GenerationMixin:
          return model_inputs


- def patched_dynamic_rope_update(rope_forward):
+ def patched__compute_dynamic_ntk_parameters(
+     config: Optional[transformers.PretrainedConfig] = None,
+     device: Optional["torch.device"] = None,
+     seq_len: Optional[int] = None,
+     **rope_kwargs,
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     manual patch:
+     ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
+
+     Computes the inverse frequencies with NTK scaling.
+     Credits to the Reddit users /u/bloc97 and /u/emozilla
+
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
+         seq_len (`int`, *optional*):
+             The current sequence length,
+             used to update the dynamic RoPE at inference time.
+         rope_kwargs (`Dict`, *optional*):
+             BC compatibility with the previous
+             RoPE class instantiation, will be removed in v4.45.
+
+     Returns:
+         Tuple of (`torch.Tensor`, `float`),
+         containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the
+         omputed cos/sin (unused in this type of RoPE).
      """
-     patch:transformers.modeling_rope_utils.dynamic_rope_update
+     if config is not None and len(rope_kwargs) > 0:
+         raise ValueError(
+             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+             f"`_compute_dynamic_ntk_parameters`, got "
+             f"`rope_kwargs`={rope_kwargs} and `config`={config}"
+         )
+     if len(rope_kwargs) > 0:
+         base = rope_kwargs["base"]
+         dim = rope_kwargs["dim"]
+         max_position_embeddings = rope_kwargs["max_position_embeddings"]
+         factor = rope_kwargs["factor"]
+     elif config is not None:
+         base = config.rope_theta
+         partial_rotary_factor = (
+             config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+         )
+         head_dim = getattr(
+             config, "head_dim", config.hidden_size // config.num_attention_heads
+         )
+         dim = int(head_dim * partial_rotary_factor)
+         max_position_embeddings = config.max_position_embeddings
+         factor = config.rope_scaling["factor"]
+
+     attention_factor = 1.0  # Unused in this type of RoPE
+
+     # seq_len: default to max_position_embeddings, e.g. at init time
+     # seq_len = seq_len if seq_len is not None and
+     #     seq_len > max_position_embeddings else max_position_embeddings
+     if seq_len is None:
+         seq_len = max_position_embeddings
+     else:
+         torch._check(isinstance(seq_len, torch.Tensor))
+         seq_len = torch.maximum(
+             seq_len,
+             torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+         )
+
+     # Compute the inverse frequencies
+     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
+         dim / (dim - 2)
+     )
+     inv_freq = 1.0 / (
+         base
+         ** (
+             torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
+             / dim
+         )
+     )
+     return inv_freq, attention_factor
+
+
+ def patched_dynamic_rope_update(rope_forward):
+     """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
+
+     ``rope_type`` is determined in the constructor of class
+     :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
+
+     .. code-block:: python
+
+         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+             self.rope_type = config.rope_scaling.get(
+                 "rope_type", config.rope_scaling.get("type"))
+         else:
+             self.rope_type = "default"
+
+     The original code of the patched function:
+
+     .. code-block:: python
+
+         def dynamic_rope_update(rope_forward):
+             def longrope_frequency_update(self, position_ids, device):
+                 seq_len = torch.max(position_ids) + 1
+                 if hasattr(self.config, "original_max_position_embeddings"):
+                     original_max_position_embeddings =
+                         self.config.original_max_position_embeddings
+                 else:
+                     original_max_position_embeddings =
+                         self.config.max_position_embeddings
+                 if seq_len > original_max_position_embeddings:
+                     if not hasattr(self, "long_inv_freq"):
+                         self.long_inv_freq, _ = self.rope_init_fn(
+                             self.config, device, seq_len=original_max_position_embeddings + 1
+                         )
+                     self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                 else:
+                     self.original_inv_freq = self.original_inv_freq.to(device)
+                     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+             def dynamic_frequency_update(self, position_ids, device):
+                 seq_len = torch.max(position_ids) + 1
+                 if seq_len > self.max_seq_len_cached:  # growth
+                     inv_freq, self.attention_scaling = self.rope_init_fn(
+                         self.config, device, seq_len=seq_len)
+                     self.register_buffer("inv_freq", inv_freq, persistent=False)
+                     self.max_seq_len_cached = seq_len
+
+                 if seq_len < self.original_max_seq_len and
+                         self.max_seq_len_cached > self.original_max_seq_len:
+                     self.original_inv_freq = self.original_inv_freq.to(device)
+                     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                     self.max_seq_len_cached = self.original_max_seq_len
+
+             @wraps(rope_forward)
+             def wrapper(self, x, position_ids):
+                 if "dynamic" in self.rope_type:
+                     dynamic_frequency_update(self, position_ids, device=x.device)
+                 elif self.rope_type == "longrope":
+                     longrope_frequency_update(self, position_ids, device=x.device)
+                 return rope_forward(self, x, position_ids)
+
+             return wrapper
+
      """

      def longrope_frequency_update(self, position_ids, device):
+         # It is no use to patch the function after the model is created
+         # as rope_init_fn is an attribute set to one function when the model
+         # is created and when no patch is applied yet.
+         # So we select the patched version here.
+         rope_init_fn = (
+             patched__compute_dynamic_ntk_parameters
+             if self.rope_init_fn
+             is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+             else self.rope_init_fn
+         )
          seq_len = torch.max(position_ids) + 1
          if hasattr(self.config, "original_max_position_embeddings"):
              original_max_position_embeddings = self.config.original_max_position_embeddings
          else:
              original_max_position_embeddings = self.config.max_position_embeddings
          # At export time, seq_len is unknown.
-         long_inv_freq, _ = self.rope_init_fn(
+         long_inv_freq, _ = rope_init_fn(
              self.config, device, seq_len=original_max_position_embeddings + 1
          )
          original_inv_freq = self.original_inv_freq.to(device)
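
The patched NTK helper mirrors ``transformers.modeling_rope_utils._compute_dynamic_ntk_parameters`` but lets ``seq_len`` arrive as a 0-d tensor so the clamp against ``max_position_embeddings`` stays traceable (``torch._check`` plus ``torch.maximum`` instead of a Python comparison). A hedged sketch using the ``rope_kwargs`` path shown above; the numbers are illustrative:

    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
        patched__compute_dynamic_ntk_parameters,
    )

    # seq_len arrives as a tensor, e.g. torch.max(position_ids) + 1 during export.
    inv_freq, attention_factor = patched__compute_dynamic_ntk_parameters(
        seq_len=torch.tensor(4096),
        base=10000.0,
        dim=64,
        max_position_embeddings=2048,
        factor=4.0,
    )
    print(inv_freq.shape, attention_factor)  # torch.Size([32]) 1.0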
@@ -565,21 +715,70 @@ def patched_dynamic_rope_update(rope_forward):
          # self.inv_freq = self.original_inv_freq

      def dynamic_frequency_update(self, position_ids, device):
+         # constructor:
+         # - self.max_seq_len_cached = config.max_position_embeddings
+         # - self.original_max_seq_len = config.max_position_embeddings
+         # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+         # It is no use to patch the function after the model is created
+         # as rope_init_fn is an attribute set to one function when the model
+         # is created and when no patch is applied yet.
+         # So we select the patched version here.
+         rope_init_fn = (
+             patched__compute_dynamic_ntk_parameters
+             if self.rope_init_fn
+             is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+             else self.rope_init_fn
+         )
+
+         # This behaviour is difficult to translate.
+         # The sequence always grows.
+         # The test should always True.
+         # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+         #
+         # if seq_len > self.max_seq_len_cached:  # growth
+         #     inv_freq, self.attention_scaling = self.rope_init_fn(
+         #         self.config, device, seq_len=seq_len
+         #     )
+         #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+         #     self.max_seq_len_cached = seq_len
+         #
+         # So we should not need what follows.
+         #
+         # cond = (seq_len > self.max_seq_len_cached).item()
+         # self.attention_scaling = torch.cond(
+         #     cond,
+         #     (lambda x, y: x.clone()),
+         #     (lambda x, y: y.clone()),
+         #     [attention_scaling, self.attention_scaling],
+         # )
+
          seq_len = torch.max(position_ids) + 1
-         if seq_len > self.max_seq_len_cached:  # growth
-             inv_freq, self.attention_scaling = self.rope_init_fn(
-                 self.config, device, seq_len=seq_len
-             )
-             self.register_buffer("inv_freq", inv_freq, persistent=False)
-             self.max_seq_len_cached = seq_len
+         long_inv_freq, self.attention_scaling = rope_init_fn(
+             self.config, device, seq_len=seq_len
+         )

-         if (
-             seq_len < self.original_max_seq_len
-             and self.max_seq_len_cached > self.original_max_seq_len
-         ):
-             self.original_inv_freq = self.original_inv_freq.to(device)
-             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-             self.max_seq_len_cached = self.original_max_seq_len
+         # Second test to translate.
+         # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+         # But in that case the following condition is a way to restore the original cache.
+
+         # if (
+         #     seq_len < self.original_max_seq_len
+         #     and self.max_seq_len_cached > self.original_max_seq_len
+         # ):
+         #     self.original_inv_freq = self.original_inv_freq.to(device)
+         #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+         #     self.max_seq_len_cached = self.original_max_seq_len
+
+         original_inv_freq = self.original_inv_freq.to(device)
+         cond = (seq_len >= self.original_max_seq_len).item()
+         inv_freq = torch.cond(
+             cond,
+             (lambda x, y: x.clone()),
+             (lambda x, y: y.clone()),
+             [long_inv_freq, original_inv_freq],
+         )
+         self.inv_freq = inv_freq

      @wraps(rope_forward)
      def wrapper(self, x, position_ids):
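
The data-dependent ``if seq_len > self.max_seq_len_cached`` branch is gone: both inverse-frequency candidates are computed and ``torch.cond`` keeps the selection inside the exported graph. A minimal standalone sketch of the same pattern (names are illustrative, not the library's API):

    import torch

    def pick_inv_freq(seq_len, long_inv_freq, original_inv_freq, threshold):
        # Both branches take the same operands and return tensors of the same shape.
        cond = seq_len >= threshold  # the patch calls .item() on this comparison
        return torch.cond(
            cond,
            lambda x, y: x.clone(),
            lambda x, y: y.clone(),
            [long_inv_freq, original_inv_freq],
        )

    long_f = torch.linspace(0.1, 1.0, 8)
    orig_f = torch.linspace(0.2, 2.0, 8)
    print(pick_inv_freq(torch.tensor(4096), long_f, orig_f, torch.tensor(2048)))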
@@ -619,3 +818,152 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
              sin = emb.sin() * self.attention_scaling

          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ class patched_IdeficsEmbedding(torch.nn.Module):
+     _PATCHES_ = ["forward"]
+     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
+
+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         # if seq_len > self.max_seq_len_cached:
+         #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+         def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
+             t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
+             freqs = torch.einsum("i,j->ij", t, inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+         def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
+             torch._check(seq_len.item() <= cos_cached.shape[0])
+             co = cos_cached[: seq_len.item()].detach().clone()
+             torch._check(seq_len.item() <= sin_cached.shape[0])
+             si = sin_cached[: seq_len.item()].detach().clone()
+             return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
+
+         cos_cached, sin_cached = torch.cond(
+             (seq_len > self.max_seq_len_cached).item(),
+             _set_cos_sin_cache_then,
+             _set_cos_sin_cache_else,
+             [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
+         )
+         return cos_cached, sin_cached
+
+
+ class patched_IdeficsAttention(torch.nn.Module):
+     _PATCHES_ = ["forward"]
+     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_value_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         # if key_value_states are provided this layer is used as a cross-attention layer
+         is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = (
+             self.q_proj(hidden_states)
+             .view(bsz, q_len, self.num_heads, self.head_dim)
+             .transpose(1, 2)
+         )
+         if not is_cross_attention:
+             key_states = (
+                 self.k_proj(hidden_states)
+                 .view(bsz, q_len, self.num_heads, self.head_dim)
+                 .transpose(1, 2)
+             )
+             value_states = (
+                 self.v_proj(hidden_states)
+                 .view(bsz, q_len, self.num_heads, self.head_dim)
+                 .transpose(1, 2)
+             )
+         else:
+             _, kv_len, _ = (
+                 key_value_states.size()
+             )  # Note that, in this case, `kv_len` == `kv_seq_len`
+             key_states = (
+                 self.k_proj(key_value_states)
+                 .view(bsz, kv_len, self.num_heads, self.head_dim)
+                 .transpose(1, 2)
+             )
+             value_states = (
+                 self.v_proj(key_value_states)
+                 .view(bsz, kv_len, self.num_heads, self.head_dim)
+                 .transpose(1, 2)
+             )
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += cache_position[0]
+
+         if not is_cross_attention:
+             rotary_length = torch.maximum(
+                 torch.tensor(kv_seq_len, dtype=torch.int64),
+                 torch.tensor(q_len, dtype=torch.int64),
+             )
+             cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
+             query_states, key_states = (
+                 transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+                     query_states, key_states, cos, sin, position_ids
+                 )
+             )
+             # [bsz, nh, t, hd]
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models;
+             # cache_position needed for the static cache
+             cache_kwargs = {"cache_position": cache_position}
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx, cache_kwargs
+             )
+
+         if self.qk_layer_norms:
+             query_states = self.q_layer_norm(query_states)
+             key_states = self.k_layer_norm(key_states)
+
+         attention_interface: Callable = (
+             transformers.models.idefics.modeling_idefics.eager_attention_forward
+         )
+
+         if self.config._attn_implementation != "eager":
+             if self.config._attn_implementation == "sdpa" and output_attentions:
+                 transformers.models.idefics.modeling_idefics.logger.warning_once(
+                     "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                     "`output_attentions=True`. Falling back to "
+                     "eager attention. This warning can be removed using the argument "
+                     '`attn_implementation="eager"` when loading the model.'
+                 )
+             else:
+                 attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                     self.config._attn_implementation
+                 ]
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )
+
+         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+
+         if output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
@@ -2,6 +2,7 @@ import copy
  import functools
  import json
  import os
+ import pprint
  from typing import Any, Dict, List, Optional, Union
  import transformers
  from huggingface_hub import HfApi, model_info, hf_hub_download
@@ -33,10 +34,14 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig
      return res


- def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.PretrainedConfig]:
+ def get_cached_configuration(
+     name: str, exc: bool = False, **kwargs
+ ) -> Optional[transformers.PretrainedConfig]:
      """
      Returns cached configuration to avoid having to many accesses to internet.
      It returns None if not Cache. The list of cached models follows.
+     If *exc* is True or if environment variable ``NOHTTP`` is defined,
+     the function raises an exception if *name* is not found.

      .. runpython::

@@ -54,8 +59,11 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
          conf = copy.deepcopy(conf)
          update_config(conf, kwargs)
          return conf
-     if os.environ.get("NOHTTP", ""):
-         raise AssertionError(f"Unable to find {name!r} in {sorted(cached)}")
+     assert not exc and not os.environ.get("NOHTTP", ""), (
+         f"Unable to find {name!r} (exc={exc}, "
+         f"NOHTTP={os.environ.get('NOHTTP', '')!r}) "
+         f"in {pprint.pformat(sorted(cached))}"
+     )
      return None


@@ -64,6 +72,7 @@ def get_pretrained_config(
      trust_remote_code: bool = True,
      use_preinstalled: bool = True,
      subfolder: Optional[str] = None,
+     use_only_preinstalled: bool = False,
      **kwargs,
  ) -> Any:
      """
@@ -77,13 +86,20 @@ def get_pretrained_config(
          :func:`get_cached_configuration`, the cached list is mostly for
          unit tests
      :param subfolder: subfolder for the given model id
+     :param use_only_preinstalled: if True, raises an exception if not preinstalled
      :param kwargs: additional kwargs
      :return: a configuration
      """
      if use_preinstalled:
-         conf = get_cached_configuration(model_id, subfolder=subfolder, **kwargs)
+         conf = get_cached_configuration(
+             model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
+         )
          if conf is not None:
              return conf
+     assert not use_only_preinstalled, (
+         f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
+         f"use_preinstalled={use_preinstalled!r}"
+     )
      if subfolder:
          try:
              return transformers.AutoConfig.from_pretrained(
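
Together these hunks make the cached-configuration lookup strict on demand: ``exc=True`` (or a non-empty ``NOHTTP``) turns a miss into an ``AssertionError``, and ``get_pretrained_config(use_only_preinstalled=True)`` forwards that strictness. A hedged sketch of the call patterns; the model id is hypothetical:

    import os
    from onnx_diagnostic.torch_models.hghub.hub_api import (
        get_cached_configuration,
        get_pretrained_config,
    )

    model_id = "some-org/some-model"      # hypothetical id, not in the cached list
    os.environ.pop("NOHTTP", None)        # make sure the lenient path is active

    # Default behaviour: a miss silently returns None.
    assert get_cached_configuration(model_id) is None

    # Strict behaviour: use_only_preinstalled=True forwards exc=True and raises on a miss.
    try:
        get_pretrained_config(model_id, use_only_preinstalled=True)
    except AssertionError as e:
        print("raised as expected:", e)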