onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,687 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import packaging.version as pv
import torch
import transformers
import transformers.cache_utils


class CacheKeyValue:
    """
    Starting transformers>=4.54, the cache API has deprecated
    ``cache.key_cache`` and ``cache.value_cache``.
    This class wraps a cache independently of the transformers version and exposes
    attributes ``key_cache`` and ``value_cache``.

    .. code-block:: python

        capi = CacheKeyValue(cache)
        capi.key_cache
        capi.value_cache
    """

    def __init__(self, cache=None):
        if hasattr(cache, "layers"):
            layers = [
                layer
                for layer in cache.layers
                if layer is not None and layer.keys is not None and layer.values is not None
            ]
            self.key_cache = [layer.keys for layer in layers]
            self.value_cache = [layer.values for layer in layers]
            if None in self.key_cache or None in self.value_cache:
                from .helper import string_type

                raise AssertionError(
                    f"issue with key_cache={string_type(self.key_cache)}, "
                    f"or value_cache={string_type(self.value_cache)}, "
                    f"cache.layers={string_type(cache.layers)}"
                )
        elif cache is not None and hasattr(cache, "key_cache"):
            self.key_cache = cache.key_cache
            self.value_cache = cache.value_cache
        elif cache is None:
            self.key_cache = None
            self.value_cache = None
        else:
            raise NotImplementedError(f"type(cache)={type(cache)}")

    def make_dynamic_cache(self):
        """Does the reverse operation: rebuilds a ``DynamicCache`` from the wrapped tensors."""
        return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))

    @property
    def n_layers(self) -> int:
        """Returns the number of layers."""
        return len(self.key_cache) if self.key_cache else 0

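# --- Illustrative usage (reviewer sketch, not part of the packaged file) ---
# CacheKeyValue gives uniform access to the key/value tensors regardless of the
# transformers version (``cache.layers`` vs. the deprecated ``cache.key_cache``).
# The shapes below are arbitrary assumptions chosen for the example.
#
#     cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
#     capi = CacheKeyValue(cache)
#     assert capi.n_layers == 1
#     assert capi.key_cache[0].shape == (2, 4, 3, 7)
#     restored = capi.make_dynamic_cache()  # rebuilds a DynamicCache from the tensors
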
def flatten_unflatten_for_dynamic_shapes(
    obj: Any,
    use_dict: bool = False,
    change_function: Optional[Callable[[torch.Tensor], Any]] = None,
) -> Any:
    """
    Returns the object in a different structure similar to what
    the definition of the dynamic shapes should use.

    :param obj: object from a custom class
    :param use_dict: closer to the original result, but
        :func:`torch.export.export` only considers the values;
        the context gives the dictionary keys but it is not expressed
        in the dynamic shapes, and these specifications seem to differ
        between strict and non-strict mode. It also preserves tuples.
    :param change_function: modifies the tensors in the structure itself,
        for example to replace them by their shape
    :return: the serialized object
    """
    if isinstance(obj, torch.Tensor):
        return change_function(obj) if change_function else obj
    flat, spec = torch.utils._pytree.tree_flatten(obj)
    start = 0
    end = 0
    subtrees = []
    for subspec in spec.children_specs:
        end += subspec.num_leaves
        value = subspec.unflatten(flat[start:end])
        value = flatten_unflatten_for_dynamic_shapes(
            value, use_dict=use_dict, change_function=change_function
        )
        subtrees.append(value)
        start = end
    if use_dict:
        if spec.type is dict:
            # This is a dictionary.
            return dict(zip(spec.context, subtrees))
        if spec.type is tuple:
            return tuple(subtrees)
        if spec.type is list:
            return list(subtrees)
        if spec.type is None and not subtrees:
            return None
        if spec.context:
            # This is a custom class with attributes.
            # It is returned as a list.
            return list(subtrees)
        raise ValueError(
            f"Unable to interpret spec type {spec.type} "
            f"(type is {type(spec.type)}, context is {spec.context}), "
            f"spec={spec}, subtrees={subtrees}"
        )
    # This is a list.
    return subtrees

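# --- Illustrative usage (reviewer sketch, not part of the packaged file) ---
# flatten_unflatten_for_dynamic_shapes restructures a nested object the same way
# torch.export sees it, which helps when writing the matching ``dynamic_shapes``.
# Passing ``change_function`` replaces every tensor, e.g. by its shape:
#
#     inputs = {"input_ids": torch.randint(0, 10, (2, 3)), "past": [torch.randn(2, 4, 3, 7)]}
#     layout = flatten_unflatten_for_dynamic_shapes(
#         inputs, use_dict=True, change_function=lambda t: tuple(t.shape)
#     )
#     # -> {"input_ids": (2, 3), "past": [(2, 4, 3, 7)]}
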
def is_cache_dynamic_registered(fast: bool = False) -> bool:
    """
    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
    serialized and deserialized. Only then can :func:`torch.export.export`
    export a model.

    :param fast: if True, do not check that the serialization round-trips as well
    :return: result
    """
    if fast:
        return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES
    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            )
            for i in range(2)
        ]
    )
    values, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    if hasattr(cache2, "layers") and hasattr(cache, "layers"):
        return len(cache2.layers) == len(cache.layers)
    return len(cache2.key_cache) == len(cache.value_cache)

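# --- Illustrative usage (reviewer sketch, not part of the packaged file) ---
# A typical check before exporting a model that takes a DynamicCache as input.
# The registration itself is presumably handled by onnx_diagnostic.torch_export_patches
# (see the serialization modules in the file listing above); this is an assumption.
#
#     if not is_cache_dynamic_registered(fast=True):
#         # register DynamicCache with torch.utils._pytree before calling
#         # torch.export.export
#         ...
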
def make_dynamic_shapes_kv_cache(
    cache: transformers.cache_utils.Cache, shape_of_one: Dict[int, Any]
) -> List[Dict[int, Any]]:
    """
    Returns the dynamic shapes for key-value cache

    :param cache: a cache
    :param shape_of_one: shape of one element
    :return: dynamic shapes
    """
    return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)]

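# --- Illustrative usage (reviewer sketch, not part of the packaged file) ---
# One shape specification is repeated for every key and every value tensor of the
# cache (hence ``n_layers * 2`` entries). The dimension names below are arbitrary
# assumptions chosen for the example.
#
#     cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))] * 2)
#     ds = make_dynamic_shapes_kv_cache(cache, {0: "batch", 2: "cache_length"})
#     # -> 4 entries (2 layers x key/value), each {0: "batch", 2: "cache_length"}
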
def _preprocess_key_value_pairs(
    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    if not key_value_pairs or isinstance(key_value_pairs[0], tuple):
        return key_value_pairs
    return list(zip(key_value_pairs[::2], key_value_pairs[1::2]))


if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

    def make_dynamic_cache(
        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers >= 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))

        The function is fully able to handle ``FakeTensor`` with dynamic dimensions if
        ``transformers>=4.56``. Before that version, only FakeTensor with static dimensions
        are supported.
        """
        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
        if (
            key_value_pairs
            and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
            and pv.Version(transformers.__version__) >= pv.Version("4.56")
        ):
            cache = transformers.cache_utils.DynamicCache()
            cache.layers.extend(
                [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
            )
            for i, layer in enumerate(cache.layers):
                k, v = key_value_pairs[i][0], key_value_pairs[i][1]
                layer.dtype = k.dtype
                layer.device = k.device
                layer.keys = k
                layer.values = v
                layer.is_initialized = True
            assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
                f"Unexpected number of layers in the cache ({len(cache.layers)}), "
                f"{len(key_value_pairs)} expected."
            )
            return finalize_cache(cache)

        cache = transformers.cache_utils.DynamicCache(key_value_pairs)
        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
            # The cache constructor contains the two following lines
            # (in cache_utils.py) which append empty layers when the cache is
            # initialized. We need to remove them.
            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
            # self.append_new_layers(self.num_hidden_layers - 1)
            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
            f"{len(key_value_pairs)} expected."
        )
        return finalize_cache(cache)

else:

    def make_dynamic_cache(
        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers < 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
        for i, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, i)
        return cache

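# --- Illustrative note (reviewer sketch, not part of the packaged file) ---
# Thanks to _preprocess_key_value_pairs, both call forms below build the same cache:
# a list of (key, value) tuples, or a flat list alternating keys and values.
#
#     k0, v0 = torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)
#     cache_a = make_dynamic_cache([(k0, v0)])
#     cache_b = make_dynamic_cache([k0, v0])  # flat form, regrouped into pairs
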
def make_static_cache(
    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
    max_cache_len: Optional[int] = None,
) -> transformers.cache_utils.DynamicCache:
    """
    Creates an instance of :class:`transformers.cache_utils.StaticCache`.

    :param key_value_pairs: list of pairs of (key, values)
    :param max_cache_len: maximum cache length, or a value inferred from the tensors
    :return: :class:`transformers.cache_utils.StaticCache`

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.cache_helper import make_static_cache

        n_layers = 2
        bsize, nheads, slen, dim = 2, 4, 3, 7

        past_key_values = make_static_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                )
                for i in range(n_layers)
            ],
            max_cache_len=10,
        )
        print(string_type(past_key_values, with_shape=True))
    """
    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

    class _config:
        def __init__(self):
            self.head_dim = key_value_pairs[0][0].shape[-1]
            self.num_attention_heads = key_value_pairs[0][0].shape[1]
            self.num_hidden_layers = len(key_value_pairs)

        def get_text_config(self, *args, **kwargs):
            return self

    assert max_cache_len is not None, (
        f"max_cache_len={max_cache_len} cannot be setup "
        f"automatically yet from shape {key_value_pairs[0][0].shape}"
    )
    torch._check(
        max_cache_len >= key_value_pairs[0][0].shape[2],
        (
            f"max_cache_len={max_cache_len} cannot be smaller "
            f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
            f"{key_value_pairs[0][0].shape}"
        ),
    )
    cache = transformers.cache_utils.StaticCache(
        config=_config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        device=key_value_pairs[0][0].device,
        dtype=key_value_pairs[0][0].dtype,
        max_cache_len=max_cache_len,
    )
    ca = CacheKeyValue(cache)
    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
        # transformers>=4.55.2, layers are empty
        for i, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, i)
        return cache

    torch._check(
        not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers),
        lambda: (
            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
            f"len(cache.layers)={len(cache.layers)}"
        ),
    )
    torch._check(
        len(key_value_pairs) == len(ca.key_cache),
        lambda: (
            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
            f"len(ca.key_cache)={len(ca.key_cache)}"
        ),
    )
    torch._check(
        len(key_value_pairs) == len(ca.value_cache),
        lambda: (
            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
            f"len(ca.value_cache)={len(ca.value_cache)}"
        ),
    )
    for i in range(len(key_value_pairs)):
        assert (
            key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
        ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
        d = key_value_pairs[i][1].shape[2]
        ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
        ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
        # The cache constructor contains the two following lines
        # (in cache_utils.py) which append empty layers when the cache is
        # initialized. We need to remove them.
        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
        # self.append_new_layers(self.num_hidden_layers - 1)
        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
        f"{len(key_value_pairs)} expected."
    )
    return finalize_cache(cache)

def make_encoder_decoder_cache(
    self_attention_cache: transformers.cache_utils.DynamicCache,
    cross_attention_cache: transformers.cache_utils.DynamicCache,
) -> transformers.cache_utils.EncoderDecoderCache:
    """Creates an EncoderDecoderCache."""
    return transformers.cache_utils.EncoderDecoderCache(
        # self_attention_cache=self_attention_cache,
        # cross_attention_cache=cross_attention_cache
        self_attention_cache,
        cross_attention_cache,
    )

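# --- Illustrative usage (reviewer sketch, not part of the packaged file) ---
# An EncoderDecoderCache combines the decoder self-attention cache with the
# cross-attention cache; both parts can be built with make_dynamic_cache first.
# The shapes are arbitrary assumptions chosen for the example.
#
#     self_attn = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
#     cross_attn = make_dynamic_cache([(torch.randn(2, 4, 6, 7), torch.randn(2, 4, 6, 7))])
#     past_key_values = make_encoder_decoder_cache(self_attn, cross_attn)
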
def make_mamba_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> "MambaCache":  # noqa: F821
    "Creates a ``MambaCache``."
    # import is moved here because this part is slow.
    try:
        from transformers.models.mamba.modeling_mamba import MambaCache
    except ImportError:
        from transformers.cache_utils import MambaCache
    dtype = key_value_pairs[0][0].dtype

    class _config:
        def __init__(self):
            self.intermediate_size = key_value_pairs[0][0].shape[1]
            self.conv_kernel = key_value_pairs[0][0].shape[-1]
            self.state_size = key_value_pairs[0][1].shape[-1]
            self.num_hidden_layers = len(key_value_pairs)
            self.dtype = dtype

        def get_text_config(self, *args, **kwargs):
            return self

    cache = MambaCache(
        _config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        device=key_value_pairs[0][0].device,
        dtype=dtype,
    )
    for i in range(len(key_value_pairs)):
        assert cache.conv_states[i].dtype == dtype, (
            f"Type mismatch for cache.conv_states[{i}].dtype="
            f"{cache.conv_states[i].dtype} != {dtype}"
        )
        assert cache.ssm_states[i].dtype == dtype, (
            f"Type mismatch for cache.ssm_states[{i}].dtype="
            f"{cache.ssm_states[i].dtype} != {dtype}"
        )
        assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {cache.conv_states[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
        assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {cache.ssm_states[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
    return finalize_cache(cache)


def make_sliding_window_cache(
    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
) -> transformers.cache_utils.SlidingWindowCache:
    "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

    class _config:
        def __init__(self):
            self.head_dim = key_value_pairs[0][0].shape[-1]
            self.num_attention_heads = key_value_pairs[0][0].shape[1]
            self.num_hidden_layers = len(key_value_pairs)
            self.sliding_window = key_value_pairs[0][0].shape[2]

        def get_text_config(self, *args, **kwargs):
            return self

    cache = transformers.cache_utils.SlidingWindowCache(
        config=_config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
        device=key_value_pairs[0][0].device,
        dtype=key_value_pairs[0][0].dtype,
    )
    ca = CacheKeyValue(cache)
    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
        # transformers>=4.55.2, layers are empty
        cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
        for i, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
        return cache

    for i in range(len(key_value_pairs)):
        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
        # The cache constructor contains the two following lines
        # (in cache_utils.py) which append empty layers when the cache is
        # initialized. We need to remove them.
        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
        # self.append_new_layers(self.num_hidden_layers - 1)
        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
        f"{len(key_value_pairs)} expected."
    )
    return finalize_cache(cache)


def make_hybrid_cache(
    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
    max_cache_len: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    sliding_window: Optional[int] = None,
) -> transformers.cache_utils.HybridCache:
    """
    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
    This version is valid for ``transformers < 4.50``.

    :param key_value_pairs: list of pairs of (key, values)
    :return: :class:`transformers.cache_utils.HybridCache`

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

        n_layers = 2
        bsize, nheads, slen, dim = 2, 4, 3, 7

        past_key_values = make_hybrid_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                )
                for i in range(n_layers)
            ]
        )
        print(string_type(past_key_values, with_shape=True))

    This part defines how the shapes work in one HybridCache.

    .. code-block:: python

        self.max_cache_len = (
            max_cache_len if max_cache_len is not None else config.max_position_embeddings)

        # Sliding layers can't be larger than the overall max cache len
        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
        self.max_batch_size = max_batch_size

        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim")
            else config.hidden_size // config.num_attention_heads
        )

        self._dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
            else config.num_key_value_heads
        )

        # If the attribute does not exist in the config, fallback to a simple StaticCache
        if hasattr(config, "layer_types"):
            self.is_sliding = [
                layer_type != "full_attention" for layer_type in config.layer_types]
        else:
            self.is_sliding = [False] * config.num_hidden_layers

        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
                              self.max_cache_len, self.head_dim)
        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
                               self.sliding_window_len, self.head_dim)
        self.sliding_window = min(config.sliding_window, max_cache_len)
        device = torch.device(device) if device is not None else None
        for i in range(config.num_hidden_layers):
            layer_device = layer_device_map[i] if layer_device_map is not None else device
            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=layer_device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
    """
    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
    layer_types = None
    if key_value_pairs:
        assert (
            not max_batch_size and not max_cache_len
        ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
        max_batch_size = key_value_pairs[0][0].shape[0]
        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
        if len(sets_of_dim) == 1:
            max_cache_len = sets_of_dim.pop()
            sliding_window = max_cache_len
        else:
            assert (
                len(sets_of_dim) == 2
            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
            max_cache_len = max(sets_of_dim)
            sliding_window = min(sets_of_dim)
            layer_types = [
                "full_attention" if i == max_cache_len else "sliding_attention"
                for i in [kv[0].shape[2] for kv in key_value_pairs]
            ]
    else:
        assert (
            max_batch_size and max_cache_len
        ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
        if sliding_window is None:
            sliding_window = max_cache_len
    _max_cache_len = max_cache_len
    _sliding_window = sliding_window

    class _config:
        max_cache_len = _max_cache_len
        batch_size = max_batch_size
        num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
        head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
        num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
        num_hidden_layers = len(key_value_pairs)
        sliding_window = _sliding_window
        num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3

        def get_text_config(self, *args, **kwargs):
            return self

    if layer_types:
        _config.layer_types = layer_types  # type: ignore[attr-defined]

    cache = transformers.cache_utils.HybridCache(
        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
    )
    for i, (key, value) in enumerate(key_value_pairs):
        cache.update(
            key,
            value,
            i,
            cache_kwargs={
                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
                    key.device
                )
            },
        )
    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
        # The cache constructor contains the two following lines
        # (in cache_utils.py) which append empty layers when the cache is
        # initialized. We need to remove them.
        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
        # self.append_new_layers(self.num_hidden_layers - 1)
        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
        f"{len(key_value_pairs)} expected."
    )
    return finalize_cache(cache)


def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
    """
    Ensures the created cache is consistent.
    Returns the cache modified in place.
    """
    if (
        hasattr(cache, "layer_class_to_replicate")
        and hasattr(cache, "layers")
        and cache.layers
        and not cache.layer_class_to_replicate
    ):
        # This is used to expand the cache when it does not contain enough layers.
        # This is needed since transformers>4.55.3
        cache.layer_class_to_replicate = cache.layers[0].__class__
    return cache
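
# --- Illustrative usage (reviewer sketch, not part of the packaged file) ---
# With two distinct sequence lengths in key_value_pairs, make_hybrid_cache infers the
# longer one as max_cache_len (full attention) and the shorter one as the sliding
# window; finalize_cache then sets layer_class_to_replicate so that newer transformers
# versions can grow the cache when a model requests more layers.
#
#     full = (torch.randn(2, 4, 8, 7), torch.randn(2, 4, 8, 7))      # full_attention layer
#     sliding = (torch.randn(2, 4, 4, 7), torch.randn(2, 4, 4, 7))   # sliding_attention layer
#     past_key_values = make_hybrid_cache([full, sliding])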