onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.7.16"
6
+ __version__ = "0.8.1"
7
7
  __author__ = "Xavier Dupré"
@@ -265,7 +265,7 @@ def get_parser_config() -> ArgumentParser:
265
265
  "--mop",
266
266
  metavar="KEY=VALUE",
267
267
  nargs="*",
268
- help="Additional model options, use to change some parameters of the model, "
268
+ help="Additional model options, used to change some parameters of the model, "
269
269
  "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
270
270
  action=_ParseDict,
271
271
  )
@@ -442,11 +442,17 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
442
442
  default=True,
443
443
  action=_BoolOrParseDictPatch,
444
444
  nargs="*",
445
- help="Applies patches before exporting, it can be a boolean "
446
- "to enable to disable the patches or be more finetuned. It is possible to "
447
- "disable patch for torch by adding "
448
- '--patch "patch_sympy=False" --patch "patch_torch=False", '
449
- "default is True.",
445
+ help=textwrap.dedent(
446
+ """
447
+ Applies patches before exporting, it can be a boolean
448
+ to enable or disable the patches or be more finetuned
449
+ (default is True). It is possible to disable patch for torch
450
+ by adding:
451
+ --patch "patch_sympy=False" --patch "patch_torch=False"
452
+ """.strip(
453
+ "\n"
454
+ )
455
+ ),
450
456
  )
451
457
  parser.add_argument(
452
458
  "--rewrite",
@@ -476,10 +482,16 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
476
482
  "--inputs2",
477
483
  default=1,
478
484
  type=int,
479
- help="Validates or exports the model on a second set of inputs\n"
480
- "to check the exported model supports dynamism. The values is used "
481
- "as an increment to the first set of inputs. A high value may trick "
482
- "a different behavior in the model and missed by the exporter.",
485
+ help=textwrap.dedent(
486
+ """
487
+ Validates or exports the model on a second set of inputs
488
+ to check the exported model supports dynamism. The value is used
489
+ as an increment to the first set of inputs. A high value may trigger
490
+ a different behavior in the model that is missed by the exporter.
491
+ """.strip(
492
+ "\n"
493
+ )
494
+ ),
483
495
  )
484
496
  parser.add_argument(
485
497
  "--runtime",
@@ -512,9 +524,15 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
512
524
  parser.add_argument(
513
525
  "--ortfusiontype",
514
526
  required=False,
515
- help="Applies onnxruntime fusion, this parameter should contain the\n"
516
- "model type or multiple values separated by `|`. `ALL` can be used\n"
517
- "to run them all.",
527
+ help=textwrap.dedent(
528
+ """
529
+ Applies onnxruntime fusion, this parameter should contain the
530
+ model type or multiple values separated by `|`. `ALL` can be used
531
+ to run them all.
532
+ """.strip(
533
+ "\n"
534
+ )
535
+ ),
518
536
  )
519
537
  parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
520
538
  parser.add_argument("--dtype", help="Changes dtype if necessary.")
@@ -523,18 +541,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
523
541
  "--iop",
524
542
  metavar="KEY=VALUE",
525
543
  nargs="*",
526
- help="Additional input options, use to change the default"
527
- "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
528
- "\n --iop cls_cache=StaticCache",
544
+ help=textwrap.dedent(
545
+ """
546
+ Additional input options, used to change the default
547
+ inputs used to export. Examples:
548
+ --iop cls_cache=SlidingWindowCache
549
+ --iop cls_cache=StaticCache
550
+ """.strip(
551
+ "\n"
552
+ )
553
+ ),
529
554
  action=_ParseDict,
530
555
  )
531
556
  parser.add_argument(
532
557
  "--mop",
533
558
  metavar="KEY=VALUE",
534
559
  nargs="*",
535
- help="Additional model options, use to change some parameters of the model, "
536
- "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
537
- "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
560
+ help=textwrap.dedent(
561
+ """
562
+ Additional model options, used to change some parameters
563
+ of the model. Example:
564
+ --mop attn_implementation=sdpa --mop attn_implementation=eager
565
+ --mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
566
+ """.strip(
567
+ "\n"
568
+ )
569
+ ),
538
570
  action=_ParseDict,
539
571
  )
540
572
  if name == "validate":
