ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import os
|
|
16
|
+
from typing import Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.ao.quantization.quantize_pt2e
|
|
20
|
+
from torch.export import ExportedProgram
|
|
21
|
+
from torch.fx import GraphModule
|
|
22
|
+
from torch.fx import Node
|
|
23
|
+
import torch.utils._pytree as pytree
|
|
24
|
+
|
|
25
|
+
from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
|
|
26
|
+
from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
|
|
27
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
|
28
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
29
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
|
|
30
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
|
|
31
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
|
|
32
|
+
|
|
33
|
+
# Annotation alias for the `transpose_func` parameters used by the pass below:
# callers pass either `utils.tensor_to_nchw` or `utils.tensor_to_nhwc`.
# NOTE(review): `Union` is parameterized here with the function objects
# themselves rather than with types; `Callable[[torch.Tensor], torch.Tensor]`
# would be the conventional spelling - confirm before changing.
TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
  """Partitions a graph into NHWC/NCHW regions and inserts layout transposes.

  The pass marks constant nodes, asks a layout partitioner to decide which
  nodes are NHWC, and then walks the graph inserting `utils.tensor_to_nhwc` /
  `utils.tensor_to_nchw` calls on the edges that cross between the two
  layouts. Quantize/dequantize nodes on such edges get dedicated handling so
  that the inserted transpose stays correctly wrapped by Q/DQ ops.
  """

  def get_source_meta(self, node: torch.fx.Node):
    """Returns a new dict holding only the source-tracking entries of a node.

    Args:
      node: The node whose `meta` is read.

    Returns:
      A fresh dict with whichever of "stack_trace", "nn_module_stack",
      "source_fn_stack" and "from_node" are present in `node.meta`.
    """
    keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
    return {key: node.meta[key] for key in keys if key in node.meta}

  def insert_t_q_dq(
      self,
      graph: torch.fx.Graph,
      input_dq: torch.fx.Node,
      target: torch.fx.Node,
      transpose_func: TransposeFunc,
      transpose_node_meta: dict,
  ) -> list[torch.fx.Node]:
    """Inserts a transpose followed by a Q/DQ pair after a dequantize node.

    original:
      input_dq -> target
    insert the node as:
      input_dq -> (T q dq) -> target

    Args:
      graph: The graph being edited.
      input_dq: A dequantize node currently feeding `target`.
      target: The consumer whose input is redirected to the new `dq` node.
      transpose_func: The layout transpose function to call.
      transpose_node_meta: Meta dict assigned to `input_dq`, `t` and `q`.

    Returns:
      The newly created nodes in graph order: `[t, q, dq]`.
    """
    assert utils.is_dq_node(input_dq)

    # Everything except the tensor operand (scale, zero point, ...) is reused
    # verbatim for the newly created Q and DQ nodes.
    q_args = input_dq.args[1:]
    q_kwargs = input_dq.kwargs
    q_op, dq_op = utils.get_paired_q_dq_ops(input_dq.target)
    with graph.inserting_before(target):
      t = graph.call_function(transpose_func, (input_dq,))
      # Q and DQ inserted here may required updating the `axis` arg when they
      # are per_channel ops. However, instead of updating here, the nodes would
      # be marked as NHWC/NCHW and applied rewriters after partitioning.
      q = graph.call_function(q_op, (t,) + q_args, q_kwargs)
      dq = graph.call_function(dq_op, (q,) + q_args, q_kwargs)

    # NOTE(review): `input_dq`, `t` and `q` end up sharing the *same* dict
    # object, and `input_dq`'s previous meta is replaced. If layout marks are
    # stored in `node.meta`, marking one of these nodes would affect all of
    # them - confirm this is intended.
    input_dq.meta = transpose_node_meta
    t.meta = transpose_node_meta
    q.meta = transpose_node_meta
    dq.meta = self.get_source_meta(target)

    target.replace_input_with(input_dq, dq)
    return [t, q, dq]

  def insert_dq_t_q(
      self,
      graph: torch.fx.Graph,
      input_q: torch.fx.Node,
      target: torch.fx.Node,
      transpose_func: TransposeFunc,
      transpose_node_meta: dict,
  ) -> list[torch.fx.Node]:
    """Inserts a quantized transpose (dq, T, q) after a quantize node.

    original:
      input_q -> target
    insert the node as:
      input_q -> (dq T q) -> target

    Args:
      graph: The graph being edited.
      input_q: A quantize node currently feeding `target`.
      target: The consumer whose input is redirected to the new `q` node.
      transpose_func: The layout transpose function to call.
      transpose_node_meta: Meta dict assigned to all three new nodes.

    Returns:
      The newly created nodes in graph order: `[dq, t, q]`.
    """
    assert utils.is_q_node(input_q)

    q_args = input_q.args[1:]
    q_kwargs = input_q.kwargs
    # The Q/DQ pairing helper lives in `utils` (this class defines no method of
    # that name); this matches the lookup done in `insert_t_q_dq`.
    q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
    with graph.inserting_before(target):
      # Q and DQ inserted here may required updating the `axis` arg when they
      # are per_channel ops. However, instead of updating here, the nodes would
      # be marked as NHWC/NCHW and applied rewriters after partitioning.
      dq = graph.call_function(dq_op, (input_q,) + q_args, q_kwargs)
      t = graph.call_function(transpose_func, (dq,))
      q = graph.call_function(q_op, (t,) + q_args, q_kwargs)

    # NOTE(review): all three nodes share the same dict object - see the
    # matching note in `insert_t_q_dq`.
    dq.meta = transpose_node_meta
    t.meta = transpose_node_meta
    q.meta = transpose_node_meta

    target.replace_input_with(input_q, q)
    return [dq, t, q]

  def insert_layout_transpose(
      self,
      graph: torch.fx.Graph,
      input_node: torch.fx.Node,
      target_node: torch.fx.Node,
      transpose_func: TransposeFunc,
      transpose_node_meta: dict,
  ) -> None:
    """Inserts a layout transpose on the edge `input_node -> target_node`.

    Constant Q/DQ nodes directly feeding the target are first duplicated for
    this consumer so that the transpose is placed in front of them. The
    transpose is then inserted plain, as (dq T q), or as (T q dq) depending on
    whether the producer is a regular, quantize, or dequantize node. Finally,
    every node created here is marked NCHW or NHWC.

    Args:
      graph: The graph being edited.
      input_node: Producer side of the edge.
      target_node: Consumer side of the edge.
      transpose_func: `utils.tensor_to_nchw` or `utils.tensor_to_nhwc`.
      transpose_node_meta: Meta dict to assign to the inserted nodes.
    """
    assert transpose_func in (utils.tensor_to_nchw, utils.tensor_to_nhwc)

    # new_nodes only contains Q/DQ/Transpose nodes, which are all SISO.
    # Insertion order input nodes -> output nodes
    new_nodes = []

    # Constraint Q2: the NHWC partition's entry and exit must not be output
    # edges of Q/DQ ops that are connected to a constant/weight tensor.
    while layout_mark.is_const_node(input_node) and (
        utils.is_dq_node(input_node) or utils.is_q_node(input_node)
    ):
      with graph.inserting_before(target_node):
        new_input_node = graph.node_copy(input_node)

      target_node.replace_input_with(input_node, new_input_node)

      new_nodes = [new_input_node] + new_nodes
      # Walk one step towards the producer: the copy becomes the new consumer.
      input_node, target_node = new_input_node.args[0], new_input_node

    if utils.is_q_node(input_node):
      # Constraint Q3: when the entry and exit is right after a q op (occur
      # after a (dq-op-q) triplet), the transpose must be added as a quantized
      # transpose in (dq-T-q)
      # input_q -> (dq T q) -> target
      new_nodes = (
          self.insert_dq_t_q(
              graph,
              input_node,
              target_node,
              transpose_func,
              transpose_node_meta,
          )
          + new_nodes
      )
    elif utils.is_dq_node(input_node):
      # Constraint Q1: the NHWC partition's entry and exit cannot be edges
      # within (dq-op-q) triplet.
      # input_dq -> (T q dq) -> target
      new_nodes = (
          self.insert_t_q_dq(
              graph,
              input_node,
              target_node,
              transpose_func,
              transpose_node_meta,
          )
          + new_nodes
      )
    else:
      # input -> target
      with graph.inserting_before(target_node):
        t = graph.call_function(transpose_func, (input_node,))
      t.meta = transpose_node_meta
      target_node.replace_input_with(input_node, t)
      new_nodes = [t] + new_nodes

    # Mark new nodes as NCHW or NHWC
    # For all nodes before the transpose, mark it as input_marker
    # For all nodes after the transpose (incl. transpose), mark it as
    # output_marker
    if transpose_func == utils.tensor_to_nchw:
      input_marker, target_marker = (
          layout_mark.mark_as_nhwc_node,
          layout_mark.mark_as_nchw_node,
      )
    else:
      input_marker, target_marker = (
          layout_mark.mark_as_nchw_node,
          layout_mark.mark_as_nhwc_node,
      )

    marker = input_marker
    for node in new_nodes:
      if node.target == transpose_func:
        marker = target_marker
      marker(node)
    # Exactly the branches above guarantee a transpose was inserted, so the
    # marker must have switched by the end of the loop.
    assert marker == target_marker

  def input_to_nhwc(
      self,
      graph: torch.fx.Graph,
      input_node: torch.fx.Node,
      target_node: torch.fx.Node,
  ) -> None:
    """Ensures `target_node` receives `input_node` in NHWC layout.

    Does nothing when `input_node` is already marked NHWC.

    Args:
      graph: The graph being edited.
      input_node: Producer node feeding `target_node`.
      target_node: The NHWC consumer.

    Raises:
      AssertionError: If `input_node` is not reported as 4D by
        `layout_check.is_4d`.
    """
    if layout_mark.is_nhwc_node(input_node):
      return

    if not layout_check.is_4d(input_node):
      raise AssertionError(
          f"Attempting to convert non-NHWC compatible node to NHWC: {input_node}"
      )

    # Assign target node's source meta to the to_NHWC node, because the
    # transpose is added for the existence of target node.
    self.insert_layout_transpose(
        graph,
        input_node,
        target_node,
        utils.tensor_to_nhwc,
        self.get_source_meta(target_node),
    )

  def input_to_nchw(
      self,
      graph: torch.fx.Graph,
      input_node: torch.fx.Node,
      target_node: torch.fx.Node,
  ) -> None:
    """Ensures `target_node` receives `input_node` in NCHW layout.

    Does nothing when `input_node` is already marked NCHW. Unlike
    `input_to_nhwc`, the inserted nodes take the source meta of `input_node`.

    Args:
      graph: The graph being edited.
      input_node: Producer node feeding `target_node`.
      target_node: The NCHW consumer.
    """
    if layout_mark.is_nchw_node(input_node):
      return

    self.insert_layout_transpose(
        graph,
        input_node,
        target_node,
        utils.tensor_to_nchw,
        self.get_source_meta(input_node),
    )

  def mark_const_nodes(self, exported_program: torch.export.ExportedProgram):
    """Marks every node that is derived only from non-user inputs/constants.

    A node is marked when any of the following holds:
      * its name belongs to a graph input whose kind is not USER_INPUT;
      * it has input nodes and all of them are already marked const;
      * it is not a placeholder and has no input nodes at all.

    Args:
      exported_program: The program whose graph nodes are marked in place.
    """
    graph_module = exported_program.graph_module
    graph = graph_module.graph

    input_specs = exported_program.graph_signature.input_specs
    non_user_input_names = set()
    for spec in input_specs:
      if spec.kind != torch.export.graph_signature.InputKind.USER_INPUT:
        non_user_input_names.add(spec.arg.name)

    # Nodes are visited in graph order, so a node's inputs have been classified
    # before the node itself is examined.
    for node in graph.nodes:
      has_input_nodes = len(node.all_input_nodes) > 0
      all_inputs_are_const = all(
          map(layout_mark.is_const_node, node.all_input_nodes)
      )
      if (
          node.name in non_user_input_names
          or (has_input_nodes and all_inputs_are_const)
          or (node.op != "placeholder" and not has_input_nodes)
      ):
        layout_mark.mark_as_const_node(node)

  def call(self, exported_program: torch.export.ExportedProgram):
    """Runs the layout optimization on `exported_program`.

    The partitioner is selected with the `AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER`
    environment variable: "MINCUT" or "GREEDY" force the respective
    partitioner; any other value (or unset) uses min cut when
    `min_cut.can_partition` allows it and greedy otherwise.

    Args:
      exported_program: The program to transform in place.

    Returns:
      An `ExportedProgramPassResult` wrapping `exported_program`, always
      flagged as modified.
    """
    self.mark_const_nodes(exported_program)

    graph_module = exported_program.graph_module
    partitioner = os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None)
    if partitioner == "MINCUT":
      graph_module = layout_partitioners.min_cut.partition(graph_module)
    elif partitioner == "GREEDY":
      graph_module = layout_partitioners.greedy.partition(graph_module)
    else:
      # By default use min cut partitioner if possible
      if layout_partitioners.min_cut.can_partition(graph_module):
        graph_module = layout_partitioners.min_cut.partition(graph_module)
      else:
        graph_module = layout_partitioners.greedy.partition(graph_module)

    graph = graph_module.graph
    # Iterate over a snapshot: the loop body inserts new nodes into the graph.
    for node in list(graph.nodes):
      if layout_mark.is_nhwc_node(node):
        for input_node in layout_check.get_layout_sensitive_inputs(node):
          self.input_to_nhwc(graph, input_node, node)
        layout_rewrite.rewrite_nhwc_node(node)
      else:
        for input_node in layout_check.get_layout_sensitive_inputs(node):
          # Note: for non-4D tensors input_to_nchw is always noop.
          self.input_to_nchw(graph, input_node, node)

    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()
    graph_module.graph.lint()
    # Mark const node again for debugging
    self.mark_const_nodes(exported_program)

    return ExportedProgramPassResult(exported_program, True)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from typing import Callable
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import torch.ao.quantization.quantize_pt2e
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def tensor_to_nhwc(t: torch.Tensor):
  """Permutes dims (0, 1, 2, 3) -> (0, 2, 3, 1) and returns a contiguous tensor."""
  source = t.contiguous()
  permuted = torch.ops.aten.permute(source, [0, 2, 3, 1])
  return permuted.contiguous()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def tensor_to_nchw(t: torch.Tensor):
  """Permutes dims (0, 1, 2, 3) -> (0, 3, 1, 2) and returns a contiguous tensor."""
  source = t.contiguous()
  permuted = torch.ops.aten.permute(source, [0, 3, 1, 2])
  return permuted.contiguous()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def flatten_torch_op_overloads(op):
  """Expands an op overload packet into the list of its overloads.

  Anything that is not an `OpOverloadPacket` is returned unchanged, wrapped in
  a single-element list.
  """
  if not isinstance(op, torch._ops.OpOverloadPacket):
    return [op]

  expanded = []
  for overload_name in op.overloads():
    expanded.append(getattr(op, overload_name))
  return expanded
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Quantize ops recognized by `is_q_node`.
# The order matters: `get_paired_q_dq_ops` zips this list with `_TORCH_DQ_OPS`,
# so entry i here must correspond to entry i there.
_TORCH_Q_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor2,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]

