onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,735 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Callable, Optional, Tuple
|
|
3
|
+
import onnx
|
|
4
|
+
import onnx.helper as oh
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from ...export.onnx_plug import EagerDirectReplacementWithOnnx
|
|
8
|
+
from ...helpers.torch_helper import torch_dtype_to_onnx_dtype
|
|
9
|
+
from .patch_helper import _is_torchdynamo_exporting
|
|
10
|
+
from ._patch_transformers_attention import patched_sdpa_attention_forward
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import transformers.models.qwen2_5_vl
|
|
14
|
+
import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
|
|
15
|
+
|
|
16
|
+
patch_qwen2_5 = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
patch_qwen2_5 = False
|
|
19
|
+
|
|
20
|
+
PLUGS = []
|
|
21
|
+
|
|
22
|
+
if patch_qwen2_5:
|
|
23
|
+
import onnxscript
|
|
24
|
+
|
|
25
|
+
# Custom domain ("onnx_plug", version 1) under which the onnxscript functions
# defined in this module are registered.
onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
# Standard ONNX opset 22, used by the functions relying on contrib attention ops.
op = onnxscript.opset22
# Standard ONNX opset 24, used by LoopAttention24 (calls op24.Attention).
op24 = onnxscript.onnx_opset.opset24
# Domain "com.microsoft", version 1 (MultiHeadAttention, PackedMultiHeadAttention).
msft_op = onnxscript.values.Opset("com.microsoft", 1)
# Optional integer read from environment variable STOPAT (None when unset).
# The vision transformer forward defined in this module breaks out of its
# block loop once ``layer_num > STOPAT``. A non-integer value raises ValueError
# at import time because of ``int(...)``.
STOPAT = (
    int(os.environ.get("STOPAT", None))
    if os.environ.get("STOPAT", None) is not None
    else None
)
|
|
34
|
+
|
|
35
|
+
def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
    """Make sure *function_proto* imports the ``com.microsoft`` domain.

    When the domain is not among the function's opset imports, an entry
    ``com.microsoft`` / version 1 is appended. The proto is modified in
    place and the same object is returned.
    """
    declared_domains = {entry.domain for entry in function_proto.opset_import}
    if "com.microsoft" in declared_domains:
        return function_proto
    new_entry = function_proto.opset_import.add()
    new_entry.domain = "com.microsoft"
    new_entry.version = 1
    return function_proto
|
|
42
|
+
|
|
43
|
+
def _update_sequence_type(
    itype: int, function_proto: onnx.FunctionProto
) -> onnx.FunctionProto:
    """Rebuild *function_proto* with its ``SequenceEmpty`` nodes typed as *itype*.

    Every top-level node whose ``op_type`` is ``SequenceEmpty`` is replaced by
    a new ``SequenceEmpty`` node with the same inputs and outputs and
    ``dtype=itype``; all other nodes are passed through unchanged. Only
    ``function_proto.node`` is scanned; nested subgraphs are not visited.

    :param itype: ONNX element type given to the ``dtype`` attribute
    :param function_proto: function to rewrite
    :return: a function built with :func:`onnx.helper.make_function`
    """
    retyped_nodes = []
    for node in function_proto.node:
        if node.op_type == "SequenceEmpty":
            node = oh.make_node("SequenceEmpty", node.input, node.output, dtype=itype)
        retyped_nodes.append(node)
    return oh.make_function(
        function_proto.domain,
        function_proto.name,
        function_proto.input,
        function_proto.output,
        retyped_nodes,
        attributes=function_proto.attribute,
        attribute_protos=function_proto.attribute_proto,
        opset_imports=function_proto.opset_import,
    )