@@ -566,9 +598,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
566
598
  parser.add_argument(
567
599
  "--quiet-input-sets",
568
600
  default="",
569
- help="Avoids raising an exception when an input sets does not work with "
570
- "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
601
+ help=textwrap.dedent(
602
+ """
603
+ Avoids raising an exception when an input sets does not work with
604
+ the exported model. Example:
605
+ --quiet-input-sets=inputs,inputs22
606
+ """.strip(
607
+ "\n"
608
+ )
609
+ ),
571
610
  )
611
+ parser.add_argument(
612
+ "--expop",
613
+ metavar="KEY=VALUE",
614
+ nargs="*",
615
+ help=textwrap.dedent(
616
+ """
617
+ Additional exporter options, used to change some parameters
618
+ of the model. Examples:
619
+ --expop report=True
620
+ --expop report=True --expop verify=True
621
+ """.strip(
622
+ "\n"
623
+ )
624
+ ),
625
+ action=_ParseDict,
626
+ )
572
627
  return parser
573
628
 
574
629
 
@@ -634,6 +689,7 @@ def _cmd_validate(argv: List[Any]):
634
689
  output_names=(
635
690
  None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
636
691
  ),
692
+ exporter_options=args.expop,
637
693
  )
638
694
  print("")
639
695
  print("-- summary --")
@@ -940,7 +996,7 @@ def get_parser_agg() -> ArgumentParser:
940
996
  "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
941
997
  "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
942
998
  "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
943
- "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
999
+ "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,n_node_random,"
944
1000
  "n_node_constant,n_node_shape,n_node_expand,"
945
1001
  "n_node_function,n_node_initializer,n_node_scatter,"
946
1002
  "time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
@@ -0,0 +1,124 @@
1
+ from typing import Any, Dict, List, Sequence, Optional, Tuple, Union
2
+ import torch
3
+
4
+
5
+ def to_onnx(
6
+ mod: Union["torch.nn.Module", "torch.fx.GraphModule"], # noqa: F821
7
+ args: Optional[Sequence["torch.Tensor"]] = None, # noqa: F821
8
+ kwargs: Optional[Dict[str, "torch.Tensor"]] = None, # noqa: F821
9
+ input_names: Optional[Sequence[str]] = None,
10
+ target_opset: Optional[Union[int, Dict[str, int]]] = None,
11
+ verbose: int = 0,
12
+ dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
13
+ filename: Optional[str] = None,
14
+ output_names: Optional[List[str]] = None,
15
+ output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
16
+ exporter: str = "onnx-dynamo",
17
+ ) -> Any:
18
+ """
19
+ Common API for exporters. By default, the models are optimized to use the
20
+ most efficient kernels implemented in :epkg:`onnxruntime`.
21
+
22
+ :param mod: torch model
23
+ :param args: unnamed arguments
24
+ :param kwargs: named arguments
25
+ :param input_names: input names for the onnx model (optional)
26
+ :param target_opset: opset to target, if not specified, each converter
27
+ keeps its default value
28
+ :param verbose: verbosity level
29
+ :param dynamic_shapes: dynamic shapes, usually a nested structure
30
+ included a dictionary for each tensor
31
+ :param filename: output filename
32
+ :param output_names: to change the output of the onnx model
33
+ :param output_dynamic_shapes: to overwrite the dynamic shapes names
34
+ :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
35
+ :return: the output of the selected exporter, usually a structure including
36
+ an onnx model
37
+
38
+ A simple example:
39
+
40
+ .. code-block:: python
41
+
42
+ to_onnx(
43
+ model,
44
+ kwargs=inputs,
45
+ dynamic_shapes=ds,
46
+ exporter=exporter,
47
+ filename=filename,
48
+ )
49
+ """
50
+ if exporter == "custom":
51
+ from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
52
+ from experimental_experiment.xbuilder import OptimizationOptions
53
+
54
+ return _to_onnx(
55
+ mod,
56
+ args=args,
57
+ kwargs=kwargs,
58
+ input_names=input_names,
59
+ output_names=output_names,
60
+ target_opset=target_opset,
61
+ verbose=verbose,
62
+ filename=filename,
63
+ dynamic_shapes=dynamic_shapes,
64
+ large_model=True,
65
+ output_dynamic_shapes=output_dynamic_shapes,
66
+ options=OptimizationOptions(patterns="default+onnxruntime"),
67
+ )
68
+ if exporter in ("dynamo", "onnx-dynamo"):
69
+ import onnxscript.rewriter.ort_fusions as ort_fusions
70
+
71
+ assert (
72
+ not output_dynamic_shapes
73
+ ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
74
+ epo = torch.onnx.export(
75
+ mod,
76
+ args=args or tuple(),
77
+ kwargs=kwargs,
78
+ input_names=input_names,
79
+ output_names=output_names,
80
+ opset_version=target_opset,
81
+ dynamic_shapes=dynamic_shapes,
82
+ dynamo=True,
83
+ )
84
+ ort_fusions.optimize_for_ort(epo.model)
85
+ epo.save(filename)
86
+ return epo
87
+
88
+ if exporter == "modelbuilder":
89
+ import os
90
+ from ..helpers import flatten_object, string_type
91
+ from ..helpers.model_builder_helper import create_model_builder, save_model_builder
92
+
93
+ assert filename, f"filename must be specified for exporter={exporter!r}"
94
+ assert (
95
+ not output_dynamic_shapes
96
+ ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
97
+ assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}"
98
+ assert not args, f"only kwargs can be defined with exporter={exporter!r}"
99
+ assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], ( # type: ignore[arg-type]
100
+ f"Only a specified set of inputs is supported for exporter={exporter!r}, "
101
+ f"but it is {list(kwargs)}" # type: ignore[arg-type]
102
+ )
103
+ flat_inputs = flatten_object(kwargs, drop_keys=True)
104
+ first = flat_inputs[0]
105
+ first_float = [
106
+ t
107
+ for t in flat_inputs
108
+ if t.dtype in {torch.float32, torch.double, torch.float16, torch.bfloat16}
109
+ ]
110
+ assert first_float, (
111
+ f"Unable to find a float tensor in the inputs "
112
+ f"{string_type(kwargs, with_shape=True)}"
113
+ )
114
+ onx = create_model_builder(
115
+ mod.config,
116
+ mod,
117
+ precision=str(first_float[0].dtype).split(".")[-1],
118
+ execution_provider="cuda" if first.is_cuda else "cpu",
119
+ cache_dir=os.path.dirname(filename),
120
+ )
121
+ save_model_builder(onx, os.path.dirname(filename))
122
+ return onx
123
+
124
+ raise ValueError(f"Unknown exporter={exporter!r}")
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import itertools
2
3
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
3
4
  import numpy as np
