onnx-diagnostic 0.7.14__py3-none-any.whl → 0.7.16__py3-none-any.whl
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +156 -47
- onnx_diagnostic/export/dynamic_shapes.py +6 -6
- onnx_diagnostic/export/shape_helper.py +124 -6
- onnx_diagnostic/ext_test_case.py +5 -1
- onnx_diagnostic/helpers/cache_helper.py +68 -42
- onnx_diagnostic/helpers/config_helper.py +2 -1
- onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
- onnx_diagnostic/helpers/helper.py +3 -0
- onnx_diagnostic/helpers/rt_helper.py +3 -3
- onnx_diagnostic/tasks/image_text_to_text.py +7 -6
- onnx_diagnostic/tasks/text_generation.py +7 -4
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +69 -11
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +109 -18
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +133 -28
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
- onnx_diagnostic/torch_models/validate.py +73 -29
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +25 -23
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -25,8 +25,8 @@ def retrieve_stacktrace():
 
 def _catch_produce_guards_and_solve_constraints(
     previous_function: Callable,
-    fake_mode:
-    gm:
+    fake_mode: FakeTensorMode,
+    gm: torch.fx.GraphModule,
     dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
     equalities_inputs: "EqualityConstraint",  # noqa: F821
     original_signature: inspect.Signature,
@@ -88,7 +88,7 @@ def patch__check_input_constraints_for_graph(
 
 def patched_infer_size(a, b):
     """Patches ``torch._subclasses.fake_impls.infer_size``."""
-    from torch.fx.experimental.symbolic_shapes import
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     dimsA = len(a)
     dimsB = len(b)
@@ -113,19 +113,19 @@ def patched_infer_size(a, b):
         # were not the case, we'd need to write this using torch.sym_or() or
         # something like that).
         try:
-            b1 =
+            b1 = guard_or_false(sizeA == 1)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b1 = False
         try:
-            b2 =
+            b2 = guard_or_false(sizeB == 1)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b2 = False
         try:
-            b3 =
+            b3 = guard_or_false(sizeA == sizeB)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b3 = False
         if b1 or b2 or b3:
-            expandedSizes[i] = sizeB if
+            expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
         else:
             # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
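The rewritten branch keeps export going when a size is symbolic: each equality is tested with guard_or_false, and when none of them can be decided the output dimension falls back to torch.sym_max instead of asserting. A minimal standalone sketch of that decision rule (the `known` callable stands in for guard_or_false, which returns False whenever the comparison cannot be decided without a guard):

import torch

def broadcast_dim(sizeA, sizeB, known=lambda cond: bool(cond)):
    # 'known' plays the role of guard_or_false: True only when the comparison
    # can be decided; with unbacked symbolic sizes it would return False.
    if known(sizeA == 1) or known(sizeB == 1) or known(sizeA == sizeB):
        return sizeB if known(sizeA == 1) else sizeA
    # undecidable at export time: keep a symbolic max instead of asserting
    return torch.sym_max(sizeA, sizeB)

assert broadcast_dim(1, 7) == 7
assert broadcast_dim(5, 5) == 5
# with two sizes that cannot be compared, the fallback keeps the max
assert broadcast_dim(3, 7, known=lambda cond: False) == 7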
@@ -137,7 +137,6 @@ def patched__broadcast_shapes(*_shapes):
     from functools import reduce
     from torch._prims_common import IntLike
     from torch.fx.experimental.symbolic_shapes import (
-        guard_size_oblivious,
         guard_or_false,
         is_nested_int,
     )
@@ -174,13 +173,15 @@ def patched__broadcast_shapes(*_shapes):
                 continue
             # PATCHED: two cases, if == for sure, no broadcast,
             # otherwise maybe broadcast with max(dimensions)
-            if
+            if guard_or_false(common_shape[idx] != 1):
+                pass
+            elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
                 if shape[idx] < 0:
                     raise ValueError(
                         "Attempting to broadcast a dimension with negative length!"
                     )
                 common_shape[idx] = shape[idx]
-
+            else:
                 common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
 
     return common_shape
@@ -360,6 +361,10 @@ class patched_ShapeEnv:
             },
         )
 
+        for source in self.var_to_sources.get(a, []):
+            if user_tb:
+                self.specialization_stacks[source] = user_tb
+
         # PATCHED: removed lines
         # if config.print_specializations:
         #     self.log.warning(
@@ -973,15 +978,101 @@ def patched__broadcast_in_dim_meta(
                     new_strides.append(a.stride()[original_idx])
                 else:
                     new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
             else:
-                #
-                #
-                #
-                #
-                #
-
-
-
+                # This checks generates the following issue:
+                # non-broadcasting semantics require s3 == Max(s10, s3), False,
+                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
+                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
+                # original_idx=1
+                torch._check(
+                    a.shape[original_idx] == shape[idx],
+                    lambda idx=idx, original_idx=original_idx: (
+                        f"non-broadcasting semantics require "
+                        f"{a.shape[original_idx]} == {shape[idx]}, "
+                        f"{guard_or_false(a.shape[idx] != 1)}, "
+                        f"guard_or_false(a.shape[idx]==1)="
+                        f"{guard_or_false(a.shape[idx] == 1)}, "
+                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
+                        f"shape={shape}, original_idx={original_idx}"
+                    ),
+                )
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+def patched__broadcast_in_dim_meta_level_2(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
+            else:
+                # PATCHED: torch._check was removed
                 new_strides.append(a.stride()[original_idx])
             original_idx = original_idx + 1
         else:
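For reference, what _broadcast_in_dim_meta ultimately produces is a view whose broadcast dimensions get stride 0. A small sketch with concrete shapes, independent of the patch, showing the stride computation the function performs:

import torch

a = torch.arange(6.0).reshape(2, 3, 1)          # shape (2, 3, 1)
target = (2, 3, 4)
# every input dimension is kept here, so broadcast_dimensions = (0, 1, 2)
new_strides = [
    a.stride(d) if a.shape[d] != 1 else 0        # stride 0 where the size was 1
    for d in range(a.ndim)
]
view = a.as_strided(target, new_strides, a.storage_offset())
torch.testing.assert_close(view, a.expand(target))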
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1019,6 +1019,26 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor
 
 
+def _get_rope_init_fn(self, layer_type=None) -> Callable:
+    if hasattr(self, "rope_init_fn"):
+        # transformers<=5.0
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+        return rope_init_fn
+
+    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
+    rope_init_fn = self.compute_default_rope_parameters
+    if rope_type != "default":
+        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
+    if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
+        return patched__compute_dynamic_ntk_parameters
+    return rope_init_fn
+
+
 def patched_dynamic_rope_update(rope_forward):
     """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 
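A rough sketch of the dictionary dispatch `_get_rope_init_fn` performs when `rope_init_fn` is no longer an attribute (newer transformers). It assumes ROPE_INIT_FUNCTIONS and the private _compute_dynamic_ntk_parameters are importable as in current transformers; the patched replacement is passed in explicitly so the snippet stays self-contained:

from transformers.modeling_rope_utils import (
    ROPE_INIT_FUNCTIONS,
    _compute_dynamic_ntk_parameters,
)

def pick_rope_init_fn(rope_type, patched_dynamic_ntk=None):
    # rope_type -> init function lookup, falling back to the default rule
    fn = ROPE_INIT_FUNCTIONS.get(rope_type, ROPE_INIT_FUNCTIONS["default"])
    if patched_dynamic_ntk is not None and fn is _compute_dynamic_ntk_parameters:
        # only the dynamic NTK variant gets swapped for the patched version
        return patched_dynamic_ntk
    return fn

print(pick_rope_init_fn("default").__name__)
print(pick_rope_init_fn("dynamic", patched_dynamic_ntk=lambda *a, **kw: None))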
@@ -1082,22 +1102,27 @@ def patched_dynamic_rope_update(rope_forward):
 
     """
 
-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
         # It is no use to patch the function after the model is created
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
            original_max_position_embeddings = self.config.max_position_embeddings
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # At export time, seq_len is unknown.
         long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
@@ -1112,13 +1137,13 @@ def patched_dynamic_rope_update(rope_forward):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self
+        setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
         #     self.inv_freq = self.long_inv_freq
         # else:
         #     self.inv_freq = self.original_inv_freq
 
-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
         # constructor:
         # - self.max_seq_len_cached = config.max_position_embeddings
         # - self.original_max_seq_len = config.max_position_embeddings
@@ -1128,12 +1153,7 @@ def patched_dynamic_rope_update(rope_forward):
         # as rope_init_fn is an attribute set to one function when the model
        # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
 
         # This behaviour is difficult to translate.
         # The sequence always grows.
@@ -1162,6 +1182,19 @@ def patched_dynamic_rope_update(rope_forward):
             self.config, device, seq_len=seq_len
         )
 
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # Second test to translate.
         # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
         # But in that case the following condition is a way to restore the original cache.
@@ -1183,15 +1216,26 @@ def patched_dynamic_rope_update(rope_forward):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self
+        setattr(self, f"{prefix}inv_freq", inv_freq)
 
     @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
+    def wrapper(self, x, position_ids, layer_type=None):
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
         if "dynamic" in self.rope_type:
-            dynamic_frequency_update(
+            dynamic_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
         elif self.rope_type == "longrope":
-            longrope_frequency_update(
-
+            longrope_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        return rope_forward(self, x, position_ids, layer_type=layer_type)
 
     return wrapper
 
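The wrapper only refreshes the inverse-frequency buffers before delegating, and the new layer_type argument is threaded through so the same decorator covers both the transformers<5.0 and transformers>=5.0 call signatures. A stripped-down sketch of that pattern with illustrative names (ToyRope and refresh are not the actual classes or helpers):

from functools import wraps

def rope_update(forward):
    @wraps(forward)
    def wrapper(self, x, position_ids, layer_type=None):
        # refresh state first, then call the original forward with the
        # same signature it was called with
        self.refresh(position_ids, layer_type=layer_type)
        if layer_type is None:          # legacy (transformers<5.0) signature
            return forward(self, x, position_ids)
        return forward(self, x, position_ids, layer_type=layer_type)
    return wrapper

class ToyRope:
    def refresh(self, position_ids, layer_type=None):
        self.last_refresh = (len(position_ids), layer_type)

    @rope_update
    def forward(self, x, position_ids, layer_type=None):
        return f"forward(layer_type={layer_type})"

r = ToyRope()
print(r.forward(None, [0, 1, 2]))                     # legacy call
print(r.forward(None, [0, 1, 2], layer_type="full"))  # per-layer-type call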
@@ -1232,6 +1276,60 @@ def common_eager_attention_forward(
     return attn_output, attn_weights
 
 
+def patched_sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """[patch:transformers.integrations.sdpa_attention.sdpa_attention_forward]"""
+    assert not kwargs.get("output_attentions", False), (
+        "`sdpa` attention does not support `output_attentions=True`."
+        " Please set your attention to `eager` if you want any of these features."
+    )
+    sdpa_kwargs = {}
+    if hasattr(module, "num_key_value_groups"):
+        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
+            key = transformers.integrations.sdpa_attention.repeat_kv(
+                key, module.num_key_value_groups
+            )
+            value = transformers.integrations.sdpa_attention.repeat_kv(
+                value, module.num_key_value_groups
+            )
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
+
+    if attention_mask is not None and attention_mask.ndim == 4:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+    # PATCHED: remove the test query.shape[2] > 1
+    # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+    is_causal = attention_mask is None and is_causal
+
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        "Attention mask shape incompatible with key shape.",
+    )
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=dropout,
+        scale=scaling,
+        is_causal=is_causal,
+        **sdpa_kwargs,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, None
+
+
 def patched_model_bart_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
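The patched forward ends in a plain call to torch.nn.functional.scaled_dot_product_attention. A minimal sketch of the grouped-query case it handles, repeating key/value heads by hand (repeat_interleave plays the role of transformers' repeat_kv here) instead of relying on enable_gqa:

import torch
import torch.nn.functional as F

B, H_q, H_kv, S, D = 1, 8, 2, 5, 16
query = torch.randn(B, H_q, S, D)
key = torch.randn(B, H_kv, S, D)
value = torch.randn(B, H_kv, S, D)

# grouped-query attention: repeat each kv head to match the query heads
groups = H_q // H_kv
key = key.repeat_interleave(groups, dim=1)
value = value.repeat_interleave(groups, dim=1)

out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
out = out.transpose(1, 2).contiguous()   # back to (B, S, H, D) as in the patch
print(out.shape)                          # torch.Size([1, 5, 8, 16])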
@@ -1287,12 +1385,18 @@ class common_RotaryEmbedding(torch.nn.Module):
     # @torch.no_grad()
     # PATCHED: the decorator
     @patched_dynamic_rope_update
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, layer_type=None):
+        if layer_type is not None:
+            # transformers>=5.0
+            inv_freq = getattr(self, f"{layer_type}_inv_freq")
+            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+        else:
+            # transformers<5.0
+            inv_freq = self.inv_freq
+            attention_scaling = self.attention_scaling
+
         inv_freq_expanded = (
-
-            .float()
-            .expand(position_ids.shape[0], -1, 1)
-            .to(x.device)
+            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         )
         position_ids_expanded = position_ids[:, None, :].float()
 
@@ -1304,8 +1408,8 @@ class common_RotaryEmbedding(torch.nn.Module):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() *
-            sin = emb.sin() *
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
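The forward above is the usual rotary-embedding table computation; a compact standalone sketch of the same math with made-up sizes:

import torch

head_dim, base = 16, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(6)[None, :]            # (batch=1, seq=6)
attention_scaling = 1.0

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, 6, head_dim//2)
emb = torch.cat((freqs, freqs), dim=-1)            # (1, 6, head_dim)
cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling
print(cos.shape, sin.shape)                         # torch.Size([1, 6, 16]) twice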
@@ -1380,7 +1484,8 @@ class patched_IdeficsEmbedding(torch.nn.Module):
 
         def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
             t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
-            freqs = torch.einsum("i,j->ij", t, inv_freq)
+            # freqs = torch.einsum("i,j->ij", t, inv_freq)
+            freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
             emb = torch.cat((freqs, freqs), dim=-1)
             return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
 
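The last change swaps torch.einsum for an explicit outer product; both forms compute the same frequency table, as a quick check confirms:

import torch

t = torch.arange(10, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 32, 2).float() / 32))

freqs_einsum = torch.einsum("i,j->ij", t, inv_freq)
freqs_outer = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
torch.testing.assert_close(freqs_einsum, freqs_outer)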