onnx-diagnostic 0.7.14__py3-none-any.whl → 0.7.16__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (25)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +156 -47
  3. onnx_diagnostic/export/dynamic_shapes.py +6 -6
  4. onnx_diagnostic/export/shape_helper.py +124 -6
  5. onnx_diagnostic/ext_test_case.py +5 -1
  6. onnx_diagnostic/helpers/cache_helper.py +68 -42
  7. onnx_diagnostic/helpers/config_helper.py +2 -1
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
  9. onnx_diagnostic/helpers/helper.py +3 -0
  10. onnx_diagnostic/helpers/rt_helper.py +3 -3
  11. onnx_diagnostic/tasks/image_text_to_text.py +7 -6
  12. onnx_diagnostic/tasks/text_generation.py +7 -4
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +69 -11
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
  15. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +109 -18
  16. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +133 -28
  17. onnx_diagnostic/torch_models/code_sample.py +343 -0
  18. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
  19. onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
  20. onnx_diagnostic/torch_models/validate.py +73 -29
  21. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
  22. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +25 -23
  23. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
  24. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
  25. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
@@ -108,7 +108,7 @@ def flatten_unflatten_for_dynamic_shapes(
 
 def is_cache_dynamic_registered(fast: bool = False) -> bool:
     """
-    Tells class :class:`transformers.cache_utils.DynamicCache` can be
+    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
     serialized and deserialized. Only then, :func:`torch.export.export`
     can export a model.
 
@@ -168,7 +168,33 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             ]
         )
         print(string_type(past_key_values, with_shape=True))
+
+    The function is fully able to handle ``FakeTensor`` with dynamic dimensions if
+    ``transformers>=4.56``. Before that version, only FakeTensor with static dimensions
+    are supported.
     """
+    if (
+        key_value_pairs
+        and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
+        and pv.Version(transformers.__version__) >= pv.Version("4.56")
+    ):
+        cache = transformers.cache_utils.DynamicCache()
+        cache.layers.extend(
+            [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
+        )
+        for i, layer in enumerate(cache.layers):
+            k, v = key_value_pairs[i][0], key_value_pairs[i][1]
+            layer.dtype = k.dtype
+            layer.device = k.device
+            layer.keys = k
+            layer.values = v
+            layer.is_initialized = True
+        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+            f"{len(key_value_pairs)} expected."
+        )
+        return finalize_cache(cache)
+
     cache = transformers.cache_utils.DynamicCache(key_value_pairs)
     if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
         # The cache constructor contains the two following lines
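
For reference, a minimal sketch of what this new branch enables, assuming transformers>=4.56; the tensor shapes are illustrative and the FakeTensorMode setup mirrors the one used in fake_tensor_helper.py below:

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from torch._subclasses.fake_tensor import FakeTensorMode
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    # build FakeTensors with symbolic (dynamic) dimensions
    fake_mode = FakeTensorMode(shape_env=ShapeEnv())
    k = fake_mode.from_tensor(torch.rand((2, 32, 30, 96)), static_shapes=False)
    v = fake_mode.from_tensor(torch.rand((2, 32, 30, 96)), static_shapes=False)

    # with transformers>=4.56, make_dynamic_cache fills DynamicCache.layers directly
    cache = make_dynamic_cache([(k, v)])
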
@@ -494,51 +520,51 @@ def make_hybrid_cache(
 
     .. code-block:: python
 
-        self.max_cache_len = (
-            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+        self.max_cache_len = (
+            max_cache_len if max_cache_len is not None else config.max_position_embeddings)
 
-        # Sliding layers can't be larger than the overall max cache len
-        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
-        self.max_batch_size = max_batch_size
+        # Sliding layers can't be larger than the overall max cache len
+        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+        self.max_batch_size = max_batch_size
 
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim")
-            else config.hidden_size // config.num_attention_heads
-        )
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
 
-        self._dtype = dtype
-        self.num_key_value_heads = (
-            config.num_attention_heads
-            if getattr(config, "num_key_value_heads", None) is None
-            else config.num_key_value_heads
-        )
+        self._dtype = dtype
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
 
-        # If the attribute does not exist in the config, fallback to a simple StaticCache
-        if hasattr(config, "layer_types"):
-            self.is_sliding = [
-                layer_type != "full_attention" for layer_type in config.layer_types]
-        else:
-            self.is_sliding = [False] * config.num_hidden_layers
-
-        self.key_cache: list[torch.Tensor] = []
-        self.value_cache: list[torch.Tensor] = []
-        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                              self.max_cache_len, self.head_dim)
-        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                               self.sliding_window_len, self.head_dim)
-        self.sliding_window = min(config.sliding_window, max_cache_len)
-        device = torch.device(device) if device is not None else None
-        for i in range(config.num_hidden_layers):
-            layer_device = layer_device_map[i] if layer_device_map is not None else device
-            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
-            new_layer_key_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(
-                cache_shape, dtype=self._dtype, device=layer_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [
+                layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
+        self.key_cache: list[torch.Tensor] = []
+        self.value_cache: list[torch.Tensor] = []
+        global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                              self.max_cache_len, self.head_dim)
+        sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
+        device = torch.device(device) if device is not None else None
+        for i in range(config.num_hidden_layers):
+            layer_device = layer_device_map[i] if layer_device_map is not None else device
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+            new_layer_key_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(
+                cache_shape, dtype=self._dtype, device=layer_device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
 
     """
     layer_types = None
     if key_value_pairs:
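
A hedged usage sketch of make_hybrid_cache, assuming it accepts (key, value) pairs the same way make_dynamic_cache does (the full signature is not shown in this hunk); shapes are illustrative:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    cache = make_hybrid_cache(
        [(torch.rand((2, 4, 8, 16)), torch.rand((2, 4, 8, 16))) for _ in range(2)]
    )
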
@@ -95,7 +95,8 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
     mod_name = cls.__module__
     mod = importlib.import_module(mod_name)
     source = inspect.getsource(mod)
-    reg = re.compile("config: ([A-Za-z0-9]+)")
+    # [^O] avoids capturing Optional[Something]
+    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
     fall = reg.findall(source)
     if len(fall) == 0:
         assert not exc, (
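
A quick illustration of the tightened pattern (not part of the diff); note it also skips any config class whose name starts with the letter O:

    import re

    source = "config: GemmaConfig\nconfig: Optional[GemmaConfig]"
    old = re.compile("config: ([A-Za-z0-9]+)")
    new = re.compile("config: ([^O][A-Za-z0-9]+)")
    print(old.findall(source))  # ['GemmaConfig', 'Optional']
    print(new.findall(source))  # ['GemmaConfig']
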
@@ -0,0 +1,153 @@
+from typing import Any, Dict, Optional, Tuple
+
+
+_UNIQUE = set()
+
+
+def _unique():
+    i = 129 + 1
+    while i in _UNIQUE:
+        i += 1
+    _UNIQUE.add(i)
+    return i
+
+
+def fake_reshape(
+    true_tensor: "torch.Tensor",  # noqa: F821
+    sh: Dict[int, Any],  # noqa: F821
+    fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
+    fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
+) -> "FakeTensor":  # noqa: F821
+    """
+    Changes the shape of a true tensor to make it dynamic.
+
+    :param true_tensor: true tensor
+    :param sh: dynamic shape
+    :param fake_tensor: fake tensor, if None, make a fake one
+    :param fake_mode: fake tensor mode
+    :return: fake tensor
+    """
+    import torch
+
+    # deal with 0/1
+    for i in sh:
+        if true_tensor.shape[i] <= 1:
+            expanded_shape = list(true_tensor.shape)
+            expanded_shape[i] = _unique()
+            true_tensor = torch.empty(
+                tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
+            )
+
+    # deal with equivalent dimension
+    new_shape = list(true_tensor.shape)
+    mapping = {}
+    for i, s in sh.items():
+        d = true_tensor.shape[i]
+        if d not in mapping:
+            mapping[d] = s
+        elif mapping[d] != s:
+            d = _unique()
+            mapping[d] = s
+            new_shape[i] = d
+            true_tensor = torch.empty(
+                tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
+            )
+
+    # now switch to FakeTensor
+    if fake_mode is None:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+    if fake_tensor is None:
+        fake_tensor = fake_mode.from_tensor(true_tensor, static_shapes=False)
+    assert fake_mode is not None, "fake_mode must be provided"
+
+    new_shape = list(true_tensor.shape)
+    for i in sh:
+        new_shape[i] = fake_tensor.shape[i]
+
+    reduced_tensor = fake_mode.from_tensor(true_tensor, static_shapes=True).sum(
+        axis=tuple(sorted(sh)), keepdim=True
+    )
+    return reduced_tensor.expand(*new_shape)
+
+
+def make_fake(
+    x: Any, fake_mode: Optional["FakeTensorMode"] = None  # noqa: F821
+) -> Tuple[Optional["FakeTensor"], Optional["FakeTensorMode"]]:  # noqa: F821
+    """
+    Replaces all tensors by fake tensors.
+    This modification happens inplace for caches.
+    This function is only implemented for cache with
+    ``transformers>=4.55``.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
+
+        inputs, _ = make_fake(
+            dict(
+                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                past_key_values=make_dynamic_cache(
+                    [
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                    ]
+                ),
+            )
+        )
+        pprint.pprint(inputs)
+    """
+    if x is None:
+        return None, None
+    if fake_mode is None:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+
+    if isinstance(x, (list, tuple)):
+        return x.__class__([make_fake(i, fake_mode=fake_mode)[0] for i in x]), fake_mode
+    if isinstance(x, dict):
+        return {k: make_fake(v, fake_mode=fake_mode)[0] for k, v in x.items()}, fake_mode
+
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+        assert hasattr(x, "layers"), (
+            f"Une more recent version of transformers (>=4.55), "
+            f"'layers' not found in class {type(x)}"
+        )
+        for layer in x.layers:
+            assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                f"Une more recent version of transformers (>=4.55), 'layers' "
+                f"not found in class {type(layer)} ({dir(layer)})"
+            )
+            layer.keys = make_fake(layer.keys, fake_mode=fake_mode)[0]
+            layer.values = make_fake(layer.values, fake_mode=fake_mode)[0]
+        return x, fake_mode
+    if x.__class__.__name__ == "EncoderDecoderCache":
+        make_fake(x.self_attention_cache, fake_mode=fake_mode)
+        make_fake(x.cross_attention_cache, fake_mode=fake_mode)
+        return x, fake_mode
+    if hasattr(x, "shape"):
+        t = fake_mode.from_tensor(x, static_shapes=False)
+        return t, fake_mode
+    from . import string_type
+
+    raise TypeError(
+        f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+    )
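
The docstring above already demonstrates make_fake. A hedged sketch of calling fake_reshape as well; the mapping of axis index to dynamic-dimension name passed as sh is an assumption based on the signature, and the shapes are illustrative:

    import torch
    from onnx_diagnostic.helpers.fake_tensor_helper import fake_reshape

    t = torch.rand((2, 32, 30, 96), dtype=torch.float16)
    # turn axes 0 and 2 into dynamic dimensions on the returned FakeTensor
    fake = fake_reshape(t, {0: "batch", 2: "cache_length"})
    print(fake.shape)
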
@@ -463,6 +463,7 @@ def string_type(
         if verbose:
             print(f"[string_type] F2:{type(obj)}")
         return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
+
     if isinstance(obj, torch.Tensor):
         from .torch_helper import torch_dtype_to_onnx_dtype
 
@@ -783,6 +784,8 @@ def string_type(
             obj, ultralytics.engine.results.Results
         ), f"Unexpected type={type(obj)}"
         return f"ultralytics.{obj.__class__.__name__}(...)"
+    if obj.__class__.__name__ == "FakeTensorMode":
+        return f"{obj}"
 
     if verbose:
         print(f"[string_type] END:{type(obj)}")
@@ -3,8 +3,6 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .torch_helper import to_numpy
-from .cache_helper import is_cache_dynamic_registered
 
 
 def name_type_to_onnx_dtype(name: str) -> int:
@@ -49,7 +47,7 @@ def make_feeds(
     assert (
         not check_flatten
         or not all(isinstance(obj, torch.Tensor) for obj in flat)
-        or not is_cache_dynamic_registered(fast=True)
+        # or not is_cache_dynamic_registered(fast=True)
         or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
     ), (
         f"Unexpected number of flattened objects, "
@@ -57,6 +55,8 @@ def make_feeds(
         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
+        from .torch_helper import to_numpy
+
         flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]
@@ -186,12 +186,13 @@ def _get_inputs_gemma3(
         f"total_sequence_length={total_sequence_length} != 860 "
         f"for model {model.__class__.__name__}"
     )
-    assert (
-        head_dim == 256
-    ), f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
+    assert head_dim in (
+        256,
+        32,
+    ), f"head_dim={head_dim} not in (32, 256) for model {model.__class__.__name__}"
     assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
-    assert num_key_value_heads == 4, (
-        f"num_key_value_heads={num_key_value_heads} != 256 "
+    assert num_key_value_heads in (1, 4), (
+        f"num_key_value_heads={num_key_value_heads} not in (1, 4) "
         f"for this model {model.__class__.__name__}"
     )
 
@@ -270,7 +271,7 @@ def get_inputs_default(
         "input_ids": {0: batch, 1: seq_length},
         "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {0: batch, 1: "cache+seq"},
-        "position_ids": {0: batch, 1: "cache+seq"},
+        "position_ids": {0: batch, 1: seq_length},
         "past_key_values": [
             [{0: batch} for _ in range(num_hidden_layers)],
             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
@@ -19,6 +19,9 @@ __TASK__ = "text-generation"
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
     check_hasattr(
         config,
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
@@ -217,10 +220,7 @@ def get_inputs(
             0: batch,
             1: "cache+seq",  # cache_length + seq_length
         },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
+        "position_ids": {0: batch, 1: seq_length},
         "past_key_values": [
             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
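
Spelled out, the dynamic-shapes mapping for text-generation now looks like this (string dimension names and two layers used for illustration); position_ids now shares the sequence axis with input_ids instead of the cache+seq axis:

    batch, seq_length, cache_length = "batch", "seq_length", "cache_length"
    dynamic_shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "attention_mask": {0: batch, 1: "cache+seq"},
        "position_ids": {0: batch, 1: seq_length},
        "past_key_values": [
            [{0: batch, 2: cache_length} for _ in range(2)],
            [{0: batch, 2: cache_length} for _ in range(2)],
        ],
    }
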
@@ -308,6 +308,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
 
     If the configuration is None, the function selects typical dimensions.
     """
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
     if config is not None:
         check_hasattr(
             config,
@@ -2,7 +2,7 @@ import functools
 import importlib
 import contextlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,
@@ -160,7 +160,7 @@ def register_additional_serialization_functions(
 @contextlib.contextmanager
 def torch_export_patches(
     patch_sympy: bool = True,
-    patch_torch: bool = True,
+    patch_torch: Union[bool, int] = True,
     patch_transformers: bool = False,
     patch_diffusers: bool = False,
     catch_constraints: bool = True,
@@ -349,6 +349,7 @@ def torch_export_patches(
             _catch_produce_guards_and_solve_constraints,
             patch__check_input_constraints_for_graph,
             patched__broadcast_in_dim_meta,
+            patched__broadcast_in_dim_meta_level_2,
             patched__maybe_broadcast,
             patched_ShapeEnv,
         )
@@ -390,8 +391,13 @@ def torch_export_patches(
         # torch._prims._broadcast_in_dim_meta
         f_broadcast_in_dim = torch._prims.broadcast_in_dim
         f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-        torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
-        torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+        _patched_dim_f = (
+            patched__broadcast_in_dim_meta_level_2
+            if patch_torch == 2
+            else patched__broadcast_in_dim_meta
+        )
+        torch._prims._broadcast_in_dim_meta = _patched_dim_f
+        torch._prims.broadcast_in_dim = _patched_dim_f
 
         # torch._refs._maybe_broadcast
         f__maybe_broadcast = torch._refs._maybe_broadcast
@@ -422,7 +428,7 @@ def torch_export_patches(
             )
         )
 
-    if stop_if_static:
+    if patch_torch and stop_if_static:
         ShapeEnv._log_guard_remember = ShapeEnv._log_guard
 
         if verbose:
@@ -453,6 +459,16 @@ def torch_export_patches(
         except ImportError:
             masking_utils = None
 
+        try:
+            import transformers.integrations.sdpa_attention as sdpa_attention
+        except ImportError:
+            sdpa_attention = None
+
+        try:
+            import transformers.modeling_utils as modeling_utils
+        except ImportError:
+            modeling_utils = None
+
         if verbose:
             import transformers
 
@@ -464,7 +480,7 @@ def torch_export_patches(
             patch_transformers_list, verbose=verbose
         )
 
-        if (
+        if (  # vmap
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -499,7 +515,7 @@ def torch_export_patches(
         else:
             f_transformers_sdpa_mask = None
 
-        if (
+        if (  # eager_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "eager_mask")
@@ -526,7 +542,7 @@ def torch_export_patches(
                 patch_transformers_list.patched_eager_mask
             )
 
-        if (
+        if (  # sdpa_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "sdpa_mask")
@@ -547,6 +563,29 @@ def torch_export_patches(
                 patch_transformers_list.patched_sdpa_mask_recent_torch
             )
 
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+            and hasattr(modeling_utils, "AttentionInterface")
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+                )
+            f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+            sdpa_attention.sdpa_attention_forward = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+            modeling_utils.sdpa_attention_forward = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+                patch_transformers_list.patched_sdpa_attention_forward
+            )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")
@@ -656,7 +695,7 @@ def torch_export_patches(
             patch_transformers_list, revert_patches_info, verbose=verbose
         )
 
-        if (
+        if (  # vmap
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -687,7 +726,7 @@ def torch_export_patches(
                     "transformers.masking_utils.sdpa_mask"
                 )
 
-        if (
+        if (  # eager_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "eager_mask")
@@ -714,7 +753,7 @@ def torch_export_patches(
                     "in ALL_MASK_ATTENTION_FUNCTIONS"
                 )
 
-        if (
+        if (  # sdpa_mask
             masking_utils
             and patch_transformers_list.patch_masking_utils
             and hasattr(masking_utils, "sdpa_mask")
@@ -734,6 +773,25 @@ def torch_export_patches(
                     "in ALL_MASK_ATTENTION_FUNCTIONS"
                 )
 
+        if (  # sdpa_attention_forward
+            sdpa_attention is not None
+            and modeling_utils is not None
+            and hasattr(sdpa_attention, "sdpa_attention_forward")
+            and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+            and hasattr(modeling_utils, "AttentionInterface")
+        ):
+            sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
+            modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+                f_sdpa_attention_forward
+            )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.integrations.sdpa_attention."
+                    "sdpa_attention_forward"
+                )
+
         ########
         # caches
         ########
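
A hedged usage sketch of the context manager with the options touched by this diff (the import path is the package's public one; the tiny module stands in for a real model): patch_torch=2 selects patched__broadcast_in_dim_meta_level_2, while any other truthy value keeps the previous patch.

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x * 2

    with torch_export_patches(patch_torch=2, patch_transformers=True, verbose=1):
        ep = torch.export.export(Tiny(), (torch.rand(2, 3),))
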
@@ -12,17 +12,26 @@ from transformers.cache_utils import (
     StaticCache,
 )
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
 PATCH_OF_PATCHES: Set[Any] = set()
 
 
+def get_mamba_cache_cls() -> type:
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+
+        return MambaCache
+    except ImportError:
+        try:
+            from transformers.cache_utils import MambaCache
+
+            return MambaCache
+        except ImportError:
+            return None
+
+
 def register_class_serialization(
     cls,
     f_flatten: Callable,
@@ -203,13 +212,6 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
-        MambaCache: lambda verbose=verbose: register_class_serialization(
-            MambaCache,
-            flatten_mamba_cache,
-            unflatten_mamba_cache,
-            flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
         EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
             EncoderDecoderCache,
             flatten_encoder_decoder_cache,
@@ -232,6 +234,17 @@ def serialization_functions(
             verbose=verbose,
         ),
     }
+    MambaCache = get_mamba_cache_cls()
+    if MambaCache:
+        transformers_classes[MambaCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                MambaCache,
+                flatten_mamba_cache,
+                unflatten_mamba_cache,
+                flatten_with_keys_mamba_cache,
+                verbose=verbose,
+            )
+        )
     classes.update(transformers_classes)
 
     if patch_diffusers:
@@ -287,7 +300,12 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
 
 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
     """Undo all registrations."""
-    cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
+    MambaCache = get_mamba_cache_cls()
+    cls_ensemble = (
+        {DynamicCache, EncoderDecoderCache}
+        | set(undo)
+        | ({MambaCache} if MambaCache else set())
+    )
     for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
             unregister_class_serialization(cls, verbose)
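
An illustrative check of the new optional lookup; with this change, MambaCache serialization is simply skipped when neither import location provides the class instead of failing at import time:

    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        get_mamba_cache_cls,
    )

    MambaCache = get_mamba_cache_cls()
    if MambaCache is None:
        print("MambaCache not available; its serialization is not registered")
    else:
        print(f"MambaCache found: {MambaCache}")
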