onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +2 -1
- onnx_diagnostic/export/shape_helper.py +47 -70
- onnx_diagnostic/ext_test_case.py +11 -0
- onnx_diagnostic/helpers/cache_helper.py +38 -7
- onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
- onnx_diagnostic/helpers/helper.py +27 -33
- onnx_diagnostic/helpers/log_helper.py +109 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
- onnx_diagnostic/helpers/model_builder_helper.py +132 -2
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +4 -0
- onnx_diagnostic/helpers/rt_helper.py +393 -43
- onnx_diagnostic/helpers/torch_helper.py +20 -1
- onnx_diagnostic/tasks/__init__.py +7 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
- onnx_diagnostic/tasks/feature_extraction.py +2 -8
- onnx_diagnostic/tasks/image_text_to_text.py +10 -8
- onnx_diagnostic/tasks/summarization.py +2 -8
- onnx_diagnostic/tasks/text2text_generation.py +3 -8
- onnx_diagnostic/tasks/text_generation.py +86 -65
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
- onnx_diagnostic/torch_models/validate.py +626 -228
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import inspect
|
|
3
|
+
import pprint
|
|
4
|
+
import re
|
|
5
|
+
import textwrap
|
|
6
|
+
from typing import Any, Dict, Callable, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def clean_code_with_black(code: str) -> str:
    """Reformats the code with :epkg:`black` when the package is installed."""
    dedented = textwrap.dedent(code)
    try:
        import black
    except ImportError:
        # black is optional: return the dedented code unchanged.
        return dedented
    mode = black.FileMode(line_length=98)
    try:
        return black.format_str(dedented, mode=mode)
    except black.parsing.InvalidInput as e:
        raise RuntimeError(f"Unable to parse code\n\n---\n{dedented}\n---\n") from e
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def make_diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
    """
    Creates a diff between two codes.

    :param code1: first code
    :param code2: second code
    :param output: if not empty, stores the output in this file
    :return: diff
    """
    delta = difflib.unified_diff(
        code1.strip().splitlines(),
        code2.strip().splitlines(),
        fromfile="original",
        tofile="rewritten",
        lineterm="",
    )
    text = "\n".join(delta)
    if output:
        # Best-effort side output: the diff is written to the given file as well.
        with open(output, "w") as handle:
            handle.write(text)
    return text
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class PatchInfo:
    """
    Stores information about patches.

    :param function_to_patch: function to patch (a callable, or its name as a string)
    :param patch: function patched
    :param family: a category, anything to classify the patch
    """

    # ``depends_on`` is not a constructor argument; it is filled through
    # :meth:`add_dependency` after construction.
    __slots__ = ("depends_on", "family", "function_to_patch", "patch")

    def __init__(
        self, function_to_patch: Union[str, Callable], patch: Callable, family: str = ""
    ):
        assert callable(function_to_patch) or isinstance(function_to_patch, str), (
            f"function_to_patch is not a function but {type(function_to_patch)} "
            f"- {function_to_patch!r}"
        )
        # NOTE(review): the message below mentions ``function_to_patch`` but the
        # condition actually validates ``patch``.
        assert callable(patch), (
            f"function_to_patch is not a function but {type(patch)} - {patch!r}, "
            f"function_to_patch={function_to_patch!r}"
        )
        # A target whose name already starts with 'patched_' means a previous
        # patch was applied and never removed.
        assert not callable(function_to_patch) or not function_to_patch.__name__.startswith(
            "patched_"
        ), (
            f"A patch was probably not removed because function_to_patch="
            f"{function_to_patch!r} and patch={patch!r}"
        )
        self.family = family
        self.function_to_patch = function_to_patch
        self.patch = patch
        self.depends_on: List[PatchInfo] = []

    def add_dependency(self, patch_info: "PatchInfo"):
        """Declares another patch this one depends on (used by :meth:`make_diff`)."""
        self.depends_on.append(patch_info)

    def __repr__(self) -> str:
        "usual"
        # The family is only shown when it is not empty.
        return (
            (
                f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r}, "
                f"{self.family!r})"
            )
            if self.family
            else f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r})"
        )

    def to_tuple(self) -> Tuple[str, Callable, Callable]:
        "usual"
        return (self.family, self.function_to_patch, self.patch)

    def to_dict(self) -> Dict[str, Any]:
        "usual"
        return {k: getattr(self, k) for k in self.__slots__}

    def make_diff(self) -> str:
        """Returns a diff as a string."""
        # A string target has no retrievable source: show the patch source
        # itself instead of a diff.
        if isinstance(self.function_to_patch, str):
            return clean_code_with_black(inspect.getsource(self.patch))
        src1 = clean_code_with_black(inspect.getsource(self.function_to_patch))
        src2 = clean_code_with_black(inspect.getsource(self.patch))
        diff = make_diff_code(src1, src2)
        if not self.depends_on:
            return diff
        # Append the diff of every dependency, separated by a blank line.
        res = [diff]
        for d in self.depends_on:
            res.append("")
            res.append(d.make_diff())
        return "\n".join(res)

    @classmethod
    def function_name(cls, f: Callable) -> str:
        """Returns the qualified name used to display a function in a report."""
        return f.__qualname__

    def format_diff(self, format: str = "raw") -> str:
        """
        Format a diff between two functions as a string.

        :param format: ``'raw'`` or ``'rst'``
        :return: diff

        .. runpython::
            :showcode:
            :rst:

            import transformers
            import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
            from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
            from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
                patched_eager_mask,
            )

            eager_mask = transformers.masking_utils.eager_mask
            diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
            print(diff)
        """
        diff = self.make_diff()
        kind = self.family or ""
        if kind:
            kind = f"{kind}: "
        # A string target is shown with its repr, a callable with its qualname.
        function_to_pach_name = (
            f"{self.function_to_patch!r}"
            if isinstance(self.function_to_patch, str)
            else self.function_name(self.function_to_patch)
        )
        patch_name = self.function_name(self.patch)
        title = f"{kind}{function_to_pach_name} -> {patch_name}"
        if format == "raw":
            return f"{title}\n{diff}"

        # rst output: title, underline, then the diff inside a code-block.
        # NOTE(review): the leading spaces of the option/indent strings below
        # may have been collapsed by the diff rendering — confirm against the
        # released source.
        rows = [
            title,
            "=" * len(title),
            "",
            ".. code-block:: diff",
            " :linenos:",
            "",
            textwrap.indent(diff, prefix=" "),
        ]
        return "\n".join(rows)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class PatchDetails:
    """
    This class is used to store patching information.
    This helps understanding which rewriting was applied to which
    method of functions. Page :ref:`l-patch-diff` contains all the
    diff for all the implemented patches.

    .. runpython::
        :showcode:
        :rst:

        import torch
        from onnx_diagnostic.torch_export_patches import torch_export_patches
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
        from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        details = PatchDetails()
        with torch_export_patches(
            patch_transformers=True, patch_details=details, patch_torch=False
        ):
            ep = torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        patches = details.patches_involded_in_graph(ep.graph)
        report = details.make_report(patches, format="rst")
        print(report)
    """

    def __init__(self):
        # registered patches, in registration order
        self.patched = []
        # name -> PatchInfo cache for :meth:`find`; only hits are cached
        self.find_cache = {}

    def find(self, name: str) -> Optional[PatchInfo]:
        "Finds a patch by name."
        if name in self.find_cache:
            return self.find_cache[name]
        for p in self.patched:
            if p.patch.__name__ == name:
                self.find_cache[name] = p
                return p
        return None

    def append(
        self, family: str, function_to_patch: Union[str, Callable], patch: Callable
    ) -> PatchInfo:
        """
        Stores a patch.

        :param family: a category, anything to classify the patch
        :param function_to_patch: function to patch
        :param patch: function patched
        :return: instance of PatchInfo
        """
        p = PatchInfo(function_to_patch, patch, family=family)
        self.patched.append(p)
        return p

    @property
    def n_patches(self) -> int:
        "Returns the number of stored patches."
        # Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
        return len(self.patched)

    def data(self) -> List[Dict[str, Any]]:
        """Returns the data for a dataframe."""
        return [p.to_dict() for p in self.patched]

    def patches_involded_in_graph(
        self, graph: "torch.fx.Graph"  # noqa: F821
    ) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]:  # noqa: F821
        """
        Enumerates all patches impacting a graph.
        The function goes through the graph node (only the main graph) and
        looks into the metadata to determine if a listed patch was involved.

        :param graph: fx graph
        :return: list of nodes impacted by a patch
        """
        # Source file and source line interval of every registered patch.
        patches = []
        for patch in self.patched:
            f = patch.patch
            source = inspect.getsourcefile(f)
            lines, lineno = inspect.getsourcelines(f)
            interval = [lineno, lineno + len(lines)]
            patches.append((patch, f, source, interval))

        cst = "onnx_diagnostic"
        # Keep only nodes whose stack trace mentions this package,
        # to reduce the cost of the next iteration.
        node_stack = []
        for node in graph.nodes:
            meta = node.meta
            if "stack_trace" not in meta:
                continue
            stack = meta["stack_trace"]
            if cst not in stack:
                continue
            node_stack.append((node, stack))

        # The pattern does not depend on the patch: compile it once and
        # extract the (filename, line) occurrences once per node instead of
        # once per (patch, node) pair.
        reg = re.compile('File "([^"]*?%s[^"]+?)", line (\\d+)' % cst)
        node_occurrences = [(node, reg.findall(stack)) for node, stack in node_stack]

        patch_node = []
        patched_nodes = set()
        for patch, _f, source, interval in patches:
            normalized_source = source.replace("\\", "/").strip("/")
            for node, occ in node_occurrences:
                if not occ:
                    continue
                for filename, line_number in occ:
                    if normalized_source != filename.replace("\\", "/").strip("/"):
                        continue
                    line = int(line_number)
                    if (
                        line >= interval[0]
                        and line <= interval[1]
                        and self.matching_pair(patch, node)
                    ):
                        patch_node.append((patch, node))
                        patched_nodes.add(id(node))

        # checks all patches were discovered
        for node, _ in node_stack:
            assert id(node) in patched_nodes, (
                f"One node was patched but no patch was found:\n"
                f"node: {node.target}({','.join(map(str, node.args))}) -> {node.name}"
                f"\n--\n{pprint.pformat(node.meta)}"
            )

        # Group the impacted nodes by patch, preserving discovery order.
        res: Dict[PatchInfo, List["torch.fx.Node"]] = {}  # noqa: F821
        for patch, node in patch_node:
            res.setdefault(patch, []).append(node)
        return list(res.items())

    def matching_pair(self, patch: PatchInfo, node: "torch.fx.Node") -> bool:  # noqa: F821
        """
        Last validation for a pair. RotaryEmbedding has many rewriting
        and they all end up in the same code line.
        """
        target = patch.function_to_patch
        # The target may be a dotted name given as a string: in that case the
        # first component plays the role of the qualified name's first part.
        qualname = target if isinstance(target, str) else target.__qualname__
        cls_name = qualname.split(".")[0]
        if not cls_name.endswith("RotaryEmbedding"):
            return True
        # Several RotaryEmbedding classes end up on the same patched line:
        # only accept the pair if the node metadata mentions this class.
        return cls_name in str(node.meta)

    def make_report(
        self,
        patches: List[Tuple[PatchInfo, List["torch.fx.Node"]]],  # noqa: F821
        format: str = "raw",
    ) -> str:
        """
        Creates a report based on the involved patches.

        :param patches: from method :meth:`patches_involded_in_graph`
        :param format: format of the report
        :return: report
        """
        rows = []
        for patch, nodes in patches:
            rows.append(patch.format_diff(format=format))
            rows.append("")
            if format == "rst":
                # rst output also lists the fx nodes impacted by the patch.
                rows.extend(["", "", "**impacted nodes**", "", "", ".. code-block::", ""])
                for node in nodes:
                    rows.append(
                        f" {node.target}({', '.join(map(str,node.args))}) -> {node.name}"
                    )
                rows.append("")
        return "\n".join(rows)
