ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
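
The package's core workflow is converting a PyTorch nn.Module into a TFLite flatbuffer. As a minimal orientation sketch before the file-level diffs below, assuming the top-level ai_edge_torch.convert API re-exported from ai_edge_torch/__init__.py (the torchvision model and output path are illustrative only):

import torch
import torchvision

import ai_edge_torch

# Trace the module with representative sample inputs and lower it to TFLite.
model = torchvision.models.resnet18(weights=None).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model.export('resnet18.tflite')  # writes a .tflite flatbuffer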
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -0,0 +1,115 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+ from torch import nn
+
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
+ import ai_edge_torch.generative.layers.attention_utils as attention_utils
+ import ai_edge_torch.generative.layers.builder as builder
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
+     ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
+     attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
+     attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
+     attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
+     attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
+     pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
+     pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
+     embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
+     embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+     final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
+     lm_head=None,
+ )
+
+
+ class CLIP(nn.Module):
+     """CLIP text encoder
+     For details, see https://arxiv.org/abs/2103.00020
+     """
+
+     def __init__(self, config: cfg.ModelConfig):
+         super().__init__()
+         self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+         self.tok_embedding_position = nn.Parameter(
+             torch.zeros((config.max_seq_len, config.embedding_dim))
+         )
+
+         self.config = config
+         self.transformer_blocks = nn.ModuleList(
+             TransformerBlock(config) for _ in range(config.num_layers)
+         )
+         self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)
+
+         self.mask_cache = attention_utils.build_causal_mask_cache(
+             size=config.max_seq_len, dtype=torch.float32
+         )
+
+     @torch.inference_mode
+     def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
+         tokens = tokens.type(torch.long)
+
+         state = self.tok_embedding(tokens) + self.tok_embedding_position
+         for layer in self.transformer_blocks:
+             state = layer(state, mask=self.mask_cache)
+         output = self.final_norm(state)
+         return output
+
+
+ def get_model_config() -> cfg.ModelConfig:
+     max_seq_len = 77
+     vocab_size = 49408
+     num_layers = 12
+     num_heads = 12
+     num_query_groups = 12
+     embedding_dim = 768
+
+     attn_config = cfg.AttentionConfig(
+         num_heads=num_heads,
+         num_query_groups=num_query_groups,
+         rotary_percentage=0.0,
+         qkv_use_bias=True,
+         qkv_transpose_before_split=True,
+         qkv_fused_interleaved=False,
+         output_proj_use_bias=True,
+         enable_kv_cache=False,
+     )
+
+     ff_config = cfg.FeedForwardConfig(
+         type=cfg.FeedForwardType.SEQUENTIAL,
+         activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+         intermediate_size=embedding_dim * 4,
+         use_bias=True,
+     )
+
+     norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+     config = cfg.ModelConfig(
+         vocab_size=vocab_size,
+         num_layers=num_layers,
+         max_seq_len=max_seq_len,
+         embedding_dim=embedding_dim,
+         attn_config=attn_config,
+         ff_config=ff_config,
+         pre_attention_norm_config=norm_config,
+         pre_ff_norm_config=norm_config,
+         final_norm_config=norm_config,
+         enable_hlfb=True,
+     )
+
+     return config
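
A minimal sketch of exercising this CLIP text encoder on its own (the random-weight instantiation and shapes below are illustrative; in the conversion script that follows, the weights are loaded from a Stable Diffusion checkpoint via stable_diffusion_loader.ClipModelLoader):

import torch

from ai_edge_torch.generative.examples.stable_diffusion import clip

clip_model = clip.CLIP(clip.get_model_config())  # randomly initialized weights
prompt_tokens = torch.zeros((1, 77), dtype=torch.long)  # max_seq_len = 77

embeddings = clip_model(prompt_tokens)
print(embeddings.shape)  # expected: torch.Size([1, 77, 768]), since embedding_dim = 768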
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -0,0 +1,142 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import argparse
+ import os
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+
+ import ai_edge_torch
+ import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
+ import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
+ import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
+ from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
+ import ai_edge_torch.generative.examples.stable_diffusion.util as util
+ import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+
+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument(
+     '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+ )
+ arg_parser.add_argument(
+     '--diffusion_ckpt',
+     type=str,
+     help='Path to source diffusion model checkpoint',
+     required=True,
+ )
+ arg_parser.add_argument(
+     '--decoder_ckpt',
+     type=str,
+     help='Path to source image decoder model checkpoint',
+     required=True,
+ )
+ arg_parser.add_argument(
+     '--output_dir',
+     type=str,
+     help='Path to the converted TF Lite directory.',
+     required=True,
+ )
+
+
+ @torch.inference_mode
+ def convert_stable_diffusion_to_tflite(
+     output_dir: str,
+     clip_ckpt_path: str,
+     diffusion_ckpt_path: str,
+     decoder_ckpt_path: str,
+     image_height: int = 512,
+     image_width: int = 512,
+ ):
+
+     clip_model = clip.CLIP(clip.get_model_config())
+     loader = stable_diffusion_loader.ClipModelLoader(
+         clip_ckpt_path,
+         clip.TENSOR_NAMES,
+     )
+     loader.load(clip_model, strict=False)
+
+     diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
+     diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
+         diffusion_ckpt_path, diffusion.TENSOR_NAMES
+     )
+     diffusion_loader.load(diffusion_model, strict=False)
+
+     decoder_model = decoder.Decoder(decoder.get_model_config())
+     decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
+         decoder_ckpt_path, decoder.TENSOR_NAMES
+     )
+     decoder_loader.load(decoder_model, strict=False)
+
+     # TODO(yichunk): enable image encoder conversion
+     # if encoder_ckpt_path is not None:
+     #   encoder = Encoder()
+     #   encoder.load_state_dict(torch.load(encoder_ckpt_path))
+
+     # Tensors used to trace the model graph during conversion.
+     n_tokens = 77
+     timestamp = 0
+     len_prompt = 1
+     prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+     input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
+     noise = torch.full(
+         (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
+     )
+
+     input_latents = torch.zeros_like(noise)
+     context_cond = clip_model(prompt_tokens)
+     context_uncond = torch.zeros_like(context_cond)
+     context = torch.cat([context_cond, context_uncond], axis=0)
+     time_embedding = util.get_time_embedding(timestamp)
+
+     if not os.path.exists(output_dir):
+         Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+     # TODO(yichunk): convert to multi signature tflite model.
+     # CLIP text encoder
+     ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
+         f'{output_dir}/clip.tflite'
+     )
+
+     # TODO(yichunk): enable image encoder conversion
+     # Image encoder
+     # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+     #     f'{output_dir}/encoder.tflite'
+     # )
+
+     # Diffusion
+     ai_edge_torch.signature(
+         'diffusion',
+         diffusion_model,
+         (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
+     ).convert().export(f'{output_dir}/diffusion.tflite')
+
+     # Image decoder
+     ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
+         f'{output_dir}/decoder.tflite'
+     )
+
+
+ if __name__ == '__main__':
+     args = arg_parser.parse_args()
+     convert_stable_diffusion_to_tflite(
+         output_dir=args.output_dir,
+         clip_ckpt_path=args.clip_ckpt,
+         diffusion_ckpt_path=args.diffusion_ckpt,
+         decoder_ckpt_path=args.decoder_ckpt,
+         image_height=512,
+         image_width=512,
+     )
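
The script above is driven by the four required flags defined in its argparse block (--clip_ckpt, --diffusion_ckpt, --decoder_ckpt, --output_dir); the same conversion can also be invoked directly from Python, mirroring the __main__ block (the checkpoint and output paths here are placeholders):

from ai_edge_torch.generative.examples.stable_diffusion import convert_to_tflite

convert_to_tflite.convert_stable_diffusion_to_tflite(
    output_dir='/tmp/stable_diffusion_tflite',  # receives clip/diffusion/decoder .tflite files
    clip_ckpt_path='/path/to/clip.ckpt',
    diffusion_ckpt_path='/path/to/diffusion.ckpt',
    decoder_ckpt_path='/path/to/decoder.ckpt',
    image_height=512,
    image_width=512,
)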
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -0,0 +1,317 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import torch
+ from torch import nn
+
+ import ai_edge_torch.generative.layers.builder as layers_builder
+ import ai_edge_torch.generative.layers.model_config as layers_cfg
+ import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+ import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+ import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+
+ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
+     post_quant_conv="first_stage_model.post_quant_conv",
+     conv_in="first_stage_model.decoder.conv_in",
+     mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
+         residual_block_tensor_names=[
+             stable_diffusion_loader.ResidualBlockTensorNames(
+                 norm_1="first_stage_model.decoder.mid.block_1.norm1",
+                 norm_2="first_stage_model.decoder.mid.block_1.norm2",
+                 conv_1="first_stage_model.decoder.mid.block_1.conv1",
+                 conv_2="first_stage_model.decoder.mid.block_1.conv2",
+             ),
+             stable_diffusion_loader.ResidualBlockTensorNames(
+                 norm_1="first_stage_model.decoder.mid.block_2.norm1",
+                 norm_2="first_stage_model.decoder.mid.block_2.norm2",
+                 conv_1="first_stage_model.decoder.mid.block_2.conv1",
+                 conv_2="first_stage_model.decoder.mid.block_2.conv2",
+             ),
+         ],
+         attention_block_tensor_names=[
+             stable_diffusion_loader.AttentionBlockTensorNames(
+                 norm="first_stage_model.decoder.mid.attn_1.norm",
+                 q_proj="first_stage_model.decoder.mid.attn_1.q",
+                 k_proj="first_stage_model.decoder.mid.attn_1.k",
+                 v_proj="first_stage_model.decoder.mid.attn_1.v",
+                 output_proj="first_stage_model.decoder.mid.attn_1.proj_out",
+             )
+         ],
+     ),
+     up_decoder_blocks_tensor_names=[
+         stable_diffusion_loader.UpDecoderBlockTensorNames(
+             residual_block_tensor_names=[
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.3.block.0.norm1",
+                     norm_2="first_stage_model.decoder.up.3.block.0.norm2",
+                     conv_1="first_stage_model.decoder.up.3.block.0.conv1",
+                     conv_2="first_stage_model.decoder.up.3.block.0.conv2",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.3.block.1.norm1",
+                     norm_2="first_stage_model.decoder.up.3.block.1.norm2",
+                     conv_1="first_stage_model.decoder.up.3.block.1.conv1",
+                     conv_2="first_stage_model.decoder.up.3.block.1.conv2",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.3.block.2.norm1",
+                     norm_2="first_stage_model.decoder.up.3.block.2.norm2",
+                     conv_1="first_stage_model.decoder.up.3.block.2.conv1",
+                     conv_2="first_stage_model.decoder.up.3.block.2.conv2",
+                 ),
+             ],
+             upsample_conv="first_stage_model.decoder.up.3.upsample.conv",
+         ),
+         stable_diffusion_loader.UpDecoderBlockTensorNames(
+             residual_block_tensor_names=[
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.2.block.0.norm1",
+                     norm_2="first_stage_model.decoder.up.2.block.0.norm2",
+                     conv_1="first_stage_model.decoder.up.2.block.0.conv1",
+                     conv_2="first_stage_model.decoder.up.2.block.0.conv2",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.2.block.1.norm1",
+                     norm_2="first_stage_model.decoder.up.2.block.1.norm2",
+                     conv_1="first_stage_model.decoder.up.2.block.1.conv1",
+                     conv_2="first_stage_model.decoder.up.2.block.1.conv2",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.2.block.2.norm1",
+                     norm_2="first_stage_model.decoder.up.2.block.2.norm2",
+                     conv_1="first_stage_model.decoder.up.2.block.2.conv1",
+                     conv_2="first_stage_model.decoder.up.2.block.2.conv2",
+                 ),
+             ],
+             upsample_conv="first_stage_model.decoder.up.2.upsample.conv",
+         ),
+         stable_diffusion_loader.UpDecoderBlockTensorNames(
+             residual_block_tensor_names=[
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.1.block.0.norm1",
+                     norm_2="first_stage_model.decoder.up.1.block.0.norm2",
+                     conv_1="first_stage_model.decoder.up.1.block.0.conv1",
+                     conv_2="first_stage_model.decoder.up.1.block.0.conv2",
+                     residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.1.block.1.norm1",
+                     norm_2="first_stage_model.decoder.up.1.block.1.norm2",
+                     conv_1="first_stage_model.decoder.up.1.block.1.conv1",
+                     conv_2="first_stage_model.decoder.up.1.block.1.conv2",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.1.block.2.norm1",
+                     norm_2="first_stage_model.decoder.up.1.block.2.norm2",
+                     conv_1="first_stage_model.decoder.up.1.block.2.conv1",
+                     conv_2="first_stage_model.decoder.up.1.block.2.conv2",
+                 ),
+             ],
+             upsample_conv="first_stage_model.decoder.up.1.upsample.conv",
+         ),
+         stable_diffusion_loader.UpDecoderBlockTensorNames(
+             residual_block_tensor_names=[
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.0.block.0.norm1",
+                     norm_2="first_stage_model.decoder.up.0.block.0.norm2",
+                     conv_1="first_stage_model.decoder.up.0.block.0.conv1",
+                     conv_2="first_stage_model.decoder.up.0.block.0.conv2",
+                     residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.0.block.1.norm1",
+                     norm_2="first_stage_model.decoder.up.0.block.1.norm2",
+                     conv_1="first_stage_model.decoder.up.0.block.1.conv1",
+                     conv_2="first_stage_model.decoder.up.0.block.1.conv2",
+                 ),
+                 stable_diffusion_loader.ResidualBlockTensorNames(
+                     norm_1="first_stage_model.decoder.up.0.block.2.norm1",
+                     norm_2="first_stage_model.decoder.up.0.block.2.norm2",
+                     conv_1="first_stage_model.decoder.up.0.block.2.conv1",
+                     conv_2="first_stage_model.decoder.up.0.block.2.conv2",
+                 ),
+             ],
+         ),
+     ],
+     final_norm="first_stage_model.decoder.norm_out",
+     conv_out="first_stage_model.decoder.conv_out",
+ )
+
+
+ class Decoder(nn.Module):
+     """The Decoder model used in Stable Diffusion.
+
+     For details, see https://arxiv.org/abs/2103.00020
+
+     Structure of the Decoder:
+
+            latents tensor
+                  |
+
+     ┌───────────────────┐
+     │  Post Quant Conv  │
+     └─────────┬─────────┘
+
+     ┌─────────▼─────────┐
+     │      ConvIn       │
+     └─────────┬─────────┘
+
+     ┌─────────▼─────────┐
+     │    MidBlock2D     │
+     └─────────┬─────────┘
+
+     ┌─────────▼─────────┐
+     │    UpDecoder2D    │ x 4
+     └─────────┬─────────┘
+
+     ┌─────────▼─────────┐
+     │     FinalNorm     │
+     └─────────┬─────────┘
+               │
+     ┌─────────▼─────────┐
+     │    Activation     │
+     └─────────┬─────────┘
+               │
+     ┌─────────▼─────────┐
+     │      ConvOut      │
+     └─────────┬─────────┘
+               │
+
+            Output Image
+     """
+
+     def __init__(self, config: unet_cfg.AutoEncoderConfig):
+         super().__init__()
+         self.config = config
+         self.post_quant_conv = nn.Conv2d(
+             config.latent_channels,
+             config.latent_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+         )
+         reversed_block_out_channels = list(reversed(config.block_out_channels))
+         self.conv_in = nn.Conv2d(
+             config.latent_channels,
+             reversed_block_out_channels[0],
+             kernel_size=3,
+             stride=1,
+             padding=1,
+         )
+         self.mid_block = blocks_2d.MidBlock2D(config.mid_block_config)
+         up_decoder_blocks = []
+         block_out_channels = reversed_block_out_channels[0]
+         for i, out_channels in enumerate(reversed_block_out_channels):
+             prev_output_channel = block_out_channels
+             block_out_channels = out_channels
+             not_final_block = i < len(reversed_block_out_channels) - 1
+             up_decoder_blocks.append(
+                 blocks_2d.UpDecoderBlock2D(
+                     unet_cfg.UpDecoderBlock2DConfig(
+                         in_channels=prev_output_channel,
+                         out_channels=block_out_channels,
+                         normalization_config=config.normalization_config,
+                         activation_config=config.activation_config,
+                         num_layers=config.layers_per_block,
+                         add_upsample=not_final_block,
+                         upsample_conv=True,
+                         sampling_config=unet_cfg.UpSamplingConfig(
+                             mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
+                         ),
+                     )
+                 )
+             )
+         self.up_decoder_blocks = nn.ModuleList(up_decoder_blocks)
+         self.final_norm = layers_builder.build_norm(
+             block_out_channels, config.normalization_config
+         )
+         self.act_fn = layers_builder.get_activation(config.activation_config)
+         self.conv_out = nn.Conv2d(
+             block_out_channels,
+             config.out_channels,
+             kernel_size=3,
+             stride=1,
+             padding=1,
+         )
+
+     def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
+         """Forward function of decoder model.
+
+         Args:
+             latents_tensor (torch.Tensor): latent space tensor.
+
+         Returns:
+             output decoded image tensor from decoder model.
+         """
+         x = latents_tensor / self.config.scaling_factor
+         x = self.post_quant_conv(x)
+         x = self.conv_in(x)
+         x = self.mid_block(x)
+         for up_decoder_block in self.up_decoder_blocks:
+             x = up_decoder_block(x)
+         x = self.final_norm(x)
+         x = self.act_fn(x)
+         x = self.conv_out(x)
+         return x
+
+
+ def get_model_config() -> unet_cfg.AutoEncoderConfig:
+     """Get configs for the Decoder of Stable Diffusion v1.5."""
+     in_channels = 3
+     latent_channels = 4
+     out_channels = 3
+     block_out_channels = [128, 256, 512, 512]
+     scaling_factor = 0.18215
+     layers_per_block = 3
+
+     norm_config = layers_cfg.NormalizationConfig(
+         layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+     )
+
+     att_config = unet_cfg.AttentionBlock2DConfig(
+         dim=block_out_channels[-1],
+         normalization_config=norm_config,
+         attention_config=layers_cfg.AttentionConfig(
+             num_heads=1,
+             num_query_groups=1,
+             qkv_use_bias=True,
+             output_proj_use_bias=True,
+             enable_kv_cache=False,
+             qkv_transpose_before_split=True,
+             qkv_fused_interleaved=False,
+             rotary_percentage=0.0,
+         ),
+     )
+
+     mid_block_config = unet_cfg.MidBlock2DConfig(
+         in_channels=block_out_channels[-1],
+         normalization_config=norm_config,
+         activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+         num_layers=1,
+         attention_block_config=att_config,
+     )
+
+     config = unet_cfg.AutoEncoderConfig(
+         in_channels=in_channels,
+         latent_channels=latent_channels,
+         out_channels=out_channels,
+         activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+         block_out_channels=block_out_channels,
+         scaling_factor=scaling_factor,
+         layers_per_block=layers_per_block,
+         normalization_config=norm_config,
+         mid_block_config=mid_block_config,
+     )
+     return config
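
As a sanity check on the configuration above, a short sketch of the expected shape flow through the decoder (random latents; the 8x spatial upsampling follows from the three of four up-decoder blocks that add upsampling, which is also why the conversion script sizes its latents at image_height // 8):

import torch

from ai_edge_torch.generative.examples.stable_diffusion import decoder

decoder_model = decoder.Decoder(decoder.get_model_config())  # random weights
latents = torch.randn(1, 4, 64, 64)  # (batch, latent_channels, height // 8, width // 8)

image = decoder_model(latents)
print(image.shape)  # expected: torch.Size([1, 3, 512, 512])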