|
|
64
|
+
|
|
65
|
+
@onnxscript.script(opset=onnx_plugs_op)
def LoopMHAAttention(
    query_states,
    key_states,
    value_states,
    cu_seqlens,
    scaling: float = 0.11180339887498948,
    num_heads: int = 16,
):
    # onnxscript function (domain "onnx_plug"): attention evaluated chunk by
    # chunk with com.microsoft.MultiHeadAttention, chunk boundaries being the
    # consecutive values of cu_seqlens.
    # Documentation is kept in '#' comments on purpose so the generated
    # FunctionProto is not altered.
    # NOTE(review): inputs are presumably (batch, num_heads, seq, head_dim),
    # as hinted by the annotations on qwen_sdpa_attention -- confirm.

    # Reshape target: keep the first two dimensions, merge the remaining ones.
    to_3d_shape = op.Constant(value_ints=[0, 0, -1])
    # Swap axes 1 and 2 before flattening to 3D.
    query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
    # 4D shape of the transposed query, restored on the final output.
    output_shape = op.Shape(query_transposed)
    query_3d = op.Reshape(query_transposed, to_3d_shape)
    value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
    key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
    cu_seqlens = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
    # One chunk per pair of consecutive boundaries.
    num_patches = op.Size(cu_seqlens) - 1
    # Slices are taken along axis 1 of the 3D tensors; the axes tensor is cast
    # to INT32 to match the type of the start/end values gathered below.
    seq_axis = op.Constant(value_ints=[1])
    seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
    # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
    # The sequence is typed FLOAT here; the float16 entry of the dispatch
    # table rewrites this node through _update_sequence_type.
    seq_attn = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
    for i_patch in range(num_patches):
        i_1d = op.Reshape(i_patch, [1])
        i_plus_1_1d = i_1d + 1
        # Chunk i covers [cu_seqlens[i], cu_seqlens[i + 1]) along axis 1.
        start = op.Gather(cu_seqlens, i_1d, axis=0)
        end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0)
        query_i = op.Slice(query_3d, start, end, seq_axis_int32)
        key_i = op.Slice(key_3d, start, end, seq_axis_int32)
        value_i = op.Slice(value_3d, start, end, seq_axis_int32)
        mha_output = msft_op.MultiHeadAttention(
            query_i, key_i, value_i, num_heads=num_heads, scale=scaling
        )
        # attn_output = op.Concat(attn_output, mha_output, axis=1)
        seq_attn = op.SequenceInsert(seq_attn, mha_output)
    # Concatenate the per-chunk outputs along axis 1, then restore the 4D shape.
    attn_output = op.ConcatFromSequence(seq_attn, axis=1)
    attn_output_4d = op.Reshape(attn_output, output_shape)
    return attn_output_4d
