onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +56 -3
  3. onnx_diagnostic/export/dynamic_shapes.py +24 -10
  4. onnx_diagnostic/export/shape_helper.py +6 -2
  5. onnx_diagnostic/ext_test_case.py +2 -0
  6. onnx_diagnostic/helpers/_log_helper.py +6 -6
  7. onnx_diagnostic/helpers/cache_helper.py +326 -18
  8. onnx_diagnostic/helpers/config_helper.py +10 -0
  9. onnx_diagnostic/helpers/helper.py +152 -11
  10. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  11. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  12. onnx_diagnostic/helpers/torch_helper.py +33 -11
  13. onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
  14. onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
  15. onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
  16. onnx_diagnostic/tasks/__init__.py +2 -0
  17. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  18. onnx_diagnostic/tasks/feature_extraction.py +7 -3
  19. onnx_diagnostic/tasks/fill_mask.py +6 -2
  20. onnx_diagnostic/tasks/image_classification.py +6 -2
  21. onnx_diagnostic/tasks/image_text_to_text.py +289 -62
  22. onnx_diagnostic/tasks/mask_generation.py +143 -0
  23. onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
  24. onnx_diagnostic/tasks/object_detection.py +6 -2
  25. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  26. onnx_diagnostic/tasks/summarization.py +7 -2
  27. onnx_diagnostic/tasks/text2text_generation.py +7 -2
  28. onnx_diagnostic/tasks/text_classification.py +6 -2
  29. onnx_diagnostic/tasks/text_generation.py +14 -16
  30. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
  31. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  32. onnx_diagnostic/torch_export_patches/patch_inputs.py +5 -2
  33. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
  34. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +428 -129
  35. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +60 -41
  36. onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
  37. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  38. onnx_diagnostic/torch_models/validate.py +1 -0
  39. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
  40. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +43 -42
  41. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """

- __version__ = "0.7.5"
+ __version__ = "0.7.7"
  __author__ = "Xavier Dupré"
@@ -306,7 +306,7 @@ class _ParseDict(argparse.Action):
              value = split_items[1]

              if value in ("True", "true", "False", "false"):
-                 d[key] = bool(value)
+                 d[key] = value in ("True", "true")
                  continue
              try:
                  d[key] = int(value)
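
Note on the fix above: `bool()` on a string tests emptiness, not meaning, so the old line parsed `--flag False` as `True`. A two-line illustration:

    # bool("False") is True because any non-empty string is truthy (the bug);
    # membership in ("True", "true") parses the text correctly (the fix).
    assert bool("False") is True
    assert ("False" in ("True", "true")) is False
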
@@ -323,6 +323,54 @@ class _ParseDict(argparse.Action):
          setattr(namespace, self.dest, d)


+ class _BoolOrParseDictPatch(argparse.Action):
+     def __call__(self, parser, namespace, values, option_string=None):
+
+         if not values:
+             return
+         if len(values) == 1 and values[0] in (
+             "True",
+             "False",
+             "true",
+             "false",
+             "0",
+             "1",
+             0,
+             1,
+         ):
+             setattr(namespace, self.dest, values[0] in ("True", "true", 1, "1"))
+             return
+         d = getattr(namespace, self.dest) or {}
+         if not isinstance(d, dict):
+             d = {
+                 "patch_sympy": d,
+                 "patch_torch": d,
+                 "patch_transformers": d,
+                 "patch_diffusers": d,
+             }
+         for item in values:
+             split_items = item.split("=", 1)
+             key = split_items[0].strip()  # we remove blanks around keys, as is logical
+             value = split_items[1]
+
+             if value in ("True", "true", "False", "false"):
+                 d[key] = value in ("True", "true")
+                 continue
+             try:
+                 d[key] = int(value)
+                 continue
+             except (TypeError, ValueError):
+                 pass
+             try:
+                 d[key] = float(value)
+                 continue
+             except (TypeError, ValueError):
+                 pass
+             d[key] = _parse_json(value)
+
+         setattr(namespace, self.dest, d)
+
+
  def get_parser_validate() -> ArgumentParser:
      parser = ArgumentParser(
          prog="validate",
@@ -383,8 +431,13 @@ def get_parser_validate() -> ArgumentParser:
      parser.add_argument(
          "--patch",
          default=True,
-         action=BooleanOptionalAction,
-         help="Applies patches before exporting.",
+         action=_BoolOrParseDictPatch,
+         nargs="*",
+         help="Applies patches before exporting, it can be a boolean "
+         "to enable to disable the patches or be more finetuned. It is possible to "
+         "disable patch for torch by adding "
+         '--patch "patch_sympy=False" --patch "patch_torch=False", '
+         "default is True.",
      )
      parser.add_argument(
          "--rewrite",
@@ -887,19 +887,30 @@ class ModelInputs:

          # In case DynamicCache is not registered.
          if obj.__class__.__name__ == "DynamicCache":
-             kc = set(len(o.key_cache) for o in objs)
-             assert (
-                 len(kc) == 1
-             ), f"All attribute 'key_cache' should have the same length but found {kc}"
-             vc = set(len(o.value_cache) for o in objs)
-             assert (
-                 len(vc) == 1
-             ), f"All attribute 'value_cache' should have the same length but found {vc}"
+             if hasattr(obj, "layers"):
+                 kc = set(len(o.layers) for o in objs)
+                 assert (
+                     len(kc) == 1
+                 ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                 vc = kc.copy()
+             else:
+                 kc = set(len(o.key_cache) for o in objs)
+                 assert (
+                     len(kc) == 1
+                 ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                 vc = set(len(o.value_cache) for o in objs)
+                 assert (
+                     len(vc) == 1
+                 ), f"All attribute 'value_cache' should have the same length but found {vc}"
+
              key_cache = []
              for i in range(kc.pop()):
                  key_cache.append(
                      self.guess_dynamic_dimensions(
-                         *[o.key_cache[i] for o in objs],
+                         *[
+                             o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
+                             for o in objs
+                         ],
                          auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
                      )
                  )
@@ -907,7 +918,10 @@ class ModelInputs:
              for i in range(vc.pop()):
                  value_cache.append(
                      self.guess_dynamic_dimensions(
-                         *[o.value_cache[i] for o in objs],
+                         *[
+                             o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
+                             for o in objs
+                         ],
                          auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
                      )
                  )
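
Both hunks repeat the same version dispatch; as a standalone, hedged sketch (`layers` is the transformers>=4.54 layout, `key_cache`/`value_cache` the legacy one; `cache_tensors` is a hypothetical helper name, not part of the package):

    def cache_tensors(cache, i):
        # return the (key, value) tensors of layer i across transformers versions
        if hasattr(cache, "layers"):  # newer DynamicCache: one layer object per layer
            return cache.layers[i].keys, cache.layers[i].values
        return cache.key_cache[i], cache.value_cache[i]  # legacy list-based layout
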
@@ -9,6 +9,8 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
      All dimensions are considered as dynamic.
      ``dim_prefix`` can be a string (the function uses it as a prefix),
      or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
+     Depending on the version of transformers, serializations function
+     of DynamicCache class is automatically serialized or not (>= 4.51, < 4.55).

      .. runpython::
          :showcode:
@@ -17,6 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
          import torch
          from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
          from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+         from onnx_diagnostic.torch_export_patches import torch_export_patches

          bsize, nheads, slen, dim = 2, 1, 30, 96
          inputs = dict(
@@ -25,10 +28,11 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
              position_ids=torch.arange(3, dtype=torch.int64),
              past_key_values=make_dynamic_cache(
                  [(torch.randn(bsize, nheads, slen, dim),
-                  torch.randn(bsize, nheads, slen, dim))]
+                   torch.randn(bsize, nheads, slen, dim))]
              ),
          )
-         ds = all_dynamic_shape_from_inputs(inputs)
+         with torch_export_patches(patch_transformers=True):
+             ds = all_dynamic_shape_from_inputs(inputs)
          pprint.pprint(ds)

      For this function to work, patches must be enabled if :epkg:`transformers`
@@ -1058,6 +1058,8 @@ class ExtTestCase(unittest.TestCase):
          elif hasattr(expected, "shape"):
              self.assertEqual(type(expected), type(value), msg=msg)
              self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
+         elif expected is None:
+             assert value is None, f"Expected is None but value is of type {type(value)}"
          else:
              raise AssertionError(
                  f"Comparison not implemented for types {type(expected)} and {type(value)}"
@@ -33,13 +33,13 @@ def mann_kendall(series: Sequence[float], threshold: float = 0.5):
      .. math::

          sign(x) = \\left\\{ \\begin{array}{l} -1 if x < 0 \\\\ 0 if x = 0 \\\\ +1 otherwise
-         \\right.
+         \\end{array} \\right.

      And:

      .. math::

-         Var(S)= \\frac{n(n-1)(2n+5)} - \\sum_t t(t-1)(2t+5)}{18}
+         Var(S)= \\frac{n(n-1)(2n+5) - \\sum_t t(t-1)(2t+5)}{18}
      """
      aseries = np.asarray(series)
      stat = 0
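
For context, the statistic S whose variance the corrected formula gives is the pairwise sign sum; a minimal sketch, independent of the library code:

    import numpy as np

    def mann_kendall_s(series):
        # S = sum over all pairs i < j of sign(x_j - x_i)
        x = np.asarray(series, dtype=float)
        n = len(x)
        return sum(np.sign(x[j] - x[i]) for i in range(n) for j in range(i + 1, n))

    print(mann_kendall_s([1.0, 2.0, 3.0, 2.5]))  # 4.0: mostly increasing
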
@@ -251,7 +251,7 @@ def open_dataframe(
  ) -> pandas.DataFrame:
      """
      Opens a filename defined by function
-     :func:`onnx_diagnostic.helpers.log_helper.enumerate_csv_files`.
+     :func:`onnx_diagnostic.helpers._log_helper.enumerate_csv_files`.

      :param data: a dataframe, a filename, a tuple indicating the file is coming
          from a zip file
@@ -260,17 +260,17 @@ def open_dataframe(
      if isinstance(data, pandas.DataFrame):
          return data
      if isinstance(data, str):
-         df = pandas.read_csv(data)
+         df = pandas.read_csv(data, low_memory=False)
          df["RAWFILENAME"] = data
          return df
      if isinstance(data, tuple):
          if not data[-1]:
-             df = pandas.read_csv(data[2])
+             df = pandas.read_csv(data[2], low_memory=False)
              df["RAWFILENAME"] = data[2]
              return df
          zf = zipfile.ZipFile(data[-1])
          with zf.open(data[2]) as f:
-             df = pandas.read_csv(f)
+             df = pandas.read_csv(f, low_memory=False)
              df["RAWFILENAME"] = f"{data[-1]}/{data[2]}"
          zf.close()
          return df
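
`low_memory=False` makes pandas infer each column's dtype from the whole file instead of chunk by chunk, which avoids `DtypeWarning` and chunk-dependent mixed dtypes on large CSVs. A small demonstration of the intent (the tiny buffer stands in for a large file):

    import io
    import pandas

    # a column that starts numeric but turns textual gets one consistent dtype
    csv = io.StringIO("a\n1\n2\nx\n")
    df = pandas.read_csv(csv, low_memory=False)
    print(df["a"].dtype)  # object
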
@@ -4,6 +4,56 @@ import torch
  import transformers
  import transformers.cache_utils

+ try:
+     from transformers.models.mamba.modeling_mamba import MambaCache
+ except ImportError:
+     from transformers.cache_utils import MambaCache
+
+
+ class CacheKeyValue:
+     """
+     Starting transformers>=4.54, the cache API has deprecated
+     ``cache.key_cache`` and ``cache.value_cache``.
+     This class wraps a cache independently from transformers version and enables
+     attributes ``key_cache`` and ``value_cache``.
+
+     .. code-block:: python
+
+         capi = CacheKeyValue(cache)
+         capi.key_cache
+         capi.value_cache
+     """
+
+     def __init__(self, cache=None):
+         if hasattr(cache, "layers"):
+             layers = [
+                 layer
+                 for layer in cache.layers
+                 if layer is not None and layer.keys is not None and layer.values is not None
+             ]
+             self.key_cache = [layer.keys for layer in layers]
+             self.value_cache = [layer.values for layer in layers]
+             if None in self.key_cache or None in self.value_cache:
+                 from .helper import string_type
+
+                 raise AssertionError(
+                     f"issue with key_cache={string_type(self.key_cache)}, "
+                     f"or value_cache={string_type(self.value_cache)}, "
+                     f"cache.layers={string_type(cache.layers)}"
+                 )
+         elif cache is not None and hasattr(cache, "key_cache"):
+             self.key_cache = cache.key_cache
+             self.value_cache = cache.value_cache
+         elif cache is None:
+             self.key_cache = None
+             self.value_cache = None
+         else:
+             raise NotImplementedError(f"type(cache)={type(cache)}")
+
+     def make_dynamic_cache(self):
+         """Do the reverse operation."""
+         return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+

  def flatten_unflatten_for_dynamic_shapes(
      obj: Any,
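
A short usage sketch of the new wrapper, relying only on helpers visible in this diff:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
    capi = CacheKeyValue(cache)  # same attributes whether the cache has .layers or .key_cache
    print(len(capi.key_cache), capi.key_cache[0].shape)
    rebuilt = capi.make_dynamic_cache()  # the reverse operation
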
@@ -81,6 +131,8 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
      )
      values, spec = torch.utils._pytree.tree_flatten(cache)
      cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+     if hasattr(cache2, "layers") and hasattr(cache, "layers"):
+         return len(cache2.layers) == len(cache.layers)
      return len(cache2.key_cache) == len(cache.value_cache)


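The check relies on a pytree flatten/unflatten round trip preserving the layer count; the same round trip in isolation:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(1, 1, 2, 2), torch.randn(1, 1, 2, 2))])
    values, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    # with DynamicCache serialization registered, cache2 keeps the layer count
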
@@ -119,7 +171,19 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
              )
              print(string_type(past_key_values, with_shape=True))
          """
-         return transformers.cache_utils.DynamicCache(key_value_pairs)
+         cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+         if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+             # The cache constructor contains the two following lines
+             # (in cache_utils.py) which append empty layers when the cache is
+             # initialized. We need to remove them.
+             # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+             # self.append_new_layers(self.num_hidden_layers - 1)
+             cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+         assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+             f"{len(key_value_pairs)} expected."
+         )
+         return finalize_cache(cache)

  else:

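The trimming above guarantees one cache entry per provided pair; a quick version-independent check:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))])
    n_layers = len(cache.layers) if hasattr(cache, "layers") else len(cache.key_cache)
    print(n_layers)  # 1, even when the constructor pre-allocates empty layers
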
@@ -203,6 +267,9 @@ def make_static_cache(
              self.num_attention_heads = key_value_pairs[0][0].shape[1]
              self.num_hidden_layers = len(key_value_pairs)

+         def get_text_config(self):
+             return self
+
      assert max_cache_len is not None, (
          f"max_cache_len={max_cache_len} cannot be setup "
          f"automatically yet from shape {key_value_pairs[0][0].shape}"
@@ -216,20 +283,59 @@ def make_static_cache(
          ),
      )
      cache = transformers.cache_utils.StaticCache(
-         _config(),
+         config=_config(),
          max_batch_size=key_value_pairs[0][0].shape[0],
          device=key_value_pairs[0][0].device,
          dtype=key_value_pairs[0][0].dtype,
          max_cache_len=max_cache_len,
      )
+     ca = CacheKeyValue(cache)
+     if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+         # transformers>= 4.55.2, layers are empty
+         for i, (key, value) in enumerate(key_value_pairs):
+             cache.update(key, value, i)
+         return cache
+
+     torch._check(
+         not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(cache.layers)={len(cache.layers)}"
+         ),
+     )
+     torch._check(
+         len(key_value_pairs) == len(ca.key_cache),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(ca.key_cache)={len(ca.key_cache)}"
+         ),
+     )
+     torch._check(
+         len(key_value_pairs) == len(ca.value_cache),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(ca.value_cache)={len(ca.value_cache)}"
+         ),
+     )
      for i in range(len(key_value_pairs)):
          assert (
              key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
          ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
          d = key_value_pairs[i][1].shape[2]
-         cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
-         cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
-     return cache
+         ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+         ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
+     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+         # The cache constructor contains the two following lines
+         # (in cache_utils.py) which append empty layers when the cache is
+         # initialized. We need to remove them.
+         # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+         # self.append_new_layers(self.num_hidden_layers - 1)
+         cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+     assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+         f"{len(key_value_pairs)} expected."
+     )
+     return finalize_cache(cache)


  def make_encoder_decoder_cache(
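
A hedged usage sketch for the updated `make_static_cache`; the `max_cache_len` keyword follows the assertion visible in the hunk, other parameters may differ:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    past = make_static_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))],
        max_cache_len=5,  # must be given; at least the sequence length of the pairs
    )
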
@@ -238,14 +344,15 @@ def make_encoder_decoder_cache(
  ) -> transformers.cache_utils.EncoderDecoderCache:
      """Creates an EncoderDecoderCache."""
      return transformers.cache_utils.EncoderDecoderCache(
-         self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
+         # self_attention_cache=self_attention_cache,
+         # cross_attention_cache=cross_attention_cache
+         self_attention_cache,
+         cross_attention_cache,
      )


- def make_mamba_cache(
-     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
- ) -> transformers.cache_utils.MambaCache:
-     "Creates a :class:`transformers.cache_utils.MambaCache`."
+ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+     "Creates a ``MambaCache``."
      dtype = key_value_pairs[0][0].dtype

      class _config:
@@ -256,7 +363,10 @@ def make_mamba_cache(
              self.num_hidden_layers = len(key_value_pairs)
              self.dtype = dtype

-     cache = transformers.cache_utils.MambaCache(
+         def get_text_config(self):
+             return self
+
+     cache = MambaCache(
          _config(),
          max_batch_size=key_value_pairs[0][0].shape[0],
          device=key_value_pairs[0][0].device,
@@ -281,12 +391,12 @@ def make_mamba_cache(
              f"got {key_value_pairs[i][1].shape}"
          )
          cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
-     return cache
+     return finalize_cache(cache)


  def make_sliding_window_cache(
      key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
- ) -> transformers.cache_utils.MambaCache:
+ ) -> transformers.cache_utils.SlidingWindowCache:
      "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."

      class _config:
@@ -296,22 +406,220 @@ def make_sliding_window_cache(
              self.num_hidden_layers = len(key_value_pairs)
              self.sliding_window = key_value_pairs[0][0].shape[2]

+         def get_text_config(self):
+             return self
+
      cache = transformers.cache_utils.SlidingWindowCache(
-         _config(),
+         config=_config(),
          max_batch_size=key_value_pairs[0][0].shape[0],
          max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
          device=key_value_pairs[0][0].device,
          dtype=key_value_pairs[0][0].dtype,
      )
+     ca = CacheKeyValue(cache)
+     if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+         # transformers>= 4.55.2, layers are empty
+         cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+         for i, (key, value) in enumerate(key_value_pairs):
+             cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+         return cache
+
      for i in range(len(key_value_pairs)):
-         assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
+         assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
              f"Shape mismatch, expected {cache.key_cache[i].shape}, "
              f"got {key_value_pairs[i][0].shape}"
          )
-         cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
-         assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
+         ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+         assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
              f"Shape mismatch, expected {cache.value_cache[i].shape}, "
              f"got {key_value_pairs[i][1].shape}"
          )
-         cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+         ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+         # The cache constructor contains the two following lines
+         # (in cache_utils.py) which append empty layers when the cache is
+         # initialized. We need to remove them.
+         # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+         # self.append_new_layers(self.num_hidden_layers - 1)
+         cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+     assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+         f"{len(key_value_pairs)} expected."
+     )
+     return finalize_cache(cache)
+
+
+ def make_hybrid_cache(
+     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+     max_cache_len: Optional[int] = None,
+     max_batch_size: Optional[int] = None,
+     sliding_window: Optional[int] = None,
+ ) -> transformers.cache_utils.HybridCache:
+     """
+     Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+     This version is valid for ``transformers < 4.50``.
+
+     :param key_value_pairs: list of pairs of (key, values)
+     :return: :class:`transformers.cache_utils.HybridCache`
+
+     Example:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.helpers import string_type
+         from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+
+         n_layers = 2
+         bsize, nheads, slen, dim = 2, 4, 3, 7
+
+         past_key_values = make_hybrid_cache(
+             [
+                 (
+                     torch.randn(bsize, nheads, slen, dim),
+                     torch.randn(bsize, nheads, slen, dim),
+                 )
+                 for i in range(n_layers)
+             ]
+         )
+         print(string_type(past_key_values, with_shape=True))
+
+     This part defines how the shapes are working in one HybridCache.
+
+     .. code-block:: python
+
+         self.max_cache_len = (
+             max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+
+         # Sliding layers can't be larger than the overall max cache len
+         self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+         self.max_batch_size = max_batch_size
+
+         self.head_dim = (
+             config.head_dim if hasattr(config, "head_dim")
+             else config.hidden_size // config.num_attention_heads
+         )
+
+         self._dtype = dtype
+         self.num_key_value_heads = (
+             config.num_attention_heads
+             if getattr(config, "num_key_value_heads", None) is None
+             else config.num_key_value_heads
+         )
+
+         # If the attribute does not exist in the config, fallback to a simple StaticCache
+         if hasattr(config, "layer_types"):
+             self.is_sliding = [
+                 layer_type != "full_attention" for layer_type in config.layer_types]
+         else:
+             self.is_sliding = [False] * config.num_hidden_layers
+
+         self.key_cache: list[torch.Tensor] = []
+         self.value_cache: list[torch.Tensor] = []
+         global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.max_cache_len, self.head_dim)
+         sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                self.sliding_window_len, self.head_dim)
+         self.sliding_window = min(config.sliding_window, max_cache_len)
+         device = torch.device(device) if device is not None else None
+         for i in range(config.num_hidden_layers):
+             layer_device = layer_device_map[i] if layer_device_map is not None else device
+             cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+             new_layer_key_cache = torch.zeros(
+                 cache_shape, dtype=self._dtype, device=layer_device)
+             new_layer_value_cache = torch.zeros(
+                 cache_shape, dtype=self._dtype, device=layer_device)
+             torch._dynamo.mark_static_address(new_layer_key_cache)
+             torch._dynamo.mark_static_address(new_layer_value_cache)
+             self.key_cache.append(new_layer_key_cache)
+             self.value_cache.append(new_layer_value_cache)
+     """
+     layer_types = None
+     if key_value_pairs:
+         assert (
+             not max_batch_size and not max_cache_len
+         ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+         max_batch_size = key_value_pairs[0][0].shape[0]
+         sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+         if len(sets_of_dim) == 1:
+             max_cache_len = sets_of_dim.pop()
+             sliding_window = max_cache_len
+         else:
+             assert (
+                 len(sets_of_dim) == 2
+             ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+             max_cache_len = max(sets_of_dim)
+             sliding_window = min(sets_of_dim)
+             layer_types = [
+                 "full_attention" if i == max_cache_len else "sliding_attention"
+                 for i in [kv[0].shape[2] for kv in key_value_pairs]
+             ]
+     else:
+         assert (
+             max_batch_size and max_cache_len
+         ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+         if sliding_window is None:
+             sliding_window = max_cache_len
+     _max_cache_len = max_cache_len
+     _sliding_window = sliding_window
+
+     class _config:
+         max_cache_len = _max_cache_len
+         batch_size = max_batch_size
+         num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+         head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+         num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+         num_hidden_layers = len(key_value_pairs)
+         sliding_window = _sliding_window
+         num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
+
+         def get_text_config(self):
+             return self
+
+     if layer_types:
+         _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+     cache = transformers.cache_utils.HybridCache(
+         config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
+     )
+     for i, (key, value) in enumerate(key_value_pairs):
+         cache.update(
+             key,
+             value,
+             i,
+             cache_kwargs={
+                 "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                     key.device
+                 )
+             },
+         )
+     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+         # The cache constructor contains the two following lines
+         # (in cache_utils.py) which append empty layers when the cache is
+         # initialized. We need to remove them.
+         # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+         # self.append_new_layers(self.num_hidden_layers - 1)
+         cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+     assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+         f"{len(key_value_pairs)} expected."
+     )
+     return finalize_cache(cache)
+
+
+ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
+     """
+     Ensures the created cache is consistent.
+     Returns the cache modified inplace.
+     """
+     if (
+         hasattr(cache, "layer_class_to_replicate")
+         and hasattr(cache, "layers")
+         and cache.layers
+         and not cache.layer_class_to_replicate
+     ):
+         # This is used to expand the cache when it does not contains enough layers.
+         # This is needed since transformers>4.55.3
+         cache.layer_class_to_replicate = cache.layers[0].__class__
      return cache
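
`finalize_cache` restores the cache's ability to grow: per the comment above, recent transformers (>4.55.3) use `layer_class_to_replicate` to append layers on demand, and trimming the pre-allocated layers can leave it unset. A hedged check:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4))])
    if hasattr(cache, "layer_class_to_replicate"):  # only on recent transformers
        print(cache.layer_class_to_replicate)  # class used to append missing layers
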
@@ -1,6 +1,7 @@
  import functools
  import importlib
  import inspect
+ import os
  import re
  from typing import Any, Callable, Dict, Optional, Tuple, Union
  import transformers
@@ -110,3 +111,12 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
      )
      cls_name = unique.pop()
      return getattr(transformers, cls_name)
+
+
+ def default_num_hidden_layers():
+     """
+     Returns the default number of layers.
+     It is lower when the unit tests are running
+     when ``UNITTEST_GOING=1``.
+     """
+     return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
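
The environment toggle in action (import path taken from the file list above; the variable must be set before the call):

    import os

    os.environ["UNITTEST_GOING"] = "1"
    from onnx_diagnostic.helpers.config_helper import default_num_hidden_layers

    print(default_num_hidden_layers())  # 2 under unit tests, 4 otherwise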