onnx-diagnostic 0.8.10__py3-none-any.whl → 0.8.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
  4. onnx_diagnostic/export/api.py +2 -4
  5. onnx_diagnostic/export/validate.py +2 -0
  6. onnx_diagnostic/ext_test_case.py +32 -15
  7. onnx_diagnostic/helpers/args_helper.py +1 -0
  8. onnx_diagnostic/helpers/bench_run.py +0 -1
  9. onnx_diagnostic/helpers/cache_helper.py +6 -6
  10. onnx_diagnostic/helpers/doc_helper.py +7 -4
  11. onnx_diagnostic/helpers/graph_helper.py +6 -6
  12. onnx_diagnostic/helpers/log_helper.py +37 -14
  13. onnx_diagnostic/helpers/memory_peak.py +5 -1
  14. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  15. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  16. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  17. onnx_diagnostic/helpers/ort_session.py +0 -1
  18. onnx_diagnostic/helpers/torch_helper.py +8 -9
  19. onnx_diagnostic/investigate/__init__.py +0 -0
  20. onnx_diagnostic/investigate/input_observer.py +329 -0
  21. onnx_diagnostic/reference/evaluator.py +0 -1
  22. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  23. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  24. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  25. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  26. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  27. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  28. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  29. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
  32. onnx_diagnostic/torch_models/code_sample.py +5 -10
  33. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  34. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
  35. onnx_diagnostic/torch_models/validate.py +1 -1
  36. onnx_diagnostic/torch_onnx/compare.py +0 -1
  37. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  38. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  39. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  40. onnx_diagnostic/typing.py +15 -0
  41. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
  42. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +45 -43
  43. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
  44. onnx_diagnostic/api.py +0 -15
  45. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,329 @@
+ import contextlib
+ import inspect
+ from typing import Any, Callable, Sequence
+ import torch
+
+
+ def flatten_unflatten_for_dynamic_shapes(
+     obj: Any,
+     use_dict: bool = True,
+     change_function: Callable[[torch.Tensor], Any] | None = None,
+ ) -> Any:
+     """
+     Returns the object in a different structure, similar to what
+     the definition of the dynamic shapes should use.
+
+     Args:
+         obj:
+             object from a custom class
+         use_dict:
+             closer to the original result, but
+             :func:`torch.export.export` only considers the values;
+             the context gives the dictionary keys but it is not expressed
+             in the dynamic shapes, and these specifications seem to differ
+             between strict and non-strict mode. It also preserves tuples.
+         change_function:
+             a function to modify the tensors in the structure itself,
+             e.g., to replace them with their shapes
+
+     Returns:
+         the serialized object
+     """
+     if isinstance(obj, torch.Tensor):
+         return change_function(obj) if change_function else obj
+     flat, spec = torch.utils._pytree.tree_flatten(obj)
+     start = 0
+     end = 0
+     subtrees = []
+     for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
+         end += subspec.num_leaves
+         value = subspec.unflatten(flat[start:end])
+         value = flatten_unflatten_for_dynamic_shapes(
+             value, use_dict=use_dict, change_function=change_function
+         )
+         subtrees.append(value)
+         start = end
+     if use_dict:
+         if spec.type is dict:
+             # This is a dictionary.
+             return dict(zip(spec.context, subtrees))
+         if spec.type is tuple:
+             return tuple(subtrees)
+         if spec.type is list:
+             return list(subtrees)
+         if spec.type is None and not subtrees:
+             return None
+         if spec.context:
+             # This is a custom class with attributes.
+             # It is returned as a list.
+             return list(subtrees)
+         raise ValueError(
+             f"Unable to interpret spec type {spec.type} "
+             f"(type is {type(spec.type)}, context is {spec.context}), "
+             f"spec={spec}, subtrees={subtrees}"
+         )
+     # This is a list.
+     return subtrees
+
+
+ def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int]:
+     """
+     Returns the list of dynamic dimensions given a list of shapes
+     corresponding to the same tensor.
+
+     Args:
+         shape_list:
+             list of shapes; they must all have the same rank
+
+     Returns:
+         list of dynamic dimensions
+     """
+     unique_ranks = {len(shape) for shape in shape_list}
+     torch._check(
+         len(unique_ranks) == 1, lambda: "all shapes in shape_list must have the same rank"
+     )
+     rank = unique_ranks.pop()
+     dynamic = []
+     for i in range(rank):
+         dims = [shape[i] for shape in shape_list]
+         if len(set(dims)) > 1:
+             dynamic.append(i)
+     return dynamic
+
+
+ class InputObserverInfo:
+     def __init__(self, signature: inspect.Signature):
+         # pyrefly: ignore
+         self.inputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
+         self.flat_inputs: list[list[torch.Tensor | None]] = []
+         # pyrefly: ignore
+         self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
+         self.flat_outputs: list[torch.Tensor | list[torch.Tensor]] = []
+         self.signature = signature
+
+         self._max_args: tuple[Any, torch.Tensor] | None = None
+         self._max_kwargs: dict[str, torch.Tensor] | None = None
+
+     def __len__(self) -> int:
+         return len(self.flat_inputs)
+
+     def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
+         kwargs = {
+             k: v
+             for k, v in kwargs.items()
+             if v is not None and not isinstance(v, (int, float, bool))
+         }
+         flat_args, spec = torch.utils._pytree.tree_flatten((args, kwargs))
+         self.inputs_specs.append(spec)
+         cloned = [
+             (None if not isinstance(t, torch.Tensor) else t.clone().detach())
+             for t in flat_args
+         ]
+         self.flat_inputs.append(cloned)
+
+         cloned_args, cloned_kwargs = torch.utils._pytree.tree_unflatten(cloned, spec)
+         if self._max_args is None or len(cloned_args) > len(self._max_args):
+             self._max_args = cloned_args
+         if self._max_kwargs is None or len(cloned_kwargs) > len(self._max_kwargs):
+             self._max_kwargs = cloned_kwargs
+
+     def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
+         flat_res, spec = torch.utils._pytree.tree_flatten(res)
+         self.outputs_specs.append(spec)
+         self.flat_outputs.append([t.clone().detach() for t in flat_res])
+
+     def build_inputs_completed_with_none_values(self) -> list[list[torch.Tensor]]:
+         # Let's compute the sizes of each argument independently.
+         if not self.flat_inputs or self._max_args is None or self._max_kwargs is None:
+             raise RuntimeError("No inputs were captured.")
+         arg_sizes = [len(torch.utils._pytree.tree_flatten(a)[0]) for a in self._max_args]
+         kwarg_sizes = {
+             k: len(torch.utils._pytree.tree_flatten(v)[0]) for k, v in self._max_kwargs.items()
+         }
+
+         # Let's reprocess everything.
+         captured_inputs: dict[int | str, int] = {}
+         new_flat_inputs = []
+         for args_kwargs, spec in zip(self.flat_inputs, self.inputs_specs):
+             args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
+             if len(set(kwargs) | set(self._max_kwargs)) > len(self._max_kwargs):
+                 raise RuntimeError(
+                     "At least one call to the observed model "
+                     "must contain all the named arguments."
+                 )
+             flat = []
+             for i in range(len(self._max_args)):
+                 if i < len(args):
+                     ts = torch.utils._pytree.tree_flatten(args[i])[0]
+                     if i in captured_inputs and captured_inputs[i] != len(ts):
+                         raise RuntimeError(
+                             f"Positional argument {i} has {len(ts)} tensors "
+                             f"but previously got {captured_inputs[i]} tensors. "
+                             f"Inference is impossible in that case."
+                         )
+                     captured_inputs[i] = len(ts)
+                     flat.extend(ts)
+                 else:
+                     flat.extend([None for _ in range(arg_sizes[i])])
+             for k in self._max_kwargs:
+                 if k in kwargs:
+                     ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
+                     if k in captured_inputs and captured_inputs[k] != len(ts):
+                         raise RuntimeError(
+                             f"Named argument {k!r} has {len(ts)} tensors "
+                             f"but previously got {captured_inputs[k]} tensors. "
+                             f"Inference is impossible in that case."
+                         )
+                     captured_inputs[k] = len(ts)
+                     flat.extend(ts)
+                 else:
+                     flat.extend([None for _ in range(kwarg_sizes[k])])
+             new_flat_inputs.append(flat)
+         return new_flat_inputs
+
+     def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+         flat_inputs = self.build_inputs_completed_with_none_values()
+         # This is already checked by build_inputs_completed_with_none_values
+         # but this is not always well captured by tools checking types.
+         assert self._max_args is not None and self._max_kwargs is not None
+         if len({len(flat) for flat in flat_inputs}) != 1:
+             raise NotImplementedError(
+                 "infer_dynamic_shapes is not implemented "
+                 "when the number of input tensors is not the same."
+             )
+         shape_lists = [
+             [(None if t is None else t.shape) for t in tensors] for tensors in flat_inputs
+         ]
+         n_tensors = len(shape_lists[0])
+         dynamic_shapes = [
+             infer_dynamic_dimensions(
+                 [s for s in [shapes[index] for shapes in shape_lists] if s is not None]
+             )
+             for index in range(n_tensors)
+         ]
+         cst = torch.export.Dim.DYNAMIC
+         flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
+         if len(flat_dynamic_shapes) == len(self._max_args) + len(self._max_kwargs):
+             # It means the forward method is called with tensors only.
+             if not self._max_kwargs:
+                 # only positional arguments
+                 return tuple(flat_dynamic_shapes)
+             if not self._max_args:
+                 # only named arguments
+                 return dict(zip(list(self._max_kwargs), flat_dynamic_shapes))
+             # positional arguments need to be moved to the named arguments
+             n_args = len(self._max_args)
+             pos_names = list(self.signature.parameters)[:n_args]
+             return {
+                 **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
+                 **dict(zip(list(self._max_kwargs), flat_dynamic_shapes[n_args:])),
+             }
+
+         # Nested types: this is the tricky part because the shapes cannot be unflattened;
+         # custom classes must appear in their flattened form.
+         # This does not work in all cases, only when every available argument is flattened
+         # into the same number of tensors every time. The function does not check
+         # whether that assumption holds.
+         flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
+             (self._max_args, self._max_kwargs)
+         )
+         torch._check(
+             len(flat_inputs) == len(flat_dynamic_shapes),
+             (
+                 f"Length mismatch len(flat_inputs)={len(flat_inputs)}, "
+                 f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
+             ),
+         )
+         mapping = {id(t): shape for t, shape in zip(flat_inputs, flat_dynamic_shapes)}
+         ds_args, ds_kwargs = flatten_unflatten_for_dynamic_shapes(
+             (self._max_args, self._max_kwargs), change_function=lambda t: mapping[id(t)]
+         )
+         if not ds_kwargs:
+             return tuple(ds_args)
+         if not ds_args:
+             return tuple(ds_kwargs)
+         pos_names = list(self.signature.parameters)[: len(ds_args)]
+         return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
+
+     def infer_arguments(
+         self, index: int | None = None
+     ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+         # This is already checked by build_inputs_completed_with_none_values
+         # but this is not always well captured by tools checking types.
+         assert self._max_args is not None and self._max_kwargs is not None
+         candidate = None
+         if index is None:
+             for i, (args_kwargs, spec) in enumerate(zip(self.flat_inputs, self.inputs_specs)):
+                 args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
+                 if len(args) == len(self._max_args) and len(kwargs) == len(self._max_kwargs):
+                     index = i
+                     candidate = args, kwargs
+                     break
+         if index is not None:
+             # found one available set.
+             args, kwargs = candidate or torch.utils._pytree.tree_unflatten(
+                 self.flat_inputs[index], self.inputs_specs[index]
+             )
+             if not kwargs:
+                 return args
+             if not args:
+                 return kwargs
+             # We need to move args to kwargs.
+             pos_names = list(self.signature.parameters)[: len(args)]
+             return {**dict(zip(pos_names, args)), **kwargs}
+
+         raise NotImplementedError(
+             "We could not find a good set of inputs/outputs. "
+             "We need to replace None values with empty tensors."
+         )
+
+
+ class InputObserver:
+     def __init__(self, store_n_calls: int = 3):
+         self.store_n_calls = store_n_calls
+         self.info: InputObserverInfo | None = None
+
+     def _forward_captured(self, *args, _captured_forward=None, **kwargs):
+         assert _captured_forward is not None, "_captured_forward cannot be None"
+         assert self.info is not None, "info cannot be None"
+         n_stored = len(self.info)
+         if n_stored < self.store_n_calls:
+             self.info.add_inputs(args, kwargs)
+         res = _captured_forward(*args, **kwargs)
+         if n_stored < self.store_n_calls:
+             self.info.add_outputs(res)
+         return res
+
+     @contextlib.contextmanager
+     def __call__(self, model: torch.nn.Module):
+         if self.info is not None:
+             raise RuntimeError(
+                 "This class was already used to capture a model. Please create a new one."
+             )
+         self.info = InputObserverInfo(signature=inspect.signature(model.forward))
+         forward_method = model.forward
+         model.forward = (
+             lambda *args, _captured_forward=forward_method, **kwargs: self._forward_captured(
+                 *args, _captured_forward=_captured_forward, **kwargs
+             )
+         )
+         try:
+             yield self
+         finally:
+             model.forward = forward_method
+
+     def _check_captured(self):
+         if self.info is None:
+             raise RuntimeError("No inputs were captured.")
+
+     def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+         self._check_captured()
+         assert self.info is not None  # missed by type checking
+         return self.info.infer_dynamic_shapes()
+
+     def infer_arguments(
+         self, index: int | None = None
+     ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+         self._check_captured()
+         assert self.info is not None  # missed by type checking
+         return self.info.infer_arguments(index=index)
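
The new onnx_diagnostic/investigate/input_observer.py is the main addition in 0.8.11: it temporarily wraps a model's forward, records a few calls, and infers dynamic shapes and a set of example arguments for torch.export.export. A minimal usage sketch (TinyModel and the shapes below are hypothetical; the InputObserver API is the one added above):

    import torch
    from onnx_diagnostic.investigate.input_observer import InputObserver

    class TinyModel(torch.nn.Module):
        def forward(self, x, mask=None):
            return x * 2 if mask is None else x * 2 + mask

    model = TinyModel()
    observer = InputObserver(store_n_calls=3)
    with observer(model):
        # Call the model with varying batch sizes so the observer can
        # tell dynamic dimensions apart from static ones.
        model(torch.randn(2, 8), mask=torch.randn(2, 8))
        model(torch.randn(3, 8), mask=torch.randn(3, 8))

    # Dimension 0 changed between calls, so it should come back as dynamic,
    # e.g. {"x": {0: Dim.DYNAMIC}, "mask": {0: Dim.DYNAMIC}}.
    ds = observer.infer_dynamic_shapes()
    kwargs = observer.infer_arguments()  # one captured set of inputs
    ep = torch.export.export(model, (), kwargs=kwargs, dynamic_shapes=ds)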
@@ -42,7 +42,6 @@ from .ops.op_slice import Slice_1, Slice_10
  from .ops.op_transpose_cast import Transpose2DCastFP16, Transpose2DCastFP32
  from .ops.op_tri_matrix import TriMatrix

-
  logger = getLogger("onnx-diagnostic-eval")


@@ -34,7 +34,6 @@ from ..helpers.torch_helper import to_tensor
  from .report_results_comparison import ReportResultComparison
  from .evaluator import ExtendedReferenceEvaluator

-
  PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
  Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]

@@ -1,5 +1,4 @@
- from typing import Any, Dict, List, Tuple, Union
-
+ from typing import Any, Dict, List, Set, Tuple, Union

  ReportKeyNameType = Union[str, Tuple[str, int, str]]
  ReportKeyValueType = Tuple[int, Tuple[int, ...]]
@@ -14,6 +13,7 @@ class ReportResultComparison:
      :param tensors: tensor
      """

+     # pyrefly: ignore[unknown-name]
      def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]):  # noqa: F821
          from ..helpers.onnx_helper import dtype_to_tensor_dtype
          from ..helpers import max_diff, string_type
@@ -25,7 +25,9 @@ class ReportResultComparison:
          self.max_diff = max_diff
          self.tensors = tensors
          self._build_mapping()
+         self.unique_run_names: Set[str] = set()

+     # pyrefly: ignore[unknown-name]
      def key(self, tensor: "torch.Tensor") -> ReportKeyValueType:  # noqa: F821
          "Returns a key for a tensor, (onnx dtype, shape)."
          return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
@@ -59,12 +61,15 @@ class ReportResultComparison:
          for k, v in self.value.items():
              (i_run, run_name), ref_name = k
              d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
+             # pyrefly: ignore[no-matching-overload]
              d.update(v)
              rows.append(d)
          return rows

      def report(
-         self, outputs: Dict[str, "torch.Tensor"]  # noqa: F821
+         self,
+         # pyrefly: ignore[unknown-name]
+         outputs: Dict[str, "torch.Tensor"],  # noqa: F821
      ) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
          """
          For every tensor in outputs, compares it to every tensor held by
@@ -79,6 +84,7 @@ class ReportResultComparison:
              key = self.key(tensor)
              if key not in self.mapping:
                  continue
+             # pyrefly: ignore[unknown-name]
              cache: Dict["torch.device", "torch.Tensor"] = {}  # noqa: F821, UP037
              for held_key in self.mapping[key]:
                  t2 = self.tensors[held_key]
@@ -63,7 +63,7 @@ class TorchOnnxEvaluator:
      * `functions`: local functions

      The class is not multithreaded. `runtime_info` gets updated
-     by the the class. The list of available kernels is returned by function
+     by the class. The list of available kernels is returned by function
      :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
      Example:

@@ -494,8 +494,10 @@ class TorchOnnxEvaluator:
              r = self.runtime_info[k]
              r.set_value(
                  torch_ops.OpRunTensor(
+                     # pyrefly: ignore[missing-attribute]
                      v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                      is_constant=False,
+                     # pyrefly: ignore[missing-attribute]
                      may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                  )
              )
@@ -524,6 +526,7 @@ class TorchOnnxEvaluator:
                  f"for kernel {type(kernel)}."
              )
              for name, t in zip(kernel.output, res):
+                 # pyrefly: ignore[bad-argument-type]
                  self.runtime_info[name].set_value(t)
              if self.verbose:
                  for name in kernel.output:
@@ -644,6 +647,7 @@ class TorchOnnxEvaluator:
                  f"for kernel {type(kernel)}."
              )
              for name, t in zip(kernel.output, res):
+                 # pyrefly: ignore[bad-argument-type]
                  self.runtime_info[name].set_value(t)
          else:
              assert isinstance(
@@ -1,7 +1,7 @@
  from typing import Any, Dict, List, Optional, Union, Tuple
  import onnx
  import torch
- from ...api import TensorLike
+ from ...typing import TensorLike
  from ...helpers import string_type
  from ...helpers.torch_helper import to_tensor

@@ -149,7 +149,7 @@ class OpRunSequence(OpRunValue):
      ) -> "OpRunSequence":
          "Inserts a value at a given position."
          assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
-         new_seq = OpRunSequence()
+         new_seq = OpRunSequence()  # type: ignore[abstract]
          seq = self.sequence.copy()
          new_seq.sequence = seq
          if position is None:
@@ -314,9 +314,7 @@ class OpRunKernel:


  class OpRunFunction(OpRunKernel):
-     """
-     Defines a kernel based on a local functions.
-     """
+     """Defines a kernel based on a local function."""

      def __init__(
          self,
@@ -46,7 +46,7 @@ class SequenceEmpty_11(OpRunOpSequence):
      )

      def run(self) -> OpRunSequence:
-         return OpRunSequence(dtype=self.dtype)
+         return OpRunSequence(dtype=self.dtype)  # type: ignore[abstract]


  class SequenceInsert_11(OpRunOpSequence):
@@ -3,7 +3,6 @@ import torch
  from ..helpers.config_helper import update_config, check_hasattr
  from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache

-
  __TASK__ = "feature-extraction"


@@ -4,7 +4,6 @@ from .onnx_export_errors import (
  )
  from .patch_module import torch_export_rewrite

-
  # bypass_export_some_errors is the first name given to the patches.
  bypass_export_some_errors = torch_export_patches  # type: ignore

@@ -986,7 +986,7 @@ def torch_export_rewrite(
      name = me.__qualname__
      spl = name.split(".")
      if len(spl) == 1:
-         # This a function
+         # This is a function
          module = me.__module__
          if module in me.__globals__:
              mod = me.__globals__[module]
@@ -7,10 +7,10 @@ import transformers

  def patched__compute_dynamic_ntk_parameters(
      config: Optional[transformers.PretrainedConfig] = None,
-     device: Optional["torch.device"] = None,
+     device: Optional[torch.device] = None,
      seq_len: Optional[int] = None,
      **rope_kwargs,
- ) -> Tuple["torch.Tensor", float]:
+ ) -> Tuple[torch.Tensor, float]:
      """
      manual patch:
      ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
@@ -1,6 +1,7 @@
  import itertools
  from typing import Any, Callable, List, Set, Tuple
  import torch
+ import transformers.cache_utils
  from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

  try:
@@ -22,22 +23,43 @@ from transformers.modeling_outputs import BaseModelOutput
  from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue
  from . import make_serialization_function_for_dataclass

-
  SUPPORTED_DATACLASSES: Set[type] = set()
  WRONG_REGISTRATIONS = {
      DynamicCache: "4.50",
      BaseModelOutput: None,
  }
+ SHORTEN_LAYER_NAMES = {
+     "DynamicLayer": "D",
+     "DynamicSlidingWindowLayer": "W",
+     "StaticLayer": "S",
+     "StaticSlidingWindowLayer": "X",
+     "D": "DynamicLayer",
+     "W": "DynamicSlidingWindowLayer",
+     "S": "StaticLayer",
+     "X": "StaticSlidingWindowLayer",
+ }


  def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
      ca = CacheKeyValue(cache)
      flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
-     keys = list(
-         itertools.chain.from_iterable(
-             (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+     unique = set(ca.cls_layers) if ca.cls_layers else None
+     if (
+         cache.__class__.__name__ != "DynamicCache"
+         or unique is None
+         or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
+     ):
+         keys = list(
+             itertools.chain.from_iterable(
+                 (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+             )
          )
-     )
+         return flat, keys
+
+     keys = []
+     for i in range(len(ca.key_cache)):
+         letter = SHORTEN_LAYER_NAMES[ca.cls_layers[i].__name__]
+         keys.extend([f"key_{letter}{i}", f"value_{letter}{i}"])
      return flat, keys


@@ -55,7 +77,20 @@ def _unflatten_cache(
      output_type=None,
  ) -> DynamicCache:
      """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-     res = make_cache(list(zip(values[::2], values[1::2])))
+     expected = list(
+         itertools.chain.from_iterable(
+             (f"key_{i}", f"value_{i}") for i in range(len(values) // 2)
+         )
+     )
+     if expected == context:
+         res = make_cache(list(zip(values[::2], values[1::2])))
+     else:
+         cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
+         cls_layers = [
+             getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names
+         ]
+         res = make_cache(list(zip(values[::2], values[1::2])), cls_layers=cls_layers)
+
      assert output_type is None or isinstance(
          res, output_type
      ), f"Type mismatch between {output_type} (expected) and {type(res)}"
@@ -71,14 +106,6 @@ def flatten_dynamic_cache(
      dynamic_cache: DynamicCache,
  ) -> Tuple[List[Any], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-     assert (
-         not hasattr(dynamic_cache, "layers")
-         or not dynamic_cache.layers
-         or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
-     ), (
-         f"The serialization does not work yet on other layers "
-         f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
-     )
      return _flatten_key_value_cache(dynamic_cache)


@@ -86,14 +113,6 @@ def flatten_with_keys_dynamic_cache(
      dynamic_cache: DynamicCache,
  ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-     assert (
-         not hasattr(dynamic_cache, "layers")
-         or not dynamic_cache.layers
-         or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
-     ), (
-         f"The serialization does not work yet on other layers "
-         f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
-     )
      return _flatten_with_keys_cache(dynamic_cache)


@@ -161,7 +180,9 @@
  ) -> StaticCache:
      """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
      return _unflatten_cache(
-         lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
+         lambda *args, **kwargs: make_static_cache(
+             *args, max_cache_len=values[0].shape[2], **kwargs
+         ),
          values,
          context,
          output_type=output_type,
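
The serialization changes above let a DynamicCache mixing layer types round-trip through the pytree context: _flatten_key_value_cache encodes each layer class as a single letter in the context keys, and _unflatten_cache decodes it back. A standalone sketch of the naming scheme (plain Python, no transformers required; the layer list is hypothetical):

    # Encoding, as in _flatten_key_value_cache: one letter per layer class.
    SHORTEN = {
        "DynamicLayer": "D", "DynamicSlidingWindowLayer": "W",
        "D": "DynamicLayer", "W": "DynamicSlidingWindowLayer",
    }
    layer_classes = ["DynamicLayer", "DynamicSlidingWindowLayer"]
    keys = []
    for i, cls_name in enumerate(layer_classes):
        letter = SHORTEN[cls_name]
        keys.extend([f"key_{letter}{i}", f"value_{letter}{i}"])
    print(keys)  # ['key_D0', 'value_D0', 'key_W1', 'value_W1']

    # Decoding, as in _unflatten_cache: "key_W1".split("_")[1][0] -> "W".
    decoded = [SHORTEN[name.split("_")[1][0]] for name in keys][::2]
    print(decoded)  # ['DynamicLayer', 'DynamicSlidingWindowLayer']

A cache whose layers are all DynamicLayer keeps the plain key_0/value_0 names, so previously serialized contexts remain readable.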
@@ -8,11 +8,9 @@ from .hghub.model_inputs import _preprocess_model_id
  from .hghub import get_untrained_model_with_inputs
  from .validate import filter_inputs, make_patch_kwargs

-
  CODE_SAMPLES = {
      "imports": "from typing import Any\nimport torch",
-     "get_model_with_inputs": textwrap.dedent(
-         """
+     "get_model_with_inputs": textwrap.dedent("""
          def get_model_with_inputs(
              model_id:str,
              subfolder: str | None = None,
@@ -57,8 +55,7 @@ CODE_SAMPLES = {
              if device:
                  data["model"] = data["model"].to(device)
              return data["model"]
-         """
-     ),
+         """),
  }


@@ -198,7 +195,7 @@ def code_sample(
          this is not always possible
      :param use_pretrained: use the trained model, not the untrained one
      :param optimization: optimization to apply to the exported model,
-         depend on the the exporter
+         depends on the exporter
      :param quiet: if quiet, catches exception if any issue
      :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
          if True before exporting
@@ -326,11 +323,9 @@ def code_sample(
          imports,
          cache_import,
          CODE_SAMPLES["get_model_with_inputs"],
-         textwrap.dedent(
-             f"""
+         textwrap.dedent(f"""
              model = get_model_with_inputs({model_args})
-             """
-         ),
+             """),
          f"inputs = {input_code}",
          exporter_code,
      ]