ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (213) hide show
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1032 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Common utility functions for data loading etc.
16
+ from dataclasses import dataclass
17
+ from typing import Dict, List, Optional, Tuple
18
+
19
+ import ai_edge_torch.generative.layers.model_config as layers_config
20
+ import ai_edge_torch.generative.layers.unet.model_config as unet_config
21
+ import ai_edge_torch.generative.utilities.loader as loader
22
+ import torch
23
+
24
+
25
@dataclass
class ResidualBlockTensorNames:
  """Checkpoint tensor names for the layers of one residual block.

  Each field names the checkpoint entry whose `.weight`/`.bias` tensors are
  mapped by `BaseLoader._map_residual_block`. Fields default to None so a
  mapping only has to name the tensors its checkpoint actually contains
  (e.g. `residual_layer` is read only when in/out channel counts differ).
  """

  norm_1: Optional[str] = None
  conv_1: Optional[str] = None
  norm_2: Optional[str] = None
  conv_2: Optional[str] = None
  residual_layer: Optional[str] = None
  time_embedding: Optional[str] = None
33
+
34
+
35
@dataclass
class AttentionBlockTensorNames:
  """Checkpoint tensor names for one self-attention block.

  A checkpoint provides either `fused_qkv_proj` (a single combined QKV
  projection) or the separate `q_proj`/`k_proj`/`v_proj` names;
  `BaseLoader._map_attention_block` branches on `fused_qkv_proj` being set.
  All fields default to None so unused names can be omitted.
  """

  norm: Optional[str] = None
  fused_qkv_proj: Optional[str] = None
  q_proj: Optional[str] = None
  k_proj: Optional[str] = None
  v_proj: Optional[str] = None
  output_proj: Optional[str] = None
43
+
44
+
45
@dataclass
class CrossAttentionBlockTensorNames:
  """Checkpoint tensor names for one cross-attention block.

  Unlike `AttentionBlockTensorNames` there is no fused-QKV option: the
  query attends to a different source than key/value, so the projections
  are always separate. Fields default to None so unused names can be
  omitted (e.g. `norm` when normalization is disabled in the config).
  """

  norm: Optional[str] = None
  q_proj: Optional[str] = None
  k_proj: Optional[str] = None
  v_proj: Optional[str] = None
  output_proj: Optional[str] = None
52
+
53
+
54
@dataclass
class TimeEmbeddingTensorNames:
  """Checkpoint tensor names for the two linear layers of a time embedding."""

  w1: Optional[str] = None
  w2: Optional[str] = None
58
+
59
+
60
@dataclass
class FeedForwardBlockTensorNames:
  """Checkpoint tensor names for one feed-forward block.

  `BaseLoader._map_feedforward_block` reads `ge_glu` instead of `w1` when
  the block's activation is GE_GLU; `w2` and `norm` are used either way.
  Fields default to None so unused names can be omitted.
  """

  w1: Optional[str] = None
  w2: Optional[str] = None
  norm: Optional[str] = None
  ge_glu: Optional[str] = None
66
+
67
+
68
@dataclass
class TransformerBlockTensorNames:
  """Checkpoint tensor names for one transformer block.

  All fields are required: a transformer block always carries a pre-norm,
  in/out convolutions, self- and cross-attention, and a feed-forward
  sub-block (see `BaseLoader._map_transformer_block`).
  """

  pre_conv_norm: str
  conv_in: str
  self_attention: AttentionBlockTensorNames
  cross_attention: CrossAttentionBlockTensorNames
  feed_forward: FeedForwardBlockTensorNames
  conv_out: str
76
+
77
+
78
@dataclass
class MidBlockTensorNames:
  """Checkpoint tensor names for a UNet mid block.

  `residual_block_tensor_names` must hold `num_layers + 1` entries: the mid
  block maps one resnet up front, then interleaves one resnet after each
  attention/transformer layer (see `BaseLoader._map_mid_block`, which
  indexes `[0]` and `[i + 1]`). The attention and transformer name lists
  are read only when the corresponding config is set.
  """

  residual_block_tensor_names: List[ResidualBlockTensorNames]
  attention_block_tensor_names: Optional[List[AttentionBlockTensorNames]] = None
  transformer_block_tensor_names: Optional[
      List[TransformerBlockTensorNames]
  ] = None
85
+
86
+
87
@dataclass
class DownEncoderBlockTensorNames:
  """Checkpoint tensor names for one down-sampling encoder block.

  `residual_block_tensor_names` needs one entry per layer; transformer
  names are read only when the block config enables transformer layers.
  """

  residual_block_tensor_names: List[ResidualBlockTensorNames]
  transformer_block_tensor_names: Optional[
      List[TransformerBlockTensorNames]
  ] = None
  # Read only when the block downsamples via a convolution
  # (see BaseLoader._map_down_encoder_block).
  downsample_conv: Optional[str] = None
