onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,46 @@
1
+ import re
2
+ from typing import Any, Callable, List, Set, Tuple
3
+ import torch
4
+
5
+
6
def _lower_name_with_(name):
    """Convert a CamelCase identifier into its snake_case equivalent."""
    camel_boundary = re.compile("(.)([A-Z][a-z]+)")
    acronym_boundary = re.compile("([a-z0-9])([A-Z])")
    spaced = camel_boundary.sub(r"\1_\2", name)
    return acronym_boundary.sub(r"\1_\2", spaced).lower()
9
+
10
+
11
+ def make_serialization_function_for_dataclass(
12
+ cls: type, supported_classes: Set[type]
13
+ ) -> Tuple[Callable, Callable, Callable]:
14
+ """
15
+ Automatically creates serialization function for a class decorated with
16
+ ``dataclasses.dataclass``.
17
+ """
18
+
19
+ def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]: # type: ignore[valid-type]
20
+ """Serializes a ``%s`` with python objects."""
21
+ return list(obj.values()), list(obj.keys())
22
+
23
+ def flatten_with_keys_cls(
24
+ obj: cls, # type: ignore[valid-type]
25
+ ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
26
+ """Serializes a ``%s`` with python objects with keys."""
27
+ values, context = list(obj.values()), list(obj.keys())
28
+ return [
29
+ (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
30
+ ], context
31
+
32
+ def unflatten_cls(
33
+ values: List[Any], context: torch.utils._pytree.Context, output_type=None
34
+ ) -> cls: # type: ignore[valid-type]
35
+ """Restores an instance of ``%s`` from python objects."""
36
+ return cls(**dict(zip(context, values)))
37
+
38
+ name = _lower_name_with_(cls.__name__)
39
+ flatten_cls.__name__ = f"flatten_{name}"
40
+ flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
41
+ unflatten_cls.__name__ = f"unflatten_{name}"
42
+ flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__
43
+ flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__
44
+ unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__
45
+ supported_classes.add(cls)
46
+ return flatten_cls, flatten_with_keys_cls, unflatten_cls
@@ -0,0 +1,34 @@
1
+ from typing import Dict, Optional, Set
2
+
3
+ try:
4
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
5
+ except ImportError as e:
6
+ try:
7
+ import diffusers
8
+ except ImportError:
9
+ diffusers = None
10
+ UNet2DConditionOutput = None
11
+ if diffusers:
12
+ raise e
13
+
14
+ from . import make_serialization_function_for_dataclass
15
+
16
+
17
def _make_wrong_registrations() -> Dict[type, Optional[str]]:
    """Build the registration table, skipping classes whose import failed
    (those names are ``None`` when diffusers is unavailable)."""
    candidates = [UNet2DConditionOutput]
    return {candidate: None for candidate in candidates if candidate is not None}
23
+
24
+
25
# Classes whose serialization helpers were generated by this module; filled
# as a side effect of make_serialization_function_for_dataclass.
SUPPORTED_DATACLASSES: Set[type] = set()
# Classes mapped to an optional version string; empty when diffusers is
# missing.  NOTE(review): consumers appear to treat these as registrations
# to override -- confirm at the registration site.
WRONG_REGISTRATIONS = _make_wrong_registrations()


# Generate the flatten/unflatten helpers only when diffusers provides the class.
if UNet2DConditionOutput is not None:
    (
        flatten_u_net2_d_condition_output,
        flatten_with_keys_u_net2_d_condition_output,
        unflatten_u_net2_d_condition_output,
    ) = make_serialization_function_for_dataclass(UNet2DConditionOutput, SUPPORTED_DATACLASSES)
@@ -0,0 +1,313 @@
1
+ import itertools
2
+ from typing import Any, Callable, List, Set, Tuple
3
+ import torch
4
+ from transformers.cache_utils import (
5
+ Cache,
6
+ DynamicCache,
7
+ EncoderDecoderCache,
8
+ HybridCache,
9
+ StaticCache,
10
+ )
11
+
12
+ try:
13
+ from transformers.cache_utils import SlidingWindowCache
14
+ except ImportError:
15
+ SlidingWindowCache = None
16
+
17
+
18
+ try:
19
+ from transformers.models.mamba.modeling_mamba import MambaCache
20
+ except ImportError:
21
+ from transformers.cache_utils import MambaCache
22
+ from transformers.modeling_outputs import BaseModelOutput
23
+ from ...helpers.cache_helper import (
24
+ make_dynamic_cache,
25
+ make_hybrid_cache,
26
+ make_sliding_window_cache,
27
+ make_static_cache,
28
+ CacheKeyValue,
29
+ )
30
+ from . import make_serialization_function_for_dataclass
31
+
32
+
33
# Classes whose serialization helpers are generated in this module; filled as
# a side effect of make_serialization_function_for_dataclass (see the
# BaseModelOutput registration at the bottom of the file).
SUPPORTED_DATACLASSES: Set[type] = set()
# Classes mapped to the transformers version involved (None when no specific
# version applies).  NOTE(review): the exact semantics of the version string
# are defined by the consumer of this table -- confirm at the registration site.
WRONG_REGISTRATIONS = {
    DynamicCache: "4.50",
    BaseModelOutput: None,
}
38
+
39
+
40
def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
    """Flatten a key/value cache into interleaved ``[key_0, value_0, key_1,
    value_1, ...]`` tensors plus the matching context names."""
    kv = CacheKeyValue(cache)
    flat: List[Any] = []
    for key, value in zip(kv.key_cache, kv.value_cache):
        flat.extend((key, value))
    # Names are driven by key_cache alone, exactly as the tensor pairing is.
    names: List[str] = []
    for layer in range(len(kv.key_cache)):
        names.extend((f"key_{layer}", f"value_{layer}"))
    return flat, names
49
+
50
+
51
def _flatten_with_keys_cache(
    cache: Cache,
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
    """Like :func:`_flatten_key_value_cache`, but pairs every tensor with a
    pytree ``MappingKey`` built from its context name."""
    flat, names = _flatten_key_value_cache(cache)
    keyed = [(torch.utils._pytree.MappingKey(name), tensor) for name, tensor in zip(names, flat)]
    return keyed, names
56
+
57
+
58
def _unflatten_cache(
    make_cache: Callable,
    values: List[Any],
    context: torch.utils._pytree.Context,
    output_type=None,
) -> DynamicCache:
    """Rebuild a cache from python objects.

    ``values`` interleaves tensors as ``[key_0, value_0, key_1, value_1, ...]``;
    ``make_cache`` receives them re-paired as ``[(key_i, value_i), ...]``.
    When ``output_type`` is given, the rebuilt cache must be an instance of it.
    """
    key_value_pairs = list(zip(values[::2], values[1::2]))
    res = make_cache(key_value_pairs)
    assert output_type is None or isinstance(
        res, output_type
    ), f"Type mismatch between {output_type} (expected) and {type(res)}"
    return res
70
+
71
+
72
+ ##############
73
+ # DynamicCache
74
+ ##############
75
+
76
+
77
def flatten_dynamic_cache(
    dynamic_cache: DynamicCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
    # Shared helper yields the interleaved [key_0, value_0, ...] layout.
    flat, names = _flatten_key_value_cache(dynamic_cache)
    return flat, names


def flatten_with_keys_dynamic_cache(
    dynamic_cache: DynamicCache,
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python
    objects, pairing each tensor with its pytree key."""
    keyed, names = _flatten_with_keys_cache(dynamic_cache)
    return keyed, names


def unflatten_dynamic_cache(
    values: List[Any], context: torch.utils._pytree.Context, output_type=None
) -> DynamicCache:
    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
    return _unflatten_cache(make_dynamic_cache, values, context, output_type=output_type)
96
+
97
+
98
+ #############
99
+ # HybridCache
100
+ #############
101
+
102
+
103
def flatten_hybrid_cache(
    cache: HybridCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
    """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
    # Same interleaved [key_i, value_i] layout as the other cache flatteners.
    return _flatten_key_value_cache(cache)


def flatten_with_keys_hybrid_cache(
    cache: HybridCache,
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
    """Serializes a :class:`transformers.cache_utils.HybridCache` with python
    objects, pairing each tensor with its pytree key."""
    keyed, names = _flatten_with_keys_cache(cache)
    return keyed, names


def unflatten_hybrid_cache(
    values: List[Any], context: torch.utils._pytree.Context, output_type=None
) -> HybridCache:
    """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
    return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)
122
+
123
+
124
+ #############
125
+ # StaticCache
126
+ #############
127
+
128
+
129
def flatten_static_cache(
    cache: StaticCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects.

    The flattened form keeps no record of ``max_cache_len``: round-tripping is
    only possible when it equals the key tensors' sequence axis (dim 2), which
    is what the assertion below enforces.
    """
    ca = CacheKeyValue(cache)
    # Error message typo fixed ("doet" -> "does").
    assert not ca.key_cache or cache.max_cache_len == ca.key_cache[0].shape[2], (
        f"Serialization does not work when "
        f"cache.max_cache_len={cache.max_cache_len} != "
        f"cache.key_cache[0].shape[2]={ca.key_cache[0].shape[2]}"
    )
    return _flatten_key_value_cache(cache)


def flatten_with_keys_static_cache(
    cache: StaticCache,
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
    """Serializes a :class:`transformers.cache_utils.StaticCache` with python
    objects, pairing each tensor with its pytree key."""
    return _flatten_with_keys_cache(cache)


def unflatten_static_cache(
    values: List[Any], context: torch.utils._pytree.Context, output_type=None
) -> StaticCache:
    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
    return _unflatten_cache(
        # max_cache_len is recovered from the first key tensor's sequence axis;
        # the assert in flatten_static_cache guarantees the two are equal.
        lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
        values,
        context,
        output_type=output_type,
    )
159
+
160
+
161
+ ####################
162
+ # SlidingWindowCache
163
+ ####################
164
+
165
+
166
# SlidingWindowCache is None when the installed transformers does not provide
# it (see the guarded import at the top); skip the helpers entirely then.
if SlidingWindowCache:

    def flatten_sliding_window_cache(
        cache: SlidingWindowCache,
    ) -> Tuple[List[Any], torch.utils._pytree.Context]:
        """
        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
        with python objects.
        """
        # Delegates to the shared [key_0, value_0, key_1, ...] flattener.
        return _flatten_key_value_cache(cache)

    def flatten_with_keys_sliding_window_cache(
        cache: SlidingWindowCache,
    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
        """
        Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
        with python objects, pairing each tensor with its pytree key.
        """
        return _flatten_with_keys_cache(cache)

    def unflatten_sliding_window_cache(
        values: List[Any], context: torch.utils._pytree.Context, output_type=None
    ) -> SlidingWindowCache:
        """
        Restores a :class:`transformers.cache_utils.SlidingWindowCache`
        from python objects.
        """
        return _unflatten_cache(
            make_sliding_window_cache, values, context, output_type=output_type
        )
196
+
197
+
198
+ #####################
199
+ # EncoderDecoderCache
200
+ #####################
201
+
202
+
203
def flatten_encoder_decoder_cache(
    ec_cache: EncoderDecoderCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
    """
    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
    with python objects.
    """
    # Represent the two sub-caches as a plain dict and reuse the pytree dict
    # flattener so ordering agrees with torch.utils._pytree._dict_unflatten.
    as_dict = {
        "self_attention_cache": ec_cache.self_attention_cache,
        "cross_attention_cache": ec_cache.cross_attention_cache,
    }
    return torch.utils._pytree._dict_flatten(as_dict)
215
+
216
+
217
+ def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
218
+ List[Tuple[torch.utils._pytree.KeyEntry, Any]],
219
+ torch.utils._pytree.Context,
220
+ ]:
221
+ """
222
+ Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
223
+ with python objects.
224
+ """
225
+ dictionary = {
226
+ "self_attention_cache": ec_cache.self_attention_cache,
227
+ "cross_attention_cache": ec_cache.cross_attention_cache,
228
+ }
229
+ return torch.utils._pytree._dict_flatten_with_keys(dictionary)
230
+
231
+
232
def unflatten_encoder_decoder_cache(
    values: List[Any], context: torch.utils._pytree.Context, output_type=None
) -> EncoderDecoderCache:
    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
    # output_type is accepted for API symmetry with the other unflatteners but unused.
    as_dict = torch.utils._pytree._dict_unflatten(values, context)
    return EncoderDecoderCache(
        as_dict["self_attention_cache"], as_dict["cross_attention_cache"]
    )
240
+
241
+
242
+ ############
243
+ # MambaCache
244
+ ############
245
+
246
+
247
def flatten_mamba_cache(
    mamba_cache: MambaCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
    conv_states = mamba_cache.conv_states
    ssm_states = mamba_cache.ssm_states
    # The unflattening logic below assumes per-layer lists, not stacked tensors.
    assert isinstance(conv_states, list) and isinstance(ssm_states, list), (
        f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, "
        f"{type(mamba_cache.ssm_states)}"
    )
    return [conv_states, ssm_states], ["conv_states", "ssm_states"]
262
+
263
+
264
def unflatten_mamba_cache(
    values: List[Any], context: torch.utils._pytree.Context, output_type=None
) -> MambaCache:
    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.

    ``values`` is ``[conv_states, ssm_states]`` as produced by
    :func:`flatten_mamba_cache`.  ``output_type`` is accepted for API symmetry
    with the other unflatteners but unused here.
    """
    conv_states, ssm_states = values

    # MambaCache's constructor derives its buffer shapes from a config object;
    # this throwaway class rebuilds only the attributes it reads, inferring
    # them from the flattened tensors themselves.
    class _config:
        def __init__(self):
            if isinstance(conv_states, list):
                # One tensor per layer.  NOTE(review): axis meanings are
                # inferred from the indices used here -- confirm against
                # transformers' MambaCache buffer shapes.
                self.intermediate_size = conv_states[0].shape[1]
                self.state_size = ssm_states[0].shape[2]
                self.conv_kernel = conv_states[0].shape[2]
                self.num_hidden_layers = len(conv_states)
            else:
                # Single stacked tensor with a leading layer axis.
                self.intermediate_size = conv_states.shape[2]
                self.state_size = ssm_states.shape[3]
                self.conv_kernel = conv_states.shape[3]
                self.num_hidden_layers = conv_states.shape[0]

    # get_device() < 0 marks a CPU tensor; anything else is assumed CUDA.
    cache = MambaCache(
        _config(),
        max_batch_size=1,
        dtype=values[-1][0].dtype,
        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
    )
    # Overwrite the freshly allocated buffers with the deserialized tensors.
    values = dict(zip(context, values))
    for k, v in values.items():
        setattr(cache, k, v)
    return cache
293
+
294
+
295
def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
    torch.utils._pytree.Context,
]:
    """Serializes a :class:`transformers.cache_utils.MambaCache` with python
    objects, pairing each state list with its pytree key."""
    values, context = flatten_mamba_cache(cache)
    keyed = []
    for name, state in zip(context, values):
        keyed.append((torch.utils._pytree.MappingKey(name), state))
    return keyed, context
302
+
303
+
304
+ #############
305
+ # dataclasses
306
+ #############
307
+
308
+
309
# Register BaseModelOutput through the generic dataclass helper; as a side
# effect this also records the class in SUPPORTED_DATACLASSES.
(
    flatten_base_model_output,
    flatten_with_keys_base_model_output,
    unflatten_base_model_output,
) = make_serialization_function_for_dataclass(BaseModelOutput, SUPPORTED_DATACLASSES)
File without changes