onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +118 -5
- onnx_diagnostic/export/control_flow.py +214 -0
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +118 -25
- onnx_diagnostic/helpers/cache_helper.py +218 -204
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +92 -26
- onnx_diagnostic/helpers/log_helper.py +26 -4
- onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +115 -16
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/rt_helper.py +547 -0
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +108 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
     start = 0
     end = 0
     subtrees = []
-    for subspec in spec.children_specs:
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
         value = flatten_unflatten_for_dynamic_shapes(
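The only functional change in this hunk is how the pytree spec's children are enumerated: newer torch releases appear to expose a `children()` method on `TreeSpec`, while older ones only provide the `children_specs` attribute, and the `hasattr` guard keeps both working. A minimal sketch of the same version-tolerant pattern (the helper name `_iter_child_specs` is illustrative, not part of the package):

    import torch.utils._pytree as pytree

    def _iter_child_specs(spec):
        # Prefer the newer TreeSpec.children() accessor when it exists,
        # otherwise fall back to the long-standing children_specs attribute.
        return spec.children() if hasattr(spec, "children") else spec.children_specs

    flat, spec = pytree.tree_flatten({"a": [1, 2], "b": 3})
    for subspec in _iter_child_specs(spec):
        print(subspec.num_leaves)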
@@ -391,17 +391,22 @@ def make_static_cache(
     return finalize_cache(cache)


- …
+if hasattr(transformers.cache_utils, "EncoderDecoderCache"):
+
+    def make_encoder_decoder_cache(
+        self_attention_cache: transformers.cache_utils.DynamicCache,
+        cross_attention_cache: transformers.cache_utils.DynamicCache,
+    ) -> transformers.cache_utils.EncoderDecoderCache:
+        """Creates an EncoderDecoderCache."""
+        return transformers.cache_utils.EncoderDecoderCache(
+            # self_attention_cache=self_attention_cache,
+            # cross_attention_cache=cross_attention_cache
+            self_attention_cache,
+            cross_attention_cache,
+        )
+
+else:
+    make_encoder_decoder_cache = None  # type: ignore[assignment]


 def make_mamba_cache(
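The hunk above puts the helper behind a `hasattr` check so the module still imports when the installed transformers does not define `EncoderDecoderCache`; in that case the name is bound to `None` and callers can feature-test it. A hedged usage sketch, assuming a transformers release that still ships `DynamicCache` and `EncoderDecoderCache` with this constructor:

    import torch
    from transformers.cache_utils import DynamicCache
    from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache

    if make_encoder_decoder_cache is not None:
        # One DynamicCache per attention kind, each filled with a single dummy layer.
        self_cache, cross_cache = DynamicCache(), DynamicCache()
        self_cache.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4), 0)
        cross_cache.update(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4), 0)
        cache = make_encoder_decoder_cache(self_cache, cross_cache)
        print(type(cache).__name__)  # EncoderDecoderCache

The two caches are passed positionally rather than by keyword, which appears to be why the keyword arguments are left commented out in the diff.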
@@ -454,220 +459,229 @@ def make_mamba_cache(
     return finalize_cache(cache)


-def make_sliding_window_cache(
-    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
-) -> transformers.cache_utils.SlidingWindowCache:
-    "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
-    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+if hasattr(transformers.cache_utils, "SlidingWindowCache"):

-            self.sliding_window = key_value_pairs[0][0].shape[2]
-        def get_text_config(self, *args, **kwargs):
-            return self
-    cache = transformers.cache_utils.SlidingWindowCache(
-        config=_config(),
-        max_batch_size=key_value_pairs[0][0].shape[0],
-        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
-        device=key_value_pairs[0][0].device,
-        dtype=key_value_pairs[0][0].dtype,
-    )
-    ca = CacheKeyValue(cache)
-    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
-        # transformers>= 4.55.2, layers are empty
-        cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
-        for i, (key, value) in enumerate(key_value_pairs):
-            cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
-        return cache
+    def make_sliding_window_cache(
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+    ) -> transformers.cache_utils.SlidingWindowCache:
+        "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

+        class _config:
+            def __init__(self):
+                self.head_dim = key_value_pairs[0][0].shape[-1]
+                self.num_attention_heads = key_value_pairs[0][0].shape[1]
+                self.num_hidden_layers = len(key_value_pairs)
+                self.sliding_window = key_value_pairs[0][0].shape[2]
+
+            def get_text_config(self, *args, **kwargs):
+                return self
+
+        cache = transformers.cache_utils.SlidingWindowCache(
+            config=_config(),
+            max_batch_size=key_value_pairs[0][0].shape[0],
+            max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
+            device=key_value_pairs[0][0].device,
+            dtype=key_value_pairs[0][0].dtype,
         )
+        ca = CacheKeyValue(cache)
+        if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+            # transformers>= 4.55.2, layers are empty
+            cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+            for i, (key, value) in enumerate(key_value_pairs):
+                cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+            return cache
+
+        for i in range(len(key_value_pairs)):
+            assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
+                f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+                f"got {key_value_pairs[i][0].shape}"
+            )
+            ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+            assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
+                f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+                f"got {key_value_pairs[i][1].shape}"
+            )
+            ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
         )
-    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
-        # The cache constructor contains the two following lines
-        # (in cache_utils.py) which append empty layers when the cache is
-        # initialized. We need to remove them.
-        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
-        # self.append_new_layers(self.num_hidden_layers - 1)
-        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
-    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
-        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
-        f"{len(key_value_pairs)} expected."
-    )
-    return finalize_cache(cache)
+        return finalize_cache(cache)

+else:
+    make_sliding_window_cache = None  # type: ignore[assignment]

-def make_hybrid_cache(
-    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
-    max_cache_len: Optional[int] = None,
-    max_batch_size: Optional[int] = None,
-    sliding_window: Optional[int] = None,
-) -> transformers.cache_utils.HybridCache:
-    """
-    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
-    This version is valid for ``transformers < 4.50``.
+if hasattr(transformers.cache_utils, "HybridCache"):

+    def make_hybrid_cache(
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        max_cache_len: Optional[int] = None,
+        max_batch_size: Optional[int] = None,
+        sliding_window: Optional[int] = None,
+    ) -> transformers.cache_utils.HybridCache:
+        """
+        Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+        This version is valid for ``transformers < 4.50``.

+        :param key_value_pairs: list of pairs of (key, values)
+        :return: :class:`transformers.cache_utils.HybridCache`

-        :showcode:
+        Example:

-        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+        .. runpython::
+            :showcode:

+            import torch
+            from onnx_diagnostic.helpers import string_type
+            from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

+            n_layers = 2
+            bsize, nheads, slen, dim = 2, 4, 3, 7
+
+            past_key_values = make_hybrid_cache(
+                [
+                    (
+                        torch.randn(bsize, nheads, slen, dim),
+                        torch.randn(bsize, nheads, slen, dim),
+                    )
+                    for i in range(n_layers)
+                ]
+            )
+            print(string_type(past_key_values, with_shape=True))

+        This part defines how the shapes are working in one HybridCache.

+        .. code-block:: python

+            self.max_cache_len = (
+                max_cache_len if max_cache_len is not None else config.max_position_embeddings)

+            # Sliding layers can't be larger than the overall max cache len
+            self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+            self.max_batch_size = max_batch_size

+            self.head_dim = (
+                config.head_dim if hasattr(config, "head_dim")
+                else config.hidden_size // config.num_attention_heads
+            )

+            self._dtype = dtype
+            self.num_key_value_heads = (
+                config.num_attention_heads
+                if getattr(config, "num_key_value_heads", None) is None
+                else config.num_key_value_heads
+            )

+            # If the attribute does not exist in the config, fallback to a simple StaticCache
+            if hasattr(config, "layer_types"):
+                self.is_sliding = [
+                    layer_type != "full_attention" for layer_type in config.layer_types]
+            else:
+                self.is_sliding = [False] * config.num_hidden_layers
+
+            self.key_cache: list[torch.Tensor] = []
+            self.value_cache: list[torch.Tensor] = []
+            global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                  self.max_cache_len, self.head_dim)
+            sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                   self.sliding_window_len, self.head_dim)
+            self.sliding_window = min(config.sliding_window, max_cache_len)
+            device = torch.device(device) if device is not None else None
+            for i in range(config.num_hidden_layers):
+                layer_device = layer_device_map[i] if layer_device_map is not None else device
+                cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+                new_layer_key_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                new_layer_value_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+                self.key_cache.append(new_layer_key_cache)
+                self.value_cache.append(new_layer_value_cache)
+        """
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+        layer_types = None
+        if key_value_pairs:
+            assert (
+                not max_batch_size and not max_cache_len
+            ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+            max_batch_size = key_value_pairs[0][0].shape[0]
+            sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+            if len(sets_of_dim) == 1:
+                max_cache_len = sets_of_dim.pop()
+                sliding_window = max_cache_len
+            else:
+                assert (
+                    len(sets_of_dim) == 2
+                ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+                max_cache_len = max(sets_of_dim)
+                sliding_window = min(sets_of_dim)
+                layer_types = [
+                    "full_attention" if i == max_cache_len else "sliding_attention"
+                    for i in [kv[0].shape[2] for kv in key_value_pairs]
+                ]
         else:
             assert (
-        num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
-        def get_text_config(self, *args, **kwargs):
-            return self
-    if layer_types:
-        _config.layer_types = layer_types  # type: ignore[attr-defined]
-    cache = transformers.cache_utils.HybridCache(
-        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
-    )
-    for i, (key, value) in enumerate(key_value_pairs):
-        cache.update(
-            key,
-            value,
-            i,
-            cache_kwargs={
-                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
-                    key.device
-                )
-            },
+                max_batch_size and max_cache_len
+            ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+            if sliding_window is None:
+                sliding_window = max_cache_len
+        _max_cache_len = max_cache_len
+        _sliding_window = sliding_window
+
+        class _config:
+            max_cache_len = _max_cache_len
+            batch_size = max_batch_size
+            num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+            head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+            num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+            num_hidden_layers = len(key_value_pairs)
+            sliding_window = _sliding_window
+            num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
+
+            def get_text_config(self, *args, **kwargs):
+                return self
+
+        if layer_types:
+            _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+        cache = transformers.cache_utils.HybridCache(
+            config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
         )
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(
+                key,
+                value,
+                i,
+                cache_kwargs={
+                    "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                        key.device
+                    )
+                },
+            )
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return finalize_cache(cache)
+
+else:
+    make_hybrid_cache = None  # type: ignore[assignment]


 def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
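After this change, `make_sliding_window_cache` and `make_hybrid_cache` (like `make_encoder_decoder_cache` above) may be bound to `None` when the corresponding cache class is missing from the installed transformers, so callers are expected to feature-test them. A minimal sketch reusing the shapes from the docstring example above:

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import (
        make_hybrid_cache,
        make_sliding_window_cache,
    )

    # Two layers of (key, value) pairs shaped (batch, heads, seq_len, head_dim).
    pairs = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]

    for make in (make_hybrid_cache, make_sliding_window_cache):
        if make is None:
            # The installed transformers no longer ships this cache class.
            continue
        print(string_type(make(pairs), with_shape=True))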
|