94
+
95
+
96
@dataclass
class UpDecoderBlockTensorNames:
  """Checkpoint tensor names for one up-sampling decoder block.

  `residual_block_tensor_names` needs one entry per layer; transformer
  names are read only when the block config enables transformer layers.
  """

  residual_block_tensor_names: List[ResidualBlockTensorNames]
  transformer_block_tensor_names: Optional[
      List[TransformerBlockTensorNames]
  ] = None
  # Read only when the block upsamples via a convolution
  # (see BaseLoader._map_up_decoder_block).
  upsample_conv: Optional[str] = None
103
+
104
+
105
@dataclass
class SkipUpDecoderBlockTensorNames:
  """Checkpoint tensor names for an up-decoder block with UNet skip inputs.

  Structurally identical to `UpDecoderBlockTensorNames`; presumably kept
  as a distinct type so call sites name the block kind they are mapping.
  """

  residual_block_tensor_names: List[ResidualBlockTensorNames]
  transformer_block_tensor_names: Optional[
      List[TransformerBlockTensorNames]
  ] = None
  upsample_conv: Optional[str] = None
112
+
113
+
114
+ def _map_to_converted_state(
115
+ state: Dict[str, torch.Tensor],
116
+ state_param: str,
117
+ converted_state: Dict[str, torch.Tensor],
118
+ converted_state_param: str,
119
+ squeeze_dims: bool = False,
120
+ ):
121
+ converted_state[f"{converted_state_param}.weight"] = state.pop(
122
+ f"{state_param}.weight"
123
+ )
124
+ if squeeze_dims:
125
+ converted_state[f"{converted_state_param}.weight"] = torch.squeeze(
126
+ converted_state[f"{converted_state_param}.weight"]
127
+ )
128
+ if f"{state_param}.bias" in state:
129
+ converted_state[f"{converted_state_param}.bias"] = state.pop(
130
+ f"{state_param}.bias"
131
+ )
132
+ if squeeze_dims:
133
+ converted_state[f"{converted_state_param}.bias"] = torch.squeeze(
134
+ converted_state[f"{converted_state_param}.bias"]
135
+ )
136
+
137
+
138
class BaseLoader(loader.ModelLoader):
  """Checkpoint mapper for diffusion UNet / autoencoder building blocks.

  Extends `loader.ModelLoader` with `_map_*` helpers that move tensors out
  of a raw checkpoint `state` dict into a `converted_state` dict keyed by
  ai_edge_torch module names, one building block at a time. Every helper
  mutates both dicts in place: mapped entries are popped from `state` and
  inserted into `converted_state`.
  """

  def _map_residual_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      tensor_names: ResidualBlockTensorNames,
      converted_state_param_prefix: str,
      config: unet_config.ResidualBlock2DConfig,
  ):
    """Maps one residual block: two norm+conv pairs plus optional extras."""
    _map_to_converted_state(
        state,
        tensor_names.norm_1,
        converted_state,
        f"{converted_state_param_prefix}.norm_1",
    )
    _map_to_converted_state(
        state,
        tensor_names.conv_1,
        converted_state,
        f"{converted_state_param_prefix}.conv_1",
    )
    _map_to_converted_state(
        state,
        tensor_names.norm_2,
        converted_state,
        f"{converted_state_param_prefix}.norm_2",
    )
    _map_to_converted_state(
        state,
        tensor_names.conv_2,
        converted_state,
        f"{converted_state_param_prefix}.conv_2",
    )
    # A projection on the skip path exists only when the block changes its
    # channel count.
    if config.in_channels != config.out_channels:
      _map_to_converted_state(
          state,
          tensor_names.residual_layer,
          converted_state,
          f"{converted_state_param_prefix}.residual_layer",
      )
    # Time-embedding projection is present only for blocks configured with
    # time_embedding_channels (UNet); note the converted name differs:
    # `time_embedding` -> `time_emb_proj`.
    if config.time_embedding_channels is not None:
      _map_to_converted_state(
          state,
          tensor_names.time_embedding,
          converted_state,
          f"{converted_state_param_prefix}.time_emb_proj",
      )

  def _map_attention_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      tensor_names: AttentionBlockTensorNames,
      converted_state_param_prefix: str,
      config: unet_config.AttentionBlock2DConfig,
  ):
    """Maps one self-attention block into `{prefix}.attention.*`.

    If the checkpoint stores separate Q/K/V projections they are squeezed,
    concatenated along dim 0 into a single fused `qkv_projection`, and the
    intermediate per-projection entries are deleted again.
    """
    if config.normalization_config.type != layers_config.NormalizationType.NONE:
      _map_to_converted_state(
          state,
          tensor_names.norm,
          converted_state,
          f"{converted_state_param_prefix}.norm",
      )
    attention_layer_prefix = f"{converted_state_param_prefix}.attention"
    if tensor_names.fused_qkv_proj is not None:
      # Checkpoint already stores a fused QKV projection; map it directly.
      _map_to_converted_state(
          state,
          tensor_names.fused_qkv_proj,
          converted_state,
          f"{attention_layer_prefix}.qkv_projection",
      )
    else:
      # Map the separate projections first (squeezing 1-sized conv dims so
      # they concatenate as plain matrices), then fuse them.
      _map_to_converted_state(
          state,
          tensor_names.q_proj,
          converted_state,
          f"{attention_layer_prefix}.q_projection",
          squeeze_dims=True,
      )
      _map_to_converted_state(
          state,
          tensor_names.k_proj,
          converted_state,
          f"{attention_layer_prefix}.k_projection",
          squeeze_dims=True,
      )
      _map_to_converted_state(
          state,
          tensor_names.v_proj,
          converted_state,
          f"{attention_layer_prefix}.v_projection",
          squeeze_dims=True,
      )
      converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = (
          torch.concat(
              [
                  converted_state[
                      f"{attention_layer_prefix}.q_projection.weight"
                  ],
                  converted_state[
                      f"{attention_layer_prefix}.k_projection.weight"
                  ],
                  converted_state[
                      f"{attention_layer_prefix}.v_projection.weight"
                  ],
              ],
              axis=0,
          )
      )
      # Drop the per-projection entries; only the fused tensor is kept.
      del converted_state[f"{attention_layer_prefix}.q_projection.weight"]
      del converted_state[f"{attention_layer_prefix}.k_projection.weight"]
      del converted_state[f"{attention_layer_prefix}.v_projection.weight"]
      if config.attention_config.qkv_use_bias:
        converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = (
            torch.concat(
                [
                    converted_state[
                        f"{attention_layer_prefix}.q_projection.bias"
                    ],
                    converted_state[
                        f"{attention_layer_prefix}.k_projection.bias"
                    ],
                    converted_state[
                        f"{attention_layer_prefix}.v_projection.bias"
                    ],
                ],
                axis=0,
            )
        )
        del converted_state[f"{attention_layer_prefix}.q_projection.bias"]
        del converted_state[f"{attention_layer_prefix}.k_projection.bias"]
        del converted_state[f"{attention_layer_prefix}.v_projection.bias"]

    _map_to_converted_state(
        state,
        tensor_names.output_proj,
        converted_state,
        f"{attention_layer_prefix}.output_projection",
        squeeze_dims=True,
    )

  def _map_cross_attention_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      tensor_names: CrossAttentionBlockTensorNames,
      converted_state_param_prefix: str,
      config: unet_config.CrossAttentionBlock2DConfig,
  ):
    """Maps one cross-attention block (always separate Q/K/V, no fusion)."""
    if config.normalization_config.type != layers_config.NormalizationType.NONE:
      _map_to_converted_state(
          state,
          tensor_names.norm,
          converted_state,
          f"{converted_state_param_prefix}.norm",
      )
    attention_layer_prefix = f"{converted_state_param_prefix}.attention"
    _map_to_converted_state(
        state,
        tensor_names.q_proj,
        converted_state,
        f"{attention_layer_prefix}.q_projection",
    )
    _map_to_converted_state(
        state,
        tensor_names.k_proj,
        converted_state,
        f"{attention_layer_prefix}.k_projection",
    )
    _map_to_converted_state(
        state,
        tensor_names.v_proj,
        converted_state,
        f"{attention_layer_prefix}.v_projection",
    )
    _map_to_converted_state(
        state,
        tensor_names.output_proj,
        converted_state,
        f"{attention_layer_prefix}.output_projection",
    )

  def _map_feedforward_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      tensor_names: FeedForwardBlockTensorNames,
      converted_state_param_prefix: str,
      config: unet_config.FeedForwardBlock2DConfig,
  ):
    """Maps one feed-forward block; the first layer depends on activation."""
    _map_to_converted_state(
        state,
        tensor_names.norm,
        converted_state,
        f"{converted_state_param_prefix}.norm",
    )
    # GE_GLU folds the first linear layer into the activation's projection.
    if config.activation_config.type == layers_config.ActivationType.GE_GLU:
      _map_to_converted_state(
          state,
          tensor_names.ge_glu,
          converted_state,
          f"{converted_state_param_prefix}.act.proj",
      )
    else:
      _map_to_converted_state(
          state,
          tensor_names.w1,
          converted_state,
          f"{converted_state_param_prefix}.w1",
      )

    _map_to_converted_state(
        state,
        tensor_names.w2,
        converted_state,
        f"{converted_state_param_prefix}.w2",
    )

  def _map_transformer_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      tensor_names: TransformerBlockTensorNames,
      converted_state_param_prefix: str,
      config: unet_config.TransformerBlock2DConfig,
  ):
    """Maps one transformer block by delegating to the sub-block mappers."""
    _map_to_converted_state(
        state,
        tensor_names.pre_conv_norm,
        converted_state,
        f"{converted_state_param_prefix}.pre_conv_norm",
    )
    _map_to_converted_state(
        state,
        tensor_names.conv_in,
        converted_state,
        f"{converted_state_param_prefix}.conv_in",
    )
    self._map_attention_block(
        state,
        converted_state,
        tensor_names.self_attention,
        f"{converted_state_param_prefix}.self_attention",
        config.attention_block_config,
    )
    self._map_cross_attention_block(
        state,
        converted_state,
        tensor_names.cross_attention,
        f"{converted_state_param_prefix}.cross_attention",
        config.cross_attention_block_config,
    )
    self._map_feedforward_block(
        state,
        converted_state,
        tensor_names.feed_forward,
        f"{converted_state_param_prefix}.feed_forward",
        config.feed_forward_block_config,
    )
    _map_to_converted_state(
        state,
        tensor_names.conv_out,
        converted_state,
        f"{converted_state_param_prefix}.conv_out",
    )

  def _map_mid_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      tensor_names: MidBlockTensorNames,
      converted_state_param_prefix: str,
      config: unet_config.MidBlock2DConfig,
  ):
    """Maps a mid block: resnets.0, then per layer the optional attention
    and/or transformer block followed by resnets.{i+1}.

    Requires num_layers + 1 entries in residual_block_tensor_names.
    """
    # Mid-block resnets keep the channel count constant, so a single config
    # is shared by all of them.
    residual_block_config = unet_config.ResidualBlock2DConfig(
        in_channels=config.in_channels,
        hidden_channels=config.in_channels,
        out_channels=config.in_channels,
        time_embedding_channels=config.time_embedding_channels,
        normalization_config=config.normalization_config,
        activation_config=config.activation_config,
    )
    self._map_residual_block(
        state,
        converted_state,
        tensor_names.residual_block_tensor_names[0],
        f"{converted_state_param_prefix}.resnets.0",
        residual_block_config,
    )
    for i in range(config.num_layers):
      if config.attention_block_config:
        self._map_attention_block(
            state,
            converted_state,
            tensor_names.attention_block_tensor_names[i],
            f"{converted_state_param_prefix}.attentions.{i}",
            config.attention_block_config,
        )
      if config.transformer_block_config:
        self._map_transformer_block(
            state,
            converted_state,
            tensor_names.transformer_block_tensor_names[i],
            f"{converted_state_param_prefix}.transformers.{i}",
            config.transformer_block_config,
        )
      self._map_residual_block(
          state,
          converted_state,
          tensor_names.residual_block_tensor_names[i + 1],
          f"{converted_state_param_prefix}.resnets.{i+1}",
          residual_block_config,
      )

  def _map_down_encoder_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      converted_state_param_prefix: str,
      config: unet_config.DownEncoderBlock2DConfig,
      tensor_names: DownEncoderBlockTensorNames,
  ):
    """Maps a down-encoder block: per-layer resnet (+ optional transformer)
    and, when configured, a convolutional downsampler.

    NOTE(review): parameter order here (config before tensor_names) differs
    from the earlier _map_* helpers.
    """
    for i in range(config.num_layers):
      # Only the first resnet sees the block's input channel count; later
      # layers run at out_channels.
      input_channels = config.in_channels if i == 0 else config.out_channels
      self._map_residual_block(
          state,
          converted_state,
          tensor_names.residual_block_tensor_names[i],
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
              hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
              activation_config=config.activation_config,
          ),
      )
      if config.transformer_block_config:
        self._map_transformer_block(
            state,
            converted_state,
            tensor_names.transformer_block_tensor_names[i],
            f"{converted_state_param_prefix}.transformers.{i}",
            config.transformer_block_config,
        )
    # Only convolutional downsampling has weights to map.
    if (
        config.add_downsample
        and config.sampling_config.mode == unet_config.SamplingType.CONVOLUTION
    ):
      _map_to_converted_state(
          state,
          tensor_names.downsample_conv,
          converted_state,
          f"{converted_state_param_prefix}.downsampler",
      )

  def _map_up_decoder_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      converted_state_param_prefix: str,
      config: unet_config.UpDecoderBlock2DConfig,
      tensor_names: UpDecoderBlockTensorNames,
  ):
    """Maps an up-decoder block: per-layer resnet (+ optional transformer)
    and, when configured, an upsampling convolution."""
    for i in range(config.num_layers):
      input_channels = config.in_channels if i == 0 else config.out_channels
      self._map_residual_block(
          state,
          converted_state,
          tensor_names.residual_block_tensor_names[i],
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
              hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
              activation_config=config.activation_config,
          ),
      )
      if config.transformer_block_config:
        self._map_transformer_block(
            state,
            converted_state,
            tensor_names.transformer_block_tensor_names[i],
            f"{converted_state_param_prefix}.transformers.{i}",
            config.transformer_block_config,
        )
    if config.add_upsample and config.upsample_conv:
      _map_to_converted_state(
          state,
          tensor_names.upsample_conv,
          converted_state,
          f"{converted_state_param_prefix}.upsample_conv",
      )

  def _map_skip_up_decoder_block(
      self,
      state: Dict[str, torch.Tensor],
      converted_state: Dict[str, torch.Tensor],
      converted_state_param_prefix: str,
      config: unet_config.SkipUpDecoderBlock2DConfig,
      # NOTE(review): annotated UpDecoderBlockTensorNames although a
      # field-identical SkipUpDecoderBlockTensorNames exists — the two are
      # interchangeable at runtime; confirm the intended type.
      tensor_names: UpDecoderBlockTensorNames,
  ):
    """Maps an up-decoder block whose resnets also consume UNet skip inputs,
    so each resnet's in_channels includes the concatenated skip channels."""
    for i in range(config.num_layers):
      # The last layer's skip comes from the matching encoder stage at
      # in_channels; earlier skips are at out_channels.
      res_skip_channels = (
          config.in_channels
          if (i == config.num_layers - 1)
          else config.out_channels
      )
      # The first resnet receives the previous decoder block's output.
      resnet_in_channels = (
          config.prev_out_channels if i == 0 else config.out_channels
      )
      self._map_residual_block(
          state,
          converted_state,
          tensor_names.residual_block_tensor_names[i],
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=resnet_in_channels + res_skip_channels,
              hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
              activation_config=config.activation_config,
          ),
      )
      if config.transformer_block_config:
        self._map_transformer_block(
            state,
            converted_state,
            tensor_names.transformer_block_tensor_names[i],
            f"{converted_state_param_prefix}.transformers.{i}",
            config.transformer_block_config,
        )
    if config.add_upsample and config.upsample_conv:
      _map_to_converted_state(
          state,
          tensor_names.upsample_conv,
          converted_state,
          f"{converted_state_param_prefix}.upsample_conv",
      )
