ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (213) hide show
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
@@ -0,0 +1,258 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Layout check for the optimized layout transposes pass."""
16
+
17
+ import dataclasses
18
+ import operator
19
+
20
+ import ai_edge_torch
21
+ from ai_edge_torch import lowertools
22
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
23
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
24
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry
25
+ import torch
26
+ from torch.fx import Node
27
+
28
# Shorthand for the ATen operator namespace used by the op registrations in
# this module.
aten = torch.ops.aten

# Names exported by `from ... import *`.
__all__ = [
    "is_4d",
    "can_be_nhwc",
    "must_be_nhwc",
    "get_layout_sensitive_inputs",
    "get_no_rewriter_nhwc_ops",
]
37
+
38
+
39
class LayoutSensitiveInputsGettersRegistry(OpFuncRegistry):
  """Registry of functions returning a node's layout sensitive inputs.

  Ops without a registered getter fall back to a getter that returns every
  input node of the queried node.
  """

  def __missing__(self, op):
    """Returns the fallback getter used when `op` has no registered getter."""

    def _all_inputs_getter(node: Node):
      """Treats every input node of `node` as layout sensitive."""
      return node.all_input_nodes

    return _all_inputs_getter
48
+
49
+
50
@dataclasses.dataclass
class NHWCable:
  """Result of a layout check: whether a node can be / must be NHWC.

  Instances deliberately refuse implicit truth testing so that callers have to
  pick `.can_be` or `.must_be` explicitly.
  """

  # Whether the node is allowed to be NHWC.
  can_be: bool
  # Whether the node is required to be NHWC.
  must_be: bool

  def __bool__(self):
    """Always raises, to reject ambiguous `if nhwcable:` style checks.

    Raises:
      RuntimeError: On every call.
    """
    message = (
        "Boolean value on NHWCable is disabled. Please call .can_be or .must_be"
    )
    raise RuntimeError(message)
59
+
60
+
61
class NHWCableNodeCheckersRegistry(OpFuncRegistry):
  """Registry of per-op checkers reporting whether a node can/must be NHWC.

  Ops without a registered checker fall back to the default checker built by
  `__missing__`. Targets that the default checker finds eligible for NHWC but
  that have no NHWC rewriter are collected in `no_rewriter_nhwc_ops`.
  """

  def __init__(self):
    # Targets of nodes that could have been NHWC but lack a rewriter.
    self.no_rewriter_nhwc_ops = set()

  def __missing__(self, op):
    """Returns the default checker used when `op` has no registered checker."""

    def _default_checker(node: Node):
      """Default checker for most of the layout insensitive ops.

      The node is reported as NHWC-able (never as must-be NHWC) if:
        1. `is_4d(node)` holds for the node output.
        2. All layout sensitive input nodes (default all inputs) of this
          node are 4-D.
        3. There exists a rewrite rule for this node (explicit registry
          required for noop.)

      Whether the inputs are actually marked as NHWC is not checked here.
      When 1 and 2 hold but 3 does not, the node target is recorded in
      `no_rewriter_nhwc_ops` for debugging.
      """
      # `self` is only read through the closure, so no `nonlocal` is needed.
      eligible = is_4d(node) and all_layout_sensitive_inputs_are_4d(node)
      has_rewriter = layout_rewrite.has_nhwc_rewriter(node)

      if eligible and not has_rewriter:
        self.no_rewriter_nhwc_ops.add(node.target)

      return NHWCable(eligible and has_rewriter, must_be=False)

    return _default_checker
91
+
92
+
93
# Module-level registries queried by the helper functions in this module.
# Maps an op to the getter of its layout sensitive inputs.
layout_sensitive_inputs_getters = LayoutSensitiveInputsGettersRegistry()
# Maps an op to the checker producing its NHWCable result.
nhwcable_node_checkers = NHWCableNodeCheckersRegistry()
95
+
96
+
97
def can_be_nhwc(node: Node):
  """Returns the `can_be` flag from the checker looked up by `node.target`."""
  checker = nhwcable_node_checkers[node.target]
  return checker(node).can_be
99
+
100
+
101
def must_be_nhwc(node: Node):
  """Returns the `must_be` flag from the checker looked up by `node.target`."""
  checker = nhwcable_node_checkers[node.target]
  return checker(node).must_be
103
+
104
+
105
def get_layout_sensitive_inputs(node: Node):
  """Returns the layout sensitive inputs of `node`.

  The getter is looked up in `layout_sensitive_inputs_getters` by
  `node.target` and applied to the node.
  """
  getter = layout_sensitive_inputs_getters[node.target]
  return getter(node)
107
+
108
+
109
def get_no_rewriter_nhwc_ops():
  """Debug only: returns ops that could be NHWC but have no rewriter.

  The returned set is the one owned by the checkers registry (not a copy).
  """
  registry = nhwcable_node_checkers
  return registry.no_rewriter_nhwc_ops
112
+
113
+
114
def is_4d(node: Node):
  """Returns whether the value recorded in `node.meta["val"]` is 4-D.

  For a non-empty list/tuple value only the first element is inspected.
  Returns False when no value is recorded or the inspected value has no
  `shape` attribute.
  """
  meta_val = node.meta.get("val")
  if meta_val is None:
    return False

  # Multi-output values are judged by their first element.
  is_sequence = isinstance(meta_val, (list, tuple))
  candidate = meta_val[0] if is_sequence and meta_val else meta_val

  return hasattr(candidate, "shape") and len(candidate.shape) == 4
126
+
127
+
128
def all_layout_sensitive_inputs_are_4d(node: Node):
  """Returns True iff every layout sensitive input of `node` is 4-D.

  Vacuously True when the node has no layout sensitive inputs.
  """
  for input_node in get_layout_sensitive_inputs(node):
    if not is_4d(input_node):
      return False
  return True
130
+
131
+
132
+ # ==== Quantize ops (use default NHWC checker)
133
+
134
+
135
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.dequantize_per_tensor
)
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.quantize_per_tensor
)
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.dequantize_per_channel
)
@layout_sensitive_inputs_getters.register(
    torch.ops.quantized_decomposed.quantize_per_channel
)
def _qdq_layout_sensitive_inputs_getter(node: Node):
  """Quantize/dequantize ops: only the first argument is layout sensitive."""
  first_arg = node.args[0]
  return [first_arg]
149
+
150
+
151
+ # ==== Ops must be NHWC if possible
152
+
153
+
154
@layout_sensitive_inputs_getters.register(aten.conv2d)
@layout_sensitive_inputs_getters.register(aten.convolution)
@layout_sensitive_inputs_getters.register(
    aten._native_batch_norm_legit_no_training
)
@layout_sensitive_inputs_getters.register(aten.group_norm)
@layout_sensitive_inputs_getters.register(aten.native_group_norm)
def _first_arg_getter(node):
  """Conv/norm ops: only the first argument is layout sensitive."""
  first_arg = node.args[0]
  return [first_arg]
163
+
164
+
165
+ # Note: default layout sensitive inputs are all inputs when not specified.
166
@nhwcable_node_checkers.register(aten.max_pool2d)
@nhwcable_node_checkers.register(aten.max_pool2d_with_indices)
@nhwcable_node_checkers.register(aten.amax)
@nhwcable_node_checkers.register(aten.avg_pool2d)
@nhwcable_node_checkers.register(aten._prelu_kernel)
@nhwcable_node_checkers.register(aten.upsample_bilinear2d)
@nhwcable_node_checkers.register(aten.upsample_nearest2d)
@nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
@nhwcable_node_checkers.register(aten.conv2d)
@nhwcable_node_checkers.register(aten.convolution)
def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
  """Checker for ops that must be NHWC whenever they can be.

  Both flags are set to whether all layout sensitive inputs of `node` are 4-D.
  """
  inputs_are_4d = all_layout_sensitive_inputs_are_4d(node)
  return NHWCable(can_be=inputs_are_4d, must_be=inputs_are_4d)
179
+
180
+
181
@nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
def _aten_norm_checker(node):
  """Batch norm can be (but never must be) NHWC when its first output is 4-D.

  `node.meta["val"]` has to be a non-empty list/tuple whose first element has
  a `shape`; otherwise the node is reported as not NHWC-able.
  """
  outputs = node.meta.get("val")
  if (
      isinstance(outputs, (list, tuple))
      and outputs
      and hasattr(outputs[0], "shape")
  ):
    return NHWCable(can_be=len(outputs[0].shape) == 4, must_be=False)
  return NHWCable(can_be=False, must_be=False)
191
+
192
+
193
@nhwcable_node_checkers.register(aten.group_norm)
def _aten_group_norm_checker(node):
  """Group norm can be NHWC when its recorded output value is 4-D.

  It additionally must be NHWC when
  `ai_edge_torch.config.enable_group_norm_composite` is set.
  """
  output = node.meta.get("val")
  if hasattr(output, "shape"):
    can_be = len(output.shape) == 4
    must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
    return NHWCable(can_be=can_be, must_be=must_be)

  # No shape information recorded for this node.
  return NHWCable(can_be=False, must_be=False)
202
+
203
+
204
@nhwcable_node_checkers.register(aten.native_group_norm)
def _aten_native_group_norm_checker(node):
  """Native group norm can be NHWC only without weight and bias.

  The node is NHWC-able (never must-be) when the first element of
  `node.meta["val"]` is 4-D and, if at least three positional args are
  present, both `args[1]` and `args[2]` are None.
  """
  not_nhwc = NHWCable(can_be=False, must_be=False)

  outputs = node.meta.get("val")
  has_shaped_first_output = (
      isinstance(outputs, (list, tuple))
      and outputs
      and hasattr(outputs[0], "shape")
  )
  if not has_shaped_first_output:
    return not_nhwc

  if len(node.args) >= 3:
    second_arg, third_arg = node.args[1], node.args[2]
    if not (second_arg is None and third_arg is None):
      # Disable NHWC rewriter due to precision issue with weight and bias.
      # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
      return not_nhwc

  return NHWCable(can_be=len(outputs[0].shape) == 4, must_be=False)
220
+
221
+
222
+ # ==== Ops must be NCHW
223
+
224
+
225
@nhwcable_node_checkers.register(lowertools.mark_tensor_op)
@nhwcable_node_checkers.register(utils.tensor_to_nchw)
@nhwcable_node_checkers.register(utils.tensor_to_nhwc)
@nhwcable_node_checkers.register("output")
@nhwcable_node_checkers.register(aten.view)
@nhwcable_node_checkers.register(aten.unsqueeze_copy)
@nhwcable_node_checkers.register(aten.expand)
@nhwcable_node_checkers.register(aten.permute)
@nhwcable_node_checkers.register(aten.as_strided)
def _not_nhwc(node: Node):
  """Checker for ops that stay NCHW: neither can be nor must be NHWC."""
  return NHWCable(False, False)
236
+
237
+
238
+ # ==== Others
239
+
240
+
241
@layout_sensitive_inputs_getters.register(aten.index)
@layout_sensitive_inputs_getters.register(aten._unsafe_index)
def _aten_index_layout_sensitive_inputs_getter(node):
  """Index ops: only the first argument is layout sensitive."""
  first_arg = node.args[0]
  return [first_arg]
245
+
246
+
247
@nhwcable_node_checkers.register(aten.index)
@nhwcable_node_checkers.register(aten._unsafe_index)
def _aten_index_checker(node):
  """Index ops can be (but never must be) NHWC.

  They are NHWC-able when the node output is 4-D and all of the node's layout
  sensitive inputs are 4-D.
  """
  # The layout sensitive inputs are resolved inside
  # all_layout_sensitive_inputs_are_4d(), so no separate lookup is needed.
  can_be = is_4d(node) and all_layout_sensitive_inputs_are_4d(node)
  return NHWCable(can_be, must_be=False)
253
+
254
+
255
@nhwcable_node_checkers.register(operator.getitem)
def _getitem_checker(node):
  """Delegates to the checker of the node that `getitem` reads from.

  The source node is `node.args[0]`; its checker result is returned as-is.
  """
  source_node = node.args[0]
  source_checker = nhwcable_node_checkers[source_node.target]
  return source_checker(source_node)
@@ -0,0 +1,50 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Layout mark for the optimized layout transposes pass."""
16
+
17
+ import torch
18
+
19
# Key stored in a node's meta to indicate that the node is part of the NHWC
# partition.
IS_NHWC_NODE = "OPTIMIZE_LAYOUT_TRANSPOSES_PASS__IS_NHWC_NODE"


