onnx-diagnostic 0.8.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -0,0 +1,2139 @@
1
+ import inspect
2
+ import math
3
+ import os
4
+ from dataclasses import dataclass
5
+ from functools import wraps
6
+ from typing import Callable, List, Optional, Tuple, Union
7
+ import packaging.version as pv
8
+ import torch
9
+ import transformers
10
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
11
+ from transformers.cache_utils import StaticCache, Cache
12
+ from transformers.generation.utils import (
13
+ GenerateNonBeamOutput,
14
+ GenerationConfig,
15
+ StoppingCriteriaList,
16
+ LogitsProcessorList,
17
+ )
18
+
19
+ try:
20
+ from transformers.cache_utils import parse_processor_args # noqa: F401
21
+
22
+ patch_parse_processor_args = True
23
+ except ImportError:
24
+ patch_parse_processor_args = False
25
+
26
+ try:
27
+ import transformers.masking_utils
28
+
29
+ patch_masking_utils = True
30
+ except ImportError:
31
+ patch_masking_utils = False
32
+
33
+
34
+ try:
35
+ # transformers>= 4.55.1
36
+ from transformers.cache_utils import DynamicLayer
37
+
38
+ patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
39
+ except ImportError:
40
+ patch_DynamicLayer = False
41
+
42
+
43
+ def _has_transformers(version: str) -> bool:
44
+ return pv.Version(transformers.__version__) >= pv.Version(version)
45
+
46
+
47
+ def _is_torchdynamo_exporting() -> bool:
48
+ """
49
+ Tells if :epkg:`torch` is exporting a model.
50
+ Relies on ``torch.compiler.is_exporting()``.
51
+ """
52
+ import torch
53
+
54
+ if not hasattr(torch.compiler, "is_exporting"):
55
+ # torch.compiler.is_exporting requires torch>=2.7
56
+ return False
57
+
58
+ try:
59
+ return torch.compiler.is_exporting()
60
+ except Exception:
61
+ try:
62
+ import torch._dynamo as dynamo
63
+
64
+ return dynamo.is_exporting() # type: ignore
65
+ except Exception:
66
+ return False
67
+
68
+
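A minimal sketch (not part of the wheel) of how such an export guard can be used, assuming only a recent torch; the helper name ``branch_on_export`` is made up for illustration:

.. code-block:: python

    import torch

    def branch_on_export() -> str:
        # torch.compiler.is_exporting() requires torch>=2.7; on older
        # versions we fall back to the eager path, as the helper above does.
        is_exporting = getattr(torch.compiler, "is_exporting", lambda: False)
        return "export path" if is_exporting() else "eager path"

    print(branch_on_export())  # "eager path" outside torch.export.export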
69
+ patch_sdpa_is_causal = _has_transformers("4.99")
70
+ patch_is_initialized = _has_transformers("4.56.99")
71
+
72
+
73
+ if patch_masking_utils:
74
+ # Introduced in 4.52
75
+ from transformers.masking_utils import (
76
+ _ignore_causal_mask_sdpa,
77
+ and_masks,
78
+ causal_mask_function,
79
+ padding_mask_function,
80
+ prepare_padding_mask,
81
+ )
82
+
83
+ try:
84
+ # transformers>=5.0
85
+ from transformers.masking_utils import (
86
+ _ignore_bidirectional_mask_sdpa,
87
+ bidirectional_mask_function,
88
+ )
89
+ except ImportError:
90
+ _ignore_bidirectional_mask_sdpa = None
91
+ bidirectional_mask_function = None
92
+
93
+ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
94
+ """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
95
+ from ...helpers import string_type
96
+
97
+ dimensions: List[Tuple[Optional[int], ...]] = [
98
+ (None, None, None, 0),
99
+ (None, None, 0, None),
100
+ ]
101
+ if bh_indices:
102
+ dimensions.extend([(None, 0, None, None), (0, None, None, None)])
103
+ # reshape
104
+ dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
105
+ dimensions = tuple(reversed(dimensions))
106
+ indices = tuple(shape.index(-1) for shape in dimensions)
107
+
108
+ # unsqueeze
109
+ udimensions = [
110
+ tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
111
+ ]
112
+
113
+ def vector_mask_function(
114
+ *args, mask_function=mask_function, dimensions=dimensions, indices=indices
115
+ ):
116
+ assert len(args) == len(dimensions) == len(udimensions), (
117
+ f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
118
+ f"and udimensions={udimensions}."
119
+ )
120
+ assert len(indices) == len(args), (
121
+ f"Mismatch between args={string_type(args)} and indices={indices}, "
122
+ f"they should have the same length."
123
+ )
124
+ for a in args:
125
+ assert (
126
+ a.ndim == 1
127
+ ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
128
+ torch._check(a.shape[0] > 0)
129
+
130
+ new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
131
+ # new_args = [
132
+ # a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
133
+ # for a, dims in zip(args, udimensions)
134
+ # ]
135
+ max_shape = tuple(args[i].shape[0] for i in indices)
136
+ # if _is_torchdynamo_exporting():
137
+ # for a in args:
138
+ # # The exporter should export with a dimension > 1
139
+ # # to make sure it is dynamic.
140
+ # torch._check(a.shape[0] > 1)
141
+ expanded_args = [a.expand(max_shape) for a in new_args]
142
+ return mask_function(*expanded_args)
143
+
144
+ return vector_mask_function
145
+
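A standalone sketch (illustration only, not package code) of the reshape-and-expand trick the patch uses instead of ``torch.vmap``: each 1-D index tensor owns one axis of a 4-D grid, and the mask function is evaluated on the broadcast tensors:

.. code-block:: python

    import torch

    def causal(b, h, q_idx, kv_idx):
        # toy mask function with the (batch, head, query, key) signature
        return kv_idx <= q_idx

    batch = torch.arange(2).reshape(-1, 1, 1, 1)
    head = torch.arange(1).reshape(1, -1, 1, 1)
    q_idx = torch.arange(4).reshape(1, 1, -1, 1)
    kv_idx = torch.arange(5).reshape(1, 1, 1, -1)
    target = (2, 1, 4, 5)
    mask = causal(*(t.expand(target) for t in (batch, head, q_idx, kv_idx)))
    print(mask.shape)  # torch.Size([2, 1, 4, 5])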
146
+ def patched_eager_mask(
147
+ batch_size: int,
148
+ cache_position: torch.Tensor,
149
+ kv_length: int,
150
+ kv_offset: int = 0,
151
+ mask_function: Callable = causal_mask_function,
152
+ attention_mask: Optional[torch.Tensor] = None,
153
+ dtype: torch.dtype = torch.float32,
154
+ **kwargs,
155
+ ) -> torch.Tensor:
156
+ """manual patch for function ``transformers.masking_utils.eager_mask``."""
157
+ # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
158
+ _ = kwargs.pop("allow_is_causal_skip", None)
159
+ _ = kwargs.pop("allow_is_bidirectional_skip", None)
160
+ # PATCHED: this line calls the patched version of sdpa_mask
161
+ mask = patched_sdpa_mask_recent_torch(
162
+ batch_size=batch_size,
163
+ cache_position=cache_position,
164
+ kv_length=kv_length,
165
+ kv_offset=kv_offset,
166
+ mask_function=mask_function,
167
+ attention_mask=attention_mask,
168
+ allow_is_causal_skip=False,
169
+ allow_is_bidirectional_skip=False,
170
+ allow_torch_fix=False,
171
+ **kwargs,
172
+ )
173
+ min_dtype = torch.finfo(dtype).min
174
+ # PATCHED: the following line
175
+ # we need 0s where the tokens should be taken into account,
176
+ # and -inf otherwise (mask is already of boolean type)
177
+ # mask =
178
+ # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
179
+ mask = (~mask).to(dtype) * min_dtype
180
+ return mask
181
+
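For reference, a tiny sketch (not from the wheel) of the patched conversion: the boolean keep-mask becomes an additive eager mask with 0 where tokens are kept and the most negative representable value elsewhere:

.. code-block:: python

    import torch

    bool_mask = torch.tensor([[True, True, False, False]])
    dtype = torch.float32
    additive = (~bool_mask).to(dtype) * torch.finfo(dtype).min
    print(additive)  # tensor([[0., 0., -3.4028e+38, -3.4028e+38]])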
182
+ def patched_sdpa_mask_recent_torch(
183
+ batch_size: int,
184
+ cache_position: torch.Tensor,
185
+ kv_length: int,
186
+ kv_offset: int = 0,
187
+ mask_function: Callable = causal_mask_function,
188
+ attention_mask: Optional[torch.Tensor] = None,
189
+ local_size: Optional[int] = None,
190
+ allow_is_causal_skip: bool = True,
191
+ allow_is_bidirectional_skip: bool = False,
192
+ **kwargs,
193
+ ) -> Optional[torch.Tensor]:
194
+ """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
195
+ q_length = cache_position.shape[0]
196
+ padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
197
+ if allow_is_causal_skip and _ignore_causal_mask_sdpa(
198
+ padding_mask, q_length, kv_length, kv_offset, local_size
199
+ ):
200
+ return None
201
+ if (
202
+ allow_is_bidirectional_skip
203
+ and _ignore_bidirectional_mask_sdpa
204
+ and _ignore_bidirectional_mask_sdpa(padding_mask)
205
+ ):
206
+ return None
207
+
208
+ if mask_function is bidirectional_mask_function:
209
+ if padding_mask is not None:
210
+ # used for slicing without data-dependent slicing
211
+ mask_indices = (
212
+ torch.arange(kv_length, device=cache_position.device) + kv_offset
213
+ )
214
+ return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
215
+ return torch.ones(
216
+ batch_size,
217
+ 1,
218
+ q_length,
219
+ kv_length,
220
+ dtype=torch.bool,
221
+ device=cache_position.device,
222
+ )
223
+
224
+ kv_arange = torch.arange(kv_length, device=cache_position.device)
225
+ kv_arange += kv_offset
226
+ if padding_mask is not None:
227
+ mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
228
+ batch_arange = torch.arange(batch_size, device=cache_position.device)
229
+ head_arange = torch.arange(1, device=cache_position.device)
230
+ # PATCHED: this line calls the patched version of vmap_for_bhqkv
231
+ causal_mask = patched__vmap_for_bhqkv(mask_function)(
232
+ batch_arange, head_arange, cache_position, kv_arange
233
+ )
234
+ return causal_mask
235
+
236
+
237
+ if patch_parse_processor_args:
238
+
239
+ def _init_cache_inspect():
240
+ res = {}
241
+ for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
242
+ try:
243
+ params = list(inspect.signature(processor_class.__init__).parameters)[2:]
244
+ res[processor_class.__init__] = params
245
+ except Exception:
246
+ res[processor_class.__init__] = None
247
+ return res
248
+
249
+ _cache_inspect = _init_cache_inspect()
250
+
251
+ def patched_parse_processor_args(
252
+ processor_class: Optional[type["CacheProcessor"]], kwargs: dict # noqa: F821
253
+ ) -> tuple[dict, dict]:
254
+ """[patch:transformers.cache_utils.parse_processor_args]"""
255
+ # If not patched...
256
+ # Fails with transformers>=4.54 because function ``parse_processor_args``
257
+ # relies on inspect, and the exporter is not very fond of that.
258
+ # torch._dynamo.exc.Unsupported: id() with unsupported args
259
+ # Explanation: Dynamo doesn't know how to trace id()
260
+ # call with args
261
+ # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
262
+ # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
263
+ # objects from outside the compiled region.
264
+ # Hint: It may be possible to write Dynamo tracing rules for this code.
265
+ #
266
+ # The patch is caching the signature to avoid any call to inspect.
267
+ if processor_class is None:
268
+ return {}, kwargs
269
+ params = _cache_inspect[processor_class.__init__]
270
+ if params is None:
271
+ return {}, kwargs
272
+ processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
273
+ remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
274
+ return processor_kwargs, remaining_kwargs
275
+
276
+
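A self-contained sketch of the caching idea (the ``_Processor`` class below is hypothetical, not a transformers class): the signature is inspected once at import time, so no ``inspect`` call remains inside the traced region:

.. code-block:: python

    import inspect

    class _Processor:  # hypothetical stand-in for a cache processor class
        def __init__(self, cache, window=None, factor=1.0):
            self.window, self.factor = window, factor

    # computed once, outside any torch.export/compile region
    _PARAMS = list(inspect.signature(_Processor.__init__).parameters)[2:]

    def split_kwargs(kwargs: dict) -> tuple[dict, dict]:
        taken = {k: kwargs[k] for k in _PARAMS if k in kwargs}
        return taken, {k: v for k, v in kwargs.items() if k not in taken}

    print(split_kwargs({"window": 8, "use_cache": True}))
    # ({'window': 8}, {'use_cache': True})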
277
+ if patch_DynamicLayer:
278
+
279
+ class patched_DynamicLayer:
280
+ _PATCHES_ = ["lazy_initialization"]
281
+ _PATCHED_CLASS_ = DynamicLayer
282
+
283
+ def lazy_initialization(self, key_states: torch.Tensor):
284
+ self.dtype, self.device = key_states.dtype, key_states.device
285
+ new_shape = list(key_states.shape)
286
+ new_shape[-2] = 0
287
+ # PATCHED: used a tensor with an empty shape and not an empty list to initialize
288
+ self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
289
+ self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
290
+ if patch_is_initialized:
291
+ self.is_initialized = True
292
+
293
+
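A short sketch (illustration only) of why the patched ``lazy_initialization`` uses a tensor with a zero-length sequence axis: the first ``torch.cat`` then behaves like an append and keeps ranks consistent for the exporter:

.. code-block:: python

    import torch

    key_states = torch.randn(1, 8, 4, 64)   # (batch, heads, seq, head_dim)
    shape = list(key_states.shape)
    shape[-2] = 0                            # empty along the sequence axis
    keys = torch.empty(shape, dtype=key_states.dtype)
    keys = torch.cat([keys, key_states], dim=-2)
    print(keys.shape)  # torch.Size([1, 8, 4, 64])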
294
+ def _patch_make_causal_mask(
295
+ input_ids_shape: torch.Size,
296
+ dtype: torch.dtype,
297
+ device: torch.device,
298
+ past_key_values_length: int = 0,
299
+ sliding_window: Optional[int] = None,
300
+ ):
301
+ """Patched method."""
302
+ bsz, tgt_len = input_ids_shape
303
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
304
+ mask_cond = torch.arange(mask.size(-1), device=device)
305
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
306
+
307
+ mask = mask.to(dtype)
308
+
309
+ if past_key_values_length > 0:
310
+ mask = torch.cat(
311
+ [
312
+ torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
313
+ mask,
314
+ ],
315
+ dim=-1,
316
+ )
317
+
318
+ if sliding_window is not None:
319
+ diagonal = past_key_values_length - sliding_window - 1
320
+
321
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
322
+ # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
323
+ # and used masked_fill instead of masked_fill_
324
+ # In this case, the current implementation of torch fails (17/12/2024).
325
+ # Try model Phi-3.5-Mini-Instruct.
326
+ mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
327
+
328
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
329
+
330
+
331
+ @dataclass
332
+ class patched_AttentionMaskConverter:
333
+ """
334
+ Patches
335
+ ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
336
+ """
337
+
338
+ # This method was fixed in 4.51 at least.
339
+ _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
340
+ _PATCHED_CLASS_ = AttentionMaskConverter
341
+
342
+ @staticmethod
343
+ def _make_causal_mask(
344
+ *args,
345
+ **kwargs,
346
+ # input_ids_shape: torch.Size,
347
+ # dtype: torch.dtype,
348
+ # device: torch.device,
349
+ # past_key_values_length: int = 0,
350
+ # sliding_window: Optional[int] = None,
351
+ ):
352
+ """
353
+ Patched method.
354
+
355
+ This static method may be called with ``AttentionMaskConverter._make_causal_mask``
356
+ or ``self._make_causal_mask``. That changes the arguments it receives.
357
+ That should not matter but...
358
+ The patch should be implemented in another way: static methods do not play well
359
+ with a simple replacement.
360
+ Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
361
+ """
362
+ if args:
363
+ index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
364
+ names = [
365
+ "input_ids_shape",
366
+ "dtype",
367
+ "device",
368
+ "past_key_values_length",
369
+ "sliding_window",
370
+ ]
371
+ for i, a in enumerate(args):
372
+ if i < index:
373
+ continue
374
+ kwargs[names[i - index]] = a
375
+ return _patch_make_causal_mask(**kwargs)
376
+
377
+
378
+ if pv.Version(transformers.__version__) < pv.Version("4.51"):
379
+ from typing import Any, Dict
380
+ from transformers.cache_utils import DynamicCache
381
+
382
+ class patched_DynamicCache:
383
+ """
384
+ Applies modifications implemented in PR
385
+ `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
386
+ """
387
+
388
+ _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
389
+ _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
390
+
391
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
392
+ """Returns the sequence length of the cached states.
393
+ A layer index can be optionally passed."""
394
+ # TODO: deprecate this function in favor of `cache_position`
395
+ is_empty_layer = (
396
+ len(self.key_cache) == 0 # no cache in any layer
397
+ or len(self.key_cache)
398
+ <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
399
+ or self.key_cache[layer_idx].numel() == 0 # the layer has no cache
400
+ )
401
+ layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
402
+ return layer_seq_length
403
+
404
+ def reorder_cache(self, beam_idx: torch.LongTensor):
405
+ """Reorders the cache for beam search, given the selected beam indices."""
406
+ for layer_idx in range(len(self.key_cache)):
407
+ if self.key_cache[layer_idx].numel():
408
+ device = self.key_cache[layer_idx].device
409
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
410
+ 0, beam_idx.to(device)
411
+ )
412
+ if self.value_cache[layer_idx].numel():
413
+ device = self.value_cache[layer_idx].device
414
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
415
+ 0, beam_idx.to(device)
416
+ )
417
+
418
+ def update(
419
+ self,
420
+ key_states: torch.Tensor,
421
+ value_states: torch.Tensor,
422
+ layer_idx: int,
423
+ cache_kwargs: Optional[Dict[str, Any]] = None,
424
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
425
+ """
426
+ Updates the cache with the new `key_states`
427
+ and `value_states` for the layer `layer_idx`.
428
+ Parameters:
429
+ key_states (`torch.Tensor`):
430
+ The new key states to cache.
431
+ value_states (`torch.Tensor`):
432
+ The new value states to cache.
433
+ layer_idx (`int`):
434
+ The index of the layer to cache the states for.
435
+ cache_kwargs (`Dict[str, Any]`, `optional`):
436
+ Additional arguments for the cache subclass.
437
+ No additional arguments are used in `DynamicCache`.
438
+ Return:
439
+ A tuple containing the updated key and value states.
440
+ """
441
+ # Update the number of seen tokens
442
+ if layer_idx == 0:
443
+ if hasattr(self, "_seen_tokens"):
444
+ self._seen_tokens += key_states.shape[-2]
445
+
446
+ # Update the cache
447
+ if key_states is not None:
448
+ if len(self.key_cache) <= layer_idx:
449
+ # There may be skipped layers, fill them with empty lists
450
+ for _ in range(len(self.key_cache), layer_idx):
451
+ self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
452
+ self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
453
+ self.key_cache.append(key_states)
454
+ self.value_cache.append(value_states)
455
+ elif not self.key_cache[
456
+ layer_idx
457
+ ].numel(): # prefers not t.numel() to len(t) == 0 to export the model
458
+ # fills previously skipped layers; checking for tensor causes errors
459
+ self.key_cache[layer_idx] = key_states
460
+ self.value_cache[layer_idx] = value_states
461
+ else:
462
+ torch._check(
463
+ len(self.key_cache[layer_idx].shape) == len(key_states.shape),
464
+ lambda: (
465
+ f"Rank mismatch len(self.key_cache[layer_idx].shape)="
466
+ f"{len(self.key_cache[layer_idx].shape)}, "
467
+ f"len(key_states.shape)={len(key_states.shape)}"
468
+ ),
469
+ )
470
+ self.key_cache[layer_idx] = torch.cat(
471
+ [self.key_cache[layer_idx], key_states], dim=-2
472
+ )
473
+ self.value_cache[layer_idx] = torch.cat(
474
+ [self.value_cache[layer_idx], value_states], dim=-2
475
+ )
476
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
477
+
478
+ def crop(self, max_length: int):
479
+ """Crop the past key values up to a new `max_length`
480
+ in terms of tokens. `max_length` can also be
481
+ negative to remove `max_length` tokens.
482
+ This is used in assisted decoding and contrastive search.
483
+ """
484
+ # In case it is negative
485
+ if max_length < 0:
486
+ max_length = self.get_seq_length() - abs(max_length)
487
+
488
+ if self.get_seq_length() <= max_length:
489
+ return
490
+
491
+ if hasattr(self, "_seen_tokens"):
492
+ self._seen_tokens = max_length
493
+ for idx in range(len(self.key_cache)):
494
+ if self.key_cache[idx].numel():
495
+ self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
496
+ self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
497
+
498
+ @classmethod
499
+ def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
500
+ """This is the opposite of the above `batch_split()` method.
501
+ This will be used by `stack_model_outputs` in
502
+ `generation.utils`"""
503
+ cache = cls()
504
+ for idx in range(len(splits[0])):
505
+ key_cache = [
506
+ current.key_cache[idx]
507
+ for current in splits
508
+ if current.key_cache[idx].numel()
509
+ ]
510
+ value_cache = [
511
+ current.value_cache[idx]
512
+ for current in splits
513
+ if current.value_cache[idx].numel()
514
+ ]
515
+ if key_cache != []:
516
+ layer_keys = torch.cat(key_cache, dim=0)
517
+ layer_values = torch.cat(value_cache, dim=0)
518
+ cache.update(layer_keys, layer_values, idx)
519
+ return cache
520
+
521
+
522
+ class patched_GenerationMixin:
523
+ """
524
+ Applies modifications implemented in PR
525
+ `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
526
+ """
527
+
528
+ _PATCHES_ = [
529
+ "_cache_dependant_input_preparation",
530
+ "_cache_dependant_input_preparation_exporting",
531
+ (
532
+ None
533
+ if pv.Version(transformers.__version__) >= pv.Version("4.56")
534
+ else "prepare_inputs_for_generation"
535
+ ),
536
+ (
537
+ "_sample"
538
+ if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
539
+ else None
540
+ ),
541
+ ]
542
+ _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
543
+
544
+ def _cache_dependant_input_preparation(
545
+ self,
546
+ input_ids: torch.LongTensor,
547
+ inputs_embeds: Optional[torch.FloatTensor],
548
+ cache_position: Optional[torch.LongTensor],
549
+ ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
550
+ """
551
+ Generic cache-dependent input preparation
552
+ The code is put in a separate function to allow granular unit testing
553
+ as it needs a different implementation to be exportable.
554
+
555
+ If we have cache: let's slice `input_ids` through `cache_position`,
556
+ to keep only the unprocessed tokens
557
+ - Exception 1: when passing input_embeds,
558
+ input_ids may be missing entries
559
+ - Exception 2: some generation methods do special slicing of input_ids,
560
+ so we don't need to do it here
561
+ - Exception 3: with synced GPUs cache_position may go out of bounds,
562
+ but we only want dummy token in that case.
563
+ - Exception 4: If input_embeds are passed then slice it through
564
+ `cache_position`, to keep only the unprocessed tokens and
565
+ generate the first token for each sequence.
566
+ Later use the generated Input ids for continuation.
567
+
568
+ The current implementation does not rely on ``self`` and could be
569
+ a class method. It is left as a standard method to be easily rewritten.
570
+ """
571
+ if _is_torchdynamo_exporting():
572
+ return self._cache_dependant_input_preparation_exporting(
573
+ input_ids, inputs_embeds, cache_position
574
+ )
575
+ if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
576
+ inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
577
+ elif inputs_embeds is not None or ( # Exception 1
578
+ cache_position[-1] >= input_ids.shape[1]
579
+ ): # Exception 3
580
+ input_ids = input_ids[:, -cache_position.shape[0] :]
581
+ elif (
582
+ input_ids.shape[1] != cache_position.shape[0]
583
+ ): # Default case (the "else", a no op, is Exception 2)
584
+ input_ids = input_ids[:, cache_position]
585
+ return inputs_embeds, input_ids
586
+
587
+ def _cache_dependant_input_preparation_exporting(
588
+ self,
589
+ input_ids: torch.LongTensor,
590
+ inputs_embeds: Optional[torch.FloatTensor],
591
+ cache_position: Optional[torch.LongTensor],
592
+ ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
593
+ """
594
+ This method implements method ``_cache_dependant_input_preparation``
595
+ with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
596
+ The code is put in a separate function to allow granular unit testing.
597
+ """
598
+ if inputs_embeds is None:
599
+ input_ids = input_ids[:, cache_position]
600
+ else:
601
+ # This is the code we need to implement with torch.cond.
602
+ # if input_ids.shape[1] == 0:
603
+ # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
604
+ # else:
605
+ # if cache_position[-1] >= input_ids.shape[1]:
606
+ # input_ids = input_ids[:, -cache_position.shape[0] :]
607
+ # else:
608
+ # if input_ids.shape[1] != cache_position.shape[0]:
609
+ # input_ids = input_ids[:, cache_position]
610
+ def branch_1(inputs_embeds, cache_position):
611
+ return inputs_embeds[:, -cache_position.shape[0] :]
612
+
613
+ def branch_2(input_ids, cache_position):
614
+ return input_ids[:, -cache_position.shape[0] :]
615
+
616
+ def branch_3(input_ids, cache_position):
617
+ return input_ids[:, cache_position]
618
+
619
+ inputs_embeds, input_ids = torch.cond(
620
+ input_ids.shape[1] == 0,
621
+ (
622
+ lambda input_ids, inputs_embeds, cache_position: (
623
+ branch_1(inputs_embeds, cache_position),
624
+ input_ids,
625
+ )
626
+ ),
627
+ (
628
+ lambda input_ids, inputs_embeds, cache_position: (
629
+ inputs_embeds,
630
+ torch.cond(
631
+ cache_position[-1] >= input_ids.shape[1],
632
+ branch_2,
633
+ lambda input_ids, cache_position: (
634
+ torch.cond(
635
+ input_ids.shape[1] != cache_position.shape[0],
636
+ branch_3,
637
+ (lambda input_ids, cache_position: input_ids),
638
+ [input_ids, cache_position],
639
+ )
640
+ ),
641
+ [input_ids, cache_position],
642
+ ),
643
+ )
644
+ ),
645
+ [input_ids, inputs_embeds, cache_position],
646
+ )
647
+ return inputs_embeds, input_ids
648
+
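A minimal, standalone sketch of the ``torch.cond`` pattern used above (not the library code itself): both branches take the same operands and return tensors, which is the form ``torch.cond`` requires (under export the branch outputs must also have compatible shapes):

.. code-block:: python

    import torch

    def take_tail(x, pos):
        return x[:, -pos.shape[0]:].clone()

    def keep_all(x, pos):
        return x.clone()

    x = torch.arange(12).reshape(1, 12)
    pos = torch.arange(3)
    out = torch.cond(torch.tensor(True), take_tail, keep_all, [x, pos])
    print(out)  # tensor([[ 9, 10, 11]])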
649
+ def prepare_inputs_for_generation(
650
+ self,
651
+ input_ids: torch.LongTensor,
652
+ past_key_values: Optional[Cache] = None,
653
+ attention_mask: Optional[torch.LongTensor] = None,
654
+ inputs_embeds: Optional[torch.FloatTensor] = None,
655
+ cache_position: Optional[torch.LongTensor] = None,
656
+ **kwargs,
657
+ ):
658
+ """
659
+ Prepare the model inputs for generation.
660
+ It includes operations like computing the 4D attention mask or
661
+ slicing inputs given the existing cache.
662
+
663
+ See the forward pass in the model documentation
664
+ for expected arguments (different models might have different
665
+ requirements for e.g. `past_key_values`).
666
+ This function should work as is for most LLMs.
667
+ """
668
+
669
+ # 1. Handle BC:
670
+ model_inputs = {}
671
+ # - some models don't have `Cache` support
672
+ # (which implies they don't expect `cache_position` in `forward`)
673
+ if getattr(self, "_supports_cache_class", False):
674
+ model_inputs["cache_position"] = cache_position
675
+ # - `cache_position` was not a mandatory input in
676
+ # `prepare_inputs_for_generation` for those models, and this
677
+ # function may be called outside of `generate`.
678
+ # Handle most use cases by creating `cache_position` on the fly
679
+ # (this alternative is not as robust as calling
680
+ # `generate` and letting it create `cache_position`)
681
+ elif cache_position is None:
682
+ past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
683
+ cache_position = torch.arange(
684
+ past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device
685
+ )
686
+
687
+ # 2. Generic cache-dependent input preparation
688
+ if past_key_values is not None:
689
+ model_inputs["past_key_values"] = past_key_values
690
+ inputs_embeds, input_ids = self._cache_dependant_input_preparation(
691
+ input_ids, inputs_embeds, cache_position
692
+ )
693
+
694
+ # 3. Prepare base model inputs
695
+ input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
696
+ # if `inputs_embeds` are passed, we only want
697
+ # to use them in the 1st generation step for every prompt.
698
+ if not self.config.is_encoder_decoder:
699
+ if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
700
+ model_inputs[input_ids_key] = None
701
+ model_inputs["inputs_embeds"] = inputs_embeds
702
+ else:
703
+ # `clone` calls in this function ensure a consistent stride. See #32227
704
+ model_inputs[input_ids_key] = input_ids.clone(
705
+ memory_format=torch.contiguous_format
706
+ )
707
+ model_inputs["inputs_embeds"] = None
708
+ else:
709
+ model_inputs[input_ids_key] = input_ids.clone(
710
+ memory_format=torch.contiguous_format
711
+ )
712
+
713
+ # 4. Create missing `position_ids` on the fly
714
+ encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
715
+ attention_mask = (
716
+ kwargs.pop("decoder_attention_mask", None)
717
+ if self.config.is_encoder_decoder
718
+ else attention_mask
719
+ )
720
+ attention_mask_key = (
721
+ "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
722
+ )
723
+ position_ids_key = (
724
+ "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
725
+ )
726
+ if (
727
+ attention_mask is not None
728
+ and kwargs.get(position_ids_key) is None
729
+ and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
730
+ ):
731
+ position_ids = attention_mask.long().cumsum(-1) - 1
732
+ position_ids.masked_fill_(attention_mask == 0, 1)
733
+ kwargs[position_ids_key] = (
734
+ position_ids # placed in kwargs for further processing (see below)
735
+ )
736
+
737
+ # 5. Slice model inputs if it's an input
738
+ # that should have the same length as `input_ids`
739
+ for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
740
+ model_input = kwargs.get(model_input_name)
741
+ if model_input is not None:
742
+ if past_key_values is not None:
743
+ current_input_length = (
744
+ model_inputs["inputs_embeds"].shape[1]
745
+ if model_inputs.get("inputs_embeds") is not None
746
+ else model_inputs[input_ids_key].shape[1]
747
+ )
748
+ model_input = model_input[:, -current_input_length:]
749
+ model_input = model_input.clone(memory_format=torch.contiguous_format)
750
+ model_inputs[model_input_name] = model_input
751
+
752
+ # 6. Create a 4D attention mask if we are using a
753
+ # `StaticCache` (important for performant compiled forward pass)
754
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
755
+ if model_inputs["inputs_embeds"] is not None:
756
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
757
+ device = model_inputs["inputs_embeds"].device
758
+ else:
759
+ batch_size, sequence_length = model_inputs[input_ids_key].shape
760
+ device = model_inputs[input_ids_key].device
761
+
762
+ # Create the causal mask with fixed shape in advance,
763
+ # to reduce recompilations. If the function to create
764
+ # the 4D causal mask exists,
765
+ # it should be present in the base model (XXXModel class).
766
+ base_model = getattr(self, self.base_model_prefix, None)
767
+ if base_model is None:
768
+ causal_mask_creation_function = getattr(
769
+ self, "_prepare_4d_causal_attention_mask_with_cache_position", None
770
+ )
771
+ else:
772
+ causal_mask_creation_function = getattr(
773
+ base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
774
+ )
775
+ if causal_mask_creation_function is None:
776
+ pass
777
+ # logger.warning_once(
778
+ # f"{self.__class__.__name__} has no "
779
+ # "`_prepare_4d_causal_attention_mask_with_cache_position` method "
780
+ # "defined in its base modeling class. "
781
+ # "Compiled forward passes will be sub-optimal. If you're "
782
+ # "writing code, see Llama for an example implementation. "
783
+ # "If you're a user, please report this "
784
+ # "issue on GitHub."
785
+ # )
786
+ else:
787
+ attention_mask = causal_mask_creation_function(
788
+ attention_mask,
789
+ sequence_length=sequence_length,
790
+ target_length=past_key_values.get_max_cache_shape(),
791
+ dtype=self.dtype,
792
+ device=device,
793
+ cache_position=cache_position,
794
+ batch_size=batch_size,
795
+ config=self.config,
796
+ past_key_values=past_key_values,
797
+ )
798
+ if attention_mask is not None:
799
+ model_inputs[attention_mask_key] = attention_mask
800
+
801
+ if encoder_attention_mask is not None:
802
+ model_inputs["attention_mask"] = encoder_attention_mask
803
+
804
+ # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
805
+ for key, value in kwargs.items():
806
+ if key not in model_inputs:
807
+ model_inputs[key] = value
808
+
809
+ # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
810
+ model_inputs.pop("labels", None)
811
+ return model_inputs
812
+
813
+ def _sample(
814
+ self,
815
+ input_ids: torch.LongTensor,
816
+ logits_processor: "LogitsProcessorList", # noqa: F821
817
+ stopping_criteria: "StoppingCriteriaList", # noqa: F821
818
+ generation_config: "GenerationConfig", # noqa: F821
819
+ synced_gpus: bool = False,
820
+ streamer: Optional["BaseStreamer"] = None, # noqa: F821
821
+ **model_kwargs,
822
+ ) -> Union["GenerateNonBeamOutput", torch.LongTensor]: # noqa: F821
823
+ """
824
+ 2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
825
+ """
826
+ # init values
827
+ pad_token_id = generation_config._pad_token_tensor
828
+ output_attentions = generation_config.output_attentions
829
+ output_hidden_states = generation_config.output_hidden_states
830
+ output_scores = generation_config.output_scores
831
+ output_logits = generation_config.output_logits
832
+ return_dict_in_generate = generation_config.return_dict_in_generate
833
+ has_eos_stopping_criteria = any(
834
+ hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
835
+ )
836
+ do_sample = generation_config.do_sample
837
+
838
+ # init attention / hidden states / scores tuples
839
+ scores = () if (return_dict_in_generate and output_scores) else None
840
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
841
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
842
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
843
+ decoder_hidden_states = (
844
+ () if (return_dict_in_generate and output_hidden_states) else None
845
+ )
846
+
847
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
848
+ if return_dict_in_generate and self.config.is_encoder_decoder:
849
+ encoder_attentions = (
850
+ model_kwargs["encoder_outputs"].get("attentions")
851
+ if output_attentions
852
+ else None
853
+ )
854
+ encoder_hidden_states = (
855
+ model_kwargs["encoder_outputs"].get("hidden_states")
856
+ if output_hidden_states
857
+ else None
858
+ )
859
+
860
+ # keep track of which sequences are already finished
861
+ batch_size, cur_len = input_ids.shape[:2]
862
+ this_peer_finished = False
863
+ unfinished_sequences = torch.ones(
864
+ batch_size, dtype=torch.long, device=input_ids.device
865
+ )
866
+ model_kwargs = self._get_initial_cache_position(
867
+ cur_len, input_ids.device, model_kwargs
868
+ )
869
+
870
+ model_forward = self.__call__
871
+ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
872
+ if compile_forward:
873
+ os.environ["TOKENIZERS_PARALLELISM"] = "0"
874
+ # If we use FA2 and a static cache, we cannot compile with fullgraph
875
+ if self.config._attn_implementation == "flash_attention_2":
876
+ # only raise warning if the user passed an explicit compile-config
877
+ if (
878
+ generation_config.compile_config is not None
879
+ and generation_config.compile_config.fullgraph
880
+ ):
881
+ generation_config.compile_config.fullgraph = False
882
+ model_forward = self.get_compiled_call(generation_config.compile_config)
883
+
884
+ if generation_config.prefill_chunk_size is not None:
885
+ model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
886
+ is_prefill = False
887
+ else:
888
+ is_prefill = True
889
+
890
+ while self._has_unfinished_sequences(
891
+ this_peer_finished, synced_gpus, device=input_ids.device
892
+ ):
893
+ # prepare model inputs
894
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
895
+
896
+ if is_prefill:
897
+ outputs = self(**model_inputs, return_dict=True)
898
+ is_prefill = False
899
+ else:
900
+ outputs = model_forward(**model_inputs, return_dict=True)
901
+
902
+ model_kwargs = self._update_model_kwargs_for_generation(
903
+ outputs,
904
+ model_kwargs,
905
+ is_encoder_decoder=self.config.is_encoder_decoder,
906
+ )
907
+ if synced_gpus and this_peer_finished:
908
+ continue
909
+
910
+ next_token_logits = outputs.logits[:, -1, :].to(
911
+ copy=True, dtype=torch.float32, device=input_ids.device
912
+ )
913
+
914
+ # pre-process distribution
915
+ next_token_scores = logits_processor(input_ids, next_token_logits)
916
+
917
+ # Store scores, attentions and hidden_states when required
918
+ if return_dict_in_generate:
919
+ if output_scores:
920
+ scores += (next_token_scores,)
921
+ if output_logits:
922
+ raw_logits += (next_token_logits,)
923
+ if output_attentions:
924
+ decoder_attentions += (
925
+ (outputs.decoder_attentions,)
926
+ if self.config.is_encoder_decoder
927
+ else (outputs.attentions,)
928
+ )
929
+ if self.config.is_encoder_decoder:
930
+ cross_attentions += (outputs.cross_attentions,)
931
+
932
+ if output_hidden_states:
933
+ decoder_hidden_states += (
934
+ (outputs.decoder_hidden_states,)
935
+ if self.config.is_encoder_decoder
936
+ else (outputs.hidden_states,)
937
+ )
938
+
939
+ # token selection
940
+ if do_sample:
941
+ probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
942
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
943
+ else:
944
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
945
+
946
+ # finished sentences should have their next token be a padding token
947
+ if has_eos_stopping_criteria:
948
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
949
+ 1 - unfinished_sequences
950
+ )
951
+
952
+ # update generated ids, model inputs, and length for next step
953
+ # PATCHED: the following two lines; next_tokens can be 2D already for this model
954
+ next_tokens_2d = (
955
+ next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
956
+ )
957
+ input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
958
+ if streamer is not None:
959
+ streamer.put(next_tokens.cpu())
960
+
961
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
962
+ this_peer_finished = unfinished_sequences.max() == 0
963
+ cur_len += 1
964
+
965
+ # This is needed to properly delete outputs.logits which may be very large
966
+ # for the first iteration.
967
+ # Otherwise a reference to outputs is kept which keeps
968
+ # the logits alive in the next iteration
969
+ del outputs
970
+
971
+ if streamer is not None:
972
+ streamer.end()
973
+
974
+ if return_dict_in_generate:
975
+ if self.config.is_encoder_decoder:
976
+ return transformers.generation.utils.GenerateEncoderDecoderOutput(
977
+ sequences=input_ids,
978
+ scores=scores,
979
+ logits=raw_logits,
980
+ encoder_attentions=encoder_attentions,
981
+ encoder_hidden_states=encoder_hidden_states,
982
+ decoder_attentions=decoder_attentions,
983
+ cross_attentions=cross_attentions,
984
+ decoder_hidden_states=decoder_hidden_states,
985
+ past_key_values=model_kwargs.get("past_key_values"),
986
+ )
987
+ else:
988
+ return transformers.generation.utils.GenerateDecoderOnlyOutput(
989
+ sequences=input_ids,
990
+ scores=scores,
991
+ logits=raw_logits,
992
+ attentions=decoder_attentions,
993
+ hidden_states=decoder_hidden_states,
994
+ past_key_values=model_kwargs.get("past_key_values"),
995
+ )
996
+ else:
997
+ return input_ids
998
+
999
+
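A toy sketch (not package code) of the patched token-append step: some models already produce 2-D ``next_tokens``, so the extra dimension is only added when it is missing:

.. code-block:: python

    import torch

    input_ids = torch.tensor([[1, 2, 3]])
    for next_tokens in (torch.tensor([4]), torch.tensor([[5]])):
        next_tokens_2d = (
            next_tokens if next_tokens.dim() == 2 else next_tokens[:, None]
        )
        input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
    print(input_ids)  # tensor([[1, 2, 3, 4, 5]])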
1000
+ def patched__compute_dynamic_ntk_parameters(
1001
+ config: Optional[transformers.PretrainedConfig] = None,
1002
+ device: Optional["torch.device"] = None,
1003
+ seq_len: Optional[int] = None,
1004
+ **rope_kwargs,
1005
+ ) -> Tuple["torch.Tensor", float]:
1006
+ """
1007
+ manual patch:
1008
+ ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
1009
+
1010
+ Computes the inverse frequencies with NTK scaling.
1011
+ Credits to the Reddit users /u/bloc97 and /u/emozilla
1012
+
1013
+ Args:
1014
+ config ([`~transformers.PretrainedConfig`]):
1015
+ The model configuration.
1016
+ device (`torch.device`):
1017
+ The device to use for initialization of the inverse frequencies.
1018
+ seq_len (`int`, *optional*):
1019
+ The current sequence length,
1020
+ used to update the dynamic RoPE at inference time.
1021
+ rope_kwargs (`Dict`, *optional*):
1022
+ BC compatibility with the previous
1023
+ RoPE class instantiation, will be removed in v4.45.
1024
+
1025
+ Returns:
1026
+ Tuple of (`torch.Tensor`, `float`),
1027
+ containing the inverse frequencies for the RoPE embeddings and the
1028
+ post-processing scaling factor applied to the
1029
+ computed cos/sin (unused in this type of RoPE).
1030
+ """
1031
+ if config is not None and len(rope_kwargs) > 0:
1032
+ raise ValueError(
1033
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
1034
+ f"`_compute_dynamic_ntk_parameters`, got "
1035
+ f"`rope_kwargs`={rope_kwargs} and `config`={config}"
1036
+ )
1037
+ if len(rope_kwargs) > 0:
1038
+ base = rope_kwargs["base"]
1039
+ dim = rope_kwargs["dim"]
1040
+ max_position_embeddings = rope_kwargs["max_position_embeddings"]
1041
+ factor = rope_kwargs["factor"]
1042
+ elif config is not None:
1043
+ base = config.rope_theta
1044
+ partial_rotary_factor = (
1045
+ config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
1046
+ )
1047
+ head_dim = getattr(
1048
+ config, "head_dim", config.hidden_size // config.num_attention_heads
1049
+ )
1050
+ dim = int(head_dim * partial_rotary_factor)
1051
+ max_position_embeddings = config.max_position_embeddings
1052
+ factor = config.rope_scaling["factor"]
1053
+
1054
+ attention_factor = 1.0 # Unused in this type of RoPE
1055
+
1056
+ # seq_len: default to max_position_embeddings, e.g. at init time
1057
+ # seq_len = seq_len if seq_len is not None and
1058
+ # seq_len > max_position_embeddings else max_position_embeddings
1059
+ if seq_len is None:
1060
+ seq_len = max_position_embeddings
1061
+ else:
1062
+ # PATCHED: remove the line using max
1063
+ torch._check(isinstance(seq_len, torch.Tensor))
1064
+ seq_len = torch.maximum(
1065
+ seq_len,
1066
+ torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
1067
+ )
1068
+
1069
+ # Compute the inverse frequencies
1070
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
1071
+ dim / (dim - 2)
1072
+ )
1073
+ inv_freq = 1.0 / (
1074
+ base
1075
+ ** (
1076
+ torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
1077
+ / dim
1078
+ )
1079
+ )
1080
+ return inv_freq, attention_factor
1081
+
1082
+
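A small sketch (illustrative only) of the line the patch changes: the Python built-in ``max`` is replaced by ``torch.maximum`` so a traced ``seq_len`` tensor never has to be converted to a Python int:

.. code-block:: python

    import torch

    seq_len = torch.tensor(4096)             # a tensor at export time
    max_position_embeddings = 2048
    seq_len = torch.maximum(
        seq_len,
        torch.tensor(max_position_embeddings, dtype=seq_len.dtype),
    )
    print(seq_len)  # tensor(4096)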
1083
+ def _get_rope_init_fn(self, layer_type=None) -> Callable:
1084
+ if hasattr(self, "rope_init_fn"):
1085
+ # transformers<=5.0
1086
+ rope_init_fn = (
1087
+ patched__compute_dynamic_ntk_parameters
1088
+ if self.rope_init_fn
1089
+ is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
1090
+ else self.rope_init_fn
1091
+ )
1092
+ return rope_init_fn
1093
+
1094
+ rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
1095
+ rope_init_fn = self.compute_default_rope_parameters
1096
+ if rope_type != "default":
1097
+ rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
1098
+ if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
1099
+ return patched__compute_dynamic_ntk_parameters
1100
+ return rope_init_fn
1101
+
1102
+
1103
+ def patched_dynamic_rope_update(rope_forward):
1104
+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
1105
+
1106
+ ``rope_type`` is determined in the constructor of class
1107
+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
1108
+
1109
+ .. code-block:: python
1110
+
1111
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
1112
+ self.rope_type = config.rope_scaling.get(
1113
+ "rope_type", config.rope_scaling.get("type"))
1114
+ else:
1115
+ self.rope_type = "default"
1116
+
1117
+ The original code of the patched function:
1118
+
1119
+ .. code-block:: python
1120
+
1121
+ def dynamic_rope_update(rope_forward):
1122
+ def longrope_frequency_update(self, position_ids, device):
1123
+ seq_len = torch.max(position_ids) + 1
1124
+ if hasattr(self.config, "original_max_position_embeddings"):
1125
+ original_max_position_embeddings =
1126
+ self.config.original_max_position_embeddings
1127
+ else:
1128
+ original_max_position_embeddings =
1129
+ self.config.max_position_embeddings
1130
+ if seq_len > original_max_position_embeddings:
1131
+ if not hasattr(self, "long_inv_freq"):
1132
+ self.long_inv_freq, _ = self.rope_init_fn(
1133
+ self.config, device, seq_len=original_max_position_embeddings + 1
1134
+ )
1135
+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
1136
+ else:
1137
+ self.original_inv_freq = self.original_inv_freq.to(device)
1138
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1139
+
1140
+ def dynamic_frequency_update(self, position_ids, device):
1141
+ seq_len = torch.max(position_ids) + 1
1142
+ if seq_len > self.max_seq_len_cached: # growth
1143
+ inv_freq, self.attention_scaling = self.rope_init_fn(
1144
+ self.config, device, seq_len=seq_len)
1145
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1146
+ self.max_seq_len_cached = seq_len
1147
+
1148
+ if seq_len < self.original_max_seq_len and
1149
+ self.max_seq_len_cached > self.original_max_seq_len:
1150
+ self.original_inv_freq = self.original_inv_freq.to(device)
1151
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1152
+ self.max_seq_len_cached = self.original_max_seq_len
1153
+
1154
+ @wraps(rope_forward)
1155
+ def wrapper(self, x, position_ids):
1156
+ if "dynamic" in self.rope_type:
1157
+ dynamic_frequency_update(self, position_ids, device=x.device)
1158
+ elif self.rope_type == "longrope":
1159
+ longrope_frequency_update(self, position_ids, device=x.device)
1160
+ return rope_forward(self, x, position_ids)
1161
+
1162
+ return wrapper
1163
+
1164
+ """
1165
+
1166
+ def longrope_frequency_update(self, position_ids, device, layer_type=None):
1167
+ # There is no point patching the function after the model is created,
1168
+ # because rope_init_fn is an attribute bound to a function when the model
1169
+ # is created, before any patch is applied.
1170
+ # So we select the patched version here.
1171
+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
1172
+ seq_len = torch.max(position_ids) + 1
1173
+ if hasattr(self.config, "original_max_position_embeddings"):
1174
+ original_max_position_embeddings = self.config.original_max_position_embeddings
1175
+ else:
1176
+ original_max_position_embeddings = self.config.max_position_embeddings
1177
+
1178
+ if layer_type is None:
1179
+ # rope_type = self.rope_type
1180
+ original_inv_freq = self.original_inv_freq
1181
+ prefix = ""
1182
+ else:
1183
+ # rope_type = self.rope_type[layer_type]
1184
+ original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
1185
+ prefix = f"{layer_type}_"
1186
+
1187
+ # At export time, seq_len is unknown.
1188
+ long_inv_freq, _ = rope_init_fn(
1189
+ self.config, device, seq_len=original_max_position_embeddings + 1
1190
+ )
1191
+ original_inv_freq = self.original_inv_freq.to(device)
1192
+
1193
+ # PATCHED: uses torch.cond instead of a test
1194
+ cond = (seq_len > original_max_position_embeddings).item()
1195
+ inv_freq = torch.cond(
1196
+ cond,
1197
+ (lambda x, y: x.clone()),
1198
+ (lambda x, y: y.clone()),
1199
+ [long_inv_freq, original_inv_freq],
1200
+ )
1201
+ setattr(self, f"{prefix}inv_freq", inv_freq)
1202
+ # if seq_len > original_max_position_embeddings:
1203
+ # self.inv_freq = self.long_inv_freq
1204
+ # else:
1205
+ # self.inv_freq = self.original_inv_freq
1206
+
1207
+ def dynamic_frequency_update(self, position_ids, device, layer_type=None):
1208
+ # constructor:
1209
+ # - self.max_seq_len_cached = config.max_position_embeddings
1210
+ # - self.original_max_seq_len = config.max_position_embeddings
1211
+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
1212
+
1213
+ # There is no point patching the function after the model is created,
1214
+ # because rope_init_fn is an attribute bound to a function when the model
1215
+ # is created, before any patch is applied.
1216
+ # So we select the patched version here.
1217
+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
1218
+
1219
+ # This behaviour is difficult to translate.
1220
+ # The sequence always grows.
1221
+ # The test should always be True.
1222
+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
1223
+ #
1224
+ # if seq_len > self.max_seq_len_cached: # growth
1225
+ # inv_freq, self.attention_scaling = self.rope_init_fn(
1226
+ # self.config, device, seq_len=seq_len
1227
+ # )
1228
+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
1229
+ # self.max_seq_len_cached = seq_len
1230
+ #
1231
+ # So we should not need what follows.
1232
+ #
1233
+ # cond = (seq_len > self.max_seq_len_cached).item()
1234
+ # self.attention_scaling = torch.cond(
1235
+ # cond,
1236
+ # (lambda x, y: x.clone()),
1237
+ # (lambda x, y: y.clone()),
1238
+ # [attention_scaling, self.attention_scaling],
1239
+ # )
1240
+
1241
+ seq_len = torch.max(position_ids) + 1
1242
+ long_inv_freq, self.attention_scaling = rope_init_fn(
1243
+ self.config, device, seq_len=seq_len
1244
+ )
1245
+
1246
+ if layer_type is None:
1247
+ # rope_type = self.rope_type
1248
+ # max_seq_len_cached = self.max_seq_len_cached
1249
+ original_inv_freq = self.original_inv_freq
1250
+ prefix = ""
1251
+ else:
1252
+ # rope_type = self.rope_type[layer_type]
1253
+ # max_seq_len_cached = getattr(
1254
+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
1255
+ # )
1256
+ original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
1257
+ prefix = f"{layer_type}_"
1258
+
1259
+ # Second test to translate.
1260
+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
1261
+ # But in that case the following condition is a way to restore the original cache.
1262
+
1263
+ # if (
1264
+ # seq_len < self.original_max_seq_len
1265
+ # and self.max_seq_len_cached > self.original_max_seq_len
1266
+ # ):
1267
+ # self.original_inv_freq = self.original_inv_freq.to(device)
1268
+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
1269
+ # self.max_seq_len_cached = self.original_max_seq_len
1270
+
1271
+ original_inv_freq = original_inv_freq.to(device)
1272
+ cond = (seq_len >= self.original_max_seq_len).item()
1273
+ # PATCHED: uses torch.cond instead of a test
1274
+ inv_freq = torch.cond(
1275
+ cond,
1276
+ (lambda x, y: x.clone()),
1277
+ (lambda x, y: y.clone()),
1278
+ [long_inv_freq, original_inv_freq],
1279
+ )
1280
+ setattr(self, f"{prefix}inv_freq", inv_freq)
1281
+
1282
+ @wraps(rope_forward)
1283
+ def wrapper(self, x, position_ids, layer_type=None):
1284
+ if layer_type is None:
1285
+ if "dynamic" in self.rope_type:
1286
+ dynamic_frequency_update(self, position_ids, device=x.device)
1287
+ elif self.rope_type == "longrope":
1288
+ longrope_frequency_update(self, position_ids, device=x.device)
1289
+ return rope_forward(self, x, position_ids)
1290
+
1291
+ if "dynamic" in self.rope_type:
1292
+ dynamic_frequency_update(
1293
+ self, position_ids, device=x.device, layer_type=layer_type
1294
+ )
1295
+ elif self.rope_type == "longrope":
1296
+ longrope_frequency_update(
1297
+ self, position_ids, device=x.device, layer_type=layer_type
1298
+ )
1299
+ return rope_forward(self, x, position_ids, layer_type=layer_type)
1300
+
1301
+ return wrapper
1302
+
1303
+
1304
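A minimal sketch of the torch.cond rewrite the decorator above relies on, with a hard-coded 4096 threshold and toy tensors (not the package's API): both branches are traced, so the choice between the long-context and the original frequencies stays in the exported graph.

import torch

def select_inv_freq(seq_len: torch.Tensor, long_freq: torch.Tensor, short_freq: torch.Tensor) -> torch.Tensor:
    # Both branches must return tensors with matching metadata; .clone() avoids
    # returning a view of an operand.
    return torch.cond(
        seq_len > 4096,               # boolean tensor predicate
        lambda a, b: a.clone(),       # long-context frequencies
        lambda a, b: b.clone(),       # original frequencies
        [long_freq, short_freq],
    )

print(select_inv_freq(torch.tensor(5000), torch.rand(64), torch.rand(64)).shape)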
+ def common_eager_attention_forward(
1305
+ module: torch.nn.Module,
1306
+ query: torch.Tensor,
1307
+ key: torch.Tensor,
1308
+ value: torch.Tensor,
1309
+ attention_mask: Optional[torch.Tensor],
1310
+ scaling: Optional[float] = None,
1311
+ dropout: float = 0.0,
1312
+ head_mask: Optional[torch.Tensor] = None,
1313
+ **kwargs,
1314
+ ):
1315
+ if scaling is None:
1316
+ scaling = query.size(-1) ** -0.5
1317
+
1318
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
1319
+ if attention_mask is not None:
1320
+ # PATCHED
1321
+ # The two following lines were added.
1322
+ if attention_mask.ndim == 4:
1323
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
1324
+ attn_weights = attn_weights + attention_mask
1325
+
1326
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
1327
+
1328
+ if head_mask is not None:
1329
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
1330
+
1331
+ attn_weights = torch.nn.functional.dropout(
1332
+ attn_weights, p=dropout, training=module.training
1333
+ )
1334
+ attn_output = torch.matmul(attn_weights, value)
1335
+ attn_output = attn_output.transpose(1, 2).contiguous()
1336
+
1337
+ return attn_output, attn_weights
1338
+
1339
+
1340
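A small sketch of the mask-trimming lines added in common_eager_attention_forward above, with toy shapes (the 4D additive mask and key tensor are assumptions): the cached mask may cover more positions than the current key, so it is sliced to the key length.

import torch

mask = torch.zeros(1, 1, 2, 16)          # (batch, heads, q_len, max_cache_len)
key = torch.rand(1, 4, 10, 8)            # (batch, heads, kv_len, head_dim)
mask = mask[:, :, :, : key.shape[-2]]    # trim to the actual key length
print(mask.shape)                        # torch.Size([1, 1, 2, 10])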
+ def patched_sdpa_attention_forward(
1341
+ module: torch.nn.Module,
1342
+ query: torch.Tensor,
1343
+ key: torch.Tensor,
1344
+ value: torch.Tensor,
1345
+ attention_mask: Optional[torch.Tensor],
1346
+ dropout: float = 0.0,
1347
+ scaling: Optional[float] = None,
1348
+ is_causal: Optional[bool] = None,
1349
+ **kwargs,
1350
+ ) -> tuple[torch.Tensor, None]:
1351
+ """
1352
+ Manual patch for the function
1353
+ ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
1354
+ """
1355
+ assert not kwargs.get("output_attentions", False), (
1356
+ "`sdpa` attention does not support `output_attentions=True`."
1357
+ " Please set your attention to `eager` if you want any of these features."
1358
+ )
1359
+ torch._check(
1360
+ query.shape[0] == key.shape[0] or query.shape[0] == 1,
1361
+ lambda: (
1362
+ f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
1363
+ f"value: {value.shape}"
1364
+ ),
1365
+ )
1366
+ torch._check(
1367
+ key.shape[0] == value.shape[0] or key.shape[0] == 1,
1368
+ lambda: (
1369
+ f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
1370
+ f"value: {value.shape}"
1371
+ ),
1372
+ )
1373
+
1374
+ sdpa_kwargs = {}
1375
+ if hasattr(module, "num_key_value_groups"):
1376
+ if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
1377
+ key = transformers.integrations.sdpa_attention.repeat_kv(
1378
+ key, module.num_key_value_groups
1379
+ )
1380
+ value = transformers.integrations.sdpa_attention.repeat_kv(
1381
+ value, module.num_key_value_groups
1382
+ )
1383
+ else:
1384
+ sdpa_kwargs = {"enable_gqa": True}
1385
+
1386
+ if attention_mask is not None and attention_mask.ndim == 4:
1387
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
1388
+
1389
+ torch._check(
1390
+ attention_mask is None or attention_mask.shape[3] == key.shape[2],
1391
+ lambda: "Attention mask shape incompatible with key shape.",
1392
+ )
1393
+
1394
+ if patch_sdpa_is_causal:
1395
+ # transformers>=4.55
1396
+ is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
1397
+
1398
+ # PATCHED: remove the test query.shape[2] > 1
1399
+ # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
1400
+ # and split the test so that only the data-dependent part goes through torch.cond
1401
+ is_causal = attention_mask is None and is_causal
1402
+
1403
+ if not is_causal:
1404
+ return (
1405
+ torch.nn.functional.scaled_dot_product_attention(
1406
+ query,
1407
+ key,
1408
+ value,
1409
+ attn_mask=attention_mask,
1410
+ dropout_p=dropout,
1411
+ scale=scaling,
1412
+ is_causal=is_causal,
1413
+ **sdpa_kwargs,
1414
+ )
1415
+ .transpose(1, 2)
1416
+ .contiguous(),
1417
+ None,
1418
+ )
1419
+ else:
1420
+ # transformers<4.55
1421
+ if is_causal is None and attention_mask is not None:
1422
+ is_causal = False
1423
+ if is_causal is not None:
1424
+ return (
1425
+ torch.nn.functional.scaled_dot_product_attention(
1426
+ query,
1427
+ key,
1428
+ value,
1429
+ attn_mask=attention_mask,
1430
+ dropout_p=dropout,
1431
+ scale=scaling,
1432
+ is_causal=is_causal,
1433
+ **sdpa_kwargs,
1434
+ )
1435
+ .transpose(1, 2)
1436
+ .contiguous(),
1437
+ None,
1438
+ )
1439
+
1440
+ # To avoid the following errors:
1441
+ # is_causal=query.shape[2] > 1
1442
+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
1443
+ # is_causal=torch.tensor(query.shape[2] > 1)
1444
+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
1445
+ attn_output = torch.cond(
1446
+ query.shape[2] > 1, # distinction between prefill and decoding steps
1447
+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
1448
+ query,
1449
+ key,
1450
+ value,
1451
+ dropout_p=dropout,
1452
+ scale=scaling,
1453
+ is_causal=True,
1454
+ **sdpa_kwargs,
1455
+ ),
1456
+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
1457
+ query,
1458
+ key,
1459
+ value,
1460
+ dropout_p=dropout,
1461
+ scale=scaling,
1462
+ is_causal=False,
1463
+ **sdpa_kwargs,
1464
+ ),
1465
+ [query, key, value],
1466
+ )
1467
+ attn_output = attn_output.transpose(1, 2).contiguous()
1468
+ return attn_output, None
1469
+
1470
+
1471
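A self-contained sketch of the prefill/decode split implemented above with torch.cond, using toy tensors: scaled_dot_product_attention only accepts a Python bool for is_causal, so the data-dependent q_len > 1 test is expressed as two traced branches.

import torch
import torch.nn.functional as F

def causal_or_not(q, k, v):
    return torch.cond(
        q.shape[2] > 1,  # prefill (several query positions) vs decode (one position)
        lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True),
        lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=False),
        [q, k, v],
    )

q, k, v = (torch.rand(1, 2, 4, 8) for _ in range(3))
print(causal_or_not(q, k, v).shape)      # torch.Size([1, 2, 4, 8])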
+ def patched_model_bart_eager_attention_forward(
1472
+ module: torch.nn.Module,
1473
+ query: torch.Tensor,
1474
+ key: torch.Tensor,
1475
+ value: torch.Tensor,
1476
+ attention_mask: Optional[torch.Tensor],
1477
+ scaling: Optional[float] = None,
1478
+ dropout: float = 0.0,
1479
+ head_mask: Optional[torch.Tensor] = None,
1480
+ **kwargs,
1481
+ ):
1482
+ """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
1483
+ return common_eager_attention_forward(
1484
+ module,
1485
+ query,
1486
+ key,
1487
+ value,
1488
+ attention_mask=attention_mask,
1489
+ scaling=scaling,
1490
+ dropout=dropout,
1491
+ head_mask=head_mask,
1492
+ **kwargs,
1493
+ )
1494
+
1495
+
1496
+ def patched_modeling_marian_eager_attention_forward(
1497
+ module: torch.nn.Module,
1498
+ query: torch.Tensor,
1499
+ key: torch.Tensor,
1500
+ value: torch.Tensor,
1501
+ attention_mask: Optional[torch.Tensor],
1502
+ scaling: Optional[float] = None,
1503
+ dropout: float = 0.0,
1504
+ head_mask: Optional[torch.Tensor] = None,
1505
+ **kwargs,
1506
+ ):
1507
+ """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
1508
+ return common_eager_attention_forward(
1509
+ module,
1510
+ query,
1511
+ key,
1512
+ value,
1513
+ attention_mask=attention_mask,
1514
+ scaling=scaling,
1515
+ dropout=dropout,
1516
+ head_mask=head_mask,
1517
+ **kwargs,
1518
+ )
1519
+
1520
+
1521
+ class common_RotaryEmbedding(torch.nn.Module):
1522
+ # PATCHED: the @torch.no_grad() decorator is replaced by the one below,
1523
+ # which may cause some issues.
1524
+ # @torch.no_grad()
1525
+ @patched_dynamic_rope_update
1526
+ def forward(self, x, position_ids, layer_type=None):
1527
+ if layer_type is not None:
1528
+ # transformers>=5.0
1529
+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
1530
+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
1531
+ else:
1532
+ # transformers<5.0
1533
+ inv_freq = self.inv_freq
1534
+ attention_scaling = self.attention_scaling
1535
+
1536
+ inv_freq_expanded = (
1537
+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
1538
+ )
1539
+ position_ids_expanded = position_ids[:, None, :].float()
1540
+
1541
+ device_type = (
1542
+ x.device.type
1543
+ if isinstance(x.device.type, str) and x.device.type != "mps"
1544
+ else "cpu"
1545
+ )
1546
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
1547
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1548
+ emb = torch.cat((freqs, freqs), dim=-1)
1549
+ cos = emb.cos() * attention_scaling
1550
+ sin = emb.sin() * attention_scaling
1551
+
1552
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1553
+
1554
+
1555
+ class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
1556
+ _PATCHES_ = ["forward"]
1557
+ _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
1558
+
1559
+
1560
+ if pv.Version(transformers.__version__) >= pv.Version("4.52"):
1561
+
1562
+ class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
1563
+ _PATCHES_ = ["forward"]
1564
+ _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
1565
+
1566
+ class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
1567
+ _PATCHES_ = ["forward"]
1568
+ _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
1569
+
1570
+
1571
+ class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
1572
+ _PATCHES_ = ["forward"]
1573
+ _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
1574
+
1575
+
1576
+ class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
1577
+ _PATCHES_ = ["forward"]
1578
+ _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
1579
+
1580
+
1581
+ class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
1582
+ _PATCHES_ = ["forward"]
1583
+ _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
1584
+
1585
+
1586
+ class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
1587
+ _PATCHES_ = ["forward"]
1588
+ _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
1589
+
1590
+
1591
+ if pv.Version(transformers.__version__) >= pv.Version("4.51"):
1592
+
1593
+ class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
1594
+ _PATCHES_ = ["forward"]
1595
+ _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
1596
+
1597
+
1598
+ if pv.Version(transformers.__version__) >= pv.Version("4.52"):
1599
+
1600
+ class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
1601
+ _PATCHES_ = ["forward"]
1602
+ _PATCHED_CLASS_ = (
1603
+ transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
1604
+ )
1605
+
1606
+
1607
+ if pv.Version(transformers.__version__) >= pv.Version("4.53"):
1608
+
1609
+ class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
1610
+ _PATCHES_ = ["forward"]
1611
+ _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
1612
+
1613
+
1614
+ class patched_IdeficsEmbedding(torch.nn.Module):
1615
+ _PATCHES_ = ["forward"]
1616
+ _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
1617
+
1618
+ def forward(self, x, seq_len=None):
1619
+ # x: [bs, num_attention_heads, seq_len, head_size]
1620
+ # if seq_len > self.max_seq_len_cached:
1621
+ # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
1622
+
1623
+ def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
1624
+ t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
1625
+ # freqs = torch.einsum("i,j->ij", t, inv_freq)
1626
+ freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
1627
+ emb = torch.cat((freqs, freqs), dim=-1)
1628
+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
1629
+
1630
+ def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
1631
+ torch._check(seq_len.item() <= cos_cached.shape[0])
1632
+ co = cos_cached[: seq_len.item()].detach().clone()
1633
+ torch._check(seq_len.item() <= sin_cached.shape[0])
1634
+ si = sin_cached[: seq_len.item()].detach().clone()
1635
+ return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
1636
+
1637
+ cos_cached, sin_cached = torch.cond(
1638
+ (seq_len > self.max_seq_len_cached).item(),
1639
+ _set_cos_sin_cache_then,
1640
+ _set_cos_sin_cache_else,
1641
+ [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
1642
+ )
1643
+ return cos_cached, sin_cached
1644
+
1645
+
1646
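A short sketch of the bounded-slice pattern used in _set_cos_sin_cache_else above, with toy sizes: torch._check tells the exporter that the data-dependent length fits inside the cache before slicing.

import torch

cos_cached = torch.rand(32, 8)
seq_len = torch.tensor(5)
torch._check(seq_len.item() <= cos_cached.shape[0])
print(cos_cached[: seq_len.item()].shape)    # torch.Size([5, 8])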
+ class patched_IdeficsAttention(torch.nn.Module):
1647
+ _PATCHES_ = ["forward"]
1648
+ _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
1649
+
1650
+ def forward(
1651
+ self,
1652
+ hidden_states: torch.Tensor,
1653
+ key_value_states: Optional[torch.Tensor] = None,
1654
+ attention_mask: Optional[torch.Tensor] = None,
1655
+ position_ids: Optional[torch.LongTensor] = None,
1656
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1657
+ output_attentions: bool = False,
1658
+ use_cache: bool = False,
1659
+ cache_position: Optional[torch.LongTensor] = None,
1660
+ **kwargs,
1661
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1662
+ # if key_value_states are provided this layer is used as a cross-attention layer
1663
+ is_cross_attention = self.is_cross_attention or key_value_states is not None
1664
+
1665
+ bsz, q_len, _ = hidden_states.size()
1666
+
1667
+ query_states = (
1668
+ self.q_proj(hidden_states)
1669
+ .view(bsz, q_len, self.num_heads, self.head_dim)
1670
+ .transpose(1, 2)
1671
+ )
1672
+ if not is_cross_attention:
1673
+ key_states = (
1674
+ self.k_proj(hidden_states)
1675
+ .view(bsz, q_len, self.num_heads, self.head_dim)
1676
+ .transpose(1, 2)
1677
+ )
1678
+ value_states = (
1679
+ self.v_proj(hidden_states)
1680
+ .view(bsz, q_len, self.num_heads, self.head_dim)
1681
+ .transpose(1, 2)
1682
+ )
1683
+ else:
1684
+ _, kv_len, _ = (
1685
+ key_value_states.size()
1686
+ ) # Note that, in this case, `kv_len` == `kv_seq_len`
1687
+ key_states = (
1688
+ self.k_proj(key_value_states)
1689
+ .view(bsz, kv_len, self.num_heads, self.head_dim)
1690
+ .transpose(1, 2)
1691
+ )
1692
+ value_states = (
1693
+ self.v_proj(key_value_states)
1694
+ .view(bsz, kv_len, self.num_heads, self.head_dim)
1695
+ .transpose(1, 2)
1696
+ )
1697
+
1698
+ kv_seq_len = key_states.shape[-2]
1699
+ if past_key_value is not None:
1700
+ kv_seq_len += cache_position[0]
1701
+
1702
+ if not is_cross_attention:
1703
+ rotary_length = torch.maximum(
1704
+ torch.tensor(kv_seq_len, dtype=torch.int64),
1705
+ torch.tensor(q_len, dtype=torch.int64),
1706
+ )
1707
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
1708
+ query_states, key_states = (
1709
+ transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
1710
+ query_states, key_states, cos, sin, position_ids
1711
+ )
1712
+ )
1713
+ # [bsz, nh, t, hd]
1714
+
1715
+ if past_key_value is not None:
1716
+ # sin and cos are specific to RoPE models;
1717
+ # cache_position needed for the static cache
1718
+ cache_kwargs = {"cache_position": cache_position}
1719
+ key_states, value_states = past_key_value.update(
1720
+ key_states, value_states, self.layer_idx, cache_kwargs
1721
+ )
1722
+
1723
+ if self.qk_layer_norms:
1724
+ query_states = self.q_layer_norm(query_states)
1725
+ key_states = self.k_layer_norm(key_states)
1726
+
1727
+ attention_interface: Callable = (
1728
+ transformers.models.idefics.modeling_idefics.eager_attention_forward
1729
+ )
1730
+
1731
+ if self.config._attn_implementation != "eager":
1732
+ if self.config._attn_implementation == "sdpa" and output_attentions:
1733
+ transformers.models.idefics.modeling_idefics.logger.warning_once(
1734
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
1735
+ "`output_attentions=True`. Falling back to "
1736
+ "eager attention. This warning can be removed using the argument "
1737
+ '`attn_implementation="eager"` when loading the model.'
1738
+ )
1739
+ else:
1740
+ attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
1741
+ self.config._attn_implementation
1742
+ ]
1743
+
1744
+ attn_output, attn_weights = attention_interface(
1745
+ self,
1746
+ query_states,
1747
+ key_states,
1748
+ value_states,
1749
+ attention_mask,
1750
+ dropout=0.0 if not self.training else self.dropout,
1751
+ scaling=self.scaling,
1752
+ **kwargs,
1753
+ )
1754
+
1755
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
1756
+ attn_output = self.o_proj(attn_output)
1757
+
1758
+ if output_attentions:
1759
+ attn_weights = None
1760
+
1761
+ if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
1762
+ return attn_output, attn_weights, past_key_value
1763
+ return attn_output, attn_weights
1764
+
1765
+
1766
+ class patched_SamMaskDecoder(torch.nn.Module):
1767
+ _PATCHES_ = ["forward"]
1768
+ _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
1769
+
1770
+ def forward(
1771
+ self,
1772
+ image_embeddings: torch.Tensor,
1773
+ image_positional_embeddings: torch.Tensor,
1774
+ sparse_prompt_embeddings: torch.Tensor,
1775
+ dense_prompt_embeddings: torch.Tensor,
1776
+ multimask_output: bool,
1777
+ output_attentions: Optional[bool] = None,
1778
+ attention_similarity: Optional[torch.Tensor] = None,
1779
+ target_embedding: Optional[torch.Tensor] = None,
1780
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1781
+ """
1782
+ Predict masks given image and prompt embeddings.
1783
+
1784
+ Args:
1785
+ image_embeddings (`torch.Tensor`):
1786
+ the embeddings from the image encoder
1787
+ image_positional_embeddings (`torch.Tensor`):
1788
+ positional encoding with the shape of image_embeddings
1789
+ sparse_prompt_embeddings (`torch.Tensor`):
1790
+ The embeddings of the points and boxes
1791
+ dense_prompt_embeddings (`torch.Tensor`):
1792
+ the embeddings of the mask inputs
1793
+ multimask_output (bool):
1794
+ Whether to return multiple masks or a single mask.
1795
+ output_attentions (bool, *optional*):
1796
+ Whether or not to return the attentions tensors of all attention layers.
1797
+ """
1798
+ batch_size, num_channels, height, width = image_embeddings.shape
1799
+ point_batch_size = sparse_prompt_embeddings.shape[1]
1800
+ # Concatenate output tokens
1801
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
1802
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
1803
+
1804
+ # torch.cond rewrites the if-else logic handling empty sparse_prompt_embeddings;
1805
+ # torch.any provides the predicate and avoids the data-dependent control flow
1806
+ # implied by sparse_prompt_embeddings.sum().item() != 0
1807
+ def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
1808
+ return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
1809
+
1810
+ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
1811
+ return output_tokens.clone()
1812
+
1813
+ tokens = torch.cond(
1814
+ torch.any(sparse_prompt_embeddings != 0),
1815
+ sparse_prompt_embeddings_is_not_empty,
1816
+ sparse_prompt_embeddings_is_empty,
1817
+ [output_tokens, sparse_prompt_embeddings],
1818
+ )
1819
+
1820
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
1821
+
1822
+ # Expand per-image data in batch direction to be per-point
1823
+ image_embeddings = image_embeddings + dense_prompt_embeddings
1824
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
1825
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(
1826
+ point_batch_size, 0
1827
+ )
1828
+
1829
+ # Run the transformer, image_positional_embedding are consumed
1830
+ torch._check(point_embeddings.shape[0] != 0)
1831
+ torch._check(point_embeddings.shape[1] != 0)
1832
+ torch._check(point_embeddings.shape[2] != 0)
1833
+ torch._check(point_embeddings.shape[3] != 0)
1834
+ embeddings_attentions = self.transformer(
1835
+ point_embeddings=point_embeddings,
1836
+ image_embeddings=image_embeddings,
1837
+ image_positional_embeddings=image_positional_embeddings,
1838
+ attention_similarity=attention_similarity,
1839
+ target_embedding=target_embedding,
1840
+ output_attentions=output_attentions,
1841
+ )
1842
+ point_embedding, image_embeddings = embeddings_attentions[:2]
1843
+ iou_token_out = torch.select(point_embedding, dim=2, index=0)
1844
+ mask_tokens_out = torch.narrow(
1845
+ point_embedding, dim=2, start=1, length=self.num_mask_tokens
1846
+ )
1847
+
1848
+ # Upscale mask embeddings and predict masks using the mask tokens
1849
+ image_embeddings = image_embeddings.transpose(2, 3).reshape(
1850
+ batch_size * point_batch_size, num_channels, height, width
1851
+ )
1852
+
1853
+ upscaled_embedding = self.upscale_conv1(image_embeddings)
1854
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
1855
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
1856
+
1857
+ hyper_in_list = []
1858
+ for i in range(self.num_mask_tokens):
1859
+ current_mlp = self.output_hypernetworks_mlps[i]
1860
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
1861
+ hyper_in = torch.stack(hyper_in_list, dim=2)
1862
+
1863
+ _, num_channels, height, width = upscaled_embedding.shape
1864
+ upscaled_embedding = upscaled_embedding.reshape(
1865
+ batch_size, point_batch_size, num_channels, height * width
1866
+ )
1867
+ masks = (hyper_in @ upscaled_embedding).reshape(
1868
+ batch_size, point_batch_size, -1, height, width
1869
+ )
1870
+
1871
+ # Generate mask quality predictions
1872
+ iou_pred = self.iou_prediction_head(iou_token_out)
1873
+
1874
+ # Select the correct mask or masks for output
1875
+ if multimask_output:
1876
+ mask_slice = slice(1, None)
1877
+ else:
1878
+ mask_slice = slice(0, 1)
1879
+ masks = masks[:, :, mask_slice, :, :]
1880
+ iou_pred = iou_pred[:, :, mask_slice]
1881
+
1882
+ outputs = (masks, iou_pred)
1883
+
1884
+ if len(embeddings_attentions) == 2:
1885
+ # transformers==4.54
1886
+ return outputs
1887
+
1888
+ if output_attentions and len(embeddings_attentions) > 2:
1889
+ outputs = outputs + (embeddings_attentions[2],) # noqa: RUF005
1890
+ else:
1891
+ outputs = outputs + (None,) # noqa: RUF005
1892
+ return outputs
1893
+
1894
+
1895
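A small sketch of the torch.select / torch.narrow calls used by the mask decoder above, on a toy tensor: they are the index-free equivalents of point_embedding[:, :, 0] and point_embedding[:, :, 1:1 + num_mask_tokens].

import torch

point_embedding = torch.rand(2, 3, 5, 8)
iou_token_out = torch.select(point_embedding, dim=2, index=0)               # (2, 3, 8)
mask_tokens_out = torch.narrow(point_embedding, dim=2, start=1, length=4)   # (2, 3, 4, 8)
print(iou_token_out.shape, mask_tokens_out.shape)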
+ def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
1896
+ """
1897
+ Rewrites the loop in:
1898
+
1899
+ .. code-block:: python
1900
+
1901
+ attention_mask = torch.full(
1902
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
1903
+ )
1904
+ for i in range(1, len(seq)):
1905
+ attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
1906
+ """
1907
+ r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
1908
+ less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
1909
+ less = less0.sum(axis=-1, keepdim=True) + 1
1910
+ sq = less * less.T
1911
+ look = (
1912
+ torch.max(seq.min() == 0, less != less.max())
1913
+ * torch.max(seq.max() == mask.shape[-1], less != less.min())
1914
+ * less
1915
+ )
1916
+ filt = (sq != look**2).to(mask.dtype)
1917
+ return mask * filt
1918
+
1919
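A quick equivalence check, not part of the package's tests, assuming rewrite_loop_for_square_mask above is in scope: the vectorized rewrite and the original loop produce the same block-diagonal mask when seq starts at 0 and ends at the mask size.

import torch

seq = torch.tensor([0, 3, 7, 10], dtype=torch.int64)
n = 10
mask_loop = torch.full([1, n, n], torch.finfo(torch.float32).min)
for i in range(1, len(seq)):
    mask_loop[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0

mask_vec = rewrite_loop_for_square_mask(
    torch.full([1, n, n], torch.finfo(torch.float32).min), seq
)
print(bool(torch.allclose(mask_loop, mask_vec)))    # expected: True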
+
1920
+ class patched_VisionAttention(torch.nn.Module):
1921
+ _PATCHES_ = ["forward"]
1922
+ _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
1923
+
1924
+ def forward(
1925
+ self,
1926
+ hidden_states: torch.Tensor,
1927
+ cu_seqlens: torch.Tensor,
1928
+ rotary_pos_emb: Optional[torch.Tensor] = None,
1929
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1930
+ ) -> torch.Tensor:
1931
+ seq_length = hidden_states.shape[0]
1932
+ q, k, v = (
1933
+ self.qkv(hidden_states)
1934
+ .reshape(seq_length, 3, self.num_heads, -1)
1935
+ .permute(1, 0, 2, 3)
1936
+ .unbind(0)
1937
+ )
1938
+ if position_embeddings is None:
1939
+ transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
1940
+ "The attention layers in this model are transitioning from "
1941
+ " computing the RoPE embeddings internally "
1942
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
1943
+ "to using externally computed "
1944
+ "`position_embeddings` (Tuple of tensors, containing cos and sin)."
1945
+ " In v4.54 `rotary_pos_emb` will be "
1946
+ "removed and `position_embeddings` will be mandatory."
1947
+ )
1948
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
1949
+ cos = emb.cos()
1950
+ sin = emb.sin()
1951
+ else:
1952
+ cos, sin = position_embeddings
1953
+ q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
1954
+ q, k, cos, sin
1955
+ )
1956
+
1957
+ attention_mask = torch.full(
1958
+ [1, seq_length, seq_length],
1959
+ torch.finfo(q.dtype).min,
1960
+ device=q.device,
1961
+ dtype=q.dtype,
1962
+ )
1963
+ # for i in range(1, len(cu_seqlens)):
1964
+ # attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
1965
+ # cu_seqlens[i - 1] : cu_seqlens[i]] = 0
1966
+ attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
1967
+
1968
+ q = q.transpose(0, 1)
1969
+ k = k.transpose(0, 1)
1970
+ v = v.transpose(0, 1)
1971
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
1972
+ attn_weights = attn_weights + attention_mask
1973
+ attn_weights = torch.nn.functional.softmax(
1974
+ attn_weights, dim=-1, dtype=torch.float32
1975
+ ).to(q.dtype)
1976
+ attn_output = torch.matmul(attn_weights, v)
1977
+ attn_output = attn_output.transpose(0, 1)
1978
+ attn_output = attn_output.reshape(seq_length, -1)
1979
+ attn_output = self.proj(attn_output)
1980
+ return attn_output
1981
+
1982
+
1983
+ try:
1984
+ import transformers.models.qwen3_moe
1985
+
1986
+ patch_qwen3 = True
1987
+ except ImportError:
1988
+ patch_qwen3 = False
1989
+
1990
+ if patch_qwen3:
1991
+
1992
+ class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
1993
+ _PATCHES_ = ["forward", "_forward_expert_loop"]
1994
+ _PATCHED_CLASS_ = (
1995
+ transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
1996
+ )
1997
+
1998
+ def _forward_expert_loop(
1999
+ self,
2000
+ final_hidden_states,
2001
+ expert_mask_idx,
2002
+ hidden_states,
2003
+ routing_weights,
2004
+ expert_idx: int,
2005
+ ):
2006
+ # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
2007
+ idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
2008
+ hidden_dim = hidden_states.shape[-1]
2009
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2010
+ expert_current_state = self.experts[expert_idx](current_state)
2011
+ current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
2012
+ return final_hidden_states.index_add(
2013
+ 0, top_x, current_hidden_states.to(hidden_states.dtype)
2014
+ )
2015
+
2016
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
2017
+ """ """
2018
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
2019
+ hidden_states = hidden_states.view(-1, hidden_dim)
2020
+ # router_logits: (batch * sequence_length, n_experts)
2021
+ router_logits = self.gate(hidden_states)
2022
+
2023
+ routing_weights = torch.nn.functional.softmax(
2024
+ router_logits, dim=1, dtype=torch.float
2025
+ )
2026
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
2027
+ if self.norm_topk_prob: # only diff with mixtral sparse moe block!
2028
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
2029
+ # we cast back to the input dtype
2030
+ routing_weights = routing_weights.to(hidden_states.dtype)
2031
+
2032
+ final_hidden_states = torch.zeros(
2033
+ (batch_size * sequence_length, hidden_dim),
2034
+ dtype=hidden_states.dtype,
2035
+ device=hidden_states.device,
2036
+ )
2037
+
2038
+ # One hot encode the selected experts to create an expert mask
2039
+ # this will be used to easily index which expert is going to be solicited
2040
+ expert_mask = torch.nn.functional.one_hot(
2041
+ selected_experts, num_classes=self.num_experts
2042
+ ).permute(2, 1, 0)
2043
+
2044
+ # Loop over all available experts in the model
2045
+ # and perform the computation on each expert
2046
+ expert_sum = expert_mask.sum(dim=(-1, -2))
2047
+ # expert_hit = torch.greater(expert_sum, 0).nonzero()
2048
+ # for expert_idx in expert_hit:
2049
+ for expert_idx in range(self.num_experts):
2050
+ # the original code has a squeeze, which cannot be done here.
2051
+ # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
2052
+ expert_mask_idx = expert_mask[expert_idx]
2053
+ final_hidden_states = torch.cond(
2054
+ (expert_sum[expert_idx] > 0).item(),
2055
+ lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop( # noqa: E501
2056
+ final_hidden_states,
2057
+ expert_mask,
2058
+ hidden_states,
2059
+ routing_weights,
2060
+ expert_idx=_i,
2061
+ ),
2062
+ lambda final_hidden_states, *args: final_hidden_states.clone(),
2063
+ [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
2064
+ )
2065
+
2066
+ # if expert_sum[expert_idx] > 0:
2067
+ # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
2068
+
2069
+ # Index the correct hidden states and compute the expert hidden state for
2070
+ # the current expert. We need to make sure to multiply the output hidden
2071
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
2072
+ # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2073
+ # current_hidden_states = (
2074
+ # expert_layer(current_state) * routing_weights[top_x, idx, None]
2075
+ # )
2076
+
2077
+ # However `index_add_` only support torch tensors for indexing so we'll use
2078
+ # the `top_x` tensor here.
2079
+ # final_hidden_states.index_add_(
2080
+ # 0, top_x, current_hidden_states.to(hidden_states.dtype)
2081
+ # )
2082
+
2083
+ final_hidden_states = final_hidden_states.reshape(
2084
+ batch_size, sequence_length, hidden_dim
2085
+ )
2086
+ return final_hidden_states, router_logits
2087
+
2088
+
2089
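A minimal sketch, with toy sizes, of the out-of-place update performed by _forward_expert_loop above: torch.cond branches must be functional, so index_add (which returns a new tensor) replaces the original index_add_.

import torch

hidden = torch.rand(6, 4)                  # (tokens, hidden_dim)
acc = torch.zeros(6, 4)
top_x = torch.tensor([1, 4])               # tokens routed to this expert
weights = torch.tensor([[0.7], [0.3]])     # routing weights for those tokens
update = hidden[top_x] * weights           # stand-in for the expert output
acc = acc.index_add(0, top_x, update)      # functional, unlike index_add_
print(acc[1], acc[4])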
+ try:
2090
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3Model # noqa: F401
2091
+
2092
+ patch_gemma3 = True
2093
+ except ImportError:
2094
+ patch_gemma3 = False
2095
+
2096
+
2097
+ if patch_gemma3:
2098
+
2099
+ class patched_Gemma3Model(torch.nn.Module):
2100
+ _PATCHES_ = ["get_placeholder_mask"]
2101
+ _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
2102
+ _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
2103
+
2104
+ def get_placeholder_mask(
2105
+ self,
2106
+ input_ids: torch.LongTensor,
2107
+ inputs_embeds: torch.FloatTensor,
2108
+ image_features: torch.FloatTensor,
2109
+ ):
2110
+ if input_ids is None:
2111
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
2112
+ torch.tensor(
2113
+ self.config.image_token_id,
2114
+ dtype=torch.long,
2115
+ device=inputs_embeds.device,
2116
+ )
2117
+ )
2118
+ special_image_mask = special_image_mask.all(-1)
2119
+ else:
2120
+ special_image_mask = input_ids == self.config.image_token_id
2121
+
2122
+ n_image_tokens = special_image_mask.sum()
2123
+ special_image_mask = (
2124
+ special_image_mask.unsqueeze(-1)
2125
+ .expand_as(inputs_embeds)
2126
+ .to(inputs_embeds.device)
2127
+ )
2128
+ n_image_features = image_features.shape[0] * image_features.shape[1]
2129
+ # PATCHED: torch._check
2130
+ # if inputs_embeds[special_image_mask].numel() != image_features.numel():
2131
+ # raise ValueError( ... )
2132
+ torch._check(
2133
+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
2134
+ lambda: (
2135
+ f"Image features and image tokens do not match: tokens: "
2136
+ f"{n_image_tokens}, features {n_image_features}"
2137
+ ),
2138
+ )
2139
+ return special_image_mask
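A minimal sketch of the torch._check pattern used in get_placeholder_mask above, with hypothetical tensors: the eager ValueError becomes a constraint the exporter can assume instead of a data-dependent branch.

import torch

def check_match(embeds: torch.Tensor, mask: torch.Tensor, feats: torch.Tensor):
    torch._check(
        embeds[mask].numel() == feats.numel(),
        lambda: "Image features and image tokens do not match",
    )
    return embeds.masked_scatter(mask, feats)

e = torch.zeros(2, 3, 4)
m = torch.zeros(2, 3, 4, dtype=torch.bool)
m[0, 0] = True                              # mark 4 placeholder positions
f = torch.rand(4)
print(check_match(e, m, f)[0, 0])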