|
|
@@ -38,7 +38,7 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
|
|
|
38
38
|
for v in subset.values():
|
|
39
39
|
axes = v
|
|
40
40
|
break
|
|
41
|
-
new_shape = [
|
|
41
|
+
new_shape = [axes for i in range(cache_length * 2)]
|
|
42
42
|
return new_shape
|
|
43
43
|
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
|
|
44
44
|
raise NotImplementedError(
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import ast
|
|
2
2
|
import copy
|
|
3
3
|
import contextlib
|
|
4
|
-
import difflib
|
|
5
4
|
import inspect
|
|
6
5
|
import os
|
|
7
6
|
import types
|
|
@@ -9,6 +8,7 @@ import textwrap
|
|
|
9
8
|
import sys
|
|
10
9
|
from typing import Callable, Dict, List, Set, Optional, Tuple, Union
|
|
11
10
|
from .patch_module_helper import code_needing_rewriting
|
|
11
|
+
from .patch_details import PatchDetails, make_diff_code, clean_code_with_black
|
|
12
12
|
|
|
13
13
|
NODE_TYPES = tuple(
|
|
14
14
|
getattr(ast, k)
|
|
@@ -881,6 +881,7 @@ def torch_export_rewrite(
|
|
|
881
881
|
] = None,
|
|
882
882
|
dump_rewriting: Optional[str] = None,
|
|
883
883
|
verbose: int = 0,
|
|
884
|
+
patch_details: Optional[PatchDetails] = None,
|
|
884
885
|
):
|
|
885
886
|
"""
|
|
886
887
|
Automatically rewrite the methods given in `rewrite` to export
|
|
@@ -897,6 +898,8 @@ def torch_export_rewrite(
|
|
|
897
898
|
:param verbose: verbosity, up to 10, 10 shows the rewritten code,
|
|
898
899
|
``verbose=1`` shows the rewritten function,
|
|
899
900
|
``verbose=2`` shows the rewritten code as well
|
|
901
|
+
:param patch_details: to store any applied patch and get a better understanding
|
|
902
|
+
of the applied modifications
|
|
900
903
|
|
|
901
904
|
Example:
|
|
902
905
|
|
|
@@ -1019,7 +1022,7 @@ def torch_export_rewrite(
|
|
|
1019
1022
|
if verbose:
|
|
1020
1023
|
print(f"[torch_export_rewrite] dump original code in {filename1!r}")
|
|
1021
1024
|
with open(filename1, "w") as f:
|
|
1022
|
-
code =
|
|
1025
|
+
code = clean_code_with_black(inspect.getsource(to_rewrite))
|
|
1023
1026
|
f.write(code)
|
|
1024
1027
|
rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0), **kws)
|
|
1025
1028
|
if dump_rewriting:
|
|
@@ -1027,10 +1030,12 @@ def torch_export_rewrite(
|
|
|
1027
1030
|
if verbose:
|
|
1028
1031
|
print(f"[torch_export_rewrite] dump rewritten code in {filename2!r}")
|
|
1029
1032
|
with open(filename2, "w") as f:
|
|
1030
|
-
rcode =
|
|
1033
|
+
rcode = clean_code_with_black(rewr.code)
|
|
1031
1034
|
f.write(rcode)
|
|
1032
1035
|
diff = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.diff")
|
|
1033
|
-
|
|
1036
|
+
make_diff_code(code, rcode, diff)
|
|
1037
|
+
if patch_details:
|
|
1038
|
+
patch_details.append("rewrite", getattr(cls, name), rewr.func)
|
|
1034
1039
|
setattr(cls, name, rewr.func)
|
|
1035
1040
|
|
|
1036
1041
|
try:
|
|
@@ -1040,35 +1045,3 @@ def torch_export_rewrite(
|
|
|
1040
1045
|
if verbose:
|
|
1041
1046
|
print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}")
|
|
1042
1047
|
setattr(cls, name, me)
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
def _clean_code(code: str) -> str:
|
|
1046
|
-
try:
|
|
1047
|
-
import black
|
|
1048
|
-
except ImportError:
|
|
1049
|
-
return code
|
|
1050
|
-
return black.format_str(code, mode=black.FileMode(line_length=98))
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
def make_diff(code1: str, code2: str, output: Optional[str] = None) -> str:
|
|
1054
|
-
"""
|
|
1055
|
-
Creates a diff between two codes.
|
|
1056
|
-
|
|
1057
|
-
:param code1: first code
|
|
1058
|
-
:param code2: second code
|
|
1059
|
-
:param output: if not empty, stores the output in this file
|
|
1060
|
-
:return: diff
|
|
1061
|
-
"""
|
|
1062
|
-
text = "\n".join(
|
|
1063
|
-
difflib.unified_diff(
|
|
1064
|
-
code1.strip().splitlines(),
|
|
1065
|
-
code2.strip().splitlines(),
|
|
1066
|
-
fromfile="original",
|
|
1067
|
-
tofile="rewritten",
|
|
1068
|
-
lineterm="",
|
|
1069
|
-
)
|
|
1070
|
-
)
|
|
1071
|
-
if output:
|
|
1072
|
-
with open(output, "w") as f:
|
|
1073
|
-
f.write(text)
|
|
1074
|
-
return text
|
|
@@ -195,9 +195,12 @@ class patched_ShapeEnv:
|
|
|
195
195
|
if self.frozen:
|
|
196
196
|
self.counter["ignored_backward_guard"] += 1
|
|
197
197
|
# PATCHED: raised an exception instead of logging.
|
|
198
|
+
import transformers
|
|
199
|
+
|
|
198
200
|
raise AssertionError(
|
|
199
201
|
f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
|
|
200
|
-
f"this could result in accuracy problems"
|
|
202
|
+
f"this could result in accuracy problems, transformers.__version__="
|
|
203
|
+
f"{transformers.__version__!r}"
|
|
201
204
|
)
|
|
202
205
|
|
|
203
206
|
def _set_replacement(
|
|
@@ -683,7 +686,7 @@ class patched_ShapeEnv:
|
|
|
683
686
|
return concrete_val
|
|
684
687
|
|
|
685
688
|
|
|
686
|
-
def patched_vmap(func, in_dims=0, out_dims=0):
|
|
689
|
+
def patched_vmap(func, in_dims=0, out_dims=0, use_scan: bool = False):
|
|
687
690
|
"""
|
|
688
691
|
Python implementation of :func:`torch.vmap`.
|
|
689
692
|
The implementation raises an issue when it is being exported with
|
|
@@ -724,8 +727,9 @@ def patched_vmap(func, in_dims=0, out_dims=0):
|
|
|
724
727
|
arg = arg.movedim(in_dim, 0)
|
|
725
728
|
batched_args.append(arg)
|
|
726
729
|
|
|
727
|
-
if
|
|
728
|
-
|
|
730
|
+
if use_scan or (
|
|
731
|
+
all(isinstance(a, torch.Tensor) for a in args)
|
|
732
|
+
and isinstance(batch_size, torch.SymInt)
|
|
729
733
|
):
|
|
730
734
|
batched_tensors = [
|
|
731
735
|
(
|
|
@@ -735,7 +739,9 @@ def patched_vmap(func, in_dims=0, out_dims=0):
|
|
|
735
739
|
)
|
|
736
740
|
for arg, in_dim in zip(batched_args, in_dims_)
|
|
737
741
|
]
|
|
738
|
-
results = torch.ops.higher_order.scan(
|
|
742
|
+
results = torch.ops.higher_order.scan(
|
|
743
|
+
lambda *args, **kwargs: [func(*args, **kwargs)], [], batched_tensors, []
|
|
744
|
+
)
|
|
739
745
|
stacked = results[0]
|
|
740
746
|
if out_dims != 0:
|
|
741
747
|
return stacked.movedim(0, out_dims)
|
|
@@ -745,7 +751,7 @@ def patched_vmap(func, in_dims=0, out_dims=0):
|
|
|
745
751
|
torch._check(
|
|
746
752
|
not isinstance(batch_size, torch.SymInt),
|
|
747
753
|
lambda: (
|
|
748
|
-
f"patched_vmap supports dynamic batch_size only if all
|
|
754
|
+
f"patched_vmap supports dynamic batch_size only if all arguments "
|
|
749
755
|
f"are tensors but types are {[type(a) for a in args]}"
|
|
750
756
|
),
|
|
751
757
|
)
|