onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
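
The headline change of this release is the split of the monolithic patch_transformers.py (2608 lines removed) into per-feature _patch_transformers_* submodules; the new top-level module only re-exports the patched classes behind version flags such as patch_parse_processor_args or patch_DynamicLayer, as the diff below shows. A minimal sketch, assuming the 0.8.4 layout listed above, of how one could check which flags the installed transformers version enables (the loop and the fallback string are illustrative, not part of the package's documented API):

    # Sketch: inspect the version flags re-exported by the refactored module.
    # Flag names come from the diff below; everything else is illustrative.
    from onnx_diagnostic.torch_export_patches.patches import patch_transformers

    flags = [
        "patch_parse_processor_args",  # transformers.cache_utils.parse_processor_args exists
        "patch_DynamicLayer",          # DynamicLayer.lazy_initialization exists
        "patch_DynamicCache",          # DynamicCache patches required (transformers<4.51 in the old code)
        "patch_masking_utils",         # transformers.masking_utils exists
    ]
    for name in flags:
        print(name, getattr(patch_transformers, name, "not defined in this version"))
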
@@ -1,2641 +1,97 @@
1
- import inspect
2
- import math
3
- import os
4
- from dataclasses import dataclass
5
- from functools import wraps
6
- from typing import Callable, List, Optional, Tuple, Union
7
- import packaging.version as pv
8
- import torch
9
- import transformers
10
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
11
- from transformers.cache_utils import StaticCache, Cache
12
- from transformers.generation.utils import (
13
- GenerateNonBeamOutput,
14
- GenerationConfig,
15
- StoppingCriteriaList,
16
- LogitsProcessorList,
1
+ # transformers
2
+ from typing import List
3
+ from .patch_helper import _has_transformers
4
+
5
+ from ._patch_transformers_attention import (
6
+ patched_sdpa_attention_forward,
7
+ patched_model_bart_eager_attention_forward,
8
+ patched_modeling_marian_eager_attention_forward,
17
9
  )
18
10
 
19
- try:
20
- from transformers.cache_utils import parse_processor_args # noqa: F401
21
-
22
- patch_parse_processor_args = True
23
- except ImportError:
24
- patch_parse_processor_args = False
25
-
26
- try:
27
- import transformers.masking_utils
28
-
29
- patch_masking_utils = True
30
- except ImportError:
31
- patch_masking_utils = False
32
-
33
-
34
- try:
35
- # transformers>=4.55.1
36
- from transformers.cache_utils import DynamicLayer
37
-
38
- patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
39
- except ImportError:
40
- patch_DynamicLayer = False
41
-
42
-
43
- def _has_transformers(version: str) -> bool:
44
- return pv.Version(transformers.__version__) >= pv.Version(version)
45
-
46
-
47
- def _is_torchdynamo_exporting() -> bool:
48
- """
49
- Tells if :epkg:`torch` is exporting a model.
50
- Relies on ``torch.compiler.is_exporting()``.
51
- """
52
- import torch
53
-
54
- if not hasattr(torch.compiler, "is_exporting"):
55
- # torch.compiler.is_exporting requires torch>=2.7
56
- return False
57
-
58
- try:
59
- return torch.compiler.is_exporting()
60
- except Exception:
61
- try:
62
- import torch._dynamo as dynamo
63
-
64
- return dynamo.is_exporting() # type: ignore
65
- except Exception:
66
- return False
67
-
68
-
69
- patch_sdpa_is_causal = _has_transformers("4.99")
70
- patch_is_initialized = _has_transformers("4.56.99")
71
-
72
-
73
- if patch_masking_utils:
74
- # Introduced in 4.52
75
- from transformers.masking_utils import (
76
- _ignore_causal_mask_sdpa,
77
- and_masks,
78
- causal_mask_function,
79
- padding_mask_function,
80
- prepare_padding_mask,
81
- )
82
-
83
- try:
84
- # transformers>=5.0
85
- from transformers.masking_utils import (
86
- _ignore_bidirectional_mask_sdpa,
87
- bidirectional_mask_function,
88
- )
89
- except ImportError:
90
- _ignore_bidirectional_mask_sdpa = None
91
- bidirectional_mask_function = None
92
-
93
- def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
94
- """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
95
- from ...helpers import string_type
96
-
97
- dimensions: List[Tuple[Optional[int], ...]] = [
98
- (None, None, None, 0),
99
- (None, None, 0, None),
100
- ]
101
- if bh_indices:
102
- dimensions.extend([(None, 0, None, None), (0, None, None, None)])
103
- # reshape
104
- dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
105
- dimensions = tuple(reversed(dimensions))
106
- indices = tuple(shape.index(-1) for shape in dimensions)
107
-
108
- # unsqueeze
109
- udimensions = [
110
- tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
111
- ]
112
-
113
- def vector_mask_function(
114
- *args, mask_function=mask_function, dimensions=dimensions, indices=indices
115
- ):
116
- assert len(args) == len(dimensions) == len(udimensions), (
117
- f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
118
- f"and udimensions={udimensions}."
119
- )
120
- assert len(indices) == len(args), (
121
- f"Mismatch between args={string_type(args)} and indices={indices}, "
122
- f"they should have the same length."
123
- )
124
- for a in args:
125
- assert (
126
- a.ndim == 1
127
- ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
128
- torch._check(a.shape[0] > 0)
129
-
130
- new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
131
- # new_args = [
132
- # a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
133
- # for a, dims in zip(args, udimensions)
134
- # ]
135
- max_shape = tuple(args[i].shape[0] for i in indices)
136
- # if _is_torchdynamo_exporting():
137
- # for a in args:
138
- # # The exporter should export with a dimension > 1
139
- # # to make sure it is dynamic.
140
- # torch._check(a.shape[0] > 1)
141
- expanded_args = [a.expand(max_shape) for a in new_args]
142
- return mask_function(*expanded_args)
143
-
144
- return vector_mask_function
145
-
146
- def patched_eager_mask(
147
- batch_size: int,
148
- cache_position: torch.Tensor,
149
- kv_length: int,
150
- kv_offset: int = 0,
151
- mask_function: Callable = causal_mask_function,
152
- attention_mask: Optional[torch.Tensor] = None,
153
- dtype: torch.dtype = torch.float32,
154
- **kwargs,
155
- ) -> torch.Tensor:
156
- """manual patch for function ``transformers.masking_utils.eager_mask``."""
157
- # The masks for eager attention are simply the boolean mask from sdpa, cast to 0 and -inf
158
- _ = kwargs.pop("allow_is_causal_skip", None)
159
- _ = kwargs.pop("allow_is_bidirectional_skip", None)
160
- # PATCHED: this line calls the patched version of sdpa_mask
161
- mask = patched_sdpa_mask_recent_torch(
162
- batch_size=batch_size,
163
- cache_position=cache_position,
164
- kv_length=kv_length,
165
- kv_offset=kv_offset,
166
- mask_function=mask_function,
167
- attention_mask=attention_mask,
168
- allow_is_causal_skip=False,
169
- allow_is_bidirectional_skip=False,
170
- allow_torch_fix=False,
171
- **kwargs,
172
- )
173
- min_dtype = torch.finfo(dtype).min
174
- # PATCHED: the following line
175
- # we need 0s where the tokens should be taken into account,
176
- # and -inf otherwise (mask is already of boolean type)
177
- # mask =
178
- # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
179
- mask = (~mask).to(dtype) * min_dtype
180
- return mask
181
-
182
- def patched_sdpa_mask_recent_torch(
183
- batch_size: int,
184
- cache_position: torch.Tensor,
185
- kv_length: int,
186
- kv_offset: int = 0,
187
- mask_function: Callable = causal_mask_function,
188
- attention_mask: Optional[torch.Tensor] = None,
189
- local_size: Optional[int] = None,
190
- allow_is_causal_skip: bool = True,
191
- allow_is_bidirectional_skip: bool = False,
192
- **kwargs,
193
- ) -> Optional[torch.Tensor]:
194
- """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
195
- q_length = cache_position.shape[0]
196
- padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
197
- if allow_is_causal_skip and _ignore_causal_mask_sdpa(
198
- padding_mask, q_length, kv_length, kv_offset, local_size
199
- ):
200
- return None
201
- if (
202
- allow_is_bidirectional_skip
203
- and _ignore_bidirectional_mask_sdpa
204
- and _ignore_bidirectional_mask_sdpa(padding_mask)
205
- ):
206
- return None
207
-
208
- if mask_function is bidirectional_mask_function:
209
- if padding_mask is not None:
210
- # used for slicing without data-dependent slicing
211
- mask_indices = (
212
- torch.arange(kv_length, device=cache_position.device) + kv_offset
213
- )
214
- return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
215
- return torch.ones(
216
- batch_size,
217
- 1,
218
- q_length,
219
- kv_length,
220
- dtype=torch.bool,
221
- device=cache_position.device,
222
- )
223
-
224
- kv_arange = torch.arange(kv_length, device=cache_position.device)
225
- kv_arange += kv_offset
226
- if padding_mask is not None:
227
- mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
228
- batch_arange = torch.arange(batch_size, device=cache_position.device)
229
- head_arange = torch.arange(1, device=cache_position.device)
230
- # PATCHED: this line calls the patched version of vmap_for_bhqkv
231
- causal_mask = patched__vmap_for_bhqkv(mask_function)(
232
- batch_arange, head_arange, cache_position, kv_arange
233
- )
234
- return causal_mask
235
-
11
+ from ._patch_transformers_cache_utils import patch_parse_processor_args
236
12
 
237
13
  if patch_parse_processor_args:
14
+ from ._patch_transformers_cache_utils import patched_parse_processor_args
238
15
 
239
- def _init_cache_inspect():
240
- res = {}
241
- for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
242
- try:
243
- params = list(inspect.signature(processor_class.__init__).parameters)[2:]
244
- res[processor_class.__init__] = params
245
- except Exception:
246
- res[processor_class.__init__] = None
247
- return res
248
-
249
- _cache_inspect = _init_cache_inspect()
250
-
251
- def patched_parse_processor_args(
252
- processor_class: Optional[type["CacheProcessor"]], kwargs: dict # noqa: F821
253
- ) -> tuple[dict, dict]:
254
- """[patch:transformers.cache_utils.parse_processor_args]"""
255
- # If not patched...
256
- # Fails with transformers>=4.54 because function ``parse_processor_args``
257
- # relies on inspect and the exporter is not very fond of that.
258
- # torch._dynamo.exc.Unsupported: id() with unsupported args
259
- # Explanation: Dynamo doesn't know how to trace id()
260
- # call with args
261
- # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
262
- # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
263
- # objects from outside the compiled region.
264
- # Hint: It may be possible to write Dynamo tracing rules for this code.
265
- #
266
- # The patch caches the signature to avoid any call to inspect.
267
- if processor_class is None:
268
- return {}, kwargs
269
- params = _cache_inspect[processor_class.__init__]
270
- if params is None:
271
- return {}, kwargs
272
- processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
273
- remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
274
- return processor_kwargs, remaining_kwargs
16
+ from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
275
17
 
18
+ from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache
276
19
 
277
20
  if patch_DynamicLayer:
21
+ from ._patch_transformers_dynamic_cache import patched_DynamicLayer
22
+ if patch_DynamicCache:
23
+ from ._patch_transformers_dynamic_cache import patched_DynamicCache
278
24
 
279
- class patched_DynamicLayer:
280
- _PATCHES_ = ["lazy_initialization"]
281
- _PATCHED_CLASS_ = DynamicLayer
282
-
283
- def lazy_initialization(self, key_states: torch.Tensor):
284
- self.dtype, self.device = key_states.dtype, key_states.device
285
- new_shape = list(key_states.shape)
286
- new_shape[-2] = 0
287
- # PATCHED: uses a tensor with an empty shape and not an empty list to initialize
288
- self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
289
- self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
290
- if patch_is_initialized:
291
- self.is_initialized = True
292
-
293
-
294
- def _patch_make_causal_mask(
295
- input_ids_shape: torch.Size,
296
- dtype: torch.dtype,
297
- device: torch.device,
298
- past_key_values_length: int = 0,
299
- sliding_window: Optional[int] = None,
300
- ):
301
- """Patched method."""
302
- bsz, tgt_len = input_ids_shape
303
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
304
- mask_cond = torch.arange(mask.size(-1), device=device)
305
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
306
-
307
- mask = mask.to(dtype)
308
-
309
- if past_key_values_length > 0:
310
- mask = torch.cat(
311
- [
312
- torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
313
- mask,
314
- ],
315
- dim=-1,
316
- )
317
-
318
- if sliding_window is not None:
319
- diagonal = past_key_values_length - sliding_window - 1
320
-
321
- context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
322
- # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
323
- # and used masked_fill instead of masked_fill_
324
- # In this case, the current implementation of torch fails (17/12/2024).
325
- # Try model Phi-3.5-Mini-Instruct.
326
- mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
327
-
328
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
329
-
330
-
331
- @dataclass
332
- class patched_AttentionMaskConverter:
333
- """
334
- Patches
335
- ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
336
- """
337
-
338
- # This method was fixed in 4.51 at least.
339
- _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
340
- _PATCHED_CLASS_ = AttentionMaskConverter
341
-
342
- @staticmethod
343
- def _make_causal_mask(
344
- *args,
345
- **kwargs,
346
- # input_ids_shape: torch.Size,
347
- # dtype: torch.dtype,
348
- # device: torch.device,
349
- # past_key_values_length: int = 0,
350
- # sliding_window: Optional[int] = None,
351
- ):
352
- """
353
- Patched method.
354
-
355
- This static method may be called with ``AttentionMaskConverter._make_causal_mask``
356
- or ``self._make_causal_mask``. That changes the arguments it receives.
357
- That should not matter but...
358
- The patch should be implemented in another way. Static methods do not play well
359
- with a simple replacement.
360
- Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
361
- """
362
- if args:
363
- index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
364
- names = [
365
- "input_ids_shape",
366
- "dtype",
367
- "device",
368
- "past_key_values_length",
369
- "sliding_window",
370
- ]
371
- for i, a in enumerate(args):
372
- if i < index:
373
- continue
374
- kwargs[names[i - index]] = a
375
- return _patch_make_causal_mask(**kwargs)
376
-
377
-
378
- if pv.Version(transformers.__version__) < pv.Version("4.51"):
379
- from typing import Any, Dict
380
- from transformers.cache_utils import DynamicCache
381
-
382
- class patched_DynamicCache:
383
- """
384
- Applies modifications implemented in PR
385
- `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
386
- """
387
-
388
- _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
389
- _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
390
-
391
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
392
- """Returns the sequence length of the cached states.
393
- A layer index can be optionally passed."""
394
- # TODO: deprecate this function in favor of `cache_position`
395
- is_empty_layer = (
396
- len(self.key_cache) == 0 # no cache in any layer
397
- or len(self.key_cache)
398
- <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
399
- or self.key_cache[layer_idx].numel() == 0 # the layer has no cache
400
- )
401
- layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
402
- return layer_seq_length
403
-
404
- def reorder_cache(self, beam_idx: torch.LongTensor):
405
- """Reorders the cache for beam search, given the selected beam indices."""
406
- for layer_idx in range(len(self.key_cache)):
407
- if self.key_cache[layer_idx].numel():
408
- device = self.key_cache[layer_idx].device
409
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
410
- 0, beam_idx.to(device)
411
- )
412
- if self.value_cache[layer_idx].numel():
413
- device = self.value_cache[layer_idx].device
414
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
415
- 0, beam_idx.to(device)
416
- )
417
-
418
- def update(
419
- self,
420
- key_states: torch.Tensor,
421
- value_states: torch.Tensor,
422
- layer_idx: int,
423
- cache_kwargs: Optional[Dict[str, Any]] = None,
424
- ) -> Tuple[torch.Tensor, torch.Tensor]:
425
- """
426
- Updates the cache with the new `key_states`
427
- and `value_states` for the layer `layer_idx`.
428
- Parameters:
429
- key_states (`torch.Tensor`):
430
- The new key states to cache.
431
- value_states (`torch.Tensor`):
432
- The new value states to cache.
433
- layer_idx (`int`):
434
- The index of the layer to cache the states for.
435
- cache_kwargs (`Dict[str, Any]`, `optional`):
436
- Additional arguments for the cache subclass.
437
- No additional arguments are used in `DynamicCache`.
438
- Return:
439
- A tuple containing the updated key and value states.
440
- """
441
- # Update the number of seen tokens
442
- if layer_idx == 0:
443
- if hasattr(self, "_seen_tokens"):
444
- self._seen_tokens += key_states.shape[-2]
445
-
446
- # Update the cache
447
- if key_states is not None:
448
- if len(self.key_cache) <= layer_idx:
449
- # There may be skipped layers, fill them with empty tensors
450
- for _ in range(len(self.key_cache), layer_idx):
451
- self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
452
- self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
453
- self.key_cache.append(key_states)
454
- self.value_cache.append(value_states)
455
- elif not self.key_cache[
456
- layer_idx
457
- ].numel(): # prefers not t.numel() to len(t) == 0 to export the model
458
- # fills previously skipped layers; checking for tensor causes errors
459
- self.key_cache[layer_idx] = key_states
460
- self.value_cache[layer_idx] = value_states
461
- else:
462
- torch._check(
463
- len(self.key_cache[layer_idx].shape) == len(key_states.shape),
464
- lambda: (
465
- f"Rank mismatch len(self.key_cache[layer_idx].shape)="
466
- f"{len(self.key_cache[layer_idx].shape)}, "
467
- f"len(key_states.shape)={len(key_states.shape)}"
468
- ),
469
- )
470
- self.key_cache[layer_idx] = torch.cat(
471
- [self.key_cache[layer_idx], key_states], dim=-2
472
- )
473
- self.value_cache[layer_idx] = torch.cat(
474
- [self.value_cache[layer_idx], value_states], dim=-2
475
- )
476
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
477
-
478
- def crop(self, max_length: int):
479
- """Crop the past key values up to a new `max_length`
480
- in terms of tokens. `max_length` can also be
481
- negative to remove `max_length` tokens.
482
- This is used in assisted decoding and contrastive search.
483
- """
484
- # In case it is negative
485
- if max_length < 0:
486
- max_length = self.get_seq_length() - abs(max_length)
487
-
488
- if self.get_seq_length() <= max_length:
489
- return
490
-
491
- if hasattr(self, "_seen_tokens"):
492
- self._seen_tokens = max_length
493
- for idx in range(len(self.key_cache)):
494
- if self.key_cache[idx].numel():
495
- self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
496
- self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
497
-
498
- @classmethod
499
- def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
500
- """This is the opposite of the above `batch_split()` method.
501
- This will be used by `stack_model_outputs` in
502
- `generation.utils`"""
503
- cache = cls()
504
- for idx in range(len(splits[0])):
505
- key_cache = [
506
- current.key_cache[idx]
507
- for current in splits
508
- if current.key_cache[idx].numel()
509
- ]
510
- value_cache = [
511
- current.value_cache[idx]
512
- for current in splits
513
- if current.value_cache[idx].numel()
514
- ]
515
- if key_cache != []:
516
- layer_keys = torch.cat(key_cache, dim=0)
517
- layer_values = torch.cat(value_cache, dim=0)
518
- cache.update(layer_keys, layer_values, idx)
519
- return cache
520
-
521
-
522
- class patched_GenerationMixin:
523
- """
524
- Applies modifications implemented in PR
525
- `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
526
- """
527
-
528
- _PATCHES_ = [
529
- "_cache_dependant_input_preparation",
530
- "_cache_dependant_input_preparation_exporting",
531
- (
532
- None
533
- if pv.Version(transformers.__version__) >= pv.Version("4.56")
534
- else "prepare_inputs_for_generation"
535
- ),
536
- (
537
- "_sample"
538
- if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
539
- else None
540
- ),
541
- ]
542
- _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
543
-
544
- def _cache_dependant_input_preparation(
545
- self,
546
- input_ids: torch.LongTensor,
547
- inputs_embeds: Optional[torch.FloatTensor],
548
- cache_position: Optional[torch.LongTensor],
549
- ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
550
- """
551
- Generic cache-dependent input preparation
552
- The code is put in a separate function to allow granular unit testing
553
- as it needs a different implementation to be exportable.
554
-
555
- If we have cache: let's slice `input_ids` through `cache_position`,
556
- to keep only the unprocessed tokens
557
- - Exception 1: when passing input_embeds,
558
- input_ids may be missing entries
559
- - Exception 2: some generation methods do special slicing of input_ids,
560
- so we don't need to do it here
561
- - Exception 3: with synced GPUs cache_position may go out of bounds,
562
- but we only want dummy token in that case.
563
- - Exception 4: If input_embeds are passed then slice it through
564
- `cache_position`, to keep only the unprocessed tokens and
565
- generate the first token for each sequence.
566
- Later use the generated Input ids for continuation.
567
-
568
- The current implementation does not rely on ``self`` and could be
569
- a class method. It is left as a standard method to be easily rewritten.
570
- """
571
- if _is_torchdynamo_exporting():
572
- return self._cache_dependant_input_preparation_exporting(
573
- input_ids, inputs_embeds, cache_position
574
- )
575
- if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
576
- inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
577
- elif inputs_embeds is not None or ( # Exception 1
578
- cache_position[-1] >= input_ids.shape[1]
579
- ): # Exception 3
580
- input_ids = input_ids[:, -cache_position.shape[0] :]
581
- elif (
582
- input_ids.shape[1] != cache_position.shape[0]
583
- ): # Default case (the "else", a no op, is Exception 2)
584
- input_ids = input_ids[:, cache_position]
585
- return inputs_embeds, input_ids
586
-
587
- def _cache_dependant_input_preparation_exporting(
588
- self,
589
- input_ids: torch.LongTensor,
590
- inputs_embeds: Optional[torch.FloatTensor],
591
- cache_position: Optional[torch.LongTensor],
592
- ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
593
- """
594
- This method implements method ``_cache_dependant_input_preparation``
595
- with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
596
- The code is put in a separate function to allow granular unit testing.
597
- """
598
- if inputs_embeds is None:
599
- input_ids = input_ids[:, cache_position]
600
- else:
601
- # This is the code we need to implement with torch.cond.
602
- # if input_ids.shape[1] == 0:
603
- # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
604
- # else:
605
- # if cache_position[-1] >= input_ids.shape[1]:
606
- # input_ids = input_ids[:, -cache_position.shape[0] :]
607
- # else:
608
- # if input_ids.shape[1] != cache_position.shape[0]:
609
- # input_ids = input_ids[:, cache_position]
610
- def branch_1(inputs_embeds, cache_position):
611
- return inputs_embeds[:, -cache_position.shape[0] :].clone()
612
-
613
- def branch_2(input_ids, cache_position):
614
- return input_ids[:, -cache_position.shape[0] :].clone()
615
-
616
- def branch_3(input_ids, cache_position):
617
- return input_ids[:, cache_position].clone()
618
-
619
- inputs_embeds, input_ids = torch.cond(
620
- input_ids.shape[1] == 0,
621
- (
622
- lambda input_ids, inputs_embeds, cache_position: (
623
- branch_1(inputs_embeds, cache_position),
624
- input_ids.clone(),
625
- )
626
- ),
627
- (
628
- lambda input_ids, inputs_embeds, cache_position: (
629
- inputs_embeds,
630
- torch.cond(
631
- cache_position[-1] >= input_ids.shape[1],
632
- branch_2,
633
- lambda input_ids, cache_position: (
634
- torch.cond(
635
- input_ids.shape[1] != cache_position.shape[0],
636
- branch_3,
637
- (lambda input_ids, cache_position: input_ids),
638
- [input_ids, cache_position],
639
- )
640
- ),
641
- [input_ids, cache_position],
642
- ),
643
- )
644
- ),
645
- [input_ids, inputs_embeds, cache_position],
646
- )
647
- return inputs_embeds, input_ids
648
-
649
- def prepare_inputs_for_generation(
650
- self,
651
- input_ids: torch.LongTensor,
652
- past_key_values: Optional[Cache] = None,
653
- attention_mask: Optional[torch.LongTensor] = None,
654
- inputs_embeds: Optional[torch.FloatTensor] = None,
655
- cache_position: Optional[torch.LongTensor] = None,
656
- **kwargs,
657
- ):
658
- """
659
- Prepare the model inputs for generation.
660
- It includes operations like computing the 4D attention mask or
661
- slicing inputs given the existing cache.
662
-
663
- See the forward pass in the model documentation
664
- for expected arguments (different models might have different
665
- requirements for e.g. `past_key_values`).
666
- This function should work as is for most LLMs.
667
- """
668
-
669
- # 1. Handle BC:
670
- model_inputs = {}
671
- # - some models don't have `Cache` support
672
- # (which implies they don't expect `cache_position` in `forward`)
673
- if getattr(self, "_supports_cache_class", False):
674
- model_inputs["cache_position"] = cache_position
675
- # - `cache_position` was not a mandatory input in
676
- # `prepare_inputs_for_generation` for those models, and this
677
- # function may be called outside of `generate`.
678
- # Handle most use cases by creating `cache_position` on the fly
679
- # (this alternative is not as robust as calling
680
- # `generate` and letting it create `cache_position`)
681
- elif cache_position is None:
682
- past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
683
- cache_position = torch.arange(
684
- past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device
685
- )
686
-
687
- # 2. Generic cache-dependent input preparation
688
- if past_key_values is not None:
689
- model_inputs["past_key_values"] = past_key_values
690
- inputs_embeds, input_ids = self._cache_dependant_input_preparation(
691
- input_ids, inputs_embeds, cache_position
692
- )
693
-
694
- # 3. Prepare base model inputs
695
- input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
696
- # if `inputs_embeds` are passed, we only want
697
- # to use them in the 1st generation step for every prompt.
698
- if not self.config.is_encoder_decoder:
699
- if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
700
- model_inputs[input_ids_key] = None
701
- model_inputs["inputs_embeds"] = inputs_embeds
702
- else:
703
- # `clone` calls in this function ensure a consistent stride. See #32227
704
- model_inputs[input_ids_key] = input_ids.clone(
705
- memory_format=torch.contiguous_format
706
- )
707
- model_inputs["inputs_embeds"] = None
708
- else:
709
- model_inputs[input_ids_key] = input_ids.clone(
710
- memory_format=torch.contiguous_format
711
- )
712
-
713
- # 4. Create missing `position_ids` on the fly
714
- encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
715
- attention_mask = (
716
- kwargs.pop("decoder_attention_mask", None)
717
- if self.config.is_encoder_decoder
718
- else attention_mask
719
- )
720
- attention_mask_key = (
721
- "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
722
- )
723
- position_ids_key = (
724
- "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
725
- )
726
- if (
727
- attention_mask is not None
728
- and kwargs.get(position_ids_key) is None
729
- and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
730
- ):
731
- position_ids = attention_mask.long().cumsum(-1) - 1
732
- position_ids.masked_fill_(attention_mask == 0, 1)
733
- kwargs[position_ids_key] = (
734
- position_ids # placed in kwargs for further processing (see below)
735
- )
736
-
737
- # 5. Slice model inputs if it's an input
738
- # that should have the same length as `input_ids`
739
- for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
740
- model_input = kwargs.get(model_input_name)
741
- if model_input is not None:
742
- if past_key_values is not None:
743
- current_input_length = (
744
- model_inputs["inputs_embeds"].shape[1]
745
- if model_inputs.get("inputs_embeds") is not None
746
- else model_inputs[input_ids_key].shape[1]
747
- )
748
- model_input = model_input[:, -current_input_length:]
749
- model_input = model_input.clone(memory_format=torch.contiguous_format)
750
- model_inputs[model_input_name] = model_input
751
-
752
- # 6. Create 4D attention mask if we are using a
753
- # `StaticCache` (important for performant compiled forward pass)
754
- if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
755
- if model_inputs["inputs_embeds"] is not None:
756
- batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
757
- device = model_inputs["inputs_embeds"].device
758
- else:
759
- batch_size, sequence_length = model_inputs[input_ids_key].shape
760
- device = model_inputs[input_ids_key].device
761
-
762
- # Create the causal mask with fixed shape in advance,
763
- # to reduce recompilations. If the function to create
764
- # the 4D causal mask exists,
765
- # it should be present in the base model (XXXModel class).
766
- base_model = getattr(self, self.base_model_prefix, None)
767
- if base_model is None:
768
- causal_mask_creation_function = getattr(
769
- self, "_prepare_4d_causal_attention_mask_with_cache_position", None
770
- )
771
- else:
772
- causal_mask_creation_function = getattr(
773
- base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
774
- )
775
- if causal_mask_creation_function is None:
776
- pass
777
- # logger.warning_once(
778
- # f"{self.__class__.__name__} has no "
779
- # "`_prepare_4d_causal_attention_mask_with_cache_position` method "
780
- # "defined in its base modeling class. "
781
- # "Compiled forward passes will be sub-optimal. If you're "
782
- # "writing code, see Llama for an example implementation. "
783
- # "If you're a user, please report this "
784
- # "issue on GitHub."
785
- # )
786
- else:
787
- attention_mask = causal_mask_creation_function(
788
- attention_mask,
789
- sequence_length=sequence_length,
790
- target_length=past_key_values.get_max_cache_shape(),
791
- dtype=self.dtype,
792
- device=device,
793
- cache_position=cache_position,
794
- batch_size=batch_size,
795
- config=self.config,
796
- past_key_values=past_key_values,
797
- )
798
- if attention_mask is not None:
799
- model_inputs[attention_mask_key] = attention_mask
800
-
801
- if encoder_attention_mask is not None:
802
- model_inputs["attention_mask"] = encoder_attention_mask
803
-
804
- # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
805
- for key, value in kwargs.items():
806
- if key not in model_inputs:
807
- model_inputs[key] = value
808
-
809
- # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
810
- model_inputs.pop("labels", None)
811
- return model_inputs
812
-
813
- def _sample(
814
- self,
815
- input_ids: torch.LongTensor,
816
- logits_processor: "LogitsProcessorList", # noqa: F821
817
- stopping_criteria: "StoppingCriteriaList", # noqa: F821
818
- generation_config: "GenerationConfig", # noqa: F821
819
- synced_gpus: bool = False,
820
- streamer: Optional["BaseStreamer"] = None, # noqa: F821
821
- **model_kwargs,
822
- ) -> Union["GenerateNonBeamOutput", torch.LongTensor]: # noqa: F821
823
- """
824
- 2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
825
- """
826
- # init values
827
- pad_token_id = generation_config._pad_token_tensor
828
- output_attentions = generation_config.output_attentions
829
- output_hidden_states = generation_config.output_hidden_states
830
- output_scores = generation_config.output_scores
831
- output_logits = generation_config.output_logits
832
- return_dict_in_generate = generation_config.return_dict_in_generate
833
- has_eos_stopping_criteria = any(
834
- hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
835
- )
836
- do_sample = generation_config.do_sample
837
-
838
- # init attention / hidden states / scores tuples
839
- scores = () if (return_dict_in_generate and output_scores) else None
840
- raw_logits = () if (return_dict_in_generate and output_logits) else None
841
- decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
842
- cross_attentions = () if (return_dict_in_generate and output_attentions) else None
843
- decoder_hidden_states = (
844
- () if (return_dict_in_generate and output_hidden_states) else None
845
- )
846
-
847
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
848
- if return_dict_in_generate and self.config.is_encoder_decoder:
849
- encoder_attentions = (
850
- model_kwargs["encoder_outputs"].get("attentions")
851
- if output_attentions
852
- else None
853
- )
854
- encoder_hidden_states = (
855
- model_kwargs["encoder_outputs"].get("hidden_states")
856
- if output_hidden_states
857
- else None
858
- )
859
-
860
- # keep track of which sequences are already finished
861
- batch_size, cur_len = input_ids.shape[:2]
862
- this_peer_finished = False
863
- unfinished_sequences = torch.ones(
864
- batch_size, dtype=torch.long, device=input_ids.device
865
- )
866
- model_kwargs = self._get_initial_cache_position(
867
- cur_len, input_ids.device, model_kwargs
868
- )
869
-
870
- model_forward = self.__call__
871
- compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
872
- if compile_forward:
873
- os.environ["TOKENIZERS_PARALLELISM"] = "0"
874
- # If we use FA2 and a static cache, we cannot compile with fullgraph
875
- if self.config._attn_implementation == "flash_attention_2":
876
- # only raise warning if the user passed an explicit compile-config
877
- if (
878
- generation_config.compile_config is not None
879
- and generation_config.compile_config.fullgraph
880
- ):
881
- generation_config.compile_config.fullgraph = False
882
- model_forward = self.get_compiled_call(generation_config.compile_config)
883
-
884
- if generation_config.prefill_chunk_size is not None:
885
- model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
886
- is_prefill = False
887
- else:
888
- is_prefill = True
889
-
890
- while self._has_unfinished_sequences(
891
- this_peer_finished, synced_gpus, device=input_ids.device
892
- ):
893
- # prepare model inputs
894
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
895
-
896
- if is_prefill:
897
- outputs = self(**model_inputs, return_dict=True)
898
- is_prefill = False
899
- else:
900
- outputs = model_forward(**model_inputs, return_dict=True)
901
-
902
- model_kwargs = self._update_model_kwargs_for_generation(
903
- outputs,
904
- model_kwargs,
905
- is_encoder_decoder=self.config.is_encoder_decoder,
906
- )
907
- if synced_gpus and this_peer_finished:
908
- continue
909
-
910
- next_token_logits = outputs.logits[:, -1, :].to(
911
- copy=True, dtype=torch.float32, device=input_ids.device
912
- )
913
-
914
- # pre-process distribution
915
- next_token_scores = logits_processor(input_ids, next_token_logits)
916
-
917
- # Store scores, attentions and hidden_states when required
918
- if return_dict_in_generate:
919
- if output_scores:
920
- scores += (next_token_scores,)
921
- if output_logits:
922
- raw_logits += (next_token_logits,)
923
- if output_attentions:
924
- decoder_attentions += (
925
- (outputs.decoder_attentions,)
926
- if self.config.is_encoder_decoder
927
- else (outputs.attentions,)
928
- )
929
- if self.config.is_encoder_decoder:
930
- cross_attentions += (outputs.cross_attentions,)
931
-
932
- if output_hidden_states:
933
- decoder_hidden_states += (
934
- (outputs.decoder_hidden_states,)
935
- if self.config.is_encoder_decoder
936
- else (outputs.hidden_states,)
937
- )
938
-
939
- # token selection
940
- if do_sample:
941
- probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
942
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
943
- else:
944
- next_tokens = torch.argmax(next_token_scores, dim=-1)
945
-
946
- # finished sentences should have their next token be a padding token
947
- if has_eos_stopping_criteria:
948
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
949
- 1 - unfinished_sequences
950
- )
951
-
952
- # update generated ids, model inputs, and length for next step
953
- # PATCHED: the two following lines; next_tokens can be 2D already for this model
954
- next_tokens_2d = (
955
- next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
956
- )
957
- input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
958
- if streamer is not None:
959
- streamer.put(next_tokens.cpu())
960
-
961
- unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
962
- this_peer_finished = unfinished_sequences.max() == 0
963
- cur_len += 1
964
-
965
- # This is needed to properly delete outputs.logits which may be very large
966
- # for first iteration
967
- # Otherwise a reference to outputs is kept which keeps
968
- # the logits alive in the next iteration
969
- del outputs
970
-
971
- if streamer is not None:
972
- streamer.end()
973
-
974
- if return_dict_in_generate:
975
- if self.config.is_encoder_decoder:
976
- return transformers.generation.utils.GenerateEncoderDecoderOutput(
977
- sequences=input_ids,
978
- scores=scores,
979
- logits=raw_logits,
980
- encoder_attentions=encoder_attentions,
981
- encoder_hidden_states=encoder_hidden_states,
982
- decoder_attentions=decoder_attentions,
983
- cross_attentions=cross_attentions,
984
- decoder_hidden_states=decoder_hidden_states,
985
- past_key_values=model_kwargs.get("past_key_values"),
986
- )
987
- else:
988
- return transformers.generation.utils.GenerateDecoderOnlyOutput(
989
- sequences=input_ids,
990
- scores=scores,
991
- logits=raw_logits,
992
- attentions=decoder_attentions,
993
- hidden_states=decoder_hidden_states,
994
- past_key_values=model_kwargs.get("past_key_values"),
995
- )
996
- else:
997
- return input_ids
998
-
999
-
1000
- def patched__compute_dynamic_ntk_parameters(
1001
- config: Optional[transformers.PretrainedConfig] = None,
1002
- device: Optional["torch.device"] = None,
1003
- seq_len: Optional[int] = None,
1004
- **rope_kwargs,
1005
- ) -> Tuple["torch.Tensor", float]:
1006
- """
1007
- manual patch:
1008
- ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
1009
-
1010
- Computes the inverse frequencies with NTK scaling.
1011
- Credits to the Reddit users /u/bloc97 and /u/emozilla
1012
-
1013
- Args:
1014
- config ([`~transformers.PretrainedConfig`]):
1015
- The model configuration.
1016
- device (`torch.device`):
1017
- The device to use for initialization of the inverse frequencies.
1018
- seq_len (`int`, *optional*):
1019
- The current sequence length,
1020
- used to update the dynamic RoPE at inference time.
1021
- rope_kwargs (`Dict`, *optional*):
1022
- BC compatibility with the previous
1023
- RoPE class instantiation, will be removed in v4.45.
1024
-
1025
- Returns:
1026
- Tuple of (`torch.Tensor`, `float`),
1027
- containing the inverse frequencies for the RoPE embeddings and the
1028
- post-processing scaling factor applied to the
1029
- computed cos/sin (unused in this type of RoPE).
1030
- """
1031
- if config is not None and len(rope_kwargs) > 0:
1032
- raise ValueError(
1033
- "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
1034
- f"`_compute_dynamic_ntk_parameters`, got "
1035
- f"`rope_kwargs`={rope_kwargs} and `config`={config}"
1036
- )
1037
- if len(rope_kwargs) > 0:
1038
- base = rope_kwargs["base"]
1039
- dim = rope_kwargs["dim"]
1040
- max_position_embeddings = rope_kwargs["max_position_embeddings"]
1041
- factor = rope_kwargs["factor"]
1042
- elif config is not None:
1043
- base = config.rope_theta
1044
- partial_rotary_factor = (
1045
- config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
1046
- )
1047
- head_dim = getattr(
1048
- config, "head_dim", config.hidden_size // config.num_attention_heads
1049
- )
1050
- dim = int(head_dim * partial_rotary_factor)
1051
- max_position_embeddings = config.max_position_embeddings
1052
- factor = config.rope_scaling["factor"]
1053
-
1054
- attention_factor = 1.0 # Unused in this type of RoPE
1055
-
1056
- # seq_len: default to max_position_embeddings, e.g. at init time
1057
- # seq_len = seq_len if seq_len is not None and
1058
- # seq_len > max_position_embeddings else max_position_embeddings
1059
- if seq_len is None:
1060
- seq_len = max_position_embeddings
1061
- else:
1062
- # PATCHED: remove the line using max
1063
- torch._check(isinstance(seq_len, torch.Tensor))
1064
- seq_len = torch.maximum(
1065
- seq_len,
1066
- torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
1067
- )
1068
-
1069
- # Compute the inverse frequencies
1070
- base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
1071
- dim / (dim - 2)
1072
- )
1073
- inv_freq = 1.0 / (
1074
- base
1075
- ** (
1076
- torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
1077
- / dim
1078
- )
1079
- )
1080
- return inv_freq, attention_factor
1081
-
1082
-
1083
- def _get_rope_init_fn(self, layer_type=None) -> Callable:
1084
- if hasattr(self, "rope_init_fn"):
1085
- # transformers<=5.0
1086
- rope_init_fn = (
1087
- patched__compute_dynamic_ntk_parameters
1088
- if self.rope_init_fn
1089
- is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
1090
- else self.rope_init_fn
1091
- )
1092
- return rope_init_fn
1093
-
1094
- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
1095
- rope_init_fn = self.compute_default_rope_parameters
1096
- if rope_type != "default":
1097
- rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
1098
- if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
1099
- return patched__compute_dynamic_ntk_parameters
1100
- return rope_init_fn
1101
-
1102
-
1103
- def patched_dynamic_rope_update(rope_forward):
1104
- """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
1105
-
1106
- ``rope_type`` is determined in the constructor of class
1107
- :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
1108
-
1109
- .. code-block:: python
1110
-
1111
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
1112
- self.rope_type = config.rope_scaling.get(
1113
- "rope_type", config.rope_scaling.get("type"))
1114
- else:
1115
- self.rope_type = "default"
1116
-
1117
- The original code of the patched function:
1118
-
1119
- .. code-block:: python
1120
-
1121
- def dynamic_rope_update(rope_forward):
1122
- def longrope_frequency_update(self, position_ids, device):
1123
- seq_len = torch.max(position_ids) + 1
1124
- if hasattr(self.config, "original_max_position_embeddings"):
1125
- original_max_position_embeddings =
1126
- self.config.original_max_position_embeddings
1127
- else:
1128
- original_max_position_embeddings =
1129
- self.config.max_position_embeddings
1130
- if seq_len > original_max_position_embeddings:
1131
- if not hasattr(self, "long_inv_freq"):
1132
- self.long_inv_freq, _ = self.rope_init_fn(
1133
- self.config, device, seq_len=original_max_position_embeddings + 1
1134
- )
1135
- self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
1136
- else:
1137
- self.original_inv_freq = self.original_inv_freq.to(device)
1138
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1139
-
1140
- def dynamic_frequency_update(self, position_ids, device):
1141
- seq_len = torch.max(position_ids) + 1
1142
- if seq_len > self.max_seq_len_cached: # growth
1143
- inv_freq, self.attention_scaling = self.rope_init_fn(
1144
- self.config, device, seq_len=seq_len)
1145
- self.register_buffer("inv_freq", inv_freq, persistent=False)
1146
- self.max_seq_len_cached = seq_len
1147
-
1148
- if seq_len < self.original_max_seq_len and
1149
- self.max_seq_len_cached > self.original_max_seq_len:
1150
- self.original_inv_freq = self.original_inv_freq.to(device)
1151
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1152
- self.max_seq_len_cached = self.original_max_seq_len
1153
-
1154
- @wraps(rope_forward)
1155
- def wrapper(self, x, position_ids):
1156
- if "dynamic" in self.rope_type:
1157
- dynamic_frequency_update(self, position_ids, device=x.device)
1158
- elif self.rope_type == "longrope":
1159
- longrope_frequency_update(self, position_ids, device=x.device)
1160
- return rope_forward(self, x, position_ids)
1161
-
1162
- return wrapper
1163
-
1164
- """
1165
-
1166
- def longrope_frequency_update(self, position_ids, device, layer_type=None):
1167
- # There is no point in patching the function after the model is created
1168
- # as rope_init_fn is an attribute set to one function when the model
1169
- # is created and when no patch is applied yet.
1170
- # So we select the patched version here.
1171
- rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
1172
- seq_len = torch.max(position_ids) + 1
1173
- if hasattr(self.config, "original_max_position_embeddings"):
1174
- original_max_position_embeddings = self.config.original_max_position_embeddings
1175
- else:
1176
- original_max_position_embeddings = self.config.max_position_embeddings
1177
-
1178
- if layer_type is None:
1179
- # rope_type = self.rope_type
1180
- original_inv_freq = self.original_inv_freq
1181
- prefix = ""
1182
- else:
1183
- # rope_type = self.rope_type[layer_type]
1184
- original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
1185
- prefix = f"{layer_type}_"
1186
-
1187
- # At export time, seq_len is unknown.
1188
- long_inv_freq, _ = rope_init_fn(
1189
- self.config, device, seq_len=original_max_position_embeddings + 1
1190
- )
1191
- original_inv_freq = self.original_inv_freq.to(device)
1192
-
1193
- # PATCHED: uses torch.cond instead of a test
1194
- cond = (seq_len > original_max_position_embeddings).item()
1195
- inv_freq = torch.cond(
1196
- cond,
1197
- (lambda x, y: x.clone()),
1198
- (lambda x, y: y.clone()),
1199
- [long_inv_freq, original_inv_freq],
1200
- )
1201
- setattr(self, f"{prefix}inv_freq", inv_freq)
1202
- # if seq_len > original_max_position_embeddings:
1203
- # self.inv_freq = self.long_inv_freq
1204
- # else:
1205
- # self.inv_freq = self.original_inv_freq
1206
-
1207
- def dynamic_frequency_update(self, position_ids, device, layer_type=None):
1208
- # constructor:
1209
- # - self.max_seq_len_cached = config.max_position_embeddings
1210
- # - self.original_max_seq_len = config.max_position_embeddings
1211
- # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
1212
-
1213
- # There is no point in patching the function after the model is created
1214
- # as rope_init_fn is an attribute set to one function when the model
1215
- # is created and when no patch is applied yet.
1216
- # So we select the patched version here.
1217
- rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
1218
-
1219
- # This behaviour is difficult to translate.
1220
- # The sequence always grows.
1221
- # The test should always be True.
1222
- # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
1223
- #
1224
- # if seq_len > self.max_seq_len_cached: # growth
1225
- # inv_freq, self.attention_scaling = self.rope_init_fn(
1226
- # self.config, device, seq_len=seq_len
1227
- # )
1228
- # self.register_buffer("inv_freq", inv_freq, persistent=False)
1229
- # self.max_seq_len_cached = seq_len
1230
- #
1231
- # So we should not need what follows.
1232
- #
1233
- # cond = (seq_len > self.max_seq_len_cached).item()
1234
- # self.attention_scaling = torch.cond(
1235
- # cond,
1236
- # (lambda x, y: x.clone()),
1237
- # (lambda x, y: y.clone()),
1238
- # [attention_scaling, self.attention_scaling],
1239
- # )
1240
-
1241
- seq_len = torch.max(position_ids) + 1
1242
- long_inv_freq, self.attention_scaling = rope_init_fn(
1243
- self.config, device, seq_len=seq_len
1244
- )
1245
-
1246
- if layer_type is None:
1247
- # rope_type = self.rope_type
1248
- # max_seq_len_cached = self.max_seq_len_cached
1249
- original_inv_freq = self.original_inv_freq
1250
- prefix = ""
1251
- else:
1252
- # rope_type = self.rope_type[layer_type]
1253
- # max_seq_len_cached = getattr(
1254
- # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
1255
- # )
1256
- original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
1257
- prefix = f"{layer_type}_"
1258
-
1259
- # Second test to translate.
1260
- # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
1261
- # But in that case the following condition is a way to restore the original cache.
1262
-
1263
- # if (
1264
- # seq_len < self.original_max_seq_len
1265
- # and self.max_seq_len_cached > self.original_max_seq_len
1266
- # ):
1267
- # self.original_inv_freq = self.original_inv_freq.to(device)
1268
- # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1269
- # self.max_seq_len_cached = self.original_max_seq_len
1270
-
1271
- original_inv_freq = self.original_inv_freq.to(device)
1272
- cond = (seq_len >= self.original_max_seq_len).item()
1273
- # PATCHED: uses torch.cond instead of a test
1274
- inv_freq = torch.cond(
1275
- cond,
1276
- (lambda x, y: x.clone()),
1277
- (lambda x, y: y.clone()),
1278
- [long_inv_freq, original_inv_freq],
1279
- )
1280
- setattr(self, f"{prefix}inv_freq", inv_freq)
1281
-
1282
- @wraps(rope_forward)
1283
- def wrapper(self, x, position_ids, layer_type=None):
1284
- if layer_type is None:
1285
- if "dynamic" in self.rope_type:
1286
- dynamic_frequency_update(self, position_ids, device=x.device)
1287
- elif self.rope_type == "longrope":
1288
- longrope_frequency_update(self, position_ids, device=x.device)
1289
- return rope_forward(self, x, position_ids)
1290
-
1291
- if "dynamic" in self.rope_type:
1292
- dynamic_frequency_update(
1293
- self, position_ids, device=x.device, layer_type=layer_type
1294
- )
1295
- elif self.rope_type == "longrope":
1296
- longrope_frequency_update(
1297
- self, position_ids, device=x.device, layer_type=layer_type
1298
- )
1299
- return rope_forward(self, x, position_ids, layer_type=layer_type)
1300
-
1301
- return wrapper
1302
-
1303
-
1304
- def common_eager_attention_forward(
1305
- module: torch.nn.Module,
1306
- query: torch.Tensor,
1307
- key: torch.Tensor,
1308
- value: torch.Tensor,
1309
- attention_mask: Optional[torch.Tensor],
1310
- scaling: Optional[float] = None,
1311
- dropout: float = 0.0,
1312
- head_mask: Optional[torch.Tensor] = None,
1313
- **kwargs,
1314
- ):
1315
- if scaling is None:
1316
- scaling = query.size(-1) ** -0.5
1317
-
1318
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
1319
- if attention_mask is not None:
1320
- # PATCHED
1321
- # The two following lines were added.
1322
- if attention_mask is not None and attention_mask.ndim == 4:
1323
- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
1324
- attn_weights = attn_weights + attention_mask
1325
-
1326
- attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
1327
-
1328
- if head_mask is not None:
1329
- attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
1330
-
1331
- attn_weights = torch.nn.functional.dropout(
1332
- attn_weights, p=dropout, training=module.training
1333
- )
1334
- attn_output = torch.matmul(attn_weights, value)
1335
- attn_output = attn_output.transpose(1, 2).contiguous()
1336
-
1337
- return attn_output, attn_weights
25
+ from ._patch_transformers_generation_mixin import patched_GenerationMixin
1338
26
 
27
+ from ._patch_transformers_masking_utils import patch_masking_utils
1339
28
 
1340
- def patched_sdpa_attention_forward(
1341
- module: torch.nn.Module,
1342
- query: torch.Tensor,
1343
- key: torch.Tensor,
1344
- value: torch.Tensor,
1345
- attention_mask: Optional[torch.Tensor],
1346
- dropout: float = 0.0,
1347
- scaling: Optional[float] = None,
1348
- is_causal: Optional[bool] = None,
1349
- **kwargs,
1350
- ) -> tuple[torch.Tensor, None]:
1351
- """
1352
- manual patch for function
1353
- ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
1354
- """
1355
- assert not kwargs.get("output_attentions", False), (
1356
- "`sdpa` attention does not support `output_attentions=True`."
1357
- " Please set your attention to `eager` if you want any of these features."
1358
- )
1359
- torch._check(
1360
- query.shape[0] == key.shape[0] or query.shape[0] == 1,
1361
- lambda: (
1362
- f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
1363
- f"value: {value.shape}"
1364
- ),
1365
- )
1366
- torch._check(
1367
- key.shape[0] == value.shape[0] or key.shape[0] == 1,
1368
- lambda: (
1369
- f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
1370
- f"value: {value.shape}"
1371
- ),
1372
- )
1373
-
1374
- sdpa_kwargs = {}
1375
- if hasattr(module, "num_key_value_groups"):
1376
- if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
1377
- key = transformers.integrations.sdpa_attention.repeat_kv(
1378
- key, module.num_key_value_groups
1379
- )
1380
- value = transformers.integrations.sdpa_attention.repeat_kv(
1381
- value, module.num_key_value_groups
1382
- )
1383
- else:
1384
- sdpa_kwargs = {"enable_gqa": True}
1385
-
1386
- if attention_mask is not None and attention_mask.ndim == 4:
1387
- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
1388
-
1389
- torch._check(
1390
- attention_mask is None or attention_mask.shape[3] == key.shape[2],
1391
- lambda: "Attention mask shape incompatible with key shape.",
1392
- )
1393
-
1394
- if patch_sdpa_is_causal:
1395
- # transformers>=4.55
1396
- is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
1397
-
1398
- # PATCHED: remove the test query.shape[2] > 1
1399
- # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
1400
- # and we split the test to keep the minimum in torch.cond
1401
- is_causal = attention_mask is None and is_causal
1402
-
1403
- if not is_causal:
1404
- torch._check(query.shape[0] > 0)
1405
- torch._check(query.shape[1] > 0)
1406
- torch._check(query.shape[2] > 0)
1407
- torch._check(query.shape[3] > 0)
1408
- torch._check(key.shape[0] > 0)
1409
- torch._check(key.shape[1] > 0)
1410
- torch._check(key.shape[2] > 0)
1411
- torch._check(key.shape[3] > 0)
1412
- torch._check(value.shape[0] > 0)
1413
- torch._check(value.shape[1] > 0)
1414
- torch._check(value.shape[2] > 0)
1415
- torch._check(value.shape[3] > 0)
1416
- return (
1417
- torch.nn.functional.scaled_dot_product_attention(
1418
- query,
1419
- key,
1420
- value,
1421
- attn_mask=attention_mask,
1422
- dropout_p=dropout,
1423
- scale=scaling,
1424
- is_causal=is_causal,
1425
- **sdpa_kwargs,
1426
- )
1427
- .transpose(1, 2)
1428
- .contiguous(),
1429
- None,
1430
- )
1431
- else:
1432
- # transformers<4.55
1433
- if is_causal is None and attention_mask is not None:
1434
- is_causal = False
1435
- if is_causal is not None:
1436
- return (
1437
- torch.nn.functional.scaled_dot_product_attention(
1438
- query,
1439
- key,
1440
- value,
1441
- attn_mask=attention_mask,
1442
- dropout_p=dropout,
1443
- scale=scaling,
1444
- is_causal=is_causal,
1445
- **sdpa_kwargs,
1446
- )
1447
- .transpose(1, 2)
1448
- .contiguous(),
1449
- None,
1450
- )
1451
-
1452
- # To avoid the following errors:
1453
- # is_causal=query.shape[2] > 1
1454
- # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
1455
- # is_causal=torch.tensor(query.shape[2] > 1)
1456
- # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
1457
- attn_output = torch.cond(
1458
- query.shape[2] > 1, # distinction between prefill and decoding steps
1459
- lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
1460
- query,
1461
- key,
1462
- value,
1463
- dropout_p=dropout,
1464
- scale=scaling,
1465
- is_causal=True,
1466
- **sdpa_kwargs,
1467
- ).contiguous(),
1468
- lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
1469
- query,
1470
- key,
1471
- value,
1472
- dropout_p=dropout,
1473
- scale=scaling,
1474
- is_causal=False,
1475
- **sdpa_kwargs,
1476
- ).contiguous(),
1477
- [query, key, value],
1478
- )
1479
- attn_output = attn_output.transpose(1, 2).contiguous()
1480
- return attn_output, None
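`patched_sdpa_attention_forward` above splits the prefill and decoding paths with `torch.cond` because `scaled_dot_product_attention` only accepts a plain Python bool for `is_causal`, never a SymBool. A condensed sketch of that branch structure, without the GQA, mask and dropout handling of the real patch:

import torch
import torch.nn.functional as F

def sdpa_prefill_or_decode(query, key, value, scale=None):
    # query/key/value: (batch, heads, seq_len, head_dim)
    return torch.cond(
        query.shape[2] > 1,  # prefill step: more than one query token
        lambda q, k, v: F.scaled_dot_product_attention(
            q, k, v, is_causal=True, scale=scale
        ).contiguous(),
        lambda q, k, v: F.scaled_dot_product_attention(
            q, k, v, is_causal=False, scale=scale
        ).contiguous(),
        [query, key, value],
    )

q = torch.rand(1, 4, 6, 16)
print(sdpa_prefill_or_decode(q, q, q).shape)  # torch.Size([1, 4, 6, 16])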
1481
-
1482
-
1483
- def patched_model_bart_eager_attention_forward(
1484
- module: torch.nn.Module,
1485
- query: torch.Tensor,
1486
- key: torch.Tensor,
1487
- value: torch.Tensor,
1488
- attention_mask: Optional[torch.Tensor],
1489
- scaling: Optional[float] = None,
1490
- dropout: float = 0.0,
1491
- head_mask: Optional[torch.Tensor] = None,
1492
- **kwargs,
1493
- ):
1494
- """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
1495
- return common_eager_attention_forward(
1496
- module,
1497
- query,
1498
- key,
1499
- value,
1500
- attention_mask=attention_mask,
1501
- scaling=scaling,
1502
- dropout=dropout,
1503
- head_mask=head_mask,
1504
- **kwargs,
+ if patch_masking_utils:
+ from ._patch_transformers_masking_utils import (
+ patched__vmap_for_bhqkv,
+ patched_eager_mask,
+ patched_sdpa_mask_recent_torch,
  )
 
+ from ._patch_transformers_rotary_embedding import (
+ patched__compute_dynamic_ntk_parameters,
+ patched_dynamic_rope_update,
+ patched_GemmaRotaryEmbedding,
+ patched_LlamaRotaryEmbedding,
+ patched_MistralRotaryEmbedding,
+ patched_MixtralRotaryEmbedding,
+ patched_PhiRotaryEmbedding,
+ )
 
- def patched_modeling_marian_eager_attention_forward(
- module: torch.nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: Optional[float] = None,
- dropout: float = 0.0,
- head_mask: Optional[torch.Tensor] = None,
- **kwargs,
- ):
- """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
- return common_eager_attention_forward(
- module,
- query,
- key,
- value,
- attention_mask=attention_mask,
- scaling=scaling,
- dropout=dropout,
- head_mask=head_mask,
- **kwargs,
+ if _has_transformers("4.51"):
+ from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding
+ if _has_transformers("4.52"):
+ from ._patch_transformers_rotary_embedding import (
+ patched_Gemma2RotaryEmbedding,
+ patched_Gemma3RotaryEmbedding,
+ patched_Phi4MultimodalRotaryEmbedding,
  )
+ if _has_transformers("4.53"):
+ from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding
 
+ # Models
 
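The version-gated imports above only expose a patch when the corresponding class exists in the installed transformers. The removed code further down performs the same check with explicit `pv.Version` comparisons; `_has_transformers` presumably wraps that comparison, roughly as follows (a sketch, not the package's actual helper):

import packaging.version as pv
import transformers

def _has_transformers(version: str) -> bool:
    # True when the installed transformers release is at least `version`.
    return pv.Version(transformers.__version__) >= pv.Version(version)

if _has_transformers("4.52"):
    # e.g. Gemma2/Gemma3 rotary embeddings only exist from 4.52 onwards,
    # so their patched classes are only defined and imported behind this gate.
    ...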
- class common_RotaryEmbedding(torch.nn.Module):
1534
- # This may cause some issues.
1535
- # @torch.no_grad()
1536
- # PATCHED: the decorator
1537
- @patched_dynamic_rope_update
1538
- def forward(self, x, position_ids, layer_type=None):
1539
- if layer_type is not None:
1540
- # transformers>=5.0
1541
- inv_freq = getattr(self, f"{layer_type}_inv_freq")
1542
- attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
1543
- else:
1544
- # transformers<5.0
1545
- inv_freq = self.inv_freq
1546
- attention_scaling = self.attention_scaling
1547
-
1548
- inv_freq_expanded = (
1549
- inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1550
- )
1551
- position_ids_expanded = position_ids[:, None, :].float()
1552
-
1553
- device_type = (
1554
- x.device.type
1555
- if isinstance(x.device.type, str) and x.device.type != "mps"
1556
- else "cpu"
1557
- )
1558
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
1559
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1560
- emb = torch.cat((freqs, freqs), dim=-1)
1561
- cos = emb.cos() * attention_scaling
1562
- sin = emb.sin() * attention_scaling
1563
-
1564
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1565
-
1566
-
1567
- class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
1568
- _PATCHES_ = ["forward"]
1569
- _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
1570
-
1571
-
1572
- if pv.Version(transformers.__version__) >= pv.Version("4.52"):
1573
-
1574
- class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
1575
- _PATCHES_ = ["forward"]
1576
- _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
1577
-
1578
- class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
1579
- _PATCHES_ = ["forward"]
1580
- _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
1581
-
1582
-
1583
- class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
1584
- _PATCHES_ = ["forward"]
1585
- _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
1586
-
1587
-
1588
- class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
1589
- _PATCHES_ = ["forward"]
1590
- _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
1591
-
1592
-
1593
- class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
1594
- _PATCHES_ = ["forward"]
1595
- _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
1596
-
1597
-
1598
- class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
1599
- _PATCHES_ = ["forward"]
1600
- _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
1601
-
1602
-
1603
- if pv.Version(transformers.__version__) >= pv.Version("4.51"):
1604
-
1605
- class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
1606
- _PATCHES_ = ["forward"]
1607
- _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
1608
-
1609
-
1610
- if pv.Version(transformers.__version__) >= pv.Version("4.52"):
1611
-
1612
- class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
1613
- _PATCHES_ = ["forward"]
1614
- _PATCHED_CLASS_ = (
1615
- transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
1616
- )
1617
-
1618
-
1619
- if pv.Version(transformers.__version__) >= pv.Version("4.53"):
1620
-
1621
- class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
1622
- _PATCHES_ = ["forward"]
1623
- _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
1624
-
1625
-
1626
- class patched_IdeficsEmbedding(torch.nn.Module):
1627
- _PATCHES_ = ["forward"]
1628
- _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
1629
-
1630
- def forward(self, x, seq_len=None):
1631
- # x: [bs, num_attention_heads, seq_len, head_size]
1632
- # if seq_len > self.max_seq_len_cached:
1633
- # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
1634
-
1635
- def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
1636
- t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
1637
- # freqs = torch.einsum("i,j->ij", t, inv_freq)
1638
- freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
1639
- emb = torch.cat((freqs, freqs), dim=-1)
1640
- return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
1641
-
1642
- def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
1643
- torch._check(seq_len.item() <= cos_cached.shape[0])
1644
- co = cos_cached[: seq_len.item()].detach().clone()
1645
- torch._check(seq_len.item() <= sin_cached.shape[0])
1646
- si = sin_cached[: seq_len.item()].detach().clone()
1647
- return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
1648
-
1649
- cos_cached, sin_cached = torch.cond(
1650
- (seq_len > self.max_seq_len_cached).item(),
1651
- _set_cos_sin_cache_then,
1652
- _set_cos_sin_cache_else,
1653
- [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
1654
- )
1655
- return cos_cached, sin_cached
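The `else` branch above slices the cached cos/sin tables with a data-dependent length and relies on `torch._check` to tell the exporter that the bound is valid. The same idiom in isolation (toy tensors, not the Idefics cache):

import torch

def take_prefix(cache, length):
    n = length.item()
    # Without these checks, torch.export cannot prove the slice is in range
    # and raises a data-dependent guard error; torch._check records the
    # assumption so the exported graph may rely on it.
    torch._check(n >= 0)
    torch._check(n <= cache.shape[0])
    return cache[:n].clone()

cache = torch.arange(10.0)
print(take_prefix(cache, torch.tensor(4)))  # tensor([0., 1., 2., 3.])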
1656
-
1657
-
1658
- class patched_IdeficsAttention(torch.nn.Module):
1659
- _PATCHES_ = ["forward"]
1660
- _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
1661
-
1662
- def forward(
1663
- self,
1664
- hidden_states: torch.Tensor,
1665
- key_value_states: Optional[torch.Tensor] = None,
1666
- attention_mask: Optional[torch.Tensor] = None,
1667
- position_ids: Optional[torch.LongTensor] = None,
1668
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
1669
- output_attentions: bool = False,
1670
- use_cache: bool = False,
1671
- cache_position: Optional[torch.LongTensor] = None,
1672
- **kwargs,
1673
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1674
- # if key_value_states are provided this layer is used as a cross-attention layer
1675
- is_cross_attention = self.is_cross_attention or key_value_states is not None
1676
-
1677
- bsz, q_len, _ = hidden_states.size()
1678
-
1679
- query_states = (
1680
- self.q_proj(hidden_states)
1681
- .view(bsz, q_len, self.num_heads, self.head_dim)
1682
- .transpose(1, 2)
1683
- )
1684
- if not is_cross_attention:
1685
- key_states = (
1686
- self.k_proj(hidden_states)
1687
- .view(bsz, q_len, self.num_heads, self.head_dim)
1688
- .transpose(1, 2)
1689
- )
1690
- value_states = (
1691
- self.v_proj(hidden_states)
1692
- .view(bsz, q_len, self.num_heads, self.head_dim)
1693
- .transpose(1, 2)
1694
- )
1695
- else:
1696
- _, kv_len, _ = (
1697
- key_value_states.size()
1698
- ) # Note that, in this case, `kv_len` == `kv_seq_len`
1699
- key_states = (
1700
- self.k_proj(key_value_states)
1701
- .view(bsz, kv_len, self.num_heads, self.head_dim)
1702
- .transpose(1, 2)
1703
- )
1704
- value_states = (
1705
- self.v_proj(key_value_states)
1706
- .view(bsz, kv_len, self.num_heads, self.head_dim)
1707
- .transpose(1, 2)
1708
- )
1709
-
1710
- kv_seq_len = key_states.shape[-2]
1711
- if past_key_value is not None:
1712
- kv_seq_len += cache_position[0]
1713
-
1714
- if not is_cross_attention:
1715
- rotary_length = torch.maximum(
1716
- torch.tensor(kv_seq_len, dtype=torch.int64),
1717
- torch.tensor(q_len, dtype=torch.int64),
1718
- )
1719
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
1720
- query_states, key_states = (
1721
- transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
1722
- query_states, key_states, cos, sin, position_ids
1723
- )
1724
- )
1725
- # [bsz, nh, t, hd]
1726
-
1727
- if past_key_value is not None:
1728
- # sin and cos are specific to RoPE models;
1729
- # cache_position needed for the static cache
1730
- cache_kwargs = {"cache_position": cache_position}
1731
- key_states, value_states = past_key_value.update(
1732
- key_states, value_states, self.layer_idx, cache_kwargs
1733
- )
1734
-
1735
- if self.qk_layer_norms:
1736
- query_states = self.q_layer_norm(query_states)
1737
- key_states = self.k_layer_norm(key_states)
1738
-
1739
- attention_interface: Callable = (
1740
- transformers.models.idefics.modeling_idefics.eager_attention_forward
1741
- )
1742
-
1743
- if self.config._attn_implementation != "eager":
1744
- if self.config._attn_implementation == "sdpa" and output_attentions:
1745
- transformers.models.idefics.modeling_idefics.logger.warning_once(
1746
- "`torch.nn.functional.scaled_dot_product_attention` does not support "
1747
- "`output_attentions=True`. Falling back to "
1748
- "eager attention. This warning can be removed using the argument "
1749
- '`attn_implementation="eager"` when loading the model.'
1750
- )
1751
- else:
1752
- attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
1753
- self.config._attn_implementation
1754
- ]
1755
-
1756
- attn_output, attn_weights = attention_interface(
1757
- self,
1758
- query_states,
1759
- key_states,
1760
- value_states,
1761
- attention_mask,
1762
- dropout=0.0 if not self.training else self.dropout,
1763
- scaling=self.scaling,
1764
- **kwargs,
1765
- )
1766
-
1767
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
1768
- attn_output = self.o_proj(attn_output)
1769
-
1770
- if output_attentions:
1771
- attn_weights = None
1772
-
1773
- if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
1774
- return attn_output, attn_weights, past_key_value
1775
- return attn_output, attn_weights
1776
-
1777
-
1778
- class patched_SamMaskDecoder(torch.nn.Module):
1779
- _PATCHES_ = ["forward"]
1780
- _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
1781
-
1782
- def forward(
1783
- self,
1784
- image_embeddings: torch.Tensor,
1785
- image_positional_embeddings: torch.Tensor,
1786
- sparse_prompt_embeddings: torch.Tensor,
1787
- dense_prompt_embeddings: torch.Tensor,
1788
- multimask_output: bool,
1789
- output_attentions: Optional[bool] = None,
1790
- attention_similarity: Optional[torch.Tensor] = None,
1791
- target_embedding: Optional[torch.Tensor] = None,
1792
- ) -> tuple[torch.Tensor, torch.Tensor]:
1793
- """
1794
- Predict masks given image and prompt embeddings.
1795
-
1796
- Args:
1797
- image_embeddings (`torch.Tensor`):
1798
- the embeddings from the image encoder
1799
- image_positional_embedding (`torch.Tensor`):
1800
- positional encoding with the shape of image_embeddings
1801
- sparse_prompt_embeddings (`torch.Tensor`):
1802
- The embeddings of the points and boxes
1803
- dense_prompt_embeddings (`torch.Tensor`):
1804
- the embeddings of the mask inputs
1805
- multimask_output (bool):
1806
- Whether to return multiple masks or a single mask.
1807
- output_attentions (bool, *optional*):
1808
- Whether or not to return the attentions tensors of all attention layers.
1809
- """
1810
- batch_size, num_channels, height, width = image_embeddings.shape
1811
- point_batch_size = sparse_prompt_embeddings.shape[1]
1812
- # Concatenate output tokens
1813
- output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
1814
- output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
1815
-
1816
- # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
1817
- # torch.any is needed to avoid data-dependent control flow
1818
- # with sparse_prompt_embeddings.sum().item() != 0
1819
- def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
1820
- return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
1821
-
1822
- def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
1823
- return output_tokens.clone()
1824
-
1825
- tokens = torch.cond(
1826
- torch.any(sparse_prompt_embeddings != 0),
1827
- sparse_prompt_embeddings_is_not_empty,
1828
- sparse_prompt_embeddings_is_empty,
1829
- [output_tokens, sparse_prompt_embeddings],
1830
- )
1831
-
1832
- point_embeddings = tokens.to(self.iou_token.weight.dtype)
1833
-
1834
- # Expand per-image data in batch direction to be per-point
1835
- image_embeddings = image_embeddings + dense_prompt_embeddings
1836
- image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
1837
- image_positional_embeddings = image_positional_embeddings.repeat_interleave(
1838
- point_batch_size, 0
1839
- )
1840
-
1841
- # Run the transformer, image_positional_embedding are consumed
1842
- torch._check(point_embeddings.shape[0] != 0)
1843
- torch._check(point_embeddings.shape[1] != 0)
1844
- torch._check(point_embeddings.shape[2] != 0)
1845
- torch._check(point_embeddings.shape[3] != 0)
1846
- embeddings_attentions = self.transformer(
1847
- point_embeddings=point_embeddings,
1848
- image_embeddings=image_embeddings,
1849
- image_positional_embeddings=image_positional_embeddings,
1850
- attention_similarity=attention_similarity,
1851
- target_embedding=target_embedding,
1852
- output_attentions=output_attentions,
1853
- )
1854
- point_embedding, image_embeddings = embeddings_attentions[:2]
1855
- iou_token_out = torch.select(point_embedding, dim=2, index=0)
1856
- mask_tokens_out = torch.narrow(
1857
- point_embedding, dim=2, start=1, length=self.num_mask_tokens
1858
- )
1859
-
1860
- # Upscale mask embeddings and predict masks using the mask tokens
1861
- image_embeddings = image_embeddings.transpose(2, 3).reshape(
1862
- batch_size * point_batch_size, num_channels, height, width
1863
- )
+ from ._patch_transformers_gemma3 import patch_gemma3
 
- upscaled_embedding = self.upscale_conv1(image_embeddings)
1866
- upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
1867
- upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
1868
-
1869
- hyper_in_list = []
1870
- for i in range(self.num_mask_tokens):
1871
- current_mlp = self.output_hypernetworks_mlps[i]
1872
- hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
1873
- hyper_in = torch.stack(hyper_in_list, dim=2)
1874
-
1875
- _, num_channels, height, width = upscaled_embedding.shape
1876
- upscaled_embedding = upscaled_embedding.reshape(
1877
- batch_size, point_batch_size, num_channels, height * width
1878
- )
1879
- masks = (hyper_in @ upscaled_embedding).reshape(
1880
- batch_size, point_batch_size, -1, height, width
1881
- )
1882
-
1883
- # Generate mask quality predictions
1884
- iou_pred = self.iou_prediction_head(iou_token_out)
1885
-
1886
- # Select the correct mask or masks for output
1887
- if multimask_output:
1888
- mask_slice = slice(1, None)
1889
- else:
1890
- mask_slice = slice(0, 1)
1891
- masks = masks[:, :, mask_slice, :, :]
1892
- iou_pred = iou_pred[:, :, mask_slice]
1893
-
1894
- outputs = (masks, iou_pred)
1895
-
1896
- if len(embeddings_attentions) == 2:
1897
- # transformers==4.54
1898
- return outputs
1899
-
1900
- if output_attentions and len(embeddings_attentions) > 2:
1901
- outputs = outputs + (embeddings_attentions[2],) # noqa: RUF005
1902
- else:
1903
- outputs = outputs + (None,) # noqa: RUF005
1904
- return outputs
1905
-
1906
-
1907
- def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
1908
- """
1909
- Rewrites the loop in:
1910
-
1911
- .. code-block:: python
1912
-
1913
- attention_mask = torch.full(
1914
- [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
1915
- )
1916
- for i in range(1, len(seq)):
1917
- attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
1918
- """
1919
- r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
1920
- less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
1921
- less = less0.sum(axis=-1, keepdim=True) + 1
1922
- sq = less * less.T
1923
- look = (
1924
- torch.max(seq.min() == 0, less != less.max())
1925
- * torch.max(seq.max() == mask.shape[-1], less != less.min())
1926
- * less
1927
- )
1928
- filt = (sq != look**2).to(mask.dtype)
1929
- return mask * filt
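`rewrite_loop_for_square_mask` above vectorizes the data-dependent Python loop quoted in its docstring. A runnable reference for that loop, handy for comparing the two on small inputs:

import torch

def square_mask_loop(mask, seq):
    # Original formulation: zero out one square block per (seq[i-1], seq[i]) pair.
    out = mask.clone()
    for i in range(1, len(seq)):
        out[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
    return out

seq = torch.tensor([0, 3, 5])
mask = torch.full([1, 5, 5], torch.finfo(torch.float32).min)
print(square_mask_loop(mask, seq)[0])  # two blocks of zeros on the diagonal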
+ if patch_gemma3:
+ from ._patch_transformers_gemma3 import patched_Gemma3Model
 
+ from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
 
- try:
- import transformers.models.qwen2_vl
 
- patch_qwen2 = True
- except ImportError:
- patch_qwen2 = False
+ from ._patch_transformers_qwen2 import patch_qwen2
 
  if patch_qwen2:
+ from ._patch_transformers_qwen2 import patched_VisionAttention
 
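The class that follows (now removed from this file) shows the patching convention used by every entry in these modules: `_PATCHED_CLASS_` names the transformers class to modify and `_PATCHES_` lists the methods to swap in. The machinery that consumes these attributes lives elsewhere in the package; conceptually it amounts to something like this sketch:

def apply_patch(patch_cls):
    # Hypothetical illustration only: replace the listed methods on the target
    # class and keep the originals so they can be restored after export.
    target = patch_cls._PATCHED_CLASS_
    saved = {}
    for name in patch_cls._PATCHES_:
        saved[name] = getattr(target, name)
        setattr(target, name, getattr(patch_cls, name))
    return saved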
- class patched_VisionAttention(torch.nn.Module):
1942
- _PATCHES_ = ["forward"]
1943
- _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
1944
-
1945
- def forward(
1946
- self,
1947
- hidden_states: torch.Tensor,
1948
- cu_seqlens: torch.Tensor,
1949
- rotary_pos_emb: Optional[torch.Tensor] = None,
1950
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1951
- ) -> torch.Tensor:
1952
- seq_length = hidden_states.shape[0]
1953
- q, k, v = (
1954
- self.qkv(hidden_states)
1955
- .reshape(seq_length, 3, self.num_heads, -1)
1956
- .permute(1, 0, 2, 3)
1957
- .unbind(0)
1958
- )
1959
- if position_embeddings is None:
1960
- transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
1961
- "The attention layers in this model are transitioning from "
1962
- " computing the RoPE embeddings internally "
1963
- "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
1964
- "to using externally computed "
1965
- "`position_embeddings` (Tuple of tensors, containing cos and sin)."
1966
- " In v4.54 `rotary_pos_emb` will be "
1967
- "removed and `position_embeddings` will be mandatory."
1968
- )
1969
- emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
1970
- cos = emb.cos()
1971
- sin = emb.sin()
1972
- else:
1973
- cos, sin = position_embeddings
1974
- q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
1975
- q, k, cos, sin
1976
- )
1977
-
1978
- attention_mask = torch.full(
1979
- [1, seq_length, seq_length],
1980
- torch.finfo(q.dtype).min,
1981
- device=q.device,
1982
- dtype=q.dtype,
1983
- )
1984
- # for i in range(1, len(cu_seqlens)):
1985
- # attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
1986
- # cu_seqlens[i - 1] : cu_seqlens[i]] = 0
1987
- attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
1988
-
1989
- q = q.transpose(0, 1)
1990
- k = k.transpose(0, 1)
1991
- v = v.transpose(0, 1)
1992
- attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
1993
- attn_weights = attn_weights + attention_mask
1994
- attn_weights = torch.nn.functional.softmax(
1995
- attn_weights, dim=-1, dtype=torch.float32
1996
- ).to(q.dtype)
1997
- attn_output = torch.matmul(attn_weights, v)
1998
- attn_output = attn_output.transpose(0, 1)
1999
- attn_output = attn_output.reshape(seq_length, -1)
2000
- attn_output = self.proj(attn_output)
2001
- return attn_output
2002
-
2003
-
2004
- try:
2005
- import transformers.models.qwen2_5_vl
2006
- import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
2007
-
2008
- patch_qwen2_5 = True
2009
- except ImportError:
2010
- patch_qwen2_5 = False
+ from ._patch_transformers_qwen2_5 import patch_qwen2_5
 
  if patch_qwen2_5:
- import torch.nn.functional as F
2014
-
2015
- use_loop_for_attention_in_qwen_2_5 = False
2016
-
2017
- class patched_Qwen2_5_VLForConditionalGeneration:
2018
- _PATCHES_ = ["prepare_inputs_for_generation"]
2019
- _PATCHED_CLASS_ = (
2020
- transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
2021
- )
2022
-
2023
- def prepare_inputs_for_generation(
2024
- self,
2025
- input_ids,
2026
- past_key_values=None,
2027
- attention_mask=None,
2028
- inputs_embeds=None,
2029
- cache_position=None,
2030
- position_ids=None,
2031
- use_cache=True,
2032
- pixel_values=None,
2033
- pixel_values_videos=None,
2034
- image_grid_thw=None,
2035
- video_grid_thw=None,
2036
- second_per_grid_ts=None,
2037
- **kwargs,
2038
- ):
2039
- # Overwritten -- in specific circumstances we don't want to f
2040
- # forward image inputs to the model
2041
- from transformers.generation import GenerationMixin
2042
-
2043
- model_inputs = GenerationMixin.prepare_inputs_for_generation(
2044
- self,
2045
- input_ids,
2046
- past_key_values=past_key_values,
2047
- attention_mask=attention_mask,
2048
- inputs_embeds=inputs_embeds,
2049
- cache_position=cache_position,
2050
- position_ids=position_ids,
2051
- pixel_values=pixel_values,
2052
- pixel_values_videos=pixel_values_videos,
2053
- image_grid_thw=image_grid_thw,
2054
- video_grid_thw=video_grid_thw,
2055
- second_per_grid_ts=second_per_grid_ts,
2056
- use_cache=use_cache,
2057
- **kwargs,
2058
- )
2059
-
2060
- # Qwen2-5-VL position_ids are prepared with rope_deltas
2061
- if position_ids is None:
2062
- # Calculate RoPE index once per generation in the pre-fill stage only.
2063
- # When compiling, we can't check tensor values thus we check only input length
2064
- # It is safe to assume that `length!=1` means we're in pre-fill
2065
- # because compiled models currently cannot do assisted decoding
2066
- if cache_position[0] == 0 or self.model.rope_deltas is None:
2067
- vision_positions, rope_deltas = self.model.get_rope_index(
2068
- model_inputs.get("input_ids", None),
2069
- image_grid_thw=image_grid_thw,
2070
- video_grid_thw=video_grid_thw,
2071
- second_per_grid_ts=second_per_grid_ts,
2072
- attention_mask=attention_mask,
2073
- )
2074
- self.model.rope_deltas = rope_deltas
2075
- # then use the prev pre-calculated rope-deltas to get the correct position ids
2076
- elif (
2077
- "position_ids" in model_inputs and model_inputs["position_ids"] is not None
2078
- ):
2079
- batch_size, seq_length = model_inputs["position_ids"].shape
2080
- device = model_inputs["position_ids"].device
2081
- position_ids = torch.arange(seq_length, device=device)
2082
- position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
2083
- delta = cache_position[0] + self.model.rope_deltas
2084
- delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
2085
- vision_positions = position_ids + delta.expand_as(position_ids)
2086
-
2087
- # Concatenate "text + vision" positions into [4, bs, seq-len]
2088
- if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
2089
- text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
2090
- None, None, :
2091
- ]
2092
- else:
2093
- text_positions = model_inputs["position_ids"][None, ...]
2094
- # text_positions = model_inputs["position_ids"][None, ...]
2095
- assert vision_positions is not None, "vision_positions are missing"
2096
- model_inputs["position_ids"] = torch.cat(
2097
- [text_positions, vision_positions], dim=0
2098
- )
2099
-
2100
- if cache_position[0] != 0:
2101
- model_inputs["pixel_values"] = None
2102
- model_inputs["pixel_values_videos"] = None
2103
-
2104
- return model_inputs
2105
-
2106
- class patched_Qwen2_5_VisionTransformerPretrainedModel:
2107
- _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
2108
- _PATCHED_CLASS_ = (
2109
- transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
2110
- )
2111
-
2112
- def rot_pos_emb(self, grid_thw):
2113
- pos_ids = []
2114
- for thw_ in grid_thw:
2115
- # PATCHED: avoid unbind
2116
- t = thw_[0]
2117
- h = thw_[1]
2118
- w = thw_[2]
2119
- hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
2120
- hpos_ids = hpos_ids.reshape(
2121
- h // self.spatial_merge_size,
2122
- self.spatial_merge_size,
2123
- w // self.spatial_merge_size,
2124
- self.spatial_merge_size,
2125
- )
2126
- hpos_ids = hpos_ids.permute(0, 2, 1, 3)
2127
- hpos_ids = hpos_ids.flatten()
2128
-
2129
- wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
2130
- wpos_ids = wpos_ids.reshape(
2131
- h // self.spatial_merge_size,
2132
- self.spatial_merge_size,
2133
- w // self.spatial_merge_size,
2134
- self.spatial_merge_size,
2135
- )
2136
- wpos_ids = wpos_ids.permute(0, 2, 1, 3)
2137
- wpos_ids = wpos_ids.flatten()
2138
- pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
2139
- pos_ids = torch.cat(pos_ids, dim=0)
2140
- max_grid_size = grid_thw[:, 1:].max()
2141
- rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
2142
- rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
2143
- return rotary_pos_emb
2144
-
2145
- def get_window_index(self, grid_thw):
2146
- window_index: list = [] # type: ignore[annotation-unchecked]
2147
- # PATCHED
2148
- cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)] # type: ignore[annotation-unchecked]
2149
- window_index_id = 0
2150
- vit_merger_window_size = (
2151
- self.window_size // self.spatial_merge_size // self.patch_size
2152
- )
2153
-
2154
- for _thw in grid_thw:
2155
- # PATCHED: avoid unbind
2156
- grid_t = _thw[0]
2157
- grid_h = _thw[1]
2158
- grid_w = _thw[2]
2159
- llm_grid_h, llm_grid_w = (
2160
- grid_h // self.spatial_merge_size,
2161
- grid_w // self.spatial_merge_size,
2162
- )
2163
- index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
2164
- grid_t, llm_grid_h, llm_grid_w
2165
- )
2166
- pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
2167
- pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
2168
- num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
2169
- num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
2170
- index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
2171
- index_padded = index_padded.reshape(
2172
- grid_t,
2173
- num_windows_h,
2174
- vit_merger_window_size,
2175
- num_windows_w,
2176
- vit_merger_window_size,
2177
- )
2178
- index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
2179
- grid_t,
2180
- num_windows_h * num_windows_w,
2181
- vit_merger_window_size,
2182
- vit_merger_window_size,
2183
- )
2184
- seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
2185
- index_padded = index_padded.reshape(-1)
2186
- index_new = index_padded[index_padded != -100]
2187
- window_index.append(index_new + window_index_id)
2188
- cu_seqlens_tmp = (
2189
- seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
2190
- )
2191
- # PATCHED
2192
- # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
2193
- cu_window_seqlens.append(cu_seqlens_tmp)
2194
- window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
2195
- window_index = torch.cat(window_index, dim=0)
2196
-
2197
- return window_index, torch.cat(cu_window_seqlens, dim=0)
2198
-
2199
- def forward(
2200
- self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
2201
- ) -> torch.Tensor:
2202
- """
2203
- Args:
2204
- hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
2205
- The final hidden states of the model.
2206
- grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
2207
- The temporal, height and width of feature shape of each image in LLM.
2208
-
2209
- Returns:
2210
- `torch.Tensor`: hidden_states.
2211
- """
2212
- hidden_states = self.patch_embed(hidden_states)
2213
- rotary_pos_emb = self.rot_pos_emb(grid_thw)
2214
- window_index, cu_window_seqlens = self.get_window_index(grid_thw)
2215
- # PATCHED
2216
- # cu_window_seqlens = torch.tensor(
2217
- # cu_window_seqlens,
2218
- # device=hidden_states.device,
2219
- # dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
2220
- # )
2221
- cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
2222
- cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
2223
-
2224
- seq_len, _ = hidden_states.size()
2225
- hidden_states = hidden_states.reshape(
2226
- seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
2227
- )
2228
- hidden_states = hidden_states[window_index, :, :]
2229
- hidden_states = hidden_states.reshape(seq_len, -1)
2230
- rotary_pos_emb = rotary_pos_emb.reshape(
2231
- seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
2232
- )
2233
- rotary_pos_emb = rotary_pos_emb[window_index, :, :]
2234
- rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
2235
- emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
2236
- position_embeddings = (emb.cos(), emb.sin())
2237
-
2238
- cu_seqlens = torch.repeat_interleave(
2239
- grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
2240
- ).cumsum(
2241
- dim=0,
2242
- # Select dtype based on the following factors:
2243
- # - FA2 requires that cu_seqlens_q must have dtype int32
2244
- # - torch.onnx.export requires that cu_seqlens_q must have same dtype
2245
- # as grid_thw
2246
- # See https://github.com/huggingface/transformers/pull/34852
2247
- # for more information
2248
- dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
2249
- )
2250
- cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
2251
-
2252
- for layer_num, blk in enumerate(self.blocks):
2253
- if layer_num in self.fullatt_block_indexes:
2254
- cu_seqlens_now = cu_seqlens
2255
- else:
2256
- cu_seqlens_now = cu_window_seqlens
2257
-
2258
- hidden_states = blk(
2259
- hidden_states,
2260
- cu_seqlens=cu_seqlens_now,
2261
- position_embeddings=position_embeddings,
2262
- **kwargs,
2263
- )
2264
-
2265
- hidden_states = self.merger(hidden_states)
2266
- reverse_indices = torch.argsort(window_index)
2267
- hidden_states = hidden_states[reverse_indices, :]
2268
- return hidden_states
2269
-
2270
- class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
2271
- def forward(
2272
- self,
2273
- start_end,
2274
- query_states,
2275
- key_states,
2276
- value_states,
2277
- scaling: float = 1.0,
2278
- dropout: float = 0.0,
2279
- **kwargs,
2280
- ):
2281
- a = start_end[0].item()
2282
- b = start_end[1].item()
2283
- q = query_states[:, :, a:b, :]
2284
- k = key_states[:, :, a:b, :]
2285
- v = value_states[:, :, a:b, :]
2286
- return patched_sdpa_attention_forward(
2287
- self,
2288
- q,
2289
- k,
2290
- v,
2291
- attention_mask=None,
2292
- scaling=scaling,
2293
- dropout=dropout,
2294
- is_causal=False,
2295
- **kwargs,
2296
- )[0]
2297
-
2298
- class patched_Qwen2_5_VLVisionAttention:
2299
- _PATCHES_ = ["forward"]
2300
- _PATCHED_CLASS_ = (
2301
- transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
2302
- )
2303
-
2304
- def forward(
2305
- self,
2306
- hidden_states: torch.Tensor,
2307
- cu_seqlens: torch.Tensor,
2308
- rotary_pos_emb: Optional[torch.Tensor] = None,
2309
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
2310
- **kwargs,
2311
- ) -> torch.Tensor:
2312
- seq_length = hidden_states.shape[0]
2313
- # PATCHED: avoid the use of unbind
2314
- qkv = (
2315
- self.qkv(hidden_states)
2316
- .reshape(seq_length, 3, self.num_heads, -1)
2317
- .permute(1, 0, 2, 3)
2318
- )
2319
-
2320
- query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
2321
- cos, sin = position_embeddings
2322
-
2323
- # This part should be moved into the loop
2324
- # iteration to enable fusion inside the loop.
2325
- query_states, key_states = (
2326
- transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
2327
- query_states, key_states, cos, sin
2328
- )
2329
- )
2330
-
2331
- query_states = query_states.transpose(0, 1).unsqueeze(0)
2332
- key_states = key_states.transpose(0, 1).unsqueeze(0)
2333
- value_states = value_states.transpose(0, 1).unsqueeze(0)
2334
-
2335
- attention_interface: Callable = (
2336
- transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
2337
- )
2338
- if self.config._attn_implementation != "eager":
2339
- # PATCHED
2340
- # attention_interface = ALL_ATTENTION_FUNCTIONS[
2341
- # self.config._attn_implementation]
2342
- attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
2343
- self.config._attn_implementation
2344
- ]
2345
-
2346
- if (
2347
- self.config._attn_implementation == "flash_attention_2"
2348
- and _is_torchdynamo_exporting()
2349
- ):
2350
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
2351
- attn_output = torch.onnx.ops.symbolic(
2352
- "custom::qwen25_attention",
2353
- (
2354
- query_states,
2355
- key_states,
2356
- value_states,
2357
- cu_seqlens,
2358
- cu_seqlens,
2359
- max_seqlen,
2360
- max_seqlen,
2361
- torch.tensor(self.scaling, dtype=torch.float32),
2362
- ),
2363
- dtype=query_states.dtype,
2364
- shape=(
2365
- key_states.shape[0],
2366
- value_states.shape[1],
2367
- max_seqlen,
2368
- value_states.shape[-1],
2369
- ),
2370
- version=1,
2371
- )
2372
- elif self.config._attn_implementation == "flash_attention_2":
2373
- # Flash Attention 2: Use cu_seqlens for variable length attention
2374
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
2375
- attn_output, _ = attention_interface(
2376
- self,
2377
- query_states,
2378
- key_states,
2379
- value_states,
2380
- attention_mask=None,
2381
- scaling=self.scaling,
2382
- dropout=0.0 if not self.training else self.attention_dropout,
2383
- cu_seq_lens_q=cu_seqlens,
2384
- cu_seq_lens_k=cu_seqlens,
2385
- max_length_q=max_seqlen,
2386
- max_length_k=max_seqlen,
2387
- is_causal=False,
2388
- **kwargs,
2389
- )
2390
- elif _is_torchdynamo_exporting():
2391
- if (
2392
- attention_interface
2393
- is transformers.integrations.sdpa_attention.sdpa_attention_forward
2394
- ):
2395
- attention_interface = patched_sdpa_attention_forward
2396
-
2397
- if use_loop_for_attention_in_qwen_2_5:
2398
-
2399
- def _iteration(start_end, query_states, key_states, value_states):
2400
- return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
2401
- self,
2402
- start_end,
2403
- query_states,
2404
- key_states,
2405
- value_states,
2406
- scaling=self.scaling,
2407
- dropout=0.0 if not self.training else self.attention_dropout,
2408
- )
2409
-
2410
- starts = cu_seqlens[:-1]
2411
- ends = cu_seqlens[1:]
2412
- # cu_seqlens = [0, 10, 14, 27]
2413
- # starts: [0, 10, 14]
2414
- # ends: [10, 14, 17]
2415
- # starts_ends: [[0, 10], [10, 14], [14, 27]]
2416
- starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
2417
- attn_outputs = [
2418
- _iteration(start_end, query_states, key_states, value_states)
2419
- for start_end in starts_ends
2420
- ]
2421
- # attn_outputs = torch._higher_order_ops.while_loop(
2422
- # attn_outputs = torch.ops.higher_order.while_loop(
2423
- # (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
2424
- # _iteration,
2425
- # (torch.tensor(0),
2426
- # starts_ends, query_states, key_states, value_states), tuple(),
2427
- # )
2428
- attn_output = torch.cat(attn_outputs, dim=1)
2429
- else:
2430
- # make square mask
2431
- indices = torch.arange(
2432
- cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
2433
- )
2434
- dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
2435
- cu_seqlens.dtype
2436
- )
2437
- dot = dot.sum(dim=0)
2438
- mask = dot.unsqueeze(1) - dot.unsqueeze(0)
2439
- bool_mask = mask == 0
2440
- bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
2441
-
2442
- torch._check(bool_mask.shape[2] == key_states.shape[2])
2443
- torch._check(bool_mask.shape[3] == key_states.shape[2])
2444
-
2445
- attn_output, _ = attention_interface(
2446
- self,
2447
- query_states,
2448
- key_states,
2449
- value_states,
2450
- attention_mask=bool_mask,
2451
- scaling=self.scaling,
2452
- dropout=0.0 if not self.training else self.attention_dropout,
2453
- is_causal=False,
2454
- **kwargs,
2455
- )
2456
- else:
2457
- # Other implementations: Process each chunk separately
2458
- lengths = cu_seqlens[1:] - cu_seqlens[:-1]
2459
- splits = [
2460
- torch.split(tensor, lengths.tolist(), dim=2)
2461
- for tensor in (query_states, key_states, value_states)
2462
- ]
2463
-
2464
- attn_outputs = [
2465
- attention_interface(
2466
- self,
2467
- q,
2468
- k,
2469
- v,
2470
- attention_mask=None,
2471
- scaling=self.scaling,
2472
- dropout=0.0 if not self.training else self.attention_dropout,
2473
- is_causal=False,
2474
- **kwargs,
2475
- )[0]
2476
- for q, k, v in zip(*splits)
2477
- ]
2478
- attn_output = torch.cat(attn_outputs, dim=1)
2479
-
2480
- attn_output = attn_output.reshape(seq_length, -1).contiguous()
2481
- attn_output = self.proj(attn_output)
2482
- return attn_output
2483
-
2484
-
2485
- try:
2486
- import transformers.models.qwen3_moe
+ from ._patch_transformers_qwen2_5 import (
+ patched_Qwen2_5_VLForConditionalGeneration,
+ patched_Qwen2_5_VisionTransformerPretrainedModel,
+ patched_Qwen2_5_VLVisionAttentionOneIteration,
+ patched_Qwen2_5_VLVisionAttention,
+ PLUGS as PLUGS_Qwen25,
+ )
 
- patch_qwen3 = True
- except ImportError:
- patch_qwen3 = False
+ from ._patch_transformers_qwen3 import patch_qwen3
 
  if patch_qwen3:
+ from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock
 
- class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
2495
- _PATCHES_ = ["forward", "_forward_expert_loop"]
2496
- _PATCHED_CLASS_ = (
2497
- transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
2498
- )
2499
-
2500
- def _forward_expert_loop(
2501
- self,
2502
- final_hidden_states,
2503
- expert_mask_idx,
2504
- hidden_states,
2505
- routing_weights,
2506
- expert_idx: int,
2507
- ):
2508
- # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
2509
- idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
2510
- hidden_dim = hidden_states.shape[-1]
2511
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2512
- expert_current_state = self.experts[expert_idx](current_state)
2513
- current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
2514
- return final_hidden_states.index_add(
2515
- 0, top_x, current_hidden_states.to(hidden_states.dtype)
2516
- )
2517
-
2518
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
2519
- """ """
2520
- batch_size, sequence_length, hidden_dim = hidden_states.shape
2521
- hidden_states = hidden_states.view(-1, hidden_dim)
2522
- # router_logits: (batch * sequence_length, n_experts)
2523
- router_logits = self.gate(hidden_states)
2524
-
2525
- routing_weights = torch.nn.functional.softmax(
2526
- router_logits, dim=1, dtype=torch.float
2527
- )
2528
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
2529
- if self.norm_topk_prob: # only diff with mixtral sparse moe block!
2530
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
2531
- # we cast back to the input dtype
2532
- routing_weights = routing_weights.to(hidden_states.dtype)
2533
-
2534
- final_hidden_states = torch.zeros(
2535
- (batch_size * sequence_length, hidden_dim),
2536
- dtype=hidden_states.dtype,
2537
- device=hidden_states.device,
2538
- )
2539
-
2540
- # One hot encode the selected experts to create an expert mask
2541
- # this will be used to easily index which expert is going to be sollicitated
2542
- expert_mask = torch.nn.functional.one_hot(
2543
- selected_experts, num_classes=self.num_experts
2544
- ).permute(2, 1, 0)
2545
-
2546
- # Loop over all available experts in the model
2547
- # and perform the computation on each expert
2548
- expert_sum = expert_mask.sum(dim=(-1, -2))
2549
- # expert_hit = torch.greater(expert_sum, 0).nonzero()
2550
- # for expert_idx in expert_hit:
2551
- for expert_idx in range(self.num_experts):
2552
- # initial code has a squeeze but it is not possible to do that.
2553
- # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
2554
- expert_mask_idx = expert_mask[expert_idx]
2555
- final_hidden_states = torch.cond(
2556
- (expert_sum[expert_idx] > 0).item(),
2557
- lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop( # noqa: E501
2558
- final_hidden_states,
2559
- expert_mask,
2560
- hidden_states,
2561
- routing_weights,
2562
- expert_idx=_i,
2563
- ),
2564
- lambda final_hidden_states, *args: final_hidden_states.clone(),
2565
- [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
2566
- )
2567
-
2568
- # if expert_sum[expert_idx] > 0:
2569
- # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
2570
-
2571
- # Index the correct hidden states and compute the expert hidden state for
2572
- # the current expert. We need to make sure to multiply the output hidden
2573
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
2574
- # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2575
- # current_hidden_states = (
2576
- # expert_layer(current_state) * routing_weights[top_x, idx, None]
2577
- # )
2578
-
2579
- # However `index_add_` only support torch tensors for indexing so we'll use
2580
- # the `top_x` tensor here.
2581
- # final_hidden_states.index_add_(
2582
- # 0, top_x, current_hidden_states.to(hidden_states.dtype)
2583
- # )
2584
-
2585
- final_hidden_states = final_hidden_states.reshape(
2586
- batch_size, sequence_length, hidden_dim
2587
- )
2588
- return final_hidden_states, router_logits
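The MoE patch above routes every expert through `torch.cond` and therefore replaces the in-place `index_add_` of the original code with the functional `index_add`: branches of `torch.cond` must not mutate their operands. A small standalone illustration of that constraint:

import torch

def add_rows(base, rows, idx):
    # Out-of-place update: returns a new tensor instead of mutating `base`.
    return base.index_add(0, idx, rows)

base = torch.zeros(4, 3)
rows = torch.ones(2, 3)
updated = torch.cond(
    torch.tensor(True),                      # e.g. "this expert received tokens"
    lambda b, r: add_rows(b, r, torch.tensor([0, 2])),
    lambda b, r: b.clone(),                  # untouched copy when the expert is idle
    [base, rows],
)
print(updated)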
2589
-
2590
-
2591
- try:
2592
- from transformers.models.gemma3.modeling_gemma3 import Gemma3Model # noqa: F401
2593
-
2594
- patch_gemma3 = True
2595
- except ImportError:
2596
- patch_gemma3 = False
2597
-
2598
-
2599
- if patch_gemma3:
2600
88
 
2601
- class patched_Gemma3Model(torch.nn.Module):
2602
- _PATCHES_ = ["get_placeholder_mask"]
2603
- _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
2604
- _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+ from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
 
- def get_placeholder_mask(
- self,
- input_ids: torch.LongTensor,
- inputs_embeds: torch.FloatTensor,
- image_features: torch.FloatTensor,
- ):
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(
- self.config.image_token_id,
- dtype=torch.long,
- device=inputs_embeds.device,
- )
- )
- special_image_mask = special_image_mask.all(-1)
- else:
- special_image_mask = input_ids == self.config.image_token_id
 
- n_image_tokens = special_image_mask.sum()
- special_image_mask = (
- special_image_mask.unsqueeze(-1)
- .expand_as(inputs_embeds)
- .to(inputs_embeds.device)
- )
- n_image_features = image_features.shape[0] * image_features.shape[1]
- # PATCHED: torch._check
- # if inputs_embeds[special_image_mask].numel() != image_features.numel():
- # raise ValueError( ... )
- torch._check(
- inputs_embeds[special_image_mask].numel() == image_features.numel(),
- lambda: (
- f"Image features and image tokens do not match: tokens: "
- f"{n_image_tokens}, features {n_image_features}"
- ),
- )
- return special_image_mask
+ def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]: # noqa: F821
+ """Returns the necessary plugs to rewrite models."""
+ plugs = []
+ if patch_qwen2_5:
+ plugs.extend(PLUGS_Qwen25)
+ return plugs
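A hypothetical usage sketch for the new `get_transformers_plugs` entry point; the import path is inferred from this file's location, and the consumer API for `EagerDirectReplacementWithOnnx` is defined elsewhere in the package, not in this diff:

from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    get_transformers_plugs,
)

plugs = get_transformers_plugs()
# Empty unless the optional Qwen2.5-VL code is importable from transformers.
print(len(plugs), "plug(s) available")
for plug in plugs:
    print(type(plug).__name__)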