onnx-diagnostic 0.8.8__py3-none-any.whl → 0.8.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/doc.py +258 -8
- onnx_diagnostic/export/api.py +492 -17
- onnx_diagnostic/export/dynamic_shapes.py +21 -6
- onnx_diagnostic/export/shape_helper.py +0 -8
- onnx_diagnostic/helpers/cache_helper.py +98 -13
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/onnx_helper.py +7 -0
- onnx_diagnostic/helpers/rt_helper.py +14 -1
- onnx_diagnostic/helpers/torch_helper.py +22 -9
- onnx_diagnostic/tasks/image_text_to_text.py +4 -1
- onnx_diagnostic/tasks/text_generation.py +17 -17
- onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +67 -39
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -9
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +1 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -0
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.10.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.10.dist-info}/RECORD +23 -23
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.10.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.10.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.10.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/cache_helper.py

@@ -19,7 +19,7 @@ class CacheKeyValue:
         capi.value_cache
     """

-    def __init__(self, cache=None):
+    def __init__(self, cache=None, cls_layers=None):
        if hasattr(cache, "layers"):
            layers = [
                layer
@@ -28,24 +28,52 @@
            ]
            self.key_cache = [layer.keys for layer in layers]
            self.value_cache = [layer.values for layer in layers]
+            assert (
+                cls_layers is None
+            ), f"cache is {type(cache)}, cannot specify cls_layers={cls_layers}"
+            self.cls_layers = [type(lay) for lay in cache.layers]
        elif cache is not None and hasattr(cache, "key_cache"):
            self.key_cache = cache.key_cache
            self.value_cache = cache.value_cache
+            self.cls_layers = cls_layers
+        elif (
+            cache is not None
+            and isinstance(cache, list)
+            and all(isinstance(t, torch.Tensor) for t in cache)
+        ):
+            self.key_cache = cache[::2]
+            self.value_cache = cache[1::2]
+            self.cls_layers = cls_layers
        elif cache is None:
            self.key_cache = None
            self.value_cache = None
+            self.cls_layers = cls_layers
        else:
            raise NotImplementedError(f"type(cache)={type(cache)}")

    def make_dynamic_cache(self):
        """Does the reverse operation."""
-        return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+        return make_dynamic_cache(
+            list(zip(self.key_cache, self.value_cache)), cls_layers=self.cls_layers
+        )

    @property
    def n_layers(self) -> int:
        """Returns the number of layers."""
        return len(self.key_cache) if self.key_cache else 0

+    def __len__(self) -> int:
+        "Returns the number of tensors."
+        return len(self.key_cache) + len(self.value_cache)
+
+    def aslist(self) -> List[torch.Tensor]:
+        "Returns tensors in a list."
+        res = []
+        for i in range(self.n_layers):
+            res.append(self.key_cache[i])
+            res.append(self.value_cache[i])
+        return res
+

def flatten_unflatten_for_dynamic_shapes(
    obj: Any,
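Taken together, these two hunks make `CacheKeyValue` round-trippable: it now remembers the layer classes of the cache it wraps and can emit its tensors as a flat list. A minimal sketch of the new surface, assuming the API exactly as it appears in this diff (tensor shapes are arbitrary):

```python
import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

# Two layers of (key, value) pairs; shapes are arbitrary for illustration.
cache = make_dynamic_cache(
    [(torch.rand(1, 4, 8, 16), torch.rand(1, 4, 8, 16)) for _ in range(2)]
)
capi = CacheKeyValue(cache)      # records cls_layers from cache.layers when present
print(capi.n_layers)             # 2
print(len(capi))                 # 4: keys + values
flat = capi.aslist()             # [k0, v0, k1, v1], the order flatten_object now relies on
restored = capi.make_dynamic_cache()  # rebuilds with the remembered layer classes
```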
onnx_diagnostic/helpers/cache_helper.py

@@ -156,12 +184,16 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

    def make_dynamic_cache(
        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        cls_layers: Optional[Union[str, List[type]]] = None,
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers >= 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
+        :param cls_layers: to select the appropriate class to use on each layer,
+            if specified, sliding_window is ignored, it can be a string
+            if all layers are expected to follow the same class
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:
@@ -192,15 +224,49 @@
        are supported.
        """
        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+        cls_kwargs = {}
+        if isinstance(cls_layers, str):
+            assert hasattr(
+                transformers.cache_utils, cls_layers
+            ), f"Unable to find class {cls_layers!r} in transformers.cache_utils"
+            cls_layer = getattr(transformers.cache_utils, cls_layers)
+            if cls_layers == "DynamicSlidingWindowLayer":
+                cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                assert isinstance(
+                    cls_kwargs["sliding_window"], int
+                ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+        elif cls_layers is not None:
+            unique = set(cls_layers)
+            assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}"
+            cls_layer = unique.pop()
+            if (
+                hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer")
+                and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer
+            ):
+                from .helper import string_type
+
+                assert key_value_pairs and key_value_pairs[0], (
+                    f"not implemented for key_value_pairs="
+                    f"{string_type(key_value_pairs, with_shape=True)}"
+                )
+                cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                assert isinstance(
+                    cls_kwargs["sliding_window"], int
+                ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+        else:
+            cls_layer = (
+                transformers.cache_utils.DynamicLayer
+                if hasattr(transformers.cache_utils, "DynamicLayer")
+                else None
+            )
+
        if (
            key_value_pairs
            and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
            and pv.Version(transformers.__version__) >= pv.Version("4.56")
        ):
            cache = transformers.cache_utils.DynamicCache()
-            cache.layers.extend(
-                [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
-            )
+            cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
            for i, layer in enumerate(cache.layers):
                k, v = key_value_pairs[i][0], key_value_pairs[i][1]
                layer.dtype = k.dtype
@@ -214,14 +280,21 @@
            )
            return finalize_cache(cache)

-        cache = transformers.cache_utils.DynamicCache(key_value_pairs)
-        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
-            # The cache constructor contains the two following lines
-            # (in cache_utils.py) which append empty layers when the cache is
-            # initialized. We need to remove them.
-            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
-            # self.append_new_layers(self.num_hidden_layers - 1)
-            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        cache = transformers.cache_utils.DynamicCache()
+        if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer:
+            cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
+            for i, layer in enumerate(cache.layers):
+                layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
+                layer.is_initialized = True
+        else:
+            cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+            if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+                # The cache constructor contains the two following lines
+                # (in cache_utils.py) which append empty layers when the cache is
+                # initialized. We need to remove them.
+                # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+                # self.append_new_layers(self.num_hidden_layers - 1)
+                cache.layers[:] = cache.layers[-len(key_value_pairs) :]
        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
            f"{len(key_value_pairs)} expected."
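The effect of the new `cls_layers` argument, sketched under the assumption that the installed transformers (>= 4.50 for this branch) provides the named layer classes; `DynamicSlidingWindowLayer` in particular only exists in recent releases:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

pairs = [(torch.rand(1, 2, 6, 8), torch.rand(1, 2, 6, 8))]

# Default: DynamicLayer (when available) is used for every layer.
cache = make_dynamic_cache(pairs)

# Named layer class: sliding_window is derived from the key shape, dim 2 (here 6).
sliding = make_dynamic_cache(pairs, cls_layers="DynamicSlidingWindowLayer")

# A list of classes also works, as long as every layer uses the same class.
same = make_dynamic_cache(pairs, cls_layers=[type(cache.layers[0])])
```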
onnx_diagnostic/helpers/cache_helper.py

@@ -232,6 +305,7 @@ else:

    def make_dynamic_cache(
        key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        cls_layers: Optional[Union[str, List[type]]] = None,
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -263,6 +337,7 @@ else:
        )
        print(string_type(past_key_values, with_shape=True))
        """
+        assert not cls_layers, "cls_layers cannot be used for transformers<5."
        key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
        for i, (key, value) in enumerate(key_value_pairs):
@@ -508,9 +583,13 @@ if hasattr(transformers.cache_utils, "SlidingWindowCache"):
        )
        return finalize_cache(cache)

+    def get_make_hybrid_cache():
+        return make_sliding_window_cache
+
else:
    make_sliding_window_cache = None  # type: ignore[assignment]

+
if hasattr(transformers.cache_utils, "HybridCache"):

    def make_hybrid_cache(
@@ -672,9 +751,15 @@ if hasattr(transformers.cache_utils, "HybridCache"):
        )
        return finalize_cache(cache)

+    def get_make_hybrid_cache():
+        return make_hybrid_cache
+
else:
    make_hybrid_cache = None  # type: ignore[assignment]

+    def get_make_hybrid_cache():
+        return None
+

def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
    """
onnx_diagnostic/helpers/helper.py

@@ -1,7 +1,6 @@
 import ast
 import enum
 import inspect
-import itertools
 import json
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -991,15 +990,17 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
        from .cache_helper import CacheKeyValue

-        kc = CacheKeyValue(x)
-        return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache)))
+        return CacheKeyValue(x).aslist()

    if x.__class__.__name__ == "EncoderDecoderCache":
-        res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
+        res = [
+            *flatten_object(x.self_attention_cache),
+            *flatten_object(x.cross_attention_cache),
+        ]
        return tuple(res)
    if x.__class__.__name__ == "MambaCache":
        if isinstance(x.conv_states, list):
-            res = flatten_object(x.conv_states)
+            res = [*flatten_object(x.conv_states), *flatten_object(x.ssm_states)]
            return tuple(res)
        return (x.conv_states, x.ssm_states)
    if hasattr(x, "to_tuple"):
onnx_diagnostic/helpers/onnx_helper.py

@@ -28,6 +28,7 @@ from onnx import (
     NodeProto,
     OperatorSetIdProto,
     TensorProto,
+    TypeProto,
     ValueInfoProto,
     load as onnx_load,
 )
@@ -385,6 +386,12 @@ def pretty_onnx(
        shape_str = ",".join(map(str, shape))
        return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"

+    if isinstance(onx, TypeProto):
+        itype = onx.tensor_type.elem_type
+        shape = tuple((d.dim_param or d.dim_value) for d in onx.tensor_type.shape.dim)
+        shape_str = ",".join(map(str, shape))
+        return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}]"
+
    if isinstance(onx, AttributeProto):
        att = onx
        if att.type == AttributeProto.INT:
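With the `TypeProto` branch above, `pretty_onnx` can now render a bare tensor type. A small sketch, assuming the rendering shown in the added lines (the exact dtype-name casing comes from `onnx_dtype_name`):

```python
import onnx
import onnx.helper as oh
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

# A float tensor type with one symbolic and one static dimension.
t = oh.make_tensor_type_proto(onnx.TensorProto.FLOAT, ["batch", 8])
print(pretty_onnx(t))  # expected to print something like FLOAT[batch,8]
```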
onnx_diagnostic/helpers/rt_helper.py

@@ -41,7 +41,20 @@ def make_feeds(
    """
    # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
    # because it's fued into rotary embedding in GQA.
-    if is_modelbuilder and isinstance(inputs, dict):
+    if is_modelbuilder and isinstance(inputs, dict) and "position_ids" in inputs:
+        position_ids = inputs["position_ids"]  # type: ignore[valid-type]
+        # We just check position_ids are contiguous.
+        assert isinstance(position_ids, torch.Tensor) and (
+            (
+                (position_ids - position_ids.min())
+                == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)
+            )
+            .max()
+            .item()
+        ), (
+            f"ModelBuilder does not support position_ids={position_ids}, "
+            f"inputs={string_type(inputs, with_shape=True)}"
+        )
        inputs.pop("position_ids", None)  # Ensure 'position_ids' absent before removing.

    flat = flatten_object(inputs, drop_keys=True)
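The predicate inside the new assert, pulled out for illustration: the ids are shifted by the global minimum and compared against `0..n-1`, and `.max().item()` makes one matching entry sufficient for the check to pass:

```python
import torch

# The heuristic from the assert above, in isolation.
position_ids = torch.arange(10, 20).unsqueeze(0)  # a min-shifted arange row
ok = (
    (position_ids - position_ids.min())
    == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)
).max().item()
assert ok  # make_feeds then pops position_ids, which GQA fuses into rotary embedding
```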
onnx_diagnostic/helpers/torch_helper.py

@@ -15,9 +15,6 @@ from .helper import string_type, size_type
 from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
-    make_hybrid_cache,
-    make_sliding_window_cache,
-    make_mamba_cache,
     make_static_cache,
     CacheKeyValue,
 )
@@ -769,10 +766,22 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
    return {to_any(t, to_value) for t in value}
    if type(value) is dict:
        return {k: to_any(t, to_value) for k, t in value.items()}
-    if value.__class__.__name__ in ("DynamicCache", "HybridCache"):
-        make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
+    if value.__class__.__name__ == "DynamicCache":
        cc = CacheKeyValue(value)
-        return make[value.__class__.__name__](
+        return make_dynamic_cache(
+            list(
+                zip(
+                    [t.to(to_value) if t is not None else t for t in cc.key_cache],
+                    [t.to(to_value) if t is not None else t for t in cc.value_cache],
+                )
+            ),
+            cls_layers=cc.cls_layers,
+        )
+    if value.__class__.__name__ == "HybridCache":
+        from .cache_helper import make_hybrid_cache
+
+        cc = CacheKeyValue(value)
+        return make_hybrid_cache(
            list(
                zip(
                    [t.to(to_value) if t is not None else t for t in cc.key_cache],
@@ -843,7 +852,9 @@ def torch_deepcopy(value: Any) -> Any:
        from .cache_helper import CacheKeyValue

        ca = CacheKeyValue(value)
-        return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
+        return make_dynamic_cache(
+            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
+        )
    if value.__class__.__name__ == "StaticCache":
        from .cache_helper import CacheKeyValue

@@ -858,12 +869,12 @@ def torch_deepcopy(value: Any) -> Any:
            max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
        )
    if value.__class__.__name__ == "HybridCache":
-        from .cache_helper import CacheKeyValue
+        from .cache_helper import CacheKeyValue, make_hybrid_cache

        ca = CacheKeyValue(value)
        return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
    if value.__class__.__name__ == "SlidingWindowCache":
-        from .cache_helper import CacheKeyValue
+        from .cache_helper import CacheKeyValue, make_sliding_window_cache

        ca = CacheKeyValue(value)
        return make_sliding_window_cache(
@@ -875,6 +886,8 @@ def torch_deepcopy(value: Any) -> Any:
            torch_deepcopy(value.cross_attention_cache),
        )
    if value.__class__.__name__ == "MambaCache":
+        from .cache_helper import make_mamba_cache
+
        return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))

    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
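Both `to_any` and `torch_deepcopy` now rebuild `DynamicCache` objects through `make_dynamic_cache`, forwarding the layer classes captured by `CacheKeyValue`. A sketch of the round-trips, assuming the cache was built with the helpers above:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy

cache = make_dynamic_cache([(torch.rand(1, 2, 4, 8), torch.rand(1, 2, 4, 8))])
copy = torch_deepcopy(cache)         # cls_layers now travels with the copy
half = to_any(cache, torch.float16)  # dtype conversion, same layer classes
```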
onnx_diagnostic/tasks/image_text_to_text.py

@@ -1,7 +1,7 @@
 import itertools
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
+from ..helpers.cache_helper import make_dynamic_cache, get_make_hybrid_cache
 from ..helpers.config_helper import (
     update_config,
     check_hasattr,
@@ -200,6 +200,9 @@ def _get_inputs_gemma3(

    _check_()

+    make_hybrid_cache = get_make_hybrid_cache()
+    assert make_hybrid_cache is not None, "not implemented when make_hybrid_cache is missing"
+
    inputs = dict(
        input_ids=dummies["input_ids"],
        token_type_ids=dummies["token_type_ids"],
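The lazy getter keeps the module importable when the installed transformers lacks `HybridCache`; the builder is only resolved at call time, as the gemma3 hunk above shows. A sketch of the pattern, assuming a transformers version that still ships `HybridCache`:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import get_make_hybrid_cache

make_hybrid_cache = get_make_hybrid_cache()  # returns None when HybridCache is missing
if make_hybrid_cache is None:
    raise NotImplementedError("HybridCache is not available in this transformers version")
cache = make_hybrid_cache([(torch.rand(1, 2, 4, 8), torch.rand(1, 2, 4, 8))])
```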
onnx_diagnostic/tasks/text_generation.py

@@ -1,11 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-from ..helpers.cache_helper import (
-    make_dynamic_cache,
-    make_mamba_cache,
-    make_sliding_window_cache,
-    make_static_cache,
-)
+from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache, make_static_cache
 from ..helpers.config_helper import (
     update_config,
     check_hasattr,
@@ -187,17 +182,22 @@ def get_inputs(
        if cls_cache is None or isinstance(cls_cache, str)
        else cls_cache.__name__
    )
-    … 11 removed lines are illegible in this diff view …
+    if cache_name == "DynamicSlidingWindowCache":
+        from ..helpers.cache_helper import make_sliding_window_cache
+
+        make_cache = make_sliding_window_cache
+        is_static = False
+    else:
+        make_caches = {
+            "DynamicCache": make_dynamic_cache,
+            "StaticCache": make_static_cache,
+        }
+        assert cache_name is None or cache_name in make_caches, (
+            f"Unable to handle cls_cache={cache_name!r}, it should be in "
+            f"{sorted(make_caches)}"
+        )
+        make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]  # type: ignore[assignment]
+        is_static = cache_name == "StaticCache"

    if is_static:
        # static
onnx_diagnostic/torch_export_patches/eval/__init__.py

@@ -521,7 +521,7 @@ def run_exporter(
    :param exporter: exporter
    :param cls_model: model class to create
    :param inputs: list of inputs to try
-    :param dynamic: use dynamic
+    :param dynamic: use dynamic shapes or not
    :param quiet: raise exception or not
    :param verbose: verbosity
    :return: results
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -7,15 +7,9 @@ import transformers
 from transformers.cache_utils import DynamicCache, StaticCache

 try:
-    from transformers.cache_utils import (
-        EncoderDecoderCache,
-        HybridCache,
-        SlidingWindowCache,
-    )
+    from transformers.cache_utils import EncoderDecoderCache
 except ImportError:
     EncoderDecoderCache = None
-    HybridCache = None
-    SlidingWindowCache = None
 from ..helpers import string_type
 from .serialization import _lower_name_with_

@@ -36,6 +30,24 @@ def get_mamba_cache_cls() -> type:
        return None


+def get_hybrid_cache_cls() -> type:
+    try:
+        from transformers.cache_utils import HybridCache
+
+        return HybridCache
+    except ImportError:
+        return None
+
+
+def get_sliding_window_cache_cls() -> type:
+    try:
+        from transformers.cache_utils import SlidingWindowCache
+
+        return SlidingWindowCache
+    except ImportError:
+        return None
+
+
def register_class_serialization(
    cls,
    f_flatten: Callable,
@@ -179,18 +191,9 @@ def serialization_functions(
        flatten_dynamic_cache,
        unflatten_dynamic_cache,
        flatten_with_keys_dynamic_cache,
-        flatten_hybrid_cache,
-        unflatten_hybrid_cache,
-        flatten_with_keys_hybrid_cache,
-        flatten_mamba_cache,
-        unflatten_mamba_cache,
-        flatten_with_keys_mamba_cache,
        flatten_encoder_decoder_cache,
        unflatten_encoder_decoder_cache,
        flatten_with_keys_encoder_decoder_cache,
-        flatten_sliding_window_cache,
-        unflatten_sliding_window_cache,
-        flatten_with_keys_sliding_window_cache,
        flatten_static_cache,
        unflatten_static_cache,
        flatten_with_keys_static_cache,
@@ -208,14 +211,6 @@ def serialization_functions(
            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
            verbose=verbose,
        ),
-        HybridCache: lambda verbose=verbose: register_class_serialization(
-            HybridCache,
-            flatten_hybrid_cache,
-            unflatten_hybrid_cache,
-            flatten_with_keys_hybrid_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        ),
        EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
            EncoderDecoderCache,
            flatten_encoder_decoder_cache,
@@ -223,13 +218,6 @@ def serialization_functions(
            flatten_with_keys_encoder_decoder_cache,
            verbose=verbose,
        ),
-        SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
-            SlidingWindowCache,
-            flatten_sliding_window_cache,
-            unflatten_sliding_window_cache,
-            flatten_with_keys_sliding_window_cache,
-            verbose=verbose,
-        ),
        StaticCache: lambda verbose=verbose: register_class_serialization(
            StaticCache,
            flatten_static_cache,
@@ -240,6 +228,12 @@ def serialization_functions(
    }
    MambaCache = get_mamba_cache_cls()
    if MambaCache:
+        from .serialization.transformers_impl import (
+            flatten_mamba_cache,
+            unflatten_mamba_cache,
+            flatten_with_keys_mamba_cache,
+        )
+
        transformers_classes[MambaCache] = (
            lambda verbose=verbose: register_class_serialization(
                MambaCache,
@@ -249,6 +243,42 @@ def serialization_functions(
                verbose=verbose,
            )
        )
+    HybridCache = get_hybrid_cache_cls()
+    if HybridCache:
+        from .serialization.transformers_impl import (
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+        )
+
+        transformers_classes[HybridCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                HybridCache,
+                flatten_hybrid_cache,
+                unflatten_hybrid_cache,
+                flatten_with_keys_hybrid_cache,
+                verbose=verbose,
+            )
+        )
+
+    SlidingWindowCache = get_sliding_window_cache_cls()
+    if SlidingWindowCache:
+        from .serialization.transformers_impl import (
+            flatten_sliding_window_cache,
+            unflatten_sliding_window_cache,
+            flatten_with_keys_sliding_window_cache,
+        )
+
+        transformers_classes[SlidingWindowCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                SlidingWindowCache,
+                flatten_sliding_window_cache,
+                unflatten_sliding_window_cache,
+                flatten_with_keys_sliding_window_cache,
+                verbose=verbose,
+            )
+        )
+
    classes.update(transformers_classes)

    if patch_diffusers:
@@ -275,7 +305,7 @@ def serialization_functions(


def unregister_class_serialization(cls: type, verbose: int = 0):
-    """Undo the registration."""
+    """Undo the registration for a class."""
    # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
    if cls in torch.fx._pytree.SUPPORTED_NODES:
        del torch.fx._pytree.SUPPORTED_NODES[cls]
@@ -303,13 +333,11 @@ def unregister_class_serialization(cls: type, verbose: int = 0):


def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
-    """
-    Undo the registration made by …
-    """
-    cls_ensemble = (
-        {DynamicCache, EncoderDecoderCache, …}
-        | ({MambaCache} if MambaCache else set())
-    )
+    """
+    Undo the registration made by
+    :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`.
+    """
+    cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
    for cls in cls_ensemble:
        if undo.get(cls.__name__, False):
            unregister_class_serialization(cls, verbose)
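Registration of `HybridCache` and `SlidingWindowCache` serialization is now conditional on the installed transformers, mirroring the existing `MambaCache` handling. A usage sketch, assuming `register_cache_serialization` returns the undo dictionary that `unregister_cache_serialization` consumes, as the signatures in this file suggest:

```python
from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    register_cache_serialization,
    unregister_cache_serialization,
)

undo = register_cache_serialization(verbose=1)  # skips cache classes transformers lacks
try:
    ...  # run torch.export on a model taking cache inputs
finally:
    unregister_cache_serialization(undo, verbose=1)
```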
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -524,13 +524,16 @@ class patched_ShapeEnv:

        transmute_into_runtime_assert = False

+        backed_var_to_val = getattr(
+            self, "backed_var_to_val", getattr(self, "var_to_val", {})
+        )
        concrete_val = None
-        if not (expr.free_symbols <= self.var_to_val.keys()):
+        if not (expr.free_symbols <= backed_var_to_val.keys()):
            # TODO: dedupe this with _maybe_evaluate_static
            # Attempt to eliminate the unbacked SymInt
            new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
            assert new_expr is not None
-            if not (new_expr.free_symbols <= self.var_to_val.keys()):
+            if not (new_expr.free_symbols <= backed_var_to_val.keys()):
                ok = False

        # fallback_value is set when guard_or_true or guard_or_false are used.
@@ -542,13 +545,14 @@ class patched_ShapeEnv:
        # with DimDynamic.OBLIVIOUS_SIZE type.
        # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
        if (
-            self.var_to_val
+            backed_var_to_val
+            and getattr(self, "real_tensor_prop_unbacked_vals", True)
            and not (
-                correct_hint := orig_expr.xreplace(self.var_to_val)
+                correct_hint := orig_expr.xreplace(backed_var_to_val)
            ).free_symbols
            and not (
                counterfactual_hint := orig_expr.xreplace(
-                    {k: max(2, v) for k, v in self.var_to_val.items()}
+                    {k: max(2, v) for k, v in backed_var_to_val.items()}
                )
            ).free_symbols
            and correct_hint == counterfactual_hint
@@ -571,11 +575,11 @@ class patched_ShapeEnv:
        # and if they pass we add a runtime assertions and continue.
        if (
            not ok
-            and self.var_to_val
+            and backed_var_to_val
            and not (
-                unsound_result := orig_expr.xreplace(self.var_to_val).xreplace(
-                    self.unbacked_var_to_val
-                )
+                unsound_result := orig_expr.xreplace(backed_var_to_val).xreplace(
+                    backed_var_to_val
+                )
            ).free_symbols
        ):
            # pyrefly: ignore # unbound-name