onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +26 -1
- onnx_diagnostic/export/api.py +66 -46
- onnx_diagnostic/export/control_flow_research.py +10 -5
- onnx_diagnostic/export/onnx_plug.py +195 -60
- onnx_diagnostic/ext_test_case.py +99 -53
- onnx_diagnostic/helpers/dot_helper.py +37 -25
- onnx_diagnostic/helpers/helper.py +18 -11
- onnx_diagnostic/helpers/onnx_helper.py +441 -18
- onnx_diagnostic/helpers/ort_session.py +8 -8
- onnx_diagnostic/helpers/torch_helper.py +28 -2
- onnx_diagnostic/reference/ort_evaluator.py +6 -29
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +14 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +11 -5
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +25 -25
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/reference/ort_evaluator.py

@@ -18,10 +18,11 @@ from onnx.defs import onnx_opset_version
 import onnxruntime
 from ..helpers import string_type
 from ..helpers.onnx_helper import (
-
+    get_hidden_inputs,
     dtype_to_tensor_dtype,
-    to_array_extended,
     np_dtype_to_tensor_dtype,
+    to_array_extended,
+    pretty_onnx,
 )
 from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..helpers.ort_session import (
@@ -472,39 +473,15 @@ class OnnxruntimeEvaluator:
                     yield from self.enumerate_nodes(att.g.node)
            yield node

-    @classmethod
-    def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
-        """
-        Returns the hidden inputs (inputs coming from an upper context)
-        used by a subgraph.
-        """
-        hidden = set()
-        memo = (
-            {i.name for i in graph.initializer}
-            | {i.name for i in graph.sparse_initializer}
-            | {i.name for i in graph.input}
-        )
-        for node in graph.node:
-            for i in node.input:
-                if i not in memo:
-                    hidden.add(i)
-            for att in node.attribute:
-                if att.type == AttributeProto.GRAPH and att.g:
-                    hid = cls._get_hidden_inputs(att.g)
-                    less = set(h for h in hid if h not in memo)
-                    hidden |= less
-            memo |= set(node.output)
-        return hidden
-
     @classmethod
     def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
-        """Calls multiple _get_hidden_inputs on every attribute."""
+        """Calls multiple get_hidden_inputs on every attribute."""
         if node.op_type not in {"Loop", "Scan", "If"}:
             return set()
         hidden = set()
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
-                hidden |= cls._get_hidden_inputs(att.g)
+                hidden |= get_hidden_inputs(att.g)
         return hidden - (hidden & set(node.input))

     def _get_sess(
@@ -624,7 +601,7 @@ class OnnxruntimeEvaluator:
            value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
            vinputs.append(value)

-        reduced_set = self._get_hidden_inputs(g)
+        reduced_set = get_hidden_inputs(g)
        for i, v in context.items():
            if i in reduced_set and i not in unique_names:
                unique_names.add(i)
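The removed `OnnxruntimeEvaluator._get_hidden_inputs` classmethod now lives in `onnx_diagnostic.helpers.onnx_helper` as the public `get_hidden_inputs` (part of the +441 lines in that file). A minimal sketch of calling it directly, assuming the public helper keeps the removed classmethod's contract (a `GraphProto` in, the set of names the subgraph pulls from its outer scope out); the model path is hypothetical:

    import onnx
    from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

    model = onnx.load("model_with_control_flow.onnx")  # hypothetical path
    for node in model.graph.node:
        if node.op_type in {"Loop", "Scan", "If"}:
            for att in node.attribute:
                if att.type == onnx.AttributeProto.GRAPH and att.g:
                    # names the subgraph reads from the enclosing graph
                    print(node.op_type, sorted(get_hidden_inputs(att.g)))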
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

@@ -1,9 +1,11 @@
 import os
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple
 import onnx
+import onnx.helper as oh
 import torch
 import torch.nn.functional as F
 from ...export.onnx_plug import EagerDirectReplacementWithOnnx
+from ...helpers.torch_helper import torch_dtype_to_onnx_dtype
 from .patch_helper import _is_torchdynamo_exporting
 from ._patch_transformers_attention import patched_sdpa_attention_forward

@@ -22,7 +24,43 @@ if patch_qwen2_5:

     onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
     op = onnxscript.opset22
+    op24 = onnxscript.onnx_opset.opset24
     msft_op = onnxscript.values.Opset("com.microsoft", 1)
+    STOPAT = (
+        int(os.environ.get("STOPAT", None))
+        if os.environ.get("STOPAT", None) is not None
+        else None
+    )
+
+    def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
+        opsets = {d.domain: d.version for d in function_proto.opset_import}
+        if "com.microsoft" not in opsets:
+            d = function_proto.opset_import.add()
+            d.domain = "com.microsoft"
+            d.version = 1
+        return function_proto
+
+    def _update_sequence_type(
+        itype: int, function_proto: onnx.FunctionProto
+    ) -> onnx.FunctionProto:
+        proto = oh.make_function(
+            function_proto.domain,
+            function_proto.name,
+            function_proto.input,
+            function_proto.output,
+            [
+                (
+                    oh.make_node("SequenceEmpty", node.input, node.output, dtype=itype)
+                    if node.op_type == "SequenceEmpty"
+                    else node
+                )
+                for node in function_proto.node
+            ],
+            attributes=function_proto.attribute,
+            attribute_protos=function_proto.attribute_proto,
+            opset_imports=function_proto.opset_import,
+        )
+        return proto

     @onnxscript.script(opset=onnx_plugs_op)
     def LoopMHAAttention(
@@ -32,7 +70,6 @@ if patch_qwen2_5:
         cu_seqlens,
         scaling: float = 0.11180339887498948,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ):
         to_3d_shape = op.Constant(value_ints=[0, 0, -1])
         query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -45,7 +82,7 @@ if patch_qwen2_5:
         seq_axis = op.Constant(value_ints=[1])
         seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
         # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-        seq_attn = op.SequenceEmpty(dtype=itype)
+        seq_attn = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
         for i_patch in range(num_patches):
             i_1d = op.Reshape(i_patch, [1])
             i_plus_1_1d = i_1d + 1
@@ -55,11 +92,7 @@ if patch_qwen2_5:
             key_i = op.Slice(key_3d, start, end, seq_axis_int32)
             value_i = op.Slice(value_3d, start, end, seq_axis_int32)
             mha_output = msft_op.MultiHeadAttention(
-                query_i,
-                key_i,
-                value_i,
-                num_heads=num_heads,
-                scale=scaling,
+                query_i, key_i, value_i, num_heads=num_heads, scale=scaling
             )
             # attn_output = op.Concat(attn_output, mha_output, axis=1)
             seq_attn = op.SequenceInsert(seq_attn, mha_output)
@@ -67,13 +100,47 @@ if patch_qwen2_5:
         attn_output_4d = op.Reshape(attn_output, output_shape)
         return attn_output_4d

-
-
-
-
-
-
-
+    @onnxscript.script(opset=onnx_plugs_op)
+    def LoopAttention24(
+        query_states,
+        key_states,
+        value_states,
+        cu_seqlens,
+        scaling: float = 0.11180339887498948,
+        num_heads: int = 16,
+    ):
+        to_3d_shape = op24.Constant(value_ints=[0, 0, -1])
+        query_transposed = op24.Transpose(query_states, perm=[0, 2, 1, 3])
+        output_shape = op24.Shape(query_transposed)
+        query_3d = op24.Reshape(query_transposed, to_3d_shape)
+        value_3d = op24.Reshape(op24.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        key_3d = op24.Reshape(op24.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        cu_seqlens = op24.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+        num_patches = op24.Size(cu_seqlens) - 1
+        seq_axis = op24.Constant(value_ints=[1])
+        seq_axis_int32 = op24.Cast(seq_axis, to=onnx.TensorProto.INT32)
+        seq_attn = op24.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
+        for i_patch in range(num_patches):
+            i_1d = op24.Reshape(i_patch, [1])
+            i_plus_1_1d = i_1d + 1
+            start = op24.Gather(cu_seqlens, i_1d, axis=0)
+            end = op24.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+            query_i = op24.Slice(query_3d, start, end, seq_axis_int32)
+            key_i = op24.Slice(key_3d, start, end, seq_axis_int32)
+            value_i = op24.Slice(value_3d, start, end, seq_axis_int32)
+            mha_output = op24.Attention(
+                query_i,
+                key_i,
+                value_i,
+                scale=scaling,
+                q_num_heads=num_heads,
+                kv_num_heads=num_heads,
+                softmax_precision=onnx.TensorProto.FLOAT,
+            )
+            seq_attn = op24.SequenceInsert(seq_attn, mha_output)
+        attn_output = op24.ConcatFromSequence(seq_attn, axis=1)
+        attn_output_4d = op24.Reshape(attn_output, output_shape)
+        return attn_output_4d

     @onnxscript.script(opset=onnx_plugs_op)
     def PackedAttention(
@@ -132,8 +199,40 @@ if patch_qwen2_5:
         cu_seqlens: torch.Tensor,  # F7su19
         scaling: float = 0,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ) -> torch.Tensor:
+        """
+        The loop can be removed with the following code
+        but it hits memory overflow for big inputs.
+
+        .. code-block:: python
+
+            # make square mask
+            indices = torch.arange(
+                cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+            )
+            dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                cu_seqlens.dtype
+            )
+            dot = dot.sum(dim=0)
+            mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+            bool_mask = mask == 0
+            bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+            torch._check(bool_mask.shape[2] == key_states.shape[2])
+            torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=bool_mask,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                is_causal=False,
+                **kwargs,
+            )
+        """
         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
         splits = [
             torch.split(tensor, lengths.tolist(), dim=2)
@@ -156,36 +255,58 @@ if patch_qwen2_5:
         attn_output = torch.cat(attn_outputs, dim=1)
         return attn_output

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
+        first_tensor = next(a for a in args if a is not None)
+        dtype = first_tensor.dtype
+        strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
+        itype = torch_dtype_to_onnx_dtype(dtype)
+        if strategy is not None:
+            return strategy, itype
+        if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
+            if opset >= 24:
+                return "LOOPA24", itype
+            return "LOOPMHA", itype
+        if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
+            # first_tensor may be a SymbolicTensor (onnx).
+            # is_cuda is not available.
+            if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
+                return "PACKED", itype
+            return "LOOPMHA", itype
+        raise AssertionError(
+            f"Unable to handle type {torch.dtype} (itype={itype}) "
+            f"on device {torch.device} with opset={opset}"
+        )

-    qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
+    qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
         qwen_sdpa_attention,
         lambda qs, *args, **kwargs: torch.empty(
             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
             dtype=qs.dtype,
             device=qs.device,
         ),
-
+        {
+            ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
+                PackedAttention.to_function_proto()
+            ),
+            ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
+            ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
+                LoopMHAAttention.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16,
+                _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+            ),
+        },
         n_inputs=4,
         n_outputs=1,
-        kwargs=dict(scaling=0.11180339887498948, num_heads=16
-        name="qwen_sdpa_attention_packed_versatile",
+        kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+        name="qwen_sdpa_attention_versatile",
+        version_selector=qwen_version_selector,
     )
-    PLUGS.append(qwen_sdpa_attention_packed_versatile)
+    PLUGS.append(qwen_sdpa_attention_versatile)

     class patched_Qwen2_5_VLForConditionalGeneration:
         _PATCHES_ = ["prepare_inputs_for_generation"]
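A rough sketch of the dispatch added above, as implied by `qwen_version_selector`: the `QWEN25ATTENTION` override wins, float32 picks the opset-24 `Attention` loop when the target opset allows it, and float16 uses the packed com.microsoft kernel only on CUDA. Illustrative calls only, assuming `QWEN25ATTENTION` is unset; at export time the inputs are symbolic tensors without `is_cuda`, so float16 falls back to LOOPMHA:

    import torch
    from onnx import TensorProto

    # illustrative calls to the selector defined in the diff above
    qwen_version_selector(24, torch.zeros(2, dtype=torch.float32))
    # -> ("LOOPA24", TensorProto.FLOAT)
    qwen_version_selector(22, torch.zeros(2, dtype=torch.float32))
    # -> ("LOOPMHA", TensorProto.FLOAT)
    qwen_version_selector(24, torch.zeros(2, dtype=torch.float16))
    # -> ("LOOPMHA", TensorProto.FLOAT16) on CPU, ("PACKED", ...) on CUDA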
@@ -434,6 +555,8 @@ if patch_qwen2_5:
                     position_embeddings=position_embeddings,
                     **kwargs,
                 )
+                if STOPAT is not None and layer_num > STOPAT:
+                    break

             hidden_states = self.merger(hidden_states)
             reverse_indices = torch.argsort(window_index)
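`STOPAT`, read once at import time earlier in this file, is a debugging aid: when set, the vision-layer loop breaks once `layer_num` exceeds it, truncating the model. A minimal sketch; the variable must be set before the patch module is imported:

    import os

    # must happen before onnx_diagnostic.torch_export_patches imports this module
    os.environ["STOPAT"] = "3"  # stop the vision-layer loop once layer_num exceeds 3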
@@ -473,9 +596,7 @@ if patch_qwen2_5:
         _PATCHED_CLASS_ = (
             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
         )
-        STRATEGY_FOR_ATTENTION = lambda: os.environ.get(  # noqa: E731
-            "QWEN25ATTENTION", "PACKED"
-        )
+        STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None)  # noqa: E731

         def forward(
             self,
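The default strategy changes from a hard-coded "PACKED" to None, so `qwen_version_selector` decides from dtype, opset, and device unless the environment overrides it. A sketch of forcing a strategy:

    import os

    # force one of the registered strategies: "PACKED", "LOOPMHA" or "LOOPA24"
    os.environ["QWEN25ATTENTION"] = "LOOPMHA"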
@@ -519,14 +640,15 @@ if patch_qwen2_5:
                 self.config._attn_implementation
             ]

-            is_sdpa = (
+            is_sdpa_or_eager = (
                 attention_interface
                 is transformers.integrations.sdpa_attention.sdpa_attention_forward
                 or attention_interface is patched_sdpa_attention_forward
+                or attention_interface
+                is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
             )
-
-
-            attn_output = qwen_sdpa_attention_packed_versatile(
+            if is_sdpa_or_eager:
+                attn_output = qwen_sdpa_attention_versatile(
                     query_states,
                     key_states,
                     value_states,
@@ -558,78 +680,10 @@ if patch_qwen2_5:
                     ),
                     version=1,
                 )
-            elif is_sdpa and attention_strategy == "LOOPMHA":
-                attn_output = qwen_sdpa_attention_loopmha_versatile(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens,
-                    self.scaling,
-                    self.num_heads,
-                    (
-                        onnx.TensorProto.FLOAT
-                        if query_states.dtype == torch.float32
-                        else (
-                            onnx.TensorProto.FLOAT16
-                            if query_states.dtype == torch.float16
-                            else onnx.TensorProto.BFLOAT16
-                        )
-                    ),
-                )
-
-                # to rewrite later with a for loop
-                # def _iteration(start_end, query_states, key_states, value_states):
-                #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
-                #         self,
-                #         start_end,
-                #         query_states,
-                #         key_states,
-                #         value_states,
-                #         scaling=self.scaling,
-                #         dropout=0.0 if not self.training else self.attention_dropout,
-                #     )
-
-                # starts = cu_seqlens[:-1]
-                # ends = cu_seqlens[1:]
-                # torch._check(starts.shape[0] > 0)
-                # torch._check(ends.shape[0] > 0)
-                # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-                # attn_outputs = [
-                #     _iteration(start_end, query_states, key_states, value_states)
-                #     for start_end in starts_ends
-                # ]
-                # attn_output = torch.cat(attn_outputs, dim=1)
-            elif is_sdpa and attention_strategy == "BIGMASK":
-                # make square mask
-                indices = torch.arange(
-                    cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
-                )
-                dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
-                    cu_seqlens.dtype
-                )
-                dot = dot.sum(dim=0)
-                mask = dot.unsqueeze(1) - dot.unsqueeze(0)
-                bool_mask = mask == 0
-                bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
-
-                torch._check(bool_mask.shape[2] == key_states.shape[2])
-                torch._check(bool_mask.shape[3] == key_states.shape[2])
-
-                attn_output, _ = attention_interface(
-                    self,
-                    query_states,
-                    key_states,
-                    value_states,
-                    attention_mask=bool_mask,
-                    scaling=self.scaling,
-                    dropout=0.0 if not self.training else self.attention_dropout,
-                    is_causal=False,
-                    **kwargs,
-                )
             else:
                 raise NotImplementedError(
-                    f"No corresponding export strategy for "
-                    f"{self.config._attn_implementation!r}, "
+                    f"No corresponding export strategy for implementation "
+                    f"{self.config._attn_implementation!r}, "
                     f"(use QWEN25ATTENTION to change it), and attention_interface="
                     f"{attention_interface!r} (use sdpa)"
                 )
@@ -653,6 +707,7 @@ if patch_qwen2_5:
                 )
             else:
                 # Other implementations: Process each chunk separately
+                # = qwen_sdpa_attention
                 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
                 splits = [
                     torch.split(tensor, lengths.tolist(), dim=2)
onnx_diagnostic/torch_models/code_sample.py

@@ -236,7 +236,7 @@ def code_sample(
         )
     )
     """
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
         model_kwargs=mop,
         subfolder=subfolder,
         add_second_input=False,
+        submodule=submodule,
     )
     if drop_inputs:
         update = {}
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:


 def _preprocess_model_id(
-    model_id: str,
-
+    model_id: str,
+    subfolder: Optional[str],
+    same_as_pretrained: bool,
+    use_pretrained: bool,
+    submodule: Optional[str] = None,
+) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
+    if "::" in model_id:
+        assert (
+            not submodule
+        ), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
+        model_id, submodule = model_id.split("::", maxsplit=1)
     if subfolder or "//" not in model_id:
-        return model_id, subfolder, same_as_pretrained, use_pretrained
+        return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
     spl = model_id.split("//")
     if spl[-1] == "pretrained":
-        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
-    return model_id, subfolder, same_as_pretrained, use_pretrained
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained, submodule
+    return model_id, subfolder, same_as_pretrained, use_pretrained, submodule


 def get_untrained_model_with_inputs(
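`_preprocess_model_id` now also extracts an optional submodule from a `model_id::submodule` suffix. A minimal sketch of the parsing, under the signature shown above (the model id is hypothetical):

    # "::" splits the model id from the submodule to extract
    _preprocess_model_id(
        "Qwen/Qwen2.5-VL-7B-Instruct::visual",  # hypothetical model id
        None,
        same_as_pretrained=False,
        use_pretrained=False,
    )
    # -> ("Qwen/Qwen2.5-VL-7B-Instruct", None, False, False, "visual")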
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
+    submodule: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
         <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
         this function takes a configuration and a task (string)
         as arguments
+    :param submodule: use a submodule instead of the main model
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
             f"model_id={model_id!r}, preinstalled model is only available "
             f"if use_only_preinstalled is False."
         )
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     if verbose:
         print(
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+        if submodule:
+            print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
     if task is None:
         task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
     if verbose:
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)

+    if submodule:
+        path = submodule.split("::") if "::" in submodule else [submodule]
+        for p in path:
+            assert hasattr(model, p), (
+                f"Unable to find submodule {p!r} in in class {type(model)}, "
+                f"submodule={submodule!r}, possible candidates: "
+                f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
+            )
+            model = getattr(model, p)
+
+        if verbose:
+            print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
+
     sizes = compute_model_size(model)
     res["model"] = model
     res["configuration"] = config
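With `submodule` set (directly or via the `::` suffix), the returned `res["model"]` is the traversed submodule rather than the full model, and a `::`-separated path walks nested attributes. A hedged usage sketch (model id hypothetical; the inputs are still generated for the top-level model's task before the swap):

    # extract the vision tower of a hypothetical checkpoint;
    # "model::visual" would walk model.model.visual instead
    data = get_untrained_model_with_inputs(
        "Qwen/Qwen2.5-VL-7B-Instruct::visual",
    )
    submodel = data["model"]  # the submodule, not the full model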
onnx_diagnostic/torch_models/validate.py

@@ -349,13 +349,15 @@ def _prepare_validation(
     verbose,
     output_names,
     dump_folder,
+    submodule,
 ):
     main_validation_begin = time.perf_counter()
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     time_preprocess_model_id = time.perf_counter() - main_validation_begin
     patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
@@ -364,6 +366,7 @@ def _prepare_validation(
     summary.update(
         dict(
             version_model_id=model_id,
+            version_submodule=submodule,
             version_do_run=str(do_run),
             version_dtype=str(dtype or ""),
             version_device=str(device or ""),
@@ -444,6 +447,7 @@ def _prepare_validation(
         dump_folder,
         folder_name,
         patch_kwargs,
+        submodule,
     )

@@ -460,6 +464,7 @@ def _get_untrained_model_with_inputs(
     inputs2,
     quiet,
     dump_folder,
+    submodule,
 ):
     iop = input_options or {}
     mop = model_options or {}
@@ -480,6 +485,7 @@ def _get_untrained_model_with_inputs(
                 model_kwargs=mop,
                 subfolder=sub,
                 add_second_input=i2,
+                submodule=submodule,
             )
         )
     ),
@@ -842,6 +848,7 @@ def validate_model(
     ort_logs: bool = False,
     quiet_input_sets: Optional[Set[str]] = None,
     save_ep: Optional[str] = None,
+    submodule: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -902,6 +909,7 @@ def validate_model(
         even if quiet is False
     :param save_ep: if not empty, this can be used to save the input sets and
         the exported program
+    :param submodule: to test not the model but a submodule of this model
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

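A hedged sketch of validating only a submodule; the `::` form of the model id sets the same parameter through `_preprocess_model_id` (model id hypothetical, and `do_run` assumed to be the usual flag of `validate_model`):

    # validate only the vision tower instead of the whole model
    summary, data = validate_model(
        "Qwen/Qwen2.5-VL-7B-Instruct",  # hypothetical model id
        do_run=True,
        submodule="visual",
    )
    # equivalent: validate_model("Qwen/Qwen2.5-VL-7B-Instruct::visual", do_run=True)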
@@ -966,6 +974,7 @@ def validate_model(
         use_pretrained=use_pretrained,
         same_as_pretrained=same_as_pretrained,
         save_ep=save_ep,
+        submodule=submodule,
     )
     if dump_folder:
         with open(dump_stats, "w") as f:
@@ -1053,6 +1062,7 @@ def _validate_model_step1(
     use_pretrained,
     same_as_pretrained,
     save_ep,
+    submodule,
 ):
     assert not do_same or do_run, (
         f"Discrepancies cannot be measured if the model is not run, "
@@ -1067,6 +1077,7 @@ def _validate_model_step1(
         dump_folder,
         folder_name,
         patch_kwargs,
+        submodule,
     ) = _prepare_validation(
         model_id=model_id,
         subfolder=subfolder,
@@ -1093,6 +1104,7 @@ def _validate_model_step1(
         verbose=verbose,
         output_names=output_names,
         dump_folder=dump_folder,
+        submodule=submodule,
     )

     data, iop, mop = _get_untrained_model_with_inputs(
|
|
|
1108
1120
|
inputs2=inputs2,
|
|
1109
1121
|
quiet=quiet,
|
|
1110
1122
|
dump_folder=dump_folder,
|
|
1123
|
+
submodule=submodule,
|
|
1111
1124
|
)
|
|
1112
1125
|
|
|
1113
1126
|
second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
|