onnx-diagnostic 0.8.2 → 0.8.4 (py3-none-any.whl)

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py (new file)
@@ -0,0 +1,735 @@
+ import os
+ from typing import Callable, Optional, Tuple
+ import onnx
+ import onnx.helper as oh
+ import torch
+ import torch.nn.functional as F
+ from ...export.onnx_plug import EagerDirectReplacementWithOnnx
+ from ...helpers.torch_helper import torch_dtype_to_onnx_dtype
+ from .patch_helper import _is_torchdynamo_exporting
+ from ._patch_transformers_attention import patched_sdpa_attention_forward
+
+ try:
+     import transformers.models.qwen2_5_vl
+     import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
+
+     patch_qwen2_5 = True
+ except ImportError:
+     patch_qwen2_5 = False
+
+ PLUGS = []
+
+ if patch_qwen2_5:
+     import onnxscript
+
+     onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
+     op = onnxscript.opset22
+     op24 = onnxscript.onnx_opset.opset24
+     msft_op = onnxscript.values.Opset("com.microsoft", 1)
+     STOPAT = (
+         int(os.environ.get("STOPAT", None))
+         if os.environ.get("STOPAT", None) is not None
+         else None
+     )
+
+     def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
+         opsets = {d.domain: d.version for d in function_proto.opset_import}
+         if "com.microsoft" not in opsets:
+             d = function_proto.opset_import.add()
+             d.domain = "com.microsoft"
+             d.version = 1
+         return function_proto
+
+     def _update_sequence_type(
+         itype: int, function_proto: onnx.FunctionProto
+     ) -> onnx.FunctionProto:
+         proto = oh.make_function(
+             function_proto.domain,
+             function_proto.name,
+             function_proto.input,
+             function_proto.output,
+             [
+                 (
+                     oh.make_node("SequenceEmpty", node.input, node.output, dtype=itype)
+                     if node.op_type == "SequenceEmpty"
+                     else node
+                 )
+                 for node in function_proto.node
+             ],
+             attributes=function_proto.attribute,
+             attribute_protos=function_proto.attribute_proto,
+             opset_imports=function_proto.opset_import,
+         )
+         return proto
+
+     @onnxscript.script(opset=onnx_plugs_op)
+     def LoopMHAAttention(
+         query_states,
+         key_states,
+         value_states,
+         cu_seqlens,
+         scaling: float = 0.11180339887498948,
+         num_heads: int = 16,
+     ):
+         to_3d_shape = op.Constant(value_ints=[0, 0, -1])
+         query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
+         output_shape = op.Shape(query_transposed)
+         query_3d = op.Reshape(query_transposed, to_3d_shape)
+         value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+         key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+         cu_seqlens = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+         num_patches = op.Size(cu_seqlens) - 1
+         seq_axis = op.Constant(value_ints=[1])
+         seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
+         # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
+         seq_attn = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
+         for i_patch in range(num_patches):
+             i_1d = op.Reshape(i_patch, [1])
+             i_plus_1_1d = i_1d + 1
+             start = op.Gather(cu_seqlens, i_1d, axis=0)
+             end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+             query_i = op.Slice(query_3d, start, end, seq_axis_int32)
+             key_i = op.Slice(key_3d, start, end, seq_axis_int32)
+             value_i = op.Slice(value_3d, start, end, seq_axis_int32)
+             mha_output = msft_op.MultiHeadAttention(
+                 query_i, key_i, value_i, num_heads=num_heads, scale=scaling
+             )
+             # attn_output = op.Concat(attn_output, mha_output, axis=1)
+             seq_attn = op.SequenceInsert(seq_attn, mha_output)
+         attn_output = op.ConcatFromSequence(seq_attn, axis=1)
+         attn_output_4d = op.Reshape(attn_output, output_shape)
+         return attn_output_4d
+
+     @onnxscript.script(opset=onnx_plugs_op)
+     def LoopAttention24(
+         query_states,
+         key_states,
+         value_states,
+         cu_seqlens,
+         scaling: float = 0.11180339887498948,
+         num_heads: int = 16,
+     ):
+         to_3d_shape = op24.Constant(value_ints=[0, 0, -1])
+         query_transposed = op24.Transpose(query_states, perm=[0, 2, 1, 3])
+         output_shape = op24.Shape(query_transposed)
+         query_3d = op24.Reshape(query_transposed, to_3d_shape)
+         value_3d = op24.Reshape(op24.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+         key_3d = op24.Reshape(op24.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+         cu_seqlens = op24.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+         num_patches = op24.Size(cu_seqlens) - 1
+         seq_axis = op24.Constant(value_ints=[1])
+         seq_axis_int32 = op24.Cast(seq_axis, to=onnx.TensorProto.INT32)
+         seq_attn = op24.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
+         for i_patch in range(num_patches):
+             i_1d = op24.Reshape(i_patch, [1])
+             i_plus_1_1d = i_1d + 1
+             start = op24.Gather(cu_seqlens, i_1d, axis=0)
+             end = op24.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+             query_i = op24.Slice(query_3d, start, end, seq_axis_int32)
+             key_i = op24.Slice(key_3d, start, end, seq_axis_int32)
+             value_i = op24.Slice(value_3d, start, end, seq_axis_int32)
+             mha_output = op24.Attention(
+                 query_i,
+                 key_i,
+                 value_i,
+                 scale=scaling,
+                 q_num_heads=num_heads,
+                 kv_num_heads=num_heads,
+                 softmax_precision=onnx.TensorProto.FLOAT,
+             )
+             seq_attn = op24.SequenceInsert(seq_attn, mha_output)
+         attn_output = op24.ConcatFromSequence(seq_attn, axis=1)
+         attn_output_4d = op24.Reshape(attn_output, output_shape)
+         return attn_output_4d
+
+     @onnxscript.script(opset=onnx_plugs_op)
+     def PackedAttention(
+         query,
+         key,
+         value,
+         cu_seqlens,
+         scaling: float = 0.11180339887498948,
+         num_heads: int = 16,
+     ):
+         num_patches = op.Cast(op.Size(cu_seqlens), to=onnx.TensorProto.INT32) - 1
+         starts = op.Slice(cu_seqlens, [0], [-1], [0])
+         ends = op.Slice(cu_seqlens, [1], [9223372036854775807], [0])
+         lengths = ends - starts
+         max_length = op.ReduceMax(lengths, [0], keepdims=0)  # max_seqlen
+         rows = op.Range(0, num_patches, 1)
+         rows_2d = op.Unsqueeze(rows, [1])
+         cols = op.Range(0, max_length, 1)
+         cols_2d = op.Unsqueeze(cols, [0])
+
+         position_matrix = op.Cast(rows_2d, to=onnx.TensorProto.INT32) * op.Cast(
+             max_length, to=onnx.TensorProto.INT32
+         ) + op.Cast(cols_2d, to=onnx.TensorProto.INT32)
+         position_matrix_shape = op.Shape(position_matrix)
+         token_mask = cols_2d < op.Unsqueeze(lengths, [1])
+         token_mask_1d = op.Reshape(token_mask, [-1])
+         padded_mask_1d = op.Not(token_mask_1d)
+         valid_token_positions = op.Compress(position_matrix, token_mask)
+         padded_token_positions = op.Compress(position_matrix, padded_mask_1d)
+         token_offset_1d = op.Concat(valid_token_positions, padded_token_positions, axis=0)
+         token_offset = op.Reshape(token_offset_1d, position_matrix_shape)
+
+         query_3d = op.Transpose(op.Squeeze(query, [0]), perm=[1, 0, 2])
+         shape_3d = op.Shape(query_3d)
+         query_2d = op.Reshape(query_3d, [0, -1])
+         key_2d = op.Reshape(op.Transpose(op.Squeeze(key, [0]), perm=[1, 0, 2]), [0, -1])
+         value_2d = op.Reshape(op.Transpose(op.Squeeze(value, [0]), perm=[1, 0, 2]), [0, -1])
+
+         packed_attn_output_2d = msft_op.PackedMultiHeadAttention(
+             query_2d,
+             key_2d,
+             value_2d,
+             None,
+             op.Cast(token_offset, to=onnx.TensorProto.INT32),
+             op.Cast(cu_seqlens, to=onnx.TensorProto.INT32),
+             scale=scaling,
+             num_heads=num_heads,
+         )
+         packed_attn_output_3d = op.Reshape(packed_attn_output_2d, shape_3d)
+         return op.Unsqueeze(packed_attn_output_3d, [0])
+
+     def qwen_sdpa_attention(
+         query_states: torch.Tensor,  # F10s1x16xs47x80
+         key_states: torch.Tensor,  # F10s1x16xs47x80
+         value_states: torch.Tensor,  # F10s1x16xs47x80
+         cu_seqlens: torch.Tensor,  # F7su19
+         scaling: float = 0,
+         num_heads: int = 16,
+     ) -> torch.Tensor:
+         """
+         The loop can be removed with the following code
+         but it hits memory overflow for big inputs.
+
+         .. code-block:: python
+
+             # make square mask
+             indices = torch.arange(
+                 cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+             )
+             dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                 cu_seqlens.dtype
+             )
+             dot = dot.sum(dim=0)
+             mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+             bool_mask = mask == 0
+             bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+             torch._check(bool_mask.shape[2] == key_states.shape[2])
+             torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+             attn_output, _ = attention_interface(
+                 self,
+                 query_states,
+                 key_states,
+                 value_states,
+                 attention_mask=bool_mask,
+                 scaling=self.scaling,
+                 dropout=0.0 if not self.training else self.attention_dropout,
+                 is_causal=False,
+                 **kwargs,
+             )
+         """
+         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+         splits = [
+             torch.split(tensor, lengths.tolist(), dim=2)
+             for tensor in (query_states, key_states, value_states)
+         ]
+
+         attn_outputs = [
+             patched_sdpa_attention_forward(
+                 None,
+                 q,
+                 k,
+                 v,
+                 attention_mask=None,
+                 scaling=scaling,
+                 dropout=0.0,
+                 is_causal=False,
+             )[0]
+             for q, k, v in zip(*splits)
+         ]
+         attn_output = torch.cat(attn_outputs, dim=1)
+         return attn_output
+
+     def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
+         first_tensor = next(a for a in args if a is not None)
+         dtype = first_tensor.dtype
+         strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
+         itype = torch_dtype_to_onnx_dtype(dtype)
+         if strategy is not None:
+             return strategy, itype
+         if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
+             if opset >= 24:
+                 return "LOOPA24", itype
+             return "LOOPMHA", itype
+         if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
+             # first_tensor may be a SymbolicTensor (onnx).
+             # is_cuda is not available.
+             if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
+                 return "PACKED", itype
+             return "LOOPMHA", itype
+         raise AssertionError(
+             f"Unable to handle type {torch.dtype} (itype={itype}) "
+             f"on device {torch.device} with opset={opset}"
+         )
+
+     qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
+         qwen_sdpa_attention,
+         lambda qs, *args, **kwargs: torch.empty(
+             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
+             dtype=qs.dtype,
+             device=qs.device,
+         ),
+         {
+             ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
+                 PackedAttention.to_function_proto()
+             ),
+             ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
+             ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                 onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
+             ),
+             ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
+                 LoopMHAAttention.to_function_proto()
+             ),
+             ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                 onnx.TensorProto.FLOAT16,
+                 _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+             ),
+         },
+         n_inputs=4,
+         n_outputs=1,
+         kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+         name="qwen_sdpa_attention_versatile",
+         version_selector=qwen_version_selector,
+     )
+     PLUGS.append(qwen_sdpa_attention_versatile)
+
+     class patched_Qwen2_5_VLForConditionalGeneration:
+         _PATCHES_ = ["prepare_inputs_for_generation"]
+         _PATCHED_CLASS_ = (
+             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
+         )
+
+         def prepare_inputs_for_generation(
+             self,
+             input_ids,
+             past_key_values=None,
+             attention_mask=None,
+             inputs_embeds=None,
+             cache_position=None,
+             position_ids=None,
+             use_cache=True,
+             pixel_values=None,
+             pixel_values_videos=None,
+             image_grid_thw=None,
+             video_grid_thw=None,
+             second_per_grid_ts=None,
+             **kwargs,
+         ):
+             # Overwritten -- in specific circumstances we don't want to
+             # forward image inputs to the model
+             from transformers.generation import GenerationMixin
+
+             model_inputs = GenerationMixin.prepare_inputs_for_generation(
+                 self,
+                 input_ids,
+                 past_key_values=past_key_values,
+                 attention_mask=attention_mask,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 position_ids=position_ids,
+                 pixel_values=pixel_values,
+                 pixel_values_videos=pixel_values_videos,
+                 image_grid_thw=image_grid_thw,
+                 video_grid_thw=video_grid_thw,
+                 second_per_grid_ts=second_per_grid_ts,
+                 use_cache=use_cache,
+                 **kwargs,
+             )
+
+             # Qwen2-5-VL position_ids are prepared with rope_deltas
+             if position_ids is None:
+                 # Calculate RoPE index once per generation in the pre-fill stage only.
+                 # When compiling, we can't check tensor values thus we check only input length
+                 # It is safe to assume that `length!=1` means we're in pre-fill
+                 # because compiled models currently cannot do assisted decoding
+                 if cache_position[0] == 0 or self.model.rope_deltas is None:
+                     vision_positions, rope_deltas = self.model.get_rope_index(
+                         model_inputs.get("input_ids", None),
+                         image_grid_thw=image_grid_thw,
+                         video_grid_thw=video_grid_thw,
+                         second_per_grid_ts=second_per_grid_ts,
+                         attention_mask=attention_mask,
+                     )
+                     self.model.rope_deltas = rope_deltas
+                 # then use the prev pre-calculated rope-deltas to get the correct position ids
+                 elif (
+                     "position_ids" in model_inputs and model_inputs["position_ids"] is not None
+                 ):
+                     batch_size, seq_length = model_inputs["position_ids"].shape
+                     device = model_inputs["position_ids"].device
+                     position_ids = torch.arange(seq_length, device=device)
+                     position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                     delta = cache_position[0] + self.model.rope_deltas
+                     delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                     vision_positions = position_ids + delta.expand_as(position_ids)
+
+                 # Concatenate "text + vision" positions into [4, bs, seq-len]
+                 if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
+                     text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
+                         None, None, :
+                     ]
+                 else:
+                     text_positions = model_inputs["position_ids"][None, ...]
+                 # text_positions = model_inputs["position_ids"][None, ...]
+                 assert vision_positions is not None, "vision_positions are missing"
+                 model_inputs["position_ids"] = torch.cat(
+                     [text_positions, vision_positions], dim=0
+                 )
+
+             if cache_position[0] != 0:
+                 model_inputs["pixel_values"] = None
+                 model_inputs["pixel_values_videos"] = None
+
+             return model_inputs
+
+     class patched_Qwen2_5_VisionTransformerPretrainedModel:
+         _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
+         _PATCHED_CLASS_ = (
+             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
+         )
+
+         def rot_pos_emb(self, grid_thw):
+             pos_ids = []
+             for thw_ in grid_thw:
+                 # PATCHED: avoid unbind
+                 t = thw_[0]
+                 h = thw_[1]
+                 w = thw_[2]
+                 hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+                 hpos_ids = hpos_ids.reshape(
+                     h // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                     w // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                 )
+                 hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+                 hpos_ids = hpos_ids.flatten()
+
+                 wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+                 wpos_ids = wpos_ids.reshape(
+                     h // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                     w // self.spatial_merge_size,
+                     self.spatial_merge_size,
+                 )
+                 wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+                 wpos_ids = wpos_ids.flatten()
+                 pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+             pos_ids = torch.cat(pos_ids, dim=0)
+             max_grid_size = grid_thw[:, 1:].max()
+             rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+             rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+             return rotary_pos_emb
+
+         def get_window_index(self, grid_thw):
+             window_index: list = []  # type: ignore[annotation-unchecked]
+             # PATCHED
+             cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)]  # type: ignore[annotation-unchecked]
+             window_index_id = 0
+             vit_merger_window_size = (
+                 self.window_size // self.spatial_merge_size // self.patch_size
+             )
+
+             for _thw in grid_thw:
+                 # PATCHED: avoid unbind
+                 grid_t = _thw[0]
+                 grid_h = _thw[1]
+                 grid_w = _thw[2]
+                 llm_grid_h, llm_grid_w = (
+                     grid_h // self.spatial_merge_size,
+                     grid_w // self.spatial_merge_size,
+                 )
+                 index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                     grid_t, llm_grid_h, llm_grid_w
+                 )
+                 pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+                 pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+                 num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+                 num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+                 index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+                 index_padded = index_padded.reshape(
+                     grid_t,
+                     num_windows_h,
+                     vit_merger_window_size,
+                     num_windows_w,
+                     vit_merger_window_size,
+                 )
+                 index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                     grid_t,
+                     num_windows_h * num_windows_w,
+                     vit_merger_window_size,
+                     vit_merger_window_size,
+                 )
+                 seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+                 index_padded = index_padded.reshape(-1)
+                 index_new = index_padded[index_padded != -100]
+                 window_index.append(index_new + window_index_id)
+                 cu_seqlens_tmp = (
+                     seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
+                 )
+                 # PATCHED
+                 # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+                 cu_window_seqlens.append(cu_seqlens_tmp)
+                 window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+             window_index = torch.cat(window_index, dim=0)
+
+             return window_index, torch.cat(cu_window_seqlens, dim=0)
+
+         def forward(
+             self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+         ) -> torch.Tensor:
+             """
+             Args:
+                 hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                     The final hidden states of the model.
+                 grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                     The temporal, height and width of feature shape of each image in LLM.
+
+             Returns:
+                 `torch.Tensor`: hidden_states.
+             """
+             hidden_states = self.patch_embed(hidden_states)
+             rotary_pos_emb = self.rot_pos_emb(grid_thw)
+             window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+             # PATCHED
+             # cu_window_seqlens = torch.tensor(
+             #     cu_window_seqlens,
+             #     device=hidden_states.device,
+             #     dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+             # )
+             cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
+             cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+             seq_len, _ = hidden_states.size()
+             hidden_states = hidden_states.reshape(
+                 seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+             )
+             hidden_states = hidden_states[window_index, :, :]
+             hidden_states = hidden_states.reshape(seq_len, -1)
+             rotary_pos_emb = rotary_pos_emb.reshape(
+                 seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+             )
+             rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+             rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+             emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+             position_embeddings = (emb.cos(), emb.sin())
+
+             cu_seqlens = torch.repeat_interleave(
+                 grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+             ).cumsum(
+                 dim=0,
+                 # Select dtype based on the following factors:
+                 # - FA2 requires that cu_seqlens_q must have dtype int32
+                 # - torch.onnx.export requires that cu_seqlens_q must have same dtype
+                 #   as grid_thw
+                 # See https://github.com/huggingface/transformers/pull/34852
+                 # for more information
+                 dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+             )
+             cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+             for layer_num, blk in enumerate(self.blocks):
+                 if layer_num in self.fullatt_block_indexes:
+                     cu_seqlens_now = cu_seqlens
+                 else:
+                     cu_seqlens_now = cu_window_seqlens
+
+                 hidden_states = blk(
+                     hidden_states,
+                     cu_seqlens=cu_seqlens_now,
+                     position_embeddings=position_embeddings,
+                     **kwargs,
+                 )
+                 if STOPAT is not None and layer_num > STOPAT:
+                     break
+
+             hidden_states = self.merger(hidden_states)
+             reverse_indices = torch.argsort(window_index)
+             hidden_states = hidden_states[reverse_indices, :]
+             return hidden_states
+
+     class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
+         def forward(
+             self,
+             start_end,
+             query_states,
+             key_states,
+             value_states,
+             scaling: float = 1.0,
+             dropout: float = 0.0,
+             **kwargs,
+         ):
+             a = start_end[0].item()
+             b = start_end[1].item()
+             q = query_states[:, :, a:b, :]
+             k = key_states[:, :, a:b, :]
+             v = value_states[:, :, a:b, :]
+             return patched_sdpa_attention_forward(
+                 self,
+                 q,
+                 k,
+                 v,
+                 attention_mask=None,
+                 scaling=scaling,
+                 dropout=dropout,
+                 is_causal=False,
+                 **kwargs,
+             )[0]
+
+     class patched_Qwen2_5_VLVisionAttention:
+         _PATCHES_ = ["forward"]
+         _PATCHED_CLASS_ = (
+             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
+         )
+         STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None)  # noqa: E731
+
+         def forward(
+             self,
+             hidden_states: torch.Tensor,
+             cu_seqlens: torch.Tensor,
+             rotary_pos_emb: Optional[torch.Tensor] = None,
+             position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+             **kwargs,
+         ) -> torch.Tensor:
+             seq_length = hidden_states.shape[0]
+             # PATCHED: avoid the use of unbind
+             qkv = (
+                 self.qkv(hidden_states)
+                 .reshape(seq_length, 3, self.num_heads, -1)
+                 .permute(1, 0, 2, 3)
+             )
+
+             query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
+             cos, sin = position_embeddings
+
+             # This part should be moved into the loop
+             # iteration to enable fusion inside the loop.
+             query_states, key_states = (
+                 transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
+                     query_states, key_states, cos, sin
+                 )
+             )
+
+             query_states = query_states.transpose(0, 1).unsqueeze(0)
+             key_states = key_states.transpose(0, 1).unsqueeze(0)
+             value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+             attention_interface: Callable = (
+                 transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
+             )
+             if self.config._attn_implementation != "eager":
+                 # PATCHED
+                 # attention_interface = ALL_ATTENTION_FUNCTIONS[
+                 #     self.config._attn_implementation]
+                 attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                     self.config._attn_implementation
+                 ]
+
+             is_sdpa_or_eager = (
+                 attention_interface
+                 is transformers.integrations.sdpa_attention.sdpa_attention_forward
+                 or attention_interface is patched_sdpa_attention_forward
+                 or attention_interface
+                 is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
+             )
+             if is_sdpa_or_eager:
+                 attn_output = qwen_sdpa_attention_versatile(
+                     query_states,
+                     key_states,
+                     value_states,
+                     cu_seqlens,
+                     self.scaling,
+                     self.num_heads,
+                 )
+             elif _is_torchdynamo_exporting():
+                 if self.config._attn_implementation == "flash_attention_2":
+                     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                     attn_output = torch.onnx.ops.symbolic(
+                         "custom::qwen25_flash_attention",
+                         (
+                             query_states,
+                             key_states,
+                             value_states,
+                             cu_seqlens,
+                             cu_seqlens,
+                             max_seqlen,
+                             max_seqlen,
+                             torch.tensor(self.scaling, dtype=torch.float32),
+                         ),
+                         dtype=query_states.dtype,
+                         shape=(
+                             query_states.shape[0],  # batch_size
+                             query_states.shape[2],  # sequence_length (total patches)
+                             query_states.shape[1],  # num_heads
+                             query_states.shape[3],  # head_size
+                         ),
+                         version=1,
+                     )
+                 else:
+                     raise NotImplementedError(
+                         f"No corresponding export strategy for implementation "
+                         f"{self.config._attn_implementation!r}, "
+                         f"(use QWEN25ATTENTION to change it), and attention_interface="
+                         f"{attention_interface!r} (use sdpa)"
+                     )
+             elif self.config._attn_implementation == "flash_attention_2":
+                 # Flash Attention 2: Use cu_seqlens for variable length attention
+                 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                 attn_output, _ = attention_interface(
+                     self,
+                     query_states,
+                     key_states,
+                     value_states,
+                     attention_mask=None,
+                     scaling=self.scaling,
+                     dropout=0.0 if not self.training else self.attention_dropout,
+                     cu_seq_lens_q=cu_seqlens,
+                     cu_seq_lens_k=cu_seqlens,
+                     max_length_q=max_seqlen,
+                     max_length_k=max_seqlen,
+                     is_causal=False,
+                     **kwargs,
+                 )
+             else:
+                 # Other implementations: Process each chunk separately
+                 # = qwen_sdpa_attention
+                 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+                 splits = [
+                     torch.split(tensor, lengths.tolist(), dim=2)
+                     for tensor in (query_states, key_states, value_states)
+                 ]
+
+                 attn_outputs = [
+                     attention_interface(
+                         self,
+                         q,
+                         k,
+                         v,
+                         attention_mask=None,
+                         scaling=self.scaling,
+                         dropout=0.0 if not self.training else self.attention_dropout,
+                         is_causal=False,
+                         **kwargs,
+                     )[0]
+                     for q, k, v in zip(*splits)
+                 ]
+                 attn_output = torch.cat(attn_outputs, dim=1)
+
+             attn_output = attn_output.reshape(seq_length, -1).contiguous()
+             attn_output = self.proj(attn_output)
+             return attn_output