onnx-diagnostic 0.8.0-py3-none-any.whl → 0.8.2-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +35 -5
- onnx_diagnostic/export/control_flow.py +511 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/ext_test_case.py +33 -9
- onnx_diagnostic/helpers/cache_helper.py +217 -203
- onnx_diagnostic/helpers/helper.py +6 -2
- onnx_diagnostic/helpers/log_helper.py +39 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +55 -3
- onnx_diagnostic/helpers/onnx_helper.py +13 -16
- onnx_diagnostic/helpers/rt_helper.py +579 -15
- onnx_diagnostic/helpers/torch_helper.py +5 -0
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/text2text_generation.py +1 -0
- onnx_diagnostic/tasks/text_generation.py +84 -54
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +563 -61
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +620 -213
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +30 -28
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0
onnx_diagnostic/ext_test_case.py
CHANGED
@@ -1188,6 +1188,7 @@ class ExtTestCase(unittest.TestCase):
         copy_inputs: bool = True,
         expected: Optional[Any] = None,
         use_ort: bool = False,
+        ort_optimized_graph: bool = False,
         **kwargs,
     ):
         """
@@ -1206,6 +1207,7 @@ class ExtTestCase(unittest.TestCase):
         :param expected: expected values
         :param copy_inputs: to copy the inputs
         :param use_ort: use :class:`onnxruntime.InferenceSession`
+        :param ort_optimized_graph: dumps the optimized onnxruntime graph
         :param kwargs: arguments sent to
             :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
         """
@@ -1214,30 +1216,52 @@ class ExtTestCase(unittest.TestCase):
         from .helpers.ort_session import InferenceSessionForTorch

         kws = dict(with_shape=True, with_min_max=verbose > 1)
-
-        vname = test_name or "assert_onnx_disc"
+        vname = test_name or "assert_onnx_disc"
         if test_name:
+            import onnx
+
             name = f"{test_name}.onnx"
-            (3 removed lines not captured in this extract)
+            if verbose:
+                print(f"[{vname}] save the onnx model into {name!r}")
+            if isinstance(proto, str):
+                name = proto
+                proto = onnx.load(name)
+            else:
+                assert isinstance(
+                    proto, onnx.ModelProto
+                ), f"Unexpected type {type(proto)} for proto"
+                name = self.dump_onnx(name, proto)
+            if verbose:
+                print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
         if verbose:
             print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
         if use_ort:
+            assert isinstance(
+                proto, onnx.ModelProto
+            ), f"Unexpected type {type(proto)} for proto"
             feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
-            if verbose:
-                print(f"[{vname}] feeds {string_type(feeds, **kws)}")
             import onnxruntime

+            if verbose:
+                print(f"[{vname}] create onnxruntime.InferenceSession")
+            options = onnxruntime.SessionOptions()
+            if ort_optimized_graph:
+                options.optimized_model_filepath = f"{name}.optort.onnx"
             sess = onnxruntime.InferenceSession(
-                proto.SerializeToString(),
+                proto.SerializeToString(),
+                options,
+                providers=kwargs.get("providers", ["CPUExecutionProvider"]),
             )
+            if verbose:
+                print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
             got = sess.run(None, feeds)
         else:
             feeds = make_feeds(proto, inputs, copy=True)
             if verbose:
-                print(f"[{vname}]
+                print(f"[{vname}] create InferenceSessionForTorch")
             sess = InferenceSessionForTorch(proto, **kwargs)
+            if verbose:
+                print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
             got = sess.run(None, feeds)
         if verbose:
             print(f"[{vname}] compute expected values")
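The new `ort_optimized_graph` flag rides on onnxruntime's standard `SessionOptions.optimized_model_filepath` setting. Below is a minimal standalone sketch of that pattern outside the test helper; the file names and provider list are placeholders, not values used by the package:

```python
import onnxruntime

# Ask onnxruntime to write out the graph it actually executes after its own
# graph optimizations; useful when investigating discrepancies.
options = onnxruntime.SessionOptions()
options.optimized_model_filepath = "model.optort.onnx"  # placeholder output path

sess = onnxruntime.InferenceSession(
    "model.onnx",  # placeholder input model
    options,
    providers=["CPUExecutionProvider"],
)
# The optimized model is serialized as a side effect of creating the session;
# sess.run(...) then executes that optimized graph.
```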
onnx_diagnostic/helpers/cache_helper.py
CHANGED

@@ -391,17 +391,22 @@ def make_static_cache(
     return finalize_cache(cache)


-(11 removed lines not captured in this extract)
+if hasattr(transformers.cache_utils, "EncoderDecoderCache"):
+
+    def make_encoder_decoder_cache(
+        self_attention_cache: transformers.cache_utils.DynamicCache,
+        cross_attention_cache: transformers.cache_utils.DynamicCache,
+    ) -> transformers.cache_utils.EncoderDecoderCache:
+        """Creates an EncoderDecoderCache."""
+        return transformers.cache_utils.EncoderDecoderCache(
+            # self_attention_cache=self_attention_cache,
+            # cross_attention_cache=cross_attention_cache
+            self_attention_cache,
+            cross_attention_cache,
+        )
+
+else:
+    make_encoder_decoder_cache = None  # type: ignore[assignment]


 def make_mamba_cache(
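Because `make_encoder_decoder_cache` is now only defined when the installed transformers still exposes `EncoderDecoderCache` (and is `None` otherwise), callers are expected to guard for that. A hedged usage sketch; it assumes `make_dynamic_cache` from the same module is available to build the two inner caches:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,  # assumed helper from the same module
    make_encoder_decoder_cache,
)

kv = [(torch.randn(2, 4, 3, 8), torch.randn(2, 4, 3, 8)) for _ in range(2)]

if make_encoder_decoder_cache is not None:  # None when EncoderDecoderCache is gone
    cache = make_encoder_decoder_cache(
        make_dynamic_cache(kv),  # self-attention cache
        make_dynamic_cache(kv),  # cross-attention cache
    )
    print(type(cache).__name__)
```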
@@ -454,220 +459,229 @@ def make_mamba_cache(
     return finalize_cache(cache)


-
-    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
-) -> transformers.cache_utils.SlidingWindowCache:
-    "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
-    key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+if hasattr(transformers.cache_utils, "SlidingWindowCache"):

-(5 removed lines not captured in this extract)
-            self.sliding_window = key_value_pairs[0][0].shape[2]
-
-        def get_text_config(self, *args, **kwargs):
-            return self
-
-    cache = transformers.cache_utils.SlidingWindowCache(
-        config=_config(),
-        max_batch_size=key_value_pairs[0][0].shape[0],
-        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
-        device=key_value_pairs[0][0].device,
-        dtype=key_value_pairs[0][0].dtype,
-    )
-    ca = CacheKeyValue(cache)
-    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
-        # transformers>= 4.55.2, layers are empty
-        cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
-        for i, (key, value) in enumerate(key_value_pairs):
-            cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
-        return cache
+    def make_sliding_window_cache(
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+    ) -> transformers.cache_utils.SlidingWindowCache:
+        "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

-(4 removed lines not captured in this extract)
+        class _config:
+            def __init__(self):
+                self.head_dim = key_value_pairs[0][0].shape[-1]
+                self.num_attention_heads = key_value_pairs[0][0].shape[1]
+                self.num_hidden_layers = len(key_value_pairs)
+                self.sliding_window = key_value_pairs[0][0].shape[2]
+
+            def get_text_config(self, *args, **kwargs):
+                return self
+
+        cache = transformers.cache_utils.SlidingWindowCache(
+            config=_config(),
+            max_batch_size=key_value_pairs[0][0].shape[0],
+            max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
+            device=key_value_pairs[0][0].device,
+            dtype=key_value_pairs[0][0].dtype,
         )
-    ca
-(3 removed lines not captured in this extract)
+        ca = CacheKeyValue(cache)
+        if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+            # transformers>= 4.55.2, layers are empty
+            cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+            for i, (key, value) in enumerate(key_value_pairs):
+                cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+            return cache
+
+        for i in range(len(key_value_pairs)):
+            assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
+                f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+                f"got {key_value_pairs[i][0].shape}"
+            )
+            ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+            assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
+                f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+                f"got {key_value_pairs[i][1].shape}"
+            )
+            ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
         )
-
-    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
-        # The cache constructor contains the two following lines
-        # (in cache_utils.py) which append empty layers when the cache is
-        # initialized. We need to remove them.
-        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
-        # self.append_new_layers(self.num_hidden_layers - 1)
-        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
-    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
-        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
-        f"{len(key_value_pairs)} expected."
-    )
-    return finalize_cache(cache)
+        return finalize_cache(cache)

+else:
+    make_sliding_window_cache = None  # type: ignore[assignment]

-
-    key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
-    max_cache_len: Optional[int] = None,
-    max_batch_size: Optional[int] = None,
-    sliding_window: Optional[int] = None,
-) -> transformers.cache_utils.HybridCache:
-    """
-    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
-    This version is valid for ``transformers < 4.50``.
+if hasattr(transformers.cache_utils, "HybridCache"):

-(2 removed lines not captured in this extract)
+    def make_hybrid_cache(
+        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        max_cache_len: Optional[int] = None,
+        max_batch_size: Optional[int] = None,
+        sliding_window: Optional[int] = None,
+    ) -> transformers.cache_utils.HybridCache:
+        """
+        Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+        This version is valid for ``transformers < 4.50``.

-
+        :param key_value_pairs: list of pairs of (key, values)
+        :return: :class:`transformers.cache_utils.HybridCache`

-
-        :showcode:
+        Example:

-(2 removed lines not captured in this extract)
-        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+        .. runpython::
+            :showcode:

-(2 removed lines not captured in this extract)
+            import torch
+            from onnx_diagnostic.helpers import string_type
+            from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

-(10 removed lines not captured in this extract)
+            n_layers = 2
+            bsize, nheads, slen, dim = 2, 4, 3, 7
+
+            past_key_values = make_hybrid_cache(
+                [
+                    (
+                        torch.randn(bsize, nheads, slen, dim),
+                        torch.randn(bsize, nheads, slen, dim),
+                    )
+                    for i in range(n_layers)
+                ]
+            )
+            print(string_type(past_key_values, with_shape=True))

-
+        This part defines how the shapes are working in one HybridCache.

-
+        .. code-block:: python

-(2 removed lines not captured in this extract)
+            self.max_cache_len = (
+                max_cache_len if max_cache_len is not None else config.max_position_embeddings)

-(3 removed lines not captured in this extract)
+            # Sliding layers can't be larger than the overall max cache len
+            self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+            self.max_batch_size = max_batch_size

-(4 removed lines not captured in this extract)
+            self.head_dim = (
+                config.head_dim if hasattr(config, "head_dim")
+                else config.hidden_size // config.num_attention_heads
+            )

-(6 removed lines not captured in this extract)
+            self._dtype = dtype
+            self.num_key_value_heads = (
+                config.num_attention_heads
+                if getattr(config, "num_key_value_heads", None) is None
+                else config.num_key_value_heads
+            )

-(38 removed lines not captured in this extract)
+            # If the attribute does not exist in the config, fallback to a simple StaticCache
+            if hasattr(config, "layer_types"):
+                self.is_sliding = [
+                    layer_type != "full_attention" for layer_type in config.layer_types]
+            else:
+                self.is_sliding = [False] * config.num_hidden_layers
+
+            self.key_cache: list[torch.Tensor] = []
+            self.value_cache: list[torch.Tensor] = []
+            global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                  self.max_cache_len, self.head_dim)
+            sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                   self.sliding_window_len, self.head_dim)
+            self.sliding_window = min(config.sliding_window, max_cache_len)
+            device = torch.device(device) if device is not None else None
+            for i in range(config.num_hidden_layers):
+                layer_device = layer_device_map[i] if layer_device_map is not None else device
+                cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+                new_layer_key_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                new_layer_value_cache = torch.zeros(
+                    cache_shape, dtype=self._dtype, device=layer_device)
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+                self.key_cache.append(new_layer_key_cache)
+                self.value_cache.append(new_layer_value_cache)
+        """
+        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+        layer_types = None
+        if key_value_pairs:
+            assert (
+                not max_batch_size and not max_cache_len
+            ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+            max_batch_size = key_value_pairs[0][0].shape[0]
+            sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
+            if len(sets_of_dim) == 1:
+                max_cache_len = sets_of_dim.pop()
+                sliding_window = max_cache_len
+            else:
+                assert (
+                    len(sets_of_dim) == 2
+                ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
+                max_cache_len = max(sets_of_dim)
+                sliding_window = min(sets_of_dim)
+                layer_types = [
+                    "full_attention" if i == max_cache_len else "sliding_attention"
+                    for i in [kv[0].shape[2] for kv in key_value_pairs]
+                ]
         else:
             assert (
-
-            ),
-(23 removed lines not captured in this extract)
-        num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
-
-        def get_text_config(self, *args, **kwargs):
-            return self
-
-    if layer_types:
-        _config.layer_types = layer_types  # type: ignore[attr-defined]
-
-    cache = transformers.cache_utils.HybridCache(
-        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
-    )
-    for i, (key, value) in enumerate(key_value_pairs):
-        cache.update(
-            key,
-            value,
-            i,
-            cache_kwargs={
-                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
-                    key.device
-                )
-            },
+                max_batch_size and max_cache_len
+            ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+        if sliding_window is None:
+            sliding_window = max_cache_len
+        _max_cache_len = max_cache_len
+        _sliding_window = sliding_window
+
+        class _config:
+            max_cache_len = _max_cache_len
+            batch_size = max_batch_size
+            num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+            head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+            num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+            num_hidden_layers = len(key_value_pairs)
+            sliding_window = _sliding_window
+            num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
+
+            def get_text_config(self, *args, **kwargs):
+                return self
+
+        if layer_types:
+            _config.layer_types = layer_types  # type: ignore[attr-defined]
+
+        cache = transformers.cache_utils.HybridCache(
+            config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
         )
-(12 removed lines not captured in this extract)
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(
+                key,
+                value,
+                i,
+                cache_kwargs={
+                    "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
+                        key.device
+                    )
+                },
+            )
+        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+            # The cache constructor contains the two following lines
+            # (in cache_utils.py) which append empty layers when the cache is
+            # initialized. We need to remove them.
+            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+            # self.append_new_layers(self.num_hidden_layers - 1)
+            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return finalize_cache(cache)
+
+else:
+    make_hybrid_cache = None  # type: ignore[assignment]


 def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
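With this rewrite, `make_hybrid_cache` infers `max_batch_size`, `max_cache_len`, `sliding_window`, and per-layer `layer_types` from the key/value shapes: one distinct sequence length means all layers are treated alike, while two distinct lengths mark the longer layers as `full_attention` and the shorter ones as `sliding_attention`. A sketch exercising the two-length branch, assuming a transformers version that still provides `HybridCache`:

```python
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

bsize, nheads, dim = 2, 4, 7
full_len, sliding_len = 8, 4  # two distinct lengths -> mixed layer_types

past_key_values = make_hybrid_cache(
    [
        # longer layer: inferred as "full_attention"
        (torch.randn(bsize, nheads, full_len, dim), torch.randn(bsize, nheads, full_len, dim)),
        # shorter layer: inferred as "sliding_attention"
        (torch.randn(bsize, nheads, sliding_len, dim), torch.randn(bsize, nheads, sliding_len, dim)),
    ]
)
print(string_type(past_key_values, with_shape=True))
```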
onnx_diagnostic/helpers/helper.py
CHANGED

@@ -787,6 +787,8 @@ def string_type(
         return f"ultralytics.{obj.__class__.__name__}(...)"
     if obj.__class__.__name__ == "FakeTensorMode":
         return f"{obj}"
+    if obj.__class__.__name__ == "FakeTensorContext":
+        return "FakeTensorContext(...)"

     if verbose:
         print(f"[string_type] END:{type(obj)}")
@@ -1016,6 +1018,8 @@ def max_diff(

     You may use :func:`string_diff` to display the discrepancies in one string.
     """
+    if verbose >= 10:
+        print(f"[max_diff] {type(expected)} ? {type(got)}")
     if expected is None and got is None:
         return dict(abs=0, rel=0, sum=0, n=0, dnan=0)

@@ -1061,8 +1065,8 @@ def max_diff(
     if expected.__class__.__name__ == "CausalLMOutputWithPast":
         if verbose >= 6:
             print(
-                f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
-                f"? {string_type(got)}"
+                f"[max_diff] CausalLMOutputWithPast: {string_type(expected, with_shape=True)} "
+                f"? {string_type(got, with_shape=True)}"
             )
         if got.__class__.__name__ == "CausalLMOutputWithPast":
             return max_diff(