onnx-diagnostic 0.8.10__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (56)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
  4. onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
  5. onnx_diagnostic/ci_models/data/__init__.py +0 -0
  6. onnx_diagnostic/ci_models/export_phi4_mm.py +10 -7
  7. onnx_diagnostic/export/api.py +13 -4
  8. onnx_diagnostic/export/dynamic_shapes.py +1 -1
  9. onnx_diagnostic/export/validate.py +2 -0
  10. onnx_diagnostic/ext_test_case.py +32 -15
  11. onnx_diagnostic/helpers/args_helper.py +1 -0
  12. onnx_diagnostic/helpers/bench_run.py +0 -1
  13. onnx_diagnostic/helpers/cache_helper.py +102 -36
  14. onnx_diagnostic/helpers/doc_helper.py +7 -4
  15. onnx_diagnostic/helpers/graph_helper.py +6 -6
  16. onnx_diagnostic/helpers/helper.py +39 -0
  17. onnx_diagnostic/helpers/log_helper.py +37 -14
  18. onnx_diagnostic/helpers/memory_peak.py +5 -1
  19. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  20. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  21. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  22. onnx_diagnostic/helpers/ort_session.py +5 -2
  23. onnx_diagnostic/helpers/rt_helper.py +53 -9
  24. onnx_diagnostic/helpers/torch_helper.py +15 -11
  25. onnx_diagnostic/investigate/__init__.py +0 -0
  26. onnx_diagnostic/investigate/input_observer.py +970 -0
  27. onnx_diagnostic/reference/evaluator.py +0 -1
  28. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  29. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  30. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  31. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  32. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  33. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  34. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  35. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
  36. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  39. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
  40. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
  41. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +70 -23
  42. onnx_diagnostic/torch_models/code_sample.py +5 -10
  43. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  44. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
  45. onnx_diagnostic/torch_models/validate.py +1 -1
  46. onnx_diagnostic/torch_onnx/compare.py +0 -1
  47. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  48. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  49. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  50. onnx_diagnostic/typing.py +15 -0
  51. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
  52. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +55 -50
  53. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
  54. onnx_diagnostic/api.py +0 -15
  55. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  56. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
@@ -42,7 +42,6 @@ from .ops.op_slice import Slice_1, Slice_10
  from .ops.op_transpose_cast import Transpose2DCastFP16, Transpose2DCastFP32
  from .ops.op_tri_matrix import TriMatrix

-
  logger = getLogger("onnx-diagnostic-eval")


@@ -34,7 +34,6 @@ from ..helpers.torch_helper import to_tensor
  from .report_results_comparison import ReportResultComparison
  from .evaluator import ExtendedReferenceEvaluator

-
  PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
  Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]

@@ -1,5 +1,4 @@
- from typing import Any, Dict, List, Tuple, Union
-
+ from typing import Any, Dict, List, Set, Tuple, Union

  ReportKeyNameType = Union[str, Tuple[str, int, str]]
  ReportKeyValueType = Tuple[int, Tuple[int, ...]]
@@ -14,6 +13,7 @@ class ReportResultComparison:
  :param tensors: tensor
  """

+ # pyrefly: ignore[unknown-name]
  def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]): # noqa: F821
  from ..helpers.onnx_helper import dtype_to_tensor_dtype
  from ..helpers import max_diff, string_type
@@ -25,7 +25,9 @@ class ReportResultComparison:
  self.max_diff = max_diff
  self.tensors = tensors
  self._build_mapping()
+ self.unique_run_names: Set[str] = set()

+ # pyrefly: ignore[unknown-name]
  def key(self, tensor: "torch.Tensor") -> ReportKeyValueType: # noqa: F821
  "Returns a key for a tensor, (onnx dtype, shape)."
  return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
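
Note: the comparison report keys every tensor by (onnx element type, integer shape), so only dtype- and shape-compatible tensors are ever compared against each other. A quick sketch of the key `key` computes for a float32 tensor of shape (2, 3), using the fact that onnx.TensorProto.FLOAT is 1:

    import onnx
    import torch

    t = torch.zeros(2, 3, dtype=torch.float32)
    # Mirrors ReportResultComparison.key: (onnx dtype, integer shape tuple).
    k = (onnx.TensorProto.FLOAT, tuple(map(int, t.shape)))
    print(k)  # (1, (2, 3))
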
@@ -59,12 +61,15 @@ class ReportResultComparison:
  for k, v in self.value.items():
  (i_run, run_name), ref_name = k
  d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
+ # pyrefly: ignore[no-matching-overload]
  d.update(v)
  rows.append(d)
  return rows

  def report(
- self, outputs: Dict[str, "torch.Tensor"] # noqa: F821
+ self,
+ # pyrefly: ignore[unknown-name]
+ outputs: Dict[str, "torch.Tensor"], # noqa: F821
  ) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
  """
  For every tensor in outputs, compares it to every tensor held by
@@ -79,6 +84,7 @@ class ReportResultComparison:
  key = self.key(tensor)
  if key not in self.mapping:
  continue
+ # pyrefly: ignore[unknown-name]
  cache: Dict["torch.device", "torch.Tensor"] = {} # noqa: F821, UP037
  for held_key in self.mapping[key]:
  t2 = self.tensors[held_key]
@@ -63,7 +63,7 @@ class TorchOnnxEvaluator:
  * `functions`: local functions

  The class is not multithreaded. `runtime_info` gets updated
- by the the class. The list of available kernels is returned by function
+ by the class. The list of available kernels is returned by function
  :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
  Example:

@@ -494,8 +494,10 @@ class TorchOnnxEvaluator:
  r = self.runtime_info[k]
  r.set_value(
  torch_ops.OpRunTensor(
+ # pyrefly: ignore[missing-attribute]
  v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
  is_constant=False,
+ # pyrefly: ignore[missing-attribute]
  may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
  )
  )
@@ -524,6 +526,7 @@ class TorchOnnxEvaluator:
  f"for kernel {type(kernel)}."
  )
  for name, t in zip(kernel.output, res):
+ # pyrefly: ignore[bad-argument-type]
  self.runtime_info[name].set_value(t)
  if self.verbose:
  for name in kernel.output:
@@ -644,6 +647,7 @@ class TorchOnnxEvaluator:
  f"for kernel {type(kernel)}."
  )
  for name, t in zip(kernel.output, res):
+ # pyrefly: ignore[bad-argument-type]
  self.runtime_info[name].set_value(t)
  else:
  assert isinstance(
@@ -1,7 +1,7 @@
  from typing import Any, Dict, List, Optional, Union, Tuple
  import onnx
  import torch
- from ...api import TensorLike
+ from ...typing import TensorLike
  from ...helpers import string_type
  from ...helpers.torch_helper import to_tensor

@@ -149,7 +149,7 @@ class OpRunSequence(OpRunValue):
  ) -> "OpRunSequence":
  "Inserts a value at a given position."
  assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
- new_seq = OpRunSequence()
+ new_seq = OpRunSequence() # type: ignore[abstract]
  seq = self.sequence.copy()
  new_seq.sequence = seq
  if position is None:
@@ -314,9 +314,7 @@ class OpRunKernel:


  class OpRunFunction(OpRunKernel):
- """
- Defines a kernel based on a local functions.
- """
+ """Defines a kernel based on a local functions."""

  def __init__(
  self,
@@ -46,7 +46,7 @@ class SequenceEmpty_11(OpRunOpSequence):
  )

  def run(self) -> OpRunSequence:
- return OpRunSequence(dtype=self.dtype)
+ return OpRunSequence(dtype=self.dtype) # type: ignore[abstract]


  class SequenceInsert_11(OpRunOpSequence):
@@ -3,7 +3,6 @@ import torch
  from ..helpers.config_helper import update_config, check_hasattr
  from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache

-
  __TASK__ = "feature-extraction"


@@ -4,7 +4,6 @@ from .onnx_export_errors import (
  )
  from .patch_module import torch_export_rewrite

-
  # bypass_export_some_errors is the first name given to the patches.
  bypass_export_some_errors = torch_export_patches # type: ignore

@@ -562,6 +562,7 @@ def _patch_transformers(
  "[torch_export_patches] patches "
  "transformers.masking_utils.sdpa_mask_recent_torch"
  )
+
  f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
  masking_utils.sdpa_mask_recent_torch = (
  patch_transformers_list.patched_sdpa_mask_recent_torch
@@ -574,7 +575,9 @@ def _patch_transformers(
  )
  if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
  if verbose:
- print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask")
+ print(
+ "[torch_export_patches] patches transformers.masking_utils.sdpa_mask (1)"
+ )
  f_transformers_sdpa_mask = masking_utils.sdpa_mask
  masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask_recent_torch
  if patch_details:
@@ -583,8 +586,23 @@ def _patch_transformers(
  f_transformers_sdpa_mask,
  patch_transformers_list.patched_sdpa_mask_recent_torch,
  )
- else:
- f_transformers_sdpa_mask = None
+
+ if ( # vmap
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "sdpa_mask")
+ and f_transformers_sdpa_mask is None
+ ):
+ if verbose:
+ print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask (3)")
+ f_transformers_sdpa_mask = masking_utils.sdpa_mask
+ masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask
+ if patch_details:
+ patch_details.append(
+ "transformers",
+ f_transformers_sdpa_mask,
+ patch_transformers_list.patched_sdpa_mask,
+ )

  if ( # eager_mask
  masking_utils
@@ -742,17 +760,17 @@ def _unpatch_transformers(
  "transformers.masking_utils.sdpa_mask_recent_torch"
  )

- if f_transformers_sdpa_mask is not None:
- assert f_transformers_sdpa_mask.__name__ in (
- "sdpa_mask",
- "sdpa_mask_recent_torch",
- ), (
- f"corrupted function 'sdpa_mask', its name is "
- f"{f_transformers_sdpa_mask.__name__!r}"
- )
- masking_utils.sdpa_mask = f_transformers_sdpa_mask
- if verbose:
- print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask")
+ if f_transformers_sdpa_mask is not None:
+ assert f_transformers_sdpa_mask.__name__ in (
+ "sdpa_mask",
+ "sdpa_mask_recent_torch",
+ ), (
+ f"corrupted function 'sdpa_mask', its name is "
+ f"{f_transformers_sdpa_mask.__name__!r}"
+ )
+ masking_utils.sdpa_mask = f_transformers_sdpa_mask
+ if verbose:
+ print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask")

  if ( # eager_mask
  masking_utils
@@ -986,7 +986,7 @@ def torch_export_rewrite(
  name = me.__qualname__
  spl = name.split(".")
  if len(spl) == 1:
- # This a function
+ # This is a function
  module = me.__module__
  if module in me.__globals__:
  mod = me.__globals__[module]
@@ -36,6 +36,26 @@ if patch_masking_utils:
  _ignore_bidirectional_mask_sdpa = None
  bidirectional_mask_function = None

+ try:
+ from transformers.masking_utils import _non_vmap_expansion_sdpa
+ except ImportError:
+
+ def _non_vmap_expansion_sdpa(
+ batch_indices: torch.Tensor,
+ head_indices: torch.Tensor,
+ q_indices: torch.Tensor,
+ kv_indices: torch.Tensor,
+ ):
+ """
+ https://github.com/huggingface/optimum-onnx/blob/
+ c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
+ """
+ batch_indices = batch_indices[:, None, None, None]
+ head_indices = head_indices[None, :, None, None]
+ q_indices = q_indices[None, None, :, None]
+ kv_indices = kv_indices[None, None, None, :]
+ return batch_indices, head_indices, q_indices, kv_indices
+
  def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
  """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
  from ...helpers import string_type
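
The fallback above mirrors optimum-onnx's index expansion: each 1-D index tensor receives its own broadcast axis, so an element-wise mask function applied to the four of them materializes a (batch, heads, q, kv) grid without vmap. A small sketch of what the reshaping buys (shapes are illustrative):

    import torch

    batch = torch.arange(2)   # (2,)
    heads = torch.arange(1)   # (1,)
    q_idx = torch.arange(4)   # (4,)
    kv_idx = torch.arange(6)  # (6,)

    # After expansion: (2,1,1,1), (1,1,1,1), (1,1,4,1), (1,1,1,6).
    b, h, q, kv = (
        batch[:, None, None, None],
        heads[None, :, None, None],
        q_idx[None, None, :, None],
        kv_idx[None, None, None, :],
    )
    # Any element-wise mask function now broadcasts over the full 4D grid.
    causal = q >= kv
    print(causal.shape)  # torch.Size([1, 1, 4, 6]); (B, H, Q, KV) once b and h are used
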
@@ -146,12 +166,13 @@ if patch_masking_utils:
  padding_mask, q_length, kv_length, kv_offset, local_size
  ):
  return None
- if (
- allow_is_bidirectional_skip
- and _ignore_bidirectional_mask_sdpa
- and _ignore_bidirectional_mask_sdpa(padding_mask)
- ):
- return None
+ if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa:
+ # transformers<=5.0: 1 parameter, 3 for transformers>5.0
+ n_parameters = len(inspect.signature(_ignore_bidirectional_mask_sdpa).parameters)
+ if _ignore_bidirectional_mask_sdpa(
+ *[padding_mask, kv_length, kv_offset][:n_parameters]
+ ):
+ return None

  if mask_function is bidirectional_mask_function:
  if padding_mask is not None:
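
The rewritten guard calls `_ignore_bidirectional_mask_sdpa` with however many of `(padding_mask, kv_length, kv_offset)` its signature accepts, so one code path covers both transformers generations. A minimal illustration of that arity-dispatch pattern:

    import inspect

    def old_api(padding_mask):  # one-parameter form (transformers<=5.0 style)
        return padding_mask is None

    def new_api(padding_mask, kv_length, kv_offset):  # three-parameter form
        return padding_mask is None and kv_offset == 0

    for fn in (old_api, new_api):
        n_parameters = len(inspect.signature(fn).parameters)
        # Slice the candidate arguments down to what the callee accepts.
        print(fn(*[None, 8, 0][:n_parameters]))  # True, True
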
@@ -180,3 +201,83 @@ if patch_masking_utils:
  batch_arange, head_arange, cache_position, kv_arange
  )
  return causal_mask
+
+ def patched_sdpa_mask(
+ batch_size: int,
+ cache_position: torch.Tensor,
+ kv_length: int,
+ kv_offset: int = 0,
+ mask_function: Callable = causal_mask_function,
+ attention_mask: torch.Tensor | None = None,
+ local_size: int | None = None,
+ allow_is_causal_skip: bool = True,
+ allow_is_bidirectional_skip: bool = False,
+ allow_torch_fix: bool = True,
+ use_vmap: bool = False,
+ **kwargs,
+ ) -> torch.Tensor | None:
+ """manual patch for function ``transformers.masking_utils.sdpa_mask``."""
+ q_length = cache_position.shape[0]
+
+ # Potentially pad the 2D mask
+ padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
+
+ # Under specific conditions, we can avoid materializing the mask
+ # 1. Causal masks can rely on the `is_causal` argument
+ # 2. Bidirectional do not need any further processing (no bias)
+ if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+ padding_mask, q_length, kv_length, kv_offset, local_size
+ ):
+ return None
+ if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(
+ padding_mask, kv_length, local_size
+ ):
+ return None
+
+ # Potentially add the padding 2D mask
+ if padding_mask is not None:
+ mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+
+ batch_arange = torch.arange(batch_size, device=cache_position.device)
+ head_arange = torch.arange(1, device=cache_position.device)
+ # Similar to `kv_arange = torch.arange(start=kv_offset,
+ # end=kv_offset + kv_length, device=cache_position.device)`
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+ kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
+
+ # Actual mask creation
+ # Option 1: Fast non-vmap mask creation (default)
+ # PATCHED
+ use_vmap = False
+ if not use_vmap:
+ # Apply mask function element-wise through broadcasting
+ attention_mask = mask_function(
+ *_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange)
+ )
+ # Expand the mask to match batch size
+ # and query length if they weren't used in the mask function
+ attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
+
+ # Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
+ # elif _is_torch_greater_or_equal_than_2_6:
+ # This creates the 4D mask easily.
+ # Note that we need this context manager as vmap cannot handle slicing a tensor from
+ # scalar tensor (it internally calls `.item()` which vmap does not allow,
+ # but this context works around it
+ # We don't need to add an offset to the mask_function either,
+ # as we vmap directly the correct indices for k and kv indices
+ # with TransformGetItemToIndex():
+ # attention_mask = _vmap_expansion_sdpa(mask_function)(
+ # batch_arange, head_arange, cache_position, kv_arange
+ # )
+
+ # Option 3: Error out since it indicates that the user did something custom,
+ # which they shouldn't have (torch<2.6)
+ else:
+ raise ValueError(
+ "The vmap functionality for mask creation "
+ "is only supported from torch>=2.6. "
+ "Please update your torch version or use "
+ "`use_vmap=False` with index-based masks."
+ )
+ return attention_mask
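
The patch pins `use_vmap = False`, so the mask is always materialized by broadcasting: aranges are built for batch, head, query and kv positions, expanded via `_non_vmap_expansion_sdpa`, fed to `mask_function`, and the result is expanded to `(batch_size, heads, q_length, kv_length)`. Roughly, for the default causal case (an illustrative sketch, not the patched code itself):

    import torch

    batch_size, q_length, kv_length, kv_offset = 2, 4, 6, 0
    cache_position = torch.arange(q_length)

    batch_arange = torch.arange(batch_size)
    head_arange = torch.arange(1)
    kv_arange = torch.arange(kv_length) + kv_offset  # offset added, no data-dependent slice

    def causal(b, h, q, kv):  # stand-in for mask_function
        return kv <= q

    mask = causal(
        batch_arange[:, None, None, None],
        head_arange[None, :, None, None],
        cache_position[None, None, :, None],
        kv_arange[None, None, None, :],
    ).expand(batch_size, -1, q_length, kv_length)
    print(mask.shape)  # torch.Size([2, 1, 4, 6])
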
@@ -7,10 +7,10 @@ import transformers

  def patched__compute_dynamic_ntk_parameters(
  config: Optional[transformers.PretrainedConfig] = None,
- device: Optional["torch.device"] = None,
+ device: Optional[torch.device] = None,
  seq_len: Optional[int] = None,
  **rope_kwargs,
- ) -> Tuple["torch.Tensor", float]:
+ ) -> Tuple[torch.Tensor, float]:
  """
  manual patch:
  ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
@@ -188,6 +188,11 @@ def patched__broadcast_shapes(*_shapes):
  return common_shape


+ def value_ranges_is_positive(value_ranges: torch.utils._sympy.value_ranges.ValueRanges):
+ """Tells if an interval is equivalent to a positive or null integer."""
+ return value_ranges.lower == 0 and value_ranges.upper > 4623372036854775806
+
+
  class patched_ShapeEnv:

  def _check_frozen(
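
`value_ranges_is_positive` flags a symbolic bound whose interval starts at 0 and whose upper bound exceeds a huge constant, i.e. it is effectively "any non-negative integer". A plain-Python analogue of the check (`ValueRanges` itself is a torch internal; this mock is only illustrative):

    from dataclasses import dataclass

    @dataclass
    class FakeValueRanges:  # stand-in for torch.utils._sympy.value_ranges.ValueRanges
        lower: int
        upper: int

    def is_positive(vr: FakeValueRanges) -> bool:
        # Same test as the patch: lower pinned at 0, upper effectively unbounded.
        return vr.lower == 0 and vr.upper > 4623372036854775806

    print(is_positive(FakeValueRanges(0, 2**63 - 1)))  # True: an unbounded index
    print(is_positive(FakeValueRanges(2, 8)))          # False: a genuinely finite range
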
@@ -281,7 +286,10 @@ class patched_ShapeEnv:
  )
  self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
  tgt_bound = self.bound_sympy(tgt)
- assert tgt_bound.issubset(
+ assert (
+ value_ranges_is_positive(tgt_bound)
+ and value_ranges_is_positive(src_bound)
+ ) or tgt_bound.issubset(
  src_bound
  ), f"{tgt_bound=} not a subset of {src_bound=}"

@@ -524,8 +532,10 @@

  transmute_into_runtime_assert = False

- backed_var_to_val = getattr(
- self, "backed_var_to_val", getattr(self, "var_to_val", {})
+ backed_var_to_val = (
+ self.backed_var_to_val
+ if hasattr(self, "backed_var_to_val")
+ else self.var_to_val
  )
  concrete_val = None
  if not (expr.free_symbols <= backed_var_to_val.keys()):
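
The second change is not just cosmetic: `getattr` with a default evaluates that default eagerly, so the old form always probed `var_to_val` (and silently fell back to `{}` when both attributes were missing). The conditional form touches `var_to_val` only when `backed_var_to_val` is absent and lets a genuinely missing attribute raise. For example:

    class Env:
        var_to_val = {"s0": 4}

    env = Env()

    # Old form: the inner getattr runs even when backed_var_to_val exists,
    # and a missing var_to_val would silently become {}.
    old = getattr(env, "backed_var_to_val", getattr(env, "var_to_val", {}))

    # New form: var_to_val is consulted only when backed_var_to_val is absent,
    # and a missing attribute now raises instead of hiding behind {}.
    new = env.backed_var_to_val if hasattr(env, "backed_var_to_val") else env.var_to_val

    assert old == new == {"s0": 4}
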
@@ -38,6 +38,7 @@ if patch_masking_utils:
  from ._patch_transformers_masking_utils import (
  patched__vmap_for_bhqkv,
  patched_eager_mask,
+ patched_sdpa_mask,
  patched_sdpa_mask_recent_torch,
  )

@@ -1,6 +1,7 @@
  import itertools
  from typing import Any, Callable, List, Set, Tuple
  import torch
+ import transformers.cache_utils
  from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

  try:
@@ -22,22 +23,63 @@ from transformers.modeling_outputs import BaseModelOutput
  from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue
  from . import make_serialization_function_for_dataclass

-
  SUPPORTED_DATACLASSES: Set[type] = set()
+
  WRONG_REGISTRATIONS = {
  DynamicCache: "4.50",
  BaseModelOutput: None,
  }

+ SHORTEN_LAYER_NAMES = {
+ "DynamicLayer": "D",
+ "DynamicSlidingWindowLayer": "W",
+ "StaticLayer": "S",
+ "StaticSlidingWindowLayer": "X",
+ "D": "DynamicLayer",
+ "W": "DynamicSlidingWindowLayer",
+ "S": "StaticLayer",
+ "X": "StaticSlidingWindowLayer",
+ }
+
+ KWARGS_LAYER_NAMES = {
+ "DynamicLayer": lambda layer: "",
+ "DynamicSlidingWindowLayer": lambda layer: str(layer.sliding_window),
+ "StaticLayer": lambda layer: "",
+ "StaticSlidingWindowLayer": lambda layer: str(layer.sliding_window),
+ }
+
+ PARSE_LAYER_NAMES = {
+ "DynamicLayer": lambda skw: {},
+ "DynamicSlidingWindowLayer": lambda skw: dict(sliding_window=int(skw[1:])),
+ "StaticLayer": lambda skw: {},
+ "StaticSlidingWindowLayer": lambda skw: dict(sliding_window=int(skw[1:])),
+ }
+

  def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
  ca = CacheKeyValue(cache)
  flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
- keys = list(
- itertools.chain.from_iterable(
- (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+ unique = set(ca.cls_layers) if ca.cls_layers else None
+ if (
+ cache.__class__.__name__ != "DynamicCache"
+ or unique is None
+ or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
+ ):
+ keys = list(
+ itertools.chain.from_iterable(
+ (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+ )
  )
- )
+ return flat, keys
+
+ keys = []
+ for i in range(len(ca.key_cache)):
+ letter = SHORTEN_LAYER_NAMES[ca.cls_layers[i].__name__]
+ if hasattr(cache, "layers"):
+ kwargs = KWARGS_LAYER_NAMES[ca.cls_layers[i].__name__](cache.layers[i])
+ else:
+ kwargs = ""
+ keys.extend([f"key_{letter}{kwargs}_{i}", f"value_{letter}{kwargs}_{i}"])
  return flat, keys


@@ -55,7 +97,26 @@ def _unflatten_cache(
  output_type=None,
  ) -> DynamicCache:
  """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
- res = make_cache(list(zip(values[::2], values[1::2])))
+ expected = list(
+ itertools.chain.from_iterable(
+ (f"key_{i}", f"value_{i}") for i in range(len(values) // 2)
+ )
+ )
+ if expected == context:
+ res = make_cache(list(zip(values[::2], values[1::2])))
+ else:
+ cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
+ cls_kwargs = [
+ PARSE_LAYER_NAMES[SHORTEN_LAYER_NAMES[name.split("_")[1][0]]](name.split("_")[1])
+ for name in context
+ ][::2]
+ cls_layers = [
+ getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names
+ ]
+ res = make_cache(
+ list(zip(values[::2], values[1::2])), cls_layers=cls_layers, cls_kwargs=cls_kwargs
+ )
+
  assert output_type is None or isinstance(
  res, output_type
  ), f"Type mismatch between {output_type} (expected) and {type(res)}"
@@ -71,14 +132,6 @@ def flatten_dynamic_cache(
  dynamic_cache: DynamicCache,
  ) -> Tuple[List[Any], torch.utils._pytree.Context]:
  """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
- assert (
- not hasattr(dynamic_cache, "layers")
- or not dynamic_cache.layers
- or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
- ), (
- f"The serialization does not work yet on other layers "
- f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
- )
  return _flatten_key_value_cache(dynamic_cache)


@@ -86,14 +139,6 @@ def flatten_with_keys_dynamic_cache(
  dynamic_cache: DynamicCache,
  ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
  """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
- assert (
- not hasattr(dynamic_cache, "layers")
- or not dynamic_cache.layers
- or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
- ), (
- f"The serialization does not work yet on other layers "
- f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
- )
  return _flatten_with_keys_cache(dynamic_cache)


@@ -161,7 +206,9 @@
  ) -> StaticCache:
  """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
  return _unflatten_cache(
- lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
+ lambda *args, **kwargs: make_static_cache(
+ *args, max_cache_len=values[0].shape[2], **kwargs
+ ),
  values,
  context,
  output_type=output_type,
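
The lambda change here is what lets the layer metadata reach `make_static_cache`: the old `lambda *args: ...` dropped any keyword arguments `_unflatten_cache` now passes (such as `cls_layers`/`cls_kwargs`), while the new one forwards them. A toy version of the difference:

    def make(*pairs, max_cache_len=None, **kwargs):
        return pairs, max_cache_len, kwargs

    old = lambda *args: make(*args, max_cache_len=8)              # swallows keywords
    new = lambda *args, **kw: make(*args, max_cache_len=8, **kw)  # forwards them

    print(new(("k", "v"), cls_layers=["DynamicLayer"]))
    # ((('k', 'v'),), 8, {'cls_layers': ['DynamicLayer']})
    # old(("k", "v"), cls_layers=[...]) would raise TypeError instead.
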
@@ -8,11 +8,9 @@ from .hghub.model_inputs import _preprocess_model_id
  from .hghub import get_untrained_model_with_inputs
  from .validate import filter_inputs, make_patch_kwargs

-
  CODE_SAMPLES = {
  "imports": "from typing import Any\nimport torch",
- "get_model_with_inputs": textwrap.dedent(
- """
+ "get_model_with_inputs": textwrap.dedent("""
  def get_model_with_inputs(
  model_id:str,
  subfolder: str | None = None,
@@ -57,8 +55,7 @@ CODE_SAMPLES = {
  if device:
  data["model"] = data["model"].to(device)
  return data["model"]
- """
- ),
+ """),
  }


@@ -198,7 +195,7 @@ def code_sample(
  this is not always possible
  :param use_pretrained: use the trained model, not the untrained one
  :param optimization: optimization to apply to the exported model,
- depend on the the exporter
+ depend on the exporter
  :param quiet: if quiet, catches exception if any issue
  :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
  if True before exporting
@@ -326,11 +323,9 @@ def code_sample(
  imports,
  cache_import,
  CODE_SAMPLES["get_model_with_inputs"],
- textwrap.dedent(
- f"""
+ textwrap.dedent(f"""
  model = get_model_with_inputs({model_args})
- """
- ),
+ """),
  f"inputs = {input_code}",
  exporter_code,
  ]
@@ -10,8 +10,7 @@ __data_arch_values__ = {
  "ResNetForImageClassification": dict(image_size=224),
  }

- __data_arch__ = textwrap.dedent(
- """
+ __data_arch__ = textwrap.dedent("""
  architecture,task
  ASTModel,feature-extraction
  AutoencoderKL,image-to-image
@@ -166,8 +165,7 @@ __data_arch__ = textwrap.dedent(
  YolosModel,image-feature-extraction
  Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
  emilyalsentzer/Bio_ClinicalBERT,fill-mask
- nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video"""
- )
+ nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video""")

  __data_tasks__ = [
  "audio-classification",