onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,340 @@
1
+ import difflib
2
+ import inspect
3
+ import pprint
4
+ import re
5
+ import textwrap
6
+ from typing import Any, Dict, Callable, List, Optional, Tuple, Union
7
+
8
+
9
def clean_code_with_black(code: str) -> str:
    """Changes the code style with :epkg:`black` if available."""
    dedented = textwrap.dedent(code)
    try:
        import black
    except ImportError:
        # black is optional: fall back to the dedented code unchanged.
        return dedented
    try:
        return black.format_str(dedented, mode=black.FileMode(line_length=98))
    except black.parsing.InvalidInput as e:
        raise RuntimeError(f"Unable to parse code\n\n---\n{dedented}\n---\n") from e
20
+
21
+
22
def make_diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
    """
    Creates a diff between two codes.

    :param code1: first code
    :param code2: second code
    :param output: if not empty, stores the output in this file
    :return: diff
    """
    left = code1.strip().splitlines()
    right = code2.strip().splitlines()
    diff_lines = difflib.unified_diff(
        left,
        right,
        fromfile="original",
        tofile="rewritten",
        lineterm="",
    )
    text = "\n".join(diff_lines)
    if output:
        # Persist the diff when a destination file is given.
        with open(output, "w") as f:
            f.write(text)
    return text
44
+
45
+
46
class PatchInfo:
    """
    Stores information about patches.

    :param function_to_patch: function to patch (or its dotted name as a string)
    :param patch: function patched (the replacement)
    :param family: a category, anything to classify the patch
    """

    __slots__ = ("depends_on", "family", "function_to_patch", "patch")

    def __init__(
        self, function_to_patch: Union[str, Callable], patch: Callable, family: str = ""
    ):
        assert callable(function_to_patch) or isinstance(function_to_patch, str), (
            f"function_to_patch is not a function but {type(function_to_patch)} "
            f"- {function_to_patch!r}"
        )
        # This assertion validates ``patch`` (message fixed accordingly).
        assert callable(patch), (
            f"patch is not a function but {type(patch)} - {patch!r}, "
            f"function_to_patch={function_to_patch!r}"
        )
        # A function whose name starts with 'patched_' is most likely a patch
        # that was never removed before being patched again.
        assert not callable(function_to_patch) or not function_to_patch.__name__.startswith(
            "patched_"
        ), (
            f"A patch was probably not removed because function_to_patch="
            f"{function_to_patch!r} and patch={patch!r}"
        )
        self.family = family
        self.function_to_patch = function_to_patch
        self.patch = patch
        # Other patches whose diffs are reported together with this one
        # (see :meth:`make_diff`).
        self.depends_on: List["PatchInfo"] = []

    def add_dependency(self, patch_info: "PatchInfo"):
        "Registers another patch this one depends on."
        self.depends_on.append(patch_info)

    def __repr__(self) -> str:
        "usual"
        return (
            (
                f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r}, "
                f"{self.family!r})"
            )
            if self.family
            else f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r})"
        )

    def to_tuple(self) -> Tuple[str, Callable, Callable]:
        "Returns ``(family, function_to_patch, patch)``."
        return (self.family, self.function_to_patch, self.patch)

    def to_dict(self) -> Dict[str, Any]:
        "Returns the attributes listed in ``__slots__`` as a dictionary."
        return {k: getattr(self, k) for k in self.__slots__}

    def make_diff(self) -> str:
        """Returns a diff as a string."""
        if isinstance(self.function_to_patch, str):
            # No original source is available for a dotted name,
            # the patched source is shown instead.
            return clean_code_with_black(inspect.getsource(self.patch))
        src1 = clean_code_with_black(inspect.getsource(self.function_to_patch))
        src2 = clean_code_with_black(inspect.getsource(self.patch))
        diff = make_diff_code(src1, src2)
        if not self.depends_on:
            return diff
        # Appends the diffs of the dependent patches.
        res = [diff]
        for d in self.depends_on:
            res.append("")
            res.append(d.make_diff())
        return "\n".join(res)

    @classmethod
    def function_name(cls, f: Callable) -> str:
        "Returns the qualified name of a function."
        return f.__qualname__

    def format_diff(self, format: str = "raw") -> str:
        """
        Format a diff between two function as a string.

        :param format: ``'raw'`` or ``'rst'``
        :return: diff

        .. runpython::
            :showcode:
            :rst:

            import transformers
            import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
            from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
            from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
                patched_eager_mask,
            )

            eager_mask = transformers.masking_utils.eager_mask
            diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
            print(diff)
        """
        diff = self.make_diff()
        kind = self.family or ""
        if kind:
            kind = f"{kind}: "
        function_to_patch_name = (
            f"{self.function_to_patch!r}"
            if isinstance(self.function_to_patch, str)
            else self.function_name(self.function_to_patch)
        )
        patch_name = self.function_name(self.patch)
        title = f"{kind}{function_to_patch_name} -> {patch_name}"
        if format == "raw":
            return f"{title}\n{diff}"

        rows = [
            title,
            "=" * len(title),
            "",
            ".. code-block:: diff",
            " :linenos:",
            "",
            textwrap.indent(diff, prefix=" "),
        ]
        return "\n".join(rows)
