ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (213) hide show
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
@@ -0,0 +1,749 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import ai_edge_torch.generative.layers.builder as layers_builder
17
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
18
+ from ai_edge_torch.generative.layers.unet import blocks_2d
19
+ import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
20
+ from ai_edge_torch.generative.utilities import stable_diffusion_loader
21
+ import torch
22
+ from torch import nn
23
+
24
def _residual_block_tensor_names(prefix, **kwargs):
  """Builds the ResidualBlockTensorNames for the residual block at `prefix`.

  Args:
    prefix: Checkpoint name of the residual block, e.g.
      "model.diffusion_model.input_blocks.1.0".
    **kwargs: Extra fields forwarded verbatim to ResidualBlockTensorNames
      (used for `residual_layer`). Fields that are not passed are left to
      the constructor's own defaults, exactly as if they were omitted.

  Returns:
    A stable_diffusion_loader.ResidualBlockTensorNames instance.
  """
  return stable_diffusion_loader.ResidualBlockTensorNames(
      norm_1=f"{prefix}.in_layers.0",
      conv_1=f"{prefix}.in_layers.2",
      norm_2=f"{prefix}.out_layers.0",
      conv_2=f"{prefix}.out_layers.3",
      time_embedding=f"{prefix}.emb_layers.1",
      **kwargs,
  )


def _transformer_block_tensor_names(prefix):
  """Builds the TransformerBlockTensorNames for the transformer at `prefix`.

  The same naming scheme is shared by the input, middle and output blocks.
  Only the checkpoint prefix differs between them.

  Args:
    prefix: Checkpoint name of the transformer block, e.g.
      "model.diffusion_model.middle_block.1".

  Returns:
    A stable_diffusion_loader.TransformerBlockTensorNames instance.
  """
  # Each transformer holds a single inner block: transformer_blocks.0.
  inner = f"{prefix}.transformer_blocks.0"
  return stable_diffusion_loader.TransformerBlockTensorNames(
      pre_conv_norm=f"{prefix}.norm",
      conv_in=f"{prefix}.proj_in",
      conv_out=f"{prefix}.proj_out",
      self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
          norm=f"{inner}.norm1",
          q_proj=f"{inner}.attn1.to_q",
          k_proj=f"{inner}.attn1.to_k",
          v_proj=f"{inner}.attn1.to_v",
          output_proj=f"{inner}.attn1.to_out.0",
      ),
      cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
          norm=f"{inner}.norm2",
          q_proj=f"{inner}.attn2.to_q",
          k_proj=f"{inner}.attn2.to_k",
          v_proj=f"{inner}.attn2.to_v",
          output_proj=f"{inner}.attn2.to_out.0",
      ),
      feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
          norm=f"{inner}.norm3",
          ge_glu=f"{inner}.ff.net.0.proj",
          w2=f"{inner}.ff.net.2",
      ),
  )


# input_blocks.0 is conv_in (see TENSOR_NAMES.conv_in). Each of the four
# levels `i` then occupies three consecutive slots:
#   - two residual(+transformer) layers at i*3+1 and i*3+2;
#   - the downsample conv at i*3+3 (absent on the last level).
# Only input_blocks 4 and 7 carry a skip_connection (presumably where the
# channel count changes -- confirm against the checkpoint).
_down_encoder_blocks_tensor_names = [
    stable_diffusion_loader.DownEncoderBlockTensorNames(
        residual_block_tensor_names=[
            _residual_block_tensor_names(
                f"model.diffusion_model.input_blocks.{i*3+j+1}.0",
                residual_layer=(
                    f"model.diffusion_model.input_blocks.{i*3+j+1}.0.skip_connection"
                    if (i * 3 + j + 1) in [4, 7]
                    else None
                ),
            )
            for j in range(2)
        ],
        transformer_block_tensor_names=(
            [
                _transformer_block_tensor_names(
                    f"model.diffusion_model.input_blocks.{i*3+j+1}.1"
                )
                for j in range(2)
            ]
            if i < 3
            else None
        ),
        downsample_conv=(
            f"model.diffusion_model.input_blocks.{i*3+3}.0.op"
            if i < 3
            else None
        ),
    )
    for i in range(4)
]