4
5
  import torch
@@ -934,7 +935,7 @@ class ModelInputs:
934
935
  auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
935
936
  )
936
937
  )
937
- return [key_cache, value_cache]
938
+ return list(itertools.chain.from_iterable(zip(key_cache, value_cache)))
938
939
 
939
940
  raise NotImplementedError(
940
941
  f"Unable to build dynamic shapes for type {set_types.pop()}: "
@@ -1,6 +1,5 @@
1
1
  from typing import Any, Dict, List, Set, Optional, Tuple, Union
2
2
  from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
3
- from ..helpers.fake_tensor_helper import fake_reshape
4
3
  from .dynamic_shapes import ModelInputs
5
4
 
6
5
 
@@ -203,14 +202,49 @@ def guess_dynamic_shapes_from_inputs(
203
202
 
204
203
 
205
204
  def make_fake_with_dynamic_dimensions(
206
- x: Any,
207
- dynamic_shapes: Any,
208
- fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
209
- ) -> Tuple[Any, "FakeTensorMode"]: # noqa: F821
205
+ x: Any, dynamic_shapes: Any, context: Optional["FakeTensorContext"] = None # noqa: F821
206
+ ) -> Tuple[Any, "FakeTensorContext"]: # noqa: F821
210
207
  """
211
208
  Replaces all tensors by fake tensor respecting the same
212
209
  constraints as the following dynamic shapes.
213
210
  This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
211
+ Parameter ``context`` is used to reuse the same object when the dynamic
212
+ dimension is given the same name as another one.
213
+
214
+ A simple tensor:
215
+
216
+ .. runpython::
217
+ :showcode:
218
+
219
+ import torch
220
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
221
+ from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
222
+
223
+ inputs, _ = make_fake_with_dynamic_dimensions(
224
+ torch.rand((2, 3, 4, 5), dtype=torch.float32),
225
+ {0: "batch", 2: "cache_length"},
226
+ )
227
+ print(inputs)
228
+
229
+ Two tensors:
230
+
231
+ .. runpython::
232
+ :showcode:
233
+
234
+ import torch
235
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
236
+ from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
237
+
238
+ inputs, _ = make_fake_with_dynamic_dimensions(
239
+ (
240
+ torch.rand((2, 3, 4, 5), dtype=torch.float32),
241
+ torch.rand((2, 3, 4, 5), dtype=torch.float32),
242
+ ),
243
+ ({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
244
+ )
245
+ print(inputs)
246
+
247
+ With a cache:
214
248
 
215
249
  .. runpython::
216
250
  :showcode:
@@ -243,8 +277,10 @@ def make_fake_with_dynamic_dimensions(
243
277
  "attention_mask": {0: "batch", 1: "cache+seq"},
244
278
  "position_ids": {0: "batch", 1: "seq_length"},
245
279
  "past_key_values": [
246
- [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
247
- [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
280
+ {0: "batch", 2: "cache_length"},
281
+ {0: "batch", 2: "cache_length"},
282
+ {0: "batch", 2: "cache_length"},
283
+ {0: "batch", 2: "cache_length"},
248
284
  ],
249
285
  },
250
286
  )
@@ -252,68 +288,9 @@ def make_fake_with_dynamic_dimensions(
252
288
  """
253
289
  if x is None:
254
290
  return None, None
255
- if fake_mode is None:
256
- from torch.fx.experimental.symbolic_shapes import ShapeEnv
257
- from torch._subclasses.fake_tensor import FakeTensorMode
291
+ if context is None:
292
+ from ..helpers.fake_tensor_helper import FakeTensorContext
258
293
 
259
- shape_env = ShapeEnv()
260
- fake_mode = FakeTensorMode(shape_env=shape_env)
294
+ context = FakeTensorContext()
261
295
 
262
- if isinstance(x, (list, tuple)):
263
- return (
264
- x.__class__(
265
- [
266
- make_fake_with_dynamic_dimensions(
267
- i, fake_mode=fake_mode, dynamic_shapes=ds
268
- )[0]
269
- for i, ds in zip(x, dynamic_shapes)
270
- ]
271
- ),
272
- fake_mode,
273
- )
274
- if isinstance(x, dict):
275
- return {
276
- k: make_fake_with_dynamic_dimensions(
277
- v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
278
- )[0]
279
- for k, v in x.items()
280
- }, fake_mode
281
-
282
- if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
283
- assert hasattr(x, "layers"), (
284
- f"Une more recent version of transformers (>=4.55), "
285
- f"'layers' not found in class {type(x)}"
286
- )
287
- assert (
288
- isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2
289
- ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
290
- for il, layer in enumerate(x.layers):
291
- assert hasattr(layer, "keys") and hasattr(layer, "values"), (
292
- f"Une more recent version of transformers (>=4.55), 'layers' "
293
- f"not found in class {type(layer)} ({dir(layer)})"
294
- )
295
- layer.keys = make_fake_with_dynamic_dimensions(
296
- layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il]
297
- )[0]
298
- layer.values = make_fake_with_dynamic_dimensions(
299
- layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il]
300
- )[0]
301
- return x, fake_mode
302
- if x.__class__.__name__ == "EncoderDecoderCache":
303
- make_fake_with_dynamic_dimensions(
304
- x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
305
- )
306
- make_fake_with_dynamic_dimensions(
307
- x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
308
- )
309
- return x, fake_mode
310
- if hasattr(x, "shape"):
311
- t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
312
- assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
313
- assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
314
- return t, fake_mode
315
- from ..helpers import string_type
316
-
317
- raise TypeError(
318
- f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
319
- )
296
+ return context.make_fake_with_dynamic_dimensions(x, dynamic_shapes), context
@@ -630,6 +630,17 @@ def has_onnxruntime_training(push_back_batch: bool = False):
630
630
  return True
