onnx-diagnostic 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +213 -5
  3. onnx_diagnostic/export/dynamic_shapes.py +48 -20
  4. onnx_diagnostic/export/shape_helper.py +126 -0
  5. onnx_diagnostic/ext_test_case.py +31 -0
  6. onnx_diagnostic/helpers/cache_helper.py +42 -20
  7. onnx_diagnostic/helpers/config_helper.py +16 -1
  8. onnx_diagnostic/helpers/log_helper.py +1561 -177
  9. onnx_diagnostic/helpers/torch_helper.py +6 -2
  10. onnx_diagnostic/tasks/__init__.py +2 -0
  11. onnx_diagnostic/tasks/image_text_to_text.py +69 -18
  12. onnx_diagnostic/tasks/text_generation.py +17 -8
  13. onnx_diagnostic/tasks/text_to_image.py +91 -0
  14. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  15. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
  16. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
  17. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  18. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  19. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  20. onnx_diagnostic/torch_models/hghub/hub_api.py +73 -5
  21. onnx_diagnostic/torch_models/hghub/hub_data.py +7 -2
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +74 -14
  24. onnx_diagnostic/torch_models/validate.py +45 -16
  25. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +29 -24
  27. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -1,5 +1,5 @@
 import pprint
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Set
 import packaging.version as pv
 import optree
 import torch
@@ -11,10 +11,9 @@ from transformers.cache_utils import (
     SlidingWindowCache,
     StaticCache,
 )
-from transformers.modeling_outputs import BaseModelOutput
-from ..helpers import string_type
-from ..helpers.cache_helper import make_static_cache
 
+from ..helpers import string_type
+from .serialization import _lower_name_with_
 
 PATCH_OF_PATCHES: Set[Any] = set()
 
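Note: `_lower_name_with_`, imported above from the new `serialization` subpackage, is what lets `serialization_functions` further down look up `flatten_*`/`unflatten_*` helpers by naming convention (class name lowered to snake_case). A minimal sketch of such a helper, assuming a plain CamelCase-to-snake_case conversion (the packaged implementation may differ):

    import re

    def _lower_name_with_(name: str) -> str:
        # e.g. "DynamicCache" -> "dynamic_cache", "BaseModelOutput" -> "base_model_output"
        return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
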
@@ -40,10 +39,12 @@ def register_class_serialization(
     :return: registered or not
     """
     if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose and cls is not None:
+            print(f"[register_class_serialization] already registered {cls.__name__}")
         return False
 
     if verbose:
-        print(f"[register_cache_serialization] register {cls}")
+        print(f"[register_class_serialization] ---------- register {cls.__name__}")
     torch.utils._pytree.register_pytree_node(
         cls,
         f_flatten,
@@ -54,8 +55,8 @@ def register_class_serialization(
     if pv.Version(torch.__version__) < pv.Version("2.7"):
         if verbose:
             print(
-                f"[register_cache_serialization] "
-                f"register {cls} for torch=={torch.__version__}"
+                f"[register_class_serialization] "
+                f"---------- register {cls.__name__} for torch=={torch.__version__}"
             )
         torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])
 
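For readers unfamiliar with the mechanism: `register_class_serialization` wraps `torch.utils._pytree.register_pytree_node`, which is what lets `torch.export` see through a container class. A self-contained sketch with a toy class (illustrative only, not part of the package):

    from typing import Any, List, Tuple
    import torch

    class Pair:  # toy container, for illustration only
        def __init__(self, a, b):
            self.a, self.b = a, b

    def flatten_pair(p: Pair) -> Tuple[List[Any], Any]:
        # leaves first, then the context needed to rebuild the object
        return [p.a, p.b], ["a", "b"]

    def unflatten_pair(values: List[Any], context: Any) -> Pair:
        return Pair(*values)

    torch.utils._pytree.register_pytree_node(Pair, flatten_pair, unflatten_pair)

    leaves, spec = torch.utils._pytree.tree_flatten(Pair(torch.ones(2), torch.zeros(2)))
    restored = torch.utils._pytree.tree_unflatten(leaves, spec)
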
@@ -72,11 +73,34 @@ def register_class_serialization(
     return True
 
 
-def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
+def register_cache_serialization(
+    patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
+) -> Dict[str, bool]:
     """
     Registers many classes with :func:`register_class_serialization`.
     Returns information needed to undo the registration.
+
+    :param patch_transformers: add serialization function for
+        :epkg:`transformers` package
+    :param patch_diffusers: add serialization function for
+        :epkg:`diffusers` package
+    :param verbose: verbosity level
+    :return: information to unpatch
     """
+    wrong: Dict[type, Optional[str]] = {}
+    if patch_transformers:
+        from .serialization.transformers_impl import WRONG_REGISTRATIONS
+
+        wrong |= WRONG_REGISTRATIONS
+    if patch_diffusers:
+        from .serialization.diffusers_impl import WRONG_REGISTRATIONS
+
+        wrong |= WRONG_REGISTRATIONS
+
+    registration_functions = serialization_functions(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
+
     # DynamicCache serialization is different in transformers and does not
     # play way with torch.export.export.
     # see test test_export_dynamic_cache_cat with NOBYPASS=1
@@ -85,109 +109,137 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # torch.fx._pytree.register_pytree_flatten_spec(
     #     DynamicCache, _flatten_dynamic_cache_for_fx)
     # so we remove it anyway
-    if (
-        DynamicCache in torch.utils._pytree.SUPPORTED_NODES
-        and DynamicCache not in PATCH_OF_PATCHES
-        # and pv.Version(torch.__version__) < pv.Version("2.7")
-        and pv.Version(transformers.__version__) >= pv.Version("4.50")
-    ):
-        if verbose:
-            print(
-                f"[_fix_registration] DynamicCache is unregistered and "
-                f"registered first for transformers=={transformers.__version__}"
-            )
-        unregister(DynamicCache, verbose=verbose)
-        register_class_serialization(
-            DynamicCache,
-            flatten_dynamic_cache,
-            unflatten_dynamic_cache,
-            flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        )
-        if verbose:
-            print("[_fix_registration] DynamicCache done.")
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(DynamicCache)
-
     # BaseModelOutput serialization is incomplete.
     # It does not include dynamic shapes mapping.
-    if (
-        BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES
-        and BaseModelOutput not in PATCH_OF_PATCHES
-    ):
-        if verbose:
-            print(
-                f"[_fix_registration] BaseModelOutput is unregistered and "
-                f"registered first for transformers=={transformers.__version__}"
+    for cls, version in wrong.items():
+        if (
+            cls in torch.utils._pytree.SUPPORTED_NODES
+            and cls not in PATCH_OF_PATCHES
+            # and pv.Version(torch.__version__) < pv.Version("2.7")
+            and (
+                version is None or pv.Version(transformers.__version__) >= pv.Version(version)
             )
-        unregister(BaseModelOutput, verbose=verbose)
-        register_class_serialization(
-            BaseModelOutput,
-            flatten_base_model_output,
-            unflatten_base_model_output,
-            flatten_with_keys_base_model_output,
-            verbose=verbose,
-        )
-        if verbose:
-            print("[_fix_registration] BaseModelOutput done.")
-
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(BaseModelOutput)
-
-    return serialization_functions(verbose=verbose)
+        ):
+            assert cls in registration_functions, (
+                f"{cls} has no registration functions mapped to it, "
+                f"available options are {list(registration_functions)}"
+            )
+            if verbose:
+                print(
+                    f"[_fix_registration] {cls.__name__} is unregistered and "
+                    f"registered first"
+                )
+            unregister_class_serialization(cls, verbose=verbose)
+            registration_functions[cls](verbose=verbose)  # type: ignore[arg-type, call-arg]
+            if verbose:
+                print(f"[_fix_registration] {cls.__name__} done.")
+            # To avoid doing it multiple times.
+            PATCH_OF_PATCHES.add(cls)
+
+    # classes with no registration at all.
+    done = {}
+    for k, v in registration_functions.items():
+        done[k] = v(verbose=verbose)  # type: ignore[arg-type, call-arg]
+    return done
+
+
+def serialization_functions(
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
+) -> Dict[type, Callable[[int], bool]]:
+    """Returns the list of serialization functions."""
 
+    supported_classes: Set[type] = set()
+    classes: Dict[type, Callable[[int], bool]] = {}
+    all_functions: Dict[type, Optional[str]] = {}
 
-def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]:
-    """Returns the list of serialization functions."""
-    return dict(
-        DynamicCache=register_class_serialization(
-            DynamicCache,
+    if patch_transformers:
+        from .serialization.transformers_impl import (
+            __dict__ as dtr,
+            SUPPORTED_DATACLASSES,
             flatten_dynamic_cache,
             unflatten_dynamic_cache,
             flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        ),
-        MambaCache=register_class_serialization(
-            MambaCache,
             flatten_mamba_cache,
             unflatten_mamba_cache,
             flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
-        EncoderDecoderCache=register_class_serialization(
-            EncoderDecoderCache,
             flatten_encoder_decoder_cache,
             unflatten_encoder_decoder_cache,
             flatten_with_keys_encoder_decoder_cache,
-            verbose=verbose,
-        ),
-        BaseModelOutput=register_class_serialization(
-            BaseModelOutput,
-            flatten_base_model_output,
-            unflatten_base_model_output,
-            flatten_with_keys_base_model_output,
-            verbose=verbose,
-        ),
-        SlidingWindowCache=register_class_serialization(
-            SlidingWindowCache,
             flatten_sliding_window_cache,
             unflatten_sliding_window_cache,
             flatten_with_keys_sliding_window_cache,
-            verbose=verbose,
-        ),
-        StaticCache=register_class_serialization(
-            StaticCache,
             flatten_static_cache,
             unflatten_static_cache,
             flatten_with_keys_static_cache,
-            verbose=verbose,
-        ),
-    )
+        )
 
+        all_functions.update(dtr)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+        transformers_classes = {
+            DynamicCache: lambda verbose=verbose: register_class_serialization(
+                DynamicCache,
+                flatten_dynamic_cache,
+                unflatten_dynamic_cache,
+                flatten_with_keys_dynamic_cache,
+                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+                verbose=verbose,
+            ),
+            MambaCache: lambda verbose=verbose: register_class_serialization(
+                MambaCache,
+                flatten_mamba_cache,
+                unflatten_mamba_cache,
+                flatten_with_keys_mamba_cache,
+                verbose=verbose,
+            ),
+            EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
+                EncoderDecoderCache,
+                flatten_encoder_decoder_cache,
+                unflatten_encoder_decoder_cache,
+                flatten_with_keys_encoder_decoder_cache,
+                verbose=verbose,
+            ),
+            SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
+                SlidingWindowCache,
+                flatten_sliding_window_cache,
+                unflatten_sliding_window_cache,
+                flatten_with_keys_sliding_window_cache,
+                verbose=verbose,
+            ),
+            StaticCache: lambda verbose=verbose: register_class_serialization(
+                StaticCache,
+                flatten_static_cache,
+                unflatten_static_cache,
+                flatten_with_keys_static_cache,
+                verbose=verbose,
+            ),
+        }
+        classes.update(transformers_classes)
+
+    if patch_diffusers:
+        from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu
+
+        all_functions.update(dfu)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+    for cls in supported_classes:
+        lname = _lower_name_with_(cls.__name__)
+        assert (
+            f"flatten_{lname}" in all_functions
+        ), f"Unable to find function 'flatten_{lname}' in {list(all_functions)}"
+        classes[cls] = (
+            lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization(  # noqa: E501
+                cls,
+                _al[f"flatten_{_ln}"],
+                _al[f"unflatten_{_ln}"],
+                _al[f"flatten_with_keys_{_ln}"],
+                verbose=verbose,
+            )
+        )
+    return classes
 
-def unregister(cls: type, verbose: int = 0):
+
+def unregister_class_serialization(cls: type, verbose: int = 0):
     """Undo the registration."""
     # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
     if cls in torch.fx._pytree.SUPPORTED_NODES:
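The net effect of the refactoring above is that registration becomes lazy (a dict of callables keyed by class) and is driven by the `patch_transformers`/`patch_diffusers` flags. A typical round trip might look like this (a sketch; the exact contents of the returned dict follow the code above):

    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        register_cache_serialization,
        unregister_cache_serialization,
    )

    # registers the cache classes with pytree and returns what to undo
    undo = register_cache_serialization(patch_transformers=True, verbose=1)
    try:
        ...  # e.g. torch.export.export on a model taking cache inputs
    finally:
        unregister_cache_serialization(undo, verbose=1)
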
@@ -217,264 +269,7 @@ def unregister(cls: type, verbose: int = 0):
 
 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
     """Undo all registrations."""
-    for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]:
+    cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
+    for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
-            unregister(cls, verbose)
-
-
-############
-# MambaCache
-############
-
-
-def flatten_mamba_cache(
-    mamba_cache: MambaCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    flat = [
-        ("conv_states", mamba_cache.conv_states),
-        ("ssm_states", mamba_cache.ssm_states),
-    ]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def unflatten_mamba_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> MambaCache:
-    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
-    conv_states, ssm_states = values
-
-    class _config:
-        def __init__(self):
-            if isinstance(conv_states, list):
-                self.intermediate_size = conv_states[0].shape[1]
-                self.state_size = ssm_states[0].shape[2]
-                self.conv_kernel = conv_states[0].shape[2]
-                self.num_hidden_layers = len(conv_states)
-            else:
-                self.intermediate_size = conv_states.shape[2]
-                self.state_size = ssm_states.shape[3]
-                self.conv_kernel = conv_states.shape[3]
-                self.num_hidden_layers = conv_states.shape[0]
-
-    cache = MambaCache(
-        _config(),
-        max_batch_size=1,
-        dtype=values[-1][0].dtype,
-        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
-    )
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    values, context = flatten_mamba_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-##############
-# DynamicCache
-##############
-
-
-def flatten_dynamic_cache(
-    dynamic_cache: DynamicCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
-        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
-    flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_dynamic_cache(
-    dynamic_cache: DynamicCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
-        return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
-    values, context = flatten_dynamic_cache(dynamic_cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_dynamic_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> DynamicCache:
-    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
-        assert output_type is None, f"output_type={output_type} not supported"
-        return transformers.cache_utils._unflatten_dynamic_cache(values, context)
-
-    cache = transformers.cache_utils.DynamicCache()
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-##############
-# DynamicCache
-##############
-
-
-def flatten_static_cache(
-    cache: StaticCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_static_cache(
-    cache: StaticCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
-    values, context = flatten_static_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_static_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> StaticCache:
-    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
-    return make_static_cache(list(zip(values[0], values[1])))
-
-
-####################
-# SlidingWindowCache
-####################
-
-
-def flatten_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
-    return [f[1] for f in flat], [f[0] for f in flat]
-
-
-def flatten_with_keys_sliding_window_cache(
-    cache: SlidingWindowCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
-    with python objects.
-    """
-    values, context = flatten_sliding_window_cache(cache)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_sliding_window_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> SlidingWindowCache:
-    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    key_cache, value_cache = values
-
-    class _config:
-        def __init__(self):
-            self.head_dim = key_cache[0].shape[-1]
-            self.num_attention_heads = key_cache[0].shape[1]
-            self.num_hidden_layers = len(key_cache)
-            self.sliding_window = key_cache[0].shape[2]
-
-    cache = SlidingWindowCache(
-        _config(),
-        max_batch_size=key_cache[0].shape[0],
-        max_cache_len=key_cache[0].shape[2],  # sligding window
-        device=key_cache[0].device,
-        dtype=key_cache[0].dtype,
-    )
-
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
-
-
-#####################
-# EncoderDecoderCache
-#####################
-
-
-def flatten_encoder_decoder_cache(
-    ec_cache: EncoderDecoderCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
-    with python objects.
-    """
-    dictionary = {
-        "self_attention_cache": ec_cache.self_attention_cache,
-        "cross_attention_cache": ec_cache.cross_attention_cache,
-    }
-    return torch.utils._pytree._dict_flatten(dictionary)
-
-
-def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
-    """
-    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
-    with python objects.
-    """
-    dictionary = {
-        "self_attention_cache": ec_cache.self_attention_cache,
-        "cross_attention_cache": ec_cache.cross_attention_cache,
-    }
-    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
-
-
-def unflatten_encoder_decoder_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> EncoderDecoderCache:
-    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
-    dictionary = torch.utils._pytree._dict_unflatten(values, context)
-    return EncoderDecoderCache(**dictionary)
-
-
-#################
-# BaseModelOutput
-#################
-
-
-def flatten_base_model_output(
-    bo: BaseModelOutput,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
-    with python objects.
-    """
-    return list(bo.values()), list(bo.keys())
-
-
-def flatten_with_keys_base_model_output(
-    bo: BaseModelOutput,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """
-    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
-    with python objects.
-    """
-    values, context = flatten_base_model_output(bo)
-    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
-
-
-def unflatten_base_model_output(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> BaseModelOutput:
-    """
-    Restores a :class:`transformers.modeling_outputs.BaseModelOutput`
-    from python objects.
-    """
-    return BaseModelOutput(**dict(zip(context, values)))
+            unregister_class_serialization(cls, verbose)
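The flatten/unflatten implementations removed above are not gone: per the file list, they move to `onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py` (+259). Once registered, they make cache objects transparent to pytree utilities; a sketch of the observable effect (assuming transformers is installed and the `update` signature shown here):

    import torch
    from transformers.cache_utils import DynamicCache
    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        register_cache_serialization,
    )

    register_cache_serialization(patch_transformers=True)
    cache = DynamicCache()
    cache.update(torch.rand(1, 2, 3, 4), torch.rand(1, 2, 3, 4), layer_idx=0)

    # the registered functions decompose the cache into plain tensors and back
    leaves, spec = torch.utils._pytree.tree_flatten(cache)
    restored = torch.utils._pytree.tree_unflatten(leaves, spec)
    assert isinstance(restored, DynamicCache)
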
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -2,6 +2,7 @@ import inspect
 from dataclasses import dataclass
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -20,18 +21,41 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
     ]
     if bh_indices:
         dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+    # reshape
     dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
     dimensions = tuple(reversed(dimensions))
     indices = tuple(shape.index(-1) for shape in dimensions)
 
+    # unsqueeze
+    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
+
     def vector_mask_function(
         *args, mask_function=mask_function, dimensions=dimensions, indices=indices
     ):
-        assert len(args) == len(
-            dimensions
-        ), f"Mismatch between args={string_type(args)} and dimensions={dimensions}"
+        assert len(args) == len(dimensions) == len(udimensions), (
+            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+            f"and udimensions={udimensions}."
+        )
+        assert len(indices) == len(args), (
+            f"Mismatch between args={string_type(args)} and indices={indices}, "
+            f"they should have the same length."
+        )
+        for a in args:
+            assert (
+                a.ndim == 1
+            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+            torch._check(a.shape[0] > 0)
+
         new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+        # new_args = [
+        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+        #     for a, dims in zip(args, udimensions)
+        # ]
         max_shape = tuple(args[i].shape[0] for i in indices)
+        # if is_torchdynamo_exporting():
+        #     for a in args:
+        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
+        #         torch._check(a.shape[0] > 1)
         expanded_args = [a.expand(max_shape) for a in new_args]
         return mask_function(*expanded_args)
 
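The patched `_vmap_for_bhqkv` above avoids `torch.vmap` by reshaping each 1-D index tensor so it owns one broadcast axis, expanding all of them to the joint shape, and calling the mask function once on whole tensors. The same trick in isolation, on a 2-D causal mask (illustrative, outside the patch, which uses four axes for batch/head/q/kv):

    import torch

    # scalar rule: query position q may attend to key positions kv <= q
    def causal(q_idx, kv_idx):
        return q_idx >= kv_idx

    q, kv = torch.arange(4), torch.arange(6)
    mask = causal(
        q.reshape(-1, 1).expand(4, 6),   # queries vary along axis 0
        kv.reshape(1, -1).expand(4, 6),  # keys vary along axis 1
    )
    assert mask.shape == (4, 6)
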
@@ -791,10 +815,7 @@ def patched_dynamic_rope_update(rope_forward):
     return wrapper
 
 
-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -820,6 +841,65 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
+
+    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
+
+
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
+
+
+class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
+
+
+class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.51"):
+
+    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
+        )
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.53"):
+
+    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
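All the `patched_*RotaryEmbedding` classes added above inherit the single `common_RotaryEmbedding.forward` and differ only in `_PATCHED_CLASS_`. The `_PATCHES_`/`_PATCHED_CLASS_` attributes are consumed by the patching machinery elsewhere in `torch_export_patches`; conceptually (a simplified sketch, not the package's actual implementation) applying one of these patches amounts to:

    def apply_patch(patch_cls):
        # swap the listed methods on the target class, keep originals to restore
        target = patch_cls._PATCHED_CLASS_
        originals = {}
        for name in patch_cls._PATCHES_:
            originals[name] = getattr(target, name)
            setattr(target, name, getattr(patch_cls, name))
        return originals

    def undo_patch(patch_cls, originals):
        for name, fn in originals.items():
            setattr(patch_cls._PATCHED_CLASS_, name, fn)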