|
|
102
|
+
|
|
103
|
+
@onnxscript.script(opset=onnx_plugs_op)
def LoopAttention24(
    query_states,
    key_states,
    value_states,
    cu_seqlens,
    scaling: float = 0.11180339887498948,
    num_heads: int = 16,
):
    # onnxscript function (domain "onnx_plug"): same chunk-by-chunk scheme as
    # LoopMHAAttention but every operator comes from ONNX opset 24 and the
    # per-chunk attention is the standard op24.Attention operator.
    # Documentation is kept in '#' comments on purpose so the generated
    # FunctionProto is not altered.
    # NOTE(review): inputs are presumably (batch, num_heads, seq, head_dim),
    # as hinted by the annotations on qwen_sdpa_attention -- confirm.

    # Reshape target: keep the first two dimensions, merge the remaining ones.
    to_3d_shape = op24.Constant(value_ints=[0, 0, -1])
    # Swap axes 1 and 2 before flattening to 3D.
    query_transposed = op24.Transpose(query_states, perm=[0, 2, 1, 3])
    # 4D shape of the transposed query, restored on the final output.
    output_shape = op24.Shape(query_transposed)
    query_3d = op24.Reshape(query_transposed, to_3d_shape)
    value_3d = op24.Reshape(op24.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
    key_3d = op24.Reshape(op24.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
    cu_seqlens = op24.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
    # One chunk per pair of consecutive boundaries.
    num_patches = op24.Size(cu_seqlens) - 1
    # Slices are taken along axis 1; axes cast to INT32 like the start/end values.
    seq_axis = op24.Constant(value_ints=[1])
    seq_axis_int32 = op24.Cast(seq_axis, to=onnx.TensorProto.INT32)
    # The sequence is typed FLOAT here; the float16 entry of the dispatch
    # table rewrites this node through _update_sequence_type.
    seq_attn = op24.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
    for i_patch in range(num_patches):
        i_1d = op24.Reshape(i_patch, [1])
        i_plus_1_1d = i_1d + 1
        # Chunk i covers [cu_seqlens[i], cu_seqlens[i + 1]) along axis 1.
        start = op24.Gather(cu_seqlens, i_1d, axis=0)
        end = op24.Gather(cu_seqlens, i_plus_1_1d, axis=0)
        query_i = op24.Slice(query_3d, start, end, seq_axis_int32)
        key_i = op24.Slice(key_3d, start, end, seq_axis_int32)
        value_i = op24.Slice(value_3d, start, end, seq_axis_int32)
        # 3D inputs: the head counts are passed explicitly as attributes.
        mha_output = op24.Attention(
            query_i,
            key_i,
            value_i,
            scale=scaling,
            q_num_heads=num_heads,
            kv_num_heads=num_heads,
            softmax_precision=onnx.TensorProto.FLOAT,
        )
        seq_attn = op24.SequenceInsert(seq_attn, mha_output)
    # Concatenate the per-chunk outputs along axis 1, then restore the 4D shape.
    attn_output = op24.ConcatFromSequence(seq_attn, axis=1)
    attn_output_4d = op24.Reshape(attn_output, output_shape)
    return attn_output_4d
|
|
144
|
+
|
|
145
|
+
@onnxscript.script(opset=onnx_plugs_op)
def PackedAttention(
    query,
    key,
    value,
    cu_seqlens,
    scaling: float = 0.11180339887498948,
    num_heads: int = 16,
):
    # onnxscript function (domain "onnx_plug"): attention over all chunks in a
    # single call to com.microsoft.PackedMultiHeadAttention. The chunk
    # boundaries (cu_seqlens) are turned into the token_offset input expected
    # by that operator.
    # Documentation is kept in '#' comments on purpose so the generated
    # FunctionProto carries no doc_string change.
    # NOTE(review): query/key/value are presumably (1, num_heads, seq, head_dim);
    # the Squeeze on axis 0 below requires a leading dimension of 1 -- confirm.

    # Number of chunks and length of every chunk.
    num_patches = op.Cast(op.Size(cu_seqlens), to=onnx.TensorProto.INT32) - 1
    starts = op.Slice(cu_seqlens, [0], [-1], [0])
    ends = op.Slice(cu_seqlens, [1], [9223372036854775807], [0])
    lengths = ends - starts
    max_length = op.ReduceMax(lengths, [0], keepdims=0)  # max_seqlen
    rows = op.Range(0, num_patches, 1)
    rows_2d = op.Unsqueeze(rows, [1])
    cols = op.Range(0, max_length, 1)
    cols_2d = op.Unsqueeze(cols, [0])

    # position_matrix[r, c] = r * max_length + c, i.e. the flat index of slot c
    # of chunk r in a (num_patches, max_length) padded layout.
    position_matrix = op.Cast(rows_2d, to=onnx.TensorProto.INT32) * op.Cast(
        max_length, to=onnx.TensorProto.INT32
    ) + op.Cast(cols_2d, to=onnx.TensorProto.INT32)
    position_matrix_shape = op.Shape(position_matrix)
    # True where slot c holds a real token of chunk r (c < length of chunk r).
    token_mask = cols_2d < op.Unsqueeze(lengths, [1])
    token_mask_1d = op.Reshape(token_mask, [-1])
    padded_mask_1d = op.Not(token_mask_1d)
    # FIX: Compress expects a rank-1 condition (the data is flattened when no
    # axis is given). The original passed the 2D token_mask; the flattened
    # mask selects the same elements and matches the operator specification.
    valid_token_positions = op.Compress(position_matrix, token_mask_1d)
    padded_token_positions = op.Compress(position_matrix, padded_mask_1d)
    # Valid positions first, padded positions last.
    token_offset_1d = op.Concat(valid_token_positions, padded_token_positions, axis=0)
    token_offset = op.Reshape(token_offset_1d, position_matrix_shape)

    # Drop the leading axis, swap the two first remaining axes and merge the
    # last two into one: the packed operator receives 2D query/key/value.
    query_3d = op.Transpose(op.Squeeze(query, [0]), perm=[1, 0, 2])
    shape_3d = op.Shape(query_3d)
    query_2d = op.Reshape(query_3d, [0, -1])
    key_2d = op.Reshape(op.Transpose(op.Squeeze(key, [0]), perm=[1, 0, 2]), [0, -1])
    value_2d = op.Reshape(op.Transpose(op.Squeeze(value, [0]), perm=[1, 0, 2]), [0, -1])

    # The fourth input is left empty (None).
    packed_attn_output_2d = msft_op.PackedMultiHeadAttention(
        query_2d,
        key_2d,
        value_2d,
        None,
        op.Cast(token_offset, to=onnx.TensorProto.INT32),
        op.Cast(cu_seqlens, to=onnx.TensorProto.INT32),
        scale=scaling,
        num_heads=num_heads,
    )
    # Back to the 3D layout of the transposed query, then restore axis 0.
    packed_attn_output_3d = op.Reshape(packed_attn_output_2d, shape_3d)
    return op.Unsqueeze(packed_attn_output_3d, [0])
|
|
194
|
+
|
|
195
|
+
def qwen_sdpa_attention(
    query_states: torch.Tensor,  # F10s1x16xs47x80
    key_states: torch.Tensor,  # F10s1x16xs47x80
    value_states: torch.Tensor,  # F10s1x16xs47x80
    cu_seqlens: torch.Tensor,  # F7su19
    scaling: float = 0,
    num_heads: int = 16,
) -> torch.Tensor:
    """
    Eager attention computed independently on every chunk delimited by
    two consecutive values of *cu_seqlens*.

    The tensors are split along dimension 2, ``patched_sdpa_attention_forward``
    is called on each triple of chunks (no mask, no dropout, not causal)
    and the first result of every call is concatenated along dimension 1.

    :param query_states: queries, split along dimension 2
    :param key_states: keys, split along dimension 2
    :param value_states: values, split along dimension 2
    :param cu_seqlens: 1D tensor of chunk boundaries,
        chunk ``i`` has length ``cu_seqlens[i + 1] - cu_seqlens[i]``
    :param scaling: scaling forwarded to the attention function
    :param num_heads: not used by this eager implementation,
        kept so that the signature matches the ONNX functions
    :return: concatenated attention outputs

    The loop can be removed with the following code
    but it hits memory overflow for big inputs.

    .. code-block:: python

        # make square mask
        indices = torch.arange(
            cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
        )
        dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
            cu_seqlens.dtype
        )
        dot = dot.sum(dim=0)
        mask = dot.unsqueeze(1) - dot.unsqueeze(0)
        bool_mask = mask == 0
        bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)

        torch._check(bool_mask.shape[2] == key_states.shape[2])
        torch._check(bool_mask.shape[3] == key_states.shape[2])

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=bool_mask,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.attention_dropout,
            is_causal=False,
            **kwargs,
        )
    """
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    # Converted once: the original called ``lengths.tolist()`` for each of the
    # three tensors, repeating the tensor-to-list conversion every time.
    split_sizes = lengths.tolist()
    splits = [
        torch.split(tensor, split_sizes, dim=2)
        for tensor in (query_states, key_states, value_states)
    ]

    attn_outputs = [
        patched_sdpa_attention_forward(
            None,
            q,
            k,
            v,
            attention_mask=None,
            scaling=scaling,
            dropout=0.0,
            is_causal=False,
        )[0]
        for q, k, v in zip(*splits)
    ]
    attn_output = torch.cat(attn_outputs, dim=1)
    return attn_output
|
|
257
|
+
|
|
258
|
+
def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, int]:
    """
    Selects which ONNX implementation of the attention should be used.

    The decision is based on the dtype of the first argument which is not
    None. If ``patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()``
    (environment variable ``QWEN25ATTENTION``) returns a value, that value
    is used as the strategy whatever the type is. Otherwise:

    * float32: ``"LOOPA24"`` if ``opset >= 24`` else ``"LOOPMHA"``
    * float16: ``"PACKED"`` if the tensor exposes ``is_cuda`` and it is true,
      else ``"LOOPMHA"``

    :param opset: target ONNX opset
    :param args: inputs of the attention, torch tensors or any object
        exposing ``dtype`` (``is_cuda`` and ``device`` are optional)
    :return: tuple ``(strategy, itype)`` where *itype* is the ONNX element
        type returned by ``torch_dtype_to_onnx_dtype``
    :raises AssertionError: if the type is neither float32 nor float16
        and no strategy is forced
    """
    first_tensor = next(a for a in args if a is not None)
    dtype = first_tensor.dtype
    strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
    itype = torch_dtype_to_onnx_dtype(dtype)
    if strategy is not None:
        return strategy, itype
    if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
        if opset >= 24:
            return "LOOPA24", itype
        return "LOOPMHA", itype
    if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
        # first_tensor may be a SymbolicTensor (onnx).
        # is_cuda is not available.
        if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
            return "PACKED", itype
        return "LOOPMHA", itype
    # FIX: report the actual dtype and device, the original message
    # interpolated the classes ``torch.dtype`` and ``torch.device``.
    raise AssertionError(
        f"Unable to handle type {dtype} (itype={itype}) "
        f"on device {getattr(first_tensor, 'device', None)} with opset={opset}"
    )
|
|
279
|
+
|
|
280
|
+
# Wraps the eager implementation qwen_sdpa_attention together with its ONNX
# counterparts. NOTE(review): the exact meaning of the positional arguments is
# defined by EagerDirectReplacementWithOnnx (not visible here) -- presumably
# the eager function, a function producing an empty output of the right
# shape/dtype, and the table of ONNX functions; confirm against that class.
qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
    qwen_sdpa_attention,
    # Output placeholder: same dtype/device as the first input, with
    # dimensions 1 and 2 of the first input swapped.
    lambda qs, *args, **kwargs: torch.empty(
        (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
        dtype=qs.dtype,
        device=qs.device,
    ),
    # Keys are (strategy, ONNX element type), the same pair returned by
    # qwen_version_selector. Functions using com.microsoft operators get that
    # opset import added; float16 loop variants get their SequenceEmpty node
    # retyped. There is no ("PACKED", FLOAT) entry.
    {
        ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
            PackedAttention.to_function_proto()
        ),
        ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
        ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
            onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
        ),
        ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
            LoopMHAAttention.to_function_proto()
        ),
        ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
            onnx.TensorProto.FLOAT16,
            _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
        ),
    },
    n_inputs=4,
    n_outputs=1,
    # Same default values as the attributes of the onnxscript functions.
    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
    name="qwen_sdpa_attention_versatile",
    version_selector=qwen_version_selector,
)
# Registers the replacement in the module-level list PLUGS.
PLUGS.append(qwen_sdpa_attention_versatile)
|
|
310
|
+
|
|
311
|
+
class patched_Qwen2_5_VLForConditionalGeneration:
    """
    Replacement for ``prepare_inputs_for_generation`` of
    ``Qwen2_5_VLForConditionalGeneration``.
    """

    # Names of the methods defined here which replace the original ones and
    # the class they apply to (presumably consumed by the patching machinery
    # defined elsewhere in the package).
    _PATCHES_ = ["prepare_inputs_for_generation"]
    _PATCHED_CLASS_ = (
        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
    )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        """
        Builds the model inputs for one generation step.

        The generic inputs come from
        ``GenerationMixin.prepare_inputs_for_generation``. When *position_ids*
        is None, ``model_inputs["position_ids"]`` is replaced by the
        concatenation (dimension 0) of the text positions and the vision
        positions. Pixel inputs are set to None as soon as
        ``cache_position[0] != 0``.

        :return: dictionary of model inputs
        :raises AssertionError: if *position_ids* is None and the vision
            positions could not be computed
        """
        # Overwritten -- in specific circumstances we don't want to
        # forward image inputs to the model
        from transformers.generation import GenerationMixin

        model_inputs = GenerationMixin.prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )

        # Qwen2-5-VL position_ids are prepared with rope_deltas
        if position_ids is None:
            # FIX: initialized so that the assertion below reports the missing
            # positions; without it, reaching the assertion with none of the
            # two branches taken raised UnboundLocalError instead.
            vision_positions = None
            # Calculate RoPE index once per generation in the pre-fill stage only.
            # When compiling, we can't check tensor values thus we check only input length
            # It is safe to assume that `length!=1` means we're in pre-fill
            # because compiled models currently cannot do assisted decoding
            if cache_position[0] == 0 or self.model.rope_deltas is None:
                vision_positions, rope_deltas = self.model.get_rope_index(
                    model_inputs.get("input_ids", None),
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask,
                )
                self.model.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            elif (
                "position_ids" in model_inputs and model_inputs["position_ids"] is not None
            ):
                batch_size, seq_length = model_inputs["position_ids"].shape
                device = model_inputs["position_ids"].device
                position_ids = torch.arange(seq_length, device=device)
                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
                delta = cache_position[0] + self.model.rope_deltas
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                vision_positions = position_ids + delta.expand_as(position_ids)

            # Concatenate "text + vision" positions into [4, bs, seq-len]
            if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
                text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
                    None, None, :
                ]
            else:
                text_positions = model_inputs["position_ids"][None, ...]
            # text_positions = model_inputs["position_ids"][None, ...]
            assert vision_positions is not None, "vision_positions are missing"
            model_inputs["position_ids"] = torch.cat(
                [text_positions, vision_positions], dim=0
            )

        # Pixel inputs are only forwarded on the first step.
        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs
