onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2139 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from typing import Callable, List, Optional, Tuple, Union
|
|
7
|
+
import packaging.version as pv
|
|
8
|
+
import torch
|
|
9
|
+
import transformers
|
|
10
|
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
11
|
+
from transformers.cache_utils import StaticCache, Cache
|
|
12
|
+
from transformers.generation.utils import (
|
|
13
|
+
GenerateNonBeamOutput,
|
|
14
|
+
GenerationConfig,
|
|
15
|
+
StoppingCriteriaList,
|
|
16
|
+
LogitsProcessorList,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
try:
    # Only imported to detect whether this transformers version defines it.
    from transformers.cache_utils import parse_processor_args  # noqa: F401

    # The function exists: a patched version is defined when this flag is True.
    patch_parse_processor_args = True
except ImportError:
    patch_parse_processor_args = False

try:
    import transformers.masking_utils

    # transformers.masking_utils exists (NOTE(review): presumably transformers>=4.52).
    patch_masking_utils = True
except ImportError:
    patch_masking_utils = False


try:
    # transformers>= 4.55.1
    from transformers.cache_utils import DynamicLayer

    # Only patch DynamicLayer when it actually defines the method being replaced.
    patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
except ImportError:
    patch_DynamicLayer = False
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _has_transformers(version: str) -> bool:
    """
    Returns True if the installed :epkg:`transformers` is ``version`` or newer.

    Both versions are compared as :class:`packaging.version.Version` objects.
    """
    installed = pv.Version(transformers.__version__)
    minimum = pv.Version(version)
    return installed >= minimum
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _is_torchdynamo_exporting() -> bool:
|
|
48
|
+
"""
|
|
49
|
+
Tells if :epkg:`torch` is exporting a model.
|
|
50
|
+
Relies on ``torch.compiler.is_exporting()``.
|
|
51
|
+
"""
|
|
52
|
+
import torch
|
|
53
|
+
|
|
54
|
+
if not hasattr(torch.compiler, "is_exporting"):
|
|
55
|
+
# torch.compiler.is_exporting requires torch>=2.7
|
|
56
|
+
return False
|
|
57
|
+
|
|
58
|
+
try:
|
|
59
|
+
return torch.compiler.is_exporting()
|
|
60
|
+
except Exception:
|
|
61
|
+
try:
|
|
62
|
+
import torch._dynamo as dynamo
|
|
63
|
+
|
|
64
|
+
return dynamo.is_exporting() # type: ignore
|
|
65
|
+
except Exception:
|
|
66
|
+
return False
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# True when transformers>=4.99. NOTE(review): the name suggests it guards a patch
# related to sdpa's ``is_causal``; no consumer is visible in this part of the file.
patch_sdpa_is_causal = _has_transformers("4.99")
# True when transformers>=4.56.99 (presumably 4.57 dev builds and later).
# patched_DynamicLayer.lazy_initialization then also sets ``is_initialized``.
patch_is_initialized = _has_transformers("4.56.99")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
if patch_masking_utils:
|
|
74
|
+
# Introduced in 4.52
|
|
75
|
+
from transformers.masking_utils import (
|
|
76
|
+
_ignore_causal_mask_sdpa,
|
|
77
|
+
and_masks,
|
|
78
|
+
causal_mask_function,
|
|
79
|
+
padding_mask_function,
|
|
80
|
+
prepare_padding_mask,
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
try:
|
|
84
|
+
# transformers>=5.0
|
|
85
|
+
from transformers.masking_utils import (
|
|
86
|
+
_ignore_bidirectional_mask_sdpa,
|
|
87
|
+
bidirectional_mask_function,
|
|
88
|
+
)
|
|
89
|
+
except ImportError:
|
|
90
|
+
_ignore_bidirectional_mask_sdpa = None
|
|
91
|
+
bidirectional_mask_function = None
|
|
92
|
+
|
|
93
|
+
def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
|
|
94
|
+
"""manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
|
|
95
|
+
from ...helpers import string_type
|
|
96
|
+
|
|
97
|
+
dimensions: List[Tuple[Optional[int], ...]] = [
|
|
98
|
+
(None, None, None, 0),
|
|
99
|
+
(None, None, 0, None),
|
|
100
|
+
]
|
|
101
|
+
if bh_indices:
|
|
102
|
+
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
|
|
103
|
+
# reshape
|
|
104
|
+
dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
|
|
105
|
+
dimensions = tuple(reversed(dimensions))
|
|
106
|
+
indices = tuple(shape.index(-1) for shape in dimensions)
|
|
107
|
+
|
|
108
|
+
# unsqueeze
|
|
109
|
+
udimensions = [
|
|
110
|
+
tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
|
|
111
|
+
]
|
|
112
|
+
|
|
113
|
+
def vector_mask_function(
|
|
114
|
+
*args, mask_function=mask_function, dimensions=dimensions, indices=indices
|
|
115
|
+
):
|
|
116
|
+
assert len(args) == len(dimensions) == len(udimensions), (
|
|
117
|
+
f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
|
|
118
|
+
f"and udimensions={udimensions}."
|
|
119
|
+
)
|
|
120
|
+
assert len(indices) == len(args), (
|
|
121
|
+
f"Mismatch between args={string_type(args)} and indices={indices}, "
|
|
122
|
+
f"they should have the same length."
|
|
123
|
+
)
|
|
124
|
+
for a in args:
|
|
125
|
+
assert (
|
|
126
|
+
a.ndim == 1
|
|
127
|
+
), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
|
|
128
|
+
torch._check(a.shape[0] > 0)
|
|
129
|
+
|
|
130
|
+
new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
|
|
131
|
+
# new_args = [
|
|
132
|
+
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
|
|
133
|
+
# for a, dims in zip(args, udimensions)
|
|
134
|
+
# ]
|
|
135
|
+
max_shape = tuple(args[i].shape[0] for i in indices)
|
|
136
|
+
# if _is_torchdynamo_exporting():
|
|
137
|
+
# for a in args:
|
|
138
|
+
# # The exporter should export with a dimension > 1
|
|
139
|
+
# # to make sure it is dynamic.
|
|
140
|
+
# torch._check(a.shape[0] > 1)
|
|
141
|
+
expanded_args = [a.expand(max_shape) for a in new_args]
|
|
142
|
+
return mask_function(*expanded_args)
|
|
143
|
+
|
|
144
|
+
return vector_mask_function
|
|
145
|
+
|
|
146
|
+
def patched_eager_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    dtype: torch.dtype = torch.float32,
    **kwargs,
) -> torch.Tensor:
    """
    Manual patch for function ``transformers.masking_utils.eager_mask``.

    Builds the boolean mask with :func:`patched_sdpa_mask_recent_torch`, with
    every skip and the torch fix disabled. It then converts that mask into an
    additive mask of type ``dtype``: 0 where the boolean mask is True and
    ``torch.finfo(dtype).min`` elsewhere.
    """
    # Both skip flags are forced to False below. Any caller-provided value is
    # dropped so that the same keyword is not passed twice.
    for ignored in ("allow_is_causal_skip", "allow_is_bidirectional_skip"):
        kwargs.pop(ignored, None)

    # PATCHED: this line called the patched version of sdpa_mask
    boolean_mask = patched_sdpa_mask_recent_torch(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=False,
        allow_torch_fix=False,
        **kwargs,
    )
    lowest = torch.finfo(dtype).min
    # PATCHED: arithmetic instead of
    # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
    # (the boolean mask is inverted, so hidden positions become 1 before scaling).
    hidden = (~boolean_mask).to(dtype)
    return hidden * lowest
|
|
181
|
+
|
|
182
|
+
def patched_sdpa_mask_recent_torch(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    allow_is_bidirectional_skip: bool = False,
    **kwargs,
) -> Optional[torch.Tensor]:
    """
    Manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``.

    Returns None when a skip is allowed (``allow_is_causal_skip`` or
    ``allow_is_bidirectional_skip``) and the matching ``_ignore_*_mask_sdpa``
    helper confirms it.

    Otherwise it returns a boolean mask with one row per query position
    (``cache_position``) and one column per key/value position:

    * for ``bidirectional_mask_function``, the padding mask broadcast over
      queries, or all True when there is no padding;
    * for any other mask function, ``mask_function`` (combined with the padding
      mask if any) evaluated through :func:`patched__vmap_for_bhqkv` on a
      ``(batch_size, 1, q_length, kv_length)`` grid.
    """
    q_length = cache_position.shape[0]
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
    device = cache_position.device

    # Under specific conditions, the mask is not materialized at all.
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(
        padding_mask, q_length, kv_length, kv_offset, local_size
    ):
        return None
    # _ignore_bidirectional_mask_sdpa is None for transformers versions without it.
    can_skip_bidirectional = allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa
    if can_skip_bidirectional and _ignore_bidirectional_mask_sdpa(padding_mask):
        return None

    if mask_function is bidirectional_mask_function:
        if padding_mask is None:
            return torch.ones(
                batch_size, 1, q_length, kv_length, dtype=torch.bool, device=device
            )
        # used for slicing without data-dependent slicing
        kv_positions = torch.arange(kv_length, device=device) + kv_offset
        return padding_mask[:, None, None, kv_positions].expand(-1, -1, q_length, -1)

    kv_arange = torch.arange(kv_length, device=device)
    kv_arange += kv_offset
    if padding_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
    batch_arange = torch.arange(batch_size, device=device)
    head_arange = torch.arange(1, device=device)
    # PATCHED: this line calls the patched version of vmap_for_bhqkv
    return patched__vmap_for_bhqkv(mask_function)(
        batch_arange, head_arange, cache_position, kv_arange
    )
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
if patch_parse_processor_args:

    def _processor_init_params(processor_class) -> Optional[List[str]]:
        """
        Returns the parameter names of ``processor_class.__init__`` after the
        first two (``self`` and the cache), or None if the signature cannot be
        inspected.
        """
        try:
            return list(inspect.signature(processor_class.__init__).parameters)[2:]
        except Exception:
            # Best effort: an uninspectable signature routes no argument
            # to the processor.
            return None

    def _init_cache_inspect():
        """
        Computes :func:`_processor_init_params` once for every class of
        ``transformers.cache_utils.PROCESSOR_CLASS_MAP``, keyed by
        ``processor_class.__init__``, so that no call to :mod:`inspect` is
        needed while a model is being exported.
        """
        return {
            processor_class.__init__: _processor_init_params(processor_class)
            for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values()
        }

    _cache_inspect = _init_cache_inspect()

    def patched_parse_processor_args(
        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
    ) -> tuple[dict, dict]:
        """[patch:transformers.cache_utils.parse_processor_args]

        Splits ``kwargs`` into the arguments accepted by ``processor_class.__init__``
        (first returned dictionary) and the remaining ones (second dictionary).
        If ``processor_class`` is None or its signature cannot be inspected,
        nothing is routed to the processor.
        """
        # If not patched...
        # Fails with transformers>=4.54 because function ``parse_processor_args``
        # relies on inspect and the exporter is not very fond of that.
        # torch._dynamo.exc.Unsupported: id() with unsupported args
        # Explanation: Dynamo doesn't know how to trace id()
        # call with args
        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
        # objects from outside the compiled region.
        # Hint: It may be possible to write Dynamo tracing rules for this code.
        #
        # The patch is caching the signature to avoid any call to inspect.
        if processor_class is None:
            return {}, kwargs
        init = processor_class.__init__
        if init in _cache_inspect:
            params = _cache_inspect[init]
        else:
            # The class was not registered in PROCESSOR_CLASS_MAP at import time
            # (e.g. a processor class given directly). Inspect it once, as the
            # unpatched function does, and remember the result instead of
            # raising KeyError.
            params = _processor_init_params(processor_class)
            _cache_inspect[init] = params
        if params is None:
            return {}, kwargs
        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
        return processor_kwargs, remaining_kwargs
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
if patch_DynamicLayer:

    class patched_DynamicLayer:
        """
        Patch for ``transformers.cache_utils.DynamicLayer.lazy_initialization``.

        ``_PATCHES_`` lists the replaced method names and ``_PATCHED_CLASS_`` the
        patched class. NOTE(review): both are presumably consumed by the patching
        machinery elsewhere in the package.
        """

        _PATCHES_ = ["lazy_initialization"]
        _PATCHED_CLASS_ = DynamicLayer

        def lazy_initialization(self, key_states: torch.Tensor):
            """
            Records the dtype and device of ``key_states``. It then creates empty
            ``keys`` and ``values`` tensors shaped like ``key_states``, except that
            axis -2 has size 0.

            :param key_states: first key states received by the layer
            """
            self.dtype, self.device = key_states.dtype, key_states.device
            new_shape = list(key_states.shape)
            new_shape[-2] = 0
            # PATCHED: used a tensor with an empty shape and not an empty list to initialize
            # NOTE(review): values reuse the shape of key_states. If a model uses
            # a different head dimension for values and the values are later
            # concatenated along axis -2, that concatenation would fail.
            # Confirm this does not happen for the patched models.
            self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
            self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
            if patch_is_initialized:
                # Only set for recent transformers (see patch_is_initialized).
                self.is_initialized = True
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _patch_make_causal_mask(
|
|
295
|
+
input_ids_shape: torch.Size,
|
|
296
|
+
dtype: torch.dtype,
|
|
297
|
+
device: torch.device,
|
|
298
|
+
past_key_values_length: int = 0,
|
|
299
|
+
sliding_window: Optional[int] = None,
|
|
300
|
+
):
|
|
301
|
+
"""Patched method."""
|
|
302
|
+
bsz, tgt_len = input_ids_shape
|
|
303
|
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
|
304
|
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
|
305
|
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
|
306
|
+
|
|
307
|
+
mask = mask.to(dtype)
|
|
308
|
+
|
|
309
|
+
if past_key_values_length > 0:
|
|
310
|
+
mask = torch.cat(
|
|
311
|
+
[
|
|
312
|
+
torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
|
|
313
|
+
mask,
|
|
314
|
+
],
|
|
315
|
+
dim=-1,
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
if sliding_window is not None:
|
|
319
|
+
diagonal = past_key_values_length - sliding_window - 1
|
|
320
|
+
|
|
321
|
+
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
|
|
322
|
+
# PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
|
|
323
|
+
# and used masked_fill instead of masked_fill_
|
|
324
|
+
# In this case, the current implementation of torch fails (17/12/2024).
|
|
325
|
+
# Try model Phi-3.5-Mini-Instruct.
|
|
326
|
+
mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
|
|
327
|
+
|
|
328
|
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
@dataclass
class patched_AttentionMaskConverter:
    """
    Patches
    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.

    ``_PATCHES_`` is empty (nothing to patch) for transformers>=4.48.3.
    """

    # Only patched with transformers<4.48.3; later versions do not need it.
    _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
    _PATCHED_CLASS_ = AttentionMaskConverter

    @staticmethod
    def _make_causal_mask(
        *args,
        **kwargs,
        # input_ids_shape: torch.Size,
        # dtype: torch.dtype,
        # device: torch.device,
        # past_key_values_length: int = 0,
        # sliding_window: Optional[int] = None,
    ):
        """
        Patched method.

        This static method may be called with ``AttentionMaskConverter._make_causal_mask``
        or ``self._make_causal_mask``, which changes the arguments it receives.
        That should not matter but...
        The patch should be implemented in another way: static methods do not play well
        with a simple replacement.
        Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.

        Positional arguments are mapped to the keyword arguments of
        :func:`_patch_make_causal_mask`, which computes the mask.
        """
        if args:
            # When the first positional argument is not a shape, it is assumed
            # to be the instance (``self``) and skipped.
            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
            names = [
                "input_ids_shape",
                "dtype",
                "device",
                "past_key_values_length",
                "sliding_window",
            ]
            for i, a in enumerate(args):
                if i < index:
                    continue
                # Raises IndexError if more positional arguments than names are given.
                kwargs[names[i - index]] = a
        return _patch_make_causal_mask(**kwargs)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
if pv.Version(transformers.__version__) < pv.Version("4.51"):
    from typing import Any, Dict
    from transformers.cache_utils import DynamicCache

    class patched_DynamicCache:
        """
        Applies modifications implemented in PR
        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.

        Only defined for transformers<4.51. Layers skipped by :meth:`update` are
        stored as empty tensors, and every method tests emptiness with ``numel()``.
        """

        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache

        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
            """Returns the sequence length of the cached states.
            A layer index can be optionally passed.

            The length is ``key_cache[layer_idx].shape[-2]``, or 0 when the layer
            has no cached tensor or only an empty placeholder."""
            # TODO: deprecate this function in favor of `cache_position`
            is_empty_layer = (
                len(self.key_cache) == 0  # no cache in any layer
                or len(self.key_cache)
                <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
                or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
            )
            layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
            return layer_seq_length

        def reorder_cache(self, beam_idx: torch.LongTensor):
            """Reorders the cache for beam search, given the selected beam indices."""
            for layer_idx in range(len(self.key_cache)):
                # Empty placeholders (skipped layers) are left untouched.
                if self.key_cache[layer_idx].numel():
                    device = self.key_cache[layer_idx].device
                    self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
                        0, beam_idx.to(device)
                    )
                if self.value_cache[layer_idx].numel():
                    device = self.value_cache[layer_idx].device
                    self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
                        0, beam_idx.to(device)
                    )

        def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Updates the cache with the new `key_states`
            and `value_states` for the layer `layer_idx`.
            Parameters:
                key_states (`torch.Tensor`):
                    The new key states to cache.
                value_states (`torch.Tensor`):
                    The new value states to cache.
                layer_idx (`int`):
                    The index of the layer to cache the states for.
                cache_kwargs (`Dict[str, Any]`, `optional`):
                    Additional arguments for the cache subclass.
                    No additional arguments are used in `DynamicCache`.
            Return:
                A tuple containing the updated key and value states.
            """
            # Update the number of seen tokens
            if layer_idx == 0:
                if hasattr(self, "_seen_tokens"):
                    self._seen_tokens += key_states.shape[-2]

            # Update the cache
            if key_states is not None:
                if len(self.key_cache) <= layer_idx:
                    # There may be skipped layers, fill them with empty tensors
                    for _ in range(len(self.key_cache), layer_idx):
                        self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
                        self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
                    self.key_cache.append(key_states)
                    self.value_cache.append(value_states)
                elif not self.key_cache[layer_idx].numel():
                    # prefers not t.numel() to len(t) == 0 to export the model
                    # fills previously skipped layers; checking for tensor causes errors
                    self.key_cache[layer_idx] = key_states
                    self.value_cache[layer_idx] = value_states
                else:
                    # Concatenation along axis -2 requires both tensors to have the same rank.
                    torch._check(
                        len(self.key_cache[layer_idx].shape) == len(key_states.shape),
                        lambda: (
                            f"Rank mismatch len(self.key_cache[layer_idx].shape)="
                            f"{len(self.key_cache[layer_idx].shape)}, "
                            f"len(key_states.shape)={len(key_states.shape)}"
                        ),
                    )
                    self.key_cache[layer_idx] = torch.cat(
                        [self.key_cache[layer_idx], key_states], dim=-2
                    )
                    self.value_cache[layer_idx] = torch.cat(
                        [self.value_cache[layer_idx], value_states], dim=-2
                    )
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        def crop(self, max_length: int):
            """Crop the past key values up to a new `max_length`
            in terms of tokens. `max_length` can also be
            negative to remove `max_length` tokens.
            This is used in assisted decoding and contrastive search.
            """
            # In case it is negative
            if max_length < 0:
                max_length = self.get_seq_length() - abs(max_length)

            # Nothing to remove.
            if self.get_seq_length() <= max_length:
                return

            if hasattr(self, "_seen_tokens"):
                self._seen_tokens = max_length
            for idx in range(len(self.key_cache)):
                # Empty placeholders are not sliced.
                if self.key_cache[idx].numel():
                    self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                    self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

        @classmethod
        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
            """This is the opposite of the above `batch_split()` method.
            This will be used by `stack_model_outputs` in
            `generation.utils`

            For every layer, the non-empty tensors of all splits are concatenated
            along axis 0. A layer that is empty in every split is not updated."""
            cache = cls()
            for idx in range(len(splits[0])):
                key_cache = [
                    current.key_cache[idx]
                    for current in splits
                    if current.key_cache[idx].numel()
                ]
                value_cache = [
                    current.value_cache[idx]
                    for current in splits
                    if current.value_cache[idx].numel()
                ]
                if key_cache != []:
                    layer_keys = torch.cat(key_cache, dim=0)
                    layer_values = torch.cat(value_cache, dim=0)
                    cache.update(layer_keys, layer_values, idx)
            return cache
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
class patched_GenerationMixin:
|
|
523
|
+
"""
|
|
524
|
+
Applies modifications implemented in PR
|
|
525
|
+
`transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
|
|
526
|
+
"""
|
|
527
|
+
|
|
528
|
+
_PATCHES_ = [
|
|
529
|
+
"_cache_dependant_input_preparation",
|
|
530
|
+
"_cache_dependant_input_preparation_exporting",
|
|
531
|
+
(
|
|
532
|
+
None
|
|
533
|
+
if pv.Version(transformers.__version__) >= pv.Version("4.56")
|
|
534
|
+
else "prepare_inputs_for_generation"
|
|
535
|
+
),
|
|
536
|
+
(
|
|
537
|
+
"_sample"
|
|
538
|
+
if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
|
|
539
|
+
else None
|
|
540
|
+
),
|
|
541
|
+
]
|
|
542
|
+
_PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
|
|
543
|
+
|
|
544
|
+
def _cache_dependant_input_preparation(
|
|
545
|
+
self,
|
|
546
|
+
input_ids: torch.LongTensor,
|
|
547
|
+
inputs_embeds: Optional[torch.FloatTensor],
|
|
548
|
+
cache_position: Optional[torch.LongTensor],
|
|
549
|
+
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
|
550
|
+
"""
|
|
551
|
+
Generic cache-dependent input preparation
|
|
552
|
+
The code is put in a separate function to allow granular unit testing
|
|
553
|
+
as it needs a different implementation to be exportable.
|
|
554
|
+
|
|
555
|
+
If we have cache: let's slice `input_ids` through `cache_position`,
|
|
556
|
+
to keep only the unprocessed tokens
|
|
557
|
+
- Exception 1: when passing input_embeds,
|
|
558
|
+
input_ids may be missing entries
|
|
559
|
+
- Exception 2: some generation methods do special slicing of input_ids,
|
|
560
|
+
so we don't need to do it here
|
|
561
|
+
- Exception 3: with synced GPUs cache_position may go out of bounds,
|
|
562
|
+
but we only want dummy token in that case.
|
|
563
|
+
- Exception 4: If input_embeds are passed then slice it through
|
|
564
|
+
`cache_position`, to keep only the unprocessed tokens and
|
|
565
|
+
generate the first token for each sequence.
|
|
566
|
+
Later use the generated Input ids for continuation.
|
|
567
|
+
|
|
568
|
+
The current implementation does not rely on ``self`` and could be
|
|
569
|
+
a class method. It is left as a standard method to be easily rewritten.
|
|
570
|
+
"""
|
|
571
|
+
if _is_torchdynamo_exporting():
|
|
572
|
+
return self._cache_dependant_input_preparation_exporting(
|
|
573
|
+
input_ids, inputs_embeds, cache_position
|
|
574
|
+
)
|
|
575
|
+
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
|
|
576
|
+
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
|
577
|
+
elif inputs_embeds is not None or ( # Exception 1
|
|
578
|
+
cache_position[-1] >= input_ids.shape[1]
|
|
579
|
+
): # Exception 3
|
|
580
|
+
input_ids = input_ids[:, -cache_position.shape[0] :]
|
|
581
|
+
elif (
|
|
582
|
+
input_ids.shape[1] != cache_position.shape[0]
|
|
583
|
+
): # Default case (the "else", a no op, is Exception 2)
|
|
584
|
+
input_ids = input_ids[:, cache_position]
|
|
585
|
+
return inputs_embeds, input_ids
|
|
586
|
+
|
|
587
|
+
def _cache_dependant_input_preparation_exporting(
|
|
588
|
+
self,
|
|
589
|
+
input_ids: torch.LongTensor,
|
|
590
|
+
inputs_embeds: Optional[torch.FloatTensor],
|
|
591
|
+
cache_position: Optional[torch.LongTensor],
|
|
592
|
+
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
|
593
|
+
"""
|
|
594
|
+
This method implements method ``_cache_dependant_input_preparation``
|
|
595
|
+
with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
|
|
596
|
+
The code is put in a separate function to allow granular unit testing.
|
|
597
|
+
"""
|
|
598
|
+
if inputs_embeds is None:
|
|
599
|
+
input_ids = input_ids[:, cache_position]
|
|
600
|
+
else:
|
|
601
|
+
# This is the code we need to implemented with torch.cond.
|
|
602
|
+
# if input_ids.shape[1] == 0:
|
|
603
|
+
# inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
|
|
604
|
+
# else:
|
|
605
|
+
# if cache_position[-1] >= input_ids.shape[1]:
|
|
606
|
+
# input_ids = input_ids[:, -cache_position.shape[0] :]
|
|
607
|
+
# else:
|
|
608
|
+
# if input_ids.shape[1] != cache_position.shape[0]:
|
|
609
|
+
# input_ids = input_ids[:, cache_position]
|
|
610
|
+
def branch_1(inputs_embeds, cache_position):
|
|
611
|
+
return inputs_embeds[:, -cache_position.shape[0] :]
|
|
612
|
+
|
|
613
|
+
def branch_2(input_ids, cache_position):
|
|
614
|
+
return input_ids[:, -cache_position.shape[0] :]
|
|
615
|
+
|
|
616
|
+
def branch_3(input_ids, cache_position):
|
|
617
|
+
return input_ids[:, cache_position]
|
|
618
|
+
|
|
619
|
+
inputs_embeds, input_ids = torch.cond(
|
|
620
|
+
input_ids.shape[1] == 0,
|
|
621
|
+
(
|
|
622
|
+
lambda input_ids, inputs_embeds, cache_position: (
|
|
623
|
+
branch_1(inputs_embeds, cache_position),
|
|
624
|
+
input_ids,
|
|
625
|
+
)
|
|
626
|
+
),
|
|
627
|
+
(
|
|
628
|
+
lambda input_ids, inputs_embeds, cache_position: (
|
|
629
|
+
inputs_embeds,
|
|
630
|
+
torch.cond(
|
|
631
|
+
cache_position[-1] >= input_ids.shape[1],
|
|
632
|
+
branch_2,
|
|
633
|
+
lambda input_ids, cache_position: (
|
|
634
|
+
torch.cond(
|
|
635
|
+
input_ids.shape[1] != cache_position.shape[0],
|
|
636
|
+
branch_3,
|
|
637
|
+
(lambda input_ids, cache_position: input_ids),
|
|
638
|
+
[input_ids, cache_position],
|
|
639
|
+
)
|
|
640
|
+
),
|
|
641
|
+
[input_ids, cache_position],
|
|
642
|
+
),
|
|
643
|
+
)
|
|
644
|
+
),
|
|
645
|
+
[input_ids, inputs_embeds, cache_position],
|
|
646
|
+
)
|
|
647
|
+
return inputs_embeds, input_ids
|
|
648
|
+
|
|
649
|
+
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """
    Prepare the model inputs for generation.
    It includes operations like computing the 4D attention mask or
    slicing inputs given the existing cache.

    See the forward pass in the model documentation
    for expected arguments (different models might have different
    requirements for e.g. `past_key_values`).
    This function should work as is for most LLMs.

    :param input_ids: token ids generated so far
    :param past_key_values: optional cache
    :param attention_mask: optional attention mask (encoder mask for
        encoder-decoder models, the decoder mask is then taken from
        ``kwargs["decoder_attention_mask"]``)
    :param inputs_embeds: optional input embeddings
    :param cache_position: positions of the tokens not processed yet
    :param kwargs: any other model input, forwarded if not already set
    :return: dictionary of keyword arguments for the model forward
    """

    # 1. Handle BC:
    model_inputs = {}
    # - some models don't have `Cache` support
    #   (which implies they don't expect `cache_position` in `forward`)
    if getattr(self, "_supports_cache_class", False):
        model_inputs["cache_position"] = cache_position
    # - `cache_position` was not a mandatory input in
    #   `prepare_inputs_for_generation` for those models, and this
    #   function may be called outside of `generate`.
    #   Handle most use cases by creating `cache_position` on the fly
    #   (this alternative is not as robust as calling
    #   `generate` and letting it create `cache_position`)
    elif cache_position is None:
        past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        cache_position = torch.arange(
            past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device
        )

    # 2. Generic cache-dependent input preparation
    if past_key_values is not None:
        model_inputs["past_key_values"] = past_key_values
        inputs_embeds, input_ids = self._cache_dependant_input_preparation(
            input_ids, inputs_embeds, cache_position
        )

    # 3. Prepare base model inputs
    input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
    # if `inputs_embeds` are passed, we only want
    # to use them in the 1st generation step for every prompt.
    if not self.config.is_encoder_decoder:
        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
            model_inputs[input_ids_key] = None
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            # `clone` calls in this function ensure a consistent stride. See #32227
            model_inputs[input_ids_key] = input_ids.clone(
                memory_format=torch.contiguous_format
            )
            model_inputs["inputs_embeds"] = None
    else:
        model_inputs[input_ids_key] = input_ids.clone(
            memory_format=torch.contiguous_format
        )

    # 4. Create missing `position_ids` on the fly
    encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
    attention_mask = (
        kwargs.pop("decoder_attention_mask", None)
        if self.config.is_encoder_decoder
        else attention_mask
    )
    attention_mask_key = (
        "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
    )
    position_ids_key = (
        "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
    )
    if (
        attention_mask is not None
        and kwargs.get(position_ids_key) is None
        and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
    ):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        kwargs[position_ids_key] = (
            position_ids  # placed in kwargs for further processing (see below)
        )

    # 5. Slice model inputs if it's an input
    #    that should have the same length as `input_ids`
    for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
        model_input = kwargs.get(model_input_name)
        if model_input is not None:
            if past_key_values is not None:
                current_input_length = (
                    model_inputs["inputs_embeds"].shape[1]
                    if model_inputs.get("inputs_embeds") is not None
                    else model_inputs[input_ids_key].shape[1]
                )
                model_input = model_input[:, -current_input_length:]
            model_input = model_input.clone(memory_format=torch.contiguous_format)
            model_inputs[model_input_name] = model_input

    # 6. Create 4D attention mask if we are using a
    #    `StaticCache` (important for performant compiled forward pass).
    #    The mask may be missing (None): nothing to expand in that case.
    if (
        isinstance(past_key_values, StaticCache)
        and attention_mask is not None
        and attention_mask.ndim == 2
    ):
        # "inputs_embeds" is never set for encoder-decoder models, use .get
        if model_inputs.get("inputs_embeds") is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            device = model_inputs["inputs_embeds"].device
        else:
            batch_size, sequence_length = model_inputs[input_ids_key].shape
            device = model_inputs[input_ids_key].device

        # Create the causal mask with fixed shape in advance,
        # to reduce recompilations. If the function to create
        # the 4D causal mask exists,
        # it should be present in the base model (XXXModel class).
        base_model = getattr(self, self.base_model_prefix, None)
        if base_model is None:
            causal_mask_creation_function = getattr(
                self, "_prepare_4d_causal_attention_mask_with_cache_position", None
            )
        else:
            causal_mask_creation_function = getattr(
                base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
            )
        if causal_mask_creation_function is None:
            pass
            # logger.warning_once(
            #     f"{self.__class__.__name__} has no "
            #     "`_prepare_4d_causal_attention_mask_with_cache_position` method "
            #     "defined in its base modeling class. "
            #     "Compiled forward passes will be sub-optimal. If you're "
            #     "writing code, see Llama for an example implementation. "
            #     "If you're a user, please report this "
            #     "issue on GitHub."
            # )
        else:
            attention_mask = causal_mask_creation_function(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
                config=self.config,
                past_key_values=past_key_values,
            )
    if attention_mask is not None:
        model_inputs[attention_mask_key] = attention_mask

    if encoder_attention_mask is not None:
        model_inputs["attention_mask"] = encoder_attention_mask

    # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value

    # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
    model_inputs.pop("labels", None)
    return model_inputs
|
|
812
|
+
|
|
813
|
+
def _sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: "LogitsProcessorList",  # noqa: F821
    stopping_criteria: "StoppingCriteriaList",  # noqa: F821
    generation_config: "GenerationConfig",  # noqa: F821
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,  # noqa: F821
    **model_kwargs,
) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
    """
    2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.

    Token-by-token generation loop. At every step the model is run on the
    inputs built by ``prepare_inputs_for_generation``, the last-position logits
    go through ``logits_processor``, and the next token is chosen with
    ``torch.argmax`` (``generation_config.do_sample`` false) or
    ``torch.multinomial`` on the softmax (``do_sample`` true).
    The loop ends when ``_has_unfinished_sequences`` returns False.

    :param input_ids: prompt token ids; generated tokens are appended along
        the last dimension
    :param logits_processor: callable ``(input_ids, logits) -> scores``
    :param stopping_criteria: callable ``(input_ids, scores)`` whose result is
        negated and combined with ``unfinished_sequences``
    :param generation_config: provides the output flags, ``do_sample``,
        ``compile_config`` and ``prefill_chunk_size``
    :param synced_gpus: forwarded to ``_has_unfinished_sequences``
    :param streamer: optional, receives every new token (``put``) and ``end()``
    :param model_kwargs: model keyword arguments, updated after every step
    :return: a ``GenerateEncoderDecoderOutput`` or ``GenerateDecoderOnlyOutput``
        when ``return_dict_in_generate`` is set, the final ``input_ids`` otherwise
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(
        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
    )
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    # (they stay None unless both return_dict_in_generate and the flag are set)
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = (
        () if (return_dict_in_generate and output_hidden_states) else None
    )

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    # (only defined here, and only read under the same condition at the end)
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = (
            model_kwargs["encoder_outputs"].get("attentions")
            if output_attentions
            else None
        )
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states")
            if output_hidden_states
            else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    this_peer_finished = False
    # one flag per batch row: 1 = still generating, 0 = finished
    unfinished_sequences = torch.ones(
        batch_size, dtype=torch.long, device=input_ids.device
    )
    model_kwargs = self._get_initial_cache_position(
        cur_len, input_ids.device, model_kwargs
    )

    model_forward = self.__call__
    compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
    if compile_forward:
        os.environ["TOKENIZERS_PARALLELISM"] = "0"
        # If we use FA2 and a static cache, we cannot compile with fullgraph
        if self.config._attn_implementation == "flash_attention_2":
            # only raise warning if the user passed an explicit compile-config
            # NOTE(review): no warning is actually emitted in this patched
            # version, fullgraph is silently switched off.
            if (
                generation_config.compile_config is not None
                and generation_config.compile_config.fullgraph
            ):
                generation_config.compile_config.fullgraph = False
        model_forward = self.get_compiled_call(generation_config.compile_config)

    if generation_config.prefill_chunk_size is not None:
        model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
        is_prefill = False
    else:
        is_prefill = True

    while self._has_unfinished_sequences(
        this_peer_finished, synced_gpus, device=input_ids.device
    ):
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # the first (prefill) step always runs the uncompiled ``self``,
        # the following steps use ``model_forward`` (possibly compiled)
        if is_prefill:
            outputs = self(**model_inputs, return_dict=True)
            is_prefill = False
        else:
            outputs = model_forward(**model_inputs, return_dict=True)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # logits of the last position only, copied as float32 on the ids' device
        next_token_logits = outputs.logits[:, -1, :].to(
            copy=True, dtype=torch.float32, device=input_ids.device
        )

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if self.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        # NOTE(review): unfinished_sequences is 1D (batch_size,); if next_tokens
        # is 2D at this point (see the PATCHED lines below), the product
        # broadcasts along the last dimension -- confirm shapes for that case.
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )

        # update generated ids, model inputs, and length for next step
        # PATCHED: the two following lines, next_tokens can be 2D already for this model
        next_tokens_2d = (
            next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
        )
        input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # This is needed to properly delete outputs.logits which may be very large
        # for first iteration
        # Otherwise a reference to outputs is kept which keeps
        # the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return transformers.generation.utils.GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return transformers.generation.utils.GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
|
|
998
|
+
|
|
999
|
+
|
|
1000
|
+
def patched__compute_dynamic_ntk_parameters(
    config: Optional["transformers.PretrainedConfig"] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    manual patch:
    ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``

    Computes the inverse frequencies with NTK scaling.
    Credits to the Reddit users /u/bloc97 and /u/emozilla

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int` or `torch.Tensor`, *optional*):
            The current sequence length,
            used to update the dynamic RoPE at inference time.
            A tensor is combined with ``torch.maximum``, a python int or
            a SymInt with ``torch.sym_max``: neither introduces a
            data-dependent branch.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous
            RoPE class instantiation, will be removed in v4.45.
            Expected keys: ``base``, ``dim``, ``max_position_embeddings``, ``factor``.

    Returns:
        Tuple of (`torch.Tensor`, `float`),
        containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the
        computed cos/sin (unused in this type of RoPE).

    Raises:
        ValueError: if both ``config`` and ``rope_kwargs`` are given,
        or if neither is given.
    """
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_dynamic_ntk_parameters`, got "
            f"`rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
        max_position_embeddings = rope_kwargs["max_position_embeddings"]
        factor = rope_kwargs["factor"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = (
            config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        )
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        dim = int(head_dim * partial_rotary_factor)
        max_position_embeddings = config.max_position_embeddings
        factor = config.rope_scaling["factor"]
    else:
        # without this, the code below would fail with an UnboundLocalError
        raise ValueError(
            "`_compute_dynamic_ntk_parameters` requires either `config` or "
            "`**rope_kwargs` (base, dim, max_position_embeddings, factor)"
        )

    attention_factor = 1.0  # Unused in this type of RoPE

    # seq_len: default to max_position_embeddings, e.g. at init time
    # upstream: seq_len = max(seq_len, max_position_embeddings)
    if seq_len is None:
        seq_len = max_position_embeddings
    elif isinstance(seq_len, torch.Tensor):
        # PATCHED: torch.maximum instead of the builtin max (no guard on a tensor)
        seq_len = torch.maximum(
            seq_len,
            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
        )
    else:
        # python int or SymInt: torch.sym_max does not specialize symbolic values
        seq_len = torch.sym_max(seq_len, max_position_embeddings)

    # Compute the inverse frequencies
    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
        dim / (dim - 2)
    )
    inv_freq = 1.0 / (
        base
        ** (
            torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
            / dim
        )
    )
    return inv_freq, attention_factor
|
|
1081
|
+
|
|
1082
|
+
|
|
1083
|
+
def _get_rope_init_fn(self, layer_type=None) -> Callable:
    """
    Returns the RoPE initialization function of a rotary embedding module,
    replacing ``transformers.modeling_rope_utils._compute_dynamic_ntk_parameters``
    by :func:`patched__compute_dynamic_ntk_parameters`.

    :param self: the rotary embedding module
    :param layer_type: optional layer type, used to pick the rope type
        from ``self.rope_type[layer_type]`` when the module has no
        ``rope_init_fn`` attribute
    :return: the function to call
    """
    if hasattr(self, "rope_init_fn"):
        # transformers<=5.0
        rope_init_fn = (
            patched__compute_dynamic_ntk_parameters
            if self.rope_init_fn
            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
            else self.rope_init_fn
        )
        return rope_init_fn

    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
    rope_init_fn = self.compute_default_rope_parameters
    if rope_type != "default":
        # use the rope type selected for this layer: self.rope_type is indexed
        # by layer type when layer_type is given and cannot be used as a key
        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[rope_type]
    if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
        return patched__compute_dynamic_ntk_parameters
    return rope_init_fn
|
|
1101
|
+
|
|
1102
|
+
|
|
1103
|
+
def patched_dynamic_rope_update(rope_forward):
    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

    ``rope_type`` is determined in the constructor of class
    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.

    .. code-block:: python

        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"

    The original code of the patched function:

    .. code-block:: python

        def dynamic_rope_update(rope_forward):
            def longrope_frequency_update(self, position_ids, device):
                seq_len = torch.max(position_ids) + 1
                if hasattr(self.config, "original_max_position_embeddings"):
                    original_max_position_embeddings =
                        self.config.original_max_position_embeddings
                else:
                    original_max_position_embeddings =
                        self.config.max_position_embeddings
                if seq_len > original_max_position_embeddings:
                    if not hasattr(self, "long_inv_freq"):
                        self.long_inv_freq, _ = self.rope_init_fn(
                            self.config, device, seq_len=original_max_position_embeddings + 1
                        )
                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
                else:
                    self.original_inv_freq = self.original_inv_freq.to(device)
                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

            def dynamic_frequency_update(self, position_ids, device):
                seq_len = torch.max(position_ids) + 1
                if seq_len > self.max_seq_len_cached:  # growth
                    inv_freq, self.attention_scaling = self.rope_init_fn(
                        self.config, device, seq_len=seq_len)
                    self.register_buffer("inv_freq", inv_freq, persistent=False)
                    self.max_seq_len_cached = seq_len

                if seq_len < self.original_max_seq_len and
                        self.max_seq_len_cached > self.original_max_seq_len:
                    self.original_inv_freq = self.original_inv_freq.to(device)
                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
                    self.max_seq_len_cached = self.original_max_seq_len

            @wraps(rope_forward)
            def wrapper(self, x, position_ids):
                if "dynamic" in self.rope_type:
                    dynamic_frequency_update(self, position_ids, device=x.device)
                elif self.rope_type == "longrope":
                    longrope_frequency_update(self, position_ids, device=x.device)
                return rope_forward(self, x, position_ids)

            return wrapper

    The patched version computes both candidate frequency tensors and
    chooses between them with :func:`torch.cond` instead of a python ``if``
    on ``seq_len``. The result is stored with ``setattr`` in
    ``inv_freq`` (or ``{layer_type}_inv_freq``). The wrapper also accepts an
    optional ``layer_type`` forwarded to ``rope_forward``.
    """

    def longrope_frequency_update(self, position_ids, device, layer_type=None):
        # It is no use to patch the function after the model is created
        # as rope_init_fn is an attribute set to one function when the model
        # is created and when no patch is applied yet.
        # So we select the patched version here.
        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
        else:
            original_max_position_embeddings = self.config.max_position_embeddings

        if layer_type is None:
            # rope_type = self.rope_type
            original_inv_freq = self.original_inv_freq
            prefix = ""
        else:
            # rope_type = self.rope_type[layer_type]
            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
            prefix = f"{layer_type}_"

        # At export time, seq_len is unknown.
        long_inv_freq, _ = rope_init_fn(
            self.config, device, seq_len=original_max_position_embeddings + 1
        )
        # Move the frequencies selected above (per layer type if any) to the
        # device; self.original_inv_freq would ignore the layer type.
        original_inv_freq = original_inv_freq.to(device)

        # PATCHED: uses torch.cond instead of a test
        cond = (seq_len > original_max_position_embeddings).item()
        inv_freq = torch.cond(
            cond,
            (lambda x, y: x.clone()),
            (lambda x, y: y.clone()),
            [long_inv_freq, original_inv_freq],
        )
        setattr(self, f"{prefix}inv_freq", inv_freq)
        # if seq_len > original_max_position_embeddings:
        #     self.inv_freq = self.long_inv_freq
        # else:
        #     self.inv_freq = self.original_inv_freq

    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
        # constructor:
        # - self.max_seq_len_cached = config.max_position_embeddings
        # - self.original_max_seq_len = config.max_position_embeddings
        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        # It is no use to patch the function after the model is created
        # as rope_init_fn is an attribute set to one function when the model
        # is created and when no patch is applied yet.
        # So we select the patched version here.
        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)

        # This behaviour is difficult to translate.
        # The sequence always grows.
        # The test should always be True.
        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
        #
        # if seq_len > self.max_seq_len_cached:  # growth
        #     inv_freq, self.attention_scaling = self.rope_init_fn(
        #         self.config, device, seq_len=seq_len
        #     )
        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
        #     self.max_seq_len_cached = seq_len
        #
        # So we should not need what follows.
        #
        # cond = (seq_len > self.max_seq_len_cached).item()
        # self.attention_scaling = torch.cond(
        #     cond,
        #     (lambda x, y: x.clone()),
        #     (lambda x, y: y.clone()),
        #     [attention_scaling, self.attention_scaling],
        # )

        seq_len = torch.max(position_ids) + 1
        long_inv_freq, self.attention_scaling = rope_init_fn(
            self.config, device, seq_len=seq_len
        )

        if layer_type is None:
            # rope_type = self.rope_type
            # max_seq_len_cached = self.max_seq_len_cached
            original_inv_freq = self.original_inv_freq
            prefix = ""
        else:
            # rope_type = self.rope_type[layer_type]
            # max_seq_len_cached = getattr(
            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
            # )
            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
            prefix = f"{layer_type}_"

        # Second test to translate.
        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
        # But in that case the following condition is a way to restore the original cache.

        # if (
        #     seq_len < self.original_max_seq_len
        #     and self.max_seq_len_cached > self.original_max_seq_len
        # ):
        #     self.original_inv_freq = self.original_inv_freq.to(device)
        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
        #     self.max_seq_len_cached = self.original_max_seq_len

        # Move the frequencies selected above (per layer type if any) to the
        # device; self.original_inv_freq would ignore the layer type.
        original_inv_freq = original_inv_freq.to(device)
        cond = (seq_len >= self.original_max_seq_len).item()
        # PATCHED: uses torch.cond instead of a test
        inv_freq = torch.cond(
            cond,
            (lambda x, y: x.clone()),
            (lambda x, y: y.clone()),
            [long_inv_freq, original_inv_freq],
        )
        setattr(self, f"{prefix}inv_freq", inv_freq)

    @wraps(rope_forward)
    def wrapper(self, x, position_ids, layer_type=None):
        if layer_type is None:
            if "dynamic" in self.rope_type:
                dynamic_frequency_update(self, position_ids, device=x.device)
            elif self.rope_type == "longrope":
                longrope_frequency_update(self, position_ids, device=x.device)
            return rope_forward(self, x, position_ids)

        if "dynamic" in self.rope_type:
            dynamic_frequency_update(
                self, position_ids, device=x.device, layer_type=layer_type
            )
        elif self.rope_type == "longrope":
            longrope_frequency_update(
                self, position_ids, device=x.device, layer_type=layer_type
            )
        return rope_forward(self, x, position_ids, layer_type=layer_type)

    return wrapper
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
def common_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Eager attention (matmul, softmax, matmul) shared by the patched
    ``eager_attention_forward`` functions of this file.

    :param module: attention module, only ``module.training`` is read (dropout)
    :param query: query tensor, multiplied by ``key.transpose(2, 3)``
    :param key: key tensor
    :param value: value tensor
    :param attention_mask: optional additive mask; a 4D mask is first
        truncated on its last axis to ``key.shape[-2]``
    :param scaling: multiplier applied to the scores,
        ``query.size(-1) ** -0.5`` when None
    :param dropout: dropout probability applied to the attention weights
    :param head_mask: optional multiplier reshaped to ``(1, -1, 1, 1)``
    :param kwargs: ignored
    :return: ``(attn_output, attn_weights)``, ``attn_output`` has its axes
        1 and 2 swapped and is contiguous
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling
    scores = torch.matmul(query, key.transpose(2, 3)) * scale

    if attention_mask is not None:
        # PATCHED: a 4D mask may be longer than the key sequence,
        # only its first key.shape[-2] columns are used.
        mask = attention_mask
        if mask.ndim == 4:
            mask = mask[:, :, :, : key.shape[-2]]
        scores = scores + mask

    probs = torch.nn.functional.softmax(scores, dim=-1)
    if head_mask is not None:
        probs = probs * head_mask.view(1, -1, 1, 1)
    probs = torch.nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
|
|
1338
|
+
|
|
1339
|
+
|
|
1340
|
+
def patched_sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """
    manual patch for function
    ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
    """
    # Exportable SDPA attention. When no mask is given and causality is not
    # decided by the caller, the choice between a causal call (more than one
    # query position) and a non-causal call is expressed with torch.cond
    # instead of a python test on query.shape[2].
    # Returns (attn_output, None): the SDPA output with axes 1 and 2 swapped,
    # made contiguous; attention weights are never returned.
    assert not kwargs.get("output_attentions", False), (
        "`sdpa` attention does not support `output_attentions=True`."
        " Please set your attention to `eager` if you want any of these features."
    )
    # Declares to the symbolic shape machinery that batch dimensions are equal
    # or that the left-hand one is 1 (broadcast).
    torch._check(
        query.shape[0] == key.shape[0] or query.shape[0] == 1,
        lambda: (
            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
            f"value: {value.shape}"
        ),
    )
    torch._check(
        key.shape[0] == value.shape[0] or key.shape[0] == 1,
        lambda: (
            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
            f"value: {value.shape}"
        ),
    )

    # Grouped-query attention: either key/value go through repeat_kv
    # (presumably expanding the key/value heads to the query heads) or SDPA
    # is asked to handle it with enable_gqa=True.
    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
            key = transformers.integrations.sdpa_attention.repeat_kv(
                key, module.num_key_value_groups
            )
            value = transformers.integrations.sdpa_attention.repeat_kv(
                value, module.num_key_value_groups
            )
        else:
            sdpa_kwargs = {"enable_gqa": True}

    # A 4D mask longer than the key sequence is truncated to the key length.
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    torch._check(
        attention_mask is None or attention_mask.shape[3] == key.shape[2],
        lambda: "Attention mask shape incompatible with key shape.",
    )

    # patch_sdpa_is_causal is a module-level flag defined outside this function.
    if patch_sdpa_is_causal:
        # transformers>=4.55
        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)

        # PATCHED: remove the test query.shape[2] > 1
        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
        # and we split the test to keep the minimum in torch.cond
        is_causal = attention_mask is None and is_causal

        if not is_causal:
            return (
                torch.nn.functional.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attention_mask,
                    dropout_p=dropout,
                    scale=scaling,
                    is_causal=is_causal,
                    **sdpa_kwargs,
                )
                .transpose(1, 2)
                .contiguous(),
                None,
            )
    else:
        # transformers<4.55
        if is_causal is None and attention_mask is not None:
            is_causal = False
        if is_causal is not None:
            return (
                torch.nn.functional.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attention_mask,
                    dropout_p=dropout,
                    scale=scaling,
                    is_causal=is_causal,
                    **sdpa_kwargs,
                )
                .transpose(1, 2)
                .contiguous(),
                None,
            )

    # Only reached when attention_mask is None (both branches above return
    # otherwise), which is why neither torch.cond branch passes attn_mask.
    # To avoid the following errors:
    # is_causal=query.shape[2] > 1
    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
    # is_causal=torch.tensor(query.shape[2] > 1)
    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
    attn_output = torch.cond(
        query.shape[2] > 1,  # distinction between prefill and decoding steps
        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=dropout,
            scale=scaling,
            is_causal=True,
            **sdpa_kwargs,
        ),
        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=dropout,
            scale=scaling,
            is_causal=False,
            **sdpa_kwargs,
        ),
        [query, key, value],
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
|
|
1469
|
+
|
|
1470
|
+
|
|
1471
|
+
def patched_model_bart_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
    # The docstring above is kept verbatim, it is presumably read by the
    # patching machinery to find the function to replace.
    # All the work is delegated to common_eager_attention_forward; arguments
    # are forwarded positionally, in the order of its signature.
    return common_eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout, head_mask, **kwargs
    )
|
|
1494
|
+
|
|
1495
|
+
|
|
1496
|
+
def patched_modeling_marian_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
    # The docstring above is kept verbatim, it is presumably read by the
    # patching machinery to find the function to replace.
    # All the work is delegated to common_eager_attention_forward; arguments
    # are forwarded positionally, in the order of its signature.
    return common_eager_attention_forward(
        module, query, key, value, attention_mask, scaling, dropout, head_mask, **kwargs
    )
|
|
1519
|
+
|
|
1520
|
+
|
|
1521
|
+
class common_RotaryEmbedding(torch.nn.Module):
    """
    Shared ``forward`` for the patched rotary embedding classes below.
    It returns the ``(cos, sin)`` tables computed from ``inv_freq`` and
    ``position_ids``, scaled by ``attention_scaling`` and cast to ``x.dtype``.
    """

    # The upstream decorator @torch.no_grad() is not used: this may cause some issues.
    # PATCHED: the decorator below replaces the upstream one.
    @patched_dynamic_rope_update
    def forward(self, x, position_ids, layer_type=None):
        # Per-layer-type buffers exist with transformers>=5.0,
        # a single inv_freq / attention_scaling pair before that.
        if layer_type is None:
            inv_freq, attention_scaling = self.inv_freq, self.attention_scaling
        else:
            inv_freq = getattr(self, f"{layer_type}_inv_freq")
            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")

        batch = position_ids.shape[0]
        freq_col = inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos_row = position_ids[:, None, :].float()

        # autocast is disabled on the device type of x, "mps" (or a
        # non-string type) is replaced by "cpu"
        dev_type = x.device.type
        if not isinstance(dev_type, str) or dev_type == "mps":
            dev_type = "cpu"

        with torch.autocast(device_type=dev_type, enabled=False):  # Force float32
            angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
1553
|
+
|
|
1554
|
+
|
|
1555
|
+
class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
    # Patch declaration: the methods listed in _PATCHES_ (forward, inherited
    # from common_RotaryEmbedding) are presumably copied onto _PATCHED_CLASS_
    # by the patching machinery defined outside this view.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
|
|
1558
|
+
|
|
1559
|
+
|
|
1560
|
+
if pv.Version(transformers.__version__) >= pv.Version("4.52"):
    # Only declared for transformers>=4.52, where these classes are accessed
    # below. Both reuse common_RotaryEmbedding.forward.

    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
        # patch declaration consumed outside this view
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding

    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
        # patch declaration consumed outside this view
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
|
|
1569
|
+
|
|
1570
|
+
|
|
1571
|
+
class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
    # Patch declaration: forward (from common_RotaryEmbedding) is presumably
    # copied onto _PATCHED_CLASS_ by the patching machinery outside this view.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
|
|
1574
|
+
|
|
1575
|
+
|
|
1576
|
+
class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
    # Patch declaration: forward (from common_RotaryEmbedding) is presumably
    # copied onto _PATCHED_CLASS_ by the patching machinery outside this view.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
|
|
1579
|
+
|
|
1580
|
+
|
|
1581
|
+
class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
    # Patch declaration: forward (from common_RotaryEmbedding) is presumably
    # copied onto _PATCHED_CLASS_ by the patching machinery outside this view.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
|
|
1584
|
+
|
|
1585
|
+
|
|
1586
|
+
class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
    # Patch declaration: forward (from common_RotaryEmbedding) is presumably
    # copied onto _PATCHED_CLASS_ by the patching machinery outside this view.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
|
|
1589
|
+
|
|
1590
|
+
|
|
1591
|
+
if pv.Version(transformers.__version__) >= pv.Version("4.51"):
    # Only declared for transformers>=4.51.

    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
        # patch declaration consumed outside this view
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
|
|
1596
|
+
|
|
1597
|
+
|
|
1598
|
+
if pv.Version(transformers.__version__) >= pv.Version("4.52"):
    # Only declared for transformers>=4.52.

    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
        # patch declaration consumed outside this view
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = (
            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
        )
|
|
1605
|
+
|
|
1606
|
+
|
|
1607
|
+
if pv.Version(transformers.__version__) >= pv.Version("4.53"):
    # Only declared for transformers>=4.53.

    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
        # patch declaration consumed outside this view
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
|
|
1612
|
+
|
|
1613
|
+
|
|
1614
|
+
class patched_IdeficsEmbedding(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding

    def forward(self, x, seq_len=None):
        """
        Returns the rotary tables ``(cos, sin)`` for the first ``seq_len``
        positions, cast to ``x.dtype``.

        The upstream code extends the cache when needed::

            if seq_len > self.max_seq_len_cached:
                self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        PATCHED: the test is rewritten with :func:`torch.cond`, both branches
        only compute the tables, the cached buffers are never updated.

        :param x: tensor providing the device and the output dtype
        :param seq_len: number of positions, a 0-d integer tensor
            or a python int
        :return: tuple ``(cos, sin)``
        """
        if isinstance(seq_len, int):
            # torch.cond only accepts tensors as operands
            seq_len = torch.tensor(seq_len, dtype=torch.int64)

        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
            # The cache is too short: recompute the tables from inv_freq.
            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
            # freqs = torch.einsum("i,j->ij", t, inv_freq)
            freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
            emb = torch.cat((freqs, freqs), dim=-1)
            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

        def _set_cos_sin_cache_else(x, _inv_freq, seq_len, cos_cached, sin_cached):
            # The cache is long enough: slice it. Only the branch operands are
            # used (no closure over the outer tensors), and .item() is called
            # once so that the checks constrain the same value used to slice.
            n = seq_len.item()
            torch._check(n <= cos_cached.shape[0])
            torch._check(n <= sin_cached.shape[0])
            co = cos_cached[:n].detach().clone()
            si = sin_cached[:n].detach().clone()
            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)

        cos_cached, sin_cached = torch.cond(
            (seq_len > self.max_seq_len_cached).item(),
            _set_cos_sin_cache_then,
            _set_cos_sin_cache_else,
            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
        )
        return cos_cached, sin_cached
|
|
1644
|
+
|
|
1645
|
+
|
|
1646
|
+
class patched_IdeficsAttention(torch.nn.Module):
    # Patch declaration (consumed outside this view): forward replaces
    # IdeficsAttention.forward.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Self- or cross-attention of Idefics.
        # The rotary length is built as an int64 tensor with torch.maximum
        # instead of a python max; presumably because the patched
        # IdeficsEmbedding.forward calls ``seq_len.item()`` and feeds it to
        # torch.cond.
        # Returns (attn_output, attn_weights, past_key_value) for
        # transformers < 4.53.99, (attn_output, attn_weights) otherwise.
        # ``use_cache`` is not read.

        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = self.is_cross_attention or key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        if not is_cross_attention:
            key_states = (
                self.k_proj(hidden_states)
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(hidden_states)
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
        else:
            _, kv_len, _ = (
                key_value_states.size()
            )  # Note that, in this case, `kv_len` == `kv_seq_len`
            key_states = (
                self.k_proj(key_value_states)
                .view(bsz, kv_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(key_value_states)
                .view(bsz, kv_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            # kv_seq_len becomes a tensor here (cache_position[0] is one).
            kv_seq_len += cache_position[0]

        if not is_cross_attention:
            # NOTE(review): torch.tensor() on an existing tensor (the cache
            # case above) copy-constructs it and may warn; confirm it is
            # intended for export.
            rotary_length = torch.maximum(
                torch.tensor(kv_seq_len, dtype=torch.int64),
                torch.tensor(q_len, dtype=torch.int64),
            )
            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
            query_states, key_states = (
                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, position_ids
                )
            )
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # sin and cos are specific to RoPE models;
            # cache_position needed for the static cache
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        # eager attention by default; another registered implementation is
        # used unless sdpa was requested together with output_attentions
        attention_interface: Callable = (
            transformers.models.idefics.modeling_idefics.eager_attention_forward
        )

        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                transformers.models.idefics.modeling_idefics.logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
                    "`output_attentions=True`. Falling back to "
                    "eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
                    self.config._attn_implementation
                ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        # NOTE(review): the weights are dropped exactly when they are
        # requested; this looks inverted but presumably mirrors the upstream
        # implementation — confirm before changing.
        if output_attentions:
            attn_weights = None

        if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
            return attn_output, attn_weights, past_key_value
        return attn_output, attn_weights
|
|
1764
|
+
|
|
1765
|
+
|
|
1766
|
+
class patched_SamMaskDecoder(torch.nn.Module):
    # Patch declaration (consumed outside this view): forward replaces
    # SamMaskDecoder.forward.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                the embeddings from the image encoder
            image_positional_embeddings (`torch.Tensor`):
                positional encoding with the shape of image_embeddings
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes
            dense_prompt_embeddings (`torch.Tensor`):
                the embeddings of the mask inputs
            multimask_output (bool):
                Whether to return multiple masks or a single mask.
            output_attentions (bool, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            attention_similarity (`torch.Tensor`, *optional*):
                forwarded unchanged to ``self.transformer``
            target_embedding (`torch.Tensor`, *optional*):
                forwarded unchanged to ``self.transformer``

        Returns:
            ``(masks, iou_pred)`` when ``self.transformer`` returns two values,
            ``(masks, iou_pred, attentions)`` otherwise, the third element
            being None unless ``output_attentions`` is set and available.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
        # torch.any is needed to avoid data-dependent control flow
        # with sparse_prompt_embeddings.sum().item() != 0
        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)

        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
            return output_tokens.clone()

        tokens = torch.cond(
            torch.any(sparse_prompt_embeddings != 0),
            sparse_prompt_embeddings_is_not_empty,
            sparse_prompt_embeddings_is_empty,
            [output_tokens, sparse_prompt_embeddings],
        )

        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-point
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(
            point_batch_size, 0
        )

        # Run the transformer, image_positional_embedding are consumed
        # The checks declare every dimension of point_embeddings non-empty,
        # presumably so the exporter does not specialize on size 0.
        torch._check(point_embeddings.shape[0] != 0)
        torch._check(point_embeddings.shape[1] != 0)
        torch._check(point_embeddings.shape[2] != 0)
        torch._check(point_embeddings.shape[3] != 0)
        embeddings_attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )
        point_embedding, image_embeddings = embeddings_attentions[:2]
        # token 0 is the IoU token, the next num_mask_tokens are the mask tokens
        iou_token_out = torch.select(point_embedding, dim=2, index=0)
        mask_tokens_out = torch.narrow(
            point_embedding, dim=2, start=1, length=self.num_mask_tokens
        )

        # Upscale mask embeddings and predict masks using the mask tokens
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)

        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.reshape(
            batch_size, point_batch_size, num_channels, height * width
        )
        masks = (hyper_in @ upscaled_embedding).reshape(
            batch_size, point_batch_size, -1, height, width
        )

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)

        if len(embeddings_attentions) == 2:
            # transformers==4.54
            return outputs

        if output_attentions and len(embeddings_attentions) > 2:
            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
        else:
            outputs = outputs + (None,)  # noqa: RUF005
        return outputs
|
|
1893
|
+
|
|
1894
|
+
|
|
1895
|
+
def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
    """
    Rewrites the loop in:

    .. code-block:: python

        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
        )
        for i in range(1, len(seq)):
            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0

    without any python loop or in-place assignment.

    :param mask: tensor whose last two dimensions have the same size
    :param seq: 1D increasing integer tensor of block boundaries
    :return: ``mask`` multiplied by 0 inside every square block
        ``[seq[i-1]:seq[i], seq[i-1]:seq[i]]`` and by 1 elsewhere,
        on ``mask``'s device and with ``mask``'s dtype
    """
    # positions are created on seq's device so that they can be compared with it
    r = torch.arange(0, mask.shape[-1], dtype=torch.int64, device=seq.device)
    # less[i] - 1 = number of boundaries strictly greater than position i,
    # two positions share the same value iff no boundary separates them
    less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
    less = less0.sum(dim=-1, keepdim=True) + 1
    sq = less * less.T
    # look[i] = less[i] for positions inside [seq.min(), seq.max()), 0 otherwise:
    # positions before seq.min() share the maximum of less,
    # positions at or after seq.max() share its minimum.
    look = (
        torch.logical_or(seq.min() == 0, less != less.max())
        * torch.logical_or(seq.max() == mask.shape[-1], less != less.min())
        * less
    )
    # less > 0, so sq[i, j] == look[i] ** 2 <=> i and j are in the same block
    filt = (sq != look**2).to(device=mask.device, dtype=mask.dtype)
    return mask * filt
|
|
1918
|
+
|
|
1919
|
+
|
|
1920
|
+
class patched_VisionAttention(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Attention restricted to the diagonal blocks delimited by ``cu_seqlens``.

        PATCHED: the python loop building the block mask is replaced by
        :func:`rewrite_loop_for_square_mask`.

        :param hidden_states: input, its first dimension is the sequence length
        :param cu_seqlens: block boundaries
        :param rotary_pos_emb: RoPE angles, only used when
            ``position_embeddings`` is None
        :param position_embeddings: tuple ``(cos, sin)``
        :return: output of ``self.proj``
        """
        modeling = transformers.models.qwen2_vl.modeling_qwen2_vl
        n_tokens = hidden_states.shape[0]
        qkv = self.qkv(hidden_states).reshape(n_tokens, 3, self.num_heads, -1)
        q, k, v = qkv.permute(1, 0, 2, 3).unbind(0)

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            modeling.logger.warning_once(
                "The attention layers in this model are transitioning from "
                " computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
                "to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
                " In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            doubled = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos, sin = doubled.cos(), doubled.sin()
        q, k = modeling.apply_rotary_pos_emb_vision(q, k, cos, sin)

        # Everything is masked with the lowest float value, then the diagonal
        # blocks [cu_seqlens[i - 1]:cu_seqlens[i]] are multiplied by zero.
        block_mask = torch.full(
            [1, n_tokens, n_tokens],
            torch.finfo(q.dtype).min,
            device=q.device,
            dtype=q.dtype,
        )
        block_mask = rewrite_loop_for_square_mask(block_mask, cu_seqlens)

        q, k, v = q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
        scores = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        scores = scores + block_mask
        probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        context = torch.matmul(probs, v).transpose(0, 1).reshape(n_tokens, -1)
        return self.proj(context)
|
|
1981
|
+
|
|
1982
|
+
|
|
1983
|
+
# patch_qwen3 is True when transformers.models.qwen3_moe can be imported,
# the Qwen3-MoE patches below are only declared in that case.
try:
    import transformers.models.qwen3_moe  # noqa: F401
except ImportError:
    patch_qwen3 = False
else:
    patch_qwen3 = True
|
|
1989
|
+
|
|
1990
|
+
if patch_qwen3:

    class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
        # Names of the methods this class provides for ``_PATCHED_CLASS_``
        # (presumably swapped in by the patching helper -- not visible here).
        _PATCHES_ = ["forward", "_forward_expert_loop"]
        _PATCHED_CLASS_ = (
            transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
        )

        def _forward_expert_loop(
            self,
            final_hidden_states,
            expert_mask_idx,
            hidden_states,
            routing_weights,
            expert_idx: int,
        ):
            """Add the weighted output of one expert to the running result.

            As called from :meth:`forward`:

            :param final_hidden_states: accumulator of shape
                ``(batch * seq, hidden_dim)``
            :param expert_mask_idx: ``expert_mask[expert_idx]``, a
                ``(top_k, batch * seq)`` one-hot slice; each non-zero entry
                marks a (top-k slot, token) pair routed to this expert
            :param hidden_states: flattened token states
                ``(batch * seq, hidden_dim)``
            :param routing_weights: ``(batch * seq, top_k)`` routing weights
            :param expert_idx: index into ``self.experts``
            :return: a new tensor (``index_add`` is out-of-place) equal to
                ``final_hidden_states`` plus this expert's contribution
            """
            # Upstream used ``torch.where(mask.squeeze(0))``; nonzero with
            # as_tuple=True yields the same (row, column) index pair.
            slot_pos, token_pos = torch.nonzero(expert_mask_idx, as_tuple=True)
            width = hidden_states.shape[-1]
            routed_tokens = hidden_states[None, token_pos].reshape(-1, width)
            expert_output = self.experts[expert_idx](routed_tokens)
            weighted = expert_output * routing_weights[token_pos, slot_pos, None]
            # Out-of-place add: torch.cond branches must not mutate operands.
            return final_hidden_states.index_add(
                0, token_pos, weighted.to(hidden_states.dtype)
            )

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            """Export-friendly replacement of ``Qwen3MoeSparseMoeBlock.forward``.

            Every token is routed to its ``top_k`` experts and the weighted
            expert outputs are summed. Upstream iterates only over the experts
            hit by at least one token. This patch iterates over all
            ``self.num_experts`` experts, so the loop bound does not depend on
            the data. Each expert's work is guarded by :func:`torch.cond`.

            :param hidden_states: tensor of shape ``(batch, seq, hidden_dim)``
            :return: tuple ``(final_hidden_states, router_logits)`` of shapes
                ``(batch, seq, hidden_dim)`` and ``(batch * seq, n_experts)``.
                The return annotation says ``torch.Tensor``, but a tuple is
                returned.
            """
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            flat_states = hidden_states.view(-1, hidden_dim)

            # router_logits: (batch * sequence_length, n_experts)
            router_logits = self.gate(flat_states)
            probabilities = torch.nn.functional.softmax(
                router_logits, dim=1, dtype=torch.float
            )
            top_weights, top_experts = torch.topk(probabilities, self.top_k, dim=-1)
            if self.norm_topk_prob:
                # Renormalise the kept probabilities; this is the only
                # difference with the Mixtral sparse MoE block.
                top_weights /= top_weights.sum(dim=-1, keepdim=True)
            # Return to the activation dtype after the float32 softmax.
            top_weights = top_weights.to(flat_states.dtype)

            accumulated = torch.zeros(
                (batch_size * sequence_length, hidden_dim),
                dtype=flat_states.dtype,
                device=flat_states.device,
            )

            # Shape (num_experts, top_k, tokens): entry [e, k, t] is 1 when
            # token t picked expert e as its k-th choice.
            expert_mask = torch.nn.functional.one_hot(
                top_experts, num_classes=self.num_experts
            ).permute(2, 1, 0)
            tokens_per_expert = expert_mask.sum(dim=(-1, -2))

            def skip_expert(acc, *_unused):
                # Return a copy instead of the operand itself, so the output
                # does not alias a torch.cond input.
                return acc.clone()

            for expert_idx in range(self.num_experts):

                def run_expert(acc, mask, states, weights, _i=expert_idx):
                    # ``_i`` is bound by default argument, so each branch keeps
                    # its own expert index.
                    return self._forward_expert_loop(
                        acc, mask, states, weights, expert_idx=_i
                    )

                # Upstream squeezes this slice; that is not possible here, so
                # the (top_k, tokens) slice is passed as is.
                accumulated = torch.cond(
                    (tokens_per_expert[expert_idx] > 0).item(),
                    run_expert,
                    skip_expert,
                    [accumulated, expert_mask[expert_idx], flat_states, top_weights],
                )

            return (
                accumulated.reshape(batch_size, sequence_length, hidden_dim),
                router_logits,
            )
|
|
2087
|
+
|
|
2088
|
+
|
|
2089
|
+
# Optional dependency: the Gemma3 patch is only defined when ``Gemma3Model``
# can be imported from this transformers installation.
try:
    from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
except ImportError:
    patch_gemma3 = False
else:
    patch_gemma3 = True
|
|
2095
|
+
|
|
2096
|
+
|
|
2097
|
+
if patch_gemma3:

    class patched_Gemma3Model(torch.nn.Module):
        _PATCHES_ = ["get_placeholder_mask"]
        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
        _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"

        def get_placeholder_mask(
            self,
            input_ids: torch.LongTensor,
            inputs_embeds: torch.FloatTensor,
            image_features: torch.FloatTensor,
        ):
            """Return the image-token mask expanded to the shape of ``inputs_embeds``.

            Image positions are taken from ``input_ids`` when it is given.
            Otherwise every embedding is compared with the embedding of
            ``config.image_token_id``. The upstream ``if ...: raise ValueError``
            consistency check is replaced by :func:`torch._check` (see
            ``_PATCHED_PR_``).

            :param input_ids: token ids, or ``None`` to detect image positions
                from ``inputs_embeds``
            :param inputs_embeds: token embeddings; the mask is expanded to
                their shape and moved to their device
            :param image_features: image embeddings; the first two dimensions
                and ``numel()`` are used for the check
            :return: boolean mask with the shape of ``inputs_embeds``
            :raises RuntimeError: raised by ``torch._check`` when the number of
                masked embedding values differs from ``image_features.numel()``
            """
            image_token_id = self.config.image_token_id
            if input_ids is not None:
                token_mask = input_ids == image_token_id
            else:
                image_token = torch.tensor(
                    image_token_id,
                    dtype=torch.long,
                    device=inputs_embeds.device,
                )
                image_embedding = self.get_input_embeddings()(image_token)
                # A position is an image token when its whole embedding vector
                # matches the image-token embedding.
                token_mask = (inputs_embeds == image_embedding).all(-1)

            n_image_tokens = token_mask.sum()
            expanded_mask = token_mask.unsqueeze(-1).expand_as(inputs_embeds)
            expanded_mask = expanded_mask.to(inputs_embeds.device)
            n_image_features = image_features.shape[0] * image_features.shape[1]

            # PATCHED: torch._check instead of ``if ...: raise ValueError(...)``.
            torch._check(
                inputs_embeds[expanded_mask].numel() == image_features.numel(),
                lambda: (
                    f"Image features and image tokens do not match: tokens: "
                    f"{n_image_tokens}, features {n_image_features}"
                ),
            )
            return expanded_mask
|