ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (91) hide show
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
@@ -0,0 +1,423 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import contextlib
17
+ import copy
18
+ import dataclasses
19
+ import functools
20
+ import io
21
+ import operator
22
+ import os
23
+ import sys
24
+ from typing import Any, Generator, List, Optional, Tuple
25
+
26
+ from functorch.compile import minifier as fx_minifier
27
+ import torch
28
+ from torch._functorch import aot_autograd
29
+ import torch.utils._pytree as pytree
30
+
31
+ import ai_edge_torch
32
+ from ai_edge_torch.debug import utils
33
+
34
# Floating-point dtypes. Sample inputs of these dtypes are regenerated with
# torch.randn when culprit code is printed. Alias names are listed explicitly.
_torch_float_dtypes = {
    torch.float32,
    torch.float,
    torch.float64,
    torch.double,
    torch.float16,
    torch.half,
    torch.bfloat16,
}
# Integer dtypes. Sample inputs of these dtypes are regenerated with
# torch.randint(0, 10, ...) when culprit code is printed.
_torch_int_dtypes = {
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.short,
    torch.int32,
    torch.int,
    torch.int64,
    torch.long,
}

# Maps an FX node op kind to a callable that eagerly executes
# (target, args, kwargs). Only "call_function" and "call_method" have runners;
# for "call_method" the receiver is args[0] and the target is the method name.
_fx_op_runner = {
    "call_function": lambda target, args, kwargs: target(*args, **kwargs),
    "call_method": lambda target, args, kwargs: getattr(args[0], target)(
        *args[1:], **kwargs
    ),
}

# Class name assigned to the minified graph module and referenced by the
# generated reproduction code.
_CULPRIT_GRAPH_MODULE_NAME = "CulpritGraphModule"
63
+
64
+ def _get_shape_str(t: torch.Tensor):
65
+ return f"({', '.join(map(str, t.shape))},)"
66
+
67
+
68
def _tensor_to_random_tensor_call(t: torch.Tensor):
    """Return source text for a random tensor with the shape and dtype of `t`.

    Float dtypes map to torch.randn, integer dtypes to torch.randint(0, 10, ...)
    and bool to torch.randint(0, 2, ...).

    Raises:
      ValueError: if the dtype of `t` is none of the above.
    """
    dims = _get_shape_str(t)
    dtype = t.dtype

    if dtype in _torch_float_dtypes:
        return f"torch.randn({dims}, dtype={dtype})"
    if dtype in _torch_int_dtypes:
        return f"torch.randint(0, 10, {dims}, dtype={dtype})"
    if dtype == torch.bool:
        return f"torch.randint(0, 2, {dims}, dtype={dtype})"
    raise ValueError(f"Unsupported dtype: {dtype}")
