ai_edge_torch_nightly-0.3.0.dev20250114-py3-none-any.whl
- ai_edge_torch/__init__.py +32 -0
- ai_edge_torch/_config.py +69 -0
- ai_edge_torch/_convert/__init__.py +14 -0
- ai_edge_torch/_convert/conversion.py +153 -0
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/_convert/converter.py +270 -0
- ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
- ai_edge_torch/_convert/signature.py +66 -0
- ai_edge_torch/_convert/test/__init__.py +14 -0
- ai_edge_torch/_convert/test/test_convert.py +558 -0
- ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
- ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/_convert/to_channel_last_io.py +92 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +496 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +140 -0
- ai_edge_torch/debug/test/test_search_model.py +51 -0
- ai_edge_torch/debug/utils.py +59 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/fx_pass_base.py +110 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
- ai_edge_torch/generative/examples/llama/llama.py +196 -0
- ai_edge_torch/generative/examples/llama/verify.py +88 -0
- ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
- ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
- ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
- ai_edge_torch/generative/examples/openelm/verify.py +71 -0
- ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
- ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
- ai_edge_torch/generative/examples/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/phi/phi2.py +107 -0
- ai_edge_torch/generative/examples/phi/phi3.py +219 -0
- ai_edge_torch/generative/examples/phi/verify.py +64 -0
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
- ai_edge_torch/generative/examples/qwen/verify.py +88 -0
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
- ai_edge_torch/generative/examples/smollm/verify.py +86 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
- ai_edge_torch/generative/examples/t5/t5.py +655 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
- ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
- ai_edge_torch/generative/fx_passes/__init__.py +30 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +399 -0
- ai_edge_torch/generative/layers/attention_utils.py +210 -0
- ai_edge_torch/generative/layers/builder.py +160 -0
- ai_edge_torch/generative/layers/feed_forward.py +120 -0
- ai_edge_torch/generative/layers/kv_cache.py +204 -0
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +238 -0
- ai_edge_torch/generative/layers/normalization.py +222 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
- ai_edge_torch/generative/layers/unet/builder.py +50 -0
- ai_edge_torch/generative/layers/unet/model_config.py +282 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +47 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_custom_dus.py +107 -0
- ai_edge_torch/generative/test/test_kv_cache.py +120 -0
- ai_edge_torch/generative/test/test_loader.py +83 -0
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion.py +191 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
- ai_edge_torch/generative/test/test_quantize.py +183 -0
- ai_edge_torch/generative/test/utils.py +82 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/converter.py +215 -0
- ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
- ai_edge_torch/generative/utilities/loader.py +398 -0
- ai_edge_torch/generative/utilities/model_builder.py +180 -0
- ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
- ai_edge_torch/generative/utilities/t5_loader.py +512 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +335 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
- ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
- ai_edge_torch/lowertools/__init__.py +18 -0
- ai_edge_torch/lowertools/_shim.py +86 -0
- ai_edge_torch/lowertools/common_utils.py +142 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
- ai_edge_torch/lowertools/test_utils.py +62 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
- ai_edge_torch/lowertools/translate_recipe.py +163 -0
- ai_edge_torch/model.py +177 -0
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +88 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +403 -0
- ai_edge_torch/odml_torch/export_utils.py +157 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
- ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +156 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
- ai_edge_torch/version.py +16 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
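The diffs below cover the ai_edge_torch/debug package, whose culprit finder exercises the package's top-level ai_edge_torch.convert entry point (see _convert/converter.py and model.py in the manifest above). For orientation, a minimal conversion sketch; this is an illustrative snippet rather than code from the wheel, and the module and sample inputs are hypothetical placeholders:

    import torch
    import ai_edge_torch

    # Any torch-exportable nn.Module works; Linear is just a stand-in.
    model = torch.nn.Linear(4, 2).eval()
    # Tracing inputs: model(*sample_args) must run.
    sample_args = (torch.randn(1, 4),)

    edge_model = ai_edge_torch.convert(model, sample_args)
    print(edge_model(*sample_args))     # runs the converted model via the TFLite runtime
    edge_model.export("linear.tflite")  # serializes the TFLite flatbuffer to disk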
ai_edge_torch/debug/culprit.py
@@ -0,0 +1,496 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Culprit finder for AI Edge Torch conversion."""

import contextlib
import copy
import dataclasses
import functools
import io
import operator
import os
from typing import Any, Callable, Generator, List, Optional, Tuple, Union

import ai_edge_torch
from ai_edge_torch.debug import utils
import torch
from torch._functorch import aot_autograd
from torch._functorch.fx_minifier import minifier as fx_minifier
import torch.utils._pytree as pytree

_torch_float_dtypes = {
    torch.float32,
    torch.float,
    torch.float64,
    torch.double,
    torch.float16,
    torch.half,
    torch.bfloat16,
}
_torch_int_dtypes = {
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.short,
    torch.int32,
    torch.int,
    torch.int64,
    torch.long,
}

_fx_op_runner = {
    "call_function": lambda target, args, kwargs: target(*args, **kwargs),
    "call_method": lambda target, args, kwargs: getattr(args[0], target)(
        *args[1:], **kwargs
    ),
}

_CULPRIT_GRAPH_MODULE_NAME = "CulpritGraphModule"


def _get_shape_str(t: torch.Tensor):
  return f"({', '.join(map(str, t.shape))},)"


def _tensor_to_random_tensor_call(t: torch.Tensor):
  shape_str = _get_shape_str(t)
  if t.dtype in _torch_float_dtypes:
    return f"torch.randn({shape_str}, dtype={t.dtype})"
  elif t.dtype in _torch_int_dtypes:
    return f"torch.randint(0, 10, {shape_str}, dtype={t.dtype})"
  elif t.dtype == torch.bool:
    return f"torch.randint(0, 2, {shape_str}, dtype={t.dtype})"
  else:
    raise ValueError(f"Unsupported dtype: {t.dtype}")


def _tensor_to_buffer(t: torch.Tensor):
  buff = io.BytesIO()
  torch.save(t, buff)
  buff.seek(0)
  return buff.read()


@dataclasses.dataclass
class SearchResult:
  graph_module: torch.fx.GraphModule
  inputs: Tuple[Any]

  @property
  def graph(self) -> torch.fx.Graph:
    return self.graph_module.graph

  @graph.setter
  def graph(self, fx_g: torch.fx.Graph):
    self.graph_module.graph = fx_g


@dataclasses.dataclass
class Culprit(SearchResult):
  _runtime_errors: bool

  @property
  def stack_traces(self) -> List[str]:
    stack_traces = set()
    for node in self.graph.nodes:
      if node.op.startswith("call_") and "stack_trace" in node.meta:
        stack_traces.add(node.meta["stack_trace"])
    return list(stack_traces)

  def print_readable(self, print_output=True):
    """Prints the Python code for the culprit graph module and sample args.

    Args:
      print_output: bool - If true, prints the code to stdout. Otherwise
        returns the code in a str.
    """
    # TODO: b/321263453 - Support Python code gen with sample arg tensor values.
    random_inputs = True

    graph_module_code = self.graph_module.print_readable(
        print_output=False
    ).rstrip()

    input_strs = []
    for value in self.inputs:
      if torch.is_tensor(value):
        if not random_inputs:
          input_strs.append(
              f"# size={_get_shape_str(value)}, dtype={value.dtype}"
          )
          input_strs.append(
              f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),"
          )
        else:
          input_strs.append(_tensor_to_random_tensor_call(value) + ",")
      else:
        input_strs.append(str(value) + ",")

    inputs_code = (
        "_args = (\n"
        + "\n".join([" " * 4 + code for code in input_strs])
        + "\n)"
    )

    code = graph_module_code + "\n\n" + inputs_code
    if print_output:
      print(code)
    else:
      return code

  def print_code(self, print_output=True):
    """Prints the Python code for the culprit graph module, sample args, and
    the AI Edge Torch conversion that will fail with the error.

    Args:
      print_output: bool - If true, prints the code to stdout. Otherwise
        returns the code in a str.
    """
    definitions = self.print_readable(print_output=False)
    code = (
        "import torch\n"
        + "from torch import device\n"
        + "import ai_edge_torch\n\n"
        + definitions
        + "\n\n_edge_model ="
        f" ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(),"
        " _args)\n"
    )
    if self._runtime_errors:
      code += "_edge_model(*_args)\n"

    if print_output:
      print(code)
    else:
      return code

  @property
  def code(self):
    return self.print_code(print_output=False)

  def __repr__(self):
    return self.print_readable(print_output=False)

  def __str__(self):
    return self.print_readable(print_output=False)


def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
  """Turns all operator.getitem nodes in an ExportedProgram FX graph into

  new nodes composed of "computation + getitem". The normalization duplicates
  some computations in the graph but makes the graph friendlier for
  partitioning in the FX minifier.
  """
  fx_gm = copy.deepcopy(fx_gm)
  graph = fx_gm.graph
  for n in graph.nodes:
    if n.target != operator.getitem:
      continue

    src_n, key = n.args
    if src_n.op not in _fx_op_runner:
      continue

    runner = _fx_op_runner.get(src_n.op)

    with graph.inserting_after(n):
      new_n = graph.call_function(
          lambda src_target, key, args, kwargs: operator.getitem(
              runner(src_target, args, kwargs), key
          ),
          (src_n.target, key, src_n.args, src_n.kwargs),
      )
      n.replace_all_uses_with(new_n)

  graph.eliminate_dead_code()
  fx_gm.graph = graph
  return fx_gm


def _erase_unused_inputs(
    fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
):
  fx_gm = copy.deepcopy(fx_gm)
  inputs = tuple(inputs)
  args = fx_gm.graph.process_inputs(*inputs)
  args_iter = iter(args)

  graph = fx_gm.graph
  new_inputs = []
  for n in graph.nodes:
    if n.op == "placeholder":
      if n.target.startswith("*"):
        new_inputs += list(args_iter)
      elif len(n.users) > 0:
        new_inputs.append(next(args_iter))
      else:
        graph.erase_node(n)
        next(args_iter)
  new_inputs = tuple(new_inputs)
  fx_gm.graph = graph
  return fx_gm, new_inputs


def _lift_dead_ops_to_outputs(fx_gm: torch.fx.GraphModule):
  fx_gm = copy.deepcopy(fx_gm)

  new_outputs = []
  graph = fx_gm.graph
  nodes = list(graph.nodes)
  assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
  for node in nodes:
    if node.op not in ("placeholder", "output") and len(node.users) == 0:
      new_outputs.append(node)

  output_node = nodes[-1]
  # FX output node returns the first arg as is.
  # ref: https://github.com/pytorch/pytorch/blob/1a578df57cc0f417f671634e564c62ef5d9a97e2/torch/fx/interpreter.py#L337
  new_outputs, _ = pytree.tree_flatten([new_outputs, output_node.args[0]])
  output_node.update_arg(0, tuple(new_outputs))

  fx_gm.graph = graph
  return fx_gm


def _erase_trivial_outputs(fx_gm: torch.fx.GraphModule):
  """Removes output nodes directly connected to an input node."""
  fx_gm = copy.deepcopy(fx_gm)

  graph = fx_gm.graph
  nodes = list(graph.nodes)
  assert nodes[-1].op == "output" and sum(n.op == "output" for n in nodes) == 1
  output_node = nodes[-1]

  outputs, _ = pytree.tree_flatten(output_node.args[0])
  new_outputs = [output for output in outputs if output.op != "placeholder"]
  output_node.update_arg(0, tuple(new_outputs))

  fx_gm.recompile()
  return fx_gm


def _erase_sub_gm_from_gm(
    fx_gm: torch.fx.GraphModule,
    fx_inputs: Tuple[torch.Tensor],
    sub_gm: torch.fx.GraphModule,
    sub_inputs: Tuple[torch.Tensor],
):
  fx_gm = copy.deepcopy(fx_gm)
  fx_inputs = list(fx_inputs)

  class EraseNodeInterpreter(torch.fx.Interpreter):

    def run_node(self, node):
      nonlocal fx_gm, fx_inputs
      res = super().run_node(node)
      if node.op not in ("placeholder", "output"):
        to_erase = next(m for m in fx_gm.graph.nodes if m.name == node.name)
        # Raise the output (tensor) of the erased node to be an input of
        # the new model graph. Some raised inputs may become unused later
        # when all their users are within the erased subgraph; those inputs
        # will be removed by the subsequent `_erase_unused_inputs` pass.
        with fx_gm.graph.inserting_before(to_erase):
          new_input = fx_gm.graph.placeholder(node.name + "__value")
          to_erase.replace_all_uses_with(new_input)

        fx_gm.graph.erase_node(to_erase)
        fx_inputs.append(res)
      return res

  interpreter = EraseNodeInterpreter(sub_gm)
  interpreter.run(*sub_inputs)

  fx_gm.graph.lint()
  fx_gm.recompile()

  # Ops prior to the erased subgraph may be dangling. Lift them as outputs.
  fx_gm = _lift_dead_ops_to_outputs(fx_gm)
  fx_gm = _erase_trivial_outputs(fx_gm)
  fx_gm, fx_inputs = _erase_unused_inputs(fx_gm, fx_inputs)

  fx_gm.graph.lint()
  fx_gm.recompile()
  return fx_gm, fx_inputs


def _normalize_minified_fx_gm(
    fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
):
  fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
  fx_gm = _lift_dead_ops_to_outputs(fx_gm)
  fx_gm, _ = aot_autograd.aot_export_module(fx_gm, inputs, trace_joint=False)
  fx_gm.__class__.__name__ = _CULPRIT_GRAPH_MODULE_NAME
  return fx_gm, inputs


def _fx_minifier_checker(fx_gm, inputs, runtime_errors=False):
  fx_gm, inputs = _normalize_minified_fx_gm(fx_gm, inputs)

  trivial_aten_ops = {
      torch.ops.aten.view,
      torch.ops.aten.view.default,
  }
  if all(
      node.op in ("placeholder", "output") or node.target in trivial_aten_ops
      for node in fx_gm.graph.nodes
  ):
    return False

  try:
    edge_model = ai_edge_torch.convert(fx_gm.eval(), inputs)
    if runtime_errors:
      edge_model(*inputs)
  except Exception as err:
    return True
  return False


def _search_model(
    predicate_f: Callable[[torch.fx.GraphModule, List[Any]], bool],
    model: Union[torch.export.ExportedProgram, torch.nn.Module],
    export_args: Tuple[Any] = None,
    *,
    max_granularity: Optional[int] = None,
    enable_fx_minifier_logging: bool = False,
) -> Generator[SearchResult, None, None]:
  """Finds subgraphs in the torch model that satisfy a user-provided predicate function.

  Args:
    predicate_f: A predicate function the user specifies. It takes an FX
      (sub)graph and the inputs to this graph, and returns True if the graph
      satisfies the predicate, False otherwise.
    model: The model in which to search for the subgraph.
    export_args: A set of args to trace the model with, i.e. model(*args) must
      run.
    max_granularity: FX minifier arg. The maximum granularity (number of
      nodes) in the returned ATen FX subgraph of the culprit.
    enable_fx_minifier_logging: If true, allows the underlying FX minifier to
      log the progress.
  """

  if isinstance(model, torch.nn.Module):
    try:
      ep = torch.export.export(model, export_args)
    except Exception as err:
      raise ValueError(
          "Your model is not exportable by torch.export.export. Please modify"
          " your model to be torch-exportable first."
      ) from err
  else:
    ep = model

  fx_gm, fx_inputs = utils.exported_program_to_fx_graph_module_and_inputs(ep)
  fx_gm = _normalize_getitem_nodes(fx_gm)

  # HACK: Temporarily disable XLA_HLO_DEBUG and create_minified_hlo_graph so
  # that fx_minifier won't dump intermediate stablehlo files to storage.
  # https://github.com/pytorch/pytorch/blob/main/torch/_functorch/fx_minifier.py#L440
  @contextlib.contextmanager
  def disable_minifier_xla_debug():
    xla_hlo_debug_value = None
    if "XLA_HLO_DEBUG" in os.environ:
      xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
      del os.environ["XLA_HLO_DEBUG"]

    create_minified_hlo_graph = (
        torch._functorch.fx_minifier.create_minified_hlo_graph
    )
    torch._functorch.fx_minifier.create_minified_hlo_graph = (
        lambda *args, **kwargs: None
    )

    try:
      yield
    finally:
      if xla_hlo_debug_value is not None:
        os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value

      torch._functorch.fx_minifier.create_minified_hlo_graph = (
          create_minified_hlo_graph
      )

  found_culprits_num = 0
  while True:
    try:
      with disable_minifier_xla_debug(), open(os.devnull, "w") as devnull:
        with contextlib.nullcontext() if enable_fx_minifier_logging else utils.redirect_stdio(
            stdout=devnull,
            stderr=devnull,
        ):
          raw_min_fx_gm, raw_min_inputs = fx_minifier(
              fx_gm,
              fx_inputs,
              predicate_f,
              max_granularity=max_granularity,
          )

      min_fx_gm, min_inputs = _normalize_minified_fx_gm(
          raw_min_fx_gm, raw_min_inputs
      )
      found_culprits_num += 1
      yield SearchResult(min_fx_gm, min_inputs)

      fx_gm, fx_inputs = _erase_sub_gm_from_gm(
          fx_gm, fx_inputs, raw_min_fx_gm, raw_min_inputs
      )

    except RuntimeError as e:
      if (
          str(e) == "Input graph did not fail the tester"
          and found_culprits_num > 0
      ):
        break
      raise e


def find_culprits(
    torch_model: torch.nn.Module,
    args: Tuple[Any],
    max_granularity: Optional[int] = None,
    runtime_errors: bool = False,
    *,
    enable_fx_minifier_logging: bool = False,
) -> Generator[Culprit, None, None]:
  """Finds culprits in the AI Edge Torch model conversion.

  Args:
    torch_model: The model in which to find conversion culprits.
    args: A set of args to trace the model with, i.e. torch_model(*args) must
      run.
    max_granularity: FX minifier arg. The maximum granularity (number of
      nodes) in the returned ATen FX subgraph of the culprit.
    runtime_errors: If true, find culprits for Python runtime errors with the
      converted model.
    enable_fx_minifier_logging: If true, allows the underlying FX minifier to
      log the progress.
  """

  fx_minifier_checker = functools.partial(
      _fx_minifier_checker, runtime_errors=runtime_errors
  )

  for search_result in _search_model(
      fx_minifier_checker,
      torch_model,
      args,
      max_granularity=max_granularity,
      enable_fx_minifier_logging=enable_fx_minifier_logging,
  ):
    yield Culprit(
        search_result.graph_module,
        search_result.inputs,
        _runtime_errors=runtime_errors,
    )
ai_edge_torch/debug/test/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
ai_edge_torch/debug/test/test_culprit.py
@@ -0,0 +1,140 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import ast

import ai_edge_torch.debug
import torch

from absl.testing import absltest as googletest

find_culprits = ai_edge_torch.debug.find_culprits

_test_culprit_lib = torch.library.Library("test_culprit", "DEF")

_test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")


@torch.library.impl(
    _test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd"
)
def non_lowerable_op(x):
  if x.max() > 10.0:
    return x + 1.0
  return x


@torch.library.impl(_test_culprit_lib, "non_lowerable_op", "Meta")
def non_lowerable_op_meta(x):
  return torch.empty_like(x)


class BadModel(torch.nn.Module):

  def forward(self, x):
    x = x + 1
    x = torch.ops.test_culprit.non_lowerable_op.default(x)
    return x


class TestCulprit(googletest.TestCase):

  def setUp(self):
    super().setUp()
    torch.manual_seed(0)
    torch._dynamo.reset()

  def test_find_culprits(self):
    model = BadModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 1)
    self.assertIn(
        torch.ops.test_culprit.non_lowerable_op.default,
        [n.target for n in culprits[0].graph.nodes],
    )

  def test_valid_culprit_readable(self):
    model = BadModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 1)

    code = culprits[0].print_readable(print_output=False)

    # The code should be valid Python code.
    ast.parse(code)

  def test_valid_culprit_code(self):
    model = BadModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 1)

    code = culprits[0].print_code(print_output=False)

    # The code should be valid Python code.
    ast.parse(code)

  def test_find_multiple_culprits(self):
    class MultiBadOpsModel(torch.nn.Module):

      def forward(self, x):
        x = x + 1
        a = torch.ops.test_culprit.non_lowerable_op.default(x)
        b = torch.ops.test_culprit.non_lowerable_op.default(x)
        c = a + b
        d = torch.ops.test_culprit.non_lowerable_op.default(c)
        return d

    model = MultiBadOpsModel().eval()
    args = (torch.rand(10),)

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 3)
    for culprit in culprits:
      self.assertIn(
          torch.ops.test_culprit.non_lowerable_op.default,
          [n.target for n in culprit.graph.nodes],
      )

  def test_find_culprits_with_trivial_inputs_outputs(self):

    class MultiBadOpsModel(torch.nn.Module):

      def forward(self, x, y, z):
        x = x + 1
        a = torch.ops.test_culprit.non_lowerable_op.default(x)
        b = torch.ops.test_culprit.non_lowerable_op.default(y)
        return a, b, x, y, a, b

    model = MultiBadOpsModel().eval()
    args = (torch.rand(10), torch.rand(10), torch.rand(10))

    culprits = list(find_culprits(model, args))
    self.assertEqual(len(culprits), 2)
    for culprit in culprits:
      self.assertIn(
          torch.ops.test_culprit.non_lowerable_op.default,
          [n.target for n in culprit.graph.nodes],
      )


if __name__ == "__main__":
  googletest.main()
ai_edge_torch/debug/test/test_search_model.py
@@ -0,0 +1,51 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for search_model."""

from ai_edge_torch.debug import _search_model
import torch

from absl.testing import absltest as googletest


class TestSearchModel(googletest.TestCase):

  def test_search_model_with_ops(self):
    class MultipleOpsModel(torch.nn.Module):

      def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        sub_0 = x - 1
        add_0 = y + 1
        mul_0 = x * y
        add_1 = sub_0 + add_0
        mul_1 = add_0 * mul_0
        sub_1 = add_1 - mul_1
        return sub_1

    model = MultipleOpsModel().eval()
    args = (torch.rand(10), torch.rand(10))

    def find_subgraph_with_sub(fx_gm, inputs):
      return torch.ops.aten.sub.Tensor in [n.target for n in fx_gm.graph.nodes]

    results = list(_search_model(find_subgraph_with_sub, model, args))
    self.assertEqual(len(results), 2)
    self.assertIn(
        torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes]
    )


if __name__ == "__main__":
  googletest.main()