onnx-diagnostic 0.7.15__py3-none-any.whl → 0.7.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -168,7 +168,33 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
              ]
          )
          print(string_type(past_key_values, with_shape=True))
+
+     The function fully supports ``FakeTensor`` inputs with dynamic dimensions if
+     ``transformers>=4.56``. Before that version, only ``FakeTensor`` inputs with
+     static dimensions are supported.
      """
+     if (
+         key_value_pairs
+         and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
+         and pv.Version(transformers.__version__) >= pv.Version("4.56")
+     ):
+         cache = transformers.cache_utils.DynamicCache()
+         cache.layers.extend(
+             [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
+         )
+         for i, layer in enumerate(cache.layers):
+             k, v = key_value_pairs[i][0], key_value_pairs[i][1]
+             layer.dtype = k.dtype
+             layer.device = k.device
+             layer.keys = k
+             layer.values = v
+             layer.is_initialized = True
+         assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
+             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
+             f"{len(key_value_pairs)} expected."
+         )
+         return finalize_cache(cache)
+
      cache = transformers.cache_utils.DynamicCache(key_value_pairs)
      if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
          # The cache constructor contains the two following lines
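
Note on the hunk above: with ``transformers>=4.56`` the patched function (presumably
``onnx_diagnostic.helpers.cache_helper.make_dynamic_cache``, the import used in the
``make_fake`` example later in this diff) can be fed ``FakeTensor`` pairs directly.
A minimal, hypothetical sketch under those assumptions, not part of the package:

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from torch._subclasses.fake_tensor import FakeTensorMode
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    # FakeTensors with symbolic (dynamic) dimensions
    fake_mode = FakeTensorMode(shape_env=ShapeEnv())
    k = fake_mode.from_tensor(torch.rand((2, 8, 16, 64)), static_shapes=False)
    v = fake_mode.from_tensor(torch.rand((2, 8, 16, 64)), static_shapes=False)

    # With transformers>=4.56 the new branch fills DynamicLayer objects directly.
    cache = make_dynamic_cache([(k, v)])
    print(type(cache).__name__, len(cache.layers))
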
@@ -494,51 +520,51 @@ def make_hybrid_cache(

      .. code-block:: python

-         self.max_cache_len = (
-             max_cache_len if max_cache_len is not None else config.max_position_embeddings)
+         self.max_cache_len = (
+             max_cache_len if max_cache_len is not None else config.max_position_embeddings)

-         # Sliding layers can't be larger than the overall max cache len
-         self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
-         self.max_batch_size = max_batch_size
+         # Sliding layers can't be larger than the overall max cache len
+         self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
+         self.max_batch_size = max_batch_size

-         self.head_dim = (
-             config.head_dim if hasattr(config, "head_dim")
-             else config.hidden_size // config.num_attention_heads
-         )
+         self.head_dim = (
+             config.head_dim if hasattr(config, "head_dim")
+             else config.hidden_size // config.num_attention_heads
+         )

-         self._dtype = dtype
-         self.num_key_value_heads = (
-             config.num_attention_heads
-             if getattr(config, "num_key_value_heads", None) is None
-             else config.num_key_value_heads
-         )
+         self._dtype = dtype
+         self.num_key_value_heads = (
+             config.num_attention_heads
+             if getattr(config, "num_key_value_heads", None) is None
+             else config.num_key_value_heads
+         )

-         # If the attribute does not exist in the config, fallback to a simple StaticCache
-         if hasattr(config, "layer_types"):
-             self.is_sliding = [
-                 layer_type != "full_attention" for layer_type in config.layer_types]
-         else:
-             self.is_sliding = [False] * config.num_hidden_layers
-
-         self.key_cache: list[torch.Tensor] = []
-         self.value_cache: list[torch.Tensor] = []
-         global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                               self.max_cache_len, self.head_dim)
-         sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
-                                self.sliding_window_len, self.head_dim)
-         self.sliding_window = min(config.sliding_window, max_cache_len)
-         device = torch.device(device) if device is not None else None
-         for i in range(config.num_hidden_layers):
-             layer_device = layer_device_map[i] if layer_device_map is not None else device
-             cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
-             new_layer_key_cache = torch.zeros(
-                 cache_shape, dtype=self._dtype, device=layer_device)
-             new_layer_value_cache = torch.zeros(
-                 cache_shape, dtype=self._dtype, device=layer_device)
-             torch._dynamo.mark_static_address(new_layer_key_cache)
-             torch._dynamo.mark_static_address(new_layer_value_cache)
-             self.key_cache.append(new_layer_key_cache)
-             self.value_cache.append(new_layer_value_cache)
+         # If the attribute does not exist in the config, fallback to a simple StaticCache
+         if hasattr(config, "layer_types"):
+             self.is_sliding = [
+                 layer_type != "full_attention" for layer_type in config.layer_types]
+         else:
+             self.is_sliding = [False] * config.num_hidden_layers
+
+         self.key_cache: list[torch.Tensor] = []
+         self.value_cache: list[torch.Tensor] = []
+         global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                               self.max_cache_len, self.head_dim)
+         sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
+                                self.sliding_window_len, self.head_dim)
+         self.sliding_window = min(config.sliding_window, max_cache_len)
+         device = torch.device(device) if device is not None else None
+         for i in range(config.num_hidden_layers):
+             layer_device = layer_device_map[i] if layer_device_map is not None else device
+             cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
+             new_layer_key_cache = torch.zeros(
+                 cache_shape, dtype=self._dtype, device=layer_device)
+             new_layer_value_cache = torch.zeros(
+                 cache_shape, dtype=self._dtype, device=layer_device)
+             torch._dynamo.mark_static_address(new_layer_key_cache)
+             torch._dynamo.mark_static_address(new_layer_value_cache)
+             self.key_cache.append(new_layer_key_cache)
+             self.value_cache.append(new_layer_value_cache)
      """
      layer_types = None
      if key_value_pairs:
@@ -0,0 +1,153 @@
+ from typing import Any, Dict, Optional, Tuple
+
+
+ _UNIQUE = set()
+
+
+ def _unique():
+     i = 129 + 1
+     while i in _UNIQUE:
+         i += 1
+     _UNIQUE.add(i)
+     return i
+
+
+ def fake_reshape(
+     true_tensor: "torch.Tensor", # noqa: F821
+     sh: Dict[int, Any], # noqa: F821
+     fake_tensor: Optional["FakeTensor"] = None, # noqa: F821
+     fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
+ ) -> "FakeTensor": # noqa: F821
+     """
+     Changes the shape of a true tensor to make it dynamic.
+
+     :param true_tensor: true tensor
+     :param sh: dynamic shape
+     :param fake_tensor: fake tensor, if None, make a fake one
+     :param fake_mode: fake tensor mode
+     :return: fake tensor
+     """
+     import torch
+
+     # deal with 0/1
+     for i in sh:
+         if true_tensor.shape[i] <= 1:
+             expanded_shape = list(true_tensor.shape)
+             expanded_shape[i] = _unique()
+             true_tensor = torch.empty(
+                 tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
+             )
+
+     # deal with equivalent dimension
+     new_shape = list(true_tensor.shape)
+     mapping = {}
+     for i, s in sh.items():
+         d = true_tensor.shape[i]
+         if d not in mapping:
+             mapping[d] = s
+         elif mapping[d] != s:
+             d = _unique()
+             mapping[d] = s
+             new_shape[i] = d
+             true_tensor = torch.empty(
+                 tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
+             )
+
+     # now switch to FakeTensor
+     if fake_mode is None:
+         from torch.fx.experimental.symbolic_shapes import ShapeEnv
+         from torch._subclasses.fake_tensor import FakeTensorMode
+
+         shape_env = ShapeEnv()
+         fake_mode = FakeTensorMode(shape_env=shape_env)
+     if fake_tensor is None:
+         fake_tensor = fake_mode.from_tensor(true_tensor, static_shapes=False)
+     assert fake_mode is not None, "fake_mode must be provided"
+
+     new_shape = list(true_tensor.shape)
+     for i in sh:
+         new_shape[i] = fake_tensor.shape[i]
+
+     reduced_tensor = fake_mode.from_tensor(true_tensor, static_shapes=True).sum(
+         axis=tuple(sorted(sh)), keepdim=True
+     )
+     return reduced_tensor.expand(*new_shape)
+
+
+ def make_fake(
+     x: Any, fake_mode: Optional["FakeTensorMode"] = None # noqa: F821
+ ) -> Tuple[Optional["FakeTensor"], Optional["FakeTensorMode"]]: # noqa: F821
+     """
+     Replaces all tensors by fake tensors.
+     This modification happens in place for caches.
+     This function is only implemented for caches with
+     ``transformers>=4.55``.
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
+
+         inputs, _ = make_fake(
+             dict(
+                 input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                 attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                 position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                 past_key_values=make_dynamic_cache(
+                     [
+                         (
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                         ),
+                         (
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                         ),
+                     ]
+                 ),
+             )
+         )
+         pprint.pprint(inputs)
+     """
+     if x is None:
+         return None, None
+     if fake_mode is None:
+         from torch.fx.experimental.symbolic_shapes import ShapeEnv
+         from torch._subclasses.fake_tensor import FakeTensorMode
+
+         shape_env = ShapeEnv()
+         fake_mode = FakeTensorMode(shape_env=shape_env)
+
+     if isinstance(x, (list, tuple)):
+         return x.__class__([make_fake(i, fake_mode=fake_mode)[0] for i in x]), fake_mode
+     if isinstance(x, dict):
+         return {k: make_fake(v, fake_mode=fake_mode)[0] for k, v in x.items()}, fake_mode
+
+     if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+         assert hasattr(x, "layers"), (
+             f"Use a more recent version of transformers (>=4.55), "
+             f"'layers' not found in class {type(x)}"
+         )
+         for layer in x.layers:
+             assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                 f"Use a more recent version of transformers (>=4.55), 'keys' or 'values' "
+                 f"not found in class {type(layer)} ({dir(layer)})"
+             )
+             layer.keys = make_fake(layer.keys, fake_mode=fake_mode)[0]
+             layer.values = make_fake(layer.values, fake_mode=fake_mode)[0]
+         return x, fake_mode
+     if x.__class__.__name__ == "EncoderDecoderCache":
+         make_fake(x.self_attention_cache, fake_mode=fake_mode)
+         make_fake(x.cross_attention_cache, fake_mode=fake_mode)
+         return x, fake_mode
+     if hasattr(x, "shape"):
+         t = fake_mode.from_tensor(x, static_shapes=False)
+         return t, fake_mode
+     from . import string_type
+
+     raise TypeError(
+         f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+     )
@@ -463,6 +463,7 @@ def string_type(
          if verbose:
              print(f"[string_type] F2:{type(obj)}")
          return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
+
      if isinstance(obj, torch.Tensor):
          from .torch_helper import torch_dtype_to_onnx_dtype

@@ -783,6 +784,8 @@ def string_type(
              obj, ultralytics.engine.results.Results
          ), f"Unexpected type={type(obj)}"
          return f"ultralytics.{obj.__class__.__name__}(...)"
+     if obj.__class__.__name__ == "FakeTensorMode":
+         return f"{obj}"

      if verbose:
          print(f"[string_type] END:{type(obj)}")
@@ -271,7 +271,7 @@ def get_inputs_default(
          "input_ids": {0: batch, 1: seq_length},
          "token_type_ids": {0: batch, 1: seq_length},
          "attention_mask": {0: batch, 1: "cache+seq"},
-         "position_ids": {0: batch, 1: "cache+seq"},
+         "position_ids": {0: batch, 1: seq_length},
          "past_key_values": [
              [{0: batch} for _ in range(num_hidden_layers)],
              [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
@@ -220,10 +220,7 @@ def get_inputs(
              0: batch,
              1: "cache+seq", # cache_length + seq_length
          },
-         "position_ids": {
-             0: batch,
-             1: "cache+seq", # cache_length + seq_length
-         },
+         "position_ids": {0: batch, 1: seq_length},
          "past_key_values": [
              [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
              [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
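
Note: both hunks above tie ``position_ids`` to the same dynamic dimension as ``input_ids``
(the sequence length) instead of ``cache+seq``. A hedged sketch of how dictionaries of this
shape are typically fed to ``torch.export.export`` (the model and inputs are illustrative,
not part of the package):

    import torch

    batch = torch.export.Dim("batch")
    seq_length = torch.export.Dim("seq_length")
    dynamic_shapes = {
        "input_ids": {0: batch, 1: seq_length},
        # position_ids now shares seq_length with input_ids
        "position_ids": {0: batch, 1: seq_length},
    }
    # ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)
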
@@ -2,7 +2,7 @@ import functools
  import importlib
  import contextlib
  import re
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  from .onnx_export_serialization import (
      register_cache_serialization,
      unregister_cache_serialization,
@@ -160,7 +160,7 @@ def register_additional_serialization_functions(
  @contextlib.contextmanager
  def torch_export_patches(
      patch_sympy: bool = True,
-     patch_torch: bool = True,
+     patch_torch: Union[bool, int] = True,
      patch_transformers: bool = False,
      patch_diffusers: bool = False,
      catch_constraints: bool = True,
@@ -349,6 +349,7 @@ def torch_export_patches(
          _catch_produce_guards_and_solve_constraints,
          patch__check_input_constraints_for_graph,
          patched__broadcast_in_dim_meta,
+         patched__broadcast_in_dim_meta_level_2,
          patched__maybe_broadcast,
          patched_ShapeEnv,
      )
@@ -390,8 +391,13 @@ def torch_export_patches(
          # torch._prims._broadcast_in_dim_meta
          f_broadcast_in_dim = torch._prims.broadcast_in_dim
          f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-         torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
-         torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+         _patched_dim_f = (
+             patched__broadcast_in_dim_meta_level_2
+             if patch_torch == 2
+             else patched__broadcast_in_dim_meta
+         )
+         torch._prims._broadcast_in_dim_meta = _patched_dim_f
+         torch._prims.broadcast_in_dim = _patched_dim_f

          # torch._refs._maybe_broadcast
          f__maybe_broadcast = torch._refs._maybe_broadcast
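
Note: ``patch_torch`` now also accepts ``2``, which selects
``patched__broadcast_in_dim_meta_level_2`` (defined further down in this diff) instead of the
default patch. A hedged usage sketch, assuming the usual import path of the context manager
(the model is illustrative):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x * 2

    with torch_export_patches(patch_torch=2):
        # inside the context, torch._prims._broadcast_in_dim_meta points to the
        # level-2 patch shown below
        ep = torch.export.export(Tiny(), (torch.rand(2, 3),))
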
@@ -453,6 +459,16 @@ def torch_export_patches(
      except ImportError:
          masking_utils = None

+     try:
+         import transformers.integrations.sdpa_attention as sdpa_attention
+     except ImportError:
+         sdpa_attention = None
+
+     try:
+         import transformers.modeling_utils as modeling_utils
+     except ImportError:
+         modeling_utils = None
+
      if verbose:
          import transformers

@@ -464,7 +480,7 @@ def torch_export_patches(
          patch_transformers_list, verbose=verbose
      )

-     if (
+     if ( # vmap
          masking_utils
          and patch_transformers_list.patch_masking_utils
          and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -499,7 +515,7 @@ def torch_export_patches(
      else:
          f_transformers_sdpa_mask = None

-     if (
+     if ( # eager_mask
          masking_utils
          and patch_transformers_list.patch_masking_utils
          and hasattr(masking_utils, "eager_mask")
@@ -526,7 +542,7 @@ def torch_export_patches(
              patch_transformers_list.patched_eager_mask
          )

-     if (
+     if ( # sdpa_mask
          masking_utils
          and patch_transformers_list.patch_masking_utils
          and hasattr(masking_utils, "sdpa_mask")
@@ -547,6 +563,29 @@ def torch_export_patches(
              patch_transformers_list.patched_sdpa_mask_recent_torch
          )

+     if ( # sdpa_attention_forward
+         sdpa_attention is not None
+         and modeling_utils is not None
+         and hasattr(sdpa_attention, "sdpa_attention_forward")
+         and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+         and hasattr(modeling_utils, "AttentionInterface")
+     ):
+         if verbose:
+             print(
+                 "[torch_export_patches] patches "
+                 "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+             )
+         f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+         sdpa_attention.sdpa_attention_forward = (
+             patch_transformers_list.patched_sdpa_attention_forward
+         )
+         modeling_utils.sdpa_attention_forward = (
+             patch_transformers_list.patched_sdpa_attention_forward
+         )
+         modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+             patch_transformers_list.patched_sdpa_attention_forward
+         )
+
      if custom_patches:
          if verbose:
              print("[torch_export_patches] applies custom patches")
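
Note: the new block also re-registers the patched forward in
``AttentionInterface._global_mapping["sdpa"]``, so models that resolve their attention
implementation by name pick it up as well. A hedged sketch of how one might verify that
inside the context, assuming the patch is applied when ``patch_transformers=True``:

    from transformers.modeling_utils import AttentionInterface
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    with torch_export_patches(patch_transformers=True):
        fn = AttentionInterface._global_mapping["sdpa"]
        print(fn.__name__)  # expected: patched_sdpa_attention_forward
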
@@ -656,7 +695,7 @@ def torch_export_patches(
          patch_transformers_list, revert_patches_info, verbose=verbose
      )

-     if (
+     if ( # vmap
          masking_utils
          and patch_transformers_list.patch_masking_utils
          and hasattr(masking_utils, "_vmap_for_bhqkv")
@@ -687,7 +726,7 @@ def torch_export_patches(
                  "transformers.masking_utils.sdpa_mask"
              )

-     if (
+     if ( # eager_mask
          masking_utils
          and patch_transformers_list.patch_masking_utils
          and hasattr(masking_utils, "eager_mask")
@@ -714,7 +753,7 @@ def torch_export_patches(
                  "in ALL_MASK_ATTENTION_FUNCTIONS"
              )

-     if (
+     if ( # sdpa_mask
          masking_utils
          and patch_transformers_list.patch_masking_utils
          and hasattr(masking_utils, "sdpa_mask")
@@ -734,6 +773,25 @@ def torch_export_patches(
                  "in ALL_MASK_ATTENTION_FUNCTIONS"
              )

+     if ( # sdpa_attention_forward
+         sdpa_attention is not None
+         and modeling_utils is not None
+         and hasattr(sdpa_attention, "sdpa_attention_forward")
+         and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+         and hasattr(modeling_utils, "AttentionInterface")
+     ):
+         sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
+         modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
+         modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+             f_sdpa_attention_forward
+         )
+         if verbose:
+             print(
+                 "[torch_export_patches] restored "
+                 "transformers.integrations.sdpa_attention."
+                 "sdpa_attention_forward"
+             )
+
      ########
      # caches
      ########
@@ -25,8 +25,8 @@ def retrieve_stacktrace():

  def _catch_produce_guards_and_solve_constraints(
      previous_function: Callable,
-     fake_mode: "FakeTensorMode", # noqa: F821
-     gm: "torch.fx.GraphModule", # noqa: F821
+     fake_mode: FakeTensorMode,
+     gm: torch.fx.GraphModule,
      dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
      equalities_inputs: "EqualityConstraint", # noqa: F821
      original_signature: inspect.Signature,
@@ -982,16 +982,21 @@ def patched__broadcast_in_dim_meta(
              elif guard_or_false(a.shape[original_idx] != 1):
                  new_strides.append(a.stride()[original_idx])
              else:
+                 # This check generates the following issue:
+                 # non-broadcasting semantics require s3 == Max(s10, s3), False,
+                 # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
+                 # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
+                 # original_idx=1
                  torch._check(
                      a.shape[original_idx] == shape[idx],
                      lambda idx=idx, original_idx=original_idx: (
                          f"non-broadcasting semantics require "
                          f"{a.shape[original_idx]} == {shape[idx]}, "
                          f"{guard_or_false(a.shape[idx] != 1)}, "
-                         f"guard_or_false(a.shape[idx] == 1)="
+                         f"guard_or_false(a.shape[idx]==1)="
                          f"{guard_or_false(a.shape[idx] == 1)}, "
-                         f"a.stride()={a.stride()}, idx={idx}, "
-                         f"original_idx={original_idx}"
+                         f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
+                         f"shape={shape}, original_idx={original_idx}"
                      ),
                  )
                  new_strides.append(a.stride()[original_idx])
@@ -1006,3 +1011,77 @@ def patched__broadcast_in_dim_meta(
                  new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

      return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+ def patched__broadcast_in_dim_meta_level_2(
+     a: torch._prims_common.TensorLikeType,
+     shape: torch._prims_common.ShapeType,
+     broadcast_dimensions: Sequence[int],
+ ):
+     """Patches ``torch._prims._broadcast_in_dim_meta``."""
+     from torch.fx.experimental.symbolic_shapes import (
+         guard_or_false,
+         guard_or_true,
+         sym_or,
+     )
+
+     # Type checks
+     assert isinstance(a, torch._prims_common.TensorLike)
+     assert isinstance(shape, Sequence)
+     assert isinstance(broadcast_dimensions, Sequence)
+
+     # every dimension must be accounted for
+     assert a.ndim == len(broadcast_dimensions)
+
+     # broadcast shape must have weakly more dimensions
+     assert len(shape) >= a.ndim
+
+     # broadcast_dimensions must be an ascending sequence
+     # (no relative reordering of dims) of integers and
+     # each dimension must be within the new shape
+     def _greater_than_reduce(acc, x):
+         assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+         assert x > acc
+         assert x < len(shape)
+
+         return x
+
+     reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+     # shape must be broadcastable to
+     for idx, new_idx in enumerate(broadcast_dimensions):
+         torch._check(
+             sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+             lambda idx=idx, new_idx=new_idx: (
+                 f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+             ),
+         )
+
+     new_strides = []
+     original_idx = 0
+     for idx in range(len(shape)):
+         if idx in broadcast_dimensions:
+             # Assigns a stride of zero to dimensions
+             # which were actually broadcast
+             if guard_or_false(a.shape[original_idx] == 1):
+                 if guard_or_false(a.shape[original_idx] == shape[idx]):
+                     new_strides.append(a.stride()[original_idx])
+                 else:
+                     new_strides.append(0)
+             # PATCHED: disabled this check
+             elif guard_or_false(a.shape[original_idx] != 1):
+                 new_strides.append(a.stride()[original_idx])
+             else:
+                 # PATCHED: torch._check was removed
+                 new_strides.append(a.stride()[original_idx])
+             original_idx = original_idx + 1
+         else:
+             if guard_or_true(shape[idx] != 1):
+                 # consistent with previous use of guard_size_oblivious
+                 new_strides.append(0)
+             elif original_idx == a.ndim:
+                 new_strides.append(1)
+             else:
+                 new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+     return a.as_strided(shape, new_strides, a.storage_offset())
@@ -1276,6 +1276,60 @@ def common_eager_attention_forward(
      return attn_output, attn_weights


+ def patched_sdpa_attention_forward(
+     module: torch.nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     dropout: float = 0.0,
+     scaling: Optional[float] = None,
+     is_causal: Optional[bool] = None,
+     **kwargs,
+ ) -> tuple[torch.Tensor, None]:
+     """[patch:transformers.integrations.sdpa_attention.sdpa_attention_forward]"""
+     assert not kwargs.get("output_attentions", False), (
+         "`sdpa` attention does not support `output_attentions=True`."
+         " Please set your attention to `eager` if you want any of these features."
+     )
+     sdpa_kwargs = {}
+     if hasattr(module, "num_key_value_groups"):
+         if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
+             key = transformers.integrations.sdpa_attention.repeat_kv(
+                 key, module.num_key_value_groups
+             )
+             value = transformers.integrations.sdpa_attention.repeat_kv(
+                 value, module.num_key_value_groups
+             )
+         else:
+             sdpa_kwargs = {"enable_gqa": True}
+
+     if attention_mask is not None and attention_mask.ndim == 4:
+         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+     is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+     # PATCHED: remove the test query.shape[2] > 1
+     # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+     is_causal = attention_mask is None and is_causal
+
+     torch._check(
+         attention_mask is None or attention_mask.shape[3] == key.shape[2],
+         lambda: "Attention mask shape incompatible with key shape.",
+     )
+     attn_output = torch.nn.functional.scaled_dot_product_attention(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         dropout_p=dropout,
+         scale=scaling,
+         is_causal=is_causal,
+         **sdpa_kwargs,
+     )
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     return attn_output, None
+
+
  def patched_model_bart_eager_attention_forward(
      module: torch.nn.Module,
      query: torch.Tensor,