166
+
167
+
168
class PatchDetails:
    """
    This class is used to store patching information.
    This helps understanding which rewriting was applied to which
    method of functions. Page :ref:`l-patch-diff` contains all the
    diff for all the implemented patches.

    .. runpython::
        :showcode:
        :rst:

        import torch
        from onnx_diagnostic.torch_export_patches import torch_export_patches
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
        from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        details = PatchDetails()
        with torch_export_patches(
            patch_transformers=True, patch_details=details, patch_torch=False
        ):
            ep = torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        patches = details.patches_involded_in_graph(ep.graph)
        report = details.make_report(patches, format="rst")
        print(report)
    """

    def __init__(self):
        # Registered patches, in registration order.
        self.patched: List["PatchInfo"] = []
        # Cache for :meth:`find`; only successful lookups are cached.
        self.find_cache: Dict[str, "PatchInfo"] = {}

    def find(self, name: str) -> Optional["PatchInfo"]:
        "Finds a patch by the name of its replacement function."
        if name in self.find_cache:
            return self.find_cache[name]
        for p in self.patched:
            if p.patch.__name__ == name:
                self.find_cache[name] = p
                return p
        return None

    def append(
        self, family: str, function_to_patch: Union[str, Callable], patch: Callable
    ) -> "PatchInfo":
        """
        Stores a patch.

        :param family: a category, anything to classify the patch
        :param function_to_patch: function to patch
        :param patch: function patched
        :return: instance of PatchInfo
        """
        p = PatchInfo(function_to_patch, patch, family=family)
        self.patched.append(p)
        return p

    @property
    def n_patches(self) -> int:
        "Returns the number of stored patches."
        # Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
        return len(self.patched)

    def data(self) -> List[Dict[str, Any]]:
        """Returns the data for a dataframe."""
        return [p.to_dict() for p in self.patched]

    def patches_involded_in_graph(
        self, graph: "torch.fx.Graph"  # noqa: F821
    ) -> List[Tuple["PatchInfo", List["torch.fx.Node"]]]:  # noqa: F821
        """
        Enumerates all patches impacting a graph.
        The function goes through the graph node (only the main graph) and
        looks into the metadata to determine if a listed patch was involved.

        :param graph: fx graph
        :return: list of nodes impacted by a patch
        """
        # Pre-compute the source file and line interval of every patch.
        patches = []
        for patch in self.patched:
            f = patch.patch
            source = inspect.getsourcefile(f)
            lines, lineno = inspect.getsourcelines(f)
            interval = [lineno, lineno + len(lines)]
            patches.append((patch, f, source, interval))

        cst = "onnx_diagnostic"
        node_stack = []
        for node in graph.nodes:
            meta = node.meta
            if "stack_trace" not in meta:
                continue
            stack = meta["stack_trace"]
            if cst not in stack:
                # to reduce the cost of the next iteration
                continue
            node_stack.append((node, stack))

        # The pattern does not depend on the patch: compile it once
        # instead of once per patch.
        reg = re.compile('File "([^"]*?%s[^"]+?)", line (\\d+)' % cst)

        patch_node = []
        patched_nodes = set()
        for patch, _f, source, interval in patches:
            for node, stack in node_stack:
                occ = reg.findall(stack)
                if not occ:
                    continue
                for filename, line_number in occ:
                    # Normalize path separators before comparing files.
                    if source.replace("\\", "/").strip("/") != filename.replace(
                        "\\", "/"
                    ).strip("/"):
                        continue
                    line = int(line_number)
                    if (
                        line >= interval[0]
                        and line <= interval[1]
                        and self.matching_pair(patch, node)
                    ):
                        patch_node.append((patch, node))
                        patched_nodes.add(id(node))

        # checks all patches were discovered
        for node, _ in node_stack:
            assert id(node) in patched_nodes, (
                f"One node was patched but no patch was found:\n"
                f"node: {node.target}({','.join(map(str, node.args))}) -> {node.name}"
                f"\n--\n{pprint.pformat(node.meta)}"
            )

        # Groups the nodes by patch, preserving insertion order.
        res = {}  # type: ignore[var-annotated]
        for patch, node in patch_node:
            if patch not in res:
                res[patch] = []
            res[patch].append(node)
        return list(res.items())

    # Correctly spelled alias; the historical (misspelled) name above
    # remains the canonical one for backward compatibility.
    patches_involved_in_graph = patches_involded_in_graph

    def matching_pair(
        self, patch: "PatchInfo", node: "torch.fx.Node"  # noqa: F821
    ) -> bool:
        """
        Last validation for a pair. RotaryEmbedding has many rewriting
        and they all end up in the same code line.
        """
        if isinstance(patch.function_to_patch, str):
            # A dotted name has no __qualname__; nothing more to check.
            return True
        cls_name = patch.function_to_patch.__qualname__.split(".")[0]
        if not cls_name.endswith("RotaryEmbedding"):
            return True
        return cls_name in str(node.meta)

    def make_report(
        self,
        patches: List[Tuple["PatchInfo", List["torch.fx.Node"]]],  # noqa: F821
        format: str = "raw",
    ) -> str:
        """
        Creates a report based on the involved patches.

        :param patches: from method :meth:`patches_involded_in_graph`
        :param format: format of the report, ``'raw'`` or ``'rst'``
        :return: report
        """
        rows = []
        for patch, nodes in patches:
            rows.append(patch.format_diff(format=format))
            rows.append("")
            if format == "rst":
                rows.extend(["", "", "**impacted nodes**", "", "", ".. code-block::", ""])
                for node in nodes:
                    rows.append(
                        f" {node.target}({', '.join(map(str, node.args))}) -> {node.name}"
                    )
                rows.append("")
        return "\n".join(rows)
@@ -38,7 +38,7 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
38
38
  for v in subset.values():
39
39
  axes = v
40
40
  break
41
- new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]]
41
+ new_shape = [axes for i in range(cache_length * 2)]
42
42
  return new_shape
43
43
  if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
44
44
  raise NotImplementedError(
@@ -1,7 +1,6 @@
1
1
  import ast
2
2
  import copy
3
3
  import contextlib
4
- import difflib
5
4
  import inspect
6
5
  import os
7
6
  import types
@@ -9,6 +8,7 @@ import textwrap
9
8
  import sys
10
9
  from typing import Callable, Dict, List, Set, Optional, Tuple, Union
11
10
  from .patch_module_helper import code_needing_rewriting
11
+ from .patch_details import PatchDetails, make_diff_code, clean_code_with_black
12
12
 
13
13
  NODE_TYPES = tuple(
14
14
  getattr(ast, k)
@@ -881,6 +881,7 @@ def torch_export_rewrite(
881
881
  ] = None,
882
882
  dump_rewriting: Optional[str] = None,
883
883
  verbose: int = 0,
884
+ patch_details: Optional[PatchDetails] = None,
884
885
  ):
885
886
  """
886
887
  Automatically rewrite the methods given in `rewrite` to export
@@ -897,6 +898,8 @@ def torch_export_rewrite(
897
898
  :param verbose: verbosity, up to 10, 10 shows the rewritten code,
898
899
  ``verbose=1`` shows the rewritten function,
899
900
  ``verbose=2`` shows the rewritten code as well
901
+ :param patch_details: to store any applied patch and get a better understanding
902
+ of the applied modifications
900
903
 
901
904
  Example:
902
905
 
@@ -1019,7 +1022,7 @@ def torch_export_rewrite(
1019
1022
  if verbose:
1020
1023
  print(f"[torch_export_rewrite] dump original code in {filename1!r}")
1021
1024
  with open(filename1, "w") as f:
1022
- code = _clean_code(inspect.getsource(to_rewrite))
1025
+ code = clean_code_with_black(inspect.getsource(to_rewrite))
1023
1026
  f.write(code)
1024
1027
  rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0), **kws)
1025
1028
  if dump_rewriting:
@@ -1027,10 +1030,12 @@ def torch_export_rewrite(
1027
1030
  if verbose:
1028
1031
  print(f"[torch_export_rewrite] dump rewritten code in {filename2!r}")
1029
1032
  with open(filename2, "w") as f:
1030
- rcode = _clean_code(rewr.code)
1033
+ rcode = clean_code_with_black(rewr.code)
1031
1034
  f.write(rcode)
1032
1035
  diff = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.diff")
1033
- make_diff(code, rcode, diff)
1036
+ make_diff_code(code, rcode, diff)
1037
+ if patch_details:
1038
+ patch_details.append("rewrite", getattr(cls, name), rewr.func)
1034
1039
  setattr(cls, name, rewr.func)
1035
1040
 
1036
1041
  try:
@@ -1040,35 +1045,3 @@ def torch_export_rewrite(
1040
1045
  if verbose:
1041
1046
  print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}")
1042
1047
  setattr(cls, name, me)
1043
-
1044
-
1045
- def _clean_code(code: str) -> str:
1046
- try:
1047
- import black
1048
- except ImportError:
1049
- return code
1050
- return black.format_str(code, mode=black.FileMode(line_length=98))
1051
-
1052
-
1053
- def make_diff(code1: str, code2: str, output: Optional[str] = None) -> str:
1054
- """
1055
- Creates a diff between two codes.
1056
-
1057
- :param code1: first code
1058
- :param code2: second code
1059
- :param output: if not empty, stores the output in this file
1060
- :return: diff
1061
- """
1062
- text = "\n".join(
1063
- difflib.unified_diff(
1064
- code1.strip().splitlines(),
1065
- code2.strip().splitlines(),
1066
- fromfile="original",
1067
- tofile="rewritten",
1068
- lineterm="",
1069
- )
1070
- )
1071
- if output:
1072
- with open(output, "w") as f:
1073
- f.write(text)
1074
- return text
@@ -195,9 +195,12 @@ class patched_ShapeEnv:
195
195
  if self.frozen:
196
196
  self.counter["ignored_backward_guard"] += 1
197
197
  # PATCHED: raised an exception instead of logging.
198
+ import transformers
199
+
198
200
  raise AssertionError(
199
201
  f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
200
- f"this could result in accuracy problems"
202
+ f"this could result in accuracy problems, transformers.__version__="
203
+ f"{transformers.__version__!r}"
201
204
  )
202
205
 
203
206
  def _set_replacement(
@@ -683,7 +686,7 @@ class patched_ShapeEnv:
683
686
  return concrete_val
684
687
 
685
688
 
686
- def patched_vmap(func, in_dims=0, out_dims=0):
689
+ def patched_vmap(func, in_dims=0, out_dims=0, use_scan: bool = False):
687
690
  """
688
691
  Python implementation of :func:`torch.vmap`.
689
692
  The implementation raises an issue when it is being exported with
@@ -724,8 +727,9 @@ def patched_vmap(func, in_dims=0, out_dims=0):
724
727
  arg = arg.movedim(in_dim, 0)
725
728
  batched_args.append(arg)
726
729
 
727
- if all(isinstance(a, torch.Tensor) for a in args) and isinstance(
728
- batch_size, torch.SymInt
730
+ if use_scan or (
731
+ all(isinstance(a, torch.Tensor) for a in args)
732
+ and isinstance(batch_size, torch.SymInt)
729
733
  ):
730
734
  batched_tensors = [
731
735
  (
@@ -735,7 +739,9 @@ def patched_vmap(func, in_dims=0, out_dims=0):
735
739
  )
736
740
  for arg, in_dim in zip(batched_args, in_dims_)
737
741
  ]
738
- results = torch.ops.higher_order.scan(func, [], batched_tensors, [])
742
+ results = torch.ops.higher_order.scan(
743
+ lambda *args, **kwargs: [func(*args, **kwargs)], [], batched_tensors, []
744
+ )
739
745
  stacked = results[0]
740
746
  if out_dims != 0:
741
747
  return stacked.movedim(0, out_dims)
@@ -745,7 +751,7 @@ def patched_vmap(func, in_dims=0, out_dims=0):
745
751
  torch._check(
746
752
  not isinstance(batch_size, torch.SymInt),
747
753
  lambda: (
748
- f"patched_vmap supports dynamic batch_size only if all argument "
754
+ f"patched_vmap supports dynamic batch_size only if all arguments "
749
755
  f"are tensors but types are {[type(a) for a in args]}"
750
756
  ),
751
757
  )