# middle_block: residual blocks are submodules 0 and 2, and the transformer
# is submodule 1. `residual_layer` is intentionally not passed here.
_mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
    residual_block_tensor_names=[
        _residual_block_tensor_names(f"model.diffusion_model.middle_block.{i}")
        for i in [0, 2]
    ],
    transformer_block_tensor_names=[
        _transformer_block_tensor_names(
            f"model.diffusion_model.middle_block.{i}"
        )
        for i in [1]
    ],
)

# output_blocks: each of the four levels `i` occupies slots i*3 .. i*3+2, and
# every residual layer has a skip_connection. Level 0 has no transformer, so
# its upsample conv is submodule .1 of slot 2. Levels 1 and 2 keep it in
# submodule .2 of their last slot. The last level does not upsample.
_up_decoder_blocks_tensor_names = [
    stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
        residual_block_tensor_names=[
            _residual_block_tensor_names(
                f"model.diffusion_model.output_blocks.{i*3+j}.0",
                residual_layer=f"model.diffusion_model.output_blocks.{i*3+j}.0.skip_connection",
            )
            for j in range(3)
        ],
        transformer_block_tensor_names=(
            [
                _transformer_block_tensor_names(
                    f"model.diffusion_model.output_blocks.{i*3+j}.1"
                )
                for j in range(3)
            ]
            if i > 0
            else None
        ),
        upsample_conv=(
            f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
            if 0 < i < 3
            else (
                "model.diffusion_model.output_blocks.2.1.conv"
                if i == 0
                else None
            )
        ),
    )
    for i in range(4)
]

TENSOR_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
    time_embedding=stable_diffusion_loader.TimeEmbeddingTensorNames(
        w1="model.diffusion_model.time_embed.0",
        w2="model.diffusion_model.time_embed.2",
    ),
    conv_in="model.diffusion_model.input_blocks.0.0",
    conv_out="model.diffusion_model.out.2",
    final_norm="model.diffusion_model.out.0",
    down_encoder_blocks_tensor_names=_down_encoder_blocks_tensor_names,
    mid_block_tensor_names=_mid_block_tensor_names,
    up_decoder_blocks_tensor_names=_up_decoder_blocks_tensor_names,
)
196
+
197
+
198
def build_attention_config(
    num_heads,
    dim,
    num_query_groups,
    rotary_base=0,
    rotary_percentage=0.0,
    qkv_transpose_before_split=True,
    qkv_use_bias=False,
    output_proj_use_bias=True,
    enable_kv_cache=False,
    qkv_fused_interleaved=False,
):
  """Creates a layers_cfg.AttentionConfig for a UNet attention layer.

  Every argument except `dim` is forwarded unchanged to AttentionConfig.
  `dim` is only used to derive the per-head size.

  Args:
    num_heads: Number of attention heads.
    dim: Total model dimension; head_dim is computed as `dim // num_heads`.
      Floor division is used and divisibility is not checked here.
    num_query_groups: Number of query groups forwarded to AttentionConfig.
    rotary_base: Forwarded as AttentionConfig.rotary_base (default 0).
    rotary_percentage: Forwarded as AttentionConfig.rotary_percentage
      (default 0.0).
    qkv_transpose_before_split: Forwarded unchanged (default True).
    qkv_use_bias: Forwarded unchanged (default False).
    output_proj_use_bias: Forwarded unchanged (default True).
    enable_kv_cache: Forwarded unchanged (default False).
    qkv_fused_interleaved: Forwarded unchanged (default False).

  Returns:
    The constructed layers_cfg.AttentionConfig.
  """
  head_dim = dim // num_heads
  return layers_cfg.AttentionConfig(
      num_heads=num_heads,
      head_dim=head_dim,
      num_query_groups=num_query_groups,
      rotary_base=rotary_base,
      rotary_percentage=rotary_percentage,
      qkv_transpose_before_split=qkv_transpose_before_split,
      qkv_use_bias=qkv_use_bias,
      output_proj_use_bias=output_proj_use_bias,
      enable_kv_cache=enable_kv_cache,
      qkv_fused_interleaved=qkv_fused_interleaved,
  )
