onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +118 -5
- onnx_diagnostic/export/control_flow.py +214 -0
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +118 -25
- onnx_diagnostic/helpers/cache_helper.py +218 -204
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +92 -26
- onnx_diagnostic/helpers/log_helper.py +26 -4
- onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +115 -16
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/rt_helper.py +547 -0
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +108 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,680 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Callable, Optional
|
|
3
|
+
import onnx
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from ...export.onnx_plug import EagerDirectReplacementWithOnnx
|
|
7
|
+
from .patch_helper import _is_torchdynamo_exporting
|
|
8
|
+
from ._patch_transformers_attention import patched_sdpa_attention_forward
|
|
9
|
+
|
|
10
|
+
# Probe for the Qwen2.5-VL modeling code; the patches of this module are only
# defined when this import succeeds.
try:
    import transformers.models.qwen2_5_vl
    import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
except ImportError:
    patch_qwen2_5 = False
else:
    patch_qwen2_5 = True

# Registry of EagerDirectReplacementWithOnnx plugs; populated by the
# ``PLUGS.append(...)`` calls further down in this module.
PLUGS = []

if patch_qwen2_5:
    import onnxscript

    # Domain under which the ONNX functions of this module are declared.
    onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
    # Standard ONNX operators, opset 22.
    op = onnxscript.opset22
    # ONNX Runtime contrib domain (MultiHeadAttention, PackedMultiHeadAttention).
    msft_op = onnxscript.values.Opset("com.microsoft", 1)
|
|
26
|
+
|
|
27
|
+
@onnxscript.script(opset=onnx_plugs_op)
def LoopMHAAttention(
    query_states,
    key_states,
    value_states,
    cu_seqlens,
    scaling: float = 0.11180339887498948,
    num_heads: int = 16,
    itype: int = onnx.TensorProto.FLOAT,
):
    # ONNX function: attention restricted to the segments
    # [cu_seqlens[i], cu_seqlens[i + 1]) computed by running one
    # com.microsoft.MultiHeadAttention per segment inside an ONNX Loop.
    # Inputs are transposed with perm=[0, 2, 1, 3] and the heads are folded
    # into the last axis before slicing along axis 1; the result is reshaped
    # back to the shape of the transposed query.
    # ``itype`` is the element type of the sequence collecting the outputs.
    flatten_heads = op.Constant(value_ints=[0, 0, -1])
    slice_axis = op.Cast(op.Constant(value_ints=[1]), to=onnx.TensorProto.INT32)
    boundaries = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
    n_segments = op.Size(boundaries) - 1

    q_swapped = op.Transpose(query_states, perm=[0, 2, 1, 3])
    final_shape = op.Shape(q_swapped)
    q_flat = op.Reshape(q_swapped, flatten_heads)
    k_flat = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), flatten_heads)
    v_flat = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), flatten_heads)

    # One attention call per segment; outputs are gathered in a sequence
    # and concatenated along the sequence axis afterwards.
    chunks = op.SequenceEmpty(dtype=itype)
    for seg in range(n_segments):
        lo_idx = op.Reshape(seg, [1])
        hi_idx = lo_idx + 1
        lo = op.Gather(boundaries, lo_idx, axis=0)
        hi = op.Gather(boundaries, hi_idx, axis=0)
        chunk = msft_op.MultiHeadAttention(
            op.Slice(q_flat, lo, hi, slice_axis),
            op.Slice(k_flat, lo, hi, slice_axis),
            op.Slice(v_flat, lo, hi, slice_axis),
            num_heads=num_heads,
            scale=scaling,
        )
        chunks = op.SequenceInsert(chunks, chunk)
    merged = op.ConcatFromSequence(chunks, axis=1)
    return op.Reshape(merged, final_shape)
