onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
--- /dev/null
+++ onnx_diagnostic/torch_export_patches/serialization/__init__.py
@@ -0,0 +1,46 @@
+import re
+from typing import Any, Callable, List, Set, Tuple
+import torch
+
+
+def _lower_name_with_(name):
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+def make_serialization_function_for_dataclass(
+    cls: type, supported_classes: Set[type]
+) -> Tuple[Callable, Callable, Callable]:
+    """
+    Automatically creates serialization functions for a class decorated with
+    ``dataclasses.dataclass``.
+    """
+
+    def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:  # type: ignore[valid-type]
+        """Serializes a ``%s`` with python objects."""
+        return list(obj.values()), list(obj.keys())
+
+    def flatten_with_keys_cls(
+        obj: cls,  # type: ignore[valid-type]
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects with keys."""
+        values, context = list(obj.values()), list(obj.keys())
+        return [
+            (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+        ], context
+
+    def unflatten_cls(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> cls:  # type: ignore[valid-type]
+        """Restores an instance of ``%s`` from python objects."""
+        return cls(**dict(zip(context, values)))
+
+    name = _lower_name_with_(cls.__name__)
+    flatten_cls.__name__ = f"flatten_{name}"
+    flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
+    unflatten_cls.__name__ = f"unflatten_{name}"
+    flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__
+    flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__
+    unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__
+    supported_classes.add(cls)
+    return flatten_cls, flatten_with_keys_cls, unflatten_cls
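To make the generated trio concrete, here is a minimal round-trip sketch. The `Point` class is hypothetical (not part of the package); it supplies the dict-like `keys()`/`values()` that the generated `flatten_cls` relies on, as the `ModelOutput` dataclasses in transformers and diffusers do:

```python
import dataclasses

import torch

from onnx_diagnostic.torch_export_patches.serialization import (
    make_serialization_function_for_dataclass,
)


@dataclasses.dataclass
class Point:  # hypothetical example class, not part of onnx-diagnostic
    x: torch.Tensor
    y: torch.Tensor

    # flatten_cls calls obj.values()/obj.keys(), so the class must be dict-like.
    def keys(self):
        return [f.name for f in dataclasses.fields(self)]

    def values(self):
        return [getattr(self, f.name) for f in dataclasses.fields(self)]


supported: set = set()
flat_fn, flat_keys_fn, unflat_fn = make_serialization_function_for_dataclass(
    Point, supported
)

p = Point(torch.zeros(2), torch.ones(2))
values, context = flat_fn(p)           # ([tensor, tensor], ["x", "y"])
restored = unflat_fn(values, context)  # Point(x=..., y=...)
assert isinstance(restored, Point) and Point in supported
```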
--- /dev/null
+++ onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py
@@ -0,0 +1,34 @@
+from typing import Dict, Optional, Set
+
+try:
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+except ImportError as e:
+    try:
+        import diffusers
+    except ImportError:
+        diffusers = None
+    UNet2DConditionOutput = None
+    if diffusers:
+        raise e
+
+from . import make_serialization_function_for_dataclass
+
+
+def _make_wrong_registrations() -> Dict[type, Optional[str]]:
+    res: Dict[type, Optional[str]] = {}
+    for c in [UNet2DConditionOutput]:
+        if c is not None:
+            res[c] = None
+    return res
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = _make_wrong_registrations()
+
+
+if UNet2DConditionOutput is not None:
+    (
+        flatten_u_net2_d_condition_output,
+        flatten_with_keys_u_net2_d_condition_output,
+        unflatten_u_net2_d_condition_output,
+    ) = make_serialization_function_for_dataclass(UNet2DConditionOutput, SUPPORTED_DATACLASSES)
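The generated functions match the signatures `torch.utils._pytree` expects, so wiring `UNet2DConditionOutput` into the pytree registry is a single call. The following is an illustrative sketch only, assuming `diffusers` is installed and the class is not already registered; the package performs its actual registrations elsewhere, presumably in `onnx_export_serialization.py` listed above:

```python
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

from onnx_diagnostic.torch_export_patches.serialization.diffusers_impl import (
    flatten_u_net2_d_condition_output,
    flatten_with_keys_u_net2_d_condition_output,
    unflatten_u_net2_d_condition_output,
)

# register_pytree_node calls unflatten as fn(values, context); the generated
# unflatten takes an extra output_type argument defaulting to None, so the
# signatures are compatible. Keyword names follow recent torch releases.
torch.utils._pytree.register_pytree_node(
    UNet2DConditionOutput,
    flatten_u_net2_d_condition_output,
    unflatten_u_net2_d_condition_output,
    serialized_type_name=f"{UNet2DConditionOutput.__module__}.UNet2DConditionOutput",
    flatten_with_keys_fn=flatten_with_keys_u_net2_d_condition_output,
)
```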
--- /dev/null
+++ onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -0,0 +1,313 @@
+import itertools
+from typing import Any, Callable, List, Set, Tuple
+import torch
+from transformers.cache_utils import (
+    Cache,
+    DynamicCache,
+    EncoderDecoderCache,
+    HybridCache,
+    StaticCache,
+)
+
+try:
+    from transformers.cache_utils import SlidingWindowCache
+except ImportError:
+    SlidingWindowCache = None
+
+
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+from transformers.modeling_outputs import BaseModelOutput
+from ...helpers.cache_helper import (
+    make_dynamic_cache,
+    make_hybrid_cache,
+    make_sliding_window_cache,
+    make_static_cache,
+    CacheKeyValue,
+)
+from . import make_serialization_function_for_dataclass
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = {
+    DynamicCache: "4.50",
+    BaseModelOutput: None,
+}
+
+
+def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    ca = CacheKeyValue(cache)
+    flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
+    keys = list(
+        itertools.chain.from_iterable(
+            (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+        )
+    )
+    return flat, keys
+
+
+def _flatten_with_keys_cache(
+    cache: Cache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    values, context = _flatten_key_value_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def _unflatten_cache(
+    make_cache: Callable,
+    values: List[Any],
+    context: torch.utils._pytree.Context,
+    output_type=None,
+) -> DynamicCache:
+    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
+    res = make_cache(list(zip(values[::2], values[1::2])))
+    assert output_type is None or isinstance(
+        res, output_type
+    ), f"Type mismatch between {output_type} (expected) and {type(res)}"
+    return res
+
+
+##############
+# DynamicCache
+##############
+
+
+def flatten_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    return _flatten_key_value_cache(dynamic_cache)
+
+
+def flatten_with_keys_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    return _flatten_with_keys_cache(dynamic_cache)
+
+
+def unflatten_dynamic_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> DynamicCache:
+    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
+    return _unflatten_cache(make_dynamic_cache, values, context, output_type=output_type)
+
+
+#############
+# HybridCache
+#############
+
+
+def flatten_hybrid_cache(
+    cache: HybridCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+    return _flatten_key_value_cache(cache)
+
+
+def flatten_with_keys_hybrid_cache(
+    cache: HybridCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+    return _flatten_with_keys_cache(cache)
+
+
+def unflatten_hybrid_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> HybridCache:
+    """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
+    return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)
+
+
+#############
+# StaticCache
+#############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    ca = CacheKeyValue(cache)
+    assert not ca.key_cache or cache.max_cache_len == ca.key_cache[0].shape[2], (
+        f"Serialization does not work when "
+        f"cache.max_cache_len={cache.max_cache_len} != "
+        f"cache.key_cache[0].shape[2]={ca.key_cache[0].shape[2]}"
+    )
+    return _flatten_key_value_cache(cache)
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    return _flatten_with_keys_cache(cache)
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return _unflatten_cache(
+        lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
+        values,
+        context,
+        output_type=output_type,
+    )
+
+
+####################
+# SlidingWindowCache
+####################
+
+
+if SlidingWindowCache:
+
+    def flatten_sliding_window_cache(
+        cache: SlidingWindowCache,
+    ) -> Tuple[List[Any], torch.utils._pytree.Context]:
+        """
+        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+        with python objects.
+        """
+        return _flatten_key_value_cache(cache)
+
+    def flatten_with_keys_sliding_window_cache(
+        cache: SlidingWindowCache,
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """
+        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+        with python objects.
+        """
+        return _flatten_with_keys_cache(cache)
+
+    def unflatten_sliding_window_cache(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> SlidingWindowCache:
+        """
+        Restores a :class:`transformers.cache_utils.SlidingWindowCache`
+        from python objects.
+        """
+        return _unflatten_cache(
+            make_sliding_window_cache, values, context, output_type=output_type
+        )
+
+
+#####################
+# EncoderDecoderCache
+#####################
+
+
+def flatten_encoder_decoder_cache(
+    ec_cache: EncoderDecoderCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def unflatten_encoder_decoder_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> EncoderDecoderCache:
+    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    return EncoderDecoderCache(
+        dictionary["self_attention_cache"], dictionary["cross_attention_cache"]
+    )
+
+
+############
+# MambaCache
+############
+
+
+def flatten_mamba_cache(
+    mamba_cache: MambaCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    assert isinstance(mamba_cache.conv_states, list) and isinstance(
+        mamba_cache.ssm_states, list
+    ), (
+        f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, "
+        f"{type(mamba_cache.ssm_states)}"
+    )
+    flat = [
+        ("conv_states", mamba_cache.conv_states),
+        ("ssm_states", mamba_cache.ssm_states),
+    ]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def unflatten_mamba_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
+    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
+    conv_states, ssm_states = values
+
+    class _config:
+        def __init__(self):
+            if isinstance(conv_states, list):
+                self.intermediate_size = conv_states[0].shape[1]
+                self.state_size = ssm_states[0].shape[2]
+                self.conv_kernel = conv_states[0].shape[2]
+                self.num_hidden_layers = len(conv_states)
+            else:
+                self.intermediate_size = conv_states.shape[2]
+                self.state_size = ssm_states.shape[3]
+                self.conv_kernel = conv_states.shape[3]
+                self.num_hidden_layers = conv_states.shape[0]
+
+    cache = MambaCache(
+        _config(),
+        max_batch_size=1,
+        dtype=values[-1][0].dtype,
+        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
+    )
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    values, context = flatten_mamba_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+#############
+# dataclasses
+#############
+
+
+(
+    flatten_base_model_output,
+    flatten_with_keys_base_model_output,
+    unflatten_base_model_output,
+) = make_serialization_function_for_dataclass(BaseModelOutput, SUPPORTED_DATACLASSES)
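The cache serializers above interleave keys and values per layer (`[key_0, value_0, key_1, value_1, ...]`), and `_unflatten_cache` re-pairs them with `zip(values[::2], values[1::2])`. A minimal round-trip sketch, assuming a transformers version compatible with this release:

```python
import torch

from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches.serialization.transformers_impl import (
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
)

# Two layers of (key, value) pairs, each [batch, num_heads, seq_len, head_dim].
layers = [(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)) for _ in range(2)]
cache = make_dynamic_cache(layers)

values, context = flatten_dynamic_cache(cache)
assert context == ["key_0", "value_0", "key_1", "value_1"]  # interleaved layout

restored = unflatten_dynamic_cache(values, context, output_type=type(cache))
assert type(restored) is type(cache)
```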