onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py (new file)
@@ -0,0 +1,680 @@
+ import os
+ from typing import Callable, Optional
+ import onnx
+ import torch
+ import torch.nn.functional as F
+ from ...export.onnx_plug import EagerDirectReplacementWithOnnx
+ from .patch_helper import _is_torchdynamo_exporting
+ from ._patch_transformers_attention import patched_sdpa_attention_forward
+
+ try:
+     import transformers.models.qwen2_5_vl
+     import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
+
+     patch_qwen2_5 = True
+ except ImportError:
+     patch_qwen2_5 = False
+
+ PLUGS = []
+
+ if patch_qwen2_5:
+     import onnxscript
+
+     onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
+     op = onnxscript.opset22
+     msft_op = onnxscript.values.Opset("com.microsoft", 1)
+
+     @onnxscript.script(opset=onnx_plugs_op)
+     def LoopMHAAttention(
+         query_states,
+         key_states,
+         value_states,
+         cu_seqlens,
+         scaling: float = 0.11180339887498948,
+         num_heads: int = 16,
+         itype: int = onnx.TensorProto.FLOAT,
+     ):
+         to_3d_shape = op.Constant(value_ints=[0, 0, -1])
+         query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
+         output_shape = op.Shape(query_transposed)
+         query_3d = op.Reshape(query_transposed, to_3d_shape)
+         value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+         key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+         cu_seqlens = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+         num_patches = op.Size(cu_seqlens) - 1
+         seq_axis = op.Constant(value_ints=[1])
+         seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
+         # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
+         seq_attn = op.SequenceEmpty(dtype=itype)
+         for i_patch in range(num_patches):
+             i_1d = op.Reshape(i_patch, [1])
+             i_plus_1_1d = i_1d + 1
+             start = op.Gather(cu_seqlens, i_1d, axis=0)
+             end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+             query_i = op.Slice(query_3d, start, end, seq_axis_int32)
+             key_i = op.Slice(key_3d, start, end, seq_axis_int32)
+             value_i = op.Slice(value_3d, start, end, seq_axis_int32)
+             mha_output = msft_op.MultiHeadAttention(
+                 query_i,
+                 key_i,
+                 value_i,
+                 num_heads=num_heads,
+                 scale=scaling,
+             )
+             # attn_output = op.Concat(attn_output, mha_output, axis=1)
+             seq_attn = op.SequenceInsert(seq_attn, mha_output)
+         attn_output = op.ConcatFromSequence(seq_attn, axis=1)
+         attn_output_4d = op.Reshape(attn_output, output_shape)
+         return attn_output_4d
+
+     def _add_com_microsoft_opset(function_proto):
+         opsets = {d.domain: d.version for d in function_proto.opset_import}
+         if "com.microsoft" not in opsets:
+             d = function_proto.opset_import.add()
+             d.domain = "com.microsoft"
+             d.version = 1
+         return function_proto
+
+     @onnxscript.script(opset=onnx_plugs_op)
+     def PackedAttention(
+         query,
+         key,
+         value,
+         cu_seqlens,
+         scaling: float = 0.11180339887498948,
+         num_heads: int = 16,
+     ):
+         num_patches = op.Cast(op.Size(cu_seqlens), to=onnx.TensorProto.INT32) - 1
+         starts = op.Slice(cu_seqlens, [0], [-1], [0])
+         ends = op.Slice(cu_seqlens, [1], [9223372036854775807], [0])
+         lengths = ends - starts
+         max_length = op.ReduceMax(lengths, [0], keepdims=0)  # max_seqlen
+         rows = op.Range(0, num_patches, 1)
+         rows_2d = op.Unsqueeze(rows, [1])
+         cols = op.Range(0, max_length, 1)
+         cols_2d = op.Unsqueeze(cols, [0])
+
+         position_matrix = op.Cast(rows_2d, to=onnx.TensorProto.INT32) * op.Cast(
+             max_length, to=onnx.TensorProto.INT32
+         ) + op.Cast(cols_2d, to=onnx.TensorProto.INT32)
+         position_matrix_shape = op.Shape(position_matrix)
+         token_mask = cols_2d < op.Unsqueeze(lengths, [1])
+         token_mask_1d = op.Reshape(token_mask, [-1])
+         padded_mask_1d = op.Not(token_mask_1d)
+         valid_token_positions = op.Compress(position_matrix, token_mask)
+         padded_token_positions = op.Compress(position_matrix, padded_mask_1d)
+         token_offset_1d = op.Concat(valid_token_positions, padded_token_positions, axis=0)
+         token_offset = op.Reshape(token_offset_1d, position_matrix_shape)
+
+         query_3d = op.Transpose(op.Squeeze(query, [0]), perm=[1, 0, 2])
+         shape_3d = op.Shape(query_3d)
+         query_2d = op.Reshape(query_3d, [0, -1])
+         key_2d = op.Reshape(op.Transpose(op.Squeeze(key, [0]), perm=[1, 0, 2]), [0, -1])
+         value_2d = op.Reshape(op.Transpose(op.Squeeze(value, [0]), perm=[1, 0, 2]), [0, -1])
+
+         packed_attn_output_2d = msft_op.PackedMultiHeadAttention(
+             query_2d,
+             key_2d,
+             value_2d,
+             None,
+             op.Cast(token_offset, to=onnx.TensorProto.INT32),
+             op.Cast(cu_seqlens, to=onnx.TensorProto.INT32),
+             scale=scaling,
+             num_heads=num_heads,
+         )
+         packed_attn_output_3d = op.Reshape(packed_attn_output_2d, shape_3d)
+         return op.Unsqueeze(packed_attn_output_3d, [0])
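The token_offset construction above is easiest to follow on a concrete case. The following standalone sketch (illustration only, not part of the diff) replays the same arithmetic eagerly with torch for cu_seqlens = [0, 2, 5], i.e. two patches of lengths 2 and 3:

import torch

cu_seqlens = torch.tensor([0, 2, 5])
lengths = cu_seqlens[1:] - cu_seqlens[:-1]       # tensor([2, 3])
max_length = int(lengths.max())                  # 3
rows = torch.arange(len(lengths)).unsqueeze(1)   # shape (num_patches, 1)
cols = torch.arange(max_length).unsqueeze(0)     # shape (1, max_length)
position_matrix = rows * max_length + cols       # [[0, 1, 2], [3, 4, 5]]
token_mask = cols < lengths.unsqueeze(1)         # [[True, True, False], [True, True, True]]
valid = position_matrix[token_mask]              # tensor([0, 1, 3, 4, 5])
padded = position_matrix[~token_mask]            # tensor([2])
token_offset = torch.cat([valid, padded]).reshape(position_matrix.shape)
print(token_offset)                              # tensor([[0, 1, 3], [4, 5, 2]])

Valid token positions come first and the padded slots are appended at the end before reshaping back to (num_patches, max_length), which is the token_offset the function casts to INT32 and feeds to PackedMultiHeadAttention.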
+
+     def qwen_sdpa_attention(
+         query_states: torch.Tensor,  # F10s1x16xs47x80
+         key_states: torch.Tensor,  # F10s1x16xs47x80
+         value_states: torch.Tensor,  # F10s1x16xs47x80
+         cu_seqlens: torch.Tensor,  # F7su19
+         scaling: float = 0,
+         num_heads: int = 16,
+         itype: int = onnx.TensorProto.FLOAT,
+     ) -> torch.Tensor:
+         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+         splits = [
+             torch.split(tensor, lengths.tolist(), dim=2)
+             for tensor in (query_states, key_states, value_states)
+         ]
+
+         attn_outputs = [
+             patched_sdpa_attention_forward(
+                 None,
+                 q,
+                 k,
+                 v,
+                 attention_mask=None,
+                 scaling=scaling,
+                 dropout=0.0,
+                 is_causal=False,
+             )[0]
+             for q, k, v in zip(*splits)
+         ]
+         attn_output = torch.cat(attn_outputs, dim=1)
+         return attn_output
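qwen_sdpa_attention is the eager reference used by the two replacements registered below: it cuts query/key/value at the cu_seqlens boundaries and runs SDPA independently on each patch. A rough standalone equivalent (illustration only, using plain torch SDPA instead of patched_sdpa_attention_forward and assuming the usual (batch, seq, heads, dim) output layout) looks like:

import torch
import torch.nn.functional as F

q = torch.randn(1, 16, 5, 80)                    # (batch, num_heads, total_seq, head_dim)
k = torch.randn(1, 16, 5, 80)
v = torch.randn(1, 16, 5, 80)
cu_seqlens = torch.tensor([0, 2, 5])             # two patches of lengths 2 and 3

lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
chunks = [
    F.scaled_dot_product_attention(qi, ki, vi, scale=80**-0.5)
    for qi, ki, vi in zip(
        q.split(lengths, dim=2), k.split(lengths, dim=2), v.split(lengths, dim=2)
    )
]
# each patch only attends within itself; concatenating restores the full sequence
out = torch.cat([c.transpose(1, 2) for c in chunks], dim=1)
print(out.shape)                                 # torch.Size([1, 5, 16, 80])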
+
+     # not ideal
+     qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
+         qwen_sdpa_attention,
+         lambda qs, *args, **kwargs: torch.empty(
+             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
+             dtype=qs.dtype,
+             device=qs.device,
+         ),
+         _add_com_microsoft_opset(PackedAttention.to_function_proto()),
+         n_inputs=4,
+         n_outputs=1,
+         kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
+         name="qwen_sdpa_attention_packed",
+     )
+     PLUGS.append(qwen_sdpa_attention_packed_versatile)
+
+     qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
+         qwen_sdpa_attention,
+         lambda qs, *args, **kwargs: torch.empty(
+             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
+             dtype=qs.dtype,
+             device=qs.device,
+         ),
+         _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+         n_inputs=4,
+         n_outputs=1,
+         kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
+         name="qwen_sdpa_attention_loopmha",
+     )
+     PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
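Each EagerDirectReplacementWithOnnx above pairs the eager qwen_sdpa_attention with a shape-only stand-in (the torch.empty lambda) and an ONNX FunctionProto, so the exported graph can use the com.microsoft kernel while eager runs keep the reference implementation. A hedged usage sketch, mirroring the call made in the patched attention forward further down (the exact call semantics depend on onnx_plug):

import torch

q = torch.randn(1, 16, 5, 80)
k = torch.randn(1, 16, 5, 80)
v = torch.randn(1, 16, 5, 80)
cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int64)

# outside torch.export this presumably dispatches to the eager qwen_sdpa_attention;
# during export the PackedAttention FunctionProto is emitted instead
out = qwen_sdpa_attention_packed_versatile(q, k, v, cu_seqlens, 80**-0.5, 16)
print(out.shape)  # expected torch.Size([1, 5, 16, 80]), as given by the fake-output lambda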
+
+     class patched_Qwen2_5_VLForConditionalGeneration:
+         _PATCHES_ = ["prepare_inputs_for_generation"]
+         _PATCHED_CLASS_ = (
+             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
+         )
+
+         def prepare_inputs_for_generation(
+             self,
+             input_ids,
+             past_key_values=None,
+             attention_mask=None,
+             inputs_embeds=None,
+             cache_position=None,
+             position_ids=None,
+             use_cache=True,
+             pixel_values=None,
+             pixel_values_videos=None,
+             image_grid_thw=None,
+             video_grid_thw=None,
+             second_per_grid_ts=None,
+             **kwargs,
+         ):
+             # Overwritten -- in specific circumstances we don't want to
+             # forward image inputs to the model
+             from transformers.generation import GenerationMixin
+
+             model_inputs = GenerationMixin.prepare_inputs_for_generation(
+                 self,
+                 input_ids,
+                 past_key_values=past_key_values,
+                 attention_mask=attention_mask,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 position_ids=position_ids,
+                 pixel_values=pixel_values,
+                 pixel_values_videos=pixel_values_videos,
+                 image_grid_thw=image_grid_thw,
+                 video_grid_thw=video_grid_thw,
+                 second_per_grid_ts=second_per_grid_ts,
+                 use_cache=use_cache,
+                 **kwargs,
+             )
+
+             # Qwen2-5-VL position_ids are prepared with rope_deltas
+             if position_ids is None:
+                 # Calculate RoPE index once per generation in the pre-fill stage only.
+                 # When compiling, we can't check tensor values thus we check only input length
+                 # It is safe to assume that `length!=1` means we're in pre-fill
+                 # because compiled models currently cannot do assisted decoding
+                 if cache_position[0] == 0 or self.model.rope_deltas is None:
+                     vision_positions, rope_deltas = self.model.get_rope_index(
+                         model_inputs.get("input_ids", None),
+                         image_grid_thw=image_grid_thw,
+                         video_grid_thw=video_grid_thw,
+                         second_per_grid_ts=second_per_grid_ts,
+                         attention_mask=attention_mask,
+                     )
+                     self.model.rope_deltas = rope_deltas
+                 # then use the prev pre-calculated rope-deltas to get the correct position ids
+                 elif (
+                     "position_ids" in model_inputs and model_inputs["position_ids"] is not None
+                 ):
+                     batch_size, seq_length = model_inputs["position_ids"].shape
+                     device = model_inputs["position_ids"].device
+                     position_ids = torch.arange(seq_length, device=device)
+                     position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                     delta = cache_position[0] + self.model.rope_deltas
+                     delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                     vision_positions = position_ids + delta.expand_as(position_ids)
+
+                 # Concatenate "text + vision" positions into [4, bs, seq-len]
+                 if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
+                     text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
+                         None, None, :
+                     ]
+                 else:
+                     text_positions = model_inputs["position_ids"][None, ...]
+                 # text_positions = model_inputs["position_ids"][None, ...]
+                 assert vision_positions is not None, "vision_positions are missing"
+                 model_inputs["position_ids"] = torch.cat(
+                     [text_positions, vision_positions], dim=0
+                 )
+
+             if cache_position[0] != 0:
+                 model_inputs["pixel_values"] = None
+                 model_inputs["pixel_values_videos"] = None
+
+             return model_inputs
+
+     class patched_Qwen2_5_VisionTransformerPretrainedModel:
+         _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
+         _PATCHED_CLASS_ = (
+             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
+         )
+
+         def rot_pos_emb(self, grid_thw):
+             pos_ids = []
+             for thw_ in grid_thw:
+                 # PATCHED: avoid unbind
+                 t = thw_[0]
+                 h = thw_[1]
+                 w = thw_[2]
+                 hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+                 hpos_ids = hpos_ids.reshape(
+                     h // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                     w // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                 )
+                 hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+                 hpos_ids = hpos_ids.flatten()
+
+                 wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+                 wpos_ids = wpos_ids.reshape(
+                     h // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                     w // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                 )
+                 wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+                 wpos_ids = wpos_ids.flatten()
+                 pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+             pos_ids = torch.cat(pos_ids, dim=0)
+             max_grid_size = grid_thw[:, 1:].max()
+             rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+             rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+             return rotary_pos_emb
+
+         def get_window_index(self, grid_thw):
+             window_index: list = []  # type: ignore[annotation-unchecked]
+             # PATCHED
+             cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)]  # type: ignore[annotation-unchecked]
+             window_index_id = 0
+             vit_merger_window_size = (
+                 self.window_size // self.spatial_merge_size // self.patch_size
+             )
+
+             for _thw in grid_thw:
+                 # PATCHED: avoid unbind
+                 grid_t = _thw[0]
+                 grid_h = _thw[1]
+                 grid_w = _thw[2]
+                 llm_grid_h, llm_grid_w = (
+                     grid_h // self.spatial_merge_size,
+                     grid_w // self.spatial_merge_size,
+                 )
+                 index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                     grid_t, llm_grid_h, llm_grid_w
+                 )
+                 pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+                 pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+                 num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+                 num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+                 index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+                 index_padded = index_padded.reshape(
+                     grid_t,
+                     num_windows_h,
+                     vit_merger_window_size,
+                     num_windows_w,
+                     vit_merger_window_size,
+                 )
+                 index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                     grid_t,
+                     num_windows_h * num_windows_w,
+                     vit_merger_window_size,
+                     vit_merger_window_size,
+                 )
+                 seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+                 index_padded = index_padded.reshape(-1)
+                 index_new = index_padded[index_padded != -100]
+                 window_index.append(index_new + window_index_id)
+                 cu_seqlens_tmp = (
+                     seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
+                 )
+                 # PATCHED
+                 # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+                 cu_window_seqlens.append(cu_seqlens_tmp)
+                 window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+             window_index = torch.cat(window_index, dim=0)
+
+             return window_index, torch.cat(cu_window_seqlens, dim=0)
+
+         def forward(
+             self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+         ) -> torch.Tensor:
+             """
+             Args:
+                 hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                     The final hidden states of the model.
+                 grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                     The temporal, height and width of feature shape of each image in LLM.
+
+             Returns:
+                 `torch.Tensor`: hidden_states.
+             """
+             hidden_states = self.patch_embed(hidden_states)
+             rotary_pos_emb = self.rot_pos_emb(grid_thw)
+             window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+             # PATCHED
+             # cu_window_seqlens = torch.tensor(
+             #     cu_window_seqlens,
+             #     device=hidden_states.device,
+             #     dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+             # )
+             cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
+             cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+             seq_len, _ = hidden_states.size()
+             hidden_states = hidden_states.reshape(
+                 seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+             )
+             hidden_states = hidden_states[window_index, :, :]
+             hidden_states = hidden_states.reshape(seq_len, -1)
+             rotary_pos_emb = rotary_pos_emb.reshape(
+                 seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+             )
+             rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+             rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+             emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+             position_embeddings = (emb.cos(), emb.sin())
+
+             cu_seqlens = torch.repeat_interleave(
+                 grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+             ).cumsum(
+                 dim=0,
+                 # Select dtype based on the following factors:
+                 # - FA2 requires that cu_seqlens_q must have dtype int32
+                 # - torch.onnx.export requires that cu_seqlens_q must have same dtype
+                 #   as grid_thw
+                 # See https://github.com/huggingface/transformers/pull/34852
+                 # for more information
+                 dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+             )
+             cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+             for layer_num, blk in enumerate(self.blocks):
+                 if layer_num in self.fullatt_block_indexes:
+                     cu_seqlens_now = cu_seqlens
+                 else:
+                     cu_seqlens_now = cu_window_seqlens
+
+                 hidden_states = blk(
+                     hidden_states,
+                     cu_seqlens=cu_seqlens_now,
+                     position_embeddings=position_embeddings,
+                     **kwargs,
+                 )
+
+             hidden_states = self.merger(hidden_states)
+             reverse_indices = torch.argsort(window_index)
+             hidden_states = hidden_states[reverse_indices, :]
+             return hidden_states
+
+     class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
+         def forward(
+             self,
+             start_end,
+             query_states,
+             key_states,
+             value_states,
+             scaling: float = 1.0,
+             dropout: float = 0.0,
+             **kwargs,
+         ):
+             a = start_end[0].item()
+             b = start_end[1].item()
+             q = query_states[:, :, a:b, :]
+             k = key_states[:, :, a:b, :]
+             v = value_states[:, :, a:b, :]
+             return patched_sdpa_attention_forward(
+                 self,
+                 q,
+                 k,
+                 v,
+                 attention_mask=None,
+                 scaling=scaling,
+                 dropout=dropout,
+                 is_causal=False,
+                 **kwargs,
+             )[0]
+
+     class patched_Qwen2_5_VLVisionAttention:
+         _PATCHES_ = ["forward"]
+         _PATCHED_CLASS_ = (
+             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
+         )
+         STRATEGY_FOR_ATTENTION = lambda: os.environ.get(  # noqa: E731
+             "QWEN25ATTENTION", "PACKED"
+         )
+
+         def forward(
+             self,
+             hidden_states: torch.Tensor,
+             cu_seqlens: torch.Tensor,
+             rotary_pos_emb: Optional[torch.Tensor] = None,
+             position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+             **kwargs,
+         ) -> torch.Tensor:
+             seq_length = hidden_states.shape[0]
+             # PATCHED: avoid the use of unbind
+             qkv = (
+                 self.qkv(hidden_states)
+                 .reshape(seq_length, 3, self.num_heads, -1)
+                 .permute(1, 0, 2, 3)
+             )
+
+             query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
+             cos, sin = position_embeddings
+
+             # This part should be moved into the loop
+             # iteration to enable fusion inside the loop.
+             query_states, key_states = (
+                 transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
+                     query_states, key_states, cos, sin
+                 )
+             )
+
+             query_states = query_states.transpose(0, 1).unsqueeze(0)
+             key_states = key_states.transpose(0, 1).unsqueeze(0)
+             value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+             attention_interface: Callable = (
+                 transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
+             )
+             if self.config._attn_implementation != "eager":
+                 # PATCHED
+                 # attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 #     self.config._attn_implementation]
+                 attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                     self.config._attn_implementation
+                 ]
+
+             is_sdpa = (
+                 attention_interface
+                 is transformers.integrations.sdpa_attention.sdpa_attention_forward
+                 or attention_interface is patched_sdpa_attention_forward
+             )
+             attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
+             if is_sdpa and attention_strategy in "PACKED":
+                 attn_output = qwen_sdpa_attention_packed_versatile(
+                     query_states,
+                     key_states,
+                     value_states,
+                     cu_seqlens,
+                     self.scaling,
+                     self.num_heads,
+                 )
+             elif _is_torchdynamo_exporting():
+                 if self.config._attn_implementation == "flash_attention_2":
+                     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                     attn_output = torch.onnx.ops.symbolic(
+                         "custom::qwen25_flash_attention",
+                         (
+                             query_states,
+                             key_states,
+                             value_states,
+                             cu_seqlens,
+                             cu_seqlens,
+                             max_seqlen,
+                             max_seqlen,
+                             torch.tensor(self.scaling, dtype=torch.float32),
+                         ),
+                         dtype=query_states.dtype,
+                         shape=(
+                             query_states.shape[0],  # batch_size
+                             query_states.shape[2],  # sequence_length (total patches)
+                             query_states.shape[1],  # num_heads
+                             query_states.shape[3],  # head_size
+                         ),
+                         version=1,
+                     )
+                 elif is_sdpa and attention_strategy == "LOOPMHA":
+                     attn_output = qwen_sdpa_attention_loopmha_versatile(
+                         query_states,
+                         key_states,
+                         value_states,
+                         cu_seqlens,
+                         self.scaling,
+                         self.num_heads,
+                         (
+                             onnx.TensorProto.FLOAT
+                             if query_states.dtype == torch.float32
+                             else (
+                                 onnx.TensorProto.FLOAT16
+                                 if query_states.dtype == torch.float16
+                                 else onnx.TensorProto.BFLOAT16
+                             )
+                         ),
+                     )
+
+                     # to rewrite later with a for loop
+                     # def _iteration(start_end, query_states, key_states, value_states):
+                     #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+                     #         self,
+                     #         start_end,
+                     #         query_states,
+                     #         key_states,
+                     #         value_states,
+                     #         scaling=self.scaling,
+                     #         dropout=0.0 if not self.training else self.attention_dropout,
+                     #     )
+
+                     # starts = cu_seqlens[:-1]
+                     # ends = cu_seqlens[1:]
+                     # torch._check(starts.shape[0] > 0)
+                     # torch._check(ends.shape[0] > 0)
+                     # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
+                     # attn_outputs = [
+                     #     _iteration(start_end, query_states, key_states, value_states)
+                     #     for start_end in starts_ends
+                     # ]
+                     # attn_output = torch.cat(attn_outputs, dim=1)
+                 elif is_sdpa and attention_strategy == "BIGMASK":
+                     # make square mask
+                     indices = torch.arange(
+                         cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+                     )
+                     dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                         cu_seqlens.dtype
+                     )
+                     dot = dot.sum(dim=0)
+                     mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+                     bool_mask = mask == 0
+                     bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+                     torch._check(bool_mask.shape[2] == key_states.shape[2])
+                     torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+                     attn_output, _ = attention_interface(
+                         self,
+                         query_states,
+                         key_states,
+                         value_states,
+                         attention_mask=bool_mask,
+                         scaling=self.scaling,
+                         dropout=0.0 if not self.training else self.attention_dropout,
+                         is_causal=False,
+                         **kwargs,
+                     )
+                 else:
+                     raise NotImplementedError(
+                         f"No corresponding export strategy for "
+                         f"{attention_strategy!r}, "
+                         f"(use QWEN25ATTENTION to change it), and attention_interface="
+                         f"{attention_interface!r} (use sdpa)"
+                     )
+             elif self.config._attn_implementation == "flash_attention_2":
+                 # Flash Attention 2: Use cu_seqlens for variable length attention
+                 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                 attn_output, _ = attention_interface(
+                     self,
+                     query_states,
+                     key_states,
+                     value_states,
+                     attention_mask=None,
+                     scaling=self.scaling,
+                     dropout=0.0 if not self.training else self.attention_dropout,
+                     cu_seq_lens_q=cu_seqlens,
+                     cu_seq_lens_k=cu_seqlens,
+                     max_length_q=max_seqlen,
+                     max_length_k=max_seqlen,
+                     is_causal=False,
+                     **kwargs,
+                 )
+             else:
+                 # Other implementations: Process each chunk separately
+                 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+                 splits = [
+                     torch.split(tensor, lengths.tolist(), dim=2)
+                     for tensor in (query_states, key_states, value_states)
+                 ]
+
+                 attn_outputs = [
+                     attention_interface(
+                         self,
+                         q,
+                         k,
+                         v,
+                         attention_mask=None,
+                         scaling=self.scaling,
+                         dropout=0.0 if not self.training else self.attention_dropout,
+                         is_causal=False,
+                         **kwargs,
+                     )[0]
+                     for q, k, v in zip(*splits)
+                 ]
+                 attn_output = torch.cat(attn_outputs, dim=1)
+
+             attn_output = attn_output.reshape(seq_length, -1).contiguous()
+             attn_output = self.proj(attn_output)
+             return attn_output
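For reference, the export strategy of the patched vision attention is selected through the QWEN25ATTENTION environment variable read by STRATEGY_FOR_ATTENTION (default "PACKED"), and only takes effect on the SDPA path per the is_sdpa check above. A minimal selection sketch (illustration only, not part of the diff):

import os

# must be set before the patched forward runs / before export
os.environ["QWEN25ATTENTION"] = "BIGMASK"    # block-diagonal mask through the regular SDPA path
# os.environ["QWEN25ATTENTION"] = "PACKED"   # com.microsoft PackedMultiHeadAttention (default)
# os.environ["QWEN25ATTENTION"] = "LOOPMHA"  # per-patch loop over com.microsoft MultiHeadAttention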