ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly might be problematic.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
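
The wheel's public entry point is the ai_edge_torch.convert API (see ai_edge_torch/__init__.py, convert/converter.py and model.py in the list above). A minimal usage sketch, assuming the package's documented convert/export flow; the torchvision ResNet-18 model and output filename are illustrative only and not part of this release:

    import torch
    import torchvision

    import ai_edge_torch

    # Any torch.nn.Module that torch.export can capture should work; ResNet-18 is just an example.
    model = torchvision.models.resnet18().eval()
    sample_args = (torch.randn(1, 3, 224, 224),)

    edge_model = ai_edge_torch.convert(model, sample_args)  # TFLite-backed edge model
    edge_model.export("resnet18.tflite")                    # serialize the flatbuffer to disk
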
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -0,0 +1,573 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+ from torch import nn
+
+ import ai_edge_torch.generative.layers.builder as layers_builder
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
+ import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+ import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+ import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+
+ _down_encoder_blocks_tensor_names = [
+     stable_diffusion_loader.DownEncoderBlockTensorNames(
+         residual_block_tensor_names=[
+             stable_diffusion_loader.ResidualBlockTensorNames(
+                 norm_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.0",
+                 conv_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.2",
+                 norm_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.0",
+                 conv_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.3",
+                 time_embedding=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.emb_layers.1",
+                 residual_layer=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.skip_connection"
+                 if (i * 3 + j + 1) in [4, 7]
+                 else None,
+             )
+             for j in range(2)
+         ],
+         transformer_block_tensor_names=[
+             stable_diffusion_loader.TransformerBlockTensorNames(
+                 pre_conv_norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm",
+                 conv_in=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in",
+                 conv_out=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out",
+                 self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
+                     norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
+                     q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
+                     k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_k",
+                     v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_v",
+                     output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_out.0",
+                 ),
+                 cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
+                     norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm2",
+                     q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_q",
+                     k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_k",
+                     v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_v",
+                     output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_out.0",
+                 ),
+                 feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
+                     norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm3",
+                     ge_glu=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.0.proj",
+                     w2=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.2",
+                 ),
+             )
+             for j in range(2)
+         ]
+         if i < 3
+         else None,
+         downsample_conv=f"model.diffusion_model.input_blocks.{i*3+3}.0.op"
+         if i < 3
+         else None,
+     )
+     for i in range(4)
+ ]
+
+ _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
+     residual_block_tensor_names=[
+         stable_diffusion_loader.ResidualBlockTensorNames(
+             norm_1=f"model.diffusion_model.middle_block.{i}.in_layers.0",
+             conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
+             norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
+             conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
+             time_embedding=f"model.diffusion_model.middle_block.{i}.emb_layers.1",
+         )
+         for i in [0, 2]
+     ],
+     transformer_block_tensor_names=[
+         stable_diffusion_loader.TransformerBlockTensorNames(
+             pre_conv_norm=f"model.diffusion_model.middle_block.{i}.norm",
+             conv_in=f"model.diffusion_model.middle_block.{i}.proj_in",
+             conv_out=f"model.diffusion_model.middle_block.{i}.proj_out",
+             self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
+                 norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm1",
+                 q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_q",
+                 k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_k",
+                 v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_v",
+                 output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_out.0",
+             ),
+             cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
+                 norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm2",
+                 q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_q",
+                 k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_k",
+                 v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_v",
+                 output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_out.0",
+             ),
+             feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
+                 norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm3",
+                 ge_glu=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.0.proj",
+                 w2=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.2",
+             ),
+         )
+         for i in [1]
+     ],
+ )
+
+ _up_decoder_blocks_tensor_names = [
+     stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
+         residual_block_tensor_names=[
+             stable_diffusion_loader.ResidualBlockTensorNames(
+                 norm_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0",
+                 conv_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2",
+                 norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
+                 conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
+                 time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
+                 residual_layer=f"model.diffusion_model.output_blocks.{i*3+j}.0.skip_connection",
+             )
+             for j in range(3)
+         ],
+         transformer_block_tensor_names=[
+             stable_diffusion_loader.TransformerBlockTensorNames(
+                 pre_conv_norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.norm",
+                 conv_in=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in",
+                 conv_out=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out",
+                 self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
+                     norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
+                     q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
+                     k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_k",
+                     v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_v",
+                     output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_out.0",
+                 ),
+                 cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
+                     norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm2",
+                     q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_q",
+                     k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_k",
+                     v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_v",
+                     output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_out.0",
+                 ),
+                 feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
+                     norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm3",
+                     ge_glu=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.0.proj",
+                     w2=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.2",
+                 ),
+             )
+             for j in range(3)
+         ]
+         if i > 0
+         else None,
+         upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
+         if 0 < i < 3
+         else (f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None),
+     )
+     for i in range(4)
+ ]
+
+ TENSOR_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
+     time_embedding=stable_diffusion_loader.TimeEmbeddingTensorNames(
+         w1="model.diffusion_model.time_embed.0",
+         w2="model.diffusion_model.time_embed.2",
+     ),
+     conv_in="model.diffusion_model.input_blocks.0.0",
+     conv_out="model.diffusion_model.out.2",
+     final_norm="model.diffusion_model.out.0",
+     down_encoder_blocks_tensor_names=_down_encoder_blocks_tensor_names,
+     mid_block_tensor_names=_mid_block_tensor_names,
+     up_decoder_blocks_tensor_names=_up_decoder_blocks_tensor_names,
+ )
+
+
+ class TimeEmbedding(nn.Module):
+
+   def __init__(self, in_dim, out_dim):
+     super().__init__()
+     self.w1 = nn.Linear(in_dim, out_dim)
+     self.w2 = nn.Linear(out_dim, out_dim)
+     self.act = layers_builder.get_activation(
+         layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU)
+     )
+
+   def forward(self, x: torch.Tensor):
+     return self.w2(self.act(self.w1(x)))
+
+
+ class Diffusion(nn.Module):
+   """The Diffusion model used in Stable Diffusion.
+
+   For details, see https://arxiv.org/abs/2103.00020
+
+   Structure of the Diffusion model:
+
+                            latents       text context    time embed
+                               │               │              │
+                               │               │              │
+                     ┌─────────▼─────────┐     │    ┌─────────▼─────────┐
+                     │       ConvIn      │     │    │   Time Embedding  │
+                     └─────────┬─────────┘     │    └─────────┬─────────┘
+                               │               │              │
+                     ┌─────────▼─────────┐     │              │
+            ┌──────┤   DownEncoder2D   │ ◄────┼──────────────┤
+            │      └─────────┬─────────┘ x 4  │              │
+            │                │                │              │
+            │      ┌─────────▼─────────┐      │              │
+   skip connection │    MidBlock2D     │ ◄────┼──────────────┤
+            │      └─────────┬─────────┘      │              │
+            │                │                │              │
+            │      ┌─────────▼─────────┐      │              │
+            └──────►  SkipUpDecoder2D  │ ◄────┴──────────────┘
+                   └─────────┬─────────┘ x 4
+
+                   ┌─────────▼─────────┐
+                   │     FinalNorm     │
+                   └─────────┬─────────┘
+
+                   ┌─────────▼─────────┐
+                   │     Activation    │
+                   └─────────┬─────────┘
+
+                   ┌─────────▼─────────┐
+                   │      ConvOut      │
+                   └─────────┬─────────┘
+
+
+                       output image
+   """
+
+   def __init__(self, config: unet_cfg.DiffusionModelConfig):
+     super().__init__()
+
+     self.config = config
+     block_out_channels = config.block_out_channels
+     reversed_block_out_channels = list(reversed(block_out_channels))
+
+     time_embedding_blocks_dim = config.time_embedding_blocks_dim
+     self.time_embedding = TimeEmbedding(
+         config.time_embedding_dim, config.time_embedding_blocks_dim
+     )
+
+     self.conv_in = nn.Conv2d(
+         config.in_channels, block_out_channels[0], kernel_size=3, padding=1
+     )
+
+     attention_config = layers_cfg.AttentionConfig(
+         num_heads=config.transformer_num_attention_heads,
+         num_query_groups=config.transformer_num_attention_heads,
+         rotary_percentage=0.0,
+         qkv_transpose_before_split=True,
+         qkv_use_bias=False,
+         output_proj_use_bias=True,
+         enable_kv_cache=False,
+         qkv_fused_interleaved=False,
+     )
+
+     # Down encoders.
+     down_encoders = []
+     output_channel = block_out_channels[0]
+     for i, block_out_channel in enumerate(block_out_channels):
+       input_channel = output_channel
+       output_channel = block_out_channel
+       not_final_block = i < len(block_out_channels) - 1
+       if not_final_block:
+         down_encoders.append(
+             blocks_2d.DownEncoderBlock2D(
+                 unet_cfg.DownEncoderBlock2DConfig(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     normalization_config=config.residual_norm_config,
+                     activation_config=layers_cfg.ActivationConfig(
+                         config.residual_activation_type
+                     ),
+                     num_layers=config.layers_per_block,
+                     padding=config.downsample_padding,
+                     time_embedding_channels=time_embedding_blocks_dim,
+                     add_downsample=True,
+                     sampling_config=unet_cfg.DownSamplingConfig(
+                         mode=unet_cfg.SamplingType.CONVOLUTION,
+                         in_channels=output_channel,
+                         out_channels=output_channel,
+                         kernel_size=3,
+                         stride=2,
+                         padding=config.downsample_padding,
+                     ),
+                     transformer_block_config=unet_cfg.TransformerBlock2DConfig(
+                         attention_block_config=unet_cfg.AttentionBlock2DConfig(
+                             dim=output_channel,
+                             attention_batch_size=config.transformer_batch_size,
+                             normalization_config=config.transformer_norm_config,
+                             attention_config=attention_config,
+                         ),
+                         cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
+                             query_dim=output_channel,
+                             cross_dim=config.transformer_cross_attention_dim,
+                             attention_batch_size=config.transformer_batch_size,
+                             normalization_config=config.transformer_norm_config,
+                             attention_config=attention_config,
+                         ),
+                         pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                         feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
+                             dim=output_channel,
+                             hidden_dim=output_channel * 4,
+                             normalization_config=config.transformer_norm_config,
+                             activation_config=layers_cfg.ActivationConfig(
+                                 type=config.transformer_ff_activation_type,
+                                 dim_in=output_channel,
+                                 dim_out=output_channel * 4,
+                             ),
+                             use_bias=True,
+                         ),
+                     ),
+                 )
+             )
+         )
+       else:
+         down_encoders.append(
+             blocks_2d.DownEncoderBlock2D(
+                 unet_cfg.DownEncoderBlock2DConfig(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     normalization_config=config.residual_norm_config,
+                     activation_config=layers_cfg.ActivationConfig(
+                         config.residual_activation_type
+                     ),
+                     num_layers=config.layers_per_block,
+                     padding=config.downsample_padding,
+                     time_embedding_channels=time_embedding_blocks_dim,
+                     add_downsample=False,
+                 )
+             )
+         )
+     self.down_encoders = nn.ModuleList(down_encoders)
+
+     # Mid block.
+     mid_block_channels = block_out_channels[-1]
+     self.mid_block = blocks_2d.MidBlock2D(
+         unet_cfg.MidBlock2DConfig(
+             in_channels=block_out_channels[-1],
+             normalization_config=config.residual_norm_config,
+             activation_config=layers_cfg.ActivationConfig(
+                 config.residual_activation_type
+             ),
+             num_layers=config.mid_block_layers,
+             time_embedding_channels=config.time_embedding_blocks_dim,
+             transformer_block_config=unet_cfg.TransformerBlock2DConfig(
+                 attention_block_config=unet_cfg.AttentionBlock2DConfig(
+                     dim=mid_block_channels,
+                     attention_batch_size=config.transformer_batch_size,
+                     normalization_config=config.transformer_norm_config,
+                     attention_config=attention_config,
+                 ),
+                 cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
+                     query_dim=mid_block_channels,
+                     cross_dim=config.transformer_cross_attention_dim,
+                     attention_batch_size=config.transformer_batch_size,
+                     normalization_config=config.transformer_norm_config,
+                     attention_config=attention_config,
+                 ),
+                 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                 feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
+                     dim=mid_block_channels,
+                     hidden_dim=mid_block_channels * 4,
+                     normalization_config=config.transformer_norm_config,
+                     activation_config=layers_cfg.ActivationConfig(
+                         type=config.transformer_ff_activation_type,
+                         dim_in=mid_block_channels,
+                         dim_out=mid_block_channels * 4,
+                     ),
+                     use_bias=True,
+                 ),
+             ),
+         )
+     )
+
+     # Up decoders.
+     up_decoders = []
+     up_decoder_layers_per_block = config.layers_per_block + 1
+     output_channel = reversed_block_out_channels[0]
+     for i, block_out_channel in enumerate(reversed_block_out_channels):
+       prev_out_channel = output_channel
+       output_channel = block_out_channel
+       input_channel = reversed_block_out_channels[
+           min(i + 1, len(reversed_block_out_channels) - 1)
+       ]
+       not_final_block = i < len(reversed_block_out_channels) - 1
+       not_first_block = i != 0
+       if not_first_block:
+         up_decoders.append(
+             blocks_2d.SkipUpDecoderBlock2D(
+                 unet_cfg.SkipUpDecoderBlock2DConfig(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     prev_out_channels=prev_out_channel,
+                     normalization_config=config.residual_norm_config,
+                     activation_config=layers_cfg.ActivationConfig(
+                         config.residual_activation_type
+                     ),
+                     num_layers=up_decoder_layers_per_block,
+                     time_embedding_channels=time_embedding_blocks_dim,
+                     add_upsample=not_final_block,
+                     upsample_conv=True,
+                     sampling_config=unet_cfg.UpSamplingConfig(
+                         mode=unet_cfg.SamplingType.NEAREST,
+                         scale_factor=2,
+                     ),
+                     transformer_block_config=unet_cfg.TransformerBlock2DConfig(
+                         attention_block_config=unet_cfg.AttentionBlock2DConfig(
+                             dim=output_channel,
+                             attention_batch_size=config.transformer_batch_size,
+                             normalization_config=config.transformer_norm_config,
+                             attention_config=attention_config,
+                         ),
+                         cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
+                             query_dim=output_channel,
+                             cross_dim=config.transformer_cross_attention_dim,
+                             attention_batch_size=config.transformer_batch_size,
+                             normalization_config=config.transformer_norm_config,
+                             attention_config=attention_config,
+                         ),
+                         pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                         feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
+                             dim=output_channel,
+                             hidden_dim=output_channel * 4,
+                             normalization_config=config.transformer_norm_config,
+                             activation_config=layers_cfg.ActivationConfig(
+                                 type=config.transformer_ff_activation_type,
+                                 dim_in=output_channel,
+                                 dim_out=output_channel * 4,
+                             ),
+                             use_bias=True,
+                         ),
+                     ),
+                 )
+             )
+         )
+       else:
+         up_decoders.append(
+             blocks_2d.SkipUpDecoderBlock2D(
+                 unet_cfg.SkipUpDecoderBlock2DConfig(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     prev_out_channels=prev_out_channel,
+                     normalization_config=config.residual_norm_config,
+                     activation_config=layers_cfg.ActivationConfig(
+                         config.residual_activation_type
+                     ),
+                     num_layers=up_decoder_layers_per_block,
+                     time_embedding_channels=time_embedding_blocks_dim,
+                     add_upsample=not_final_block,
+                     upsample_conv=True,
+                     sampling_config=unet_cfg.UpSamplingConfig(
+                         mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
+                     ),
+                 )
+             )
+         )
+     self.up_decoders = nn.ModuleList(up_decoders)
+
+     self.final_norm = layers_builder.build_norm(
+         reversed_block_out_channels[-1], config.final_norm_config
+     )
+     self.final_act = layers_builder.get_activation(
+         layers_cfg.ActivationConfig(config.final_activation_type)
+     )
+     self.conv_out = nn.Conv2d(
+         reversed_block_out_channels[-1], config.out_channels, kernel_size=3, padding=1
+     )
+
+   @torch.inference_mode
+   def forward(
+       self, latents: torch.Tensor, context: torch.Tensor, time_emb: torch.Tensor
+   ) -> torch.Tensor:
+     """Forward function of diffusion model.
+
+     Args:
+       latents (torch.Tensor): latent space tensor.
+       context (torch.Tensor): context tensor from CLIP text encoder.
+       time_emb (torch.Tensor): the time embedding tensor.
+
+     Returns:
+       output latents from diffusion model.
+     """
+     time_emb = self.time_embedding(time_emb)
+     x = self.conv_in(latents)
+     skip_connection_tensors = [x]
+     for encoder in self.down_encoders:
+       x, hidden_states = encoder(x, time_emb, context, output_hidden_states=True)
+       skip_connection_tensors.extend(hidden_states)
+     x = self.mid_block(x, time_emb, context)
+     for decoder in self.up_decoders:
+       encoder_tensors = [
+           skip_connection_tensors.pop() for i in range(self.config.layers_per_block + 1)
+       ]
+       x = decoder(x, encoder_tensors, time_emb, context)
+     x = self.final_norm(x)
+     x = self.final_act(x)
+     x = self.conv_out(x)
+     return x
+
+
+ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+   """Get configs for the Diffusion model of Stable Diffusion v1.5
+
+   Args:
+     batch_size (int): the batch size of input.
+
+   Returns:
+     The configuration of diffusion model of Stable Diffusion v1.5.
+
+   """
+   in_channels = 4
+   out_channels = 4
+   block_out_channels = [320, 640, 1280, 1280]
+   layers_per_block = 2
+   downsample_padding = 1
+
+   # Residual configs.
+   residual_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+   )
+   residual_activation_type = layers_cfg.ActivationType.SILU
+
+   # Transformer configs.
+   transformer_num_attention_heads = 8
+   transformer_batch_size = batch_size
+   transformer_cross_attention_dim = 768  # Embedding from CLIP model
+   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
+   )
+   transformer_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.LAYER_NORM
+   )
+   transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+   # Time embedding configs.
+   time_embedding_dim = 320
+   time_embedding_blocks_dim = 1280
+
+   # Mid block configs.
+   mid_block_layers = 1
+
+   # Final layer configs.
+   final_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+   )
+   final_activation_type = layers_cfg.ActivationType.SILU
+
+   return unet_cfg.DiffusionModelConfig(
+       in_channels=in_channels,
+       out_channels=out_channels,
+       block_out_channels=block_out_channels,
+       layers_per_block=layers_per_block,
+       downsample_padding=downsample_padding,
+       residual_norm_config=residual_norm_config,
+       residual_activation_type=residual_activation_type,
+       transformer_batch_size=transformer_batch_size,
+       transformer_num_attention_heads=transformer_num_attention_heads,
+       transformer_cross_attention_dim=transformer_cross_attention_dim,
+       transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+       transformer_norm_config=transformer_norm_config,
+       transformer_ff_activation_type=transformer_ff_activation_type,
+       mid_block_layers=mid_block_layers,
+       time_embedding_dim=time_embedding_dim,
+       time_embedding_blocks_dim=time_embedding_blocks_dim,
+       final_norm_config=final_norm_config,
+       final_activation_type=final_activation_type,
+   )
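
The module above exposes Diffusion and get_model_config for the Stable Diffusion v1.5 UNet. A rough instantiation sketch; the tensor shapes below are the usual SD v1.5 ones for 512x512 output and are illustrative only, and a real checkpoint would be loaded through stable_diffusion_loader.DiffusionModelLoader using the TENSOR_NAMES defined above:

    import torch

    from ai_edge_torch.generative.examples.stable_diffusion import diffusion

    config = diffusion.get_model_config(batch_size=2)  # 2 = conditional + unconditional branch
    model = diffusion.Diffusion(config)

    latents = torch.randn(2, 4, 64, 64)   # VAE latent at 1/8 the resolution of a 512x512 image
    context = torch.randn(2, 77, 768)     # CLIP text-encoder hidden states
    time_emb = torch.randn(2, 320)        # timestep embedding (config.time_embedding_dim)
    out = model(latents, context, time_emb)  # -> (2, 4, 64, 64)
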
ai_edge_torch/generative/examples/stable_diffusion/encoder.py
@@ -0,0 +1,118 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+
+
+ class AttentionBlock(nn.Module):
+
+   def __init__(self, channels):
+     super().__init__()
+     self.groupnorm = nn.GroupNorm(32, channels)
+     self.attention = SelfAttention(1, channels)
+
+   def forward(self, x):
+     residue = x
+     x = self.groupnorm(x)
+
+     n, c, h, w = x.shape
+     x = x.view((n, c, h * w))
+     x = x.transpose(-1, -2)
+     x = self.attention(x)
+     x = x.transpose(-1, -2)
+     x = x.view((n, c, h, w))
+
+     x += residue
+     return x
+
+
+ class ResidualBlock(nn.Module):
+
+   def __init__(self, in_channels, out_channels):
+     super().__init__()
+     self.groupnorm_1 = nn.GroupNorm(32, in_channels)
+     self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+     self.groupnorm_2 = nn.GroupNorm(32, out_channels)
+     self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+     if in_channels == out_channels:
+       self.residual_layer = nn.Identity()
+     else:
+       self.residual_layer = nn.Conv2d(
+           in_channels, out_channels, kernel_size=1, padding=0
+       )
+
+   def forward(self, x):
+     residue = x
+
+     x = self.groupnorm_1(x)
+     x = F.silu(x)
+     x = self.conv_1(x)
+
+     x = self.groupnorm_2(x)
+     x = F.silu(x)
+     x = self.conv_2(x)
+
+     return x + self.residual_layer(residue)
+
+
+ class Encoder(nn.Sequential):
+
+   def __init__(self):
+     super().__init__(
+         nn.Conv2d(3, 128, kernel_size=3, padding=1),
+         ResidualBlock(128, 128),
+         ResidualBlock(128, 128),
+         nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
+         ResidualBlock(128, 256),
+         ResidualBlock(256, 256),
+         nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
+         ResidualBlock(256, 512),
+         ResidualBlock(512, 512),
+         nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         ResidualBlock(512, 512),
+         AttentionBlock(512),
+         ResidualBlock(512, 512),
+         nn.GroupNorm(32, 512),
+         nn.SiLU(),
+         nn.Conv2d(512, 8, kernel_size=3, padding=1),
+         nn.Conv2d(8, 8, kernel_size=1, padding=0),
+     )
+
+   @torch.inference_mode
+   def forward(self, x, noise):
+     for module in self:
+       if getattr(module, 'stride', None) == (
+           2,
+           2,
+       ):  # Padding at downsampling should be asymmetric (see #8)
+         x = F.pad(x, (0, 1, 0, 1))
+       x = module(x)
+
+     mean, log_variance = torch.chunk(x, 2, dim=1)
+     log_variance = torch.clamp(log_variance, -30, 20)
+     variance = log_variance.exp()
+     stdev = variance.sqrt()
+     x = mean + stdev * noise
+
+     x *= 0.18215
+     return x
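
For reference, a short sketch of how this Encoder maps an image plus sampled noise to a latent; the 512x512 input size is illustrative (three stride-2 convolutions give a 64x64 feature map, and the mean/log-variance split leaves 4 latent channels):

    import torch

    from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder

    encoder = Encoder()
    image = torch.randn(1, 3, 512, 512)  # NCHW RGB input
    noise = torch.randn(1, 4, 64, 64)    # must match the latent (mean) shape
    latent = encoder(image, noise)       # reparameterized sample, scaled by 0.18215
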