631
631
 
632
632
 
633
+ def has_onnxruntime_genai():
634
+ """Tells if onnxruntime_genai is installed."""
635
+ try:
636
+ import onnxruntime_genai # noqa: F401
637
+
638
+ return True
639
+ except ImportError:
640
+ # onnxruntime not training
641
+ return False
642
+
643
+
633
644
  def requires_onnxruntime_training(
634
645
  push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
635
646
  ) -> Callable:
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable, List, Optional, Tuple
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2
2
  import packaging.version as pv
3
3
  import torch
4
4
  import transformers
@@ -46,9 +46,14 @@ class CacheKeyValue:
46
46
  raise NotImplementedError(f"type(cache)={type(cache)}")
47
47
 
48
48
  def make_dynamic_cache(self):
49
- """Do the reverse operation."""
49
+ """Does the reverse operation."""
50
50
  return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
51
51
 
52
+ @property
53
+ def n_layers(self) -> int:
54
+ """Returns the number of layers."""
55
+ return len(self.key_cache) if self.key_cache else 0
56
+
52
57
 
53
58
  def flatten_unflatten_for_dynamic_shapes(
54
59
  obj: Any,
@@ -134,10 +139,31 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
134
139
  return len(cache2.key_cache) == len(cache.value_cache)
135
140
 
136
141
 
142
+ def make_dynamic_shapes_kv_cache(
143
+ cache: transformers.cache_utils.Cache, shape_of_one: Dict[int, Any]
144
+ ) -> List[Dict[int, Any]]:
145
+ """
146
+ Returns the dynamic shapes for key-value cache
147
+
148
+ :param cache: a cache
149
+ :param shape_of_one: shape of one element
150
+ :return: dynamic shapes
151
+ """
152
+ return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)]
153
+
154
+
155
+ def _preprocess_key_value_pairs(
156
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
157
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
158
+ if not key_value_pairs or isinstance(key_value_pairs[0], tuple):
159
+ return key_value_pairs
160
+ return list(zip(key_value_pairs[::2], key_value_pairs[1::2]))
161
+
162
+
137
163
  if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
138
164
 
139
165
  def make_dynamic_cache(
140
- key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
166
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
141
167
  ) -> transformers.cache_utils.DynamicCache:
142
168
  """
143
169
  Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -173,6 +199,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
173
199
  ``transformers>=4.56``. Before that version, only FakeTensor with static dimensions
174
200
  are supported.
175
201
  """
202
+ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
176
203
  if (
177
204
  key_value_pairs
178
205
  and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
@@ -212,7 +239,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
212
239
  else:
213
240
 
214
241
  def make_dynamic_cache(
215
- key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
242
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
216
243
  ) -> transformers.cache_utils.DynamicCache:
217
244
  """
218
245
  Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -244,6 +271,7 @@ else:
244
271
  )
245
272
  print(string_type(past_key_values, with_shape=True))
246
273
  """
274
+ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
247
275
  cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore
248
276
  for i, (key, value) in enumerate(key_value_pairs):
249
277
  cache.update(key, value, i)
@@ -251,7 +279,7 @@ else:
251
279
 
252
280
 
253
281
  def make_static_cache(
254
- key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
282
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
255
283
  max_cache_len: Optional[int] = None,
256
284
  ) -> transformers.cache_utils.DynamicCache:
257
285
  """
@@ -284,6 +312,7 @@ def make_static_cache(
284
312
  )
285
313
  print(string_type(past_key_values, with_shape=True))
286
314
  """
315
+ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
287
316
 
288
317
  class _config:
289
318
  def __init__(self):
@@ -426,9 +455,10 @@ def make_mamba_cache(
426
455
 
427
456
 
428
457
  def make_sliding_window_cache(
429
- key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
458
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
430
459
  ) -> transformers.cache_utils.SlidingWindowCache:
431
460
  "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
461
+ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
432
462
 
433
463
  class _config:
434
464
  def __init__(self):
@@ -481,7 +511,7 @@ def make_sliding_window_cache(
481
511
 
482
512
 
483
513
  def make_hybrid_cache(
484
- key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
514
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
485
515
  max_cache_len: Optional[int] = None,
486
516
  max_batch_size: Optional[int] = None,
487
517
  sliding_window: Optional[int] = None,
@@ -566,6 +596,7 @@ def make_hybrid_cache(
566
596
  self.key_cache.append(new_layer_key_cache)
567
597
  self.value_cache.append(new_layer_value_cache)
568
598
  """
599
+ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
569
600
  layer_types = None
570
601
  if key_value_pairs:
571
602
  assert (