ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
@@ -0,0 +1,293 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import os
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.ao.quantization.quantize_pt2e
20
+ from torch.export import ExportedProgram
21
+ from torch.fx import GraphModule
22
+ from torch.fx import Node
23
+ import torch.utils._pytree as pytree
24
+
25
+ from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
26
+ from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
27
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
28
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
29
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
30
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
31
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
32
+
33
from typing import Callable  # Local import: the file-level typing import lacks Callable.

# Signature shared by the layout-transpose helpers utils.tensor_to_nchw and
# utils.tensor_to_nhwc (tensor in, permuted tensor out). A Callable is used
# rather than a Union of the two function objects, because functions are not
# valid type arguments. Which of the two helpers is allowed is enforced at
# runtime by the pass itself.
TransposeFunc = Callable[[torch.Tensor], torch.Tensor]
34
+
35
+
36
class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
    """Moves layout-sensitive regions of an exported program to NHWC.

    `call` marks constant nodes, partitions the graph into NHWC / NCHW regions,
    inserts layout transposes on every layout-sensitive input that crosses a
    region boundary (with special handling for quantized Q/DQ nodes), and
    rewrites NHWC nodes with `layout_rewrite.rewrite_nhwc_node`.
    """

    def get_source_meta(self, node: torch.fx.Node) -> dict:
        """Returns a new dict holding the source-tracking entries of `node.meta`.

        Only "stack_trace", "nn_module_stack", "source_fn_stack" and
        "from_node" are copied, and only when present. Nodes synthesized by
        this pass use the result so they point back at the node they were
        created for.
        """
        keys = ("stack_trace", "nn_module_stack", "source_fn_stack", "from_node")
        return {key: node.meta[key] for key in keys if key in node.meta}

    def insert_t_q_dq(
        self,
        graph: torch.fx.Graph,
        input_dq: torch.fx.Node,
        target: torch.fx.Node,
        transpose_func: TransposeFunc,
        transpose_node_meta: dict,
    ) -> list[torch.fx.Node]:
        """Inserts (transpose, quantize, dequantize) after a DQ node.

        original:
          input_dq -> target
        insert the node as:
          input_dq -> (T q dq) -> target

        The new Q/DQ pair reuses the quantization args/kwargs of `input_dq`
        (everything after its first positional argument).

        Args:
          graph: graph being rewritten.
          input_dq: existing dequantize node feeding `target`.
          target: node whose `input_dq` input is redirected to the new DQ.
          transpose_func: utils.tensor_to_nchw or utils.tensor_to_nhwc.
          transpose_node_meta: meta copied onto `input_dq`, the transpose and
            the new Q node.

        Returns:
          The new nodes [t, q, dq] in input-to-output order.
        """
        assert utils.is_dq_node(input_dq)

        q_args = input_dq.args[1:]
        q_kwargs = input_dq.kwargs
        q_op, dq_op = utils.get_paired_q_dq_ops(input_dq.target)
        with graph.inserting_before(target):
            t = graph.call_function(transpose_func, (input_dq,))
            # Q and DQ inserted here may require updating the `axis` arg when
            # they are per_channel ops. However, instead of updating here, the
            # nodes are marked as NHWC/NCHW and rewriters are applied after
            # partitioning.
            q = graph.call_function(q_op, (t,) + q_args, q_kwargs)
            dq = graph.call_function(dq_op, (q,) + q_args, q_kwargs)

        # Give every node its own meta dict. The layout markers applied later
        # presumably write into node.meta (see layout_mark); with one shared
        # dict, marking one node would silently re-mark the others.
        # NOTE(review): this replaces input_dq's existing meta entirely; kept
        # for parity with the original implementation -- confirm intended.
        input_dq.meta = dict(transpose_node_meta)
        t.meta = dict(transpose_node_meta)
        q.meta = dict(transpose_node_meta)
        dq.meta = self.get_source_meta(target)

        target.replace_input_with(input_dq, dq)
        return [t, q, dq]

    def insert_dq_t_q(
        self,
        graph: torch.fx.Graph,
        input_q: torch.fx.Node,
        target: torch.fx.Node,
        transpose_func: TransposeFunc,
        transpose_node_meta: dict,
    ) -> list[torch.fx.Node]:
        """Inserts a quantized transpose (dequantize, transpose, quantize).

        original:
          input_q -> target
        insert the node as:
          input_q -> (dq T q) -> target

        Args:
          graph: graph being rewritten.
          input_q: existing quantize node feeding `target`.
          target: node whose `input_q` input is redirected to the new Q.
          transpose_func: utils.tensor_to_nchw or utils.tensor_to_nhwc.
          transpose_node_meta: meta copied onto each of the three new nodes.

        Returns:
          The new nodes [dq, t, q] in input-to-output order.
        """
        assert utils.is_q_node(input_q)

        q_args = input_q.args[1:]
        q_kwargs = input_q.kwargs
        # The paired-op lookup lives in utils; the class has no such method.
        q_op, dq_op = utils.get_paired_q_dq_ops(input_q.target)
        with graph.inserting_before(target):
            # Q and DQ inserted here may require updating the `axis` arg when
            # they are per_channel ops. However, instead of updating here, the
            # nodes are marked as NHWC/NCHW and rewriters are applied after
            # partitioning.
            dq = graph.call_function(dq_op, (input_q,) + q_args, q_kwargs)
            t = graph.call_function(transpose_func, (dq,))
            q = graph.call_function(q_op, (t,) + q_args, q_kwargs)

        # Separate dicts: dq (before the transpose) and t/q (after it) are
        # later marked with different layouts.
        dq.meta = dict(transpose_node_meta)
        t.meta = dict(transpose_node_meta)
        q.meta = dict(transpose_node_meta)

        target.replace_input_with(input_q, q)
        return [dq, t, q]

    def insert_layout_transpose(
        self,
        graph: torch.fx.Graph,
        input_node: torch.fx.Node,
        target_node: torch.fx.Node,
        transpose_func: TransposeFunc,
        transpose_node_meta: dict,
    ) -> None:
        """Inserts a layout transpose on the edge `input_node -> target_node`.

        Quantization constraints honored:
          Q1: a transpose after a DQ node is inserted as (T q dq), so the
              edge never splits a (dq-op-q) triplet.
          Q2: constant Q/DQ chains feeding `target_node` are duplicated for
              this user, and the transpose goes upstream of the copies.
          Q3: a transpose after a Q node is inserted as a quantized (dq T q).

        Every node created here is layout-marked. Nodes before the transpose
        get the source layout; the transpose and every node after it get the
        layout produced by `transpose_func`.

        Args:
          graph: graph being rewritten.
          input_node: producer side of the edge.
          target_node: consumer side of the edge.
          transpose_func: utils.tensor_to_nchw or utils.tensor_to_nhwc.
          transpose_node_meta: meta copied onto the inserted nodes.
        """
        assert transpose_func in (utils.tensor_to_nchw, utils.tensor_to_nhwc)

        # new_nodes only contains Q/DQ/Transpose nodes, which are all SISO.
        # Insertion order: input nodes -> output nodes.
        new_nodes = []

        # Constraint Q2: the NHWC partition's entry and exit must not be output
        # edges of Q/DQ ops that are connected to a constant/weight tensor.
        while layout_mark.is_const_node(input_node) and (
            utils.is_dq_node(input_node) or utils.is_q_node(input_node)
        ):
            with graph.inserting_before(target_node):
                new_input_node = graph.node_copy(input_node)

            target_node.replace_input_with(input_node, new_input_node)

            # Walk one step upstream along the copied chain.
            new_nodes = [new_input_node] + new_nodes
            input_node, target_node = new_input_node.args[0], new_input_node

        if utils.is_q_node(input_node):
            # Constraint Q3: when the entry and exit is right after a q op
            # (occurs after a (dq-op-q) triplet), the transpose must be added
            # as a quantized transpose in (dq-T-q).
            # input_q -> (dq T q) -> target
            new_nodes = (
                self.insert_dq_t_q(
                    graph,
                    input_node,
                    target_node,
                    transpose_func,
                    transpose_node_meta,
                )
                + new_nodes
            )
        elif utils.is_dq_node(input_node):
            # Constraint Q1: the NHWC partition's entry and exit cannot be
            # edges within a (dq-op-q) triplet.
            # input_dq -> (T q dq) -> target
            new_nodes = (
                self.insert_t_q_dq(
                    graph,
                    input_node,
                    target_node,
                    transpose_func,
                    transpose_node_meta,
                )
                + new_nodes
            )
        else:
            # input -> target
            with graph.inserting_before(target_node):
                t = graph.call_function(transpose_func, (input_node,))
            t.meta = dict(transpose_node_meta)
            target_node.replace_input_with(input_node, t)
            new_nodes = [t] + new_nodes

        # Mark new nodes as NCHW or NHWC.
        # Nodes before the transpose get input_marker; the transpose itself
        # and every node after it get target_marker.
        if transpose_func == utils.tensor_to_nchw:
            input_marker, target_marker = (
                layout_mark.mark_as_nhwc_node,
                layout_mark.mark_as_nchw_node,
            )
        else:
            input_marker, target_marker = (
                layout_mark.mark_as_nchw_node,
                layout_mark.mark_as_nhwc_node,
            )

        marker = input_marker
        for node in new_nodes:
            if node.target == transpose_func:
                marker = target_marker
            marker(node)
        # Sanity check: exactly one transpose must have been visited.
        assert marker == target_marker

    def input_to_nhwc(
        self,
        graph: torch.fx.Graph,
        input_node: torch.fx.Node,
        target_node: torch.fx.Node,
    ) -> None:
        """Makes `target_node` consume `input_node` in NHWC layout.

        No-op when `input_node` is already marked NHWC.

        Raises:
          AssertionError: if `input_node` is not 4-D (per layout_check.is_4d).
        """
        if layout_mark.is_nhwc_node(input_node):
            return

        if not layout_check.is_4d(input_node):
            raise AssertionError(
                f"Attempting to convert non-NHWC compatible node to NHWC: {input_node}"
            )

        # Assign target node's source meta to the to_NHWC node, because the
        # transpose is added for the existence of target node.
        self.insert_layout_transpose(
            graph,
            input_node,
            target_node,
            utils.tensor_to_nhwc,
            self.get_source_meta(target_node),
        )

    def input_to_nchw(
        self,
        graph: torch.fx.Graph,
        input_node: torch.fx.Node,
        target_node: torch.fx.Node,
    ) -> None:
        """Makes `target_node` consume `input_node` in NCHW layout.

        No-op when `input_node` is already marked NCHW. The inserted transpose
        carries the source meta of `input_node`.
        """
        if layout_mark.is_nchw_node(input_node):
            return

        self.insert_layout_transpose(
            graph,
            input_node,
            target_node,
            utils.tensor_to_nchw,
            self.get_source_meta(input_node),
        )

    def mark_const_nodes(self, exported_program: torch.export.ExportedProgram):
        """Marks nodes whose value does not depend on user inputs as constant.

        A node is marked when any of the following holds:
          * it is an input whose spec kind is not USER_INPUT (e.g. parameters
            and buffers);
          * it has inputs and all of them are already marked constant;
          * it is not a placeholder and has no inputs at all.
        Nodes are visited in graph order, so constness propagates forward.
        """
        graph = exported_program.graph_module.graph

        non_user_input_names = {
            spec.arg.name
            for spec in exported_program.graph_signature.input_specs
            if spec.kind != torch.export.graph_signature.InputKind.USER_INPUT
        }

        for node in graph.nodes:
            input_nodes = node.all_input_nodes
            has_input_nodes = len(input_nodes) > 0
            all_inputs_are_const = all(map(layout_mark.is_const_node, input_nodes))
            if (
                node.name in non_user_input_names
                or (has_input_nodes and all_inputs_are_const)
                or (node.op != "placeholder" and not has_input_nodes)
            ):
                layout_mark.mark_as_const_node(node)

    def call(self, exported_program: torch.export.ExportedProgram):
        """Runs the layout optimization on `exported_program`.

        The partitioner comes from the AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER
        environment variable. "MINCUT" or "GREEDY" forces that partitioner.
        Any other value, or leaving it unset, uses min-cut when
        `min_cut.can_partition` allows it and greedy otherwise.

        Returns:
          ExportedProgramPassResult(exported_program, True).
        """
        self.mark_const_nodes(exported_program)

        graph_module = exported_program.graph_module
        partitioner = os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None)
        if partitioner == "MINCUT":
            graph_module = layout_partitioners.min_cut.partition(graph_module)
        elif partitioner == "GREEDY":
            graph_module = layout_partitioners.greedy.partition(graph_module)
        elif layout_partitioners.min_cut.can_partition(graph_module):
            # By default use the min-cut partitioner when it is applicable.
            graph_module = layout_partitioners.min_cut.partition(graph_module)
        else:
            graph_module = layout_partitioners.greedy.partition(graph_module)

        graph = graph_module.graph
        # Iterate over a snapshot: the loop body inserts new nodes.
        for node in list(graph.nodes):
            if layout_mark.is_nhwc_node(node):
                for input_node in layout_check.get_layout_sensitive_inputs(node):
                    self.input_to_nhwc(graph, input_node, node)
                layout_rewrite.rewrite_nhwc_node(node)
            else:
                for input_node in layout_check.get_layout_sensitive_inputs(node):
                    # Note: for non-4D tensors input_to_nchw is always noop.
                    self.input_to_nchw(graph, input_node, node)

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module.graph.lint()
        # Mark const nodes again for debugging.
        self.mark_const_nodes(exported_program)

        return ExportedProgramPassResult(exported_program, True)
