onnx-diagnostic 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +108 -77
- onnx_diagnostic/doc.py +68 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +59 -0
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/doc_helper.py +27 -7
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +585 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +57 -73
- onnx_diagnostic/helpers/onnx_helper.py +291 -7
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +23 -2
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +3 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +174 -114
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +44 -42
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -1,5 +1,8 @@
+import functools
+import importlib
 import contextlib
-
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,
@@ -7,6 +10,41 @@ from .onnx_export_serialization import (
 from .patches import patch_transformers as patch_transformers_list


+def get_function(name: str) -> Tuple[type, Callable]:
+    """Returns the module and the function based on its name."""
+    spl = name.split(".")
+    module_name = ".".join(spl[:-1])
+    fname = spl[-1]
+    mod = importlib.import_module(module_name)
+    return mod, getattr(mod, fname)
+
+
+@functools.lru_cache
+def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
+    """Returns the list of patches to make for a specific module."""
+    to_patch = []
+    for k in dir(mod):
+        if k.startswith("patched_"):
+            v = getattr(mod, k)
+            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                to_patch.append(v)
+            else:
+                # a function
+                doc = v.__doc__.lstrip()
+                if doc.startswith("manual patch"):
+                    continue
+                reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+                fall = reg.findall(doc)
+                assert (
+                    len(fall) == 1
+                ), f"Unable to find patching information for {v} in \n{doc}"
+                fmod, f = get_function(fall[0])
+                to_patch.append({"module": fmod, "function": f, "patch": v})
+
+    name = mod.__name__
+    return name, to_patch
+
+
 def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``
@@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
         to_patch = mod
         name = "list"
     else:
-        to_patch =
-        for k in dir(mod):
-            if k.startswith("patched_"):
-                v = getattr(mod, k)
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
+        name, to_patch = get_patches(mod, verbose)

     res = {}
     for cls in to_patch:
+        if isinstance(cls, dict):
+            # a function
+            keep = {}
+            original = cls["module"]
+            f = cls["function"]
+            res[f] = f
+            if verbose:
+                print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+            setattr(original, f.__name__, cls["patch"])
+            continue
+
         original = cls._PATCHED_CLASS_
         methods = cls._PATCHES_
         if verbose:
@@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
         to_patch = mod
         name = "list"
     else:
-        to_patch =
-
-
-
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
-    set_patch = set(to_patch)
+        name, to_patch = get_patches(mod, verbose)
+
+    set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
+    dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}

     for cls, methods in info.items():
-
+        if cls in set_patch_cls:
+            if verbose:
+                print(
+                    f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
+                )
+            original = cls._PATCHED_CLASS_
+            for n, v in methods.items():
+                if v is None:
+                    # The method did not exist. We remove it.
+                    delattr(original, n)
+                else:
+                    setattr(original, n, v)
+            continue
+        assert cls in dict_patch_fct, (
+            f"No patch registered for {cls} in {mod} "
+            f"(found {set_patch_cls} and {set(dict_patch_fct)})"
+        )
+        patch = dict_patch_fct[cls]
         if verbose:
-            print(
-
-
-
-
-                delattr(original, n)
-            else:
-                setattr(original, n, v)
+            print(
+                f"[unpatch_module_or_classes] function "
+                f"{patch['module'].__name__}.{cls.__name__}"
+            )
+        setattr(patch["module"], cls.__name__, patch["function"])


 @contextlib.contextmanager
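
With 0.7.0, a plain function (not only a class carrying `_PATCHED_CLASS_`/`_PATCHES_`) can be declared as a patch: `get_patches` resolves the target from a `[patch:module.function]` tag in the docstring, `patch_module_or_classes` swaps the attribute on the target module, and `unpatch_module_or_classes` restores it. The following is a minimal, self-contained sketch of that convention only; `patched_dumps` and the `json.dumps` target are hypothetical examples, not part of onnx-diagnostic.

    import importlib
    import json
    import re


    def patched_dumps(obj, **kwargs):
        """Replacement declared through a docstring tag ``[patch:json.dumps]``."""
        # The replacement sorts keys so the output is deterministic.
        kwargs.setdefault("sort_keys", True)
        return _original_dumps(obj, **kwargs)


    _original_dumps = json.dumps

    # Same lookup as get_patches: find the [patch:...] tag, resolve the target,
    # then swap the attribute on the target's module.
    target = re.findall(r"[[]patch:([a-z_A-Z.]+)[]]", patched_dumps.__doc__)[0]
    module_name, _, function_name = target.rpartition(".")
    target_module = importlib.import_module(module_name)
    setattr(target_module, function_name, patched_dumps)

    print(json.dumps({"b": 1, "a": 2}))  # keys now sorted by the patch
    # Restore, which is what unpatch_module_or_classes does for function patches.
    setattr(target_module, function_name, _original_dumps)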
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -9,9 +9,11 @@ from transformers.cache_utils import (
     MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
+    StaticCache,
 )
 from transformers.modeling_outputs import BaseModelOutput
 from ..helpers import string_type
+from ..helpers.cache_helper import make_static_cache


 PATCH_OF_PATCHES: Set[Any] = set()
@@ -175,6 +177,13 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]
             flatten_with_keys_sliding_window_cache,
             verbose=verbose,
         ),
+        StaticCache=register_class_serialization(
+            StaticCache,
+            flatten_static_cache,
+            unflatten_static_cache,
+            flatten_with_keys_static_cache,
+            verbose=verbose,
+        ),
     )


@@ -309,6 +318,34 @@ def unflatten_dynamic_cache(
     return cache


+##############
+# DynamicCache
+##############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    values, context = flatten_static_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return make_static_cache(list(zip(values[0], values[1])))
+
+
 ####################
 # SlidingWindowCache
 ####################
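
The hunk above registers `StaticCache` with a flatten/unflatten/flatten-with-keys trio so that the cache can cross the `torch.export` boundary as a pytree. `register_class_serialization` and `make_static_cache` are not shown here; a reasonable reading is that they build on `torch.utils._pytree` registration. The sketch below illustrates that underlying pattern on a hypothetical `ToyCache`, not on the package's actual helpers.

    from typing import Any, List, Tuple

    import torch


    class ToyCache:
        """Stand-in for a key/value cache container (hypothetical)."""

        def __init__(self, key_cache: List[torch.Tensor], value_cache: List[torch.Tensor]):
            self.key_cache = key_cache
            self.value_cache = value_cache


    def _flatten(cache: ToyCache) -> Tuple[List[Any], Any]:
        # Children first, then a context describing how to rebuild the object.
        return [cache.key_cache, cache.value_cache], ["key_cache", "value_cache"]


    def _unflatten(values: List[Any], context: Any) -> ToyCache:
        return ToyCache(values[0], values[1])


    torch.utils._pytree.register_pytree_node(ToyCache, _flatten, _unflatten)

    cache = ToyCache([torch.zeros(2, 3)], [torch.ones(2, 3)])
    leaves, spec = torch.utils._pytree.tree_flatten(cache)
    rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)
    print(len(leaves), rebuilt.value_cache[0].shape)  # 2 tensor leaves, round-trips intact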
onnx_diagnostic/torch_export_patches/patch_module_helper.py

@@ -80,6 +80,7 @@ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
         "AutoformerModel": "AutoformerEncoderLayer",
         "BartEncoderLayer": "BartEncoderLayer",
         "BartForConditionalGeneration": "BartEncoderLayer",
+        "BartModel": "BartEncoderLayer",
         "BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
         "BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
         "BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -11,7 +11,7 @@ from ...helpers.torch_helper import is_torchdynamo_exporting


 def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
-    """
+    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
     from ...helpers import string_type

     dimensions: List[Tuple[Optional[int], ...]] = [
@@ -534,19 +534,169 @@ class patched_GenerationMixin:
         return model_inputs


-def
+def patched__compute_dynamic_ntk_parameters(
+    config: Optional[transformers.PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+    **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+    """
+    manual patch:
+    ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
+
+    Computes the inverse frequencies with NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length,
+            used to update the dynamic RoPE at inference time.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous
+            RoPE class instantiation, will be removed in v4.45.
+
+    Returns:
+        Tuple of (`torch.Tensor`, `float`),
+        containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the
+        computed cos/sin (unused in this type of RoPE).
     """
-
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_dynamic_ntk_parameters`, got "
+            f"`rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+        max_position_embeddings = rope_kwargs["max_position_embeddings"]
+        factor = rope_kwargs["factor"]
+    elif config is not None:
+        base = config.rope_theta
+        partial_rotary_factor = (
+            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        )
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        dim = int(head_dim * partial_rotary_factor)
+        max_position_embeddings = config.max_position_embeddings
+        factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # seq_len: default to max_position_embeddings, e.g. at init time
+    # seq_len = seq_len if seq_len is not None and
+    # seq_len > max_position_embeddings else max_position_embeddings
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    else:
+        torch._check(isinstance(seq_len, torch.Tensor))
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+
+    # Compute the inverse frequencies
+    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
+        dim / (dim - 2)
+    )
+    inv_freq = 1.0 / (
+        base
+        ** (
+            torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
+            / dim
+        )
+    )
+    return inv_freq, attention_factor
+
+
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
+
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
+
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings =
+                        self.config.original_max_position_embeddings
+                else:
+                    original_max_position_embeddings =
+                        self.config.max_position_embeddings
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if seq_len < self.original_max_seq_len and
+                        self.max_seq_len_cached > self.original_max_seq_len:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """

     def longrope_frequency_update(self, position_ids, device):
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
         # At export time, seq_len is unknown.
-        long_inv_freq, _ =
+        long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
         )
         original_inv_freq = self.original_inv_freq.to(device)
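
The patched `_compute_dynamic_ntk_parameters` keeps the usual NTK formula but accepts `seq_len` as a tensor (clamped with `torch.maximum` and checked with `torch._check`) so the computation stays traceable at export time. A small numeric sketch of the scaling itself, with made-up values:

    # Illustration of the NTK scaling used above (base, dim, factor are made up).
    import torch

    base, dim, max_position_embeddings, factor = 10000.0, 8, 2048, 4.0
    seq_len = torch.tensor(4096.0)  # current sequence, twice the training length

    scaled_base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
        dim / (dim - 2)
    )
    inv_freq = 1.0 / (
        scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(torch.float) / dim)
    )
    # Longer context -> larger effective base -> slower-rotating frequencies.
    print(scaled_base, inv_freq)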
@@ -565,21 +715,70 @@ def patched_dynamic_rope_update(rope_forward):
         # self.inv_freq = self.original_inv_freq

     def dynamic_frequency_update(self, position_ids, device):
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows.
+        # The test should always True.
+        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #     inv_freq, self.attention_scaling = self.rope_init_fn(
+        #         self.config, device, seq_len=seq_len
+        #     )
+        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #     self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #     cond,
+        #     (lambda x, y: x.clone()),
+        #     (lambda x, y: y.clone()),
+        #     [attention_scaling, self.attention_scaling],
+        # )
+
         seq_len = torch.max(position_ids) + 1
-
-
-
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.max_seq_len_cached = seq_len
+        long_inv_freq, self.attention_scaling = rope_init_fn(
+            self.config, device, seq_len=seq_len
+        )

-
-
-
-
-
-
-
+        # Second test to translate.
+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+        # But in that case the following condition is a way to restore the original cache.
+
+        # if (
+        #     seq_len < self.original_max_seq_len
+        #     and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #     self.original_inv_freq = self.original_inv_freq.to(device)
+        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #     self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        self.inv_freq = inv_freq

     @wraps(rope_forward)
     def wrapper(self, x, position_ids):
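
This rewrite, like the Idefics patches in the next hunk, swaps a data-dependent Python `if` for `torch.cond` so that `torch.export` records both branches instead of specializing on one. A minimal sketch of that pattern; the threshold and tensor names are illustrative, and it assumes a recent PyTorch where `torch.cond` also runs eagerly.

    import torch


    def pick_inv_freq(seq_len, long_inv_freq, original_inv_freq):
        # Both branches take the same operands and return tensors of identical
        # shape and dtype; the predicate is a 0-dim boolean tensor.
        return torch.cond(
            seq_len >= 16,
            lambda a, b: a.clone(),
            lambda a, b: b.clone(),
            [long_inv_freq, original_inv_freq],
        )


    long_f = torch.full((4,), 0.5)
    orig_f = torch.full((4,), 2.0)
    print(pick_inv_freq(torch.tensor(32), long_f, orig_f))  # picks long_f
    print(pick_inv_freq(torch.tensor(8), long_f, orig_f))   # picks orig_f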
@@ -619,3 +818,152 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         sin = emb.sin() * self.attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class patched_IdeficsEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # if seq_len > self.max_seq_len_cached:
+        #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
+            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
+            freqs = torch.einsum("i,j->ij", t, inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+        def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
+            torch._check(seq_len.item() <= cos_cached.shape[0])
+            co = cos_cached[: seq_len.item()].detach().clone()
+            torch._check(seq_len.item() <= sin_cached.shape[0])
+            si = sin_cached[: seq_len.item()].detach().clone()
+            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
+
+        cos_cached, sin_cached = torch.cond(
+            (seq_len > self.max_seq_len_cached).item(),
+            _set_cos_sin_cache_then,
+            _set_cos_sin_cache_else,
+            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
+        )
+        return cos_cached, sin_cached
+
+
+class patched_IdeficsAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        if not is_cross_attention:
+            key_states = (
+                self.k_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+        else:
+            _, kv_len, _ = (
+                key_value_states.size()
+            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = (
+                self.k_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += cache_position[0]
+
+        if not is_cross_attention:
+            rotary_length = torch.maximum(
+                torch.tensor(kv_seq_len, dtype=torch.int64),
+                torch.tensor(q_len, dtype=torch.int64),
+            )
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
+            query_states, key_states = (
+                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids
+                )
+            )
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models;
+            # cache_position needed for the static cache
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        attention_interface: Callable = (
+            transformers.models.idefics.modeling_idefics.eager_attention_forward
+        )
+
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and output_attentions:
+                transformers.models.idefics.modeling_idefics.logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                    "`output_attentions=True`. Falling back to "
+                    "eager attention. This warning can be removed using the argument "
+                    '`attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -2,6 +2,7 @@ import copy
 import functools
 import json
 import os
+import pprint
 from typing import Any, Dict, List, Optional, Union
 import transformers
 from huggingface_hub import HfApi, model_info, hf_hub_download
@@ -33,10 +34,14 @@ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig
     return res


-def get_cached_configuration(
+def get_cached_configuration(
+    name: str, exc: bool = False, **kwargs
+) -> Optional[transformers.PretrainedConfig]:
     """
     Returns cached configuration to avoid having to many accesses to internet.
     It returns None if not Cache. The list of cached models follows.
+    If *exc* is True or if environment variable ``NOHTTP`` is defined,
+    the function raises an exception if *name* is not found.

     .. runpython::

@@ -54,8 +59,11 @@ def get_cached_configuration(name: str, **kwargs) -> Optional[transformers.Pretr
         conf = copy.deepcopy(conf)
         update_config(conf, kwargs)
         return conf
-
-
+    assert not exc and not os.environ.get("NOHTTP", ""), (
+        f"Unable to find {name!r} (exc={exc}, "
+        f"NOHTTP={os.environ.get('NOHTTP', '')!r}) "
+        f"in {pprint.pformat(sorted(cached))}"
+    )
     return None


@@ -64,6 +72,7 @@ def get_pretrained_config(
     trust_remote_code: bool = True,
     use_preinstalled: bool = True,
     subfolder: Optional[str] = None,
+    use_only_preinstalled: bool = False,
     **kwargs,
 ) -> Any:
     """
@@ -77,13 +86,20 @@ def get_pretrained_config(
         :func:`get_cached_configuration`, the cached list is mostly for
         unit tests
     :param subfolder: subfolder for the given model id
+    :param use_only_preinstalled: if True, raises an exception if not preinstalled
     :param kwargs: additional kwargs
     :return: a configuration
     """
     if use_preinstalled:
-        conf = get_cached_configuration(
+        conf = get_cached_configuration(
+            model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
+        )
         if conf is not None:
            return conf
+        assert not use_only_preinstalled, (
+            f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
+            f"use_preinstalled={use_preinstalled!r}"
+        )
    if subfolder:
        try:
            return transformers.AutoConfig.from_pretrained(
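
Putting the hub_api changes together: `use_only_preinstalled=True` (or setting the `NOHTTP` environment variable) makes configuration lookup strict, so a model id missing from the bundled cache raises instead of silently falling back to a Hugging Face Hub download. A hedged usage sketch, assuming onnx-diagnostic 0.7.0 is installed; the model id is only an example and must be one of the cached configurations shipped with the package.

    from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config

    # Resolve the configuration strictly from the preinstalled cache, no network access.
    config = get_pretrained_config(
        "arnir0/Tiny-LLM",          # assumed to be part of the cached configurations
        use_preinstalled=True,
        use_only_preinstalled=True,
    )
    print(type(config))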