onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +66 -8
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +461 -0
- onnx_diagnostic/helpers/cache_helper.py +250 -15
- onnx_diagnostic/helpers/helper.py +146 -10
- onnx_diagnostic/helpers/log_helper.py +404 -315
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/feature_extraction.py +86 -5
- onnx_diagnostic/tasks/image_text_to_text.py +260 -56
- onnx_diagnostic/tasks/mask_generation.py +139 -0
- onnx_diagnostic/tasks/text2text_generation.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
- onnx_diagnostic/torch_models/validate.py +26 -3
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/cache_helper.py

```diff
@@ -4,6 +4,51 @@ import torch
 import transformers
 import transformers.cache_utils
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
+
+class CacheKeyValue:
+    """
+    Starting transformers>=4.54, the cache API has deprecated
+    ``cache.key_cache`` and ``cache.value_cache``.
+    This class wraps a cache independently from transformers version and enables
+    attributes ``key_cache`` and ``value_cache``.
+
+    .. code-block:: python
+
+        capi = CacheKeyValue(cache)
+        capi.key_cache
+        capi.value_cache
+    """
+
+    def __init__(self, cache=None):
+        if hasattr(cache, "layers"):
+            layers = [
+                layer
+                for layer in cache.layers
+                if layer is not None and layer.keys is not None and layer.values is not None
+            ]
+            self.key_cache = [layer.keys for layer in layers]
+            self.value_cache = [layer.values for layer in layers]
+            if None in self.key_cache or None in self.value_cache:
+                from .helper import string_type
+
+                raise AssertionError(
+                    f"issue with key_cache={string_type(self.key_cache)}, "
+                    f"or value_cache={string_type(self.value_cache)}, "
+                    f"cache.layers={string_type(cache.layers)}"
+                )
+        elif cache is not None:
+            self.key_cache = cache.key_cache
+            self.value_cache = cache.value_cache
+
+    def make_dynamic_cache(self):
+        """Do the reverse operation."""
+        return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+
 
 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
```
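The new `CacheKeyValue` wrapper gives a uniform `key_cache`/`value_cache` view of a cache regardless of the installed transformers version. A minimal usage sketch (assuming onnx-diagnostic 0.7.6 and a transformers version supported by the package; the tensor shapes are arbitrary):

```python
import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

# Build a DynamicCache from two layers of (key, value) tensors.
cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)

# Wrap it: key_cache/value_cache stay available even when transformers>=4.54
# stores the tensors in cache.layers instead of the deprecated attributes.
capi = CacheKeyValue(cache)
print(len(capi.key_cache), capi.key_cache[0].shape)

# Reverse operation documented on the class: rebuild a DynamicCache.
rebuilt = capi.make_dynamic_cache()
```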
```diff
@@ -119,7 +164,19 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             )
             print(string_type(past_key_values, with_shape=True))
         """
-
+        cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return cache
 
 else:
 
```
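The hunk above trims the empty layers that recent `DynamicCache` constructors pre-allocate, so the returned cache holds exactly one layer per provided pair. A small check of that behaviour (a sketch, assuming 0.7.6 is installed):

```python
import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

pairs = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
cache = make_dynamic_cache(pairs)

# One layer per (key, value) pair, even if DynamicCache pre-allocated
# num_hidden_layers empty layers at construction time.
assert len(CacheKeyValue(cache).key_cache) == len(pairs)
```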
```diff
@@ -216,19 +273,31 @@ def make_static_cache(
         ),
     )
     cache = transformers.cache_utils.StaticCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
         max_cache_len=max_cache_len,
     )
+    ca = CacheKeyValue(cache)
     for i in range(len(key_value_pairs)):
         assert (
             key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
         ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
         d = key_value_pairs[i][1].shape[2]
-
-
+        ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+        ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
     return cache
 
 
```
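A usage sketch for the updated `make_static_cache`. This is hypothetical in its call signature: it assumes the function keeps a `key_value_pairs` argument and a `max_cache_len` argument, which the hunk above forwards to `StaticCache`; the given tensors are copied into the first `slen` positions of the pre-allocated static buffers.

```python
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_static_cache

bsize, nheads, slen, dim = 2, 4, 3, 7
pairs = [
    (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
    for _ in range(2)
]

# Static buffers are allocated with max_cache_len positions; the pairs fill
# only the first slen of them.
cache = make_static_cache(pairs, max_cache_len=10)
print(string_type(cache, with_shape=True))
```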
```diff
@@ -242,10 +311,8 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(
-
-) -> transformers.cache_utils.MambaCache:
-    "Creates a :class:`transformers.cache_utils.MambaCache`."
+def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+    "Creates a ``MambaCache``."
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
```
```diff
@@ -256,7 +323,7 @@ def make_mamba_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype
 
-    cache =
+    cache = MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
```
```diff
@@ -286,7 +353,7 @@ def make_mamba_cache(
 
 def make_sliding_window_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.
+) -> transformers.cache_utils.SlidingWindowCache:
     "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
 
     class _config:
```
```diff
@@ -297,21 +364,189 @@ def make_sliding_window_cache(
             self.sliding_window = key_value_pairs[0][0].shape[2]
 
     cache = transformers.cache_utils.SlidingWindowCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
     )
+    ca = CacheKeyValue(cache)
     for i in range(len(key_value_pairs)):
-        assert
+        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.key_cache[i].shape}, "
             f"got {key_value_pairs[i][0].shape}"
         )
-
-        assert
+        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
             f"Shape mismatch, expected {cache.value_cache[i].shape}, "
             f"got {key_value_pairs[i][1].shape}"
         )
-
+        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
+    return cache
+
+
+def make_hybrid_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    max_cache_len: Optional[int] = None,
+    max_batch_size: Optional[int] = None,
+    sliding_window: Optional[int] = None,
+) -> transformers.cache_utils.HybridCache:
+    """
+    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+    This version is valid for ``transformers < 4.50``.
+
+    :param key_value_pairs: list of pairs of (key, values)
+    :return: :class:`transformers.cache_utils.HybridCache`
+
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_hybrid_cache(
+            [
+                (
+                    torch.randn(bsize, nheads, slen, dim),
+                    torch.randn(bsize, nheads, slen, dim),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        print(string_type(past_key_values, with_shape=True))
+
+    This part defines how the shapes are working in one HybridCache.
+
+    .. code-block:: python
+
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size
+
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
+
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
+    """
+    layer_types = None
+    if key_value_pairs:
+        assert (
+            not max_batch_size and not max_cache_len
+        ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+        max_batch_size = key_value_pairs[0][0].shape[0]
+        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+        if len(sets_of_dim) == 1:
+            max_cache_len = sets_of_dim.pop()
+            sliding_window = max_cache_len
+        else:
+            assert (
+                len(sets_of_dim) == 2
+            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+            max_cache_len = max(sets_of_dim)
+            sliding_window = min(sets_of_dim)
+            layer_types = [
+                "full_attention" if i == max_cache_len else "sliding_attention"
+                for i in [kv[0].shape[2] for kv in key_value_pairs]
+            ]
+    else:
+        assert (
+            max_batch_size and max_cache_len
+        ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+        if sliding_window is None:
+            sliding_window = max_cache_len
+    _max_cache_len = max_cache_len
+    _sliding_window = sliding_window
+
+    class _config:
+        max_cache_len = _max_cache_len
+        batch_size = max_batch_size
+        num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+        head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+        num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+        num_hidden_layers = len(key_value_pairs)
+        sliding_window = _sliding_window
+        num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
+
+    if layer_types:
+        _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+    cache = transformers.cache_utils.HybridCache(
+        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
+    )
+    for i, (key, value) in enumerate(key_value_pairs):
+        cache.update(
+            key,
+            value,
+            i,
+            cache_kwargs={
+                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                    key.device
+                )
+            },
+        )
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
     return cache
```
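The new `make_hybrid_cache` infers `max_cache_len`, the sliding window and `layer_types` from the shapes it receives: with two distinct sequence lengths, the longer one becomes the full-attention length and the shorter one the sliding window. A sketch of that mixed case (assuming a transformers version that still exposes `HybridCache`):

```python
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

bsize, nheads, dim = 2, 4, 7
past_key_values = make_hybrid_cache(
    [
        # full-attention layer: the longest sequence length becomes max_cache_len
        (torch.randn(bsize, nheads, 30, dim), torch.randn(bsize, nheads, 30, dim)),
        # sliding-attention layer: the shortest length becomes the sliding window
        (torch.randn(bsize, nheads, 8, dim), torch.randn(bsize, nheads, 8, dim)),
    ]
)
print(string_type(past_key_values, with_shape=True))
```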
onnx_diagnostic/helpers/helper.py

```diff
@@ -558,9 +558,17 @@ def string_type(
             print(f"[string_type] CACHE1:{type(obj)}")
         return f"MambaCache(conv_states={c}, ssm_states={d})"
 
-    if obj.__class__.__name__ in {
+    if obj.__class__.__name__ in {
+        "DynamicCache",
+        "SlidingWindowCache",
+        "StaticCache",
+        "HybridCache",
+    }:
+        from .cache_helper import CacheKeyValue
+
+        ca = CacheKeyValue(obj)
         kc = string_type(
-
+            ca.key_cache,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
```
```diff
@@ -568,7 +576,7 @@ def string_type(
             verbose=verbose,
         )
         vc = string_type(
-
+            ca.value_cache,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
```
```diff
@@ -579,6 +587,27 @@ def string_type(
             print(f"[string_type] CACHE2:{type(obj)}")
         return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
 
+    if obj.__class__.__name__ == "StaticLayer":
+        kc = string_type(
+            list(obj.keys),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        vc = string_type(
+            list(obj.values),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        if verbose:
+            print(f"[string_type] SL:{type(obj)}")
+        return f"{obj.__class__.__name__}(keys={kc}, values={vc})"
+
     if obj.__class__.__name__ == "EncoderDecoderCache":
         att = string_type(
             obj.self_attention_cache,
```
```diff
@@ -663,6 +692,50 @@ def string_type(
             f"dtype={obj.dtype}, shape={obj.shape})"
         )
 
+    if obj.__class__.__name__ == "KeyValuesWrapper":
+        import transformers
+
+        assert isinstance(
+            obj, transformers.cache_utils.KeyValuesWrapper
+        ), f"Unexpected type {type(obj)}"
+        if verbose:
+            print(f"[string_type] KW0:{type(obj)}")
+        s = string_type(
+            list(obj),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        return f"{obj.__class__.__name__}[{obj.cache_type}]{s}"
+
+    if obj.__class__.__name__ == "DynamicLayer":
+        import transformers
+
+        assert isinstance(
+            obj, transformers.cache_utils.DynamicLayer
+        ), f"Unexpected type {type(obj)}"
+        if verbose:
+            print(f"[string_type] LY0:{type(obj)}")
+        s1 = string_type(
+            obj.keys,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        s2 = string_type(
+            obj.values,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        return f"{obj.__class__.__name__}(keys={s1}, values={s2})"
+
     if isinstance(obj, torch.nn.Module):
         if verbose:
             print(f"[string_type] MM:{type(obj)}")
```
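With these branches, `string_type` can summarize both the caches and the new layer-based objects (`StaticLayer`, `DynamicLayer`, `KeyValuesWrapper`) through the `CacheKeyValue` wrapper. A small sketch of the kind of output to expect (the exact string depends on the transformers version):

```python
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)
# Prints something like:
# DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[...])
print(string_type(cache, with_shape=True))
```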
```diff
@@ -858,7 +931,10 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
         return flatten_object(list(x.items()), drop_keys=drop_keys)
 
     if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
-
+        from .cache_helper import CacheKeyValue
+
+        kc = CacheKeyValue(x)
+        res = flatten_object(kc.key_cache) + flatten_object(kc.value_cache)
         return tuple(res)
     if x.__class__.__name__ == "EncoderDecoderCache":
         res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
```
```diff
@@ -1424,19 +1500,58 @@ def max_diff(
             f"level={level}"
         )
 
+    # backup function in case pytorch does not know how to serialize.
+    if expected.__class__.__name__ == "HybridCache":
+        if got.__class__.__name__ == "HybridCache":
+            from .cache_helper import CacheKeyValue
+
+            if verbose >= 6:
+                print(f"[max_diff] HybridCache: {string_type(expected)} ? {string_type(got)}")
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
+            return max_diff(
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
+                verbose=verbose,
+                hist=hist,
+            )
+        if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            return max_diff(
+                [cae.key_cache, cae.value_cache],
+                [got[0], got[1]],
+                debug_info=_debug(expected.__class__.__name__),
+                **_dkws,
+            )
+        raise AssertionError(
+            f"HybridCache not fully implemented with classes "
+            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+            f"and expected={string_type(expected)}, got={string_type(got)},\n"
+            f"level={level}"
+        )
+
     if expected.__class__.__name__ == "StaticCache":
         if got.__class__.__name__ == "StaticCache":
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
             if verbose >= 6:
                 print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
             return max_diff(
-                [
-                [
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
                 verbose=verbose,
                 hist=hist,
             )
         if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
             return max_diff(
-                [
+                [cae.key_cache, cae.value_cache],
                 [got[0], got[1]],
                 debug_info=_debug(expected.__class__.__name__),
                 **_dkws,
```
```diff
@@ -1455,15 +1570,22 @@ def max_diff(
                 f"[max_diff] SlidingWindowCache: "
                 f"{string_type(expected)} ? {string_type(got)}"
             )
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
             return max_diff(
-                [
-                [
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
                 verbose=verbose,
                 hist=hist,
             )
         if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
             return max_diff(
-                [
+                [cae.key_cache, cae.value_cache],
                 [got[0], got[1]],
                 debug_info=_debug(expected.__class__.__name__),
                 **_dkws,
```
```diff
@@ -1521,6 +1643,20 @@ def max_diff(
             **_dkws,
         )
 
+    if expected.__class__.__name__ == "KeyValuesWrapper":
+        if verbose >= 6:
+            print(f"[max_diff] KeyValuesWrapper: {string_type(expected)} ? {string_type(got)}")
+        if got.__class__.__name__ != expected.__class__.__name__:
+            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+        if got.cache_type != expected.cache_type:
+            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+        return max_diff(
+            list(expected),
+            list(got),
+            debug_info=_debug(expected.__class__.__name__),
+            **_dkws,
+        )
+
     raise AssertionError(
         f"Not implemented with implemented with expected="
         f"{string_type(expected)}, got={string_type(got)},\n"
```