onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +22 -5
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +23 -12
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +308 -83
- onnx_diagnostic/helpers/rt_helper.py +11 -1
- onnx_diagnostic/helpers/torch_helper.py +7 -3
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
- onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
- onnx_diagnostic/torch_models/validate.py +36 -12
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import pprint
|
|
2
|
-
from typing import Any, Callable, Dict,
|
|
2
|
+
from typing import Any, Callable, Dict, Optional, Set
|
|
3
3
|
import packaging.version as pv
|
|
4
4
|
import optree
|
|
5
5
|
import torch
|
|
@@ -11,10 +11,9 @@ from transformers.cache_utils import (
|
|
|
11
11
|
SlidingWindowCache,
|
|
12
12
|
StaticCache,
|
|
13
13
|
)
|
|
14
|
-
from transformers.modeling_outputs import BaseModelOutput
|
|
15
|
-
from ..helpers import string_type
|
|
16
|
-
from ..helpers.cache_helper import make_static_cache
|
|
17
14
|
|
|
15
|
+
from ..helpers import string_type
|
|
16
|
+
from .serialization import _lower_name_with_
|
|
18
17
|
|
|
19
18
|
PATCH_OF_PATCHES: Set[Any] = set()
|
|
20
19
|
|
|
@@ -29,7 +28,8 @@ def register_class_serialization(
|
|
|
29
28
|
) -> bool:
|
|
30
29
|
"""
|
|
31
30
|
Registers a class.
|
|
32
|
-
It can be undone with
|
|
31
|
+
It can be undone with
|
|
32
|
+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`.
|
|
33
33
|
|
|
34
34
|
:param cls: class to register
|
|
35
35
|
:param f_flatten: see ``torch.utils._pytree.register_pytree_node``
|
|
@@ -40,10 +40,12 @@ def register_class_serialization(
|
|
|
40
40
|
:return: registered or not
|
|
41
41
|
"""
|
|
42
42
|
if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
|
|
43
|
+
if verbose and cls is not None:
|
|
44
|
+
print(f"[register_class_serialization] already registered {cls.__name__}")
|
|
43
45
|
return False
|
|
44
46
|
|
|
45
47
|
if verbose:
|
|
46
|
-
print(f"[
|
|
48
|
+
print(f"[register_class_serialization] ---------- register {cls.__name__}")
|
|
47
49
|
torch.utils._pytree.register_pytree_node(
|
|
48
50
|
cls,
|
|
49
51
|
f_flatten,
|
|
@@ -54,8 +56,8 @@ def register_class_serialization(
|
|
|
54
56
|
if pv.Version(torch.__version__) < pv.Version("2.7"):
|
|
55
57
|
if verbose:
|
|
56
58
|
print(
|
|
57
|
-
f"[
|
|
58
|
-
f"register {cls} for torch=={torch.__version__}"
|
|
59
|
+
f"[register_class_serialization] "
|
|
60
|
+
f"---------- register {cls.__name__} for torch=={torch.__version__}"
|
|
59
61
|
)
|
|
60
62
|
torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])
|
|
61
63
|
|
|
@@ -72,11 +74,35 @@ def register_class_serialization(
|
|
|
72
74
|
return True
|
|
73
75
|
|
|
74
76
|
|
|
75
|
-
def register_cache_serialization(
|
|
77
|
+
def register_cache_serialization(
|
|
78
|
+
patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
|
|
79
|
+
) -> Dict[str, bool]:
|
|
76
80
|
"""
|
|
77
|
-
Registers many classes with
|
|
81
|
+
Registers many classes with
|
|
82
|
+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`.
|
|
78
83
|
Returns information needed to undo the registration.
|
|
84
|
+
|
|
85
|
+
:param patch_transformers: add serialization function for
|
|
86
|
+
:epkg:`transformers` package
|
|
87
|
+
:param patch_diffusers: add serialization function for
|
|
88
|
+
:epkg:`diffusers` package
|
|
89
|
+
:param verbosity: verbosity level
|
|
90
|
+
:return: information to unpatch
|
|
79
91
|
"""
|
|
92
|
+
wrong: Dict[type, Optional[str]] = {}
|
|
93
|
+
if patch_transformers:
|
|
94
|
+
from .serialization.transformers_impl import WRONG_REGISTRATIONS
|
|
95
|
+
|
|
96
|
+
wrong |= WRONG_REGISTRATIONS
|
|
97
|
+
if patch_diffusers:
|
|
98
|
+
from .serialization.diffusers_impl import WRONG_REGISTRATIONS
|
|
99
|
+
|
|
100
|
+
wrong |= WRONG_REGISTRATIONS
|
|
101
|
+
|
|
102
|
+
registration_functions = serialization_functions(
|
|
103
|
+
patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
|
|
104
|
+
)
|
|
105
|
+
|
|
80
106
|
# DynamicCache serialization is different in transformers and does not
|
|
81
107
|
# play way with torch.export.export.
|
|
82
108
|
# see test test_export_dynamic_cache_cat with NOBYPASS=1
|
|
@@ -85,109 +111,137 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
|
|
|
85
111
|
# torch.fx._pytree.register_pytree_flatten_spec(
|
|
86
112
|
# DynamicCache, _flatten_dynamic_cache_for_fx)
|
|
87
113
|
# so we remove it anyway
|
|
88
|
-
if (
|
|
89
|
-
DynamicCache in torch.utils._pytree.SUPPORTED_NODES
|
|
90
|
-
and DynamicCache not in PATCH_OF_PATCHES
|
|
91
|
-
# and pv.Version(torch.__version__) < pv.Version("2.7")
|
|
92
|
-
and pv.Version(transformers.__version__) >= pv.Version("4.50")
|
|
93
|
-
):
|
|
94
|
-
if verbose:
|
|
95
|
-
print(
|
|
96
|
-
f"[_fix_registration] DynamicCache is unregistered and "
|
|
97
|
-
f"registered first for transformers=={transformers.__version__}"
|
|
98
|
-
)
|
|
99
|
-
unregister(DynamicCache, verbose=verbose)
|
|
100
|
-
register_class_serialization(
|
|
101
|
-
DynamicCache,
|
|
102
|
-
flatten_dynamic_cache,
|
|
103
|
-
unflatten_dynamic_cache,
|
|
104
|
-
flatten_with_keys_dynamic_cache,
|
|
105
|
-
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
|
|
106
|
-
verbose=verbose,
|
|
107
|
-
)
|
|
108
|
-
if verbose:
|
|
109
|
-
print("[_fix_registration] DynamicCache done.")
|
|
110
|
-
# To avoid doing it multiple times.
|
|
111
|
-
PATCH_OF_PATCHES.add(DynamicCache)
|
|
112
|
-
|
|
113
114
|
# BaseModelOutput serialization is incomplete.
|
|
114
115
|
# It does not include dynamic shapes mapping.
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
f"registered first for transformers=={transformers.__version__}"
|
|
116
|
+
for cls, version in wrong.items():
|
|
117
|
+
if (
|
|
118
|
+
cls in torch.utils._pytree.SUPPORTED_NODES
|
|
119
|
+
and cls not in PATCH_OF_PATCHES
|
|
120
|
+
# and pv.Version(torch.__version__) < pv.Version("2.7")
|
|
121
|
+
and (
|
|
122
|
+
version is None or pv.Version(transformers.__version__) >= pv.Version(version)
|
|
123
123
|
)
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
124
|
+
):
|
|
125
|
+
assert cls in registration_functions, (
|
|
126
|
+
f"{cls} has no registration functions mapped to it, "
|
|
127
|
+
f"available options are {list(registration_functions)}"
|
|
128
|
+
)
|
|
129
|
+
if verbose:
|
|
130
|
+
print(
|
|
131
|
+
f"[_fix_registration] {cls.__name__} is unregistered and "
|
|
132
|
+
f"registered first"
|
|
133
|
+
)
|
|
134
|
+
unregister_class_serialization(cls, verbose=verbose)
|
|
135
|
+
registration_functions[cls](verbose=verbose) # type: ignore[arg-type, call-arg]
|
|
136
|
+
if verbose:
|
|
137
|
+
print(f"[_fix_registration] {cls.__name__} done.")
|
|
138
|
+
# To avoid doing it multiple times.
|
|
139
|
+
PATCH_OF_PATCHES.add(cls)
|
|
140
|
+
|
|
141
|
+
# classes with no registration at all.
|
|
142
|
+
done = {}
|
|
143
|
+
for k, v in registration_functions.items():
|
|
144
|
+
done[k] = v(verbose=verbose) # type: ignore[arg-type, call-arg]
|
|
145
|
+
return done
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def serialization_functions(
|
|
149
|
+
patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
|
|
150
|
+
) -> Dict[type, Callable[[int], bool]]:
|
|
151
|
+
"""Returns the list of serialization functions."""
|
|
139
152
|
|
|
153
|
+
supported_classes: Set[type] = set()
|
|
154
|
+
classes: Dict[type, Callable[[int], bool]] = {}
|
|
155
|
+
all_functions: Dict[type, Optional[str]] = {}
|
|
140
156
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
DynamicCache,
|
|
157
|
+
if patch_transformers:
|
|
158
|
+
from .serialization.transformers_impl import (
|
|
159
|
+
__dict__ as dtr,
|
|
160
|
+
SUPPORTED_DATACLASSES,
|
|
146
161
|
flatten_dynamic_cache,
|
|
147
162
|
unflatten_dynamic_cache,
|
|
148
163
|
flatten_with_keys_dynamic_cache,
|
|
149
|
-
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
|
|
150
|
-
verbose=verbose,
|
|
151
|
-
),
|
|
152
|
-
MambaCache=register_class_serialization(
|
|
153
|
-
MambaCache,
|
|
154
164
|
flatten_mamba_cache,
|
|
155
165
|
unflatten_mamba_cache,
|
|
156
166
|
flatten_with_keys_mamba_cache,
|
|
157
|
-
verbose=verbose,
|
|
158
|
-
),
|
|
159
|
-
EncoderDecoderCache=register_class_serialization(
|
|
160
|
-
EncoderDecoderCache,
|
|
161
167
|
flatten_encoder_decoder_cache,
|
|
162
168
|
unflatten_encoder_decoder_cache,
|
|
163
169
|
flatten_with_keys_encoder_decoder_cache,
|
|
164
|
-
verbose=verbose,
|
|
165
|
-
),
|
|
166
|
-
BaseModelOutput=register_class_serialization(
|
|
167
|
-
BaseModelOutput,
|
|
168
|
-
flatten_base_model_output,
|
|
169
|
-
unflatten_base_model_output,
|
|
170
|
-
flatten_with_keys_base_model_output,
|
|
171
|
-
verbose=verbose,
|
|
172
|
-
),
|
|
173
|
-
SlidingWindowCache=register_class_serialization(
|
|
174
|
-
SlidingWindowCache,
|
|
175
170
|
flatten_sliding_window_cache,
|
|
176
171
|
unflatten_sliding_window_cache,
|
|
177
172
|
flatten_with_keys_sliding_window_cache,
|
|
178
|
-
verbose=verbose,
|
|
179
|
-
),
|
|
180
|
-
StaticCache=register_class_serialization(
|
|
181
|
-
StaticCache,
|
|
182
173
|
flatten_static_cache,
|
|
183
174
|
unflatten_static_cache,
|
|
184
175
|
flatten_with_keys_static_cache,
|
|
185
|
-
|
|
186
|
-
),
|
|
187
|
-
)
|
|
176
|
+
)
|
|
188
177
|
|
|
178
|
+
all_functions.update(dtr)
|
|
179
|
+
supported_classes |= SUPPORTED_DATACLASSES
|
|
180
|
+
|
|
181
|
+
transformers_classes = {
|
|
182
|
+
DynamicCache: lambda verbose=verbose: register_class_serialization(
|
|
183
|
+
DynamicCache,
|
|
184
|
+
flatten_dynamic_cache,
|
|
185
|
+
unflatten_dynamic_cache,
|
|
186
|
+
flatten_with_keys_dynamic_cache,
|
|
187
|
+
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
|
|
188
|
+
verbose=verbose,
|
|
189
|
+
),
|
|
190
|
+
MambaCache: lambda verbose=verbose: register_class_serialization(
|
|
191
|
+
MambaCache,
|
|
192
|
+
flatten_mamba_cache,
|
|
193
|
+
unflatten_mamba_cache,
|
|
194
|
+
flatten_with_keys_mamba_cache,
|
|
195
|
+
verbose=verbose,
|
|
196
|
+
),
|
|
197
|
+
EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
|
|
198
|
+
EncoderDecoderCache,
|
|
199
|
+
flatten_encoder_decoder_cache,
|
|
200
|
+
unflatten_encoder_decoder_cache,
|
|
201
|
+
flatten_with_keys_encoder_decoder_cache,
|
|
202
|
+
verbose=verbose,
|
|
203
|
+
),
|
|
204
|
+
SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
|
|
205
|
+
SlidingWindowCache,
|
|
206
|
+
flatten_sliding_window_cache,
|
|
207
|
+
unflatten_sliding_window_cache,
|
|
208
|
+
flatten_with_keys_sliding_window_cache,
|
|
209
|
+
verbose=verbose,
|
|
210
|
+
),
|
|
211
|
+
StaticCache: lambda verbose=verbose: register_class_serialization(
|
|
212
|
+
StaticCache,
|
|
213
|
+
flatten_static_cache,
|
|
214
|
+
unflatten_static_cache,
|
|
215
|
+
flatten_with_keys_static_cache,
|
|
216
|
+
verbose=verbose,
|
|
217
|
+
),
|
|
218
|
+
}
|
|
219
|
+
classes.update(transformers_classes)
|
|
220
|
+
|
|
221
|
+
if patch_diffusers:
|
|
222
|
+
from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu
|
|
223
|
+
|
|
224
|
+
all_functions.update(dfu)
|
|
225
|
+
supported_classes |= SUPPORTED_DATACLASSES
|
|
226
|
+
|
|
227
|
+
for cls in supported_classes:
|
|
228
|
+
lname = _lower_name_with_(cls.__name__)
|
|
229
|
+
assert (
|
|
230
|
+
f"flatten_{lname}" in all_functions
|
|
231
|
+
), f"Unable to find function 'flatten_{lname}' in {list(all_functions)}"
|
|
232
|
+
classes[cls] = (
|
|
233
|
+
lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( # noqa: E501
|
|
234
|
+
cls,
|
|
235
|
+
_al[f"flatten_{_ln}"],
|
|
236
|
+
_al[f"unflatten_{_ln}"],
|
|
237
|
+
_al[f"flatten_with_keys_{_ln}"],
|
|
238
|
+
verbose=verbose,
|
|
239
|
+
)
|
|
240
|
+
)
|
|
241
|
+
return classes
|
|
189
242
|
|
|
190
|
-
|
|
243
|
+
|
|
244
|
+
def unregister_class_serialization(cls: type, verbose: int = 0):
|
|
191
245
|
"""Undo the registration."""
|
|
192
246
|
# torch.utils._pytree._deregister_pytree_flatten_spec(cls)
|
|
193
247
|
if cls in torch.fx._pytree.SUPPORTED_NODES:
|
|
@@ -217,264 +271,7 @@ def unregister(cls: type, verbose: int = 0):
|
|
|
217
271
|
|
|
218
272
|
def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
|
|
219
273
|
"""Undo all registrations."""
|
|
220
|
-
|
|
274
|
+
cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
|
|
275
|
+
for cls in cls_ensemble:
|
|
221
276
|
if undo.get(cls.__name__, False):
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
############
|
|
226
|
-
# MambaCache
|
|
227
|
-
############
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
def flatten_mamba_cache(
|
|
231
|
-
mamba_cache: MambaCache,
|
|
232
|
-
) -> Tuple[List[Any], torch.utils._pytree.Context]:
|
|
233
|
-
"""Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
|
|
234
|
-
flat = [
|
|
235
|
-
("conv_states", mamba_cache.conv_states),
|
|
236
|
-
("ssm_states", mamba_cache.ssm_states),
|
|
237
|
-
]
|
|
238
|
-
return [f[1] for f in flat], [f[0] for f in flat]
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
def unflatten_mamba_cache(
|
|
242
|
-
values: List[Any], context: torch.utils._pytree.Context, output_type=None
|
|
243
|
-
) -> MambaCache:
|
|
244
|
-
"""Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
|
|
245
|
-
conv_states, ssm_states = values
|
|
246
|
-
|
|
247
|
-
class _config:
|
|
248
|
-
def __init__(self):
|
|
249
|
-
if isinstance(conv_states, list):
|
|
250
|
-
self.intermediate_size = conv_states[0].shape[1]
|
|
251
|
-
self.state_size = ssm_states[0].shape[2]
|
|
252
|
-
self.conv_kernel = conv_states[0].shape[2]
|
|
253
|
-
self.num_hidden_layers = len(conv_states)
|
|
254
|
-
else:
|
|
255
|
-
self.intermediate_size = conv_states.shape[2]
|
|
256
|
-
self.state_size = ssm_states.shape[3]
|
|
257
|
-
self.conv_kernel = conv_states.shape[3]
|
|
258
|
-
self.num_hidden_layers = conv_states.shape[0]
|
|
259
|
-
|
|
260
|
-
cache = MambaCache(
|
|
261
|
-
_config(),
|
|
262
|
-
max_batch_size=1,
|
|
263
|
-
dtype=values[-1][0].dtype,
|
|
264
|
-
device="cpu" if values[-1][0].get_device() < 0 else "cuda",
|
|
265
|
-
)
|
|
266
|
-
values = dict(zip(context, values))
|
|
267
|
-
for k, v in values.items():
|
|
268
|
-
setattr(cache, k, v)
|
|
269
|
-
return cache
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
|
|
273
|
-
List[Tuple[torch.utils._pytree.KeyEntry, Any]],
|
|
274
|
-
torch.utils._pytree.Context,
|
|
275
|
-
]:
|
|
276
|
-
"""Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
|
|
277
|
-
values, context = flatten_mamba_cache(cache)
|
|
278
|
-
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
##############
|
|
282
|
-
# DynamicCache
|
|
283
|
-
##############
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
def flatten_dynamic_cache(
|
|
287
|
-
dynamic_cache: DynamicCache,
|
|
288
|
-
) -> Tuple[List[Any], torch.utils._pytree.Context]:
|
|
289
|
-
"""Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
|
|
290
|
-
if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
|
|
291
|
-
return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
|
|
292
|
-
flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
|
|
293
|
-
return [f[1] for f in flat], [f[0] for f in flat]
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
def flatten_with_keys_dynamic_cache(
|
|
297
|
-
dynamic_cache: DynamicCache,
|
|
298
|
-
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
|
|
299
|
-
"""Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
|
|
300
|
-
if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
|
|
301
|
-
return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
|
|
302
|
-
values, context = flatten_dynamic_cache(dynamic_cache)
|
|
303
|
-
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def unflatten_dynamic_cache(
|
|
307
|
-
values: List[Any], context: torch.utils._pytree.Context, output_type=None
|
|
308
|
-
) -> DynamicCache:
|
|
309
|
-
"""Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
|
|
310
|
-
if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
|
|
311
|
-
assert output_type is None, f"output_type={output_type} not supported"
|
|
312
|
-
return transformers.cache_utils._unflatten_dynamic_cache(values, context)
|
|
313
|
-
|
|
314
|
-
cache = transformers.cache_utils.DynamicCache()
|
|
315
|
-
values = dict(zip(context, values))
|
|
316
|
-
for k, v in values.items():
|
|
317
|
-
setattr(cache, k, v)
|
|
318
|
-
return cache
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
#############
|
|
322
|
-
# StaticCache
|
|
323
|
-
#############
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
def flatten_static_cache(
|
|
327
|
-
cache: StaticCache,
|
|
328
|
-
) -> Tuple[List[Any], torch.utils._pytree.Context]:
|
|
329
|
-
"""Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
|
|
330
|
-
flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
|
|
331
|
-
return [f[1] for f in flat], [f[0] for f in flat]
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
def flatten_with_keys_static_cache(
|
|
335
|
-
cache: StaticCache,
|
|
336
|
-
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
|
|
337
|
-
"""Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
|
|
338
|
-
values, context = flatten_static_cache(cache)
|
|
339
|
-
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
def unflatten_static_cache(
|
|
343
|
-
values: List[Any], context: torch.utils._pytree.Context, output_type=None
|
|
344
|
-
) -> StaticCache:
|
|
345
|
-
"""Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
|
|
346
|
-
return make_static_cache(list(zip(values[0], values[1])))
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
####################
|
|
350
|
-
# SlidingWindowCache
|
|
351
|
-
####################
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
def flatten_sliding_window_cache(
|
|
355
|
-
cache: SlidingWindowCache,
|
|
356
|
-
) -> Tuple[List[Any], torch.utils._pytree.Context]:
|
|
357
|
-
"""
|
|
358
|
-
Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
|
|
359
|
-
with python objects.
|
|
360
|
-
"""
|
|
361
|
-
flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
|
|
362
|
-
return [f[1] for f in flat], [f[0] for f in flat]
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
def flatten_with_keys_sliding_window_cache(
|
|
366
|
-
cache: SlidingWindowCache,
|
|
367
|
-
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
|
|
368
|
-
"""
|
|
369
|
-
Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
|
|
370
|
-
with python objects.
|
|
371
|
-
"""
|
|
372
|
-
values, context = flatten_sliding_window_cache(cache)
|
|
373
|
-
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
def unflatten_sliding_window_cache(
|
|
377
|
-
values: List[Any], context: torch.utils._pytree.Context, output_type=None
|
|
378
|
-
) -> SlidingWindowCache:
|
|
379
|
-
"""Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
|
|
380
|
-
key_cache, value_cache = values
|
|
381
|
-
|
|
382
|
-
class _config:
|
|
383
|
-
def __init__(self):
|
|
384
|
-
self.head_dim = key_cache[0].shape[-1]
|
|
385
|
-
self.num_attention_heads = key_cache[0].shape[1]
|
|
386
|
-
self.num_hidden_layers = len(key_cache)
|
|
387
|
-
self.sliding_window = key_cache[0].shape[2]
|
|
388
|
-
|
|
389
|
-
cache = SlidingWindowCache(
|
|
390
|
-
_config(),
|
|
391
|
-
max_batch_size=key_cache[0].shape[0],
|
|
392
|
-
max_cache_len=key_cache[0].shape[2], # sligding window
|
|
393
|
-
device=key_cache[0].device,
|
|
394
|
-
dtype=key_cache[0].dtype,
|
|
395
|
-
)
|
|
396
|
-
|
|
397
|
-
values = dict(zip(context, values))
|
|
398
|
-
for k, v in values.items():
|
|
399
|
-
setattr(cache, k, v)
|
|
400
|
-
return cache
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
#####################
|
|
404
|
-
# EncoderDecoderCache
|
|
405
|
-
#####################
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
def flatten_encoder_decoder_cache(
|
|
409
|
-
ec_cache: EncoderDecoderCache,
|
|
410
|
-
) -> Tuple[List[Any], torch.utils._pytree.Context]:
|
|
411
|
-
"""
|
|
412
|
-
Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
|
|
413
|
-
with python objects.
|
|
414
|
-
"""
|
|
415
|
-
dictionary = {
|
|
416
|
-
"self_attention_cache": ec_cache.self_attention_cache,
|
|
417
|
-
"cross_attention_cache": ec_cache.cross_attention_cache,
|
|
418
|
-
}
|
|
419
|
-
return torch.utils._pytree._dict_flatten(dictionary)
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
|
|
423
|
-
List[Tuple[torch.utils._pytree.KeyEntry, Any]],
|
|
424
|
-
torch.utils._pytree.Context,
|
|
425
|
-
]:
|
|
426
|
-
"""
|
|
427
|
-
Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
|
|
428
|
-
with python objects.
|
|
429
|
-
"""
|
|
430
|
-
dictionary = {
|
|
431
|
-
"self_attention_cache": ec_cache.self_attention_cache,
|
|
432
|
-
"cross_attention_cache": ec_cache.cross_attention_cache,
|
|
433
|
-
}
|
|
434
|
-
return torch.utils._pytree._dict_flatten_with_keys(dictionary)
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
def unflatten_encoder_decoder_cache(
|
|
438
|
-
values: List[Any], context: torch.utils._pytree.Context, output_type=None
|
|
439
|
-
) -> EncoderDecoderCache:
|
|
440
|
-
"""Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
|
|
441
|
-
dictionary = torch.utils._pytree._dict_unflatten(values, context)
|
|
442
|
-
return EncoderDecoderCache(**dictionary)
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
#################
|
|
446
|
-
# BaseModelOutput
|
|
447
|
-
#################
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
def flatten_base_model_output(
|
|
451
|
-
bo: BaseModelOutput,
|
|
452
|
-
) -> Tuple[List[Any], torch.utils._pytree.Context]:
|
|
453
|
-
"""
|
|
454
|
-
Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
|
|
455
|
-
with python objects.
|
|
456
|
-
"""
|
|
457
|
-
return list(bo.values()), list(bo.keys())
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
def flatten_with_keys_base_model_output(
|
|
461
|
-
bo: BaseModelOutput,
|
|
462
|
-
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
|
|
463
|
-
"""
|
|
464
|
-
Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
|
|
465
|
-
with python objects.
|
|
466
|
-
"""
|
|
467
|
-
values, context = flatten_base_model_output(bo)
|
|
468
|
-
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
def unflatten_base_model_output(
|
|
472
|
-
values: List[Any],
|
|
473
|
-
context: torch.utils._pytree.Context,
|
|
474
|
-
output_type=None,
|
|
475
|
-
) -> BaseModelOutput:
|
|
476
|
-
"""
|
|
477
|
-
Restores a :class:`transformers.modeling_outputs.BaseModelOutput`
|
|
478
|
-
from python objects.
|
|
479
|
-
"""
|
|
480
|
-
return BaseModelOutput(**dict(zip(context, values)))
|
|
277
|
+
unregister_class_serialization(cls, verbose)
|