onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/torch_export_patches/patch_module.py
@@ -0,0 +1,1047 @@
import ast
import copy
import contextlib
import inspect
import os
import types
import textwrap
import sys
from typing import Callable, Dict, List, Set, Optional, Tuple, Union
from .patch_module_helper import code_needing_rewriting
from .patch_details import PatchDetails, make_diff_code, clean_code_with_black

NODE_TYPES = tuple(
    getattr(ast, k)
    for k in dir(ast)
    if "A" <= k[0] <= "Z" and isinstance(getattr(ast, k), type)
)


def _settl(node, lineno, level=0):
    if isinstance(node, (str, int, float)):
        return node
    if isinstance(node, list):
        for n in node:
            _settl(n, lineno, level=level + 1)
        return node
    if isinstance(node, NODE_TYPES):
        if not hasattr(node, "lineno") or node.lineno is None:
            node.lineno = lineno
        for k in dir(node):
            if k in {"s", "n", "parent"}:
                continue
            if k[0] == "_":
                continue
            v = getattr(node, k)
            _settl(v, max(lineno, node.lineno), level=level + 1)
    return node


class UsedVarsFinder(ast.NodeVisitor):
"""Finds used and defined local variables with a section."""

    def __init__(self):
        self.used = set()
        self.defined = set()

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load):
            self.used.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.defined.add(node.id)
        self.generic_visit(node)

    def visit_Global(self, node):
        pass

    def visit_Nonlocal(self, node):
        pass


class ShapeFinder(ast.NodeVisitor):
    """Finds <x> in the expression ``x.shape[0]``."""

    def __init__(self):
        self.found_shape = set()
        super().__init__()

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id == "range" and len(node.args) == 1:
            n = node.args[0]
            if (
                isinstance(n, ast.Subscript)
                and isinstance(n.slice, ast.Constant)
                and isinstance(n.slice.value, int)
                and n.slice.value == 0
                and isinstance(n.value, ast.Attribute)
                and isinstance(n.value.value, ast.Name)
                and n.value.attr == "shape"
            ):
                self.found_shape.add(n.value.value.id)

        self.generic_visit(node)


class RewriteControlFlow(ast.NodeTransformer):
    """
    The class rewrites tests with function :func:`torch.cond`.
    ``empty_tensor`` is a function returning an empty tensor,
    used when a branch returns something the other branch does not.

    :param prefix: prefix used for nested tests
    :param skip_objects: variable names to skip if included in that list,
        such as modules
    :param args_names: defines the local variables
    :param filter_node: a function used to decide which nodes
        to rewrite, True by default
    :param pre_rewriter: a rewriter applied before the automated rewriting
    :param post_rewriter: a rewriter applied after the automated rewriting
    """

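    # For example (a rough, hypothetical sketch based on the code below,
    # with the default prefix), a test such as
    #     if x.sum() > 0:
    #         r = x + y
    #     else:
    #         r = x - y
    # is rewritten into two local branch functions and one call:
    #     def branch_cond_then_1(x, y):
    #         r = x + y
    #         return r.clone()
    #     def branch_cond_else_1(x, y):
    #         r = x - y
    #         return r.clone()
    #     r = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x, y])
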
    def __init__(
        self,
        prefix: str = "branch_cond",
        skip_objects: Optional[Dict[str, object]] = None,
        args_names: Optional[Set[str]] = None,
        filter_node: Optional[Callable[["ast.Node"], bool]] = None,
        pre_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
        post_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
    ):
        self.counter_test = 0
        self.counter_loop = 0
        self.current_func_args = None
        self.prefix = prefix
        self.skip_objects = skip_objects or {}
        self.args_names = args_names or set()
        self.local_variables = self.args_names.copy()
        self.filter_node = filter_node or (lambda _node: True)
        self.pre_rewriter = pre_rewriter or (lambda node: node)
        self.post_rewriter = post_rewriter or (lambda node: node)

    def generic_visit(self, node):
        return super().generic_visit(node)

    def _check(
        self, cond: bool, node: "ast.Node", msg: str, cls: Optional[type[Exception]] = None
    ):
        """
        Checks the condition is True, otherwise raises an exception with an error message
        including the parsed code.
        """
        if cls is not None:
            if not cond:
                smsg = msg if isinstance(msg, str) else msg()
                raise cls(f"{smsg}\n\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}")
            return
        assert cond, (
            f"{msg if isinstance(msg, str) else msg()}\n\n--\n"
            f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
        )

    def visit_Name(self, node):
        node = self.generic_visit(node)
        if isinstance(node.ctx, ast.Store):
            self.local_variables.add(node.id)
        return node

    def visit_FunctionDef(self, node):
        # Capture argument names for branch functions
        old_args = self.current_func_args
        self.current_func_args = [arg.arg for arg in node.args.args]
        node.body = [self.visit(n) for n in node.body]
        self.current_func_args = old_args
        return node

    def _find_id(self, exprs: List["ast.Node"]) -> List[str]:
        vars = []
        for expr in exprs:
            for n in ast.walk(expr):
                if (
                    isinstance(n, ast.Name)
                    # and isinstance(n.ctx, ast.Load)
                    and n.id not in self.skip_objects
                ):
                    vars.append(n.id)
        return sorted(set(vars))

    def _clone(self, name):
        assert isinstance(name, ast.Name), f"Unexpected type {type(name)} for name"
        return ast.Call(
            func=ast.Attribute(value=name, attr="clone", ctx=ast.Load()), args=[], keywords=[]
        )

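    # Builds the two branch functions and the ``torch.cond`` call. Both
    # branches must expose the same outputs, so a variable assigned in only
    # one branch is either dropped (when unused and not an input) or
    # completed with the existing value from the other branch; every
    # returned value is cloned.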
    def _rewrite_if(
        self, node, then_exprs, else_exprs, tgt_mapping=None, known_local_variables=None
    ):
        assert known_local_variables is not None, "known_local_variables cannot be None"
        test_node = node.test
        drop = set()

        # extract free variables
        then_name = f"{self.prefix}_then_{self.counter_test}"
        else_name = f"{self.prefix}_else_{self.counter_test}"
        then_vars = self._find_id(then_exprs)
        else_vars = self._find_id(else_exprs)
        then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ in known_local_variables)
        then_ret, else_ret = None, None
        if tgt_mapping is None and len(then_exprs) == 1 and len(else_exprs) == 1:
            # return
            then_ret = then_exprs[0]
            else_ret = else_exprs[0]
            then_exprs = [n for n in node.body if not isinstance(n, ast.Return)]
            else_exprs = [n for n in node.orelse if not isinstance(n, ast.Return)]
            is_tuple_or_list = (
                isinstance(then_ret, (ast.Tuple, ast.List)),
                isinstance(else_ret, (ast.Tuple, ast.List)),
            )
            assert len(set(is_tuple_or_list)) == 1, (
                f"is_tuple_or_list={is_tuple_or_list}, inconsistent return values, "
                f"then value={then_ret}, "
                f"else value={else_ret}"
            )
            if is_tuple_or_list[0]:
                assert len(then_ret.elts) == len(else_ret.elts), (
                    f"Unexpected number of elements on both branches, "
                    f"then:{then_ret.elts}, else:{else_ret.elts}"
                )
                n_returned_values = len(then_ret.elts)
            else:
                n_returned_values = 0
        else:
            self._check(
                tgt_mapping,
                node,
                "then and else branches do not have the same number "
                "of assignments, we need more information to understand "
                "which ones to return",
            )
            drop = set()
            then_exprs, else_exprs = node.body, node.orelse
            then_rets, else_rets = [], []
            for t, then_else in sorted(tgt_mapping.items()):
                then_e, else_e = then_else
                if (then_e is None or else_e is None) and t not in then_else_vars:
                    # The variable is not used by one branch and it is not an input.
                    # Let's drop it.
                    drop.add(t)
                    continue
                then_rets.append(then_e or ast.Name(else_e.id, ctx=ast.Load()))
                else_rets.append(else_e or ast.Name(then_e.id, ctx=ast.Load()))
            then_ret = (
                self._clone(then_rets[0])
                if len(then_rets) == 1
                else ast.Tuple([self._clone(r) for r in then_rets], ctx=ast.Load())
            )
            else_ret = (
                self._clone(else_rets[0])
                if len(else_rets) == 1
                else ast.Tuple([self._clone(r) for r in else_rets], ctx=ast.Load())
            )
            n_returned_values = len(then_rets) if len(then_rets) > 1 else 0

        # build local funcs
        then_def = ast.FunctionDef(
            name=then_name,
            args=ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg=v, annotation=None) for v in then_else_vars],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=[*then_exprs, ast.Return(then_ret)],
            decorator_list=[],
            returns=None,
        )
        else_def = ast.FunctionDef(
            name=else_name,
            args=ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg=v, annotation=None) for v in then_else_vars],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=[*else_exprs, ast.Return(else_ret)],
            decorator_list=[],
            returns=None,
        )
        # fix locations
        for n in (then_def, else_def):
            ast.copy_location(n, node)
            ast.fix_missing_locations(n)
            assert hasattr(n, "lineno")
        # wrapper call and assignment
        then_else_args_list = ast.List(
            [ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
            ctx=ast.Load(),
        )

        call = ast.Call(
            func=ast.Attribute(
                value=ast.Name(id="torch", ctx=ast.Load()), attr="cond", ctx=ast.Load()
            ),
            args=[
                test_node,
                ast.Name(id=then_name, ctx=ast.Load()),
                ast.Name(id=else_name, ctx=ast.Load()),
                then_else_args_list,
            ],
            keywords=[],
        )
        return then_def, else_def, call, drop, n_returned_values

    def _filter_target(self, node, tgt_mapping):
        """
        This function should reduce the number of elements to return
        by looking at the ones used after the If statement.
        """
        return tgt_mapping

    def _make_targets(self, node, then_assigns, else_assigns):
        tgt_mapping = {}
        for a, then_or_else in [
            *[(a, True) for a in then_assigns],
            *[(a, False) for a in else_assigns],
        ]:
            for t in a.targets:
                if isinstance(t, ast.Name) and isinstance(t.ctx, ast.Store):
                    if t.id not in tgt_mapping:
                        tgt_mapping[t.id] = (t, None) if then_or_else else (None, t)
                    else:
                        v = tgt_mapping[t.id]
                        tgt_mapping[t.id] = (t, v[1]) if then_or_else else (v[0], t)
                    continue

                self._check(
                    isinstance(t, ast.Tuple) and all(isinstance(_, ast.Name) for _ in t.elts),
                    node,
                    "Unexpected assignment. Not Supported.",
                )
                for _t in t.elts:
                    if not isinstance(_t, ast.Name) or not isinstance(_t.ctx, ast.Store):
                        continue
                    if _t.id not in tgt_mapping:
                        tgt_mapping[_t.id] = (_t, None) if then_or_else else (None, _t)
                    else:
                        v = tgt_mapping[_t.id]
                        tgt_mapping[_t.id] = (_t, v[1]) if then_or_else else (v[0], _t)

        tgt_mapping = self._filter_target(node, tgt_mapping)
        d = [(v[0] or v[1]) for k, v in sorted(dict(tgt_mapping).items())]
        tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load())
        return tgt, tgt_mapping

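    # Rewrites ``if`` statements. Two cases are handled: both branches
    # assign variables (rewritten as ``<targets> = torch.cond(...)``) or
    # both branches end with ``return`` (rewritten as
    # ``return torch.cond(...)``); mixing the two is not supported.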
    def visit_If(self, node):
        if not self.filter_node(node):
            return [node]

        node = self.pre_rewriter(node)

        # First recurse into subnodes
        known_local_variables = self.local_variables.copy()
        node = self.generic_visit(node)

        has_then_return = any(isinstance(n, ast.Return) for n in node.body)
        has_else_return = any(isinstance(n, ast.Return) for n in node.orelse)
        ok = (has_then_return and has_else_return) or (
            not has_then_return and not has_else_return
        )
        self._check(
            ok,
            node,
            "Cannot mix return and assignment in a test or a "
            "unique then branch with a return",
            NotImplementedError,
        )
        self._check(self.current_func_args is not None, node, "current_func_args is None")
        self.counter_test += 1

        if not has_then_return:
            # Case 1: simple assignment in both branches
            then_assigns = [n for n in node.body if isinstance(n, ast.Assign)]
            else_assigns = [n for n in node.orelse if isinstance(n, ast.Assign)]
            self._check(then_assigns or else_assigns, node, "Missing assignment")

            # the targets we need to export
            tgt, tgt_mapping = self._make_targets(node, then_assigns, else_assigns)

            then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
                node,
                then_assigns,
                else_assigns,
                tgt_mapping=tgt_mapping,
                known_local_variables=known_local_variables,
            )
            if isinstance(tgt, ast.Tuple):
                tgt_elts = tuple(t for t in tgt.elts if t.id not in dropped)
            else:
                tgt_elts = [tgt]

            if n_returned_values == 0:
                assert len(tgt_elts) == 1, (
                    f"Inconsistencies between n_returned_values={n_returned_values}, "
                    f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
                )
                tgt = tgt_elts[0]
            else:
                assert n_returned_values == len(tgt_elts), (
                    f"Inconsistencies between n_returned_values={n_returned_values}, "
                    f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
                )
                tgt = ast.Tuple(tgt_elts, ctx=ast.Store())

            added = {tgt.id} if isinstance(tgt, ast.Name) else set(t.id for t in tgt.elts)
            assign = ast.Assign(targets=[tgt], value=call)
            ast.copy_location(assign, node)
            ast.fix_missing_locations(assign)
            self.local_variables = known_local_variables | added
            return [self.post_rewriter(n) for n in [then_def, else_def, assign]]

        # Case 2: return in both branches, we assume both branches return the same results.
        then_ret = node.body[-1]
        else_ret = node.orelse[-1]
        self._check(
            isinstance(then_ret, ast.Return),
            node,
            "return is not the last instruction of then branch",
        )
        self._check(
            isinstance(else_ret, ast.Return),
            node,
            "return is not the last instruction of else branch",
        )
        then_expr = then_ret.value
        else_expr = else_ret.value
        then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
            node, [then_expr], [else_expr], known_local_variables=known_local_variables
        )
        ret = ast.Return(call)
        ast.copy_location(ret, node)
        ast.fix_missing_locations(ret)
        return [self.post_rewriter(n) for n in [then_def, else_def, ret]]

    def _find_loop_vars(self, node):
        assert isinstance(node, ast.For), f"Unexpected type {type(node)} for node"
        finder = ShapeFinder()
        finder.visit(node.iter)
        scan_shape_vars = finder.found_shape
        scan_vars = set()

        finder = UsedVarsFinder()
        for stmt in node.body:
            finder.visit(stmt)

        assigned_in_body = set()
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                for tgt in stmt.targets:
                    if isinstance(tgt, ast.Name) and isinstance(tgt.ctx, ast.Store):
                        assigned_in_body |= {tgt.id}

        extra_defined = set()
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                for tgt in stmt.targets:
                    if isinstance(tgt, ast.Subscript):
                        # It means the target existed before.
                        if (
                            isinstance(tgt.value, ast.Name)
                            and tgt.value.id not in assigned_in_body
                        ):
                            extra_defined.add(tgt.value.id)

        loop_vars = set()
        if isinstance(node.target, ast.Name):
            loop_vars.add(node.target.id)
        elif isinstance(node.target, (ast.Tuple, ast.List)):
            loop_vars |= {elt.id for elt in node.target.elts if isinstance(elt, ast.Name)}

        output_vars = finder.defined | assigned_in_body
        input_vars = (
            finder.used
            - finder.defined
            - loop_vars
            - scan_shape_vars
            - scan_vars
            - output_vars
            - assigned_in_body
            - extra_defined
        )
        return dict(
            init=sorted(extra_defined),
            loop=sorted(loop_vars),
            scan_shape=sorted(scan_shape_vars),
            scan=sorted(scan_vars),
            input=sorted(input_vars),
            output=sorted(output_vars),
        )

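    # Rewrites a loop such as ``for i in range(x.shape[0]): ...`` into a
    # call to ``torch.ops.higher_order.scan(loop_body_0, [carried values],
    # [scanned values], [additional inputs])`` where ``loop_body_0`` wraps
    # the original loop body.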
    def visit_For(self, node):
        if not self.filter_node(node):
            return [node]

        node = self.pre_rewriter(node)

        # For nested loops.
        self.generic_visit(node)
        # look for variables, loop, inputs and outputs of the body
        vars = self._find_loop_vars(node)
        init_vars, loop_vars, scan_shape_vars, scan_vars, input_vars, output_vars = [
            vars[k] for k in ["init", "loop", "scan_shape", "scan", "input", "output"]
        ]
        self._check(
            len(scan_shape_vars) == len(loop_vars),
            node,
            lambda: (
                f"Inconsistencies between loop_vars={loop_vars} "
                f"and scan_shape_vars={scan_shape_vars}"
            ),
        )
        self._check(
            len(scan_shape_vars) in {0, 1},
            node,
            lambda: f"Inconsistencies with scan_shape_vars={scan_shape_vars}",
        )
        self._check(
            (len(scan_shape_vars) == 0 or len(scan_vars) == 0)
            and (scan_shape_vars or scan_vars),
            node,
            lambda: (
                f"Inconsistencies between scan_vars={scan_vars} "
                f"and scan_shape_vars={scan_shape_vars}"
            ),
        )

        # creates the function
        func_name = f"loop_body_{self.counter_loop}"
        self.counter_loop += 1
        func_def = ast.FunctionDef(
            name=func_name,
            args=ast.arguments(
                posonlyargs=[],
                args=[
                    ast.arg(arg=v)
                    for v in [
                        *init_vars,
                        *loop_vars,
                        *scan_vars,
                        *scan_shape_vars,
                        *input_vars,
                    ]
                ],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=[
                *[
                    ast.Assign(
                        targets=[ast.Name(id=i, ctx=ast.Load())],
                        value=[
                            ast.Call(
                                func=ast.Attribute(
                                    value=ast.Name(id=i, ctx=ast.Load()),
                                    attr="clone",
                                    ctx=ast.Load(),
                                ),
                                args=[],
                                keywords=[],
                                ctx=ast.Load(),
                            )
                        ],
                    )
                    for i in init_vars
                ],
                *node.body,
                ast.Return(
                    value=ast.List(
                        [
                            ast.Name(id=v, ctx=ast.Load())
                            for v in [*init_vars, *loop_vars, *output_vars]
                        ],
                        ctx=ast.Load(),
                    )
                ),
            ],
            decorator_list=[],
            ctx=ast.Store(),
        )

        # final rewriting
        call = ast.Call(
            func=(
                ast.Attribute(
                    value=ast.Attribute(
                        value=ast.Attribute(
                            value=ast.Name(id="torch", ctx=ast.Load()),
                            attr="ops",
                            ctx=ast.Load(),
                        ),
                        attr="higher_order",
                        ctx=ast.Load(),
                    ),
                    attr="scan",
                    ctx=ast.Load(),
                )
            ),
            args=[
                ast.Name(id=func_name, ctx=ast.Load()),
                ast.List(
                    elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Store()
                ),
                ast.List(
                    elts=[
                        *[
                            ast.Call(
                                ast.Attribute(
                                    value=ast.Name(id="torch", ctx=ast.Load()),
                                    attr="arange",
                                    ctx=ast.Load(),
                                ),
                                args=[
                                    ast.Subscript(
                                        value=ast.Attribute(
                                            value=ast.Name(id=v, ctx=ast.Load()),
                                            attr="shape",
                                            ctx=ast.Load(),
                                        ),
                                        slice=ast.Constant(value=0, ctx=ast.Load()),
                                        ctx=ast.Load(),
                                    ),
                                ],
                                keywords=[
                                    ast.keyword(
                                        arg="dtype",
                                        value=ast.Attribute(
                                            value=ast.Name(id="torch", ctx=ast.Load()),
                                            attr="int64",
                                            ctx=ast.Load(),
                                        ),
                                    )
                                ],
                                ctx=ast.Load(),
                            )
                            for v in scan_shape_vars
                        ],
                        *[ast.Name(id=v, ctx=ast.Load()) for v in scan_vars],
                    ],
                    ctx=ast.Store(),
                ),
                ast.List(
                    elts=[
                        ast.Name(id=v, ctx=ast.Load()) for v in [*scan_shape_vars, *input_vars]
                    ],
                    ctx=ast.Store(),
                ),
            ],
            keywords=[],
            ctx=ast.Load(),
        )
        target = ast.Tuple(
            [ast.Name(id=v, ctx=ast.Store()) for v in [*init_vars, *loop_vars, *output_vars]],
            ctx=ast.Store(),
        )
        assign = ast.Assign(targets=[target], value=call)
        return [self.post_rewriter(func_def), self.post_rewriter(assign)]


class RewrittenMethod:
    """
    Stores a rewritten method using
    :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.

    :param tree: ast tree
    :param func: callable compiled from the tree
    """

    def __init__(self, tree, func):
        self.tree = tree
        self.func = func

    @property
    def code(self) -> str:
        """Returns the source."""
        return ast.unparse(self.tree)

    @property
    def dump(self) -> str:
        """Returns the tree dumped as a string."""
        return ast.dump(self.tree, indent=2)

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}({self.func})"


class _AddParentTransformer(ast.NodeTransformer):
    parent = None

    def visit(self, node):
        node.parent = self.parent
        self.parent = node
        node = super().visit(node)
        if isinstance(node, ast.AST):
            self.parent = node.parent
        return node


class _SelectiveAssignNormalizer(ast.NodeTransformer):
    def visit_If(self, node):
        self.generic_visit(node)
        node.body = [self._transform_if_needed(stmt) for stmt in node.body]
        node.orelse = [self._transform_if_needed(stmt) for stmt in node.orelse]
        return node

    def _transform_if_needed(self, stmt):
        if isinstance(stmt, ast.AugAssign):
            return ast.Assign(
                targets=[stmt.target],
                value=ast.BinOp(left=copy.deepcopy(stmt.target), op=stmt.op, right=stmt.value),
            )
        if isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
            return ast.Assign(targets=[stmt.target], value=stmt.value)
        return self.visit(stmt)


def inplace_add_parent(tree: "ast.Node"):
    """Adds parents to an AST tree."""
    _AddParentTransformer().visit(tree)


def normalize_assignment_in_test(tree: "ast.Node"):
    """Split AugAssign into BinOp and Assign to simplify whatever comes after."""
    _SelectiveAssignNormalizer().visit(tree)
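

# For instance, ``normalize_assignment_in_test`` rewrites ``x += y`` inside a
# test into ``x = x + y`` and an annotated assignment ``x: T = y`` into
# ``x = y``, so that the control-flow rewriter only has to handle plain
# assignments.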


def transform_method(
    func: Callable,
    prefix: str = "branch_cond",
    verbose: int = 0,
    filter_node: Optional[Callable[["ast.Node"], bool]] = None,
    pre_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
    post_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
) -> RewrittenMethod:
    """
    Returns a new function based on `func` where every test (if)
    is replaced by a call to :func:`torch.cond`.
    Some known rewritings are part of the default patches
    (see :ref:`l-control-flow-rewriting`).

    Both branches of a test must return the same things if they return
    something, or assign the same things. A test cannot return in one branch
    and assign in the other branch.

    .. warning:: room for improvement

        When a test assigns values to variables,
        the current implementation does not check which ones are really used
        after the test. The rewritten local functions return every
        assigned variable. This could be reduced.
        See method ``_filter_target``.

    :param func: method or function to rewrite
    :param prefix: prefix used to create the functions for the branches
    :param verbose: verbosity
    :param filter_node: a function which tells which node to rewrite
    :param pre_rewriter: a rewriter applied before the automated rewriting
    :param post_rewriter: a rewriter applied after the automated rewriting
    :return: rewritten method

    An example with **return**:

    .. runpython::
        :showcode:
        :process:
        :store_in_file: test_example_transform_method_1.py

        import torch
        from onnx_diagnostic.torch_export_patches.patch_module import transform_method

        class Model(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > 0:
                    return x + y, x - y
                else:
                    return torch.abs(x) + y, torch.abs(x) - y

        x, y = torch.rand((3, 4)), torch.rand((3, 4))
        expected = Model()(x, y)

        rewritten = transform_method(Model.forward)
        print("-- code --")
        print(rewritten.code)

        print(" -- export --")
        Model.forward = rewritten.func

        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
        print(ep)

    An example with **assignments**:

    .. runpython::
        :showcode:
        :process:
        :store_in_file: test_example_transform_method_2.py

        import torch
        from onnx_diagnostic.torch_export_patches.patch_module import transform_method

        class Model(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > 0:
                    w = x + y
                    z = x - y
                else:
                    w = torch.abs(x) + y
                    z = torch.abs(x) - y
                return w, z

        x, y = torch.rand((3, 4)), torch.rand((3, 4))
        expected = Model()(x, y)

        rewritten = transform_method(Model.forward)
        print("-- code --")
        print(rewritten.code)

        print(" -- export --")
        Model.forward = rewritten.func

        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
        print(ep)
    """
    # Retrieve source of the function
    modules = {k: v for k, v in func.__globals__.items() if inspect.ismodule(v)}
    src = inspect.getsource(func)
    sig = inspect.signature(func)
    if verbose:
        print(f"[transform_method] -- source -- {func}\n\n{src}\n\n[transform_method] --")
    # Parse into AST
    tree = ast.parse(textwrap.dedent(src))
    if verbose > 1:
        print(f"[transform_method] -- tree --\n\n{ast.dump(tree, indent=2)}")
    # Apply transformation
    transformer = RewriteControlFlow(
        prefix=prefix,
        skip_objects=modules,
        args_names=set(sig.parameters),
        filter_node=filter_node,
        pre_rewriter=pre_rewriter,
        post_rewriter=post_rewriter,
    )
    normalize_assignment_in_test(tree)
    inplace_add_parent(tree)
    new_tree = transformer.visit(tree)
    if verbose > 1:
        print(f"[transform_method] -- new tree --\n\n{ast.dump(tree, indent=2)}")
    ast.fix_missing_locations(new_tree)
    _settl(new_tree, 0)

    if verbose > 0:
        print(
            f"[transform_method] -- new code --\n\n"
            f"{ast.unparse(new_tree)}\n\n[transform_method] --"
        )
    try:
        mod = compile(new_tree, filename="<ast>", mode="exec")
    except TypeError as e:
        if 'required field "lineno" missing from stmt' in str(e):
            # Could not find a way to avoid compiling a string.
            # The error message still pops up without indicating which node is not
            # properly set.
            code = ast.unparse(new_tree)
            mod = compile(code, filename="<source>", mode="exec")
        else:
            kws = dict(include_attributes=True, annotate_fields=True, indent=4)
            raise RuntimeError(
                f"Unable to compile code\n--CODE--\n"
                f"{ast.unparse(new_tree)}\n--TREE--\n"
                f"{ast.dump(new_tree, **kws)}"
            ) from e
    namespace: Dict[str, type] = {}
    globs = func.__globals__.copy()
    exec(mod, globs, namespace)
    new_func = namespace.get(func.__name__)
    if not isinstance(new_func, types.FunctionType):
        raise RuntimeError("Transformed function not found")
    return RewrittenMethod(new_tree, new_func)


@contextlib.contextmanager
def torch_export_rewrite(
    rewrite: Optional[
        Union["torch.nn.Module", List[Union[Tuple[type, str], Callable]]]  # noqa: F821
    ] = None,
    dump_rewriting: Optional[str] = None,
    verbose: int = 0,
    patch_details: Optional[PatchDetails] = None,
):
    """
    Automatically rewrite the methods given in `rewrite` to export
    control flows (tests and loops).

    :param rewrite: methods or functions to rewrite; if empty, the function may try
        to discover them. A method is defined by its class (a type) and its name
        if the class is local, and by itself otherwise. It can also be a model;
        in that case, the function calls :func:`code_needing_rewriting
        <onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting>`
        to retrieve the necessary rewritings
    :param dump_rewriting: dumps the rewritings into that folder; if the folder
        does not exist, it is created
    :param verbose: verbosity, up to 10;
        ``verbose=1`` shows the rewritten function,
        ``verbose=2`` shows the rewritten code as well
    :param patch_details: to store any applied patch and get a better understanding
        of the applied modifications

    Example:

    .. code-block:: python

        class Model(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > 0:
                    return x + y
                else:
                    return torch.abs(x) + y + 1

        model = Model()
        x, y = torch.rand((4, 5)), torch.rand((4, 5))
        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        with torch_export_rewrite(rewrite=[(Model, "forward")]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

    If the method to rewrite is not local, then the following can be used:

    .. code-block:: python

        with torch_export_rewrite(rewrite=[Model.forward]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

    Functions (if not local) can also be rewritten:

    .. code-block:: python

        def outside(x, y):
            if x.sum() > 0:
                return x + y
            else:
                return torch.abs(x) + y + 1

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return outside(x, y)

        model = Model()
        x, y = torch.rand((4, 5)), torch.rand((4, 5))
        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        with torch_export_rewrite(rewrite=[outside]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
    """
    if hasattr(rewrite, "forward"):
        # It is a torch.nn.Module.
        # Let's retrieve the known rewriting for this model class.
        rewrite = code_needing_rewriting(rewrite.__class__.__name__)
    assert rewrite, "rewrite is empty, automated discovery is not implemented yet"
    keep = {}
    for me in rewrite:
        if isinstance(me, tuple):
            assert len(me) == 2, f"Unexpected value for a rewritten method or function {me}"
            cls, name = me
            to_rewrite = getattr(cls, name)
            kind = "method"
            kws = {}  # type: ignore[var-annotated]
        else:
            if isinstance(me, dict):
                assert "function" in me and (
                    "filter_node" in me or "pre_rewriter" in me or "post_rewriter" in me
                ), (
                    f"If the rewriting code is defined as a dictionary, key "
                    f"'function' must be defined, other arguments must be understood by "
                    f"{transform_method.__name__}, "
                    f"the given value is {me!r}."
                )
                kws = me
                me = me["function"]
                del kws["function"]
            else:
                kws = {}
            name = me.__qualname__
            spl = name.split(".")
            if len(spl) == 1:
                # This is a function
                module = me.__module__
                if module in me.__globals__:
                    mod = me.__globals__[module]
                else:
                    assert module in sys.modules, (
                        f"Cannot find module name {module!r} in sys.modules or "
                        f"__globals__={sorted(me.__globals__)}"
                    )
                    mod = sys.modules[module]
                cls_name = module
                cls = mod
                name = name
                to_rewrite = me
                kind = "function"
            else:
                kind = "method"
                # This is a method
                assert len(spl) >= 2, (
f"{me} is not method, its name {name!r} does not contain a class name, "
                    f"dir(me)={dir(me)}"
                )
                cls_name = spl[-2]
                assert cls_name in me.__globals__, (
                    f"Class name {cls_name!r} from method {name!r} "
                    f"could not be found in set(me.__globals__)={sorted(me.__globals__)}"
                )
                cls = me.__globals__[cls_name]
                name = me.__name__
                to_rewrite = me
        assert hasattr(
            cls, name
), f"Method {name!r} inferred form {me} was not found in class {cls}."
        assert (cls, name) not in keep, f"{kind} {me} cannot be rewritten twice."
        if verbose:
            print(f"[torch_export_rewrite] rewrites {kind} {cls.__name__}.{name}")
        keep[cls, name] = to_rewrite
        if dump_rewriting:
            if not os.path.exists(dump_rewriting):
                os.makedirs(dump_rewriting)
            filename1 = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.original.py")
            if verbose:
                print(f"[torch_export_rewrite] dump original code in {filename1!r}")
            with open(filename1, "w") as f:
                code = clean_code_with_black(inspect.getsource(to_rewrite))
                f.write(code)
        rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0), **kws)
        if dump_rewriting:
            filename2 = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.rewritten.py")
            if verbose:
                print(f"[torch_export_rewrite] dump rewritten code in {filename2!r}")
            with open(filename2, "w") as f:
                rcode = clean_code_with_black(rewr.code)
                f.write(rcode)
            diff = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.diff")
            make_diff_code(code, rcode, diff)
        if patch_details:
            patch_details.append("rewrite", getattr(cls, name), rewr.func)
        setattr(cls, name, rewr.func)

    try:
        yield
    finally:
        for (cls, name), me in keep.items():
            if verbose:
                print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}")
            setattr(cls, name, me)