onnx-diagnostic 0.8.3-py3-none-any.whl → 0.8.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +47 -10
- onnx_diagnostic/export/api.py +81 -50
- onnx_diagnostic/export/control_flow_research.py +10 -5
- onnx_diagnostic/export/onnx_plug.py +250 -61
- onnx_diagnostic/ext_test_case.py +99 -53
- onnx_diagnostic/helpers/dot_helper.py +37 -25
- onnx_diagnostic/helpers/helper.py +44 -38
- onnx_diagnostic/helpers/onnx_helper.py +441 -18
- onnx_diagnostic/helpers/ort_session.py +8 -8
- onnx_diagnostic/helpers/torch_helper.py +28 -2
- onnx_diagnostic/reference/ort_evaluator.py +6 -29
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +14 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +11 -5
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/reference/ort_evaluator.py

@@ -18,10 +18,11 @@ from onnx.defs import onnx_opset_version
 import onnxruntime
 from ..helpers import string_type
 from ..helpers.onnx_helper import (
-
+    get_hidden_inputs,
     dtype_to_tensor_dtype,
-    to_array_extended,
     np_dtype_to_tensor_dtype,
+    to_array_extended,
+    pretty_onnx,
 )
 from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..helpers.ort_session import (
@@ -472,39 +473,15 @@ class OnnxruntimeEvaluator:
                 yield from self.enumerate_nodes(att.g.node)
         yield node

-    @classmethod
-    def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
-        """
-        Returns the hidden inputs (inputs coming from an upper context)
-        used by a subgraph.
-        """
-        hidden = set()
-        memo = (
-            {i.name for i in graph.initializer}
-            | {i.name for i in graph.sparse_initializer}
-            | {i.name for i in graph.input}
-        )
-        for node in graph.node:
-            for i in node.input:
-                if i not in memo:
-                    hidden.add(i)
-            for att in node.attribute:
-                if att.type == AttributeProto.GRAPH and att.g:
-                    hid = cls._get_hidden_inputs(att.g)
-                    less = set(h for h in hid if h not in memo)
-                    hidden |= less
-            memo |= set(node.output)
-        return hidden
-
     @classmethod
     def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
-        """Calls multiple
+        """Calls multiple get_hidden_inputs on every attribute."""
         if node.op_type not in {"Loop", "Scan", "If"}:
             return set()
         hidden = set()
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
-                hidden |=
+                hidden |= get_hidden_inputs(att.g)
         return hidden - (hidden & set(node.input))

     def _get_sess(
@@ -624,7 +601,7 @@ class OnnxruntimeEvaluator:
             value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
             vinputs.append(value)

-        reduced_set =
+        reduced_set = get_hidden_inputs(g)
         for i, v in context.items():
             if i in reduced_set and i not in unique_names:
                 unique_names.add(i)
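The classmethod removed above now comes from the shared helper `get_hidden_inputs` in `onnx_diagnostic.helpers.onnx_helper`, imported in the first hunk. A minimal sketch of what such a helper does, reusing the traversal of the removed classmethod; the name `get_hidden_inputs_sketch` is illustrative, not the library's API:

```python
from typing import Set
from onnx import AttributeProto, GraphProto


def get_hidden_inputs_sketch(graph: GraphProto) -> Set[str]:
    """Collect names used by a subgraph that come from an outer scope
    (same traversal as the classmethod removed above)."""
    hidden: Set[str] = set()
    # Names already known inside the subgraph: initializers and explicit inputs.
    memo = (
        {i.name for i in graph.initializer}
        | {i.name for i in graph.sparse_initializer}
        | {i.name for i in graph.input}
    )
    for node in graph.node:
        for name in node.input:
            if name not in memo:
                hidden.add(name)
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH and att.g:
                # Recurse into nested subgraphs (Loop/Scan/If bodies).
                hidden |= {h for h in get_hidden_inputs_sketch(att.g) if h not in memo}
        memo |= set(node.output)
    return hidden
```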
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py

@@ -1,3 +1,4 @@
+import inspect
 from typing import Callable, List, Optional, Tuple
 import torch

@@ -19,6 +20,12 @@ if patch_masking_utils:
         prepare_padding_mask,
     )

+    _prepare_padding_mask_kwargs = (
+        dict(_slice=False)
+        if "_slice" in inspect.signature(prepare_padding_mask).parameters
+        else {}
+    )
+
     try:
         # transformers>=5.0
         from transformers.masking_utils import (
@@ -132,7 +139,9 @@ if patch_masking_utils:
     ) -> Optional[torch.Tensor]:
         """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
         q_length = cache_position.shape[0]
-        padding_mask = prepare_padding_mask(
+        padding_mask = prepare_padding_mask(
+            attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
+        )
         if allow_is_causal_skip and _ignore_causal_mask_sdpa(
             padding_mask, q_length, kv_length, kv_offset, local_size
         ):
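The `_prepare_padding_mask_kwargs` guard above keeps the patch working across transformers versions: the `_slice` keyword is only forwarded when `prepare_padding_mask` actually declares it. A self-contained sketch of the same feature-detection pattern, with hypothetical old/new signatures standing in for the two transformers variants:

```python
import inspect


def call_with_optional_slice(prepare, attention_mask, kv_length, kv_offset):
    """Forward ``_slice=False`` only when the callee accepts that keyword."""
    extra = (
        dict(_slice=False)
        if "_slice" in inspect.signature(prepare).parameters
        else {}
    )
    return prepare(attention_mask, kv_length, kv_offset, **extra)


# Hypothetical stand-ins for the two signatures.
def prepare_old(attention_mask, kv_length, kv_offset):
    return "old", attention_mask, kv_length, kv_offset


def prepare_new(attention_mask, kv_length, kv_offset, _slice=True):
    return "new", attention_mask, kv_length, kv_offset, _slice


print(call_with_optional_slice(prepare_old, None, 8, 0))  # no extra keyword passed
print(call_with_optional_slice(prepare_new, None, 8, 0))  # _slice=False forwarded
```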
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

@@ -1,9 +1,11 @@
 import os
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple
 import onnx
+import onnx.helper as oh
 import torch
 import torch.nn.functional as F
 from ...export.onnx_plug import EagerDirectReplacementWithOnnx
+from ...helpers.torch_helper import torch_dtype_to_onnx_dtype
 from .patch_helper import _is_torchdynamo_exporting
 from ._patch_transformers_attention import patched_sdpa_attention_forward

@@ -22,7 +24,43 @@ if patch_qwen2_5:

     onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
     op = onnxscript.opset22
+    op23 = onnxscript.onnx_opset.opset23
     msft_op = onnxscript.values.Opset("com.microsoft", 1)
+    STOPAT = (
+        int(os.environ.get("STOPAT", None))
+        if os.environ.get("STOPAT", None) is not None
+        else None
+    )
+
+    def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
+        opsets = {d.domain: d.version for d in function_proto.opset_import}
+        if "com.microsoft" not in opsets:
+            d = function_proto.opset_import.add()
+            d.domain = "com.microsoft"
+            d.version = 1
+        return function_proto
+
+    def _update_sequence_type(
+        itype: int, function_proto: onnx.FunctionProto
+    ) -> onnx.FunctionProto:
+        proto = oh.make_function(
+            function_proto.domain,
+            function_proto.name,
+            function_proto.input,
+            function_proto.output,
+            [
+                (
+                    oh.make_node("SequenceEmpty", node.input, node.output, dtype=itype)
+                    if node.op_type == "SequenceEmpty"
+                    else node
+                )
+                for node in function_proto.node
+            ],
+            attributes=function_proto.attribute,
+            attribute_protos=function_proto.attribute_proto,
+            opset_imports=function_proto.opset_import,
+        )
+        return proto

     @onnxscript.script(opset=onnx_plugs_op)
     def LoopMHAAttention(
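Both helpers above feed the registration dictionary further down in this file: `_add_com_microsoft_opset` declares the `com.microsoft` domain for function protos that rely on contrib ops such as `MultiHeadAttention`, and `_update_sequence_type` rewrites the `dtype` attribute of `SequenceEmpty` nodes so a float16 variant can be derived from the float32 onnxscript function. A short sketch of how they combine, mirroring the dictionary entries registered later:

```python
# Sketch mirroring the ("LOOPMHA", FLOAT16) entry registered below.
loop_mha_f32 = _add_com_microsoft_opset(LoopMHAAttention.to_function_proto())
loop_mha_f16 = _update_sequence_type(onnx.TensorProto.FLOAT16, loop_mha_f32)
```

Note that `_update_sequence_type` builds a new `FunctionProto` with `oh.make_function` rather than mutating its argument, so the float32 proto can be reused.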
@@ -32,7 +70,6 @@ if patch_qwen2_5:
         cu_seqlens,
         scaling: float = 0.11180339887498948,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ):
         to_3d_shape = op.Constant(value_ints=[0, 0, -1])
         query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -45,7 +82,7 @@ if patch_qwen2_5:
         seq_axis = op.Constant(value_ints=[1])
         seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
         # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-        seq_attn = op.SequenceEmpty(dtype=
+        seq_attn = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
         for i_patch in range(num_patches):
             i_1d = op.Reshape(i_patch, [1])
             i_plus_1_1d = i_1d + 1
@@ -55,11 +92,7 @@ if patch_qwen2_5:
             key_i = op.Slice(key_3d, start, end, seq_axis_int32)
             value_i = op.Slice(value_3d, start, end, seq_axis_int32)
             mha_output = msft_op.MultiHeadAttention(
-                query_i,
-                key_i,
-                value_i,
-                num_heads=num_heads,
-                scale=scaling,
+                query_i, key_i, value_i, num_heads=num_heads, scale=scaling
             )
             # attn_output = op.Concat(attn_output, mha_output, axis=1)
             seq_attn = op.SequenceInsert(seq_attn, mha_output)
@@ -67,13 +100,47 @@ if patch_qwen2_5:
         attn_output_4d = op.Reshape(attn_output, output_shape)
         return attn_output_4d

-
-
-
-
-
-
-
+    @onnxscript.script(opset=onnx_plugs_op)
+    def LoopAttention23(
+        query_states,
+        key_states,
+        value_states,
+        cu_seqlens,
+        scaling: float = 0.11180339887498948,
+        num_heads: int = 16,
+    ):
+        to_3d_shape = op23.Constant(value_ints=[0, 0, -1])
+        query_transposed = op23.Transpose(query_states, perm=[0, 2, 1, 3])
+        output_shape = op23.Shape(query_transposed)
+        query_3d = op23.Reshape(query_transposed, to_3d_shape)
+        value_3d = op23.Reshape(op23.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        key_3d = op23.Reshape(op23.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        cu_seqlens = op23.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+        num_patches = op23.Size(cu_seqlens) - 1
+        seq_axis = op23.Constant(value_ints=[1])
+        seq_axis_int32 = op23.Cast(seq_axis, to=onnx.TensorProto.INT32)
+        seq_attn = op23.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
+        for i_patch in range(num_patches):
+            i_1d = op23.Reshape(i_patch, [1])
+            i_plus_1_1d = i_1d + 1
+            start = op23.Gather(cu_seqlens, i_1d, axis=0)
+            end = op23.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+            query_i = op23.Slice(query_3d, start, end, seq_axis_int32)
+            key_i = op23.Slice(key_3d, start, end, seq_axis_int32)
+            value_i = op23.Slice(value_3d, start, end, seq_axis_int32)
+            mha_output = op23.Attention(
+                query_i,
+                key_i,
+                value_i,
+                scale=scaling,
+                q_num_heads=num_heads,
+                kv_num_heads=num_heads,
+                softmax_precision=onnx.TensorProto.FLOAT,
+            )
+            seq_attn = op23.SequenceInsert(seq_attn, mha_output)
+        attn_output = op23.ConcatFromSequence(seq_attn, axis=1)
+        attn_output_4d = op23.Reshape(attn_output, output_shape)
+        return attn_output_4d

     @onnxscript.script(opset=onnx_plugs_op)
     def PackedAttention(
@@ -132,8 +199,40 @@ if patch_qwen2_5:
         cu_seqlens: torch.Tensor,  # F7su19
         scaling: float = 0,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ) -> torch.Tensor:
+        """
+        The loop can be removed with the following code
+        but it hits memory overflow for big inputs.
+
+        .. code-block:: python
+
+            # make square mask
+            indices = torch.arange(
+                cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+            )
+            dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                cu_seqlens.dtype
+            )
+            dot = dot.sum(dim=0)
+            mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+            bool_mask = mask == 0
+            bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+            torch._check(bool_mask.shape[2] == key_states.shape[2])
+            torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=bool_mask,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                is_causal=False,
+                **kwargs,
+            )
+        """
         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
         splits = [
             torch.split(tensor, lengths.tolist(), dim=2)
@@ -156,36 +255,58 @@ if patch_qwen2_5:
         attn_output = torch.cat(attn_outputs, dim=1)
         return attn_output

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
+        first_tensor = next(a for a in args if a is not None)
+        dtype = first_tensor.dtype
+        strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
+        itype = torch_dtype_to_onnx_dtype(dtype)
+        if strategy is not None:
+            return strategy, itype
+        if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
+            if opset >= 23:
+                return "LOOPA23", itype
+            return "LOOPMHA", itype
+        if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
+            # first_tensor may be a SymbolicTensor (onnx).
+            # is_cuda is not available.
+            if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
+                return "PACKED", itype
+            return "LOOPMHA", itype
+        raise AssertionError(
+            f"Unable to handle type {torch.dtype} (itype={itype}) "
+            f"on device {torch.device} with opset={opset}"
+        )

-
+    qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
         qwen_sdpa_attention,
         lambda qs, *args, **kwargs: torch.empty(
             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
             dtype=qs.dtype,
             device=qs.device,
         ),
-
+        {
+            ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
+                PackedAttention.to_function_proto()
+            ),
+            ("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
+            ("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
+                LoopMHAAttention.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16,
+                _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+            ),
+        },
         n_inputs=4,
         n_outputs=1,
-        kwargs=dict(scaling=0.11180339887498948, num_heads=16
-        name="
+        kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+        name="qwen_sdpa_attention_versatile",
+        version_selector=qwen_version_selector,
     )
-    PLUGS.append(
+    PLUGS.append(qwen_sdpa_attention_versatile)

     class patched_Qwen2_5_VLForConditionalGeneration:
         _PATCHES_ = ["prepare_inputs_for_generation"]
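With `STRATEGY_FOR_ATTENTION` now defaulting to `None` (see the patched class below), `qwen_version_selector` picks `LOOPA23`, `LOOPMHA` or `PACKED` automatically from dtype, device and opset, and the `QWEN25ATTENTION` environment variable stays available to force a strategy. A hedged sketch of forcing it for one export session:

```python
import os

# Force the MultiHeadAttention-based loop instead of the automatic selection.
os.environ["QWEN25ATTENTION"] = "LOOPMHA"
# ... run the export with the qwen2.5 patches enabled ...
# Removing the variable restores the automatic choice (the selector sees strategy=None).
os.environ.pop("QWEN25ATTENTION", None)
```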
@@ -434,6 +555,8 @@ if patch_qwen2_5:
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
+            if STOPAT is not None and layer_num > STOPAT:
+                break

         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)
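The new `STOPAT` environment variable, read when the patch module is imported (see the hunk near the top of the file), lets the patched vision forward stop after a given block index, which helps when bisecting numerical differences layer by layer. A hedged sketch:

```python
import os

# Hypothetical debugging session: skip vision blocks with layer_num > 3.
# The variable must be set before the patch module is imported,
# since STOPAT is evaluated at import time.
os.environ["STOPAT"] = "3"
# ... import the patches and run the model or the export; the ``break``
# added above leaves the remaining blocks unexecuted ...
```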
@@ -473,9 +596,7 @@ if patch_qwen2_5:
         _PATCHED_CLASS_ = (
             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
         )
-        STRATEGY_FOR_ATTENTION = lambda: os.environ.get(  # noqa: E731
-            "QWEN25ATTENTION", "PACKED"
-        )
+        STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None)  # noqa: E731

         def forward(
             self,
@@ -519,14 +640,15 @@ if patch_qwen2_5:
                 self.config._attn_implementation
             ]

-
+            is_sdpa_or_eager = (
                 attention_interface
                 is transformers.integrations.sdpa_attention.sdpa_attention_forward
                 or attention_interface is patched_sdpa_attention_forward
+                or attention_interface
+                is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
             )
-
-
-            attn_output = qwen_sdpa_attention_packed_versatile(
+            if is_sdpa_or_eager:
+                attn_output = qwen_sdpa_attention_versatile(
                     query_states,
                     key_states,
                     value_states,
@@ -558,78 +680,10 @@ if patch_qwen2_5:
                     ),
                     version=1,
                 )
-            elif is_sdpa and attention_strategy == "LOOPMHA":
-                attn_output = qwen_sdpa_attention_loopmha_versatile(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens,
-                    self.scaling,
-                    self.num_heads,
-                    (
-                        onnx.TensorProto.FLOAT
-                        if query_states.dtype == torch.float32
-                        else (
-                            onnx.TensorProto.FLOAT16
-                            if query_states.dtype == torch.float16
-                            else onnx.TensorProto.BFLOAT16
-                        )
-                    ),
-                )
-
-                # to rewrite later with a for loop
-                # def _iteration(start_end, query_states, key_states, value_states):
-                #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
-                #         self,
-                #         start_end,
-                #         query_states,
-                #         key_states,
-                #         value_states,
-                #         scaling=self.scaling,
-                #         dropout=0.0 if not self.training else self.attention_dropout,
-                #     )
-
-                # starts = cu_seqlens[:-1]
-                # ends = cu_seqlens[1:]
-                # torch._check(starts.shape[0] > 0)
-                # torch._check(ends.shape[0] > 0)
-                # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-                # attn_outputs = [
-                #     _iteration(start_end, query_states, key_states, value_states)
-                #     for start_end in starts_ends
-                # ]
-                # attn_output = torch.cat(attn_outputs, dim=1)
-            elif is_sdpa and attention_strategy == "BIGMASK":
-                # make square mask
-                indices = torch.arange(
-                    cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
-                )
-                dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
-                    cu_seqlens.dtype
-                )
-                dot = dot.sum(dim=0)
-                mask = dot.unsqueeze(1) - dot.unsqueeze(0)
-                bool_mask = mask == 0
-                bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
-
-                torch._check(bool_mask.shape[2] == key_states.shape[2])
-                torch._check(bool_mask.shape[3] == key_states.shape[2])
-
-                attn_output, _ = attention_interface(
-                    self,
-                    query_states,
-                    key_states,
-                    value_states,
-                    attention_mask=bool_mask,
-                    scaling=self.scaling,
-                    dropout=0.0 if not self.training else self.attention_dropout,
-                    is_causal=False,
-                    **kwargs,
-                )
             else:
                 raise NotImplementedError(
-                    f"No corresponding export strategy for "
-                    f"{
+                    f"No corresponding export strategy for implementation "
+                    f"{self.config._attn_implementation!r}, "
                     f"(use QWEN25ATTENTION to change it), and attention_interface="
                     f"{attention_interface!r} (use sdpa)"
                 )
@@ -653,6 +707,7 @@ if patch_qwen2_5:
             )
         else:
             # Other implementations: Process each chunk separately
+            # = qwen_sdpa_attention
             lengths = cu_seqlens[1:] - cu_seqlens[:-1]
             splits = [
                 torch.split(tensor, lengths.tolist(), dim=2)
onnx_diagnostic/torch_models/code_sample.py

@@ -236,7 +236,7 @@ def code_sample(
        )
    )
    """
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
        model_id,
        subfolder,
        same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
         model_kwargs=mop,
         subfolder=subfolder,
         add_second_input=False,
+        submodule=submodule,
     )
     if drop_inputs:
         update = {}
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:


 def _preprocess_model_id(
-    model_id: str,
-
+    model_id: str,
+    subfolder: Optional[str],
+    same_as_pretrained: bool,
+    use_pretrained: bool,
+    submodule: Optional[str] = None,
+) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
+    if "::" in model_id:
+        assert (
+            not submodule
+        ), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
+        model_id, submodule = model_id.split("::", maxsplit=1)
     if subfolder or "//" not in model_id:
-        return model_id, subfolder, same_as_pretrained, use_pretrained
+        return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
     spl = model_id.split("//")
     if spl[-1] == "pretrained":
-        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
-    return model_id, subfolder, same_as_pretrained, use_pretrained
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained, submodule
+    return model_id, subfolder, same_as_pretrained, use_pretrained, submodule


 def get_untrained_model_with_inputs(
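`_preprocess_model_id` can now carry the target submodule inside the `model_id` itself, separated by `::`. A small sketch of the parsing rule with a hypothetical identifier:

```python
# Hypothetical identifier: everything after "::" names a submodule to extract.
model_id = "some-org/some-vlm::visual"
submodule = None
if "::" in model_id:
    model_id, submodule = model_id.split("::", maxsplit=1)
print(model_id)   # some-org/some-vlm
print(submodule)  # visual
```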
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
+    submodule: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
         <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
         this function takes a configuration and a task (string)
         as arguments
+    :param submodule: use a submodule instead of the main model
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
             f"model_id={model_id!r}, preinstalled model is only available "
             f"if use_only_preinstalled is False."
         )
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     if verbose:
         print(
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+        if submodule:
+            print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
     if task is None:
         task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
     if verbose:
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)

+    if submodule:
+        path = submodule.split("::") if "::" in submodule else [submodule]
+        for p in path:
+            assert hasattr(model, p), (
+                f"Unable to find submodule {p!r} in in class {type(model)}, "
+                f"submodule={submodule!r}, possible candidates: "
+                f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
+            )
+            model = getattr(model, p)
+
+        if verbose:
+            print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
+
     sizes = compute_model_size(model)
     res["model"] = model
     res["configuration"] = config
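Once the model is built, the requested submodule is resolved by walking the `::`-separated attribute path with `getattr`, and the resulting module replaces the full model in the returned dictionary. A hedged usage sketch with a hypothetical model id and attribute name:

```python
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# Either pass submodule= explicitly or embed it in the model id ("<id>::<attr>").
# The id and the "visual" attribute below are illustrative, not guaranteed names.
data = get_untrained_model_with_inputs(
    "some-org/some-vlm", submodule="visual", verbose=1
)
print(type(data["model"]))  # the resolved submodule, not the full model
```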