onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -39,22 +39,57 @@ try:
 except ImportError:
     patch_DynamicLayer = False

-from ...ext_test_case import has_transformers
-from ...helpers.torch_helper import is_torchdynamo_exporting

-patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
+def _has_transformers(version: str) -> bool:
+    return pv.Version(transformers.__version__) >= pv.Version(version)
+
+
+def _is_torchdynamo_exporting() -> bool:
+    """
+    Tells if :epkg:`torch` is exporting a model.
+    Relies on ``torch.compiler.is_exporting()``.
+    """
+    import torch
+
+    if not hasattr(torch.compiler, "is_exporting"):
+        # torch.compiler.is_exporting requires torch>=2.7
+        return False
+
+    try:
+        return torch.compiler.is_exporting()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo
+
+            return dynamo.is_exporting()  # type: ignore
+        except Exception:
+            return False
+
+
+patch_sdpa_is_causal = _has_transformers("4.99")
+patch_is_initialized = _has_transformers("4.56.99")


 if patch_masking_utils:
     # Introduced in 4.52
     from transformers.masking_utils import (
+        _ignore_causal_mask_sdpa,
+        and_masks,
         causal_mask_function,
         padding_mask_function,
-        and_masks,
-        _ignore_causal_mask_sdpa,
         prepare_padding_mask,
     )

+    try:
+        # transformers>=5.0
+        from transformers.masking_utils import (
+            _ignore_bidirectional_mask_sdpa,
+            bidirectional_mask_function,
+        )
+    except ImportError:
+        _ignore_bidirectional_mask_sdpa = None
+        bidirectional_mask_function = None
+
     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
         from ...helpers import string_type
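The hunk above replaces the imports of has_transformers and is_torchdynamo_exporting with local helpers and introduces a patch_sdpa_is_causal flag gated on the installed transformers version. A minimal sketch of what _is_torchdynamo_exporting detects, assuming torch>=2.7 so that torch.compiler.is_exporting() is available (the Branchy module below is illustrative only):

    import torch

    class Branchy(torch.nn.Module):
        def forward(self, x):
            # torch.compiler.is_exporting() is expected to be False in eager mode
            # and True while torch.export traces the module, so the branch taken
            # during tracing is the one recorded in the exported graph.
            if torch.compiler.is_exporting():
                return x + 1
            return x - 1

    print(Branchy()(torch.ones(2)))                       # eager path: x - 1
    ep = torch.export.export(Branchy(), (torch.ones(2),))
    print(ep.module()(torch.ones(2)))                     # traced path: x + 1
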
@@ -98,7 +133,7 @@ if patch_masking_utils:
             #     for a, dims in zip(args, udimensions)
             # ]
             max_shape = tuple(args[i].shape[0] for i in indices)
-            # if is_torchdynamo_exporting():
+            # if _is_torchdynamo_exporting():
             #     for a in args:
             #         # The exporter should export with a dimension > 1
             #         # to make sure it is dynamic.
@@ -121,6 +156,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
+        _ = kwargs.pop("allow_is_bidirectional_skip", None)
         # PATCHED: this line called the patched version of sdpa_mask
         mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
@@ -130,6 +166,7 @@ if patch_masking_utils:
             mask_function=mask_function,
             attention_mask=attention_mask,
             allow_is_causal_skip=False,
+            allow_is_bidirectional_skip=False,
             allow_torch_fix=False,
             **kwargs,
         )
@@ -151,6 +188,7 @@ if patch_masking_utils:
         attention_mask: Optional[torch.Tensor] = None,
         local_size: Optional[int] = None,
         allow_is_causal_skip: bool = True,
+        allow_is_bidirectional_skip: bool = False,
         **kwargs,
     ) -> Optional[torch.Tensor]:
         """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
@@ -160,6 +198,29 @@ if patch_masking_utils:
             padding_mask, q_length, kv_length, kv_offset, local_size
         ):
             return None
+        if (
+            allow_is_bidirectional_skip
+            and _ignore_bidirectional_mask_sdpa
+            and _ignore_bidirectional_mask_sdpa(padding_mask)
+        ):
+            return None
+
+        if mask_function is bidirectional_mask_function:
+            if padding_mask is not None:
+                # used for slicing without data-dependent slicing
+                mask_indices = (
+                    torch.arange(kv_length, device=cache_position.device) + kv_offset
+                )
+                return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
+            return torch.ones(
+                batch_size,
+                1,
+                q_length,
+                kv_length,
+                dtype=torch.bool,
+                device=cache_position.device,
+            )
+
         kv_arange = torch.arange(kv_length, device=cache_position.device)
         kv_arange += kv_offset
         if padding_mask is not None:
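The new bidirectional-mask branch above avoids a data-dependent slice of the padding mask: it indexes the key/value axis with a precomputed arange and broadcasts the result over the query axis. A small self-contained illustration of that indexing pattern (the tensor sizes are arbitrary):

    import torch

    batch_size, q_length, kv_length, kv_offset = 2, 3, 5, 0
    padding_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)

    # index the kv axis with precomputed positions instead of slicing on data,
    # then broadcast over the query axis: (batch, 1, q_length, kv_length)
    mask_indices = torch.arange(kv_length) + kv_offset
    mask = padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
    assert mask.shape == (batch_size, 1, q_length, kv_length)
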
@@ -275,7 +336,7 @@ class patched_AttentionMaskConverter:
     """

     # This method was fixed in 4.51 at least.
-    _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
+    _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
     _PATCHED_CLASS_ = AttentionMaskConverter

     @staticmethod
@@ -507,7 +568,7 @@ class patched_GenerationMixin:
         The current implementation does not rely on ``self`` and could be
         a class method. It is left as a standard method to be easily rewritten.
         """
-        if is_torchdynamo_exporting():
+        if _is_torchdynamo_exporting():
             return self._cache_dependant_input_preparation_exporting(
                 input_ids, inputs_embeds, cache_position
             )
@@ -1287,11 +1348,29 @@ def patched_sdpa_attention_forward(
     is_causal: Optional[bool] = None,
     **kwargs,
 ) -> tuple[torch.Tensor, None]:
-    """[patch:transformers.integrations.sdpa_attention.sdpa_attention_forward]"""
+    """
+    manual patch for function
+    ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
+    """
     assert not kwargs.get("output_attentions", False), (
         "`sdpa` attention does not support `output_attentions=True`."
         " Please set your attention to `eager` if you want any of these features."
     )
+    torch._check(
+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+    torch._check(
+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+
     sdpa_kwargs = {}
     if hasattr(module, "num_key_value_groups"):
         if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
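The added torch._check calls turn the assumptions on the leading (batch) dimension into checks the exporter can track as guards on symbolic shapes; passing a lambda defers building the error message until a check actually fails. A minimal sketch of the same pattern outside the patch (the function name is illustrative):

    import torch

    def check_batch_broadcast(query: torch.Tensor, key: torch.Tensor) -> None:
        # torch._check records the condition as an assertion / shape guard;
        # the lambda only runs if the condition turns out to be False.
        torch._check(
            query.shape[0] == key.shape[0] or query.shape[0] == 1,
            lambda: f"broadcast issue: query {tuple(query.shape)} vs key {tuple(key.shape)}",
        )

    check_batch_broadcast(torch.rand(1, 2, 3, 4), torch.rand(5, 2, 3, 4))  # passes, query batch is 1
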
@@ -1307,24 +1386,83 @@ def patched_sdpa_attention_forward(
     if attention_mask is not None and attention_mask.ndim == 4:
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]

-    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
-    # PATCHED: remove the test query.shape[2] > 1
-    # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
-    is_causal = attention_mask is None and is_causal
-
     torch._check(
         attention_mask is None or attention_mask.shape[3] == key.shape[2],
-        "Attention mask shape incompatible with key shape.",
+        lambda: "Attention mask shape incompatible with key shape.",
     )
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
-        attn_mask=attention_mask,
-        dropout_p=dropout,
-        scale=scaling,
-        is_causal=is_causal,
-        **sdpa_kwargs,
+
+    if patch_sdpa_is_causal:
+        # transformers>=4.55
+        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+
+        # PATCHED: remove the test query.shape[2] > 1
+        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+        # and we split the test to keep the minimum in torch.cond
+        is_causal = attention_mask is None and is_causal
+
+        if not is_causal:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+    else:
+        # transformers<4.55
+        if is_causal is None and attention_mask is not None:
+            is_causal = False
+        if is_causal is not None:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+
+    # To avoid the following errors:
+    # is_causal=query.shape[2] > 1
+    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
+    # is_causal=torch.tensor(query.shape[2] > 1)
+    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
+    attn_output = torch.cond(
+        query.shape[2] > 1,  # distinction between prefill and decoding steps
+        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            dropout_p=dropout,
+            scale=scaling,
+            is_causal=True,
+            **sdpa_kwargs,
+        ).contiguous(),
+        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            dropout_p=dropout,
+            scale=scaling,
+            is_causal=False,
+            **sdpa_kwargs,
+        ).contiguous(),
+        [query, key, value],
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
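The torch.cond call above replaces a data-dependent Python bool: under export, query.shape[2] > 1 is a SymBool, which scaled_dot_product_attention's is_causal argument (a plain bool) cannot accept, so both the prefill and the single-token decode branches are kept in the graph. A self-contained sketch of that pattern, assuming a torch version with torch.cond and torch.export (module and dimension names are illustrative):

    import torch
    from torch.export import Dim, export

    class PrefillOrDecode(torch.nn.Module):
        def forward(self, q, k, v):
            # A plain `if q.shape[2] > 1:` would specialize on the example input;
            # torch.cond records both branches in the exported graph instead.
            return torch.cond(
                q.shape[2] > 1,  # prefill (several query tokens) vs. decode (one token)
                lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, is_causal=True
                ),
                lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, is_causal=False
                ),
                [q, k, v],
            )

    q, k, v = torch.rand(1, 2, 4, 8), torch.rand(1, 2, 6, 8), torch.rand(1, 2, 6, 8)
    q_len, kv_len = Dim("q_len"), Dim("kv_len")
    ep = export(PrefillOrDecode(), (q, k, v), dynamic_shapes=({2: q_len}, {2: kv_len}, {2: kv_len}))
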
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

@@ -1,13 +1,20 @@
-from typing import Any, List, Set, Tuple
+import itertools
+from typing import Any, Callable, List, Set, Tuple
 import torch
 from transformers.cache_utils import (
+    Cache,
     DynamicCache,
     EncoderDecoderCache,
     HybridCache,
-    SlidingWindowCache,
     StaticCache,
 )

+try:
+    from transformers.cache_utils import SlidingWindowCache
+except ImportError:
+    SlidingWindowCache = None
+
+
 try:
     from transformers.models.mamba.modeling_mamba import MambaCache
 except ImportError:
@@ -30,66 +37,36 @@ WRONG_REGISTRATIONS = {
 }


-############
-# MambaCache
-############
-
-
-def flatten_mamba_cache(
-    mamba_cache: MambaCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    assert isinstance(mamba_cache.conv_states, list) and isinstance(
-        mamba_cache.ssm_states, list
-    ), (
-        f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, "
-        f"{type(mamba_cache.ssm_states)}"
+def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    ca = CacheKeyValue(cache)
+    flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
+    keys = list(
+        itertools.chain.from_iterable(
+            (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+        )
     )
-    flat = [
-        ("conv_states", mamba_cache.conv_states),
-        ("ssm_states", mamba_cache.ssm_states),
-    ]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return flat, keys


-def unflatten_mamba_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> MambaCache:
-    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
-    conv_states, ssm_states = values
-
-    class _config:
-        def __init__(self):
-            if isinstance(conv_states, list):
-                self.intermediate_size = conv_states[0].shape[1]
-                self.state_size = ssm_states[0].shape[2]
-                self.conv_kernel = conv_states[0].shape[2]
-                self.num_hidden_layers = len(conv_states)
-            else:
-                self.intermediate_size = conv_states.shape[2]
-                self.state_size = ssm_states.shape[3]
-                self.conv_kernel = conv_states.shape[3]
-                self.num_hidden_layers = conv_states.shape[0]
-
-    cache = MambaCache(
-        _config(),
-        max_batch_size=1,
-        dtype=values[-1][0].dtype,
-        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
-    )
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
+def _flatten_with_keys_cache(
+    cache: Cache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    values, context = _flatten_key_value_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context


-def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    values, context = flatten_mamba_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+def _unflatten_cache(
+    make_cache: Callable,
+    values: List[Any],
+    context: torch.utils._pytree.Context,
+    output_type=None,
+) -> DynamicCache:
+    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
+    res = make_cache(list(zip(values[::2], values[1::2])))
+    assert output_type is None or isinstance(
+        res, output_type
+    ), f"Type mismatch between {output_type} (expected) and {type(res)}"
+    return res


 ##############
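The new helpers flatten a cache into one pytree leaf per key/value tensor, interleaved layer by layer, instead of two parallel lists, and _unflatten_cache rebuilds the per-layer (key, value) pairs with zip(values[::2], values[1::2]). A minimal sketch of that interleaving, with plain strings standing in for the tensors:

    import itertools

    # stand-ins for the per-layer key/value tensors of a 2-layer cache
    key_cache = ["k0", "k1"]
    value_cache = ["v0", "v1"]

    # flatten: interleave keys and values layer by layer
    flat = list(itertools.chain.from_iterable(zip(key_cache, value_cache)))
    keys = list(
        itertools.chain.from_iterable((f"key_{i}", f"value_{i}") for i in range(len(key_cache)))
    )
    assert flat == ["k0", "v0", "k1", "v1"]
    assert keys == ["key_0", "value_0", "key_1", "value_1"]

    # unflatten: pair the leaves back per layer, as _unflatten_cache does
    pairs = list(zip(flat[::2], flat[1::2]))
    assert pairs == [("k0", "v0"), ("k1", "v1")]
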
@@ -101,24 +78,21 @@ def flatten_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    ca = CacheKeyValue(dynamic_cache)
-    flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return _flatten_key_value_cache(dynamic_cache)


 def flatten_with_keys_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    values, context = flatten_dynamic_cache(dynamic_cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+    return _flatten_with_keys_cache(dynamic_cache)


 def unflatten_dynamic_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    return make_dynamic_cache(list(zip(values[0], values[1])))
+    return _unflatten_cache(make_dynamic_cache, values, context, output_type=output_type)


 #############
@@ -130,24 +104,21 @@ def flatten_hybrid_cache(
     cache: HybridCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-    ca = CacheKeyValue(cache)
-    flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return _flatten_key_value_cache(cache)


 def flatten_with_keys_hybrid_cache(
     cache: HybridCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-    values, context = flatten_hybrid_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+    return _flatten_with_keys_cache(cache)


 def unflatten_hybrid_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> HybridCache:
     """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
-    return make_hybrid_cache(list(zip(values[0], values[1])))
+    return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)


 #############
@@ -163,26 +134,27 @@ def flatten_static_cache(
     assert not ca.key_cache or cache.max_cache_len == ca.key_cache[0].shape[2], (
         f"Serialization doet not work when "
         f"cache.max_cache_len={cache.max_cache_len} != "
-        f"cache.key_cache[0].shape[2]={ca.keu_cache[0].shape[2]}"
+        f"cache.key_cache[0].shape[2]={ca.key_cache[0].shape[2]}"
     )
-    flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return _flatten_key_value_cache(cache)


 def flatten_with_keys_static_cache(
     cache: StaticCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-    values, context = flatten_static_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+    return _flatten_with_keys_cache(cache)


 def unflatten_static_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> StaticCache:
     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
-    return make_static_cache(
-        list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2]
+    return _unflatten_cache(
+        lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
+        values,
+        context,
+        output_type=output_type,
     )

@@ -191,34 +163,36 @@ def unflatten_static_cache(
 ####################


-def flatten_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    ca = CacheKeyValue(cache)
-    flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    values, context = flatten_sliding_window_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_sliding_window_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> SlidingWindowCache:
-    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    return make_sliding_window_cache(list(zip(values[0], values[1])))
+if SlidingWindowCache:
+
+    def flatten_sliding_window_cache(
+        cache: SlidingWindowCache,
+    ) -> Tuple[List[Any], torch.utils._pytree.Context]:
+        """
+        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+        with python objects.
+        """
+        return _flatten_key_value_cache(cache)
+
+    def flatten_with_keys_sliding_window_cache(
+        cache: SlidingWindowCache,
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """
+        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+        with python objects.
+        """
+        return _flatten_with_keys_cache(cache)
+
+    def unflatten_sliding_window_cache(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> SlidingWindowCache:
+        """
+        Restores a :class:`transformers.cache_utils.SlidingWindowCache`
+        from python objects.
+        """
+        return _unflatten_cache(
+            make_sliding_window_cache, values, context, output_type=output_type
+        )


 #####################
@@ -265,6 +239,68 @@ def unflatten_encoder_decoder_cache(
     )


+############
+# MambaCache
+############
+
+
+def flatten_mamba_cache(
+    mamba_cache: MambaCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    assert isinstance(mamba_cache.conv_states, list) and isinstance(
+        mamba_cache.ssm_states, list
+    ), (
+        f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, "
+        f"{type(mamba_cache.ssm_states)}"
+    )
+    flat = [
+        ("conv_states", mamba_cache.conv_states),
+        ("ssm_states", mamba_cache.ssm_states),
+    ]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def unflatten_mamba_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
+    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
+    conv_states, ssm_states = values
+
+    class _config:
+        def __init__(self):
+            if isinstance(conv_states, list):
+                self.intermediate_size = conv_states[0].shape[1]
+                self.state_size = ssm_states[0].shape[2]
+                self.conv_kernel = conv_states[0].shape[2]
+                self.num_hidden_layers = len(conv_states)
+            else:
+                self.intermediate_size = conv_states.shape[2]
+                self.state_size = ssm_states.shape[3]
+                self.conv_kernel = conv_states.shape[3]
+                self.num_hidden_layers = conv_states.shape[0]
+
+    cache = MambaCache(
+        _config(),
+        max_batch_size=1,
+        dtype=values[-1][0].dtype,
+        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
+    )
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    values, context = flatten_mamba_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
 #############
 # dataclasses
 #############
onnx_diagnostic/torch_models/untrained/llm_phi2.py

@@ -84,10 +84,7 @@ def get_phi2(
             0: batch,
             1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
         },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-        ],
+        "past_key_values": [{0: batch, 2: cache_length} for _ in range(n_layers * 2)],
     }
     inputs = dict(
         input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
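With the cache now flattened to 2 * n_layers leaves (key_0, value_0, key_1, value_1, ...), the dynamic-shapes specification for past_key_values becomes a flat list with one entry per leaf, which is what the simplified expression above produces. A short sketch of what the spec expands to for a hypothetical 2-layer model, reusing the batch and cache_length names from the hunk:

    import torch

    n_layers = 2
    batch = torch.export.Dim("batch")
    cache_length = torch.export.Dim("cache_length")

    # one spec per flattened cache leaf, in order: key_0, value_0, key_1, value_1
    past_key_values_spec = [{0: batch, 2: cache_length} for _ in range(n_layers * 2)]
    assert len(past_key_values_spec) == 4
    # each entry marks dim 0 (batch) and dim 2 (past sequence length) as dynamic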