# Key stored in a node's meta to indicate that the node is derived completely
# from constant and/or weight tensor(s).
IS_CONST_NODE = "OPTIMIZE_LAYOUT_TRANSPOSES_PASS__IS_CONST_NODE"
27
+
28
+
29
def mark_as_nhwc_node(node: torch.fx.Node) -> None:
  """Sets the NHWC tag in `node.meta` to True."""
  meta = node.meta
  meta[IS_NHWC_NODE] = True
31
+
32
+
33
def mark_as_nchw_node(node: torch.fx.Node) -> None:
  """Sets the NHWC tag in `node.meta` to False, i.e. marks the node NCHW."""
  meta = node.meta
  meta[IS_NHWC_NODE] = False
35
+
36
+
37
def is_nhwc_node(node: torch.fx.Node) -> bool:
  """Returns the NHWC tag stored in `node.meta`, or False when it is absent."""
  meta = node.meta
  return meta.get(IS_NHWC_NODE, False)
39
+
40
+
41
def is_nchw_node(node: torch.fx.Node) -> bool:
  """Returns True when `node` is not tagged as NHWC (including untagged)."""
  tagged_nhwc = is_nhwc_node(node)
  return not tagged_nhwc
43
+
44
+
45
def mark_as_const_node(node: torch.fx.Node) -> None:
  """Sets the const tag in `node.meta` to True."""
  meta = node.meta
  meta[IS_CONST_NODE] = True
