onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +35 -5
- onnx_diagnostic/export/control_flow.py +511 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/ext_test_case.py +33 -9
- onnx_diagnostic/helpers/cache_helper.py +217 -203
- onnx_diagnostic/helpers/helper.py +6 -2
- onnx_diagnostic/helpers/log_helper.py +39 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +55 -3
- onnx_diagnostic/helpers/onnx_helper.py +13 -16
- onnx_diagnostic/helpers/rt_helper.py +579 -15
- onnx_diagnostic/helpers/torch_helper.py +5 -0
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/text2text_generation.py +1 -0
- onnx_diagnostic/tasks/text_generation.py +84 -54
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +563 -61
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +620 -213
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +30 -28
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0
@@ -608,20 +608,20 @@ class patched_GenerationMixin:
         # if input_ids.shape[1] != cache_position.shape[0]:
         # input_ids = input_ids[:, cache_position]
         def branch_1(inputs_embeds, cache_position):
-            return inputs_embeds[:, -cache_position.shape[0] :]
+            return inputs_embeds[:, -cache_position.shape[0] :].clone()

         def branch_2(input_ids, cache_position):
-            return input_ids[:, -cache_position.shape[0] :]
+            return input_ids[:, -cache_position.shape[0] :].clone()

         def branch_3(input_ids, cache_position):
-            return input_ids[:, cache_position]
+            return input_ids[:, cache_position].clone()

         inputs_embeds, input_ids = torch.cond(
             input_ids.shape[1] == 0,
             (
                 lambda input_ids, inputs_embeds, cache_position: (
                     branch_1(inputs_embeds, cache_position),
-                    input_ids,
+                    input_ids.clone(),
                 )
             ),
             (
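Note on this hunk: 0.8.2 clones every value returned from the `torch.cond` branches instead of returning views of the inputs. A minimal standalone sketch of the pattern (illustrative only, not code from the package; the function names are made up):

```python
import torch

# Illustrative sketch only (not package code). torch.cond branch outputs
# should be fresh tensors rather than views of the branch inputs; cloning
# the slice, as the patch above does, avoids aliasing issues during export.
def keep_last(x, cache_position):
    return x[:, -cache_position.shape[0]:].clone()

def gather(x, cache_position):
    return x[:, cache_position].clone()

x = torch.arange(12).reshape(1, 12)
cache_position = torch.tensor([10, 11])
# Both branches return tensors with identical shape and dtype, as torch.cond requires.
out = torch.cond(torch.tensor(True), keep_last, gather, (x, cache_position))
print(out)  # tensor([[10, 11]])
```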
@@ -1401,6 +1401,18 @@ def patched_sdpa_attention_forward(
     is_causal = attention_mask is None and is_causal

     if not is_causal:
+        torch._check(query.shape[0] > 0)
+        torch._check(query.shape[1] > 0)
+        torch._check(query.shape[2] > 0)
+        torch._check(query.shape[3] > 0)
+        torch._check(key.shape[0] > 0)
+        torch._check(key.shape[1] > 0)
+        torch._check(key.shape[2] > 0)
+        torch._check(key.shape[3] > 0)
+        torch._check(value.shape[0] > 0)
+        torch._check(value.shape[1] > 0)
+        torch._check(value.shape[2] > 0)
+        torch._check(value.shape[3] > 0)
         return (
             torch.nn.functional.scaled_dot_product_attention(
                 query,
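The added `torch._check` calls register runtime assertions that the exporter can treat as guards on the symbolic shapes, ruling out zero-sized query/key/value dimensions before the SDPA call. A hedged, standalone sketch of the same idiom (not package code):

```python
import torch

# Illustrative sketch only (not package code): torch._check records an
# assertion that holds at runtime and that torch.export can use as a guard
# on symbolic dimensions, e.g. "this attention dimension is never empty".
def guarded_sdpa(query, key, value):
    for t in (query, key, value):
        for d in range(t.dim()):
            torch._check(t.shape[d] > 0)
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

q = k = v = torch.randn(1, 2, 5, 8)
print(guarded_sdpa(q, k, v).shape)  # torch.Size([1, 2, 5, 8])
```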
@@ -1452,7 +1464,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=True,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
@@ -1461,7 +1473,7 @@ def patched_sdpa_attention_forward(
             scale=scaling,
             is_causal=False,
             **sdpa_kwargs,
-        ),
+        ).contiguous(),
         [query, key, value],
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
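Both hunks above apply the same fix: the output of each `torch.cond` branch is normalized with `.contiguous()`, so the two SDPA calls produce tensors with matching layout regardless of which branch runs. A standalone sketch of the resulting pattern (an assumption-laden illustration, not package code):

```python
import torch

# Illustrative sketch only (not package code): both torch.cond branches call
# SDPA and return a contiguous tensor, keeping the branch outputs interchangeable.
def attend(query, key, value, causal_flag):
    return torch.cond(
        causal_flag,
        lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True
        ).contiguous(),
        lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=False
        ).contiguous(),
        (query, key, value),
    )

q = k = v = torch.randn(1, 2, 5, 8)
print(attend(q, k, v, torch.tensor(True)).shape)  # torch.Size([1, 2, 5, 8])
```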
@@ -1917,67 +1929,557 @@ def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
     return mask * filt


-
-
-    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+try:
+    import transformers.models.qwen2_vl

-
-
-
-
-
-
-
-
-
-
-
-
-
+    patch_qwen2 = True
+except ImportError:
+    patch_qwen2 = False
+
+if patch_qwen2:
+
+    class patched_VisionAttention(torch.nn.Module):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+        def forward(
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: Optional[torch.Tensor] = None,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        ) -> torch.Tensor:
+            seq_length = hidden_states.shape[0]
+            q, k, v = (
+                self.qkv(hidden_states)
+                .reshape(seq_length, 3, self.num_heads, -1)
+                .permute(1, 0, 2, 3)
+                .unbind(0)
+            )
+            if position_embeddings is None:
+                transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                    "The attention layers in this model are transitioning from "
+                    " computing the RoPE embeddings internally "
+                    "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                    "to using externally computed "
+                    "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                    " In v4.54 `rotary_pos_emb` will be "
+                    "removed and `position_embeddings` will be mandatory."
+                )
+                emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+                cos = emb.cos()
+                sin = emb.sin()
+            else:
+                cos, sin = position_embeddings
+            q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+                q, k, cos, sin
+            )
+
+            attention_mask = torch.full(
+                [1, seq_length, seq_length],
+                torch.finfo(q.dtype).min,
+                device=q.device,
+                dtype=q.dtype,
+            )
+            # for i in range(1, len(cu_seqlens)):
+            #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+            #                    cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+            attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+            q = q.transpose(0, 1)
+            k = k.transpose(0, 1)
+            v = v.transpose(0, 1)
+            attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_output = torch.matmul(attn_weights, v)
+            attn_output = attn_output.transpose(0, 1)
+            attn_output = attn_output.reshape(seq_length, -1)
+            attn_output = self.proj(attn_output)
+            return attn_output
+
+
+try:
+    import transformers.models.qwen2_5_vl
+    import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
+
+    patch_qwen2_5 = True
+except ImportError:
+    patch_qwen2_5 = False
+
+if patch_qwen2_5:
+    import torch.nn.functional as F
+
+    use_loop_for_attention_in_qwen_2_5 = False
+
+    class patched_Qwen2_5_VLForConditionalGeneration:
+        _PATCHES_ = ["prepare_inputs_for_generation"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
         )
-
-
-
-
-
-
-
-
-
+
+        def prepare_inputs_for_generation(
+            self,
+            input_ids,
+            past_key_values=None,
+            attention_mask=None,
+            inputs_embeds=None,
+            cache_position=None,
+            position_ids=None,
+            use_cache=True,
+            pixel_values=None,
+            pixel_values_videos=None,
+            image_grid_thw=None,
+            video_grid_thw=None,
+            second_per_grid_ts=None,
+            **kwargs,
+        ):
+            # Overwritten -- in specific circumstances we don't want to f
+            # forward image inputs to the model
+            from transformers.generation import GenerationMixin
+
+            model_inputs = GenerationMixin.prepare_inputs_for_generation(
+                self,
+                input_ids,
+                past_key_values=past_key_values,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                position_ids=position_ids,
+                pixel_values=pixel_values,
+                pixel_values_videos=pixel_values_videos,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                second_per_grid_ts=second_per_grid_ts,
+                use_cache=use_cache,
+                **kwargs,
             )
-
-
-
-
-
-
-
+
+            # Qwen2-5-VL position_ids are prepared with rope_deltas
+            if position_ids is None:
+                # Calculate RoPE index once per generation in the pre-fill stage only.
+                # When compiling, we can't check tensor values thus we check only input length
+                # It is safe to assume that `length!=1` means we're in pre-fill
+                # because compiled models currently cannot do assisted decoding
+                if cache_position[0] == 0 or self.model.rope_deltas is None:
+                    vision_positions, rope_deltas = self.model.get_rope_index(
+                        model_inputs.get("input_ids", None),
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        attention_mask=attention_mask,
+                    )
+                    self.model.rope_deltas = rope_deltas
+                # then use the prev pre-calculated rope-deltas to get the correct position ids
+                elif (
+                    "position_ids" in model_inputs and model_inputs["position_ids"] is not None
+                ):
+                    batch_size, seq_length = model_inputs["position_ids"].shape
+                    device = model_inputs["position_ids"].device
+                    position_ids = torch.arange(seq_length, device=device)
+                    position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                    delta = cache_position[0] + self.model.rope_deltas
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                    vision_positions = position_ids + delta.expand_as(position_ids)
+
+                # Concatenate "text + vision" positions into [4, bs, seq-len]
+                if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
+                    text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
+                        None, None, :
+                    ]
+                else:
+                    text_positions = model_inputs["position_ids"][None, ...]
+                # text_positions = model_inputs["position_ids"][None, ...]
+                assert vision_positions is not None, "vision_positions are missing"
+                model_inputs["position_ids"] = torch.cat(
+                    [text_positions, vision_positions], dim=0
+                )
+
+            if cache_position[0] != 0:
+                model_inputs["pixel_values"] = None
+                model_inputs["pixel_values_videos"] = None
+
+            return model_inputs
+
+    class patched_Qwen2_5_VisionTransformerPretrainedModel:
+        _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
         )

-
-
-
-
-
+        def rot_pos_emb(self, grid_thw):
+            pos_ids = []
+            for thw_ in grid_thw:
+                # PATCHED: avoid unbind
+                t = thw_[0]
+                h = thw_[1]
+                w = thw_[2]
+                hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+                hpos_ids = hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+                hpos_ids = hpos_ids.flatten()
+
+                wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+                wpos_ids = wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+                wpos_ids = wpos_ids.flatten()
+                pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            pos_ids = torch.cat(pos_ids, dim=0)
+            max_grid_size = grid_thw[:, 1:].max()
+            rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+            rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+            return rotary_pos_emb
+
+        def get_window_index(self, grid_thw):
+            window_index: list = []  # type: ignore[annotation-unchecked]
+            # PATCHED
+            cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)]  # type: ignore[annotation-unchecked]
+            window_index_id = 0
+            vit_merger_window_size = (
+                self.window_size // self.spatial_merge_size // self.patch_size
+            )
+
+            for _thw in grid_thw:
+                # PATCHED: avoid unbind
+                grid_t = _thw[0]
+                grid_h = _thw[1]
+                grid_w = _thw[2]
+                llm_grid_h, llm_grid_w = (
+                    grid_h // self.spatial_merge_size,
+                    grid_w // self.spatial_merge_size,
+                )
+                index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                    grid_t, llm_grid_h, llm_grid_w
+                )
+                pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+                pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+                num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+                num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+                index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+                index_padded = index_padded.reshape(
+                    grid_t,
+                    num_windows_h,
+                    vit_merger_window_size,
+                    num_windows_w,
+                    vit_merger_window_size,
+                )
+                index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                    grid_t,
+                    num_windows_h * num_windows_w,
+                    vit_merger_window_size,
+                    vit_merger_window_size,
+                )
+                seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+                index_padded = index_padded.reshape(-1)
+                index_new = index_padded[index_padded != -100]
+                window_index.append(index_new + window_index_id)
+                cu_seqlens_tmp = (
+                    seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
+                )
+                # PATCHED
+                # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+                cu_window_seqlens.append(cu_seqlens_tmp)
+                window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+            window_index = torch.cat(window_index, dim=0)
+
+            return window_index, torch.cat(cu_window_seqlens, dim=0)
+
+        def forward(
+            self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+        ) -> torch.Tensor:
+            """
+            Args:
+                hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+                    The final hidden states of the model.
+                grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+                    The temporal, height and width of feature shape of each image in LLM.
+
+            Returns:
+                `torch.Tensor`: hidden_states.
+            """
+            hidden_states = self.patch_embed(hidden_states)
+            rotary_pos_emb = self.rot_pos_emb(grid_thw)
+            window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+            # PATCHED
+            # cu_window_seqlens = torch.tensor(
+            #     cu_window_seqlens,
+            #     device=hidden_states.device,
+            #     dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+            # )
+            cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
+            cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+            seq_len, _ = hidden_states.size()
+            hidden_states = hidden_states.reshape(
+                seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+            )
+            hidden_states = hidden_states[window_index, :, :]
+            hidden_states = hidden_states.reshape(seq_len, -1)
+            rotary_pos_emb = rotary_pos_emb.reshape(
+                seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+            )
+            rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+            rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            position_embeddings = (emb.cos(), emb.sin())
+
+            cu_seqlens = torch.repeat_interleave(
+                grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+            ).cumsum(
+                dim=0,
+                # Select dtype based on the following factors:
+                #  - FA2 requires that cu_seqlens_q must have dtype int32
+                #  - torch.onnx.export requires that cu_seqlens_q must have same dtype
+                #    as grid_thw
+                # See https://github.com/huggingface/transformers/pull/34852
+                # for more information
+                dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+            )
+            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+            for layer_num, blk in enumerate(self.blocks):
+                if layer_num in self.fullatt_block_indexes:
+                    cu_seqlens_now = cu_seqlens
+                else:
+                    cu_seqlens_now = cu_window_seqlens
+
+                hidden_states = blk(
+                    hidden_states,
+                    cu_seqlens=cu_seqlens_now,
+                    position_embeddings=position_embeddings,
+                    **kwargs,
+                )
+
+            hidden_states = self.merger(hidden_states)
+            reverse_indices = torch.argsort(window_index)
+            hidden_states = hidden_states[reverse_indices, :]
+            return hidden_states
+
+    class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
+        def forward(
+            self,
+            start_end,
+            query_states,
+            key_states,
+            value_states,
+            scaling: float = 1.0,
+            dropout: float = 0.0,
+            **kwargs,
+        ):
+            a = start_end[0].item()
+            b = start_end[1].item()
+            q = query_states[:, :, a:b, :]
+            k = key_states[:, :, a:b, :]
+            v = value_states[:, :, a:b, :]
+            return patched_sdpa_attention_forward(
+                self,
+                q,
+                k,
+                v,
+                attention_mask=None,
+                scaling=scaling,
+                dropout=dropout,
+                is_causal=False,
+                **kwargs,
+            )[0]
+
+    class patched_Qwen2_5_VLVisionAttention:
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        def forward(
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: Optional[torch.Tensor] = None,
+            position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+            **kwargs,
+        ) -> torch.Tensor:
+            seq_length = hidden_states.shape[0]
+            # PATCHED: avoid the use of unbind
+            qkv = (
+                self.qkv(hidden_states)
+                .reshape(seq_length, 3, self.num_heads, -1)
+                .permute(1, 0, 2, 3)
+            )
+
+            query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
+            cos, sin = position_embeddings
+
+            # This part should be moved into the loop
+            # iteration to enable fusion inside the loop.
+            query_states, key_states = (
+                transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
+                    query_states, key_states, cos, sin
+                )
+            )
+
+            query_states = query_states.transpose(0, 1).unsqueeze(0)
+            key_states = key_states.transpose(0, 1).unsqueeze(0)
+            value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+            attention_interface: Callable = (
+                transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
+            )
+            if self.config._attn_implementation != "eager":
+                # PATCHED
+                # attention_interface = ALL_ATTENTION_FUNCTIONS[
+                #     self.config._attn_implementation]
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
+            if (
+                self.config._attn_implementation == "flash_attention_2"
+                and _is_torchdynamo_exporting()
+            ):
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                attn_output = torch.onnx.ops.symbolic(
+                    "custom::qwen25_attention",
+                    (
+                        query_states,
+                        key_states,
+                        value_states,
+                        cu_seqlens,
+                        cu_seqlens,
+                        max_seqlen,
+                        max_seqlen,
+                        torch.tensor(self.scaling, dtype=torch.float32),
+                    ),
+                    dtype=query_states.dtype,
+                    shape=(
+                        key_states.shape[0],
+                        value_states.shape[1],
+                        max_seqlen,
+                        value_states.shape[-1],
+                    ),
+                    version=1,
+                )
+            elif self.config._attn_implementation == "flash_attention_2":
+                # Flash Attention 2: Use cu_seqlens for variable length attention
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+                attn_output, _ = attention_interface(
+                    self,
+                    query_states,
+                    key_states,
+                    value_states,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    cu_seq_lens_q=cu_seqlens,
+                    cu_seq_lens_k=cu_seqlens,
+                    max_length_q=max_seqlen,
+                    max_length_k=max_seqlen,
+                    is_causal=False,
+                    **kwargs,
+                )
+            elif _is_torchdynamo_exporting():
+                if (
+                    attention_interface
+                    is transformers.integrations.sdpa_attention.sdpa_attention_forward
+                ):
+                    attention_interface = patched_sdpa_attention_forward
+
+                if use_loop_for_attention_in_qwen_2_5:
+
+                    def _iteration(start_end, query_states, key_states, value_states):
+                        return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+                            self,
+                            start_end,
+                            query_states,
+                            key_states,
+                            value_states,
+                            scaling=self.scaling,
+                            dropout=0.0 if not self.training else self.attention_dropout,
+                        )
+
+                    starts = cu_seqlens[:-1]
+                    ends = cu_seqlens[1:]
+                    # cu_seqlens = [0, 10, 14, 27]
+                    # starts: [0, 10, 14]
+                    # ends: [10, 14, 17]
+                    # starts_ends: [[0, 10], [10, 14], [14, 27]]
+                    starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
+                    attn_outputs = [
+                        _iteration(start_end, query_states, key_states, value_states)
+                        for start_end in starts_ends
+                    ]
+                    # attn_outputs = torch._higher_order_ops.while_loop(
+                    # attn_outputs = torch.ops.higher_order.while_loop(
+                    #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
+                    #     _iteration,
+                    #     (torch.tensor(0),
+                    #      starts_ends, query_states, key_states, value_states), tuple(),
+                    # )
+                    attn_output = torch.cat(attn_outputs, dim=1)
+                else:
+                    # make square mask
+                    indices = torch.arange(
+                        cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+                    )
+                    dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                        cu_seqlens.dtype
+                    )
+                    dot = dot.sum(dim=0)
+                    mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+                    bool_mask = mask == 0
+                    bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+                    torch._check(bool_mask.shape[2] == key_states.shape[2])
+                    torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+                    attn_output, _ = attention_interface(
+                        self,
+                        query_states,
+                        key_states,
+                        value_states,
+                        attention_mask=bool_mask,
+                        scaling=self.scaling,
+                        dropout=0.0 if not self.training else self.attention_dropout,
+                        is_causal=False,
+                        **kwargs,
+                    )
+            else:
+                # Other implementations: Process each chunk separately
+                lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+                splits = [
+                    torch.split(tensor, lengths.tolist(), dim=2)
+                    for tensor in (query_states, key_states, value_states)
+                ]
+
+                attn_outputs = [
+                    attention_interface(
+                        self,
+                        q,
+                        k,
+                        v,
+                        attention_mask=None,
+                        scaling=self.scaling,
+                        dropout=0.0 if not self.training else self.attention_dropout,
+                        is_causal=False,
+                        **kwargs,
+                    )[0]
+                    for q, k, v in zip(*splits)
+                ]
+                attn_output = torch.cat(attn_outputs, dim=1)
+
+            attn_output = attn_output.reshape(seq_length, -1).contiguous()
+            attn_output = self.proj(attn_output)
+            return attn_output


 try:
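The non-loop export path in this hunk builds a block-diagonal attention mask directly from `cu_seqlens` instead of the commented-out Python loop. A standalone sketch (illustrative only, not package code) checking that the vectorized construction matches the loop for the example `cu_seqlens` used in the comments:

```python
import torch

# Illustrative sketch only (not package code): verify that the vectorized
# block mask built from cu_seqlens equals the loop-based block-diagonal mask.
cu_seqlens = torch.tensor([0, 10, 14, 27], dtype=torch.int64)
n = int(cu_seqlens.max())

# Loop version: attention is allowed only inside each [start, end) segment.
loop_mask = torch.zeros(n, n, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    loop_mask[s:e, s:e] = True

# Vectorized version, as in the patch: dot[j] counts boundaries <= j, so two
# positions belong to the same segment exactly when their counts are equal.
indices = torch.arange(n, dtype=cu_seqlens.dtype)
dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(cu_seqlens.dtype).sum(dim=0)
vec_mask = (dot.unsqueeze(1) - dot.unsqueeze(0)) == 0

assert torch.equal(loop_mask, vec_mask)
print(vec_mask.shape)  # torch.Size([27, 27])
```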