@@ -0,0 +1,62 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Callable
16
+
17
+ import torch
18
+ import torch.ao.quantization.quantize_pt2e
19
+
20
+
21
def tensor_to_nhwc(t: torch.Tensor):
    """Permutes a 4-D tensor from NCHW to NHWC order.

    Both the input and the permuted result are made contiguous, so the
    returned tensor is a dense NHWC copy.
    """
    nhwc_order = [0, 2, 3, 1]
    permuted = torch.ops.aten.permute(t.contiguous(), nhwc_order)
    return permuted.contiguous()
23
+
24
+
25
def tensor_to_nchw(t: torch.Tensor):
    """Permutes a 4-D tensor from NHWC back to NCHW order.

    Both the input and the permuted result are made contiguous, so the
    returned tensor is a dense NCHW copy.
    """
    nchw_order = [0, 3, 1, 2]
    permuted = torch.ops.aten.permute(t.contiguous(), nchw_order)
    return permuted.contiguous()
27
+
28
+
29
def flatten_torch_op_overloads(op):
    """Expands an op overload packet into its individual overloads.

    Args:
      op: a `torch._ops.OpOverloadPacket` (e.g. `torch.ops.aten.add`) or any
        other op/callable.

    Returns:
      Every overload of `op` when it is a packet; otherwise `[op]`.
    """
    if not isinstance(op, torch._ops.OpOverloadPacket):
        return [op]
    overload_names = op.overloads()
    return [getattr(op, name) for name in overload_names]
