onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +281 -80
- onnx_diagnostic/doc.py +22 -0
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +78 -8
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +1744 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +54 -73
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +21 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +72 -18
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +81 -8
- onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +58 -14
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +166 -106
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +44 -41
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/dynamic_shapes.py
CHANGED
@@ -630,9 +630,12 @@ class ModelInputs:
         method_name: str = "forward",
         name: str = "main",
     ):
-        assert isinstance(model, torch.nn.Module) or inspect.ismodule(
-            model
-        ), f"unexpected type for model={type(model)}"
+        assert (
+            model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
+        ), (
+            f"unexpected type for model={type(model)}, "
+            f"it must be a torch.nn.Module or None"
+        )
         assert name, (
             f"name={name!r} cannot be empty this string is used to "
             f"display meaningful error messages"
@@ -641,26 +644,42 @@ class ModelInputs:
         self.model = model
         self.level = level
         self.method_name = method_name
-        self.forward = getattr(model, method_name)
-        self.signature = inspect.signature(self.forward)
+        self.forward = getattr(model, method_name) if model is not None else None
+        self.signature = inspect.signature(self.forward) if self.forward else None

         # information about the signature
-        self.forward_parameter_names = set(
-            p.name
-            for p in self.signature.parameters.values()
-            if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
+        self.forward_parameter_names = (
+            set(
+                p.name
+                for p in self.signature.parameters.values()
+                if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
+            )
+            if self.signature
+            else None
+        )
+        self.forward_ordered_parameter_names = (
+            list(self.signature.parameters) if self.signature else None
+        )
+        self.forward_positioned_parameter_names = (
+            [
+                p.name
+                for p in self.signature.parameters.values()
+                if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+            ]
+            if self.signature
+            else None
+        )
+        names = (
+            [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
+            if self.signature
+            else None
         )
-        self.forward_ordered_parameter_names = list(self.signature.parameters)
-        self.forward_positioned_parameter_names = [
-            p.name
-            for p in self.signature.parameters.values()
-            if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
-        ]
-        names = [
-            p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
-        ]
         self.forward_args = names[0] if names else None
-        names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
+        names = (
+            [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
+            if self.signature
+            else None
+        )
         self.forward_kwargs = names[0] if names else None
         self.forward_custom_op_schema = None
         self.forward_need_serialization = False
@@ -711,6 +730,7 @@ class ModelInputs:
     @property
     def true_model_name(self) -> str:
         "Returns class name or module name."
+        assert self.model is not None, "model was None when the class was initialized."
        return (
            self.model.__class__.__name__
            if isinstance(self.model, torch.nn.Module)
@@ -942,7 +962,7 @@ class ModelInputs:
            )
        )
        names = s2.pop()
-        for name in names:
+        for i, name in enumerate(names):
            assert name not in {"_diag", "verbose"}, (
                f"{self.full_name}: unexpected parameter {name!r}, names={names}"
                f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
@@ -968,6 +988,14 @@ class ModelInputs:
        with the corresponding dynamic shapes.
        *kwargs*, *dynamic_shapes* are modified inplace.
        """
+        assert (
+            self.signature is not None
+            and self.forward_parameter_names is not None
+            and self.forward_ordered_parameter_names is not None
+        ), (
+            "model was None when the class was initialized, "
+            "cannot move args to kwargs without the signature."
+        )
        sig = self.signature
        arg_dyn, kw_dyn = dynamic_shapes
        for i, p in enumerate(sig.parameters):
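The net effect of these changes is that ``ModelInputs`` can now be built without a model, which the new ``shape_helper`` module (next file) relies on. A minimal sketch, assuming inputs are given as a list of keyword-argument sets as in the examples further below:

```python
import torch
from onnx_diagnostic.export.dynamic_shapes import ModelInputs

inputs1 = dict(x=torch.randn(2, 3))
inputs2 = dict(x=torch.randn(4, 5))
# model=None is accepted in 0.7.1; signature-dependent features stay disabled.
mi = ModelInputs(None, [inputs1, inputs2])
# Dimensions that differ between the two input sets become dynamic.
print(mi.guess_dynamic_shapes(auto="d"))
```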
onnx_diagnostic/export/shape_helper.py
ADDED
@@ -0,0 +1,126 @@
+from typing import Any, Dict, List, Set, Tuple, Union
+from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
+from .dynamic_shapes import ModelInputs
+
+
+def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
+    """
+    Returns the dynamic shapes for the given inputs.
+    All dimensions are considered as dynamic.
+    ``dim_prefix`` can be a string (the function uses it as a prefix),
+    or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+
+        bsize, nheads, slen, dim = 2, 1, 30, 96
+        inputs = dict(
+            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
+            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+            position_ids=torch.arange(3, dtype=torch.int64),
+            past_key_values=make_dynamic_cache(
+                [(torch.randn(bsize, nheads, slen, dim),
+                  torch.randn(bsize, nheads, slen, dim))]
+            ),
+        )
+        ds = all_dynamic_shape_from_inputs(inputs)
+        pprint.pprint(ds)
+    """
+    if isinstance(dim_prefix, str):
+        prefixes: Set[str] = set()
+
+        def tensor_to_shape(tensor):
+            n = len(prefixes)
+            p = f"{dim_prefix}_{n}"
+            prefixes.add(p)
+            return {i: f"{p}_{i}" for i in range(tensor.ndim)}
+
+    else:
+
+        def tensor_to_shape(tensor):
+            return {i: dim_prefix for i in range(tensor.ndim)}  # noqa: C420
+
+    return flatten_unflatten_for_dynamic_shapes(
+        inputs, change_function=tensor_to_shape, use_dict=True
+    )
+
+
+def guess_dynamic_shapes_from_inputs(
+    inputs: List[Any], auto: Union[bool, str] = False
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Guesses which dimensions are dynamic from a set of inputs.
+    Every dimension having different values over multiple sets
+    of inputs becomes dynamic. Every dimension not changing remains static.
+
+    :param inputs: a list of input sets
+    :param auto: True for ``torch.export.Dim.AUTO``,
+        False for ``torch.export.Dim.DYNAMIC``,
+        a string to get a unique string for every dynamic dimension
+    :return: args and kwargs
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
+
+        bsize, nheads, slen, dim = 2, 1, 30, 96
+        inputs1 = dict(
+            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
+            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+            position_ids=torch.arange(3, dtype=torch.int64),
+            past_key_values=make_dynamic_cache(
+                [
+                    (
+                        torch.randn(bsize, nheads, slen, dim),
+                        torch.randn(bsize, nheads, slen, dim),
+                    ),
+                ]
+            ),
+        )
+        bsize, nheads, slen, dim = 3, 1, 33, 96
+        inputs2 = dict(
+            input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
+            attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
+            position_ids=torch.arange(4, dtype=torch.int64),
+            past_key_values=make_dynamic_cache(
+                [
+                    (
+                        torch.randn(bsize, nheads, slen, dim),
+                        torch.randn(bsize, nheads, slen, dim),
+                    ),
+                ]
+            ),
+        )
+        ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
+        pprint.pprint(ds)
+
+    This function returns something equivalent to
+    :class:`torch.export.dynamic_shapes.AdditionalInputs`, but that class
+    needs a model.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
+        ds = torch.export.dynamic_shapes.AdditionalInputs()
+        ds.add((), data["inputs"])
+        ds.add((), data["inputs2"])
+        pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
+    """
+    mi = ModelInputs(None, inputs)
+    return mi.guess_dynamic_shapes(auto=auto)
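A short usage sketch (not in the diff): per the docstring above, ``dim_prefix`` may also be ``torch.export.Dim.DYNAMIC`` instead of a string, in which case every dimension maps to that sentinel:

```python
import torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

inputs = dict(x=torch.randn(2, 3), y=torch.randn(2, 3, 4))
# Every dimension of x and y is marked Dim.DYNAMIC instead of getting a name.
ds = all_dynamic_shape_from_inputs(inputs, dim_prefix=torch.export.Dim.DYNAMIC)
print(ds)
```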
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -1014,7 +1014,7 @@ class ExtTestCase(unittest.TestCase):
            msg_ = "\n".join(excs)
            msg = f"{msg}\n{msg_}" if msg else msg_
            raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
-        elif expected.__class__.__name__ == "DynamicCache":
+        elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
            atts = {"key_cache", "value_cache"}
            self.assertEqualArrayAny(
                {k: expected.__dict__.get(k, None) for k in atts},
onnx_diagnostic/helpers/cache_helper.py
CHANGED
@@ -1,11 +1,15 @@
-from typing import Any, List, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 import packaging.version as pv
 import torch
 import transformers
 import transformers.cache_utils


-def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
+def flatten_unflatten_for_dynamic_shapes(
+    obj: Any,
+    use_dict: bool = False,
+    change_function: Optional[Callable[[torch.Tensor], Any]] = None,
+) -> Any:
     """
     Returns the object in a different structure similar to what
     the definition of the dynamic shapes should use.
@@ -15,11 +19,13 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
     :func:`torch.export.export` only considers the values,
     the context gives the dictionary keys but it is not expressed
     in the dynamic shapes, these specifications seem to be different
-    for the strict and non strict mode.
+    for the strict and non strict mode. It also preserves tuples.
+    :param change_function: modifies the tensors in the structure itself,
+        for example to replace them by a shape
     :return: the serialized object
     """
     if isinstance(obj, torch.Tensor):
-        return obj
+        return change_function(obj) if change_function else obj
     flat, spec = torch.utils._pytree.tree_flatten(obj)
     start = 0
     end = 0
@@ -27,12 +33,17 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
     for subspec in spec.children_specs:
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
-        value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
+        value = flatten_unflatten_for_dynamic_shapes(
+            value, use_dict=use_dict, change_function=change_function
+        )
         subtrees.append(value)
         start = end
-    if use_dict and spec.type is dict:
-        # This is a dictionary.
-        return dict(zip(spec.context, subtrees))
+    if use_dict:
+        if spec.type is dict or spec.context:
+            # This is a dictionary.
+            return dict(zip(spec.context, subtrees))
+        if spec.type is tuple:
+            return tuple(subtrees)
     # This is a list.
     return subtrees

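A sketch of the new ``change_function`` hook (assumed behavior, following the code above): it rewrites each tensor leaf while the container structure, dict keys and tuples included, is preserved:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes

obj = {"a": torch.randn(2, 3), "b": (torch.randn(4), [torch.randn(5, 6)])}
# Replace every tensor by its shape; keys and tuples survive thanks to use_dict.
shapes = flatten_unflatten_for_dynamic_shapes(
    obj, use_dict=True, change_function=lambda t: tuple(t.shape)
)
print(shapes)  # {'a': (2, 3), 'b': ((4,), [(5, 6)])}
```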
@@ -141,6 +152,65 @@ else:
         return cache


+def make_static_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> transformers.cache_utils.StaticCache:
+    """
+    Creates an instance of :class:`transformers.cache_utils.StaticCache`.
+    :param key_value_pairs: list of pairs of (key, values)
+    :return: :class:`transformers.cache_utils.StaticCache`
+
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.cache_helper import make_static_cache
+
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_static_cache(
+            [
+                (
+                    torch.randn(bsize, nheads, slen, dim),
+                    torch.randn(bsize, nheads, slen, dim),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        print(string_type(past_key_values, with_shape=True))
+    """
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_value_pairs[0][0].shape[-1]
+            self.num_attention_heads = key_value_pairs[0][0].shape[1]
+            self.num_hidden_layers = len(key_value_pairs)
+
+    cache = transformers.cache_utils.StaticCache(
+        _config(),
+        max_batch_size=key_value_pairs[0][0].shape[0],
+        device=key_value_pairs[0][0].device,
+        dtype=key_value_pairs[0][0].dtype,
+        max_cache_len=key_value_pairs[0][0].shape[2],
+    )
+    for i in range(len(key_value_pairs)):
+        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
+            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+            f"got {key_value_pairs[i][0].shape}"
+        )
+        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
+            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+            f"got {key_value_pairs[i][1].shape}"
+        )
+        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    return cache
+
+
 def make_encoder_decoder_cache(
     self_attention_cache: transformers.cache_utils.DynamicCache,
     cross_attention_cache: transformers.cache_utils.DynamicCache,
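The ``_config`` shim exists because ``StaticCache`` requires a configuration object exposing ``head_dim``, ``num_attention_heads`` and ``num_hidden_layers``; all sizes are read off the first key/value pair. A hedged sketch contrasting the new constructor with the existing ``make_dynamic_cache``:

```python
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_static_cache

pairs = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))]
# Same contents, two cache classes; string_type prints both with their shapes.
print(string_type(make_dynamic_cache(pairs), with_shape=True))
print(string_type(make_static_cache(pairs), with_shape=True))
```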
onnx_diagnostic/helpers/config_helper.py
CHANGED
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
            config._attn_implementation_autoset = False
            continue
        if isinstance(v, dict):
-            if not hasattr(config, k):
-                setattr(config, k, v)
-                continue
-            update_config(getattr(config, k), v)
+            if not hasattr(config, k) or getattr(config, k) is None:
+                setattr(config, k, v)
+                continue
+            existing = getattr(config, k)
+            if type(existing) is dict:
+                existing.update(v)
+            else:
+                update_config(getattr(config, k), v)
            continue
        setattr(config, k, v)

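A sketch of the updated merge logic, assuming a plain ``transformers.PretrainedConfig``: missing or ``None`` attributes are set, plain-dict attributes are merged in place, and anything else is updated recursively:

```python
import transformers
from onnx_diagnostic.helpers.config_helper import update_config

config = transformers.PretrainedConfig()
update_config(config, {"rope_scaling": {"type": "linear", "factor": 2.0}})
update_config(config, {"rope_scaling": {"factor": 4.0}})
# The second call merges into the existing dict instead of replacing it.
print(config.rope_scaling)  # {'type': 'linear', 'factor': 4.0}
```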
onnx_diagnostic/helpers/helper.py
CHANGED
@@ -558,7 +558,7 @@ def string_type(
        print(f"[string_type] CACHE1:{type(obj)}")
        return f"MambaCache(conv_states={c}, ssm_states={d})"

-    if obj.__class__.__name__ in {"DynamicCache", "SlidingWindowCache"}:
+    if obj.__class__.__name__ in {"DynamicCache", "SlidingWindowCache", "StaticCache"}:
        kc = string_type(
            obj.key_cache,
            with_shape=with_shape,
@@ -857,7 +857,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
            return flatten_object(list(x.values()), drop_keys=drop_keys)
        return flatten_object(list(x.items()), drop_keys=drop_keys)

-    if x.__class__.__name__ == "DynamicCache":
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
        res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
        return tuple(res)
    if x.__class__.__name__ == "EncoderDecoderCache":
@@ -1424,10 +1424,37 @@ def max_diff(
            f"level={level}"
        )

+    if expected.__class__.__name__ == "StaticCache":
+        if got.__class__.__name__ == "StaticCache":
+            if verbose >= 6:
+                print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
+            return max_diff(
+                [expected.key_cache, expected.value_cache],
+                [got.key_cache, got.value_cache],
+                verbose=verbose,
+                hist=hist,
+            )
+        if isinstance(got, tuple) and len(got) == 2:
+            return max_diff(
+                [expected.key_cache, expected.value_cache],
+                [got[0], got[1]],
+                debug_info=_debug(expected.__class__.__name__),
+                **_dkws,
+            )
+        raise AssertionError(
+            f"StaticCache not fully implemented with classes "
+            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+            f"and expected={string_type(expected)}, got={string_type(got)},\n"
+            f"level={level}"
+        )
+
    if expected.__class__.__name__ == "SlidingWindowCache":
        if got.__class__.__name__ == "SlidingWindowCache":
            if verbose >= 6:
-                print(f"[max_diff] SlidingWindowCache: {string_type(expected)} ? {string_type(got)}")
+                print(
+                    f"[max_diff] SlidingWindowCache: "
+                    f"{string_type(expected)} ? {string_type(got)}"
+                )
            return max_diff(
                [expected.key_cache, expected.value_cache],
                [got.key_cache, got.value_cache],