onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +35 -5
  4. onnx_diagnostic/export/control_flow.py +511 -0
  5. onnx_diagnostic/export/control_flow_research.py +135 -0
  6. onnx_diagnostic/ext_test_case.py +33 -9
  7. onnx_diagnostic/helpers/cache_helper.py +217 -203
  8. onnx_diagnostic/helpers/helper.py +6 -2
  9. onnx_diagnostic/helpers/log_helper.py +39 -5
  10. onnx_diagnostic/helpers/memory_peak.py +2 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +55 -3
  12. onnx_diagnostic/helpers/onnx_helper.py +13 -16
  13. onnx_diagnostic/helpers/rt_helper.py +579 -15
  14. onnx_diagnostic/helpers/torch_helper.py +5 -0
  15. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  16. onnx_diagnostic/tasks/text2text_generation.py +1 -0
  17. onnx_diagnostic/tasks/text_generation.py +84 -54
  18. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
  22. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +563 -61
  23. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  25. onnx_diagnostic/torch_models/validate.py +620 -213
  26. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
  27. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0
@@ -1188,6 +1188,7 @@ class ExtTestCase(unittest.TestCase):
  copy_inputs: bool = True,
  expected: Optional[Any] = None,
  use_ort: bool = False,
+ ort_optimized_graph: bool = False,
  **kwargs,
  ):
  """
@@ -1206,6 +1207,7 @@ class ExtTestCase(unittest.TestCase):
  :param expected: expected values
  :param copy_inputs: to copy the inputs
  :param use_ort: use :class:`onnxruntime.InferenceSession`
+ :param ort_optimized_graph: dumps the optimized onnxruntime graph
  :param kwargs: arguments sent to
  :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
  """
@@ -1214,30 +1216,52 @@ class ExtTestCase(unittest.TestCase):
  from .helpers.ort_session import InferenceSessionForTorch

  kws = dict(with_shape=True, with_min_max=verbose > 1)
- if verbose:
- vname = test_name or "assert_onnx_disc"
+ vname = test_name or "assert_onnx_disc"
  if test_name:
+ import onnx
+
  name = f"{test_name}.onnx"
- print(f"[{vname}] save the onnx model into {name!r}")
- name = self.dump_onnx(name, proto)
- print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
+ if verbose:
+ print(f"[{vname}] save the onnx model into {name!r}")
+ if isinstance(proto, str):
+ name = proto
+ proto = onnx.load(name)
+ else:
+ assert isinstance(
+ proto, onnx.ModelProto
+ ), f"Unexpected type {type(proto)} for proto"
+ name = self.dump_onnx(name, proto)
+ if verbose:
+ print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
  if verbose:
  print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
  if use_ort:
+ assert isinstance(
+ proto, onnx.ModelProto
+ ), f"Unexpected type {type(proto)} for proto"
  feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
- if verbose:
- print(f"[{vname}] feeds {string_type(feeds, **kws)}")
  import onnxruntime

+ if verbose:
+ print(f"[{vname}] create onnxruntime.InferenceSession")
+ options = onnxruntime.SessionOptions()
+ if ort_optimized_graph:
+ options.optimized_model_filepath = f"{name}.optort.onnx"
  sess = onnxruntime.InferenceSession(
- proto.SerializeToString(), providers=["CPUExecutionProvider"]
+ proto.SerializeToString(),
+ options,
+ providers=kwargs.get("providers", ["CPUExecutionProvider"]),
  )
+ if verbose:
+ print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
  got = sess.run(None, feeds)
  else:
  feeds = make_feeds(proto, inputs, copy=True)
  if verbose:
- print(f"[{vname}] feeds {string_type(feeds, **kws)}")
+ print(f"[{vname}] create InferenceSessionForTorch")
  sess = InferenceSessionForTorch(proto, **kwargs)
+ if verbose:
+ print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
  got = sess.run(None, feeds)
  if verbose:
  print(f"[{vname}] compute expected values")
@@ -391,17 +391,22 @@ def make_static_cache(
  return finalize_cache(cache)


- def make_encoder_decoder_cache(
- self_attention_cache: transformers.cache_utils.DynamicCache,
- cross_attention_cache: transformers.cache_utils.DynamicCache,
- ) -> transformers.cache_utils.EncoderDecoderCache:
- """Creates an EncoderDecoderCache."""
- return transformers.cache_utils.EncoderDecoderCache(
- # self_attention_cache=self_attention_cache,
- # cross_attention_cache=cross_attention_cache
- self_attention_cache,
- cross_attention_cache,
- )
+ if hasattr(transformers.cache_utils, "EncoderDecoderCache"):
+
+ def make_encoder_decoder_cache(
+ self_attention_cache: transformers.cache_utils.DynamicCache,
+ cross_attention_cache: transformers.cache_utils.DynamicCache,
+ ) -> transformers.cache_utils.EncoderDecoderCache:
+ """Creates an EncoderDecoderCache."""
+ return transformers.cache_utils.EncoderDecoderCache(
+ # self_attention_cache=self_attention_cache,
+ # cross_attention_cache=cross_attention_cache
+ self_attention_cache,
+ cross_attention_cache,
+ )
+
+ else:
+ make_encoder_decoder_cache = None # type: ignore[assignment]


  def make_mamba_cache(
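
With this change the helper is only defined when the installed transformers exposes EncoderDecoderCache; otherwise the name is bound to None, so callers should test for availability first. A minimal, hedged usage sketch (it assumes make_dynamic_cache is also exported by onnx_diagnostic.helpers.cache_helper; tensor shapes are arbitrary):

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        make_dynamic_cache,  # assumed companion helper in the same module
        make_encoder_decoder_cache,
    )

    if make_encoder_decoder_cache is not None:
        # One (key, value) pair per layer; shapes are arbitrary for the sketch.
        kv = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))]
        cache = make_encoder_decoder_cache(make_dynamic_cache(kv), make_dynamic_cache(kv))
        print(type(cache).__name__)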
@@ -454,220 +459,229 @@ def make_mamba_cache(
  return finalize_cache(cache)


- def make_sliding_window_cache(
- key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
- ) -> transformers.cache_utils.SlidingWindowCache:
- "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
- key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+ if hasattr(transformers.cache_utils, "SlidingWindowCache"):

- class _config:
- def __init__(self):
- self.head_dim = key_value_pairs[0][0].shape[-1]
- self.num_attention_heads = key_value_pairs[0][0].shape[1]
- self.num_hidden_layers = len(key_value_pairs)
- self.sliding_window = key_value_pairs[0][0].shape[2]
-
- def get_text_config(self, *args, **kwargs):
- return self
-
- cache = transformers.cache_utils.SlidingWindowCache(
- config=_config(),
- max_batch_size=key_value_pairs[0][0].shape[0],
- max_cache_len=key_value_pairs[0][0].shape[2], # same as sliding_window
- device=key_value_pairs[0][0].device,
- dtype=key_value_pairs[0][0].dtype,
- )
- ca = CacheKeyValue(cache)
- if hasattr(cache, "layers") and len(ca.key_cache) == 0:
- # transformers>= 4.55.2, layers are empty
- cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
- for i, (key, value) in enumerate(key_value_pairs):
- cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
- return cache
+ def make_sliding_window_cache(
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+ ) -> transformers.cache_utils.SlidingWindowCache:
+ "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

- for i in range(len(key_value_pairs)):
- assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
- f"Shape mismatch, expected {cache.key_cache[i].shape}, "
- f"got {key_value_pairs[i][0].shape}"
+ class _config:
+ def __init__(self):
+ self.head_dim = key_value_pairs[0][0].shape[-1]
+ self.num_attention_heads = key_value_pairs[0][0].shape[1]
+ self.num_hidden_layers = len(key_value_pairs)
+ self.sliding_window = key_value_pairs[0][0].shape[2]
+
+ def get_text_config(self, *args, **kwargs):
+ return self
+
+ cache = transformers.cache_utils.SlidingWindowCache(
+ config=_config(),
+ max_batch_size=key_value_pairs[0][0].shape[0],
+ max_cache_len=key_value_pairs[0][0].shape[2], # same as sliding_window
+ device=key_value_pairs[0][0].device,
+ dtype=key_value_pairs[0][0].dtype,
  )
- ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
- assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
- f"Shape mismatch, expected {cache.value_cache[i].shape}, "
- f"got {key_value_pairs[i][1].shape}"
+ ca = CacheKeyValue(cache)
+ if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+ # transformers>= 4.55.2, layers are empty
+ cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+ for i, (key, value) in enumerate(key_value_pairs):
+ cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+ return cache
+
+ for i in range(len(key_value_pairs)):
+ assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
+ f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+ f"got {key_value_pairs[i][0].shape}"
+ )
+ ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+ assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
+ f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+ f"got {key_value_pairs[i][1].shape}"
+ )
+ ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+ if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+ # The cache constructor contains the two following lines
+ # (in cache_utils.py) which append empty layers when the cache is
+ # initialized. We need to remove them.
+ # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+ # self.append_new_layers(self.num_hidden_layers - 1)
+ cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+ assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+ f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+ f"{len(key_value_pairs)} expected."
  )
- ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
- if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
- # The cache constructor contains the two following lines
- # (in cache_utils.py) which append empty layers when the cache is
- # initialized. We need to remove them.
- # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
- # self.append_new_layers(self.num_hidden_layers - 1)
- cache.layers[:] = cache.layers[-len(key_value_pairs) :]
- assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
- f"Unexpected number of layers in the cache ({len(cache.layers)}), "
- f"{len(key_value_pairs)} expected."
- )
- return finalize_cache(cache)
+ return finalize_cache(cache)

+ else:
+ make_sliding_window_cache = None # type: ignore[assignment]

- def make_hybrid_cache(
- key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
- max_cache_len: Optional[int] = None,
- max_batch_size: Optional[int] = None,
- sliding_window: Optional[int] = None,
- ) -> transformers.cache_utils.HybridCache:
- """
- Creates an instance of :class:`transformers.cache_utils.HybridCache`.
- This version is valid for ``transformers < 4.50``.
+ if hasattr(transformers.cache_utils, "HybridCache"):

- :param key_value_pairs: list of pairs of (key, values)
- :return: :class:`transformers.cache_utils.HybridCache`
+ def make_hybrid_cache(
+ key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+ max_cache_len: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
+ sliding_window: Optional[int] = None,
+ ) -> transformers.cache_utils.HybridCache:
+ """
+ Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+ This version is valid for ``transformers < 4.50``.

- Example:
+ :param key_value_pairs: list of pairs of (key, values)
+ :return: :class:`transformers.cache_utils.HybridCache`

- .. runpython::
- :showcode:
+ Example:

- import torch
- from onnx_diagnostic.helpers import string_type
- from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+ .. runpython::
+ :showcode:

- n_layers = 2
- bsize, nheads, slen, dim = 2, 4, 3, 7
+ import torch
+ from onnx_diagnostic.helpers import string_type
+ from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

- past_key_values = make_hybrid_cache(
- [
- (
- torch.randn(bsize, nheads, slen, dim),
- torch.randn(bsize, nheads, slen, dim),
- )
- for i in range(n_layers)
- ]
- )
- print(string_type(past_key_values, with_shape=True))
+ n_layers = 2
+ bsize, nheads, slen, dim = 2, 4, 3, 7
+
+ past_key_values = make_hybrid_cache(
+ [
+ (
+ torch.randn(bsize, nheads, slen, dim),
+ torch.randn(bsize, nheads, slen, dim),
+ )
+ for i in range(n_layers)
+ ]
+ )
+ print(string_type(past_key_values, with_shape=True))

- This part defines how the shapes are working in one HybridCache.
+ This part defines how the shapes are working in one HybridCache.

- .. code-block:: python
+ .. code-block:: python

- self.max_cache_len = (
- max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+ self.max_cache_len = (
+ max_cache_len if max_cache_len is not None else config.max_position_embeddings)

- # Sliding layers can't be larger than the overall max cache len
- self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
- self.max_batch_size = max_batch_size
+ # Sliding layers can't be larger than the overall max cache len
+ self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+ self.max_batch_size = max_batch_size

- self.head_dim = (
- config.head_dim if hasattr(config, "head_dim")
- else config.hidden_size // config.num_attention_heads
- )
+ self.head_dim = (
+ config.head_dim if hasattr(config, "head_dim")
+ else config.hidden_size // config.num_attention_heads
+ )

- self._dtype = dtype
- self.num_key_value_heads = (
- config.num_attention_heads
- if getattr(config, "num_key_value_heads", None) is None
- else config.num_key_value_heads
- )
+ self._dtype = dtype
+ self.num_key_value_heads = (
+ config.num_attention_heads
+ if getattr(config, "num_key_value_heads", None) is None
+ else config.num_key_value_heads
+ )

- # If the attribute does not exist in the config, fallback to a simple StaticCache
- if hasattr(config, "layer_types"):
- self.is_sliding = [
- layer_type != "full_attention" for layer_type in config.layer_types]
- else:
- self.is_sliding = [False] * config.num_hidden_layers
-
- self.key_cache: list[torch.Tensor] = []
- self.value_cache: list[torch.Tensor] = []
- global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
- self.max_cache_len, self.head_dim)
- sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
- self.sliding_window_len, self.head_dim)
- self.sliding_window = min(config.sliding_window, max_cache_len)
- device = torch.device(device) if device is not None else None
- for i in range(config.num_hidden_layers):
- layer_device = layer_device_map[i] if layer_device_map is not None else device
- cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
- new_layer_key_cache = torch.zeros(
- cache_shape, dtype=self._dtype, device=layer_device)
- new_layer_value_cache = torch.zeros(
- cache_shape, dtype=self._dtype, device=layer_device)
- torch._dynamo.mark_static_address(new_layer_key_cache)
- torch._dynamo.mark_static_address(new_layer_value_cache)
- self.key_cache.append(new_layer_key_cache)
- self.value_cache.append(new_layer_value_cache)
- """
- key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
- layer_types = None
- if key_value_pairs:
- assert (
- not max_batch_size and not max_cache_len
- ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
- max_batch_size = key_value_pairs[0][0].shape[0]
- sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
- if len(sets_of_dim) == 1:
- max_cache_len = sets_of_dim.pop()
- sliding_window = max_cache_len
+ # If the attribute does not exist in the config, fallback to a simple StaticCache
+ if hasattr(config, "layer_types"):
+ self.is_sliding = [
+ layer_type != "full_attention" for layer_type in config.layer_types]
+ else:
+ self.is_sliding = [False] * config.num_hidden_layers
+
+ self.key_cache: list[torch.Tensor] = []
+ self.value_cache: list[torch.Tensor] = []
+ global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+ self.max_cache_len, self.head_dim)
+ sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+ self.sliding_window_len, self.head_dim)
+ self.sliding_window = min(config.sliding_window, max_cache_len)
+ device = torch.device(device) if device is not None else None
+ for i in range(config.num_hidden_layers):
+ layer_device = layer_device_map[i] if layer_device_map is not None else device
+ cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+ new_layer_key_cache = torch.zeros(
+ cache_shape, dtype=self._dtype, device=layer_device)
+ new_layer_value_cache = torch.zeros(
+ cache_shape, dtype=self._dtype, device=layer_device)
+ torch._dynamo.mark_static_address(new_layer_key_cache)
+ torch._dynamo.mark_static_address(new_layer_value_cache)
+ self.key_cache.append(new_layer_key_cache)
+ self.value_cache.append(new_layer_value_cache)
+ """
+ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+ layer_types = None
+ if key_value_pairs:
+ assert (
+ not max_batch_size and not max_cache_len
+ ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+ max_batch_size = key_value_pairs[0][0].shape[0]
+ sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+ if len(sets_of_dim) == 1:
+ max_cache_len = sets_of_dim.pop()
+ sliding_window = max_cache_len
+ else:
+ assert (
+ len(sets_of_dim) == 2
+ ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+ max_cache_len = max(sets_of_dim)
+ sliding_window = min(sets_of_dim)
+ layer_types = [
+ "full_attention" if i == max_cache_len else "sliding_attention"
+ for i in [kv[0].shape[2] for kv in key_value_pairs]
+ ]
  else:
  assert (
- len(sets_of_dim) == 2
- ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
- max_cache_len = max(sets_of_dim)
- sliding_window = min(sets_of_dim)
- layer_types = [
- "full_attention" if i == max_cache_len else "sliding_attention"
- for i in [kv[0].shape[2] for kv in key_value_pairs]
- ]
- else:
- assert (
- max_batch_size and max_cache_len
- ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
- if sliding_window is None:
- sliding_window = max_cache_len
- _max_cache_len = max_cache_len
- _sliding_window = sliding_window
-
- class _config:
- max_cache_len = _max_cache_len
- batch_size = max_batch_size
- num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
- head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
- num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
- num_hidden_layers = len(key_value_pairs)
- sliding_window = _sliding_window
- num_key_value_heads = key_value_pairs[0][1].shape[1] # transformers 4.48.3
-
- def get_text_config(self, *args, **kwargs):
- return self
-
- if layer_types:
- _config.layer_types = layer_types # type: ignore[attr-defined]
-
- cache = transformers.cache_utils.HybridCache(
- config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
- )
- for i, (key, value) in enumerate(key_value_pairs):
- cache.update(
- key,
- value,
- i,
- cache_kwargs={
- "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
- key.device
- )
- },
+ max_batch_size and max_cache_len
+ ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+ if sliding_window is None:
+ sliding_window = max_cache_len
+ _max_cache_len = max_cache_len
+ _sliding_window = sliding_window
+
+ class _config:
+ max_cache_len = _max_cache_len
+ batch_size = max_batch_size
+ num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+ head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+ num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+ num_hidden_layers = len(key_value_pairs)
+ sliding_window = _sliding_window
+ num_key_value_heads = key_value_pairs[0][1].shape[1] # transformers 4.48.3
+
+ def get_text_config(self, *args, **kwargs):
+ return self
+
+ if layer_types:
+ _config.layer_types = layer_types # type: ignore[attr-defined]
+
+ cache = transformers.cache_utils.HybridCache(
+ config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
  )
- if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
- # The cache constructor contains the two following lines
- # (in cache_utils.py) which append empty layers when the cache is
- # initialized. We need to remove them.
- # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
- # self.append_new_layers(self.num_hidden_layers - 1)
- cache.layers[:] = cache.layers[-len(key_value_pairs) :]
- assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
- f"Unexpected number of layers in the cache ({len(cache.layers)}), "
- f"{len(key_value_pairs)} expected."
- )
- return finalize_cache(cache)
+ for i, (key, value) in enumerate(key_value_pairs):
+ cache.update(
+ key,
+ value,
+ i,
+ cache_kwargs={
+ "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+ key.device
+ )
+ },
+ )
+ if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+ # The cache constructor contains the two following lines
+ # (in cache_utils.py) which append empty layers when the cache is
+ # initialized. We need to remove them.
+ # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+ # self.append_new_layers(self.num_hidden_layers - 1)
+ cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+ assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+ f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+ f"{len(key_value_pairs)} expected."
+ )
+ return finalize_cache(cache)
+
+ else:
+ make_hybrid_cache = None # type: ignore[assignment]


  def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
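
The docstring inside this hunk already shows the intended call pattern for make_hybrid_cache; a condensed sketch of it with the new None guard (shapes are arbitrary, and the helper is None when the installed transformers no longer provides HybridCache):

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    if make_hybrid_cache is not None:
        bsize, nheads, slen, dim = 2, 4, 3, 7
        # One (key, value) pair per layer, two layers here.
        past_key_values = make_hybrid_cache(
            [
                (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
                for _ in range(2)
            ]
        )
        print(string_type(past_key_values, with_shape=True))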
@@ -787,6 +787,8 @@ def string_type(
  return f"ultralytics.{obj.__class__.__name__}(...)"
  if obj.__class__.__name__ == "FakeTensorMode":
  return f"{obj}"
+ if obj.__class__.__name__ == "FakeTensorContext":
+ return "FakeTensorContext(...)"

  if verbose:
  print(f"[string_type] END:{type(obj)}")
@@ -1016,6 +1018,8 @@ def max_diff(

  You may use :func:`string_diff` to display the discrepancies in one string.
  """
+ if verbose >= 10:
+ print(f"[max_diff] {type(expected)} ? {type(got)}")
  if expected is None and got is None:
  return dict(abs=0, rel=0, sum=0, n=0, dnan=0)

@@ -1061,8 +1065,8 @@ def max_diff(
  if expected.__class__.__name__ == "CausalLMOutputWithPast":
  if verbose >= 6:
  print(
- f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
- f"? {string_type(got)}"
+ f"[max_diff] CausalLMOutputWithPast: {string_type(expected, with_shape=True)} "
+ f"? {string_type(got, with_shape=True)}"
  )
  if got.__class__.__name__ == "CausalLMOutputWithPast":
  return max_diff(
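
For reference, string_type(..., with_shape=True), which the updated verbose message now uses, summarizes nested structures together with tensor shapes, and max_diff reports the abs/rel/sum/n/dnan statistics seen above. A short sketch (it assumes max_diff is re-exported from onnx_diagnostic.helpers alongside string_type; otherwise it lives in onnx_diagnostic.helpers.helper):

    import torch
    from onnx_diagnostic.helpers import max_diff, string_type

    expected = torch.zeros((2, 8, 16))
    got = expected + 1e-5
    # with_shape=True adds tensor shapes to the printed summary,
    # matching the updated verbose message above.
    print(string_type({"logits": expected}, with_shape=True))
    # max_diff returns a dict with abs/rel/sum/n/dnan statistics.
    print(max_diff(expected, got))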