onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +2 -1
- onnx_diagnostic/export/shape_helper.py +47 -70
- onnx_diagnostic/ext_test_case.py +11 -0
- onnx_diagnostic/helpers/cache_helper.py +38 -7
- onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
- onnx_diagnostic/helpers/helper.py +27 -33
- onnx_diagnostic/helpers/log_helper.py +109 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
- onnx_diagnostic/helpers/model_builder_helper.py +132 -2
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +4 -0
- onnx_diagnostic/helpers/rt_helper.py +393 -43
- onnx_diagnostic/helpers/torch_helper.py +20 -1
- onnx_diagnostic/tasks/__init__.py +7 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
- onnx_diagnostic/tasks/feature_extraction.py +2 -8
- onnx_diagnostic/tasks/image_text_to_text.py +10 -8
- onnx_diagnostic/tasks/summarization.py +2 -8
- onnx_diagnostic/tasks/text2text_generation.py +3 -8
- onnx_diagnostic/tasks/text_generation.py +86 -65
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
- onnx_diagnostic/torch_models/validate.py +626 -228
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
@@ -39,22 +39,57 @@ try:
 except ImportError:
     patch_DynamicLayer = False

-from ...ext_test_case import has_transformers
-from ...helpers.torch_helper import is_torchdynamo_exporting

-
+def _has_transformers(version: str) -> bool:
+    return pv.Version(transformers.__version__) >= pv.Version(version)
+
+
+def _is_torchdynamo_exporting() -> bool:
+    """
+    Tells if :epkg:`torch` is exporting a model.
+    Relies on ``torch.compiler.is_exporting()``.
+    """
+    import torch
+
+    if not hasattr(torch.compiler, "is_exporting"):
+        # torch.compiler.is_exporting requires torch>=2.7
+        return False
+
+    try:
+        return torch.compiler.is_exporting()
+    except Exception:
+        try:
+            import torch._dynamo as dynamo
+
+            return dynamo.is_exporting()  # type: ignore
+        except Exception:
+            return False
+
+
+patch_sdpa_is_causal = _has_transformers("4.99")
+patch_is_initialized = _has_transformers("4.56.99")


 if patch_masking_utils:
     # Introduced in 4.52
     from transformers.masking_utils import (
+        _ignore_causal_mask_sdpa,
+        and_masks,
         causal_mask_function,
         padding_mask_function,
-        and_masks,
-        _ignore_causal_mask_sdpa,
         prepare_padding_mask,
     )

+    try:
+        # transformers>=5.0
+        from transformers.masking_utils import (
+            _ignore_bidirectional_mask_sdpa,
+            bidirectional_mask_function,
+        )
+    except ImportError:
+        _ignore_bidirectional_mask_sdpa = None
+        bidirectional_mask_function = None
+
     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
         from ...helpers import string_type
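Note on the hunk above: the patch file now carries its own local copies of the version test and the export-detection test instead of importing them from onnx_diagnostic.ext_test_case and onnx_diagnostic.helpers.torch_helper. A minimal sketch of the same pattern outside the package (the names below are illustrative, not part of onnx-diagnostic):

import packaging.version as pv
import torch
import transformers


def has_transformers_at_least(version: str) -> bool:
    # gate a patch on the installed transformers version, as _has_transformers does above
    return pv.Version(transformers.__version__) >= pv.Version(version)


def is_exporting() -> bool:
    # torch.compiler.is_exporting only exists in recent torch versions,
    # so fall back to False when it is missing, as _is_torchdynamo_exporting does above
    if hasattr(torch.compiler, "is_exporting"):
        return torch.compiler.is_exporting()
    return False


ENABLE_NEW_BEHAVIOR = has_transformers_at_least("4.55")  # illustrative threshold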
@@ -98,7 +133,7 @@ if patch_masking_utils:
         # for a, dims in zip(args, udimensions)
         # ]
         max_shape = tuple(args[i].shape[0] for i in indices)
-        # if
+        # if _is_torchdynamo_exporting():
         # for a in args:
         #     # The exporter should export with a dimension > 1
         #     # to make sure it is dynamic.
@@ -121,6 +156,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
+        _ = kwargs.pop("allow_is_bidirectional_skip", None)
         # PATCHED: this line called the patched version of sdpa_mask
         mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
@@ -130,6 +166,7 @@ if patch_masking_utils:
             mask_function=mask_function,
             attention_mask=attention_mask,
             allow_is_causal_skip=False,
+            allow_is_bidirectional_skip=False,
             allow_torch_fix=False,
             **kwargs,
         )
@@ -151,6 +188,7 @@ if patch_masking_utils:
         attention_mask: Optional[torch.Tensor] = None,
         local_size: Optional[int] = None,
         allow_is_causal_skip: bool = True,
+        allow_is_bidirectional_skip: bool = False,
         **kwargs,
     ) -> Optional[torch.Tensor]:
         """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
@@ -160,6 +198,29 @@ if patch_masking_utils:
             padding_mask, q_length, kv_length, kv_offset, local_size
         ):
             return None
+        if (
+            allow_is_bidirectional_skip
+            and _ignore_bidirectional_mask_sdpa
+            and _ignore_bidirectional_mask_sdpa(padding_mask)
+        ):
+            return None
+
+        if mask_function is bidirectional_mask_function:
+            if padding_mask is not None:
+                # used for slicing without data-dependent slicing
+                mask_indices = (
+                    torch.arange(kv_length, device=cache_position.device) + kv_offset
+                )
+                return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
+            return torch.ones(
+                batch_size,
+                1,
+                q_length,
+                kv_length,
+                dtype=torch.bool,
+                device=cache_position.device,
+            )
+
         kv_arange = torch.arange(kv_length, device=cache_position.device)
         kv_arange += kv_offset
         if padding_mask is not None:
@@ -275,7 +336,7 @@ class patched_AttentionMaskConverter:
     """

     # This method was fixed in 4.51 at least.
-    _PATCHES_ = ["_make_causal_mask"] if not
+    _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
     _PATCHED_CLASS_ = AttentionMaskConverter

     @staticmethod
@@ -507,7 +568,7 @@ class patched_GenerationMixin:
         The current implementation does not rely on ``self`` and could be
         a class method. It is left as a standard method to be easily rewritten.
         """
-        if
+        if _is_torchdynamo_exporting():
            return self._cache_dependant_input_preparation_exporting(
                input_ids, inputs_embeds, cache_position
            )
@@ -1287,11 +1348,29 @@ def patched_sdpa_attention_forward(
     is_causal: Optional[bool] = None,
     **kwargs,
 ) -> tuple[torch.Tensor, None]:
-    """
+    """
+    manual patch for function
+    ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
+    """
     assert not kwargs.get("output_attentions", False), (
         "`sdpa` attention does not support `output_attentions=True`."
         " Please set your attention to `eager` if you want any of these features."
     )
+    torch._check(
+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+    torch._check(
+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+
     sdpa_kwargs = {}
     if hasattr(module, "num_key_value_groups"):
         if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
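The added torch._check calls above assert that the batch dimensions of query, key, and value are broadcastable before the attention computation runs. torch._check takes its message as a zero-argument callable, which is why the messages are wrapped in lambdas. A small standalone sketch (the function name is illustrative, not from the package):

import torch


def check_mask_matches_key(attention_mask, key):
    # the message is a callable, so it is only built if the check fails
    torch._check(
        attention_mask is None or attention_mask.shape[3] == key.shape[2],
        lambda: "Attention mask shape incompatible with key shape.",
    )


check_mask_matches_key(None, torch.randn(1, 2, 5, 8))
check_mask_matches_key(torch.ones(1, 1, 3, 5, dtype=torch.bool), torch.randn(1, 2, 5, 8))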
@@ -1307,24 +1386,83 @@ def patched_sdpa_attention_forward(
     if attention_mask is not None and attention_mask.ndim == 4:
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]

-    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
-    # PATCHED: remove the test query.shape[2] > 1
-    # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
-    is_causal = attention_mask is None and is_causal
-
     torch._check(
         attention_mask is None or attention_mask.shape[3] == key.shape[2],
-        "Attention mask shape incompatible with key shape.",
+        lambda: "Attention mask shape incompatible with key shape.",
     )
-
-
-
-
-
-
-
-
-
+
+    if patch_sdpa_is_causal:
+        # transformers>=4.55
+        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+
+        # PATCHED: remove the test query.shape[2] > 1
+        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+        # and we split the test to keep the minimum in torch.cond
+        is_causal = attention_mask is None and is_causal
+
+        if not is_causal:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+    else:
+        # transformers<4.55
+        if is_causal is None and attention_mask is not None:
+            is_causal = False
+        if is_causal is not None:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+
+    # To avoid the following errors:
+    # is_causal=query.shape[2] > 1
+    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
+    # is_causal=torch.tensor(query.shape[2] > 1)
+    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
+    attn_output = torch.cond(
+        query.shape[2] > 1,  # distinction between prefill and decoding steps
+        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            dropout_p=dropout,
+            scale=scaling,
+            is_causal=True,
+            **sdpa_kwargs,
+        ).contiguous(),
+        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            dropout_p=dropout,
+            scale=scaling,
+            is_causal=False,
+            **sdpa_kwargs,
+        ).contiguous(),
+        [query, key, value],
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, None
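The rewritten tail of patched_sdpa_attention_forward routes the prefill/decoding distinction through torch.cond because, under torch.export, query.shape[2] > 1 is a SymBool and cannot be passed as the boolean is_causal argument. A minimal standalone sketch of the same pattern, not the package's code, run here in eager mode (during export both branches are kept in the graph):

import torch


def causal_or_not(query, key, value):
    # prefill (several query positions) uses a causal mask, single-token decoding does not
    return torch.cond(
        query.shape[2] > 1,
        lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True),
        lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False),
        [query, key, value],
    )


kv = torch.randn(1, 4, 8, 16)
prefill = causal_or_not(torch.randn(1, 4, 8, 16), kv, kv)  # true branch
decode = causal_or_not(torch.randn(1, 4, 1, 16), kv, kv)   # false branch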
@@ -1,13 +1,20 @@
-
+import itertools
+from typing import Any, Callable, List, Set, Tuple
 import torch
 from transformers.cache_utils import (
+    Cache,
     DynamicCache,
     EncoderDecoderCache,
     HybridCache,
-    SlidingWindowCache,
     StaticCache,
 )

+try:
+    from transformers.cache_utils import SlidingWindowCache
+except ImportError:
+    SlidingWindowCache = None
+
+
 try:
     from transformers.models.mamba.modeling_mamba import MambaCache
 except ImportError:
@@ -30,66 +37,36 @@ WRONG_REGISTRATIONS = {
 }


-
-
-
-
-
-
-
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    assert isinstance(mamba_cache.conv_states, list) and isinstance(
-        mamba_cache.ssm_states, list
-    ), (
-        f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, "
-        f"{type(mamba_cache.ssm_states)}"
+def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    ca = CacheKeyValue(cache)
+    flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
+    keys = list(
+        itertools.chain.from_iterable(
+            (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+        )
     )
-    flat
-        ("conv_states", mamba_cache.conv_states),
-        ("ssm_states", mamba_cache.ssm_states),
-    ]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return flat, keys


-def
-
-) ->
-
-
-
-    class _config:
-        def __init__(self):
-            if isinstance(conv_states, list):
-                self.intermediate_size = conv_states[0].shape[1]
-                self.state_size = ssm_states[0].shape[2]
-                self.conv_kernel = conv_states[0].shape[2]
-                self.num_hidden_layers = len(conv_states)
-            else:
-                self.intermediate_size = conv_states.shape[2]
-                self.state_size = ssm_states.shape[3]
-                self.conv_kernel = conv_states.shape[3]
-                self.num_hidden_layers = conv_states.shape[0]
-
-    cache = MambaCache(
-        _config(),
-        max_batch_size=1,
-        dtype=values[-1][0].dtype,
-        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
-    )
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
+def _flatten_with_keys_cache(
+    cache: Cache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    values, context = _flatten_key_value_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context


-def
-
-
-
-
-
-
+def _unflatten_cache(
+    make_cache: Callable,
+    values: List[Any],
+    context: torch.utils._pytree.Context,
+    output_type=None,
+) -> DynamicCache:
+    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
+    res = make_cache(list(zip(values[::2], values[1::2])))
+    assert output_type is None or isinstance(
+        res, output_type
+    ), f"Type mismatch between {output_type} (expected) and {type(res)}"
+    return res


 ##############
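The refactoring above replaces per-class copies of the flatten logic with the shared _flatten_key_value_cache/_unflatten_cache pair, which lays a cache out as the interleaved list [key_0, value_0, key_1, value_1, ...] plus matching names. A small stand-alone sketch of that layout and of the values[::2]/values[1::2] reconstruction (plain strings stand in for the tensors held by CacheKeyValue):

import itertools

key_cache = ["K0", "K1"]    # stand-ins for per-layer key tensors
value_cache = ["V0", "V1"]  # stand-ins for per-layer value tensors

flat = list(itertools.chain.from_iterable(zip(key_cache, value_cache)))
names = list(
    itertools.chain.from_iterable((f"key_{i}", f"value_{i}") for i in range(len(key_cache)))
)
assert flat == ["K0", "V0", "K1", "V1"]
assert names == ["key_0", "value_0", "key_1", "value_1"]

# reconstruction as in _unflatten_cache: re-pair keys with values
pairs = list(zip(flat[::2], flat[1::2]))
assert pairs == [("K0", "V0"), ("K1", "V1")]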
@@ -101,24 +78,21 @@ def flatten_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-
-    flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return _flatten_key_value_cache(dynamic_cache)


 def flatten_with_keys_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+    return _flatten_with_keys_cache(dynamic_cache)


 def unflatten_dynamic_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    return make_dynamic_cache
+    return _unflatten_cache(make_dynamic_cache, values, context, output_type=output_type)


 #############
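For context, flatten/unflatten pairs such as flatten_dynamic_cache and unflatten_dynamic_cache are the kind of functions that get registered with torch's pytree machinery so that torch.export can see the tensors inside a cache object. A generic, self-contained sketch of that mechanism with a toy container (illustrative only; it is not how the package performs its own registration):

import torch


class TinyCache:
    def __init__(self, key_cache, value_cache):
        self.key_cache = key_cache
        self.value_cache = value_cache


def _flatten(c):
    # interleave keys and values, mirroring _flatten_key_value_cache above
    values = [t for pair in zip(c.key_cache, c.value_cache) for t in pair]
    names = [n for i in range(len(c.key_cache)) for n in (f"key_{i}", f"value_{i}")]
    return values, names


def _unflatten(values, context):
    values = list(values)
    return TinyCache(values[::2], values[1::2])


torch.utils._pytree.register_pytree_node(TinyCache, _flatten, _unflatten)

cache = TinyCache([torch.zeros(1, 2)], [torch.ones(1, 2)])
leaves, spec = torch.utils._pytree.tree_flatten(cache)
restored = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(restored, TinyCache) and len(restored.key_cache) == 1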
@@ -130,24 +104,21 @@ def flatten_hybrid_cache(
     cache: HybridCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-
-    flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return _flatten_key_value_cache(cache)


 def flatten_with_keys_hybrid_cache(
     cache: HybridCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+    return _flatten_with_keys_cache(cache)


 def unflatten_hybrid_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> HybridCache:
     """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
-    return make_hybrid_cache
+    return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)


 #############
@@ -163,26 +134,27 @@ def flatten_static_cache(
     assert not ca.key_cache or cache.max_cache_len == ca.key_cache[0].shape[2], (
         f"Serialization doet not work when "
         f"cache.max_cache_len={cache.max_cache_len} != "
-        f"cache.key_cache[0].shape[2]={ca.
+        f"cache.key_cache[0].shape[2]={ca.key_cache[0].shape[2]}"
     )
-
-    return [f[1] for f in flat], [f[0] for f in flat]
+    return _flatten_key_value_cache(cache)


 def flatten_with_keys_static_cache(
     cache: StaticCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+    return _flatten_with_keys_cache(cache)


 def unflatten_static_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> StaticCache:
     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
-    return
-
+    return _unflatten_cache(
+        lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
+        values,
+        context,
+        output_type=output_type,
     )


@@ -191,34 +163,36 @@ def unflatten_static_cache(
 ####################


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+if SlidingWindowCache:
+
+    def flatten_sliding_window_cache(
+        cache: SlidingWindowCache,
+    ) -> Tuple[List[Any], torch.utils._pytree.Context]:
+        """
+        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+        with python objects.
+        """
+        return _flatten_key_value_cache(cache)
+
+    def flatten_with_keys_sliding_window_cache(
+        cache: SlidingWindowCache,
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """
+        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+        with python objects.
+        """
+        return _flatten_with_keys_cache(cache)
+
+    def unflatten_sliding_window_cache(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> SlidingWindowCache:
+        """
+        Restores a :class:`transformers.cache_utils.SlidingWindowCache`
+        from python objects.
+        """
+        return _unflatten_cache(
+            make_sliding_window_cache, values, context, output_type=output_type
+        )


 #####################
@@ -265,6 +239,68 @@ def unflatten_encoder_decoder_cache(
     )


+############
+# MambaCache
+############
+
+
+def flatten_mamba_cache(
+    mamba_cache: MambaCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    assert isinstance(mamba_cache.conv_states, list) and isinstance(
+        mamba_cache.ssm_states, list
+    ), (
+        f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, "
+        f"{type(mamba_cache.ssm_states)}"
+    )
+    flat = [
+        ("conv_states", mamba_cache.conv_states),
+        ("ssm_states", mamba_cache.ssm_states),
+    ]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def unflatten_mamba_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
+    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
+    conv_states, ssm_states = values
+
+    class _config:
+        def __init__(self):
+            if isinstance(conv_states, list):
+                self.intermediate_size = conv_states[0].shape[1]
+                self.state_size = ssm_states[0].shape[2]
+                self.conv_kernel = conv_states[0].shape[2]
+                self.num_hidden_layers = len(conv_states)
+            else:
+                self.intermediate_size = conv_states.shape[2]
+                self.state_size = ssm_states.shape[3]
+                self.conv_kernel = conv_states.shape[3]
+                self.num_hidden_layers = conv_states.shape[0]
+
+    cache = MambaCache(
+        _config(),
+        max_batch_size=1,
+        dtype=values[-1][0].dtype,
+        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
+    )
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    values, context = flatten_mamba_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
 #############
 # dataclasses
 #############
@@ -84,10 +84,7 @@ def get_phi2(
             0: batch,
             1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
         },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-            [{0: batch, 2: cache_length} for _ in range(n_layers)],
-        ],
+        "past_key_values": [{0: batch, 2: cache_length} for _ in range(n_layers * 2)],
     }
     inputs = dict(
         input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
|