ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (213) hide show
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
@@ -0,0 +1,398 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import ai_edge_torch.generative.layers.builder as layers_builder
17
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
18
+ from ai_edge_torch.generative.layers.unet import blocks_2d
19
+ import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
20
+ from ai_edge_torch.generative.utilities import stable_diffusion_loader
21
+ import torch
22
+ from torch import nn
23
+
24
# Checkpoint tensor names for the Stable Diffusion VAE decoder.
#
# `Decoder` builds its up blocks from reversed(block_out_channels), so the
# up-block entries are listed from checkpoint level "up.3" down to "up.0" to
# line up with `Decoder.up_decoder_blocks` in order.
TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
    post_quant_conv="first_stage_model.post_quant_conv",
    conv_in="first_stage_model.decoder.conv_in",
    mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
        # The mid block has two resnets, named block_1 and block_2.
        residual_block_tensor_names=[
            stable_diffusion_loader.ResidualBlockTensorNames(
                norm_1=f"first_stage_model.decoder.mid.block_{idx}.norm1",
                norm_2=f"first_stage_model.decoder.mid.block_{idx}.norm2",
                conv_1=f"first_stage_model.decoder.mid.block_{idx}.conv1",
                conv_2=f"first_stage_model.decoder.mid.block_{idx}.conv2",
            )
            for idx in (1, 2)
        ],
        attention_block_tensor_names=[
            stable_diffusion_loader.AttentionBlockTensorNames(
                norm="first_stage_model.decoder.mid.attn_1.norm",
                q_proj="first_stage_model.decoder.mid.attn_1.q",
                k_proj="first_stage_model.decoder.mid.attn_1.k",
                v_proj="first_stage_model.decoder.mid.attn_1.v",
                output_proj="first_stage_model.decoder.mid.attn_1.proj_out",
            )
        ],
    ),
    up_decoder_blocks_tensor_names=[
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            # Three resnets per level (matches layers_per_block=3 in
            # get_model_config).
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
                    norm_1=f"first_stage_model.decoder.up.{level}.block.{blk}.norm1",
                    norm_2=f"first_stage_model.decoder.up.{level}.block.{blk}.norm2",
                    conv_1=f"first_stage_model.decoder.up.{level}.block.{blk}.conv1",
                    conv_2=f"first_stage_model.decoder.up.{level}.block.{blk}.conv2",
                    # With get_model_config's channels, only the first resnet
                    # of levels 1 (512->256) and 0 (256->128) changes the
                    # channel count, so only those carry a shortcut tensor.
                    **(
                        {
                            "residual_layer": (
                                f"first_stage_model.decoder.up.{level}"
                                f".block.{blk}.nin_shortcut"
                            )
                        }
                        if blk == 0 and level in (1, 0)
                        else {}
                    ),
                )
                for blk in range(3)
            ],
            # Level 0 is the last up block; it is built with
            # add_upsample=False and therefore has no upsample conv.
            **(
                {
                    "upsample_conv": (
                        f"first_stage_model.decoder.up.{level}.upsample.conv"
                    )
                }
                if level != 0
                else {}
            ),
        )
        for level in (3, 2, 1, 0)
    ],
    final_norm="first_stage_model.decoder.norm_out",
    conv_out="first_stage_model.decoder.conv_out",
)
154
+
155
+
156
class Decoder(nn.Module):
  """The Decoder model used in Stable Diffusion.

  This is the decoder half of the latent-diffusion autoencoder: it maps a
  latent tensor back to image space. For details, see
  https://arxiv.org/abs/2112.10752

  The input latents are first divided by `config.scaling_factor`.

  Structure of the Decoder:

        latents tensor
              |
              ▼
    ┌───────────────────┐
    │  Post Quant Conv  │
    └─────────┬─────────┘
              │
    ┌─────────▼─────────┐
    │       ConvIn      │
    └─────────┬─────────┘
              │
    ┌─────────▼─────────┐
    │     MidBlock2D    │
    └─────────┬─────────┘
              │
    ┌─────────▼─────────┐
    │    UpDecoder2D    │ x len(config.block_out_channels)
    └─────────┬─────────┘
              │
    ┌─────────▼─────────┐
    │     FinalNorm     │
    └─────────┬─────────┘
              │
    ┌─────────▼─────────┐
    │     Activation    │
    └─────────┬─────────┘
              │
    ┌─────────▼─────────┐
    │      ConvOut      │
    └─────────┬─────────┘
              │
              ▼
         Output Image
  """

  def __init__(self, config: unet_cfg.AutoEncoderConfig):
    """Builds the decoder layers from an AutoEncoderConfig.

    Args:
      config: Autoencoder configuration. Reads `latent_channels`,
        `block_out_channels`, `mid_block_config`, `normalization_config`,
        `activation_config`, `layers_per_block`, `out_channels` and (in
        `forward`) `scaling_factor`.
    """
    super().__init__()
    self.config = config
    # 1x1 conv applied to the latents; keeps the latent channel count.
    self.post_quant_conv = nn.Conv2d(
        config.latent_channels,
        config.latent_channels,
        kernel_size=1,
        stride=1,
        padding=0,
    )
    # The decoder walks block_out_channels in reverse: it starts at the last
    # (widest, in the shipped configs) level and ends at the first one.
    reversed_block_out_channels = list(reversed(config.block_out_channels))
    self.conv_in = nn.Conv2d(
        config.latent_channels,
        reversed_block_out_channels[0],
        kernel_size=3,
        stride=1,
        padding=1,
    )
    self.mid_block = blocks_2d.MidBlock2D(config.mid_block_config)
    up_decoder_blocks = []
    block_out_channels = reversed_block_out_channels[0]
    for i, out_channels in enumerate(reversed_block_out_channels):
      # Each block consumes the previous block's output channel count.
      prev_output_channel = block_out_channels
      block_out_channels = out_channels
      # All blocks except the last one get a nearest-neighbour upsampler
      # (scale_factor=2) followed by a conv.
      not_final_block = i < len(reversed_block_out_channels) - 1
      up_decoder_blocks.append(
          blocks_2d.UpDecoderBlock2D(
              unet_cfg.UpDecoderBlock2DConfig(
                  in_channels=prev_output_channel,
                  out_channels=block_out_channels,
                  normalization_config=config.normalization_config,
                  activation_config=config.activation_config,
                  num_layers=config.layers_per_block,
                  add_upsample=not_final_block,
                  upsample_conv=True,
                  sampling_config=unet_cfg.UpSamplingConfig(
                      mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
                  ),
              )
          )
      )
    self.up_decoder_blocks = nn.ModuleList(up_decoder_blocks)
    # After the loop, block_out_channels is the output channel count of the
    # last up block, i.e. reversed_block_out_channels[-1].
    self.final_norm = layers_builder.build_norm(
        block_out_channels, config.normalization_config
    )
    self.act_fn = layers_builder.get_activation(config.activation_config)
    self.conv_out = nn.Conv2d(
        block_out_channels,
        config.out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )

  def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
    """Forward function of decoder model.

    Args:
      latents_tensor (torch.Tensor): latent-space tensor. It is divided by
        `config.scaling_factor` before being decoded.

    Returns:
      The decoded image tensor produced by the decoder model.
    """
    # Rescale the latents by 1 / scaling_factor before decoding.
    x = latents_tensor / self.config.scaling_factor
    x = self.post_quant_conv(x)
    x = self.conv_in(x)
    x = self.mid_block(x)
    for up_decoder_block in self.up_decoder_blocks:
      x = up_decoder_block(x)
    x = self.final_norm(x)
    x = self.act_fn(x)
    x = self.conv_out(x)
    return x