78
+
79
+
80
+ def _tensor_to_buffer(t: torch.Tensor):
81
+ buff = io.BytesIO()
82
+ torch.save(t, buff)
83
+ buff.seek(0)
84
+ return buff.read()
85
+
86
+
87
@dataclasses.dataclass
class Culprit:
    """A graph module and sample inputs that reproduce a conversion failure.

    The class can render itself as Python source: the graph module definition
    followed by an `_args` tuple, and optionally the `ai_edge_torch.convert`
    call that is expected to fail.
    """

    # Graph module holding the culprit subgraph.
    graph_module: torch.fx.GraphModule
    # Positional sample inputs for `graph_module`.
    inputs: Tuple[Any, ...]
    # When true, the generated code also invokes the converted model.
    _runtime_errors: bool

    @property
    def graph(self) -> torch.fx.Graph:
        """The FX graph of the culprit graph module."""
        return self.graph_module.graph

    @graph.setter
    def graph(self, fx_g: torch.fx.Graph):
        self.graph_module.graph = fx_g

    @property
    def stack_traces(self) -> List[str]:
        """Unique stack traces recorded on the call_* nodes of the graph.

        Traces are returned in first-seen graph order so the result is
        deterministic across runs.
        """
        stack_traces = {}
        for node in self.graph.nodes:
            if node.op.startswith("call_") and "stack_trace" in node.meta:
                # dict keys de-duplicate while preserving insertion order.
                stack_traces.setdefault(node.meta["stack_trace"], None)
        return list(stack_traces)

    def print_readable(self, print_output=True):
        """Print the Python code for culprit graph module and sample args.

        Args:
          print_output: bool - If true, prints the code to stdout. Otherwise
            returns the code in a str.

        Returns:
          The generated code when `print_output` is false, otherwise None.
        """
        # TODO (b/321263453): Support Python code gen with sample arg tensor values.
        random_inputs = True

        graph_module_code = self.graph_module.print_readable(
            print_output=False
        ).rstrip()

        input_strs = []
        for value in self.inputs:
            if not torch.is_tensor(value):
                # repr() keeps non-tensor values valid as source text; str()
                # would turn the string "abc" into the bare identifier abc.
                input_strs.append(repr(value) + ",")
            elif random_inputs:
                input_strs.append(_tensor_to_random_tensor_call(value) + ",")
            else:
                input_strs.append(
                    f"# size={_get_shape_str(value)}, dtype={value.dtype}"
                )
                input_strs.append(
                    f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),"
                )

        inputs_code = (
            "_args = (\n" + "\n".join(" " * 4 + code for code in input_strs) + "\n)"
        )

        code = graph_module_code + "\n\n" + inputs_code
        if print_output:
            print(code)
            return None
        return code

    def print_code(self, print_output=True):
        """Print the Python code for culprit graph module, sample args, and AI
        Edge Torch conversion that will fail with the error.

        Args:
          print_output: bool - If true, prints the code to stdout. Otherwise
            returns the code in a str.

        Returns:
          The generated code when `print_output` is false, otherwise None.
        """
        definitions = self.print_readable(print_output=False)
        code = (
            "import torch\n"
            + "from torch import device\n"
            + "import ai_edge_torch\n\n"
            + definitions
            + f"\n\n_edge_model = ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(), _args)\n"
        )
        if self._runtime_errors:
            # Runtime failures only show up when the converted model is run.
            code += "_edge_model(*_args)\n"

        if print_output:
            print(code)
            return None
        return code

    @property
    def code(self):
        """The full reproduction script, as returned by `print_code`."""
        return self.print_code(print_output=False)

    def __repr__(self):
        return self.print_readable(print_output=False)

    def __str__(self):
        return self.print_readable(print_output=False)
175
+
176
+
177
def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
    """Fuse each operator.getitem node with the computation it indexes.

    This function turns all operator getitem nodes in ExportedProgram FX graph
    to new nodes composed of "computation + getitem". The normalization
    duplicates some computations in the graph but would make the graph more
    friendly for partitioning in FX minifier.

    Args:
      fx_gm: graph module to normalize. It is deep-copied, not modified.

    Returns:
      A new graph module with the getitem nodes replaced and dead code removed.
    """

    def _make_fused_getitem(runner):
        # A factory gives every fused node its own `runner` binding. A lambda
        # written directly in the loop would late-bind the loop variable, so
        # all fused nodes would use the runner of the last getitem processed.
        return lambda src_target, key, args, kwargs: operator.getitem(
            runner(src_target, args, kwargs), key
        )

    fx_gm = copy.deepcopy(fx_gm)
    graph = fx_gm.graph
    for n in graph.nodes:
        if n.target != operator.getitem:
            continue

        src_n, key = n.args
        runner = _fx_op_runner.get(src_n.op)
        if runner is None:
            # Only sources with a registered runner can be re-executed.
            continue

        with graph.inserting_after(n):
            new_n = graph.call_function(
                _make_fused_getitem(runner),
                (src_n.target, key, src_n.args, src_n.kwargs),
            )
            n.replace_all_uses_with(new_n)

    graph.eliminate_dead_code()
    fx_gm.graph = graph
    return fx_gm