# Dequantize ops recognized by `is_dq_node`, index-aligned with `_TORCH_Q_OPS`.
_TORCH_DQ_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def is_q_node(node: torch.fx.Node):
  """Returns whether the node's target is one of the known quantize ops."""
  target = node.target
  return target in _TORCH_Q_OPS
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def is_dq_node(node: torch.fx.Node):
  """Returns whether the node's target is one of the known dequantize ops."""
  target = node.target
  return target in _TORCH_DQ_OPS
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_paired_q_dq_ops(op: Callable) -> tuple[Callable, Callable]:
  """Looks up the (quantize, dequantize) op pair that `op` belongs to.

  Args:
    op: Either the quantize or the dequantize member of a known pair.

  Returns:
    The `(q, dq)` tuple of the first pair containing `op`.

  Raises:
    AssertionError: If `op` is in none of the known pairs.
  """
  candidates = zip(_TORCH_Q_OPS, _TORCH_DQ_OPS)
  found = next((pair for pair in candidates if op in pair), None)
  if found is None:
    raise AssertionError(f"{op} is not a Q/DQ op.")
  return found
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import tempfile
|
|
19
|
+
import unittest
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
import torchvision
|
|
23
|
+
|
|
24
|
+
import ai_edge_torch
|
|
25
|
+
from ai_edge_torch.convert import conversion_utils as cutils
|
|
26
|
+
from ai_edge_torch.testing import model_coverage
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TestConvert(unittest.TestCase):
  """Tests conversion of various modules."""

  def setUp(self):
    # Fix the RNG so the randomly generated sample inputs are reproducible.
    torch.manual_seed(0)

  def test_convert_add(self):
    """Tests conversion of a simple Add module."""

    class Add(torch.nn.Module):

      def forward(self, a, b):
        return a + b

    args = (
        torch.randn((5, 10)),
        torch.randn((5, 10)),
    )
    torch_module = Add().eval()
    edge_model = ai_edge_torch.convert(torch_module, args)

    self.assertTrue(
        model_coverage.compare_tflite_torch(edge_model, torch_module, args)
    )

  def test_convert_dot_add(self):
    """Tests conversion of a matrix multiplication followed by an add."""

    class DotAdd(torch.nn.Module):

      def forward(self, a, b, c):
        return a @ b + c

    args = (
        torch.randn((5, 10)),
        torch.randn((10, 5)),
        torch.randn((5, 5)),
    )
    torch_module = DotAdd().eval()
    edge_model = ai_edge_torch.convert(torch_module, args)

    self.assertTrue(
        model_coverage.compare_tflite_torch(edge_model, torch_module, args)
    )

  def test_convert_resnet18(self):
    """Tests conversion of torchvision's resnet18."""
    args = (torch.randn(4, 3, 224, 224),)
    torch_module = torchvision.models.resnet18().eval()
    edge_model = ai_edge_torch.convert(torch_module, args)

    self.assertTrue(
        model_coverage.compare_tflite_torch(edge_model, torch_module, args)
    )

  def test_signature_args_ordering(self):
    """Tests conversion of a model with more than 10 arguments."""

    class AddChainWith11Args(torch.nn.Module):

      def forward(
          self,
          arg0: "f32[64]",
          arg1: "f32[64]",
          arg2: "f32[64]",
          arg3: "f32[64]",
          arg4: "f32[64]",
          arg5: "f32[64]",
          arg6: "f32[64]",
          arg7: "f32[64]",
          arg8: "f32[64]",
          arg9: "f32[64]",
          arg10: "f32[64]",
      ):
        add0 = torch.add(arg0, arg1)
        add1 = torch.add(add0, arg2)
        add2 = torch.add(add1, arg3)
        add3 = torch.add(add2, arg4)
        add4 = torch.add(add3, arg5)
        add5 = torch.add(add4, arg6)
        add6 = torch.add(add5, arg7)
        add7 = torch.add(add6, arg8)
        add8 = torch.add(add7, arg9)
        add9 = torch.add(add8, arg10)
        return add9

    # One fresh random tensor per forward() argument (11 in total).
    sample_input = lambda: tuple(
        torch.rand((64,), dtype=torch.float32) for _ in range(11)
    )
    torch_model = AddChainWith11Args().eval()
    edge_model = ai_edge_torch.convert(torch_model, sample_input())

    result = model_coverage.compare_tflite_torch(
        edge_model, torch_model, sample_input, num_valid_inputs=10
    )
    self.assertTrue(result)

  def test_multi_output_model(self):
    """Tests conversion of a model that returns multiple outputs."""

    class BasicAddModelWithMultipleOutputs(torch.nn.Module):

      def forward(self, arg0, arg1):
        add0 = arg0 + arg1
        mul0 = arg0 * arg1
        return add0, mul0

    sample_input = (
        torch.rand((64,), dtype=torch.float32),
        torch.rand((64,), dtype=torch.float32),
    )

    torch_model = BasicAddModelWithMultipleOutputs().eval()
    edge_model = ai_edge_torch.convert(torch_model, sample_input)

    result = model_coverage.compare_tflite_torch(
        edge_model, torch_model, sample_input
    )
    self.assertTrue(result)

  def test_12_outputs_model(self):
    """Tests conversion of a model that returns 12 outputs."""

    class BasicAddModelWithMultipleOutputs(torch.nn.Module):

      def forward(self, arg0, arg1):
        add0 = arg0 + arg1
        mul0 = arg0 * arg1
        add1 = add0 + mul0
        mul1 = add0 * mul0
        add2 = add1 + mul1
        mul2 = add1 * mul1
        add3 = add2 + mul2
        mul3 = add2 * mul2
        add4 = add3 + mul3
        mul4 = add3 * mul3
        add5 = add4 + mul4
        mul5 = add4 * mul4

        return (
            add0,
            mul0,
            add1,
            mul1,
            add2,
            mul2,
            add3,
            mul3,
            add4,
            mul4,
            add5,
            mul5,
        )

    sample_input = (
        torch.rand((64,), dtype=torch.float32),
        torch.rand((64,), dtype=torch.float32),
    )

    torch_model = BasicAddModelWithMultipleOutputs().eval()
    edge_model = ai_edge_torch.convert(torch_model, sample_input)

    result = model_coverage.compare_tflite_torch(
        edge_model, torch_model, sample_input
    )
    self.assertTrue(result)

  def test_apply_tfl_backdoor_flags(self):
    """Tests if _apply_tfl_backdoor_flags correctly sets the values in a Converter object."""

    class MockConverterInternalObject:

      def __init__(self):
        self.subkey2 = "original_subvalue2"

    class MockConverter:

      def __init__(self):
        self.key1 = "original_value1"
        self.key2 = MockConverterInternalObject()

    mock_converter = MockConverter()
    flags = {"key1": "new_value1", "key2": {"subkey2": "new_subvalue2"}}
    cutils._apply_tfl_backdoor_flags(mock_converter, flags)

    # Check the converter object itself. `assertTrue(value, msg)` on the input
    # dict would pass for any non-empty string and prove nothing.
    self.assertEqual(mock_converter.key1, "new_value1")
    self.assertEqual(mock_converter.key2.subkey2, "new_subvalue2")

  def test_convert_add_backdoor_flags(self):
    """Tests conversion of an add module setting a tflite converter flag."""

    class Add(torch.nn.Module):

      def forward(self, a, b):
        return a + b

    args = (
        torch.randn((5, 10)),
        torch.randn((5, 10)),
    )
    torch_module = Add().eval()

    with tempfile.TemporaryDirectory() as tmp_dir_path:
      ir_dump_path = os.path.join(
          tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
      )
      ai_edge_torch.convert(
          torch_module, args, _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path}
      )
      # The dump directory only exists if the flag reached the converter.
      self.assertTrue(os.path.isdir(ir_dump_path))

  def test_convert_model_with_dynamic_batch(self):
    """Tests converting a simple model with dynamic batch size."""

    class SampleModel(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.w = torch.ones((10, 10)) * 2.7

      def forward(self, x, y):
        return x + y + self.w

    sample_input = (torch.randn(4, 3, 10, 10), torch.randn(4, 3, 10, 10))
    batch = torch.export.Dim("batch")
    dynamic_shapes = ({0: batch}, {0: batch})

    model = SampleModel().eval()
    edge_model = ai_edge_torch.convert(
        model, sample_input, dynamic_shapes=dynamic_shapes
    )

    # Validate with batch sizes both equal to and different from the sample's.
    for batch_size in [2, 4, 10]:
      validate_input = (
          torch.randn(batch_size, 3, 10, 10),
          torch.randn(batch_size, 3, 10, 10),
      )
      self.assertTrue(
          model_coverage.compare_tflite_torch(edge_model, model, validate_input)
      )

  def test_convert_model_with_kwargs(self):
    """Tests converting a simple model with sample_kwargs."""

    class SampleModel(torch.nn.Module):

      def forward(self, x, y):
        return x + y

    kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))

    model = SampleModel().eval()
    edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())

    self.assertTrue(
        model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
    )

  def test_convert_model_with_args_kwargs(self):
    """Tests converting a simple model with both sample_args and sample_kwargs."""

    class SampleModel(torch.nn.Module):

      def forward(self, x, y):
        return x + y

    args_gen = lambda: (torch.randn(10, 10),)
    kwargs_gen = lambda: dict(y=torch.randn(10, 10))

    model = SampleModel().eval()
    edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())

    self.assertTrue(
        model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
    )
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|