223
+
224
+
225
class TimeEmbedding(nn.Module):
  """Two-layer MLP applied to the time embedding: w2(act(w1(x))).

  The activation is built by layers_builder.get_activation from
  ActivationType.SILU.
  """

  def __init__(self, in_dim, out_dim):
    super().__init__()
    # Attribute names w1/w2 line up with TimeEmbeddingTensorNames(w1=..., w2=...)
    # in TENSOR_NAMES. Presumably the loader matches them by name, so keep
    # them stable.
    self.w1 = nn.Linear(in_dim, out_dim)
    self.w2 = nn.Linear(out_dim, out_dim)
    silu_config = layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU)
    self.act = layers_builder.get_activation(silu_config)

  def forward(self, x: torch.Tensor):
    hidden = self.w1(x)
    hidden = self.act(hidden)
    return self.w2(hidden)
237
+
238
+
239
+ class Diffusion(nn.Module):
240
+ """The Diffusion model used in Stable Diffusion.
241
+
242
+ For details, see https://arxiv.org/abs/2103.00020
243
+
244
+ Sturcture of the Diffusion model:
245
+
246
+ latents text context time embed
247
+ │ │ │
248
+ │ │ │
249
+ ┌─────────▼─────────┐ │ ┌─────────▼─────────┐
250
+ │ ConvIn │ │ │ Time Embedding │
251
+ └─────────┬─────────┘ │ └─────────┬─────────┘
252
+ │ │ │
253
+ ┌─────────▼─────────┐ │ │
254
+ ┌──────┤ DownEncoder2D │ ◄─────┼────────────┤
255
+ │ └─────────┬─────────┘ x 4 │ │
256
+ │ │ │ │
257
+ │ ┌─────────▼─────────┐ │ │
258
+ skip connection │ MidBlock2D │ ◄─────┼────────────┤
259
+ │ └─────────┬─────────┘ │ │
260
+ │ │ │ │
261
+ │ ┌─────────▼─────────┐ │ │
262
+ └──────► SkipUpDecoder2D │ ◄─────┴────────────┘
263
+ └─────────┬─────────┘ x 4
264
+
265
+ ┌─────────▼─────────┐
266
+ │ FinalNorm │
267
+ └─────────┬─────────┘
268
+
269
+ ┌─────────▼─────────┐
270
+ │ Activation │
271
+ └─────────┬─────────┘
272
+
273
+ ┌─────────▼─────────┐
274
+ │ ConvOut │
275
+ └─────────┬─────────┘
276
+
277
+
278
+ output image
279
+ """
280
+
281
+ def __init__(self, config: unet_cfg.DiffusionModelConfig):
282
+ super().__init__()
283
+
284
+ self.config = config
285
+ block_out_channels = config.block_out_channels
286
+ reversed_block_out_channels = list(reversed(block_out_channels))
287
+
288
+ time_embedding_blocks_dim = config.time_embedding_blocks_dim
289
+ self.time_embedding = TimeEmbedding(
290
+ config.time_embedding_dim, config.time_embedding_blocks_dim
291
+ )
292
+
293
+ self.conv_in = nn.Conv2d(
294
+ config.in_channels, block_out_channels[0], kernel_size=3, padding=1
295
+ )
296
+
297
+ # Down encoders.
298
+ down_encoders = []
299
+ output_channel = block_out_channels[0]
300
+ for i, block_out_channel in enumerate(block_out_channels):
301
+ input_channel = output_channel
302
+ output_channel = block_out_channel
303
+ not_final_block = i < len(block_out_channels) - 1
304
+ if not_final_block:
305
+ down_encoders.append(
306
+ blocks_2d.DownEncoderBlock2D(
307
+ unet_cfg.DownEncoderBlock2DConfig(
308
+ in_channels=input_channel,
309
+ out_channels=output_channel,
310
+ normalization_config=config.residual_norm_config,
311
+ activation_config=layers_cfg.ActivationConfig(
312
+ config.residual_activation_type
313
+ ),
314
+ num_layers=config.layers_per_block,
315
+ padding=config.downsample_padding,
316
+ time_embedding_channels=time_embedding_blocks_dim,
317
+ add_downsample=True,
318
+ sampling_config=unet_cfg.DownSamplingConfig(
319
+ mode=unet_cfg.SamplingType.CONVOLUTION,
320
+ in_channels=output_channel,
321
+ out_channels=output_channel,
322
+ kernel_size=3,
323
+ stride=2,
324
+ padding=config.downsample_padding,
325
+ ),
326
+ transformer_block_config=unet_cfg.TransformerBlock2DConfig(
327
+ attention_block_config=unet_cfg.AttentionBlock2DConfig(
328
+ dim=output_channel,
329
+ attention_batch_size=config.transformer_batch_size,
330
+ normalization_config=config.transformer_norm_config,
331
+ attention_config=build_attention_config(
332
+ num_heads=config.transformer_num_attention_heads,
333
+ dim=output_channel,
334
+ num_query_groups=config.transformer_num_attention_heads,
335
+ ),
336
+ enable_hlfb=config.enable_hlfb,
337
+ ),
338
+ cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
339
+ query_dim=output_channel,
340
+ cross_dim=config.transformer_cross_attention_dim,
341
+ hidden_dim=output_channel,
342
+ output_dim=output_channel,
343
+ attention_batch_size=config.transformer_batch_size,
344
+ normalization_config=config.transformer_norm_config,
345
+ attention_config=build_attention_config(
346
+ num_heads=config.transformer_num_attention_heads,
347
+ dim=output_channel,
348
+ num_query_groups=config.transformer_num_attention_heads,
349
+ ),
350
+ enable_hlfb=config.enable_hlfb,
351
+ ),
352
+ pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
353
+ feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
354
+ dim=output_channel,
355
+ hidden_dim=output_channel * 4,
356
+ normalization_config=config.transformer_norm_config,
357
+ activation_config=layers_cfg.ActivationConfig(
358
+ type=config.transformer_ff_activation_type,
359
+ dim_in=output_channel,
360
+ dim_out=output_channel * 4,
361
+ ),
362
+ use_bias=True,
363
+ ),
364
+ ),
365
+ )
366
+ )
367
+ )
368
+ else:
369
+ down_encoders.append(
370
+ blocks_2d.DownEncoderBlock2D(
371
+ unet_cfg.DownEncoderBlock2DConfig(
372
+ in_channels=input_channel,
373
+ out_channels=output_channel,
374
+ normalization_config=config.residual_norm_config,
375
+ activation_config=layers_cfg.ActivationConfig(
376
+ config.residual_activation_type
377
+ ),
378
+ num_layers=config.layers_per_block,
379
+ padding=config.downsample_padding,
380
+ time_embedding_channels=time_embedding_blocks_dim,
381
+ add_downsample=False,
382
+ )
383
+ )
384
+ )
385
+ self.down_encoders = nn.ModuleList(down_encoders)
386
+
387
+ # Mid block.
388
+ mid_block_channels = block_out_channels[-1]
389
+ self.mid_block = blocks_2d.MidBlock2D(
390
+ unet_cfg.MidBlock2DConfig(
391
+ in_channels=block_out_channels[-1],
392
+ normalization_config=config.residual_norm_config,
393
+ activation_config=layers_cfg.ActivationConfig(
394
+ config.residual_activation_type
395
+ ),
396
+ num_layers=config.mid_block_layers,
397
+ time_embedding_channels=config.time_embedding_blocks_dim,
398
+ transformer_block_config=unet_cfg.TransformerBlock2DConfig(
399
+ attention_block_config=unet_cfg.AttentionBlock2DConfig(
400
+ dim=mid_block_channels,
401
+ attention_batch_size=config.transformer_batch_size,
402
+ normalization_config=config.transformer_norm_config,
403
+ attention_config=build_attention_config(
404
+ num_heads=config.transformer_num_attention_heads,
405
+ dim=mid_block_channels,
406
+ num_query_groups=config.transformer_num_attention_heads,
407
+ ),
408
+ enable_hlfb=config.enable_hlfb,
409
+ ),
410
+ cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
411
+ query_dim=mid_block_channels,
412
+ cross_dim=config.transformer_cross_attention_dim,
413
+ hidden_dim=mid_block_channels,
414
+ output_dim=mid_block_channels,
415
+ attention_batch_size=config.transformer_batch_size,
416
+ normalization_config=config.transformer_norm_config,
417
+ attention_config=build_attention_config(
418
+ num_heads=config.transformer_num_attention_heads,
419
+ dim=mid_block_channels,
420
+ num_query_groups=config.transformer_num_attention_heads,
421
+ ),
422
+ enable_hlfb=config.enable_hlfb,
423
+ ),
424
+ pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
425
+ feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
426
+ dim=mid_block_channels,
427
+ hidden_dim=mid_block_channels * 4,
428
+ normalization_config=config.transformer_norm_config,
429
+ activation_config=layers_cfg.ActivationConfig(
430
+ type=config.transformer_ff_activation_type,
431
+ dim_in=mid_block_channels,
432
+ dim_out=mid_block_channels * 4,
433
+ ),
434
+ use_bias=True,
435
+ ),
436
+ ),
437
+ )
438
+ )
439
+
440
+ # Up decoders.
441
+ up_decoders = []
442
+ up_decoder_layers_per_block = config.layers_per_block + 1
443
+ output_channel = reversed_block_out_channels[0]
444
+ for i, block_out_channel in enumerate(reversed_block_out_channels):
445
+ prev_out_channel = output_channel
446
+ output_channel = block_out_channel
447
+ input_channel = reversed_block_out_channels[
448
+ min(i + 1, len(reversed_block_out_channels) - 1)
449
+ ]
450
+ not_final_block = i < len(reversed_block_out_channels) - 1
451
+ not_first_block = i != 0
452
+ if not_first_block:
453
+ up_decoders.append(
454
+ blocks_2d.SkipUpDecoderBlock2D(
455
+ unet_cfg.SkipUpDecoderBlock2DConfig(
456
+ in_channels=input_channel,
457
+ out_channels=output_channel,
458
+ prev_out_channels=prev_out_channel,
459
+ normalization_config=config.residual_norm_config,
460
+ activation_config=layers_cfg.ActivationConfig(
461
+ config.residual_activation_type
462
+ ),
463
+ num_layers=up_decoder_layers_per_block,
464
+ time_embedding_channels=time_embedding_blocks_dim,
465
+ add_upsample=not_final_block,
466
+ upsample_conv=True,
467
+ sampling_config=unet_cfg.UpSamplingConfig(
468
+ mode=unet_cfg.SamplingType.NEAREST,
469
+ scale_factor=2,
470
+ ),
471
+ transformer_block_config=unet_cfg.TransformerBlock2DConfig(
472
+ attention_block_config=unet_cfg.AttentionBlock2DConfig(
473
+ dim=output_channel,
474
+ attention_batch_size=config.transformer_batch_size,
475
+ normalization_config=config.transformer_norm_config,
476
+ attention_config=build_attention_config(
477
+ num_heads=config.transformer_num_attention_heads,
478
+ dim=output_channel,
479
+ num_query_groups=config.transformer_num_attention_heads,
480
+ ),
481
+ enable_hlfb=config.enable_hlfb,
482
+ ),
483
+ cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
484
+ query_dim=output_channel,
485
+ cross_dim=config.transformer_cross_attention_dim,
486
+ hidden_dim=output_channel,
487
+ output_dim=output_channel,
488
+ attention_batch_size=config.transformer_batch_size,
489
+ normalization_config=config.transformer_norm_config,
490
+ attention_config=build_attention_config(
491
+ num_heads=config.transformer_num_attention_heads,
492
+ dim=output_channel,
493
+ num_query_groups=config.transformer_num_attention_heads,
494
+ ),
495
+ enable_hlfb=config.enable_hlfb,
496
+ ),
497
+ pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
498
+ feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
499
+ dim=output_channel,
500
+ hidden_dim=output_channel * 4,
501
+ normalization_config=config.transformer_norm_config,
502
+ activation_config=layers_cfg.ActivationConfig(
503
+ type=config.transformer_ff_activation_type,
504
+ dim_in=output_channel,
505
+ dim_out=output_channel * 4,
506
+ ),
507
+ use_bias=True,
508
+ ),
509
+ ),
510
+ )
511
+ )
512
+ )
513
+ else:
514
+ up_decoders.append(
515
+ blocks_2d.SkipUpDecoderBlock2D(
516
+ unet_cfg.SkipUpDecoderBlock2DConfig(
517
+ in_channels=input_channel,
518
+ out_channels=output_channel,
519
+ prev_out_channels=prev_out_channel,
520
+ normalization_config=config.residual_norm_config,
521
+ activation_config=layers_cfg.ActivationConfig(
522
+ config.residual_activation_type
523
+ ),
524
+ num_layers=up_decoder_layers_per_block,
525
+ time_embedding_channels=time_embedding_blocks_dim,
526
+ add_upsample=not_final_block,
527
+ upsample_conv=True,
528
+ sampling_config=unet_cfg.UpSamplingConfig(
529
+ mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
530
+ ),
531
+ )
532
+ )
533
+ )
534
+ self.up_decoders = nn.ModuleList(up_decoders)
535
+
536
+ self.final_norm = layers_builder.build_norm(
537
+ reversed_block_out_channels[-1], config.final_norm_config
538
+ )
539
+ self.final_act = layers_builder.get_activation(
540
+ layers_cfg.ActivationConfig(config.final_activation_type)
541
+ )
542
+ self.conv_out = nn.Conv2d(
543
+ reversed_block_out_channels[-1],
544
+ config.out_channels,
545
+ kernel_size=3,
546
+ padding=1,
547
+ )
548
+
549
+ @torch.inference_mode
550
+ def forward(
551
+ self, latents: torch.Tensor, context: torch.Tensor, time_emb: torch.Tensor
552
+ ) -> torch.Tensor:
553
+ """Forward function of diffusion model.
554
+
555
+ Args:
556
+ latents (torch.Tensor): latents space tensor.
557
+ context (torch.Tensor): context tensor from CLIP text encoder.
558
+ time_emb (torch.Tensor): the time embedding tensor.
559
+
560
+ Returns:
561
+ output latents from diffusion model.
562
+ """
563
+ time_emb = self.time_embedding(time_emb)
564
+ x = self.conv_in(latents)
565
+ skip_connection_tensors = [x]
566
+ for encoder in self.down_encoders:
567
+ x, hidden_states = encoder(
568
+ x, time_emb, context, output_hidden_states=True
569
+ )
570
+ skip_connection_tensors.extend(hidden_states)
571
+ x = self.mid_block(x, time_emb, context)
572
+ for decoder in self.up_decoders:
573
+ encoder_tensors = [
574
+ skip_connection_tensors.pop()
575
+ for i in range(self.config.layers_per_block + 1)
576
+ ]
577
+ x = decoder(x, encoder_tensors, time_emb, context)
578
+ x = self.final_norm(x)
579
+ x = self.final_act(x)
580
+ x = self.conv_out(x)
581
+ return x
582
+
583
+
584
def get_model_config(
    batch_size: int, device_type: str = "cpu"
) -> unet_cfg.DiffusionModelConfig:
  """Get configs for the Diffusion model of Stable Diffusion v1.5.

  Args:
    batch_size (int): the batch size of input. It is also used as the
      attention batch size of the transformer blocks.
    device_type (str): the device type of the model. Defaults to "cpu". Only
      the exact string "gpu" enables StableHLO composite ops (HLFB) in the
      normalization configs and the returned config.

  Returns:
    The configuration of diffusion model of Stable Diffusion v1.5.
  """
  in_channels = 4
  out_channels = 4
  block_out_channels = [320, 640, 1280, 1280]
  layers_per_block = 2
  downsample_padding = 1

  # For now, only turns on StableHLO composite ops on GPU backend for better
  # performance. CPU should also switch to it once the support is done.
  enable_hlfb = device_type == "gpu"

  # Residual configs.
  residual_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      group_num=32,
      enable_hlfb=enable_hlfb,
  )
  residual_activation_type = layers_cfg.ActivationType.SILU

  # Transformer configs.
  transformer_num_attention_heads = 8
  transformer_batch_size = batch_size
  transformer_cross_attention_dim = 768  # Embedding from CLIP model
  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      epsilon=1e-6,
      group_num=32,
      enable_hlfb=enable_hlfb,
  )
  transformer_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.LAYER_NORM,
      enable_hlfb=enable_hlfb,
  )
  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU

  # Time embedding configs.
  time_embedding_dim = 320
  time_embedding_blocks_dim = 1280

  # Mid block configs.
  mid_block_layers = 1

  # Final layer configs.
  final_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      group_num=32,
      enable_hlfb=enable_hlfb,
  )
  final_activation_type = layers_cfg.ActivationType.SILU

  return unet_cfg.DiffusionModelConfig(
      in_channels=in_channels,
      out_channels=out_channels,
      block_out_channels=block_out_channels,
      layers_per_block=layers_per_block,
      downsample_padding=downsample_padding,
      residual_norm_config=residual_norm_config,
      residual_activation_type=residual_activation_type,
      transformer_batch_size=transformer_batch_size,
      transformer_num_attention_heads=transformer_num_attention_heads,
      transformer_cross_attention_dim=transformer_cross_attention_dim,
      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
      transformer_norm_config=transformer_norm_config,
      transformer_ff_activation_type=transformer_ff_activation_type,
      mid_block_layers=mid_block_layers,
      time_embedding_dim=time_embedding_dim,
      time_embedding_blocks_dim=time_embedding_blocks_dim,
      final_norm_config=final_norm_config,
      final_activation_type=final_activation_type,
      enable_hlfb=enable_hlfb,
  )