|
|
69
|
+
|
|
70
|
+
def _add_com_microsoft_opset(function_proto):
|
|
71
|
+
opsets = {d.domain: d.version for d in function_proto.opset_import}
|
|
72
|
+
if "com.microsoft" not in opsets:
|
|
73
|
+
d = function_proto.opset_import.add()
|
|
74
|
+
d.domain = "com.microsoft"
|
|
75
|
+
d.version = 1
|
|
76
|
+
return function_proto
|
|
77
|
+
|
|
78
|
+
@onnxscript.script(opset=onnx_plugs_op)
def PackedAttention(
    query,
    key,
    value,
    cu_seqlens,
    scaling: float = 0.11180339887498948,
    num_heads: int = 16,
):
    # ONNX function computing attention restricted to the segments
    # [cu_seqlens[i], cu_seqlens[i + 1]) with com.microsoft.PackedMultiHeadAttention.
    # query/key/value are squeezed on axis 0 and transposed with perm=[1, 0, 2],
    # so they need a leading axis of size 1; the eager caller passes them as
    # (1, num_heads, seq, head_dim). The result is (1, seq, num_heads, head_dim).

    # Segment lengths and the longest segment.
    num_patches = op.Cast(op.Size(cu_seqlens), to=onnx.TensorProto.INT32) - 1
    starts = op.Slice(cu_seqlens, [0], [-1], [0])
    ends = op.Slice(cu_seqlens, [1], [9223372036854775807], [0])
    lengths = ends - starts
    max_length = op.ReduceMax(lengths, [0], keepdims=0)  # max_seqlen
    rows = op.Range(0, num_patches, 1)
    rows_2d = op.Unsqueeze(rows, [1])
    cols = op.Range(0, max_length, 1)
    cols_2d = op.Unsqueeze(cols, [0])

    # position_matrix[i, j] = i * max_length + j: flat index of slot j of
    # segment i in a (num_patches, max_length) padded layout.
    position_matrix = op.Cast(rows_2d, to=onnx.TensorProto.INT32) * op.Cast(
        max_length, to=onnx.TensorProto.INT32
    ) + op.Cast(cols_2d, to=onnx.TensorProto.INT32)
    position_matrix_shape = op.Shape(position_matrix)
    # True for slots holding a real token, False for padding slots.
    token_mask = cols_2d < op.Unsqueeze(lengths, [1])
    token_mask_1d = op.Reshape(token_mask, [-1])
    padded_mask_1d = op.Not(token_mask_1d)
    # FIXED: Compress without ``axis`` flattens its input and its condition must
    # be a rank-1 tensor, so the flattened mask is used here, as it already is
    # for the padding slots (the 2-D mask violated the operator specification).
    valid_token_positions = op.Compress(position_matrix, token_mask_1d)
    padded_token_positions = op.Compress(position_matrix, padded_mask_1d)
    # Padded positions of all real tokens (in order) followed by the padding
    # positions, reshaped to (num_patches, max_length); presumably the layout
    # ORT expects for the token_offset input (see the contrib op docs).
    token_offset_1d = op.Concat(valid_token_positions, padded_token_positions, axis=0)
    token_offset = op.Reshape(token_offset_1d, position_matrix_shape)

    # (1, H, S, D) -> (S, H, D) -> packed (S, H * D) tensors.
    query_3d = op.Transpose(op.Squeeze(query, [0]), perm=[1, 0, 2])
    shape_3d = op.Shape(query_3d)
    query_2d = op.Reshape(query_3d, [0, -1])
    key_2d = op.Reshape(op.Transpose(op.Squeeze(key, [0]), perm=[1, 0, 2]), [0, -1])
    value_2d = op.Reshape(op.Transpose(op.Squeeze(value, [0]), perm=[1, 0, 2]), [0, -1])

    packed_attn_output_2d = msft_op.PackedMultiHeadAttention(
        query_2d,
        key_2d,
        value_2d,
        None,  # optional 4th input (bias in the ORT contrib spec) omitted
        op.Cast(token_offset, to=onnx.TensorProto.INT32),
        op.Cast(cu_seqlens, to=onnx.TensorProto.INT32),
        scale=scaling,
        num_heads=num_heads,
    )
    # Back to (S, H, D) then (1, S, H, D).
    packed_attn_output_3d = op.Reshape(packed_attn_output_2d, shape_3d)
    return op.Unsqueeze(packed_attn_output_3d, [0])