209
+
210
+
211
+ def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
212
+ fx_gm = copy.deepcopy(fx_gm)
213
+ inputs = tuple(inputs)
214
+ args = fx_gm.graph.process_inputs(*inputs)
215
+ args_iter = iter(args)
216
+
217
+ graph = fx_gm.graph
218
+ new_inputs = []
219
+ for n in graph.nodes:
220
+ if n.op == "placeholder":
221
+ if n.target.startswith("*"):
222
+ new_inputs += list(args_iter)
223
+ elif len(n.users) > 0:
224
+ new_inputs.append(next(args_iter))
225
+ else:
226
+ graph.erase_node(n)
227
+ next(args_iter)
228
+ new_inputs = tuple(new_inputs)
229
+ fx_gm.graph = graph
230
+ return fx_gm, new_inputs
231
+
232
+
233
+ def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
234
+ fx_gm = copy.deepcopy(fx_gm)
235
+
236
+ new_outputs = []
237
+ graph = fx_gm.graph
238
+ nodes = list(graph.nodes)
239
+ assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
240
+ for node in nodes:
241
+ if node.op not in ("placeholder", "output") and len(node.users) == 0:
242
+ new_outputs.append(node)
243
+
244
+ output_node = nodes[-1]
245
+ # FX output node returns the first arg as is.
246
+ # ref: https://github.com/pytorch/pytorch/blob/1a578df57cc0f417f671634e564c62ef5d9a97e2/torch/fx/interpreter.py#L337
247
+ new_outputs, _ = pytree.tree_flatten([new_outputs, output_node.args[0]])
248
+ output_node.update_arg(0, tuple(new_outputs))
249
+
250
+ fx_gm.graph = graph
251
+ return fx_gm
252
+
253
+
254
+ def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
255
+ """Remove output nodes directly connected to an input node."""
256
+ fx_gm = copy.deepcopy(fx_gm)
257
+
258
+ graph = fx_gm.graph
259
+ nodes = list(graph.nodes)
260
+ assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
261
+ output_node = nodes[-1]
262
+
263
+ outputs, _ = pytree.tree_flatten(output_node.args[0])
264
+ new_outputs = [output for output in outputs if output.op != "placeholder"]
265
+ output_node.update_arg(0, tuple(new_outputs))
266
+
267
+ fx_gm.recompile()
268
+ return fx_gm
269
+
270
+
271
def _erase_sub_gm_from_gm(
    fx_gm: torch.fx.GraphModule,
    fx_inputs: Tuple[torch.Tensor],
    sub_gm: torch.fx.GraphModule,
    sub_inputs: Tuple[torch.Tensor],
):
    """Cut the nodes of `sub_gm` out of `fx_gm`.

    `sub_gm` is executed on `sub_inputs`. Each of its op nodes is looked up by
    name in a copy of `fx_gm`, replaced there by a new placeholder, and the
    value computed for it is appended to the inputs.

    Returns:
      The rewritten (graph module, inputs) pair after dangling ops were lifted
      to outputs and trivial outputs and unused inputs were removed.
    """
    fx_gm = copy.deepcopy(fx_gm)
    fx_inputs = list(fx_inputs)

    class EraseNodeInterpreter(torch.fx.Interpreter):

        def run_node(self, node):
            res = super().run_node(node)
            if node.op in ("placeholder", "output"):
                return res

            host_graph = fx_gm.graph
            to_erase = next(m for m in host_graph.nodes if m.name == node.name)
            # Raise the output (tensor) of the erased node to be an input of
            # the new model graph. Some raised inputs may become unused later
            # when all the users are within the erased subgraph, those inputs
            # will be removed by the followed `_erase_unused_inputs` pass.
            with host_graph.inserting_before(to_erase):
                new_input = host_graph.placeholder(node.name + "__value")
            to_erase.replace_all_uses_with(new_input)

            host_graph.erase_node(to_erase)
            fx_inputs.append(res)
            return res

    EraseNodeInterpreter(sub_gm).run(*sub_inputs)

    fx_gm.graph.lint()
    fx_gm.recompile()

    # Ops prior to the erased subgraph may be dangling. Lift them as outputs.
    fx_gm = _lift_dead_ops_to_outputs(fx_gm)
    fx_gm = _erase_trivial_outputs(fx_gm)
    fx_gm, fx_inputs = _erase_unused_inputs(fx_gm, fx_inputs)

    fx_gm.graph.lint()
    fx_gm.recompile()
    return fx_gm, fx_inputs