+ )
582
+
583
+
584
# `BaseLoader` already covers the CLIP text-encoder checkpoint mapping; this
# alias exists purely for readability at call sites.
ClipModelLoader = BaseLoader
586
+
587
+
588
class AutoEncoderModelLoader(BaseLoader):
  """Checkpoint loader for auto-encoder (VAE encoder/decoder) models."""

  @dataclass
  class TensorNames:
    # Checkpoint tensor name (or name pattern) for each sub-module. A field
    # left as None means the model does not contain that sub-module and the
    # mapping for it is skipped.
    quant_conv: str = None
    post_quant_conv: str = None
    conv_in: str = None
    conv_out: str = None
    final_norm: str = None
    mid_block_tensor_names: MidBlockTensorNames = None
    up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None

  def __init__(self, file_name: str, names: TensorNames):
    """AutoEncoderModelLoader constructor.

    Can be used to load encoder and decoder models.

    Args:
      file_name (str): Path to the checkpoint. Can be a directory or an exact
        file.
      names (TensorNames): An instance of `TensorNames` to determine mappings.
    """
    self._file_name = file_name
    self._names = names
    self._loader = self._get_loader()

  def load(
      self, model: torch.nn.Module, strict: bool = True
  ) -> Tuple[List[str], List[str]]:
    """Load the model from the checkpoint.

    Args:
      model (torch.nn.Module): The pytorch model that needs to be loaded.
      strict (bool, optional): Whether the converted keys are strictly
        matched. Defaults to True.

    Returns:
      missing_keys (List[str]): a list of str containing the missing keys.
      unexpected_keys (List[str]): a list of str containing the unexpected
        keys.

    Raises:
      ValueError: If conversion results in unmapped tensors and strict mode is
        enabled.
    """
    state = self._loader(self._file_name)
    converted_state = dict()
    # Top-level convs/norm are optional: only map the ones named in
    # `TensorNames` (encoder and decoder differ in which they have).
    if self._names.quant_conv is not None:
      _map_to_converted_state(
          state, self._names.quant_conv, converted_state, "quant_conv"
      )
    if self._names.post_quant_conv is not None:
      _map_to_converted_state(
          state, self._names.post_quant_conv, converted_state, "post_quant_conv"
      )
    if self._names.conv_in is not None:
      _map_to_converted_state(
          state, self._names.conv_in, converted_state, "conv_in"
      )
    if self._names.conv_out is not None:
      _map_to_converted_state(
          state, self._names.conv_out, converted_state, "conv_out"
      )
    if self._names.final_norm is not None:
      _map_to_converted_state(
          state, self._names.final_norm, converted_state, "final_norm"
      )
    self._map_mid_block(
        state,
        converted_state,
        self._names.mid_block_tensor_names,
        "mid_block",
        model.config.mid_block_config,
    )

    # Walk the decoder blocks from the innermost (largest channel count)
    # outward, carrying the previous block's output width as the next input.
    reversed_block_out_channels = list(
        reversed(model.config.block_out_channels)
    )
    block_out_channels = reversed_block_out_channels[0]
    for i, out_channels in enumerate(reversed_block_out_channels):
      prev_output_channel = block_out_channels
      block_out_channels = out_channels
      # Every block except the outermost one upsamples its output.
      not_final_block = i < len(reversed_block_out_channels) - 1
      self._map_up_decoder_block(
          state,
          converted_state,
          f"up_decoder_blocks.{i}",
          unet_config.UpDecoderBlock2DConfig(
              in_channels=prev_output_channel,
              out_channels=block_out_channels,
              normalization_config=model.config.normalization_config,
              activation_config=model.config.activation_config,
              num_layers=model.config.layers_per_block,
              add_upsample=not_final_block,
              upsample_conv=True,
          ),
          self._names.up_decoder_blocks_tensor_names[i],
      )
    # Fix: original message had typos ("all tensor", "Remaing tensor are").
    if strict and state:
      raise ValueError(
          f"Failed to map all tensors. Remaining tensors are:"
          f" {list(state.keys())}"
      )
    return model.load_state_dict(converted_state, strict=strict)
