onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +22 -5
  3. onnx_diagnostic/ext_test_case.py +31 -0
  4. onnx_diagnostic/helpers/cache_helper.py +23 -12
  5. onnx_diagnostic/helpers/config_helper.py +16 -1
  6. onnx_diagnostic/helpers/log_helper.py +308 -83
  7. onnx_diagnostic/helpers/rt_helper.py +11 -1
  8. onnx_diagnostic/helpers/torch_helper.py +7 -3
  9. onnx_diagnostic/tasks/__init__.py +2 -0
  10. onnx_diagnostic/tasks/text_generation.py +17 -8
  11. onnx_diagnostic/tasks/text_to_image.py +91 -0
  12. onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
  15. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
  16. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  17. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  18. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  19. onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
  20. onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
  23. onnx_diagnostic/torch_models/validate.py +36 -12
  24. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
  25. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
  26. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
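The bulk of this release is the rewrite of onnx_export_serialization.py shown below: the per-class flatten/unflatten helpers move into the new serialization/ subpackage (transformers_impl.py, diffusers_impl.py), and registration becomes opt-in per package. As a rough sketch of the new call surface, built only from the signatures visible in this diff (the exact usage pattern is an assumption, not documented here):

from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    register_cache_serialization,
    unregister_cache_serialization,
)

# Register pytree flatten/unflatten functions for the supported cache classes
# so torch.export can trace models taking DynamicCache, MambaCache, ... inputs.
undo = register_cache_serialization(patch_transformers=True, patch_diffusers=True, verbose=1)
try:
    pass  # e.g. torch.export.export(model, args, kwargs) on a model with cache inputs
finally:
    # Undo every registration recorded in `undo`.
    unregister_cache_serialization(undo, verbose=1)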
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -1,5 +1,5 @@
 import pprint
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Set
 import packaging.version as pv
 import optree
 import torch
@@ -11,10 +11,9 @@ from transformers.cache_utils import (
     SlidingWindowCache,
     StaticCache,
 )
-from transformers.modeling_outputs import BaseModelOutput
-from ..helpers import string_type
-from ..helpers.cache_helper import make_static_cache
 
+from ..helpers import string_type
+from .serialization import _lower_name_with_
 
 PATCH_OF_PATCHES: Set[Any] = set()
 
@@ -29,7 +28,8 @@ def register_class_serialization(
 ) -> bool:
     """
     Registers a class.
-    It can be undone with :func:`unregister`.
+    It can be undone with
+    :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`.
 
     :param cls: class to register
     :param f_flatten: see ``torch.utils._pytree.register_pytree_node``
@@ -40,10 +40,12 @@ def register_class_serialization(
     :return: registered or not
     """
     if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose and cls is not None:
+            print(f"[register_class_serialization] already registered {cls.__name__}")
         return False
 
     if verbose:
-        print(f"[register_cache_serialization] register {cls}")
+        print(f"[register_class_serialization] ---------- register {cls.__name__}")
     torch.utils._pytree.register_pytree_node(
         cls,
         f_flatten,
@@ -54,8 +56,8 @@
     if pv.Version(torch.__version__) < pv.Version("2.7"):
         if verbose:
             print(
-                f"[register_cache_serialization] "
-                f"register {cls} for torch=={torch.__version__}"
+                f"[register_class_serialization] "
+                f"---------- register {cls.__name__} for torch=={torch.__version__}"
             )
         torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])
 
@@ -72,11 +74,35 @@ def register_class_serialization(
     return True
 
 
-def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
+def register_cache_serialization(
+    patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
+) -> Dict[str, bool]:
     """
-    Registers many classes with :func:`register_class_serialization`.
+    Registers many classes with
+    :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`.
     Returns information needed to undo the registration.
+
+    :param patch_transformers: add serialization function for
+        :epkg:`transformers` package
+    :param patch_diffusers: add serialization function for
+        :epkg:`diffusers` package
+    :param verbosity: verbosity level
+    :return: information to unpatch
     """
+    wrong: Dict[type, Optional[str]] = {}
+    if patch_transformers:
+        from .serialization.transformers_impl import WRONG_REGISTRATIONS
+
+        wrong |= WRONG_REGISTRATIONS
+    if patch_diffusers:
+        from .serialization.diffusers_impl import WRONG_REGISTRATIONS
+
+        wrong |= WRONG_REGISTRATIONS
+
+    registration_functions = serialization_functions(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
+
     # DynamicCache serialization is different in transformers and does not
     # play way with torch.export.export.
     # see test test_export_dynamic_cache_cat with NOBYPASS=1
@@ -85,109 +111,137 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # torch.fx._pytree.register_pytree_flatten_spec(
     # DynamicCache, _flatten_dynamic_cache_for_fx)
     # so we remove it anyway
-    if (
-        DynamicCache in torch.utils._pytree.SUPPORTED_NODES
-        and DynamicCache not in PATCH_OF_PATCHES
-        # and pv.Version(torch.__version__) < pv.Version("2.7")
-        and pv.Version(transformers.__version__) >= pv.Version("4.50")
-    ):
-        if verbose:
-            print(
-                f"[_fix_registration] DynamicCache is unregistered and "
-                f"registered first for transformers=={transformers.__version__}"
-            )
-        unregister(DynamicCache, verbose=verbose)
-        register_class_serialization(
-            DynamicCache,
-            flatten_dynamic_cache,
-            unflatten_dynamic_cache,
-            flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        )
-        if verbose:
-            print("[_fix_registration] DynamicCache done.")
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(DynamicCache)
-
     # BaseModelOutput serialization is incomplete.
     # It does not include dynamic shapes mapping.
-    if (
-        BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES
-        and BaseModelOutput not in PATCH_OF_PATCHES
-    ):
-        if verbose:
-            print(
-                f"[_fix_registration] BaseModelOutput is unregistered and "
-                f"registered first for transformers=={transformers.__version__}"
+    for cls, version in wrong.items():
+        if (
+            cls in torch.utils._pytree.SUPPORTED_NODES
+            and cls not in PATCH_OF_PATCHES
+            # and pv.Version(torch.__version__) < pv.Version("2.7")
+            and (
+                version is None or pv.Version(transformers.__version__) >= pv.Version(version)
             )
-        unregister(BaseModelOutput, verbose=verbose)
-        register_class_serialization(
-            BaseModelOutput,
-            flatten_base_model_output,
-            unflatten_base_model_output,
-            flatten_with_keys_base_model_output,
-            verbose=verbose,
-        )
-        if verbose:
-            print("[_fix_registration] BaseModelOutput done.")
-
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(BaseModelOutput)
-
-    return serialization_functions(verbose=verbose)
+        ):
+            assert cls in registration_functions, (
+                f"{cls} has no registration functions mapped to it, "
+                f"available options are {list(registration_functions)}"
+            )
+            if verbose:
+                print(
+                    f"[_fix_registration] {cls.__name__} is unregistered and "
+                    f"registered first"
+                )
+            unregister_class_serialization(cls, verbose=verbose)
+            registration_functions[cls](verbose=verbose)  # type: ignore[arg-type, call-arg]
+            if verbose:
+                print(f"[_fix_registration] {cls.__name__} done.")
+            # To avoid doing it multiple times.
+            PATCH_OF_PATCHES.add(cls)
+
+    # classes with no registration at all.
+    done = {}
+    for k, v in registration_functions.items():
+        done[k] = v(verbose=verbose)  # type: ignore[arg-type, call-arg]
+    return done
+
+
+def serialization_functions(
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
+) -> Dict[type, Callable[[int], bool]]:
+    """Returns the list of serialization functions."""
 
+    supported_classes: Set[type] = set()
+    classes: Dict[type, Callable[[int], bool]] = {}
+    all_functions: Dict[type, Optional[str]] = {}
 
-def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]:
-    """Returns the list of serialization functions."""
-    return dict(
-        DynamicCache=register_class_serialization(
-            DynamicCache,
+    if patch_transformers:
+        from .serialization.transformers_impl import (
+            __dict__ as dtr,
+            SUPPORTED_DATACLASSES,
             flatten_dynamic_cache,
             unflatten_dynamic_cache,
             flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        ),
-        MambaCache=register_class_serialization(
-            MambaCache,
             flatten_mamba_cache,
             unflatten_mamba_cache,
             flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
-        EncoderDecoderCache=register_class_serialization(
-            EncoderDecoderCache,
             flatten_encoder_decoder_cache,
             unflatten_encoder_decoder_cache,
             flatten_with_keys_encoder_decoder_cache,
-            verbose=verbose,
-        ),
-        BaseModelOutput=register_class_serialization(
-            BaseModelOutput,
-            flatten_base_model_output,
-            unflatten_base_model_output,
-            flatten_with_keys_base_model_output,
-            verbose=verbose,
-        ),
-        SlidingWindowCache=register_class_serialization(
-            SlidingWindowCache,
             flatten_sliding_window_cache,
             unflatten_sliding_window_cache,
             flatten_with_keys_sliding_window_cache,
-            verbose=verbose,
-        ),
-        StaticCache=register_class_serialization(
-            StaticCache,
             flatten_static_cache,
             unflatten_static_cache,
             flatten_with_keys_static_cache,
-            verbose=verbose,
-        ),
-    )
+        )
 
+        all_functions.update(dtr)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+        transformers_classes = {
+            DynamicCache: lambda verbose=verbose: register_class_serialization(
+                DynamicCache,
+                flatten_dynamic_cache,
+                unflatten_dynamic_cache,
+                flatten_with_keys_dynamic_cache,
+                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+                verbose=verbose,
+            ),
+            MambaCache: lambda verbose=verbose: register_class_serialization(
+                MambaCache,
+                flatten_mamba_cache,
+                unflatten_mamba_cache,
+                flatten_with_keys_mamba_cache,
+                verbose=verbose,
+            ),
+            EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
+                EncoderDecoderCache,
+                flatten_encoder_decoder_cache,
+                unflatten_encoder_decoder_cache,
+                flatten_with_keys_encoder_decoder_cache,
+                verbose=verbose,
+            ),
+            SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
+                SlidingWindowCache,
+                flatten_sliding_window_cache,
+                unflatten_sliding_window_cache,
+                flatten_with_keys_sliding_window_cache,
+                verbose=verbose,
+            ),
+            StaticCache: lambda verbose=verbose: register_class_serialization(
+                StaticCache,
+                flatten_static_cache,
+                unflatten_static_cache,
+                flatten_with_keys_static_cache,
+                verbose=verbose,
+            ),
+        }
+        classes.update(transformers_classes)
+
+    if patch_diffusers:
+        from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu
+
+        all_functions.update(dfu)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+    for cls in supported_classes:
+        lname = _lower_name_with_(cls.__name__)
+        assert (
+            f"flatten_{lname}" in all_functions
+        ), f"Unable to find function 'flatten_{lname}' in {list(all_functions)}"
+        classes[cls] = (
+            lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization(  # noqa: E501
+                cls,
+                _al[f"flatten_{_ln}"],
+                _al[f"unflatten_{_ln}"],
+                _al[f"flatten_with_keys_{_ln}"],
+                verbose=verbose,
+            )
+        )
+    return classes
 
-def unregister(cls: type, verbose: int = 0):
+
+def unregister_class_serialization(cls: type, verbose: int = 0):
     """Undo the registration."""
     # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
     if cls in torch.fx._pytree.SUPPORTED_NODES:
@@ -217,264 +271,7 @@ def unregister(cls: type, verbose: int = 0):
 
 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
     """Undo all registrations."""
-    for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]:
+    cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
+    for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
-            unregister(cls, verbose)
-
-
-############
-# MambaCache
-############
-
-
-def flatten_mamba_cache(
-    mamba_cache: MambaCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    flat = [
-        ("conv_states", mamba_cache.conv_states),
-        ("ssm_states", mamba_cache.ssm_states),
-    ]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def unflatten_mamba_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> MambaCache:
-    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
-    conv_states, ssm_states = values
-
-    class _config:
-        def __init__(self):
-            if isinstance(conv_states, list):
-                self.intermediate_size = conv_states[0].shape[1]
-                self.state_size = ssm_states[0].shape[2]
-                self.conv_kernel = conv_states[0].shape[2]
-                self.num_hidden_layers = len(conv_states)
-            else:
-                self.intermediate_size = conv_states.shape[2]
-                self.state_size = ssm_states.shape[3]
-                self.conv_kernel = conv_states.shape[3]
-                self.num_hidden_layers = conv_states.shape[0]
-
-    cache = MambaCache(
-        _config(),
-        max_batch_size=1,
-        dtype=values[-1][0].dtype,
-        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
-    )
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    values, context = flatten_mamba_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-##############
-# DynamicCache
-##############
-
-
-def flatten_dynamic_cache(
-    dynamic_cache: DynamicCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
-        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
-    flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_dynamic_cache(
-    dynamic_cache: DynamicCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
-        return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
-    values, context = flatten_dynamic_cache(dynamic_cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_dynamic_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> DynamicCache:
-    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
-        assert output_type is None, f"output_type={output_type} not supported"
-        return transformers.cache_utils._unflatten_dynamic_cache(values, context)
-
-    cache = transformers.cache_utils.DynamicCache()
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-#############
-# StaticCache
-#############
-
-
-def flatten_static_cache(
-    cache: StaticCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_static_cache(
-    cache: StaticCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-    values, context = flatten_static_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_static_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> StaticCache:
-    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
-    return make_static_cache(list(zip(values[0], values[1])))
-
-
-####################
-# SlidingWindowCache
-####################
-
-
-def flatten_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    values, context = flatten_sliding_window_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_sliding_window_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> SlidingWindowCache:
-    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    key_cache, value_cache = values
-
-    class _config:
-        def __init__(self):
-            self.head_dim = key_cache[0].shape[-1]
-            self.num_attention_heads = key_cache[0].shape[1]
-            self.num_hidden_layers = len(key_cache)
-            self.sliding_window = key_cache[0].shape[2]
-
-    cache = SlidingWindowCache(
-        _config(),
-        max_batch_size=key_cache[0].shape[0],
-        max_cache_len=key_cache[0].shape[2], # sligding window
-        device=key_cache[0].device,
-        dtype=key_cache[0].dtype,
-    )
-
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-#####################
-# EncoderDecoderCache
-#####################
-
-
-def flatten_encoder_decoder_cache(
-    ec_cache: EncoderDecoderCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
-    with python objects.
-    """
-    dictionary = {
-        "self_attention_cache": ec_cache.self_attention_cache,
-        "cross_attention_cache": ec_cache.cross_attention_cache,
-    }
-    return torch.utils._pytree._dict_flatten(dictionary)
-
-
-def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """
-    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
-    with python objects.
-    """
-    dictionary = {
-        "self_attention_cache": ec_cache.self_attention_cache,
-        "cross_attention_cache": ec_cache.cross_attention_cache,
-    }
-    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
-
-
-def unflatten_encoder_decoder_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> EncoderDecoderCache:
-    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
-    dictionary = torch.utils._pytree._dict_unflatten(values, context)
-    return EncoderDecoderCache(**dictionary)
-
-
-#################
-# BaseModelOutput
-#################
-
-
-def flatten_base_model_output(
-    bo: BaseModelOutput,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
-    with python objects.
-    """
-    return list(bo.values()), list(bo.keys())
-
-
-def flatten_with_keys_base_model_output(
-    bo: BaseModelOutput,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
-    with python objects.
-    """
-    values, context = flatten_base_model_output(bo)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_base_model_output(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> BaseModelOutput:
-    """
-    Restores a :class:`transformers.modeling_outputs.BaseModelOutput`
-    from python objects.
-    """
-    return BaseModelOutput(**dict(zip(context, values)))
+            unregister_class_serialization(cls, verbose)
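Note on the lookup convention: the new serialization_functions resolves each supported dataclass to its flatten_*/unflatten_*/flatten_with_keys_* helpers by lower-snake-casing the class name with _lower_name_with_, imported from the new serialization/__init__.py (not shown in this diff). A minimal sketch of what that helper plausibly does, assuming a standard CamelCase-to-snake_case conversion:

import re

def _lower_name_with_(name: str) -> str:
    # Assumed behavior: insert "_" before each inner uppercase letter, then lowercase,
    # so "DynamicCache" -> "dynamic_cache" and the keys looked up in all_functions become
    # "flatten_dynamic_cache", "unflatten_dynamic_cache", "flatten_with_keys_dynamic_cache".
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()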