666
+
667
+
668
def get_fake_model_config(
    batch_size: int, device_type: str = "cpu"
) -> unet_cfg.DiffusionModelConfig:
  """Get fake (tiny) configs for the SD v1.5 Diffusion model, for testing.

  The returned config sets the same fields as the real model config but uses
  very small channel, head and embedding sizes so tests stay cheap.

  Args:
    batch_size (int): the batch size of input. It is also used as the
      attention batch size of the transformer blocks.
    device_type (str): the device type of the model. Defaults to "cpu". Only
      the exact string "gpu" enables StableHLO composite ops (HLFB).

  Returns:
    A tiny diffusion model configuration intended for tests, not for real
    Stable Diffusion v1.5 inference.
  """
  in_channels = 4
  out_channels = 4
  block_out_channels = [2, 4, 8, 8]
  layers_per_block = 1
  downsample_padding = 1

  # For now, only turns on StableHLO composite ops on GPU backend for better
  # performance. CPU should also switch to it once the support is done.
  enable_hlfb = device_type == "gpu"

  # Residual configs.
  residual_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      group_num=2,
      enable_hlfb=enable_hlfb,
  )
  residual_activation_type = layers_cfg.ActivationType.SILU

  # Transformer configs.
  transformer_num_attention_heads = 1
  transformer_batch_size = batch_size
  transformer_cross_attention_dim = 4  # Tiny stand-in for the CLIP embedding.
  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      epsilon=1e-6,
      group_num=2,
      enable_hlfb=enable_hlfb,
  )
  transformer_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.LAYER_NORM,
      enable_hlfb=enable_hlfb,
  )
  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU

  # Time embedding configs.
  time_embedding_dim = 2
  time_embedding_blocks_dim = 4

  # Mid block configs.
  mid_block_layers = 1

  # Final layer configs.
  final_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM,
      group_num=2,
      enable_hlfb=enable_hlfb,
  )
  final_activation_type = layers_cfg.ActivationType.SILU

  return unet_cfg.DiffusionModelConfig(
      in_channels=in_channels,
      out_channels=out_channels,
      block_out_channels=block_out_channels,
      layers_per_block=layers_per_block,
      downsample_padding=downsample_padding,
      residual_norm_config=residual_norm_config,
      residual_activation_type=residual_activation_type,
      transformer_batch_size=transformer_batch_size,
      transformer_num_attention_heads=transformer_num_attention_heads,
      transformer_cross_attention_dim=transformer_cross_attention_dim,
      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
      transformer_norm_config=transformer_norm_config,
      transformer_ff_activation_type=transformer_ff_activation_type,
      mid_block_layers=mid_block_layers,
      time_embedding_dim=time_embedding_dim,
      time_embedding_blocks_dim=time_embedding_blocks_dim,
      final_norm_config=final_norm_config,
      final_activation_type=final_activation_type,
      enable_hlfb=enable_hlfb,
  )