|
|
127
|
+
|
|
128
|
+
def qwen_sdpa_attention(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    scaling: float = 0,
    num_heads: int = 16,
    itype: int = onnx.TensorProto.FLOAT,
) -> torch.Tensor:
    """Eager counterpart of the PackedAttention / LoopMHAAttention ONNX functions.

    query/key/value are split along dim 2 into the segments delimited by
    ``cu_seqlens``. Non-causal scaled dot-product attention (no mask, no
    dropout) runs on each segment independently, and the per-segment outputs
    are concatenated along dim 1.

    The vision attention forward of this module passes tensors built as
    ``(seq, heads, head_dim).transpose(0, 1).unsqueeze(0)``, i.e.
    ``(1, heads, seq, head_dim)``.

    ``num_heads`` and ``itype`` are not used here; they mirror the attributes
    of the ONNX functions this eager code stands for.
    """
    segment_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    per_segment = zip(
        torch.split(query_states, segment_sizes, dim=2),
        torch.split(key_states, segment_sizes, dim=2),
        torch.split(value_states, segment_sizes, dim=2),
    )

    outputs = []
    for q_seg, k_seg, v_seg in per_segment:
        result = patched_sdpa_attention_forward(
            None,
            q_seg,
            k_seg,
            v_seg,
            attention_mask=None,
            scaling=scaling,
            dropout=0.0,
            is_causal=False,
        )
        # Keep only the attention output, not the attention weights.
        outputs.append(result[0])
    return torch.cat(outputs, dim=1)
|
|
158
|
+
|
|
159
|
+
# NOTE(review): the original author flagged this set-up as "not ideal".
# The attribute defaults are hard-coded for one configuration
# (scaling = 1 / sqrt(80) = 0.1118..., num_heads = 16) rather than read from a
# model config; the vision attention forward of this module passes
# self.scaling and self.num_heads explicitly.


def _qwen_attention_output_placeholder(qs, *args, **kwargs):
    """Shape function shared by both plugs.

    Returns an uninitialized tensor whose shape swaps axes 1 and 2 of the
    query, ``(d0, d2, d1, d3)``, with the query's dtype and device.
    """
    dims = qs.shape
    return torch.empty(
        (dims[0], dims[2], dims[1], dims[3]),
        dtype=qs.dtype,
        device=qs.device,
    )


# Segmented attention exported as com.microsoft.PackedMultiHeadAttention.
# NOTE(review): PackedAttention declares no ``itype`` attribute although it is
# listed in ``kwargs``; confirm EagerDirectReplacementWithOnnx does not forward
# it to the function call node.
qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
    qwen_sdpa_attention,
    _qwen_attention_output_placeholder,
    _add_com_microsoft_opset(PackedAttention.to_function_proto()),
    n_inputs=4,
    n_outputs=1,
    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
    name="qwen_sdpa_attention_packed",
)
PLUGS.append(qwen_sdpa_attention_packed_versatile)

