ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Files changed (213)
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
ai_edge_torch/__init__.py
@@ -0,0 +1,32 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from ai_edge_torch._config import config
+ from ai_edge_torch._convert.converter import convert
+ from ai_edge_torch._convert.converter import signature
+ from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
+ from ai_edge_torch.model import Model
+ from ai_edge_torch.version import __version__
+
+ def load(path: str) -> Model:
+   """Imports an ai_edge_torch model from disk.
+
+   Args:
+     path: The path to the serialized ai_edge_torch model.
+
+   Returns:
+     An ai_edge_torch.model.Model object.
+   """
+   return Model.load(path)
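The `__init__` module above re-exports the package's public conversion API (`convert`, `signature`, `to_channel_last_io`, `load`, `Model`). A minimal round-trip sketch, assuming the usual ai_edge_torch workflow where the converted `TfLiteModel` is callable and exposes `export()`; the toy module and output path are illustrative only, not from the wheel:

import torch
import ai_edge_torch

class AddOne(torch.nn.Module):  # toy module, for illustration only
  def forward(self, x):
    return x + 1

# Convert with a tuple of sample inputs, then serialize the flatbuffer.
edge_model = ai_edge_torch.convert(AddOne().eval(), (torch.randn(1, 4),))
edge_model.export("/tmp/add_one.tflite")  # illustrative path

# `load` (defined above) restores the serialized model for inference.
reloaded = ai_edge_torch.load("/tmp/add_one.tflite")
print(reloaded(torch.randn(1, 4)))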
ai_edge_torch/_config.py
@@ -0,0 +1,69 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Provides a configuration for the ai-edge-torch."""
+
+ import functools
+ import logging
+ import os
+
+ __all__ = ["config"]
+
+
+ def _get_bool_env_var(name: str, default: bool) -> bool:
+   var = os.environ.get(name, "false")
+   var = var.lower().strip()
+   if var in ("y", "yes", "t", "true", "on", "1"):
+     return True
+   elif var in ("n", "no", "f", "false", "off", "0"):
+     return False
+   else:
+     logging.warning("Invalid %s value is ignored: %s.", name, var)
+     return default
+
+
+ class _Config:
+   """ai-edge-torch global configs."""
+
+   @property
+   @functools.cache  # pylint: disable=method-cache-max-size-none
+   def use_torch_xla(self) -> bool:
+     """True if using torch_xla to lower torch ops to StableHLO.
+
+     To use torch_xla as the lowering backend, set environment variable
+     `USE_TORCH_XLA` to "true".
+     """
+     return _get_bool_env_var("USE_TORCH_XLA", default=False)
+
+   @property
+   def in_oss(self) -> bool:
+     """True if the code is not running in google internal environment."""
+     return True
+
+   @property
+   def enable_group_norm_composite(self) -> bool:
+     """True if lowering group norm in StableHLO composite.
+
+     Currently only supports NHWC group norm generated by
+     OptimizeLayoutTransposesPass.
+     """
+     return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+   @enable_group_norm_composite.setter
+   def enable_group_norm_composite(self, value: bool):
+     os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
+
+ config = _Config()
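A short sketch of how the `_Config` object above is driven by environment variables. Note that `use_torch_xla` is memoized with `functools.cache`, so the variable must be set before the first read; the printed values are assumptions about a fresh process:

import os
from ai_edge_torch._config import config

# Select the torch_xla lowering backend; this must happen before the first
# access because the property result is cached.
os.environ["USE_TORCH_XLA"] = "true"
print(config.use_torch_xla)  # True

# This flag has a setter that writes back to the environment, so it can be
# flipped at runtime (its getter is not cached).
config.enable_group_norm_composite = True
print(config.enable_group_norm_composite)  # True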
ai_edge_torch/_convert/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/_convert/conversion.py
@@ -0,0 +1,153 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import logging
+ from typing import Any, Literal, Optional, Union
+
+ import ai_edge_torch
+ from ai_edge_torch import fx_pass_base
+ from ai_edge_torch import lowertools
+ from ai_edge_torch import model
+ from ai_edge_torch._convert import fx_passes
+ from ai_edge_torch._convert import signature
+ from ai_edge_torch.generative import fx_passes as generative_fx_passes
+ from ai_edge_torch.quantize import quant_config as qcfg
+ import torch
+
+
+ def _run_convert_passes(
+     exported_program: torch.export.ExportedProgram,
+ ) -> torch.export.ExportedProgram:
+   exported_program = generative_fx_passes.run_generative_passes(
+       exported_program
+   )
+
+   passes = [
+       fx_passes.BuildInterpolateCompositePass(),
+       fx_passes.CanonicalizePass(),
+       fx_passes.OptimizeLayoutTransposesPass(),
+       fx_passes.CanonicalizePass(),
+       fx_passes.BuildAtenCompositePass(),
+       fx_passes.CanonicalizePass(),
+       fx_passes.RemoveNonUserOutputsPass(),
+       fx_passes.CanonicalizePass(),
+   ]
+
+   # Debuginfo is not injected automatically by odml_torch. Only inject
+   # debuginfo via fx pass when using torch_xla.
+   if ai_edge_torch.config.use_torch_xla:
+     passes += [
+         fx_passes.InjectMlirDebuginfoPass(),
+         fx_passes.CanonicalizePass(),
+     ]
+
+   exported_program = fx_pass_base.run_passes(exported_program, passes)
+   return exported_program
+
+
+ def _warn_training_modules(signatures: list[signature.Signature]):
+   """Warns the user if the module is in training mode (.eval not called)."""
+   for sig in signatures:
+     if not sig.module.training:
+       continue
+
+     message = (
+         "Your model {sig_name}is converted in training mode. Please set the"
+         " module in evaluation mode with `module.eval()` for better on-device"
+         " performance and compatibility."
+     )
+     if len(signatures) == 1 and sig.name == model.DEFAULT_SIGNATURE_NAME:
+       # User does not specify any signature names explicitly.
+       message = message.format(sig_name="")
+     else:
+       message = message.format(sig_name=f'"{sig.name}" ')
+
+     logging.warning(message)
+
+
+ def convert_signatures(
+     signatures: list[signature.Signature],
+     *,
+     strict_export: Union[Literal["auto"], bool] = True,
+     quant_config: Optional[qcfg.QuantConfig] = None,
+     _tfl_converter_flags: Optional[dict[str, Any]] = None,
+     _saved_model_dir: Optional[str] = None,
+ ) -> model.TfLiteModel:
+   """Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
+
+   Args:
+     signatures: The list of `signature.Signature` objects containing PyTorch
+       modules to be converted.
+     strict_export: Experimental `strict` arg for torch.export.export. When
+       enabled, the export function will trace the program through TorchDynamo
+       and ensure the soundness of the exported graph. When
+       strict_export="auto", the function will try to export the module in both
+       modes and use the first one that succeeds for downstream conversion.
+     quant_config: User-defined quantization method and scheme of the model.
+     _tfl_converter_flags: A nested dictionary allowing setting flags for the
+       underlying tflite converter.
+     _saved_model_dir: Directory for the intermediate saved model. If not
+       specified, a random temporary directory would be used.
+
+   Returns:
+     The converted `model.TfLiteModel` object.
+   """
+   if _tfl_converter_flags is None:
+     _tfl_converter_flags = {}
+
+   _warn_training_modules(signatures)
+
+   def export(*args, **kwargs):
+     nonlocal strict_export
+     if strict_export == "auto":
+       try:
+         exported_program = torch.export.export(*args, **kwargs, strict=True)
+       except Exception:
+         logging.warning(
+             "torch.export.export(..., strict=True) failed. Retrying with"
+             " strict=False"
+         )
+         exported_program = torch.export.export(*args, **kwargs, strict=False)
+     elif not strict_export:
+       exported_program = torch.export.export(*args, **kwargs, strict=False)
+     else:
+       exported_program = torch.export.export(*args, **kwargs, strict=True)
+
+     if hasattr(torch._decomp, "_decomp_table_to_post_autograd_aten"):
+       # Available after torch 2.5.0: `_decomp_table_to_post_autograd_aten` is a
+       # stop-gap table which replicates the old behaviour of post-dispatch IR.
+       # This helps keep the collection of remaining aten ops stable as the
+       # implementation of torch.export changes.
+       exported_program = exported_program.run_decompositions(
+           torch._decomp._decomp_table_to_post_autograd_aten()
+       )
+     return exported_program
+
+   exported_programs: list[torch.export.ExportedProgram] = [
+       export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
+       for sig in signatures
+   ]
+
+   # Apply default fx passes
+   exported_programs = list(map(_run_convert_passes, exported_programs))
+   tflite_model = lowertools.exported_programs_to_tflite(
+       exported_programs,
+       signatures,
+       quant_config=quant_config,
+       _tfl_converter_flags=_tfl_converter_flags,
+       _saved_model_dir=_saved_model_dir,
+   )
+
+   return model.TfLiteModel(tflite_model)
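The nested `export` helper in `convert_signatures` implements the `strict_export="auto"` fallback: it first attempts `torch.export.export(..., strict=True)` and retries with `strict=False` if strict tracing fails. A minimal sketch of reaching that path through the public API; the module and input shapes are made up for illustration:

import torch
import ai_edge_torch

module = torch.nn.Linear(4, 2).eval()  # illustrative module
edge_model = ai_edge_torch.convert(
    module,
    (torch.randn(1, 4),),
    strict_export="auto",  # try strict tracing first, fall back on failure
)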
ai_edge_torch/_convert/conversion_utils.py
@@ -0,0 +1,64 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Any
+
+ from ai_edge_torch.quantize import quant_config as qcfg
+ import tensorflow as tf
+
+
+ def apply_tfl_converter_flags(
+     converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict[str, Any]
+ ):
+   """Applies TFLite converter flags to the converter.
+
+   Args:
+     converter: TFLite converter.
+     tfl_converter_flags: TFLite converter flags.
+   """
+
+   def _set_converter_flag(path: list[Any]):
+     if len(path) < 2:
+       raise ValueError("Expecting at least two values in the path.")
+
+     target_obj = converter
+     for idx in range(len(path) - 2):
+       target_obj = getattr(target_obj, path[idx])
+
+     setattr(target_obj, path[-2], path[-1])
+
+   def _iterate_dict_tree(flags_dict: dict[str, Any], path: list[Any]):
+     for key, value in flags_dict.items():
+       path.append(key)
+       if isinstance(value, dict):
+         _iterate_dict_tree(value, path)
+       else:
+         path.append(value)
+         _set_converter_flag(path)
+         path.pop()
+       path.pop()
+
+   _iterate_dict_tree(tfl_converter_flags, [])
+
+
+ def set_tfl_converter_quant_flags(
+     converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
+ ):
+   if quant_config is not None:
+     quantizer_mode = quant_config._quantizer_mode
+     if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
+       converter._experimental_qdq_conversion_mode = "DYNAMIC"
+     elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
+       converter._experimental_qdq_conversion_mode = "STATIC"
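`apply_tfl_converter_flags` flattens a nested dict into attribute paths on the `tf.lite.TFLiteConverter`. A sketch with assumed flag values and an illustrative saved-model path; the two entries below would be applied as `converter.optimizations = [...]` and `converter.target_spec.supported_ops = [...]`:

import tensorflow as tf
from ai_edge_torch._convert import conversion_utils

# Nested keys become attribute paths; leaf values become the assigned values.
flags = {
    "optimizations": [tf.lite.Optimize.DEFAULT],
    "target_spec": {"supported_ops": [tf.lite.OpsSet.TFLITE_BUILTINS]},
}

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")  # illustrative path
conversion_utils.apply_tfl_converter_flags(converter, flags)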
ai_edge_torch/_convert/converter.py
@@ -0,0 +1,270 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from typing import Any, Literal, Optional, Tuple, Union
+
+ from ai_edge_torch import model
+ from ai_edge_torch._convert import conversion
+ from ai_edge_torch._convert import signature as signature_module
+ from ai_edge_torch.quantize import quant_config as qcfg
+ import torch
+
+
+ class Converter:
+   """A converter for converting PyTorch models to edge models.
+
+   This class allows adding multiple signatures to the converted edge model.
+   """
+
+   def __init__(self):
+     self._signatures: list[signature_module.Signature] = []
+
+   def signature(
+       self,
+       name: str,
+       module: torch.nn.Module,
+       sample_args=None,
+       sample_kwargs=None,
+       *,
+       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+   ) -> Converter:
+     """Functions as an alias to `add_signature`."""
+     return self.add_signature(
+         name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+     )
+
+   def add_signature(
+       self,
+       name: str,
+       module: torch.nn.Module,
+       sample_args=None,
+       sample_kwargs=None,
+       *,
+       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+   ) -> Converter:
+     """Allows adding a new named torch model along with sample args to the conversion.
+
+     Args:
+       name: The name of the signature included in the converted edge model.
+       module: The torch module to be converted.
+       sample_args: Tuple of tensors with which the torch module will be traced
+         prior to conversion.
+       sample_kwargs: Dict of str to tensor with which the torch module will be
+         traced prior to conversion.
+       dynamic_shapes: Optional dict or tuple that specifies dynamic shape
+         specifications for each input in original order. See
+         https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+         details.
+
+     Returns:
+       The converter object itself.
+
+     Raises:
+       ValueError: If a signature with the provided name already exists.
+     """
+
+     if name in [sig.name for sig in self._signatures]:
+       raise ValueError(
+           f"A signature with the provided name ({name}) is already added."
+       )
+
+     if sample_args is None and sample_kwargs is None:
+       raise ValueError("sample_args or sample_kwargs must be provided.")
+
+     self._signatures.append(
+         signature_module.Signature(
+             name,
+             module,
+             sample_args,
+             sample_kwargs,
+             dynamic_shapes=dynamic_shapes,
+         )
+     )
+     return self
+
+   def convert(
+       self,
+       module: torch.nn.Module = None,
+       sample_args=None,
+       sample_kwargs=None,
+       *,
+       strict_export: Union[Literal["auto"], bool] = True,
+       quant_config: Optional[qcfg.QuantConfig] = None,
+       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+       _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
+       _saved_model_dir: Optional[str] = None,
+   ) -> model.TfLiteModel:
+     """Finalizes the conversion and produces an edge model.
+
+     This could be called with no arguments as follows:
+
+       edge_model = Converter().signature(name, module, args).convert()
+
+     Or it could be used to set the default signature for the converted edge
+     model:
+
+       edge_model = Converter().convert(module, args)
+
+     Args:
+       module: The torch module to be converted.
+       sample_args: Tuple of tensors with which the torch module will be traced
+         prior to conversion.
+       sample_kwargs: Dict of str to tensor with which the torch module will be
+         traced prior to conversion.
+       strict_export: Experimental `strict` arg for torch.export.export. When
+         enabled, the export function will trace the program through TorchDynamo
+         and ensure the soundness of the exported graph. When
+         strict_export="auto", the function will try to export the module in both
+         modes and use the first one that succeeds for downstream conversion.
+       quant_config: User-defined quantization method and scheme of the model.
+       dynamic_shapes: Optional dict or tuple that specifies dynamic shape
+         specifications for each input in original order. See
+         https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+         details.
+       _ai_edge_converter_flags: A nested dictionary allowing setting flags for
+         the underlying converter. This gives access to an implementation detail
+         of this function and so needs to be treated as such. Please do not rely
+         on this parameter except for local debugging as this can be removed in a
+         future release.
+       _saved_model_dir: Directory for the intermediate saved model. If not
+         specified, a random temporary directory would be used.
+
+     Returns:
+       The converted edge model.
+
+     Raises:
+       ValueError: If the arguments are not provided as expected. See the example
+         in this function's comment.
+     """
+     if _ai_edge_converter_flags is None:
+       _ai_edge_converter_flags = {}
+
+     if module is not None:
+       if (
+           sample_args is not None or sample_kwargs is not None
+       ):  # both module and args provided
+         self.add_signature(
+             model.DEFAULT_SIGNATURE_NAME,
+             module,
+             sample_args,
+             sample_kwargs,
+             dynamic_shapes=dynamic_shapes,
+         )
+       else:  # module is provided but not args
+         raise ValueError(
+             "sample_args or sample_kwargs must be provided if a module is"
+             " specified."
+         )
+     return conversion.convert_signatures(
+         self._signatures,
+         strict_export=strict_export,
+         quant_config=quant_config,
+         _tfl_converter_flags=_ai_edge_converter_flags,
+         _saved_model_dir=_saved_model_dir,
+     )
+
+
+ def signature(
+     name: str,
+     module: torch.nn.Module,
+     sample_args=None,
+     sample_kwargs=None,
+     dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+ ) -> Converter:
+   """Initiates a Converter object with the provided signature.
+
+   Args:
+     name: The name of the signature included in the converted edge model.
+     module: The torch module to be converted.
+     sample_args: Tuple of tensors with which the torch module will be traced
+       prior to conversion.
+     sample_kwargs: Dict of str to tensor with which the torch module will be
+       traced prior to conversion.
+     dynamic_shapes: Optional dict or tuple that specifies dynamic shape
+       specifications for each input in original order. See
+       https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+       details.
+
+   Returns:
+     A Converter object with the provided signature.
+
+   Example:
+     converter = ai_edge_torch.signature(name, module, args)
+     edge_model = converter.convert()
+   """
+   return Converter().signature(
+       name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+   )
+
+
+ def convert(
+     module: torch.nn.Module = None,
+     sample_args=None,
+     sample_kwargs=None,
+     *,
+     strict_export: Union[Literal["auto"], bool] = True,
+     quant_config: Optional[qcfg.QuantConfig] = None,
+     dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+     _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
+     _saved_model_dir: Optional[str] = None,
+ ) -> model.TfLiteModel:
+   """Converts a PyTorch model to an edge model with a default signature.
+
+   Args:
+     module: The torch module to be converted.
+     sample_args: Tuple of tensors with which the torch module will be traced
+       prior to conversion.
+     sample_kwargs: Dict of str to tensor with which the torch module will be
+       traced prior to conversion.
+     strict_export: Experimental `strict` arg for torch.export.export. When
+       enabled, the export function will trace the program through TorchDynamo
+       and ensure the soundness of the exported graph. When strict_export="auto",
+       the function will try to export the module in both modes and use the first
+       one that succeeds for downstream conversion.
+     quant_config: User-defined quantization method and scheme of the model.
+     dynamic_shapes: Optional dict or tuple that specifies dynamic shape
+       specifications for each input in original order. See
+       https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+       details.
+     _ai_edge_converter_flags: A nested dictionary allowing setting flags for the
+       underlying converter. This gives access to an implementation detail of
+       this function and so needs to be treated as such. Please do not rely on
+       this parameter except for local debugging as this can be removed in a
+       future release.
+     _saved_model_dir: Directory for the intermediate saved model. If not
+       specified, a random temporary directory would be used.
+
+   Returns:
+     The converted edge model.
+
+   Example:
+     edge_model = ai_edge_torch.convert(module, args)
+   """
+
+   if _ai_edge_converter_flags is None:
+     _ai_edge_converter_flags = {}
+
+   return Converter().convert(
+       module,
+       sample_args,
+       sample_kwargs,
+       strict_export=strict_export,
+       quant_config=quant_config,
+       dynamic_shapes=dynamic_shapes,
+       _ai_edge_converter_flags=_ai_edge_converter_flags,
+       _saved_model_dir=_saved_model_dir,
+   )
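The `Converter` class above accumulates named signatures before a single `convert()` call. A sketch of packing two signatures into one edge model; the module choices, signature names, shapes, and output path are illustrative:

import torch
import ai_edge_torch

encoder = torch.nn.Linear(8, 4).eval()  # illustrative modules
decoder = torch.nn.Linear(4, 8).eval()

# ai_edge_torch.signature() returns a Converter, so signatures can be chained.
edge_model = (
    ai_edge_torch.signature("encode", encoder, (torch.randn(1, 8),))
    .add_signature("decode", decoder, (torch.randn(1, 4),))
    .convert()
)
edge_model.export("/tmp/two_signatures.tflite")  # illustrative path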
ai_edge_torch/_convert/fx_passes/__init__.py
@@ -0,0 +1,23 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Sequence, Union
+
+ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
+ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+ from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
+ from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
+ from ai_edge_torch.fx_pass_base import CanonicalizePass
+ from ai_edge_torch.fx_pass_base import CanonicalizePass