271
+
272
+
273
def get_model_config(device_type: str = "cpu") -> unet_cfg.AutoEncoderConfig:
  """Get configs for the Decoder of Stable Diffusion v1.5.

  Args:
    device_type: Target backend. StableHLO composite ops (HLFB) are enabled
      for the group norms and the attention block only when this is exactly
      "gpu".

  Returns:
    The AutoEncoderConfig used to build `Decoder` for Stable Diffusion v1.5.
  """
  in_channels = 3
  latent_channels = 4
  out_channels = 3
  block_out_channels = [128, 256, 512, 512]
  # Decoder.forward divides the incoming latents by this value.
  scaling_factor = 0.18215
  layers_per_block = 3

  # For now, only turns on StableHLO composite ops on GPU backend for better
  # performance. CPU should also switch to it once the support is done.
  enable_hlfb = device_type == "gpu"

  norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      group_num=32,
      enable_hlfb=enable_hlfb,
  )

  # Single-head attention over the full channel width of the last level,
  # without KV cache and without rotary embeddings.
  att_config = unet_cfg.AttentionBlock2DConfig(
      dim=block_out_channels[-1],
      normalization_config=norm_config,
      attention_config=layers_cfg.AttentionConfig(
          num_heads=1,
          head_dim=block_out_channels[-1],
          num_query_groups=1,
          qkv_use_bias=True,
          output_proj_use_bias=True,
          enable_kv_cache=False,
          qkv_transpose_before_split=True,
          qkv_fused_interleaved=False,
          rotary_base=0,
          rotary_percentage=0.0,
      ),
      enable_hlfb=enable_hlfb,
  )

  # The mid block runs at the channel width Decoder.conv_in produces, which
  # is block_out_channels[-1].
  mid_block_config = unet_cfg.MidBlock2DConfig(
      in_channels=block_out_channels[-1],
      normalization_config=norm_config,
      activation_config=layers_cfg.ActivationConfig(
          layers_cfg.ActivationType.SILU
      ),
      num_layers=1,
      attention_block_config=att_config,
  )

  config = unet_cfg.AutoEncoderConfig(
      in_channels=in_channels,
      latent_channels=latent_channels,
      out_channels=out_channels,
      activation_config=layers_cfg.ActivationConfig(
          layers_cfg.ActivationType.SILU
      ),
      block_out_channels=block_out_channels,
      scaling_factor=scaling_factor,
      layers_per_block=layers_per_block,
      normalization_config=norm_config,
      mid_block_config=mid_block_config,
  )
  return config
