onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1047 @@
1
+ import ast
2
+ import copy
3
+ import contextlib
4
+ import inspect
5
+ import os
6
+ import types
7
+ import textwrap
8
+ import sys
9
+ from typing import Callable, Dict, List, Set, Optional, Tuple, Union
10
+ from .patch_module_helper import code_needing_rewriting
11
+ from .patch_details import PatchDetails, make_diff_code, clean_code_with_black
12
+
13
+ NODE_TYPES = tuple(
14
+ getattr(ast, k)
15
+ for k in dir(ast)
16
+ if "A" <= k[0] <= "Z" and isinstance(getattr(ast, k), type)
17
+ )
18
+
19
+
20
+ def _settl(node, lineno, level=0):
21
+ if isinstance(node, (str, int, float)):
22
+ return node
23
+ if isinstance(node, list):
24
+ for n in node:
25
+ _settl(n, lineno, level=level + 1)
26
+ return node
27
+ if isinstance(node, NODE_TYPES):
28
+ if not hasattr(node, "lineno") or node.lineno is None:
29
+ node.lineno = lineno
30
+ for k in dir(node):
31
+ if k in {"s", "n", "parent"}:
32
+ continue
33
+ if k[0] == "_":
34
+ continue
35
+ v = getattr(node, k)
36
+ _settl(v, max(lineno, node.lineno), level=level + 1)
37
+ return node
38
+
39
+
40
class UsedVarsFinder(ast.NodeVisitor):
    """
    Collects the names read (``used``) and the names written (``defined``)
    within the visited section of code.
    ``global`` and ``nonlocal`` declarations are ignored.
    """

    def __init__(self):
        self.used = set()
        self.defined = set()

    def visit_Name(self, node):
        context = node.ctx
        if isinstance(context, ast.Load):
            self.used.add(node.id)
        elif isinstance(context, ast.Store):
            self.defined.add(node.id)
        self.generic_visit(node)

    def visit_Global(self, node):
        # Declarations do not read or write anything.
        return None

    def visit_Nonlocal(self, node):
        # Declarations do not read or write anything.
        return None
59
+
60
+
61
class ShapeFinder(ast.NodeVisitor):
    """
    Collects every name ``x`` appearing in an expression ``range(x.shape[0])``
    into the set ``found_shape``.
    """

    def __init__(self):
        self.found_shape = set()
        super().__init__()

    @staticmethod
    def _range_first_dim_owner(call):
        """Returns ``x`` if ``call`` is ``range(x.shape[0])``, None otherwise."""
        func = call.func
        if not (isinstance(func, ast.Name) and func.id == "range" and len(call.args) == 1):
            return None
        arg = call.args[0]
        if not isinstance(arg, ast.Subscript):
            return None
        index = arg.slice
        if not (
            isinstance(index, ast.Constant)
            and isinstance(index.value, int)
            and index.value == 0
        ):
            return None
        base = arg.value
        if not (
            isinstance(base, ast.Attribute)
            and isinstance(base.value, ast.Name)
            and base.attr == "shape"
        ):
            return None
        return base.value.id

    def visit_Call(self, node):
        owner = self._range_first_dim_owner(node)
        if owner is not None:
            self.found_shape.add(owner)
        # nested calls may contain other ranges
        self.generic_visit(node)
