onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +56 -3
- onnx_diagnostic/export/dynamic_shapes.py +24 -10
- onnx_diagnostic/export/shape_helper.py +6 -2
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +6 -6
- onnx_diagnostic/helpers/cache_helper.py +326 -18
- onnx_diagnostic/helpers/config_helper.py +10 -0
- onnx_diagnostic/helpers/helper.py +152 -11
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
- onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
- onnx_diagnostic/tasks/feature_extraction.py +7 -3
- onnx_diagnostic/tasks/fill_mask.py +6 -2
- onnx_diagnostic/tasks/image_classification.py +6 -2
- onnx_diagnostic/tasks/image_text_to_text.py +289 -62
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
- onnx_diagnostic/tasks/object_detection.py +6 -2
- onnx_diagnostic/tasks/sentence_similarity.py +6 -2
- onnx_diagnostic/tasks/summarization.py +7 -2
- onnx_diagnostic/tasks/text2text_generation.py +7 -2
- onnx_diagnostic/tasks/text_classification.py +6 -2
- onnx_diagnostic/tasks/text_generation.py +14 -16
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +5 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +428 -129
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +60 -41
- onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- onnx_diagnostic/torch_models/validate.py +1 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +43 -42
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED

onnx_diagnostic/_command_lines_parser.py
CHANGED
@@ -306,7 +306,7 @@ class _ParseDict(argparse.Action):
             value = split_items[1]
 
             if value in ("True", "true", "False", "false"):
-                d[key] =
+                d[key] = value in ("True", "true")
                 continue
             try:
                 d[key] = int(value)
@@ -323,6 +323,54 @@ class _ParseDict(argparse.Action):
         setattr(namespace, self.dest, d)
 
 
+class _BoolOrParseDictPatch(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+
+        if not values:
+            return
+        if len(values) == 1 and values[0] in (
+            "True",
+            "False",
+            "true",
+            "false",
+            "0",
+            "1",
+            0,
+            1,
+        ):
+            setattr(namespace, self.dest, values[0] in ("True", "true", 1, "1"))
+            return
+        d = getattr(namespace, self.dest) or {}
+        if not isinstance(d, dict):
+            d = {
+                "patch_sympy": d,
+                "patch_torch": d,
+                "patch_transformers": d,
+                "patch_diffusers": d,
+            }
+        for item in values:
+            split_items = item.split("=", 1)
+            key = split_items[0].strip()  # we remove blanks around keys, as is logical
+            value = split_items[1]
+
+            if value in ("True", "true", "False", "false"):
+                d[key] = value in ("True", "true")
+                continue
+            try:
+                d[key] = int(value)
+                continue
+            except (TypeError, ValueError):
+                pass
+            try:
+                d[key] = float(value)
+                continue
+            except (TypeError, ValueError):
+                pass
+            d[key] = _parse_json(value)
+
+        setattr(namespace, self.dest, d)
+
+
 def get_parser_validate() -> ArgumentParser:
     parser = ArgumentParser(
         prog="validate",
@@ -383,8 +431,13 @@ def get_parser_validate() -> ArgumentParser:
     parser.add_argument(
         "--patch",
         default=True,
-        action=
-
+        action=_BoolOrParseDictPatch,
+        nargs="*",
+        help="Applies patches before exporting, it can be a boolean "
+        "to enable to disable the patches or be more finetuned. It is possible to "
+        "disable patch for torch by adding "
+        '--patch "patch_sympy=False" --patch "patch_torch=False", '
+        "default is True.",
     )
     parser.add_argument(
         "--rewrite",
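The new _BoolOrParseDictPatch action lets --patch accept either a plain boolean or a list of key=value pairs. A minimal standalone sketch of that behaviour follows (illustrative only, not the package's code; json.loads stands in for the private _parse_json helper):

    import argparse
    import json

    class BoolOrDictAction(argparse.Action):
        # Simplified stand-in for _BoolOrParseDictPatch.
        def __call__(self, parser, namespace, values, option_string=None):
            if len(values) == 1 and values[0] in ("True", "true", "False", "false", "0", "1"):
                # A single boolean-like token toggles every patch at once.
                setattr(namespace, self.dest, values[0] in ("True", "true", "1"))
                return
            d = getattr(namespace, self.dest) or {}
            if not isinstance(d, dict):
                d = {}  # the real action expands a previous boolean into per-library keys
            for item in values:
                key, value = item.split("=", 1)
                if value in ("True", "true", "False", "false"):
                    d[key.strip()] = value in ("True", "true")
                else:
                    d[key.strip()] = json.loads(value)  # stand-in for _parse_json
            setattr(namespace, self.dest, d)

    parser = argparse.ArgumentParser(prog="validate")
    parser.add_argument("--patch", default=True, nargs="*", action=BoolOrDictAction)

    print(parser.parse_args(["--patch", "False"]).patch)              # False
    print(parser.parse_args(["--patch", "patch_torch=False"]).patch)  # {'patch_torch': False}
    print(parser.parse_args(
        ["--patch", "patch_sympy=False", "--patch", "patch_torch=False"]).patch)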
onnx_diagnostic/export/dynamic_shapes.py
CHANGED

@@ -887,19 +887,30 @@ class ModelInputs:
 
         # In case DynamicCache is not registered.
         if obj.__class__.__name__ == "DynamicCache":
-
-
-
-
-
-
-
-
+            if hasattr(obj, "layers"):
+                kc = set(len(o.layers) for o in objs)
+                assert (
+                    len(kc) == 1
+                ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                vc = kc.copy()
+            else:
+                kc = set(len(o.key_cache) for o in objs)
+                assert (
+                    len(kc) == 1
+                ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                vc = set(len(o.value_cache) for o in objs)
+                assert (
+                    len(vc) == 1
+                ), f"All attribute 'value_cache' should have the same length but found {vc}"
+
             key_cache = []
             for i in range(kc.pop()):
                 key_cache.append(
                     self.guess_dynamic_dimensions(
-                        *[
+                        *[
+                            o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
+                            for o in objs
+                        ],
                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
                     )
                 )
@@ -907,7 +918,10 @@ class ModelInputs:
             for i in range(vc.pop()):
                 value_cache.append(
                     self.guess_dynamic_dimensions(
-                        *[
+                        *[
+                            o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
+                            for o in objs
+                        ],
                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
                     )
                 )
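The layer-wise guesses above depend on which attribute the installed transformers release exposes: newer versions store per-layer objects with .keys/.values under cache.layers, older ones keep flat key_cache/value_cache lists. A short hedged probe (which attribute exists depends on the transformers version installed):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 8), torch.randn(2, 4, 3, 8))])

    # Pick whichever layout the installed transformers release provides.
    if hasattr(cache, "layers"):
        print(cache.layers[0].keys.shape, cache.layers[0].values.shape)
    else:
        print(cache.key_cache[0].shape, cache.value_cache[0].shape)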
onnx_diagnostic/export/shape_helper.py
CHANGED

@@ -9,6 +9,8 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
     All dimensions are considered as dynamic.
     ``dim_prefix`` can be a string (the function uses it as a prefix),
     or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
+    Depending on the version of transformers, serializations function
+    of DynamicCache class is automatically serialized or not (>= 4.51, < 4.55).
 
     .. runpython::
         :showcode:
@@ -17,6 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
         import torch
         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
         from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
 
         bsize, nheads, slen, dim = 2, 1, 30, 96
         inputs = dict(
@@ -25,10 +28,11 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
             position_ids=torch.arange(3, dtype=torch.int64),
             past_key_values=make_dynamic_cache(
                 [(torch.randn(bsize, nheads, slen, dim),
-
+                  torch.randn(bsize, nheads, slen, dim))]
             ),
         )
-
+        with torch_export_patches(patch_transformers=True):
+            ds = all_dynamic_shape_from_inputs(inputs)
         pprint.pprint(ds)
 
     For this function to work, patches must be enabled if :epkg:`transformers`
onnx_diagnostic/ext_test_case.py
CHANGED
@@ -1058,6 +1058,8 @@ class ExtTestCase(unittest.TestCase):
         elif hasattr(expected, "shape"):
             self.assertEqual(type(expected), type(value), msg=msg)
             self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
+        elif expected is None:
+            assert value is None, f"Expected is None but value is of type {type(value)}"
         else:
             raise AssertionError(
                 f"Comparison not implemented for types {type(expected)} and {type(value)}"
onnx_diagnostic/helpers/_log_helper.py
CHANGED

@@ -33,13 +33,13 @@ def mann_kendall(series: Sequence[float], threshold: float = 0.5):
     .. math::
 
         sign(x) = \\left\\{ \\begin{array}{l} -1 if x < 0 \\\\ 0 if x = 0 \\\\ +1 otherwise
-        \\right.
+        \\end{array} \\right.
 
     And:
 
     .. math::
 
-        Var(S)= \\frac{n(n-1)(2n+5
+        Var(S)= \\frac{n(n-1)(2n+5) - \\sum_t t(t-1)(2t+5)}{18}
     """
     aseries = np.asarray(series)
     stat = 0
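The completed formula is the usual Mann-Kendall variance with a tie correction; with no ties the sum vanishes and Var(S) = n(n-1)(2n+5)/18. A quick hedged arithmetic check (not part of the package):

    def var_s(n, ties=()):
        # Mann-Kendall variance of S; ties lists the size of each group of tied values.
        return (n * (n - 1) * (2 * n + 5) - sum(t * (t - 1) * (2 * t + 5) for t in ties)) / 18

    print(var_s(10))             # 125.0 when there are no ties
    print(var_s(10, ties=[2]))   # 124.0, slightly smaller with one tied pair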
@@ -251,7 +251,7 @@ def open_dataframe(
 ) -> pandas.DataFrame:
     """
     Opens a filename defined by function
-    :func:`onnx_diagnostic.helpers.
+    :func:`onnx_diagnostic.helpers._log_helper.enumerate_csv_files`.
 
     :param data: a dataframe, a filename, a tuple indicating the file is coming
         from a zip file
@@ -260,17 +260,17 @@ def open_dataframe(
     if isinstance(data, pandas.DataFrame):
         return data
     if isinstance(data, str):
-        df = pandas.read_csv(data)
+        df = pandas.read_csv(data, low_memory=False)
         df["RAWFILENAME"] = data
         return df
     if isinstance(data, tuple):
         if not data[-1]:
-            df = pandas.read_csv(data[2])
+            df = pandas.read_csv(data[2], low_memory=False)
            df["RAWFILENAME"] = data[2]
            return df
        zf = zipfile.ZipFile(data[-1])
        with zf.open(data[2]) as f:
-            df = pandas.read_csv(f)
+            df = pandas.read_csv(f, low_memory=False)
             df["RAWFILENAME"] = f"{data[-1]}/{data[2]}"
         zf.close()
         return df
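low_memory=False makes pandas read each CSV in a single pass, so column dtypes are inferred from the whole file rather than chunk by chunk, which avoids DtypeWarning on large benchmark logs with mixed columns. A hedged sketch of the behaviour the change targets (the file name is illustrative):

    import pandas

    # One-pass read: consistent dtypes, no mixed-type chunk inference.
    df = pandas.read_csv("benchmark_results.csv", low_memory=False)  # hypothetical file
    print(df.dtypes)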
onnx_diagnostic/helpers/cache_helper.py
CHANGED

@@ -4,6 +4,56 @@ import torch
 import transformers
 import transformers.cache_utils
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
+
+class CacheKeyValue:
+    """
+    Starting transformers>=4.54, the cache API has deprecated
+    ``cache.key_cache`` and ``cache.value_cache``.
+    This class wraps a cache independently from transformers version and enables
+    attributes ``key_cache`` and ``value_cache``.
+
+    .. code-block:: python
+
+        capi = CacheKeyValue(cache)
+        capi.key_cache
+        capi.value_cache
+    """
+
+    def __init__(self, cache=None):
+        if hasattr(cache, "layers"):
+            layers = [
+                layer
+                for layer in cache.layers
+                if layer is not None and layer.keys is not None and layer.values is not None
+            ]
+            self.key_cache = [layer.keys for layer in layers]
+            self.value_cache = [layer.values for layer in layers]
+            if None in self.key_cache or None in self.value_cache:
+                from .helper import string_type
+
+                raise AssertionError(
+                    f"issue with key_cache={string_type(self.key_cache)}, "
+                    f"or value_cache={string_type(self.value_cache)}, "
+                    f"cache.layers={string_type(cache.layers)}"
+                )
+        elif cache is not None and hasattr(cache, "key_cache"):
+            self.key_cache = cache.key_cache
+            self.value_cache = cache.value_cache
+        elif cache is None:
+            self.key_cache = None
+            self.value_cache = None
+        else:
+            raise NotImplementedError(f"type(cache)={type(cache)}")
+
+    def make_dynamic_cache(self):
+        """Do the reverse operation."""
+        return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+
 
 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
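A hedged usage sketch of the new wrapper: CacheKeyValue gives uniform key_cache/value_cache access whether the installed transformers stores tensors in cache.layers or in flat lists, and its make_dynamic_cache method reverses the wrapping:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    )

    capi = CacheKeyValue(cache)
    print(len(capi.key_cache), capi.key_cache[0].shape)  # 1 torch.Size([2, 4, 3, 7])
    rebuilt = capi.make_dynamic_cache()                  # reverse operation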
@@ -81,6 +131,8 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
     )
     values, spec = torch.utils._pytree.tree_flatten(cache)
     cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+    if hasattr(cache2, "layers") and hasattr(cache, "layers"):
+        return len(cache2.layers) == len(cache.layers)
     return len(cache2.key_cache) == len(cache.value_cache)
 
 
@@ -119,7 +171,19 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
         )
         print(string_type(past_key_values, with_shape=True))
         """
-
+        cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return finalize_cache(cache)
 
 else:
 
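The trimming above exists because recent DynamicCache constructors pre-allocate empty layers; make_dynamic_cache removes them so the cache holds exactly the pairs it was given. A hedged check of that invariant:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    pairs = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(3)]
    cache = make_dynamic_cache(pairs)

    # Whatever the transformers version, the cache should expose exactly len(pairs) layers.
    assert len(CacheKeyValue(cache).key_cache) == len(pairs)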
@@ -203,6 +267,9 @@ def make_static_cache(
             self.num_attention_heads = key_value_pairs[0][0].shape[1]
             self.num_hidden_layers = len(key_value_pairs)
 
+        def get_text_config(self):
+            return self
+
     assert max_cache_len is not None, (
         f"max_cache_len={max_cache_len} cannot be setup "
         f"automatically yet from shape {key_value_pairs[0][0].shape}"
@@ -216,20 +283,59 @@ def make_static_cache(
         ),
     )
     cache = transformers.cache_utils.StaticCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
         max_cache_len=max_cache_len,
     )
+    ca = CacheKeyValue(cache)
+    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+        # transformers>= 4.55.2, layers are empty
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(key, value, i)
+        return cache
+
+    torch._check(
+        not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(cache.layers)={len(cache.layers)}"
+        ),
+    )
+    torch._check(
+        len(key_value_pairs) == len(ca.key_cache),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(ca.key_cache)={len(ca.key_cache)}"
+        ),
+    )
+    torch._check(
+        len(key_value_pairs) == len(ca.value_cache),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(ca.value_cache)={len(ca.value_cache)}"
+        ),
+    )
     for i in range(len(key_value_pairs)):
         assert (
             key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
         ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
         d = key_value_pairs[i][1].shape[2]
-
-
-
+        ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+        ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
+    return finalize_cache(cache)
 
 
 def make_encoder_decoder_cache(
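A hedged usage sketch for make_static_cache based on the signature visible in the hunk (key/value pairs plus an explicit max_cache_len); string_type is the helper already used by the examples in this file:

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    bsize, nheads, slen, dim = 2, 4, 3, 7
    pairs = [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for _ in range(2)
    ]

    # max_cache_len must cover the sequence length of the provided pairs.
    cache = make_static_cache(pairs, max_cache_len=slen)
    print(string_type(cache, with_shape=True))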
@@ -238,14 +344,15 @@ def make_encoder_decoder_cache(
 ) -> transformers.cache_utils.EncoderDecoderCache:
     """Creates an EncoderDecoderCache."""
     return transformers.cache_utils.EncoderDecoderCache(
-        self_attention_cache=self_attention_cache,
+        # self_attention_cache=self_attention_cache,
+        # cross_attention_cache=cross_attention_cache
+        self_attention_cache,
+        cross_attention_cache,
     )
 
 
-def make_mamba_cache(
-
-) -> transformers.cache_utils.MambaCache:
-    "Creates a :class:`transformers.cache_utils.MambaCache`."
+def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+    "Creates a ``MambaCache``."
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
@@ -256,7 +363,10 @@ def make_mamba_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype
 
-
+        def get_text_config(self):
+            return self
+
+    cache = MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
@@ -281,12 +391,12 @@ def make_mamba_cache(
             f"got {key_value_pairs[i][1].shape}"
         )
         cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
-    return cache
+    return finalize_cache(cache)
 
 
 def make_sliding_window_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.
+) -> transformers.cache_utils.SlidingWindowCache:
     "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
 
     class _config:
@@ -296,22 +406,220 @@ def make_sliding_window_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.sliding_window = key_value_pairs[0][0].shape[2]
 
+        def get_text_config(self):
+            return self
+
     cache = transformers.cache_utils.SlidingWindowCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
     )
+    ca = CacheKeyValue(cache)
+    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+        # transformers>= 4.55.2, layers are empty
+        cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+        return cache
+
     for i in range(len(key_value_pairs)):
-        assert
+        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.key_cache[i].shape}, "
             f"got {key_value_pairs[i][0].shape}"
         )
-
-        assert
+        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
             f"Shape mismatch, expected {cache.value_cache[i].shape}, "
             f"got {key_value_pairs[i][1].shape}"
         )
-
+        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
+    return finalize_cache(cache)
+
+
+def make_hybrid_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    max_cache_len: Optional[int] = None,
+    max_batch_size: Optional[int] = None,
+    sliding_window: Optional[int] = None,
+) -> transformers.cache_utils.HybridCache:
+    """
+    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+    This version is valid for ``transformers < 4.50``.
+
+    :param key_value_pairs: list of pairs of (key, values)
+    :return: :class:`transformers.cache_utils.HybridCache`
+
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_hybrid_cache(
+            [
+                (
+                    torch.randn(bsize, nheads, slen, dim),
+                    torch.randn(bsize, nheads, slen, dim),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        print(string_type(past_key_values, with_shape=True))
+
+    This part defines how the shapes are working in one HybridCache.
+
+    .. code-block:: python
+
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size
+
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
+
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
+    """
+    layer_types = None
+    if key_value_pairs:
+        assert (
+            not max_batch_size and not max_cache_len
+        ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+        max_batch_size = key_value_pairs[0][0].shape[0]
+        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+        if len(sets_of_dim) == 1:
+            max_cache_len = sets_of_dim.pop()
+            sliding_window = max_cache_len
+        else:
+            assert (
+                len(sets_of_dim) == 2
+            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+            max_cache_len = max(sets_of_dim)
+            sliding_window = min(sets_of_dim)
+            layer_types = [
+                "full_attention" if i == max_cache_len else "sliding_attention"
+                for i in [kv[0].shape[2] for kv in key_value_pairs]
+            ]
+    else:
+        assert (
+            max_batch_size and max_cache_len
+        ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+        if sliding_window is None:
+            sliding_window = max_cache_len
+    _max_cache_len = max_cache_len
+    _sliding_window = sliding_window
+
+    class _config:
+        max_cache_len = _max_cache_len
+        batch_size = max_batch_size
+        num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+        head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+        num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+        num_hidden_layers = len(key_value_pairs)
+        sliding_window = _sliding_window
+        num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
+
+        def get_text_config(self):
+            return self
+
+    if layer_types:
+        _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+    cache = transformers.cache_utils.HybridCache(
+        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
+    )
+    for i, (key, value) in enumerate(key_value_pairs):
+        cache.update(
+            key,
+            value,
+            i,
+            cache_kwargs={
+                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                    key.device
+                )
+            },
+        )
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
+    return finalize_cache(cache)
+
+
+def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
+    """
+    Ensures the created cache is consistent.
+    Returns the cache modified inplace.
+    """
+    if (
+        hasattr(cache, "layer_class_to_replicate")
+        and hasattr(cache, "layers")
+        and cache.layers
+        and not cache.layer_class_to_replicate
+    ):
+        # This is used to expand the cache when it does not contains enough layers.
+        # This is needed since transformers>4.55.3
+        cache.layer_class_to_replicate = cache.layers[0].__class__
     return cache
onnx_diagnostic/helpers/config_helper.py
CHANGED

@@ -1,6 +1,7 @@
 import functools
 import importlib
 import inspect
+import os
 import re
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import transformers
@@ -110,3 +111,12 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
     )
     cls_name = unique.pop()
     return getattr(transformers, cls_name)
+
+
+def default_num_hidden_layers():
+    """
+    Returns the default number of layers.
+    It is lower when the unit tests are running
+    when ``UNITTEST_GOING=1``.
+    """
+    return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4