onnx-diagnostic 0.7.15__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their registry. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +154 -52
- onnx_diagnostic/export/dynamic_shapes.py +6 -6
- onnx_diagnostic/export/shape_helper.py +124 -6
- onnx_diagnostic/ext_test_case.py +5 -1
- onnx_diagnostic/helpers/cache_helper.py +67 -41
- onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
- onnx_diagnostic/helpers/helper.py +3 -0
- onnx_diagnostic/tasks/image_text_to_text.py +1 -1
- onnx_diagnostic/tasks/text_generation.py +1 -4
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +68 -10
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +84 -5
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +54 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
- onnx_diagnostic/torch_models/validate.py +39 -20
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +21 -19
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0

@@ -168,7 +168,33 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             ]
         )
         print(string_type(past_key_values, with_shape=True))
+
+    The function is fully able to handle ``FakeTensor`` with dynamic dimensions if
+    ``transformers>=4.56``. Before that version, only FakeTensor with static dimensions
+    are supported.
     """
+    if (
+        key_value_pairs
+        and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
+        and pv.Version(transformers.__version__) >= pv.Version("4.56")
+    ):
+        cache = transformers.cache_utils.DynamicCache()
+        cache.layers.extend(
+            [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
+        )
+        for i, layer in enumerate(cache.layers):
+            k, v = key_value_pairs[i][0], key_value_pairs[i][1]
+            layer.dtype = k.dtype
+            layer.device = k.device
+            layer.keys = k
+            layer.values = v
+            layer.is_initialized = True
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return finalize_cache(cache)
+
     cache = transformers.cache_utils.DynamicCache(key_value_pairs)
     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
         # The cache constructor contains the two following lines
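
The new branch above builds a `DynamicCache` directly from `FakeTensor` pairs. A minimal sketch of how it can be exercised, assuming `make_dynamic_cache` is imported from `onnx_diagnostic.helpers.cache_helper` (as in the example shipped in the new `fake_tensor_helper` module below) and `transformers>=4.56` as the docstring requires; shapes are illustrative only:

```python
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses.fake_tensor import FakeTensorMode
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

# Build FakeTensors with dynamic shapes, then feed them as (key, value) pairs.
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
k = fake_mode.from_tensor(torch.rand((2, 8, 16, 32)), static_shapes=False)
v = fake_mode.from_tensor(torch.rand((2, 8, 16, 32)), static_shapes=False)

cache = make_dynamic_cache([(k, v)])
print(type(cache))  # expected: transformers.cache_utils.DynamicCache
```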

@@ -494,51 +520,51 @@ def make_hybrid_cache(
 
     .. code-block:: python
 
-
-
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
 
-
-
-
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size
 
-
-
-
-
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
 
-
-
-
-
-
-
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
     """
     layer_types = None
     if key_value_pairs:
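
For reference, a hedged usage sketch of `make_hybrid_cache`. Its exact signature is not shown in this hunk; the call below assumes it mirrors `make_dynamic_cache` (a list of `(key, value)` pairs), which the surrounding context (`if key_value_pairs:`) suggests. Shapes are illustrative only.

```python
import torch
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

# Assumption: make_hybrid_cache accepts a list of (key, value) tensor pairs,
# one pair per layer, like make_dynamic_cache.
past_key_values = make_hybrid_cache(
    [
        (torch.rand((2, 8, 16, 32)), torch.rand((2, 8, 16, 32))),
        (torch.rand((2, 8, 16, 32)), torch.rand((2, 8, 16, 32))),
    ]
)
print(type(past_key_values))
```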

@@ -0,0 +1,153 @@
+from typing import Any, Dict, Optional, Tuple
+
+
+_UNIQUE = set()
+
+
+def _unique():
+    i = 129 + 1
+    while i in _UNIQUE:
+        i += 1
+    _UNIQUE.add(i)
+    return i
+
+
+def fake_reshape(
+    true_tensor: "torch.Tensor",  # noqa: F821
+    sh: Dict[int, Any],  # noqa: F821
+    fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
+    fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
+) -> "FakeTensor":  # noqa: F821
+    """
+    Changes the shape of a true tensor to make it dynamic.
+
+    :param true_tensor: true tensor
+    :param sh: dynamic shape
+    :param fake_tensor: fake tensor, if None, make a fake one
+    :param fake_mode: fake tensor mode
+    :return: fake tensor
+    """
+    import torch
+
+    # deal with 0/1
+    for i in sh:
+        if true_tensor.shape[i] <= 1:
+            expanded_shape = list(true_tensor.shape)
+            expanded_shape[i] = _unique()
+            true_tensor = torch.empty(
+                tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
+            )
+
+    # deal with equivalent dimension
+    new_shape = list(true_tensor.shape)
+    mapping = {}
+    for i, s in sh.items():
+        d = true_tensor.shape[i]
+        if d not in mapping:
+            mapping[d] = s
+        elif mapping[d] != s:
+            d = _unique()
+            mapping[d] = s
+        new_shape[i] = d
+    true_tensor = torch.empty(
+        tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
+    )
+
+    # now switch to FakeTensor
+    if fake_mode is None:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+    if fake_tensor is None:
+        fake_tensor = fake_mode.from_tensor(true_tensor, static_shapes=False)
+    assert fake_mode is not None, "fake_mode must be provided"
+
+    new_shape = list(true_tensor.shape)
+    for i in sh:
+        new_shape[i] = fake_tensor.shape[i]
+
+    reduced_tensor = fake_mode.from_tensor(true_tensor, static_shapes=True).sum(
+        axis=tuple(sorted(sh)), keepdim=True
+    )
+    return reduced_tensor.expand(*new_shape)
+
+
+def make_fake(
+    x: Any, fake_mode: Optional["FakeTensorMode"] = None  # noqa: F821
+) -> Tuple[Optional["FakeTensor"], Optional["FakeTensorMode"]]:  # noqa: F821
+    """
+    Replaces all tensors by fake tensors.
+    This modification happens inplace for caches.
+    This function is only implemented for cache with
+    ``transformers>=4.55``.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
+
+        inputs, _ = make_fake(
+            dict(
+                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                past_key_values=make_dynamic_cache(
+                    [
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                    ]
+                ),
+            )
+        )
+        pprint.pprint(inputs)
+    """
+    if x is None:
+        return None, None
+    if fake_mode is None:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+
+    if isinstance(x, (list, tuple)):
+        return x.__class__([make_fake(i, fake_mode=fake_mode)[0] for i in x]), fake_mode
+    if isinstance(x, dict):
+        return {k: make_fake(v, fake_mode=fake_mode)[0] for k, v in x.items()}, fake_mode
+
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+        assert hasattr(x, "layers"), (
+            f"Une more recent version of transformers (>=4.55), "
+            f"'layers' not found in class {type(x)}"
+        )
+        for layer in x.layers:
+            assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                f"Une more recent version of transformers (>=4.55), 'layers' "
+                f"not found in class {type(layer)} ({dir(layer)})"
+            )
+            layer.keys = make_fake(layer.keys, fake_mode=fake_mode)[0]
+            layer.values = make_fake(layer.values, fake_mode=fake_mode)[0]
+        return x, fake_mode
+    if x.__class__.__name__ == "EncoderDecoderCache":
+        make_fake(x.self_attention_cache, fake_mode=fake_mode)
+        make_fake(x.cross_attention_cache, fake_mode=fake_mode)
+        return x, fake_mode
+    if hasattr(x, "shape"):
+        t = fake_mode.from_tensor(x, static_shapes=False)
+        return t, fake_mode
+    from . import string_type
+
+    raise TypeError(
+        f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+    )
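
The new `make_fake` already ships with the `runpython` example above. The other new helper, `fake_reshape`, can be sketched the same way: it maps selected dimensions of a real tensor to symbolic dimensions of a `FakeTensor`. The dimension names `"batch"` and `"seq"` below are illustrative values for the `sh` mapping, not a documented convention.

```python
import torch
from onnx_diagnostic.helpers.fake_tensor_helper import fake_reshape

t = torch.rand((2, 16, 8))
# Dimensions 0 and 1 become dynamic; dimension 2 stays static.
fake = fake_reshape(t, {0: "batch", 1: "seq"})
print(fake.shape)  # a FakeTensor shape with symbolic sizes on dims 0 and 1
```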

@@ -463,6 +463,7 @@ def string_type(
         if verbose:
             print(f"[string_type] F2:{type(obj)}")
         return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
+
     if isinstance(obj, torch.Tensor):
         from .torch_helper import torch_dtype_to_onnx_dtype
 

@@ -783,6 +784,8 @@ def string_type(
             obj, ultralytics.engine.results.Results
         ), f"Unexpected type={type(obj)}"
         return f"ultralytics.{obj.__class__.__name__}(...)"
+    if obj.__class__.__name__ == "FakeTensorMode":
+        return f"{obj}"
 
     if verbose:
         print(f"[string_type] END:{type(obj)}")

@@ -271,7 +271,7 @@ def get_inputs_default(
             "input_ids": {0: batch, 1: seq_length},
             "token_type_ids": {0: batch, 1: seq_length},
             "attention_mask": {0: batch, 1: "cache+seq"},
-            "position_ids": {0: batch, 1:
+            "position_ids": {0: batch, 1: seq_length},
             "past_key_values": [
                 [{0: batch} for _ in range(num_hidden_layers)],
                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

@@ -220,10 +220,7 @@ def get_inputs(
                 0: batch,
                 1: "cache+seq",  # cache_length + seq_length
             },
-            "position_ids": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
+            "position_ids": {0: batch, 1: seq_length},
             "past_key_values": [
                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

@@ -2,7 +2,7 @@ import functools
 import importlib
 import contextlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,

@@ -160,7 +160,7 @@ def register_additional_serialization_functions(
 @contextlib.contextmanager
 def torch_export_patches(
     patch_sympy: bool = True,
-    patch_torch: bool = True,
+    patch_torch: Union[bool, int] = True,
     patch_transformers: bool = False,
     patch_diffusers: bool = False,
     catch_constraints: bool = True,

@@ -349,6 +349,7 @@
         _catch_produce_guards_and_solve_constraints,
         patch__check_input_constraints_for_graph,
         patched__broadcast_in_dim_meta,
+        patched__broadcast_in_dim_meta_level_2,
         patched__maybe_broadcast,
         patched_ShapeEnv,
     )

@@ -390,8 +391,13 @@
         # torch._prims._broadcast_in_dim_meta
         f_broadcast_in_dim = torch._prims.broadcast_in_dim
         f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-
-
+        _patched_dim_f = (
+            patched__broadcast_in_dim_meta_level_2
+            if patch_torch == 2
+            else patched__broadcast_in_dim_meta
+        )
+        torch._prims._broadcast_in_dim_meta = _patched_dim_f
+        torch._prims.broadcast_in_dim = _patched_dim_f
 
         # torch._refs._maybe_broadcast
         f__maybe_broadcast = torch._refs._maybe_broadcast
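
`patch_torch` now accepts an integer: passing `2` selects `patched__broadcast_in_dim_meta_level_2` (added further below) instead of the default broadcast patch. A minimal sketch, assuming `torch_export_patches` is re-exported from `onnx_diagnostic.torch_export_patches` as in the package documentation; the toy model and dynamic shape are placeholders, and whether the level-2 patch is actually needed depends on the model being exported:

```python
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Tiny(torch.nn.Module):
    def forward(self, x):
        # a simple broadcasting pattern, the kind the patched prims above target
        return x.unsqueeze(1) + x

x = torch.rand((4, 3))
with torch_export_patches(patch_torch=2, verbose=1):
    ep = torch.export.export(
        Tiny(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
    )
print(ep.graph)
```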

@@ -453,6 +459,16 @@
     except ImportError:
         masking_utils = None
 
+    try:
+        import transformers.integrations.sdpa_attention as sdpa_attention
+    except ImportError:
+        sdpa_attention = None
+
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
     if verbose:
         import transformers
 

@@ -464,7 +480,7 @@
             patch_transformers_list, verbose=verbose
         )
 
-        if (
+        if (  # vmap
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "_vmap_for_bhqkv")

@@ -499,7 +515,7 @@
         else:
             f_transformers_sdpa_mask = None
 
-        if (
+        if (  # eager_mask
            masking_utils
            and patch_transformers_list.patch_masking_utils
            and hasattr(masking_utils, "eager_mask")

@@ -526,7 +542,7 @@
                 patch_transformers_list.patched_eager_mask
             )
 
-        if (
+        if (  # sdpa_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "sdpa_mask")

@@ -547,6 +563,29 @@
                 patch_transformers_list.patched_sdpa_mask_recent_torch
             )
 
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+            and hasattr(modeling_utils, "AttentionInterface")
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+                )
+            f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+            sdpa_attention.sdpa_attention_forward = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+            modeling_utils.sdpa_attention_forward = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+
         if custom_patches:
             if verbose:
                 print("[torch_export_patches] applies custom patches")

@@ -656,7 +695,7 @@
             patch_transformers_list, revert_patches_info, verbose=verbose
         )
 
-        if (
+        if (  # vmap
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "_vmap_for_bhqkv")

@@ -687,7 +726,7 @@
                     "transformers.masking_utils.sdpa_mask"
                 )
 
-        if (
+        if (  # eager_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "eager_mask")

@@ -714,7 +753,7 @@
                     "in ALL_MASK_ATTENTION_FUNCTIONS"
                 )
 
-        if (
+        if (  # sdpa_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "sdpa_mask")

@@ -734,6 +773,25 @@
                     "in ALL_MASK_ATTENTION_FUNCTIONS"
                 )
 
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+            and hasattr(modeling_utils, "AttentionInterface")
+        ):
+            sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+                f_sdpa_attention_forward
+            )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.integrations.sdpa_attention."
+                    "sdpa_attention_forward"
+                )
+
         ########
         # caches
         ########
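
The patch/restore pair above swaps `sdpa_attention_forward` in and out of transformers' attention registry. A hedged sketch of what that looks like from the caller's side (requires a transformers version that exposes `AttentionInterface`, and assumes the usual `torch_export_patches` re-export):

```python
import transformers.modeling_utils as modeling_utils
from onnx_diagnostic.torch_export_patches import torch_export_patches

original = modeling_utils.AttentionInterface._global_mapping["sdpa"]
with torch_export_patches(patch_transformers=True):
    inside = modeling_utils.AttentionInterface._global_mapping["sdpa"]
    print(inside is original)  # False: the patched forward is registered
# On exit the original mapping is restored, as the block above shows.
print(modeling_utils.AttentionInterface._global_mapping["sdpa"] is original)  # True
```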

@@ -25,8 +25,8 @@ def retrieve_stacktrace():
 
 def _catch_produce_guards_and_solve_constraints(
     previous_function: Callable,
-    fake_mode:
-    gm:
+    fake_mode: FakeTensorMode,
+    gm: torch.fx.GraphModule,
     dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
     equalities_inputs: "EqualityConstraint",  # noqa: F821
     original_signature: inspect.Signature,

@@ -982,16 +982,21 @@ def patched__broadcast_in_dim_meta(
             elif guard_or_false(a.shape[original_idx] != 1):
                 new_strides.append(a.stride()[original_idx])
             else:
+                # This checks generates the following issue:
+                # non-broadcasting semantics require s3 == Max(s10, s3), False,
+                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
+                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
+                # original_idx=1
                 torch._check(
                     a.shape[original_idx] == shape[idx],
                     lambda idx=idx, original_idx=original_idx: (
                         f"non-broadcasting semantics require "
                         f"{a.shape[original_idx]} == {shape[idx]}, "
                         f"{guard_or_false(a.shape[idx] != 1)}, "
-                        f"guard_or_false(a.shape[idx]
+                        f"guard_or_false(a.shape[idx]==1)="
                         f"{guard_or_false(a.shape[idx] == 1)}, "
-                        f"a.stride()={a.stride()}, idx={idx}, "
-                        f"original_idx={original_idx}"
+                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
+                        f"shape={shape}, original_idx={original_idx}"
                     ),
                 )
                 new_strides.append(a.stride()[original_idx])

@@ -1006,3 +1011,77 @@
                 new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
 
     return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+def patched__broadcast_in_dim_meta_level_2(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
+            else:
+                # PATCHED: torch._check was removed
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())

@@ -1276,6 +1276,60 @@ def common_eager_attention_forward(
     return attn_output, attn_weights
 
 
+def patched_sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """[patch:transformers.integrations.sdpa_attention.sdpa_attention_forward]"""
+    assert not kwargs.get("output_attentions", False), (
+        "`sdpa` attention does not support `output_attentions=True`."
+        " Please set your attention to `eager` if you want any of these features."
+    )
+    sdpa_kwargs = {}
+    if hasattr(module, "num_key_value_groups"):
+        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
+            key = transformers.integrations.sdpa_attention.repeat_kv(
+                key, module.num_key_value_groups
+            )
+            value = transformers.integrations.sdpa_attention.repeat_kv(
+                value, module.num_key_value_groups
+            )
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
+
+    if attention_mask is not None and attention_mask.ndim == 4:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+    # PATCHED: remove the test query.shape[2] > 1
+    # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+    is_causal = attention_mask is None and is_causal
+
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        "Attention mask shape incompatible with key shape.",
+    )
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=dropout,
+        scale=scaling,
+        is_causal=is_causal,
+        **sdpa_kwargs,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, None
+
+
 def patched_model_bart_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,