# Segmented attention exported as an ONNX Loop running one
# com.microsoft.MultiHeadAttention per segment.
qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
    qwen_sdpa_attention,
    _qwen_attention_output_placeholder,
    _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
    n_inputs=4,
    n_outputs=1,
    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
    name="qwen_sdpa_attention_loopmha",
)
PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
|
|
189
|
+
|
|
190
|
+
class patched_Qwen2_5_VLForConditionalGeneration:
    """Replacement for ``Qwen2_5_VLForConditionalGeneration.prepare_inputs_for_generation``.

    ``_PATCHES_`` lists the method names this class defines for
    ``_PATCHED_CLASS_`` (presumably consumed by this package's patching
    machinery).
    """

    _PATCHES_ = ["prepare_inputs_for_generation"]
    _PATCHED_CLASS_ = (
        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
    )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        """Build the model inputs for one generation step.

        The work is delegated to ``GenerationMixin.prepare_inputs_for_generation``.
        When ``position_ids`` is None, the result's position ids are then
        replaced by 4 rows: one row of text positions on top of the 3 rows of
        rope positions.

        On the first step (``cache_position[0] == 0``), or when
        ``self.model.rope_deltas`` is unset, the rope positions come from
        ``self.model.get_rope_index`` and the returned deltas are cached on
        ``self.model``. Later steps rebuild the positions from the cached
        deltas.

        Pixel inputs are cleared whenever ``cache_position[0] != 0``.
        ``cache_position`` must therefore be provided, since it is indexed.

        Raises:
            AssertionError: if the rope (vision) positions could not be computed.
        """
        # Overwritten -- in specific circumstances we don't want to
        # forward image inputs to the model
        from transformers.generation import GenerationMixin

        model_inputs = GenerationMixin.prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )

        # Qwen2-5-VL position_ids are prepared with rope_deltas
        if position_ids is None:
            # FIXED: bound up front so the assertion below reports a clear
            # message instead of raising UnboundLocalError when neither
            # branch computes the vision positions.
            vision_positions = None
            # Calculate RoPE index once per generation in the pre-fill stage only.
            # When compiling, we can't check tensor values thus we check only input length
            # It is safe to assume that `length!=1` means we're in pre-fill
            # because compiled models currently cannot do assisted decoding
            if cache_position[0] == 0 or self.model.rope_deltas is None:
                vision_positions, rope_deltas = self.model.get_rope_index(
                    model_inputs.get("input_ids", None),
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask,
                )
                self.model.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            elif (
                "position_ids" in model_inputs and model_inputs["position_ids"] is not None
            ):
                batch_size, seq_length = model_inputs["position_ids"].shape
                device = model_inputs["position_ids"].device
                position_ids = torch.arange(seq_length, device=device)
                # (3, batch, seq) rope positions shifted by the cached deltas.
                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
                delta = cache_position[0] + self.model.rope_deltas
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                vision_positions = position_ids + delta.expand_as(position_ids)

            # Concatenate "text + vision" positions into [4, bs, seq-len]
            if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
                # FIXED: expand over the batch dimension; a (1, 1, seq) tensor
                # cannot be concatenated on dim 0 with (3, batch, seq) vision
                # positions when batch > 1 (identical values when batch == 1).
                text_positions = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(1, input_ids.shape[0], -1)
                )
            else:
                text_positions = model_inputs["position_ids"][None, ...]
            assert vision_positions is not None, "vision_positions are missing"
            model_inputs["position_ids"] = torch.cat(
                [text_positions, vision_positions], dim=0
            )

        # Image/video inputs are only consumed on the first step.
        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs
|
|
278
|
+
|
|
279
|
+
class patched_Qwen2_5_VisionTransformerPretrainedModel:
    """Replacements for three methods of ``Qwen2_5_VisionTransformerPretrainedModel``.

    ``_PATCHES_`` lists the methods defined here for ``_PATCHED_CLASS_``. The
    lines marked PATCHED avoid constructs (``unbind`` through tuple unpacking,
    ``tolist()``) in favour of tensor indexing and tensor lists.
    """

    _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
    _PATCHED_CLASS_ = (
        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
    )

    def rot_pos_emb(self, grid_thw):
        """Look up rotary embeddings for every patch of every image/video.

        For each ``(t, h, w)`` row of ``grid_thw``, the (row, column) index of
        every patch is built. Indices are reordered so that each
        ``spatial_merge_size x spatial_merge_size`` group is contiguous, then
        repeated ``t`` times. The embeddings looked up for the row and column
        indices are concatenated per patch (``flatten(1)``).
        """
        pos_ids = []
        for thw_ in grid_thw:
            # PATCHED: avoid unbind
            t = thw_[0]
            h = thw_[1]
            w = thw_[2]
            # Row index of every patch, grouped by spatial-merge block.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column index of every patch, same grouping.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # The embedding table only needs to cover the largest height/width.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window permutation and window boundaries.

        Returns:
            ``(window_index, cu_window_seqlens)``:

            * ``window_index`` is a permutation of the merged-patch indices that
              groups them window by window.
            * ``cu_window_seqlens`` is a 1-D tensor of cumulative window
              boundaries, starting at 0 and counted in un-merged tokens
              (scaled by ``spatial_merge_unit``).

            When a grid side is a multiple of the window size, a full window of
            padding is added. This yields zero-length windows, i.e. repeated
            boundaries, which ``forward`` removes with ``unique_consecutive``.
        """
        window_index: list = []  # type: ignore[annotation-unchecked]
        # PATCHED
        cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)]  # type: ignore[annotation-unchecked]
        window_index_id = 0
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        for _thw in grid_thw:
            # PATCHED: avoid unbind
            grid_t = _thw[0]
            grid_h = _thw[1]
            grid_w = _thw[2]
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            # -100 marks padding slots; they are filtered out below.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) merged patches per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Boundaries continue from the last boundary computed so far.
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
            )
            # PATCHED
            # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            cu_window_seqlens.append(cu_seqlens_tmp)
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, torch.cat(cu_window_seqlens, dim=0)

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Patch inputs of the vision tower. They first go through
                ``self.patch_embed``, whose output must be 2-D
                ``(seq_len, hidden_size)`` because ``hidden_states.size()``
                is unpacked into two values.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: output of ``self.merger``, with rows put back into the
            original (pre-windowing) order of the merged patches.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        # PATCHED
        # cu_window_seqlens = torch.tensor(
        #     cu_window_seqlens,
        #     device=hidden_states.device,
        #     dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        # )
        cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
        # Drop repeated boundaries (zero-length windows created by padding).
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (in groups of spatial_merge_unit) into window order.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        # Same reordering for the rotary embeddings.
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Per-frame boundaries: each frame of each image has h * w tokens.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype
            #    as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852
            # for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        # Blocks listed in fullatt_block_indexes attend within whole frames,
        # the others within windows.
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.merger(hidden_states)
        # Undo the window permutation.
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states
|
|
442
|
+
|
|
443
|
+
class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
    """Attention over a single ``[start, end)`` slice of the sequence axis (dim 2)."""

    def forward(
        self,
        start_end,
        query_states,
        key_states,
        value_states,
        scaling: float = 1.0,
        dropout: float = 0.0,
        **kwargs,
    ):
        """Run sdpa on positions ``start_end[0]:start_end[1]`` of dim 2.

        ``start_end`` holds two scalars read with ``.item()``. The sliced
        query/key/value are passed, without a mask and non-causally, to
        ``patched_sdpa_attention_forward`` together with ``kwargs``. Only the
        first element of its result is returned.
        """
        lo, hi = start_end[0].item(), start_end[1].item()
        window = slice(lo, hi)

        def _take(tensor):
            # Restrict dim 2 to the current window.
            return tensor[:, :, window, :]

        return patched_sdpa_attention_forward(
            self,
            _take(query_states),
            _take(key_states),
            _take(value_states),
            attention_mask=None,
            scaling=scaling,
            dropout=dropout,
            is_causal=False,
            **kwargs,
        )[0]
|
|
470
|
+
|
|
471
|
+
class patched_Qwen2_5_VLVisionAttention:
    """Replacement for ``Qwen2_5_VLVisionAttention.forward``.

    For sdpa, the strategy is read from the environment variable
    ``QWEN25ATTENTION`` (default ``"PACKED"``) each time ``forward`` runs.
    ``"LOOPMHA"`` and ``"BIGMASK"`` are handled only when
    ``_is_torchdynamo_exporting()`` is true.
    """

    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = (
        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
    )
    # Evaluated on every call so the environment variable can change between calls.
    STRATEGY_FOR_ATTENTION = lambda: os.environ.get(  # noqa: E731
        "QWEN25ATTENTION", "PACKED"
    )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Self-attention restricted to the segments delimited by ``cu_seqlens``.

        Args:
            hidden_states: its first dimension is used as the token count
                ``seq_length``; ``self.qkv`` output is reshaped to
                ``(seq_length, 3, num_heads, -1)``.
            cu_seqlens: cumulative boundaries; a token attends only to tokens
                of its segment ``[cu_seqlens[i], cu_seqlens[i + 1])``.
            rotary_pos_emb: unused, kept for signature compatibility.
            position_embeddings: ``(cos, sin)`` given to
                ``apply_rotary_pos_emb_vision``; must not be None.
            kwargs: forwarded to the attention function in the BIGMASK,
                flash_attention_2 (eager) and per-segment branches.

        Returns:
            ``self.proj`` applied to the attention output reshaped to
            ``(seq_length, -1)``.

        Raises:
            NotImplementedError: while exporting, if no branch matches the
                strategy and the attention implementation.
        """
        seq_length = hidden_states.shape[0]
        # PATCHED: avoid the use of unbind
        qkv = (
            self.qkv(hidden_states)
            .reshape(seq_length, 3, self.num_heads, -1)
            .permute(1, 0, 2, 3)
        )

        query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
        cos, sin = position_embeddings

        # This part should be moved into the loop
        # iteration to enable fusion inside the loop.
        query_states, key_states = (
            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
                query_states, key_states, cos, sin
            )
        )

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim), assuming the
        # rotary embedding keeps the shape.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = (
            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
        )
        if self.config._attn_implementation != "eager":
            # PATCHED
            # attention_interface = ALL_ATTENTION_FUNCTIONS[
            #     self.config._attn_implementation]
            attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        is_sdpa = (
            attention_interface
            is transformers.integrations.sdpa_attention.sdpa_attention_forward
            or attention_interface is patched_sdpa_attention_forward
        )
        attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
        # FIXED: exact comparison, consistent with the other strategies;
        # ``in "PACKED"`` was a substring test that also accepted "", "P", "ACK", ...
        if is_sdpa and attention_strategy == "PACKED":
            attn_output = qwen_sdpa_attention_packed_versatile(
                query_states,
                key_states,
                value_states,
                cu_seqlens,
                self.scaling,
                self.num_heads,
            )
        elif _is_torchdynamo_exporting():
            if self.config._attn_implementation == "flash_attention_2":
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
                attn_output = torch.onnx.ops.symbolic(
                    "custom::qwen25_flash_attention",
                    (
                        query_states,
                        key_states,
                        value_states,
                        cu_seqlens,
                        cu_seqlens,
                        max_seqlen,
                        max_seqlen,
                        torch.tensor(self.scaling, dtype=torch.float32),
                    ),
                    dtype=query_states.dtype,
                    shape=(
                        query_states.shape[0],  # batch_size
                        query_states.shape[2],  # sequence_length (total patches)
                        query_states.shape[1],  # num_heads
                        query_states.shape[3],  # head_size
                    ),
                    version=1,
                )
            elif is_sdpa and attention_strategy == "LOOPMHA":
                attn_output = qwen_sdpa_attention_loopmha_versatile(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens,
                    self.scaling,
                    self.num_heads,
                    (
                        onnx.TensorProto.FLOAT
                        if query_states.dtype == torch.float32
                        else (
                            onnx.TensorProto.FLOAT16
                            if query_states.dtype == torch.float16
                            else onnx.TensorProto.BFLOAT16
                        )
                    ),
                )
                # TODO (original author): rewrite this branch with a for loop
                # over the (start, end) pairs of cu_seqlens, e.g. by calling
                # patched_Qwen2_5_VLVisionAttentionOneIteration.forward per pair
                # and concatenating the outputs along dim 1.
            elif is_sdpa and attention_strategy == "BIGMASK":
                # make square mask
                indices = torch.arange(
                    cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
                )
                # dot[j] = number of boundaries <= j; two tokens belong to the
                # same segment iff they have the same count.
                dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
                    cu_seqlens.dtype
                )
                dot = dot.sum(dim=0)
                mask = dot.unsqueeze(1) - dot.unsqueeze(0)
                bool_mask = mask == 0
                bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)

                torch._check(bool_mask.shape[2] == key_states.shape[2])
                torch._check(bool_mask.shape[3] == key_states.shape[2])

                attn_output, _ = attention_interface(
                    self,
                    query_states,
                    key_states,
                    value_states,
                    attention_mask=bool_mask,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )
            else:
                raise NotImplementedError(
                    f"No corresponding export strategy for "
                    f"{attention_strategy!r}, "
                    f"(use QWEN25ATTENTION to change it), and attention_interface="
                    f"{attention_interface!r} (use sdpa)"
                )
        elif self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2)
                for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
|