onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,687 @@
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ import packaging.version as pv
+ import torch
+ import transformers
+ import transformers.cache_utils
+
+
+ class CacheKeyValue:
+     """
+     Starting with transformers>=4.54, the cache API has deprecated
+     ``cache.key_cache`` and ``cache.value_cache``.
+     This class wraps a cache independently of the transformers version and exposes
+     attributes ``key_cache`` and ``value_cache``.
+
+     .. code-block:: python
+
+         capi = CacheKeyValue(cache)
+         capi.key_cache
+         capi.value_cache
+     """
+
+     def __init__(self, cache=None):
+         if hasattr(cache, "layers"):
+             layers = [
+                 layer
+                 for layer in cache.layers
+                 if layer is not None and layer.keys is not None and layer.values is not None
+             ]
+             self.key_cache = [layer.keys for layer in layers]
+             self.value_cache = [layer.values for layer in layers]
+             if None in self.key_cache or None in self.value_cache:
+                 from .helper import string_type
+
+                 raise AssertionError(
+                     f"issue with key_cache={string_type(self.key_cache)}, "
+                     f"or value_cache={string_type(self.value_cache)}, "
+                     f"cache.layers={string_type(cache.layers)}"
+                 )
+         elif cache is not None and hasattr(cache, "key_cache"):
+             self.key_cache = cache.key_cache
+             self.value_cache = cache.value_cache
+         elif cache is None:
+             self.key_cache = None
+             self.value_cache = None
+         else:
+             raise NotImplementedError(f"type(cache)={type(cache)}")
+
+     def make_dynamic_cache(self):
+         """Rebuilds a :class:`transformers.cache_utils.DynamicCache` from the wrapped tensors."""
+         return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+
+     @property
+     def n_layers(self) -> int:
+         """Returns the number of layers."""
+         return len(self.key_cache) if self.key_cache else 0
+
+
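The wrapper above is easiest to see in action. Below is a minimal usage sketch (assuming the wheel is installed as ``onnx_diagnostic``); it builds a small cache with ``make_dynamic_cache``, defined later in this file, and reads it back without touching the deprecated attributes:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    # Two layers of (key, value) tensors with shape (batch, heads, seq, dim).
    cache = make_dynamic_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    capi = CacheKeyValue(cache)
    print(capi.n_layers)            # 2
    print(capi.key_cache[0].shape)  # torch.Size([2, 4, 3, 7])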
+ def flatten_unflatten_for_dynamic_shapes(
+     obj: Any,
+     use_dict: bool = False,
+     change_function: Optional[Callable[[torch.Tensor], Any]] = None,
+ ) -> Any:
+     """
+     Returns the object in a different structure, similar to what
+     the definition of the dynamic shapes should use.
+
+     :param obj: object from a custom class
+     :param use_dict: this is closer to the original result, but
+         :func:`torch.export.export` only considers the values:
+         the context gives the dictionary keys, which are not expressed
+         in the dynamic shapes, and these specifications seem to differ
+         between strict and non-strict mode. It also preserves tuples.
+     :param change_function: modifies the tensors in the structure itself,
+         e.g. replaces them with their shapes
+     :return: the serialized object
+     """
+     if isinstance(obj, torch.Tensor):
+         return change_function(obj) if change_function else obj
+     flat, spec = torch.utils._pytree.tree_flatten(obj)
+     start = 0
+     end = 0
+     subtrees = []
+     for subspec in spec.children_specs:
+         end += subspec.num_leaves
+         value = subspec.unflatten(flat[start:end])
+         value = flatten_unflatten_for_dynamic_shapes(
+             value, use_dict=use_dict, change_function=change_function
+         )
+         subtrees.append(value)
+         start = end
+     if use_dict:
+         if spec.type is dict:
+             # This is a dictionary.
+             return dict(zip(spec.context, subtrees))
+         if spec.type is tuple:
+             return tuple(subtrees)
+         if spec.type is list:
+             return list(subtrees)
+         if spec.type is None and not subtrees:
+             return None
+         if spec.context:
+             # This is a custom class with attributes.
+             # It is returned as a list.
+             return list(subtrees)
+         raise ValueError(
+             f"Unable to interpret spec type {spec.type} "
+             f"(type is {type(spec.type)}, context is {spec.context}), "
+             f"spec={spec}, subtrees={subtrees}"
+         )
+     # This is a list.
+     return subtrees
+
+
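A sketch of what ``flatten_unflatten_for_dynamic_shapes`` produces on a plain dictionary of tensors, using ``change_function`` to replace each tensor by its shape (the input names are arbitrary examples):

    import torch
    from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes

    inputs = {
        "input_ids": torch.randint(0, 10, (2, 3)),
        "attention_mask": torch.ones(2, 3, dtype=torch.int64),
    }
    # With use_dict=True the dictionary keys are preserved; each tensor is
    # replaced by its shape, mirroring a dynamic shapes specification.
    structure = flatten_unflatten_for_dynamic_shapes(
        inputs, use_dict=True, change_function=lambda t: tuple(t.shape)
    )
    print(structure)  # {'input_ids': (2, 3), 'attention_mask': (2, 3)}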
+ def is_cache_dynamic_registered(fast: bool = False) -> bool:
+     """
+     Tells if class :class:`transformers.cache_utils.DynamicCache` can be
+     serialized and deserialized. Only then can :func:`torch.export.export`
+     export a model.
+
+     :param fast: if True, only checks the class is registered with pytree
+         and skips the serialization round-trip check
+     :return: result
+     """
+     if fast:
+         return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES
+     bsize, nheads, slen, dim = 2, 4, 3, 7
+     cache = make_dynamic_cache(
+         [
+             (
+                 torch.randn(bsize, nheads, slen, dim),
+                 torch.randn(bsize, nheads, slen, dim),
+             )
+             for i in range(2)
+         ]
+     )
+     values, spec = torch.utils._pytree.tree_flatten(cache)
+     cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+     if hasattr(cache2, "layers") and hasattr(cache, "layers"):
+         return len(cache2.layers) == len(cache.layers)
+     return len(cache2.key_cache) == len(cache.key_cache)
+
+
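A small guard sketch using the fast path before attempting an export (the error message is illustrative):

    from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

    # Fast path: only checks that DynamicCache is registered with pytree,
    # without running the flatten/unflatten round trip.
    if not is_cache_dynamic_registered(fast=True):
        raise RuntimeError(
            "DynamicCache is not serializable; torch.export.export "
            "would fail on a model taking a cache as input."
        )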
+ def make_dynamic_shapes_kv_cache(
+     cache: transformers.cache_utils.Cache, shape_of_one: Dict[int, Any]
+ ) -> List[Dict[int, Any]]:
+     """
+     Returns the dynamic shapes for a key-value cache.
+
+     :param cache: a cache
+     :param shape_of_one: shape specification of one cache tensor
+     :return: dynamic shapes
+     """
+     return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)]
+
+
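Since a cache with ``n_layers`` layers flattens into ``n_layers * 2`` tensors, the helper simply repeats the specification; a sketch with hypothetical dimension names:

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        CacheKeyValue,
        make_dynamic_cache,
        make_dynamic_shapes_kv_cache,
    )

    cache = make_dynamic_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    # One specification per cache tensor: 2 layers * 2 = 4 entries.
    ds = make_dynamic_shapes_kv_cache(cache, {0: "batch", 2: "seq_length"})
    assert len(ds) == CacheKeyValue(cache).n_layers * 2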
+ def _preprocess_key_value_pairs(
+     key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+     if not key_value_pairs or isinstance(key_value_pairs[0], tuple):
+         return key_value_pairs
+     return list(zip(key_value_pairs[::2], key_value_pairs[1::2]))
+
+
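This private helper lets every ``make_*_cache`` function accept either a list of pairs or a flat ``[key0, value0, key1, value1, ...]`` list; a quick sketch of the two equivalent calls:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    k0, v0, k1, v1 = (torch.randn(2, 4, 3, 7) for _ in range(4))
    # Both calls build the same two-layer cache; the flat list is regrouped
    # into pairs by _preprocess_key_value_pairs.
    c1 = make_dynamic_cache([(k0, v0), (k1, v1)])
    c2 = make_dynamic_cache([k0, v0, k1, v1])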
+ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
+
+     def make_dynamic_cache(
+         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+     ) -> transformers.cache_utils.DynamicCache:
+         """
+         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
+         This version is valid for ``transformers >= 4.50``.
+
+         :param key_value_pairs: list of pairs of (key, values)
+         :return: :class:`transformers.cache_utils.DynamicCache`
+
+         Example:
+
+         .. runpython::
+             :showcode:
+
+             import torch
+             from onnx_diagnostic.helpers import string_type
+             from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+
+             n_layers = 2
+             bsize, nheads, slen, dim = 2, 4, 3, 7
+
+             past_key_values = make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(bsize, nheads, slen, dim),
+                         torch.randn(bsize, nheads, slen, dim),
+                     )
+                     for i in range(n_layers)
+                 ]
+             )
+             print(string_type(past_key_values, with_shape=True))
+
+         The function fully handles ``FakeTensor`` with dynamic dimensions if
+         ``transformers>=4.56``. Before that version, only FakeTensors with
+         static dimensions are supported.
+         """
+         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+         if (
+             key_value_pairs
+             and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
+             and pv.Version(transformers.__version__) >= pv.Version("4.56")
+         ):
+             cache = transformers.cache_utils.DynamicCache()
+             cache.layers.extend(
+                 [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
+             )
+             for i, layer in enumerate(cache.layers):
+                 k, v = key_value_pairs[i][0], key_value_pairs[i][1]
+                 layer.dtype = k.dtype
+                 layer.device = k.device
+                 layer.keys = k
+                 layer.values = v
+                 layer.is_initialized = True
+             assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+                 f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+                 f"{len(key_value_pairs)} expected."
+             )
+             return finalize_cache(cache)
+
+         cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+         if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+             # The cache constructor contains the following two lines
+             # (in cache_utils.py) which append empty layers when the cache is
+             # initialized. We need to remove them.
+             # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+             # self.append_new_layers(self.num_hidden_layers - 1)
+             cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+         assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+             f"{len(key_value_pairs)} expected."
+         )
+         return finalize_cache(cache)
+
+ else:
+
+     def make_dynamic_cache(
+         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+     ) -> transformers.cache_utils.DynamicCache:
+         """
+         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
+         This version is valid for ``transformers < 4.50``.
+
+         :param key_value_pairs: list of pairs of (key, values)
+         :return: :class:`transformers.cache_utils.DynamicCache`
+
+         Example:
+
+         .. runpython::
+             :showcode:
+
+             import torch
+             from onnx_diagnostic.helpers import string_type
+             from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+
+             n_layers = 2
+             bsize, nheads, slen, dim = 2, 4, 3, 7
+
+             past_key_values = make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(bsize, nheads, slen, dim),
+                         torch.randn(bsize, nheads, slen, dim),
+                     )
+                     for i in range(n_layers)
+                 ]
+             )
+             print(string_type(past_key_values, with_shape=True))
+         """
+         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+         cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
+         for i, (key, value) in enumerate(key_value_pairs):
+             cache.update(key, value, i)
+         return cache
+
+
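Whichever branch is selected at import time, the result can be rebuilt from the wrapper defined at the top of the file; a round-trip sanity sketch:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    cache = make_dynamic_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    # CacheKeyValue.make_dynamic_cache rebuilds an equivalent cache from
    # the wrapped key/value tensors.
    cache2 = CacheKeyValue(cache).make_dynamic_cache()
    assert CacheKeyValue(cache2).n_layers == CacheKeyValue(cache).n_layers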
+ def make_static_cache(
+     key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+     max_cache_len: Optional[int] = None,
+ ) -> transformers.cache_utils.StaticCache:
+     """
+     Creates an instance of :class:`transformers.cache_utils.StaticCache`.
+
+     :param key_value_pairs: list of pairs of (key, values)
+     :param max_cache_len: maximum cache length; currently required,
+         it cannot be inferred from the tensors yet
+     :return: :class:`transformers.cache_utils.StaticCache`
+
+     Example:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.helpers import string_type
+         from onnx_diagnostic.helpers.cache_helper import make_static_cache
+
+         n_layers = 2
+         bsize, nheads, slen, dim = 2, 4, 3, 7
+
+         past_key_values = make_static_cache(
+             [
+                 (
+                     torch.randn(bsize, nheads, slen, dim),
+                     torch.randn(bsize, nheads, slen, dim),
+                 )
+                 for i in range(n_layers)
+             ],
+             max_cache_len=10,
+         )
+         print(string_type(past_key_values, with_shape=True))
+     """
+     key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+
+     class _config:
+         def __init__(self):
+             self.head_dim = key_value_pairs[0][0].shape[-1]
+             self.num_attention_heads = key_value_pairs[0][0].shape[1]
+             self.num_hidden_layers = len(key_value_pairs)
+
+         def get_text_config(self, *args, **kwargs):
+             return self
+
+     assert max_cache_len is not None, (
+         f"max_cache_len={max_cache_len} cannot be set up "
+         f"automatically yet from shape {key_value_pairs[0][0].shape}"
+     )
+     torch._check(
+         max_cache_len >= key_value_pairs[0][0].shape[2],
+         (
+             f"max_cache_len={max_cache_len} cannot be smaller than "
+             f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
+             f"{key_value_pairs[0][0].shape}"
+         ),
+     )
+     cache = transformers.cache_utils.StaticCache(
+         config=_config(),
+         max_batch_size=key_value_pairs[0][0].shape[0],
+         device=key_value_pairs[0][0].device,
+         dtype=key_value_pairs[0][0].dtype,
+         max_cache_len=max_cache_len,
+     )
+     ca = CacheKeyValue(cache)
+     if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+         # transformers>=4.55.2: layers are empty at construction time
+         for i, (key, value) in enumerate(key_value_pairs):
+             cache.update(key, value, i)
+         return cache
+
+     torch._check(
+         not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(cache.layers)={len(cache.layers)}"
+         ),
+     )
+     torch._check(
+         len(key_value_pairs) == len(ca.key_cache),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(ca.key_cache)={len(ca.key_cache)}"
+         ),
+     )
+     torch._check(
+         len(key_value_pairs) == len(ca.value_cache),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(ca.value_cache)={len(ca.value_cache)}"
+         ),
+     )
+     for i in range(len(key_value_pairs)):
+         assert (
+             key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
+         ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
+         d = key_value_pairs[i][1].shape[2]
+         ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+         ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
+     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+         # The cache constructor contains the following two lines
+         # (in cache_utils.py) which append empty layers when the cache is
+         # initialized. We need to remove them.
+         # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+         # self.append_new_layers(self.num_hidden_layers - 1)
+         cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+     assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+         f"{len(key_value_pairs)} expected."
+     )
+     return finalize_cache(cache)
+
+
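Unlike the dynamic case, the static cache pre-allocates buffers of length ``max_cache_len`` and copies the given tensors into the first ``shape[2]`` positions. A sketch checking the padded buffer (the printed shape assumes the buffers are allocated at ``max_cache_len``, which is the point of a static cache):

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_static_cache

    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_static_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))],
        max_cache_len=10,
    )
    ca = CacheKeyValue(cache)
    # Buffers cover max_cache_len positions (10), not just slen (3).
    print(ca.key_cache[0].shape)  # torch.Size([2, 4, 10, 7])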
+ def make_encoder_decoder_cache(
+     self_attention_cache: transformers.cache_utils.DynamicCache,
+     cross_attention_cache: transformers.cache_utils.DynamicCache,
+ ) -> transformers.cache_utils.EncoderDecoderCache:
+     """Creates an EncoderDecoderCache."""
+     return transformers.cache_utils.EncoderDecoderCache(
+         # self_attention_cache=self_attention_cache,
+         # cross_attention_cache=cross_attention_cache
+         self_attention_cache,
+         cross_attention_cache,
+     )
+
+
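A sketch combining a self-attention and a cross-attention cache the way an encoder-decoder model expects them (the sequence lengths 3 and 5 are arbitrary):

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        make_dynamic_cache,
        make_encoder_decoder_cache,
    )

    self_attn = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
    cross_attn = make_dynamic_cache([(torch.randn(2, 4, 5, 7), torch.randn(2, 4, 5, 7))])
    cache = make_encoder_decoder_cache(self_attn, cross_attn)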
+ def make_mamba_cache(
+     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+ ) -> "MambaCache":  # noqa: F821
+     "Creates a ``MambaCache``."
+     # The import is local because it is slow.
+     try:
+         from transformers.models.mamba.modeling_mamba import MambaCache
+     except ImportError:
+         from transformers.cache_utils import MambaCache
+     dtype = key_value_pairs[0][0].dtype
+
+     class _config:
+         def __init__(self):
+             self.intermediate_size = key_value_pairs[0][0].shape[1]
+             self.conv_kernel = key_value_pairs[0][0].shape[-1]
+             self.state_size = key_value_pairs[0][1].shape[-1]
+             self.num_hidden_layers = len(key_value_pairs)
+             self.dtype = dtype
+
+         def get_text_config(self, *args, **kwargs):
+             return self
+
+     cache = MambaCache(
+         _config(),
+         max_batch_size=key_value_pairs[0][0].shape[0],
+         device=key_value_pairs[0][0].device,
+         dtype=dtype,
+     )
+     for i in range(len(key_value_pairs)):
+         assert cache.conv_states[i].dtype == dtype, (
+             f"Type mismatch for cache.conv_states[{i}].dtype="
+             f"{cache.conv_states[i].dtype} != {dtype}"
+         )
+         assert cache.ssm_states[i].dtype == dtype, (
+             f"Type mismatch for cache.ssm_states[{i}].dtype="
+             f"{cache.ssm_states[i].dtype} != {dtype}"
+         )
+         assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
+             f"Shape mismatch, expected {cache.conv_states[i].shape}, "
+             f"got {key_value_pairs[i][0].shape}"
+         )
+         cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
+         assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, (
+             f"Shape mismatch, expected {cache.ssm_states[i].shape}, "
+             f"got {key_value_pairs[i][1].shape}"
+         )
+         cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
+     return finalize_cache(cache)
+
+
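Here the pairs are not attention keys and values: for each layer, the first tensor fills ``conv_states`` (shape ``(batch, intermediate_size, conv_kernel)``) and the second fills ``ssm_states`` (shape ``(batch, intermediate_size, state_size)``). A sketch with hypothetical sizes:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_mamba_cache

    batch, intermediate, conv_kernel, state_size = 2, 16, 4, 8
    cache = make_mamba_cache(
        [
            (
                torch.randn(batch, intermediate, conv_kernel),  # conv state
                torch.randn(batch, intermediate, state_size),   # ssm state
            )
            for _ in range(2)
        ]
    )
    print(cache.conv_states[0].shape)  # torch.Size([2, 16, 4])
    print(cache.ssm_states[0].shape)   # torch.Size([2, 16, 8])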
+ def make_sliding_window_cache(
+     key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+ ) -> transformers.cache_utils.SlidingWindowCache:
+     "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+     key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+
+     class _config:
+         def __init__(self):
+             self.head_dim = key_value_pairs[0][0].shape[-1]
+             self.num_attention_heads = key_value_pairs[0][0].shape[1]
+             self.num_hidden_layers = len(key_value_pairs)
+             self.sliding_window = key_value_pairs[0][0].shape[2]
+
+         def get_text_config(self, *args, **kwargs):
+             return self
+
+     cache = transformers.cache_utils.SlidingWindowCache(
+         config=_config(),
+         max_batch_size=key_value_pairs[0][0].shape[0],
+         max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
+         device=key_value_pairs[0][0].device,
+         dtype=key_value_pairs[0][0].dtype,
+     )
+     ca = CacheKeyValue(cache)
+     if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+         # transformers>=4.55.2: layers are empty at construction time
+         cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+         for i, (key, value) in enumerate(key_value_pairs):
+             cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+         return cache
+
+     for i in range(len(key_value_pairs)):
+         assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
+             f"Shape mismatch, expected {ca.key_cache[i].shape}, "
+             f"got {key_value_pairs[i][0].shape}"
+         )
+         ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+         assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
+             f"Shape mismatch, expected {ca.value_cache[i].shape}, "
+             f"got {key_value_pairs[i][1].shape}"
+         )
+         ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+         # The cache constructor contains the following two lines
+         # (in cache_utils.py) which append empty layers when the cache is
+         # initialized. We need to remove them.
+         # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+         # self.append_new_layers(self.num_hidden_layers - 1)
+         cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+     assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+         f"{len(key_value_pairs)} expected."
+     )
+     return finalize_cache(cache)
+
+
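The window length is read from ``shape[2]`` of the first key, so the provided tensors must already be truncated to the sliding window; a short sketch:

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        CacheKeyValue,
        make_sliding_window_cache,
    )

    # Window length = shape[2] = 3 here.
    cache = make_sliding_window_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    assert CacheKeyValue(cache).n_layers == 2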
+ def make_hybrid_cache(
+     key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+     max_cache_len: Optional[int] = None,
+     max_batch_size: Optional[int] = None,
+     sliding_window: Optional[int] = None,
+ ) -> transformers.cache_utils.HybridCache:
+     """
+     Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+
+     :param key_value_pairs: list of pairs of (key, values)
+     :param max_cache_len: maximum cache length, required when
+         ``key_value_pairs`` is empty, inferred from the shapes otherwise
+     :param max_batch_size: maximum batch size, required when
+         ``key_value_pairs`` is empty, inferred from the shapes otherwise
+     :param sliding_window: sliding window length, inferred when not given
+     :return: :class:`transformers.cache_utils.HybridCache`
+
+     Example:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.helpers import string_type
+         from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+
+         n_layers = 2
+         bsize, nheads, slen, dim = 2, 4, 3, 7
+
+         past_key_values = make_hybrid_cache(
+             [
+                 (
+                     torch.randn(bsize, nheads, slen, dim),
+                     torch.randn(bsize, nheads, slen, dim),
+                 )
+                 for i in range(n_layers)
+             ]
+         )
+         print(string_type(past_key_values, with_shape=True))
+
+     This excerpt from the HybridCache constructor (in ``cache_utils.py``)
+     shows how the shapes are computed:
+
+     .. code-block:: python
+
+         self.max_cache_len = (
+             max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+
+         # Sliding layers can't be larger than the overall max cache len
+         self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+         self.max_batch_size = max_batch_size
+
+         self.head_dim = (
+             config.head_dim if hasattr(config, "head_dim")
+             else config.hidden_size // config.num_attention_heads
+         )
+
+         self._dtype = dtype
+         self.num_key_value_heads = (
+             config.num_attention_heads
+             if getattr(config, "num_key_value_heads", None) is None
+             else config.num_key_value_heads
+         )
+
+         # If the attribute does not exist in the config, fallback to a simple StaticCache
+         if hasattr(config, "layer_types"):
+             self.is_sliding = [
+                 layer_type != "full_attention" for layer_type in config.layer_types]
+         else:
+             self.is_sliding = [False] * config.num_hidden_layers
+
+         self.key_cache: list[torch.Tensor] = []
+         self.value_cache: list[torch.Tensor] = []
+         global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.max_cache_len, self.head_dim)
+         sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                self.sliding_window_len, self.head_dim)
+         self.sliding_window = min(config.sliding_window, max_cache_len)
+         device = torch.device(device) if device is not None else None
+         for i in range(config.num_hidden_layers):
+             layer_device = layer_device_map[i] if layer_device_map is not None else device
+             cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+             new_layer_key_cache = torch.zeros(
+                 cache_shape, dtype=self._dtype, device=layer_device)
+             new_layer_value_cache = torch.zeros(
+                 cache_shape, dtype=self._dtype, device=layer_device)
+             torch._dynamo.mark_static_address(new_layer_key_cache)
+             torch._dynamo.mark_static_address(new_layer_value_cache)
+             self.key_cache.append(new_layer_key_cache)
+             self.value_cache.append(new_layer_value_cache)
+     """
+     key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+     layer_types = None
+     if key_value_pairs:
+         assert (
+             not max_batch_size and not max_cache_len
+         ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+         max_batch_size = key_value_pairs[0][0].shape[0]
+         sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+         if len(sets_of_dim) == 1:
+             max_cache_len = sets_of_dim.pop()
+             sliding_window = max_cache_len
+         else:
+             assert (
+                 len(sets_of_dim) == 2
+             ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+             max_cache_len = max(sets_of_dim)
+             sliding_window = min(sets_of_dim)
+             layer_types = [
+                 "full_attention" if dim == max_cache_len else "sliding_attention"
+                 for dim in [kv[0].shape[2] for kv in key_value_pairs]
+             ]
+     else:
+         assert (
+             max_batch_size and max_cache_len
+         ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+         if sliding_window is None:
+             sliding_window = max_cache_len
+     _max_cache_len = max_cache_len
+     _sliding_window = sliding_window
+
+     class _config:
+         max_cache_len = _max_cache_len
+         batch_size = max_batch_size
+         num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+         head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+         num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+         num_hidden_layers = len(key_value_pairs)
+         sliding_window = _sliding_window
+         num_key_value_heads = (
+             key_value_pairs[0][1].shape[1] if key_value_pairs else None
+         )  # transformers 4.48.3
+
+         def get_text_config(self, *args, **kwargs):
+             return self
+
+     if layer_types:
+         _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+     cache = transformers.cache_utils.HybridCache(
+         config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
+     )
+     for i, (key, value) in enumerate(key_value_pairs):
+         cache.update(
+             key,
+             value,
+             i,
+             cache_kwargs={
+                 "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                     key.device
+                 )
+             },
+         )
+     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+         # The cache constructor contains the following two lines
+         # (in cache_utils.py) which append empty layers when the cache is
+         # initialized. We need to remove them.
+         # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+         # self.append_new_layers(self.num_hidden_layers - 1)
+         cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+     assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+         f"{len(key_value_pairs)} expected."
+     )
+     return finalize_cache(cache)
+
+
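When the pairs mix exactly two sequence lengths, the longer one becomes ``max_cache_len`` (those layers get ``full_attention``) and the shorter one becomes the ``sliding_window``; a sketch of that inference:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    # Layer 0 is inferred as full_attention (length 8),
    # layer 1 as sliding_attention (window 4).
    cache = make_hybrid_cache(
        [
            (torch.randn(2, 4, 8, 7), torch.randn(2, 4, 8, 7)),
            (torch.randn(2, 4, 4, 7), torch.randn(2, 4, 4, 7)),
        ]
    )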
+ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
+     """
+     Ensures the created cache is consistent.
+     Returns the cache modified in place.
+     """
+     if (
+         hasattr(cache, "layer_class_to_replicate")
+         and hasattr(cache, "layers")
+         and cache.layers
+         and not cache.layer_class_to_replicate
+     ):
+         # This attribute is used to expand the cache when it does not contain
+         # enough layers. It is needed since transformers>4.55.3.
+         cache.layer_class_to_replicate = cache.layers[0].__class__
+     return cache
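A last sketch showing the effect of ``finalize_cache``; the attribute only exists with the layered cache API of recent transformers, hence the guard:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
    # On transformers>4.55.3, finalize_cache fills layer_class_to_replicate so
    # the cache can grow new layers of the right class when the model needs them.
    if hasattr(cache, "layer_class_to_replicate"):
        print(cache.layer_class_to_replicate)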