onnx-diagnostic 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +213 -5
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +42 -20
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +1561 -177
- onnx_diagnostic/helpers/torch_helper.py +6 -2
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/image_text_to_text.py +69 -18
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +73 -5
- onnx_diagnostic/torch_models/hghub/hub_data.py +7 -2
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +74 -14
- onnx_diagnostic/torch_models/validate.py +45 -16
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +29 -24
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -1,5 +1,5 @@
 import pprint
-from typing import Any, Callable, Dict,
+from typing import Any, Callable, Dict, Optional, Set
 import packaging.version as pv
 import optree
 import torch
@@ -11,10 +11,9 @@ from transformers.cache_utils import (
     SlidingWindowCache,
     StaticCache,
 )
-from transformers.modeling_outputs import BaseModelOutput
-from ..helpers import string_type
-from ..helpers.cache_helper import make_static_cache
 
+from ..helpers import string_type
+from .serialization import _lower_name_with_
 
 PATCH_OF_PATCHES: Set[Any] = set()
 
@@ -40,10 +39,12 @@ def register_class_serialization(
     :return: registered or not
     """
     if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose and cls is not None:
+            print(f"[register_class_serialization] already registered {cls.__name__}")
         return False
 
     if verbose:
-        print(f"[
+        print(f"[register_class_serialization] ---------- register {cls.__name__}")
     torch.utils._pytree.register_pytree_node(
         cls,
         f_flatten,
@@ -54,8 +55,8 @@
     if pv.Version(torch.__version__) < pv.Version("2.7"):
         if verbose:
             print(
-                f"[
-                f"register {cls} for torch=={torch.__version__}"
+                f"[register_class_serialization] "
+                f"---------- register {cls.__name__} for torch=={torch.__version__}"
             )
         torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])
 
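Both hunks above only touch logging, but they decorate the core mechanism of this module: `torch.utils._pytree.register_pytree_node`, which teaches PyTorch to decompose a container into tensor leaves plus a context and to rebuild it later. A minimal sketch of the same pattern on a toy container (the class and helper names below are illustrative, not part of onnx-diagnostic):

```python
import torch


class ToyCache:
    """Illustrative container, not an onnx-diagnostic class."""

    def __init__(self, key, value):
        self.key = key
        self.value = value


def flatten_toy_cache(cache):
    # Returns the tensor leaves and a context describing how to rebuild them.
    return [cache.key, cache.value], ["key", "value"]


def unflatten_toy_cache(values, context):
    return ToyCache(**dict(zip(context, values)))


torch.utils._pytree.register_pytree_node(ToyCache, flatten_toy_cache, unflatten_toy_cache)

leaves, spec = torch.utils._pytree.tree_flatten(ToyCache(torch.ones(2), torch.zeros(2)))
restored = torch.utils._pytree.tree_unflatten(leaves, spec)
```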
@@ -72,11 +73,34 @@
     return True
 
 
-def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
+def register_cache_serialization(
+    patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
+) -> Dict[str, bool]:
     """
     Registers many classes with :func:`register_class_serialization`.
     Returns information needed to undo the registration.
+
+    :param patch_transformers: add serialization function for
+        :epkg:`transformers` package
+    :param patch_diffusers: add serialization function for
+        :epkg:`diffusers` package
+    :param verbosity: verbosity level
+    :return: information to unpatch
     """
+    wrong: Dict[type, Optional[str]] = {}
+    if patch_transformers:
+        from .serialization.transformers_impl import WRONG_REGISTRATIONS
+
+        wrong |= WRONG_REGISTRATIONS
+    if patch_diffusers:
+        from .serialization.diffusers_impl import WRONG_REGISTRATIONS
+
+        wrong |= WRONG_REGISTRATIONS
+
+    registration_functions = serialization_functions(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
+
     # DynamicCache serialization is different in transformers and does not
     # play way with torch.export.export.
     # see test test_export_dynamic_cache_cat with NOBYPASS=1
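Judging from this new signature, callers now opt in per package; a hedged usage sketch (note that `patch_diffusers` defaults to `True` in the hunk above, so environments without `diffusers` installed presumably want it switched off explicitly):

```python
from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    register_cache_serialization,
    unregister_cache_serialization,
)

# Register pytree serialization for the transformers cache classes only,
# keep the returned bookkeeping dict, and undo everything afterwards.
undo = register_cache_serialization(
    patch_transformers=True, patch_diffusers=False, verbose=1
)
try:
    ...  # e.g. torch.export.export(model, args) with cache inputs
finally:
    unregister_cache_serialization(undo, verbose=1)
```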
@@ -85,109 +109,137 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # torch.fx._pytree.register_pytree_flatten_spec(
     #     DynamicCache, _flatten_dynamic_cache_for_fx)
     # so we remove it anyway
-    if (
-        DynamicCache in torch.utils._pytree.SUPPORTED_NODES
-        and DynamicCache not in PATCH_OF_PATCHES
-        # and pv.Version(torch.__version__) < pv.Version("2.7")
-        and pv.Version(transformers.__version__) >= pv.Version("4.50")
-    ):
-        if verbose:
-            print(
-                f"[_fix_registration] DynamicCache is unregistered and "
-                f"registered first for transformers=={transformers.__version__}"
-            )
-        unregister(DynamicCache, verbose=verbose)
-        register_class_serialization(
-            DynamicCache,
-            flatten_dynamic_cache,
-            unflatten_dynamic_cache,
-            flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        )
-        if verbose:
-            print("[_fix_registration] DynamicCache done.")
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(DynamicCache)
-
     # BaseModelOutput serialization is incomplete.
     # It does not include dynamic shapes mapping.
-
-
-
-
-
-
-
-            f"registered first for transformers=={transformers.__version__}"
+    for cls, version in wrong.items():
+        if (
+            cls in torch.utils._pytree.SUPPORTED_NODES
+            and cls not in PATCH_OF_PATCHES
+            # and pv.Version(torch.__version__) < pv.Version("2.7")
+            and (
+                version is None or pv.Version(transformers.__version__) >= pv.Version(version)
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        ):
+            assert cls in registration_functions, (
+                f"{cls} has no registration functions mapped to it, "
+                f"available options are {list(registration_functions)}"
+            )
+            if verbose:
+                print(
+                    f"[_fix_registration] {cls.__name__} is unregistered and "
+                    f"registered first"
+                )
+            unregister_class_serialization(cls, verbose=verbose)
+            registration_functions[cls](verbose=verbose)  # type: ignore[arg-type, call-arg]
+            if verbose:
+                print(f"[_fix_registration] {cls.__name__} done.")
+            # To avoid doing it multiple times.
+            PATCH_OF_PATCHES.add(cls)
+
+    # classes with no registration at all.
+    done = {}
+    for k, v in registration_functions.items():
+        done[k] = v(verbose=verbose)  # type: ignore[arg-type, call-arg]
+    return done
+
+
+def serialization_functions(
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
+) -> Dict[type, Callable[[int], bool]]:
+    """Returns the list of serialization functions."""
 
+    supported_classes: Set[type] = set()
+    classes: Dict[type, Callable[[int], bool]] = {}
+    all_functions: Dict[type, Optional[str]] = {}
 
-
-
-
-
-        DynamicCache,
+    if patch_transformers:
+        from .serialization.transformers_impl import (
+            __dict__ as dtr,
+            SUPPORTED_DATACLASSES,
            flatten_dynamic_cache,
            unflatten_dynamic_cache,
            flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        ),
-        MambaCache=register_class_serialization(
-            MambaCache,
            flatten_mamba_cache,
            unflatten_mamba_cache,
            flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
-        EncoderDecoderCache=register_class_serialization(
-            EncoderDecoderCache,
            flatten_encoder_decoder_cache,
            unflatten_encoder_decoder_cache,
            flatten_with_keys_encoder_decoder_cache,
-            verbose=verbose,
-        ),
-        BaseModelOutput=register_class_serialization(
-            BaseModelOutput,
-            flatten_base_model_output,
-            unflatten_base_model_output,
-            flatten_with_keys_base_model_output,
-            verbose=verbose,
-        ),
-        SlidingWindowCache=register_class_serialization(
-            SlidingWindowCache,
            flatten_sliding_window_cache,
            unflatten_sliding_window_cache,
            flatten_with_keys_sliding_window_cache,
-            verbose=verbose,
-        ),
-        StaticCache=register_class_serialization(
-            StaticCache,
            flatten_static_cache,
            unflatten_static_cache,
            flatten_with_keys_static_cache,
-
-        ),
-    )
+        )
 
+        all_functions.update(dtr)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+        transformers_classes = {
+            DynamicCache: lambda verbose=verbose: register_class_serialization(
+                DynamicCache,
+                flatten_dynamic_cache,
+                unflatten_dynamic_cache,
+                flatten_with_keys_dynamic_cache,
+                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+                verbose=verbose,
+            ),
+            MambaCache: lambda verbose=verbose: register_class_serialization(
+                MambaCache,
+                flatten_mamba_cache,
+                unflatten_mamba_cache,
+                flatten_with_keys_mamba_cache,
+                verbose=verbose,
+            ),
+            EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
+                EncoderDecoderCache,
+                flatten_encoder_decoder_cache,
+                unflatten_encoder_decoder_cache,
+                flatten_with_keys_encoder_decoder_cache,
+                verbose=verbose,
+            ),
+            SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
+                SlidingWindowCache,
+                flatten_sliding_window_cache,
+                unflatten_sliding_window_cache,
+                flatten_with_keys_sliding_window_cache,
+                verbose=verbose,
+            ),
+            StaticCache: lambda verbose=verbose: register_class_serialization(
+                StaticCache,
+                flatten_static_cache,
+                unflatten_static_cache,
+                flatten_with_keys_static_cache,
+                verbose=verbose,
+            ),
+        }
+        classes.update(transformers_classes)
+
+    if patch_diffusers:
+        from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu
+
+        all_functions.update(dfu)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+    for cls in supported_classes:
+        lname = _lower_name_with_(cls.__name__)
+        assert (
+            f"flatten_{lname}" in all_functions
+        ), f"Unable to find function 'flatten_{lname}' in {list(all_functions)}"
+        classes[cls] = (
+            lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization(  # noqa: E501
+                cls,
+                _al[f"flatten_{_ln}"],
+                _al[f"unflatten_{_ln}"],
+                _al[f"flatten_with_keys_{_ln}"],
+                verbose=verbose,
+            )
+        )
+    return classes
 
-
+
+def unregister_class_serialization(cls: type, verbose: int = 0):
     """Undo the registration."""
     # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
     if cls in torch.fx._pytree.SUPPORTED_NODES:
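The generic loop at the end of `serialization_functions` relies on a naming convention: for each supported dataclass, the implementation module must expose `flatten_<name>`, `unflatten_<name>` and `flatten_with_keys_<name>`, where `<name>` comes from `_lower_name_with_`. The real converter lives in `torch_export_patches/serialization/__init__.py` and may differ; a plausible sketch of what it has to do:

```python
import re


def lower_name_with_(name: str) -> str:
    """Assumed behavior: CamelCase class name -> snake_case suffix."""
    # "BaseModelOutput" -> "base_model_output", "DynamicCache" -> "dynamic_cache"
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


assert lower_name_with_("DynamicCache") == "dynamic_cache"
assert lower_name_with_("BaseModelOutput") == "base_model_output"
```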
@@ -217,264 +269,7 @@ def unregister(cls: type, verbose: int = 0):
 
 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
     """Undo all registrations."""
-
+    cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
+    for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
-
-
-
-############
-# MambaCache
-############
-
-
-def flatten_mamba_cache(
-    mamba_cache: MambaCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    flat = [
-        ("conv_states", mamba_cache.conv_states),
-        ("ssm_states", mamba_cache.ssm_states),
-    ]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def unflatten_mamba_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> MambaCache:
-    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
-    conv_states, ssm_states = values
-
-    class _config:
-        def __init__(self):
-            if isinstance(conv_states, list):
-                self.intermediate_size = conv_states[0].shape[1]
-                self.state_size = ssm_states[0].shape[2]
-                self.conv_kernel = conv_states[0].shape[2]
-                self.num_hidden_layers = len(conv_states)
-            else:
-                self.intermediate_size = conv_states.shape[2]
-                self.state_size = ssm_states.shape[3]
-                self.conv_kernel = conv_states.shape[3]
-                self.num_hidden_layers = conv_states.shape[0]
-
-    cache = MambaCache(
-        _config(),
-        max_batch_size=1,
-        dtype=values[-1][0].dtype,
-        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
-    )
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    values, context = flatten_mamba_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-##############
-# DynamicCache
-##############
-
-
-def flatten_dynamic_cache(
-    dynamic_cache: DynamicCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
-        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
-    flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_dynamic_cache(
-    dynamic_cache: DynamicCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
-        return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
-    values, context = flatten_dynamic_cache(dynamic_cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_dynamic_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> DynamicCache:
-    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
-        assert output_type is None, f"output_type={output_type} not supported"
-        return transformers.cache_utils._unflatten_dynamic_cache(values, context)
-
-    cache = transformers.cache_utils.DynamicCache()
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-##############
-# DynamicCache
-##############
-
-
-def flatten_static_cache(
-    cache: StaticCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_static_cache(
-    cache: StaticCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-    values, context = flatten_static_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_static_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> StaticCache:
-    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
-    return make_static_cache(list(zip(values[0], values[1])))
-
-
-####################
-# SlidingWindowCache
-####################
-
-
-def flatten_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    values, context = flatten_sliding_window_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_sliding_window_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> SlidingWindowCache:
-    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    key_cache, value_cache = values
-
-    class _config:
-        def __init__(self):
-            self.head_dim = key_cache[0].shape[-1]
-            self.num_attention_heads = key_cache[0].shape[1]
-            self.num_hidden_layers = len(key_cache)
-            self.sliding_window = key_cache[0].shape[2]
-
-    cache = SlidingWindowCache(
-        _config(),
-        max_batch_size=key_cache[0].shape[0],
-        max_cache_len=key_cache[0].shape[2],  # sligding window
-        device=key_cache[0].device,
-        dtype=key_cache[0].dtype,
-    )
-
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-#####################
-# EncoderDecoderCache
-#####################
-
-
-def flatten_encoder_decoder_cache(
-    ec_cache: EncoderDecoderCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
-    with python objects.
-    """
-    dictionary = {
-        "self_attention_cache": ec_cache.self_attention_cache,
-        "cross_attention_cache": ec_cache.cross_attention_cache,
-    }
-    return torch.utils._pytree._dict_flatten(dictionary)
-
-
-def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """
-    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
-    with python objects.
-    """
-    dictionary = {
-        "self_attention_cache": ec_cache.self_attention_cache,
-        "cross_attention_cache": ec_cache.cross_attention_cache,
-    }
-    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
-
-
-def unflatten_encoder_decoder_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> EncoderDecoderCache:
-    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
-    dictionary = torch.utils._pytree._dict_unflatten(values, context)
-    return EncoderDecoderCache(**dictionary)
-
-
-#################
-# BaseModelOutput
-#################
-
-
-def flatten_base_model_output(
-    bo: BaseModelOutput,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
-    with python objects.
-    """
-    return list(bo.values()), list(bo.keys())
-
-
-def flatten_with_keys_base_model_output(
-    bo: BaseModelOutput,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
-    with python objects.
-    """
-    values, context = flatten_base_model_output(bo)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_base_model_output(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> BaseModelOutput:
-    """
-    Restores a :class:`transformers.modeling_outputs.BaseModelOutput`
-    from python objects.
-    """
-    return BaseModelOutput(**dict(zip(context, values)))
+            unregister_class_serialization(cls, verbose)
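The net effect of the two big hunks: the per-class flatten/unflatten helpers left this module for `serialization/transformers_impl.py` and `serialization/diffusers_impl.py`, and unregistration became a generic loop. A hedged smoke test of the registration round trip (exact `DynamicCache` internals vary across transformers versions):

```python
import torch
import transformers
from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    register_cache_serialization,
    unregister_cache_serialization,
)

undo = register_cache_serialization(patch_transformers=True, patch_diffusers=False)
cache = transformers.cache_utils.DynamicCache()
cache.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4), 0)

# Once registered, pytree utilities see tensor leaves instead of an opaque object.
leaves, spec = torch.utils._pytree.tree_flatten(cache)
print([tuple(t.shape) for t in leaves])

restored = torch.utils._pytree.tree_unflatten(leaves, spec)
unregister_cache_serialization(undo)
```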

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -2,6 +2,7 @@ import inspect
 from dataclasses import dataclass
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -20,18 +21,41 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
     ]
     if bh_indices:
         dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+    # reshape
     dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
     dimensions = tuple(reversed(dimensions))
     indices = tuple(shape.index(-1) for shape in dimensions)
 
+    # unsqueeze
+    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
+
     def vector_mask_function(
         *args, mask_function=mask_function, dimensions=dimensions, indices=indices
     ):
-        assert len(args) == len(
-            dimensions
-
+        assert len(args) == len(dimensions) == len(udimensions), (
+            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+            f"and udimensions={udimensions}."
+        )
+        assert len(indices) == len(args), (
+            f"Mismatch between args={string_type(args)} and indices={indices}, "
+            f"they should have the same length."
+        )
+        for a in args:
+            assert (
+                a.ndim == 1
+            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+            torch._check(a.shape[0] > 0)
+
         new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+        # new_args = [
+        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+        #     for a, dims in zip(args, udimensions)
+        # ]
         max_shape = tuple(args[i].shape[0] for i in indices)
+        # if is_torchdynamo_exporting():
+        #     for a in args:
+        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
+        #         torch._check(a.shape[0] > 1)
         expanded_args = [a.expand(max_shape) for a in new_args]
         return mask_function(*expanded_args)
 
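The patched `_vmap_for_bhqkv` replaces `torch.vmap` with reshape-plus-expand broadcasting, which the exporter can trace. The trick in isolation, with plain torch and a toy causal mask (shapes chosen arbitrarily):

```python
import torch

# Four 1-D index tensors: batch, head, query position, key/value position.
args = (torch.arange(2), torch.arange(4), torch.arange(3), torch.arange(5))

# Put each tensor's length on its own axis ...
dims = [(-1, 1, 1, 1), (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)]
new_args = [a.reshape(s) for a, s in zip(args, dims)]

# ... then expand all of them to the common (batch, head, q, kv) grid.
max_shape = tuple(a.shape[0] for a in args)
expanded = [a.expand(max_shape) for a in new_args]

causal = expanded[3] <= expanded[2]  # kv_idx <= q_idx on the full grid
print(causal.shape)  # torch.Size([2, 4, 3, 5])
```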
@@ -791,10 +815,7 @@ def patched_dynamic_rope_update(rope_forward):
     return wrapper
 
 
-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -820,6 +841,65 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
+
+    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
+
+
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
+
+
+class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
+
+
+class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.51"):
+
+    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
+        )
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.53"):
+
+    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding