onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +91 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +3 -3
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +92 -23
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +90 -26
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +103 -1
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +103 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0

--- a/onnx_diagnostic/tasks/summarization.py
+++ b/onnx_diagnostic/tasks/summarization.py
@@ -1,23 +1,16 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import (
-    update_config,
-    check_hasattr,
-    _pick,
-    default_num_hidden_layers as nhl,
-)
+from ..helpers.config_helper import update_config, check_hasattr

 __TASK__ = "summarization"


 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    …
-    …
-    …
-    if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    check_hasattr(config, "vocab_size")
+    # Bart architecture does not like too much that the number of layers is changed.
+    kwargs = dict(vocab_size=2056)
     update_config(config, kwargs)
     return kwargs

@@ -25,96 +18,66 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
     dummy_max_token_id: int,
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    sequence_length2: int = 3,
+    past_length: int = 30,
+    past_length2: int = 4,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
     add_second_input: int = 1,
     **kwargs, # unused
 ):
     """
-    Generates …
-
-    :param model: model to get the missing information
-    :param config: configuration used to generate the model
-    :param head_dim_encoder: last dimension of the cache for the encoder
-    :param head_dim_decoder: last dimension of the cache for the decoder
-    :param num_key_value_heads_encoder: number of heads for the encoder
-    :param num_key_value_heads_decoder: number of heads for the decoder
-    :param dummy_max_token_id: dummy max token id
-    :param batch_size: batch size
-    :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
-    :return: dictionary
-
-    Stolen inputs for one model.
+    Generates inputs for task ``feature-extraction``.
+    Example:

     ::

-        …
-        …
-        …
-            key_cache=#6[T1s1x8x1x64,...],
-            value_cache=#6[T1s1x8x1x64,...]),
-        cross_attention_cache=DynamicCache(
-            key_cache=#6[T1s1x8x16x64,...],
-            value_cache=#6[T1s1x8x16x64,...])),
-        decoder_input_ids:T7s1x1,
-        encoder_outputs:dict(last_hidden_state:T1s1x16x512)
+        input_ids:T7s1x13[101,72654:A16789.23076923077],
+        token_type_ids:T7s1x13[0,0:A0.0],
+        attention_mask:T7s1x13[1,1:A1.0])
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = "batch"
-    seq_length = "…
-    cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096)
-    cache_length2 = "cache_length_val" # torch.export.Dim("cache_length2", min=1, max=4096)
-
+    seq_length = "sequence_length"
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
-        "…
-        "attention_mask": {0: batch, 1: "seq_mask"},
-        # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
-            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
-        ],
-        # one these is selected based on the forward method signature
-        # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
-        # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "attention_mask": {0: batch, 1: seq_length},
     }
-
     inputs = dict(
         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
            torch.int64
        ),
-        decoder_input_ids=torch.randint(
-            0, dummy_max_token_id, (batch_size, sequence_length2)
-        ).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
-        …
-        …
-        …
-        …
+    )
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
             make_dynamic_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
-                            …
-                            …
-                            …
+                            encoder_attention_heads,
+                            past_length,
+                            encoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
-                            …
-                            …
-                            …
+                            encoder_attention_heads,
+                            past_length,
+                            encoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
@@ -125,22 +88,28 @@ def get_inputs(
                     (
                         torch.randn(
                             batch_size,
-                            …
-                            …
-                            …
+                            decoder_attention_heads,
+                            past_length2,
+                            decoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
-                            …
-                            …
-                            …
+                            decoder_attention_heads,
+                            past_length2,
+                            decoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ]
             ),
-    )
-
+        )
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [ # type: ignore[assignment]
+            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
+        ]
+
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
@@ -149,15 +118,16 @@ def get_inputs(
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
-            dummy_max_token_id=dummy_max_token_id,
-            num_key_value_heads_encoder=num_key_value_heads_encoder,
-            num_key_value_heads_decoder=num_key_value_heads_decoder,
-            num_hidden_layers=num_hidden_layers,
-            head_dim_encoder=head_dim_encoder,
-            head_dim_decoder=head_dim_decoder,
             batch_size=batch_size + 1,
             sequence_length=sequence_length + add_second_input,
-            …
+            dummy_max_token_id=dummy_max_token_id,
+            past_length=past_length,
+            past_length2=past_length2,
+            decoder_attention_heads=decoder_attention_heads,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_ffn_dim=encoder_ffn_dim,
+            decoder_ffn_dim=decoder_ffn_dim,
+            num_hidden_layers=num_hidden_layers,
             add_second_input=0,
             **kwargs,
         )["inputs"]
@@ -171,57 +141,22 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
-    # exceptions = {
-    #     "PLBartForConditionalGeneration": (
-    #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
-    #     )
-    # }
+        check_hasattr(config, "vocab_size")
     kwargs = dict(
         batch_size=2,
-        sequence_length=…
-        …
-        …
-        …
-        ),
-        head_dim_decoder=(
-            16 if config is None else int(_pick(config, "decoder_ffn_dim") ** 0.5)
-        ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=(
-            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
-        ),
-        num_key_value_heads_encoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "encoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        num_key_value_heads_decoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "decoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
+        sequence_length=12,
+        past_length=30,
+        past_length2=4,
+        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
     return kwargs, get_inputs
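
The rewritten `get_inputs` drops the `head_dim_*` and `num_key_value_heads_*` arguments and instead takes the encoder/decoder attention-head and feed-forward sizes, which `random_input_kwargs` now copies from the configuration when they exist. A minimal usage sketch (the tiny `BartConfig` values are arbitrary illustrations, not taken from the package):

```python
# Sketch only: compose random_input_kwargs/get_inputs from the hunks above.
from transformers import BartConfig, BartForConditionalGeneration
from onnx_diagnostic.tasks.summarization import random_input_kwargs

config = BartConfig(
    vocab_size=2056, d_model=64,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=64, decoder_ffn_dim=64,
)
model = BartForConditionalGeneration(config).eval()

kwargs, fn = random_input_kwargs(config)  # fn is get_inputs
res = fn(model, config, **kwargs)
inputs, dynamic_shapes = res["inputs"], res["dynamic_shapes"]
# input_ids/attention_mask are always produced; past_key_values (an
# EncoderDecoderCache) and its dynamic shapes are only added when the
# attention-head and ffn sizes are available on the configuration.
```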
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py
@@ -0,0 +1,235 @@
+from typing import Optional
+import torch
+import transformers
+from .patch_helper import _has_transformers
+
+patch_sdpa_is_causal = _has_transformers("4.99")
+
+
+def common_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        # PATCHED
+        # The two following lines were added.
+        if attention_mask is not None and attention_mask.ndim == 4:
+            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+    attn_weights = torch.nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+def patched_sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """
+    manual patch for function
+    ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
+    """
+    assert not kwargs.get("output_attentions", False), (
+        "`sdpa` attention does not support `output_attentions=True`."
+        " Please set your attention to `eager` if you want any of these features."
+    )
+    torch._check(
+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+    torch._check(
+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+
+    sdpa_kwargs = {}
+    if hasattr(module, "num_key_value_groups"):
+        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
+            key = transformers.integrations.sdpa_attention.repeat_kv(
+                key, module.num_key_value_groups
+            )
+            value = transformers.integrations.sdpa_attention.repeat_kv(
+                value, module.num_key_value_groups
+            )
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
+
+    if attention_mask is not None and attention_mask.ndim == 4:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        lambda: "Attention mask shape incompatible with key shape.",
+    )
+
+    if patch_sdpa_is_causal:
+        # transformers>=4.55
+        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+
+        # PATCHED: remove the test query.shape[2] > 1
+        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+        # and we split the test to keep the minimum in torch.cond
+        is_causal = attention_mask is None and is_causal
+
+        if not is_causal:
+            torch._check(query.shape[0] > 0)
+            torch._check(query.shape[1] > 0)
+            torch._check(query.shape[2] > 0)
+            torch._check(query.shape[3] > 0)
+            torch._check(key.shape[0] > 0)
+            torch._check(key.shape[1] > 0)
+            torch._check(key.shape[2] > 0)
+            torch._check(key.shape[3] > 0)
+            torch._check(value.shape[0] > 0)
+            torch._check(value.shape[1] > 0)
+            torch._check(value.shape[2] > 0)
+            torch._check(value.shape[3] > 0)
+        return (
+            torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=is_causal,
+                **sdpa_kwargs,
+            )
+            .transpose(1, 2)
+            .contiguous(),
+            None,
+        )
+    else:
+        # transformers<4.55
+        if is_causal is None and attention_mask is not None:
+            is_causal = False
+        if is_causal is not None:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+
+        # To avoid the following errors:
+        # is_causal=query.shape[2] > 1
+        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
+        # is_causal=torch.tensor(query.shape[2] > 1)
+        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
+        attn_output = torch.cond(
+            query.shape[2] > 1,  # distinction between prefill and decoding steps
+            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=True,
+                **sdpa_kwargs,
+            ).contiguous(),
+            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=False,
+                **sdpa_kwargs,
+            ).contiguous(),
+            [query, key, value],
+        )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        return attn_output, None
+
+
+def patched_model_bart_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
+def patched_modeling_marian_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
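
The fallback branch of `patched_sdpa_attention_forward` exists because, under `torch.export`, `query.shape[2] > 1` is a `SymBool` that `scaled_dot_product_attention` rejects for `is_causal`; `torch.cond` keeps both the causal (prefill) and non-causal (decoding) branches in the graph instead. A standalone sketch of the same trick (the function name and shapes below are illustrative, not part of the package):

```python
import torch

def sdpa_prefill_or_decode(query, key, value):
    # Select the branch on the number of query tokens instead of passing a
    # SymBool to is_causal; both branches stay in the exported graph.
    return torch.cond(
        query.shape[2] > 1,  # prefill vs. single-token decoding step
        lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True
        ).contiguous(),
        lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=False
        ).contiguous(),
        [query, key, value],
    )

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
prefill = sdpa_prefill_or_decode(q, k, v)             # takes the is_causal=True branch
decoding = sdpa_prefill_or_decode(q[:, :, :1], k, v)  # takes the is_causal=False branch
```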
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py
@@ -0,0 +1,50 @@
+from typing import Optional
+import inspect
+import transformers
+
+try:
+    from transformers.cache_utils import parse_processor_args  # noqa: F401
+
+    patch_parse_processor_args = True
+except ImportError:
+    patch_parse_processor_args = False
+
+
+if patch_parse_processor_args:
+
+    def _init_cache_inspect():
+        res = {}
+        for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
+            try:
+                params = list(inspect.signature(processor_class.__init__).parameters)[2:]
+                res[processor_class.__init__] = params
+            except Exception:
+                res[processor_class.__init__] = None
+        return res
+
+    _cache_inspect = _init_cache_inspect()
+
+    def patched_parse_processor_args(
+        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
+    ) -> tuple[dict, dict]:
+        """[patch:transformers.cache_utils.parse_processor_args]"""
+        # If not patched...
+        # Fails with transformers>=4.54 because function ``parse_processor_args``
+        # relies in inspect and the exporter is not very fond of that.
+        # torch._dynamo.exc.Unsupported: id() with unsupported args
+        # Explanation: Dynamo doesn't know how to trace id()
+        # call with args
+        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+        # objects from outside the compiled region.
+        # Hint: It may be possible to write Dynamo tracing rules for this code.
+        #
+        # The patch is caching the signature to avoid any call to inspect.
+        if processor_class is None:
+            return {}, kwargs
+        params = _cache_inspect[processor_class.__init__]
+        if params is None:
+            return {}, kwargs
+        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
+        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
+        return processor_kwargs, remaining_kwargs
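
These `_patch_transformers_*` modules are not meant to be imported directly; they back the patches the package applies around an export call. A hedged sketch of how they are normally activated (the `torch_export_patches` context manager and its `patch_transformers` flag come from the package's public API, not from this diff; the tiny model is illustrative only and may need extra export options):

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration
from onnx_diagnostic.torch_export_patches import torch_export_patches

config = BartConfig(
    vocab_size=1024, d_model=32,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config).eval()
inputs = dict(
    input_ids=torch.randint(0, 1024, (1, 8)),
    attention_mask=torch.ones(1, 8, dtype=torch.int64),
    decoder_input_ids=torch.randint(0, 1024, (1, 4)),
)

# The transformers patches (including the modules added in this release)
# are only active inside the context manager.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(model, (), kwargs=inputs)
```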
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py
@@ -0,0 +1,89 @@
+from dataclasses import dataclass
+from typing import Optional
+import torch
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from .patch_helper import _has_transformers
+
+
+def _patch_make_causal_mask(
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+    sliding_window: Optional[int] = None,
+):
+    """Patched method."""
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat(
+            [
+                torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
+                mask,
+            ],
+            dim=-1,
+        )
+
+    if sliding_window is not None:
+        diagonal = past_key_values_length - sliding_window - 1
+
+        context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
+        # and used masked_fill instead of masked_fill_
+        # In this case, the current implementation of torch fails (17/12/2024).
+        # Try model Phi-3.5-Mini-Instruct.
+        mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
+
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+@dataclass
+class patched_AttentionMaskConverter:
+    """
+    Patches
+    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+    """
+
+    # This method was fixed in 4.51 at least.
+    _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
+    _PATCHED_CLASS_ = AttentionMaskConverter
+
+    @staticmethod
+    def _make_causal_mask(
+        *args,
+        **kwargs,
+        # input_ids_shape: torch.Size,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # past_key_values_length: int = 0,
+        # sliding_window: Optional[int] = None,
+    ):
+        """
+        Patched method.
+
+        This static method may be called with ``AttentionMaskConverter._make_causal_mask``
+        or ``self._make_causal_mask``. That changes this argument is receives.
+        That should not matter but...
+        The patch should be implemented in another way. static methods do not play well
+        with a simple replacement.
+        Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
+        """
+        if args:
+            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
+            names = [
+                "input_ids_shape",
+                "dtype",
+                "device",
+                "past_key_values_length",
+                "sliding_window",
+            ]
+            for i, a in enumerate(args):
+                if i < index:
+                    continue
+                kwargs[names[i - index]] = a
+        return _patch_make_causal_mask(**kwargs)
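
A quick check of what `_patch_make_causal_mask` returns, based only on the code above (the import path is the new module added by this release):

```python
import torch
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_causal_mask import (
    _patch_make_causal_mask,
)

# 3 query tokens attending to 2 cached tokens plus themselves (causally).
mask = _patch_make_causal_mask(
    torch.Size([1, 3]), torch.float32, torch.device("cpu"), past_key_values_length=2
)
print(mask.shape)  # torch.Size([1, 1, 3, 5]) = batch x 1 x tgt_len x (past + tgt_len)
print((mask == 0).to(torch.int64)[0, 0])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])  1 = visible position, 0 = masked with finfo.min
```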