ai-edge-torch-nightly 0.3.0.dev20250114 (py3-none-any.whl)

Files changed (213)
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
@@ -0,0 +1,50 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Builder utils for individual components.
+
+ import ai_edge_torch.generative.layers.unet.model_config as unet_config
+ from torch import nn
+
+
+ def build_upsampling(config: unet_config.UpSamplingConfig):
+   if config.mode == unet_config.SamplingType.NEAREST:
+     return nn.UpsamplingNearest2d(scale_factor=config.scale_factor)
+   elif config.mode == unet_config.SamplingType.BILINEAR:
+     return nn.UpsamplingBilinear2d(scale_factor=config.scale_factor)
+   else:
+     raise ValueError("Unsupported upsampling type.")
+
+
+ def build_downsampling(config: unet_config.DownSamplingConfig):
+   if config.mode == unet_config.SamplingType.AVERAGE:
+     return nn.AvgPool2d(
+         config.kernel_size, config.stride, padding=config.padding
+     )
+   elif config.mode == unet_config.SamplingType.CONVOLUTION:
+     out_channels = (
+         config.in_channels
+         if config.out_channels is None
+         else config.out_channels
+     )
+     padding = (0, 1, 0, 1) if config.padding == 0 else config.padding
+     return nn.Conv2d(
+         config.in_channels,
+         out_channels=out_channels,
+         kernel_size=config.kernel_size,
+         stride=config.stride,
+         padding=padding,
+     )
+   else:
+     raise ValueError("Unsupported downsampling type.")
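
For orientation, here is a minimal sketch of how build_upsampling() above is exercised (not part of the wheel; the tensor shape and scale factor are made up for illustration):

    import torch
    from torch import nn
    import ai_edge_torch.generative.layers.unet.model_config as unet_config
    from ai_edge_torch.generative.layers.unet import builder

    # 2x nearest-neighbor upsampling, handled by the NEAREST branch above.
    up = builder.build_upsampling(
        unet_config.UpSamplingConfig(
            mode=unet_config.SamplingType.NEAREST, scale_factor=2.0
        )
    )
    assert isinstance(up, nn.UpsamplingNearest2d)
    print(up(torch.zeros(1, 4, 8, 8)).shape)  # torch.Size([1, 4, 16, 16])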
@@ -0,0 +1,282 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # UNet configuration class.
+ import dataclasses
+ import enum
+ from typing import List, Optional
+
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
+
+
+ @enum.unique
+ class SamplingType(enum.Enum):
+   NEAREST = enum.auto()
+   BILINEAR = enum.auto()
+   AVERAGE = enum.auto()
+   CONVOLUTION = enum.auto()
+
+
+ @dataclasses.dataclass
+ class UpSamplingConfig:
+   mode: SamplingType
+   scale_factor: float
+
+
+ @dataclasses.dataclass
+ class DownSamplingConfig:
+   mode: SamplingType
+   in_channels: int
+   kernel_size: int
+   stride: int
+   padding: int
+   out_channels: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class ResidualBlock2DConfig:
+   in_channels: int
+   hidden_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   # Optional time embedding channels if the residual block takes a time embedding context as input.
+   time_embedding_channels: Optional[int] = None
+   residual_out_channels: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class AttentionBlock2DConfig:
+   dim: int
+   normalization_config: layers_cfg.NormalizationConfig
+   attention_config: layers_cfg.AttentionConfig
+   enable_hlfb: bool = True
+   attention_batch_size: int = 1
+   hidden_dim: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class CrossAttentionBlock2DConfig:
+   query_dim: int
+   cross_dim: int
+   hidden_dim: int
+   output_dim: int
+   normalization_config: layers_cfg.NormalizationConfig
+   attention_config: layers_cfg.AttentionConfig
+   enable_hlfb: bool = True
+   attention_batch_size: int = 1
+
+
+ @dataclasses.dataclass
+ class FeedForwardBlock2DConfig:
+   dim: int
+   hidden_dim: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   use_bias: bool
+
+
+ @dataclasses.dataclass
+ class TransformerBlock2DConfig:
+   pre_conv_normalization_config: layers_cfg.NormalizationConfig
+   attention_block_config: AttentionBlock2DConfig
+   cross_attention_block_config: CrossAttentionBlock2DConfig
+   feed_forward_block_config: FeedForwardBlock2DConfig
+
+
+ @dataclasses.dataclass
+ class UpDecoderBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # The dimension of output channels of the previously connected block.
+   prev_out_channels: Optional[int] = None
+   # Optional time embedding channels if the residual blocks take a time embedding as input.
+   time_embedding_channels: Optional[int] = None
+   # Whether to add an upsample operation after the residual blocks.
+   add_upsample: bool = True
+   # Whether to add a conv2d layer after upsample.
+   upsample_conv: bool = True
+   # Optional sampling config if add_upsample is True.
+   sampling_config: Optional[UpSamplingConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of the context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class SkipUpDecoderBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   # The dimension of output channels of the previously connected block.
+   prev_out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # Optional time embedding channels if the residual blocks take a time embedding as input.
+   time_embedding_channels: Optional[int] = None
+   # Whether to add an upsample operation after the residual blocks.
+   add_upsample: bool = True
+   # Whether to add a conv2d layer after upsample.
+   upsample_conv: bool = True
+   # Optional sampling config if add_upsample is True.
+   sampling_config: Optional[UpSamplingConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of the context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+   sub_block_channels: Optional[tuple] = None
+   hidden_channels: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class DownEncoderBlock2DConfig:
+   in_channels: int
+   out_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # Padding for the downsampling convolution.
+   padding: int = 1
+   # Optional time embedding channels if the residual blocks take a time embedding as input.
+   time_embedding_channels: Optional[int] = None
+   # Whether to add a downsample operation after the residual blocks.
+   add_downsample: bool = True
+   # Optional sampling config if add_downsample is True.
+   sampling_config: Optional[DownSamplingConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of the context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+   hidden_channels: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class MidBlock2DConfig:
+   in_channels: int
+   normalization_config: layers_cfg.NormalizationConfig
+   activation_config: layers_cfg.ActivationConfig
+   num_layers: int
+   # Optional time embedding channels if the residual blocks take a time embedding context as input.
+   time_embedding_channels: Optional[int] = None
+   # Optional config of attention blocks interleaved with residual blocks.
+   attention_block_config: Optional[AttentionBlock2DConfig] = None
+   # Optional config of transformer blocks interleaved with residual blocks.
+   transformer_block_config: Optional[TransformerBlock2DConfig] = None
+   # Optional dimension of the context tensor if a context tensor is given as input.
+   context_dim: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class AutoEncoderConfig:
+   """Configurations of encoder/decoder in the autoencoder model."""
+
+   # The activation type of encoder/decoder blocks.
+   activation_config: layers_cfg.ActivationConfig
+
+   # The output channels of each block.
+   block_out_channels: List[int]
+
+   # Number of channels in the input image.
+   in_channels: int
+
+   # Number of channels in the output.
+   out_channels: int
+
+   # Number of channels in the latent space.
+   latent_channels: int
+
+   # The component-wise standard deviation of the trained latent space, computed using the first batch of the
+   # training set. This is used to scale the latent space to have unit variance when training the diffusion
+   # model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+   # diffusion model. When decoding, the latents are scaled back to the original scale with the formula:
+   # `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the paper
+   # [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).
+   scaling_factor: float
+
+   # The number of layers in each encoder/decoder block.
+   layers_per_block: int
+
+   # The normalization config.
+   normalization_config: layers_cfg.NormalizationConfig
+
+   # The configuration of the middle blocks, that is, after the last block of the encoder and before the first block of the decoder.
+   mid_block_config: MidBlock2DConfig
+
+
+ @dataclasses.dataclass
+ class DiffusionModelConfig:
+   """Configurations of the diffusion model."""
+
+   # Number of channels in the input tensor.
+   in_channels: int
+
+   # Number of channels in the output tensor.
+   out_channels: int
+
+   # The output channels of each block.
+   block_out_channels: List[int]
+
+   # The number of layers in each block.
+   layers_per_block: int
+
+   # The padding to use for the downsampling.
+   downsample_padding: int
+
+   # Normalization config used in residual blocks.
+   residual_norm_config: layers_cfg.NormalizationConfig
+
+   # Activation config used in residual blocks.
+   residual_activation_type: layers_cfg.ActivationType
+
+   # The batch size used in transformer blocks, for attention layers.
+   transformer_batch_size: int
+
+   # The number of attention heads used in transformer blocks.
+   transformer_num_attention_heads: int
+
+   # The dimension of cross attention used in transformer blocks.
+   transformer_cross_attention_dim: int
+
+   # Normalization config used in the pre-conv layer of transformer blocks.
+   transformer_pre_conv_norm_config: layers_cfg.NormalizationConfig
+
+   # Normalization config used in transformer blocks.
+   transformer_norm_config: layers_cfg.NormalizationConfig
+
+   # Activation type of the feed forward used in transformer blocks.
+   transformer_ff_activation_type: layers_cfg.ActivationType
+
+   # Number of layers in the mid block.
+   mid_block_layers: int
+
+   # Dimension of the time embedding.
+   time_embedding_dim: int
+
+   # Time embedding dimensions for blocks.
+   time_embedding_blocks_dim: int
+
+   # Normalization config used for the final layer.
+   final_norm_config: layers_cfg.NormalizationConfig
+
+   # Activation type used in the final layer.
+   final_activation_type: layers_cfg.ActivationType
+
+   # Whether to enable StableHLO composite ops in the model.
+   enable_hlfb: bool = False
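
As a companion sketch to the builder file above, an average-pool downsampling config (not part of the wheel; the values are illustrative only, and the AVERAGE branch ignores in_channels and out_channels):

    import ai_edge_torch.generative.layers.unet.model_config as unet_config
    from ai_edge_torch.generative.layers.unet import builder

    # Halve the spatial resolution with a 2x2 average pool (AVERAGE branch).
    down = builder.build_downsampling(
        unet_config.DownSamplingConfig(
            mode=unet_config.SamplingType.AVERAGE,
            in_channels=64,  # present in the config but unused by this branch
            kernel_size=2,
            stride=2,
            padding=0,
        )
    )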
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,47 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.gemma import gemma1
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.quantize import quant_recipes
+ from ai_edge_torch.generative.utilities import model_builder
+ import numpy as np
+ import torch
+
+
+ def main():
+   # Build a PyTorch model as usual
+   config = gemma1.get_fake_model_config()
+   model = model_builder.DecoderOnlyModel(config).eval()
+   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
+   tokens[0, :4] = idx
+   input_pos = torch.arange(0, 10, dtype=torch.int)
+   kv = kv_utils.KVCache.from_model_config(config)
+
+   # Create a quantization recipe to be applied to the model
+   quant_config = quant_recipes.full_int8_dynamic_recipe()
+   print(quant_config)
+
+   # Convert with quantization
+   edge_model = ai_edge_torch.convert(
+       model, (tokens, input_pos, kv), quant_config=quant_config
+   )
+   edge_model.export("/tmp/gemma_2b_quantized.tflite")
+
+
+ if __name__ == "__main__":
+   main()
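
As an informal follow-up to this example, the exported file can be sanity-checked with the stock TFLite interpreter (a sketch assuming the tensorflow package is installed; this snippet is not part of the wheel):

    import tensorflow as tf

    # Load the flatbuffer written by edge_model.export() and list its signatures.
    interpreter = tf.lite.Interpreter(model_path="/tmp/gemma_2b_quantized.tflite")
    interpreter.allocate_tensors()
    print(interpreter.get_signature_list())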
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import enum
+
+
+ @enum.unique
+ class Dtype(enum.Enum):
+   """Data types and precision of tensors."""
+
+   FP32 = enum.auto()
+   FP16 = enum.auto()
+   INT8 = enum.auto()
+
+
+ @enum.unique
+ class Algorithm(enum.Enum):
+   """Algorithm used to calculate quantization parameters.
+
+   Attributes:
+     MIN_MAX: Maps the min/max of the floating-point space to the min/max of
+       the quantized space and quantizes uniformly.
+     FLOAT_CAST: Casts a float to another float of a different type.
+   """
+
+   MIN_MAX = enum.auto()
+   FLOAT_CAST = enum.auto()
+
+
+ @enum.unique
+ class Mode(enum.Enum):
+   """Mode of quantization.
+
+   Attributes:
+     DYNAMIC_RANGE: Quantize activations during runtime and weights statically
+       to perform computation in integers.
+     WEIGHT_ONLY: Quantize weights statically and dequantize during runtime to
+       perform computation in floating point.
+   """
+
+   DYNAMIC_RANGE = enum.auto()
+   WEIGHT_ONLY = enum.auto()
+
+
+ @enum.unique
+ class Granularity(enum.Enum):
+   """Granularity of quantization parameters.
+
+   Attributes:
+     NONE: Granularity not applicable to this quantization scheme.
+     CHANNELWISE: Per-channel quantization; each channel of the relevant
+       tensors is quantized independently of the others.
+   """
+
+   NONE = enum.auto()
+   CHANNELWISE = enum.auto()
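
To make the MIN_MAX description concrete, here is a toy illustration of how an int8 scale and zero point fall out of an observed float range (purely illustrative; the actual parameter calculation happens in the converter, not in this file):

    def min_max_int8_params(x_min: float, x_max: float) -> tuple[float, int]:
      # Map [x_min, x_max] uniformly onto the int8 range [-128, 127].
      qmin, qmax = -128, 127
      scale = (x_max - x_min) / (qmax - qmin)
      zero_point = round(qmin - x_min / scale)
      return scale, zero_point

    print(min_max_int8_params(-1.0, 1.0))  # scale ~= 0.00784, zero_point == 0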
@@ -0,0 +1,154 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from dataclasses import dataclass
+ from typing import Optional, Union
+
+ from ai_edge_torch.generative.quantize import quant_attrs
+ from ai_edge_torch.generative.quantize import supported_schemes
+
+
+ @dataclass
+ class LayerQuantRecipe:
+   """Quantization recipe for a single Edge Generative API layer (e.g. Attention).
+
+   Generic layer-scoped quantization recipe that specifies how this layer
+   should be quantized by the Edge Generative API. This is applicable to
+   layers implemented in ai_edge_torch/generative/layers/. Combinations of
+   attributes that are not supported during runtime are detected when
+   .verify() is called.
+
+   Attributes:
+     activation_dtype: Desired data type of activation tensors.
+     weight_dtype: Desired data type of weight tensors.
+     mode: Type of quantization.
+     algorithm: Algorithm for calculating quantization parameters.
+     granularity: Granularity of quantization.
+   """
+
+   activation_dtype: quant_attrs.Dtype
+   weight_dtype: quant_attrs.Dtype
+   mode: quant_attrs.Mode
+   algorithm: quant_attrs.Algorithm
+   granularity: quant_attrs.Granularity
+
+   def __str__(self):
+     return (
+         f'(a:{self.activation_dtype.name}, '
+         f'w:{self.weight_dtype.name}, '
+         f'{self.mode.name}, '
+         f'{self.algorithm.name}, '
+         f'{self.granularity.name})'
+     )
+
+   __repr__ = __str__
+
+   def verify(self):
+     """Checks if all configured attributes are supported in runtime.
+
+     Raises:
+       ValueError: If any attributes are incompatible.
+     """
+     is_valid = False
+     for supported in supported_schemes.get_supported_layer_schemes():
+       if (
+           self.activation_dtype == supported[0]
+           and self.weight_dtype == supported[1]
+           and self.mode == supported[2]
+           and self.algorithm == supported[3]
+           and self.granularity == supported[4]
+       ):
+         is_valid = True
+         break
+
+     if not is_valid:
+       raise ValueError(
+           'Unsupported LayerQuantRecipe configuration. See'
+           ' get_supported_recipe_matrix()'
+       )
+
+
+ @dataclass
+ class GenerativeQuantRecipe:
+   """Quantization recipe for a model composed of the Edge Generative API layers.
+
+   Some layers (e.g. attention and feedforward) can be given a different
+   `LayerQuantRecipe` per block by providing a dictionary keyed by the
+   TransformerBlock index. For example,
+
+   ```
+   default = LayerQuantRecipeA
+   attention = { 2: LayerQuantRecipeB }
+   feedforward = { 3: LayerQuantRecipeC }
+   ```
+
+   will apply LayerQuantRecipeA to the entire model, overridden by
+   LayerQuantRecipeB for the TransformerBlock[2].attention layer and
+   LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
+   with invalid indices will be ignored.
+
+   Attributes:
+     default: The quantization recipe for the global scope of the model.
+     embedding: Recipe for the embedding table.
+     attention: Recipe for the attention blocks. This can be specified with a
+       different LayerQuantRecipe for each block by providing a dictionary
+       keyed by the TransformerBlock index.
+     feedforward: Recipe for the feedforward layers. This can be specified
+       with a different LayerQuantRecipe for each block by providing a
+       dictionary keyed by the TransformerBlock index.
+   """
+
+   default: Optional[LayerQuantRecipe] = None
+   embedding: Optional[LayerQuantRecipe] = None
+   attention: Union[
+       Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+   ] = None
+   feedforward: Union[
+       Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
+   ] = None
+
+   def __str__(self):
+     return f"""GenerativeQuantRecipe(
+   Default: {self.default}
+   Embedding: {self.embedding}
+   Attention: {self.attention}
+   Feedforward: {self.feedforward}
+ )"""
+
+   __repr__ = __str__
+
+   def verify(self):
+     """Checks if the configured recipe can be supported in runtime.
+
+     Raises:
+       ValueError: If the configured recipe is invalid or unsupported.
+     """
+     if self.default is not None:
+       self.default.verify()
+     if self.embedding is not None:
+       self.embedding.verify()
+     if self.attention is not None:
+       if isinstance(self.attention, dict):
+         for recipe in self.attention.values():
+           recipe.verify()
+       else:
+         self.attention.verify()
+     if self.feedforward is not None:
+       if isinstance(self.feedforward, dict):
+         for recipe in self.feedforward.values():
+           recipe.verify()
+       else:
+         self.feedforward.verify()
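
A short sketch of the per-block override pattern described in the GenerativeQuantRecipe docstring, built with the helpers from quant_recipe_utils.py (shown next): int8 dynamic-range quantization everywhere, except TransformerBlock[2]'s attention, which stays fp16:

    from ai_edge_torch.generative.quantize import quant_recipe
    from ai_edge_torch.generative.quantize import quant_recipe_utils

    recipe = quant_recipe.GenerativeQuantRecipe(
        default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
        attention={2: quant_recipe_utils.create_layer_quant_fp16()},
    )
    recipe.verify()  # raises ValueError if any layer recipe is unsupported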
@@ -0,0 +1,62 @@
+ # Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Helper functions to construct custom quantization recipes.
+
+ These are intended for more advanced users who want to configure their own
+ quantization recipes. For pre-constructed recipes, use `quant_recipes.py`
+ instead.
+
+ Typical usage example:
+
+ 1. Applying a single layer recipe to the entire model
+
+     quant_recipe.GenerativeQuantRecipe(
+         default=quant_recipe_utils.create_layer_quant_int8_dynamic()
+     )
+ """
+
+ from ai_edge_torch.generative.quantize import quant_attrs
+ from ai_edge_torch.generative.quantize import quant_recipe
+
+
+ def create_layer_quant_int8_dynamic() -> quant_recipe.LayerQuantRecipe:
+   return quant_recipe.LayerQuantRecipe(
+       activation_dtype=quant_attrs.Dtype.FP32,
+       weight_dtype=quant_attrs.Dtype.INT8,
+       mode=quant_attrs.Mode.DYNAMIC_RANGE,
+       algorithm=quant_attrs.Algorithm.MIN_MAX,
+       granularity=quant_attrs.Granularity.CHANNELWISE,
+   )
+
+
+ def create_layer_quant_int8_weight_only() -> quant_recipe.LayerQuantRecipe:
+   return quant_recipe.LayerQuantRecipe(
+       activation_dtype=quant_attrs.Dtype.FP32,
+       weight_dtype=quant_attrs.Dtype.INT8,
+       mode=quant_attrs.Mode.WEIGHT_ONLY,
+       algorithm=quant_attrs.Algorithm.MIN_MAX,
+       granularity=quant_attrs.Granularity.CHANNELWISE,
+   )
+
+
+ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
+   return quant_recipe.LayerQuantRecipe(
+       activation_dtype=quant_attrs.Dtype.FP32,
+       weight_dtype=quant_attrs.Dtype.FP16,
+       mode=quant_attrs.Mode.WEIGHT_ONLY,
+       algorithm=quant_attrs.Algorithm.FLOAT_CAST,
+       granularity=quant_attrs.Granularity.NONE,
+   )
+ )