onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
     start = 0
     end = 0
     subtrees = []
-    for subspec in spec.children_specs:
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
         value = flatten_unflatten_for_dynamic_shapes(
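The change above makes the pytree child iteration tolerant of both TreeSpec APIs: it prefers the newer ``children()`` accessor when it exists and falls back to the older ``children_specs`` attribute otherwise. A minimal sketch of the same feature-detection pattern, where the helper name ``_iter_child_specs`` is hypothetical:

    from typing import Any, Iterable

    def _iter_child_specs(spec: Any) -> Iterable[Any]:
        """Returns the child specs of a TreeSpec-like object.

        Mirrors the guarded access in the diff above: newer pytree versions
        expose a children() method, older ones only the children_specs attribute.
        """
        return spec.children() if hasattr(spec, "children") else spec.children_specs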
@@ -391,17 +391,22 @@ def make_static_cache(
     return finalize_cache(cache)
 
 
-def make_encoder_decoder_cache(
-    self_attention_cache: transformers.cache_utils.DynamicCache,
-    cross_attention_cache: transformers.cache_utils.DynamicCache,
-) -> transformers.cache_utils.EncoderDecoderCache:
-    """Creates an EncoderDecoderCache."""
-    return transformers.cache_utils.EncoderDecoderCache(
-        # self_attention_cache=self_attention_cache,
-        # cross_attention_cache=cross_attention_cache
-        self_attention_cache,
-        cross_attention_cache,
-    )
+if hasattr(transformers.cache_utils, "EncoderDecoderCache"):
+
+    def make_encoder_decoder_cache(
+        self_attention_cache: transformers.cache_utils.DynamicCache,
+        cross_attention_cache: transformers.cache_utils.DynamicCache,
+    ) -> transformers.cache_utils.EncoderDecoderCache:
+        """Creates an EncoderDecoderCache."""
+        return transformers.cache_utils.EncoderDecoderCache(
+            # self_attention_cache=self_attention_cache,
+            # cross_attention_cache=cross_attention_cache
+            self_attention_cache,
+            cross_attention_cache,
+        )
+
+else:
+    make_encoder_decoder_cache = None  # type: ignore[assignment]
 
 
 def make_mamba_cache(
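The guard above only defines ``make_encoder_decoder_cache`` when ``transformers.cache_utils.EncoderDecoderCache`` exists and binds the name to ``None`` otherwise, so callers on 0.8.3 have to check the binding before using it. A minimal usage sketch, assuming ``make_dynamic_cache`` from the same module as the source of the two wrapped caches:

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        make_dynamic_cache,          # assumed: DynamicCache factory from the same module
        make_encoder_decoder_cache,  # None when EncoderDecoderCache is unavailable
    )

    if make_encoder_decoder_cache is not None:
        bsize, nheads, slen, dim = 2, 4, 3, 7
        kv = [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
        # Wrap one self-attention cache and one cross-attention cache into a single object.
        cache = make_encoder_decoder_cache(make_dynamic_cache(kv), make_dynamic_cache(kv))
    else:
        # transformers builds without EncoderDecoderCache get None instead of the factory.
        cache = None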
@@ -454,220 +459,229 @@ def make_mamba_cache(
     return finalize_cache(cache)
 
 
-def make_sliding_window_cache(
-    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
-) -> transformers.cache_utils.SlidingWindowCache:
-    "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
-    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+if hasattr(transformers.cache_utils, "SlidingWindowCache"):
 
-    class _config:
-        def __init__(self):
-            self.head_dim = key_value_pairs[0][0].shape[-1]
-            self.num_attention_heads = key_value_pairs[0][0].shape[1]
-            self.num_hidden_layers = len(key_value_pairs)
-            self.sliding_window = key_value_pairs[0][0].shape[2]
-
-        def get_text_config(self, *args, **kwargs):
-            return self
-
-    cache = transformers.cache_utils.SlidingWindowCache(
-        config=_config(),
-        max_batch_size=key_value_pairs[0][0].shape[0],
-        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
-        device=key_value_pairs[0][0].device,
-        dtype=key_value_pairs[0][0].dtype,
-    )
-    ca = CacheKeyValue(cache)
-    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
-        # transformers>= 4.55.2, layers are empty
-        cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
-        for i, (key, value) in enumerate(key_value_pairs):
-            cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
-        return cache
+    def make_sliding_window_cache(
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+    ) -> transformers.cache_utils.SlidingWindowCache:
+        "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
 
-    for i in range(len(key_value_pairs)):
-        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
-            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
-            f"got {key_value_pairs[i][0].shape}"
+        class _config:
+            def __init__(self):
+                self.head_dim = key_value_pairs[0][0].shape[-1]
+                self.num_attention_heads = key_value_pairs[0][0].shape[1]
+                self.num_hidden_layers = len(key_value_pairs)
+                self.sliding_window = key_value_pairs[0][0].shape[2]
+
+            def get_text_config(self, *args, **kwargs):
+                return self
+
+        cache = transformers.cache_utils.SlidingWindowCache(
+            config=_config(),
+            max_batch_size=key_value_pairs[0][0].shape[0],
+            max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
+            device=key_value_pairs[0][0].device,
+            dtype=key_value_pairs[0][0].dtype,
         )
-        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
-        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
-            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
-            f"got {key_value_pairs[i][1].shape}"
+        ca = CacheKeyValue(cache)
+        if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+            # transformers>= 4.55.2, layers are empty
+            cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+            for i, (key, value) in enumerate(key_value_pairs):
+                cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+            return cache
+
+        for i in range(len(key_value_pairs)):
+            assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
+                f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+                f"got {key_value_pairs[i][0].shape}"
+            )
+            ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+            assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
+                f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+                f"got {key_value_pairs[i][1].shape}"
+            )
+            ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
         )
-        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
-    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
-        # The cache constructor contains the two following lines
-        # (in cache_utils.py) which append empty layers when the cache is
-        # initialized. We need to remove them.
-        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
-        # self.append_new_layers(self.num_hidden_layers - 1)
-        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
-    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
-        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
-        f"{len(key_value_pairs)} expected."
-    )
-    return finalize_cache(cache)
+        return finalize_cache(cache)
 
+else:
+    make_sliding_window_cache = None  # type: ignore[assignment]
 
-def make_hybrid_cache(
-    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
-    max_cache_len: Optional[int] = None,
-    max_batch_size: Optional[int] = None,
-    sliding_window: Optional[int] = None,
-) -> transformers.cache_utils.HybridCache:
-    """
-    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
-    This version is valid for ``transformers < 4.50``.
+if hasattr(transformers.cache_utils, "HybridCache"):
 
-    :param key_value_pairs: list of pairs of (key, values)
-    :return: :class:`transformers.cache_utils.HybridCache`
+    def make_hybrid_cache(
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        max_cache_len: Optional[int] = None,
+        max_batch_size: Optional[int] = None,
+        sliding_window: Optional[int] = None,
+    ) -> transformers.cache_utils.HybridCache:
+        """
+        Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+        This version is valid for ``transformers < 4.50``.
 
-    Example:
+        :param key_value_pairs: list of pairs of (key, values)
+        :return: :class:`transformers.cache_utils.HybridCache`
 
-    .. runpython::
-        :showcode:
+        Example:
 
-        import torch
-        from onnx_diagnostic.helpers import string_type
-        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+        .. runpython::
+            :showcode:
 
-        n_layers = 2
-        bsize, nheads, slen, dim = 2, 4, 3, 7
+            import torch
+            from onnx_diagnostic.helpers import string_type
+            from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
 
-        past_key_values = make_hybrid_cache(
-            [
-                (
-                    torch.randn(bsize, nheads, slen, dim),
-                    torch.randn(bsize, nheads, slen, dim),
-                )
-                for i in range(n_layers)
-            ]
-        )
-        print(string_type(past_key_values, with_shape=True))
+            n_layers = 2
+            bsize, nheads, slen, dim = 2, 4, 3, 7
+
+            past_key_values = make_hybrid_cache(
+                [
+                    (
+                        torch.randn(bsize, nheads, slen, dim),
+                        torch.randn(bsize, nheads, slen, dim),
+                    )
+                    for i in range(n_layers)
+                ]
+            )
+            print(string_type(past_key_values, with_shape=True))
 
-    This part defines how the shapes are working in one HybridCache.
+        This part defines how the shapes are working in one HybridCache.
 
-    .. code-block:: python
+        .. code-block:: python
 
-        self.max_cache_len = (
-            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+            self.max_cache_len = (
+                max_cache_len if max_cache_len is not None else config.max_position_embeddings)
 
-        # Sliding layers can't be larger than the overall max cache len
-        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
-        self.max_batch_size = max_batch_size
+            # Sliding layers can't be larger than the overall max cache len
+            self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+            self.max_batch_size = max_batch_size
 
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim")
-            else config.hidden_size // config.num_attention_heads
-        )
+            self.head_dim = (
+                config.head_dim if hasattr(config, "head_dim")
+                else config.hidden_size // config.num_attention_heads
+            )
 
-        self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads
-            if getattr(config, "num_key_value_heads", None) is None
-            else config.num_key_value_heads
-        )
+            self._dtype = dtype
+            self.num_key_value_heads = (
+                config.num_attention_heads
+                if getattr(config, "num_key_value_heads", None) is None
+                else config.num_key_value_heads
+            )
 
-        # If the attribute does not exist in the config, fallback to a simple StaticCache
-        if hasattr(config, "layer_types"):
-            self.is_sliding = [
-                layer_type != "full_attention" for layer_type in config.layer_types]
-        else:
-            self.is_sliding = [False] * config.num_hidden_layers
-
-        self.key_cache: list[torch.Tensor] = []
-        self.value_cache: list[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-            self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-            self.sliding_window_len, self.head_dim)
-        self.sliding_window = min(config.sliding_window, max_cache_len)
-        device = torch.device(device) if device is not None else None
-        for i in range(config.num_hidden_layers):
-            layer_device = layer_device_map[i] if layer_device_map is not None else device
-            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
-            new_layer_key_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
-    """
-    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
-    layer_types = None
-    if key_value_pairs:
-        assert (
-            not max_batch_size and not max_cache_len
-        ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
-        max_batch_size = key_value_pairs[0][0].shape[0]
-        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
-        if len(sets_of_dim) == 1:
-            max_cache_len = sets_of_dim.pop()
-            sliding_window = max_cache_len
+            # If the attribute does not exist in the config, fallback to a simple StaticCache
+            if hasattr(config, "layer_types"):
+                self.is_sliding = [
+                    layer_type != "full_attention" for layer_type in config.layer_types]
+            else:
+                self.is_sliding = [False] * config.num_hidden_layers
+
+            self.key_cache: list[torch.Tensor] = []
+            self.value_cache: list[torch.Tensor] = []
+            global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                self.max_cache_len, self.head_dim)
+            sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                self.sliding_window_len, self.head_dim)
+            self.sliding_window = min(config.sliding_window, max_cache_len)
+            device = torch.device(device) if device is not None else None
+            for i in range(config.num_hidden_layers):
+                layer_device = layer_device_map[i] if layer_device_map is not None else device
+                cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+                new_layer_key_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                new_layer_value_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+                self.key_cache.append(new_layer_key_cache)
+                self.value_cache.append(new_layer_value_cache)
+        """
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+        layer_types = None
+        if key_value_pairs:
+            assert (
+                not max_batch_size and not max_cache_len
+            ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+            max_batch_size = key_value_pairs[0][0].shape[0]
+            sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+            if len(sets_of_dim) == 1:
+                max_cache_len = sets_of_dim.pop()
+                sliding_window = max_cache_len
+            else:
+                assert (
+                    len(sets_of_dim) == 2
+                ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+                max_cache_len = max(sets_of_dim)
+                sliding_window = min(sets_of_dim)
+                layer_types = [
+                    "full_attention" if i == max_cache_len else "sliding_attention"
+                    for i in [kv[0].shape[2] for kv in key_value_pairs]
+                ]
         else:
             assert (
-                len(sets_of_dim) == 2
-            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
-            max_cache_len = max(sets_of_dim)
-            sliding_window = min(sets_of_dim)
-            layer_types = [
-                "full_attention" if i == max_cache_len else "sliding_attention"
-                for i in [kv[0].shape[2] for kv in key_value_pairs]
-            ]
-    else:
-        assert (
-            max_batch_size and max_cache_len
-        ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
-        if sliding_window is None:
-            sliding_window = max_cache_len
-    _max_cache_len = max_cache_len
-    _sliding_window = sliding_window
-
-    class _config:
-        max_cache_len = _max_cache_len
-        batch_size = max_batch_size
-        num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
-        head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
-        num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
-        num_hidden_layers = len(key_value_pairs)
-        sliding_window = _sliding_window
-        num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
-
-        def get_text_config(self, *args, **kwargs):
-            return self
-
-    if layer_types:
-        _config.layer_types = layer_types  # type: ignore[attr-defined]
-
-    cache = transformers.cache_utils.HybridCache(
-        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
-    )
-    for i, (key, value) in enumerate(key_value_pairs):
-        cache.update(
-            key,
-            value,
-            i,
-            cache_kwargs={
-                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
-                    key.device
-                )
-            },
+                max_batch_size and max_cache_len
+            ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+            if sliding_window is None:
+                sliding_window = max_cache_len
+        _max_cache_len = max_cache_len
+        _sliding_window = sliding_window
+
+        class _config:
+            max_cache_len = _max_cache_len
+            batch_size = max_batch_size
+            num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+            head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+            num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+            num_hidden_layers = len(key_value_pairs)
+            sliding_window = _sliding_window
+            num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
+
+            def get_text_config(self, *args, **kwargs):
+                return self
+
+        if layer_types:
+            _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+        cache = transformers.cache_utils.HybridCache(
+            config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
         )
-    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
-        # The cache constructor contains the two following lines
-        # (in cache_utils.py) which append empty layers when the cache is
-        # initialized. We need to remove them.
-        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
-        # self.append_new_layers(self.num_hidden_layers - 1)
-        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
-    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
-        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
-        f"{len(key_value_pairs)} expected."
-    )
-    return finalize_cache(cache)
+-    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+-        # The cache constructor contains the two following lines
+-        # (in cache_utils.py) which append empty layers when the cache is
+-        # initialized. We need to remove them.
+-        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+-        # self.append_new_layers(self.num_hidden_layers - 1)
+-        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+-    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+-        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+-        f"{len(key_value_pairs)} expected."
+-    )
+-    return finalize_cache(cache)
++        for i, (key, value) in enumerate(key_value_pairs):
++            cache.update(
++                key,
++                value,
++                i,
++                cache_kwargs={
++                    "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
++                        key.device
++                    )
++                },
++            )
++        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
++            # The cache constructor contains the two following lines
++            # (in cache_utils.py) which append empty layers when the cache is
++            # initialized. We need to remove them.
++            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
++            # self.append_new_layers(self.num_hidden_layers - 1)
++            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
++        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
++            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
++            f"{len(key_value_pairs)} expected."
++        )
++        return finalize_cache(cache)
++
++else:
++    make_hybrid_cache = None  # type: ignore[assignment]
 
 
 def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
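The same guard-and-fallback pattern is applied to ``make_sliding_window_cache`` and ``make_hybrid_cache``: each factory is defined only when the corresponding cache class is present in ``transformers.cache_utils``, and the module-level name is bound to ``None`` otherwise. A short usage sketch adapted from the docstring example in the hunk above, adding the ``None`` check a 0.8.3 caller now needs:

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    # make_hybrid_cache is None on transformers builds without HybridCache.
    if make_hybrid_cache is not None:
        n_layers = 2
        bsize, nheads, slen, dim = 2, 4, 3, 7
        past_key_values = make_hybrid_cache(
            [
                (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
                for _ in range(n_layers)
            ]
        )
        # Prints the cache structure with tensor shapes.
        print(string_type(past_key_values, with_shape=True))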