onnx-diagnostic 0.7.14__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +156 -47
- onnx_diagnostic/export/dynamic_shapes.py +6 -6
- onnx_diagnostic/export/shape_helper.py +124 -6
- onnx_diagnostic/ext_test_case.py +5 -1
- onnx_diagnostic/helpers/cache_helper.py +68 -42
- onnx_diagnostic/helpers/config_helper.py +2 -1
- onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
- onnx_diagnostic/helpers/helper.py +3 -0
- onnx_diagnostic/helpers/rt_helper.py +3 -3
- onnx_diagnostic/tasks/image_text_to_text.py +7 -6
- onnx_diagnostic/tasks/text_generation.py +7 -4
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +69 -11
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +109 -18
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +133 -28
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
- onnx_diagnostic/torch_models/validate.py +73 -29
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +25 -23
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
@@ -108,7 +108,7 @@ def flatten_unflatten_for_dynamic_shapes(

 def is_cache_dynamic_registered(fast: bool = False) -> bool:
     """
-    Tells class :class:`transformers.cache_utils.DynamicCache` can be
+    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
     serialized and deserialized. Only then, :func:`torch.export.export`
     can export a model.

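A minimal usage sketch of the function whose docstring is fixed above (assuming onnx_diagnostic and transformers are installed; the fast flag comes from the signature shown in the hunk):

    from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

    # True once DynamicCache can be flattened/unflattened by torch.export
    print(is_cache_dynamic_registered(fast=True))
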
@@ -168,7 +168,33 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             ]
         )
         print(string_type(past_key_values, with_shape=True))
+
+    The function is fully able to handle ``FakeTensor`` with dynamic dimensions if
+    ``transformers>=4.56``. Before that version, only FakeTensors with static dimensions
+    are supported.
     """
+    if (
+        key_value_pairs
+        and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
+        and pv.Version(transformers.__version__) >= pv.Version("4.56")
+    ):
+        cache = transformers.cache_utils.DynamicCache()
+        cache.layers.extend(
+            [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
+        )
+        for i, layer in enumerate(cache.layers):
+            k, v = key_value_pairs[i][0], key_value_pairs[i][1]
+            layer.dtype = k.dtype
+            layer.device = k.device
+            layer.keys = k
+            layer.values = v
+            layer.is_initialized = True
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return finalize_cache(cache)
+
     cache = transformers.cache_utils.DynamicCache(key_value_pairs)
     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
         # The cache constructor contains the two following lines

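A hedged sketch of the new FakeTensor branch added above; it assumes transformers>=4.56 and builds the fake tensors explicitly with FakeTensorMode (shapes are illustrative):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    fake_mode = FakeTensorMode()
    k = fake_mode.from_tensor(torch.rand((2, 8, 16, 64), dtype=torch.float16))
    v = fake_mode.from_tensor(torch.rand((2, 8, 16, 64), dtype=torch.float16))

    # With FakeTensor inputs, make_dynamic_cache now fills DynamicLayer objects
    # directly instead of going through the DynamicCache constructor.
    cache = make_dynamic_cache([(k, v)])
    print(type(cache).__name__, len(cache.layers))
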
@@ -494,51 +520,51 @@ def make_hybrid_cache(

     .. code-block:: python

-
-
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)

-
-
-
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size

-
-
-
-
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )

-
-
-
-
-
-
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
     """
     layer_types = None
     if key_value_pairs:

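For reference, a hedged call sketch for make_hybrid_cache itself; it assumes the function accepts a list of (key, value) pairs like make_dynamic_cache does (shapes are illustrative):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    past_key_values = make_hybrid_cache(
        [(torch.rand((2, 4, 30, 96)), torch.rand((2, 4, 30, 96))) for _ in range(2)]
    )
    print(type(past_key_values).__name__)
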
@@ -95,7 +95,8 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
     mod_name = cls.__module__
     mod = importlib.import_module(mod_name)
     source = inspect.getsource(mod)
-
+    # [^O] avoids capturing Optional[Something]
+    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
     fall = reg.findall(source)
     if len(fall) == 0:
         assert not exc, (

@@ -0,0 +1,153 @@
+from typing import Any, Dict, Optional, Tuple
+
+
+_UNIQUE = set()
+
+
+def _unique():
+    i = 129 + 1
+    while i in _UNIQUE:
+        i += 1
+    _UNIQUE.add(i)
+    return i
+
+
+def fake_reshape(
+    true_tensor: "torch.Tensor",  # noqa: F821
+    sh: Dict[int, Any],  # noqa: F821
+    fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
+    fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
+) -> "FakeTensor":  # noqa: F821
+    """
+    Changes the shape of a true tensor to make it dynamic.
+
+    :param true_tensor: true tensor
+    :param sh: dynamic shape
+    :param fake_tensor: fake tensor, if None, make a fake one
+    :param fake_mode: fake tensor mode
+    :return: fake tensor
+    """
+    import torch
+
+    # deal with 0/1
+    for i in sh:
+        if true_tensor.shape[i] <= 1:
+            expanded_shape = list(true_tensor.shape)
+            expanded_shape[i] = _unique()
+            true_tensor = torch.empty(
+                tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
+            )
+
+    # deal with equivalent dimension
+    new_shape = list(true_tensor.shape)
+    mapping = {}
+    for i, s in sh.items():
+        d = true_tensor.shape[i]
+        if d not in mapping:
+            mapping[d] = s
+        elif mapping[d] != s:
+            d = _unique()
+            mapping[d] = s
+        new_shape[i] = d
+    true_tensor = torch.empty(
+        tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
+    )
+
+    # now switch to FakeTensor
+    if fake_mode is None:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+    if fake_tensor is None:
+        fake_tensor = fake_mode.from_tensor(true_tensor, static_shapes=False)
+    assert fake_mode is not None, "fake_mode must be provided"
+
+    new_shape = list(true_tensor.shape)
+    for i in sh:
+        new_shape[i] = fake_tensor.shape[i]
+
+    reduced_tensor = fake_mode.from_tensor(true_tensor, static_shapes=True).sum(
+        axis=tuple(sorted(sh)), keepdim=True
+    )
+    return reduced_tensor.expand(*new_shape)
+
+
+def make_fake(
+    x: Any, fake_mode: Optional["FakeTensorMode"] = None  # noqa: F821
+) -> Tuple[Optional["FakeTensor"], Optional["FakeTensorMode"]]:  # noqa: F821
+    """
+    Replaces all tensors by fake tensors.
+    This modification happens inplace for caches.
+    This function is only implemented for caches with
+    ``transformers>=4.55``.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
+
+        inputs, _ = make_fake(
+            dict(
+                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                past_key_values=make_dynamic_cache(
+                    [
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                    ]
+                ),
+            )
+        )
+        pprint.pprint(inputs)
+    """
+    if x is None:
+        return None, None
+    if fake_mode is None:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+
+    if isinstance(x, (list, tuple)):
+        return x.__class__([make_fake(i, fake_mode=fake_mode)[0] for i in x]), fake_mode
+    if isinstance(x, dict):
+        return {k: make_fake(v, fake_mode=fake_mode)[0] for k, v in x.items()}, fake_mode
+
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+        assert hasattr(x, "layers"), (
+            f"Use a more recent version of transformers (>=4.55), "
+            f"'layers' not found in class {type(x)}"
+        )
+        for layer in x.layers:
+            assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                f"Use a more recent version of transformers (>=4.55), 'layers' "
+                f"not found in class {type(layer)} ({dir(layer)})"
+            )
+            layer.keys = make_fake(layer.keys, fake_mode=fake_mode)[0]
+            layer.values = make_fake(layer.values, fake_mode=fake_mode)[0]
+        return x, fake_mode
+    if x.__class__.__name__ == "EncoderDecoderCache":
+        make_fake(x.self_attention_cache, fake_mode=fake_mode)
+        make_fake(x.cross_attention_cache, fake_mode=fake_mode)
+        return x, fake_mode
+    if hasattr(x, "shape"):
+        t = fake_mode.from_tensor(x, static_shapes=False)
+        return t, fake_mode
+    from . import string_type
+
+    raise TypeError(
+        f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+    )

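Beyond the runpython example embedded in the docstring above, a short hedged sketch of the two entry points of the new module (dimension indices and names are illustrative):

    import torch
    from onnx_diagnostic.helpers.fake_tensor_helper import fake_reshape, make_fake

    t = torch.rand((2, 8, 30, 96), dtype=torch.float16)
    # dimensions 0 and 2 become symbolic in the returned FakeTensor
    fake = fake_reshape(t, {0: "batch", 2: "cache_length"})
    print(type(fake).__name__, fake.shape)

    # make_fake walks nested structures and replaces every tensor by a FakeTensor
    inputs, fake_mode = make_fake({"input_ids": torch.randint(100, size=(2, 3))})
    print(fake_mode, inputs["input_ids"].shape)
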
@@ -463,6 +463,7 @@ def string_type(
         if verbose:
             print(f"[string_type] F2:{type(obj)}")
         return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
+
     if isinstance(obj, torch.Tensor):
         from .torch_helper import torch_dtype_to_onnx_dtype

@@ -783,6 +784,8 @@ def string_type(
             obj, ultralytics.engine.results.Results
         ), f"Unexpected type={type(obj)}"
         return f"ultralytics.{obj.__class__.__name__}(...)"
+    if obj.__class__.__name__ == "FakeTensorMode":
+        return f"{obj}"

     if verbose:
         print(f"[string_type] END:{type(obj)}")

@@ -3,8 +3,6 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .torch_helper import to_numpy
-from .cache_helper import is_cache_dynamic_registered


 def name_type_to_onnx_dtype(name: str) -> int:

@@ -49,7 +47,7 @@ def make_feeds(
     assert (
         not check_flatten
         or not all(isinstance(obj, torch.Tensor) for obj in flat)
-        or not is_cache_dynamic_registered(fast=True)
+        # or not is_cache_dynamic_registered(fast=True)
         or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
     ), (
         f"Unexpected number of flattened objects, "

@@ -57,6 +55,8 @@ def make_feeds(
         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
+        from .torch_helper import to_numpy
+
         flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]

@@ -186,12 +186,13 @@ def _get_inputs_gemma3(
         f"total_sequence_length={total_sequence_length} != 860 "
         f"for model {model.__class__.__name__}"
     )
-    assert (
-
-
+    assert head_dim in (
+        256,
+        32,
+    ), f"head_dim={head_dim} not in (32, 256) for model {model.__class__.__name__}"
     assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
-    assert num_key_value_heads
-    f"num_key_value_heads={num_key_value_heads}
+    assert num_key_value_heads in (1, 4), (
+        f"num_key_value_heads={num_key_value_heads} not in (1, 4) "
         f"for this model {model.__class__.__name__}"
     )

@@ -270,7 +271,7 @@ def get_inputs_default(
             "input_ids": {0: batch, 1: seq_length},
             "token_type_ids": {0: batch, 1: seq_length},
             "attention_mask": {0: batch, 1: "cache+seq"},
-            "position_ids": {0: batch, 1:
+            "position_ids": {0: batch, 1: seq_length},
             "past_key_values": [
                 [{0: batch} for _ in range(num_hidden_layers)],
                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

@@ -19,6 +19,9 @@ __TASK__ = "text-generation"
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     check_hasattr(
         config,
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),

@@ -217,10 +220,7 @@ def get_inputs(
             0: batch,
             1: "cache+seq",  # cache_length + seq_length
         },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
+        "position_ids": {0: batch, 1: seq_length},
         "past_key_values": [
             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],

@@ -308,6 +308,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

     If the configuration is None, the function selects typical dimensions.
     """
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     if config is not None:
         check_hasattr(
             config,

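The same text_config fallback is added to both reduce_model_config and random_input_kwargs above; a minimal illustration with a hypothetical multimodal configuration object:

    class _TextConfig:
        num_hidden_layers = 2
        hidden_size = 64
        num_attention_heads = 4

    class _MultimodalConfig:
        text_config = _TextConfig()

    config = _MultimodalConfig()
    if hasattr(config, "text_config"):
        # the model is probably a mixture of models used only for text
        config = config.text_config
    print(config.num_hidden_layers, config.hidden_size)
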
@@ -2,7 +2,7 @@ import functools
 import importlib
 import contextlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,

@@ -160,7 +160,7 @@ def register_additional_serialization_functions(
 @contextlib.contextmanager
 def torch_export_patches(
     patch_sympy: bool = True,
-    patch_torch: bool = True,
+    patch_torch: Union[bool, int] = True,
     patch_transformers: bool = False,
     patch_diffusers: bool = False,
     catch_constraints: bool = True,

@@ -349,6 +349,7 @@ def torch_export_patches(
             _catch_produce_guards_and_solve_constraints,
             patch__check_input_constraints_for_graph,
             patched__broadcast_in_dim_meta,
+            patched__broadcast_in_dim_meta_level_2,
             patched__maybe_broadcast,
             patched_ShapeEnv,
         )

@@ -390,8 +391,13 @@ def torch_export_patches(
         # torch._prims._broadcast_in_dim_meta
         f_broadcast_in_dim = torch._prims.broadcast_in_dim
         f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-
-
+        _patched_dim_f = (
+            patched__broadcast_in_dim_meta_level_2
+            if patch_torch == 2
+            else patched__broadcast_in_dim_meta
+        )
+        torch._prims._broadcast_in_dim_meta = _patched_dim_f
+        torch._prims.broadcast_in_dim = _patched_dim_f

         # torch._refs._maybe_broadcast
         f__maybe_broadcast = torch._refs._maybe_broadcast

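A hedged sketch of how the new patch level could be requested: patch_torch=2 selects patched__broadcast_in_dim_meta_level_2 in the assignment above (the toy model and the import path of the context manager are assumptions based on the package's public API):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x.unsqueeze(0) * 2

    with torch_export_patches(patch_torch=2):
        ep = torch.export.export(Toy(), (torch.rand(3, 4),))
    print(type(ep).__name__)
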
@@ -422,7 +428,7 @@ def torch_export_patches(
             )
         )

-    if stop_if_static:
+    if patch_torch and stop_if_static:
         ShapeEnv._log_guard_remember = ShapeEnv._log_guard

         if verbose:

@@ -453,6 +459,16 @@ def torch_export_patches(
     except ImportError:
         masking_utils = None

+    try:
+        import transformers.integrations.sdpa_attention as sdpa_attention
+    except ImportError:
+        sdpa_attention = None
+
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
     if verbose:
         import transformers

@@ -464,7 +480,7 @@ def torch_export_patches(
             patch_transformers_list, verbose=verbose
         )

-        if (
+        if (  # vmap
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "_vmap_for_bhqkv")

@@ -499,7 +515,7 @@ def torch_export_patches(
         else:
             f_transformers_sdpa_mask = None

-        if (
+        if (  # eager_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "eager_mask")

@@ -526,7 +542,7 @@ def torch_export_patches(
                 patch_transformers_list.patched_eager_mask
             )

-        if (
+        if (  # sdpa_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "sdpa_mask")

@@ -547,6 +563,29 @@ def torch_export_patches(
                 patch_transformers_list.patched_sdpa_mask_recent_torch
             )

+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+            and hasattr(modeling_utils, "AttentionInterface")
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+                )
+            f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+            sdpa_attention.sdpa_attention_forward = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+            modeling_utils.sdpa_attention_forward = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")

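A hedged check of the effect of this block: inside the context manager the "sdpa" entry of AttentionInterface should point to the patched forward, and the restore hunk further below puts the original back on exit (assuming a transformers version that exposes AttentionInterface and sdpa_attention_forward):

    import transformers.modeling_utils as modeling_utils
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    original = modeling_utils.AttentionInterface._global_mapping["sdpa"]
    with torch_export_patches(patch_transformers=True):
        patched = modeling_utils.AttentionInterface._global_mapping["sdpa"]
        print(patched is original)  # expected False while the patch is active
    # expected True again once the context manager has restored the mapping
    print(modeling_utils.AttentionInterface._global_mapping["sdpa"] is original)
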
@@ -656,7 +695,7 @@ def torch_export_patches(
             patch_transformers_list, revert_patches_info, verbose=verbose
         )

-        if (
+        if (  # vmap
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "_vmap_for_bhqkv")

@@ -687,7 +726,7 @@ def torch_export_patches(
                     "transformers.masking_utils.sdpa_mask"
                 )

-        if (
+        if (  # eager_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "eager_mask")

@@ -714,7 +753,7 @@ def torch_export_patches(
                     "in ALL_MASK_ATTENTION_FUNCTIONS"
                 )

-        if (
+        if (  # sdpa_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "sdpa_mask")

@@ -734,6 +773,25 @@ def torch_export_patches(
                     "in ALL_MASK_ATTENTION_FUNCTIONS"
                 )

+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+            and hasattr(modeling_utils, "AttentionInterface")
+        ):
+            sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+                f_sdpa_attention_forward
+            )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.integrations.sdpa_attention."
+                    "sdpa_attention_forward"
+                )
+
     ########
     # caches
     ########

@@ -12,17 +12,26 @@ from transformers.cache_utils import (
     StaticCache,
 )

-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 from ..helpers import string_type
 from .serialization import _lower_name_with_

 PATCH_OF_PATCHES: Set[Any] = set()


+def get_mamba_cache_cls() -> type:
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+
+        return MambaCache
+    except ImportError:
+        try:
+            from transformers.cache_utils import MambaCache
+
+            return MambaCache
+        except ImportError:
+            return None
+
+
 def register_class_serialization(
     cls,
     f_flatten: Callable,

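The new helper keeps the serialization module importable whatever happened to MambaCache across transformers releases; a small sketch of how it degrades:

    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        get_mamba_cache_cls,
    )

    MambaCache = get_mamba_cache_cls()
    if MambaCache is None:
        print("MambaCache is not available in this transformers version")
    else:
        print(f"MambaCache comes from {MambaCache.__module__}")
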
@@ -203,13 +212,6 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
-        MambaCache: lambda verbose=verbose: register_class_serialization(
-            MambaCache,
-            flatten_mamba_cache,
-            unflatten_mamba_cache,
-            flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
         EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
             EncoderDecoderCache,
             flatten_encoder_decoder_cache,

|
|
|
232
234
|
verbose=verbose,
|
|
233
235
|
),
|
|
234
236
|
}
|
|
237
|
+
MambaCache = get_mamba_cache_cls()
|
|
238
|
+
if MambaCache:
|
|
239
|
+
transformers_classes[MambaCache] = (
|
|
240
|
+
lambda verbose=verbose: register_class_serialization(
|
|
241
|
+
MambaCache,
|
|
242
|
+
flatten_mamba_cache,
|
|
243
|
+
unflatten_mamba_cache,
|
|
244
|
+
flatten_with_keys_mamba_cache,
|
|
245
|
+
verbose=verbose,
|
|
246
|
+
)
|
|
247
|
+
)
|
|
235
248
|
classes.update(transformers_classes)
|
|
236
249
|
|
|
237
250
|
if patch_diffusers:
|
|
@@ -287,7 +300,12 @@ def unregister_class_serialization(cls: type, verbose: int = 0):

 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
     """Undo all registrations."""
-
+    MambaCache = get_mamba_cache_cls()
+    cls_ensemble = (
+        {DynamicCache, EncoderDecoderCache}
+        | set(undo)
+        | ({MambaCache} if MambaCache else set())
+    )
     for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
             unregister_class_serialization(cls, verbose)