33
+
34
+
35
+ _TORCH_Q_OPS = [
36
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
37
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
38
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor2,
39
+ torch.ops.quantized_decomposed.quantize_per_channel.default,
40
+ ]
41
+
42
+ _TORCH_DQ_OPS = [
43
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
44
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
45
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2,
46
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
47
+ ]
48
+
49
+
50
def is_q_node(node: torch.fx.Node):
    """Returns True iff `node` targets one of the supported quantize ops."""
    op = node.target
    return op in _TORCH_Q_OPS
52
+
53
+
54
def is_dq_node(node: torch.fx.Node):
    """Returns True iff `node` targets one of the supported dequantize ops."""
    op = node.target
    return op in _TORCH_DQ_OPS
56
+
57
+
58
def get_paired_q_dq_ops(op: Callable) -> tuple[Callable, Callable]:
    """Returns the (quantize, dequantize) op pair that `op` belongs to.

    `op` may be either member of the pair.

    Raises:
      AssertionError: if `op` is not one of the supported Q/DQ ops.
    """
    candidates = zip(_TORCH_Q_OPS, _TORCH_DQ_OPS)
    pair = next((qdq for qdq in candidates if op in qdq), None)
    if pair is None:
        raise AssertionError(f"{op} is not a Q/DQ op.")
    return pair
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,311 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import os
18
+ import tempfile
19
+ import unittest
20
+
21
+ import torch
22
+ import torchvision
23
+
24
+ import ai_edge_torch
25
+ from ai_edge_torch.convert import conversion_utils as cutils
26
+ from ai_edge_torch.testing import model_coverage
27
+
28
+
29
class TestConvert(unittest.TestCase):
    """Tests conversion of various modules."""

    def setUp(self):
        # Fixed seed so the random sample inputs are reproducible.
        torch.manual_seed(0)

    def test_convert_add(self):
        """Tests conversion of a simple Add module."""

        class Add(torch.nn.Module):

            def forward(self, a, b):
                return a + b

        args = (
            torch.randn((5, 10)),
            torch.randn((5, 10)),
        )
        torch_module = Add().eval()
        edge_model = ai_edge_torch.convert(torch_module, args)

        self.assertTrue(
            model_coverage.compare_tflite_torch(edge_model, torch_module, args)
        )

    def test_convert_dot_add(self):
        """Tests conversion of a matrix multiplication followed by an add."""

        class DotAdd(torch.nn.Module):

            def forward(self, a, b, c):
                return a @ b + c

        args = (
            torch.randn((5, 10)),
            torch.randn((10, 5)),
            torch.randn((5, 5)),
        )
        torch_module = DotAdd().eval()
        edge_model = ai_edge_torch.convert(torch_module, args)

        self.assertTrue(
            model_coverage.compare_tflite_torch(edge_model, torch_module, args)
        )

    def test_convert_resnet18(self):
        """Tests conversion of an (untrained) torchvision ResNet-18."""
        args = (torch.randn(4, 3, 224, 224),)
        torch_module = torchvision.models.resnet18().eval()
        edge_model = ai_edge_torch.convert(torch_module, args)

        self.assertTrue(
            model_coverage.compare_tflite_torch(edge_model, torch_module, args)
        )

    def test_signature_args_ordering(self):
        """Tests conversion of a model with more than 10 arguments."""

        class AddChainWith11Args(torch.nn.Module):

            def forward(
                self,
                arg0: "f32[64]",
                arg1: "f32[64]",
                arg2: "f32[64]",
                arg3: "f32[64]",
                arg4: "f32[64]",
                arg5: "f32[64]",
                arg6: "f32[64]",
                arg7: "f32[64]",
                arg8: "f32[64]",
                arg9: "f32[64]",
                arg10: "f32[64]",
            ):
                add0 = torch.add(arg0, arg1)
                add1 = torch.add(add0, arg2)
                add2 = torch.add(add1, arg3)
                add3 = torch.add(add2, arg4)
                add4 = torch.add(add3, arg5)
                add5 = torch.add(add4, arg6)
                add6 = torch.add(add5, arg7)
                add7 = torch.add(add6, arg8)
                add8 = torch.add(add7, arg9)
                add9 = torch.add(add8, arg10)
                return add9

        sample_input = lambda: (
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
        )
        torch_model = AddChainWith11Args().eval()
        edge_model = ai_edge_torch.convert(torch_model, sample_input())

        result = model_coverage.compare_tflite_torch(
            edge_model, torch_model, sample_input, num_valid_inputs=10
        )
        self.assertTrue(result)

    def test_multi_output_model(self):
        """Tests conversion of a model that returns multiple outputs."""

        class BasicAddModelWithMultipleOutputs(torch.nn.Module):

            def forward(self, arg0, arg1):
                add0 = arg0 + arg1
                mul0 = arg0 * arg1
                return add0, mul0

        sample_input = (
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
        )

        torch_model = BasicAddModelWithMultipleOutputs().eval()
        edge_model = ai_edge_torch.convert(torch_model, sample_input)

        result = model_coverage.compare_tflite_torch(
            edge_model, torch_model, sample_input
        )
        self.assertTrue(result)

    def test_12_outputs_model(self):
        """Tests conversion of a model that returns 12 outputs."""

        class BasicAddModelWithMultipleOutputs(torch.nn.Module):

            def forward(self, arg0, arg1):
                add0 = arg0 + arg1
                mul0 = arg0 * arg1
                add1 = add0 + mul0
                mul1 = add0 * mul0
                add2 = add1 + mul1
                mul2 = add1 * mul1
                add3 = add2 + mul2
                mul3 = add2 * mul2
                add4 = add3 + mul3
                mul4 = add3 * mul3
                add5 = add4 + mul4
                mul5 = add4 * mul4

                return (
                    add0,
                    mul0,
                    add1,
                    mul1,
                    add2,
                    mul2,
                    add3,
                    mul3,
                    add4,
                    mul4,
                    add5,
                    mul5,
                )

        sample_input = (
            torch.rand((64,), dtype=torch.float32),
            torch.rand((64,), dtype=torch.float32),
        )

        torch_model = BasicAddModelWithMultipleOutputs().eval()
        edge_model = ai_edge_torch.convert(torch_model, sample_input)

        result = model_coverage.compare_tflite_torch(
            edge_model, torch_model, sample_input
        )
        self.assertTrue(result)

    def test_apply_tfl_backdoor_flags(self):
        """Tests if _apply_tfl_backdoor_flags correctly sets the values in a Converter object."""

        class MockConverterInternalObject:

            def __init__(self):
                self.subkey2 = "original_subvalue2"

        class MockConverter:

            def __init__(self):
                self.key1 = "original_value1"
                self.key2 = MockConverterInternalObject()

        mock_converter = MockConverter()
        flags = {"key1": "new_value1", "key2": {"subkey2": "new_subvalue2"}}
        cutils._apply_tfl_backdoor_flags(mock_converter, flags)

        # Verify the converter object itself. Asserting on `flags` would only
        # re-read the input dict, and assertTrue(x, msg) treats the second
        # argument as a failure message, so that check could never fail.
        self.assertEqual(mock_converter.key1, "new_value1")
        self.assertEqual(mock_converter.key2.subkey2, "new_subvalue2")

    def test_convert_add_backdoor_flags(self):
        """Tests conversion of an add module setting a tflite converter flag."""

        class Add(torch.nn.Module):

            def forward(self, a, b):
                return a + b

        args = (
            torch.randn((5, 10)),
            torch.randn((5, 10)),
        )
        torch_module = Add().eval()

        with tempfile.TemporaryDirectory() as tmp_dir_path:
            ir_dump_path = os.path.join(
                tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
            )
            ai_edge_torch.convert(
                torch_module, args, _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path}
            )
            self.assertTrue(os.path.isdir(ir_dump_path))

    def test_convert_model_with_dynamic_batch(self):
        """Test converting a simple model with dynamic batch size."""

        class SampleModel(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.w = torch.ones((10, 10)) * 2.7

            def forward(self, x, y):
                return x + y + self.w

        sample_input = (torch.randn(4, 3, 10, 10), torch.randn(4, 3, 10, 10))
        batch = torch.export.Dim("batch")
        dynamic_shapes = ({0: batch}, {0: batch})

        model = SampleModel().eval()
        edge_model = ai_edge_torch.convert(
            model, sample_input, dynamic_shapes=dynamic_shapes
        )

        for batch_size in [2, 4, 10]:
            validate_input = (
                torch.randn(batch_size, 3, 10, 10),
                torch.randn(batch_size, 3, 10, 10),
            )
            self.assertTrue(
                model_coverage.compare_tflite_torch(edge_model, model, validate_input)
            )

    def test_convert_model_with_kwargs(self):
        """Test converting a simple model with sample_kwargs."""

        class SampleModel(torch.nn.Module):

            def forward(self, x, y):
                return x + y

        kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))

        model = SampleModel().eval()
        edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())

        self.assertTrue(
            model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
        )

    def test_convert_model_with_args_kwargs(self):
        """Test converting a simple model with both sample_args and sample_kwargs."""

        class SampleModel(torch.nn.Module):

            def forward(self, x, y):
                return x + y

        args_gen = lambda: (torch.randn(10, 10),)
        kwargs_gen = lambda: dict(y=torch.randn(10, 10))

        model = SampleModel().eval()
        edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())

        self.assertTrue(
            model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
        )
308
+
309
+
310
if __name__ == "__main__":
    # Running this file directly executes every test case defined above.
    unittest.main()