ai-edge-torch-nightly 0.2.0.dev20240714 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/debug/culprit.py
@@ -0,0 +1,464 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib
import copy
import dataclasses
import functools
import io
import operator
import os
import sys
from typing import Any, Callable, Generator, List, Optional, Tuple, Union

from functorch.compile import minifier as fx_minifier
import torch
from torch._functorch import aot_autograd
import torch.utils._pytree as pytree

import ai_edge_torch
from ai_edge_torch.debug import utils

_torch_float_dtypes = {
    torch.float32,
    torch.float,
    torch.float64,
    torch.double,
    torch.float16,
    torch.half,
    torch.bfloat16,
}
_torch_int_dtypes = {
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.short,
    torch.int32,
    torch.int,
    torch.int64,
    torch.long,
}

_fx_op_runner = {
    "call_function": lambda target, args, kwargs: target(*args, **kwargs),
    "call_method": lambda target, args, kwargs: getattr(args[0], target)(
        *args[1:], **kwargs
    ),
}

_CULPRIT_GRAPH_MODULE_NAME = "CulpritGraphModule"


def _get_shape_str(t: torch.Tensor):
  return f"({', '.join(map(str, t.shape))},)"


def _tensor_to_random_tensor_call(t: torch.Tensor):
  shape_str = _get_shape_str(t)
  if t.dtype in _torch_float_dtypes:
    return f"torch.randn({shape_str}, dtype={t.dtype})"
  elif t.dtype in _torch_int_dtypes:
    return f"torch.randint(0, 10, {shape_str}, dtype={t.dtype})"
  elif t.dtype == torch.bool:
    return f"torch.randint(0, 2, {shape_str}, dtype={t.dtype})"
  else:
    raise ValueError(f"Unsupported dtype: {t.dtype}")


def _tensor_to_buffer(t: torch.Tensor):
  buff = io.BytesIO()
  torch.save(t, buff)
  buff.seek(0)
  return buff.read()


@dataclasses.dataclass
class SearchResult:
  graph_module: torch.fx.GraphModule
  inputs: Tuple[Any]

  @property
  def graph(self) -> torch.fx.Graph:
    return self.graph_module.graph

  @graph.setter
  def graph(self, fx_g: torch.fx.Graph):
    self.graph_module.graph = fx_g


@dataclasses.dataclass
class Culprit(SearchResult):
  _runtime_errors: bool

  @property
  def stack_traces(self) -> List[str]:
    stack_traces = set()
    for node in self.graph.nodes:
      if node.op.startswith("call_") and "stack_trace" in node.meta:
        stack_traces.add(node.meta["stack_trace"])
    return list(stack_traces)

  def print_readable(self, print_output=True):
    """Print the Python code for the culprit graph module and sample args.

    Args:
      print_output: bool - If true, prints the code to stdout. Otherwise
        returns the code in a str.
    """
    # TODO (b/321263453): Support Python code gen with sample arg tensor values.
    random_inputs = True

    graph_module_code = self.graph_module.print_readable(print_output=False).rstrip()

    input_strs = []
    for value in self.inputs:
      if torch.is_tensor(value):
        if not random_inputs:
          input_strs.append(f"# size={_get_shape_str(value)}, dtype={value.dtype}")
          input_strs.append(f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),")
        else:
          input_strs.append(_tensor_to_random_tensor_call(value) + ",")
      else:
        input_strs.append(str(value) + ",")

    inputs_code = (
        "_args = (\n" + "\n".join([" " * 4 + code for code in input_strs]) + "\n)"
    )

    code = graph_module_code + "\n\n" + inputs_code
    if print_output:
      print(code)
    else:
      return code

  def print_code(self, print_output=True):
    """Print the Python code for the culprit graph module, sample args, and the
    AI Edge Torch conversion that will fail with the error.

    Args:
      print_output: bool - If true, prints the code to stdout. Otherwise
        returns the code in a str.
    """
    definitions = self.print_readable(print_output=False)
    code = (
        "import torch\n"
        + "from torch import device\n"
        + "import ai_edge_torch\n\n"
        + definitions
        + f"\n\n_edge_model = ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(), _args)\n"
    )
    if self._runtime_errors:
      code += "_edge_model(*_args)\n"

    if print_output:
      print(code)
    else:
      return code

  @property
  def code(self):
    return self.print_code(print_output=False)

  def __repr__(self):
    return self.print_readable(print_output=False)

  def __str__(self):
    return self.print_readable(print_output=False)


def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
  """Turns all operator.getitem nodes in the ExportedProgram FX graph into new
  nodes composed of "computation + getitem". The normalization duplicates some
  computations in the graph but makes the graph friendlier for partitioning in
  the FX minifier.
  """
  fx_gm = copy.deepcopy(fx_gm)
  graph = fx_gm.graph
  for n in graph.nodes:
    if n.target != operator.getitem:
      continue

    src_n, key = n.args
    if src_n.op not in _fx_op_runner:
      continue

    runner = _fx_op_runner.get(src_n.op)

    with graph.inserting_after(n):
      new_n = graph.call_function(
          lambda src_target, key, args, kwargs: operator.getitem(
              runner(src_target, args, kwargs), key
          ),
          (src_n.target, key, src_n.args, src_n.kwargs),
      )
      n.replace_all_uses_with(new_n)

  graph.eliminate_dead_code()
  fx_gm.graph = graph
  return fx_gm


def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
  fx_gm = copy.deepcopy(fx_gm)
  inputs = tuple(inputs)
  args = fx_gm.graph.process_inputs(*inputs)
  args_iter = iter(args)

  graph = fx_gm.graph
  new_inputs = []
  for n in graph.nodes:
    if n.op == "placeholder":
      if n.target.startswith("*"):
        new_inputs += list(args_iter)
      elif len(n.users) > 0:
        new_inputs.append(next(args_iter))
      else:
        graph.erase_node(n)
        next(args_iter)
  new_inputs = tuple(new_inputs)
  fx_gm.graph = graph
  return fx_gm, new_inputs


def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
  fx_gm = copy.deepcopy(fx_gm)

  new_outputs = []
  graph = fx_gm.graph
  nodes = list(graph.nodes)
  assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
  for node in nodes:
    if node.op not in ("placeholder", "output") and len(node.users) == 0:
      new_outputs.append(node)

  output_node = nodes[-1]
  # The FX output node returns the first arg as is.
  # ref: https://github.com/pytorch/pytorch/blob/1a578df57cc0f417f671634e564c62ef5d9a97e2/torch/fx/interpreter.py#L337
  new_outputs, _ = pytree.tree_flatten([new_outputs, output_node.args[0]])
  output_node.update_arg(0, tuple(new_outputs))

  fx_gm.graph = graph
  return fx_gm


def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
  """Remove output nodes directly connected to an input node."""
  fx_gm = copy.deepcopy(fx_gm)

  graph = fx_gm.graph
  nodes = list(graph.nodes)
  assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
  output_node = nodes[-1]

  outputs, _ = pytree.tree_flatten(output_node.args[0])
  new_outputs = [output for output in outputs if output.op != "placeholder"]
  output_node.update_arg(0, tuple(new_outputs))

  fx_gm.recompile()
  return fx_gm


def _erase_sub_gm_from_gm(
    fx_gm: torch.fx.GraphModule,
    fx_inputs: Tuple[torch.Tensor],
    sub_gm: torch.fx.GraphModule,
    sub_inputs: Tuple[torch.Tensor],
):
  fx_gm = copy.deepcopy(fx_gm)
  fx_inputs = list(fx_inputs)

  class EraseNodeInterpreter(torch.fx.Interpreter):

    def run_node(self, node):
      nonlocal fx_gm, fx_inputs
      res = super().run_node(node)
      if node.op not in ("placeholder", "output"):
        to_erase = next(m for m in fx_gm.graph.nodes if m.name == node.name)
        # Raise the output (tensor) of the erased node to be an input of the
        # new model graph. Some raised inputs may become unused later, when all
        # of their users are within the erased subgraph; those inputs will be
        # removed by the subsequent `_erase_unused_inputs` pass.
        with fx_gm.graph.inserting_before(to_erase):
          new_input = fx_gm.graph.placeholder(node.name + "__value")
          to_erase.replace_all_uses_with(new_input)

        fx_gm.graph.erase_node(to_erase)
        fx_inputs.append(res)
      return res

  interpreter = EraseNodeInterpreter(sub_gm)
  interpreter.run(*sub_inputs)

  fx_gm.graph.lint()
  fx_gm.recompile()

  # Ops prior to the erased subgraph may be dangling. Lift them as outputs.
  fx_gm = _lift_dead_ops_to_outputs(fx_gm)
  fx_gm = _erase_trivial_outputs(fx_gm)
  fx_gm, fx_inputs = _erase_unused_inputs(fx_gm, fx_inputs)

  fx_gm.graph.lint()
  fx_gm.recompile()
  return fx_gm, fx_inputs


def _normalize_minified_fx_gm(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
  fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
  fx_gm = _lift_dead_ops_to_outputs(fx_gm)
  fx_gm, _ = aot_autograd.aot_export_module(fx_gm, inputs, trace_joint=False)
  fx_gm.__class__.__name__ = _CULPRIT_GRAPH_MODULE_NAME
  return fx_gm, inputs


def _fx_minifier_checker(fx_gm, inputs, runtime_errors=False):
  fx_gm, inputs = _normalize_minified_fx_gm(fx_gm, inputs)

  trivial_aten_ops = {
      torch.ops.aten.view,
      torch.ops.aten.view.default,
  }
  if all(
      node.op in ("placeholder", "output") or node.target in trivial_aten_ops
      for node in fx_gm.graph.nodes
  ):
    return False

  try:
    edge_model = ai_edge_torch.convert(fx_gm.eval(), inputs)
    if runtime_errors:
      edge_model(*inputs)
  except Exception as err:
    return True
  return False


def _search_model(
    predicate_f: Callable[[torch.fx.GraphModule, List[Any]], bool],
    model: Union[torch.export.ExportedProgram, torch.nn.Module],
    export_args: Tuple[Any] = None,
    *,
    max_granularity: Optional[int] = None,
    enable_fx_minifier_logging: bool = False,
) -> Generator[SearchResult, None, None]:
  """Finds subgraphs in the torch model that satisfy a user-provided predicate function.

  Args:
    predicate_f: a predicate function the user specifies. It takes an FX
      (sub)graph and the inputs to this graph, and returns True if the graph
      satisfies the predicate, False otherwise.
    model: model in which to search the subgraph.
    export_args: A set of args to trace the model with, i.e. model(*args) must
      run.
    max_granularity: FX minifier arg. The maximum granularity (number of nodes)
      in the returned ATen FX subgraph of the culprit.
    enable_fx_minifier_logging: If true, allows the underlying FX minifier to
      log the progress.
  """
  if isinstance(model, torch.nn.Module):
    try:
      ep = torch.export.export(model, export_args)
    except Exception as err:
      raise ValueError(
          "Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first."
      ) from err
  else:
    ep = model

  fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
  fx_gm = _normalize_getitem_nodes(fx_gm)

  # HACK: temporarily disable XLA_HLO_DEBUG so that fx_minifier won't dump
  # intermediate stablehlo files to storage.
  # https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
  @contextlib.contextmanager
  def disable_xla_hlo_debug():
    xla_hlo_debug_value = None
    if "XLA_HLO_DEBUG" in os.environ:
      xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
      del os.environ["XLA_HLO_DEBUG"]

    try:
      yield None
    finally:
      if xla_hlo_debug_value is not None:
        os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value

  found_culprits_num = 0
  while True:
    try:
      with disable_xla_hlo_debug(), open(os.devnull, "w") as devnull:
        with contextlib.nullcontext() if enable_fx_minifier_logging else utils.redirect_stdio(
            stdout=devnull,
            stderr=devnull,
        ):
          raw_min_fx_gm, raw_min_inputs = fx_minifier(
              fx_gm,
              fx_inputs,
              predicate_f,
              max_granularity=max_granularity,
          )

      min_fx_gm, min_inputs = _normalize_minified_fx_gm(raw_min_fx_gm, raw_min_inputs)
      found_culprits_num += 1
      yield SearchResult(min_fx_gm, min_inputs)

      fx_gm, fx_inputs = _erase_sub_gm_from_gm(
          fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
      )

    except RuntimeError as e:
      if str(e) == "Input graph did not fail the tester" and found_culprits_num > 0:
        break
      raise e


def find_culprits(
    torch_model: torch.nn.Module,
    args: Tuple[Any],
    max_granularity: Optional[int] = None,
    runtime_errors: bool = False,
    *,
    enable_fx_minifier_logging: bool = False,
) -> Generator[Culprit, None, None]:
  """Finds culprits in the AI Edge Torch model conversion.

  Args:
    torch_model: the model to convert and search for culprits in.
    args: A set of args to trace the model with, i.e. torch_model(*args) must
      run.
    max_granularity: FX minifier arg. The maximum granularity (number of nodes)
      in the returned ATen FX subgraph of the culprit.
    runtime_errors: If true, find culprits for Python runtime errors with the
      converted model.
    enable_fx_minifier_logging: If true, allows the underlying FX minifier to
      log the progress.
  """
  fx_minifier_checker = functools.partial(
      _fx_minifier_checker, runtime_errors=runtime_errors
  )

  for search_result in _search_model(
      fx_minifier_checker,
      torch_model,
      args,
      max_granularity=max_granularity,
      enable_fx_minifier_logging=enable_fx_minifier_logging,
  ):
    yield Culprit(
        search_result.graph_module, search_result.inputs, _runtime_errors=runtime_errors
    )
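For orientation, here is a minimal usage sketch of the find_culprits API added above. MyFailingModel and the sample input shape are hypothetical placeholders, not part of this release:

import torch
from ai_edge_torch.debug import find_culprits

model = MyFailingModel().eval()  # hypothetical torch.nn.Module whose conversion fails
args = (torch.rand(1, 10),)      # sample args; model(*args) must run

# Each yielded Culprit wraps a minified FX subgraph that still reproduces the failure.
for culprit in find_culprits(model, args):
  culprit.print_code()  # prints a self-contained repro script ending in ai_edge_torch.convert(...)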
ai_edge_torch/debug/test/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
ai_edge_torch/debug/test/test_culprit.py
@@ -0,0 +1,133 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import ast
import io
import sys
import unittest

import torch

from ai_edge_torch.debug import find_culprits

_test_culprit_lib = torch.library.Library("test_culprit", "DEF")

_test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")


@torch.library.impl(_test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd")
def non_lowerable_op(x):
  if x.max() > 10.0:
    return x + 1.0
  return x


@torch.library.impl(_test_culprit_lib, "non_lowerable_op", "Meta")
def non_lowerable_op_meta(x):
  return torch.empty_like(x)


class BadModel(torch.nn.Module):

  def forward(self, x):
    x = x + 1
    x = torch.ops.test_culprit.non_lowerable_op.default(x)
    return x


class TestCulprit(unittest.TestCase):

  def test_find_culprits(self):
    model = BadModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 1)
    self.assertIn(
        torch.ops.test_culprit.non_lowerable_op.default,
        [n.target for n in culprits[0].graph.nodes],
    )

  def test_valid_culprit_readable(self):
    model = BadModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 1)

    code = culprits[0].print_readable(print_output=False)

    # The generated code should be valid Python.
    ast.parse(code)

  def test_valid_culprit_code(self):
    model = BadModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 1)

    code = culprits[0].print_code(print_output=False)

    # The generated code should be valid Python.
    ast.parse(code)

  def test_find_multiple_culprits(self):
    class MultiBadOpsModel(torch.nn.Module):

      def forward(self, x):
        x = x + 1
        a = torch.ops.test_culprit.non_lowerable_op.default(x)
        b = torch.ops.test_culprit.non_lowerable_op.default(x)
        c = a + b
        d = torch.ops.test_culprit.non_lowerable_op.default(c)
        return d

    model = MultiBadOpsModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 3)
    for culprit in culprits:
      self.assertIn(
          torch.ops.test_culprit.non_lowerable_op.default,
          [n.target for n in culprit.graph.nodes],
      )

  def test_find_culprits_with_trivial_inputs_outputs(self):

    class MultiBadOpsModel(torch.nn.Module):

      def forward(self, x, y, z):
        x = x + 1
        a = torch.ops.test_culprit.non_lowerable_op.default(x)
        b = torch.ops.test_culprit.non_lowerable_op.default(y)
        return a, b, x, y, a, b

    model = MultiBadOpsModel().eval()
    args = (torch.rand(10), torch.rand(10), torch.rand(10))

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 2)
    for culprit in culprits:
      self.assertIn(
          torch.ops.test_culprit.non_lowerable_op.default,
          [n.target for n in culprit.graph.nodes],
      )


if __name__ == "__main__":
  unittest.main()
ai_edge_torch/debug/test/test_search_model.py
@@ -0,0 +1,50 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import unittest

import torch

from ai_edge_torch.debug import _search_model


class TestSearchModel(unittest.TestCase):

  def test_search_model_with_ops(self):
    class MultipleOpsModel(torch.nn.Module):

      def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        sub_0 = x - 1
        add_0 = y + 1
        mul_0 = x * y
        add_1 = sub_0 + add_0
        mul_1 = add_0 * mul_0
        sub_1 = add_1 - mul_1
        return sub_1

    model = MultipleOpsModel().eval()
    args = (torch.rand(10), torch.rand(10))

    def find_subgraph_with_sub(fx_gm, inputs):
      return torch.ops.aten.sub.Tensor in [n.target for n in fx_gm.graph.nodes]

    results = list(_search_model(find_subgraph_with_sub, model, args))
    self.assertEqual(len(results), 2)
    self.assertIn(torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes])


if __name__ == "__main__":
  unittest.main()
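The test above calls _search_model with default settings. As a hedged sketch, the keyword-only options from the signature in culprit.py can also be passed; the values here are illustrative, and the predicate, model, and args are the ones defined in the test:

# Bound the size of each returned subgraph and surface the minifier's progress logs.
results = list(
    _search_model(
        find_subgraph_with_sub,  # predicate from the test above
        model,
        args,
        max_granularity=4,  # per the docstring: max nodes in the returned ATen FX subgraph
        enable_fx_minifier_logging=True,
    )
)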
ai_edge_torch/debug/utils.py
@@ -0,0 +1,48 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib
import sys

import torch
from torch.export.graph_signature import InputKind
import torch.fx._pytree as fx_pytree
from torch.utils import _pytree as pytree


def exported_program_to_fx_graph_module_and_inputs(ep: torch.export.ExportedProgram):
  fx_gm = ep.graph_module
  fx_inputs = pytree.tree_map(
      torch.tensor, ep._graph_module_flat_inputs(*ep.example_inputs)
  )
  return fx_gm, fx_inputs


@contextlib.contextmanager
def redirect_stdio(stdout, stderr):
  old_stdout = sys.stdout
  old_stderr = sys.stderr

  old_stdout.flush()
  old_stderr.flush()

  sys.stdout = stdout
  sys.stderr = stderr
  try:
    yield stdout, stderr
  finally:
    stdout.flush()
    stderr.flush()
    sys.stdout = old_stdout
    sys.stderr = old_stderr
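A short sketch of redirect_stdio in use, mirroring how culprit.py wraps the FX minifier with it; noisy_fn is an illustrative placeholder:

import os

with open(os.devnull, "w") as devnull:
  with redirect_stdio(stdout=devnull, stderr=devnull):
    noisy_fn()  # anything printed to stdout/stderr here is discarded
# On exit, both streams are flushed and sys.stdout/sys.stderr are restored.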