ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Files changed (213)
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
ai_edge_torch/debug/culprit.py
@@ -0,0 +1,496 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Culprit finder for AI Edge Torch conversion."""
+
+ import contextlib
+ import copy
+ import dataclasses
+ import functools
+ import io
+ import operator
+ import os
+ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+
+ import ai_edge_torch
+ from ai_edge_torch.debug import utils
+ import torch
+ from torch._functorch import aot_autograd
+ from torch._functorch.fx_minifier import minifier as fx_minifier
+ import torch.utils._pytree as pytree
+
+ _torch_float_dtypes = {
+     torch.float32,
+     torch.float,
+     torch.float64,
+     torch.double,
+     torch.float16,
+     torch.half,
+     torch.bfloat16,
+ }
+ _torch_int_dtypes = {
+     torch.uint8,
+     torch.int8,
+     torch.int16,
+     torch.short,
+     torch.int32,
+     torch.int,
+     torch.int64,
+     torch.long,
+ }
+
+ _fx_op_runner = {
+     "call_function": lambda target, args, kwargs: target(*args, **kwargs),
+     "call_method": lambda target, args, kwargs: getattr(args[0], target)(
+         *args[1:], **kwargs
+     ),
+ }
+
+ _CULPRIT_GRAPH_MODULE_NAME = "CulpritGraphModule"
+
+
+ def _get_shape_str(t: torch.Tensor):
+   return f"({', '.join(map(str, t.shape))},)"
+
+
+ def _tensor_to_random_tensor_call(t: torch.Tensor):
+   shape_str = _get_shape_str(t)
+   if t.dtype in _torch_float_dtypes:
+     return f"torch.randn({shape_str}, dtype={t.dtype})"
+   elif t.dtype in _torch_int_dtypes:
+     return f"torch.randint(0, 10, {shape_str}, dtype={t.dtype})"
+   elif t.dtype == torch.bool:
+     return f"torch.randint(0, 2, {shape_str}, dtype={t.dtype})"
+   else:
+     raise ValueError(f"Unsupported dtype: {t.dtype}")
+
+
+ def _tensor_to_buffer(t: torch.Tensor):
+   buff = io.BytesIO()
+   torch.save(t, buff)
+   buff.seek(0)
+   return buff.read()
+
+
+ @dataclasses.dataclass
+ class SearchResult:
+   graph_module: torch.fx.GraphModule
+   inputs: Tuple[Any]
+
+   @property
+   def graph(self) -> torch.fx.Graph:
+     return self.graph_module.graph
+
+   @graph.setter
+   def graph(self, fx_g: torch.fx.Graph):
+     self.graph_module.graph = fx_g
+
+
+ @dataclasses.dataclass
+ class Culprit(SearchResult):
+   _runtime_errors: bool
+
+   @property
+   def stack_traces(self) -> List[str]:
+     stack_traces = set()
+     for node in self.graph.nodes:
+       if node.op.startswith("call_") and "stack_trace" in node.meta:
+         stack_traces.add(node.meta["stack_trace"])
+     return list(stack_traces)
+
+   def print_readable(self, print_output=True):
+     """Prints the Python code for the culprit graph module and sample args.
+
+     Args:
+       print_output: bool - If true, prints the code to stdout. Otherwise
+         returns the code in a str.
+     """
+     # TODO: b/321263453 - Support Python code gen with sample arg tensor values.
+     random_inputs = True
+
+     graph_module_code = self.graph_module.print_readable(
+         print_output=False
+     ).rstrip()
+
+     input_strs = []
+     for value in self.inputs:
+       if torch.is_tensor(value):
+         if not random_inputs:
+           input_strs.append(
+               f"# size={_get_shape_str(value)}, dtype={value.dtype}"
+           )
+           input_strs.append(
+               f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),"
+           )
+         else:
+           input_strs.append(_tensor_to_random_tensor_call(value) + ",")
+       else:
+         input_strs.append(str(value) + ",")
+
+     inputs_code = (
+         "_args = (\n"
+         + "\n".join([" " * 4 + code for code in input_strs])
+         + "\n)"
+     )
+
+     code = graph_module_code + "\n\n" + inputs_code
+     if print_output:
+       print(code)
+     else:
+       return code
+
+   def print_code(self, print_output=True):
+     """Prints Python code that reproduces the conversion error.
+
+     The code contains the culprit graph module, sample args, and the AI Edge
+     Torch conversion call that will fail with the error.
+
+     Args:
+       print_output: bool - If true, prints the code to stdout. Otherwise
+         returns the code in a str.
+     """
+     definitions = self.print_readable(print_output=False)
+     code = (
+         "import torch\n"
+         + "from torch import device\n"
+         + "import ai_edge_torch\n\n"
+         + definitions
+         + "\n\n_edge_model ="
+         f" ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(),"
+         " _args)\n"
+     )
+     if self._runtime_errors:
+       code += "_edge_model(*_args)\n"
+
+     if print_output:
+       print(code)
+     else:
+       return code
+
+   @property
+   def code(self):
+     return self.print_code(print_output=False)
+
+   def __repr__(self):
+     return self.print_readable(print_output=False)
+
+   def __str__(self):
+     return self.print_readable(print_output=False)
+
+
+ def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
+   """Rewrites operator.getitem nodes into "computation + getitem" nodes.
+
+   The normalization duplicates some computations in the graph but makes the
+   graph friendlier for partitioning in the FX minifier.
+   """
+
+   fx_gm = copy.deepcopy(fx_gm)
+   graph = fx_gm.graph
+   for n in graph.nodes:
+     if n.target != operator.getitem:
+       continue
+
+     src_n, key = n.args
+     if src_n.op not in _fx_op_runner:
+       continue
+
+     runner = _fx_op_runner.get(src_n.op)
+
+     with graph.inserting_after(n):
+       new_n = graph.call_function(
+           lambda src_target, key, args, kwargs: operator.getitem(
+               runner(src_target, args, kwargs), key
+           ),
+           (src_n.target, key, src_n.args, src_n.kwargs),
+       )
+       n.replace_all_uses_with(new_n)
+
+   graph.eliminate_dead_code()
+   fx_gm.graph = graph
+   return fx_gm
+
+
+ def _erase_unused_inputs(
+     fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
+ ):
+   fx_gm = copy.deepcopy(fx_gm)
+   inputs = tuple(inputs)
+   args = fx_gm.graph.process_inputs(*inputs)
+   args_iter = iter(args)
+
+   graph = fx_gm.graph
+   new_inputs = []
+   for n in graph.nodes:
+     if n.op == "placeholder":
+       if n.target.startswith("*"):
+         new_inputs += list(args_iter)
+       elif len(n.users) > 0:
+         new_inputs.append(next(args_iter))
+       else:
+         graph.erase_node(n)
+         next(args_iter)
+   new_inputs = tuple(new_inputs)
+   fx_gm.graph = graph
+   return fx_gm, new_inputs
+
+
+ def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
+   fx_gm = copy.deepcopy(fx_gm)
+
+   new_outputs = []
+   graph = fx_gm.graph
+   nodes = list(graph.nodes)
+   assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
+   for node in nodes:
+     if node.op not in ("placeholder", "output") and len(node.users) == 0:
+       new_outputs.append(node)
+
+   output_node = nodes[-1]
+   # The FX output node returns its first arg as is.
+   # ref: https://github.com/pytorch/pytorch/blob/1a578df57cc0f417f671634e564c62ef5d9a97e2/torch/fx/interpreter.py#L337
+   new_outputs, _ = pytree.tree_flatten([new_outputs, output_node.args[0]])
+   output_node.update_arg(0, tuple(new_outputs))
+
+   fx_gm.graph = graph
+   return fx_gm
+
+
+ def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
+   """Removes output nodes directly connected to an input node."""
+   fx_gm = copy.deepcopy(fx_gm)
+
+   graph = fx_gm.graph
+   nodes = list(graph.nodes)
+   assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
+   output_node = nodes[-1]
+
+   outputs, _ = pytree.tree_flatten(output_node.args[0])
+   new_outputs = [output for output in outputs if output.op != "placeholder"]
+   output_node.update_arg(0, tuple(new_outputs))
+
+   fx_gm.recompile()
+   return fx_gm
+
+
+ def _erase_sub_gm_from_gm(
+     fx_gm: torch.fx.GraphModule,
+     fx_inputs: Tuple[torch.Tensor],
+     sub_gm: torch.fx.GraphModule,
+     sub_inputs: Tuple[torch.Tensor],
+ ):
+   fx_gm = copy.deepcopy(fx_gm)
+   fx_inputs = list(fx_inputs)
+
+   class EraseNodeInterpreter(torch.fx.Interpreter):
+
+     def run_node(self, node):
+       nonlocal fx_gm, fx_inputs
+       res = super().run_node(node)
+       if node.op not in ("placeholder", "output"):
+         to_erase = next(m for m in fx_gm.graph.nodes if m.name == node.name)
+         # Raise the output (tensor) of the erased node to be an input of
+         # the new model graph. Some raised inputs may become unused later,
+         # when all of their users are within the erased subgraph; those
+         # inputs will be removed by the following `_erase_unused_inputs`
+         # pass.
+         with fx_gm.graph.inserting_before(to_erase):
+           new_input = fx_gm.graph.placeholder(node.name + "__value")
+           to_erase.replace_all_uses_with(new_input)
+
+         fx_gm.graph.erase_node(to_erase)
+         fx_inputs.append(res)
+       return res
+
+   interpreter = EraseNodeInterpreter(sub_gm)
+   interpreter.run(*sub_inputs)
+
+   fx_gm.graph.lint()
+   fx_gm.recompile()
+
+   # Ops prior to the erased subgraph may be dangling. Lift them as outputs.
+   fx_gm = _lift_dead_ops_to_outputs(fx_gm)
+   fx_gm = _erase_trivial_outputs(fx_gm)
+   fx_gm, fx_inputs = _erase_unused_inputs(fx_gm, fx_inputs)
+
+   fx_gm.graph.lint()
+   fx_gm.recompile()
+   return fx_gm, fx_inputs
+
+
+ def _normalize_minified_fx_gm(
+     fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
+ ):
+   fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
+   fx_gm = _lift_dead_ops_to_outputs(fx_gm)
+   fx_gm, _ = aot_autograd.aot_export_module(fx_gm, inputs, trace_joint=False)
+   fx_gm.__class__.__name__ = _CULPRIT_GRAPH_MODULE_NAME
+   return fx_gm, inputs
+
+
+ def _fx_minifier_checker(fx_gm, inputs, runtime_errors=False):
+   fx_gm, inputs = _normalize_minified_fx_gm(fx_gm, inputs)
+
+   trivial_aten_ops = {
+       torch.ops.aten.view,
+       torch.ops.aten.view.default,
+   }
+   if all(
+       node.op in ("placeholder", "output") or node.target in trivial_aten_ops
+       for node in fx_gm.graph.nodes
+   ):
+     return False
+
+   try:
+     edge_model = ai_edge_torch.convert(fx_gm.eval(), inputs)
+     if runtime_errors:
+       edge_model(*inputs)
+   except Exception:
+     return True
+   return False
+
+
+ def _search_model(
+     predicate_f: Callable[[torch.fx.GraphModule, List[Any]], bool],
+     model: Union[torch.export.ExportedProgram, torch.nn.Module],
+     export_args: Tuple[Any] = None,
+     *,
+     max_granularity: Optional[int] = None,
+     enable_fx_minifier_logging: bool = False,
+ ) -> Generator[SearchResult, None, None]:
+   """Finds subgraphs in the torch model that satisfy a user-provided predicate.
+
+   Args:
+     predicate_f: a predicate function the users specify. It takes an FX
+       (sub)graph and the inputs to this graph, and returns True if the graph
+       satisfies the predicate, False otherwise.
+     model: model in which to search subgraph.
+     export_args: A set of args to trace the model with, i.e. model(*args)
+       must run.
+     max_granularity: FX minifier arg. The maximum granularity (number of
+       nodes) in the returned ATen FX subgraph of the culprit.
+     enable_fx_minifier_logging: If true, allows the underlying FX minifier
+       to log the progress.
+   """
+
+   if isinstance(model, torch.nn.Module):
+     try:
+       ep = torch.export.export(model, export_args)
+     except Exception as err:
+       raise ValueError(
+           "Your model is not exportable by torch.export.export. Please modify"
+           " your model to be torch-exportable first."
+       ) from err
+   else:
+     ep = model
+
+   fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
+   fx_gm = _normalize_getitem_nodes(fx_gm)
+
+   # HACK: Temporarily disable XLA_HLO_DEBUG and create_minified_hlo_graph so
+   # that fx_minifier won't dump intermediate stablehlo files to storage.
+   # https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
+   @contextlib.contextmanager
+   def disable_minifier_xla_debug():
+     xla_hlo_debug_value = None
+     if "XLA_HLO_DEBUG" in os.environ:
+       xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
+       del os.environ["XLA_HLO_DEBUG"]
+
+     create_minified_hlo_graph = (
+         torch._functorch.fx_minifier.create_minified_hlo_graph
+     )
+     torch._functorch.fx_minifier.create_minified_hlo_graph = (
+         lambda *args, **kwargs: None
+     )
+
+     try:
+       yield
+     finally:
+       if xla_hlo_debug_value is not None:
+         os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value
+
+       torch._functorch.fx_minifier.create_minified_hlo_graph = (
+           create_minified_hlo_graph
+       )
+
+   found_culprits_num = 0
+   while True:
+     try:
+       with disable_minifier_xla_debug(), open(os.devnull, "w") as devnull:
+         with contextlib.nullcontext() if enable_fx_minifier_logging else utils.redirect_stdio(
+             stdout=devnull,
+             stderr=devnull,
+         ):
+           raw_min_fx_gm, raw_min_inputs = fx_minifier(
+               fx_gm,
+               fx_inputs,
+               predicate_f,
+               max_granularity=max_granularity,
+           )
+
+       min_fx_gm, min_inputs = _normalize_minified_fx_gm(
+           raw_min_fx_gm, raw_min_inputs
+       )
+       found_culprits_num += 1
+       yield SearchResult(min_fx_gm, min_inputs)
+
+       fx_gm, fx_inputs = _erase_sub_gm_from_gm(
+           fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
+       )
+
+     except RuntimeError as e:
+       if (
+           str(e) == "Input graph did not fail the tester"
+           and found_culprits_num > 0
+       ):
+         break
+       raise e
+
+
+ def find_culprits(
+     torch_model: torch.nn.Module,
+     args: Tuple[Any],
+     max_granularity: Optional[int] = None,
+     runtime_errors: bool = False,
+     *,
+     enable_fx_minifier_logging: bool = False,
+ ) -> Generator[Culprit, None, None]:
+   """Finds culprits in the AI Edge Torch model conversion.
+
+   Args:
+     torch_model: model to export and convert.
+     args: A set of args to trace the model with, i.e. torch_model(*args)
+       must run.
+     max_granularity: FX minifier arg. The maximum granularity (number of
+       nodes) in the returned ATen FX subgraph of the culprit.
+     runtime_errors: If true, find culprits for Python runtime errors with
+       the converted model.
+     enable_fx_minifier_logging: If true, allows the underlying FX minifier
+       to log the progress.
+   """
+
+   fx_minifier_checker = functools.partial(
+       _fx_minifier_checker, runtime_errors=runtime_errors
+   )
+
+   for search_result in _search_model(
+       fx_minifier_checker,
+       torch_model,
+       args,
+       max_granularity=max_granularity,
+       enable_fx_minifier_logging=enable_fx_minifier_logging,
+   ):
+     yield Culprit(
+         search_result.graph_module,
+         search_result.inputs,
+         _runtime_errors=runtime_errors,
+     )
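
A minimal usage sketch of the culprit finder above (not part of the diff), assuming this wheel is installed. SampleModel is a hypothetical placeholder for a model whose conversion fails; a model that converts cleanly yields no culprits. find_culprits and Culprit.print_code come from culprit.py as added here.

import torch
from ai_edge_torch.debug import find_culprits


class SampleModel(torch.nn.Module):
  """Hypothetical stand-in for a model whose conversion fails."""

  def forward(self, x):
    return x * 2 + 1


model = SampleModel().eval()
args = (torch.rand(10),)

# Each yielded Culprit wraps a minimized ATen FX subgraph whose conversion
# fails; print_code() emits a self-contained Python repro including the
# ai_edge_torch.convert(...) call.
for culprit in find_culprits(model, args):
  culprit.print_code()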
ai_edge_torch/debug/test/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/debug/test/test_culprit.py
@@ -0,0 +1,140 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import ast
+
+ import ai_edge_torch.debug
+ import torch
+
+ from absl.testing import absltest as googletest
+
+ find_culprits = ai_edge_torch.debug.find_culprits
+
+ _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
+
+ _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
+
+
+ @torch.library.impl(
+     _test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd"
+ )
+ def non_lowerable_op(x):
+   if x.max() > 10.0:
+     return x + 1.0
+   return x
+
+
+ @torch.library.impl(_test_culprit_lib, "non_lowerable_op", "Meta")
+ def non_lowerable_op_meta(x):
+   return torch.empty_like(x)
+
+
+ class BadModel(torch.nn.Module):
+
+   def forward(self, x):
+     x = x + 1
+     x = torch.ops.test_culprit.non_lowerable_op.default(x)
+     return x
+
+
+ class TestCulprit(googletest.TestCase):
+
+   def setUp(self):
+     super().setUp()
+     torch.manual_seed(0)
+     torch._dynamo.reset()
+
+   def test_find_culprits(self):
+     model = BadModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 1)
+     self.assertIn(
+         torch.ops.test_culprit.non_lowerable_op.default,
+         [n.target for n in culprits[0].graph.nodes],
+     )
+
+   def test_valid_culprit_readable(self):
+     model = BadModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 1)
+
+     code = culprits[0].print_readable(print_output=False)
+
+     # The code should be valid Python code.
+     ast.parse(code)
+
+   def test_valid_culprit_code(self):
+     model = BadModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 1)
+
+     code = culprits[0].print_code(print_output=False)
+
+     # The code should be valid Python code.
+     ast.parse(code)
+
+   def test_find_multiple_culprits(self):
+     class MultiBadOpsModel(torch.nn.Module):
+
+       def forward(self, x):
+         x = x + 1
+         a = torch.ops.test_culprit.non_lowerable_op.default(x)
+         b = torch.ops.test_culprit.non_lowerable_op.default(x)
+         c = a + b
+         d = torch.ops.test_culprit.non_lowerable_op.default(c)
+         return d
+
+     model = MultiBadOpsModel().eval()
+     args = (torch.rand(10),)
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 3)
+     for culprit in culprits:
+       self.assertIn(
+           torch.ops.test_culprit.non_lowerable_op.default,
+           [n.target for n in culprit.graph.nodes],
+       )
+
+   def test_find_culprits_with_trivial_inputs_outputs(self):
+
+     class MultiBadOpsModel(torch.nn.Module):
+
+       def forward(self, x, y, z):
+         x = x + 1
+         a = torch.ops.test_culprit.non_lowerable_op.default(x)
+         b = torch.ops.test_culprit.non_lowerable_op.default(y)
+         return a, b, x, y, a, b
+
+     model = MultiBadOpsModel().eval()
+     args = (torch.rand(10), torch.rand(10), torch.rand(10))
+
+     culprits = list(find_culprits(model, args))
+     self.assertEqual(len(culprits), 2)
+     for culprit in culprits:
+       self.assertIn(
+           torch.ops.test_culprit.non_lowerable_op.default,
+           [n.target for n in culprit.graph.nodes],
+       )
+
+
+ if __name__ == "__main__":
+   googletest.main()
ai_edge_torch/debug/test/test_search_model.py
@@ -0,0 +1,51 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Tests for search_model."""
+
+ from ai_edge_torch.debug import _search_model
+ import torch
+
+ from absl.testing import absltest as googletest
+
+
+ class TestSearchModel(googletest.TestCase):
+
+   def test_search_model_with_ops(self):
+     class MultipleOpsModel(torch.nn.Module):
+
+       def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         sub_0 = x - 1
+         add_0 = y + 1
+         mul_0 = x * y
+         add_1 = sub_0 + add_0
+         mul_1 = add_0 * mul_0
+         sub_1 = add_1 - mul_1
+         return sub_1
+
+     model = MultipleOpsModel().eval()
+     args = (torch.rand(10), torch.rand(10))
+
+     def find_subgraph_with_sub(fx_gm, inputs):
+       return torch.ops.aten.sub.Tensor in [n.target for n in fx_gm.graph.nodes]
+
+     results = list(_search_model(find_subgraph_with_sub, model, args))
+     self.assertEqual(len(results), 2)
+     self.assertIn(
+         torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes]
+     )
+
+
+ if __name__ == "__main__":
+   googletest.main()
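
For completeness, a sketch of driving _search_model directly with a custom predicate, mirroring the test above (not part of the diff). HasMulModel and has_mul are illustrative names; the predicate can be any function over an FX graph module and its inputs.

import torch
from ai_edge_torch.debug import _search_model


class HasMulModel(torch.nn.Module):

  def forward(self, x):
    return (x * 2) - 1


def has_mul(fx_gm, inputs):
  # True if the (sub)graph contains an aten.mul op.
  return torch.ops.aten.mul.Tensor in [n.target for n in fx_gm.graph.nodes]


# _search_model yields SearchResult objects exposing .graph_module and .inputs.
for result in _search_model(has_mul, HasMulModel().eval(), (torch.rand(4),)):
  print(result.graph_module.print_readable(print_output=False))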