diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_conversion_utils.py
@@ -14,6 +14,8 @@
 
 import re
 
+import torch
+
 from ..utils import is_peft_version, logging
 
 
@@ -326,3 +328,648 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
         prefix = "text_encoder_2."
     new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
     return {new_name: alpha}
+
+
+# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
+# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+# All credits go to `kohya-ss`.
+def _convert_kohya_flux_lora_to_diffusers(state_dict):
+    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+
+        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
+        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
+    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+        sd_lora_rank = down_weight.shape[0]
+
+        # scale weight by alpha and dim
+        alpha = sds_sd.pop(sds_key + ".alpha")
+        scale = alpha / sd_lora_rank
+
+        # calculate scale_down and scale_up
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        # calculate dims if not provided
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # check upweight is sparse or not
+        is_sparse = False
+        if sd_lora_rank % num_splits == 0:
+            ait_rank = sd_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {sds_key}")
+
+        # make ai-toolkit weight
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+        if not is_sparse:
+            # down_weight is copied to each split
+            ait_sd.update({k: down_weight for k in ait_down_keys})
+
+            # up_weight is split to each split
+            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+        else:
+            # down_weight is chunked to each split
+            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            i = 0
+            for j in range(len(dims)):
+                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+                i += dims[j]
+
+    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
+        ait_sd = {}
+        for i in range(19):
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_out.0",
+            )
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.to_q",
+                    f"transformer.transformer_blocks.{i}.attn.to_k",
+                    f"transformer.transformer_blocks.{i}.attn.to_v",
+                ],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff.net.2",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1.linear",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_add_out",
+            )
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+                ],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff_context.net.2",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1_context.linear",
+            )
+
+        for i in range(38):
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_linear1",
+                [
+                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
+                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
+                ],
+                dims=[3072, 3072, 3072, 12288],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_linear2",
+                f"transformer.single_transformer_blocks.{i}.proj_out",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_modulation_lin",
+                f"transformer.single_transformer_blocks.{i}.norm.linear",
+            )
+
+        remaining_keys = list(sds_sd.keys())
+        te_state_dict = {}
+        if remaining_keys:
+            if not all(k.startswith("lora_te1") for k in remaining_keys):
+                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
+            for key in remaining_keys:
+                if not key.endswith("lora_down.weight"):
+                    continue
+
+                lora_name = key.split(".")[0]
+                lora_name_up = f"{lora_name}.lora_up.weight"
+                lora_name_alpha = f"{lora_name}.alpha"
+                diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    down_weight = sds_sd.pop(key)
+                    sd_lora_rank = down_weight.shape[0]
+                    te_state_dict[diffusers_name] = down_weight
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
+
+                if lora_name_alpha in sds_sd:
+                    alpha = sds_sd.pop(lora_name_alpha).item()
+                    scale = alpha / sd_lora_rank
+
+                    scale_down = scale
+                    scale_up = 1.0
+                    while scale_down * 2 < scale_up:
+                        scale_down *= 2
+                        scale_up /= 2
+
+                    te_state_dict[diffusers_name] *= scale_down
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
+
+        if len(sds_sd) > 0:
+            logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
+
+        if te_state_dict:
+            te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
+
+        new_state_dict = {**ait_sd, **te_state_dict}
+        return new_state_dict
+
+    return _convert_sd_scripts_to_ai_toolkit(state_dict)
+
+
+# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
+# Some utilities were reused from
+# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
+    new_state_dict = {}
+    orig_keys = list(old_state_dict.keys())
+
+    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        down_weight = sds_sd.pop(sds_key)
+        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
+
+        # calculate dims if not provided
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # make ai-toolkit weight
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+
+        # down_weight is copied to each split
+        ait_sd.update({k: down_weight for k in ait_down_keys})
+
+        # up_weight is split to each split
+        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+
+    for old_key in orig_keys:
+        # Handle double_blocks
+        if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
+            block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.transformer_blocks.{block_num}"
+
+            if "processor.proj_lora1" in old_key:
+                new_key += ".attn.to_out.0"
+            elif "processor.proj_lora2" in old_key:
+                new_key += ".attn.to_add_out"
+            # Handle text latents.
+            elif "processor.qkv_lora2" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
+                    ],
+                )
+                # continue
+            # Handle image latents.
+            elif "processor.qkv_lora1" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
+                # continue
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        # Handle single_blocks
+        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
+            block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.single_transformer_blocks.{block_num}"
+
+            if "proj_lora" in old_key:
+                new_key += ".proj_out"
+            elif "qkv_lora" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        else:
+            # Handle other potential key patterns here
+            new_key = old_key
+
+        # Since we already handle qkv above.
+        if "qkv" not in old_key:
+            new_state_dict[new_key] = old_state_dict.pop(old_key)
+
+    if len(old_state_dict) > 0:
+        raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
+
+    return new_state_dict
+
+
+def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    for lora_key in ["lora_A", "lora_B"]:
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+        if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+        if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.in_layer.{lora_key}.weight"
+        )
+        if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.in_layer.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.out_layer.{lora_key}.weight"
+        )
+        if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.out_layer.{lora_key}.bias"
+            )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in original_state_dict)
+        if has_guidance:
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+            if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+            if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+
+        # context_embedder
+        converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+            f"txt_in.{lora_key}.weight"
+        )
+        if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
+                f"txt_in.{lora_key}.bias"
+            )
+
+        # x_embedder
+        converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
+        if f"img_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+
+    # single transfomer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"final_layer.linear.{lora_key}.weight"
+        )
+        if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"final_layer.linear.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+            original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
+        )
+        if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
+                original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
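
Note: the converters added in this file are not usually called directly. They are dispatched internally when a Flux LoRA checkpoint in kohya-ss/sd-scripts, XLabs, or BFL Flux Control key layout is handed to the standard LoRA-loading entry point, which remaps the state dict to the diffusers/PEFT naming scheme before injecting the adapter. A minimal sketch of that flow, assuming a local checkpoint whose path is a placeholder rather than a file shipped with the package:

import torch
from diffusers import FluxPipeline

# Load the base Flux pipeline; bfloat16 keeps the memory footprint manageable.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# load_lora_weights() inspects the checkpoint's key names and, for non-diffusers
# formats (e.g. kohya-ss "lora_unet_double_blocks_*" keys), converts the state
# dict with utilities like the ones shown in this diff before loading the adapter.
pipe.load_lora_weights("path/to/flux_lora.safetensors")  # placeholder path

image = pipe("a tiny astronaut hatching from an egg on the moon", num_inference_steps=28).images[0]
image.save("flux_lora_sample.png")

The kohya converter also folds the checkpoint's alpha / rank scaling directly into the converted lora_A/lora_B matrices (see _convert_to_ai_toolkit above), so the loaded adapter reproduces the scaling it had under sd-scripts without carrying a separate alpha entry.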