|
|
399
|
+
|
|
400
|
+
class patched_Qwen2_5_VisionTransformerPretrainedModel:
    """
    Replacements for three methods of
    ``Qwen2_5_VisionTransformerPretrainedModel``. The lines marked ``PATCHED``
    index tensors explicitly instead of unbinding them and keep the
    cumulative window lengths as tensors instead of python lists.
    """

    # Names of the methods defined here which replace the original ones and
    # the class they apply to (presumably consumed by the patching machinery
    # defined elsewhere in the package).
    _PATCHES_ = ["get_window_index", "forward", "rot_pos_emb"]
    _PATCHED_CLASS_ = (
        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel
    )

    def rot_pos_emb(self, grid_thw):
        """
        Computes the rotary position embeddings of all patches.

        For every row ``(t, h, w)`` of *grid_thw*, builds the (row, column)
        index of every patch, reordered by groups of
        ``spatial_merge_size x spatial_merge_size`` patches and repeated
        ``t`` times, then gathers ``self.rotary_pos_emb`` at these indices.

        :param grid_thw: tensor whose rows are read as ``(t, h, w)``
        :return: gathered embeddings, flattened from dimension 1
        """
        pos_ids = []
        for thw_ in grid_thw:
            # PATCHED: avoid unbind
            t = thw_[0]
            h = thw_[1]
            w = thw_[2]
            # Row index of every cell of the (h, w) grid.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            # Groups the cells of the same merge block together.
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column index of every cell, same reordering.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # The embedding table is built for the largest height or width.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """
        Computes the permutation grouping the merged patches by attention
        window and the cumulative lengths of these windows.

        :param grid_thw: tensor whose rows are read as ``(t, h, w)``
        :return: tuple ``(window_index, cu_window_seqlens)``, both 1D tensors;
            *cu_window_seqlens* starts with 0 and may contain repeated values
            (``forward`` removes consecutive duplicates)
        """
        window_index: list = []  # type: ignore[annotation-unchecked]
        # PATCHED
        cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)]  # type: ignore[annotation-unchecked]
        window_index_id = 0
        # Window side expressed in merged patches.
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        for _thw in grid_thw:
            # PATCHED: avoid unbind
            grid_t = _thw[0]
            grid_h = _thw[1]
            grid_w = _thw[2]
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # Padding up to a multiple of the window size; a full extra window
            # is added when the size is already a multiple.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            # -100 marks the padded cells, they are dropped below.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non padded) cells in every window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Cumulative lengths continue from the last value already stored.
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
            )
            # PATCHED
            # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            cu_window_seqlens.append(cu_seqlens_tmp)
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, torch.cat(cu_window_seqlens, dim=0)

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: hidden_states.

        Blocks whose index is in ``self.fullatt_block_indexes`` receive the
        cumulative lengths computed from *grid_thw*, the other blocks receive
        the window cumulative lengths. If the module-level ``STOPAT`` is set,
        the loop over the blocks stops after the first block whose index is
        greater than ``STOPAT``.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        # PATCHED
        # cu_window_seqlens = torch.tensor(
        #    cu_window_seqlens,
        #    device=hidden_states.device,
        #    dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        # )
        cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = hidden_states.size()
        # Reorders groups of spatial_merge_unit rows following window_index.
        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        # Same reordering for the rotary embeddings.
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # h * w repeated t times for every row of grid_thw, then cumulated.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype
            # as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852
            # for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Adds the leading 0.
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            # Early exit driven by the environment variable STOPAT; the block
            # of index STOPAT + 1 has already been executed at this point.
            if STOPAT is not None and layer_num > STOPAT:
                break

        hidden_states = self.merger(hidden_states)
        # Inverse permutation of window_index restores the initial order.
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states
|
|
565
|
+
|
|
566
|
+
class patched_Qwen2_5_VLVisionAttentionOneIteration(torch.nn.Module):
    """
    Attention restricted to one range ``[start, end)`` of dimension 2
    of the query, key and value tensors.
    """

    def forward(
        self,
        start_end,
        query_states,
        key_states,
        value_states,
        scaling: float = 1.0,
        dropout: float = 0.0,
        **kwargs,
    ):
        """
        Slices the three tensors between ``start_end[0]`` and ``start_end[1]``
        along dimension 2 and returns the first output of
        ``patched_sdpa_attention_forward`` (no mask, not causal).
        """
        first = start_end[0].item()
        last = start_end[1].item()
        window = slice(first, last)
        q_chunk, k_chunk, v_chunk = (
            states[:, :, window, :]
            for states in (query_states, key_states, value_states)
        )
        outputs = patched_sdpa_attention_forward(
            self,
            q_chunk,
            k_chunk,
            v_chunk,
            attention_mask=None,
            scaling=scaling,
            dropout=dropout,
            is_causal=False,
            **kwargs,
        )
        return outputs[0]
|
|
593
|
+
|
|
594
|
+
class patched_Qwen2_5_VLVisionAttention:
    """
    Replacement for ``forward`` of ``Qwen2_5_VLVisionAttention``. The
    attention itself goes through ``qwen_sdpa_attention_versatile`` when the
    selected implementation is sdpa (original or patched) or eager.
    """

    # Names of the methods defined here which replace the original ones and
    # the class they apply to (presumably consumed by the patching machinery
    # defined elsewhere in the package).
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = (
        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
    )
    # Returns the value of environment variable QWEN25ATTENTION or None,
    # read at every call; qwen_version_selector uses it to force a strategy.
    STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None)  # noqa: E731

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Computes the vision attention over chunks delimited by *cu_seqlens*.

        :param hidden_states: input, its first dimension is the sequence length
        :param cu_seqlens: boundaries of the chunks along the sequence
        :param rotary_pos_emb: not used by this implementation
        :param position_embeddings: pair ``(cos, sin)``; it is unpacked
            unconditionally, so None fails despite the default value
        :param kwargs: forwarded to the attention function in the
            non-export branches
        :return: output of ``self.proj``
        :raises NotImplementedError: while exporting, if the attention
            function is neither sdpa nor eager and the implementation
            is not ``flash_attention_2``
        """
        seq_length = hidden_states.shape[0]
        # PATCHED: avoid the use of unbind
        qkv = (
            self.qkv(hidden_states)
            .reshape(seq_length, 3, self.num_heads, -1)
            .permute(1, 0, 2, 3)
        )

        query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
        cos, sin = position_embeddings

        # This part should be moved into the loop
        # iteration to enable fusion inside the loop.
        query_states, key_states = (
            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
                query_states, key_states, cos, sin
            )
        )

        # (seq, heads, dim) -> (1, heads, seq, dim)
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = (
            transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
        )
        if self.config._attn_implementation != "eager":
            # PATCHED
            # attention_interface = ALL_ATTENTION_FUNCTIONS[
            #       self.config._attn_implementation]
            attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        # True when the selected function is the original sdpa function,
        # the patched one or the eager one.
        is_sdpa_or_eager = (
            attention_interface
            is transformers.integrations.sdpa_attention.sdpa_attention_forward
            or attention_interface is patched_sdpa_attention_forward
            or attention_interface
            is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
        )
        if is_sdpa_or_eager:
            # Scaling and number of heads are passed positionally after the
            # four tensors.
            attn_output = qwen_sdpa_attention_versatile(
                query_states,
                key_states,
                value_states,
                cu_seqlens,
                self.scaling,
                self.num_heads,
            )
        elif _is_torchdynamo_exporting():
            if self.config._attn_implementation == "flash_attention_2":
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
                # Emits a single symbolic node "custom::qwen25_flash_attention";
                # cu_seqlens and max_seqlen are both passed twice.
                attn_output = torch.onnx.ops.symbolic(
                    "custom::qwen25_flash_attention",
                    (
                        query_states,
                        key_states,
                        value_states,
                        cu_seqlens,
                        cu_seqlens,
                        max_seqlen,
                        max_seqlen,
                        torch.tensor(self.scaling, dtype=torch.float32),
                    ),
                    dtype=query_states.dtype,
                    shape=(
                        query_states.shape[0],  # batch_size
                        query_states.shape[2],  # sequence_length (total patches)
                        query_states.shape[1],  # num_heads
                        query_states.shape[3],  # head_size
                    ),
                    version=1,
                )
            else:
                raise NotImplementedError(
                    f"No corresponding export strategy for implementation "
                    f"{self.config._attn_implementation!r}, "
                    f"(use QWEN25ATTENTION to change it), and attention_interface="
                    f"{attention_interface!r} (use sdpa)"
                )
        elif self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            # = qwen_sdpa_attention
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2)
                for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        # Back to two dimensions (seq_length, -1) before the output projection.
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
|