691
+
692
+
693
def build_attention_config(
    num_heads,
    dim,
    num_query_groups,
    rotary_percentage=0.0,
    qkv_transpose_before_split=True,
    qkv_use_bias=False,
    output_proj_use_bias=True,
    enable_kv_cache=False,
    qkv_fused_interleaved=False,
):
  """Builds an `AttentionConfig` for the diffusion/auto-encoder blocks.

  Args:
    num_heads: Number of attention heads.
    dim: Total model dimension; split evenly across heads.
    num_query_groups: Number of query groups (for grouped-query attention).
    rotary_percentage: Fraction of head dim using rotary embeddings.
    qkv_transpose_before_split: Whether QKV is transposed before splitting.
    qkv_use_bias: Whether the QKV projection has a bias term.
    output_proj_use_bias: Whether the output projection has a bias term.
    enable_kv_cache: Whether a KV cache is enabled.
    qkv_fused_interleaved: Whether fused QKV weights are interleaved.

  Returns:
    A populated `layers_config.AttentionConfig`.
  """
  # Each head gets an equal slice of the model dimension.
  per_head_dim = dim // num_heads
  return layers_config.AttentionConfig(
      num_heads=num_heads,
      head_dim=per_head_dim,
      num_query_groups=num_query_groups,
      rotary_percentage=rotary_percentage,
      qkv_transpose_before_split=qkv_transpose_before_split,
      qkv_use_bias=qkv_use_bias,
      output_proj_use_bias=output_proj_use_bias,
      enable_kv_cache=enable_kv_cache,
      qkv_fused_interleaved=qkv_fused_interleaved,
  )
