onnx-diagnostic 0.7.9__py3-none-any.whl → 0.7.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +8 -1
  3. onnx_diagnostic/helpers/cache_helper.py +12 -10
  4. onnx_diagnostic/helpers/helper.py +8 -0
  5. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  6. onnx_diagnostic/helpers/torch_helper.py +14 -4
  7. onnx_diagnostic/reference/ops/op_scan.py +5 -5
  8. onnx_diagnostic/reference/ort_evaluator.py +2 -2
  9. onnx_diagnostic/tasks/__init__.py +4 -2
  10. onnx_diagnostic/tasks/image_to_video.py +127 -0
  11. onnx_diagnostic/torch_export_patches/eval/model_cases.py +3 -3
  12. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +98 -4
  13. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +42 -2
  14. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +0 -1
  15. onnx_diagnostic/torch_models/hghub/hub_api.py +69 -22
  16. onnx_diagnostic/torch_models/hghub/hub_data.py +5 -1
  17. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  18. onnx_diagnostic/torch_models/hghub/model_inputs.py +173 -128
  19. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  20. onnx_diagnostic/torch_models/untrained/llm_phi2.py +11 -3
  21. onnx_diagnostic/torch_models/validate.py +146 -17
  22. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  23. {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/METADATA +2 -2
  24. {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/RECORD +27 -25
  25. {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/WHEEL +0 -0
  26. {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/licenses/LICENSE.txt +0 -0
  27. {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.7.9"
+__version__ = "0.7.11"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )
@@ -542,6 +542,12 @@ def get_parser_validate() -> ArgumentParser:
         "the onnx exporter should use.",
         default="",
     )
+    parser.add_argument(
+        "--ort-logs",
+        default=False,
+        action=BooleanOptionalAction,
+        help="Enables onnxruntime logging when the session is created",
+    )
     return parser
 
 
@@ -601,6 +607,7 @@ def _cmd_validate(argv: List[Any]):
         repeat=args.repeat,
         warmup=args.warmup,
         inputs2=args.inputs2,
+        ort_logs=args.ort_logs,
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
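Note: `--ort-logs` relies on `argparse.BooleanOptionalAction`, so a `--no-ort-logs` form is generated automatically. A minimal standalone sketch of the options added above (the parser name and sample values are illustrative, not the full `validate` command line):

    from argparse import ArgumentParser, BooleanOptionalAction

    parser = ArgumentParser("validate")
    parser.add_argument(
        "--runtime",
        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
        default="onnxruntime",
    )
    parser.add_argument(
        "--ort-logs",
        default=False,
        action=BooleanOptionalAction,
        help="Enables onnxruntime logging when the session is created",
    )
    args = parser.parse_args(["--runtime", "orteval", "--ort-logs"])
    print(args.runtime, args.ort_logs)  # orteval True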
onnx_diagnostic/helpers/cache_helper.py
@@ -4,11 +4,6 @@ import torch
 import transformers
 import transformers.cache_utils
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 
 class CacheKeyValue:
     """
@@ -270,7 +265,7 @@ def make_static_cache(
             self.num_attention_heads = key_value_pairs[0][0].shape[1]
             self.num_hidden_layers = len(key_value_pairs)
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self
 
     assert max_cache_len is not None, (
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> "MambaCache":  # noqa: F821
     "Creates a ``MambaCache``."
+    # import is moved here because this part is slow.
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
@@ -366,7 +368,7 @@ def make_mamba_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self
 
     cache = MambaCache(
@@ -409,7 +411,7 @@ def make_sliding_window_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.sliding_window = key_value_pairs[0][0].shape[2]
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self
 
     cache = transformers.cache_utils.SlidingWindowCache(
@@ -577,7 +579,7 @@ def make_hybrid_cache(
         sliding_window = _sliding_window
         num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self
 
     if layer_types:
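Note: the `get_text_config` shims on these dummy config classes now accept `*args, **kwargs` because recent transformers releases may call the method with extra arguments (for example `decoder=True`). A minimal illustrative sketch, not the package's own class:

    class _config:
        num_attention_heads = 8
        num_hidden_layers = 2

        def get_text_config(self, *args, **kwargs):
            # a strict signature like get_text_config(self) would raise TypeError here
            return self

    cfg = _config()
    assert cfg.get_text_config(decoder=True) is cfg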
onnx_diagnostic/helpers/helper.py
@@ -774,6 +774,14 @@ def string_type(
         return f"{obj.__class__.__name__}(**{s})"
     if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
         return f"{obj.__class__.__name__}(...)"
+    if obj.__class__.__name__ == "Results":
+        import ultralytics
+
+        assert isinstance(
+            obj, ultralytics.engine.results.Results
+        ), f"Unexpected type={type(obj)}"
+        return f"ultralytics.{obj.__class__.__name__}(...)"
+
     if verbose:
         print(f"[string_type] END:{type(obj)}")
     raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
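Note: `string_type` now summarizes `ultralytics.engine.results.Results` objects instead of reaching the final assertion. A hedged usage sketch of the helper itself (the `with_shape` flag and the exact output format are assumptions based on the package documentation):

    import torch
    from onnx_diagnostic.helpers.helper import string_type

    inputs = dict(x=torch.rand(2, 3), flags=[1, 2])
    print(string_type(inputs, with_shape=True))
    # prints a compact summary, something like: dict(x:T1s2x3,flags:#2[int,int])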
onnx_diagnostic/helpers/onnx_helper.py
@@ -1186,7 +1186,7 @@ def shadowing_names(
         shadow |= set(i.name for i in g.input) & shadow_context
         shadow |= set(i.name for i in g.initializer) & shadow_context
         shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
-        s, ps, c = shadowing_names(
+        s, _ps, c = shadowing_names(
             g.node, verbose=verbose, existing=existing, shadow_context=existing
         )
         shadow |= s
onnx_diagnostic/helpers/torch_helper.py
@@ -543,7 +543,7 @@ def dummy_llm(
         )
 
     def forward(self, x):
-        B, T, C = x.shape
+        _B, T, C = x.shape
 
         query = self.query(x)
         key = self.key(x)
@@ -721,9 +721,10 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return {to_any(t, to_value) for t in value}
     if type(value) is dict:
         return {k: to_any(t, to_value) for k, t in value.items()}
-    if value.__class__.__name__ == "DynamicCache":
+    if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
+        make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
         cc = CacheKeyValue(value)
-        return make_dynamic_cache(
+        return make[value.__class__.__name__](  # type: ignore[operator]
             list(
                 zip(
                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
@@ -822,6 +823,15 @@ def torch_deepcopy(value: Any) -> Any:
         new_args = torch_deepcopy(args)
         return torch.utils._pytree.tree_unflatten(new_args, spec)
 
+    if value.__class__.__name__ == "Results":
+        import copy
+        import ultralytics
+
+        assert isinstance(
+            value, ultralytics.engine.results.Results
+        ), f"Unexpected type={type(value)}"
+        return copy.deepcopy(value)
+
     # We should have a code using serialization, deserialization assuming a model
     # cannot be exported without them.
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
@@ -856,7 +866,7 @@ def torch_tensor_size(value: Any) -> Any:
     if value.__class__.__name__ == "MambaCache":
         return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
-        args, spec = torch.utils._pytree.tree_flatten(value)
+        args, _spec = torch.utils._pytree.tree_flatten(value)
         return sum(torch_tensor_size(a) for a in args)
 
     # We should have a code using serialization, deserialization assuming a model
onnx_diagnostic/reference/ops/op_scan.py
@@ -26,11 +26,11 @@ class Scan(_Scan):
         ):
             (
                 num_loop_state_vars,
-                num_scan_outputs,
-                output_directions,
-                max_dir_out,
-                output_axes,
-                max_axe_out,
+                _num_scan_outputs,
+                _output_directions,
+                _max_dir_out,
+                _output_axes,
+                _max_axe_out,
                 state_names_in,
                 state_names_out,
                 scan_names_in,
onnx_diagnostic/reference/ort_evaluator.py
@@ -562,7 +562,7 @@ class OnnxruntimeEvaluator:
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] = onx, sess = self._get_sess_if(node, name, inputs, results)
+            self._cache[key] = _onx, sess = self._get_sess_if(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
@@ -616,7 +616,7 @@ class OnnxruntimeEvaluator:
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] = onx, sess = self._get_sess_scan(node, name, inputs, results)
+            self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
onnx_diagnostic/tasks/__init__.py
@@ -5,6 +5,8 @@ from . import (
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -14,7 +16,6 @@ from . import (
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 )
 
 __TASKS__ = [
@@ -23,6 +24,8 @@ __TASKS__ = [
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -32,7 +35,6 @@ __TASKS__ = [
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 ]
 
 
onnx_diagnostic/tasks/image_to_video.py (new file)
@@ -0,0 +1,127 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "image-to-video"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    if not hasattr(config, "num_hidden_layers") and not hasattr(config, "num_layers"):
+        # We cannot reduce.
+        return {}
+    check_hasattr(config, ("num_hidden_layers", "num_layers"))
+    kwargs = {}
+    if hasattr(config, "num_layers"):
+        kwargs["num_layers"] = min(config.num_layers, nhl())
+    if hasattr(config, "num_hidden_layers"):
+        kwargs["num_hidden_layers"] = min(config.num_hidden_layers, nhl())
+
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    text_embed_dim: int,
+    latent_channels: int,
+    batch_size: int = 2,
+    image_height: int = 704,
+    image_width: int = 1280,
+    latent_frames: int = 1,
+    text_maxlen: int = 512,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-to-video``.
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    latent_height = image_height // 8
+    latent_width = image_width // 8
+    dtype = torch.float32
+
+    inputs = dict(
+        hidden_states=torch.randn(
+            batch_size,
+            latent_channels,
+            latent_frames,
+            latent_height,
+            latent_width,
+            dtype=dtype,
+        ),
+        timestep=torch.tensor([1.0] * batch_size, dtype=dtype),
+        encoder_hidden_states=torch.randn(
+            batch_size, text_maxlen, text_embed_dim, dtype=dtype
+        ),
+        padding_mask=torch.ones(1, 1, image_height, image_width, dtype=dtype),
+        fps=torch.tensor([16] * batch_size, dtype=dtype),
+        condition_mask=torch.randn(
+            batch_size, 1, latent_frames, latent_height, latent_width, dtype=dtype
+        ),
+    )
+    shapes = dict(
+        hidden_states={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+        timestep={0: "batch_size"},
+        encoder_hidden_states={0: "batch_size"},
+        padding_mask={0: "batch_size", 2: "height", 3: "width"},
+        fps={0: "batch_size"},
+        condition_mask={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            text_embed_dim=text_embed_dim,
+            latent_channels=latent_channels,
+            batch_size=batch_size,
+            image_height=image_height,
+            image_width=image_width,
+            latent_frames=latent_frames,
+            text_maxlen=text_maxlen,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "in_channels", "text_embed_dim"),
+    kwargs = dict(
+        text_embed_dim=1024 if config is None else config.text_embed_dim,
+        latent_channels=16 if config is None else config.in_channels - 1,
+        batch_size=1,
+        image_height=8 * 50,
+        image_width=8 * 80,
+        latent_frames=1,
+        text_maxlen=512,
+    )
+    return kwargs, get_inputs
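Note: a hedged usage sketch of the new task module. `random_input_kwargs` picks typical dimensions when no configuration is available, and `get_inputs` returns the inputs, dynamic shapes and a second input set (`inputs2`) meant for validating exported models with other dimensions; the export call in the comment is only indicative:

    from onnx_diagnostic.tasks import image_to_video

    kwargs, fct = image_to_video.random_input_kwargs(config=None)
    data = fct(model=None, config=None, **kwargs)
    inputs, dyn = data["inputs"], data["dynamic_shapes"]
    print({k: tuple(v.shape) for k, v in inputs.items()})
    # torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dyn) would then
    # exercise the dynamic dimensions declared above.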
onnx_diagnostic/torch_export_patches/eval/model_cases.py
@@ -384,7 +384,7 @@ class ControlFlowScan(torch.nn.Module):
 
     def forward(self, x):
         init = torch.zeros_like(x[0])
-        carry, out = torch.ops.higher_order.scan(
+        carry, _out = torch.ops.higher_order.scan(
             ControlFlowScan.add, [init], [x], additional_inputs=[]
         )
         return carry
@@ -429,7 +429,7 @@ class ControlFlowScanCDist(torch.nn.Module):
         return [carry.clone(), rd]
 
     def forward(self, x):
-        carry, out = torch.ops.higher_order.scan(
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDist.dist,
             [x],
             [x],
@@ -483,7 +483,7 @@ class ControlFlowScanCDistXY(torch.nn.Module):
         return [y.clone(), rd]
 
     def forward(self, x, y):
-        carry, out = torch.ops.higher_order.scan(
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDistXY.dist,
             [y],
             [x],
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -439,6 +439,28 @@ def torch_export_patches(
         f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
         masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
 
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+        f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+        masking_utils.sdpa_mask_recent_torch = (
+            patch_transformers_list.patched_sdpa_mask_recent_torch
+        )
+        if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.sdpa_mask"
+                )
+            f_transformers_sdpa_mask = masking_utils.sdpa_mask
+            masking_utils.sdpa_mask = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+        else:
+            f_transformers_sdpa_mask = None
+
     if (
         masking_utils
         and patch_transformers_list.patch_masking_utils
@@ -456,10 +478,37 @@ def torch_export_patches(
         and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
         == f_transformers_eager_mask
     ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.eager_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
         masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
             patch_transformers_list.patched_eager_mask
         )
 
+    if (
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+        and f_transformers_sdpa_mask is not None
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+            == f_transformers_sdpa_mask
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")
@@ -568,12 +617,31 @@ def torch_export_patches(
         and hasattr(masking_utils, "_vmap_for_bhqkv")
     ):
         masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
         if verbose:
             print(
                 "[torch_export_patches] restored "
                 "transformers.masking_utils._vmap_for_bhqkv"
             )
 
+        masking_utils.sdpa_mask_recent_torch = (
+            f_transformers_sdpa_mask_recent_torch
+        )
+
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+
+        if f_transformers_sdpa_mask is not None:
+            masking_utils.sdpa_mask = f_transformers_sdpa_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask"
+                )
+
     if (
         masking_utils
         and patch_transformers_list.patch_masking_utils
@@ -581,6 +649,11 @@ def torch_export_patches(
     ):
         f_transformers_eager_mask = masking_utils.eager_mask
         masking_utils.eager_mask = f_transformers_eager_mask
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.masking_utils.eager_mask"
+            )
         if (
             "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
@@ -589,11 +662,32 @@ def torch_export_patches(
             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
                 f_transformers_eager_mask
             )
-            if verbose:
-                print(
-                    "[torch_export_patches] restored "
-                    "transformers.masking_utils.eager_mask"
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.masking_utils.eager_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
+
+    if (
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+    ):
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+            == patch_transformers_list.patched_sdpa_mask_recent_torch
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                f_transformers_sdpa_mask
                 )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
 
     ########
     # caches
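Note: the patching logic keeps a reference to the original `sdpa_mask_recent_torch`, `sdpa_mask` and the `sdpa` entry of `ALL_MASK_ATTENTION_FUNCTIONS`, and restores each only if it still points to the patched function on exit. A generic sketch of that save/patch/restore pattern (names are illustrative, not the actual module attributes):

    import contextlib

    @contextlib.contextmanager
    def patch_attribute(module, name, replacement):
        original = getattr(module, name)
        setattr(module, name, replacement)
        try:
            yield
        finally:
            # restore only if nothing else swapped the attribute in the meantime
            if getattr(module, name) is replacement:
                setattr(module, name, original)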
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -35,9 +35,18 @@ except ImportError:
     from ...ext_test_case import has_transformers
     from ...helpers.torch_helper import is_torchdynamo_exporting
 
+patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
+
+
 if patch_masking_utils:
     # Introduced in 4.52
-    from transformers.masking_utils import causal_mask_function, sdpa_mask
+    from transformers.masking_utils import (
+        causal_mask_function,
+        padding_mask_function,
+        and_masks,
+        _ignore_causal_mask_sdpa,
+        prepare_padding_mask,
+    )
 
     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
@@ -105,7 +114,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
-        mask = sdpa_mask(
+        mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
             cache_position=cache_position,
             kv_length=kv_length,
@@ -125,6 +134,35 @@ if patch_masking_utils:
         mask = (~mask).to(dtype) * min_dtype
         return mask
 
+    def patched_sdpa_mask_recent_torch(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        local_size: Optional[int] = None,
+        allow_is_causal_skip: bool = True,
+        **kwargs,
+    ) -> Optional[torch.Tensor]:
+        """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
+        q_length = cache_position.shape[0]
+        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
+        ):
+            return None
+        kv_arange = torch.arange(kv_length, device=cache_position.device)
+        kv_arange += kv_offset
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        causal_mask = patched__vmap_for_bhqkv(mask_function)(
+            batch_arange, head_arange, cache_position, kv_arange
+        )
+        return causal_mask
+
 
 if patch_parse_processor_args:
@@ -178,6 +216,8 @@ if patch_DynamicLayer:
             new_shape[-2] = 0
         self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
         self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+        if patch_is_initialized:
+            self.is_initialized = True
 
 
 def _patch_make_causal_mask(
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -218,7 +218,6 @@ def unflatten_sliding_window_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> SlidingWindowCache:
     """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    key_cache, value_cache = values
     return make_sliding_window_cache(list(zip(values[0], values[1])))
 