313
+
314
+
315
def _normalize_minified_fx_gm(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
    """Clean up a minified graph module and re-export it.

    Unused inputs are erased, user-less ops are lifted to outputs, and the
    result is passed through `aot_export_module` (trace_joint=False). The class
    of the exported module is renamed to the culprit module name.

    Returns:
      The exported (graph module, inputs) pair.
    """
    pruned_gm, pruned_inputs = _erase_unused_inputs(fx_gm, inputs)
    lifted_gm = _lift_dead_ops_to_outputs(pruned_gm)
    exported_gm, _ = aot_autograd.aot_export_module(
        lifted_gm, pruned_inputs, trace_joint=False
    )
    exported_gm.__class__.__name__ = _CULPRIT_GRAPH_MODULE_NAME
    return exported_gm, pruned_inputs
321
+
322
+
323
def _fx_minifier_checker(fx_gm, inputs, runtime_errors=False):
    """Tell the FX minifier whether `fx_gm` still reproduces the failure.

    Args:
      fx_gm: candidate graph module.
      inputs: sample inputs for `fx_gm`.
      runtime_errors: if true, the converted model is also run on `inputs`.

    Returns:
      True if conversion (or the optional run) raises an exception. False if it
      succeeds, or if the graph holds only placeholder, output and aten.view
      nodes.
    """
    fx_gm, inputs = _normalize_minified_fx_gm(fx_gm, inputs)

    trivial_aten_ops = {
        torch.ops.aten.view,
        torch.ops.aten.view.default,
    }

    def _is_trivial(node):
        return node.op in ("placeholder", "output") or node.target in trivial_aten_ops

    if all(_is_trivial(node) for node in fx_gm.graph.nodes):
        return False

    try:
        edge_model = ai_edge_torch.convert(fx_gm.eval(), inputs)
        if runtime_errors:
            edge_model(*inputs)
    except Exception:
        # Any failure means the candidate still contains the culprit.
        return True
    return False
343
+
344
+
345
def find_culprits(
    torch_model: torch.nn.Module,
    args: Tuple[Any],
    max_granularity: Optional[int] = None,
    runtime_errors: bool = False,
    *,
    enable_fx_minifier_logging: bool = False,
) -> Generator[Culprit, None, None]:
    """Finds culprits in the AI Edge Torch model conversion.

    Args:
      torch_model: model to export and save
      args: A set of args to trace the model with, i.e. torch_model(*args) must
        run
      max_granularity: FX minifier arg. The maximum granularity (number of
        nodes) in the returned ATen FX subgraph of the culprit.
      runtime_errors: If true, find culprits for Python runtime errors with
        converted model.
      enable_fx_minifier_logging: If true, allows the underlying FX minifier to
        log the progress.

    Yields:
      One Culprit per minified failing subgraph. After each one is yielded, its
      nodes are erased from the working graph before the next search.

    Raises:
      ValueError: if torch.export.export fails on the model.
      RuntimeError: re-raised from the search, except the "did not fail the
        tester" error once at least one culprit was found, which ends the
        generator.
    """

    try:
        ep = torch.export.export(torch_model, args)
    except Exception as err:
        raise ValueError(
            "Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first."
        ) from err

    fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
    fx_gm = _normalize_getitem_nodes(fx_gm)

    fx_minifier_checker = functools.partial(
        _fx_minifier_checker, runtime_errors=runtime_errors
    )

    # HACK: temporarily disable XLA_HLO_DEBUG so that fx_minifier won't dump
    # intermediate stablehlo files to storage.
    # https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
    @contextlib.contextmanager
    def disable_xla_hlo_debug():
        saved_value = os.environ.pop("XLA_HLO_DEBUG", None)
        try:
            yield None
        finally:
            if saved_value is not None:
                os.environ["XLA_HLO_DEBUG"] = saved_value

    found_culprits_num = 0
    while True:
        try:
            with contextlib.ExitStack() as stack:
                stack.enter_context(disable_xla_hlo_debug())
                devnull = stack.enter_context(open(os.devnull, "w"))
                if not enable_fx_minifier_logging:
                    # Silence the minifier by sending stdio to the null device.
                    stack.enter_context(
                        utils.redirect_stdio(stdout=devnull, stderr=devnull)
                    )
                raw_min_fx_gm, raw_min_inputs = fx_minifier(
                    fx_gm,
                    fx_inputs,
                    fx_minifier_checker,
                    max_granularity=max_granularity,
                )

            min_fx_gm, min_inputs = _normalize_minified_fx_gm(
                raw_min_fx_gm, raw_min_inputs
            )
            found_culprits_num += 1
            yield Culprit(min_fx_gm, min_inputs, _runtime_errors=runtime_errors)

            fx_gm, fx_inputs = _erase_sub_gm_from_gm(
                fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
            )

        except RuntimeError as e:
            search_exhausted = (
                found_culprits_num > 0
                and str(e) == "Input graph did not fail the tester"
            )
            if search_exhausted:
                break
            raise e
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,133 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import ast
18
+ import io
19
+ import sys
20
+ import unittest
21
+
22
+ import torch
23
+
24
+ from ai_edge_torch.debug import find_culprits
25
+
26
# Operator library for the "test_culprit" namespace, created at import time so
# the tests have a custom op to put into their models.
_test_culprit_lib = torch.library.Library("test_culprit", "DEF")

# Schema of the custom op: one tensor in, one tensor out.
_test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
29
+
30
+
31
@torch.library.impl(_test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd")
def non_lowerable_op(x):
    """Return `x + 1.0` if the maximum of `x` exceeds 10.0, otherwise `x`.

    The result depends on the tensor's values, not only on its shape.
    """
    exceeds_threshold = x.max() > 10.0
    return x + 1.0 if exceeds_threshold else x
36
+
37
+
38
@torch.library.impl(_test_culprit_lib, "non_lowerable_op", "Meta")
def non_lowerable_op_meta(x):
    """Meta kernel for non_lowerable_op: return an empty tensor like `x`."""
    return torch.empty_like(x)
41
+
42
+
43
class BadModel(torch.nn.Module):
    """Model that adds one to its input, then applies non_lowerable_op."""

    def forward(self, x):
        shifted = x + 1
        return torch.ops.test_culprit.non_lowerable_op.default(shifted)
49
+
50
+
51
class TestCulprit(unittest.TestCase):
    """Tests for find_culprits on models containing non_lowerable_op."""

    def _find_culprits(self, model, args, expected_count):
        # Run the search to completion and check how many culprits it found.
        culprits = list(find_culprits(model, args))
        self.assertEqual(len(culprits), expected_count)
        return culprits

    def _assert_has_non_lowerable_op(self, culprit):
        # The culprit graph must contain the custom op as a node target.
        self.assertIn(
            torch.ops.test_culprit.non_lowerable_op.default,
            [n.target for n in culprit.graph.nodes],
        )

    def test_find_culprits(self):
        (culprit,) = self._find_culprits(BadModel().eval(), (torch.rand(10),), 1)
        self._assert_has_non_lowerable_op(culprit)

    def test_valid_culprit_readable(self):
        (culprit,) = self._find_culprits(BadModel().eval(), (torch.rand(10),), 1)

        code = culprit.print_readable(print_output=False)

        # The code should be a valid Python code
        ast.parse(code)

    def test_valid_culprit_code(self):
        (culprit,) = self._find_culprits(BadModel().eval(), (torch.rand(10),), 1)

        code = culprit.print_code(print_output=False)

        # The code should be a valid Python code
        ast.parse(code)

    def test_find_multiple_culprits(self):
        class MultiBadOpsModel(torch.nn.Module):

            def forward(self, x):
                x = x + 1
                a = torch.ops.test_culprit.non_lowerable_op.default(x)
                b = torch.ops.test_culprit.non_lowerable_op.default(x)
                c = a + b
                d = torch.ops.test_culprit.non_lowerable_op.default(c)
                return d

        culprits = self._find_culprits(
            MultiBadOpsModel().eval(), (torch.rand(10),), 3
        )
        for culprit in culprits:
            self._assert_has_non_lowerable_op(culprit)

    def test_find_culprits_with_trivial_inputs_outputs(self):

        class MultiBadOpsModel(torch.nn.Module):

            def forward(self, x, y, z):
                x = x + 1
                a = torch.ops.test_culprit.non_lowerable_op.default(x)
                b = torch.ops.test_culprit.non_lowerable_op.default(y)
                return a, b, x, y, a, b

        culprits = self._find_culprits(
            MultiBadOpsModel().eval(),
            (torch.rand(10), torch.rand(10), torch.rand(10)),
            2,
        )
        for culprit in culprits:
            self._assert_has_non_lowerable_op(culprit)
130
+
131
+
132
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,48 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import contextlib
16
+ import sys
17
+
18
+ import torch
19
+ from torch.export.graph_signature import InputKind
20
+ import torch.fx._pytree as fx_pytree
21
+ from torch.utils import _pytree as pytree
22
+
23
+
24
+ def exported_program_to_fx_graph_module_and_inputs(ep: torch.export.ExportedProgram):
25
+ fx_gm = ep.graph_module
26
+ fx_inputs = pytree.tree_map(
27
+ torch.tensor, ep._graph_module_flat_inputs(*ep.example_inputs)
28
+ )
29
+ return fx_gm, fx_inputs
30
+
31
+
32
@contextlib.contextmanager
def redirect_stdio(stdout, stderr):
    """Temporarily replace sys.stdout and sys.stderr with the given streams.

    Args:
      stdout: file-like object that receives stdout inside the context.
      stderr: file-like object that receives stderr inside the context.

    Yields:
      The (stdout, stderr) pair that was passed in.

    The previous streams are flushed on entry. On exit they are always put
    back, even if flushing a replacement stream raises.
    """
    old_stdout = sys.stdout
    old_stderr = sys.stderr

    old_stdout.flush()
    old_stderr.flush()

    sys.stdout = stdout
    sys.stderr = stderr
    try:
        yield stdout, stderr
    finally:
        # Restore first: a failing flush below must not leave the process
        # with redirected stdio.
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        try:
            stdout.flush()
        finally:
            stderr.flush()
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================