onnx-diagnostic 0.8.9__py3-none-any.whl → 0.8.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
  4. onnx_diagnostic/export/api.py +24 -12
  5. onnx_diagnostic/export/validate.py +2 -0
  6. onnx_diagnostic/ext_test_case.py +32 -15
  7. onnx_diagnostic/helpers/args_helper.py +1 -0
  8. onnx_diagnostic/helpers/bench_run.py +0 -1
  9. onnx_diagnostic/helpers/cache_helper.py +6 -6
  10. onnx_diagnostic/helpers/doc_helper.py +7 -4
  11. onnx_diagnostic/helpers/graph_helper.py +6 -6
  12. onnx_diagnostic/helpers/log_helper.py +37 -14
  13. onnx_diagnostic/helpers/memory_peak.py +5 -1
  14. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  15. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  16. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  17. onnx_diagnostic/helpers/ort_session.py +0 -1
  18. onnx_diagnostic/helpers/torch_helper.py +8 -9
  19. onnx_diagnostic/investigate/__init__.py +0 -0
  20. onnx_diagnostic/investigate/input_observer.py +329 -0
  21. onnx_diagnostic/reference/evaluator.py +0 -1
  22. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  23. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  24. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  25. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  26. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  27. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  28. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  29. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +5 -1
  30. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  32. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +14 -13
  33. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
  34. onnx_diagnostic/torch_models/code_sample.py +5 -10
  35. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +7 -12
  37. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -0
  38. onnx_diagnostic/torch_models/validate.py +1 -1
  39. onnx_diagnostic/torch_onnx/compare.py +0 -1
  40. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  41. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  42. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  43. onnx_diagnostic/typing.py +15 -0
  44. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
  45. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +48 -46
  46. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
  47. onnx_diagnostic/api.py +0 -15
  48. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
  49. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,329 @@
1
+ import contextlib
2
+ import inspect
3
+ from typing import Any, Callable, Sequence
4
+ import torch
5
+
6
+
7
+ def flatten_unflatten_for_dynamic_shapes(
8
+ obj: Any,
9
+ use_dict: bool = True,
10
+ change_function: Callable[[torch.Tensor], Any] | None = None,
11
+ ) -> Any:
12
+ """
13
+ Returns the object in a different structure similar to what
14
+ the definition of the dynamic shapes should use.
15
+
16
+ Args:
17
+ obj:
18
+ object from a custom class
19
+ use_dict:
20
+ closer to the original result but
21
+ :func:`torch.export.export` only considers the values,
22
+ the context gives the dictionary keys but it is not expressed
23
+ in the dynamic shapes, these specifications seems to be different
24
+ for the strict and non strict mode. It also preserves tuple.
25
+ change_function:
26
+ to modify the tensor in the structure itself,
27
+ like replace them by a shape
28
+
29
+ Returns:
30
+ the serialized object
31
+ """
32
+ if isinstance(obj, torch.Tensor):
33
+ return change_function(obj) if change_function else obj
34
+ flat, spec = torch.utils._pytree.tree_flatten(obj)
35
+ start = 0
36
+ end = 0
37
+ subtrees = []
38
+ for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
39
+ end += subspec.num_leaves
40
+ value = subspec.unflatten(flat[start:end])
41
+ value = flatten_unflatten_for_dynamic_shapes(
42
+ value, use_dict=use_dict, change_function=change_function
43
+ )
44
+ subtrees.append(value)
45
+ start = end
46
+ if use_dict:
47
+ if spec.type is dict:
48
+ # This is a dictionary.
49
+ return dict(zip(spec.context, subtrees))
50
+ if spec.type is tuple:
51
+ return tuple(subtrees)
52
+ if spec.type is list:
53
+ return list(subtrees)
54
+ if spec.type is None and not subtrees:
55
+ return None
56
+ if spec.context:
57
+ # This is a custom class with attributes.
58
+ # It is returned as a list.
59
+ return list(subtrees)
60
+ raise ValueError(
61
+ f"Unable to interpret spec type {spec.type} "
62
+ f"(type is {type(spec.type)}, context is {spec.context}), "
63
+ f"spec={spec}, subtrees={subtrees}"
64
+ )
65
+ # This is a list.
66
+ return subtrees
67
+
68
+
69
+ def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int]:
70
+ """
71
+ Returns the list of dynamic dimensions given a list of shapes
72
+ corresponding to the same tensor.
73
+
74
+ Args:
75
+ shape_list:
76
+ list of shapes, they must all have the same length
77
+
78
+ Returns:
79
+ list of dynamic dimensions
80
+ """
81
+ unique_ranks = {len(shape) for shape in shape_list}
82
+ torch._check(
83
+ len(unique_ranks) == 1, lambda: "all shapes in shape_list must have the same rank"
84
+ )
85
+ rank = unique_ranks.pop()
86
+ dynamic = []
87
+ for i in range(rank):
88
+ dims = [shape[i] for shape in shape_list]
89
+ if len(set(dims)) > 1:
90
+ dynamic.append(i)
91
+ return dynamic
92
+
93
+
94
+ class InputObserverInfo:
95
+ def __init__(self, signature: inspect.Signature):
96
+ # pyrefly: ignore
97
+ self.inputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
98
+ self.flat_inputs: list[list[torch.Tensor | None]] = []
99
+ # pyrefly: ignore
100
+ self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
101
+ self.flat_outputs: list[torch.Tensor | list[torch.Tensor]] = []
102
+ self.signature = signature
103
+
104
+ self._max_args: tuple[Any, torch.Tensor] | None = None
105
+ self._max_kwargs: dict[str, torch.Tensor] | None = None
106
+
107
+ def __len__(self) -> int:
108
+ return len(self.flat_inputs)
109
+
110
+ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
111
+ kwargs = {
112
+ k: v
113
+ for k, v in kwargs.items()
114
+ if v is not None and not isinstance(v, (int, float, bool))
115
+ }
116
+ flat_args, spec = torch.utils._pytree.tree_flatten((args, kwargs))
117
+ self.inputs_specs.append(spec)
118
+ cloned = [
119
+ (None if not isinstance(t, torch.Tensor) else t.clone().detach())
120
+ for t in flat_args
121
+ ]
122
+ self.flat_inputs.append(cloned)
123
+
124
+ cloned_args, cloned_kwargs = torch.utils._pytree.tree_unflatten(cloned, spec)
125
+ if self._max_args is None or len(cloned_args) > len(self._max_args):
126
+ self._max_args = cloned_args
127
+ if self._max_kwargs is None or len(cloned_kwargs) > len(self._max_kwargs):
128
+ self._max_kwargs = cloned_kwargs
129
+
130
+ def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
131
+ flat_res, spec = torch.utils._pytree.tree_flatten(res)
132
+ self.outputs_specs.append(spec)
133
+ self.flat_outputs.append([t.clone().detach() for t in flat_res])
134
+
135
+ def build_inputs_completed_with_none_values(self) -> list[list[torch.Tensor]]:
136
+ # Let's compute the sizes of each independently.
137
+ if not self.flat_inputs or self._max_args is None or self._max_kwargs is None:
138
+ raise RuntimeError("No inputs were captured.")
139
+ arg_sizes = [len(torch.utils._pytree.tree_flatten(a)[0]) for a in self._max_args]
140
+ kwarg_sizes = {
141
+ k: len(torch.utils._pytree.tree_flatten(v)[0]) for k, v in self._max_kwargs.items()
142
+ }
143
+
144
+ # Let's reprocess everything.
145
+ captured_inputs: dict[int | str, int] = {}
146
+ new_flat_inputs = []
147
+ for args_kwargs, spec in zip(self.flat_inputs, self.inputs_specs):
148
+ args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
149
+ if len(set(kwargs) | set(self._max_kwargs)) > len(self._max_kwargs):
150
+ raise RuntimeError(
151
+ "At least one call to the observed model "
152
+ "must contain all the named arguments."
153
+ )
154
+ flat = []
155
+ for i in range(len(self._max_args)):
156
+ if i < len(args):
157
+ ts = torch.utils._pytree.tree_flatten(args[i])[0]
158
+ if i in captured_inputs and captured_inputs[i] != len(ts):
159
+ raise RuntimeError(
160
+ f"Positional argument {i} has {len(ts)} tensors "
161
+ f"but previously got {captured_inputs[i]} tensors. "
162
+ f"Inference is impossible in that case."
163
+ )
164
+ captured_inputs[i] = len(ts)
165
+ flat.extend(ts)
166
+ else:
167
+ flat.extend([None for _ in range(arg_sizes[i])])
168
+ for k in self._max_kwargs:
169
+ if k in kwargs:
170
+ ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
171
+ if k in captured_inputs and captured_inputs[k] != len(ts):
172
+ raise RuntimeError(
173
+ f"Named argument {k!r} has {len(ts)} tensors "
174
+ f"but previously got {captured_inputs[k]} tensors. "
175
+ f"Inference is impossible in that case."
176
+ )
177
+ captured_inputs[k] = len(ts)
178
+ flat.extend(ts)
179
+ else:
180
+ flat.extend([None for _ in range(kwarg_sizes[k])])
181
+ new_flat_inputs.append(flat)
182
+ return new_flat_inputs
183
+
184
+ def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
185
+ flat_inputs = self.build_inputs_completed_with_none_values()
186
+ # This is already checked by build_inputs_completed_with_none_values
187
+ # but this is not always well captured by tools checking types.
188
+ assert self._max_args is not None and self._max_kwargs is not None
189
+ if len({len(flat) for flat in flat_inputs}) != 1:
190
+ raise NotImplementedError(
191
+ "infer_dynamic_shapes is not implemented "
192
+ "when the number of input tensors are not the same."
193
+ )
194
+ shape_lists = [
195
+ [(None if t is None else t.shape) for t in tensors] for tensors in flat_inputs
196
+ ]
197
+ n_tensors = len(shape_lists[0])
198
+ dynamic_shapes = [
199
+ infer_dynamic_dimensions(
200
+ [s for s in [shapes[index] for shapes in shape_lists] if s is not None]
201
+ )
202
+ for index in range(n_tensors)
203
+ ]
204
+ cst = torch.export.Dim.DYNAMIC
205
+ flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
206
+ if len(flat_dynamic_shapes) == len(self._max_args) + len(self._max_kwargs):
207
+ # It means forward method is called with tensors only.
208
+ if not self._max_kwargs:
209
+ # only positional arguments
210
+ return tuple(flat_dynamic_shapes)
211
+ if not self._max_args:
212
+ # only named arguments
213
+ return dict(zip(list(self._max_kwargs), flat_dynamic_shapes))
214
+ # positional arguments needs to be moved to the named arguments
215
+ n_args = len(self._max_args)
216
+ pos_names = list(self.signature.parameters)[:n_args]
217
+ return {
218
+ **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
219
+ **dict(zip(list(self._max_kwargs), flat_dynamic_shapes[n_args:])),
220
+ }
221
+
222
+ # nested types, here comes the fun part because the shapes cannot be unflattened,
223
+ # custom classes must appear in their flattened shape.
224
+ # This does not work in all cases but every time every available argument is flattened
225
+ # with the same number of tensors. The function does not check
226
+ # if that assumption is true.
227
+ flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
228
+ (self._max_args, self._max_kwargs)
229
+ )
230
+ torch._check(
231
+ len(flat_inputs) == len(flat_dynamic_shapes),
232
+ (
233
+ f"Length mismatch len(flat_inputs)={len(flat_inputs)}, "
234
+ f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
235
+ ),
236
+ )
237
+ mapping = {id(t): shape for t, shape in zip(flat_inputs, flat_dynamic_shapes)}
238
+ ds_args, ds_kwargs = flatten_unflatten_for_dynamic_shapes(
239
+ (self._max_args, self._max_kwargs), change_function=lambda t: mapping[id(t)]
240
+ )
241
+ if not ds_kwargs:
242
+ return tuple(ds_args)
243
+ if not ds_args:
244
+ return tuple(ds_kwargs)
245
+ pos_names = list(self.signature.parameters)[: len(ds_args)]
246
+ return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
247
+
248
+ def infer_arguments(
249
+ self, index: int | None = None
250
+ ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
251
+ # This is already checked by build_inputs_completed_with_none_values
252
+ # but this is not always well captured by tools checking types.
253
+ assert self._max_args is not None and self._max_kwargs is not None
254
+ candidate = None
255
+ if index is None:
256
+ for i, (args_kwargs, spec) in enumerate(zip(self.flat_inputs, self.inputs_specs)):
257
+ args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
258
+ if len(args) == len(self._max_args) and len(kwargs) == len(self._max_kwargs):
259
+ index = i
260
+ candidate = args, kwargs
261
+ break
262
+ if index is not None:
263
+ # found one available set.
264
+ args, kwargs = candidate or torch.utils._pytree.tree_unflatten(
265
+ self.flat_inputs[index], self.inputs_specs[index]
266
+ )
267
+ if not kwargs:
268
+ return args
269
+ if not args:
270
+ return kwargs
271
+ # We need to move args to kwargs
272
+ pos_names = list(self.signature.parameters)[: len(args)]
273
+ return {**dict(zip(pos_names, args)), **kwargs}
274
+
275
+ raise NotImplementedError(
276
+ "We could not find a good set of inputs/outputs. "
277
+ "We need to replace none by empty tensors."
278
+ )
279
+
280
+
281
+ class InputObserver:
282
+ def __init__(self, store_n_calls: int = 3):
283
+ self.store_n_calls = store_n_calls
284
+ self.info: InputObserverInfo | None = None
285
+
286
+ def _forward_captured(self, *args, _captured_forward=None, **kwargs):
287
+ assert _captured_forward is not None, "_captured_forward cannot be None"
288
+ assert self.info is not None, "info cannot be None"
289
+ n_stored = len(self.info)
290
+ if n_stored < self.store_n_calls:
291
+ self.info.add_inputs(args, kwargs)
292
+ res = _captured_forward(*args, **kwargs)
293
+ if n_stored < self.store_n_calls:
294
+ self.info.add_outputs(res)
295
+ return res
296
+
297
+ @contextlib.contextmanager
298
+ def __call__(self, model: torch.nn.Module):
299
+ if self.info is not None:
300
+ raise RuntimeError(
301
+ "This class was already used to capture a model. Please create a new one."
302
+ )
303
+ self.info = InputObserverInfo(signature=inspect.signature(model.forward))
304
+ forward_method = model.forward
305
+ model.forward = (
306
+ lambda *args, _captured_forward=forward_method, **kwargs: self._forward_captured(
307
+ *args, _captured_forward=_captured_forward, **kwargs
308
+ )
309
+ )
310
+ try:
311
+ yield self
312
+ finally:
313
+ model.forward = forward_method
314
+
315
+ def _check_captured(self):
316
+ if self.info is None:
317
+ raise RuntimeError("No inputs were captured.")
318
+
319
+ def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
320
+ self._check_captured()
321
+ assert self.info is not None # missed by type checking
322
+ return self.info.infer_dynamic_shapes()
323
+
324
+ def infer_arguments(
325
+ self, index: int | None = None
326
+ ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
327
+ self._check_captured()
328
+ assert self.info is not None # missed by type checking
329
+ return self.info.infer_arguments(index=index)
@@ -42,7 +42,6 @@ from .ops.op_slice import Slice_1, Slice_10
42
42
  from .ops.op_transpose_cast import Transpose2DCastFP16, Transpose2DCastFP32
43
43
  from .ops.op_tri_matrix import TriMatrix
44
44
 
45
-
46
45
  logger = getLogger("onnx-diagnostic-eval")
47
46
 
48
47
 
@@ -34,7 +34,6 @@ from ..helpers.torch_helper import to_tensor
34
34
  from .report_results_comparison import ReportResultComparison
35
35
  from .evaluator import ExtendedReferenceEvaluator
36
36
 
37
-
38
37
  PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
39
38
  Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
40
39
 
@@ -1,5 +1,4 @@
1
- from typing import Any, Dict, List, Tuple, Union
2
-
1
+ from typing import Any, Dict, List, Set, Tuple, Union
3
2
 
4
3
  ReportKeyNameType = Union[str, Tuple[str, int, str]]
5
4
  ReportKeyValueType = Tuple[int, Tuple[int, ...]]
@@ -14,6 +13,7 @@ class ReportResultComparison:
14
13
  :param tensors: tensor
15
14
  """
16
15
 
16
+ # pyrefly: ignore[unknown-name]
17
17
  def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]): # noqa: F821
18
18
  from ..helpers.onnx_helper import dtype_to_tensor_dtype
19
19
  from ..helpers import max_diff, string_type
@@ -25,7 +25,9 @@ class ReportResultComparison:
25
25
  self.max_diff = max_diff
26
26
  self.tensors = tensors
27
27
  self._build_mapping()
28
+ self.unique_run_names: Set[str] = set()
28
29
 
30
+ # pyrefly: ignore[unknown-name]
29
31
  def key(self, tensor: "torch.Tensor") -> ReportKeyValueType: # noqa: F821
30
32
  "Returns a key for a tensor, (onnx dtype, shape)."
31
33
  return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
@@ -59,12 +61,15 @@ class ReportResultComparison:
59
61
  for k, v in self.value.items():
60
62
  (i_run, run_name), ref_name = k
61
63
  d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
64
+ # pyrefly: ignore[no-matching-overload]
62
65
  d.update(v)
63
66
  rows.append(d)
64
67
  return rows
65
68
 
66
69
  def report(
67
- self, outputs: Dict[str, "torch.Tensor"] # noqa: F821
70
+ self,
71
+ # pyrefly: ignore[unknown-name]
72
+ outputs: Dict[str, "torch.Tensor"], # noqa: F821
68
73
  ) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
69
74
  """
70
75
  For every tensor in outputs, compares it to every tensor held by
@@ -79,6 +84,7 @@ class ReportResultComparison:
79
84
  key = self.key(tensor)
80
85
  if key not in self.mapping:
81
86
  continue
87
+ # pyrefly: ignore[unknown-name]
82
88
  cache: Dict["torch.device", "torch.Tensor"] = {} # noqa: F821, UP037
83
89
  for held_key in self.mapping[key]:
84
90
  t2 = self.tensors[held_key]
@@ -63,7 +63,7 @@ class TorchOnnxEvaluator:
63
63
  * `functions`: local functions
64
64
 
65
65
  The class is not multithreaded. `runtime_info` gets updated
66
- by the the class. The list of available kernels is returned by function
66
+ by the class. The list of available kernels is returned by function
67
67
  :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
68
68
  Example:
69
69
 
@@ -494,8 +494,10 @@ class TorchOnnxEvaluator:
494
494
  r = self.runtime_info[k]
495
495
  r.set_value(
496
496
  torch_ops.OpRunTensor(
497
+ # pyrefly: ignore[missing-attribute]
497
498
  v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
498
499
  is_constant=False,
500
+ # pyrefly: ignore[missing-attribute]
499
501
  may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
500
502
  )
501
503
  )
@@ -524,6 +526,7 @@ class TorchOnnxEvaluator:
524
526
  f"for kernel {type(kernel)}."
525
527
  )
526
528
  for name, t in zip(kernel.output, res):
529
+ # pyrefly: ignore[bad-argument-type]
527
530
  self.runtime_info[name].set_value(t)
528
531
  if self.verbose:
529
532
  for name in kernel.output:
@@ -644,6 +647,7 @@ class TorchOnnxEvaluator:
644
647
  f"for kernel {type(kernel)}."
645
648
  )
646
649
  for name, t in zip(kernel.output, res):
650
+ # pyrefly: ignore[bad-argument-type]
647
651
  self.runtime_info[name].set_value(t)
648
652
  else:
649
653
  assert isinstance(
@@ -1,7 +1,7 @@
1
1
  from typing import Any, Dict, List, Optional, Union, Tuple
2
2
  import onnx
3
3
  import torch
4
- from ...api import TensorLike
4
+ from ...typing import TensorLike
5
5
  from ...helpers import string_type
6
6
  from ...helpers.torch_helper import to_tensor
7
7
 
@@ -149,7 +149,7 @@ class OpRunSequence(OpRunValue):
149
149
  ) -> "OpRunSequence":
150
150
  "Inserts a value at a given position."
151
151
  assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
152
- new_seq = OpRunSequence()
152
+ new_seq = OpRunSequence() # type: ignore[abstract]
153
153
  seq = self.sequence.copy()
154
154
  new_seq.sequence = seq
155
155
  if position is None:
@@ -314,9 +314,7 @@ class OpRunKernel:
314
314
 
315
315
 
316
316
  class OpRunFunction(OpRunKernel):
317
- """
318
- Defines a kernel based on a local functions.
319
- """
317
+ """Defines a kernel based on a local functions."""
320
318
 