83
+
84
+
85
class RewriteControlFlow(ast.NodeTransformer):
    """
    Rewrites the control flow of a function.

    Every test (``if``) is replaced by a call to :func:`torch.cond`.
    Every ``for`` loop is replaced by a call to ``torch.ops.higher_order.scan``.
    Each branch or loop body becomes a local function defined right before the call.
    ``visit_If`` and ``visit_For`` return lists of statements, and
    ``visit_FunctionDef`` splices them into the function body.

    :param prefix: prefix used to name the local functions created for the branches
    :param skip_objects: names ignored when collecting the variables a branch
        depends on, such as modules
    :param args_names: names of the local variables defined before the rewritten
        code (usually the function arguments)
    :param filter_node: a function which decides which node to rewrite,
        every node is rewritten by default
    :param pre_rewriter: a rewriter applied before the automated rewriting
    :param post_rewriter: a rewriter applied after the automated rewriting
    """

    def __init__(
        self,
        prefix: str = "branch_cond",
        skip_objects: Optional[Dict[str, object]] = None,
        args_names: Optional[Set[str]] = None,
        filter_node: Optional[Callable[["ast.Node"], bool]] = None,
        pre_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
        post_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
    ):
        # Counters giving a unique name to every generated function.
        self.counter_test = 0
        self.counter_loop = 0
        # Argument names of the function being visited, None outside any function.
        self.current_func_args = None
        self.prefix = prefix
        self.skip_objects = skip_objects or {}
        self.args_names = args_names or set()
        # Variables known to be defined at the current point of the visit.
        self.local_variables = self.args_names.copy()
        self.filter_node = filter_node or (lambda _node: True)
        self.pre_rewriter = pre_rewriter or (lambda node: node)
        self.post_rewriter = post_rewriter or (lambda node: node)

    def _check(
        self,
        cond,
        node: "ast.Node",
        msg: Union[str, Callable[[], str]],
        cls: Optional[type[Exception]] = None,
    ):
        """
        Checks the condition is truthy.

        Otherwise it raises an exception whose message includes the code
        of ``node`` and its dump.

        :param cond: condition to check, any truthy value passes
        :param node: node displayed in the error message
        :param msg: error message, or a function building it; the function is
            only called when the check fails
        :param cls: exception class to raise, AssertionError if None
        """
        if cond:
            return
        smsg = msg if isinstance(msg, str) else msg()
        text = f"{smsg}\n\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
        # An explicit raise rather than assert so that the check survives python -O.
        raise (cls or AssertionError)(text)

    def visit_Name(self, node):
        """Records every assigned name as a local variable."""
        node = self.generic_visit(node)
        if isinstance(node.ctx, ast.Store):
            self.local_variables.add(node.id)
        return node

    def visit_FunctionDef(self, node):
        """
        Rewrites the body of a function.

        The argument names are stored while the body is visited.
        """
        old_args = self.current_func_args
        self.current_func_args = [arg.arg for arg in node.args.args]
        try:
            new_body = []
            for stmt in node.body:
                res = self.visit(stmt)
                if res is None:
                    continue
                # visit_If and visit_For replace one statement by several ones,
                # they must be spliced into the body to keep a valid AST.
                if isinstance(res, list):
                    new_body.extend(res)
                else:
                    new_body.append(res)
            node.body = new_body
        finally:
            self.current_func_args = old_args
        return node

    def _find_id(self, exprs: List["ast.Node"]) -> List[str]:
        """Returns the sorted unique names found in ``exprs``, except skipped objects."""
        return sorted(
            {
                n.id
                for expr in exprs
                for n in ast.walk(expr)
                if isinstance(n, ast.Name) and n.id not in self.skip_objects
            }
        )

    def _clone(self, name):
        """Returns the expression ``<name>.clone()``."""
        assert isinstance(name, ast.Name), f"Unexpected type {type(name)} for name"
        # A new Name in Load context: ``name`` may be an assignment target (Store).
        return ast.Call(
            func=ast.Attribute(
                value=ast.Name(id=name.id, ctx=ast.Load()), attr="clone", ctx=ast.Load()
            ),
            args=[],
            keywords=[],
        )

    @staticmethod
    def _make_branch_function(name, arg_names, body, ret, location):
        """Creates ``def <name>(<arg_names>): <body>; return <ret>`` located at ``location``."""
        func = ast.FunctionDef(
            name=name,
            args=ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg=v, annotation=None) for v in arg_names],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=[*body, ast.Return(ret)],
            decorator_list=[],
            returns=None,
        )
        ast.copy_location(func, location)
        ast.fix_missing_locations(func)
        assert hasattr(func, "lineno")
        return func

    def _rewrite_if(
        self, node, then_exprs, else_exprs, tgt_mapping=None, known_local_variables=None
    ):
        """
        Builds the local functions for both branches of test ``node`` and
        the call to :func:`torch.cond` using them.

        :param node: the ``ast.If`` node
        :param then_exprs: the returned expression (a list of one element) if
            ``tgt_mapping`` is None, the assignments of the then branch otherwise
        :param else_exprs: same for the else branch
        :param tgt_mapping: assigned name -> (then target, else target); one of
            them is None when only one branch assigns the name
        :param known_local_variables: variables defined before the test; the ones
            used by a branch become the arguments of the local functions
        :return: then function, else function, call, dropped names,
            number of returned values (0 means a single value)
        """
        assert known_local_variables is not None, "known_local_variables cannot be None"
        test_node = node.test
        drop = set()

        # extract free variables, sorted to make the generated code deterministic
        then_name = f"{self.prefix}_then_{self.counter_test}"
        else_name = f"{self.prefix}_else_{self.counter_test}"
        then_vars = self._find_id(then_exprs)
        else_vars = self._find_id(else_exprs)
        then_else_vars = sorted(
            {v for v in (*then_vars, *else_vars) if v in known_local_variables}
        )
        then_ret, else_ret = None, None
        if tgt_mapping is None and len(then_exprs) == 1 and len(else_exprs) == 1:
            # both branches end with a return
            then_ret = then_exprs[0]
            else_ret = else_exprs[0]
            then_exprs = [n for n in node.body if not isinstance(n, ast.Return)]
            else_exprs = [n for n in node.orelse if not isinstance(n, ast.Return)]
            is_tuple_or_list = (
                isinstance(then_ret, (ast.Tuple, ast.List)),
                isinstance(else_ret, (ast.Tuple, ast.List)),
            )
            assert len(set(is_tuple_or_list)) == 1, (
                f"is_tuple_or_list={is_tuple_or_list}, inconsistencies return "
                f"then value={then_ret}, "
                f"else value={else_ret}"
            )
            if is_tuple_or_list[0]:
                assert len(then_ret.elts) == len(else_ret.elts), (
                    f"Unexpected number of elements on both branches, "
                    f"then:{then_ret.elts}, else:{else_ret.elts}"
                )
                n_returned_values = len(then_ret.elts)
            else:
                n_returned_values = 0
        else:
            self._check(
                tgt_mapping,
                node,
                "then and else branches do not have the same number "
                "of assignments, we need more information to understand "
                "which ones to return",
            )
            drop = set()
            then_exprs, else_exprs = node.body, node.orelse
            then_rets, else_rets = [], []
            for t, then_else in sorted(tgt_mapping.items()):
                then_e, else_e = then_else
                if (then_e is None or else_e is None) and t not in then_else_vars:
                    # The variable is not used by one branch and it is not an input.
                    # Let's drop it.
                    drop.add(t)
                    continue
                then_rets.append(then_e or ast.Name(else_e.id, ctx=ast.Load()))
                else_rets.append(else_e or ast.Name(then_e.id, ctx=ast.Load()))
            then_ret = (
                self._clone(then_rets[0])
                if len(then_rets) == 1
                else ast.Tuple([self._clone(r) for r in then_rets], ctx=ast.Load())
            )
            else_ret = (
                self._clone(else_rets[0])
                if len(else_rets) == 1
                else ast.Tuple([self._clone(r) for r in else_rets], ctx=ast.Load())
            )
            n_returned_values = len(then_rets) if len(then_rets) > 1 else 0

        # build local funcs
        then_def = self._make_branch_function(
            then_name, then_else_vars, then_exprs, then_ret, node
        )
        else_def = self._make_branch_function(
            else_name, then_else_vars, else_exprs, else_ret, node
        )
        # wrapper call
        then_else_args_list = ast.List(
            [ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
            ctx=ast.Load(),
        )
        call = ast.Call(
            func=ast.Attribute(
                value=ast.Name(id="torch", ctx=ast.Load()), attr="cond", ctx=ast.Load()
            ),
            args=[
                test_node,
                ast.Name(id=then_name, ctx=ast.Load()),
                ast.Name(id=else_name, ctx=ast.Load()),
                then_else_args_list,
            ],
            keywords=[],
        )
        return then_def, else_def, call, drop, n_returned_values

    def _filter_target(self, node, tgt_mapping):
        """
        Should reduce the number of elements to return by looking at the ones
        used after the If statement. It currently returns ``tgt_mapping`` unchanged.
        """
        return tgt_mapping

    def _make_targets(self, node, then_assigns, else_assigns):
        """
        Maps every name assigned in one branch to its targets.

        :return: the target of the final assignment (a Name or a Tuple, names
            sorted) and a dictionary name -> (then target, else target)
        """
        tgt_mapping = {}
        for a, then_or_else in [
            *[(a, True) for a in then_assigns],
            *[(a, False) for a in else_assigns],
        ]:
            for t in a.targets:
                if isinstance(t, ast.Name) and isinstance(t.ctx, ast.Store):
                    if t.id not in tgt_mapping:
                        tgt_mapping[t.id] = (t, None) if then_or_else else (None, t)
                    else:
                        v = tgt_mapping[t.id]
                        tgt_mapping[t.id] = (t, v[1]) if then_or_else else (v[0], t)
                    continue

                self._check(
                    isinstance(t, ast.Tuple) and all(isinstance(_, ast.Name) for _ in t.elts),
                    node,
                    "Unexpected assignment. Not Supported.",
                )
                for _t in t.elts:
                    if not isinstance(_t, ast.Name) or not isinstance(_t.ctx, ast.Store):
                        continue
                    if _t.id not in tgt_mapping:
                        tgt_mapping[_t.id] = (_t, None) if then_or_else else (None, _t)
                    else:
                        v = tgt_mapping[_t.id]
                        tgt_mapping[_t.id] = (_t, v[1]) if then_or_else else (v[0], _t)

        tgt_mapping = self._filter_target(node, tgt_mapping)
        d = [(v[0] or v[1]) for k, v in sorted(dict(tgt_mapping).items())]
        tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load())
        return tgt, tgt_mapping

    def visit_If(self, node):
        """
        Replaces a test by two local functions (one per branch) followed by
        a call to :func:`torch.cond`.

        The call is returned if both branches return; otherwise its result is
        assigned. Returns a list of statements.
        """
        if not self.filter_node(node):
            return [node]

        node = self.pre_rewriter(node)

        # variables known before the test, then recurse into subnodes
        known_local_variables = self.local_variables.copy()
        node = self.generic_visit(node)

        has_then_return = any(isinstance(n, ast.Return) for n in node.body)
        has_else_return = any(isinstance(n, ast.Return) for n in node.orelse)
        self._check(
            has_then_return == has_else_return,
            node,
            "Cannot mix return and assignment in a test or a "
            "unique then branch with a return",
            NotImplementedError,
        )
        self._check(self.current_func_args is not None, node, "current_func_args is None")
        self.counter_test += 1

        if not has_then_return:
            # Case 1: simple assignment in both branches
            then_assigns = [n for n in node.body if isinstance(n, ast.Assign)]
            else_assigns = [n for n in node.orelse if isinstance(n, ast.Assign)]
            self._check(then_assigns or else_assigns, node, "Missing assignment")

            # the targets we need to export
            tgt, tgt_mapping = self._make_targets(node, then_assigns, else_assigns)

            then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
                node,
                then_assigns,
                else_assigns,
                tgt_mapping=tgt_mapping,
                known_local_variables=known_local_variables,
            )
            if isinstance(tgt, ast.Tuple):
                tgt_elts = tuple(t for t in tgt.elts if t.id not in dropped)
            else:
                tgt_elts = [tgt]

            if n_returned_values == 0:
                assert len(tgt_elts) == 1, (
                    f"Inconsistencies between n_returned_values={n_returned_values}, "
                    f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
                )
                tgt = tgt_elts[0]
            else:
                assert n_returned_values == len(tgt_elts), (
                    f"Inconsistencies between n_returned_values={n_returned_values}, "
                    f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
                )
                tgt = ast.Tuple(tgt_elts, ctx=ast.Store())

            added = {tgt.id} if isinstance(tgt, ast.Name) else {t.id for t in tgt.elts}
            assign = ast.Assign(targets=[tgt], value=call)
            ast.copy_location(assign, node)
            ast.fix_missing_locations(assign)
            self.local_variables = known_local_variables | added
            return [self.post_rewriter(n) for n in [then_def, else_def, assign]]

        # Case 2: return in both branches, we assume both branches return the same results.
        then_ret = node.body[-1]
        else_ret = node.orelse[-1]
        self._check(
            isinstance(then_ret, ast.Return),
            node,
            "return is not the last instruction of then branch",
        )
        self._check(
            isinstance(else_ret, ast.Return),
            node,
            "return is not the last instruction of else branch",
        )
        then_expr = then_ret.value
        else_expr = else_ret.value
        then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
            node, [then_expr], [else_expr], known_local_variables=known_local_variables
        )
        ret = ast.Return(call)
        ast.copy_location(ret, node)
        ast.fix_missing_locations(ret)
        return [self.post_rewriter(n) for n in [then_def, else_def, ret]]

    def _find_loop_vars(self, node):
        """
        Classifies the variables of a ``for`` loop.

        :return: dictionary with sorted lists of names for keys
            ``init`` (tensors modified in place by the body, defined before it),
            ``loop`` (loop variables), ``scan_shape`` (``x`` in ``range(x.shape[0])``),
            ``scan`` (always empty currently), ``input`` (names only read by the
            body) and ``output`` (names assigned by the body)
        """
        assert isinstance(node, ast.For), f"Unexpected type {type(node)} for node"
        finder = ShapeFinder()
        finder.visit(node.iter)
        scan_shape_vars = finder.found_shape
        scan_vars = set()

        finder = UsedVarsFinder()
        for stmt in node.body:
            finder.visit(stmt)

        # names directly assigned by a top-level statement of the body
        assigned_in_body = set()
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                for tgt in stmt.targets:
                    if isinstance(tgt, ast.Name) and isinstance(tgt.ctx, ast.Store):
                        assigned_in_body.add(tgt.id)

        extra_defined = set()
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                for tgt in stmt.targets:
                    if isinstance(tgt, ast.Subscript):
                        # It means the target existed before.
                        if (
                            isinstance(tgt.value, ast.Name)
                            and tgt.value.id not in assigned_in_body
                        ):
                            extra_defined.add(tgt.value.id)

        loop_vars = set()
        if isinstance(node.target, ast.Name):
            loop_vars.add(node.target.id)
        elif isinstance(node.target, (ast.Tuple, ast.List)):
            loop_vars |= {elt.id for elt in node.target.elts if isinstance(elt, ast.Name)}

        output_vars = finder.defined | assigned_in_body
        input_vars = (
            finder.used
            - finder.defined
            - loop_vars
            - scan_shape_vars
            - scan_vars
            - output_vars
            - assigned_in_body
            - extra_defined
        )
        return dict(
            init=sorted(extra_defined),
            loop=sorted(loop_vars),
            scan_shape=sorted(scan_shape_vars),
            scan=sorted(scan_vars),
            input=sorted(input_vars),
            output=sorted(output_vars),
        )

    def visit_For(self, node):
        """
        Replaces a ``for`` loop by a local function holding its body, followed
        by a call to ``torch.ops.higher_order.scan``.

        Given the checks below and ``scan`` always being empty, only loops with
        one loop variable iterating over ``range(x.shape[0])`` are accepted.
        Returns a list of statements.
        """
        if not self.filter_node(node):
            return [node]

        node = self.pre_rewriter(node)

        # Nested loops and tests are rewritten first.
        self.generic_visit(node)
        # look for variables, loop, inputs and outputs of the body
        loop_info = self._find_loop_vars(node)
        init_vars, loop_vars, scan_shape_vars, scan_vars, input_vars, output_vars = [
            loop_info[k] for k in ["init", "loop", "scan_shape", "scan", "input", "output"]
        ]
        self._check(
            len(scan_shape_vars) == len(loop_vars),
            node,
            lambda: (
                f"Inconsistencies between loop_vars={loop_vars} "
                f"and scan_shape_vars={scan_shape_vars}"
            ),
        )
        self._check(
            len(scan_shape_vars) in {0, 1},
            node,
            lambda: f"Inconsistencies with scan_shape_vars={scan_shape_vars}",
        )
        self._check(
            (len(scan_shape_vars) == 0 or len(scan_vars) == 0)
            and (scan_shape_vars or scan_vars),
            node,
            lambda: (
                f"Inconsistencies between scan_vars={scan_vars} "
                f"and scan_shape_vars={scan_shape_vars}"
            ),
        )

        def torch_attr(*path):
            # builds the expression torch.<path[0]>.<path[1]>...
            expr = ast.Name(id="torch", ctx=ast.Load())
            for attr in path:
                expr = ast.Attribute(value=expr, attr=attr, ctx=ast.Load())
            return expr

        # creates the function
        func_name = f"loop_body_{self.counter_loop}"
        self.counter_loop += 1
        # Tensors updated in place by the body are cloned before being updated.
        clones = [
            ast.Assign(
                targets=[ast.Name(id=v, ctx=ast.Store())],
                value=self._clone(ast.Name(id=v, ctx=ast.Load())),
            )
            for v in init_vars
        ]
        carried = [*init_vars, *loop_vars, *output_vars]
        func_def = ast.FunctionDef(
            name=func_name,
            args=ast.arguments(
                posonlyargs=[],
                args=[
                    ast.arg(arg=v)
                    for v in [
                        *init_vars,
                        *loop_vars,
                        *scan_vars,
                        *scan_shape_vars,
                        *input_vars,
                    ]
                ],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=[
                *clones,
                *node.body,
                ast.Return(
                    value=ast.List(
                        [ast.Name(id=v, ctx=ast.Load()) for v in carried],
                        ctx=ast.Load(),
                    )
                ),
            ],
            decorator_list=[],
            returns=None,
        )

        # final rewriting: torch.arange(x.shape[0], dtype=torch.int64) for every scanned shape
        aranges = [
            ast.Call(
                func=torch_attr("arange"),
                args=[
                    ast.Subscript(
                        value=ast.Attribute(
                            value=ast.Name(id=v, ctx=ast.Load()),
                            attr="shape",
                            ctx=ast.Load(),
                        ),
                        slice=ast.Constant(value=0),
                        ctx=ast.Load(),
                    )
                ],
                keywords=[ast.keyword(arg="dtype", value=torch_attr("int64"))],
            )
            for v in scan_shape_vars
        ]
        call = ast.Call(
            func=torch_attr("ops", "higher_order", "scan"),
            args=[
                ast.Name(id=func_name, ctx=ast.Load()),
                ast.List(
                    elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Load()
                ),
                ast.List(
                    elts=[*aranges, *[ast.Name(id=v, ctx=ast.Load()) for v in scan_vars]],
                    ctx=ast.Load(),
                ),
                ast.List(
                    elts=[
                        ast.Name(id=v, ctx=ast.Load())
                        for v in [*scan_shape_vars, *input_vars]
                    ],
                    ctx=ast.Load(),
                ),
            ],
            keywords=[],
        )
        target = ast.Tuple(
            [ast.Name(id=v, ctx=ast.Store()) for v in carried],
            ctx=ast.Store(),
        )
        assign = ast.Assign(targets=[target], value=call)
        for n in (func_def, assign):
            ast.copy_location(n, node)
            ast.fix_missing_locations(n)
        return [self.post_rewriter(func_def), self.post_rewriter(assign)]
649
+
650
+
651
class RewrittenMethod:
    """
    Holds what
    :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`
    produces: the rewritten syntax tree and the callable built from it.

    :param tree: rewritten ast tree
    :param func: callable compiled from the tree
    """

    def __init__(self, tree, func):
        self.tree, self.func = tree, func

    @property
    def code(self) -> str:
        """Returns the source code obtained by unparsing the tree."""
        tree = self.tree
        return ast.unparse(tree)

    @property
    def dump(self) -> str:
        """Returns the tree dumped as a string, indented by two spaces."""
        tree = self.tree
        return ast.dump(tree, indent=2)

    def __repr__(self):
        "usual"
        cls_name = type(self).__name__
        return f"{cls_name}({self.func})"
677
+
678
+
679
+ class _AddParentTransformer(ast.NodeTransformer):
680
+ parent = None
681
+
682
+ def visit(self, node):
683
+ node.parent = self.parent
684
+ self.parent = node
685
+ node = super().visit(node)
686
+ if isinstance(node, ast.AST):
687
+ self.parent = node.parent
688
+ return node
689
+
690
+
691
+ class _SelectiveAssignNormalizer(ast.NodeTransformer):
692
+ def visit_If(self, node):
693
+ self.generic_visit(node)
694
+ node.body = [self._transform_if_needed(stmt) for stmt in node.body]
695
+ node.orelse = [self._transform_if_needed(stmt) for stmt in node.orelse]
696
+ return node
697
+
698
+ def _transform_if_needed(self, stmt):
699
+ if isinstance(stmt, ast.AugAssign):
700
+ return ast.Assign(
701
+ targets=[stmt.target],
702
+ value=ast.BinOp(left=copy.deepcopy(stmt.target), op=stmt.op, right=stmt.value),
703
+ )
704
+ if isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
705
+ return ast.Assign(targets=[stmt.target], value=stmt.value)
706
+ return self.visit(stmt)
707
+
708
+
709
def inplace_add_parent(tree: "ast.Node"):
    """
    Sets attribute ``parent`` on every node of ``tree`` (None for the root).
    The tree is modified in place and nothing is returned.
    """
    transformer = _AddParentTransformer()
    transformer.visit(tree)
712
+
713
+
714
def normalize_assignment_in_test(tree: "ast.Node"):
    """
    Rewrites ``tree`` in place.

    In every branch of every ``if`` statement, ``x += y`` becomes ``x = x + y``
    and ``x: T = y`` becomes ``x = y``. This simplifies the rewriting applied
    afterwards.
    """
    normalizer = _SelectiveAssignNormalizer()
    normalizer.visit(tree)
717
+
718
+
719
def transform_method(
    func: Callable,
    prefix: str = "branch_cond",
    verbose: int = 0,
    filter_node: Optional[Callable[["ast.AST"], bool]] = None,
    pre_rewriter: Optional[Callable[["ast.AST"], "ast.AST"]] = None,
    post_rewriter: Optional[Callable[["ast.AST"], "ast.AST"]] = None,
) -> RewrittenMethod:
    """
    Returns a new function based on `func` where every test (if)
    is replaced by a call to :func:`torch.cond`.
    Some known rewritings are part of the default patches
    (see :ref:`l-control-flow-rewriting`).

    Both branches of a test must return the same things if they return
    something, or assign the same things if they assign something.
    A test cannot return in one branch and assign in the other branch.

    .. warning:: room for improvement

        When a test assigns values to variables,
        the current implementation does not check which ones are really used
        after the test. The rewritten local functions return every
        assigned variable. This could be reduced.
        See method ``_filter_target``.

    :param func: method or function to rewrite, its source must be
        retrievable with :func:`inspect.getsource`
    :param prefix: prefix used to create the functions for the branches
    :param verbose: verbosity, ``verbose >= 1`` prints the original and the
        rewritten code, ``verbose >= 2`` also prints the AST trees
    :param filter_node: a function which tells which node to rewrite
    :param pre_rewriter: a rewriter applied before the automated rewriting
    :param post_rewriter: a rewriter applied after the automated rewriting
    :return: rewritten method, built as ``RewrittenMethod(new_tree, new_func)``
    :raises RuntimeError: if the rewritten tree cannot be compiled or if the
        rewritten code does not define a function named ``func.__name__``

    An example with **return**:

    .. runpython::
        :showcode:
        :process:
        :store_in_file: test_example_transform_method_1.py

        import torch
        from onnx_diagnostic.torch_export_patches.patch_module import transform_method

        class Model(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > 0:
                    return x + y, x - y
                else:
                    return torch.abs(x) + y, torch.abs(x) - y

        x, y = torch.rand((3, 4)), torch.rand((3, 4))
        expected = Model()(x, y)

        rewritten = transform_method(Model.forward)
        print("-- code --")
        print(rewritten.code)

        print(" -- export --")
        Model.forward = rewritten.func

        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
        print(ep)

    An example with **assignments**:

    .. runpython::
        :showcode:
        :process:
        :store_in_file: test_example_transform_method_2.py

        import torch
        from onnx_diagnostic.torch_export_patches.patch_module import transform_method

        class Model(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > 0:
                    w = x + y
                    z = x - y
                else:
                    w = torch.abs(x) + y
                    z = torch.abs(x) - y
                return w, z

        x, y = torch.rand((3, 4)), torch.rand((3, 4))
        expected = Model()(x, y)

        rewritten = transform_method(Model.forward)
        print("-- code --")
        print(rewritten.code)

        print(" -- export --")
        Model.forward = rewritten.func

        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
        print(ep)
    """
    # Retrieve the source of the function. Modules imported in the function's
    # globals are given to the rewriter as objects to skip.
    modules = {k: v for k, v in func.__globals__.items() if inspect.ismodule(v)}
    src = inspect.getsource(func)
    sig = inspect.signature(func)
    if verbose:
        print(f"[transform_method] -- source -- {func}\n\n{src}\n\n[transform_method] --")
    # Parse into AST; methods are indented in their class, hence the dedent.
    tree = ast.parse(textwrap.dedent(src))
    if verbose > 1:
        print(f"[transform_method] -- tree --\n\n{ast.dump(tree, indent=2)}")
    # Apply the transformation.
    transformer = RewriteControlFlow(
        prefix=prefix,
        skip_objects=modules,
        args_names=set(sig.parameters),
        filter_node=filter_node,
        pre_rewriter=pre_rewriter,
        post_rewriter=post_rewriter,
    )
    normalize_assignment_in_test(tree)
    inplace_add_parent(tree)
    new_tree = transformer.visit(tree)
    if verbose > 1:
        print(f"[transform_method] -- new tree --\n\n{ast.dump(new_tree, indent=2)}")
    ast.fix_missing_locations(new_tree)
    _settl(new_tree, 0)

    if verbose > 0:
        print(
            f"[transform_method] -- new code --\n\n"
            f"{ast.unparse(new_tree)}\n\n[transform_method] --"
        )
    try:
        mod = compile(new_tree, filename="<ast>", mode="exec")
    except TypeError as e:
        if 'required field "lineno" missing from stmt' in str(e):
            # Could not find a way to avoid compiling a string.
            # The error message still pops up without indicating which node is not
            # properly set, so the tree is unparsed and the code recompiled.
            code = ast.unparse(new_tree)
            mod = compile(code, filename="<source>", mode="exec")
        else:
            kws = dict(include_attributes=True, annotate_fields=True, indent=4)
            raise RuntimeError(
                f"Unable to compile code\n--CODE--\n"
                f"{ast.unparse(new_tree)}\n--TREE--\n"
                f"{ast.dump(new_tree, **kws)}"
            ) from e
    # The rewritten function resolves its global names against a snapshot
    # of the original function's globals taken at this point.
    namespace: Dict[str, object] = {}
    globs = func.__globals__.copy()
    exec(mod, globs, namespace)
    new_func = namespace.get(func.__name__)
    if not isinstance(new_func, types.FunctionType):
        raise RuntimeError("Transformed function not found")
    return RewrittenMethod(new_tree, new_func)
875
+
876
+
877
def _resolve_rewrite_target(me) -> Tuple[object, str, str, Callable, str, Dict]:
    """
    Resolves one entry of ``rewrite`` (see :func:`torch_export_rewrite`) into
    ``(owner, owner_name, attribute_name, original, kind, transform_kwargs)``,
    where ``owner`` is the class or module whose attribute gets replaced.
    """
    if isinstance(me, tuple):
        assert len(me) == 2, f"Unexpected value for a rewritten method or function {me}"
        # A local class cannot be found through __globals__, it is given explicitly.
        cls, name = me
        return cls, cls.__name__, name, getattr(cls, name), "method", {}

    kws: Dict = {}
    if isinstance(me, dict):
        assert "function" in me and (
            "filter_node" in me or "pre_rewriter" in me or "post_rewriter" in me
        ), (
            f"If the rewriting code is defined as a dictionary, key "
            f"'function' must be defined, other arguments must be understood by "
            f"{transform_method.__name__}, "
            f"the given value is {me!r}."
        )
        # Builds a new dictionary instead of deleting the key so that the
        # caller's entry stays reusable.
        kws = {k: v for k, v in me.items() if k != "function"}
        me = me["function"]

    name = me.__qualname__
    spl = name.split(".")
    if len(spl) == 1:
        # A function: the owner is the module defining it.
        module = me.__module__
        if module in me.__globals__:
            cls = me.__globals__[module]
        else:
            assert module in sys.modules, (
                f"Cannot find module name {module!r} in sys.modules or "
                f"__globals__={sorted(me.__globals__)}"
            )
            cls = sys.modules[module]
        cls_name = module
        kind = "function"
    else:
        # A method: the class name is the second-to-last part of the qualified name.
        kind = "method"
        cls_name = spl[-2]
        assert cls_name in me.__globals__, (
            f"Class name {cls_name!r} from method {name!r} "
            f"could not be found in set(me.__globals__)={sorted(me.__globals__)}"
        )
        cls = me.__globals__[cls_name]
        name = me.__name__
    assert hasattr(
        cls, name
    ), f"Method {name!r} inferred from {me} was not found in class {cls}."
    return cls, cls_name, name, me, kind, kws


@contextlib.contextmanager
def torch_export_rewrite(
    rewrite: Optional[
        Union["torch.nn.Module", List[Union[Tuple[type, str], Callable]]]  # noqa: F821
    ] = None,
    dump_rewriting: Optional[str] = None,
    verbose: int = 0,
    patch_details: Optional[PatchDetails] = None,
):
    """
    Automatically rewrite the methods given in `rewrite` to export
    control flows (test and loops).

    :param rewrite: methods or functions to rewrite. A method is given either
        as a tuple ``(cls, name)`` (required if the class is local) or by itself;
        a function is given by itself. An entry can also be a dictionary with
        key ``'function'`` and keyword arguments understood by
        :func:`transform_method` (``filter_node``, ``pre_rewriter``,
        ``post_rewriter``). If `rewrite` is a model, the function calls
        :func:`code_needing_rewriting
        <onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting>`
        with the model class name to retrieve the necessary rewritings.
        Automated discovery is not implemented yet, `rewrite` must not be empty.
    :param dump_rewriting: folder where the original code, the rewritten code
        and their diff are dumped for every rewritten function,
        it is created if it does not exist
    :param verbose: verbosity, ``verbose >= 1`` prints which functions are
        rewritten and restored, ``verbose >= 2`` also prints the original and
        rewritten code, ``verbose >= 3`` also prints the AST trees
    :param patch_details: to store any applied patch and get a better understanding
        of the applied modifications

    Every replaced attribute is restored when the context exits, including
    when an exception is raised while rewriting a later entry.

    Example:

    .. code-block:: python

        class Model(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > 0:
                    return x + y
                else:
                    return torch.abs(x) + y + 1

        model = Model()
        x, y = torch.rand((4, 5)), torch.rand((4, 5))
        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        with torch_export_rewrite(rewrite=[(Model, "forward")]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

    If the method to rewrite is not local, then the following can be used:

    .. code-block:: python

        with torch_export_rewrite(rewrite=[Model.forward]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

    Functions (if not local) can also be rewritten:

    .. code-block:: python

        def outside(x, y):
            if x.sum() > 0:
                return x + y
            else:
                return torch.abs(x) + y + 1

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return outside(x, y)

        model = Model()
        x, y = torch.rand((4, 5)), torch.rand((4, 5))
        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        with torch_export_rewrite(rewrite=[outside]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
    """
    if hasattr(rewrite, "forward"):
        # It is a torch.nn.Module.
        # Let's retrieve the known rewriting for this model class.
        rewrite = code_needing_rewriting(rewrite.__class__.__name__)
    assert rewrite, "rewrite is empty, automated discovery is not implemented yet"
    # (owner, attribute name) -> (original object, kind). An entry is only added
    # right before the attribute is replaced, so only patched attributes are restored.
    keep: Dict[Tuple[object, str], Tuple[Callable, str]] = {}
    try:
        for me in rewrite:
            cls, cls_name, name, to_rewrite, kind, kws = _resolve_rewrite_target(me)
            assert (cls, name) not in keep, f"{kind} {me} cannot be rewritten twice."
            if verbose:
                print(f"[torch_export_rewrite] rewrites {kind} {cls.__name__}.{name}")
            if dump_rewriting:
                os.makedirs(dump_rewriting, exist_ok=True)
                filename1 = os.path.join(
                    dump_rewriting, f"{kind}.{cls_name}.{name}.original.py"
                )
                if verbose:
                    print(f"[torch_export_rewrite] dump original code in {filename1!r}")
                code = clean_code_with_black(inspect.getsource(to_rewrite))
                with open(filename1, "w", encoding="utf-8") as f:
                    f.write(code)
            rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0), **kws)
            if dump_rewriting:
                filename2 = os.path.join(
                    dump_rewriting, f"{kind}.{cls_name}.{name}.rewritten.py"
                )
                if verbose:
                    print(f"[torch_export_rewrite] dump rewritten code in {filename2!r}")
                rcode = clean_code_with_black(rewr.code)
                with open(filename2, "w", encoding="utf-8") as f:
                    f.write(rcode)
                diff = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.diff")
                make_diff_code(code, rcode, diff)
            if patch_details is not None:
                patch_details.append("rewrite", getattr(cls, name), rewr.func)
            keep[cls, name] = (to_rewrite, kind)
            setattr(cls, name, rewr.func)

        yield
    finally:
        # Runs after the with-body and also if rewriting a later entry failed,
        # so that no patched attribute is left behind.
        for (cls, name), (original, kind) in keep.items():
            if verbose:
                print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}")
            setattr(cls, name, original)