onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +35 -5
  4. onnx_diagnostic/export/control_flow.py +511 -0
  5. onnx_diagnostic/export/control_flow_research.py +135 -0
  6. onnx_diagnostic/ext_test_case.py +33 -9
  7. onnx_diagnostic/helpers/cache_helper.py +217 -203
  8. onnx_diagnostic/helpers/helper.py +6 -2
  9. onnx_diagnostic/helpers/log_helper.py +39 -5
  10. onnx_diagnostic/helpers/memory_peak.py +2 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +55 -3
  12. onnx_diagnostic/helpers/onnx_helper.py +13 -16
  13. onnx_diagnostic/helpers/rt_helper.py +579 -15
  14. onnx_diagnostic/helpers/torch_helper.py +5 -0
  15. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  16. onnx_diagnostic/tasks/text2text_generation.py +1 -0
  17. onnx_diagnostic/tasks/text_generation.py +84 -54
  18. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
  22. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +563 -61
  23. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  25. onnx_diagnostic/torch_models/validate.py +620 -213
  26. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
  27. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0
@@ -608,20 +608,20 @@ class patched_GenerationMixin:
         # if input_ids.shape[1] != cache_position.shape[0]:
         # input_ids = input_ids[:, cache_position]
         def branch_1(inputs_embeds, cache_position):
-            return inputs_embeds[:, -cache_position.shape[0] :]
+            return inputs_embeds[:, -cache_position.shape[0] :].clone()

         def branch_2(input_ids, cache_position):
-            return input_ids[:, -cache_position.shape[0] :]
+            return input_ids[:, -cache_position.shape[0] :].clone()

         def branch_3(input_ids, cache_position):
-            return input_ids[:, cache_position]
+            return input_ids[:, cache_position].clone()

         inputs_embeds, input_ids = torch.cond(
             input_ids.shape[1] == 0,
             (
                 lambda input_ids, inputs_embeds, cache_position: (
                     branch_1(inputs_embeds, cache_position),
-                    input_ids,
+                    input_ids.clone(),
                 )
             ),
             (
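
Note on the hunk above: `torch.cond` requires both branches to return tensors with matching metadata, and a branch may not hand back a view that aliases its inputs, which is what the added `.clone()` calls take care of. The sketch below reproduces the idiom outside of transformers; `pick` and its shapes are made up for illustration and are not code from the package.

import torch


def pick(x):
    # torch.cond branches must return tensors with identical shape/dtype, and
    # returning a plain view (a slice) of an input can fail under torch.export,
    # so each branch clones to materialize a fresh tensor.
    def keep_first_half(x):
        return x[:, : x.shape[1] // 2].clone()

    def keep_last_half(x):
        return x[:, x.shape[1] // 2 :].clone()

    return torch.cond(x.sum() > 0, keep_first_half, keep_last_half, (x,))


# Example: both branches yield a (2, 4) tensor for a (2, 8) input.
print(pick(torch.randn(2, 8)).shape)
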
@@ -1401,6 +1401,18 @@ def patched_sdpa_attention_forward(
     is_causal = attention_mask is None and is_causal

     if not is_causal:
+        torch._check(query.shape[0] > 0)
+        torch._check(query.shape[1] > 0)
+        torch._check(query.shape[2] > 0)
+        torch._check(query.shape[3] > 0)
+        torch._check(key.shape[0] > 0)
+        torch._check(key.shape[1] > 0)
+        torch._check(key.shape[2] > 0)
+        torch._check(key.shape[3] > 0)
+        torch._check(value.shape[0] > 0)
+        torch._check(value.shape[1] > 0)
+        torch._check(value.shape[2] > 0)
+        torch._check(value.shape[3] > 0)
         return (
             torch.nn.functional.scaled_dot_product_attention(
                 query,
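
The `torch._check` calls added here record, for the exporter, the assumption that every dimension of `query`, `key` and `value` is strictly positive, so the symbolic-shape machinery does not have to keep empty-tensor corner cases alive. A minimal, self-contained version of the same idiom (the function is illustrative, not from the package):

import torch


def scaled_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Under torch.export the sizes below may be symbolic; _check registers the
    # assumption that they are non-zero instead of letting the trace branch on it.
    torch._check(q.shape[-2] > 0)
    torch._check(k.shape[-2] > 0)
    return (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)


print(scaled_scores(torch.randn(2, 4, 5, 8), torch.randn(2, 4, 7, 8)).shape)
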
@@ -1452,7 +1464,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=True,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
@@ -1461,7 +1473,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=False,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         [query, key, value],
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
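
The two hunks above append `.contiguous()` to the tensors returned by the `torch.cond` branches around `scaled_dot_product_attention`, presumably to normalize the memory layout so both branches hand back outputs with identical metadata. A rough standalone sketch of the same pattern, with illustrative names only:

import torch
import torch.nn.functional as F


def sdpa_maybe_causal(query, key, value, use_causal):
    # Ending each branch with .contiguous() keeps the two possible outputs
    # layout-compatible, which matters when torch.cond is traced for export.
    return torch.cond(
        use_causal,
        lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True).contiguous(),
        lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous(),
        (query, key, value),
    )


q = k = v = torch.randn(1, 2, 6, 8)
print(sdpa_maybe_causal(q, k, v, torch.tensor(True)).shape)
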
@@ -1917,67 +1929,557 @@ def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
     return mask * filt


-class patched_VisionAttention(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+try:
+    import transformers.models.qwen2_vl

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = (
-            self.qkv(hidden_states)
-            .reshape(seq_length, 3, self.num_heads, -1)
-            .permute(1, 0, 2, 3)
-            .unbind(0)
+    patch_qwen2 = True
+except ImportError:
+    patch_qwen2 = False
+
+if patch_qwen2:
+
+    class patched_VisionAttention(torch.nn.Module):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+        def forward(
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: Optional[torch.Tensor] = None,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        ) -> torch.Tensor:
+            seq_length = hidden_states.shape[0]
+            q, k, v = (
+                self.qkv(hidden_states)
+                .reshape(seq_length, 3, self.num_heads, -1)
+                .permute(1, 0, 2, 3)
+                .unbind(0)
+            )
+            if position_embeddings is None:
+                transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                    "The attention layers in this model are transitioning from "
+                    " computing the RoPE embeddings internally "
+                    "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                    "to using externally computed "
+                    "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                    " In v4.54 `rotary_pos_emb` will be "
+                    "removed and `position_embeddings` will be mandatory."
+                )
+                emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+                cos = emb.cos()
+                sin = emb.sin()
+            else:
+                cos, sin = position_embeddings
+            q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+                q, k, cos, sin
+            )
+
+            attention_mask = torch.full(
+                [1, seq_length, seq_length],
+                torch.finfo(q.dtype).min,
+                device=q.device,
+                dtype=q.dtype,
+            )
+            # for i in range(1, len(cu_seqlens)):
+            # attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+            # cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+            attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+            q = q.transpose(0, 1)
+            k = k.transpose(0, 1)
+            v = v.transpose(0, 1)
+            attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_output = torch.matmul(attn_weights, v)
+            attn_output = attn_output.transpose(0, 1)
+            attn_output = attn_output.reshape(seq_length, -1)
+            attn_output = self.proj(attn_output)
+            return attn_output
+
+
+try:
+    import transformers.models.qwen2_5_vl
+    import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
+
+    patch_qwen2_5 = True
+except ImportError:
+    patch_qwen2_5 = False
+
+if patch_qwen2_5:
+    import torch.nn.functional as F
+
+    use_loop_for_attention_in_qwen_2_5 = False
+
+    class patched_Qwen2_5_VLForConditionalGeneration:
+        _PATCHES_ = ["prepare_inputs_for_generation"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
         )
-        if position_embeddings is None:
-            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
-                "The attention layers in this model are transitioning from "
-                " computing the RoPE embeddings internally "
-                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
-                "to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
-                " In v4.54 `rotary_pos_emb` will be "
-                "removed and `position_embeddings` will be mandatory."
+
+        def prepare_inputs_for_generation(
+            self,
+            input_ids,
+            past_key_values=None,
+            attention_mask=None,
+            inputs_embeds=None,
+            cache_position=None,
+            position_ids=None,
+            use_cache=True,
+            pixel_values=None,
+            pixel_values_videos=None,
+            image_grid_thw=None,
+            video_grid_thw=None,
+            second_per_grid_ts=None,
+            **kwargs,
+        ):
+            # Overwritten -- in specific circumstances we don't want to f
+            # forward image inputs to the model
+            from transformers.generation import GenerationMixin
+
+            model_inputs = GenerationMixin.prepare_inputs_for_generation(
+                self,
+                input_ids,
+                past_key_values=past_key_values,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                position_ids=position_ids,
+                pixel_values=pixel_values,
+                pixel_values_videos=pixel_values_videos,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                second_per_grid_ts=second_per_grid_ts,
+                use_cache=use_cache,
+                **kwargs,
             )
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
-            q, k, cos, sin
+
+            # Qwen2-5-VL position_ids are prepared with rope_deltas
+            if position_ids is None:
+                # Calculate RoPE index once per generation in the pre-fill stage only.
+                # When compiling, we can't check tensor values thus we check only input length
+                # It is safe to assume that `length!=1` means we're in pre-fill
+                # because compiled models currently cannot do assisted decoding
+                if cache_position[0] == 0 or self.model.rope_deltas is None:
+                    vision_positions, rope_deltas = self.model.get_rope_index(
+                        model_inputs.get("input_ids", None),
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        attention_mask=attention_mask,
+                    )
+                    self.model.rope_deltas = rope_deltas
+                # then use the prev pre-calculated rope-deltas to get the correct position ids
+                elif (
+                    "position_ids" in model_inputs and model_inputs["position_ids"] is not None
+                ):
+                    batch_size, seq_length = model_inputs["position_ids"].shape
+                    device = model_inputs["position_ids"].device
+                    position_ids = torch.arange(seq_length, device=device)
+                    position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                    delta = cache_position[0] + self.model.rope_deltas
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                    vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
+                text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
+                    None, None, :
+                ]
+            else:
+                text_positions = model_inputs["position_ids"][None, ...]
+            # text_positions = model_inputs["position_ids"][None, ...]
+            assert vision_positions is not None, "vision_positions are missing"
+            model_inputs["position_ids"] = torch.cat(
+                [text_positions, vision_positions], dim=0
+            )
+
+            if cache_position[0] != 0:
+                model_inputs["pixel_values"] = None
+                model_inputs["pixel_values_videos"] = None
+
+            return model_inputs
+
+    class patched_Qwen2_5_VisionTransformerPretrainedModel:
+        _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
         )

-        attention_mask = torch.full(
-            [1, seq_length, seq_length],
-            torch.finfo(q.dtype).min,
-            device=q.device,
-            dtype=q.dtype,
+        def rot_pos_emb(self, grid_thw):
+            pos_ids = []
+            for thw_ in grid_thw:
+                # PATCHED: avoid unbind
+                t = thw_[0]
+                h = thw_[1]
+                w = thw_[2]
+                hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+                hpos_ids = hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+                hpos_ids = hpos_ids.flatten()
+
+                wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+                wpos_ids = wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+                wpos_ids = wpos_ids.flatten()
+                pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            pos_ids = torch.cat(pos_ids, dim=0)
+            max_grid_size = grid_thw[:, 1:].max()
+            rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+            rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+            return rotary_pos_emb
+
+        def get_window_index(self, grid_thw):
+            window_index: list = []  # type: ignore[annotation-unchecked]
+            # PATCHED
+            cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)]  # type: ignore[annotation-unchecked]
+            window_index_id = 0
+            vit_merger_window_size = (
+                self.window_size // self.spatial_merge_size // self.patch_size
+            )
+
+            for _thw in grid_thw:
+                # PATCHED: avoid unbind
+                grid_t = _thw[0]
+                grid_h = _thw[1]
+                grid_w = _thw[2]
+                llm_grid_h, llm_grid_w = (
+                    grid_h // self.spatial_merge_size,
+                    grid_w // self.spatial_merge_size,
+                )
+                index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                    grid_t, llm_grid_h, llm_grid_w
+                )
+                pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+                pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+                num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+                num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+                index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+                index_padded = index_padded.reshape(
+                    grid_t,
+                    num_windows_h,
+                    vit_merger_window_size,
+                    num_windows_w,
+                    vit_merger_window_size,
+                )
+                index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                    grid_t,
+                    num_windows_h * num_windows_w,
+                    vit_merger_window_size,
+                    vit_merger_window_size,
+                )
+                seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+                index_padded = index_padded.reshape(-1)
+                index_new = index_padded[index_padded != -100]
+                window_index.append(index_new + window_index_id)
+                cu_seqlens_tmp = (
+                    seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
+                )
+                # PATCHED
+                # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+                cu_window_seqlens.append(cu_seqlens_tmp)
+                window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+            window_index = torch.cat(window_index, dim=0)
+
+            return window_index, torch.cat(cu_window_seqlens, dim=0)
+
+        def forward(
+            self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+        ) -> torch.Tensor:
+            """
+            Args:
+                hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                    The final hidden states of the model.
+                grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                    The temporal, height and width of feature shape of each image in LLM.
+
+            Returns:
+                `torch.Tensor`: hidden_states.
+            """
+            hidden_states = self.patch_embed(hidden_states)
+            rotary_pos_emb = self.rot_pos_emb(grid_thw)
+            window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+            # PATCHED
+            # cu_window_seqlens = torch.tensor(
+            # cu_window_seqlens,
+            # device=hidden_states.device,
+            # dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+            # )
+            cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
+            cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+            seq_len, _ = hidden_states.size()
+            hidden_states = hidden_states.reshape(
+                seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+            )
+            hidden_states = hidden_states[window_index, :, :]
+            hidden_states = hidden_states.reshape(seq_len, -1)
+            rotary_pos_emb = rotary_pos_emb.reshape(
+                seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+            )
+            rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+            rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            position_embeddings = (emb.cos(), emb.sin())
+
+            cu_seqlens = torch.repeat_interleave(
+                grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+            ).cumsum(
+                dim=0,
+                # Select dtype based on the following factors:
+                # - FA2 requires that cu_seqlens_q must have dtype int32
+                # - torch.onnx.export requires that cu_seqlens_q must have same dtype
+                # as grid_thw
+                # See https://github.com/huggingface/transformers/pull/34852
+                # for more information
+                dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+            )
+            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+            for layer_num, blk in enumerate(self.blocks):
+                if layer_num in self.fullatt_block_indexes:
+                    cu_seqlens_now = cu_seqlens
+                else:
+                    cu_seqlens_now = cu_window_seqlens
+
+                hidden_states = blk(
+                    hidden_states,
+                    cu_seqlens=cu_seqlens_now,
+                    position_embeddings=position_embeddings,
+                    **kwargs,
+                )
+
+            hidden_states = self.merger(hidden_states)
+            reverse_indices = torch.argsort(window_index)
+            hidden_states = hidden_states[reverse_indices, :]
+            return hidden_states
+
+    class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
+        def forward(
+            self,
+            start_end,
+            query_states,
+            key_states,
+            value_states,
+            scaling: float = 1.0,
+            dropout: float = 0.0,
+            **kwargs,
+        ):
+            a = start_end[0].item()
+            b = start_end[1].item()
+            q = query_states[:, :, a:b, :]
+            k = key_states[:, :, a:b, :]
+            v = value_states[:, :, a:b, :]
+            return patched_sdpa_attention_forward(
+                self,
+                q,
+                k,
+                v,
+                attention_mask=None,
+                scaling=scaling,
+                dropout=dropout,
+                is_causal=False,
+                **kwargs,
+            )[0]
+
+    class patched_Qwen2_5_VLVisionAttention:
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
         )
-        # for i in range(1, len(cu_seqlens)):
-        # attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
-        # cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
-
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights + attention_mask
-        attn_weights = torch.nn.functional.softmax(
-            attn_weights, dim=-1, dtype=torch.float32
-        ).to(q.dtype)
-        attn_output = torch.matmul(attn_weights, v)
-        attn_output = attn_output.transpose(0, 1)
-        attn_output = attn_output.reshape(seq_length, -1)
-        attn_output = self.proj(attn_output)
-        return attn_output
+
+        def forward(
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: Optional[torch.Tensor] = None,
+            position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+            **kwargs,
+        ) -> torch.Tensor:
+            seq_length = hidden_states.shape[0]
+            # PATCHED: avoid the use of unbind
+            qkv = (
+                self.qkv(hidden_states)
+                .reshape(seq_length, 3, self.num_heads, -1)
+                .permute(1, 0, 2, 3)
+            )
+
+            query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
+            cos, sin = position_embeddings
+
+            # This part should be moved into the loop
+            # iteration to enable fusion inside the loop.
+            query_states, key_states = (
+                transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
+                    query_states, key_states, cos, sin
+                )
+            )
+
+            query_states = query_states.transpose(0, 1).unsqueeze(0)
+            key_states = key_states.transpose(0, 1).unsqueeze(0)
+            value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+            attention_interface: Callable = (
+                transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
+            )
+            if self.config._attn_implementation != "eager":
+                # PATCHED
+                # attention_interface = ALL_ATTENTION_FUNCTIONS[
+                # self.config._attn_implementation]
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
+            if (
+                self.config._attn_implementation == "flash_attention_2"
+                and _is_torchdynamo_exporting()
+            ):
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                attn_output = torch.onnx.ops.symbolic(
+                    "custom::qwen25_attention",
+                    (
+                        query_states,
+                        key_states,
+                        value_states,
+                        cu_seqlens,
+                        cu_seqlens,
+                        max_seqlen,
+                        max_seqlen,
+                        torch.tensor(self.scaling, dtype=torch.float32),
+                    ),
+                    dtype=query_states.dtype,
+                    shape=(
+                        key_states.shape[0],
+                        value_states.shape[1],
+                        max_seqlen,
+                        value_states.shape[-1],
+                    ),
+                    version=1,
+                )
+            elif self.config._attn_implementation == "flash_attention_2":
+                # Flash Attention 2: Use cu_seqlens for variable length attention
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                attn_output, _ = attention_interface(
+                    self,
+                    query_states,
+                    key_states,
+                    value_states,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    cu_seq_lens_q=cu_seqlens,
+                    cu_seq_lens_k=cu_seqlens,
+                    max_length_q=max_seqlen,
+                    max_length_k=max_seqlen,
+                    is_causal=False,
+                    **kwargs,
+                )
+            elif _is_torchdynamo_exporting():
+                if (
+                    attention_interface
+                    is transformers.integrations.sdpa_attention.sdpa_attention_forward
+                ):
+                    attention_interface = patched_sdpa_attention_forward
+
+                if use_loop_for_attention_in_qwen_2_5:
+
+                    def _iteration(start_end, query_states, key_states, value_states):
+                        return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+                            self,
+                            start_end,
+                            query_states,
+                            key_states,
+                            value_states,
+                            scaling=self.scaling,
+                            dropout=0.0 if not self.training else self.attention_dropout,
+                        )
+
+                    starts = cu_seqlens[:-1]
+                    ends = cu_seqlens[1:]
+                    # cu_seqlens = [0, 10, 14, 27]
+                    # starts: [0, 10, 14]
+                    # ends: [10, 14, 17]
+                    # starts_ends: [[0, 10], [10, 14], [14, 27]]
+                    starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
+                    attn_outputs = [
+                        _iteration(start_end, query_states, key_states, value_states)
+                        for start_end in starts_ends
+                    ]
+                    # attn_outputs = torch._higher_order_ops.while_loop(
+                    # attn_outputs = torch.ops.higher_order.while_loop(
+                    # (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
+                    # _iteration,
+                    # (torch.tensor(0),
+                    # starts_ends, query_states, key_states, value_states), tuple(),
+                    # )
+                    attn_output = torch.cat(attn_outputs, dim=1)
+                else:
+                    # make square mask
+                    indices = torch.arange(
+                        cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+                    )
+                    dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                        cu_seqlens.dtype
+                    )
+                    dot = dot.sum(dim=0)
+                    mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+                    bool_mask = mask == 0
+                    bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+                    torch._check(bool_mask.shape[2] == key_states.shape[2])
+                    torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+                    attn_output, _ = attention_interface(
+                        self,
+                        query_states,
+                        key_states,
+                        value_states,
+                        attention_mask=bool_mask,
+                        scaling=self.scaling,
+                        dropout=0.0 if not self.training else self.attention_dropout,
+                        is_causal=False,
+                        **kwargs,
+                    )
+            else:
+                # Other implementations: Process each chunk separately
+                lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+                splits = [
+                    torch.split(tensor, lengths.tolist(), dim=2)
+                    for tensor in (query_states, key_states, value_states)
+                ]
+
+                attn_outputs = [
+                    attention_interface(
+                        self,
+                        q,
+                        k,
+                        v,
+                        attention_mask=None,
+                        scaling=self.scaling,
+                        dropout=0.0 if not self.training else self.attention_dropout,
+                        is_causal=False,
+                        **kwargs,
+                    )[0]
+                    for q, k, v in zip(*splits)
+                ]
+                attn_output = torch.cat(attn_outputs, dim=1)
+
+            attn_output = attn_output.reshape(seq_length, -1).contiguous()
+            attn_output = self.proj(attn_output)
+            return attn_output


 try:
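
For reference, the block-diagonal attention mask that the patched `Qwen2_5_VLVisionAttention.forward` derives from `cu_seqlens` (the `dot` / `mask == 0` block in the hunk above) can be reproduced in isolation. The snippet below uses made-up segment lengths and is an illustration of the idea, not code taken from the package.

import torch

# cu_seqlens marks segment boundaries, e.g. three segments of lengths 3, 2 and 4.
cu_seqlens = torch.tensor([0, 3, 5, 9])

indices = torch.arange(cu_seqlens.max(), dtype=cu_seqlens.dtype)
# For each position, count how many boundaries are <= position: that count is its segment id.
segment_id = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(cu_seqlens.dtype).sum(dim=0)
# Two positions may attend to each other only when they share a segment id.
bool_mask = (segment_id.unsqueeze(1) - segment_id.unsqueeze(0)) == 0
print(bool_mask.to(torch.int64))  # 9x9 block-diagonal mask with blocks of size 3, 2 and 4
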