716
+
717
+
718
class DiffusionModelLoader(BaseLoader):
  """Checkpoint loader for the diffusion (UNet) model of Stable Diffusion."""

  @dataclass
  class TensorNames:
    # Checkpoint tensor names for the diffusion UNet's sub-modules, grouped
    # per structural section (time embedding, in/out convs, final norm, and
    # per-block name bundles for the down/mid/up paths).
    time_embedding: TimeEmbeddingTensorNames = None
    conv_in: str = None
    conv_out: str = None
    final_norm: str = None
    down_encoder_blocks_tensor_names: List[DownEncoderBlockTensorNames] = None
    mid_block_tensor_names: MidBlockTensorNames = None
    up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
729
+
730
+ def __init__(self, file_name: str, names: TensorNames):
731
+ """DiffusionModelLoader constructor.
732
+
733
+ Can be used to load diffusion models of Stable Diffusion.
734
+
735
+ Args:
736
+ file_name (str): Path to the checkpoint. Can be a directory or an exact
737
+ file.
738
+ names (TensorNames): An instance of `TensorNames` to determine mappings.
739
+ """
740
+ self._file_name = file_name
741
+ self._names = names
742
+ self._loader = self._get_loader()
743
+
744
+ def load(
745
+ self, model: torch.nn.Module, strict: bool = True
746
+ ) -> Tuple[List[str], List[str]]:
747
+ """Load the model from the checkpoint.
748
+
749
+ Args:
750
+ model (torch.nn.Module): The pytorch model that needs to be loaded.
751
+ strict (bool, optional): Whether the converted keys are strictly
752
+ matched. Defaults to True.
753
+
754
+ Returns:
755
+ missing_keys (List[str]): a list of str containing the missing keys.
756
+ unexpected_keys (List[str]): a list of str containing the unexpected
757
+ keys.
758
+
759
+ Raises:
760
+ ValueError: If conversion results in unmapped tensors and strict mode is
761
+ enabled.
762
+ """
763
+ state = self._loader(self._file_name)
764
+ converted_state = dict()
765
+ config: unet_config.DiffusionModelConfig = model.config
766
+ self._map_time_embedding(
767
+ state, converted_state, "time_embedding", self._names.time_embedding
768
+ )
769
+ _map_to_converted_state(
770
+ state, self._names.conv_in, converted_state, "conv_in"
771
+ )
772
+ _map_to_converted_state(
773
+ state, self._names.conv_out, converted_state, "conv_out"
774
+ )
775
+ _map_to_converted_state(
776
+ state, self._names.final_norm, converted_state, "final_norm"
777
+ )
778
+
779
+ # Map down_encoders.
780
+ output_channel = config.block_out_channels[0]
781
+ for i, block_out_channel in enumerate(config.block_out_channels):
782
+ input_channel = output_channel
783
+ output_channel = block_out_channel
784
+ not_final_block = i < len(config.block_out_channels) - 1
785
+ if not_final_block:
786
+ down_encoder_block_config = unet_config.DownEncoderBlock2DConfig(
787
+ in_channels=input_channel,
788
+ out_channels=output_channel,
789
+ normalization_config=config.residual_norm_config,
790
+ activation_config=layers_config.ActivationConfig(
791
+ config.residual_activation_type
792
+ ),
793
+ num_layers=config.layers_per_block,
794
+ padding=config.downsample_padding,
795
+ time_embedding_channels=config.time_embedding_blocks_dim,
796
+ add_downsample=True,
797
+ sampling_config=unet_config.DownSamplingConfig(
798
+ mode=unet_config.SamplingType.CONVOLUTION,
799
+ in_channels=output_channel,
800
+ out_channels=output_channel,
801
+ kernel_size=3,
802
+ stride=2,
803
+ padding=config.downsample_padding,
804
+ ),
805
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
806
+ attention_block_config=unet_config.AttentionBlock2DConfig(
807
+ dim=output_channel,
808
+ normalization_config=config.transformer_norm_config,
809
+ attention_config=build_attention_config(
810
+ num_heads=config.transformer_num_attention_heads,
811
+ dim=output_channel,
812
+ num_query_groups=config.transformer_num_attention_heads,
813
+ ),
814
+ ),
815
+ cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
816
+ query_dim=output_channel,
817
+ cross_dim=config.transformer_cross_attention_dim,
818
+ hidden_dim=output_channel,
819
+ output_dim=output_channel,
820
+ normalization_config=config.transformer_norm_config,
821
+ attention_config=build_attention_config(
822
+ num_heads=config.transformer_num_attention_heads,
823
+ dim=output_channel,
824
+ num_query_groups=config.transformer_num_attention_heads,
825
+ ),
826
+ ),
827
+ pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
828
+ feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
829
+ dim=output_channel,
830
+ hidden_dim=output_channel * 4,
831
+ normalization_config=config.transformer_norm_config,
832
+ activation_config=layers_config.ActivationConfig(
833
+ type=config.transformer_ff_activation_type,
834
+ dim_in=output_channel,
835
+ dim_out=output_channel * 4,
836
+ ),
837
+ use_bias=True,
838
+ ),
839
+ ),
840
+ )
841
+ else:
842
+ down_encoder_block_config = unet_config.DownEncoderBlock2DConfig(
843
+ in_channels=input_channel,
844
+ out_channels=output_channel,
845
+ normalization_config=config.residual_norm_config,
846
+ activation_config=layers_config.ActivationConfig(
847
+ config.residual_activation_type
848
+ ),
849
+ num_layers=config.layers_per_block,
850
+ padding=config.downsample_padding,
851
+ time_embedding_channels=config.time_embedding_blocks_dim,
852
+ add_downsample=False,
853
+ )
854
+
855
+ self._map_down_encoder_block(
856
+ state,
857
+ converted_state,
858
+ f"down_encoders.{i}",
859
+ down_encoder_block_config,
860
+ self._names.down_encoder_blocks_tensor_names[i],
861
+ )
862
+
863
+ # Map mid block.
864
+ mid_block_channels = config.block_out_channels[-1]
865
+ mid_block_config = unet_config.MidBlock2DConfig(
866
+ in_channels=mid_block_channels,
867
+ normalization_config=config.residual_norm_config,
868
+ activation_config=layers_config.ActivationConfig(
869
+ config.residual_activation_type
870
+ ),
871
+ num_layers=config.mid_block_layers,
872
+ time_embedding_channels=config.time_embedding_blocks_dim,
873
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
874
+ attention_block_config=unet_config.AttentionBlock2DConfig(
875
+ dim=mid_block_channels,
876
+ normalization_config=config.transformer_norm_config,
877
+ attention_config=build_attention_config(
878
+ num_heads=config.transformer_num_attention_heads,
879
+ dim=mid_block_channels,
880
+ num_query_groups=config.transformer_num_attention_heads,
881
+ ),
882
+ ),
883
+ cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
884
+ query_dim=mid_block_channels,
885
+ cross_dim=config.transformer_cross_attention_dim,
886
+ hidden_dim=mid_block_channels,
887
+ output_dim=mid_block_channels,
888
+ normalization_config=config.transformer_norm_config,
889
+ attention_config=build_attention_config(
890
+ num_heads=config.transformer_num_attention_heads,
891
+ dim=mid_block_channels,
892
+ num_query_groups=config.transformer_num_attention_heads,
893
+ ),
894
+ ),
895
+ pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
896
+ feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
897
+ dim=mid_block_channels,
898
+ hidden_dim=mid_block_channels * 4,
899
+ normalization_config=config.transformer_norm_config,
900
+ activation_config=layers_config.ActivationConfig(
901
+ type=config.transformer_ff_activation_type,
902
+ dim_in=mid_block_channels,
903
+ dim_out=mid_block_channels * 4,
904
+ ),
905
+ use_bias=True,
906
+ ),
907
+ ),
908
+ )
909
+ self._map_mid_block(
910
+ state,
911
+ converted_state,
912
+ self._names.mid_block_tensor_names,
913
+ "mid_block",
914
+ mid_block_config,
915
+ )
916
+
917
+ # Map up_decoders.
918
+ reversed_block_out_channels = list(
919
+ reversed(model.config.block_out_channels)
920
+ )
921
+ up_decoder_layers_per_block = config.layers_per_block + 1
922
+ output_channel = reversed_block_out_channels[0]
923
+ for i, block_out_channel in enumerate(reversed_block_out_channels):
924
+ prev_out_channel = output_channel
925
+ output_channel = block_out_channel
926
+ input_channel = reversed_block_out_channels[
927
+ min(i + 1, len(reversed_block_out_channels) - 1)
928
+ ]
929
+ not_final_block = i < len(reversed_block_out_channels) - 1
930
+ not_first_block = i != 0
931
+ if not_first_block:
932
+ up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig(
933
+ in_channels=input_channel,
934
+ out_channels=output_channel,
935
+ prev_out_channels=prev_out_channel,
936
+ normalization_config=config.residual_norm_config,
937
+ activation_config=layers_config.ActivationConfig(
938
+ config.residual_activation_type
939
+ ),
940
+ num_layers=up_decoder_layers_per_block,
941
+ time_embedding_channels=config.time_embedding_blocks_dim,
942
+ add_upsample=not_final_block,
943
+ upsample_conv=True,
944
+ sampling_config=unet_config.UpSamplingConfig(
945
+ mode=unet_config.SamplingType.NEAREST,
946
+ scale_factor=2,
947
+ ),
948
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
949
+ attention_block_config=unet_config.AttentionBlock2DConfig(
950
+ dim=output_channel,
951
+ normalization_config=config.transformer_norm_config,
952
+ attention_config=build_attention_config(
953
+ num_heads=config.transformer_num_attention_heads,
954
+ dim=output_channel,
955
+ num_query_groups=config.transformer_num_attention_heads,
956
+ ),
957
+ ),
958
+ cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
959
+ query_dim=output_channel,
960
+ cross_dim=config.transformer_cross_attention_dim,
961
+ hidden_dim=output_channel,
962
+ output_dim=output_channel,
963
+ normalization_config=config.transformer_norm_config,
964
+ attention_config=build_attention_config(
965
+ num_heads=config.transformer_num_attention_heads,
966
+ dim=output_channel,
967
+ num_query_groups=config.transformer_num_attention_heads,
968
+ ),
969
+ ),
970
+ pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
971
+ feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
972
+ dim=output_channel,
973
+ hidden_dim=output_channel * 4,
974
+ normalization_config=config.transformer_norm_config,
975
+ activation_config=layers_config.ActivationConfig(
976
+ type=config.transformer_ff_activation_type,
977
+ dim_in=output_channel,
978
+ dim_out=output_channel * 4,
979
+ ),
980
+ use_bias=True,
981
+ ),
982
+ ),
983
+ )
984
+ else:
985
+ up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig(
986
+ in_channels=input_channel,
987
+ out_channels=output_channel,
988
+ prev_out_channels=prev_out_channel,
989
+ normalization_config=config.residual_norm_config,
990
+ activation_config=layers_config.ActivationConfig(
991
+ config.residual_activation_type
992
+ ),
993
+ num_layers=up_decoder_layers_per_block,
994
+ time_embedding_channels=config.time_embedding_blocks_dim,
995
+ add_upsample=not_final_block,
996
+ upsample_conv=True,
997
+ sampling_config=unet_config.UpSamplingConfig(
998
+ mode=unet_config.SamplingType.NEAREST, scale_factor=2
999
+ ),
1000
+ )
1001
+ self._map_skip_up_decoder_block(
1002
+ state,
1003
+ converted_state,
1004
+ f"up_decoders.{i}",
1005
+ up_encoder_block_config,
1006
+ self._names.up_decoder_blocks_tensor_names[i],
1007
+ )
1008
+ if strict and state:
1009
+ raise ValueError(
1010
+ f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
1011
+ )
1012
+ return model.load_state_dict(converted_state, strict=strict)
1013
+
1014
+ def _map_time_embedding(
1015
+ self,
1016
+ state: Dict[str, torch.Tensor],
1017
+ converted_state: Dict[str, torch.Tensor],
1018
+ converted_state_param_prefix: str,
1019
+ tensor_names: TimeEmbeddingTensorNames,
1020
+ ):
1021
+ _map_to_converted_state(
1022
+ state,
1023
+ tensor_names.w1,
1024
+ converted_state,
1025
+ f"{converted_state_param_prefix}.w1",
1026
+ )
1027
+ _map_to_converted_state(
1028
+ state,
1029
+ tensor_names.w2,
1030
+ converted_state,
1031
+ f"{converted_state_param_prefix}.w2",
1032
+ )