47
+
48
+
49
def is_const_node(node: torch.fx.Node) -> bool:
  """Returns the const tag stored in `node.meta`, or False when it is absent."""
  meta = node.meta
  return meta.get(IS_CONST_NODE, False)
@@ -0,0 +1,18 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Layout partitioners."""
16
+
17
+ from . import greedy
18
+ from . import min_cut
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Greedy partitioning algorithm."""
16
+
17
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check
18
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
19
+ import torch
20
+
21
+
22
def partition(graph_module: torch.fx.GraphModule):
  """Marks the nodes of `graph_module` that belong to NHWC partitions.

  Visits every node once, in the order yielded by `graph.nodes`, and greedily
  decides whether to tag it as NHWC (O(|V|) greedy partitioning). Nodes that
  are not tagged are left untouched.

  Args:
    graph_module: The graph module to be partitioned.

  Returns:
    The same graph module, after marking and recompilation.
  """
  for node in list(graph_module.graph.nodes):
    # Nodes without inputs need no change.
    if not node.all_input_nodes:
      continue

    # A node whose checker says it must be NHWC is marked unconditionally.
    if layout_check.must_be_nhwc(node):
      layout_mark.mark_as_nhwc_node(node)
      continue

    if not layout_check.can_be_nhwc(node):
      continue

    # A node that merely can be NHWC is marked only if
    # - at least one layout sensitive input is already marked as NHWC, and
    # - every layout sensitive input is either marked as NHWC or 4-D.
    sensitive_inputs = layout_check.get_layout_sensitive_inputs(node)
    has_nhwc_input = any(
        layout_mark.is_nhwc_node(input_node) for input_node in sensitive_inputs
    )
    inputs_compatible = all(
        layout_mark.is_nhwc_node(input_node) or layout_check.is_4d(input_node)
        for input_node in sensitive_inputs
    )
    if has_nhwc_input and inputs_compatible:
      layout_mark.mark_as_nhwc_node(node)

  graph_module.recompile()
  return graph_module