321
319
  def __init__(
322
320
  self,
@@ -46,7 +46,7 @@ class SequenceEmpty_11(OpRunOpSequence):
46
46
  )
47
47
 
48
48
  def run(self) -> OpRunSequence:
49
- return OpRunSequence(dtype=self.dtype)
49
+ return OpRunSequence(dtype=self.dtype) # type: ignore[abstract]
50
50
 
51
51
 
52
52
  class SequenceInsert_11(OpRunOpSequence):
@@ -3,7 +3,6 @@ import torch
3
3
  from ..helpers.config_helper import update_config, check_hasattr
4
4
  from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
5
5
 
6
-
7
6
  __TASK__ = "feature-extraction"
8
7
 
9
8
 
@@ -4,7 +4,6 @@ from .onnx_export_errors import (
4
4
  )
5
5
  from .patch_module import torch_export_rewrite
6
6
 
7
-
8
7
  # bypass_export_some_errors is the first name given to the patches.
9
8
  bypass_export_some_errors = torch_export_patches # type: ignore
10
9
 
@@ -305,7 +305,7 @@ def serialization_functions(
305
305
 
306
306
 
307
307
  def unregister_class_serialization(cls: type, verbose: int = 0):
308
- """Undo the registration."""
308
+ """Undo the registration for a class."""
309
309
  # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
310
310
  if cls in torch.fx._pytree.SUPPORTED_NODES:
311
311
  del torch.fx._pytree.SUPPORTED_NODES[cls]
@@ -333,6 +333,10 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
333
333
 
334
334
 
335
335
  def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
336
+ """
337
+ Undo the registration made by
338
+ :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`.
339
+ """
336
340
  cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
337
341
  for cls in cls_ensemble:
338
342
  if undo.get(cls.__name__, False):
@@ -986,7 +986,7 @@ def torch_export_rewrite(
986
986
  name = me.__qualname__
987
987
  spl = name.split(".")
988
988
  if len(spl) == 1:
989
- # This a function
989
+ # This is a function
990
990
  module = me.__module__
991
991
  if module in me.__globals__:
992
992
  mod = me.__globals__[module]
@@ -7,10 +7,10 @@ import transformers
7
7
 
8
8
  def patched__compute_dynamic_ntk_parameters(
9
9
  config: Optional[transformers.PretrainedConfig] = None,
10
- device: Optional["torch.device"] = None,
10
+ device: Optional[torch.device] = None,
11
11
  seq_len: Optional[int] = None,
12
12
  **rope_kwargs,
13
- ) -> Tuple["torch.Tensor", float]:
13
+ ) -> Tuple[torch.Tensor, float]:
14
14
  """
15
15
  manual patch:
16
16
  ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
@@ -524,13 +524,16 @@ class patched_ShapeEnv:
524
524
 
525
525
  transmute_into_runtime_assert = False
526
526
 
527
+ backed_var_to_val = getattr(
528
+ self, "backed_var_to_val", getattr(self, "var_to_val", {})
529
+ )
527
530
  concrete_val = None
528
- if not (expr.free_symbols <= self.var_to_val.keys()):
531
+ if not (expr.free_symbols <= backed_var_to_val.keys()):
529
532
  # TODO: dedupe this with _maybe_evaluate_static
530
533
  # Attempt to eliminate the unbacked SymInt
531
534
  new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
532
535
  assert new_expr is not None
533
- if not (new_expr.free_symbols <= self.var_to_val.keys()):
536
+ if not (new_expr.free_symbols <= backed_var_to_val.keys()):
534
537
  ok = False
535
538
 
536
539
  # fallback_value is set when guard_or_true or guard_or_false are used.
@@ -541,17 +544,15 @@ class patched_ShapeEnv:
541
544
  # oblivious_var_to_val will be defined iff we have sizes
542
545
  # with DimDynamic.OBLIVIOUS_SIZE type.
543
546
  # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
544
- var_to_val = getattr(
545
- self,
546
- "unbacked_var_to_val",
547
- getattr(self, "oblivious_var_to_val", False),
548
- )
549
547
  if (
550
- var_to_val
551
- and not (correct_hint := orig_expr.xreplace(var_to_val)).free_symbols
548
+ backed_var_to_val
549
+ and getattr(self, "real_tensor_prop_unbacked_vals", True)
550
+ and not (
551
+ correct_hint := orig_expr.xreplace(backed_var_to_val)
552
+ ).free_symbols
552
553
  and not (
553
554
  counterfactual_hint := orig_expr.xreplace(
554
- {k: max(2, v) for k, v in var_to_val.items()}
555
+ {k: max(2, v) for k, v in backed_var_to_val.items()}
555
556
  )
556
557
  ).free_symbols
557
558
  and correct_hint == counterfactual_hint
@@ -574,10 +575,10 @@ class patched_ShapeEnv:
574
575
  # and if they pass we add a runtime assertions and continue.
575
576
  if (
576
577
  not ok
577
- and var_to_val
578
+ and backed_var_to_val
578
579
  and not (
579
- unsound_result := orig_expr.xreplace(var_to_val).xreplace(
580
- var_to_val
580
+ unsound_result := orig_expr.xreplace(backed_var_to_val).xreplace(
581
+ backed_var_to_val
581
582
  )
582
583
  ).free_symbols
583
584
  ):