onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +66 -8
  3. onnx_diagnostic/ext_test_case.py +2 -0
  4. onnx_diagnostic/helpers/_log_helper.py +461 -0
  5. onnx_diagnostic/helpers/cache_helper.py +250 -15
  6. onnx_diagnostic/helpers/helper.py +146 -10
  7. onnx_diagnostic/helpers/log_helper.py +404 -315
  8. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  9. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  10. onnx_diagnostic/helpers/torch_helper.py +33 -11
  11. onnx_diagnostic/tasks/__init__.py +2 -0
  12. onnx_diagnostic/tasks/feature_extraction.py +86 -5
  13. onnx_diagnostic/tasks/image_text_to_text.py +260 -56
  14. onnx_diagnostic/tasks/mask_generation.py +139 -0
  15. onnx_diagnostic/tasks/text2text_generation.py +2 -2
  16. onnx_diagnostic/tasks/text_generation.py +6 -2
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
  18. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  19. onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
  21. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
  24. onnx_diagnostic/torch_models/validate.py +26 -3
  25. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/cache_helper.py
@@ -4,6 +4,51 @@ import torch
 import transformers
 import transformers.cache_utils
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
+
+class CacheKeyValue:
+    """
+    Starting with transformers>=4.54, the cache API deprecates
+    ``cache.key_cache`` and ``cache.value_cache``.
+    This class wraps a cache independently of the transformers version and
+    exposes the attributes ``key_cache`` and ``value_cache``.
+
+    .. code-block:: python
+
+        capi = CacheKeyValue(cache)
+        capi.key_cache
+        capi.value_cache
+    """
+
+    def __init__(self, cache=None):
+        if hasattr(cache, "layers"):
+            layers = [
+                layer
+                for layer in cache.layers
+                if layer is not None and layer.keys is not None and layer.values is not None
+            ]
+            self.key_cache = [layer.keys for layer in layers]
+            self.value_cache = [layer.values for layer in layers]
+            if None in self.key_cache or None in self.value_cache:
+                from .helper import string_type
+
+                raise AssertionError(
+                    f"issue with key_cache={string_type(self.key_cache)}, "
+                    f"or value_cache={string_type(self.value_cache)}, "
+                    f"cache.layers={string_type(cache.layers)}"
+                )
+        elif cache is not None:
+            self.key_cache = cache.key_cache
+            self.value_cache = cache.value_cache
+
+    def make_dynamic_cache(self):
+        """Does the reverse operation."""
+        return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+
 
 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
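
A minimal usage sketch of the new wrapper (not part of the diff), adapted from the docstring above; the shapes are illustrative and make_dynamic_cache comes from the same module:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    # Two layers of (key, value) tensors shaped (batch, heads, seq, dim).
    cache = make_dynamic_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    capi = CacheKeyValue(cache)          # works with or without cache.layers
    print(len(capi.key_cache))           # 2
    cache2 = capi.make_dynamic_cache()   # round-trip back to a DynamicCache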
@@ -119,7 +164,19 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             )
             print(string_type(past_key_values, with_shape=True))
         """
-        return transformers.cache_utils.DynamicCache(key_value_pairs)
+        cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return cache
 
 else:
 
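The slice assignment is the key trick here: recent DynamicCache constructors pre-allocate one empty layer per config.num_hidden_layers, and cache.layers[:] = cache.layers[-n:] drops those placeholders in place, without rebinding the list object the cache holds. A toy illustration of the list semantics:

    layers = ["empty", "empty", "kv0", "kv1"]  # stand-ins for cache layers
    alias = layers                             # what the cache object would hold
    layers[:] = layers[-2:]                    # in-place trim, alias still sees it
    print(alias)                               # ['kv0', 'kv1']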
@@ -216,19 +273,31 @@ def make_static_cache(
         ),
     )
     cache = transformers.cache_utils.StaticCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
         max_cache_len=max_cache_len,
     )
+    ca = CacheKeyValue(cache)
     for i in range(len(key_value_pairs)):
         assert (
             key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
         ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
         d = key_value_pairs[i][1].shape[2]
-        cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
-        cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
+        ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+        ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
     return cache
 
 
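For reference, a hedged usage sketch of the updated function; the max_cache_len keyword is assumed from the body above, and buffer positions beyond the copied sequence length stay zero:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    past_key_values = make_static_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)],
        max_cache_len=5,
    )
    # Each preallocated buffer spans the full max_cache_len (5 positions here);
    # only the first 3 positions along the sequence axis were filled above.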
@@ -242,10 +311,8 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(
-    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.MambaCache:
-    "Creates a :class:`transformers.cache_utils.MambaCache`."
+def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+    "Creates a ``MambaCache``."
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
@@ -256,7 +323,7 @@ def make_mamba_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype
 
-    cache = transformers.cache_utils.MambaCache(
+    cache = MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
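
The try/except at the top of the file (first hunk) is what makes this bare MambaCache reference resolve; a quick check of where the class lives in the installed transformers:

    try:  # newer transformers moved MambaCache next to the mamba model
        from transformers.models.mamba.modeling_mamba import MambaCache
    except ImportError:  # older releases keep it in cache_utils
        from transformers.cache_utils import MambaCache

    print(MambaCache.__module__)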
@@ -286,7 +353,7 @@ def make_mamba_cache(
 
 def make_sliding_window_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.MambaCache:
+) -> transformers.cache_utils.SlidingWindowCache:
     "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
 
     class _config:
@@ -297,21 +364,189 @@ def make_sliding_window_cache(
             self.sliding_window = key_value_pairs[0][0].shape[2]
 
     cache = transformers.cache_utils.SlidingWindowCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
     )
+    ca = CacheKeyValue(cache)
     for i in range(len(key_value_pairs)):
-        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
+        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.key_cache[i].shape}, "
             f"got {key_value_pairs[i][0].shape}"
         )
-        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
-        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
+        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
             f"Shape mismatch, expected {cache.value_cache[i].shape}, "
             f"got {key_value_pairs[i][1].shape}"
         )
-        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
+    return cache
+
+
+def make_hybrid_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    max_cache_len: Optional[int] = None,
+    max_batch_size: Optional[int] = None,
+    sliding_window: Optional[int] = None,
+) -> transformers.cache_utils.HybridCache:
+    """
+    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+    This version is valid for ``transformers < 4.50``.
+
+    :param key_value_pairs: list of pairs of (key, value)
+    :return: :class:`transformers.cache_utils.HybridCache`
+
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_hybrid_cache(
+            [
+                (
+                    torch.randn(bsize, nheads, slen, dim),
+                    torch.randn(bsize, nheads, slen, dim),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        print(string_type(past_key_values, with_shape=True))
+
+    The following excerpt from the ``HybridCache`` constructor shows how the
+    shapes are determined:
+
+    .. code-block:: python
+
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size
+
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
+
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
+    """
+    layer_types = None
+    if key_value_pairs:
+        assert (
+            not max_batch_size and not max_cache_len
+        ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+        max_batch_size = key_value_pairs[0][0].shape[0]
+        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+        if len(sets_of_dim) == 1:
+            max_cache_len = sets_of_dim.pop()
+            sliding_window = max_cache_len
+        else:
+            assert (
+                len(sets_of_dim) == 2
+            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+            max_cache_len = max(sets_of_dim)
+            sliding_window = min(sets_of_dim)
+            layer_types = [
+                "full_attention" if i == max_cache_len else "sliding_attention"
+                for i in [kv[0].shape[2] for kv in key_value_pairs]
+            ]
+    else:
+        assert (
+            max_batch_size and max_cache_len
+        ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+        if sliding_window is None:
+            sliding_window = max_cache_len
+    _max_cache_len = max_cache_len
+    _sliding_window = sliding_window
+
+    class _config:
+        max_cache_len = _max_cache_len
+        batch_size = max_batch_size
+        num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+        head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+        num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+        num_hidden_layers = len(key_value_pairs)
+        sliding_window = _sliding_window
+        num_key_value_heads = (
+            key_value_pairs[0][1].shape[1] if key_value_pairs else None
+        )  # transformers 4.48.3
+
+    if layer_types:
+        _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+    cache = transformers.cache_utils.HybridCache(
+        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
+    )
+    for i, (key, value) in enumerate(key_value_pairs):
+        cache.update(
+            key,
+            value,
+            i,
+            cache_kwargs={
+                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                    key.device
+                )
+            },
+        )
+    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+        # The cache constructor contains the two following lines
+        # (in cache_utils.py) which append empty layers when the cache is
+        # initialized. We need to remove them.
+        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+        # self.append_new_layers(self.num_hidden_layers - 1)
+        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+        f"{len(key_value_pairs)} expected."
+    )
     return cache
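
A sketch of how the shape-based inference in make_hybrid_cache plays out; per the branch logic above, with two distinct sequence lengths the larger becomes max_cache_len ("full_attention") and the smaller becomes the sliding_window ("sliding_attention"):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    past_key_values = make_hybrid_cache(
        [
            (torch.randn(2, 4, 5, 7), torch.randn(2, 4, 5, 7)),  # full_attention
            (torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)),  # sliding_attention
        ]
    )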
onnx_diagnostic/helpers/helper.py
@@ -558,9 +558,17 @@ def string_type(
             print(f"[string_type] CACHE1:{type(obj)}")
         return f"MambaCache(conv_states={c}, ssm_states={d})"
 
-    if obj.__class__.__name__ in {"DynamicCache", "SlidingWindowCache", "StaticCache"}:
+    if obj.__class__.__name__ in {
+        "DynamicCache",
+        "SlidingWindowCache",
+        "StaticCache",
+        "HybridCache",
+    }:
+        from .cache_helper import CacheKeyValue
+
+        ca = CacheKeyValue(obj)
         kc = string_type(
-            obj.key_cache,
+            ca.key_cache,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
@@ -568,7 +576,7 @@ def string_type(
             verbose=verbose,
         )
         vc = string_type(
-            obj.value_cache,
+            ca.value_cache,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
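
With the wrapper in place, string_type renders all four cache classes the same way regardless of the transformers version; a sketch (the printed notation is this library's compact shape syntax and may differ slightly across versions):

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    print(string_type(cache, with_shape=True))
    # e.g. DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7],
    #                   value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])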
@@ -579,6 +587,27 @@ def string_type(
             print(f"[string_type] CACHE2:{type(obj)}")
         return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
 
+    if obj.__class__.__name__ == "StaticLayer":
+        kc = string_type(
+            list(obj.keys),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        vc = string_type(
+            list(obj.values),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        if verbose:
+            print(f"[string_type] SL:{type(obj)}")
+        return f"{obj.__class__.__name__}(keys={kc}, values={vc})"
+
     if obj.__class__.__name__ == "EncoderDecoderCache":
         att = string_type(
             obj.self_attention_cache,
@@ -663,6 +692,50 @@ def string_type(
             f"dtype={obj.dtype}, shape={obj.shape})"
         )
 
+    if obj.__class__.__name__ == "KeyValuesWrapper":
+        import transformers
+
+        assert isinstance(
+            obj, transformers.cache_utils.KeyValuesWrapper
+        ), f"Unexpected type {type(obj)}"
+        if verbose:
+            print(f"[string_type] KW0:{type(obj)}")
+        s = string_type(
+            list(obj),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        return f"{obj.__class__.__name__}[{obj.cache_type}]{s}"
+
+    if obj.__class__.__name__ == "DynamicLayer":
+        import transformers
+
+        assert isinstance(
+            obj, transformers.cache_utils.DynamicLayer
+        ), f"Unexpected type {type(obj)}"
+        if verbose:
+            print(f"[string_type] LY0:{type(obj)}")
+        s1 = string_type(
+            obj.keys,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        s2 = string_type(
+            obj.values,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        return f"{obj.__class__.__name__}(keys={s1}, values={s2})"
+
     if isinstance(obj, torch.nn.Module):
         if verbose:
             print(f"[string_type] MM:{type(obj)}")
@@ -858,7 +931,10 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
         return flatten_object(list(x.items()), drop_keys=drop_keys)
 
     if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
-        res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
+        from .cache_helper import CacheKeyValue
+
+        kc = CacheKeyValue(x)
+        res = flatten_object(kc.key_cache) + flatten_object(kc.value_cache)
         return tuple(res)
     if x.__class__.__name__ == "EncoderDecoderCache":
         res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
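
flatten_object now reads caches through the same wrapper; a sketch of the flattening order (all keys first, then all values), assuming flatten_object is importable from onnx_diagnostic.helpers.helper as the paths in this diff suggest:

    import torch
    from onnx_diagnostic.helpers.helper import flatten_object
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    flat = flatten_object(cache)
    print(len(flat))  # 4 tensors: key[0], key[1], value[0], value[1]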
@@ -1424,19 +1500,58 @@ def max_diff(
             f"level={level}"
         )
 
+    # Backup path in case pytorch does not know how to serialize the cache.
+    if expected.__class__.__name__ == "HybridCache":
+        if got.__class__.__name__ == "HybridCache":
+            from .cache_helper import CacheKeyValue
+
+            if verbose >= 6:
+                print(f"[max_diff] HybridCache: {string_type(expected)} ? {string_type(got)}")
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
+            return max_diff(
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
+                verbose=verbose,
+                hist=hist,
+            )
+        if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            return max_diff(
+                [cae.key_cache, cae.value_cache],
+                [got[0], got[1]],
+                debug_info=_debug(expected.__class__.__name__),
+                **_dkws,
+            )
+        raise AssertionError(
+            f"HybridCache not fully implemented with classes "
+            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+            f"and expected={string_type(expected)}, got={string_type(got)},\n"
+            f"level={level}"
+        )
+
     if expected.__class__.__name__ == "StaticCache":
         if got.__class__.__name__ == "StaticCache":
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
             if verbose >= 6:
                 print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
             return max_diff(
-                [expected.key_cache, expected.value_cache],
-                [got.key_cache, got.value_cache],
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
                 verbose=verbose,
                 hist=hist,
            )
         if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
+                [cae.key_cache, cae.value_cache],
                 [got[0], got[1]],
                 debug_info=_debug(expected.__class__.__name__),
                 **_dkws,
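
max_diff reduces two caches to [keys, values] lists before comparing; the returned record carries the abs/rel/sum/n/dnan keys visible in the KeyValuesWrapper hunk below. A hedged sketch, assuming DynamicCache comparison is supported like the cache classes handled here:

    import torch
    from onnx_diagnostic.helpers.helper import max_diff
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    kv = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    expected = make_dynamic_cache(kv)
    got = make_dynamic_cache([(k.clone(), v.clone()) for k, v in kv])
    diff = max_diff(expected, got)
    print(diff["abs"], diff["rel"])  # 0.0 0.0 for identical contents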
@@ -1455,15 +1570,22 @@ def max_diff(
                     f"[max_diff] SlidingWindowCache: "
                     f"{string_type(expected)} ? {string_type(got)}"
                 )
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
-                [got.key_cache, got.value_cache],
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
                 verbose=verbose,
                 hist=hist,
             )
         if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
+                [cae.key_cache, cae.value_cache],
                 [got[0], got[1]],
                 debug_info=_debug(expected.__class__.__name__),
                 **_dkws,
@@ -1521,6 +1643,20 @@ def max_diff(
             **_dkws,
         )
 
+    if expected.__class__.__name__ == "KeyValuesWrapper":
+        if verbose >= 6:
+            print(f"[max_diff] KeyValuesWrapper: {string_type(expected)} ? {string_type(got)}")
+        if got.__class__.__name__ != expected.__class__.__name__:
+            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+        if got.cache_type != expected.cache_type:
+            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+        return max_diff(
+            list(expected),
+            list(got),
+            debug_info=_debug(expected.__class__.__name__),
+            **_dkws,
+        )
+
     raise AssertionError(
         f"Not implemented with expected="
         f"{string_type(expected)}, got={string_type(got)},\n"