334
+
335
+
336
def get_fake_model_config(
    device_type: str = "cpu",
) -> unet_cfg.AutoEncoderConfig:
  """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing.

  The structure mirrors the real v1.5 decoder config, but with tiny channel
  counts ([2, 4]), 2 layers per block and 2 norm groups so models build fast.

  Args:
    device_type: Target backend. StableHLO composite ops (HLFB) are enabled
      for the group norms and the attention block only when this is exactly
      "gpu".

  Returns:
    A small AutoEncoderConfig suitable for building `Decoder` in tests.
  """
  in_channels = 3
  latent_channels = 4
  out_channels = 3
  block_out_channels = [2, 4]
  # Decoder.forward divides the incoming latents by this value.
  scaling_factor = 0.18215
  layers_per_block = 2

  # For now, only turns on StableHLO composite ops on GPU backend for better
  # performance. CPU should also switch to it once the support is done.
  enable_hlfb = device_type == "gpu"

  norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      group_num=2,
      enable_hlfb=enable_hlfb,
  )

  # Single-head attention over the full channel width of the last level,
  # without KV cache and without rotary embeddings.
  att_config = unet_cfg.AttentionBlock2DConfig(
      dim=block_out_channels[-1],
      normalization_config=norm_config,
      attention_config=layers_cfg.AttentionConfig(
          num_heads=1,
          head_dim=block_out_channels[-1],
          num_query_groups=1,
          qkv_use_bias=True,
          output_proj_use_bias=True,
          enable_kv_cache=False,
          qkv_transpose_before_split=True,
          qkv_fused_interleaved=False,
          rotary_base=0,
          rotary_percentage=0.0,
      ),
      enable_hlfb=enable_hlfb,
  )

  # The mid block runs at the channel width Decoder.conv_in produces, which
  # is block_out_channels[-1].
  mid_block_config = unet_cfg.MidBlock2DConfig(
      in_channels=block_out_channels[-1],
      normalization_config=norm_config,
      activation_config=layers_cfg.ActivationConfig(
          layers_cfg.ActivationType.SILU
      ),
      num_layers=1,
      attention_block_config=att_config,
  )

  config = unet_cfg.AutoEncoderConfig(
      in_channels=in_channels,
      latent_channels=latent_channels,
      out_channels=out_channels,
      activation_config=layers_cfg.ActivationConfig(
          layers_cfg.ActivationType.SILU
      ),
      block_out_channels=block_out_channels,
      scaling_factor=scaling_factor,
      layers_per_block=layers_per_block,
      normalization_config=norm_config,
      mid_block_config=mid_block_config,
  )
  return config