onnx-diagnostic 0.8.8__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,7 +19,7 @@ class CacheKeyValue:
         capi.value_cache
     """
 
-    def __init__(self, cache=None):
+    def __init__(self, cache=None, cls_layers=None):
         if hasattr(cache, "layers"):
             layers = [
                 layer
@@ -28,24 +28,52 @@ class CacheKeyValue:
             ]
             self.key_cache = [layer.keys for layer in layers]
             self.value_cache = [layer.values for layer in layers]
+            assert (
+                cls_layers is None
+            ), f"cache is {type(cache)}, cannot specify cls_layers={cls_layers}"
+            self.cls_layers = [type(lay) for lay in cache.layers]
         elif cache is not None and hasattr(cache, "key_cache"):
             self.key_cache = cache.key_cache
             self.value_cache = cache.value_cache
+            self.cls_layers = cls_layers
+        elif (
+            cache is not None
+            and isinstance(cache, list)
+            and all(isinstance(t, torch.Tensor) for t in cache)
+        ):
+            self.key_cache = cache[::2]
+            self.value_cache = cache[1::2]
+            self.cls_layers = cls_layers
         elif cache is None:
             self.key_cache = None
             self.value_cache = None
+            self.cls_layers = cls_layers
         else:
             raise NotImplementedError(f"type(cache)={type(cache)}")
 
     def make_dynamic_cache(self):
         """Does the reverse operation."""
-        return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+        return make_dynamic_cache(
+            list(zip(self.key_cache, self.value_cache)), cls_layers=self.cls_layers
+        )
 
     @property
     def n_layers(self) -> int:
         """Returns the number of layers."""
         return len(self.key_cache) if self.key_cache else 0
 
+    def __len__(self) -> int:
+        "Returns the number of tensors."
+        return len(self.key_cache) + len(self.value_cache)
+
+    def aslist(self) -> List[torch.Tensor]:
+        "Returns tensors in a list."
+        res = []
+        for i in range(self.n_layers):
+            res.append(self.key_cache[i])
+            res.append(self.value_cache[i])
+        return res
+
 
 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
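
The new cls_layers attribute, __len__, and aslist() give callers a flat, interleaved view of the cache and remember which layer class to use when rebuilding it. A minimal usage sketch (not part of the diff), assuming the import path onnx_diagnostic.helpers.cache_helper used elsewhere in this release and a transformers version with layered caches:

import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

# one layer: a (key, value) pair of shape (batch, heads, seq, head_dim)
cache = make_dynamic_cache([(torch.rand(2, 4, 3, 8), torch.rand(2, 4, 3, 8))])
capi = CacheKeyValue(cache)
assert len(capi) == 2                 # key_0 and value_0
flat = capi.aslist()                  # [key_0, value_0], keys and values interleaved per layer
rebuilt = capi.make_dynamic_cache()   # round-trips through cls_layers
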
@@ -156,12 +184,16 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
 
     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        cls_layers: Optional[Union[str, List[type]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
         This version is valid for ``transformers >= 4.50``.
 
         :param key_value_pairs: list of pairs of (key, values)
+        :param cls_layers: to select the appropriate class to use on each layer,
+            if specified, sliding_window is ignored, it can be a string
+            if all layers are expected to follow the same class
         :return: :class:`transformers.cache_utils.DynamicCache`
 
         Example:
@@ -192,15 +224,49 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
         are supported.
         """
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+        cls_kwargs = {}
+        if isinstance(cls_layers, str):
+            assert hasattr(
+                transformers.cache_utils, cls_layers
+            ), f"Unable to find class {cls_layers!r} in transformers.cache_utils"
+            cls_layer = getattr(transformers.cache_utils, cls_layers)
+            if cls_layers == "DynamicSlidingWindowLayer":
+                cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                assert isinstance(
+                    cls_kwargs["sliding_window"], int
+                ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+        elif cls_layers is not None:
+            unique = set(cls_layers)
+            assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}"
+            cls_layer = unique.pop()
+            if (
+                hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer")
+                and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer
+            ):
+                from .helper import string_type
+
+                assert key_value_pairs and key_value_pairs[0], (
+                    f"not implemented for key_value_pairs="
+                    f"{string_type(key_value_pairs, with_shape=True)}"
+                )
+                cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                assert isinstance(
+                    cls_kwargs["sliding_window"], int
+                ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+        else:
+            cls_layer = (
+                transformers.cache_utils.DynamicLayer
+                if hasattr(transformers.cache_utils, "DynamicLayer")
+                else None
+            )
+
         if (
             key_value_pairs
             and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
             and pv.Version(transformers.__version__) >= pv.Version("4.56")
         ):
             cache = transformers.cache_utils.DynamicCache()
-            cache.layers.extend(
-                [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
-            )
+            cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
             for i, layer in enumerate(cache.layers):
                 k, v = key_value_pairs[i][0], key_value_pairs[i][1]
                 layer.dtype = k.dtype
@@ -214,14 +280,21 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             )
             return finalize_cache(cache)
 
-        cache = transformers.cache_utils.DynamicCache(key_value_pairs)
-        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
-            # The cache constructor contains the two following lines
-            # (in cache_utils.py) which append empty layers when the cache is
-            # initialized. We need to remove them.
-            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
-            # self.append_new_layers(self.num_hidden_layers - 1)
-            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+        cache = transformers.cache_utils.DynamicCache()
+        if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer:
+            cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
+            for i, layer in enumerate(cache.layers):
+                layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
+                layer.is_initialized = True
+        else:
+            cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+            if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+                # The cache constructor contains the two following lines
+                # (in cache_utils.py) which append empty layers when the cache is
+                # initialized. We need to remove them.
+                # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+                # self.append_new_layers(self.num_hidden_layers - 1)
+                cache.layers[:] = cache.layers[-len(key_value_pairs) :]
         assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
             f"{len(key_value_pairs)} expected."
@@ -232,6 +305,7 @@ else:
 
     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        cls_layers: Optional[Union[str, List[type]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -263,6 +337,7 @@ else:
             )
             print(string_type(past_key_values, with_shape=True))
         """
+        assert not cls_layers, "cls_layers cannot be used for transformers<5."
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
         cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
         for i, (key, value) in enumerate(key_value_pairs):
@@ -508,9 +583,13 @@ if hasattr(transformers.cache_utils, "SlidingWindowCache"):
         )
         return finalize_cache(cache)
 
+    def get_make_hybrid_cache():
+        return make_sliding_window_cache
+
 else:
     make_sliding_window_cache = None  # type: ignore[assignment]
 
+
 if hasattr(transformers.cache_utils, "HybridCache"):
 
     def make_hybrid_cache(
@@ -672,9 +751,15 @@ if hasattr(transformers.cache_utils, "HybridCache"):
         )
         return finalize_cache(cache)
 
+    def get_make_hybrid_cache():
+        return make_hybrid_cache
+
 else:
     make_hybrid_cache = None  # type: ignore[assignment]
 
+    def get_make_hybrid_cache():
+        return None
+
 
 def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
     """
@@ -1,7 +1,6 @@
 import ast
 import enum
 import inspect
-import itertools
 import json
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -991,15 +990,17 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
     if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
         from .cache_helper import CacheKeyValue
 
-        kc = CacheKeyValue(x)
-        return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache)))
+        return CacheKeyValue(x).aslist()
 
     if x.__class__.__name__ == "EncoderDecoderCache":
-        res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
+        res = [
+            *flatten_object(x.self_attention_cache),
+            *flatten_object(x.cross_attention_cache),
+        ]
         return tuple(res)
     if x.__class__.__name__ == "MambaCache":
         if isinstance(x.conv_states, list):
-            res = flatten_object(x.conv_states) + flatten_object(x.ssm_states)
+            res = [*flatten_object(x.conv_states), *flatten_object(x.ssm_states)]
             return tuple(res)
         return (x.conv_states, x.ssm_states)
     if hasattr(x, "to_tuple"):
@@ -28,6 +28,7 @@ from onnx import (
     NodeProto,
     OperatorSetIdProto,
     TensorProto,
+    TypeProto,
     ValueInfoProto,
     load as onnx_load,
 )
@@ -385,6 +386,12 @@ def pretty_onnx(
         shape_str = ",".join(map(str, shape))
         return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"
 
+    if isinstance(onx, TypeProto):
+        itype = onx.tensor_type.elem_type
+        shape = tuple((d.dim_param or d.dim_value) for d in onx.tensor_type.shape.dim)
+        shape_str = ",".join(map(str, shape))
+        return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}]"
+
     if isinstance(onx, AttributeProto):
         att = onx
         if att.type == AttributeProto.INT:
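
pretty_onnx now also renders a bare TypeProto, i.e. a tensor type without a value name. A hedged example, assuming pretty_onnx is imported from onnx_diagnostic.helpers.onnx_helper (the module path is not visible in this hunk):

from onnx import TensorProto
from onnx.helper import make_tensor_type_proto
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx  # assumed location

ttype = make_tensor_type_proto(TensorProto.FLOAT, ["batch", 8])
print(pretty_onnx(ttype))  # something like FLOAT[batch,8]
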
@@ -41,7 +41,20 @@ def make_feeds(
     """
     # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
     # because it's fued into rotary embedding in GQA.
-    if is_modelbuilder and isinstance(inputs, dict):
+    if is_modelbuilder and isinstance(inputs, dict) and "position_ids" in inputs:
+        position_ids = inputs["position_ids"]  # type: ignore[valid-type]
+        # We just check position_ids are contiguous.
+        assert isinstance(position_ids, torch.Tensor) and (
+            (
+                (position_ids - position_ids.min())
+                == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)
+            )
+            .max()
+            .item()
+        ), (
+            f"ModelBuilder does not support position_ids={position_ids}, "
+            f"inputs={string_type(inputs, with_shape=True)}"
+        )
         inputs.pop("position_ids", None)  # Ensure 'position_ids' absent before removing.
 
     flat = flatten_object(inputs, drop_keys=True)
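
The new guard asserts on position_ids before dropping them for ModelBuilder, whose GQA kernel fuses them into the rotary embedding. Its comparison can be reproduced as a standalone sketch (not from the package):

import torch

position_ids = torch.arange(4, 10).unsqueeze(0)  # shape (1, 6), a contiguous range with an offset
ok = (
    (position_ids - position_ids.min())
    == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)
).max().item()
print(ok)  # True here, so make_feeds removes position_ids and continues
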
@@ -15,9 +15,6 @@ from .helper import string_type, size_type
 from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
-    make_hybrid_cache,
-    make_sliding_window_cache,
-    make_mamba_cache,
     make_static_cache,
     CacheKeyValue,
 )
@@ -769,10 +766,22 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return {to_any(t, to_value) for t in value}
     if type(value) is dict:
         return {k: to_any(t, to_value) for k, t in value.items()}
-    if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
-        make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
+    if value.__class__.__name__ == "DynamicCache":
         cc = CacheKeyValue(value)
-        return make[value.__class__.__name__](  # type: ignore[operator]
+        return make_dynamic_cache(
+            list(
+                zip(
+                    [t.to(to_value) if t is not None else t for t in cc.key_cache],
+                    [t.to(to_value) if t is not None else t for t in cc.value_cache],
+                )
+            ),
+            cls_layers=cc.cls_layers,
+        )
+    if value.__class__.__name__ == "HybridCache":
+        from .cache_helper import make_hybrid_cache
+
+        cc = CacheKeyValue(value)
+        return make_hybrid_cache(
             list(
                 zip(
                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
@@ -843,7 +852,9 @@ def torch_deepcopy(value: Any) -> Any:
         from .cache_helper import CacheKeyValue
 
         ca = CacheKeyValue(value)
-        return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
+        return make_dynamic_cache(
+            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
+        )
     if value.__class__.__name__ == "StaticCache":
         from .cache_helper import CacheKeyValue
 
@@ -858,12 +869,12 @@ def torch_deepcopy(value: Any) -> Any:
             max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
         )
     if value.__class__.__name__ == "HybridCache":
-        from .cache_helper import CacheKeyValue
+        from .cache_helper import CacheKeyValue, make_hybrid_cache
 
         ca = CacheKeyValue(value)
         return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
     if value.__class__.__name__ == "SlidingWindowCache":
-        from .cache_helper import CacheKeyValue
+        from .cache_helper import CacheKeyValue, make_sliding_window_cache
 
         ca = CacheKeyValue(value)
         return make_sliding_window_cache(
@@ -875,6 +886,8 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(value.cross_attention_cache),
         )
     if value.__class__.__name__ == "MambaCache":
+        from .cache_helper import make_mamba_cache
+
         return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
 
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
@@ -1,7 +1,7 @@
 import itertools
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
+from ..helpers.cache_helper import make_dynamic_cache, get_make_hybrid_cache
 from ..helpers.config_helper import (
     update_config,
     check_hasattr,
@@ -200,6 +200,9 @@ def _get_inputs_gemma3(
 
     _check_()
 
+    make_hybrid_cache = get_make_hybrid_cache()
+    assert make_hybrid_cache is not None, "not implemented when make_hybrid_cache is missing"
+
     inputs = dict(
         input_ids=dummies["input_ids"],
         token_type_ids=dummies["token_type_ids"],
@@ -1,11 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-from ..helpers.cache_helper import (
-    make_dynamic_cache,
-    make_mamba_cache,
-    make_sliding_window_cache,
-    make_static_cache,
-)
+from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache, make_static_cache
 from ..helpers.config_helper import (
     update_config,
     check_hasattr,
@@ -187,17 +182,22 @@ def get_inputs(
         if cls_cache is None or isinstance(cls_cache, str)
         else cls_cache.__name__
     )
-    make_caches = {
-        "DynamicCache": make_dynamic_cache,
-        "SlidingWindowCache": make_sliding_window_cache,
-        "StaticCache": make_static_cache,
-    }
-    assert cache_name is None or cache_name in make_caches, (
-        f"Unable to handle cls_cache={cache_name!r}, it should be in "
-        f"{sorted(make_caches)}"
-    )
-    make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
-    is_static = cache_name == "StaticCache"
+    if cache_name == "DynamicSlidingWindowCache":
+        from ..helpers.cache_helper import make_sliding_window_cache
+
+        make_cache = make_sliding_window_cache
+        is_static = False
+    else:
+        make_caches = {
+            "DynamicCache": make_dynamic_cache,
+            "StaticCache": make_static_cache,
+        }
+        assert cache_name is None or cache_name in make_caches, (
+            f"Unable to handle cls_cache={cache_name!r}, it should be in "
+            f"{sorted(make_caches)}"
+        )
+        make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]  # type: ignore[assignment]
+        is_static = cache_name == "StaticCache"
 
     if is_static:
         # static
@@ -521,7 +521,7 @@ def run_exporter(
     :param exporter: exporter
     :param cls_model: model class to create
     :param inputs: list of inputs to try
-    :param dynamic: use dynamic shape or not
+    :param dynamic: use dynamic shapes or not
     :param quiet: raise exception or not
     :param verbose: verbosity
     :return: results
@@ -7,15 +7,9 @@ import transformers
 from transformers.cache_utils import DynamicCache, StaticCache
 
 try:
-    from transformers.cache_utils import (
-        EncoderDecoderCache,
-        HybridCache,
-        SlidingWindowCache,
-    )
+    from transformers.cache_utils import EncoderDecoderCache
 except ImportError:
     EncoderDecoderCache = None
-    HybridCache = None
-    SlidingWindowCache = None
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
@@ -36,6 +30,24 @@ def get_mamba_cache_cls() -> type:
         return None
 
 
+def get_hybrid_cache_cls() -> type:
+    try:
+        from transformers.cache_utils import HybridCache
+
+        return HybridCache
+    except ImportError:
+        return None
+
+
+def get_sliding_window_cache_cls() -> type:
+    try:
+        from transformers.cache_utils import SlidingWindowCache
+
+        return SlidingWindowCache
+    except ImportError:
+        return None
+
+
 def register_class_serialization(
     cls,
     f_flatten: Callable,
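
get_hybrid_cache_cls and get_sliding_window_cache_cls follow the same pattern as the existing get_mamba_cache_cls: the optional transformers classes are resolved lazily, so the module-level try/except imports of HybridCache and SlidingWindowCache are no longer needed. The pattern, generalized as a hypothetical helper for illustration only:

def get_optional_cache_cls(name: str):
    # hypothetical generalization of the helpers above, not part of the package
    try:
        import transformers.cache_utils as cache_utils
    except ImportError:
        return None
    return getattr(cache_utils, name, None)

HybridCache = get_optional_cache_cls("HybridCache")
if HybridCache is not None:
    pass  # register flatten/unflatten/flatten_with_keys for HybridCache here
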
@@ -179,18 +191,9 @@ def serialization_functions(
         flatten_dynamic_cache,
         unflatten_dynamic_cache,
         flatten_with_keys_dynamic_cache,
-        flatten_hybrid_cache,
-        unflatten_hybrid_cache,
-        flatten_with_keys_hybrid_cache,
-        flatten_mamba_cache,
-        unflatten_mamba_cache,
-        flatten_with_keys_mamba_cache,
         flatten_encoder_decoder_cache,
         unflatten_encoder_decoder_cache,
         flatten_with_keys_encoder_decoder_cache,
-        flatten_sliding_window_cache,
-        unflatten_sliding_window_cache,
-        flatten_with_keys_sliding_window_cache,
         flatten_static_cache,
         unflatten_static_cache,
         flatten_with_keys_static_cache,
@@ -208,14 +211,6 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
-        HybridCache: lambda verbose=verbose: register_class_serialization(
-            HybridCache,
-            flatten_hybrid_cache,
-            unflatten_hybrid_cache,
-            flatten_with_keys_hybrid_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        ),
         EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
             EncoderDecoderCache,
             flatten_encoder_decoder_cache,
@@ -223,13 +218,6 @@ def serialization_functions(
             flatten_with_keys_encoder_decoder_cache,
             verbose=verbose,
         ),
-        SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
-            SlidingWindowCache,
-            flatten_sliding_window_cache,
-            unflatten_sliding_window_cache,
-            flatten_with_keys_sliding_window_cache,
-            verbose=verbose,
-        ),
         StaticCache: lambda verbose=verbose: register_class_serialization(
             StaticCache,
             flatten_static_cache,
@@ -240,6 +228,12 @@ def serialization_functions(
     }
     MambaCache = get_mamba_cache_cls()
     if MambaCache:
+        from .serialization.transformers_impl import (
+            flatten_mamba_cache,
+            unflatten_mamba_cache,
+            flatten_with_keys_mamba_cache,
+        )
+
         transformers_classes[MambaCache] = (
             lambda verbose=verbose: register_class_serialization(
                 MambaCache,
@@ -249,6 +243,42 @@ def serialization_functions(
                 verbose=verbose,
             )
         )
+    HybridCache = get_hybrid_cache_cls()
+    if HybridCache:
+        from .serialization.transformers_impl import (
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+        )
+
+        transformers_classes[HybridCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                HybridCache,
+                flatten_hybrid_cache,
+                unflatten_hybrid_cache,
+                flatten_with_keys_hybrid_cache,
+                verbose=verbose,
+            )
+        )
+
+    SlidingWindowCache = get_sliding_window_cache_cls()
+    if SlidingWindowCache:
+        from .serialization.transformers_impl import (
+            flatten_sliding_window_cache,
+            unflatten_sliding_window_cache,
+            flatten_with_keys_sliding_window_cache,
+        )
+
+        transformers_classes[SlidingWindowCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                SlidingWindowCache,
+                flatten_sliding_window_cache,
+                unflatten_sliding_window_cache,
+                flatten_with_keys_sliding_window_cache,
+                verbose=verbose,
+            )
+        )
+
     classes.update(transformers_classes)
 
     if patch_diffusers:
@@ -303,13 +333,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
 
 
 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
-    """Undo all registrations."""
-    MambaCache = get_mamba_cache_cls()
-    cls_ensemble = (
-        {DynamicCache, EncoderDecoderCache}
-        | set(undo)
-        | ({MambaCache} if MambaCache else set())
-    )
+    cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
     for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
             unregister_class_serialization(cls, verbose)
@@ -541,14 +541,17 @@ class patched_ShapeEnv:
             # oblivious_var_to_val will be defined iff we have sizes
            # with DimDynamic.OBLIVIOUS_SIZE type.
            # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+            var_to_val = getattr(
+                self,
+                "unbacked_var_to_val",
+                getattr(self, "oblivious_var_to_val", False),
+            )
             if (
-                self.oblivious_var_to_val
-                and not (
-                    correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
-                ).free_symbols
+                var_to_val
+                and not (correct_hint := orig_expr.xreplace(var_to_val)).free_symbols
                 and not (
                     counterfactual_hint := orig_expr.xreplace(
-                        {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
+                        {k: max(2, v) for k, v in var_to_val.items()}
                     )
                 ).free_symbols
                 and correct_hint == counterfactual_hint
@@ -571,11 +574,11 @@ class patched_ShapeEnv:
             # and if they pass we add a runtime assertions and continue.
             if (
                 not ok
-                and self.unbacked_var_to_val
+                and var_to_val
                 and not (
-                    unsound_result := orig_expr.xreplace(
-                        self.unbacked_var_to_val
-                    ).xreplace(self.var_to_val)
+                    unsound_result := orig_expr.xreplace(var_to_val).xreplace(
+                        var_to_val
+                    )
                 ).free_symbols
             ):
                 # pyrefly: ignore # unbound-name