@@ -0,0 +1,216 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Min cut solver for partitioning the graph module into NHWC and non-NHWC subgraphs."""
16
+
17
+ import collections
18
+ import dataclasses
19
+
20
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
21
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
22
+ import numpy as np
23
+ import scipy
24
+ import torch
25
+
26
+
27
def can_partition(graph_module: torch.fx.GraphModule):
  """Tell whether the min cut solver can handle the graph in reasonable time.

  The min cut solver implements O(|V|^2|E|) Dinic's algorithm, which may take
  a long time to complete for a large graph module. The decision is made
  purely from the graph size: the number of nodes and the total number of
  node-to-user edges.

  Args:
    graph_module: The graph module whose size is inspected.

  Returns:
    True if |V|^2 * |E| is strictly below 2000^3, False otherwise.
  """
  fx_nodes = graph_module.graph.nodes
  vertex_count = len(fx_nodes)

  edge_count = 0
  for fx_node in fx_nodes:
    edge_count += len(fx_node.users)

  # According to experiments on our model set, graphs with |V| < 2000 can
  # generally be partitioned in a reasonable time.
  size_budget = 2000**3
  return vertex_count**2 * edge_count < size_budget
43
+
44
+
45
class MinCutSolver:
  """Builds a directed, integer-weighted graph and finds a source/sink min cut.

  Nodes are identified by consecutive integer ids. Ids 0 and 1 are reserved
  for the source and the sink; further ids are handed out by `get_nid`, which
  can also associate an id with an arbitrary hashable object.
  """

  # A number that is large enough but can fit into int32 with all computations
  # in the maximum flow.
  INF_COST = 1 << 28

  def __init__(self):
    # Adjacency map: edges_map[a][b] is the cost of the directed edge a -> b.
    self._edges_map = collections.defaultdict(dict)
    # Two-way mapping between user objects and their node ids.
    self._obj_to_node = {}
    self._node_to_obj = {}
    self._nodes_cnt = 0

    self.source = self._next_nid()
    self.sink = self._next_nid()

  def _next_nid(self):
    """Allocates and returns the next unused node id."""
    nid = self._nodes_cnt
    self._nodes_cnt += 1
    return nid

  @property
  def nodes(self):
    """List of all node ids allocated so far (source and sink included)."""
    return list(range(self._nodes_cnt))

  @property
  def edges_map(self):
    """The underlying adjacency map `{a: {b: cost}}`."""
    return self._edges_map

  @property
  def edges(self):
    """List of `[from_id, to_id, cost]` triples, one per directed edge."""
    return [
        [n, m, cost]
        for n, next_nodes in self._edges_map.items()
        for m, cost in next_nodes.items()
    ]

  @property
  def graph(self):
    """The graph as an int32 CSR capacity matrix, costs capped at INF_COST."""
    # reshape(-1, 3) keeps the array two-dimensional even when no edge has
    # been added yet, so the column slicing below never fails.
    edges = np.array(self.edges, dtype=np.int64).reshape(-1, 3)
    return scipy.sparse.csr_matrix(
        (
            np.minimum(edges[:, 2], MinCutSolver.INF_COST),
            (edges[:, 0], edges[:, 1]),
        ),
        shape=(self._nodes_cnt, self._nodes_cnt),
        dtype=np.int32,
    )

  def get_nid(self, obj=None):
    """Returns the node id for `obj`, allocating one on first use.

    Args:
      obj: A hashable object to associate with the node. When None, a fresh
        anonymous node id is allocated on every call.

    Returns:
      The integer node id.
    """
    if obj is None:
      return self._next_nid()

    nid = self._obj_to_node.get(obj)
    if nid is None:
      nid = self._next_nid()

    self._obj_to_node[obj] = nid
    self._node_to_obj[nid] = obj
    return nid

  def get_obj(self, nid: int):
    """Returns the object registered for `nid`, or None if there is none."""
    return self._node_to_obj.get(nid, None)

  def add_edge(self, a_id: int, b_id: int, cost: int):
    """Sets the cost of the directed edge a_id -> b_id (overwrites if present)."""
    assert isinstance(cost, int)
    self._edges_map[a_id][b_id] = cost

  def solve(self):
    """Computes the cut edges from a maximum flow between source and sink.

    Returns:
      A set of `(from_id, to_id)` pairs: the edges that go from a node grouped
      with the source to a node outside that group.
    """
    edges = self.edges
    if not edges:
      # Without edges there is no flow and nothing to cut.
      return set()

    flow = scipy.sparse.csgraph.maximum_flow(
        self.graph, self.source, self.sink, method="dinic"
    ).flow

    # Max-flow min-cut theorem: find min-cuts in the residual network.
    # Endpoints of every non-saturated edge are grouped together; the grouping
    # uses a disjoint set, so it is symmetric in the two endpoints.
    ds = scipy.cluster.hierarchy.DisjointSet(self.nodes)
    for n, m, cost in edges:
      if abs(flow[n, m]) < cost:
        ds.merge(n, m)

    residual_reachable_nodes = ds.subset(self.source)

    cuts = set()
    for n, m, _ in edges:
      if n in residual_reachable_nodes and m not in residual_reachable_nodes:
        cuts.add((n, m))

    return cuts
130
+
131
+
132
@dataclasses.dataclass(frozen=True)
class MultiUsersDummyNode:
  """Hashable placeholder object wrapping an FX node that has several users."""

  # The FX node this placeholder stands for.
  src: torch.fx.Node
135
+
136
+
137
def partition(graph_module: torch.fx.GraphModule):
  """Mark the NHWC nodes of a graph module using a min-cut partitioning.

  Partitions the graph module into NHWC and non-NHWC subgraphs and marks the
  nodes in the NHWC partitions.

  Implements O(|V|^2|E|) min-cut (optimal) partitioning algorithm.

  Args:
    graph_module: The graph module to be partitioned.

  Returns:
    The same graph module, recompiled after the nodes have been marked.
  """
  solver = MinCutSolver()
  inf_cost = MinCutSolver.INF_COST
  mean_targets = (torch.ops.aten.mean.default, torch.ops.aten.mean.dim)

  for fx_node in graph_module.graph.nodes:
    if layout_mark.is_const_node(fx_node):
      continue

    node_id = solver.get_nid(fx_node)
    if fx_node.op in ("placeholder", "output") or not layout_check.can_be_nhwc(
        fx_node
    ):
      # Graph inputs/outputs and all not NHWCable nodes are connected to
      # source S directly, with inf cost to cut.
      solver.add_edge(solver.source, node_id, cost=inf_cost)
    elif layout_check.must_be_nhwc(fx_node):
      # All must-be-NHWC nodes are connected to sink T directly, with inf
      # cost to cut.
      solver.add_edge(node_id, solver.sink, cost=inf_cost)

    # 10 is one unit of cut cost.
    # TFLite converter cannot fuse the lowering of (tpos-mean) but (mean-tpos)
    # when it applies on the feature dimensions. Therefore the cut cost of
    # aten.mean's out-going edges is decreased to 9 to favor having a cut
    # (transpose) after the node rather than before it when the number of
    # transposes are equal.
    # TODO: Remove this rule when converter has fuse rule for tpos-mean.
    cut_cost = 9 if fx_node.target in mean_targets else 10

    if len(fx_node.users) > 1:
      # If a node's (A1) output is used by multiple nodes (B1, B2, B3, ...),
      # the cost to split A1 and Bs into different partitions would just be 1
      # transpose. A dummy node is therefore inserted between A1 and Bs in the
      # min-cut graph, so that disconnecting them is not counted as multiple
      # transposes.
      fanout_id = solver.get_nid(MultiUsersDummyNode(fx_node))
      solver.add_edge(node_id, fanout_id, cost=cut_cost)
      solver.add_edge(fanout_id, node_id, cost=cut_cost)
      node_id = fanout_id

    for user in fx_node.users:
      # All the other nodes and edges in the model graph are mirrored as is
      # in the min-cut graph, in both directions, each costing `cut_cost`.
      user_id = solver.get_nid(user)
      solver.add_edge(node_id, user_id, cost=cut_cost)
      solver.add_edge(user_id, node_id, cost=cut_cost)

  cuts = solver.solve()

  # Group the solver nodes that remain connected once the cut edges (in either
  # direction) are removed.
  components = scipy.cluster.hierarchy.DisjointSet(solver.nodes)
  for head, tail, _ in solver.edges:
    if (head, tail) not in cuts and (tail, head) not in cuts:
      components.merge(head, tail)
  assert not components.connected(solver.source, solver.sink)

  # Every FX node that is no longer connected to the source is marked as NHWC.
  for node_id in solver.nodes:
    if components.connected(node_id, solver.source):
      continue

    obj = solver.get_obj(node_id)
    # Skip ids without an object (source/sink) and the fan-out dummy nodes.
    if obj is None or isinstance(obj, MultiUsersDummyNode):
      continue

    assert isinstance(obj, torch.fx.Node)
    layout_mark.mark_as_nhwc_node(obj)

  graph_module.recompile()
  return graph_module