onnx-diagnostic 0.7.8__py3-none-any.whl → 0.7.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +2 -2
  3. onnx_diagnostic/helpers/_log_helper.py +4 -2
  4. onnx_diagnostic/helpers/cache_helper.py +4 -4
  5. onnx_diagnostic/helpers/helper.py +8 -0
  6. onnx_diagnostic/helpers/log_helper.py +7 -1
  7. onnx_diagnostic/helpers/model_builder_helper.py +5 -0
  8. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  9. onnx_diagnostic/helpers/torch_helper.py +14 -4
  10. onnx_diagnostic/reference/ops/op_scan.py +5 -5
  11. onnx_diagnostic/reference/ort_evaluator.py +2 -2
  12. onnx_diagnostic/tasks/automatic_speech_recognition.py +1 -1
  13. onnx_diagnostic/tasks/feature_extraction.py +1 -1
  14. onnx_diagnostic/tasks/fill_mask.py +1 -1
  15. onnx_diagnostic/tasks/image_text_to_text.py +2 -2
  16. onnx_diagnostic/tasks/sentence_similarity.py +1 -1
  17. onnx_diagnostic/tasks/summarization.py +1 -1
  18. onnx_diagnostic/tasks/text2text_generation.py +1 -1
  19. onnx_diagnostic/tasks/text_classification.py +1 -1
  20. onnx_diagnostic/tasks/text_generation.py +1 -1
  21. onnx_diagnostic/tasks/zero_shot_image_classification.py +1 -1
  22. onnx_diagnostic/torch_export_patches/eval/model_cases.py +3 -3
  23. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +98 -4
  24. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
  25. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +37 -2
  26. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +0 -1
  27. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -0
  28. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  29. onnx_diagnostic/torch_models/hghub/model_inputs.py +139 -126
  30. onnx_diagnostic/torch_models/hghub/model_specific.py +49 -0
  31. onnx_diagnostic/torch_models/untrained/llm_phi2.py +11 -3
  32. onnx_diagnostic/torch_models/validate.py +44 -4
  33. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  34. {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/METADATA +2 -2
  35. {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/RECORD +38 -37
  36. {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/top_level.txt +0 -0

onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """

- __version__ = "0.7.8"
+ __version__ = "0.7.10"
  __author__ = "Xavier Dupré"

onnx_diagnostic/_command_lines_parser.py
@@ -850,13 +850,13 @@ def get_parser_agg() -> ArgumentParser:
  "--filter-in",
  default="",
  help="adds a filter to filter in data, syntax is\n"
- '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
+ '``"<column1>:<value1>;<value2>//<column2>:<value3>"`` ...',
  )
  parser.add_argument(
  "--filter-out",
  default="",
  help="adds a filter to filter out data, syntax is\n"
- '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
+ '``"<column1>:<value1>;<value2>//<column2>:<value3>"`` ...',
  )
  parser.add_argument(
  "--sbs",

onnx_diagnostic/helpers/_log_helper.py
@@ -118,9 +118,11 @@ def filter_data(
  if isinstance(fmt, str):
  cols = fmt.split("//")
  for c in cols:
- assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}"
+ assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}, cols={cols!r}"
  spl = c.split(":")
- assert len(spl) == 2, f"Unexpected value {c!r} in fmt={fmt!r}"
+ assert (
+ len(spl) == 2
+ ), f"Unexpected value {c!r} in fmt={fmt!r}, spl={spl}, cols={cols}"
  name, fil = spl
  cond[name] = set(fil.split(";"))
  return cond
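
Both the ``--filter-in``/``--filter-out`` help strings and ``filter_data`` above now agree on ``//`` as the separator between column clauses. A minimal standalone sketch of that syntax (the column and value names are invented; this is not the library function itself):

```python
def parse_filter(fmt: str) -> dict:
    # "//" separates column clauses, ":" splits a column from its values, ";" separates values
    cond = {}
    for clause in fmt.split("//"):
        name, values = clause.split(":")
        cond[name] = set(values.split(";"))
    return cond

print(parse_filter("model:phi2;llama//exporter:onnx-dynamo"))
# -> {'model': {'phi2', 'llama'}, 'exporter': {'onnx-dynamo'}} (set order may vary)
```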

onnx_diagnostic/helpers/cache_helper.py
@@ -270,7 +270,7 @@ def make_static_cache(
  self.num_attention_heads = key_value_pairs[0][0].shape[1]
  self.num_hidden_layers = len(key_value_pairs)

- def get_text_config(self):
+ def get_text_config(self, *args, **kwargs):
  return self

  assert max_cache_len is not None, (
@@ -366,7 +366,7 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
  self.num_hidden_layers = len(key_value_pairs)
  self.dtype = dtype

- def get_text_config(self):
+ def get_text_config(self, *args, **kwargs):
  return self

  cache = MambaCache(
@@ -409,7 +409,7 @@ def make_sliding_window_cache(
  self.num_hidden_layers = len(key_value_pairs)
  self.sliding_window = key_value_pairs[0][0].shape[2]

- def get_text_config(self):
+ def get_text_config(self, *args, **kwargs):
  return self

  cache = transformers.cache_utils.SlidingWindowCache(
@@ -577,7 +577,7 @@ def make_hybrid_cache(
  sliding_window = _sliding_window
  num_key_value_heads = key_value_pairs[0][1].shape[1] # transformers 4.48.3

- def get_text_config(self):
+ def get_text_config(self, *args, **kwargs):
  return self

  if layer_types:

onnx_diagnostic/helpers/helper.py
@@ -774,6 +774,14 @@ def string_type(
  return f"{obj.__class__.__name__}(**{s})"
  if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
  return f"{obj.__class__.__name__}(...)"
+ if obj.__class__.__name__ == "Results":
+ import ultralytics
+
+ assert isinstance(
+ obj, ultralytics.engine.results.Results
+ ), f"Unexpected type={type(obj)}"
+ return f"ultralytics.{obj.__class__.__name__}(...)"
+
  if verbose:
  print(f"[string_type] END:{type(obj)}")
  raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
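
``string_type`` is the helper the package uses to render inputs and outputs compactly, and the new branch makes it print ``ultralytics.Results(...)`` instead of raising on that type. A hedged usage sketch, assuming ``string_type`` is importable from ``onnx_diagnostic.helpers.helper`` (as the file layout in this diff suggests) and that its remaining arguments have defaults:

```python
import torch
from onnx_diagnostic.helpers.helper import string_type  # import path assumed from this diff

# Renders a compact, type-oriented description of a (possibly nested) value.
print(string_type({"input_ids": torch.ones((2, 8), dtype=torch.int64)}))
```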

onnx_diagnostic/helpers/log_helper.py
@@ -1,5 +1,6 @@
  import enum
  import io
+ import os
  import pprint
  import re
  import warnings
@@ -270,6 +271,10 @@ class CubePlot:
  def _to_images_bar(
  self, verbose: int = 0, merge: bool = True, title_suffix: Optional[str] = None
  ) -> List[bytes]:
+ """
+ Environment variable ``FIGSIZEH`` can be set to increase the
+ graph height. Default is 1.0.
+ """
  assert merge, f"merge={merge} not implemented yet"
  import matplotlib.pyplot as plt

@@ -279,7 +284,8 @@ class CubePlot:
  n_cols = 3
  nn = df.shape[1] // n_cols
  nn += int(df.shape[1] % n_cols != 0)
- fig, axs = plt.subplots(nn, n_cols, figsize=(6 * n_cols, nn * df.shape[0] / 5))
+ ratio = float(os.environ.get("FIGSIZEH", "1"))
+ fig, axs = plt.subplots(nn, n_cols, figsize=(6 * n_cols, nn * df.shape[0] / 3 * ratio))
  pos = 0
  imgs = []
  for c in self._make_loop(df.columns, verbose):
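
``FIGSIZEH`` is read from the environment each time ``_to_images_bar`` builds its figure, so the bar-chart height can be scaled per run without touching code. A small sketch (the value ``2`` is arbitrary and must be set before the plot is produced):

```python
import os

# The figure height above is nn * df.shape[0] / 3 * ratio, so FIGSIZEH=2
# roughly doubles the height of each generated bar chart.
os.environ["FIGSIZEH"] = "2"
```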

onnx_diagnostic/helpers/model_builder_helper.py
@@ -201,10 +201,12 @@ def create_model_builder(
  arch_map = {
  "ChatGLMForConditionalGeneration": builder.ChatGLMModel,
  "ChatGLMModel": builder.ChatGLMModel,
+ "Ernie4_5_ForCausalLM": builder.ErnieModel,
  "GemmaForCausalLM": builder.Gemma2Model,
  "Gemma3ForCausalLM": builder.Gemma3Model,
  "Gemma3ForConditionalGeneration": builder.Gemma3Model,
  "GraniteForCausalLM": builder.GraniteModel,
+ "GptOssForCausalLM": builder.GPTOSSModel,
  "LlamaForCausalLM": builder.LlamaModel,
  "MistralForCausalLM": builder.MistralModel,
  "NemotronForCausalLM": builder.NemotronModel,
@@ -235,6 +237,7 @@ def create_model_builder(
  "Phi4MMForCausalLM": builder.Phi4MMModel,
  "Qwen2ForCausalLM": builder.QwenModel,
  "Qwen3ForCausalLM": builder.Qwen3Model,
+ "SmolLM3ForCausalLM": builder.SmolLM3Model,
  }

  assert config.architectures[0] in arch_map, (
@@ -276,6 +279,8 @@ def create_model_builder(
  for key in text_config:
  if not hasattr(config, key):
  setattr(config, key, getattr(text_config, key))
+ elif config.architectures[0] == "GptOssForCausalLM":
+ delattr(config, "quantization_config")
  elif (
  config.architectures[0] == "PhiMoEForCausalLM"
  and config.max_position_embeddings != config.original_max_position_embeddings

onnx_diagnostic/helpers/onnx_helper.py
@@ -1186,7 +1186,7 @@ def shadowing_names(
  shadow |= set(i.name for i in g.input) & shadow_context
  shadow |= set(i.name for i in g.initializer) & shadow_context
  shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
- s, ps, c = shadowing_names(
+ s, _ps, c = shadowing_names(
  g.node, verbose=verbose, existing=existing, shadow_context=existing
  )
  shadow |= s

onnx_diagnostic/helpers/torch_helper.py
@@ -543,7 +543,7 @@ def dummy_llm(
  )

  def forward(self, x):
- B, T, C = x.shape
+ _B, T, C = x.shape

  query = self.query(x)
  key = self.key(x)
@@ -721,9 +721,10 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
  return {to_any(t, to_value) for t in value}
  if type(value) is dict:
  return {k: to_any(t, to_value) for k, t in value.items()}
- if value.__class__.__name__ == "DynamicCache":
+ if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
+ make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
  cc = CacheKeyValue(value)
- return make_dynamic_cache(
+ return make[value.__class__.__name__](  # type: ignore[operator]
  list(
  zip(
  [t.to(to_value) if t is not None else t for t in cc.key_cache],
@@ -822,6 +823,15 @@ def torch_deepcopy(value: Any) -> Any:
  new_args = torch_deepcopy(args)
  return torch.utils._pytree.tree_unflatten(new_args, spec)

+ if value.__class__.__name__ == "Results":
+ import copy
+ import ultralytics
+
+ assert isinstance(
+ value, ultralytics.engine.results.Results
+ ), f"Unexpected type={type(value)}"
+ return copy.deepcopy(value)
+
  # We should have a code using serialization, deserialization assuming a model
  # cannot be exported without them.
  raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
@@ -856,7 +866,7 @@ def torch_tensor_size(value: Any) -> Any:
  if value.__class__.__name__ == "MambaCache":
  return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
  if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
- args, spec = torch.utils._pytree.tree_flatten(value)
+ args, _spec = torch.utils._pytree.tree_flatten(value)
  return sum(torch_tensor_size(a) for a in args)

  # We should have a code using serialization, deserialization assuming a model
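
With this change ``to_any`` converts a ``HybridCache`` through the same path it already used for ``DynamicCache``: unpack the key/value tensors, convert them, and rebuild the cache with the matching ``make_*`` helper. A hedged usage sketch (shapes are arbitrary; the import paths follow the file layout shown in this diff):

```python
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache  # location assumed from this diff
from onnx_diagnostic.helpers.torch_helper import to_any

# One layer of key/value tensors with shape (batch, heads, seq, head_dim).
cache = make_dynamic_cache([(torch.randn(1, 4, 8, 16), torch.randn(1, 4, 8, 16))])
cache_f16 = to_any(cache, torch.float16)  # same structure, tensors cast to float16
```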

onnx_diagnostic/reference/ops/op_scan.py
@@ -26,11 +26,11 @@ class Scan(_Scan):
  ):
  (
  num_loop_state_vars,
- num_scan_outputs,
- output_directions,
- max_dir_out,
- output_axes,
- max_axe_out,
+ _num_scan_outputs,
+ _output_directions,
+ _max_dir_out,
+ _output_axes,
+ _max_axe_out,
  state_names_in,
  state_names_out,
  scan_names_in,

onnx_diagnostic/reference/ort_evaluator.py
@@ -562,7 +562,7 @@ class OnnxruntimeEvaluator:
  if key in self._cache:
  sess = self._cache[key][1]
  else:
- self._cache[key] = onx, sess = self._get_sess_if(node, name, inputs, results)
+ self._cache[key] = _onx, sess = self._get_sess_if(node, name, inputs, results)

  assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
  feeds = {name: results[name] for name in sess.input_names}
@@ -616,7 +616,7 @@ class OnnxruntimeEvaluator:
  if key in self._cache:
  sess = self._cache[key][1]
  else:
- self._cache[key] = onx, sess = self._get_sess_scan(node, name, inputs, results)
+ self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)

  assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
  feeds = {name: results[name] for name in sess.input_names}

onnx_diagnostic/tasks/automatic_speech_recognition.py
@@ -76,7 +76,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length"

  shapes = {

onnx_diagnostic/tasks/feature_extraction.py
@@ -47,7 +47,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "sequence_length"
  shapes = {
  "input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/fill_mask.py
@@ -42,7 +42,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "sequence_length"
  shapes = {
  "input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/image_text_to_text.py
@@ -107,7 +107,7 @@ def _get_inputs_gemma3(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
  # cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)

@@ -230,7 +230,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  batch_img = torch.export.Dim("batch_img", min=1, max=1024)
  seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
  cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)

onnx_diagnostic/tasks/sentence_similarity.py
@@ -42,7 +42,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length"
  shapes = {
  "input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/summarization.py
@@ -70,7 +70,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
  cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096)
  cache_length2 = "cache_length_val" # torch.export.Dim("cache_length2", min=1, max=4096)

onnx_diagnostic/tasks/text2text_generation.py
@@ -72,7 +72,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
  cache_length = "cache_length_key"
  cache_length2 = "cache_length_val"

onnx_diagnostic/tasks/text_classification.py
@@ -42,7 +42,7 @@ def get_inputs(
  assert (
  "cls_cache" not in kwargs
  ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length" # torch.export.Dim("sequence_length", min=1, max=1024)
  shapes = {
  "input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/tasks/text_generation.py
@@ -83,7 +83,7 @@ def get_inputs(
  :class:`transformers.cache_utils.DynamicCache`
  :return: dictionary
  """
- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
  cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)

onnx_diagnostic/tasks/zero_shot_image_classification.py
@@ -65,7 +65,7 @@ def get_inputs(
  input_width, int
  ), f"Unexpected type for input_height {type(input_height)}{config}"

- batch = torch.export.Dim("batch", min=1, max=1024)
+ batch = "batch"
  seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
  shapes = {
  "input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/torch_export_patches/eval/model_cases.py
@@ -384,7 +384,7 @@ class ControlFlowScan(torch.nn.Module):

  def forward(self, x):
  init = torch.zeros_like(x[0])
- carry, out = torch.ops.higher_order.scan(
+ carry, _out = torch.ops.higher_order.scan(
  ControlFlowScan.add, [init], [x], additional_inputs=[]
  )
  return carry
@@ -429,7 +429,7 @@ class ControlFlowScanCDist(torch.nn.Module):
  return [carry.clone(), rd]

  def forward(self, x):
- carry, out = torch.ops.higher_order.scan(
+ _carry, out = torch.ops.higher_order.scan(
  ControlFlowScanCDist.dist,
  [x],
  [x],
@@ -483,7 +483,7 @@ class ControlFlowScanCDistXY(torch.nn.Module):
  return [y.clone(), rd]

  def forward(self, x, y):
- carry, out = torch.ops.higher_order.scan(
+ _carry, out = torch.ops.higher_order.scan(
  ControlFlowScanCDistXY.dist,
  [y],
  [x],

onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -439,6 +439,28 @@ def torch_export_patches(
  f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
  masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv

+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils.sdpa_mask_recent_torch"
+ )
+ f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+ masking_utils.sdpa_mask_recent_torch = (
+ patch_transformers_list.patched_sdpa_mask_recent_torch
+ )
+ if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils.sdpa_mask"
+ )
+ f_transformers_sdpa_mask = masking_utils.sdpa_mask
+ masking_utils.sdpa_mask = (
+ patch_transformers_list.patched_sdpa_mask_recent_torch
+ )
+ else:
+ f_transformers_sdpa_mask = None
+
  if (
  masking_utils
  and patch_transformers_list.patch_masking_utils
@@ -456,10 +478,37 @@ def torch_export_patches(
  and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
  == f_transformers_eager_mask
  ):
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils.eager_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )
  masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
  patch_transformers_list.patched_eager_mask
  )

+ if (
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "sdpa_mask")
+ and f_transformers_sdpa_mask is not None
+ ):
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils.sdpa_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )
+ if (
+ "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+ and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+ == f_transformers_sdpa_mask
+ ):
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+ patch_transformers_list.patched_sdpa_mask_recent_torch
+ )
+
  if custom_patches:
  if verbose:
  print("[torch_export_patches] applies custom patches")
@@ -568,12 +617,31 @@ def torch_export_patches(
  and hasattr(masking_utils, "_vmap_for_bhqkv")
  ):
  masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
  if verbose:
  print(
  "[torch_export_patches] restored "
  "transformers.masking_utils._vmap_for_bhqkv"
  )

+ masking_utils.sdpa_mask_recent_torch = (
+ f_transformers_sdpa_mask_recent_torch
+ )
+
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.sdpa_mask_recent_torch"
+ )
+
+ if f_transformers_sdpa_mask is not None:
+ masking_utils.sdpa_mask = f_transformers_sdpa_mask
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.sdpa_mask"
+ )
+
  if (
  masking_utils
  and patch_transformers_list.patch_masking_utils
@@ -581,6 +649,11 @@
  ):
  f_transformers_eager_mask = masking_utils.eager_mask
  masking_utils.eager_mask = f_transformers_eager_mask
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.eager_mask"
+ )
  if (
  "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
  and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
@@ -589,11 +662,32 @@
  masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
  f_transformers_eager_mask
  )
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.masking_utils.eager_mask"
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.eager_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )
+
+ if (
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "sdpa_mask")
+ ):
+ if (
+ "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+ and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+ == patch_transformers_list.patched_sdpa_mask_recent_torch
+ ):
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+ f_transformers_sdpa_mask
  )
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.sdpa_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )

  ########
  # caches
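
The new ``sdpa_mask`` patching follows the same discipline as the existing patches in this function: remember the original attribute, install the patched callable, and put the original back when the context exits, including the ``ALL_MASK_ATTENTION_FUNCTIONS`` registry entries. A minimal standalone sketch of that save/replace/restore pattern (generic, not the library's implementation):

```python
import contextlib

@contextlib.contextmanager
def swap_attr(owner, name, replacement):
    # Save the current attribute, install the replacement, always restore on exit.
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        setattr(owner, name, original)

# e.g. with swap_attr(masking_utils, "sdpa_mask", patched_sdpa_mask_recent_torch): ...
```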

onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -205,7 +205,10 @@ class patched_ShapeEnv:
  # Precondition: a == tgt
  assert isinstance(a, sympy.Symbol)

- if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt):
+ if (
+ getattr(self, "allow_complex_guards_as_runtime_asserts", False)
+ or getattr(self, "prefer_deferred_runtime_asserts_over_guards", False)
+ ) and not _is_supported_equivalence(tgt):
  # continuing leads to placeholder shapes
  # having complex expressions that we can't resolve
  return

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -37,7 +37,13 @@ from ...helpers.torch_helper import is_torchdynamo_exporting

  if patch_masking_utils:
  # Introduced in 4.52
- from transformers.masking_utils import causal_mask_function, sdpa_mask
+ from transformers.masking_utils import (
+ causal_mask_function,
+ padding_mask_function,
+ and_masks,
+ _ignore_causal_mask_sdpa,
+ prepare_padding_mask,
+ )

  def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
  """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
@@ -105,7 +111,7 @@ if patch_masking_utils:
  """manual patch for function ``transformers.masking_utils.eager_mask``."""
  # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
  _ = kwargs.pop("allow_is_causal_skip", None)
- mask = sdpa_mask(
+ mask = patched_sdpa_mask_recent_torch(
  batch_size=batch_size,
  cache_position=cache_position,
  kv_length=kv_length,
@@ -125,6 +131,35 @@ if patch_masking_utils:
  mask = (~mask).to(dtype) * min_dtype
  return mask

+ def patched_sdpa_mask_recent_torch(
+ batch_size: int,
+ cache_position: torch.Tensor,
+ kv_length: int,
+ kv_offset: int = 0,
+ mask_function: Callable = causal_mask_function,
+ attention_mask: Optional[torch.Tensor] = None,
+ local_size: Optional[int] = None,
+ allow_is_causal_skip: bool = True,
+ **kwargs,
+ ) -> Optional[torch.Tensor]:
+ """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
+ q_length = cache_position.shape[0]
+ padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+ if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+ padding_mask, q_length, kv_length, kv_offset, local_size
+ ):
+ return None
+ kv_arange = torch.arange(kv_length, device=cache_position.device)
+ kv_arange += kv_offset
+ if padding_mask is not None:
+ mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+ batch_arange = torch.arange(batch_size, device=cache_position.device)
+ head_arange = torch.arange(1, device=cache_position.device)
+ causal_mask = patched__vmap_for_bhqkv(mask_function)(
+ batch_arange, head_arange, cache_position, kv_arange
+ )
+ return causal_mask
+

  if patch_parse_processor_args:
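
``patched_sdpa_mask_recent_torch`` builds the boolean mask by evaluating ``mask_function`` on broadcast index grids (batch, head, query position, key/value position) instead of going through ``torch.vmap``, which is presumably what makes it friendlier to export. A toy illustration of that broadcasting idea (not the patched function itself):

```python
import torch

# A mask function of (batch_idx, head_idx, q_idx, kv_idx) evaluated on index grids
# that each occupy one axis yields a (batch, heads, q_len, kv_len) boolean mask
# without loops or vmap.
valid_len = torch.tensor([6, 3])  # per-batch number of valid kv positions

def mask_fn(batch_idx, head_idx, q_idx, kv_idx):
    causal = kv_idx <= q_idx                 # lower-triangular part
    padding = kv_idx < valid_len[batch_idx]  # drop padded kv positions
    return causal & padding

batch = torch.arange(2)[:, None, None, None]
head = torch.arange(1)[None, :, None, None]
q_pos = torch.arange(4)[None, None, :, None]
kv_pos = torch.arange(6)[None, None, None, :]

mask = mask_fn(batch, head, q_pos, kv_pos)  # boolean tensor of shape (2, 1, 4, 6)
```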

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -218,7 +218,6 @@ def unflatten_sliding_window_cache(
  values: List[Any], context: torch.utils._pytree.Context, output_type=None
  ) -> SlidingWindowCache:
  """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
- key_cache, value_cache = values
  return make_sliding_window_cache(list(zip(values[0], values[1])))

onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -11,6 +11,7 @@ __data_arch__ = textwrap.dedent(
  """
  architecture,task
  ASTModel,feature-extraction
+ AutoencoderKL,image-to-image
  AlbertModel,feature-extraction
  BeitForImageClassification,image-classification
  BartForConditionalGeneration,summarization
@@ -154,6 +155,7 @@ __data_arch__ = textwrap.dedent(
  Wav2Vec2ForCTC,automatic-speech-recognition
  YolosForObjectDetection,object-detection
  YolosModel,image-feature-extraction
+ Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
  emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
  )