ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly is flagged as potentially problematic.

Files changed (121):
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/debug/culprit.py
@@ -0,0 +1,464 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import contextlib
+ import copy
+ import dataclasses
+ import functools
+ import io
+ import operator
+ import os
+ import sys
+ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+
+ from functorch.compile import minifier as fx_minifier
+ import torch
+ from torch._functorch import aot_autograd
+ import torch.utils._pytree as pytree
+
+ import ai_edge_torch
+ from ai_edge_torch.debug import utils
+
+ _torch_float_dtypes = {
+     torch.float32,
+     torch.float,
+     torch.float64,
+     torch.double,
+     torch.float16,
+     torch.half,
+     torch.bfloat16,
+ }
+ _torch_int_dtypes = {
+     torch.uint8,
+     torch.int8,
+     torch.int16,
+     torch.short,
+     torch.int32,
+     torch.int,
+     torch.int64,
+     torch.long,
+ }
+
+ _fx_op_runner = {
+     "call_function": lambda target, args, kwargs: target(*args, **kwargs),
+     "call_method": lambda target, args, kwargs: getattr(args[0], target)(
+         *args[1:], **kwargs
+     ),
+ }
+
+ _CULPRIT_GRAPH_MODULE_NAME = "CulpritGraphModule"
+
+
+ def _get_shape_str(t: torch.Tensor):
+   return f"({', '.join(map(str, t.shape))},)"
+
+
+ def _tensor_to_random_tensor_call(t: torch.Tensor):
+   shape_str = _get_shape_str(t)
+   if t.dtype in _torch_float_dtypes:
+     return f"torch.randn({shape_str}, dtype={t.dtype})"
+   elif t.dtype in _torch_int_dtypes:
+     return f"torch.randint(0, 10, {shape_str}, dtype={t.dtype})"
+   elif t.dtype == torch.bool:
+     return f"torch.randint(0, 2, {shape_str}, dtype={t.dtype})"
+   else:
+     raise ValueError(f"Unsupported dtype: {t.dtype}")
+
+
+ def _tensor_to_buffer(t: torch.Tensor):
+   buff = io.BytesIO()
+   torch.save(t, buff)
+   buff.seek(0)
+   return buff.read()
+
+
+ @dataclasses.dataclass
+ class SearchResult:
+   graph_module: torch.fx.GraphModule
+   inputs: Tuple[Any]
+
+   @property
+   def graph(self) -> torch.fx.Graph:
+     return self.graph_module.graph
+
+   @graph.setter
+   def graph(self, fx_g: torch.fx.Graph):
+     self.graph_module.graph = fx_g
+
+
+ @dataclasses.dataclass
+ class Culprit(SearchResult):
+   _runtime_errors: bool
+
+   @property
+   def stack_traces(self) -> List[str]:
+     stack_traces = set()
+     for node in self.graph.nodes:
+       if node.op.startswith("call_") and "stack_trace" in node.meta:
+         stack_traces.add(node.meta["stack_trace"])
+     return list(stack_traces)
+
+   def print_readable(self, print_output=True):
+     """Prints the Python code for the culprit graph module and sample args.
+
+     Args:
+       print_output: If True, prints the code to stdout; otherwise returns
+         the code as a str.
+     """
+     # TODO (b/321263453): Support Python code gen with sample arg tensor values.
+     random_inputs = True
+
+     graph_module_code = self.graph_module.print_readable(print_output=False).rstrip()
+
+     input_strs = []
+     for value in self.inputs:
+       if torch.is_tensor(value):
+         if not random_inputs:
+           input_strs.append(f"# size={_get_shape_str(value)}, dtype={value.dtype}")
+           input_strs.append(f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),")
+         else:
+           input_strs.append(_tensor_to_random_tensor_call(value) + ",")
+       else:
+         input_strs.append(str(value) + ",")
+
+     inputs_code = (
+         "_args = (\n" + "\n".join([" " * 4 + code for code in input_strs]) + "\n)"
+     )
+
+     code = graph_module_code + "\n\n" + inputs_code
+     if print_output:
+       print(code)
+     else:
+       return code
+
+   def print_code(self, print_output=True):
+     """Prints the Python code for the culprit graph module, sample args, and
+     the AI Edge Torch conversion that fails with the error.
+
+     Args:
+       print_output: If True, prints the code to stdout; otherwise returns
+         the code as a str.
+     """
+     definitions = self.print_readable(print_output=False)
+     code = (
+         "import torch\n"
+         + "from torch import device\n"
+         + "import ai_edge_torch\n\n"
+         + definitions
+         + f"\n\n_edge_model = ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(), _args)\n"
+     )
+     if self._runtime_errors:
+       code += "_edge_model(*_args)\n"
+
+     if print_output:
+       print(code)
+     else:
+       return code
+
+   @property
+   def code(self):
+     return self.print_code(print_output=False)
+
+   def __repr__(self):
+     return self.print_readable(print_output=False)
+
+   def __str__(self):
+     return self.print_readable(print_output=False)
+
+
+ def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
+   """Turns all operator.getitem nodes in the ExportedProgram FX graph into
+   new nodes composed of "computation + getitem". The normalization duplicates
+   some computation in the graph but makes the graph friendlier to
+   partitioning in the FX minifier.
+   """
+   fx_gm = copy.deepcopy(fx_gm)
+   graph = fx_gm.graph
+   for n in graph.nodes:
+     if n.target != operator.getitem:
+       continue
+
+     src_n, key = n.args
+     if src_n.op not in _fx_op_runner:
+       continue
+
+     runner = _fx_op_runner.get(src_n.op)
+
+     with graph.inserting_after(n):
+       new_n = graph.call_function(
+           lambda src_target, key, args, kwargs: operator.getitem(
+               runner(src_target, args, kwargs), key
+           ),
+           (src_n.target, key, src_n.args, src_n.kwargs),
+       )
+       n.replace_all_uses_with(new_n)
+
+   graph.eliminate_dead_code()
+   fx_gm.graph = graph
+   return fx_gm
+
+
+ def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
+   fx_gm = copy.deepcopy(fx_gm)
+   inputs = tuple(inputs)
+   args = fx_gm.graph.process_inputs(*inputs)
+   args_iter = iter(args)
+
+   graph = fx_gm.graph
+   new_inputs = []
+   for n in graph.nodes:
+     if n.op == "placeholder":
+       if n.target.startswith("*"):
+         new_inputs += list(args_iter)
+       elif len(n.users) > 0:
+         new_inputs.append(next(args_iter))
+       else:
+         graph.erase_node(n)
+         next(args_iter)
+   new_inputs = tuple(new_inputs)
+   fx_gm.graph = graph
+   return fx_gm, new_inputs
+
+
+ def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
+   fx_gm = copy.deepcopy(fx_gm)
+
+   new_outputs = []
+   graph = fx_gm.graph
+   nodes = list(graph.nodes)
+   assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
+   for node in nodes:
+     if node.op not in ("placeholder", "output") and len(node.users) == 0:
+       new_outputs.append(node)
+
+   output_node = nodes[-1]
+   # The FX output node returns its first arg as is.
+   # ref: https://github.com/pytorch/pytorch/blob/1a578df57cc0f417f671634e564c62ef5d9a97e2/torch/fx/interpreter.py#L337
+   new_outputs, _ = pytree.tree_flatten([new_outputs, output_node.args[0]])
+   output_node.update_arg(0, tuple(new_outputs))
+
+   fx_gm.graph = graph
+   return fx_gm
+
+
+ def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
+   """Removes output nodes directly connected to an input node."""
+   fx_gm = copy.deepcopy(fx_gm)
+
+   graph = fx_gm.graph
+   nodes = list(graph.nodes)
+   assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
+   output_node = nodes[-1]
+
+   outputs, _ = pytree.tree_flatten(output_node.args[0])
+   new_outputs = [output for output in outputs if output.op != "placeholder"]
+   output_node.update_arg(0, tuple(new_outputs))
+
+   fx_gm.recompile()
+   return fx_gm
+
+
+ def _erase_sub_gm_from_gm(
+     fx_gm: torch.fx.GraphModule,
+     fx_inputs: Tuple[torch.Tensor],
+     sub_gm: torch.fx.GraphModule,
+     sub_inputs: Tuple[torch.Tensor],
+ ):
+   fx_gm = copy.deepcopy(fx_gm)
+   fx_inputs = list(fx_inputs)
+
+   class EraseNodeInterpreter(torch.fx.Interpreter):
+
+     def run_node(self, node):
+       nonlocal fx_gm, fx_inputs
+       res = super().run_node(node)
+       if node.op not in ("placeholder", "output"):
+         to_erase = next(m for m in fx_gm.graph.nodes if m.name == node.name)
+         # Raise the output (tensor) of the erased node to be an input of
+         # the new model graph. Some raised inputs may become unused later,
+         # when all of their users are within the erased subgraph; those
+         # inputs will be removed by the following `_erase_unused_inputs` pass.
+         with fx_gm.graph.inserting_before(to_erase):
+           new_input = fx_gm.graph.placeholder(node.name + "__value")
+           to_erase.replace_all_uses_with(new_input)
+
+         fx_gm.graph.erase_node(to_erase)
+         fx_inputs.append(res)
+       return res
+
+   interpreter = EraseNodeInterpreter(sub_gm)
+   interpreter.run(*sub_inputs)
+
+   fx_gm.graph.lint()
+   fx_gm.recompile()
+
+   # Ops prior to the erased subgraph may be dangling. Lift them as outputs.
+   fx_gm = _lift_dead_ops_to_outputs(fx_gm)
+   fx_gm = _erase_trivial_outputs(fx_gm)
+   fx_gm, fx_inputs = _erase_unused_inputs(fx_gm, fx_inputs)
+
+   fx_gm.graph.lint()
+   fx_gm.recompile()
+   return fx_gm, fx_inputs
+
+
+ def _normalize_minified_fx_gm(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
+   fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
+   fx_gm = _lift_dead_ops_to_outputs(fx_gm)
+   fx_gm, _ = aot_autograd.aot_export_module(fx_gm, inputs, trace_joint=False)
+   fx_gm.__class__.__name__ = _CULPRIT_GRAPH_MODULE_NAME
+   return fx_gm, inputs
+
+
+ def _fx_minifier_checker(fx_gm, inputs, runtime_errors=False):
+   fx_gm, inputs = _normalize_minified_fx_gm(fx_gm, inputs)
+
+   trivial_aten_ops = {
+       torch.ops.aten.view,
+       torch.ops.aten.view.default,
+   }
+   if all(
+       node.op in ("placeholder", "output") or node.target in trivial_aten_ops
+       for node in fx_gm.graph.nodes
+   ):
+     return False
+
+   try:
+     edge_model = ai_edge_torch.convert(fx_gm.eval(), inputs)
+     if runtime_errors:
+       edge_model(*inputs)
+   except Exception:
+     return True
+   return False
+
+
+ def _search_model(
+     predicate_f: Callable[[torch.fx.GraphModule, List[Any]], bool],
+     model: Union[torch.export.ExportedProgram, torch.nn.Module],
+     export_args: Tuple[Any] = None,
+     *,
+     max_granularity: Optional[int] = None,
+     enable_fx_minifier_logging: bool = False,
+ ) -> Generator[SearchResult, None, None]:
+   """Finds subgraphs in the torch model that satisfy a user-provided predicate function.
+
+   Args:
+     predicate_f: A predicate function that takes an FX (sub)graph and the
+       inputs to this graph. It returns True if the graph satisfies the
+       predicate, and False otherwise.
+     model: The model in which to search for subgraphs.
+     export_args: A set of args to trace the model with, i.e. model(*args)
+       must run.
+     max_granularity: FX minifier arg. The maximum granularity (number of
+       nodes) in the returned ATen FX subgraph of the culprit.
+     enable_fx_minifier_logging: If True, allows the underlying FX minifier
+       to log its progress.
+   """
+   if isinstance(model, torch.nn.Module):
+     try:
+       ep = torch.export.export(model, export_args)
+     except Exception as err:
+       raise ValueError(
+           "Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first."
+       ) from err
+   else:
+     ep = model
+
+   fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
+   fx_gm = _normalize_getitem_nodes(fx_gm)
+
+   # HACK: temporarily disable XLA_HLO_DEBUG so that fx_minifier won't dump
+   # intermediate stablehlo files to storage.
+   # https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
+   @contextlib.contextmanager
+   def disable_xla_hlo_debug():
+     xla_hlo_debug_value = None
+     if "XLA_HLO_DEBUG" in os.environ:
+       xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
+       del os.environ["XLA_HLO_DEBUG"]
+
+     try:
+       yield None
+     finally:
+       if xla_hlo_debug_value is not None:
+         os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value
+
+   found_culprits_num = 0
+   while True:
+     try:
+       with disable_xla_hlo_debug(), open(os.devnull, "w") as devnull:
+         with contextlib.nullcontext() if enable_fx_minifier_logging else utils.redirect_stdio(
+             stdout=devnull,
+             stderr=devnull,
+         ):
+           raw_min_fx_gm, raw_min_inputs = fx_minifier(
+               fx_gm,
+               fx_inputs,
+               predicate_f,
+               max_granularity=max_granularity,
+           )
+
+       min_fx_gm, min_inputs = _normalize_minified_fx_gm(raw_min_fx_gm, raw_min_inputs)
+       found_culprits_num += 1
+       yield SearchResult(min_fx_gm, min_inputs)
+
+       fx_gm, fx_inputs = _erase_sub_gm_from_gm(
+           fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
+       )
+
+     except RuntimeError as e:
+       if str(e) == "Input graph did not fail the tester" and found_culprits_num > 0:
+         break
+       raise
+
+
+ def find_culprits(
+     torch_model: torch.nn.Module,
+     args: Tuple[Any],
+     max_granularity: Optional[int] = None,
+     runtime_errors: bool = False,
+     *,
+     enable_fx_minifier_logging: bool = False,
+ ) -> Generator[Culprit, None, None]:
+   """Finds culprits in the AI Edge Torch model conversion.
+
+   Args:
+     torch_model: The torch model to find conversion culprits in.
+     args: A set of args to trace the model with, i.e. torch_model(*args)
+       must run.
+     max_granularity: FX minifier arg. The maximum granularity (number of
+       nodes) in the returned ATen FX subgraph of the culprit.
+     runtime_errors: If True, find culprits for Python runtime errors in the
+       converted model.
+     enable_fx_minifier_logging: If True, allows the underlying FX minifier
+       to log its progress.
+   """
+   fx_minifier_checker = functools.partial(
+       _fx_minifier_checker, runtime_errors=runtime_errors
+   )
+
+   for search_result in _search_model(
+       fx_minifier_checker,
+       torch_model,
+       args,
+       max_granularity=max_granularity,
+       enable_fx_minifier_logging=enable_fx_minifier_logging,
+   ):
+     yield Culprit(
+         search_result.graph_module, search_result.inputs, _runtime_errors=runtime_errors
+     )
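
For orientation, find_culprits above is the public entry point of this module (re-exported from ai_edge_torch.debug, as the tests below show). A minimal usage sketch under stated assumptions; FailingModel is a hypothetical stand-in for any model whose conversion fails, not part of the package:

    import torch
    from ai_edge_torch.debug import find_culprits

    class FailingModel(torch.nn.Module):  # hypothetical model that fails to convert
      def forward(self, x):
        return x + 1

    model = FailingModel().eval()
    args = (torch.rand(10),)

    # Each yielded Culprit wraps a minimal failing subgraph; print_code() emits
    # a self-contained repro script (graph module, sample args, convert call).
    for culprit in find_culprits(model, args):
      culprit.print_code()

Passing runtime_errors=True additionally hunts for errors raised when running the converted model, mirroring the _fx_minifier_checker logic above.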
ai_edge_torch/debug/test/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/debug/test/test_culprit.py
@@ -0,0 +1,133 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import ast
+ import io
+ import sys
+ import unittest
+
+ import torch
+
+ from ai_edge_torch.debug import find_culprits
+
+ _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
+
+ _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
+
+
+ @torch.library.impl(_test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd")
+ def non_lowerable_op(x):
+   if x.max() > 10.0:
+     return x + 1.0
+   return x
+
+
+ @torch.library.impl(_test_culprit_lib, "non_lowerable_op", "Meta")
+ def non_lowerable_op_meta(x):
+   return torch.empty_like(x)
+
+
+ class BadModel(torch.nn.Module):
+
+   def forward(self, x):
+     x = x + 1
+     x = torch.ops.test_culprit.non_lowerable_op.default(x)
+     return x
+
+
+ class TestCulprit(unittest.TestCase):
+
+   def test_find_culprits(self):
+     model = BadModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 1)
+     self.assertIn(
+         torch.ops.test_culprit.non_lowerable_op.default,
+         [n.target for n in culprits[0].graph.nodes],
+     )
+
+   def test_valid_culprit_readable(self):
+     model = BadModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 1)
+
+     code = culprits[0].print_readable(print_output=False)
+
+     # The code should be valid Python.
+     ast.parse(code)
+
+   def test_valid_culprit_code(self):
+     model = BadModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 1)
+
+     code = culprits[0].print_code(print_output=False)
+
+     # The code should be valid Python.
+     ast.parse(code)
+
+   def test_find_multiple_culprits(self):
+     class MultiBadOpsModel(torch.nn.Module):
+
+       def forward(self, x):
+         x = x + 1
+         a = torch.ops.test_culprit.non_lowerable_op.default(x)
+         b = torch.ops.test_culprit.non_lowerable_op.default(x)
+         c = a + b
+         d = torch.ops.test_culprit.non_lowerable_op.default(c)
+         return d
+
+     model = MultiBadOpsModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 3)
+     for culprit in culprits:
+       self.assertIn(
+           torch.ops.test_culprit.non_lowerable_op.default,
+           [n.target for n in culprit.graph.nodes],
+       )
+
+   def test_find_culprits_with_trivial_inputs_outputs(self):
+
+     class MultiBadOpsModel(torch.nn.Module):
+
+       def forward(self, x, y, z):
+         x = x + 1
+         a = torch.ops.test_culprit.non_lowerable_op.default(x)
+         b = torch.ops.test_culprit.non_lowerable_op.default(y)
+         return a, b, x, y, a, b
+
+     model = MultiBadOpsModel().eval()
+     args = (torch.rand(10), torch.rand(10), torch.rand(10))
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 2)
+     for culprit in culprits:
+       self.assertIn(
+           torch.ops.test_culprit.non_lowerable_op.default,
+           [n.target for n in culprit.graph.nodes],
+       )
+
+
+ if __name__ == "__main__":
+   unittest.main()
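
The tests above only assert that print_readable() and print_code() emit parseable Python. As a rough illustration of the repro script print_code() assembles for the BadModel case, with the CulpritGraphModule body actually generated by FX from the minified graph rather than written by hand:

    import torch
    from torch import device
    import ai_edge_torch

    class CulpritGraphModule(torch.nn.Module):
      def forward(self, x):
        # FX-generated ops of the minimal failing subgraph land here.
        return torch.ops.test_culprit.non_lowerable_op.default(x)

    _args = (
        torch.randn((10,), dtype=torch.float32),
    )

    _edge_model = ai_edge_torch.convert(CulpritGraphModule().eval(), _args)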
ai_edge_torch/debug/test/test_search_model.py
@@ -0,0 +1,50 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import unittest
+
+ import torch
+
+ from ai_edge_torch.debug import _search_model
+
+
+ class TestSearchModel(unittest.TestCase):
+
+   def test_search_model_with_ops(self):
+     class MultipleOpsModel(torch.nn.Module):
+
+       def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         sub_0 = x - 1
+         add_0 = y + 1
+         mul_0 = x * y
+         add_1 = sub_0 + add_0
+         mul_1 = add_0 * mul_0
+         sub_1 = add_1 - mul_1
+         return sub_1
+
+     model = MultipleOpsModel().eval()
+     args = (torch.rand(10), torch.rand(10))
+
+     def find_subgraph_with_sub(fx_gm, inputs):
+       return torch.ops.aten.sub.Tensor in [n.target for n in fx_gm.graph.nodes]
+
+     results = list(_search_model(find_subgraph_with_sub, model, args))
+     self.assertEqual(len(results), 2)
+     self.assertIn(torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes])
+
+
+ if __name__ == "__main__":
+   unittest.main()
ai_edge_torch/debug/utils.py
@@ -0,0 +1,48 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import contextlib
+ import sys
+
+ import torch
+ from torch.export.graph_signature import InputKind
+ import torch.fx._pytree as fx_pytree
+ from torch.utils import _pytree as pytree
+
+
+ def exported_program_to_fx_graph_module_and_inputs(ep: torch.export.ExportedProgram):
+   fx_gm = ep.graph_module
+   fx_inputs = pytree.tree_map(
+       torch.tensor, ep._graph_module_flat_inputs(*ep.example_inputs)
+   )
+   return fx_gm, fx_inputs
+
+
+ @contextlib.contextmanager
+ def redirect_stdio(stdout, stderr):
+   old_stdout = sys.stdout
+   old_stderr = sys.stderr
+
+   old_stdout.flush()
+   old_stderr.flush()
+
+   sys.stdout = stdout
+   sys.stderr = stderr
+   try:
+     yield stdout, stderr
+   finally:
+     stdout.flush()
+     stderr.flush()
+     sys.stdout = old_stdout
+     sys.stderr = old_stderr
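
A short sketch of how redirect_stdio is used in practice, mirroring its call site in culprit.py where both streams are pointed at os.devnull to mute the FX minifier:

    import os
    from ai_edge_torch.debug import utils

    with open(os.devnull, "w") as devnull:
      with utils.redirect_stdio(stdout=devnull, stderr=devnull):
        print("suppressed")  # swallowed by devnull
    print("visible again")  # original streams are restored on exit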